[PYTHON] Comprendre VQ-VAE

introduction

VQ-VAE est une VAE qui utilise une technique appelée Vector Quantised. Dans la VAE conventionnelle, la variable latente z est formée pour être un vecteur de distribution normale (distribution gaussienne), mais dans VQ-VAE, la variable latente est formée pour être une valeur numérique discrète. Le modèle se compose de (Encodeur) - (partie de quantification) - (Décodeur), mais l'encodeur et le décodeur ne sont pas très différents de VAE qui effectue la convolution. Quand j'ai jeté un coup d'œil sur le papier et la mise en œuvre de VQ-VAE, ma compréhension de la façon de faire la partie en charge de la quantification a changé, je vais donc résumer ma compréhension dans un mémorandum.

Qu'est-ce que l'intégration

L'intégration est probablement inévitable lorsqu'on parle de VQ-VAE. Si vous ne le comprenez pas comme vous-même, il est difficile de comprendre à quoi cela ressemble.

Il était plus facile pour moi de voir un exemple. Par exemple, considérons le cas où la matrice d'entrée $ (2,4) $, la valeur numérique est la valeur d'index et la matrice d'incorporation est $ (10,3) $ comme indiqué ci-dessous. Dans ce cas, si vous convertissez la matrice d'entrée en onehot pour la rendre $ (2,4,10) $ et la multipliez par la matrice d'incorporation $ (10,3) $, la matrice $ (2,4,3) $ sera créée après l'incorporation. En bref, ** l'incorporation est juste une entrée chaude et multipliée par une matrice d'incorporation. ** **

python


>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])

Première compréhension de VQ-VAE (faux)

