[PYTHON] compréhension approfondie de col2im

Personne cible

Pour ceux qui veulent en savoir plus sur la fonction im2col qui apparaît dans la reconnaissance d'image à l'aide de CNN Nous expliquerons en détail de la mise en œuvre initiale à la version améliorée, la version compatible avec le canal de traitement par lots, la version compatible avec le rembourrage de foulée à l'aide de gifs et d'images.

table des matières

Qu'est-ce que «col2im»?

La fonction col2im est une fonction importante qui est indispensable dans des domaines tels que la reconnaissance d'image, qui est associée à la fonction ʻim2im. Son rôle est l'opposé de la fonction ʻim2col, qui a été convertie en ** tenseur $ \ rightarrow $ matrix ** par la fonction ʻim2collors de la ** propagation en avant **, alors qu'elle a été convertie en ** rétropropagation **. Convertit en ** matrice $ \ rightarrow $ tensor ** avec la fonctioncol2im`. En faisant cela, il sera transformé en une forme adaptée à l'apprentissage comme un filtre.

Comportement et implémentation initiale de col2im

Commençons par l'implémentation initiale. En d'autres termes

stride = 1 \\
pad = 0

Supposer que L'opération est l'inverse de la fonction ʻim2col, donc c'est comme suit. ![col2im_image.gif](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/640911/7757c2b4-43a7-ef74-8e18-649b6caf3766.gif) À ce stade, veuillez noter que ** les parties qui se chevauchent sont ajoutées **. La raison peut être comprise en considérant l'opération de filtrage. Lors de la mise au point sur un élément, la couche d'éléments suivante affectée par le filtrage est illustrée dans la figure ci-dessous. ![col2im_NN.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/640911/d663724c-caca-4b68-72af-736dd7cacb7c.png) En d'autres termes, il ** se ramifie dans chaque élément **. Cela signifie que ** les dégradés qui ont coulé dans la propagation arrière doivent être additionnés **. Par conséquent, lors de la transformation avec la fonction col2im`, il est nécessaire" d'ajouter les parties qui se chevauchent ".

Maintenant, construisons simplement un programme selon cette logique.

<détails>

Début col2im </ summary>

col2im.py


def col2im(cols, I_shape, O_shape):
    def get_f_shape(i, o):
        return i - o + 1
    
    I_h, I_w = I_shape
    O_h, O_w = O_shape
    F_h = get_f_shape(I_h, O_h)
    F_w = get_f_shape(I_w, O_w)
    images = np.zeros((I_h, I_w))
    
    for h in range(O_h):
        h_lim = h + F_h
        for w in range(O_w):
            w_lim = w + F_w
            images[h:h_lim, w:w_lim] += cols[:, h*O_h+w].reshape(F_h, F_w)
    
    return images


x = np.ones((4, 4))
f = np.arange(-2*2, 0).reshape(2, 2)
im2col_x, O_shape = im2col(x, f, pad=0, get_out_size=True)
im2col_f, Of_shape = im2col(f, f, get_out_size=True)
print(im2col_x)
print(im2col_f)
print(col2im(im2col_x, x.shape, O_shape))
print(col2im(im2col_f, f.shape, Of_shape))

Je me sens comme cela. Tout d'abord, nous préparons une boîte qui a la forme après déformation, puis nous la transformons et la jetons dans chaque rangée. Ici, la forme du filtre est l'expression relationnelle entre l'entrée et la sortie d'im2col et le filtre.

O_h = I_h - F_h + 1 \\
O_w = I_w - F_w + 1

Il est calculé en utilisant.

Amélioration de col2im

Après tout, l'implémentation initiale nécessite un accès $ O_h O_w $ comme ʻim2col, donc elle a l'inconvénient d'une vitesse de traitement lente et d'une impraticabilité. Alors, imaginez la même chose que pour ʻim2col. La méthode est juste l'ordre inverse.

Version améliorée `col2im`

col2im.py


def col2im(cols, I_shape, O_shape):
    def get_f_shape(i, o):
        return i - o + 1
    
    I_h, I_w = I_shape
    O_h, O_w = O_shape
    F_h = get_f_shape(I_h, O_h)
    F_w = get_f_shape(I_w, O_w)
    cols = cols.reshape(F_h, F_w, O_h, O_w)
    images = np.zeros((I_h, I_w))
    
    for h in range(F_h):
        h_lim = h + O_h
        for w in range(F_w):
            w_lim = w + O_w
            images[h:h_lim, w:w_lim] += cols[h, w, :, :]
    
    return images


x = np.ones((4, 4))
f = np.arange(-2*2, 0).reshape(2, 2)
im2col_x, O_shape = im2col(x, f, pad=0, get_out_size=True)
im2col_f, Of_shape = im2col(f, f, get_out_size=True)
print(im2col_x)
print(im2col_f)
print(col2im(im2col_x, x.shape, O_shape))
print(col2im(im2col_f, f.shape, Of_shape))

Tout d'abord, la matrice entrée dans col2im improved_im2col_reshape.png De improved_col.png Il se transforme en une telle forme. Lors de l'allocation de mémoire pour la matrice de sortie dans la version améliorée ʻim2col` Il a la même forme. plus tard improved_col2im.gif Je vais y accéder comme ça. C'est technique ~

