[PYTHON] Ich habe versucht, Realness GAN zu implementieren

Es scheint ein neues GAN namens RealnessGAN zu geben, aber es scheint wenig Informationen auf Japanisch zu geben. Ich habe es implementiert und etwas über CIFAR-10 gelernt.

Das Papier ist real oder nicht real, das ist die Frage. Ich habe auch auf [Implementierung durch den Autor des Papiers] verwiesen (https://github.com/kam1107/RealnessGAN). Implementierung hier ist ebenfalls hilfreich.

Es scheint, dass Sie mit RealnessGAN auch mit DCGAN wunderbar lernen können.

CelebA Dataset Training Results https://github.com/kam1107/RealnessGAN/blob/master/images/CelebA_snapshot.png

Überblick

In einem normalen GAN ist die Ausgabe des Diskriminators ein Skalarwert, der "Realität" darstellt. In diesem Artikel wird vorgeschlagen, einen Diskriminator zu verwenden, der die Wahrscheinlichkeitsverteilung der Realität ausgibt.

Je mehr Informationen der Diskriminator ausgibt, desto besser lernt der Generator. Dem Papier zufolge gelang es sogar mit einer normalen DCGAN-Struktur, ein 1024 x 1024-Gesichtsbild (FFHQ-Datensatz) zu lernen. Ergebnisse des FFHQ-Datensatzes ermitteln

Bedeutung von Symbolen

Methode

Ein normaler GAN-Diskriminator gibt einen kontinuierlichen Skalarwert "Realität" aus. Andererseits scheint der Diskriminator der Realität GAN die diskrete Wahrscheinlichkeitsverteilung der Realität auszugeben. Zum Beispiel

D(\mbox{Bild}) = 
\begin{bmatrix}
\mbox{Die Realität des Bildes}1.0\mbox{Wahrscheinlichkeit von} \\
\mbox{Die Realität des Bildes}0.9\mbox{Wahrscheinlichkeit von} \\
\vdots \\
\mbox{Die Realität des Bildes}-0.9\mbox{Wahrscheinlichkeit von} \\
\mbox{Die Realität des Bildes}-1.0\mbox{Wahrscheinlichkeit von} \\
\end{bmatrix}

Es scheint so zu sein. Es scheint, dass dieser diskrete Realitätswert im Papier als Ergebnis bezeichnet wird. Die Wahrscheinlichkeitsverteilung scheint erhalten zu werden, indem ein weiches Maximum in Kanalrichtung für die Rohleistung des Diskriminators genommen wird.

Es scheint auch, dass die richtigen Antwortdaten über die Wahrscheinlichkeitsverteilung der Realität in der Arbeit als Anker bezeichnet werden. Zum Beispiel

\mathcal{A}_0 = 
\begin{bmatrix}
\mbox{Die Realität des gefälschten Bildes}1.0\mbox{Wahrscheinlichkeit von} \\
\mbox{Die Realität des gefälschten Bildes}0.9\mbox{Wahrscheinlichkeit von} \\
\vdots \\
\mbox{Die Realität des gefälschten Bildes}-0.9\mbox{Wahrscheinlichkeit von} \\
\mbox{Die Realität des gefälschten Bildes}-1.0\mbox{Wahrscheinlichkeit von} \\
\end{bmatrix}
\mathcal{A}_1 = 
\begin{bmatrix}
\mbox{Die Realität des realen Bildes}1.0\mbox{Wahrscheinlichkeit von} \\
\mbox{Die Realität des realen Bildes}0.9\mbox{Wahrscheinlichkeit von} \\
\vdots \\
\mbox{Die Realität des realen Bildes}-0.9\mbox{Wahrscheinlichkeit von} \\
\mbox{Die Realität des realen Bildes}-1.0\mbox{Wahrscheinlichkeit von} \\
\end{bmatrix}

Es scheint, dass der Realitätswertbereich, die Ankerverteilung usw. frei angepasst werden können.

Zielfunktion

Dem Papier zufolge ist die Zielfunktion

\max_{G} \min_{D} V(G, D) = \mathbb{E}{\boldsymbol{x} \sim p{\mathrm{data}}}[\mathcal{D}{\mathrm{KL}}( \mathcal{A}{1} || D(\boldsymbol{x}) )] + \mathbb{E}{\boldsymbol{x} \sim p{g}}[\mathcal{D}{\mathrm{KL}}( \mathcal{A}{0} || D(\boldsymbol{x}) )]. \tag{3}


 Es scheint.
 Wenn Sie die Zielfunktion des Generators $ G $ daraus extrahieren,

> ```math
(G_{\mathrm{objective1}}) \quad
\min_{G} 
- \mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z}))].
\tag{18}

Es scheint, dass das Lernen damit nicht gut gehen wird. Daher schlägt das Papier zwei Zielfunktionen für $ G $ vor.

(G_{\mathrm{objective2}}) \quad \min_{G} \quad \mathbb{E}{\boldsymbol{x} \sim p{\mathrm{data}}, \boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( D(\boldsymbol{x}) || D(G(\boldsymbol{z}))]


> ```math
(G_{\mathrm{objective3}}) \quad
\min_{G} \quad
\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{1} || D(G(\boldsymbol{z}))]
- \mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z}))].
\tag{20}

