[Python] Thorough understanding of col2im

Target audience

This article is for anyone who wants to know more about the `col2im` function that appears in image recognition with CNNs. Using GIFs and images, it thoroughly explains everything from the initial implementation to the improved version, the batch/channel-compatible version, and the stride/padding-compatible version.

Table of contents

- What is col2im?
- Behavior and initial implementation of col2im
- Improvement of col2im
- Completed version col2im
- A little question

What is col2im?

The `col2im` function is the counterpart of the `im2col` function and, like it, is indispensable in fields such as image recognition. Its role is the opposite of `im2col`: during **forward propagation**, `im2col` converts a **tensor $\rightarrow$ matrix**, whereas during **backpropagation**, `col2im` converts a **matrix $\rightarrow$ tensor**. This puts the data back into a shape suited to learning, for example the shape of a filter.
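To make the forward/backward relationship concrete, here is a minimal sketch under simplifying assumptions: stride 1, no padding, a single 2-D image, and the column layout used by the code in this article (one row per position inside the filter, one column per output position). The helpers `naive_im2col` and `naive_col2im` are hypothetical names for this illustration only. The check confirms numerically that `col2im` is the transpose (adjoint) of `im2col`, i.e. $\langle \textrm{im2col}(x), G \rangle = \langle x, \textrm{col2im}(G) \rangle$, which is exactly the property needed to pass gradients backward.

```python
import numpy as np


def naive_im2col(x, F_h, F_w):
    """Hypothetical helper (stride 1, no padding).
    Row = position inside the filter, column = output position (h, w)."""
    I_h, I_w = x.shape
    O_h, O_w = I_h - F_h + 1, I_w - F_w + 1
    cols = np.empty((F_h * F_w, O_h * O_w))
    for h in range(O_h):
        for w in range(O_w):
            # Flatten the F_h x F_w patch whose top-left corner is (h, w)
            cols[:, h * O_w + w] = x[h:h + F_h, w:w + F_w].ravel()
    return cols


def naive_col2im(cols, I_shape, F_h, F_w):
    """Hypothetical helper: inverse layout of naive_im2col; overlaps are summed."""
    I_h, I_w = I_shape
    O_h, O_w = I_h - F_h + 1, I_w - F_w + 1
    images = np.zeros(I_shape)
    for h in range(O_h):
        for w in range(O_w):
            images[h:h + F_h, w:w + F_w] += cols[:, h * O_w + w].reshape(F_h, F_w)
    return images


rng = np.random.default_rng(0)
x = rng.standard_normal((5, 6))        # input "image"
F_h, F_w = 2, 3                        # filter size
cols = naive_im2col(x, F_h, F_w)       # tensor -> matrix (forward)
G = rng.standard_normal(cols.shape)    # an arbitrary upstream gradient w.r.t. cols
# Adjoint check: <im2col(x), G> == <x, col2im(G)>
lhs = np.sum(cols * G)
rhs = np.sum(x * naive_col2im(G, x.shape, F_h, F_w))
print(np.isclose(lhs, rhs))  # True
```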

Behavior and initial implementation of col2im

Let's start with the initial implementation, assuming

$$
\textrm{stride} = 1, \quad \textrm{pad} = 0
$$

Since the operation is the reverse of the `im2col` function, it looks like this:

![col2im_image.gif](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/640911/7757c2b4-43a7-ef74-8e18-649b6caf3766.gif)

Note that **the overlapping parts are added together**. The reason becomes clear if you think about how filtering works. Focusing on a single input element, the elements of the next layer that it affects through filtering are shown in the figure below.

![col2im_NN.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/640911/d663724c-caca-4b68-72af-736dd7cacb7c.png)

In other words, the element **branches out into several elements**. This means that **the gradients flowing back along these branches during backpropagation must be added together**. That is why the `col2im` function has to add the overlapping parts.
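The summation can be seen directly in a small sketch. Assuming stride 1, no padding, and the layout used by the code in this article (one row per position inside the filter, one column per output position), feeding a `cols` matrix of all ones makes `col2im` return how many filter windows cover each input element, i.e. how many gradient contributions each element receives. `count_coverage` is a hypothetical name used only for this illustration.

```python
import numpy as np


def count_coverage(I_h, I_w, F_h, F_w):
    """Hypothetical helper: col2im applied to an all-ones cols matrix
    (stride 1, no padding) counts how many windows touch each input element."""
    O_h, O_w = I_h - F_h + 1, I_w - F_w + 1
    cols = np.ones((F_h * F_w, O_h * O_w))   # gradient 1 along every branch
    images = np.zeros((I_h, I_w))
    for h in range(O_h):
        for w in range(O_w):
            # Overlapping patches are accumulated, not overwritten
            images[h:h + F_h, w:w + F_w] += cols[:, h * O_w + w].reshape(F_h, F_w)
    return images


print(count_coverage(4, 4, 2, 2))
# [[1. 2. 2. 1.]
#  [2. 4. 4. 2.]
#  [2. 4. 4. 2.]
#  [1. 2. 2. 1.]]
```

With a $2 \times 2$ filter, each interior element feeds 4 output elements, so 4 gradients are summed there; corners feed only one.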

Now, let's implement this logic in the most straightforward way.

Initial `col2im`

col2im.py


import numpy as np


def col2im(cols, I_shape, O_shape):
    # Filter size from O = I - F + 1
    def get_f_shape(i, o):
        return i - o + 1
    
    I_h, I_w = I_shape
    O_h, O_w = O_shape
    F_h = get_f_shape(I_h, O_h)
    F_w = get_f_shape(I_w, O_w)
    images = np.zeros((I_h, I_w))
    
    # One iteration per output position (h, w)
    for h in range(O_h):
        h_lim = h + F_h
        for w in range(O_w):
            w_lim = w + F_w
            # Column h*O_w + w holds the patch for output (h, w); overlaps are summed
            images[h:h_lim, w:w_lim] += cols[:, h*O_w+w].reshape(F_h, F_w)
    
    return images


# im2col is the function from "Thorough understanding of im2col" (not defined here)
x = np.ones((4, 4))
f = np.arange(-2*2, 0).reshape(2, 2)
im2col_x, O_shape = im2col(x, f, pad=0, get_out_size=True)
im2col_f, Of_shape = im2col(f, f, get_out_size=True)
print(im2col_x)
print(im2col_f)
print(col2im(im2col_x, x.shape, O_shape))
print(col2im(im2col_f, f.shape, Of_shape))

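That's the idea. First, we prepare an array with the shape after the transformation; then, for each output position, we take the corresponding column of `cols`, reshape it into the filter shape, and add it into place. The filter shape is obtained from the relation between the input, output, and filter sizes in `im2col`,

$$
\begin{aligned}
O_h &= I_h - F_h + 1 \\
O_w &= I_w - F_w + 1
\end{aligned}
$$

which gives $F_h = I_h - O_h + 1$ and $F_w = I_w - O_w + 1$.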
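As a quick arithmetic check of this rearrangement: the pairs (4, 2) and (2, 2) below correspond to the demo's `x` (a $4 \times 4$ input) and `f` (the $2 \times 2$ filter passed through `im2col` with itself), while (28, 5) is an arbitrary extra case. The loop computes the output size with the forward relation and confirms that `get_f_shape` recovers the filter size.

```python
def get_f_shape(i, o):
    # Filter size recovered from O = I - F + 1
    return i - o + 1


for I, F in [(4, 2), (2, 2), (28, 5)]:
    O = I - F + 1                # output size from the forward relation
    assert get_f_shape(I, O) == F
    print(I, F, O)
```

Note that passing the filter through `im2col` with itself yields a $1 \times 1$ output, which is why `Of_shape` in the demo is $(1, 1)$.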

Improvement of col2im

Just like the initial `im2col`, the initial implementation needs $O_h O_w$ loop iterations (array accesses), so it is slow and impractical. We therefore use the same trick as for `im2col`, just in reverse order.
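To put a number on "slow", here is a sketch that simply counts Python-level loop iterations, assuming a $28 \times 28$ input and a $3 \times 3$ filter (arbitrary example sizes). It does not measure actual run time; it only shows that the naive version loops over every output position, while the improved version below loops over filter positions and processes a whole $O_h \times O_w$ slab with NumPy in each iteration.

```python
def loop_counts(I_h, I_w, F_h, F_w):
    """Number of Python-level loop iterations (stride 1, no padding)."""
    O_h, O_w = I_h - F_h + 1, I_w - F_w + 1
    naive = O_h * O_w       # one iteration per output position
    improved = F_h * F_w    # one iteration per filter position
    return naive, improved


print(loop_counts(28, 28, 3, 3))  # (676, 9)
```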

Improved version `col2im`

col2im.py


import numpy as np


def col2im(cols, I_shape, O_shape):
    # Filter size from O = I - F + 1
    def get_f_shape(i, o):
        return i - o + 1
    
    I_h, I_w = I_shape
    O_h, O_w = O_shape
    F_h = get_f_shape(I_h, O_h)
    F_w = get_f_shape(I_w, O_w)
    # Axes 0-1: position inside the filter, axes 2-3: output position
    cols = cols.reshape(F_h, F_w, O_h, O_w)
    images = np.zeros((I_h, I_w))
    
    # Only F_h*F_w iterations; each adds a whole O_h x O_w slab at once
    for h in range(F_h):
        h_lim = h + O_h
        for w in range(F_w):
            w_lim = w + O_w
            images[h:h_lim, w:w_lim] += cols[h, w, :, :]
    
    return images


# im2col is the function from "Thorough understanding of im2col" (not defined here)
x = np.ones((4, 4))
f = np.arange(-2*2, 0).reshape(2, 2)
im2col_x, O_shape = im2col(x, f, pad=0, get_out_size=True)
im2col_f, Of_shape = im2col(f, f, get_out_size=True)
print(im2col_x)
print(im2col_f)
print(col2im(im2col_x, x.shape, O_shape))
print(col2im(im2col_f, f.shape, Of_shape))

First, the matrix passed to `col2im` is reshaped from the form shown in improved_col.png into the form shown in improved_im2col_reshape.png. This is the same shape used when allocating memory for the output matrix in the improved `im2col`. The reshaped array is then accessed one filter position at a time, as shown in improved_col2im.gif. It's a bit of a technical trick.
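A minimal sketch to check that the reshaped version gives the same result as the initial one. It uses a non-square case ($6 \times 5$ input, $4 \times 2$ output, hence a $3 \times 4$ filter) so that any mix-up between `O_h` and `O_w` would show up. Both functions are restated under the hypothetical names `col2im_naive` and `col2im_fast` so that the snippet runs on its own; the column layout is the one used by the initial and improved `col2im` above.

```python
import numpy as np


def col2im_naive(cols, I_shape, O_shape):
    I_h, I_w = I_shape
    O_h, O_w = O_shape
    F_h, F_w = I_h - O_h + 1, I_w - O_w + 1
    images = np.zeros((I_h, I_w))
    for h in range(O_h):            # loop over output positions
        for w in range(O_w):
            images[h:h + F_h, w:w + F_w] += cols[:, h * O_w + w].reshape(F_h, F_w)
    return images


def col2im_fast(cols, I_shape, O_shape):
    I_h, I_w = I_shape
    O_h, O_w = O_shape
    F_h, F_w = I_h - O_h + 1, I_w - O_w + 1
    cols = cols.reshape(F_h, F_w, O_h, O_w)
    images = np.zeros((I_h, I_w))
    for h in range(F_h):            # loop over filter positions only
        for w in range(F_w):
            images[h:h + O_h, w:w + O_w] += cols[h, w]
    return images


rng = np.random.default_rng(1)
I_shape, O_shape = (6, 5), (4, 2)     # implies a 3x4 filter
F_h, F_w = 3, 4
cols = rng.standard_normal((F_h * F_w, O_shape[0] * O_shape[1]))
print(np.allclose(col2im_naive(cols, I_shape, O_shape),
                  col2im_fast(cols, I_shape, O_shape)))  # True
```

The reshape works because row `i*F_w + j`, column `h*O_w + w` of `cols` lands at index `[i, j, h, w]` of the 4-D array in NumPy's default row-major order.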

Completed version col2im

Finally, let's take stride and padding into account.

Completed version `col2im`

col2im.py


import numpy as np


def col2im(cols, I_shape, O_shape, stride=1, pad=0):
    # Filter size from F = I + 2*pad - (O - 1)*stride
    def get_f_shape(i, o, s, p):
        return int(i + 2*p - (o - 1)*s)
    
    if len(I_shape) == 2:
        B = C = 1
        I_h, I_w = I_shape
    elif len(I_shape) == 3:
        C = 1
        B, I_h, I_w = I_shape
    else:
        B, C, I_h, I_w = I_shape
    O_h, O_w = O_shape
    
    # stride and pad may be given per direction as (up/down, left/right)
    if isinstance(stride, tuple):
        stride_ud, stride_lr = stride
    else:
        stride_ud = stride
        stride_lr = stride
    if isinstance(pad, tuple):
        pad_ud, pad_lr = pad
    else:
        # int or float (the exact "same" padding can be fractional, e.g. 0.5)
        pad_ud = pad
        pad_lr = pad
    
    # Recover the filter size with the exact (unrounded) padding...
    F_h = get_f_shape(I_h, O_h, stride_ud, pad_ud)
    F_w = get_f_shape(I_w, O_w, stride_lr, pad_lr)
    # ...then round up to get the integer padding width
    pad_ud = int(np.ceil(pad_ud))
    pad_lr = int(np.ceil(pad_lr))
    cols = cols.reshape(C, F_h, F_w, B, O_h, O_w).transpose(3, 0, 1, 2, 4, 5)
    images = np.zeros((B, C, I_h+2*pad_ud+stride_ud-1, I_w+2*pad_lr+stride_lr-1))
    
    # Strided slices pick every stride-th element; overlapping parts accumulate
    for h in range(F_h):
        h_lim = h + stride_ud*O_h
        for w in range(F_w):
            w_lim = w + stride_lr*O_w
            images[:, :, h:h_lim:stride_ud, w:w_lim:stride_lr] += cols[:, :, h, w, :, :]
    
    # Strip the padding
    return images[:, :, pad_ud : I_h+pad_ud, pad_lr : I_w+pad_lr]

# im2col is the updated version from the im2col article, which also returns the exact padding
x = np.ones((4, 4))
f = np.arange(-2*2, 0).reshape(2, 2)
im2col_x, O_shape, x_pad = im2col(x, f, pad="same")
im2col_f, Of_shape, f_pad = im2col(f, f)
print(im2col_x)
print(im2col_f)
#print((im2col_f.T@im2col_x).reshape(*O_shape))
print(col2im(im2col_x, x.shape, O_shape, pad=x_pad))
print(col2im(im2col_f, f.shape, Of_shape, pad=f_pad))
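The key change in the completed version is the strided slice `h:h_lim:stride`. Here is a minimal sketch that checks it against a direct loop over output positions, assuming a single 2-D image with no padding and no batch/channel axes, and the same `(F_h, F_w, O_h, O_w)` layout as the improved version. `naive_scatter` and `strided_scatter` are hypothetical names for this illustration.

```python
import numpy as np


def naive_scatter(cols, I_shape, F, O, s):
    """Loop over output positions; output (oh, ow) covers input rows oh*s.. and cols ow*s.."""
    images = np.zeros(I_shape)
    for oh in range(O[0]):
        for ow in range(O[1]):
            images[oh*s:oh*s + F[0], ow*s:ow*s + F[1]] += cols[:, :, oh, ow]
    return images


def strided_scatter(cols, I_shape, F, O, s):
    """Loop over filter positions only; strided slices pick every s-th element."""
    images = np.zeros(I_shape)
    for h in range(F[0]):
        for w in range(F[1]):
            images[h:h + s*O[0]:s, w:w + s*O[1]:s] += cols[h, w]
    return images


rng = np.random.default_rng(2)
I_shape, F, s = (7, 9), (3, 3), 2
O = ((I_shape[0] - F[0]) // s + 1, (I_shape[1] - F[1]) // s + 1)   # (3, 4)
cols = rng.standard_normal((F[0], F[1], O[0], O[1]))
print(np.allclose(naive_scatter(cols, I_shape, F, O, s),
                  strided_scatter(cols, I_shape, F, O, s)))  # True
```

Both functions add `cols[h, w, oh, ow]` to input element `(h + s*oh, w + s*ow)`; the strided version just groups the additions by filter position.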

When stride and padding are taken into account, the output shape is

$$
\begin{aligned}
O_h &= \left\lceil \cfrac{I_h - F_h + 2\textrm{pad}_{ud}}{\textrm{stride}_{ud}} \right\rceil + 1 \\
O_w &= \left\lceil \cfrac{I_w - F_w + 2\textrm{pad}_{lr}}{\textrm{stride}_{lr}} \right\rceil + 1
\end{aligned}
$$

so, ignoring the ceiling, we solve for the filter shape:

$$
\begin{aligned}
F_h &= I_h + 2\textrm{pad}_{ud} - (O_h - 1) \textrm{stride}_{ud} \\
F_w &= I_w + 2\textrm{pad}_{lr} - (O_w - 1) \textrm{stride}_{lr}
\end{aligned}
$$
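A quick arithmetic sketch of the two formulas, assuming $I - F + 2\textrm{pad}$ is divisible by the stride so that the ceiling has no effect. The helper names `out_size` and `filter_size` are hypothetical; `out_size` mirrors the ceiling-based formula used in this article.

```python
import math


def out_size(I, F, pad, stride):
    # O = ceil((I - F + 2*pad) / stride) + 1
    return math.ceil((I - F + 2 * pad) / stride) + 1


def filter_size(I, O, pad, stride):
    # F = I + 2*pad - (O - 1)*stride
    return int(I + 2 * pad - (O - 1) * stride)


O = out_size(5, 3, pad=1, stride=2)
print(O, filter_size(5, O, pad=1, stride=2))  # 3 3
```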

After some thought, I found that restoring the shape correctly requires the exact values of $\textrm{pad}_{ud}$ and $\textrm{pad}_{lr}$ (the values before they are rounded up by the ceiling function), so I changed the implementation of the `im2col` function accordingly.
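Here is why the unrounded value matters, as a small arithmetic sketch for the case in "A little question" below ($4 \times 4$ input, $2 \times 2$ filter, stride 1, "same" padding). It assumes "same" padding is obtained by setting $O = I$ in the output-size formula and solving for pad, i.e. $\textrm{pad} = ((O - 1)\,\textrm{stride} - I + F)/2$; this reconstruction is an assumption, not necessarily the exact code of the updated `im2col`.

```python
import math


def filter_size(I, O, pad, stride):
    # F = I + 2*pad - (O - 1)*stride
    return int(I + 2 * pad - (O - 1) * stride)


I, F, stride = 4, 2, 1
O = I                                         # "same": output size equals input size
pad_exact = ((O - 1) * stride - I + F) / 2    # assumed "same" padding formula
pad_rounded = math.ceil(pad_exact)            # the integer padding width actually applied

print(pad_exact, pad_rounded)                 # 0.5 1
print(filter_size(I, O, pad_exact, stride))   # 2 (the real filter size)
print(filter_size(I, O, pad_rounded, stride)) # 3 (wrong)
```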

A little question

While experimenting, I noticed something. If a $4 \times 4$ input matrix is padded with $\textrm{pad} = 1$ on all four sides, it becomes $6 \times 6$, so applying a $2 \times 2$ filter with $\textrm{stride} = 1$ should give a $5 \times 5$ output matrix. But that is not what happens (pad_im2col.png).

I wondered why, and then realized that under these conditions, passing $\textrm{pad} = \textrm{same}$ to the `im2col` function yields a computed padding of $\textrm{pad} = 0.5$. Since the padding width must be an integer, it is rounded up to $\textrm{pad} = 1$, which is where the $6 \times 6$ matrix comes from. The padded input should therefore really be treated as a $5 \times 5$ matrix, and indeed the `im2col` function only uses the upper-left $5 \times 5$ part. You can confirm this from the overlapping parts in the `col2im` result (col2im_result.png): the upper-left part is added 4 times (col2im_q.gif).
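A minimal sketch reproducing this observation, assuming a single 2-D image, stride 1, and the same filter-size recovery and cropping as the completed `col2im` above. `col2im_2d` is a hypothetical, simplified name. Feeding an all-ones `cols` makes the result count how many times each element is added.

```python
import numpy as np


def col2im_2d(cols, I_shape, O_shape, pad):
    """Simplified completed col2im: single image, stride 1, same pad on all sides."""
    I_h, I_w = I_shape
    O_h, O_w = O_shape
    F_h = int(I_h + 2 * pad - (O_h - 1))    # uses the exact, possibly fractional, pad
    F_w = int(I_w + 2 * pad - (O_w - 1))
    p = int(np.ceil(pad))                   # integer padding width
    cols = cols.reshape(F_h, F_w, O_h, O_w)
    images = np.zeros((I_h + 2 * p, I_w + 2 * p))
    for h in range(F_h):
        for w in range(F_w):
            images[h:h + O_h, w:w + O_w] += cols[h, w]
    # Only the upper-left (O_h + F_h - 1) x (O_w + F_w - 1) block (here 5x5) is touched
    return images[p:p + I_h, p:p + I_w]


# 4x4 input, 2x2 filter, "same" padding -> 4x4 output with pad = 0.5
print(col2im_2d(np.ones((4, 16)), (4, 4), (4, 4), pad=0.5))
# [[4. 4. 4. 2.]
#  [4. 4. 4. 2.]
#  [4. 4. 4. 2.]
#  [2. 2. 2. 1.]]
```

The asymmetry (4 in the upper left, fewer along the bottom and right edges) comes from the unused last row and column of the $6 \times 6$ padded array.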

In conclusion

Since `col2im` is simply the reverse of the `im2col` function, the explanation here is kept fairly brief. I may add more detailed explanations when I have time.

Deep learning series

- Introduction to Deep Learning ~ Basics ~
- Introduction to Deep Learning ~ Coding Preparation ~
- Introduction to Deep Learning ~ Forward Propagation ~
- Introduction to Deep Learning ~ Backpropagation ~
- Introduction to Deep Learning ~ Learning Rules ~
- Introduction to Deep Learning ~ Localization and Loss Functions ~
- Introduction to Deep Learning ~ Function Approximation ~
- List of activation functions (2020)
- Gradient descent method list (2020)
- See and understand! Comparison of optimization methods (2020)
- Thorough understanding of im2col
- Thorough understanding of col2im
- Complete understanding of numpy.pad function
