[PYTHON] Erzeugung des Junk-Zeichens MNIST (KMNIST) mit cGAN (bedingtes GAN)

Einführung

Ich habe versucht, ein Junk-Zeichen MNIST mit einer Art GAN, cGAN (Conditional GAN), zu generieren. Detaillierte theoretische Aspekte finden Sie unter den Links, die gegebenenfalls hilfreich sind.

Ich hoffe, es wird für diejenigen hilfreich sein, die mögen.

Was ist Kuzuji MNIST (KMNIST)?

Über den diesmal verwendeten Datensatz. KMNIST ist ein Datensatz, der für maschinelles Lernen als Ableitung des vom Humanities Open Data Sharing Center erstellten "Japanese Classics Kuzuji Data Set" erstellt wurde. Sie können es von GitHub Link herunterladen. kmnist.png

"KMNIST-Datensatz" (erstellt von CODH) "Japanischer Klassiker-Kuzuji-Datensatz" (Kokubunken et al.) Adapted doi: 10.20676 / 00000341

Wie diese MNIST (handschriftliche Nummer), die jeder kennt, der maschinelles Lernen durchgeführt hat, ist ein Bild 1 x 28 x 28 Pixel groß.

Die folgenden drei Arten von Datasets können im komprimierten Format von numpy.array aus dem Repository heruntergeladen werden.

--kuzushiji-MNIST (10 Zeichen von Hiragana) --kuzushiji-49 (49 Zeichen von Hiragana)

Von diesen wird diesmal "kuzushiji-49" verwendet. Es gibt keinen besonders tiefen Grund, aber wenn 49 Hiragana-Zeichen gezielt und generiert werden können, ist es dann möglich, handgeschriebene Sätze zu generieren? Ich dachte, es wäre eine leichte Motivation.

Was ist GAN?

Lassen Sie uns kurz auf GAN vor cGAN eingehen. GAN ist eine Abkürzung für "Generative Adversarial Network" (= feindliches Generationsnetzwerk) und eine Art Generationsmodell für tiefes Lernen. Es ist besonders effektiv auf dem Gebiet der Bilderzeugung, und ich denke, dass das Ergebnis der Erzeugung von Gesichtsbildern von Menschen, die es auf der Welt nicht gibt, berühmt ist.

GAN-Modellstruktur

Das Folgende ist ein grobes Modelldiagramm von GAN. "G" steht für Generator und "D" steht für Diskriminator.

Der Generator erzeugt aus dem Rauschen ein gefälschtes Bild, das so real wie möglich ist. Der Diskriminator unterscheidet zwischen dem aus dem Datensatz entnommenen realen Bild (real_img) und dem vom Generator erstellten gefälschten Bild (fake_img) (True oder False).

Durch Wiederholen dieses Lernens versucht der Generator, ein Bild zu erstellen, das der realen Sache so nahe wie möglich kommt, das der Diskriminator nicht erkennen kann, und der Diskriminator versucht, die vom Generator erstellte Fälschung und die aus dem Datensatz abgeleitete reale Sache zu erkennen, sodass der Generator generiert wird. Die Genauigkeit wird erhöht.

GAN.jpg


Referenzartikel
GAN (1) Verständnis der Grundstruktur, die ich nicht mehr hören kann

GAN-bezogene Artikel sind in This GitHub Repository organisiert.

Was ist cGAN (bedingte GAN)?

Als nächstes möchte ich über das diesmal verwendete bedingte GAN sprechen. Einfach ausgedrückt ist es ** "GAN, das das gewünschte Bild erzeugen kann" **. Die Idee ist einfach: Es ist, als würde man das zu erzeugende Bild durch Hinzufügen von Etiketteninformationen zur Eingabe von Diskriminator und Generator entscheiden.

Das Originalpapier ist hier

Modellstruktur von cGAN