Nach dem Experimentieren scheint $ G_ {\ mathrm {object2}} $ in Gleichung (19) die beste der drei Zielfunktionen für $ G $ zu sein.

Zusammenfassung,

\begin{align}
\min_{D} & \quad
\mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{1} || D(\boldsymbol{x}))] +  
\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z}) ))] \\
\min_{G} & \quad
\mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}, \boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( D(\boldsymbol{x}) || D(G(\boldsymbol{z})))] -
\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z})))]
\end{align}

Es wird sein. $ \ mathbb {E} _ {\ boldsymbol {x} \ sim p_ {\ mathrm {data}}} [\ cdots] $, $ \ mathbb {E} _ {\ boldsymbol {z} \ sim p_ {\ boldsymbol { z}}} [\ cdots] $, $ \ mathbb {E} _ {\ boldsymbol {x} \ sim p_ {\ mathrm {data}}, \ boldsymbol {z} \ sim p_ {\ boldsymbol {z}}} Sollte der [\ cdots] $ -Teil der Durchschnitt des Mini-Batch sein?

Laut dem Artikel hat die Zielfunktion dieselbe Form wie ein normales GAN, also Realness GAN, wenn Anchor $ \ mathcal {A} _0 = [1, 0] $, $ \ mathcal {A} _1 = [0, 1] $ ist Scheint als Verallgemeinerung der gewöhnlichen GAN zu gelten.

Verschiedene Informationen

Das Papier enthält einige Diskussionen und Lernideen, daher habe ich sie zusammengefasst.

Anzahl der Ergebnisse

Es scheint besser zu sein, das Ergebnis (die Output-Dimension von Discriminator) zu steigern. Wenn Sie das Ergebnis erhöhen, sollten Sie die Häufigkeit erhöhen, mit der Sie Generator $ G $ aktualisieren?

Ankerauswahl

Es scheint, dass je größer die KL-Divergenz zwischen Anchor $ \ mathcal {A} _0 $ im gefälschten Bild und Anchor $ \ mathcal {A} _1 $ im realen Bild ist, desto besser.

Feature-Resampling

Es scheint, dass sich die Leistung verbessern wird, wenn die Ausgabedimension von Discriminator verdoppelt und aus der Normalverteilung als Mittelwert und Standardabweichung abgetastet wird. In Github-Quelle scheint die Standardabweichung nicht so zu sein, wie sie ist, aber der Index wird nach Division durch $ 2 $ genommen. (Das heißt, die ursprüngliche Ausgabe ist der Logarithmus der Verteilung). Das Lernen scheint besonders in der zweiten Hälfte des Lernens stabil zu sein. Ich habe es im folgenden Code nicht getan.

Code

Erfahren Sie mehr über CIFAR-10.

realness_gan.py


import numpy
import torch
import torchvision

#Funktion zur Berechnung der KL-Divergenz
#epsilon wird protokolliert, damit NaN nicht herauskommt
def kl_divergence(p, q, epsilon=1e-16):
    return torch.mean(torch.sum(p * torch.log((p + epsilon) / (q + epsilon)), dim=1))

