[PYTHON] VQ-VAE verstehen

Einführung

VQ-VAE ist eine VAE, die eine Technik namens Vector Quantised verwendet. In der herkömmlichen VAE wird die latente Variable z als Vektor der Normalverteilung (Gaußsche Verteilung) trainiert, in der VQ-VAE wird die latente Variable als diskreter numerischer Wert trainiert. Das Modell besteht aus (Encoder) - (Quantifizierungsteil) - (Decoder), aber Encoder und Decoder unterscheiden sich nicht wesentlich von VAE, die Faltung ausführt. Als ich einen Blick auf das Papier und die Implementierung von VQ-VAE warf, änderte sich mein Verständnis, wie der für die Quantisierung zuständige Teil hergestellt werden kann, und ich werde mein Verständnis als Memorandum zusammenfassen.

Was ist Einbetten?

Das Einbetten ist wahrscheinlich unvermeidlich, wenn es um VQ-VAE geht. Wenn Sie es nicht so verstehen wie Sie selbst, ist es schwierig zu verstehen, wie das ist.

Für mich war es am einfachsten, ein Beispiel zu sehen. Betrachten Sie beispielsweise den Fall, in dem die Eingabematrix $ (2,4) $, der numerische Wert der Indexwert und die Einbettungsmatrix $ (10,3) $ ist (siehe unten). Wenn Sie in diesem Fall die Eingabematrix in onehot konvertieren, machen Sie sie zu $ (2,4,10) $ und multiplizieren Sie sie mit der Einbettungsmatrix $ (10,3) $. Nach dem Einbetten wird eine $ (2,4,3) $ -Matrix erstellt. Kurz gesagt, ** Einbetten ist nur eine heiße Eingabe und wird mit einer Einbettungsmatrix multipliziert. ** ** **

python


>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])

Erstes Verständnis von VQ-VAE (falsch)

image.png Zuerst habe ich verstanden, wie in der obigen Abbildung gezeigt. Es ist ein Fehler, zuerst zu sagen. Die Eingabe $ z_e $ ist $ (10,10,32) $, die Einbettungsmatrix ist $ (32,128) $, und ich habe über eine Vektorquantisierung von $ 128 $ im latenten Raum nachgedacht (Unterscheidung des latenten Raums). Multiplizieren Sie die Eingabe $ z_e $ mit der Einbettungsmatrix, um den Index der Position zu erhalten, die 1 am nächsten liegt. (Der Index des Onehot-Vektors, der einem Onehot-Vektor am nächsten kommt). Es wird $ q (z | x) $, eine Matrix von $ (10,10) $, und der Wert ist der Indexwert. Konvertieren Sie dies in onehot und multiplizieren Sie es mit der Umkehrung der Einbettungsmatrix, um $ z_q $ zu erhalten. Hier ist der Prozess des Konvertierens von $ z_q $ in onehot und des Multiplizierens mit der Einbettungsmatrix nichts anderes als der eingangs erläuterte Einbettungsprozess selbst. Die Verlustfunktion ist $ (z_e-z_q) ^ 2 $, da sich die Ausgabe $ z_q $ der Eingabe $ z_e $ nähern sollte, wenn der latente Raum gut dezentralisiert ist.

Wenn Sie die Änderung der Figur mit numpy schreiben, ist dies wie folgt.

python


import numpy as np

input = np.random.rand(10,10,32)
embed = np.random.rand(32,128)
embed_inv = np.linalg.pinv(embed)
dist = (np.dot(input, embed) - np.ones((10,10,128)))**2
embed_ind = np.argmin(dist, axis=2)
embed_onehot = np.identity(128)[embed_ind]
output = np.dot(embed_onehot, embed_inv)

print("input.shape=", input.shape)
print("embed.shape=", embed.shape)
print("embed_inv.shape=", embed_inv.shape)
print("dist.shape=", dist.shape)
print("embed_ind.shape=", embed_ind.shape)
print("embed_onehot.shape=", embed_onehot.shape)
print("output.shape=", output.shape)
----------------------------------------------
input.shape= (10, 10, 32)
embed.shape= (32, 128)
embed_inv.shape= (128, 32)
dist.shape= (10, 10, 128)
embed_ind.shape= (10, 10)
embed_onehot.shape= (10, 10, 128)
output.shape= (10, 10, 32)

Was ist los

Die obige Interpretation ist im Vergleich zur tatsächlichen Implementierung falsch. Einer der Gründe ist, dass der Prozess des Findens der Inversen der Einbettungsmatrix wahrscheinlich nicht wirklich möglich ist. Daher müssen wir $ q (z | x) $ und $ z_q $ so finden, dass die Umkehrung der Einbettungsmatrix nicht verwendet wird.

Das andere ist, dass es sich von der Definition des $ q (z | x) $ -Papiers unterscheidet. Das Papier ist wie folgt: image.png Die folgende Formel in der obigen Interpretation ist falsch. q(z|x)=argmin((z_e \cdot e_{mbed} - I)^2)

Unter Berücksichtigung der Multiplikation von argmin mit der inversen Matrix $ e_ {mbed \ inv} $ der Einbettungsmatrix kann dies wie folgt organisiert werden. ((z_e \cdot e_{mbed} - I)^2 \cdot e_{mbed\ inv}^2)=(z_e \cdot e_{mbed} \cdot e_{mbed\ inv}- I \cdot e_{mbed\ inv})^2=(z_e - e_{mbed\ inv})^2 Dies entspricht der Formel im Papier.