Es ist dasselbe wie ein normales GAN, außer dass ** "Geben Sie das Etikett während des Trainings ein" **. Obwohl Etiketteninformationen verwendet werden, bestimmt Discriminator nur, "ob das Bild echt ist". cGAN.jpg


Referenzartikel
Implementierung von GAN (6) Bedingte GAN, die jetzt nicht gehört werden kann

Implementierung

Kommen wir nun zur Implementierung.

Umgebung

Ich habe jupyterlab unter Ubuntu 18.04 installiert und ausgeführt.

Vorbereitung auf das Lernen

Importieren Sie das gewünschte Modul

python


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import random

Erstellen Sie einen Datensatz

Daten herunterladen

Laden Sie die Daten im Numpy-Format von KMNISTs Github herunter. Mit jupyterlab nach dem Öffnen des Terminals und dem Verschieben des Repositorys wget http://codh.rois.ac.jp/kmnist/dataset/k49/k49-train-imgs.npz wget http://codh.rois.ac.jp/kmnist/dataset/k49/k49-train-labels.npz Dann können Sie das Bild und das Etikett herunterladen. Übrigens, wenn es sich um KMNIST mit 10 Hiragana-Zeichen handelt, ist es standardmäßig in der Fackelvision enthalten. Wenn es Ihnen nichts ausmacht, genau wie bei normalem MNIST

python


transform = transforms.Compose(
    [transforms.ToTensor(),
     ])
train_data_10 = torchvision.datasets.KMNIST(root='./data', train=True,download=True,transform=transform)

Sie können es verwenden, wenn Sie dies tun.

Datenvorverarbeitung

Wenn Sie mit PyTorch ein eigenes benutzerdefiniertes Dataset erstellen möchten, müssen Sie die Vorverarbeitung selbst definieren. Die bildbasierte Vorverarbeitung ist meistens in "torchvision.transforms" enthalten, daher verwende ich diese häufig, aber Sie können auch Ihre eigenen erstellen.

python


class Transform(object):
    def __init__(self):
        pass
    
    def __call__(self, sample):
        sample = np.array(sample, dtype = np.float32)
        sample = torch.tensor(sample)
        return (sample/127.5)-1
    
transform = Transform()

Die meisten von numpy behandelten Brüche sind "np.float64" (Gleitkommazahl 64 Bit), aber PyTorch verarbeitet den Bruchwert standardmäßig mit der Gleitkommazahl 32 Bit, sodass ein Fehler auftritt, wenn sie nicht ausgerichtet sind.

Zusätzlich wird hier die Verarbeitung durchgeführt, um den Helligkeitswert des Bildes auf den Bereich von [-1,1] zu normalisieren. Dies liegt daran, dass "Tanh" in der letzten Ebene der Generatorausgabe verwendet wird, die später ausgegeben wird, sodass der Helligkeitswert des realen Bildes entsprechend angepasst wird.

Datensatzklasse

Als nächstes definieren wir die Dataset-Klasse. Dies ist ein Modul, das einen Satz von Daten und Beschriftungen zurückgibt und die Daten zurückgibt, die von der zuvor beim Abrufen der Daten definierten "Transformation" vorverarbeitet wurden.

python


from tqdm import tqdm

class dataset_full(torch.utils.data.Dataset):
    
    def __init__(self, img, label, transform=None):
        self.transform = transform
        self.data_num = len(img)
        self.data = []
        self.label = []
        for i in tqdm(range(self.data_num)):
            self.data.append([img[i]])
            self.label.append(label[i])
        self.data_num = len(self.data)
            
    def __len__(self):
        return self.data_num
    
    def __getitem__(self, idx):
        out_data = self.data[idx]
        out_label = np.identity(49)[self.label[idx]]
        out_label = np.array(out_label, dtype = np.float32)
        
        if self.transform:
            out_data = self.transform(out_data)
            
        return out_data, out_label

Wenn Sie das erste "tqdm" eingeben, wird der Fortschritt wie ein Balkendiagramm angezeigt, wenn Sie die for-Anweisung drehen. Dies hat jedoch nichts mit dem cGAN selbst zu tun.