# torch.nn.Jetzt können Sie die Umformung in Sequential setzen
class Reshape(torch.nn.Module):
    def __init__(self, *shape):
        super(Reshape, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(*self.shape)

class GAN:
    def __init__(self):
        self.noise_dimension = 100
        self.n_outcomes      = 20
        self.device          = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.discriminator = torch.nn.Sequential(
            torch.nn.Conv2d( 3, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            Reshape(-1, 32 * 4 * 4),
            torch.nn.Linear(32 * 4 * 4, self.n_outcomes),
        ).to(self.device)
        self.generator = torch.nn.Sequential(
            torch.nn.Linear(self.noise_dimension, 32 * 4 * 4),
            Reshape(-1, 32, 4, 4),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(32,  3, 3, padding=1),
            torch.nn.Sigmoid(),
        ).to(self.device)

        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(),
                                                        lr=0.0001,
                                                        betas=[0.0, 0.9])
        self.generator_optimizer     = torch.optim.Adam(self.generator.parameters(),
                                                        lr=0.0001,
                                                        betas=[0.0, 0.9])

        #Berechnen Sie hier den Anker
        #Machen Sie ein Zufallszahlenhistogramm nach der Implementierung des Autors auf Github
        normal = numpy.random.normal(1, 1, 1000) #durchschnittlich+1, Normalverteilung mit Standardabweichung 1
        count, _ = numpy.histogram(normal, self.n_outcomes, (-2, 2)) # -Ab 2+Machen Sie ein Histogramm bis zu 2
        self.real_anchor = count / sum(count) #Normalisiert auf Summe 1

        normal = numpy.random.normal(-1, 1, 1000) #durchschnittlich-1, Normalverteilung mit Standardabweichung 1
        count, _ = numpy.histogram(normal, self.n_outcomes, (-2, 2))
        self.fake_anchor = count / sum(count)

        self.real_anchor = torch.Tensor(self.real_anchor).to(self.device)
        self.fake_anchor = torch.Tensor(self.fake_anchor).to(self.device)

    def generate_fakes(self, num):
        mean = torch.zeros(num, self.noise_dimension, device=self.device)
        std  = torch.ones(num, self.noise_dimension, device=self.device)
        noise = torch.normal(mean, std)
        return self.generator(noise)

    def train_discriminator(self, real):
        batch_size = real.shape[0]
        fake = self.generate_fakes(batch_size).detach()

        #Softmax die Ausgabe von Discriminator, um es zu einer Wahrscheinlichkeit zu machen
        real_feature = torch.nn.functional.softmax(self.discriminator(real), dim=1)
        fake_feature = torch.nn.functional.softmax(self.discriminator(fake), dim=1)

        loss = kl_divergence(self.real_anchor, real_feature) + kl_divergence(self.fake_anchor, fake_feature) #Papierformel(3)

        self.discriminator_optimizer.zero_grad()
        loss.backward()
        self.discriminator_optimizer.step()
        
        return float(loss)

    def train_generator(self, real):
        batch_size = real.shape[0]
        fake = self.generate_fakes(batch_size)

        real_feature = torch.nn.functional.softmax(self.discriminator(real), dim=1)
        fake_feature = torch.nn.functional.softmax(self.discriminator(fake), dim=1)

        # loss = -kl_divergence(self.fake_anchor, fake_feature) #Papierformel(18)
        loss = kl_divergence(real_feature, fake_feature) - kl_divergence(self.fake_anchor, fake_feature) #Papierformel(19)
        # loss = kl_divergence(self.real_anchor, fake_feature) - kl_divergence(self.fake_anchor, fake_feature) #Papierformel(20)
        
        self.generator_optimizer.zero_grad()
        loss.backward()
        self.generator_optimizer.step()
        
        return float(loss)

    def step(self, real):
        real = real.to(self.device)

        discriminator_loss = self.train_discriminator(real)
        generator_loss     = self.train_generator(real)

        return discriminator_loss, generator_loss

if __name__ == '__main__':
    transformer = torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
    ])

    dataset = torchvision.datasets.CIFAR10(root='C:/datasets',
                                           transform=transformer,
                                           download=True)

    iterator = torch.utils.data.DataLoader(dataset,
                                           batch_size=128,
                                           drop_last=True)

    gan = GAN()
    n_steps = 0

    for epoch in range(1000):
        for iteration, data in enumerate(iterator):
            real = data[0].float()
            discriminator_loss, generator_loss = gan.step(real)
            
            print('epoch : {}, iteration : {}, discriminator_loss : {}, generator_loss : {}'.format(
                epoch, iteration, discriminator_loss, generator_loss
            ))

            n_steps += 1

            if iteration == 0:
                fakes = gan.generate_fakes(64)
                torchvision.utils.save_image(fakes, 'out/{}.png'.format(n_steps))

Ergebnis

0 Epoche (1. Schritt) 1.png

  1. Epoche (3901. Schritt) 3901.png

  2. Epoche (39001 Schritt) 39001.png

  3. Epoche (195001 Schritt) 195001.png

