[PYTHON] J'ai essayé de déplacer GAN (mnist) avec keras

introduction

Cette fois, GAN (Generative Adversarial Network) est [ce livre](https://www.amazon.co.jp/%E5%AE%9F%E8%B7%B5GAN-%E6%95%B5%E5%AF % BE% E7% 9A% 84% E7% 94% 9F% E6% 88% 90% E3% 83% 8D% E3% 83% 83% E3% 83% 88% E3% 83% AF% E3% 83% BC % E3% 82% AF% E3% 81% AB% E3% 82% 88% E3% 82% 8B% E6% B7% B1% E5% B1% A4% E5% AD% A6% E7% BF% 92-Compas -Books% E3% 82% B7% E3% 83% AA% E3% 83% BC% E3% 82% BA-Jakub-Langr-ebook / dp / B08573Y8GP) je voudrais donc le résumer. Après l'avoir rédigé en plusieurs parties, je le résumerai fermement à la fin. ** Cet article est donc très approximatif. ** ** Cet article concerne l'explication et le résumé du code d'implémentation simple du GAN présenté dans ce livre. Je laisserai l'explication détaillée du GAN à d'autres sites, et dans cet article je ne donnerai qu'un aperçu. (S'il semble y avoir une demande, j'aimerais publier un résumé plus tard.) Je publierai le site GAN qui semble être utile pour les débutants du GAN. [GAN: Qu'est-ce qu'un réseau de génération hostile? -Génération d'images par "l'apprentissage sans enseignant"](https://www.imagazine.co.jp/gan%EF%BC%9A%E6%95%B5%E5% AF% BE% E7% 9A% 84% E7% 94% 9F% E6% 88% 90% E3% 83% 8D% E3% 83% 83% E3% 83% 88% E3% 83% AF% E3% 83% BC% E3% 82% AF% E3% 81% A8% E3% 81% AF% E4% BD% 95% E3% 81% 8B% E3% 80% 80% EF% BD% 9E% E3% 80% 8C% E6% 95% 99% E5% B8% AB /)

Qu'est-ce que le GAN

Je vais vous expliquer brièvement.

GAN est appelé un réseau de génération hostile en japonais. C'est une variante du DNN et est maintenant très populaire dans le domaine de l'intelligence artificielle.

Apprenez les caractéristiques des données d'entrée et générez quelque chose de similaire aux données d'entrée. Les données peuvent être audio, texte, images, etc. Par exemple, si vous entrez une grande quantité d'images de chat, la sortie sera des images de chat (si vous avez bien appris).

En tant qu'algorithme, nous préparerons deux DNN et les diviserons en une personne qui génère une image et une personne qui distingue si l'image est une image réelle ou une image générée. En concurrençant ces deux modèles, une image proche de l'image d'entrée est sortie.

résultat

Je publierai le résultat en premier. Ceci est l'image générée. download.png En revanche, voici les données d'entrée. download.png

Ce code a été fait avec un DNN très simple, il y a donc encore place à l'amélioration, mais même ainsi, j'ai été surpris que même un petit nombre de modeld puisse l'exprimer jusqu'à présent. Je publierai une version améliorée dans le prochain article.

code

Ce sera le code. On me demandera probablement d'expliquer plus en détail. Je pense à le remonter après avoir lu le livre, alors attendez d'ici là! (Si vous l'aimez, ce sera très encourageant)

simple_gan.py


%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

from keras.datasets import mnist
from keras.layers import Dense, Flatten, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential
from keras.optimizers import Adam

#Forme de mnist[28, 28, 1]Définir
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)
#La dimension du bruit que le générateur entrera pour générer l'image
z_dim = 100

#generator(Fonction de définition du générateur)
def build_generator(img_shape, z_dim):
  model = Sequential()
  model.add(Dense(128, input_dim=z_dim))
  model.add(LeakyReLU(alpha=0.01))
  model.add(Dense(28*28*1, activation='tanh'))
  model.add(Reshape(img_shape))
  return model

#Fonctions de définition du discriminateur
def build_discriminatior(img_shape):
  model = Sequential()
  model.add(Flatten(input_shape=img_shape))
  model.add(Dense(128))
  model.add(LeakyReLU(alpha=0.01))
  model.add(Dense(1, activation='sigmoid'))
  return model

#Définition du modèle Gan(Fonction de connexion du générateur et du classificateur)
def build_gan(generator, discriminator):
  model = Sequential()
  model.add(generator)
  model.add(discriminator)
  return model

#Je vais en fait appeler la fonction et compiler le modèle GAN
discriminator = build_discriminatior(img_shape)
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])
generator = build_generator(img_shape, z_dim)

#Je vais désactiver la fonction d'apprentissage du classificateur. En faisant cela, le discriminateur et le générateur peuvent être formés séparément.
discriminator.trainable = False 

gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam())


