Ich habe versucht, ein Junk-Zeichen MNIST mit einer Art GAN, cGAN (Conditional GAN), zu generieren. Detaillierte theoretische Aspekte finden Sie unter den Links, die gegebenenfalls hilfreich sind.
Ich hoffe, es wird für diejenigen hilfreich sein, die mögen.
Über den diesmal verwendeten Datensatz. KMNIST ist ein Datensatz, der für maschinelles Lernen als Ableitung des vom Humanities Open Data Sharing Center erstellten "Japanese Classics Kuzuji Data Set" erstellt wurde. Sie können es von GitHub Link herunterladen.
"KMNIST-Datensatz" (erstellt von CODH) "Japanischer Klassiker-Kuzuji-Datensatz" (Kokubunken et al.) Adapted doi: 10.20676 / 00000341
Wie diese MNIST (handschriftliche Nummer), die jeder kennt, der maschinelles Lernen durchgeführt hat, ist ein Bild 1 x 28 x 28 Pixel groß.
Die folgenden drei Arten von Datasets können im komprimierten Format von numpy.array aus dem Repository heruntergeladen werden.
--kuzushiji-MNIST (10 Zeichen von Hiragana) --kuzushiji-49 (49 Zeichen von Hiragana)
Von diesen wird diesmal "kuzushiji-49" verwendet. Es gibt keinen besonders tiefen Grund, aber wenn 49 Hiragana-Zeichen gezielt und generiert werden können, ist es dann möglich, handgeschriebene Sätze zu generieren? Ich dachte, es wäre eine leichte Motivation.
Lassen Sie uns kurz auf GAN vor cGAN eingehen. GAN ist eine Abkürzung für "Generative Adversarial Network" (= feindliches Generationsnetzwerk) und eine Art Generationsmodell für tiefes Lernen. Es ist besonders effektiv auf dem Gebiet der Bilderzeugung, und ich denke, dass das Ergebnis der Erzeugung von Gesichtsbildern von Menschen, die es auf der Welt nicht gibt, berühmt ist.
Das Folgende ist ein grobes Modelldiagramm von GAN. "G" steht für Generator und "D" steht für Diskriminator.
Der Generator erzeugt aus dem Rauschen ein gefälschtes Bild, das so real wie möglich ist. Der Diskriminator unterscheidet zwischen dem aus dem Datensatz entnommenen realen Bild (real_img) und dem vom Generator erstellten gefälschten Bild (fake_img) (True oder False).
Durch Wiederholen dieses Lernens versucht der Generator, ein Bild zu erstellen, das der realen Sache so nahe wie möglich kommt, das der Diskriminator nicht erkennen kann, und der Diskriminator versucht, die vom Generator erstellte Fälschung und die aus dem Datensatz abgeleitete reale Sache zu erkennen, sodass der Generator generiert wird. Die Genauigkeit wird erhöht.
Referenzartikel
GAN (1) Verständnis der Grundstruktur, die ich nicht mehr hören kann
GAN-bezogene Artikel sind in This GitHub Repository organisiert.
Als nächstes möchte ich über das diesmal verwendete bedingte GAN sprechen. Einfach ausgedrückt ist es ** "GAN, das das gewünschte Bild erzeugen kann" **. Die Idee ist einfach: Es ist, als würde man das zu erzeugende Bild durch Hinzufügen von Etiketteninformationen zur Eingabe von Diskriminator und Generator entscheiden.
Das Originalpapier ist hier
Es ist dasselbe wie ein normales GAN, außer dass ** "Geben Sie das Etikett während des Trainings ein" **. Obwohl Etiketteninformationen verwendet werden, bestimmt Discriminator nur, "ob das Bild echt ist".
Referenzartikel
Implementierung von GAN (6) Bedingte GAN, die jetzt nicht gehört werden kann
Kommen wir nun zur Implementierung.
Ich habe jupyterlab unter Ubuntu 18.04 installiert und ausgeführt.
Importieren Sie das gewünschte Modul
python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import random
Laden Sie die Daten im Numpy-Format von KMNISTs Github herunter.
Mit jupyterlab nach dem Öffnen des Terminals und dem Verschieben des Repositorys
wget http://codh.rois.ac.jp/kmnist/dataset/k49/k49-train-imgs.npz
wget http://codh.rois.ac.jp/kmnist/dataset/k49/k49-train-labels.npz
Dann können Sie das Bild und das Etikett herunterladen.
Übrigens, wenn es sich um KMNIST mit 10 Hiragana-Zeichen handelt, ist es standardmäßig in der Fackelvision enthalten. Wenn es Ihnen nichts ausmacht, genau wie bei normalem MNIST
python
transform = transforms.Compose(
[transforms.ToTensor(),
])
train_data_10 = torchvision.datasets.KMNIST(root='./data', train=True,download=True,transform=transform)
Sie können es verwenden, wenn Sie dies tun.
Wenn Sie mit PyTorch ein eigenes benutzerdefiniertes Dataset erstellen möchten, müssen Sie die Vorverarbeitung selbst definieren. Die bildbasierte Vorverarbeitung ist meistens in "torchvision.transforms" enthalten, daher verwende ich diese häufig, aber Sie können auch Ihre eigenen erstellen.
python
class Transform(object):
def __init__(self):
pass
def __call__(self, sample):
sample = np.array(sample, dtype = np.float32)
sample = torch.tensor(sample)
return (sample/127.5)-1
transform = Transform()
Die meisten von numpy behandelten Brüche sind "np.float64" (Gleitkommazahl 64 Bit), aber PyTorch verarbeitet den Bruchwert standardmäßig mit der Gleitkommazahl 32 Bit, sodass ein Fehler auftritt, wenn sie nicht ausgerichtet sind.
Zusätzlich wird hier die Verarbeitung durchgeführt, um den Helligkeitswert des Bildes auf den Bereich von [-1,1] zu normalisieren. Dies liegt daran, dass "Tanh" in der letzten Ebene der Generatorausgabe verwendet wird, die später ausgegeben wird, sodass der Helligkeitswert des realen Bildes entsprechend angepasst wird.
Als nächstes definieren wir die Dataset-Klasse. Dies ist ein Modul, das einen Satz von Daten und Beschriftungen zurückgibt und die Daten zurückgibt, die von der zuvor beim Abrufen der Daten definierten "Transformation" vorverarbeitet wurden.
python
from tqdm import tqdm
class dataset_full(torch.utils.data.Dataset):
def __init__(self, img, label, transform=None):
self.transform = transform
self.data_num = len(img)
self.data = []
self.label = []
for i in tqdm(range(self.data_num)):
self.data.append([img[i]])
self.label.append(label[i])
self.data_num = len(self.data)
def __len__(self):
return self.data_num
def __getitem__(self, idx):
out_data = self.data[idx]
out_label = np.identity(49)[self.label[idx]]
out_label = np.array(out_label, dtype = np.float32)
if self.transform:
out_data = self.transform(out_data)
return out_data, out_label
Wenn Sie das erste "tqdm" eingeben, wird der Fortschritt wie ein Balkendiagramm angezeigt, wenn Sie die for-Anweisung drehen. Dies hat jedoch nichts mit dem cGAN selbst zu tun.
Ich verwende np.identity
, um einen One-Hot-Vektor mit einer Länge von 49 zu erstellen.
Erstellen Sie einen Datensatz mit den Klassen "Transformieren" und "Datensatz", die aus den zuvor heruntergeladenen Daten implementiert wurden.
python
path = %pwd
train_img = np.load('{}/k49-train-imgs.npz'.format(path))
train_img = train_img['arr_0']
train_label = np.load('{}/k49-train-labels.npz'.format(path))
train_label = train_label['arr_0']
train_data = dataset_full(train_img, train_label, transform=transform)
Wenn Sie das tqdm früher eingeben, wird der Fortschritt angezeigt, wenn Sie dies ausführen. Die meisten Daten sind 232.625, aber ich denke nicht, dass es lange dauern wird.
Wir haben einen Datensatz, aber wir rufen beim Training des Modells keine Daten direkt aus diesem Datensatz ab. Da wir Batch für Batch trainieren, definieren wir einen DataLoader, der Batch-Größendaten zurückgibt.
python
batch_size = 256
train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, shuffle = True, num_workers=2)
Wenn Sie "shuffle = True" setzen, sind die von DataLoader abgerufenen Daten zufällig. num_workers
ist ein Argument, das die Anzahl der von DataLoader verwendeten CPU-Kerne angibt und für das cGAN selbst nicht besonders relevant ist.
Der Transform-Dataset-DataLoader bis zu diesem Punkt ist in den folgenden Artikeln zusammengefasst.
Referenzartikel
Überprüfen Sie die grundlegende Funktionsweise von PyTorch-Transformationen / Dataset / DataLoader
Ich werde den Modellkörper machen. Der Generator erstellt aus Rauschen und Beschriftungen ein falsches Bild (fake_img).
Die Implementierungsmethode ist je nach Person sehr unterschiedlich, aber die Struktur des diesmal erstellten Generators ist wie folgt. (Es ist handgeschrieben, aber es tut mir leid ...) In der Eingabe ist "z_dim" (Rauschdimension) 30 und "num_class" (Anzahl der Klassen) 49 Hiragana-Zeichen, daher wird es auf 49 gesetzt. Das gefälschte Bild der Ausgabe hat die Form 1 (Kanal) x 28 (px) x 28 (px).
python
class Generator(nn.Module):
def __init__(self, z_dim, num_class):
super(Generator, self).__init__()
self.fc1 = nn.Linear(z_dim, 300)
self.bn1 = nn.BatchNorm1d(300)
self.LReLU1 = nn.LeakyReLU(0.2)
self.fc2 = nn.Linear(num_class, 1500)
self.bn2 = nn.BatchNorm1d(1500)
self.LReLU2 = nn.LeakyReLU(0.2)
self.fc3 = nn.Linear(1800, 128 * 7 * 7)
self.bn3 = nn.BatchNorm1d(128 * 7 * 7)
self.bo1 = nn.Dropout(p=0.5)
self.LReLU3 = nn.LeakyReLU(0.2)
self.deconv = nn.Sequential(
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), #Ändern Sie die Anzahl der Kanäle von 128 auf 64.
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2),
nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1), #Die Anzahl der Kanäle wurde von 64 auf 1 geändert
nn.Tanh(),
)
self.init_weights()
def init_weights(self):
for module in self.modules():
if isinstance(module, nn.ConvTranspose2d):
module.weight.data.normal_(0, 0.02)
module.bias.data.zero_()
elif isinstance(module, nn.Linear):
module.weight.data.normal_(0, 0.02)
module.bias.data.zero_()
elif isinstance(module, nn.BatchNorm1d):
module.weight.data.normal_(1.0, 0.02)
module.bias.data.zero_()
elif isinstance(module, nn.BatchNorm2d):
module.weight.data.normal_(1.0, 0.02)
module.bias.data.zero_()
def forward(self, noise, labels):
y_1 = self.fc1(noise)
y_1 = self.bn1(y_1)
y_1 = self.LReLU1(y_1)
y_2 = self.fc2(labels)
y_2 = self.bn2(y_2)
y_2 = self.LReLU2(y_2)
x = torch.cat([y_1, y_2], 1)
x = self.fc3(x)
x = self.bo1(x)
x = self.LReLU3(x)
x = x.view(-1, 128, 7, 7)
x = self.deconv(x)
return x
Als nächstes kommt der Diskriminator. Der Diskriminator gibt das Original- / Fälschungsbild und seine Etiketteninformationen ein und bestimmt, ob es echt oder gefälscht ist.
Die Struktur des diesmal erstellten Diskriminators ist wie folgt. "img" (Eingabebild) ist 1 (Kanal) x 28 (px) x 28 (px) sowohl für echte als auch für gefälschte Bilder, und "label" (Eingabeetikett) ist ein 49-dimensionaler eindimensionaler Vektor. Die Ausgabe bestimmt, ob sie echt ist oder nicht, mit einem Wert von 0 bis 1.
Konzentrieren Sie die Bild- und Beschriftungsinformationen in Kanalrichtung mit "cat" in der Mitte. Ich denke, der zuvor erwähnte cGAN-Artikel ist in diesem Bereich leicht zu verstehen.
python
class Discriminator(nn.Module):
def __init__(self, num_class):
super(Discriminator, self).__init__()
self.num_class = num_class
self.conv = nn.Sequential(
nn.Conv2d(num_class + 1, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2),
nn.BatchNorm2d(128),
)
self.fc = nn.Sequential(
nn.Linear(128 * 7 * 7, 1024),
nn.BatchNorm1d(1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, 1),
nn.Sigmoid(),
)
self.init_weights()
def init_weights(self):
for module in self.modules():
if isinstance(module, nn.Conv2d):
module.weight.data.normal_(0, 0.02)
module.bias.data.zero_()
elif isinstance(module, nn.Linear):
module.weight.data.normal_(0, 0.02)
module.bias.data.zero_()
elif isinstance(module, nn.BatchNorm1d):
module.weight.data.normal_(1.0, 0.02)
module.bias.data.zero_()
elif isinstance(module, nn.BatchNorm2d):
module.weight.data.normal_(1.0, 0.02)
module.bias.data.zero_()
def forward(self, img, labels):
y_2 = labels.view(-1, self.num_class, 1, 1)
y_2 = y_2.expand(-1, -1, 28, 28)
x = torch.cat([img, y_2], 1)
x = self.conv(x)
x = x.view(-1, 128 * 7 * 7)
x = self.fc(x)
return x
1 Erstellen Sie eine Funktion zur Berechnung der Epoche.
python
def train_func(D_model, G_model, batch_size, z_dim, num_class, criterion,
D_optimizer, G_optimizer, data_loader, device):
#Trainingsmodus
D_model.train()
G_model.train()
#Das echte Label ist 1
y_real = torch.ones((batch_size, 1)).to(device)
D_y_real = (torch.rand((batch_size, 1))/2 + 0.7).to(device) #Geräuschetikett zum Einfügen von D.
#Gefälschtes Etikett ist 0
y_fake = torch.zeros((batch_size, 1)).to(device)
D_y_fake = (torch.rand((batch_size, 1)) * 0.3).to(device) #Geräuschetikett zum Einfügen von D.
#Initialisierung des Verlustes
D_running_loss = 0
G_running_loss = 0
#Chargenweise Berechnung
for batch_idx, (data, labels) in enumerate(data_loader):
#Ignorieren, wenn weniger als die Stapelgröße
if data.size()[0] != batch_size:
break
#Geräuschentwicklung
z = torch.normal(mean = 0.5, std = 0.2, size = (batch_size, z_dim)) #Durchschnitt 0.Generieren Sie Zufallszahlen nach einer Normalverteilung von 5
real_img, label, z = data.to(device), labels.to(device), z.to(device)
#Diskriminator-Update
D_optimizer.zero_grad()
#Setzen Sie ein reales Bild in Discriminator und verbreiten Sie vorwärts ⇒ Verlustberechnung
D_real = D_model(real_img, label)
D_real_loss = criterion(D_real, D_y_real)
#Fügen Sie das durch Einfügen von Rauschen in Generator in Discriminator erzeugte Bild ein und verbreiten Sie es vorwärts ⇒ Verlustberechnung
fake_img = G_model(z, label)
D_fake = D_model(fake_img.detach(), label) #fake_Stop Loss wird in Bildern berechnet, damit es nicht zurück zum Generator übertragen wird
D_fake_loss = criterion(D_fake, D_y_fake)
#Minimieren Sie die Summe von zwei Verlusten
D_loss = D_real_loss + D_fake_loss
D_loss.backward()
D_optimizer.step()
D_running_loss += D_loss.item()
#Generator-Update
G_optimizer.zero_grad()
#Das Bild, das durch Einfügen von Rauschen in den Generator erzeugt wird, wird in den Diskriminator eingefügt und vorwärts weitergegeben. ⇒ Der erkannte Teil wird zu Verlust
fake_img_2 = G_model(z, label)
D_fake_2 = D_model(fake_img_2, label)
#G Verlust(max(log D)Optimiert mit)
G_loss = -criterion(D_fake_2, y_fake)
G_loss.backward()
G_optimizer.step()
G_running_loss += G_loss.item()
D_running_loss /= len(data_loader)
G_running_loss /= len(data_loader)
return D_running_loss, G_running_loss
Das "Kriterium", das im Argument erscheint, ist die Verlustklasse (in diesem Fall Binary Cross Entropy). Was wir mit dieser Funktion machen, ist in Ordnung
--Backpropagation von Fehlern, indem das reale Image des Datensatzes in Discriminator gestellt wird
ist.
Es ist ein wenig alt, aber diese Implementierung beinhaltet den Einfallsreichtum, der in "Wie man ein GAN trainiert" auf NIPS2016 erscheint, um das GAN-Lernen erfolgreich zu machen. GitHub-Link
Referenzartikel
14 Techniken zum Lernen von GAN (Generative Adversarial Networks)
Als ich die Dataset-Klasse erstellt habe
python
return (sample/127.5)-1
Ist das. Die letzte Ebene des Generators ist "nn.Tanh ()".
python
#G Verlust(max(log D)Optimiert mit)
G_loss = -criterion(D_fake_2, y_fake)
Ist das. "D_fake_2" ist das Urteil des Diskriminators, und "y_fake" ist ein 128 × 1 0-Vektor.
Probieren Sie das Rauschen, das in den Generator eingegeben werden soll, anhand einer Normalverteilung anstelle einer gleichmäßigen Verteilung aus.
python
#Geräuschentwicklung
z = torch.normal(mean = 0.5, std = 0.2, size = (batch_size, z_dim)) #Durchschnitt 0.Generieren Sie Zufallszahlen nach einer Normalverteilung von 5
Der Mittelwert und die Standardabweichung sind angemessen, aber wenn Sie mit einer gleichmäßigen Verteilung von [0,1] abtasten, erhalten Sie keinen negativen Wert, daher habe ich den abgetasteten Rauschwert fast positiv gemacht.
4.Batch Norm Alle Daten, die aus dem oben erstellten Data Loader stammen, sind ein echtes Bild. und umgekehrt
python
fake_img = G_model(z, label)
Aus den Etiketteninformationen und dem Rauschen von DataLoader erstellen wir dann gefälschte Bilder mit Stapelgröße.
LeakyReLU scheint sowohl für den Generator als auch für den Diskriminator wirksam zu sein, daher sind alle Aktivierungsfunktionen auf LeakyReLU eingestellt. Dem Argument 0.2 wurde gefolgt, da viele Implementierungen diesen Wert angenommen haben.
Das Discriminator-Label ist normalerweise 0 oder 1, aber wir fügen hier Rauschen hinzu. Probieren Sie nach dem Zufallsprinzip echte Etiketten von 0,7 bis 1,2 und gefälschte Etiketten von 0,0 bis 0,3 aus.
python
#Das echte Label ist 1
y_real = torch.ones((batch_size, 1)).to(device)
D_y_real = (torch.rand((batch_size, 1))/2 + 0.7).to(device) #Geräuschetikett zum Einfügen von D.
#Gefälschtes Etikett ist 0
y_fake = torch.zeros((batch_size, 1)).to(device)
D_y_fake = (torch.rand((batch_size, 1)) * 0.3).to(device) #Geräuschetikett zum Einfügen von D.
Das ist der Teil. Normalerweise benutze ich "y_real" / "y_fake" und dieses Mal habe ich tatsächlich "D_y_real" / "D_y_fake" verwendet.
Dies ist ein alter Artikel, daher ist ein anderer Optimierer wie RAdam jetzt möglicherweise besser.
Dieses Mal habe ich Dropout nur einmal in die lineare Ebene des Generators eingefügt. Es gibt jedoch eine Theorie, dass BatchNorm und Dropout nicht miteinander kompatibel sind, daher denke ich nicht, dass es definitiv besser ist, sie alle zusammenzufügen.
Definieren Sie vor dem Training des Modells eine Funktion zum Anzeigen des vom Generator erstellten Bildes. Machen Sie dies und überprüfen Sie den Lerngrad des Generators für jede Epoche.
python
import os
from IPython.display import Image
from torchvision.utils import save_image
%matplotlib inline
def Generate_img(epoch, G_model, device, z_dim, noise, var_mode, labels, log_dir = 'logs_cGAN'):
G_model.eval()
with torch.no_grad():
if var_mode == True:
#Für die Generierung erforderliche Zufallszahlen
noise = torch.normal(mean = 0.5, std = 0.2, size = (49, z_dim)).to(device)
else:
noise = noise
#Probengenerierung mit Generator
samples = G_model(noise, labels).data.cpu()
samples = (samples/2)+0.5
save_image(samples,os.path.join(log_dir, 'epoch_%05d.png' % (epoch)), nrow = 7)
img = Image('logs_cGAN/epoch_%05d.png' % (epoch))
display(img)
Alles, was Sie tun müssen, ist, das mit Rauschen im Generator erstellte Bild in einen Ordner namens "logs_cGAN" zu legen und anzuzeigen. Es wird angenommen, dass jedes Mal dieselbe Zufallszahl verwendet wird, wenn var_mode False ist.
Trainiere das Modell.
python
#Fester Startwert zur Gewährleistung der Reproduzierbarkeit
SEED = 1111
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
#device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def model_run(num_epochs, batch_size = batch_size, dataloader = train_loader, device = device):
#Dimension des Rauschens, das in den Generator eingefügt werden soll
z_dim = 30
var_mode = False #Gibt an, ob bei jeder Anzeige des Anzeigeergebnisses eine andere Zufallszahl verwendet werden soll
#Für die Generierung erforderliche Zufallszahlen
noise = torch.normal(mean = 0.5, std = 0.2, size = (49, z_dim)).to(device)
#Anzahl der Klassen
num_class = 49
#Erstellen Sie ein Etikett, das Sie beim Ausprobieren von Generator verwenden können
labels = []
for i in range(num_class):
tmp = np.identity(num_class)[i]
tmp = np.array(tmp, dtype = np.float32)
labels.append(tmp)
label = torch.Tensor(labels).to(device)
#Modelldefinition
D_model = Discriminator(num_class).to(device)
G_model = Generator(z_dim, num_class).to(device)
#Definition von Verlust(Das Argument ist Zug_Spezifiziert in func)
criterion = nn.BCELoss().to(device)
#Definition des Optimierers
D_optimizer = torch.optim.Adam(D_model.parameters(), lr=0.0002, betas=(0.5, 0.999), eps=1e-08, weight_decay=1e-5, amsgrad=False)
G_optimizer = torch.optim.Adam(G_model.parameters(), lr=0.0002, betas=(0.5, 0.999), eps=1e-08, weight_decay=1e-5, amsgrad=False)
D_loss_list = []
G_loss_list = []
all_time = time.time()
for epoch in range(num_epochs):
start_time = time.time()
D_loss, G_loss = train_func(D_model, G_model, batch_size, z_dim, num_class, criterion,
D_optimizer, G_optimizer, dataloader, device)
D_loss_list.append(D_loss)
G_loss_list.append(G_loss)
secs = int(time.time() - start_time)
mins = secs / 60
secs = secs % 60
#Ergebnisse nach Epoche anzeigen
print('Epoch: %d' %(epoch + 1), " |Benötigte Zeit%d Minuten%d Sekunden" %(mins, secs))
print(f'\tLoss: {D_loss:.4f}(Discriminator)')
print(f'\tLoss: {G_loss:.4f}(Generator)')
if (epoch + 1) % 1 == 0:
Generate_img(epoch, G_model, device, z_dim, noise, var_mode, label)
#Erstellen Sie eine Prüfpunktdatei zum Speichern des Modells
if (epoch + 1) % 5 == 0:
torch.save({
'epoch':epoch,
'model_state_dict':G_model.state_dict(),
'optimizer_state_dict':G_optimizer.state_dict(),
'loss':G_loss,
}, './checkpoint_cGAN/G_model_{}'.format(epoch + 1))
return D_loss_list, G_loss_list
#Drehen Sie das Modell
D_loss_list, G_loss_list = model_run(num_epochs = 100)
Es ist ziemlich lang, aber ich zeige die erforderliche Zeit und den Verlust für jede Epoche an und speichere das Modell.
Lassen Sie uns den Übergang des Verlusts von Generator und Diskriminator sehen.
python
import matplotlib.pyplot as plt
%matplotlib inline
fig = plt.figure(figsize=(10,7))
loss = fig.add_subplot(1,1,1)
loss.plot(range(len(D_loss_list)),D_loss_list,label='Discriminator_loss')
loss.plot(range(len(G_loss_list)),G_loss_list,label='Generator_loss')
loss.set_xlabel('epoch')
loss.set_ylabel('loss')
loss.legend()
loss.grid()
fig.show()
Ab etwa 20 Epochen haben sich beide Verluste nicht geändert. Sowohl der Diskriminator- als auch der Generatorverlust sind weit von 0 entfernt, daher scheint es ziemlich gut zu funktionieren. Übrigens, wenn Sie versuchen, die generierten Zeichen in der Reihenfolge von 1 bis 100 Epochen in Gifs umzuwandeln, sieht es so aus.
Oben links ist "A" und unten rechts ist "ゝ". Je nach Charakter gibt es einige Unterschiede, und es scheint, dass "u", "ku", "sa", "so" und "hi" stabil und gut erzeugt werden, aber "na" und "yu" Übergänge haben. Es ist heftig.
Nachfolgend sind die Ergebnisse der Generierung von 5 Bildern für jeden Typ aufgeführt. Epoch:5
Epoch:50
Epoch:100
Wenn man das alleine betrachtet, scheint es nicht besser zu sein, Epoche zu stapeln. "Mu" scheint in 5 Epochen das Beste zu sein, während "ゑ" in 100 Epochen das Beste zu sein scheint.
Übrigens sieht es so aus, wenn Sie 5 Trainingsdaten auf die gleiche Weise abrufen.
Es gibt einige Dinge, die selbst moderne Menschen nicht lesen können. "Su" und "mi" unterscheiden sich stark von ihren aktuellen Formen. Wenn ich das betrachte, denke ich, dass die Leistung des Modells ziemlich gut ist.
Ich habe versucht, Junk-Zeichen mit cGAN zu generieren. Ich denke, es gibt noch viel Raum für Verbesserungen bei der Implementierung, aber ich denke, das Ergebnis selbst ist vernünftig. Es ist lange her, aber ich hoffe, es hilft sogar einem Teil davon.
Einige Leute haben auch allgemeine MNIST (handschriftliche Zahlen) mit PyTorch in cGAN implementiert. Es gibt viele verschiedene Teile wie die Modellstruktur, daher denke ich, dass dies auch hilfreich ist.
Referenzartikel
Ich habe versucht, handschriftliche Zeichen durch tiefes Lernen zu generieren [Pytorch x MNIST x CGAN]
Ursprünglich war ich leicht motiviert zu denken: "Ist es möglich, handgeschriebene Sätze zu generieren?", Also werde ich es am Ende versuchen.
Laden Sie das Modellgewicht aus der gespeicherten Prüfpunktdatei und versuchen Sie es einmal mit pkl.
python
import cloudpickle
%matplotlib inline
#Geben Sie die abzurufende Epoche an
point = 50
#Definieren Sie die Struktur des Modells
z_dim = 30
num_class = 49
G = Generator(z_dim = z_dim, num_class = num_class)
#Checkpoint extrahieren
checkpoint = torch.load('./checkpoint_cGAN/G_model_{}'.format(point))
#Parameter in Generator einfügen
G.load_state_dict(checkpoint['model_state_dict'])
#Bleiben Sie im Überprüfungsmodus
G.eval()
#Mit Gurke speichern
with open ('KMNIST_cGAN.pkl','wb')as f:
cloudpickle.dump(G,f)
Es scheint, dass Sie es pkl machen können, indem Sie ein Modul namens "Cloudpickle" anstelle der üblichen "Pickle" verwenden.
Öffnen wir diese pkl-Datei und generieren einen Satz.
python
letter = 'Aiue Okakikuke Kosashi Suseso Tachi Nune no Hahifuhe Homami Mumemoya Yuyorari Rurerowa'
strs = input()
with open('KMNIST_cGAN.pkl','rb')as f:
Generator = cloudpickle.load(f)
for i in range(len(str(strs))):
noise = torch.normal(mean = 0.5, std = 0.2, size = (1, 30))
str_index = letter.index(strs[i])
tmp = np.identity(49)[str_index]
tmp = np.array(tmp, dtype = np.float32)
label = [tmp]
img = Generator(noise, torch.Tensor(label))
img = img.reshape((28,28))
img = img.detach().numpy().tolist()
if i == 0:
comp_img = img
else:
comp_img.extend(img)
save_image(torch.tensor(comp_img), './sentence.png', nrow=len(str(strs)))
img = Image('./sentence.png')
display(img)
Das Ergebnis sieht so aus.
"Ich weiß nichts mehr" ...