Bei dieser Implementierung sind sowohl Batch Normalization als auch Spectral Normalization [Feature Resampling](#Feature Resampling). ) Wird auch nicht verwendet, aber es scheint, dass es ziemlich gut erzeugt werden kann.

Recommended Posts

Ich habe versucht, Realness GAN zu implementieren
Ich habe versucht, PCANet zu implementieren
Ich habe versucht, StarGAN (1) zu implementieren.
Ich habe versucht, Deep VQE zu implementieren
Ich habe versucht, eine kontroverse Validierung zu implementieren
Ich habe versucht, PLSA in Python zu implementieren
Ich habe versucht, PLSA in Python 2 zu implementieren
Ich habe versucht, ADALINE in Python zu implementieren
Ich habe versucht, PPO in Python zu implementieren
Ich habe versucht, CVAE mit PyTorch zu implementieren
Ich habe versucht zu debuggen.
Ich habe versucht, GAN (mnist) mit Keras zu bewegen
Ich habe versucht, TOPIC MODEL in Python zu implementieren
Ich habe versucht, eine selektive Sortierung in Python zu implementieren
Ich habe versucht, das Problem des Handlungsreisenden umzusetzen
Ich habe versucht, PredNet zu lernen
Ich habe versucht, SVM zu organisieren.
Ich habe versucht, Linux wieder einzuführen
Ich habe versucht, Pylint vorzustellen
Ich habe versucht, SparseMatrix zusammenzufassen
jupyter ich habe es berührt
Ich habe versucht, DCGAN mit PyTorch zu implementieren und zu lernen
Ich habe versucht, Mine Sweeper auf dem Terminal mit Python zu implementieren
Ich habe versucht, einen Pseudo-Pachislot in Python zu implementieren
Ich habe versucht, Drakues Poker in Python zu implementieren
Ich habe versucht, künstliches Perzeptron mit Python zu implementieren
Ich habe versucht, GA (genetischer Algorithmus) in Python zu implementieren
Ich habe versucht, Grad-CAM mit Keras und Tensorflow zu implementieren
Ich habe versucht, SSD jetzt mit PyTorch zu implementieren (Dataset)
Ich habe versucht, einen automatischen Nachweis der Sequenzberechnung zu implementieren
Ich habe versucht, das grundlegende Modell des wiederkehrenden neuronalen Netzwerks zu implementieren
Ich habe versucht, eine Quip-API zu erstellen
Ich habe versucht, die Erkennung von Anomalien durch spärliches Strukturlernen zu implementieren
Ich habe versucht, Python zu berühren (Installation)
Ich habe versucht, einen eindimensionalen Zellautomaten in Python zu implementieren
Ich habe versucht, mit Quantx einen Ausbruch (Typ der Täuschungsvermeidung) zu implementieren
[Django] Ich habe versucht, Zugriffsbeschränkungen durch Klassenvererbung zu implementieren.
Ich habe versucht, Pytorchs Datensatz zu erklären
Ich habe Watson Voice Authentication (Speech to Text) ausprobiert.
Ich habe versucht, GAN in Colaboratory auszuführen
Ich habe Teslas API berührt
Ich habe versucht, ListNet of Rank Learning mit Chainer zu implementieren
Ich habe versucht, die Mail-Sendefunktion in Python zu implementieren
Ich habe versucht, Harry Potters Gruppierungshut mit CNN umzusetzen
Ich habe versucht, mich über MCMC zu organisieren.
Ich habe versucht, Perceptron Teil 1 [Deep Learning von Grund auf neu] zu implementieren.
Ich habe versucht, das Blackjack of Trump-Spiel mit Python zu implementieren
Ich habe versucht, den Ball zu bewegen
Ich habe versucht, den Abschnitt zu schätzen.
Ich habe versucht, SSD jetzt mit PyTorch zu implementieren (Modellversion)
[Python] Ich habe versucht, eine stabile Sortierung zu implementieren
Ich habe versucht, ein missverstandenes Gefangenendilemma in Python zu implementieren
Ich habe versucht, die Satzklassifizierung durch Self Attention mit PyTorch zu implementieren
Ich habe versucht, einen Linebot zu erstellen (Implementierung)
Ich habe versucht, die Behandlung von Python-Ausnahmen zusammenzufassen
Ich habe versucht, Azure Speech to Text zu verwenden.