[PYTHON] I tried running a GAN (MNIST) with Keras

Introduction

This time, I would like to summarize what I learned about GANs (Generative Adversarial Networks) from [this book](https://www.amazon.co.jp/%E5%AE%9F%E8%B7%B5GAN-%E6%95%B5%E5%AF%BE%E7%9A%84%E7%94%9F%E6%88%90%E3%83%8D%E3%83%83%E3%83%88%E3%83%AF%E3%83%BC%E3%82%AF%E3%81%AB%E3%82%88%E3%82%8B%E6%B7%B1%E5%B1%A4%E5%AD%A6%E7%BF%92-Compass-Books%E3%82%B7%E3%83%AA%E3%83%BC%E3%82%BA-Jakub-Langr-ebook/dp/B08573Y8GP). I will write it up in several parts and then summarize everything properly at the end. **So this article is very rough.**

In this article, I explain and summarize the simple GAN implementation introduced in the book. I leave the detailed explanation of GANs to other sites and only give an overview here. (If there seems to be demand, I would like to post a more detailed summary later.)

Here is a site that should be helpful for GAN beginners: [GAN: What is a Generative Adversarial Network? Image generation by "unsupervised learning"](https://www.imagazine.co.jp/gan%EF%BC%9A%E6%95%B5%E5%AF%BE%E7%9A%84%E7%94%9F%E6%88%90%E3%83%8D%E3%83%83%E3%83%88%E3%83%AF%E3%83%BC%E3%82%AF%E3%81%A8%E3%81%AF%E4%BD%95%E3%81%8B%E3%80%80%EF%BD%9E%E3%80%8C%E6%95%99%E5%B8%AB/)

What is a GAN?

I will explain briefly.

GAN stands for Generative Adversarial Network (in Japanese, 敵対的生成ネットワーク). It is a type of DNN (deep neural network) and is currently very popular in the field of artificial intelligence.

A GAN learns the characteristics of its training data and generates new data similar to it. The data can be audio, text, images, and so on. For example, if you train it on a large number of cat images, it will output new cat images (provided training goes well).

Algorithmically, a GAN uses two DNNs. One generates images (the generator). The other judges whether an image is real or generated (the discriminator). As the two models compete, the generator learns to output images close to the training images.
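This competition can be written as two binary cross-entropy objectives. The sketch below uses the per-example form of the `binary_crossentropy` loss that the code later in this article compiles into both models.

It assumes the discriminator outputs the probability that its input is real. The inputs `d_real` and `d_fake` are hypothetical example outputs, not values from an actual run, and the helper names are for illustration only.

- The discriminator wants real images labeled 1 and generated images labeled 0.
- The generator is scored as if its fakes were labeled real (1). This is why the code passes the `real` labels to `gan.train_on_batch`.

```python
import math

def bce(p, label):
    """Binary cross-entropy for one predicted probability p in (0, 1) and a 0/1 label."""
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))

def discriminator_loss(d_real, d_fake):
    # Real images should be classified as 1 and generated images as 0;
    # the two halves are averaged, like 0.5 * (d_loss_real + d_loss_fake) in the code.
    return 0.5 * (bce(d_real, 1) + bce(d_fake, 0))

def generator_loss(d_fake):
    # The generator wants the discriminator to call its images real (label 1).
    return bce(d_fake, 1)

# Hypothetical discriminator outputs for one real and one generated image
print(discriminator_loss(0.9, 0.1))  # ≈ 0.105: the discriminator is winning
print(generator_loss(0.1))           # ≈ 2.303: the generator is losing
```

When the generator improves and `d_fake` rises toward 1, `generator_loss` falls while the fake half of `discriminator_loss` rises. That opposing pull is the "competition" described above.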

Result

I will show the result first. These are the generated images:

[Figure: images generated by the GAN]

For comparison, here is the input data:

[Figure: MNIST training images]

This code uses very simple DNNs, so there is still plenty of room for improvement. Even so, I was surprised that such a small model could reproduce the digits this well. I will post an improved version in the next article.
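For a sense of how small these models are, the parameter counts of the two networks in the code can be worked out by hand:

- Each `Dense` layer has `n_in * n_out` weights plus `n_out` biases.
- `LeakyReLU`, `Flatten` and `Reshape` add no parameters.

`dense_params` is a hypothetical helper for this illustration. Keras reports the same counts via `model.summary()`.

```python
def dense_params(n_in, n_out):
    # One weight per input/output pair plus one bias per output unit
    return n_in * n_out + n_out

z_dim = 100
img_size = 28 * 28 * 1  # flattened MNIST image

# Generator: Dense(128) on the noise vector, then Dense(784) producing the image
generator_params = dense_params(z_dim, 128) + dense_params(128, img_size)
# Discriminator: Dense(128) on the flattened image, then Dense(1) producing a probability
discriminator_params = dense_params(img_size, 128) + dense_params(128, 1)

print(generator_params)      # 114064
print(discriminator_params)  # 100609
```

Together that is roughly 215 thousand parameters, which is tiny by deep learning standards.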

Code

Here is the code. The explanation is brief for now. I plan to write a more detailed summary after I finish reading the book, so please wait until then! (Likes are very encouraging.)

simple_gan.py


# %matplotlib inline  (uncomment when running in a Jupyter notebook)
import numpy as np
import matplotlib.pyplot as plt

from keras.datasets import mnist
from keras.layers import Dense, Flatten, Reshape, LeakyReLU
from keras.models import Sequential
from keras.optimizers import Adam

# Define the MNIST image shape: 28 x 28 pixels, 1 channel (grayscale)
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)
# Dimension of the noise vector the generator takes as input to generate an image
z_dim = 100

# Generator: maps a z_dim-dimensional noise vector to a 28x28x1 image in [-1, 1] (tanh)
def build_generator(img_shape, z_dim):
  model = Sequential()
  model.add(Dense(128, input_dim=z_dim))
  model.add(LeakyReLU(alpha=0.01))
  model.add(Dense(28*28*1, activation='tanh'))
  model.add(Reshape(img_shape))
  return model

# Discriminator: outputs the probability that an input image is real
def build_discriminator(img_shape):
  model = Sequential()
  model.add(Flatten(input_shape=img_shape))
  model.add(Dense(128))
  model.add(LeakyReLU(alpha=0.01))
  model.add(Dense(1, activation='sigmoid'))
  return model

# GAN model: chains the generator and the discriminator
def build_gan(generator, discriminator):
  model = Sequential()
  model.add(generator)
  model.add(discriminator)
  return model

# Build and compile the discriminator, then build the generator
discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])
generator = build_generator(img_shape, z_dim)

# Freeze the discriminator inside the combined GAN model, so that training `gan`
# updates only the generator. The discriminator was compiled above while still
# trainable, so it keeps learning when trained directly with train_on_batch.
discriminator.trainable = False

gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam())


losses = []
accuracies = []
iteration_checkpoint = []

# Training loop. Arguments: number of iterations, batch size, and how often
# (in iterations) to record the losses and show sample images
def train(iterations, batch_size, sample_interval):
  (x_train, _), (_, _) = mnist.load_data()

  # Scale pixel values from [0, 255] to [-1, 1] to match the generator's tanh output
  x_train = x_train / 127.5 - 1
  x_train = np.expand_dims(x_train, axis=3)

  # Labels: 1 for real images, 0 for generated (fake) images
  real = np.ones((batch_size, 1))
  fake = np.zeros((batch_size, 1))

  for iteration in range(iterations):

    # Train the discriminator on a random batch of real images and a batch of fakes
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    imgs = x_train[idx]
    z = np.random.normal(0, 1, (batch_size, z_dim))
    gen_imgs = generator.predict(z, verbose=0)

    d_loss_real = discriminator.train_on_batch(imgs, real)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
    d_loss, acc = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Train the generator through the combined model, labeling its fakes as real
    z = np.random.normal(0, 1, (batch_size, z_dim))
    g_loss = gan.train_on_batch(z, real)

    # Every sample_interval iterations, save the losses, accuracy, and checkpoint
    if (iteration + 1) % sample_interval == 0:
      losses.append((d_loss, g_loss))
      accuracies.append(acc)
      iteration_checkpoint.append(iteration + 1)
      # Generate and show sample images
      sample_images(generator)

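# Generate a grid of sample images from random noise and display it
def sample_images(generator, image_grid_rows=4, image_grid_columns=4):
  z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))
  gen_images = generator.predict(z, verbose=0)

  # Rescale from [-1, 1] (tanh output) to [0, 1] for display
  gen_images = 0.5 * gen_images + 0.5

  fig, axs = plt.subplots(image_grid_rows, image_grid_columns, figsize=(4, 4), sharex=True, sharey=True)

  cnt = 0
  for i in range(image_grid_rows):
    for j in range(image_grid_columns):
      axs[i, j].imshow(gen_images[cnt, :, :, 0], cmap='gray')
      axs[i, j].axis('off')
      cnt += 1
  plt.show()

# Start training (example settings; adjust as needed)
iterations = 20000
batch_size = 128
sample_interval = 1000
train(iterations, batch_size, sample_interval)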
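The code rescales pixel values twice:

- In `train()`, `x_train / 127.5 - 1` maps MNIST pixels from [0, 255] to [-1, 1]. The generator's last layer uses `tanh`, whose outputs lie in the same range, so real and fake images are on the same scale.
- In `sample_images()`, `0.5 * gen_images + 0.5` maps generated images back to [0, 1] for `imshow`.

The sketch below checks this round trip on three hypothetical pixel values standing in for real MNIST data. Combining the two steps gives `x / 255`, the usual [0, 1] scaling.

```python
import numpy as np

# Hypothetical stand-in for MNIST pixels: uint8 values in [0, 255]
x = np.array([[0, 127, 255]], dtype=np.uint8)

# Same scaling as train(): [0, 255] -> [-1, 1], matching the tanh output range
x_scaled = x / 127.5 - 1

# Same rescaling as sample_images(): [-1, 1] -> [0, 1] for display
x_display = 0.5 * x_scaled + 0.5

print(x_scaled)
print(x_display)
```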
