[PYTHON] J'ai essayé d'implémenter Realness GAN

Il semble y avoir un nouveau GAN appelé RealnessGAN, mais il semble y avoir peu d'informations en japonais. Je l'ai mis en œuvre et j'ai découvert CIFAR-10.

Le papier est réel ou pas réel, telle est la question. J'ai également fait référence à Mise en œuvre par l'auteur de l'article. Implémentation ici est également utile.

Il semble que vous puissiez étudier magnifiquement même avec DCGAN en utilisant RealnessGAN.

Résultats de la formation CelebA Dataset https://github.com/kam1107/RealnessGAN/blob/master/images/CelebA_snapshot.png

Aperçu

Dans un GAN normal, la sortie du discriminateur est une valeur scalaire qui représente la "réalité". Dans cet article, il est proposé d'utiliser Discriminator qui produit la distribution de probabilité de la réalité.

Il semble que plus le Discriminator fournit d'informations, meilleur sera l'apprentissage du Générateur. Selon l'article, même avec une structure DCGAN normale, il a réussi à apprendre une image de visage de 1024 x 1024 (ensemble de données FFHQ). Recherche des résultats de l'ensemble de données FFHQ

Signification des symboles

Méthode

Un discriminateur de GAN normal produit une valeur scalaire continue "Réalité". D'un autre côté, il semble que le Discriminateur de la Réalité GAN produit la distribution de probabilité discrète de la Réalité. Par exemple

D(\mbox{image}) = 
\begin{bmatrix}
\mbox{La réalité de l'image}1.0\mbox{Probabilité de} \\
\mbox{La réalité de l'image}0.9\mbox{Probabilité de} \\
\vdots \\
\mbox{La réalité de l'image}-0.9\mbox{Probabilité de} \\
\mbox{La réalité de l'image}-1.0\mbox{Probabilité de} \\
\end{bmatrix}

Cela semble être comme. Il semble que cette valeur de Réalité discrète soit appelée Outcome dans l'article. La distribution de probabilité semble être obtenue en prenant un soft max dans la direction du canal pour la sortie brute du discriminateur.

En outre, il semble que les données de réponse correctes sur la distribution de probabilité de la réalité soient appelées Ancre dans l'article. Par exemple

\mathcal{A}_0 = 
\begin{bmatrix}
\mbox{La réalité de la fausse image}1.0\mbox{Probabilité de} \\
\mbox{La réalité de la fausse image}0.9\mbox{Probabilité de} \\
\vdots \\
\mbox{La réalité de la fausse image}-0.9\mbox{Probabilité de} \\
\mbox{La réalité de la fausse image}-1.0\mbox{Probabilité de} \\
\end{bmatrix}
\mathcal{A}_1 = 
\begin{bmatrix}
\mbox{La réalité de l'image réelle}1.0\mbox{Probabilité de} \\
\mbox{La réalité de l'image réelle}0.9\mbox{Probabilité de} \\
\vdots \\
\mbox{La réalité de l'image réelle}-0.9\mbox{Probabilité de} \\
\mbox{La réalité de l'image réelle}-1.0\mbox{Probabilité de} \\
\end{bmatrix}

Il semble que la plage de valeurs de réalité, la distribution d'ancrage, etc. puissent être librement personnalisées.

Fonction objective

Selon l'article, la fonction objective est

\max_{G} \min_{D} V(G, D) = \mathbb{E}{\boldsymbol{x} \sim p{\mathrm{data}}}[\mathcal{D}{\mathrm{KL}}( \mathcal{A}{1} || D(\boldsymbol{x}) )] + \mathbb{E}{\boldsymbol{x} \sim p{g}}[\mathcal{D}{\mathrm{KL}}( \mathcal{A}{0} || D(\boldsymbol{x}) )]. \tag{3}


 Il semble.
 Si vous en extrayez la fonction objectif du Generator $ G $,

> ```math
(G_{\mathrm{objective1}}) \quad
\min_{G} 
- \mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z}))].
\tag{18}

Il semble que l'apprentissage ne va pas bien avec cela. Par conséquent, l'article propose deux fonctions objectives pour $ G $.

(G_{\mathrm{objective2}}) \quad \min_{G} \quad \mathbb{E}{\boldsymbol{x} \sim p{\mathrm{data}}, \boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( D(\boldsymbol{x}) || D(G(\boldsymbol{z}))]


> ```math
(G_{\mathrm{objective3}}) \quad
\min_{G} \quad
\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{1} || D(G(\boldsymbol{z}))]
- \mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z}))].
\tag{20}

Après avoir expérimenté, il semble que $ G_ {\ mathrm {objective2}} $ dans l'équation (19) était la meilleure des trois fonctions objectives pour $ G $.

Résumé,

\begin{align}
\min_{D} & \quad
\mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{1} || D(\boldsymbol{x}))] +  
\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z}) ))] \\
\min_{G} & \quad
\mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}, \boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( D(\boldsymbol{x}) || D(G(\boldsymbol{z})))] -
\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z})))]
\end{align}