image.png Au début, j'ai compris comme le montre la figure ci-dessus. C'est une erreur de dire d'abord. L'entrée $ z_e $ est $ (10,10,32) $, la matrice d'intégration est $ (32,128) $, et j'ai pensé à la quantification vectorielle de 128 $ $ dans l'espace latent (discrimination de l'espace latent). Multipliez l'entrée $ z_e $ par la matrice d'incorporation pour obtenir l'index de l'emplacement le plus proche de 1. (L'indice du vecteur onehot le plus proche de tout vecteur onehot). Il devient $ q (z | x) $, qui est une matrice de $ (10,10) $ et la valeur est la valeur d'index. Convertissez-le en onehot et multipliez-le par l'inverse de la matrice d'intégration pour obtenir $ z_q $. Ici, le processus de conversion de $ z_q $ en onehot et de le multiplier par la matrice d'incorporation n'est rien d'autre que le processus d'incorporation lui-même expliqué au début. La fonction de perte est $ (z_e-z_q) ^ 2 $ car la sortie $ z_q $ devrait se rapprocher de l'entrée $ z_e $ si l'espace latent est bien décentralisé.

Si vous écrivez le changement de la figure avec numpy, ce sera comme suit.

python


import numpy as np

input = np.random.rand(10,10,32)
embed = np.random.rand(32,128)
embed_inv = np.linalg.pinv(embed)
dist = (np.dot(input, embed) - np.ones((10,10,128)))**2
embed_ind = np.argmin(dist, axis=2)
embed_onehot = np.identity(128)[embed_ind]
output = np.dot(embed_onehot, embed_inv)

print("input.shape=", input.shape)
print("embed.shape=", embed.shape)
print("embed_inv.shape=", embed_inv.shape)
print("dist.shape=", dist.shape)
print("embed_ind.shape=", embed_ind.shape)
print("embed_onehot.shape=", embed_onehot.shape)
print("output.shape=", output.shape)
----------------------------------------------
input.shape= (10, 10, 32)
embed.shape= (32, 128)
embed_inv.shape= (128, 32)
dist.shape= (10, 10, 128)
embed_ind.shape= (10, 10)
embed_onehot.shape= (10, 10, 128)
output.shape= (10, 10, 32)

Qu'est-ce qui ne va pas

L'interprétation ci-dessus est incorrecte par rapport à la Mise en œuvre réelle. L'une des raisons est que le processus de recherche de l'inverse de la matrice d'inclusion n'est probablement pas possible en pratique. Par conséquent, nous devons trouver $ q (z | x) $ et $ z_q $ d'une manière qui n'utilise pas l'inverse de la matrice d'intégration.

L'autre est qu'il diffère de la définition du papier $ q (z | x) $. Le papier est le suivant, image.png La formule suivante dans l'interprétation ci-dessus est incorrecte. q(z|x)=argmin((z_e \cdot e_{mbed} - I)^2)

En multipliant argmin par la matrice inverse $ e_ {mbed \ inv} $ de la matrice d'intégration, elle peut être organisée comme suit. ((z_e \cdot e_{mbed} - I)^2 \cdot e_{mbed\ inv}^2)=(z_e \cdot e_{mbed} \cdot e_{mbed\ inv}- I \cdot e_{mbed\ inv})^2=(z_e - e_{mbed\ inv})^2 Cela équivaut à la formule du papier.

Ensuite, il est plus pratique de remplacer les noms de la matrice d'incorporation et de sa matrice inverse. En d'autres termes, $ e_ {mbed \ inv} $ s'appellera $ e_ {mbed} $ et $ e_ {mbed} $ s'appellera $ e_ {mbed \ inv} $.

Deuxième compréhension de VQ-VAE

Avec les corrections ci-dessus, nous avons la compréhension suivante. Notez que $ e_ {mbed \ inv} $ n'est pas inclus dans les deux expressions pour $ q (z | x) $ et $ z_q $. Les deux $ q (z | x) $ et $ z_q $ peuvent être calculés avec les entrées $ z_e $ et $ e_ {mbed} , éliminant ainsi le besoin de calculer leur inverse. En particulierq(z|x)Estz_eQuande_{mbed}Dez_qEstq(z|x)Quande_{mbed}$Il est calculé à partir de.

image.png

Propagation par gradient de la fonction de perte

Au fait, si vous pensez que la fonction de perte liée à la quantification vectorielle est $ (z_e-z_q) ^ 2 $, qui est la différence avant et après la quantification, c'est en fait différent. Il est exprimé comme $ (sg (z_e) -z_q) ^ 2 + (z_e-sg (z_q)) ^ 2 $ en utilisant la fonction qui arrête le gradient appelé $ sg () $. Ceci est différent de $ (z_e-z_q) ^ 2 $, probablement parce qu'il est difficile de calculer la rétropropagation d'erreur de $ z_e $ et $ e_ {mbed} $ à partir de $ q (z | x) $. , On pense que la transmission du gradient de cette pièce est coupée.

De plus, le contenu à mettre à jour diffère entre les deuxième et troisième éléments de la fonction de perte. Le deuxième élément met à jour la matrice d'incorporation, mais le dégradé ne se propage pas à l'entrée (encodeur). Le troisième élément propage le dégradé vers l'entrée (encodeur) mais ne met pas à jour la matrice d'incorporation. En ce qui concerne le premier élément de la fonction de perte, il semble qu'il commence à partir de Decorder, saute la partie de quantification de $ z_q à z_e $, et est transmis à Encoder. Cependant, ce n'est pas différent de la perte d'un AutoEncoder normal. image.png image.png

fonction argmin

Écrivons une fonction argmin qui prend la valeur d'index de la plus petite valeur du tableau à l'aide de la fonction d'étape du côté lourd.

argmin(a,b) = H(b-a) \cdot 0 + H(a-b) \cdot 1 \\
argmin(a,b,c) = H(b-a) \cdot H(c-a) \cdot 0 + H(a-b) \cdot H(c-b) \cdot 1 + H(a-c) \cdot H(b-c) \cdot 2\\
H(x) =\left\{
\begin{array}{ll}
1 & (x \geq 0) \\
0 & (x \lt 0)
\end{array}
\right.

Ici, le terme pour lequel la valeur minimale est soustraite est la valeur de tous les produits qui reste à 1, et lorsque la valeur qui n'est pas la valeur minimale est soustraite, l'un d'eux devient zéro et ne reste pas. Par conséquent, $ argmin (a_ {1}, \ cdots, a_ {128}) $ est le produit de fonctions d'escalier d'ordre supérieur, donc même si la fonction d'escalier est remplacée par une fonction différentielle continue pendant le calcul du gradient (par exemple, sigmoïde). Fonction) Cela semble difficile à différencier. Cependant, vous n'avez pas à y penser en utilisant une fonction qui arrête le dégradé appelée $ sg () $ comme expliqué précédemment.

Résumé

Je pensais avoir un aperçu de l'implémentation et l'avoir comprise, mais j'ai réalisé que la matrice d'intégration à laquelle je pensais au début était l'inverse de la matrice d'intégration réelle. Il peut être facile de se méprendre sur le fait que la matrice d'enrobage n'est pas la matrice à multiplier lors de la quantification vectorielle de l'entrée. La matrice d'inclusion est celle qui est multipliée lors de la conversion de la variable latente quantifiée en espace latent. De plus, j'ai senti que la quantification vectorielle était similaire à la segmentation sémantique, qui est la reconnaissance d'objets en unités de pixels. La segmentation sémantique utilise softmax pour générer un vecteur onehot pour chaque pixel, tandis que VQ utilise un carré de distance et argmin pour générer un vecteur de quantification.

Référence: pytorch VQ-VAE

De Exemple d'implémentation réelle

class Quantize(nn.Module):
    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
        embed = torch.randn(dim, n_embed)
        ...

    def forward(self, input):
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)
        ...
        diff = (quantize.detach() - input).pow(2).mean()
        quantize = input + (quantize - input).detach()

        return quantize, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.transpose(0, 1))

Recommended Posts

Comprendre VQ-VAE
Comprendre Concaténer
Comprendre Python Coroutine
Compréhension approfondie d'Im2col
[Discord.py] Comprendre Cog
Comprendre Tensor (1): Dimension
compréhension approfondie de col2im
Comprendre l'auto python
Comprendre TensorFlow avec l'arithmétique