Dann ist es bequemer, die Namen der Einbettungsmatrix und ihrer inversen Matrix zu ersetzen. Mit anderen Worten, $ e_ {mbed \ inv} $ heißt $ e_ {mbed} $ und $ e_ {mbed} $ heißt $ e_ {mbed \ inv} $.

Zweites Verständnis von VQ-VAE

Mit den obigen Korrekturen haben wir das folgende Verständnis. Beachten Sie, dass $ e_ {mbed \ inv} $ nicht in beiden Ausdrücken für $ q (z | x) $ und $ z_q $ enthalten ist. Sowohl $ q (z | x) $ als auch $ z_q $ können mit den Eingaben $ z_e $ und $ e_ {mbed} $ berechnet werden, sodass die Umkehrung nicht mehr erforderlich ist. Speziellq(z|x)Istz_eWanne_{mbed}Vonz_qIstq(z|x)Wanne_{mbed}Es wird berechnet aus.

image.png

Gradientenausbreitung der Verlustfunktion

Übrigens, wenn Sie denken, dass die mit der Vektorquantisierung verbundene Verlustfunktion $ (z_e-z_q) ^ 2 $ ist, was der Unterschied vor und nach der Quantisierung ist, ist sie tatsächlich unterschiedlich. Es wird ausgedrückt als $ (sg (z_e) -z_q) ^ 2 + (z_e-sg (z_q)) ^ 2 $ unter Verwendung der Funktion, die den Gradienten namens $ sg () $ stoppt. Dies unterscheidet sich von $ (z_e-z_q) ^ 2 $, wahrscheinlich weil es schwierig ist, die Fehlerrückausbreitung von $ z_e $ und $ e_ {mbed} $ aus $ q (z | x) $ zu berechnen. Es wird angenommen, dass die Gradientenübertragung dieses Teils unterbrochen ist.

Außerdem unterscheidet sich der zu aktualisierende Inhalt zwischen dem zweiten und dritten Element der Verlustfunktion. Das zweite Element aktualisiert die Einbettungsmatrix, aber der Gradient breitet sich nicht auf den Eingang (Encoder) aus. Das dritte Element überträgt den Gradienten an den Eingang (Encoder), aktualisiert jedoch nicht die Einbettungsmatrix. In Bezug auf das erste Element der Verlustfunktion scheint es, dass es von Decorder ausgeht, den Quantisierungsteil von $ z_q bis z_e $ überspringt und an Encoder übertragen wird. Dies unterscheidet sich jedoch nicht vom Verlust eines normalen AutoEncoders. image.png image.png

Argmin-Funktion

Schreiben wir eine Argmin-Funktion, die den Indexwert des kleinsten Werts im Array mithilfe der Heavyside-Step-Funktion verwendet.

argmin(a,b) = H(b-a) \cdot 0 + H(a-b) \cdot 1 \\
argmin(a,b,c) = H(b-a) \cdot H(c-a) \cdot 0 + H(a-b) \cdot H(c-b) \cdot 1 + H(a-c) \cdot H(b-c) \cdot 2\\
H(x) =\left\{
\begin{array}{ll}
1 & (x \geq 0) \\
0 & (x \lt 0)
\end{array}
\right.

Hier ist der Term, für den der Mindestwert abgezogen wird, der Wert aller Produkte, der als 1 verbleibt, und wenn der Wert, der nicht der Mindestwert ist, abgezogen wird, wird einer von ihnen Null und bleibt nicht erhalten. Daher ist $ argmin (a_ {1}, \ cdots, a_ {128}) $ das Produkt von Treppenfunktionen höherer Ordnung. Selbst wenn die Treppenfunktion während der Gradientenberechnung durch eine stetige Differentialfunktion ersetzt wird (z. B. Sigmoid). Funktion) Dies scheint schwer zu unterscheiden zu sein. Sie müssen jedoch nicht darüber nachdenken, indem Sie eine Funktion verwenden, die den Gradienten namens $ sg () $ stoppt, wie zuvor erläutert.

Zusammenfassung

Ich dachte, ich hätte einen Blick auf die Implementierung geworfen und sie verstanden, aber mir wurde klar, dass die Einbettungsmatrix, an die ich zu Beginn dachte, die Umkehrung der tatsächlichen Einbettungsmatrix war. Es ist leicht zu missverstehen, dass die Einbettungsmatrix nicht die Matrix ist, die beim Vektorquantisieren der Eingabe multipliziert werden muss. Die Einbettungsmatrix ist diejenige, die multipliziert wird, wenn die quantisierte latente Variable in den latenten Raum konvertiert wird. Ich hatte auch das Gefühl, dass die Vektorquantisierung der semantischen Segmentierung ähnlich ist, bei der es sich um die Objekterkennung in Pixeleinheiten handelt. Die semantische Segmentierung verwendet Softmax, um einen Onehot-Vektor für jedes Pixel zu erzeugen, während VQ Distanzquadrat und Argmin verwendet, um einen Quantisierungsvektor zu erzeugen.

Referenz: Pytorch VQ-VAE

Aus Beispiel für die tatsächliche Implementierung

class Quantize(nn.Module):
    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
        embed = torch.randn(dim, n_embed)
        ...

    def forward(self, input):
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)
        ...
        diff = (quantize.detach() - input).pow(2).mean()
        quantize = input + (quantize - input).detach()

        return quantize, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.transpose(0, 1))

Recommended Posts

VQ-VAE verstehen
Verketten verstehen
Im2col gründliches Verständnis
[Discord.py] Cog verstehen
Tensor verstehen (1): Dimension
col2im gründliches Verständnis
Python selbst verstehen
TensorFlow mit Arithmetik verstehen