Ce sera. $ \ mathbb {E} _ {\ boldsymbol {x} \ sim p_ {\ mathrm {data}}} [\ cdots] $, $ \ mathbb {E} _ {\ boldsymbol {z} \ sim p_ {\ boldsymbol { z}}} [\ cdots] $, $ \ mathbb {E} _ {\ boldsymbol {x} \ sim p_ {\ mathrm {data}}, \ boldsymbol {z} \ sim p_ {\ boldsymbol {z}}} La partie [\ cdots] $ doit-elle être la moyenne du mini-lot?

Selon l'article, si Anchor est $ \ mathcal {A} _0 = [1, 0] $, $ \ mathcal {A} _1 = [0, 1] $, la fonction objectif aura la même forme qu'un GAN normal, donc Realness GAN Semble être considéré comme une généralisation du GAN ordinaire.

Informations diverses

Le document contient quelques discussions et idées d'apprentissage, je les ai donc résumées.

Nombre de résultats

Il semble qu'il vaut mieux augmenter le résultat (la dimension de sortie du discriminateur). Si vous augmentez le résultat, devriez-vous augmenter le nombre de mises à jour de Generator $ G $?

Sélection d'ancre

Il semble que plus la divergence KL entre Anchor $ \ mathcal {A} _0 $ dans la fausse image et Anchor $ \ mathcal {A} _1 $ dans l'image réelle, mieux c'est.

Rééchantillonnage des fonctionnalités

Il semble que les performances s'amélioreront si la dimension de sortie de Discriminator est doublée et échantillonnée à partir de la distribution normale en tant que moyenne et écart type. Dans source Github, il semble que l'écart type ne soit pas utilisé tel quel, mais l'indice est pris après avoir divisé par 2 $. (Autrement dit, la sortie d'origine est la logarithmique de la distribution). L'apprentissage semble stable, en particulier dans la seconde moitié de l'apprentissage. Je ne l'ai pas fait dans le code ci-dessous.

code

En savoir plus sur CIFAR-10.

realness_gan.py


import numpy
import torch
import torchvision

#Fonction pour calculer la divergence KL
#epsilon est mis dans le journal pour que NaN ne sorte pas
def kl_divergence(p, q, epsilon=1e-16):
    return torch.mean(torch.sum(p * torch.log((p + epsilon) / (q + epsilon)), dim=1))

