[PYTHON] Ich habe versucht, CVAE mit PyTorch zu implementieren

Für meine eigene Praxis habe ich CVAE implementiert und geschult, eine Art Deep Learning. Dieser Artikel ist eine Beschreibung auf Memoebene und basiert auf der Annahme, dass Sie Kenntnisse über VAE haben. Bitte beachten Sie.

Umgebung
  • OS: Windows10
  • Python: 3.7.5
  • CUDA: 9.2
  • numpy: 1.18.1
  • torch: 1.4.0+cu92
  • torchvision: 0.5.0+cu92
  • matplotlib: 3.1.3

Es wird auch mit Jupyter Notebook implementiert.

Referenzartikel

Hier sind einige Seiten, auf die ich bei der Implementierung verwiesen habe.

Darüber hinaus verweise ich auch auf die Beispielimplementierung von Pytorch.

Was ist CVAE?

** CVAE (Conditional Variational Auto Encoder) ** ist eine fortschrittliche VAE-Methode. In der normalen VAE werden Daten in den Encoder eingegeben und latente Variablen in den Decoder, aber in der CVAE wird der Datenstatus zu diesen hinzugefügt. Dies bietet Ihnen folgende Vorteile:

Implementierung und Lernen

Dieses Mal werden wir CVAE mit Pytorch implementieren und MNIST (Datensatz handgeschriebener Zeichen) trainieren.

python


import os
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
%matplotlib inline

DEVICE = 'cuda'
SEED = 0
CLASS_SIZE = 10
BATCH_SIZE = 256
ZDIM = 16
NUM_EPOCHS = 50

# Set seeds
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)   
torch.cuda.manual_seed(SEED)