Ich verwende np.identity, um einen One-Hot-Vektor mit einer Länge von 49 zu erstellen.

Bilden Sie einen Datensatz aus DL-Daten

Erstellen Sie einen Datensatz mit den Klassen "Transformieren" und "Datensatz", die aus den zuvor heruntergeladenen Daten implementiert wurden.

python


path = %pwd
train_img = np.load('{}/k49-train-imgs.npz'.format(path))
train_img = train_img['arr_0']
train_label = np.load('{}/k49-train-labels.npz'.format(path))
train_label = train_label['arr_0']

train_data = dataset_full(train_img, train_label, transform=transform)

Wenn Sie das tqdm früher eingeben, wird der Fortschritt angezeigt, wenn Sie dies ausführen. Die meisten Daten sind 232.625, aber ich denke nicht, dass es lange dauern wird.

Erstellen Sie einen DataLoader

Wir haben einen Datensatz, aber wir rufen beim Training des Modells keine Daten direkt aus diesem Datensatz ab. Da wir Batch für Batch trainieren, definieren wir einen DataLoader, der Batch-Größendaten zurückgibt.

python



batch_size = 256

train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, shuffle = True, num_workers=2)

Wenn Sie "shuffle = True" setzen, sind die von DataLoader abgerufenen Daten zufällig. num_workers ist ein Argument, das die Anzahl der von DataLoader verwendeten CPU-Kerne angibt und für das cGAN selbst nicht besonders relevant ist.

Der Transform-Dataset-DataLoader bis zu diesem Punkt ist in den folgenden Artikeln zusammengefasst.
Referenzartikel
Überprüfen Sie die grundlegende Funktionsweise von PyTorch-Transformationen / Dataset / DataLoader

Generator definieren

Ich werde den Modellkörper machen. Der Generator erstellt aus Rauschen und Beschriftungen ein falsches Bild (fake_img).

Die Implementierungsmethode ist je nach Person sehr unterschiedlich, aber die Struktur des diesmal erstellten Generators ist wie folgt. (Es ist handgeschrieben, aber es tut mir leid ...) cGAN_G.png In der Eingabe ist "z_dim" (Rauschdimension) 30 und "num_class" (Anzahl der Klassen) 49 Hiragana-Zeichen, daher wird es auf 49 gesetzt. Das gefälschte Bild der Ausgabe hat die Form 1 (Kanal) x 28 (px) x 28 (px).

python



class Generator(nn.Module):
    def __init__(self, z_dim, num_class):
        super(Generator, self).__init__()
        
        self.fc1 = nn.Linear(z_dim, 300)
        self.bn1 = nn.BatchNorm1d(300)
        self.LReLU1 = nn.LeakyReLU(0.2)
        
        self.fc2 = nn.Linear(num_class, 1500)
        self.bn2 = nn.BatchNorm1d(1500)
        self.LReLU2 = nn.LeakyReLU(0.2)
        
        self.fc3 = nn.Linear(1800, 128 * 7 * 7)
        self.bn3 = nn.BatchNorm1d(128 * 7 * 7)
        self.bo1 = nn.Dropout(p=0.5)
        self.LReLU3 = nn.LeakyReLU(0.2)
        
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), #Ändern Sie die Anzahl der Kanäle von 128 auf 64.
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1), #Die Anzahl der Kanäle wurde von 64 auf 1 geändert
            nn.Tanh(),
        )
        
        self.init_weights()
        
    def init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.ConvTranspose2d):
                module.weight.data.normal_(0, 0.02)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.02)
                module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm1d):
                module.weight.data.normal_(1.0, 0.02)
                module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.normal_(1.0, 0.02)
                module.bias.data.zero_()
        
    def forward(self, noise, labels):
        y_1 = self.fc1(noise)
        y_1 = self.bn1(y_1)
        y_1 = self.LReLU1(y_1)
        
        y_2 = self.fc2(labels)
        y_2 = self.bn2(y_2)
        y_2 = self.LReLU2(y_2)
        
        x = torch.cat([y_1, y_2], 1)
        x = self.fc3(x)
        x = self.bo1(x)
        x = self.LReLU3(x)
        x = x.view(-1, 128, 7, 7)
        x = self.deconv(x)
        return x

