[PYTHON] J'ai essayé d'implémenter CVAE avec PyTorch

Pour ma propre pratique, j'ai implémenté et formé CVAE, qui est un type d'apprentissage profond. Cet article est une description de niveau mémo et est rédigé en partant du principe que vous connaissez la VAE. Notez s'il vous plaît.

<détails>

Environnement </ summary>

  • OS: Windows10
  • Python: 3.7.5
  • CUDA: 9.2
  • numpy: 1.18.1
  • torch: 1.4.0+cu92
  • torchvision: 0.5.0+cu92
  • matplotlib: 3.1.3

Il est également implémenté à l'aide de Jupyter Notebook.

Article de référence

Voici quelques pages auxquelles j'ai fait référence lors de la mise en œuvre.

En outre, je me réfère également à l'exemple de mise en œuvre de Pytorch.

Qu'est-ce que CVAE

** CVAE (Conditional Variational Auto Encoder) ** est une méthode avancée de VAE. Dans VAE normal, les données sont entrées dans l'encodeur et les variables latentes sont entrées dans le décodeur, mais dans CVAE, l'état des données est ajouté à celles-ci. Cela vous donne les avantages suivants:

  • Lors de la suppression de dimensions avec Encoder, des fonctionnalités autres que les étiquettes de données peuvent être reflétées.
  • Lors de la génération de données avec Decoder, vous pouvez spécifier l'état des données souhaitées.

Mise en œuvre et apprentissage

Cette fois, nous allons implémenter CVAE avec Pytorch et former MNIST (ensemble de données de caractères manuscrits).

python


import os
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
%matplotlib inline

DEVICE = 'cuda'
SEED = 0
CLASS_SIZE = 10
BATCH_SIZE = 256
ZDIM = 16
NUM_EPOCHS = 50

# Set seeds
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)   
torch.cuda.manual_seed(SEED)