Version terminée col2im

Alors enfin, pensez à la foulée et au rembourrage.

Version terminée `col2im`

col2im.py


import numpy as np


def col2im(cols, I_shape, O_shape, stride=1, pad=0):
    def get_f_shape(i, o, s, p):
        return int(i + 2*p - (o - 1)*s)
    
    if len(I_shape) == 2:
        B = C = 1
        I_h, I_w = I_shape
    elif len(img_shape) == 3:
        C = 1
        B, I_h, I_w = I_shape
    else:
        B, C, I_h, I_w = I_shape
    O_h, O_w = O_shape
    
    if isinstance(stride, tuple):
        stride_ud, stride_lr = stride
    else:
        stride_ud = stride
        stride_lr = stride
    if isinstance(pad, tuple):
        pad_ud, pad_lr = pad
    elif isinstance(pad, int):
        pad_ud = pad
        pad_lr = pad
    
    F_h = get_f_shape(I_h, O_h, stride_ud, pad_ud)
    F_w = get_f_shape(I_w, O_w, stride_lr, pad_lr)
    pad_ud = int(np.ceil(pad_ud))
    pad_lr = int(np.ceil(pad_lr))
    cols = cols.reshape(C, F_h, F_w, B, O_h, O_w).transpose(3, 0, 1, 2, 4, 5)
    images = np.zeros((B, C, I_h+2*pad_ud+stride-1, I_w+2*pad_lr+stride-1))
    
    for h in range(F_h):
        h_lim = h + stride*O_h
        for w in range(F_w):
            w_lim = w + stride*O_w
            images[:, :, h:h_lim:stride, w:w_lim:stride] += cols[:, :, h, w, :, :]
    
    return images[:, :, pad_ud : I_h+pad_ud, pad_lr : I_w+pad_lr]

x = np.ones((4, 4))
f = np.arange(-2*2, 0).reshape(2, 2)
im2col_x, O_shape, x_pad = im2col(x, f, pad="same")
im2col_f, Of_shape, f_pad = im2col(f, f)
print(im2col_x)
print(im2col_f)
#print((im2col_f.T@im2col_x).reshape(*O_shape))
print(col2im(im2col_x, x.shape, O_shape, pad=x_pad))
print(col2im(im2col_f, f.shape, Of_shape, pad=f_pad))

Calcul de la forme lorsque la foulée et le rembourrage sont pris en compte

O_h = \left\lceil \cfrac{I_h - F_h + 2\textrm{pad}_{ud}}{\textrm{stride}_{ud}} \right\rceil + 1 \\
O_w = \left\lceil \cfrac{I_w - F_w + 2\textrm{pad}_{lr}}{\textrm{stride}_{lr}} \right\rceil + 1 \\

Alors, calculez la forme du filtre à partir d'ici.

F_h = I_h + 2\textrm{pad}_{ud} - (O_h - 1) \textrm{stride}_{ud} \\
F_w = I_w + 2\textrm{pad}_{lr} - (O_w - 1) \textrm{stride}_{lr}

Je pensais à diverses choses, mais pour le restaurer correctement, la valeur exacte de $ \ textrm {pad} \ _ {ud}, \ textrm {pad} \ _ {lr} $ (la valeur avant d'arrondir par la fonction de plafond) est Comme cela semble nécessaire, j'ai changé l'implémentation de la fonction ʻim2col` en conséquence.

Une petite question

Pendant que j'expérimentais, j'ai remarqué que l'ajout de la matrice d'entrée de matrice $ 4 \ fois 4 $ haut, bas, gauche et droite $ \ textrm {pad} = 1 $ donne 6 $ \ fois 6 $, ce qui correspond à 2 $ \ fois 2 $ matrice. Si vous appliquez le filtre avec $ \ textrm {stride} = 1 $, la matrice de sortie devrait être $ 5 \ times 5 $, mais ce n'est pas le cas. pad_im2col.png Je me suis demandé pourquoi, mais au fait, si vous entrez $ \ textrm {pad} = \ textrm {same} $ dans la fonction ʻim2col sous cette condition, le remplissage du résultat du calcul sera $ \ textrm {pad} = 0,5 $. Ce sera. Et, bien sûr, la largeur de remplissage est un entier, donc elle est arrondie à $ \ textrm {pad} = 1 $, donc elle devient une matrice $ 6 \ times 6 $. Par conséquent, elle doit être traitée comme une matrice $ 5 \ times 5 $, et vous pouvez voir que la fonction ʻim2col renvoie en fait un produit qui utilise seulement la matrice $ 5 \ times 5 $ en haut à gauche. La preuve en est que la partie superposée de la fonction col2im col2im_result.png Comme, la partie supérieure gauche est ajoutée 4 fois. col2im_q.gif

en conclusion

L'explication est considérablement simplifiée car c'est juste l'ordre inverse de la fonction ʻim2col`. Des explications plus détaillées peuvent être ajoutées lorsque le temps est disponible.

Série d'apprentissage en profondeur

Recommended Posts

compréhension approfondie de col2im
Compréhension approfondie d'Im2col
Comprendre Concaténer