Definition des Diskriminators

Als nächstes kommt der Diskriminator. Der Diskriminator gibt das Original- / Fälschungsbild und seine Etiketteninformationen ein und bestimmt, ob es echt oder gefälscht ist.

Die Struktur des diesmal erstellten Diskriminators ist wie folgt. cGAN_D.png "img" (Eingabebild) ist 1 (Kanal) x 28 (px) x 28 (px) sowohl für echte als auch für gefälschte Bilder, und "label" (Eingabeetikett) ist ein 49-dimensionaler eindimensionaler Vektor. Die Ausgabe bestimmt, ob sie echt ist oder nicht, mit einem Wert von 0 bis 1.

Konzentrieren Sie die Bild- und Beschriftungsinformationen in Kanalrichtung mit "cat" in der Mitte. Ich denke, der zuvor erwähnte cGAN-Artikel ist in diesem Bereich leicht zu verstehen.

python



class Discriminator(nn.Module):
    def __init__(self, num_class):
        super(Discriminator, self).__init__()
        self.num_class = num_class
        
        self.conv = nn.Sequential(
            nn.Conv2d(num_class + 1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.BatchNorm2d(128),
        )
        
        self.fc = nn.Sequential(
            nn.Linear(128 * 7 * 7, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )
        
        self.init_weights()
        
    def init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                module.weight.data.normal_(0, 0.02)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.02)
                module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm1d):
                module.weight.data.normal_(1.0, 0.02)
                module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.normal_(1.0, 0.02)
                module.bias.data.zero_()
        
    def forward(self, img, labels):
        y_2 = labels.view(-1, self.num_class, 1, 1)
        y_2 = y_2.expand(-1, -1, 28, 28)
        
        x = torch.cat([img, y_2], 1)
        
        x = self.conv(x)
        x = x.view(-1, 128 * 7 * 7)
        x = self.fc(x)
        return x

Berechnung pro Epoche

1 Erstellen Sie eine Funktion zur Berechnung der Epoche.

python



def train_func(D_model, G_model, batch_size, z_dim, num_class, criterion, 
               D_optimizer, G_optimizer, data_loader, device):
    #Trainingsmodus
    D_model.train()
    G_model.train()

    #Das echte Label ist 1
    y_real = torch.ones((batch_size, 1)).to(device)
    D_y_real = (torch.rand((batch_size, 1))/2 + 0.7).to(device) #Geräuschetikett zum Einfügen von D.

    #Gefälschtes Etikett ist 0
    y_fake = torch.zeros((batch_size, 1)).to(device)
    D_y_fake = (torch.rand((batch_size, 1)) * 0.3).to(device) #Geräuschetikett zum Einfügen von D.
    
    #Initialisierung des Verlustes
    D_running_loss = 0
    G_running_loss = 0
    
    #Chargenweise Berechnung
    for batch_idx, (data, labels) in enumerate(data_loader):
        #Ignorieren, wenn weniger als die Stapelgröße
        if data.size()[0] != batch_size:
            break
        
        #Geräuschentwicklung
        z = torch.normal(mean = 0.5, std = 0.2, size = (batch_size, z_dim)) #Durchschnitt 0.Generieren Sie Zufallszahlen nach einer Normalverteilung von 5
        
        real_img, label, z = data.to(device), labels.to(device), z.to(device)
        
        #Diskriminator-Update
        D_optimizer.zero_grad()
        
        #Setzen Sie ein reales Bild in Discriminator und verbreiten Sie vorwärts ⇒ Verlustberechnung
        D_real = D_model(real_img, label)
        D_real_loss = criterion(D_real, D_y_real)
        
        #Fügen Sie das durch Einfügen von Rauschen in Generator in Discriminator erzeugte Bild ein und verbreiten Sie es vorwärts ⇒ Verlustberechnung
        fake_img = G_model(z, label)
        D_fake = D_model(fake_img.detach(), label) #fake_Stop Loss wird in Bildern berechnet, damit es nicht zurück zum Generator übertragen wird
        D_fake_loss = criterion(D_fake, D_y_fake)
        
        #Minimieren Sie die Summe von zwei Verlusten
        D_loss = D_real_loss + D_fake_loss
        
        D_loss.backward()
        D_optimizer.step()
                
        D_running_loss += D_loss.item()
        
        #Generator-Update
        G_optimizer.zero_grad()
        
        #Das Bild, das durch Einfügen von Rauschen in den Generator erzeugt wird, wird in den Diskriminator eingefügt und vorwärts weitergegeben. ⇒ Der erkannte Teil wird zu Verlust
        fake_img_2 = G_model(z, label)
        D_fake_2 = D_model(fake_img_2, label)
        
        #G Verlust(max(log D)Optimiert mit)
        G_loss = -criterion(D_fake_2, y_fake)
        
        G_loss.backward()
        G_optimizer.step()
        G_running_loss += G_loss.item()
        
    D_running_loss /= len(data_loader)
    G_running_loss /= len(data_loader)
    
    return D_running_loss, G_running_loss

