This time I read [this book](https://www.amazon.co.jp/%E5%AE%9F%E8%B7%B5GAN-%E6%95%B5%E5%AF%BE%E7%9A%84%E7%94%9F%E6%88%90%E3%83%8D%E3%83%83%E3%83%88%E3%83%AF%E3%83%BC%E3%82%AF%E3%81%AB%E3%82%88%E3%82%8B%E6%B7%B1%E5%B1%A4%E5%AD%A6%E7%BF%92-Compass-Books%E3%82%B7%E3%83%AA%E3%83%BC%E3%82%BA-Jakub-Langr-ebook/dp/B08573Y8GP) on GANs (Generative Adversarial Networks), so I would like to summarize it. I plan to write it up in several parts and then put together a proper summary at the end. **So this article is quite rough.**

In this article, I explain and summarize the simple GAN implementation code introduced in the book. I will leave the detailed explanation of GANs to other sites and only give an overview here. (If there seems to be demand, I would like to post a detailed summary later.) Here is a site that should be helpful for GAN beginners: [GAN: What is an adversarial generative network? - Image generation by "unsupervised learning"](https://www.imagazine.co.jp/gan%EF%BC%9A%E6%95%B5%E5%AF%BE%E7%9A%84%E7%94%9F%E6%88%90%E3%83%8D%E3%83%83%E3%83%88%E3%83%AF%E3%83%BC%E3%82%AF%E3%81%A8%E3%81%AF%E4%BD%95%E3%81%8B%E3%80%80%EF%BD%9E%E3%80%8C%E6%95%99%E5%B8%AB/)
Let me explain briefly.

GAN stands for Generative Adversarial Network (敵対的生成ネットワーク in Japanese). It is a variant of the DNN and is currently very popular in the field of artificial intelligence.

A GAN learns the characteristics of the input data and generates something that resembles it. The data can be audio, text, images, and so on. For example, if you feed in a large number of cat images, the output will be cat images (assuming training goes well).

As an algorithm, we prepare two DNNs: one that generates images (the generator) and one that judges whether an image is a real input image or a generated one (the discriminator). By making these two models compete, the GAN comes to output images close to the input images.
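For reference, this competition is often written as the following minimax objective. This formula is not from the book excerpt in this article but is the standard formulation from the original GAN paper (Goodfellow et al.): the discriminator $D$ tries to maximize it, while the generator $G$ tries to minimize it.

$$
\min_G \max_D \; \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
$$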
I will show the results first. This is the generated image, and for comparison, here is the input data.
This code uses a very simple DNN, so there is still room for improvement, but even so, I was surprised that such a small model could express this much. I will post an improved version in the next article.
Here is the code. Some of you will probably want a more detailed explanation; I am thinking of putting one together after I have finished reading the book, so please wait until then! (Likes are very encouraging.)
simple_gan.py
```python
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt

from keras.datasets import mnist
from keras.layers import Dense, Flatten, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential
from keras.optimizers import Adam

# Define the MNIST image shape: (28, 28, 1)
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)

# Dimension of the noise vector the generator takes as input to generate an image
z_dim = 100


# Function that builds the generator
def build_generator(img_shape, z_dim):
    model = Sequential()
    model.add(Dense(128, input_dim=z_dim))
    model.add(LeakyReLU(alpha=0.01))
    model.add(Dense(28 * 28 * 1, activation='tanh'))
    model.add(Reshape(img_shape))
    return model


# Function that builds the discriminator
def build_discriminator(img_shape):
    model = Sequential()
    model.add(Flatten(input_shape=img_shape))
    model.add(Dense(128))
    model.add(LeakyReLU(alpha=0.01))
    model.add(Dense(1, activation='sigmoid'))
    return model


# Function that builds the GAN model by connecting the generator and the discriminator
def build_gan(generator, discriminator):
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    return model


# Actually call the functions and compile the models
discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(),
                      metrics=['accuracy'])

generator = build_generator(img_shape, z_dim)

# Freeze the discriminator inside the combined model.
# This way the discriminator and the generator can be trained separately.
discriminator.trainable = False

gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam())

losses = []
accuracies = []
iteration_checkpoint = []


# Training function. Arguments: number of iterations, batch size,
# and how many iterations between generating and visualizing sample images.
def train(iterations, batch_size, sample_interval):
    (x_train, _), (_, _) = mnist.load_data()

    # Rescale pixel values from [0, 255] to [-1, 1]
    x_train = x_train / 127.5 - 1.0
    x_train = np.expand_dims(x_train, axis=3)

    # Labels: 1 for real images, 0 for generated (fake) images
    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for iteration in range(iterations):
        # Train the discriminator on a batch of real images and a batch of fake images
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        imgs = x_train[idx]

        z = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(z)

        d_loss_real = discriminator.train_on_batch(imgs, real)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss, acc = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train the generator through the combined model, with the labels set to "real"
        z = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(z)
        g_loss = gan.train_on_batch(z, real)

        # Save the loss values, accuracy, and checkpoint every sample_interval iterations
        if (iteration + 1) % sample_interval == 0:
            losses.append((d_loss, g_loss))
            accuracies.append(acc)
            iteration_checkpoint.append(iteration + 1)

            # Generate sample images
            sample_images(generator)


# Function that generates a grid of sample images
def sample_images(generator, image_grid_rows=4, image_grid_columns=4):
    z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))
    gen_images = generator.predict(z)

    # Rescale pixel values from [-1, 1] back to [0, 1] for display
    gen_images = 0.5 * gen_images + 0.5

    fig, axs = plt.subplots(image_grid_rows,
                            image_grid_columns,
                            figsize=(4, 4),
                            sharex=True,
                            sharey=True)

    cnt = 0
    for i in range(image_grid_rows):
        for j in range(image_grid_columns):
            axs[i, j].imshow(gen_images[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1
```
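The listing above defines `train()` but does not include the call that actually starts training. As a minimal sketch, it could be invoked like this (the hyperparameter values below are my own assumptions, not values given in this article; adjust them as needed):

```python
# Hypothetical hyperparameters -- not specified in this article, chosen only for illustration.
iterations = 20000      # total number of training iterations
batch_size = 128        # mini-batch size for both networks
sample_interval = 1000  # log metrics and show sample images every 1000 iterations

train(iterations, batch_size, sample_interval)
```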