Für meine eigene Praxis habe ich CVAE implementiert und geschult, eine Art Deep Learning. Dieser Artikel ist eine Beschreibung auf Memoebene und basiert auf der Annahme, dass Sie Kenntnisse über VAE haben. Bitte beachten Sie.
Es wird auch mit Jupyter Notebook implementiert.
Hier sind einige Seiten, auf die ich bei der Implementierung verwiesen habe.
Darüber hinaus verweise ich auch auf die Beispielimplementierung von Pytorch.
** CVAE (Conditional Variational Auto Encoder) ** ist eine fortschrittliche VAE-Methode. In der normalen VAE werden Daten in den Encoder eingegeben und latente Variablen in den Decoder, aber in der CVAE wird der Datenstatus zu diesen hinzugefügt. Dies bietet Ihnen folgende Vorteile:
Dieses Mal werden wir CVAE mit Pytorch implementieren und MNIST (Datensatz handgeschriebener Zeichen) trainieren.
python
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
%matplotlib inline
DEVICE = 'cuda'
SEED = 0
CLASS_SIZE = 10
BATCH_SIZE = 256
ZDIM = 16
NUM_EPOCHS = 50
# Set seeds
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
class CVAE(nn.Module):
def __init__(self, zdim):
super().__init__()
self._zdim = zdim
self._in_units = 28 * 28
hidden_units = 512
self._encoder = nn.Sequential(
nn.Linear(self._in_units + CLASS_SIZE, hidden_units),
nn.ReLU(inplace=True),
nn.Linear(hidden_units, hidden_units),
nn.ReLU(inplace=True),
)
self._to_mean = nn.Linear(hidden_units, zdim)
self._to_lnvar = nn.Linear(hidden_units, zdim)
self._decoder = nn.Sequential(
nn.Linear(zdim + CLASS_SIZE, hidden_units),
nn.ReLU(inplace=True),
nn.Linear(hidden_units, hidden_units),
nn.ReLU(inplace=True),
nn.Linear(hidden_units, self._in_units),
nn.Sigmoid()
)
def encode(self, x, labels):
in_ = torch.empty((x.shape[0], self._in_units + CLASS_SIZE), device=DEVICE)
in_[:, :self._in_units] = x
in_[:, self._in_units:] = labels
h = self._encoder(in_)
mean = self._to_mean(h)
lnvar = self._to_lnvar(h)
return mean, lnvar
def decode(self, z, labels):
in_ = torch.empty((z.shape[0], self._zdim + CLASS_SIZE), device=DEVICE)
in_[:, :self._zdim] = z
in_[:, self._zdim:] = labels
return self._decoder(in_)
def to_onehot(label):
return torch.eye(CLASS_SIZE, device=DEVICE, dtype=torch.float32)[label]
# Train
train_dataset = torchvision.datasets.MNIST(
root='./data',
train=True,
transform=transforms.ToTensor(),
download=True,
)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=0
)
model = CVAE(ZDIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.train()
for e in range(NUM_EPOCHS):
train_loss = 0
for i, (images, labels) in enumerate(train_loader):
labels = to_onehot(labels)
# Reconstruction images
# Encode images
x = images.view(-1, 28*28*1).to(DEVICE)
mean, lnvar = model.encode(x, labels)
std = lnvar.exp().sqrt()
epsilon = torch.randn(ZDIM, device=DEVICE)
# Decode latent variables
z = mean + std * epsilon
y = model.decode(z, labels)
# Compute loss
kld = 0.5 * (1 + lnvar - mean.pow(2) - lnvar.exp()).sum(axis=1)
bce = F.binary_cross_entropy(y, x, reduction='none').sum(axis=1)
loss = (-1 * kld + bce).mean()
# Update model
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item() * x.shape[0]
print(f'epoch: {e + 1} epoch_loss: {train_loss/len(train_dataset)}')
Ergebnis
epoch: 1 epoch_loss: 200.2185905436198
epoch: 2 epoch_loss: 160.22688263346353
epoch: 3 epoch_loss: 148.69330817057292
#Unterlassung
epoch: 48 epoch_loss: 98.95304524739583
epoch: 49 epoch_loss: 98.6720672281901
epoch: 50 epoch_loss: 98.65486107177735
Unten finden Sie eine Liste der Implementierungs- und Lernpunkte.
--Verwenden Sie 6000 Trainingsdaten von torchvision.datasets.MNIST
zum Lernen und setzen Sie die Anzahl der Epochen auf 50.
VAE hat zwei Anwendungen, das Löschen von Dimensionen und die Datengenerierung. Dieses Mal konzentrieren wir uns jedoch auf die Datengenerierung. Erstellen Sie ein neues handgeschriebenes Bild mit dem zuvor erlernten CVAE-Decoder.
Die dem Decoder gegebenen Etiketteninformationen sind auf "5" festgelegt, 100 Zufallszahlen, die der Standardnormalverteilung folgen, werden erzeugt und das entsprechende Bild wird erzeugt.
python
# Generation data with label '5'
NUM_GENERATION = 100
os.makedirs(f'img/cvae/generation/label5/', exist_ok=True)
model.eval()
for i in range(NUM_GENERATION):
z = torch.randn(ZDIM, device=DEVICE).unsqueeze(dim=0)
label = torch.tensor([5], device=DEVICE)
with torch.no_grad():
y = model.decode(z, to_onehot(label))
y = y.reshape(28, 28).cpu().detach().numpy()
# Save image
fig, ax = plt.subplots()
ax.imshow(y)
ax.set_title(f'Generation(label={label.cpu().detach().numpy()[0]})')
ax.tick_params(
labelbottom=False,
labelleft=False,
bottom=False,
left=False,
)
plt.savefig(f'img/cvae/generation/label5/img{i + 1}')
plt.close(fig)
Ergebnis
Einige von ihnen sind nicht in Form, aber wir können verschiedene "5" -Bilder erzeugen.
Ich habe nach den fetten Zahlen im Testbild von torchvision.datasets.MNIST
gesucht.
Das folgende Bild ist das 49. Bild im Datensatz.
Es ist sehr dick als "4" geschrieben. Verwenden Sie Encoder, um die latente Variable zu finden, die diesen Daten entspricht.
python
test_dataset = torchvision.datasets.MNIST(
root='./data',
train=False,
transform=transforms.ToTensor(),
download=True,
)
target_image, label = list(test_dataset)[48]
x = target_image.view(1, 28*28).to(DEVICE)
with torch.no_grad():
mean, _ = model.encode(x, to_onehot(label))
z = mean
print(f'z = {z.cpu().detach().numpy().squeeze()}')
Ergebnis
z = [ 0.7933388 2.4768877 0.49229255 -0.09540698 -1.7999544 0.03376897
0.01600834 1.3863252 0.14656337 -0.14543885 0.04157912 0.13938689
-0.2016176 0.5204378 -0.08096244 1.0930295 ]
Dieser 16-dimensionale Vektor enthält die Informationen des Bildes von ** außer dem Etikett **, das zum Zeitpunkt des Trainings angegeben wurde. Mit anderen Worten, Sie sollten die Information "sehr dick" haben, nicht die Information "es ist in Form von 4".
Versuchen Sie daher, mit dieser latenten Variablen ein Bild zu generieren, während Sie die dem Decoder übergebenen Beschriftungsinformationen ändern.
python
os.makedirs(f'img/cvae/generation/fat', exist_ok=True)
for label in range(CLASS_SIZE):
with torch.no_grad():
y = model.decode(z, to_onehot(label))
y = y.reshape(28, 28).cpu().detach().numpy()
fig, ax = plt.subplots()
ax.imshow(y)
ax.set_title(f'Generation(label={label})')
ax.tick_params(
labelbottom=False,
labelleft=False,
bottom=False,
left=False,
)
plt.savefig(f'img/cvae/generation/fat/img{label}')
plt.close(fig)
Ergebnis
"2" ist etwas verdächtig, aber ich kann ein Bild mit dicken Zahlen erzeugen.
Ich wusste lange über CVAE Bescheid, aber dies war das erste Mal, dass ich es implementierte. Ich bin froh, dass es funktioniert hat. Es ist wichtig, es nicht nur zu wissen, sondern auch umzusetzen. Einige der generierten Bilder sahen nicht besonders hübsch aus, können jedoch mithilfe der Faltung oder Translokationsfaltung im VAE-Netzwerk aufgelöst werden. Obwohl diesmal weggelassen, erkennt das VAE-System, dass es wichtig ist zu analysieren, welche Merkmale wo im niedrigdimensionalen Raum abgebildet werden. Diesmal möchte ich diese Analyse durchführen.
[^ 1]: Damit sind die Daten mit allen Beschriftungen im Mini-Batch vorhanden, sodass das Bild des Mini-Batches von Encoder der Standardnormalverteilung im latenten Variablenraum folgt.
Recommended Posts