Das "Kriterium", das im Argument erscheint, ist die Verlustklasse (in diesem Fall Binary Cross Entropy). Was wir mit dieser Funktion machen, ist in Ordnung

--Backpropagation von Fehlern, indem das reale Image des Datensatzes in Discriminator gestellt wird

ist.

Einfallsreichtum der Implementierung

Es ist ein wenig alt, aber diese Implementierung beinhaltet den Einfallsreichtum, der in "Wie man ein GAN trainiert" auf NIPS2016 erscheint, um das GAN-Lernen erfolgreich zu machen. GitHub-Link


Referenzartikel
14 Techniken zum Lernen von GAN (Generative Adversarial Networks)

1. Eingabe normalisieren

Als ich die Dataset-Klasse erstellt habe

python


return (sample/127.5)-1

Ist das. Die letzte Ebene des Generators ist "nn.Tanh ()".

2. Feste Verlustfunktion von G.

python


#G Verlust(max(log D)Optimiert mit)
        G_loss = -criterion(D_fake_2, y_fake)

Ist das. "D_fake_2" ist das Urteil des Diskriminators, und "y_fake" ist ein 128 × 1 0-Vektor.

3.z stammt aus der Gaußschen Verteilung

Probieren Sie das Rauschen, das in den Generator eingegeben werden soll, anhand einer Normalverteilung anstelle einer gleichmäßigen Verteilung aus.

python


#Geräuschentwicklung
z = torch.normal(mean = 0.5, std = 0.2, size = (batch_size, z_dim)) #Durchschnitt 0.Generieren Sie Zufallszahlen nach einer Normalverteilung von 5

Der Mittelwert und die Standardabweichung sind angemessen, aber wenn Sie mit einer gleichmäßigen Verteilung von [0,1] abtasten, erhalten Sie keinen negativen Wert, daher habe ich den abgetasteten Rauschwert fast positiv gemacht.

4.Batch Norm Alle Daten, die aus dem oben erstellten Data Loader stammen, sind ein echtes Bild. und umgekehrt

python



fake_img = G_model(z, label)

Aus den Etiketteninformationen und dem Rauschen von DataLoader erstellen wir dann gefälschte Bilder mit Stapelgröße.

5. Vermeiden Sie Dinge wie ReLU und Max Pooling, bei denen der Gradient gering ist

LeakyReLU scheint sowohl für den Generator als auch für den Diskriminator wirksam zu sein, daher sind alle Aktivierungsfunktionen auf LeakyReLU eingestellt. Dem Argument 0.2 wurde gefolgt, da viele Implementierungen diesen Wert angenommen haben.

