[PYTHON] Informationen zur Entfaltungsfunktion

Conv2D-Arithmetik

Unter Berücksichtigung der Operation der zweidimensionalen Faltung geben Sie $ (Batch, H, W, C_ {in}) $, $ (Batch, H, W, C_ {out}) $, Kernelgröße $ (3,3) ein. $, Faltungsgewicht $ W = (3,3, C_ {in}, C_ {out}) $

Effektiv Conv2D-Betrieb matmul(x,W)=matmul((Batch,HW,9C_{in}), (9C_{in}, C_{out}))=(Batch,HW,C_{out}) Dies entspricht der Betrachtung der Matrixoperation von matmul. Hier ist in der Matrixoperation von $ c = matmul (a, b) $, wenn $ a = (i, j, k, m), b = (m, n) $, dann $ c = (i, j, k, n) ) $.

Andererseits für die Eingabe $ (Batch, H, W, C_ {in}) $

python


x[:,0]=input[:,0:H-2,0:W-2,:] \\
x[:,1]=input[:,0:H-2,1:W-1,:] \\
x[:,2]=input[:,0:H-2,2:W-0,:] \\ 
x[:,3]=input[:,1:H-1,0:W-2,:] \\
x[:,4]=input[:,1:H-1,1:W-1,:] \\
x[:,5]=input[:,1:H-1,2:W-0,:] \\ 
x[:,6]=input[:,2:H-0,0:W-2,:] \\
x[:,7]=input[:,2:H-0,1:W-1,:] \\
x[:,8]=input[:,2:H-0,2:W-0,:]

Extrahieren Sie $ (H-2, W-2) $ aus $ (H, W) $ like und konvertieren Sie es vor der Matrixoperation in eine Matrix wie $ (Batch, HW, 9C_ {in}) $. müssen es tun. Eine solche Matrixtransformation heißt $ im2col $. Es kann davon ausgegangen werden, dass dieser Prozess die Anzahl der Eingangskanäle um die Gesamtzahl der Kernelgrößen verdoppelt. Auch der $ im2col $ -Prozess selbst hat kein Gewicht.

python


def im2col(input_data, filter_h, filter_w, stride=1, pad=0):

    N, C, H, W = input_data.shape
    out_h = (H + 2*pad - filter_h)//stride + 1
    out_w = (W + 2*pad - filter_w)//stride + 1

    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
    return col

In Pytorch wird die Funktion im2col als Unfold-Funktion bezeichnet. Daher sollte es ** Conv2D = (im2col + matmul) = (Unfold + matmul) ** sein. Ich versuchte herauszufinden, ob das Hauptthema wirklich so war.

Vergleich in PyTorch

PyTorch ist Kanal zuerst mit Eingang $ (Batch, C_ {in}, H, W) = (25,3,32,32) $, Ausgang $ (Batch, C_ {out}, H, W) = (25,16) , 30,30) $, Kernelgröße $ (3,3) $, Gewicht $ W = (C_ {out}, 3 × 3 × C_ {in}) = (16,27) $.

(Entfalten + Matmul) Betrieb

python


import numpy as np
import torch

input = torch.tensor(np.random.rand(25,3,32,32)).float()
weight = torch.tensor(np.random.rand(16,3,3,3)).float()
weight2 = weight.reshape((16,27))

print('input.shape=  ', input.shape)
print('weight.shape= ', weight.shape)
print('weight2.shape=', weight2.shape)

x = torch.nn.Unfold(kernel_size=(3,3), stride=(1,1), padding=(0,0), dilation=(1,1))(input)
output1 = torch.matmul(weight2, x).reshape((25,16,30,30))

print('x.shape=      ', x.shape)
print('output1.shape=', output1.shape)
-----------------------------------------------------------
input.shape=   torch.Size([25, 3, 32, 32])
weight.shape=  torch.Size([16, 3, 3, 3])
weight2.shape= torch.Size([16, 27])
x.shape=       torch.Size([25, 27, 900])
output1.shape= torch.Size([25, 16, 30, 30])

Wenn Sie hier die Entfaltungsfunktion auf die Eingabe anwenden, ist $ x = (25, 3 × 3 × 3, 30 × 30) = (25,27,900) $, und wenn $ W = (16,27) $, $ matmul (W) , x) = (25,16,30 × 30) $.

Conv2D-Arithmetik

Wenn andererseits die Eingabe $ (Batch, C_ {in}, H, W) = (25,3,32,32) $ und das Gewicht der Conv2D-Funktion $ W = (16,3,3,3) $ ist Der Code für die Ausgabe ist unten.

python


conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, bias=False)
conv1.weight.data = weight

output2 = conv1(input)

print('conv1.weight.shape=', conv1.weight.shape)
print('output2.shape= ', output2.shape)
-----------------------------------------------------------
conv1.weight.shape= torch.Size([16, 3, 3, 3])
output2.shape=  torch.Size([25, 16, 30, 30])

Beim Vergleich von ** output1 **, erhalten von (Unfold + matmul), und ** output2 **, erhalten von Conv2D, waren die Werte völlig gleich. Daher wurde bestätigt, dass es rechnerisch äquivalent zu ** Conv2D = (Unfold + matmul) ** ist.

python


