[PYTHON] I tried to implement Realness GAN

There seems to be a new GAN called RealnessGAN, but there is little information about it in Japanese, so I implemented it and trained it on CIFAR-10.

The paper is "Real or Not Real, that is the Question". I also referred to the implementation by the paper's authors. Another implementation here is also helpful.

With RealnessGAN, even a plain DCGAN architecture apparently trains well and produces clean images.

CelebA training results: https://github.com/kam1107/RealnessGAN/blob/master/images/CelebA_snapshot.png

Overview

In a normal GAN, the Discriminator outputs a single scalar that represents "realness". This paper proposes a Discriminator that instead outputs a probability distribution over realness.

Having the Discriminator output more information apparently helps the Generator learn better. According to the paper, even a plain DCGAN architecture succeeded in learning 1024 x 1024 face images (the FFHQ dataset).

FFHQ dataset training results

Meaning of symbols

- $D$: Discriminator
- $G$: Generator
- $\boldsymbol{z}$: noise fed into the Generator (latent representation)
- $\mathcal{A}_0$: anchor for fake images (the realness distribution given to $D$ as the target?)
- $\mathcal{A}_1$: anchor for real images
- $p_{\mathrm{data}}(\boldsymbol{x})$: probability of drawing the image $\boldsymbol{x}$ at random from the dataset?
- $p_g(\boldsymbol{x})$: probability that $G(\boldsymbol{z})$, for a randomly drawn $\boldsymbol{z}$, is the image $\boldsymbol{x}$?

Method

A normal GAN Discriminator outputs a continuous scalar, the "realness". The RealnessGAN Discriminator instead outputs a discrete probability distribution over realness values. For example:

D(\mbox{image}) =
\begin{bmatrix}
\mbox{Probability that the realness of the image is } 1.0 \\
\mbox{Probability that the realness of the image is } 0.9 \\
\vdots \\
\mbox{Probability that the realness of the image is } -0.9 \\
\mbox{Probability that the realness of the image is } -1.0
\end{bmatrix}

In the paper, these discretized realness values are called outcomes. The probability distribution is obtained by applying a softmax along the channel dimension of the Discriminator's raw output.
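As a rough illustration (NumPy instead of PyTorch, with random numbers standing in for real Discriminator outputs), a softmax along axis 1 turns each row of raw outputs into a distribution over outcomes. The 21 outcome values from -1.0 to 1.0 simply mirror the example above; they are an assumption, not something fixed by the paper.

```python
import numpy as np

def outcome_distribution(logits):
    """Softmax over the outcome axis (axis 1) of a (batch, n_outcomes) array."""
    # Subtracting the row max avoids overflow and does not change the result
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# 21 outcomes: realness -1.0, -0.9, ..., 1.0 (example range only)
outcomes = np.linspace(-1.0, 1.0, 21)
rng = np.random.default_rng(0)
logits = rng.normal(size=(2, outcomes.size))  # stand-in for raw D outputs of 2 images

probs = outcome_distribution(logits)
print(probs.shape)        # (2, 21)
print(probs.sum(axis=1))  # each row sums to 1
```

In PyTorch the same step is `torch.nn.functional.softmax(output, dim=1)`, which is what the script in this post uses.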

The target realness distributions (the "correct answers" for the Discriminator) are called anchors in the paper. For example:

\mathcal{A}_0 =
\begin{bmatrix}
\mbox{Probability that the realness of a fake image is } 1.0 \\
\mbox{Probability that the realness of a fake image is } 0.9 \\
\vdots \\
\mbox{Probability that the realness of a fake image is } -0.9 \\
\mbox{Probability that the realness of a fake image is } -1.0
\end{bmatrix},
\quad
\mathcal{A}_1 =
\begin{bmatrix}
\mbox{Probability that the realness of a real image is } 1.0 \\
\mbox{Probability that the realness of a real image is } 0.9 \\
\vdots \\
\mbox{Probability that the realness of a real image is } -0.9 \\
\mbox{Probability that the realness of a real image is } -1.0
\end{bmatrix}

The range of realness values, the shape of the anchor distributions, and so on can apparently be chosen freely.
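As one possible choice (my own, not taken from the paper), anchors can be built deterministically by evaluating a Gaussian density at each outcome value and normalizing. The means of ±1 and standard deviation of 1 mirror the histogram-based anchors in the script later in this post; the grid of 20 outcomes on [-2, 2] is an assumption.

```python
import numpy as np

def gaussian_anchor(outcomes, mean, std):
    """Discretize N(mean, std^2) onto the given outcome values and normalize to sum to 1."""
    density = np.exp(-0.5 * ((outcomes - mean) / std) ** 2)
    return density / density.sum()

outcomes = np.linspace(-2.0, 2.0, 20)                # 20 outcomes on [-2, 2] (assumption)
fake_anchor = gaussian_anchor(outcomes, -1.0, 1.0)   # A_0: mass concentrated on low realness
real_anchor = gaussian_anchor(outcomes, 1.0, 1.0)    # A_1: mass concentrated on high realness

print(fake_anchor.round(3))
print(real_anchor.round(3))
```

Unlike the random histogram, this version gives the same anchors on every run, and because the grid is symmetric, `real_anchor` is exactly `fake_anchor` reversed.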

Objective function

According to the paper, the objective function is

\max_{G} \min_{D} V(G, D) = \mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{1} || D(\boldsymbol{x}) )] + \mathbb{E}_{\boldsymbol{x} \sim p_{g}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(\boldsymbol{x}) )]. \tag{3}


Extracting the objective of the Generator $G$ from it gives

(G_{\mathrm{objective1}}) \quad
\min_{G}
- \mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z})))].
\tag{18}

This objective apparently does not train well, so the paper proposes two alternative objectives for $G$:

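(G_{\mathrm{objective2}}) \quad
\min_{G} \quad
\mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}, \boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( D(\boldsymbol{x}) || D(G(\boldsymbol{z})))]
- \mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z})))].
\tag{19}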


(G_{\mathrm{objective3}}) \quad
\min_{G} \quad
\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{1} || D(G(\boldsymbol{z})))]
- \mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z})))].
\tag{20}

In the paper's experiments, $G_{\mathrm{objective2}}$ in equation (19) performed best of the three objectives for $G$.

In summary, the objectives are:

\begin{align}
\min_{D} & \quad
\mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{1} || D(\boldsymbol{x}))] +  
\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z}) ))] \\
\min_{G} & \quad
\mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}, \boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( D(\boldsymbol{x}) || D(G(\boldsymbol{z})))] -
\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z})))]
\end{align}

Presumably the expectations $\mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}}[\cdots]$, $\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\cdots]$, and $\mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}, \boldsymbol{z} \sim p_{\boldsymbol{z}}}[\cdots]$ can simply be replaced by averages over the mini-batch?
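Assuming the expectations are replaced by mini-batch means (as the script later in this post does), each loss becomes a per-sample KL divergence summed over outcomes and then averaged over the batch. Below is a NumPy sketch; the function names and the toy 3-outcome anchors are hypothetical, and random rows stand in for $D(\boldsymbol{x})$ and $D(G(\boldsymbol{z}))$.

```python
import numpy as np

def kl(p, q, eps=1e-16):
    """KL(p || q) summed over outcomes (last axis), then averaged over the mini-batch."""
    return np.mean(np.sum(p * np.log((p + eps) / (q + eps)), axis=-1))

def discriminator_loss(d_real, d_fake, a0, a1):
    # Eq. (3), minimized by D; a 1-D anchor broadcasts against (batch, n) outputs
    return kl(a1, d_real) + kl(a0, d_fake)

def generator_losses(d_real, d_fake, a0, a1):
    # The three candidate Generator objectives
    return {
        "objective1": -kl(a0, d_fake),                      # Eq. (18)
        "objective2": kl(d_real, d_fake) - kl(a0, d_fake),  # Eq. (19)
        "objective3": kl(a1, d_fake) - kl(a0, d_fake),      # Eq. (20)
    }

rng = np.random.default_rng(0)

def random_probs(batch, n_outcomes):
    # Random rows summing to 1, standing in for softmaxed Discriminator outputs
    x = rng.random((batch, n_outcomes))
    return x / x.sum(axis=1, keepdims=True)

a0 = np.array([0.7, 0.2, 0.1])  # toy fake anchor (hypothetical)
a1 = a0[::-1].copy()            # toy real anchor: a0 mirrored
d_real, d_fake = random_probs(4, 3), random_probs(4, 3)

print(discriminator_loss(d_real, d_fake, a0, a1))
print(generator_losses(d_real, d_fake, a0, a1))
```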

According to the paper, with the anchors $\mathcal{A}_0 = [1, 0]$ and $\mathcal{A}_1 = [0, 1]$ the objective takes the same form as that of an ordinary GAN, so RealnessGAN can be regarded as a generalization of the ordinary GAN.
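A quick numeric check of this claim: with two outcomes, write the Discriminator output as $[1 - s, s]$, where $s$ plays the role of the usual "probability of real". Then $\mathcal{D}_{\mathrm{KL}}(\mathcal{A}_1 \| D) = -\log s$ and $\mathcal{D}_{\mathrm{KL}}(\mathcal{A}_0 \| D) = -\log(1 - s)$, which is the standard binary cross-entropy. The values of $s$ below are arbitrary toy numbers.

```python
import numpy as np

def kl(p, q):
    # KL(p || q) with the usual convention 0 * log(0 / q) = 0
    mask = p > 0
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

a0 = np.array([1.0, 0.0])  # fake anchor A_0
a1 = np.array([0.0, 1.0])  # real anchor A_1

s_real, s_fake = 0.8, 0.3  # toy "probability of real" for a real and a fake image
d_real = np.array([1 - s_real, s_real])
d_fake = np.array([1 - s_fake, s_fake])

realness_d_loss = kl(a1, d_real) + kl(a0, d_fake)     # Eq. (3) for one real and one fake sample
bce_d_loss = -np.log(s_real) - np.log(1 - s_fake)     # standard GAN discriminator loss
print(realness_d_loss, bce_d_loss)                    # the two values match
```

Likewise, the Generator objective (18) reduces to $\log(1 - s)$, the original minimax Generator loss.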

Miscellaneous information

The paper contains several discussions and training tips, so I have summarized them below.

Number of Outcomes

Apparently, the more outcomes (i.e. the larger the Discriminator's output dimension), the better. If you increase the number of outcomes, should you also update the Generator $G$ more often?

Anchor selection

The larger the KL divergence between the fake-image anchor $\mathcal{A}_0$ and the real-image anchor $\mathcal{A}_1$, the better.
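To get a feel for this, the sketch below uses Gaussian-shaped anchors (my own construction, an assumption) and measures $\mathcal{D}_{\mathrm{KL}}(\mathcal{A}_1 \| \mathcal{A}_0)$ as the two means are pushed apart. The divergence grows with the gap, so more widely separated anchors give a larger KL divergence.

```python
import numpy as np

def gaussian_anchor(outcomes, mean, std=1.0):
    # Discretized, normalized Gaussian over the outcome values (assumption)
    density = np.exp(-0.5 * ((outcomes - mean) / std) ** 2)
    return density / density.sum()

def kl(p, q):
    # All entries are strictly positive here, so no epsilon is needed
    return np.sum(p * np.log(p / q))

outcomes = np.linspace(-2.0, 2.0, 20)
gaps = (0.5, 1.0, 2.0)
divergences = []
for gap in gaps:
    a0 = gaussian_anchor(outcomes, -gap)  # fake-image anchor A_0
    a1 = gaussian_anchor(outcomes, +gap)  # real-image anchor A_1
    divergences.append(kl(a1, a0))
    print(f"means at +/-{gap}: KL(A1 || A0) = {divergences[-1]:.3f}")
```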

Feature resampling

Performance apparently improves if the Discriminator's output dimension is doubled, the two halves are used as the mean and standard deviation of a normal distribution, and the features are sampled from that distribution. In the GitHub source, the second half is not used directly as the standard deviation; it is divided by $2$ and then exponentiated (that is, the raw output is the log variance). This seems to stabilize training, especially in its later stages. I did not do this in the code below.
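A minimal PyTorch sketch of this idea as I read it, not the author's code: the first half of the raw output is the mean, the second half is the log variance, and a sample is drawn with the reparameterization trick so gradients still reach both halves. The class name `ResampledOutcomes` is hypothetical. It returns pre-softmax features, so the usual softmax is applied afterwards.

```python
import torch

class ResampledOutcomes(torch.nn.Module):
    """Hypothetical layer: maps 2*n raw Discriminator outputs to n sampled (pre-softmax) features."""

    def forward(self, x):
        # First half of the channels: mean; second half: log variance
        mean, log_var = torch.chunk(x, 2, dim=1)
        std = torch.exp(log_var / 2)  # divide the log variance by 2, then exponentiate
        eps = torch.randn_like(std)   # standard normal noise
        return mean + std * eps       # reparameterized sample, differentiable w.r.t. mean and log_var

n_outcomes = 20
raw = torch.randn(8, 2 * n_outcomes, requires_grad=True)  # stand-in for a widened final Linear layer
features = ResampledOutcomes()(raw)
probs = torch.nn.functional.softmax(features, dim=1)      # the same softmax as in the training code
probs[:, 0].sum().backward()                              # gradients flow back through the sampling
print(probs.shape)  # torch.Size([8, 20])
```

Plugging this into the Discriminator below would mean widening its last layer to `torch.nn.Linear(32 * 4 * 4, 2 * self.n_outcomes)` and appending `ResampledOutcomes()` to the `Sequential`.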

Code

This trains on CIFAR-10.

realness_gan.py


import numpy
import torch
import torchvision

# KL divergence KL(p || q): summed over outcomes (dim=1), then averaged over the mini-batch
# epsilon is added inside the log so that zero probabilities do not produce NaN
def kl_divergence(p, q, epsilon=1e-16):
    return torch.mean(torch.sum(p * torch.log((p + epsilon) / (q + epsilon)), dim=1))

# Wrapper module so that a reshape can be used inside torch.nn.Sequential
class Reshape(torch.nn.Module):
    def __init__(self, *shape):
        super(Reshape, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(*self.shape)

class GAN:
    def __init__(self):
        self.noise_dimension = 100
        self.n_outcomes      = 20
        self.device          = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.discriminator = torch.nn.Sequential(
            torch.nn.Conv2d( 3, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            Reshape(-1, 32 * 4 * 4),
            torch.nn.Linear(32 * 4 * 4, self.n_outcomes),
        ).to(self.device)
        self.generator = torch.nn.Sequential(
            torch.nn.Linear(self.noise_dimension, 32 * 4 * 4),
            Reshape(-1, 32, 4, 4),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(32,  3, 3, padding=1),
            torch.nn.Sigmoid(),
        ).to(self.device)

        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(),
                                                        lr=0.0001,
                                                        betas=[0.0, 0.9])
        self.generator_optimizer     = torch.optim.Adam(self.generator.parameters(),
                                                        lr=0.0001,
                                                        betas=[0.0, 0.9])

        # Compute the anchors here.
        # Following the author's GitHub implementation, each anchor is a normalized histogram of random samples.
        normal = numpy.random.normal(1, 1, 1000)  # 1000 samples from a normal distribution with mean +1, std 1
        count, _ = numpy.histogram(normal, self.n_outcomes, (-2, 2))  # histogram with n_outcomes bins over [-2, +2]
        self.real_anchor = count / count.sum()  # normalize so that the bins sum to 1

        normal = numpy.random.normal(-1, 1, 1000)  # mean -1, std 1
        count, _ = numpy.histogram(normal, self.n_outcomes, (-2, 2))
        self.fake_anchor = count / count.sum()

        self.real_anchor = torch.Tensor(self.real_anchor).to(self.device)
        self.fake_anchor = torch.Tensor(self.fake_anchor).to(self.device)

    def generate_fakes(self, num):
        mean = torch.zeros(num, self.noise_dimension, device=self.device)
        std  = torch.ones(num, self.noise_dimension, device=self.device)
        noise = torch.normal(mean, std)
        return self.generator(noise)

    def train_discriminator(self, real):
        batch_size = real.shape[0]
        fake = self.generate_fakes(batch_size).detach()

        #Softmax the output of Discriminator to make it a probability
        real_feature = torch.nn.functional.softmax(self.discriminator(real), dim=1)
        fake_feature = torch.nn.functional.softmax(self.discriminator(fake), dim=1)

        loss = kl_divergence(self.real_anchor, real_feature) + kl_divergence(self.fake_anchor, fake_feature)  # Eq. (3) of the paper

        self.discriminator_optimizer.zero_grad()
        loss.backward()
        self.discriminator_optimizer.step()
        
        return float(loss)

    def train_generator(self, real):
        batch_size = real.shape[0]
        fake = self.generate_fakes(batch_size)

        real_feature = torch.nn.functional.softmax(self.discriminator(real), dim=1)
        fake_feature = torch.nn.functional.softmax(self.discriminator(fake), dim=1)

        # loss = -kl_divergence(self.fake_anchor, fake_feature)  # Eq. (18) of the paper
        loss = kl_divergence(real_feature, fake_feature) - kl_divergence(self.fake_anchor, fake_feature)  # Eq. (19) of the paper
        # loss = kl_divergence(self.real_anchor, fake_feature) - kl_divergence(self.fake_anchor, fake_feature)  # Eq. (20) of the paper
        
        self.generator_optimizer.zero_grad()
        loss.backward()
        self.generator_optimizer.step()
        
        return float(loss)

    def step(self, real):
        real = real.to(self.device)

        discriminator_loss = self.train_discriminator(real)
        generator_loss     = self.train_generator(real)

        return discriminator_loss, generator_loss

if __name__ == '__main__':
    transformer = torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
    ])

    dataset = torchvision.datasets.CIFAR10(root='C:/datasets',
                                           transform=transformer,
                                           download=True)

    iterator = torch.utils.data.DataLoader(dataset,
                                           batch_size=128,
                                           shuffle=True,
                                           drop_last=True)

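    gan = GAN()
    n_steps = 0

    # save_image does not create folders, so make sure the output folder exists
    import os
    os.makedirs('out', exist_ok=True)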

    for epoch in range(1000):
        for iteration, data in enumerate(iterator):
            real = data[0].float()
            discriminator_loss, generator_loss = gan.step(real)
            
            print('epoch : {}, iteration : {}, discriminator_loss : {}, generator_loss : {}'.format(
                epoch, iteration, discriminator_loss, generator_loss
            ))

            n_steps += 1

            if iteration == 0:
                # Save an 8 x 8 grid of samples at the first iteration of each epoch
                with torch.no_grad():
                    fakes = gan.generate_fakes(64)
                torchvision.utils.save_image(fakes, 'out/{}.png'.format(n_steps))

Results

Epoch 0 (step 1): 1.png

Epoch 10 (step 3901): 3901.png

Epoch 100 (step 39001): 39001.png

Epoch 500 (step 195001): 195001.png

This implementation uses neither Batch Normalization, Spectral Normalization, nor feature resampling, yet it seems to generate reasonably good images.