6. Verwenden Sie ein lautes Etikett für das richtige Etikett von D.

Das Discriminator-Label ist normalerweise 0 oder 1, aber wir fügen hier Rauschen hinzu. Probieren Sie nach dem Zufallsprinzip echte Etiketten von 0,7 bis 1,2 und gefälschte Etiketten von 0,0 bis 0,3 aus.

python



#Das echte Label ist 1
y_real = torch.ones((batch_size, 1)).to(device)
D_y_real = (torch.rand((batch_size, 1))/2 + 0.7).to(device) #Geräuschetikett zum Einfügen von D.

#Gefälschtes Etikett ist 0
y_fake = torch.zeros((batch_size, 1)).to(device)
D_y_fake = (torch.rand((batch_size, 1)) * 0.3).to(device) #Geräuschetikett zum Einfügen von D.

Das ist der Teil. Normalerweise benutze ich "y_real" / "y_fake" und dieses Mal habe ich tatsächlich "D_y_real" / "D_y_fake" verwendet.

9. Verwenden Sie Adam als Optimierungsmethode

Dies ist ein alter Artikel, daher ist ein anderer Optimierer wie RAdam jetzt möglicherweise besser.

14. Setzen Sie Dropout in G ein

Dieses Mal habe ich Dropout nur einmal in die lineare Ebene des Generators eingefügt. Es gibt jedoch eine Theorie, dass BatchNorm und Dropout nicht miteinander kompatibel sind, daher denke ich nicht, dass es definitiv besser ist, sie alle zusammenzufügen.

Zeigen Sie das vom Generator erstellte Bild an

Definieren Sie vor dem Training des Modells eine Funktion zum Anzeigen des vom Generator erstellten Bildes. Machen Sie dies und überprüfen Sie den Lerngrad des Generators für jede Epoche.

python



import os
from IPython.display import Image
from torchvision.utils import save_image
%matplotlib inline

def Generate_img(epoch, G_model, device, z_dim, noise, var_mode, labels, log_dir = 'logs_cGAN'):
    G_model.eval()
    
    with torch.no_grad():
        if var_mode == True:
            #Für die Generierung erforderliche Zufallszahlen
            noise = torch.normal(mean = 0.5, std = 0.2, size = (49, z_dim)).to(device)
        else:
            noise = noise

        #Probengenerierung mit Generator
        samples = G_model(noise, labels).data.cpu()
        samples = (samples/2)+0.5
        save_image(samples,os.path.join(log_dir, 'epoch_%05d.png' % (epoch)), nrow = 7)
        img = Image('logs_cGAN/epoch_%05d.png' % (epoch))
        display(img)

Alles, was Sie tun müssen, ist, das mit Rauschen im Generator erstellte Bild in einen Ordner namens "logs_cGAN" zu legen und anzuzeigen. Es wird angenommen, dass jedes Mal dieselbe Zufallszahl verwendet wird, wenn var_mode False ist.

Modelltraining

Trainiere das Modell.

python



#Fester Startwert zur Gewährleistung der Reproduzierbarkeit
SEED = 1111
random.seed(SEED)
np.random.seed(SEED) 
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