output1:
tensor([[[[7.4075, 7.1269, 6.2595,  ..., 6.9860, 6.5256, 7.3597],
          [6.4978, 7.3303, 6.7621,  ..., 7.2054, 6.9357, 7.3798],
          [5.9309, 5.5016, 6.3321,  ..., 5.7143, 7.0358, 6.8819],
          ...,
          [6.0168, 6.9415, 7.5508,  ..., 5.4547, 4.7888, 6.0636],
          [5.0191, 7.0944, 7.0875,  ..., 3.9413, 4.1925, 5.5689],
          [6.2448, 6.4813, 5.5424,  ..., 4.2610, 5.8013, 5.3431]],
......
output2:
tensor([[[[7.4075, 7.1269, 6.2595,  ..., 6.9860, 6.5256, 7.3597],
          [6.4979, 7.3303, 6.7621,  ..., 7.2054, 6.9357, 7.3798],
          [5.9309, 5.5016, 6.3321,  ..., 5.7143, 7.0358, 6.8819],
          ...,
          [6.0168, 6.9415, 7.5508,  ..., 5.4547, 4.7888, 6.0636],
          [5.0191, 7.0944, 7.0874,  ..., 3.9413, 4.1925, 5.5689],
          [6.2448, 6.4813, 5.5424,  ..., 4.2610, 5.8013, 5.3431]],
......

Andere Verwendungen der Entfaltungsfunktion

Wenn kernel_size und stride gleich sind, entspricht dies der Patch-Aufteilung von Vision Transformer. Nun, Patch-Splitting kann durch Umformen und Transponieren ersetzt werden, ohne Unfold zu verwenden ...

python


input = torch.tensor(np.random.rand(25,3,224,224)).float()
x = torch.nn.Unfold(kernel_size=(14,14), stride=(14,14), padding=(0,0), dilation=(1,1))(input)
-----------------------------------------------------------
input.shape=   torch.Size([25, 3, 224, 224])
x.shape=       torch.Size([25, 588, 256]) #(25,3*14*14,16*16)

In der Geschichte, dass Vision Transformer Conv2D überhaupt nicht verwendet, hatte ich eine unbegründete Täuschung, dass Unfold + matmul selbst in ViT Conv2D entspricht, da matmul in die Berechnung von Aufmerksamkeitsgewicht und -wert einbezogen wird.

Zusammenfassung

Die Unfold-Funktion ist die im2col-Funktion in Pytorch und ** Conv2D = (Unfold + matmul) **. Im Tensorflow ist dies die Funktion extract_image_patches.

Recommended Posts

Informationen zur Entfaltungsfunktion
Über die Aufzählungsfunktion (Python)
Denken Sie grob über die Verlustfunktion nach
Über den Test
Über die Warteschlange
Über die Argumente der Setup-Funktion von PyCaret
Über Funktionsargumente (Python)
Die erste GOLD "Funktion"
Python: Über Funktionsargumente
Über den Servicebefehl
Über die Verwirrungsmatrix
Über das Besuchermuster
In Bezug auf die Aktivierungsfunktion Gelu
Was ist die Aktivierungsfunktion?
Über das Python-Modul venv
Python-Anfänger-Memorandum-Funktion
Informationen zur Funktion fork () und zur Funktion execve ()
Über das Problem der reisenden Verkäufer
Über das Verständnis des 3-Punkt-Lesers [...]
Über die Komponenten von Luigi
Über die Funktionen von Python
Was ist die Rückruffunktion?
Verwendung der Zip-Funktion
Sortierwarnung in der Funktion pd.concat
Denken Sie an das Problem der minimalen Änderung
[Python] Was ist @? (Über Dekorateure)
Über den Rückgabewert von pthread_mutex_init ()
Vorsichtsmaßnahmen bei Verwendung der Funktion urllib.parse.quote
Über den Rückgabewert des Histogramms.
[Python] Machen Sie die Funktion zu einer Lambda-Funktion
Über den Grundtyp von Go
Über die Obergrenze von Threads-max
Über die durchschnittliche Option von sklearn.metrics.f1_score
Über das Verhalten von Yield_per von SqlAlchemy
Über die Größe der Punkte in Matplotlib
Informationen zur Grundlagenliste der Python-Grundlagen
[Python Kivy] Über das Ändern des Designthemas
Informationen zum Verhalten von enable_backprop von Chainer v2
Informationen zur virtuellen Umgebung von Python Version 3.7
Verschiedene Hinweise zum Django REST-Framework
Nehmen Sie die logische Summe von List in Python (Zip-Funktion)
[OpenCV] Über das von imread zurückgegebene Array
Über NumFOCUS, eine Open Source-Support-Organisation
[Python3] Schreiben Sie das Codeobjekt der Funktion neu
Denken Sie grob über die Gradientenabstiegsmethode nach
[Python] Fassen Sie die rudimentären Dinge über Multithreading zusammen
Informationen zu der von Ihnen verwendeten Entwicklungsumgebung
Einführung der Funktion addModuleCleanup / doModuleCleanups von unittest
Was ist mit 2017 rund um die Crystal-Sprache? (Täuschung)
Über die Beziehung zwischen Git und GitHub
Über die Normalgleichung der linearen Regression
Ein Memo, dass ich das Pyramid Tutorial ausprobiert habe