[PYTHON] Median filter using xarray (median filter)


I used xarray to apply filters such as a moving average, so I am writing down how it works for future reference.

What I needed was to apply a moving average (or a moving median filter) along only one axis of a multidimensional array. Image data is used here as an example. Image data would normally be handled with an image-processing library, but I had no more suitable data at hand, so please bear with this.

import numpy as np
import xarray as xr
%matplotlib inline
import matplotlib.pyplot as plt
from PIL import Image
image = np.array(Image.open('lena.jpg'))
plt.imshow(image)

[Figure: the loaded image, output of plt.imshow(image)]

Next, shrink the image by keeping every third pixel and add salt-and-pepper noise.

small_image = image[::3, ::3]
noisy_image = small_image.copy()
# pepper: set 1000 randomly chosen (row, column, channel) values to 0
noisy_image[np.random.randint(0, noisy_image.shape[0], 1000),
            np.random.randint(0, noisy_image.shape[1], 1000),
            np.random.randint(0, noisy_image.shape[2], 1000)] = 0
# salt: set 1000 randomly chosen values to 255, the uint8 maximum (256 does not fit in uint8)
noisy_image[np.random.randint(0, noisy_image.shape[0], 1000),
            np.random.randint(0, noisy_image.shape[1], 1000),
            np.random.randint(0, noisy_image.shape[2], 1000)] = 255
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title('original')
plt.imshow(small_image)
plt.subplot(1, 2, 2)
plt.title('noisy')
plt.imshow(noisy_image)

[Figure: 'original' (left) and 'noisy' (right)]
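The noise step can also be written as a small reusable function. This is a minimal sketch, not the author's code: the helper name `add_salt_and_pepper` is hypothetical, it uses NumPy's `default_rng` Generator instead of `np.random.randint`, and a flat gray synthetic image stands in for lena.jpg. As in the code above, each noisy value lands in one randomly chosen channel.

```python
import numpy as np

def add_salt_and_pepper(image: np.ndarray, n: int, seed: int = 0) -> np.ndarray:
    """Hypothetical helper: return a copy of a uint8 (H, W, C) image with
    n values set to 0 ("pepper") and n values set to 255 ("salt")."""
    rng = np.random.default_rng(seed)
    noisy = image.copy()
    for value in (0, 255):  # 255 is the largest value a uint8 can hold
        # one random index array per axis: rows, columns, channels
        idx = tuple(rng.integers(0, size, n) for size in image.shape)
        noisy[idx] = value
    return noisy

# Usage with a synthetic gray image of the same size as small_image
gray = np.full((171, 171, 3), 128, dtype=np.uint8)
noisy = add_salt_and_pepper(gray, 1000)
print(noisy.dtype)  # uint8
```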

Store the noisy image in an xr.DataArray, naming the vertical, horizontal, and color axes x, y, and c, respectively.

data = xr.DataArray(noisy_image, dims=['x', 'y', 'c'])

To take a moving average along x, use the rolling method: pass the dimension name as the keyword and the window size as its value.

rolling = data.rolling(x=3)  # 3-pixel window along x
rolling
DataArrayRolling [window->3,center->False,dim->x]
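The rolling object itself only holds the window settings shown in its repr; values are produced by a reduction. As a quick check that only the x axis is smoothed, here is a sketch on a small random array standing in for the image (an assumption, not the lena data): each output row i from 2 onward should equal the plain average of input rows i-2, i-1 and i, while y and c are untouched.

```python
import numpy as np
import xarray as xr

# Small stand-in for the image: 6 rows (x), 4 columns (y), 3 channels (c)
rng = np.random.default_rng(0)
arr = rng.integers(0, 256, size=(6, 4, 3)).astype(float)
da = xr.DataArray(arr, dims=["x", "y", "c"])

# Moving mean over a 3-pixel window along x only
smoothed = da.rolling(x=3).mean()

# Manual check: output row i (i >= 2) averages input rows i-2, i-1 and i
manual = (arr[:-2] + arr[1:-1] + arr[2:]) / 3.0

print(smoothed.shape)                            # (6, 4, 3): same shape as the input
print(np.allclose(smoothed.values[2:], manual))  # True
```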

The rolling object provides reductions such as mean, median, max, min, and count.
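To see what each reduction computes, here is a sketch on a short made-up 1-D signal (the values are arbitrary, chosen only for illustration). Each position gets the reduction of the 3-element window ending there, and the first two positions, whose windows are incomplete, become NaN.

```python
import numpy as np
import xarray as xr

# A 1-D signal with a named dimension "x"
signal = xr.DataArray([1.0, 5.0, 2.0, 8.0, 3.0], dims=["x"])
r = signal.rolling(x=3)

# Apply several reductions to the same rolling windows
results = {
    "mean": r.mean().values,
    "median": r.median().values,
    "max": r.max().values,
    "min": r.min().values,
}
for name, values in results.items():
    print(name, values)
```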

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title('mean')
plt.imshow(rolling.mean().astype('ubyte'))  # imshow expects RGB floats in [0, 1], so cast to ubyte (uint8, 0-255)
plt.subplot(1, 2, 2)
plt.title('median')
plt.imshow(rolling.median().astype('ubyte'))

[Figure: 'mean' (left) and 'median' (right) filtered along x]
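Why the median handles salt-and-pepper noise better than the mean can be seen on a single made-up row with one "salt" pixel (a synthetic example, not taken from the image): the mean spreads the spike over every window that contains it, while the median of three values ignores one outlier.

```python
import numpy as np
import xarray as xr

# A flat row of value 10 with one salt pixel (255) in the middle
row = xr.DataArray([10.0, 10.0, 10.0, 255.0, 10.0, 10.0, 10.0], dims=["x"])

mean3 = row.rolling(x=3).mean().values
median3 = row.rolling(x=3).median().values

print(mean3)    # the spike is smeared over three positions
print(median3)  # the spike is removed
```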

If you look closely, the top two rows of pixels are black. With a 3-pixel window, the first two rows do not have 3 values available, so the moving average there is np.nan, which shows up as black after the cast to ubyte.

rolling.mean()
<xarray.DataArray (x: 171, y: 171, c: 3)>
array([[[        nan,         nan,         nan],
        [        nan,         nan,         nan],
        ..., 
        [        nan,         nan,         nan],
        [        nan,         nan,         nan]],

       [[        nan,         nan,         nan],
        [        nan,         nan,         nan],
        ..., 
        [        nan,         nan,         nan],
        [        nan,         nan,         nan]],

       ..., 
       [[  91.      ,   27.333333,   63.333333],
        [  93.666667,   26.666667,   61.      ],
        ..., 
        [ 125.666667,   25.333333,   70.      ],
        [ 143.      ,   48.333333,   69.333333]],

       [[  87.333333,   26.333333,   62.666667],
        [  91.333333,   25.333333,   58.666667],
        ..., 
        [ 149.      ,   39.333333,   79.666667],
        [ 162.333333,   60.333333,   73.666667]]])
Dimensions without coordinates: x, y, c
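More generally, with the default settings a window of size w leaves the first w - 1 positions along the rolled dimension as NaN. A small sketch to confirm this (the helper `count_nan_rows` is hypothetical, written only for this check, and uses an array of ones rather than the image):

```python
import numpy as np
import xarray as xr

def count_nan_rows(n_rows: int, window: int) -> int:
    """Hypothetical helper: count the all-NaN rows of a rolling mean along x."""
    da = xr.DataArray(np.ones((n_rows, 2)), dims=["x", "y"])
    result = da.rolling(x=window).mean()
    nan_rows = np.isnan(result.values).all(axis=1)
    return int(nan_rows.sum())

print(count_nan_rows(10, 3))  # 2
print(count_nan_rows(10, 5))  # 4
```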

To avoid this, set the minimum number of values a window needs, e.g. rolling(x=3, min_periods=1). With min_periods=1, the edge rows are computed from whatever pixels are available: the first row from 1 pixel (i.e. its own value) and the second from 2.

plt.imshow(data.rolling(x=3, min_periods=1).median().astype('ubyte'))

[Figure: moving median along x with min_periods=1]
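On the same kind of made-up 1-D signal, min_periods=1 fills the edges instead of producing NaN. This sketch shows the first value equal to itself and the second computed from two values; everything from the third position on matches the default behavior.

```python
import numpy as np
import xarray as xr

signal = xr.DataArray([1.0, 5.0, 2.0, 8.0, 3.0], dims=["x"])

# With min_periods=1, incomplete windows at the start use whatever values exist
mean_mp1 = signal.rolling(x=3, min_periods=1).mean().values
median_mp1 = signal.rolling(x=3, min_periods=1).median().values

print(mean_mp1)    # first value is 1.0, second is mean(1, 5) = 3.0
print(median_mp1)  # first value is 1.0, second is median(1, 5) = 3.0
```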

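If you look closely, the filtered image appears shifted downward. By default each output row is computed from that row and the rows above it (a trailing window), so features move down by (window - 1) / 2 rows on average. Enlarging a 30x30 region, here with a 4-pixel window, makes the shift easier to see; the dashed line marks the same row in both panels: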

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(data.rolling(x=4, min_periods=1).median().astype('ubyte')[70:100, 70:100])
plt.plot([0, 30], [15, 15], '--k')
plt.subplot(1, 2, 2)
plt.imshow(small_image[70:100, 70:100])
plt.plot([0, 30], [15, 15], '--k')

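[Figure: region [70:100, 70:100], median along x with a 4-pixel window (left) and original (right), dashed line at row 15]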
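The shift can be reproduced in one dimension. This sketch assumes a synthetic step edge instead of the image and a 3-pixel window (rather than the 4 used above) so that the shift, (3 - 1) / 2 = 1, is a whole row: the first bright row moves from index 5 to index 6.

```python
import numpy as np
import xarray as xr

# A vertical "edge": rows 0-4 are dark (0), rows 5-9 are bright (10)
step = xr.DataArray([0.0] * 5 + [10.0] * 5, dims=["x"])

# Default (trailing) window of 3: row i is the median of rows i-2, i-1, i
trailing = step.rolling(x=3, min_periods=1).median().values

print(trailing)
print(np.argmax(trailing > 0))  # 6: the edge now starts one row later
```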

To remove this shift, pass center=True, which centers each window on the row it is assigned to.

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(data.rolling(x=4, min_periods=1, center=True).median().astype('ubyte')[70:100, 70:100])
plt.plot([0, 30], [15, 15], '--k')
plt.subplot(1, 2, 2)
plt.imshow(small_image[70:100, 70:100])
plt.plot([0, 30], [15, 15], '--k')

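[Figure: region [70:100, 70:100], median along x with a 4-pixel window and center=True (left) and original (right), dashed line at row 15]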
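A 1-D sketch of the effect of center=True, assuming a synthetic step edge and a 3-pixel window: with the window centered, row i is the median of rows i-1, i and i+1, so the edge stays where it was, while the trailing default moves it. With an even window such as the 4 used above, a perfectly symmetric window is impossible, so a half-row offset can remain.

```python
import numpy as np
import xarray as xr

# A vertical "edge": rows 0-4 are dark (0), rows 5-9 are bright (10)
step = xr.DataArray([0.0] * 5 + [10.0] * 5, dims=["x"])

trailing = step.rolling(x=3, min_periods=1).median().values
centered = step.rolling(x=3, center=True, min_periods=1).median().values

print(np.array_equal(trailing, step.values))  # False: edge shifted by one row
print(np.array_equal(centered, step.values))  # True: edge stays put
```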


Recommended Posts

Median filter using xarray (median filter)
One-dimensional median filter (median filter)
Low-pass filter using closure
Write FizzBuzz using map (), reduce (), filter (), recursion