#device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def model_run(num_epochs, batch_size = batch_size, dataloader = train_loader, device = device):
    
    #Dimension des Rauschens, das in den Generator eingefügt werden soll
    z_dim = 30
    var_mode = False #Gibt an, ob bei jeder Anzeige des Anzeigeergebnisses eine andere Zufallszahl verwendet werden soll
    #Für die Generierung erforderliche Zufallszahlen
    noise = torch.normal(mean = 0.5, std = 0.2, size = (49, z_dim)).to(device)
    
    #Anzahl der Klassen
    num_class = 49
    
    #Erstellen Sie ein Etikett, das Sie beim Ausprobieren von Generator verwenden können
    labels = []
    for i in range(num_class):
        tmp = np.identity(num_class)[i]
        tmp = np.array(tmp, dtype = np.float32)
        labels.append(tmp)
    label = torch.Tensor(labels).to(device)
    
    #Modelldefinition
    D_model = Discriminator(num_class).to(device)
    G_model = Generator(z_dim, num_class).to(device)
    
    #Definition von Verlust(Das Argument ist Zug_Spezifiziert in func)
    criterion = nn.BCELoss().to(device)
    
    #Definition des Optimierers
    D_optimizer = torch.optim.Adam(D_model.parameters(), lr=0.0002, betas=(0.5, 0.999), eps=1e-08, weight_decay=1e-5, amsgrad=False)
    G_optimizer = torch.optim.Adam(G_model.parameters(), lr=0.0002, betas=(0.5, 0.999), eps=1e-08, weight_decay=1e-5, amsgrad=False)
    
    D_loss_list = []
    G_loss_list = []
    
    all_time = time.time()
    for epoch in range(num_epochs):
        start_time = time.time()
        
        D_loss, G_loss = train_func(D_model, G_model, batch_size, z_dim, num_class, criterion, 
                                    D_optimizer, G_optimizer, dataloader, device)

        D_loss_list.append(D_loss)
        G_loss_list.append(G_loss)
        
        secs = int(time.time() - start_time)
        mins = secs / 60
        secs = secs % 60
        
        #Ergebnisse nach Epoche anzeigen
        print('Epoch: %d' %(epoch + 1), " |Benötigte Zeit%d Minuten%d Sekunden" %(mins, secs))
        print(f'\tLoss: {D_loss:.4f}(Discriminator)')
        print(f'\tLoss: {G_loss:.4f}(Generator)')
        
        if (epoch + 1) % 1 == 0:
            Generate_img(epoch, G_model, device, z_dim, noise, var_mode, label)
        
        #Erstellen Sie eine Prüfpunktdatei zum Speichern des Modells
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch':epoch,
                'model_state_dict':G_model.state_dict(),
                'optimizer_state_dict':G_optimizer.state_dict(),
                'loss':G_loss,
            }, './checkpoint_cGAN/G_model_{}'.format(epoch + 1))
            
    return D_loss_list, G_loss_list

#Drehen Sie das Modell
D_loss_list, G_loss_list = model_run(num_epochs = 100)

Es ist ziemlich lang, aber ich zeige die erforderliche Zeit und den Verlust für jede Epoche an und speichere das Modell.

Ergebnis

Lassen Sie uns den Übergang des Verlusts von Generator und Diskriminator sehen.

python


import matplotlib.pyplot as plt
%matplotlib inline

fig = plt.figure(figsize=(10,7))

loss = fig.add_subplot(1,1,1)

loss.plot(range(len(D_loss_list)),D_loss_list,label='Discriminator_loss')
loss.plot(range(len(G_loss_list)),G_loss_list,label='Generator_loss')

loss.set_xlabel('epoch')
loss.set_ylabel('loss')

loss.legend()
loss.grid()

fig.show()

cGAN-result.png

Ab etwa 20 Epochen haben sich beide Verluste nicht geändert. Sowohl der Diskriminator- als auch der Generatorverlust sind weit von 0 entfernt, daher scheint es ziemlich gut zu funktionieren. Übrigens, wenn Sie versuchen, die generierten Zeichen in der Reihenfolge von 1 bis 100 Epochen in Gifs umzuwandeln, sieht es so aus. result_cGAN.gif

Oben links ist "A" und unten rechts ist "ゝ". Je nach Charakter gibt es einige Unterschiede, und es scheint, dass "u", "ku", "sa", "so" und "hi" stabil und gut erzeugt werden, aber "na" und "yu" Übergänge haben. Es ist heftig.

Nachfolgend sind die Ergebnisse der Generierung von 5 Bildern für jeden Typ aufgeführt. Epoch:5 epoch_5.png

Epoch:50 epoch_50.png

