[PYTHON] Implementation example of a generative adversarial network (GAN) with Keras [For beginners]

What I did in this article

- **MNIST image generation with a GAN**
- **Introduction of the implementation using Keras**

Introduction

Generative adversarial networks, or GANs. I often hear how popular they are, but when you actually try to implement one yourself, the bar feels quite high.

It seemed like an important technology to me, yet I kept watching it from the outside without ever touching it. Surprisingly, there are quite a lot of people in the same position.

**This time, I will introduce an example of implementing such a GAN using the MNIST data.** The data and code are borrowed from ["Unsupervised Learning with Python"](https://www.amazon.co.jp/Python%E3%81%A7%E3%81%AF%E3%81%98%E3%82%81%E3%82%8B%E6%95%99%E5%B8%AB%E3%81%AA%E3%81%97%E5%AD%A6%E7%BF%92-%E2%80%95%E6%A9%9F%E6%A2%B0%E5%AD%A6%E7%BF%92%E3%81%AE%E5%8F%AF%E8%83%BD%E6%80%A7%E3%82%92%E5%BA%83%E3%81%92%E3%82%8B%E3%83%A9%E3%83%99%E3%83%AB%E3%81%AA%E3%81%97%E3%83%87%E3%83%BC%E3%82%BF%E3%81%AE%E5%88%A9%E7%94%A8-Ankur-Patel/dp/4873119103).

The reference book is written in an object-oriented style, so it was a little advanced for me, but it was a great learning experience.

I hope it will be just as helpful to other beginners.

First, here are the results I got, since the visuals make the strongest impression.

**Genuine** (image.png)

**Generated** (mnist_14000.png)

**The generated images are eerily similar...** With longer training, it could probably do even better.

What is a Generative Adversarial Network (GAN)?

Here is only a brief overview; please refer to this article for details: "GAN: What is a generative adversarial network? Image generation by 'unsupervised learning'" https://www.imagazine.co.jp/gan%EF%BC%9A%E6%95%B5%E5%AF%BE%E7%9A%84%E7%94%9F%E6%88%90%E3%83%8D%E3%83%83%E3%83%88%E3%83%AF%E3%83%BC%E3%82%AF%E3%81%A8%E3%81%AF%E4%BD%95%E3%81%8B%E3%80%80%EF%BD%9E%E3%80%8C%E6%95%99%E5%B8%AB/

**With a GAN, you can learn from a dataset and create data that looks just like it.** In the example in the article above, a GAN generates photos of bedrooms that do not actually exist. You really can't tell them apart; machine learning is a little scary.

Since this article uses MNIST, we will generate handwritten digits. So how are these handwritten digits generated?

**A GAN has two models: one that generates data and one that discriminates data. The generative model creates data that looks like handwritten digits. The discriminative model then judges whether that created data is fake or genuine. Based on the result, the generative model is trained again so that it produces images even closer to the real thing.**

Simply put, that's all there is to the model. The only remaining question, I think, is how the training actually proceeds.

Data learning

In this model, data training is performed as follows.

- **Generate images (1 * 28 * 28) from noise (100 * 1 * 1) with the generative model**
- **Train the discriminative model on "real images" and "images created by the generative model"**
- **Create new images with the generative model, and train it so that the discriminative model classifies them as "real images"** (a rough code sketch follows below)
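To connect this flow to code before the full implementation, here is a rough sketch of a single training step. The `generator`, `discriminator`, and `adversarial` objects here are hypothetical placeholders; they are built properly in the DCGAN and MNIST_DCGAN classes later in this article.

```python
# Rough sketch of one GAN training step (hypothetical model objects; the real
# ones are defined in the DCGAN / MNIST_DCGAN classes below)
noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])    # 100-dim noise vectors
images_fake = generator.predict(noise)                          # generator: noise -> fake images

x = np.concatenate((images_real, images_fake))                  # real + fake images
y = np.ones([2 * batch_size, 1])                                # real images labeled 1
y[batch_size:, :] = 0                                           # fake images labeled 0
d_loss = discriminator.train_on_batch(x, y)                     # train the discriminator

# Train the generator: the combined model tries to get fakes classified as "real" (1)
a_loss = adversarial.train_on_batch(noise, np.ones([batch_size, 1]))
```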

We will now implement this model in full.

Library import

The code follows the reference book, but I've tweaked it slightly so that it runs on Google Colab.

```python
'''Main'''
import numpy as np
import pandas as pd
import os, time, re
import pickle, gzip, datetime

'''Data Viz'''
import matplotlib.pyplot as plt
import seaborn as sns
color = sns.color_palette()
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import Grid

%matplotlib inline

'''Data Prep and Model Evaluation'''
from sklearn import preprocessing as pp
from sklearn.model_selection import train_test_split 
from sklearn.model_selection import StratifiedKFold 
from sklearn.metrics import log_loss, accuracy_score
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.metrics import roc_curve, auc, roc_auc_score, mean_squared_error
from keras.utils import to_categorical

'''Algos'''
import lightgbm as lgb

'''TensorFlow and Keras'''
import tensorflow as tf
import keras
from keras import backend as K
from keras.models import Sequential, Model
from keras.layers import Activation, Dense, Dropout, Flatten, Conv2D, MaxPool2D
from keras.layers import LeakyReLU, Reshape, UpSampling2D, Conv2DTranspose
from keras.layers import BatchNormalization, Input, Lambda
from keras.layers import Embedding, Flatten, dot
from keras import regularizers
from keras.losses import mse, binary_crossentropy
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
from keras.optimizers import Adam, RMSprop

from keras.datasets import mnist

sns.set("talk")
```

Reading the data

Here we read the data. We use the MNIST data, and this assumes you are running on Colaboratory. Since we only use x_train, only x_train is reshaped and normalized to values between 0 and 1.

```python
# Load the MNIST data, already split into training and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape((60000, 28, 28, 1))
# Normalize pixel values to between 0 and 1
x_train = x_train / 255.0
```
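If you want to make sure the preprocessing did what you expect, a quick optional check looks like this:

```python
# Optional sanity check on the preprocessed training data
print(x_train.shape)                  # expected: (60000, 28, 28, 1)
print(x_train.min(), x_train.max())   # expected: 0.0 1.0
```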

DCGAN class design

The all-important DCGAN code. It is defined as a class that bundles the generative model and the discriminative model. Briefly, each method does the following:

- generator: a neural network that converts a 100 * 1 * 1 vector into a 28 * 28 * 1 image. Training this network is what produces the realistic-looking images.
- discriminator: a neural network that judges whether a 28 * 28 * 1 image is genuine or fake.
- discriminator_model: compiles the discriminator network into a trainable model.
- adversarial_model: a model built by chaining the generator and the discriminator; the generator network is trained through this model.

```python
#DCGAN class
class DCGAN(object):
    # Initialization
    def __init__(self, img_rows=28, img_cols=28, channel=1):

        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channel = channel
        self.D = None   # discriminator
        self.G = None   # generator
        self.AM = None  # adversarial model
        self.DM = None  # discriminator model
    
    # Generator network
    # Converts a 100*1*1 noise vector into a 28*28*1 image like those in the dataset
    def generator(self, depth=256, dim=7, dropout=0.3, momentum=0.8, \
                  window=5, input_dim=100, output_depth=1):
        if self.G:
            return self.G
        self.G = Sequential()

        #100*1*1 → 256*7*7
        self.G.add(Dense(dim*dim*depth, input_dim=input_dim))
        self.G.add(BatchNormalization(momentum=momentum))
        self.G.add(Activation('relu'))
        self.G.add(Reshape((dim, dim, depth)))
        self.G.add(Dropout(dropout))
        
        #256*7*7 → 128*14*14
        self.G.add(UpSampling2D())
        self.G.add(Conv2DTranspose(int(depth/2), window, padding='same'))
        self.G.add(BatchNormalization(momentum=momentum))
        self.G.add(Activation('relu'))

        #128*14*14 → 64*28*28
        self.G.add(UpSampling2D())
        self.G.add(Conv2DTranspose(int(depth/4), window, padding='same'))
        self.G.add(BatchNormalization(momentum=momentum))
        self.G.add(Activation('relu'))

        #64*28*28→32*28*28
        self.G.add(Conv2DTranspose(int(depth/8), window, padding='same'))
        self.G.add(BatchNormalization(momentum=momentum))
        self.G.add(Activation('relu'))

        #1*28*28
        self.G.add(Conv2DTranspose(output_depth, window, padding='same'))
        #Set each pixel to a value between 0 and 1
        self.G.add(Activation('sigmoid')) 
        self.G.summary()
        return self.G


    # Discriminator network
    # Judges whether a 28*28*1 image is genuine or fake
    def discriminator(self, depth=64, dropout=0.3, alpha=0.3):
        if self.D:
            return self.D

        self.D = Sequential()
        input_shape = (self.img_rows, self.img_cols, self.channel)

      #28*28*1 → 14*14*64
        self.D.add(Conv2D(depth*1, 5, strides=2, input_shape=input_shape,padding='same'))
        self.D.add(LeakyReLU(alpha=alpha))
        self.D.add(Dropout(dropout))

      #14*14*64 → 7*7*128
        self.D.add(Conv2D(depth*2, 5, strides=2, padding='same'))
        self.D.add(LeakyReLU(alpha=alpha))
        self.D.add(Dropout(dropout))

      #7*7*128 → 4*4*256
        self.D.add(Conv2D(depth*4, 5, strides=2, padding='same'))
        self.D.add(LeakyReLU(alpha=alpha))
        self.D.add(Dropout(dropout))

        #4*4*256 → 4*4*512
        self.D.add(Conv2D(depth*8, 5, strides=1, padding='same'))
        self.D.add(LeakyReLU(alpha=alpha))
        self.D.add(Dropout(dropout))

        #Flatten and classify by sigmoid
        self.D.add(Flatten())
        self.D.add(Dense(1))
        self.D.add(Activation('sigmoid'))

        self.D.summary()
        return self.D

    #Discriminative model
    def discriminator_model(self):
        if self.DM:
            return self.DM
        optimizer = RMSprop(lr=0.0002, decay=6e-8)
        self.DM = Sequential()
        self.DM.add(self.discriminator())
        self.DM.compile(loss='binary_crossentropy', \
                        optimizer=optimizer, metrics=['accuracy'])
        return self.DM

    # Adversarial model (generator + discriminator); used to train the generator
    def adversarial_model(self):
        if self.AM:
            return self.AM
        optimizer = RMSprop(lr=0.0001, decay=3e-8)
        self.AM = Sequential()
        self.AM.add(self.generator())
        self.AM.add(self.discriminator())
        self.AM.compile(loss='binary_crossentropy', \
                        optimizer=optimizer, metrics=['accuracy'])
        return self.AM
```
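As a quick way to confirm that the networks build correctly, you can instantiate the class and build each model once; the summaries are printed by the methods themselves. This snippet is just an illustration, not part of the book's code.

```python
# Build the networks once and look at their summaries
dcgan = DCGAN(img_rows=28, img_cols=28, channel=1)
g = dcgan.generator()             # prints the generator summary
d = dcgan.discriminator_model()   # compiled discriminator (prints its summary)
a = dcgan.adversarial_model()     # compiled generator + discriminator chain
```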

DCGAN class design for MNIST

Next, we use these functions to actually train on the MNIST data and generate images. The train function does the training, and plot_images saves the generated images.

The train function is executed in the following flow.

- **Generate fake training data from noise**
- **Feed the generated data, together with real data, to the discriminative model. How well it could tell them apart at this point is stored as D loss.**
- **Train through adversarial_model so that the generated data is judged to be real. How much the discriminator was fooled at this point is stored as A loss.**

```python
#A class that applies DCGAN to MNIST data
class MNIST_DCGAN(object):
    #Initialization
    def __init__(self, x_train):
        self.img_rows = 28
        self.img_cols = 28
        self.channel = 1

        self.x_train = x_train

        # Define the DCGAN's discriminator, adversarial, and generator models
        self.DCGAN = DCGAN()
        self.discriminator =  self.DCGAN.discriminator_model()
        self.adversarial = self.DCGAN.adversarial_model()
        self.generator = self.DCGAN.generator()

    # Training function
    # train_on_batch trains on a single batch and returns loss and accuracy
    def train(self, train_steps=2000, batch_size=256, save_interval=0):
        noise_input = None

        if save_interval>0:
            noise_input = np.random.uniform(-1.0, 1.0, size=[16, 100])

        for i in range(train_steps):
            # Randomly sample batch_size real images from the training data
            images_train = self.x_train[np.random.randint(0,self.x_train.shape[0], size=batch_size), :, :, :] 
            
            # Generate batch_size noise vectors of length 100 and turn them into fake images
            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])
            
            # Generate fake images and combine them with the real images
            images_fake = self.generator.predict(noise)
            x = np.concatenate((images_train, images_fake))
            # Label real images as 1 and generated images as 0
            y = np.ones([2*batch_size, 1])
            y[batch_size:, :] = 0
            
            #Train the discriminative model
            d_loss = self.discriminator.train_on_batch(x, y)

            y = np.ones([batch_size, 1])
            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100])

            # Train the generator through the combined (generator + discriminator) model
            # This is the only place where the generative model is trained
            a_loss = self.adversarial.train_on_batch(noise, y)

            # Loss and accuracy for the discriminator and the generator
            # D loss: loss and acc of the discriminator on real vs. generated images
            # A loss: loss and acc when the adversarial model tries to get its fakes labeled as 1
            log_mesg = "%d: [D loss: %f, acc: %f]" % (i, d_loss[0], d_loss[1])
            log_mesg = "%s  [A loss: %f, acc: %f]" % (log_mesg, a_loss[0], a_loss[1])
            print(log_mesg)

            # Save generated images every save_interval steps
            if save_interval>0:
                if (i+1)%save_interval==0:
                    self.plot_images(save2file=True, \
                        samples=noise_input.shape[0],\
                        noise=noise_input, step=(i+1))

    #Plot training results
    def plot_images(self, save2file=False, fake=True, samples=16, \
                    noise=None, step=0):
        current_path = os.getcwd()
        file = os.path.sep.join(["","data", 'images', 'chapter12', 'synthetic_mnist', ''])
        filename = 'mnist.png'
        if fake:
            if noise is None:
                noise = np.random.uniform(-1.0, 1.0, size=[samples, 100])
            else:
                filename = "mnist_%d.png " % step
            images = self.generator.predict(noise)
        else:
            i = np.random.randint(0, self.x_train.shape[0], samples)
            images = self.x_train[i, :, :, :]

        plt.figure(figsize=(10,10))
        for i in range(images.shape[0]):
            plt.subplot(4, 4, i+1)
            image = images[i, :, :, :]
            image = np.reshape(image, [self.img_rows, self.img_cols])
            plt.imshow(image, cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        if save2file:
            plt.savefig(current_path+file+filename)
            plt.close('all')
        else:
            plt.show()
```
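Finally, training is kicked off roughly like this. The step count, batch size, and save interval below are just example values; adjust them to taste. Note that plot_images writes its files to ./data/images/chapter12/synthetic_mnist/ under the current directory, so create that folder beforehand if you save to file.

```python
# Train the DCGAN on MNIST and save generated images every save_interval steps
# (train_steps / batch_size / save_interval below are example values, not fixed ones)
mnist_dcgan = MNIST_DCGAN(x_train)
mnist_dcgan.train(train_steps=10000, batch_size=256, save_interval=500)

# Show a grid of generated images, and a grid of real images for comparison
mnist_dcgan.plot_images(fake=True)
mnist_dcgan.plot_images(fake=False)
```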

In closing

GANs are amazing. It's almost unsettling to see something so close to handwritten digits being generated.

Apparently they can also be used for things like anomaly detection in real-world settings.

However, the summary of the reference book warns, **"Be prepared for a great deal of effort when using GANs."** No detailed reason was given for that...

Just how demanding are you, GAN?

Thank you for reading to the end.