losses = []
accuracies = []
iteration_checkpoint = []
#Une fonction pour vous permettre d'apprendre. Prenez le nombre d'itérations, la taille du lot et le nombre d'itérations à générer et visualiser l'image sous forme d'arguments
def train(iterations, batch_size, sample_interval):
  (x_train, _), (_, _) = mnist.load_data()

  x_train = x_train / 127.5 - 1
  x_train = np.expand_dims(x_train, axis=3)
  
  real = np.ones((batch_size, 1))
  fake = np.zeros((batch_size, 1))

  for iteration in range(iterations):

    idx = np.random.randint(0, x_train.shape[0], batch_size)
    imgs = x_train[idx]
    z = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(z)

    d_loss_real = discriminator.train_on_batch(imgs, real)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
    d_loss, acc = 0.5 * np.add(d_loss_real, d_loss_fake)

    z = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(z)

    g_loss = gan.train_on_batch(z, real)
#sample_Enregistrer la valeur de la perte, la précision et le point de contrôle pour chaque intervalle
    if (iteration+1) % sample_interval == 0:
      losses.append((d_loss, g_loss))
      accuracies.append(acc)
      iteration_checkpoint.append(iteration+1)
#Générer une image
      sample_images(generator)

#Fonction pour générer une image comme échantillon
def sample_images(generator, image_grid_rows =4, image_grid_colmuns=4):
  z = np.random.normal(0, 1, (image_grid_rows*image_grid_colmuns, z_dim))
  gen_images = generator.predict(z)

  gen_images = 0.5 * gen_images + 0.5

  fig, axs = plt.subplots(image_grid_rows, image_grid_colmuns, figsize=(4,4), sharex=True, sharey=True)

  cnt = 0
  for i in range(image_grid_rows):
    for j in range(image_grid_colmuns):
      axs[i, j].imshow(gen_images[cnt, :, :, 0], cmap='gray')
      axs[i, j].axis('off')
      cnt += 1

Recommended Posts

J'ai essayé de déplacer GAN (mnist) avec keras
J'ai essayé d'intégrer Keras dans TFv1.1
J'ai essayé de déplacer l'apprentissage automatique (détection d'objet) avec TouchDesigner
J'ai essayé d'implémenter Grad-CAM avec keras et tensorflow
J'ai essayé d'implémenter Realness GAN
J'ai essayé de déplacer le ballon
J'ai essayé de classer MNIST par GNN (avec PyTorch géométrique)
J'ai essayé d'implémenter Autoencoder avec TensorFlow
J'ai essayé de visualiser AutoEncoder avec TensorFlow
J'ai essayé de commencer avec Hy
J'ai essayé d'implémenter CVAE avec PyTorch
J'ai essayé de résoudre TSP avec QAOA
J'ai essayé de rendre le deep learning évolutif avec Spark × Keras × Docker
J'ai essayé de prédire l'année prochaine avec l'IA
J'ai essayé d'implémenter la lecture de Dataset avec PyTorch
J'ai essayé d'utiliser lightGBM, xg boost avec Boruta
J'ai essayé de sauvegarder les données avec discorde
J'ai essayé de détecter rapidement un mouvement avec OpenCV
J'ai essayé d'obtenir des données CloudWatch avec Python
J'ai essayé de sortir LLVM IR avec Python
J'ai essayé de déboguer.
J'ai essayé de détecter un objet avec M2Det!
J'ai essayé d'automatiser la fabrication des sushis avec python
J'ai essayé de prédire la survie du Titanic avec PyCaret
J'ai essayé d'utiliser Linux avec Discord Bot
J'ai essayé d'étudier DP avec séquence de Fibonacci
J'ai essayé de démarrer Jupyter avec toutes les lumières d'Amazon
J'ai essayé de juger Tundele avec Naive Bays
J'ai essayé de rendre le deep learning évolutif avec Spark × Keras × Docker 2 Multi-host edition
J'ai essayé d'entraîner la fonction péché avec chainer
J'ai essayé de lire et d'enregistrer automatiquement avec VOICEROID2 2
J'ai essayé d'implémenter et d'apprendre DCGAN avec PyTorch
J'ai essayé d'implémenter Mine Sweeper sur un terminal avec python
J'ai essayé de démarrer avec le script python de blender_Part 01
J'ai essayé de toucher un fichier CSV avec Python
J'ai essayé de résoudre Soma Cube avec python
J'ai essayé de lire et d'enregistrer automatiquement avec VOICEROID2
J'ai essayé de démarrer avec le script python de blender_Partie 02
J'ai essayé de générer ObjectId (clé primaire) avec pymongo
J'ai essayé d'implémenter le perceptron artificiel avec python
J'ai essayé de créer un pipeline ML avec Cloud Composer
J'ai essayé de découvrir notre obscurité avec l'API Chatwork
[Introduction à Pytorch] J'ai essayé de catégoriser Cifar10 avec VGG16 ♬
J'ai essayé de résoudre le problème avec Python Vol.1
J'ai essayé de créer une application OCR avec PySimpleGUI
J'ai essayé d'implémenter SSD avec PyTorch maintenant (Dataset)
J'ai essayé d'interpoler le masque R-CNN avec un flux optique
J'ai essayé de passer par l'optimisation bayésienne. (Avec des exemples)
J'ai essayé de trouver la classe alternative avec tensorflow
[Introduction à AWS] J'ai essayé de jouer avec la conversion voix-texte ♪
J'ai essayé de résoudre la théorie des nombres entiers d'AOJ avec Python
J'ai essayé fp-growth avec python
J'ai essayé de gratter avec Python
J'ai essayé d'apprendre PredNet