Epoch:100 epoch_100.png

Wenn man das alleine betrachtet, scheint es nicht besser zu sein, Epoche zu stapeln. "Mu" scheint in 5 Epochen das Beste zu sein, während "ゑ" in 100 Epochen das Beste zu sein scheint.

Übrigens sieht es so aus, wenn Sie 5 Trainingsdaten auf die gleiche Weise abrufen. train_data.png

Es gibt einige Dinge, die selbst moderne Menschen nicht lesen können. "Su" und "mi" unterscheiden sich stark von ihren aktuellen Formen. Wenn ich das betrachte, denke ich, dass die Leistung des Modells ziemlich gut ist.

Zusammenfassung

Ich habe versucht, Junk-Zeichen mit cGAN zu generieren. Ich denke, es gibt noch viel Raum für Verbesserungen bei der Implementierung, aber ich denke, das Ergebnis selbst ist vernünftig. Es ist lange her, aber ich hoffe, es hilft sogar einem Teil davon.

Einige Leute haben auch allgemeine MNIST (handschriftliche Zahlen) mit PyTorch in cGAN implementiert. Es gibt viele verschiedene Teile wie die Modellstruktur, daher denke ich, dass dies auch hilfreich ist.


Referenzartikel
Ich habe versucht, handschriftliche Zeichen durch tiefes Lernen zu generieren [Pytorch x MNIST x CGAN]

Schließlich

Ursprünglich war ich leicht motiviert zu denken: "Ist es möglich, handgeschriebene Sätze zu generieren?", Also werde ich es am Ende versuchen.

Laden Sie das Modellgewicht aus der gespeicherten Prüfpunktdatei und versuchen Sie es einmal mit pkl.

python



import cloudpickle
%matplotlib inline
#Geben Sie die abzurufende Epoche an
point = 50

#Definieren Sie die Struktur des Modells
z_dim = 30
num_class = 49
G = Generator(z_dim = z_dim, num_class = num_class)

#Checkpoint extrahieren
checkpoint = torch.load('./checkpoint_cGAN/G_model_{}'.format(point))

#Parameter in Generator einfügen
G.load_state_dict(checkpoint['model_state_dict'])

#Bleiben Sie im Überprüfungsmodus
G.eval()

#Mit Gurke speichern
with open ('KMNIST_cGAN.pkl','wb')as f:
    cloudpickle.dump(G,f)

Es scheint, dass Sie es pkl machen können, indem Sie ein Modul namens "Cloudpickle" anstelle der üblichen "Pickle" verwenden.

Öffnen wir diese pkl-Datei und generieren einen Satz.

python



letter = 'Aiue Okakikuke Kosashi Suseso Tachi Nune no Hahifuhe Homami Mumemoya Yuyorari Rurerowa'

strs = input()
with open('KMNIST_cGAN.pkl','rb')as f:
    Generator = cloudpickle.load(f)
    
for i in range(len(str(strs))):
    noise = torch.normal(mean = 0.5, std = 0.2, size = (1, 30))
    str_index = letter.index(strs[i])
    tmp = np.identity(49)[str_index]
    tmp = np.array(tmp, dtype = np.float32)
    label = [tmp]
    
    img = Generator(noise, torch.Tensor(label))
    img = img.reshape((28,28))
    img = img.detach().numpy().tolist()
    
    if i == 0:
        comp_img = img
    else:
        comp_img.extend(img)
        
save_image(torch.tensor(comp_img), './sentence.png', nrow=len(str(strs)))
img = Image('./sentence.png')
display(img)

Das Ergebnis sieht so aus. sentence.png

"Ich weiß nichts mehr" ...

Recommended Posts

Erzeugung des Junk-Zeichens MNIST (KMNIST) mit cGAN (bedingtes GAN)
Speichern Sie die Ausgabe der bedingten GAN für jede Klasse ~ Mit der cGAN-Implementierung von PyTorch ~
Bedingte GAN mit Chainer implementiert