I tried to implement CVAE with PyTorch

For practice, I implemented and trained a CVAE, a type of deep generative model. This article is only a memo and assumes that you already know how a VAE works.

Environment
  • OS: Windows10
  • Python: 3.7.5
  • CUDA: 9.2
  • numpy: 1.18.1
  • torch: 1.4.0+cu92
  • torchvision: 0.5.0+cu92
  • matplotlib: 3.1.3

The code was written and run in Jupyter Notebook.

Reference article

Here are the pages that I referred to when implementing.

  • [Qiita] Variational Autoencoder Thorough Explanation
  • [Qiita] Journey around the deep generative model (2): VAE

I also referred to the PyTorch example implementation.

What is CVAE

**CVAE (Conditional Variational Autoencoder)** is an extension of VAE. In a plain VAE, the data is input to the Encoder and the latent variable is input to the Decoder. In a CVAE, a condition describing the data (here, its label) is added to both of these inputs. This gives the following benefits:

  • When the Encoder reduces dimensionality, the latent variable can capture features other than the label.
  • When the Decoder generates data, you can specify the condition (label) of the data you want.
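As a minimal sketch of what adding the condition means in practice, the label can be turned into a one-hot vector and concatenated to the Encoder input and to the latent variable fed to the Decoder. The batch of four random "images", the labels, and the variable names below are assumptions for illustration only; the sizes (784 pixels, 10 classes, 16 latent dimensions) match the implementation later in this article.

```python
import torch
import torch.nn.functional as F

NUM_CLASSES = 10    # digits 0-9
IN_UNITS = 28 * 28  # flattened MNIST image
ZDIM = 16           # latent dimension

# Hypothetical mini-batch: 4 random "images" and their labels.
x = torch.rand(4, IN_UNITS)
labels = torch.tensor([3, 1, 4, 1])

# One-hot encode the condition (the label).
c = F.one_hot(labels, num_classes=NUM_CLASSES).float()

# Encoder input: data and condition side by side.
encoder_in = torch.cat([x, c], dim=1)

# Decoder input: latent variable and condition side by side.
z = torch.randn(4, ZDIM)
decoder_in = torch.cat([z, c], dim=1)

print(encoder_in.shape)  # torch.Size([4, 794])
print(decoder_in.shape)  # torch.Size([4, 26])
```

The implementation below reaches the same layout by writing into a preallocated tensor instead of calling `torch.cat`; both approaches produce the same input.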

Implementation and training

This time, I implement a CVAE in PyTorch and train it on MNIST (a dataset of handwritten digits).

python


import os
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
%matplotlib inline

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when CUDA is unavailable
SEED = 0
CLASS_SIZE = 10
BATCH_SIZE = 256
ZDIM = 16
NUM_EPOCHS = 50

# Set seeds
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)   
torch.cuda.manual_seed(SEED)


class CVAE(nn.Module):
    def __init__(self, zdim):
        super().__init__()
        self._zdim = zdim
        self._in_units = 28 * 28
        hidden_units = 512
        self._encoder = nn.Sequential(
            nn.Linear(self._in_units + CLASS_SIZE, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
        )
        self._to_mean = nn.Linear(hidden_units, zdim)
        self._to_lnvar = nn.Linear(hidden_units, zdim)
        self._decoder = nn.Sequential(
            nn.Linear(zdim + CLASS_SIZE, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, self._in_units),
            nn.Sigmoid()
        )

    def encode(self, x, labels):
        in_ = torch.empty((x.shape[0], self._in_units + CLASS_SIZE), device=DEVICE)
        in_[:, :self._in_units] = x
        in_[:, self._in_units:] = labels
        h = self._encoder(in_)
        mean = self._to_mean(h)
        lnvar = self._to_lnvar(h)
        return mean, lnvar

    def decode(self, z, labels):
        in_ = torch.empty((z.shape[0], self._zdim + CLASS_SIZE), device=DEVICE)
        in_[:, :self._zdim] = z
        in_[:, self._zdim:] = labels
        return self._decoder(in_)


def to_onehot(label):
    return torch.eye(CLASS_SIZE, device=DEVICE, dtype=torch.float32)[label]


# Train
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

model = CVAE(ZDIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

model.train()
for e in range(NUM_EPOCHS):
    train_loss = 0
    for i, (images, labels) in enumerate(train_loader):
        labels = to_onehot(labels)
        # Reconstruction images
        # Encode images
        x = images.view(-1, 28*28*1).to(DEVICE)
        mean, lnvar = model.encode(x, labels)
        std = lnvar.exp().sqrt()
        epsilon = torch.randn_like(std)  # independent noise for every sample in the batch
        
        # Decode latent variables
        z = mean + std * epsilon
        y = model.decode(z, labels)
        
        # Compute loss (negative ELBO); `kld` below holds the *negative* KL divergence
        kld = 0.5 * (1 + lnvar - mean.pow(2) - lnvar.exp()).sum(axis=1)
        bce = F.binary_cross_entropy(y, x, reduction='none').sum(axis=1)
        loss = (-1 * kld + bce).mean()

        # Update model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.shape[0]

    print(f'epoch: {e + 1} epoch_loss: {train_loss/len(train_dataset)}')

result

epoch: 1 epoch_loss: 200.2185905436198
epoch: 2 epoch_loss: 160.22688263346353
epoch: 3 epoch_loss: 148.69330817057292

... (omitted) ...

epoch: 48 epoch_loss: 98.95304524739583
epoch: 49 epoch_loss: 98.6720672281901
epoch: 50 epoch_loss: 98.65486107177735

The key points of the implementation and training are as follows.

  • Train on the 60,000 training images of torchvision.datasets.MNIST for 50 epochs.
  • Design a CVAE class containing the Encoder and Decoder, and implement `encode` and `decode` methods instead of `forward`.
  • Convert the dataset labels (the written digits) to one-hot vectors and append them to the Encoder and Decoder inputs.
  • Use a fairly large mini-batch size of 256. [^ 1]
  • Use a simple MLP for both the Encoder and the Decoder.
  • Set the dimension of the latent variable to 16.
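The loss in the training loop is the negative evidence lower bound: a per-pixel binary cross-entropy reconstruction term plus the closed-form KL divergence between the Encoder's Gaussian and the standard normal prior. The sketch below restates it as standalone functions. The names `reparameterize` and `cvae_loss` and the toy tensors are assumptions for illustration and do not appear in the original code.

```python
import torch
import torch.nn.functional as F


def reparameterize(mean, lnvar):
    """Sample z = mean + std * eps, with one eps per element."""
    std = (0.5 * lnvar).exp()
    eps = torch.randn_like(std)
    return mean + std * eps


def cvae_loss(y, x, mean, lnvar):
    """Return (negative ELBO, KL term, reconstruction term), batch-averaged."""
    # KL( N(mean, exp(lnvar)) || N(0, I) ), summed over latent dimensions.
    kld = -0.5 * (1 + lnvar - mean.pow(2) - lnvar.exp()).sum(dim=1)
    # Reconstruction term, summed over pixels.
    bce = F.binary_cross_entropy(y, x, reduction='none').sum(dim=1)
    return (kld + bce).mean(), kld.mean(), bce.mean()


# Toy check: if the Encoder outputs exactly the prior, the KL term vanishes.
torch.manual_seed(0)
x = torch.rand(8, 784)                     # hypothetical targets in [0, 1]
y = torch.rand(8, 784).clamp(0.01, 0.99)   # hypothetical reconstructions
mean = torch.zeros(8, 16)
lnvar = torch.zeros(8, 16)

z = reparameterize(mean, lnvar)
loss, kld, bce = cvae_loss(y, x, mean, lnvar)
print(z.shape, abs(kld.item()), torch.isclose(loss, bce).item())
```

Here `kld` is the KL divergence itself (non-negative), so it is added to the reconstruction term; the training loop above stores the same quantity with the opposite sign and subtracts it, which gives the same loss.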

Image generation by CVAE

A VAE has two main applications, dimensionality reduction and data generation; this time I focus on data generation. Let's create new handwritten digit images using the CVAE Decoder trained above.

Generation of "5" images

The label given to the Decoder is fixed to "5", 100 latent vectors are sampled from the standard normal distribution, and an image is generated from each of them.

python


# Generation data with label '5'
NUM_GENERATION = 100

os.makedirs(f'img/cvae/generation/label5/', exist_ok=True)
model.eval()
for i in range(NUM_GENERATION):
    z = torch.randn(ZDIM, device=DEVICE).unsqueeze(dim=0)
    label = torch.tensor([5], device=DEVICE)
    with torch.no_grad():
        y = model.decode(z, to_onehot(label))
    y = y.reshape(28, 28).cpu().detach().numpy()

    # Save image
    fig, ax = plt.subplots()
    ax.imshow(y)
    ax.set_title(f'Generation(label={label.cpu().detach().numpy()[0]})')
    ax.tick_params(
        labelbottom=False,
        labelleft=False,
        bottom=False,
        left=False,
    )
    plt.savefig(f'img/cvae/generation/label5/img{i + 1}')
    plt.close(fig) 

result

img_cvae.png

Some of them are misshapen, but the model generates a variety of "5" images.
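Because the Decoder accepts a batch, the 100 samples could also be drawn in a single call rather than one at a time. The following is a rough sketch under assumptions: `toy_decoder` is an untrained stand-in with the same input and output sizes as the Decoder above, used only so that the snippet runs on its own. With the trained model, the corresponding call would be `model.decode(z, onehot)`.

```python
import torch
import torch.nn as nn

ZDIM, CLASS_SIZE, NUM_GENERATION = 16, 10, 100

# Untrained stand-in for the trained Decoder (assumption for illustration).
toy_decoder = nn.Sequential(
    nn.Linear(ZDIM + CLASS_SIZE, 512),
    nn.ReLU(),
    nn.Linear(512, 28 * 28),
    nn.Sigmoid(),
)
toy_decoder.eval()

# Sample all latent vectors at once and repeat the one-hot label for "5".
z = torch.randn(NUM_GENERATION, ZDIM)
label_ids = torch.full((NUM_GENERATION,), 5, dtype=torch.long)
onehot = torch.eye(CLASS_SIZE)[label_ids]

with torch.no_grad():
    decoder_in = torch.cat([z, onehot], dim=1)
    images = toy_decoder(decoder_in).reshape(-1, 28, 28)

print(images.shape)  # torch.Size([100, 28, 28])
```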

Generation of thick digit images

I looked through the test images of torchvision.datasets.MNIST for a thickly written digit. The following is the 49th image in the test dataset.

fat_digit.png

It is a "4" written with very thick strokes. I use the Encoder to find the latent variable corresponding to this image.

python


test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
target_image, label = test_dataset[48]  # 49th image (index 48); no need to materialize the whole dataset

x = target_image.view(1, 28*28).to(DEVICE)
with torch.no_grad():
    mean, _ = model.encode(x, to_onehot(label))
z = mean

print(f'z = {z.cpu().detach().numpy().squeeze()}')

result

z = [ 0.7933388   2.4768877   0.49229255 -0.09540698 -1.7999544   0.03376897
  0.01600834  1.3863252   0.14656337 -0.14543885  0.04157912  0.13938689
 -0.2016176   0.5204378  -0.08096244  1.0930295 ]

This 16-dimensional vector should hold the information about the image **other than** the label that was given during training. In other words, it should encode "written very thickly" rather than "shaped like a 4".

Therefore, using this latent variable, let's generate an image while changing the label information given to the Decoder.

python


os.makedirs(f'img/cvae/generation/fat', exist_ok=True)
for label in range(CLASS_SIZE):
    with torch.no_grad():
        y = model.decode(z, to_onehot(label))
    y = y.reshape(28, 28).cpu().detach().numpy()
    fig, ax = plt.subplots()
    ax.imshow(y)
    ax.set_title(f'Generation(label={label})')
    ax.tick_params(
        labelbottom=False,
        labelleft=False,
        bottom=False,
        left=False,
    )
    plt.savefig(f'img/cvae/generation/fat/img{label}')
    plt.close(fig) 

result

fat_generated.png

The "2" is a little questionable, but overall I was able to generate thickly written digits.
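The same latent variable can also be paired with all ten labels in one batch, which makes the roles of the two Decoder inputs explicit: z is held fixed while only the condition changes. This is a sketch with an untrained stand-in Decoder (`toy_decoder`, an assumption so that the snippet is self-contained) and a random z. In the setting of this article, z would be the Encoder mean for the thick "4" and the Decoder would be the trained one.

```python
import torch
import torch.nn as nn

ZDIM, CLASS_SIZE = 16, 10

# Untrained stand-in for the trained Decoder (assumption for illustration).
toy_decoder = nn.Sequential(
    nn.Linear(ZDIM + CLASS_SIZE, 28 * 28),
    nn.Sigmoid(),
)
toy_decoder.eval()

z = torch.randn(1, ZDIM)           # placeholder for the encoded "style"
z_rep = z.expand(CLASS_SIZE, -1)   # the same z repeated for every row
labels = torch.eye(CLASS_SIZE)     # row k is the one-hot vector for digit k

with torch.no_grad():
    decoder_in = torch.cat([z_rep, labels], dim=1)
    images = toy_decoder(decoder_in).reshape(-1, 28, 28)

print(images.shape)  # torch.Size([10, 28, 28])
```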

Conclusion

I had known about CVAE for a long time, but this was the first time I implemented it. I'm glad it worked. Implementing a method matters as much as knowing about it. Some of the generated images did not look clean, but this might be improved by using convolution and transposed convolution layers in the network. Although I omitted it this time, I understand that for VAE-family models it is important to analyze which features are mapped where in the low-dimensional space. I would like to do that analysis next time.

[^ 1]: A large mini-batch makes it likely that data for every label is present in each mini-batch, so that the Encoder's image of the mini-batch can follow a standard normal distribution in the latent space.