# torch.nn.Vous pouvez maintenant mettre le remodelage en séquence
class Reshape(torch.nn.Module):
    def __init__(self, *shape):
        super(Reshape, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(*self.shape)

class GAN:
    def __init__(self):
        self.noise_dimension = 100
        self.n_outcomes      = 20
        self.device          = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.discriminator = torch.nn.Sequential(
            torch.nn.Conv2d( 3, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            Reshape(-1, 32 * 4 * 4),
            torch.nn.Linear(32 * 4 * 4, self.n_outcomes),
        ).to(self.device)
        self.generator = torch.nn.Sequential(
            torch.nn.Linear(self.noise_dimension, 32 * 4 * 4),
            Reshape(-1, 32, 4, 4),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(32,  3, 3, padding=1),
            torch.nn.Sigmoid(),
        ).to(self.device)

        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(),
                                                        lr=0.0001,
                                                        betas=[0.0, 0.9])
        self.generator_optimizer     = torch.optim.Adam(self.generator.parameters(),
                                                        lr=0.0001,
                                                        betas=[0.0, 0.9])

        #Calculez l'ancre ici
        #Prenez un histogramme de nombres aléatoires suivant l'implémentation de l'auteur sur Github
        normal = numpy.random.normal(1, 1, 1000) #moyenne+1, distribution normale avec écart type 1
        count, _ = numpy.histogram(normal, self.n_outcomes, (-2, 2)) # -À partir de 2+Prenez un histogramme jusqu'à 2
        self.real_anchor = count / sum(count) #Normalisé à la somme 1

        normal = numpy.random.normal(-1, 1, 1000) #moyenne-1, distribution normale avec écart type 1
        count, _ = numpy.histogram(normal, self.n_outcomes, (-2, 2))
        self.fake_anchor = count / sum(count)

        self.real_anchor = torch.Tensor(self.real_anchor).to(self.device)
        self.fake_anchor = torch.Tensor(self.fake_anchor).to(self.device)

    def generate_fakes(self, num):
        mean = torch.zeros(num, self.noise_dimension, device=self.device)
        std  = torch.ones(num, self.noise_dimension, device=self.device)
        noise = torch.normal(mean, std)
        return self.generator(noise)

    def train_discriminator(self, real):
        batch_size = real.shape[0]
        fake = self.generate_fakes(batch_size).detach()

        #Softmax la sortie de Discriminator pour en faire une probabilité
        real_feature = torch.nn.functional.softmax(self.discriminator(real), dim=1)
        fake_feature = torch.nn.functional.softmax(self.discriminator(fake), dim=1)

        loss = kl_divergence(self.real_anchor, real_feature) + kl_divergence(self.fake_anchor, fake_feature) #Formule papier(3)

        self.discriminator_optimizer.zero_grad()
        loss.backward()
        self.discriminator_optimizer.step()
        
        return float(loss)

    def train_generator(self, real):
        batch_size = real.shape[0]
        fake = self.generate_fakes(batch_size)

        real_feature = torch.nn.functional.softmax(self.discriminator(real), dim=1)
        fake_feature = torch.nn.functional.softmax(self.discriminator(fake), dim=1)

        # loss = -kl_divergence(self.fake_anchor, fake_feature) #Formule papier(18)
        loss = kl_divergence(real_feature, fake_feature) - kl_divergence(self.fake_anchor, fake_feature) #Formule papier(19)
        # loss = kl_divergence(self.real_anchor, fake_feature) - kl_divergence(self.fake_anchor, fake_feature) #Formule papier(20)
        
        self.generator_optimizer.zero_grad()
        loss.backward()
        self.generator_optimizer.step()
        
        return float(loss)

    def step(self, real):
        real = real.to(self.device)

        discriminator_loss = self.train_discriminator(real)
        generator_loss     = self.train_generator(real)

        return discriminator_loss, generator_loss

if __name__ == '__main__':
    transformer = torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
    ])

    dataset = torchvision.datasets.CIFAR10(root='C:/datasets',
                                           transform=transformer,
                                           download=True)

    iterator = torch.utils.data.DataLoader(dataset,
                                           batch_size=128,
                                           drop_last=True)

    gan = GAN()
    n_steps = 0

    for epoch in range(1000):
        for iteration, data in enumerate(iterator):
            real = data[0].float()
            discriminator_loss, generator_loss = gan.step(real)
            
            print('epoch : {}, iteration : {}, discriminator_loss : {}, generator_loss : {}'.format(
                epoch, iteration, discriminator_loss, generator_loss
            ))

            n_steps += 1

            if iteration == 0:
                fakes = gan.generate_fakes(64)
                torchvision.utils.save_image(fakes, 'out/{}.png'.format(n_steps))