class CVAE(nn.Module):
    def __init__(self, zdim):
        super().__init__()
        self._zdim = zdim
        self._in_units = 28 * 28
        hidden_units = 512
        self._encoder = nn.Sequential(
            nn.Linear(self._in_units + CLASS_SIZE, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
        )
        self._to_mean = nn.Linear(hidden_units, zdim)
        self._to_lnvar = nn.Linear(hidden_units, zdim)
        self._decoder = nn.Sequential(
            nn.Linear(zdim + CLASS_SIZE, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, self._in_units),
            nn.Sigmoid()
        )

    def encode(self, x, labels):
        in_ = torch.empty((x.shape[0], self._in_units + CLASS_SIZE), device=DEVICE)
        in_[:, :self._in_units] = x
        in_[:, self._in_units:] = labels
        h = self._encoder(in_)
        mean = self._to_mean(h)
        lnvar = self._to_lnvar(h)
        return mean, lnvar

    def decode(self, z, labels):
        in_ = torch.empty((z.shape[0], self._zdim + CLASS_SIZE), device=DEVICE)
        in_[:, :self._zdim] = z
        in_[:, self._zdim:] = labels
        return self._decoder(in_)


def to_onehot(label):
    return torch.eye(CLASS_SIZE, device=DEVICE, dtype=torch.float32)[label]


# Train
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

model = CVAE(ZDIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

model.train()
for e in range(NUM_EPOCHS):
    train_loss = 0
    for i, (images, labels) in enumerate(train_loader):
        labels = to_onehot(labels)
        # Reconstruction images
        # Encode images
        x = images.view(-1, 28*28*1).to(DEVICE)
        mean, lnvar = model.encode(x, labels)
        std = lnvar.exp().sqrt()
        epsilon = torch.randn(ZDIM, device=DEVICE)
        
        # Decode latent variables
        z = mean + std * epsilon
        y = model.decode(z, labels)
        
        # Compute loss
        kld = 0.5 * (1 + lnvar - mean.pow(2) - lnvar.exp()).sum(axis=1)
        bce = F.binary_cross_entropy(y, x, reduction='none').sum(axis=1)
        loss = (-1 * kld + bce).mean()

        # Update model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.shape[0]

    print(f'epoch: {e + 1} epoch_loss: {train_loss/len(train_dataset)}')

Ergebnis

epoch: 1 epoch_loss: 200.2185905436198
epoch: 2 epoch_loss: 160.22688263346353
epoch: 3 epoch_loss: 148.69330817057292

#Unterlassung

epoch: 48 epoch_loss: 98.95304524739583
epoch: 49 epoch_loss: 98.6720672281901
epoch: 50 epoch_loss: 98.65486107177735

Unten finden Sie eine Liste der Implementierungs- und Lernpunkte.

--Verwenden Sie 6000 Trainingsdaten von torchvision.datasets.MNIST zum Lernen und setzen Sie die Anzahl der Epochen auf 50.

Bilderzeugung durch CVAE

VAE hat zwei Anwendungen, das Löschen von Dimensionen und die Datengenerierung. Dieses Mal konzentrieren wir uns jedoch auf die Datengenerierung. Erstellen Sie ein neues handgeschriebenes Bild mit dem zuvor erlernten CVAE-Decoder.

Erzeugung von "5" Bildern

Die dem Decoder gegebenen Etiketteninformationen sind auf "5" festgelegt, 100 Zufallszahlen, die der Standardnormalverteilung folgen, werden erzeugt und das entsprechende Bild wird erzeugt.

python


# Generation data with label '5'
NUM_GENERATION = 100

os.makedirs(f'img/cvae/generation/label5/', exist_ok=True)
model.eval()
for i in range(NUM_GENERATION):
    z = torch.randn(ZDIM, device=DEVICE).unsqueeze(dim=0)
    label = torch.tensor([5], device=DEVICE)
    with torch.no_grad():
        y = model.decode(z, to_onehot(label))
    y = y.reshape(28, 28).cpu().detach().numpy()

    # Save image
    fig, ax = plt.subplots()
    ax.imshow(y)
    ax.set_title(f'Generation(label={label.cpu().detach().numpy()[0]})')
    ax.tick_params(
        labelbottom=False,
        labelleft=False,
        bottom=False,
        left=False,
    )
    plt.savefig(f'img/cvae/generation/label5/img{i + 1}')
    plt.close(fig) 

Ergebnis

img_cvae.png

Einige von ihnen sind nicht in Form, aber wir können verschiedene "5" -Bilder erzeugen.

Erzeugung eines dicken Zahlenbildes

Ich habe nach den fetten Zahlen im Testbild von torchvision.datasets.MNIST gesucht. Das folgende Bild ist das 49. Bild im Datensatz.

fat_digit.png

Es ist sehr dick als "4" geschrieben. Verwenden Sie Encoder, um die latente Variable zu finden, die diesen Daten entspricht.

python


test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)
target_image, label = list(test_dataset)[48]

x = target_image.view(1, 28*28).to(DEVICE)
with torch.no_grad():
    mean, _ = model.encode(x, to_onehot(label))
z = mean

print(f'z = {z.cpu().detach().numpy().squeeze()}')

Ergebnis

z = [ 0.7933388   2.4768877   0.49229255 -0.09540698 -1.7999544   0.03376897
  0.01600834  1.3863252   0.14656337 -0.14543885  0.04157912  0.13938689
 -0.2016176   0.5204378  -0.08096244  1.0930295 ]

Dieser 16-dimensionale Vektor enthält die Informationen des Bildes von ** außer dem Etikett **, das zum Zeitpunkt des Trainings angegeben wurde. Mit anderen Worten, Sie sollten die Information "sehr dick" haben, nicht die Information "es ist in Form von 4".

Versuchen Sie daher, mit dieser latenten Variablen ein Bild zu generieren, während Sie die dem Decoder übergebenen Beschriftungsinformationen ändern.

python


os.makedirs(f'img/cvae/generation/fat', exist_ok=True)
for label in range(CLASS_SIZE):
    with torch.no_grad():
        y = model.decode(z, to_onehot(label))
    y = y.reshape(28, 28).cpu().detach().numpy()
    fig, ax = plt.subplots()
    ax.imshow(y)
    ax.set_title(f'Generation(label={label})')
    ax.tick_params(
        labelbottom=False,
        labelleft=False,
        bottom=False,
        left=False,
    )
    plt.savefig(f'img/cvae/generation/fat/img{label}')
    plt.close(fig) 

Ergebnis

fat_generated.png

"2" ist etwas verdächtig, aber ich kann ein Bild mit dicken Zahlen erzeugen.

abschließend

Ich wusste lange über CVAE Bescheid, aber dies war das erste Mal, dass ich es implementierte. Ich bin froh, dass es funktioniert hat. Es ist wichtig, es nicht nur zu wissen, sondern auch umzusetzen. Einige der generierten Bilder sahen nicht besonders hübsch aus, können jedoch mithilfe der Faltung oder Translokationsfaltung im VAE-Netzwerk aufgelöst werden. Obwohl diesmal weggelassen, erkennt das VAE-System, dass es wichtig ist zu analysieren, welche Merkmale wo im niedrigdimensionalen Raum abgebildet werden. Diesmal möchte ich diese Analyse durchführen.

[^ 1]: Damit sind die Daten mit allen Beschriftungen im Mini-Batch vorhanden, sodass das Bild des Mini-Batches von Encoder der Standardnormalverteilung im latenten Variablenraum folgt.

Recommended Posts

Ich habe versucht, CVAE mit PyTorch zu implementieren
Ich habe versucht, das Lesen von Dataset mit PyTorch zu implementieren
Ich habe versucht, DCGAN mit PyTorch zu implementieren und zu lernen
Ich habe versucht, SSD jetzt mit PyTorch zu implementieren (Dataset)
Ich habe versucht, Autoencoder mit TensorFlow zu implementieren
Ich habe versucht, SSD jetzt mit PyTorch zu implementieren (Modellversion)
Ich habe versucht, StarGAN (1) zu implementieren.
Ich habe versucht, Faster R-CNN mit Pytorch auszuführen
Ich habe versucht, Mine Sweeper auf dem Terminal mit Python zu implementieren
Ich habe versucht, künstliches Perzeptron mit Python zu implementieren
[Einführung in Pytorch] Ich habe versucht, Cifar10 mit VGG16 ♬ zu kategorisieren
Ich habe versucht, Grad-CAM mit Keras und Tensorflow zu implementieren
Ich habe versucht, Deep VQE zu implementieren
Ich habe versucht, eine kontroverse Validierung zu implementieren
Ich habe versucht, Pytorchs Datensatz zu erklären
Ich habe versucht, DeepPose mit PyTorch zu implementieren
Ich habe versucht, Realness GAN zu implementieren
Ich habe versucht, mit Quantx eine Linie mit gleitendem Durchschnitt des Volumens zu implementieren
Ich habe versucht, mit Quantx einen Ausbruch (Typ der Täuschungsvermeidung) zu implementieren
Ich habe versucht, MNIST nach GNN zu klassifizieren (mit PyTorch-Geometrie).
Ich habe versucht, ListNet of Rank Learning mit Chainer zu implementieren
Ich habe versucht, Harry Potters Gruppierungshut mit CNN umzusetzen
Ich habe versucht, PLSA in Python zu implementieren
Ich habe versucht, Permutation in Python zu implementieren
Ich habe versucht, AutoEncoder mit TensorFlow zu visualisieren
Ich habe versucht, mit Hy anzufangen
Ich habe versucht, PLSA in Python 2 zu implementieren
[Einführung in Pytorch] Ich habe mit sinGAN ♬ gespielt
Ich habe versucht, DeepPose mit PyTorch PartⅡ zu implementieren
Ich habe versucht, PPO in Python zu implementieren
Ich habe versucht, TSP mit QAOA zu lösen
Ich habe versucht zu debuggen.
Ich habe versucht, nächstes Jahr mit AI vorherzusagen
Ich habe versucht, lightGBM, xg Boost mit Boruta zu verwenden
Ich habe versucht, mit TF Learn die logische Operation zu lernen
Ich habe versucht, GAN (mnist) mit Keras zu bewegen
Ich habe versucht, die Daten mit Zwietracht zu speichern
Ich habe versucht, mit OpenCV Bewegungen schnell zu erkennen
Ich habe versucht, Keras in TFv1.1 zu integrieren
Ich habe versucht, LLVM IR mit Python auszugeben
Ich habe versucht, TOPIC MODEL in Python zu implementieren
Ich habe versucht, ein Objekt mit M2Det zu erkennen!
Ich habe versucht, die Herstellung von Sushi mit Python zu automatisieren
Ich habe versucht, Linux mit Discord Bot zu betreiben
Ich habe versucht, eine selektive Sortierung in Python zu implementieren
Ich habe versucht, DP mit Fibonacci-Sequenz zu studieren
Ich habe versucht, Jupyter mit allen Amazon-Lichtern zu starten
Ich habe versucht, Tundele mit Naive Bays zu beurteilen
Ich habe versucht, das Problem des Handlungsreisenden umzusetzen
Ich habe versucht, Deep Learning zu implementieren, das nicht nur mit NumPy tiefgreifend ist
Ich habe versucht, eine Blockchain zu implementieren, die tatsächlich mit ungefähr 170 Zeilen funktioniert
Ich habe versucht, die Sündenfunktion mit Chainer zu trainieren
Ich habe versucht, die Zusammenführungssortierung in Python mit möglichst wenigen Zeilen zu implementieren
Ich habe fp-Wachstum mit Python versucht
Ich habe versucht, mit Python zu kratzen
Ich habe versucht, ein multivariates statistisches Prozessmanagement (MSPC) zu implementieren.