class CVAE(nn.Module):
    def __init__(self, zdim):
        super().__init__()
        self._zdim = zdim
        self._in_units = 28 * 28
        hidden_units = 512
        self._encoder = nn.Sequential(
            nn.Linear(self._in_units + CLASS_SIZE, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
        )
        self._to_mean = nn.Linear(hidden_units, zdim)
        self._to_lnvar = nn.Linear(hidden_units, zdim)
        self._decoder = nn.Sequential(
            nn.Linear(zdim + CLASS_SIZE, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, self._in_units),
            nn.Sigmoid()
        )

    def encode(self, x, labels):
        in_ = torch.empty((x.shape[0], self._in_units + CLASS_SIZE), device=DEVICE)
        in_[:, :self._in_units] = x
        in_[:, self._in_units:] = labels
        h = self._encoder(in_)
        mean = self._to_mean(h)
        lnvar = self._to_lnvar(h)
        return mean, lnvar

    def decode(self, z, labels):
        in_ = torch.empty((z.shape[0], self._zdim + CLASS_SIZE), device=DEVICE)
        in_[:, :self._zdim] = z
        in_[:, self._zdim:] = labels
        return self._decoder(in_)


def to_onehot(label):
    return torch.eye(CLASS_SIZE, device=DEVICE, dtype=torch.float32)[label]


# Train
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

model = CVAE(ZDIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

model.train()
for e in range(NUM_EPOCHS):
    train_loss = 0
    for i, (images, labels) in enumerate(train_loader):
        labels = to_onehot(labels)
        # Reconstruction images
        # Encode images
        x = images.view(-1, 28*28*1).to(DEVICE)
        mean, lnvar = model.encode(x, labels)
        std = lnvar.exp().sqrt()
        epsilon = torch.randn(ZDIM, device=DEVICE)
        
        # Decode latent variables
        z = mean + std * epsilon
        y = model.decode(z, labels)
        
        # Compute loss
        kld = 0.5 * (1 + lnvar - mean.pow(2) - lnvar.exp()).sum(axis=1)
        bce = F.binary_cross_entropy(y, x, reduction='none').sum(axis=1)
        loss = (-1 * kld + bce).mean()

        # Update model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.shape[0]

    print(f'epoch: {e + 1} epoch_loss: {train_loss/len(train_dataset)}')

résultat

epoch: 1 epoch_loss: 200.2185905436198
epoch: 2 epoch_loss: 160.22688263346353
epoch: 3 epoch_loss: 148.69330817057292

#Omission

epoch: 48 epoch_loss: 98.95304524739583
epoch: 49 epoch_loss: 98.6720672281901
epoch: 50 epoch_loss: 98.65486107177735

Voici une liste de points de mise en œuvre et d'apprentissage.

--Utilisez 6000 données d'apprentissage de torchvision.datasets.MNIST pour l'apprentissage et définissez le nombre d'époques sur 50.

  • Concevoir une classe CVAE avec Encoder et Decoder et implémenter les méthodes ʻencode et decodesans implémenterforward` --Convertissez l'étiquette de l'ensemble de données (numéro écrit) en un vecteur unique et ajoutez-la aux entrées de l'encodeur et du décodeur --La taille du mini-lot au moment de l'apprentissage est de 256 [^ 1] --Comprend un MLP simple pour l'encodeur et le décodeur --Réglez la dimension de la variable latente à 16.

Génération d'images par CVAE

VAE a deux applications, la suppression de dimension et la génération de données, mais cette fois nous nous concentrerons sur la génération de données. Pensez à créer une nouvelle image manuscrite à l'aide du décodeur CVAE que vous avez appris précédemment.

Génération d'images "5"

Les informations d'étiquette fournies au décodeur sont fixées à "5", 100 nombres aléatoires qui suivent la distribution normale standard sont générés, et les images correspondant à chacun sont générées.

python


# Generation data with label '5'
NUM_GENERATION = 100

os.makedirs(f'img/cvae/generation/label5/', exist_ok=True)
model.eval()
for i in range(NUM_GENERATION):
    z = torch.randn(ZDIM, device=DEVICE).unsqueeze(dim=0)
    label = torch.tensor([5], device=DEVICE)
    with torch.no_grad():
        y = model.decode(z, to_onehot(label))
    y = y.reshape(28, 28).cpu().detach().numpy()

    # Save image
    fig, ax = plt.subplots()
    ax.imshow(y)
    ax.set_title(f'Generation(label={label.cpu().detach().numpy()[0]})')
    ax.tick_params(
        labelbottom=False,
        labelleft=False,
        bottom=False,
        left=False,
    )
    plt.savefig(f'img/cvae/generation/label5/img{i + 1}')
    plt.close(fig) 

résultat

img_cvae.png

Certains d'entre eux sont déformés, mais nous sommes capables de générer diverses images "5".

Génération d'une image numérique épaisse

J'ai recherché les nombres en gras dans l'image de test de torchvision.datasets.MNIST. L'image suivante est la 49e image de l'ensemble de données.

fat_digit.png

Il est écrit très épais comme "4". Utilisez Encoder pour trouver la variable latente correspondant à ces données.

python


test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
target_image, label = list(test_dataset)[48]

x = target_image.view(1, 28*28).to(DEVICE)
with torch.no_grad():
    mean, _ = model.encode(x, to_onehot(label))
z = mean

print(f'z = {z.cpu().detach().numpy().squeeze()}')

résultat

z = [ 0.7933388   2.4768877   0.49229255 -0.09540698 -1.7999544   0.03376897
  0.01600834  1.3863252   0.14656337 -0.14543885  0.04157912  0.13938689
 -0.2016176   0.5204378  -0.08096244  1.0930295 ]

Ce vecteur à 16 dimensions contient les informations de l'image de ** autre que l'étiquette ** donnée au moment de la formation. En d'autres termes, vous devriez avoir l'information "très épaisse", pas l'information "c'est sous la forme de 4".

Par conséquent, en utilisant cette variable latente, essayez de générer une image tout en modifiant les informations d'étiquette fournies au décodeur.

python


os.makedirs(f'img/cvae/generation/fat', exist_ok=True)
for label in range(CLASS_SIZE):
    with torch.no_grad():
        y = model.decode(z, to_onehot(label))
    y = y.reshape(28, 28).cpu().detach().numpy()
    fig, ax = plt.subplots()
    ax.imshow(y)
    ax.set_title(f'Generation(label={label})')
    ax.tick_params(
        labelbottom=False,
        labelleft=False,
        bottom=False,
        left=False,
    )
    plt.savefig(f'img/cvae/generation/fat/img{label}')
    plt.close(fig) 

résultat

fat_generated.png

"2" est un peu suspect, mais je suis capable de générer une image avec des nombres épais.

en conclusion

Je connaissais CVAE depuis longtemps, mais c'était la première fois que je l'implémentais. Je suis content que cela ait fonctionné. Il est important non seulement de le connaître mais aussi de le mettre en œuvre. Certaines des images générées n'étaient pas jolies, mais elles peuvent être résolues en utilisant la convolution ou la convolution de translocation dans le réseau VAE. Bien que cette fois omis, le système VAE reconnaît qu'il est important d'analyser quelles entités sont mappées et où dans l'espace de faible dimension. J'aimerais faire cette analyse cette fois.

[^ 1]: C'est pour que les données avec toutes les étiquettes existent dans le mini-lot afin que l'image du mini-lot par Encoder suive la distribution normale standard dans l'espace variable latent.

Recommended Posts

J'ai essayé d'implémenter CVAE avec PyTorch
J'ai essayé d'implémenter la lecture de Dataset avec PyTorch
J'ai essayé d'implémenter et d'apprendre DCGAN avec PyTorch
J'ai essayé d'implémenter SSD avec PyTorch maintenant (Dataset)
J'ai essayé d'implémenter Autoencoder avec TensorFlow
J'ai essayé d'implémenter SSD avec PyTorch maintenant (édition du modèle)
J'ai essayé d'implémenter StarGAN (1)
J'ai essayé de déplacer Faster R-CNN rapidement avec pytorch
J'ai essayé d'implémenter Mine Sweeper sur un terminal avec python
J'ai essayé d'implémenter le perceptron artificiel avec python
[Introduction à Pytorch] J'ai essayé de catégoriser Cifar10 avec VGG16 ♬
J'ai essayé d'implémenter Grad-CAM avec keras et tensorflow
J'ai essayé d'implémenter Deep VQE
J'ai essayé de mettre en place une validation contradictoire
J'ai essayé d'expliquer l'ensemble de données de Pytorch
J'ai essayé d'implémenter DeepPose avec PyTorch
J'ai essayé d'implémenter Realness GAN
J'ai essayé d'implémenter une ligne moyenne mobile de volume avec Quantx
J'ai essayé de mettre en œuvre une évasion (type d'évitement de tromperie) avec Quantx
J'ai essayé de classer MNIST par GNN (avec PyTorch géométrique)
J'ai essayé d'implémenter ListNet d'apprentissage de rang avec Chainer
J'ai essayé de mettre en œuvre le chapeau de regroupement de Harry Potter avec CNN
J'ai essayé d'implémenter PLSA en Python
J'ai essayé d'implémenter la permutation en Python
J'ai essayé de visualiser AutoEncoder avec TensorFlow
J'ai essayé de commencer avec Hy
J'ai essayé d'implémenter PLSA dans Python 2
[Introduction à Pytorch] J'ai joué avec sinGAN ♬
J'ai essayé d'implémenter DeepPose avec PyTorch PartⅡ
J'ai essayé d'implémenter PPO en Python
J'ai essayé de résoudre TSP avec QAOA
J'ai essayé de déboguer.
J'ai essayé de prédire l'année prochaine avec l'IA
J'ai essayé d'utiliser lightGBM, xg boost avec Boruta
J'ai essayé d'apprendre le fonctionnement logique avec TF Learn
J'ai essayé de déplacer GAN (mnist) avec keras
J'ai essayé de sauvegarder les données avec discorde
J'ai essayé de détecter rapidement un mouvement avec OpenCV
J'ai essayé d'intégrer Keras dans TFv1.1
J'ai essayé de sortir LLVM IR avec Python
J'ai essayé d'implémenter TOPIC MODEL en Python
J'ai essayé de détecter un objet avec M2Det!
J'ai essayé d'automatiser la fabrication des sushis avec python
J'ai essayé d'utiliser Linux avec Discord Bot
J'ai essayé d'implémenter le tri sélectif en python
J'ai essayé d'étudier DP avec séquence de Fibonacci
J'ai essayé de démarrer Jupyter avec toutes les lumières d'Amazon
J'ai essayé de juger Tundele avec Naive Bays
J'ai essayé de mettre en œuvre le problème du voyageur de commerce
J'ai essayé de mettre en œuvre un apprentissage en profondeur qui n'est pas profond avec uniquement NumPy
J'ai essayé de mettre en œuvre une blockchain qui fonctionne réellement avec environ 170 lignes
J'ai essayé d'entraîner la fonction péché avec chainer
J'ai essayé d'implémenter le tri par fusion en Python avec le moins de lignes possible
J'ai essayé fp-growth avec python
J'ai essayé de gratter avec Python
J'ai essayé de mettre en œuvre la gestion des processus statistiques multivariés (MSPC)