résultat

0 époque (1ère étape) 1.png

10e époque (3901e étape) 3901.png

100e époque (étape 39001) 39001.png

500e époque (étape 195001) 195001.png

Avec cette implémentation, Batch Normalization et Spectral Normalization sont [Feature Resampling](#Feature Resampling). ) N'est pas utilisé non plus, mais il semble qu'il puisse être généré raisonnablement bien.

Recommended Posts

J'ai essayé d'implémenter Realness GAN
J'ai essayé d'implémenter PCANet
J'ai essayé d'implémenter StarGAN (1)
J'ai essayé d'implémenter Deep VQE
J'ai essayé de mettre en place une validation contradictoire
J'ai essayé d'implémenter PLSA en Python
J'ai essayé d'implémenter PLSA dans Python 2
J'ai essayé d'implémenter ADALINE en Python
J'ai essayé d'implémenter PPO en Python
J'ai essayé d'implémenter CVAE avec PyTorch
J'ai essayé de déboguer.
J'ai essayé de déplacer GAN (mnist) avec keras
J'ai essayé d'implémenter TOPIC MODEL en Python
J'ai essayé d'implémenter le tri sélectif en python
J'ai essayé de mettre en œuvre le problème du voyageur de commerce
J'ai essayé d'apprendre PredNet
J'ai essayé d'organiser SVM.
J'ai essayé de réintroduire Linux
J'ai essayé de présenter Pylint
J'ai essayé de résumer SparseMatrix
jupyter je l'ai touché
J'ai essayé d'implémenter et d'apprendre DCGAN avec PyTorch
J'ai essayé d'implémenter Mine Sweeper sur un terminal avec python
J'ai essayé d'implémenter un pseudo pachislot en Python
J'ai essayé d'implémenter le poker de Drakue en Python
J'ai essayé d'implémenter le perceptron artificiel avec python
J'ai essayé d'implémenter GA (algorithme génétique) en Python
J'ai essayé d'implémenter Grad-CAM avec keras et tensorflow
J'ai essayé d'implémenter SSD avec PyTorch maintenant (Dataset)
J'ai essayé d'implémenter le calcul automatique de la preuve de séquence
J'ai essayé de mettre en œuvre le modèle de base du réseau neuronal récurrent
J'ai essayé de créer l'API Quip
J'ai essayé d'implémenter la détection d'anomalies par apprentissage de structure clairsemée
J'ai essayé de toucher Python (installation)
J'ai essayé d'implémenter un automate cellulaire unidimensionnel en Python
J'ai essayé de mettre en œuvre une évasion (type d'évitement de tromperie) avec Quantx
[Django] J'ai essayé d'implémenter des restrictions d'accès par héritage de classe.
J'ai essayé d'expliquer l'ensemble de données de Pytorch
J'ai essayé l'authentification vocale Watson (Speech to Text)
J'ai essayé d'exécuter GAN dans Colaboratory
J'ai touché l'API de Tesla
J'ai essayé d'implémenter ListNet d'apprentissage de rang avec Chainer
J'ai essayé d'implémenter la fonction d'envoi de courrier en Python
J'ai essayé de mettre en œuvre le chapeau de regroupement de Harry Potter avec CNN
J'ai essayé de m'organiser à propos de MCMC.
J'ai essayé d'implémenter Perceptron Part 1 [Deep Learning from scratch]
J'ai essayé d'implémenter le blackjack du jeu Trump en Python
J'ai essayé de déplacer le ballon
J'ai essayé d'estimer la section.
J'ai essayé d'implémenter SSD avec PyTorch maintenant (édition du modèle)
[Python] J'ai essayé d'implémenter un tri stable, alors notez
J'ai essayé de mettre en œuvre un jeu de dilemme de prisonnier mal compris en Python
J'ai essayé d'implémenter la classification des phrases par Self Attention avec PyTorch
J'ai essayé de créer un linebot (implémentation)
J'ai essayé de résumer la gestion des exceptions Python
J'ai essayé d'utiliser Azure Speech to Text.