[PYTHON] Ich habe versucht, GAN (mnist) mit Keras zu bewegen

Einführung

Dieses Mal ist GAN (Generative Adversarial Network) [dieses Buch](https://www.amazon.co.jp/%E5%AE%9F%E8%B7%B5GAN-%E6%95%B5%E5%AF % BE% E7% 9A% 84% E7% 94% 9F% E6% 88% 90% E3% 83% 8D% E3% 83% 83% E3% 83% 88% E3% 83% AF% E3% 83% BC % E3% 82% AF% E3% 81% AB% E3% 82% 88% E3% 82% 8B% E6% B7% B1% E5% B1% A4% E5% AD% A6% E7% BF% 92-Kompass -Bücher% E3% 82% B7% E3% 83% AA% E3% 83% BC% E3% 82% BA-Jakub-Langr-ebook / dp / B08573Y8GP), daher möchte ich es zusammenfassen. Nachdem ich es in mehreren Teilen geschrieben habe, werde ich es am Ende fest zusammenfassen. ** Dieser Artikel ist also sehr rau. ** **. In diesem Artikel geht es um die Erklärung und Zusammenfassung des in diesem Buch vorgestellten einfachen Implementierungscodes von GAN. Ich werde die detaillierte Erklärung von GAN anderen Sites überlassen und in diesem Artikel nur einen Überblick geben. (Wenn es Nachfrage zu geben scheint, möchte ich später eine Zusammenfassung veröffentlichen.) Ich werde die GAN-Site veröffentlichen, die für GAN-Anfänger hilfreich zu sein scheint. [GAN: Was ist ein feindliches Generationsnetzwerk? - Bilderzeugung durch "Lernen ohne Lehrer"](https://www.imagazine.co.jp/gan%EF%BC%9A%E6%95%B5%E5% AF% BE% E7% 9A% 84% E7% 94% 9F% E6% 88% 90% E3% 83% 8D% E3% 83% 83% E3% 83% 88% E3% 83% AF% E3% 83% BC% E3% 82% AF% E3% 81% A8% E3% 81% AF% E4% BD% 95% E3% 81% 8B% E3% 80% 80% EF% BD% 9E% E3% 80% 8C% E6% 95% 99% E5% B8% AB /)

Was ist GAN?

Ich werde es kurz erklären.

GAN wird auf Japanisch als feindliches Generationsnetzwerk bezeichnet. Es ist eine Variante von DNN und ist heute im Bereich der künstlichen Intelligenz sehr beliebt.

Lernen Sie die Eigenschaften der Eingabedaten kennen und generieren Sie etwas Ähnliches wie die Eingabedaten. Die Daten können Audio, Text, Bilder usw. sein. Wenn Sie beispielsweise eine große Anzahl von Katzenbildern eingeben, werden Katzenbilder ausgegeben (wenn Sie gut gelernt haben).

Als Algorithmus bereiten wir zwei DNNs vor und teilen sie in eine Person, die ein Bild erzeugt, und eine Person, die unterscheidet, ob das Bild ein reales oder ein erzeugtes Bild ist. Durch die Konkurrenz dieser beiden Modelle wird ein Bild in der Nähe des Eingabebildes ausgegeben.

Ergebnis

Ich werde das Ergebnis zuerst veröffentlichen. Dies ist das erzeugte Bild. download.png Auf der anderen Seite sind hier die Eingabedaten. download.png

Dieser Code wurde mit einem sehr einfachen DNN erstellt, sodass noch Verbesserungspotenzial besteht. Trotzdem war ich überrascht, dass selbst eine kleine Anzahl von Modellen ihn bisher ausdrücken konnte. Ich werde eine verbesserte Version im nächsten Artikel veröffentlichen.

Code

Es wird der Code sein. Ich werde wahrscheinlich gebeten, dies genauer zu erklären. Ich denke darüber nach, es wieder zusammenzusetzen, nachdem ich das Buch gelesen habe, also warte bitte bis dahin! (Wenn es Ihnen gefällt, wird es sehr ermutigend sein)

simple_gan.py


%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

from keras.datasets import mnist
from keras.layers import Dense, Flatten, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential
from keras.optimizers import Adam

#Form des Mnisten[28, 28, 1]Definieren
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)
#Die Dimension des Rauschens, die der Generator eingibt, um das Bild zu erzeugen
z_dim = 100

#generator(Funktion zur Definition des Generators)
def build_generator(img_shape, z_dim):
  model = Sequential()
  model.add(Dense(128, input_dim=z_dim))
  model.add(LeakyReLU(alpha=0.01))
  model.add(Dense(28*28*1, activation='tanh'))
  model.add(Reshape(img_shape))
  return model

#Funktionen zum Definieren des Diskriminators
def build_discriminatior(img_shape):
  model = Sequential()
  model.add(Flatten(input_shape=img_shape))
  model.add(Dense(128))
  model.add(LeakyReLU(alpha=0.01))
  model.add(Dense(1, activation='sigmoid'))
  return model

#Gan-Modelldefinition(Funktion zum Verbinden von Generator und Klassifikator)
def build_gan(generator, discriminator):
  model = Sequential()
  model.add(generator)
  model.add(discriminator)
  return model

#Ich werde die Funktion tatsächlich aufrufen und das GAN-Modell kompilieren
discriminator = build_discriminatior(img_shape)
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])
generator = build_generator(img_shape, z_dim)

#Ich werde die Lernfunktion des Klassifikators ausschalten. Auf diese Weise können der Diskriminator und der Generator getrennt trainiert werden.
discriminator.trainable = False 

gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam())


losses = []
accuracies = []
iteration_checkpoint = []
#Eine Funktion, mit der Sie lernen können. Nehmen Sie die Anzahl der Iterationen, die Stapelgröße und die Anzahl der Iterationen, um das Bild als Argumente zu generieren und zu visualisieren
def train(iterations, batch_size, sample_interval):
  (x_train, _), (_, _) = mnist.load_data()

  x_train = x_train / 127.5 - 1
  x_train = np.expand_dims(x_train, axis=3)
  
  real = np.ones((batch_size, 1))
  fake = np.zeros((batch_size, 1))

  for iteration in range(iterations):

    idx = np.random.randint(0, x_train.shape[0], batch_size)
    imgs = x_train[idx]
    z = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(z)

    d_loss_real = discriminator.train_on_batch(imgs, real)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
    d_loss, acc = 0.5 * np.add(d_loss_real, d_loss_fake)

    z = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(z)

    g_loss = gan.train_on_batch(z, real)
#sample_Speichern Sie den Verlustwert, die Genauigkeit und den Prüfpunkt für jedes Intervall
    if (iteration+1) % sample_interval == 0:
      losses.append((d_loss, g_loss))
      accuracies.append(acc)
      iteration_checkpoint.append(iteration+1)
#Bild erzeugen
      sample_images(generator)

#Funktion zum Generieren eines Bildes als Beispiel
def sample_images(generator, image_grid_rows =4, image_grid_colmuns=4):
  z = np.random.normal(0, 1, (image_grid_rows*image_grid_colmuns, z_dim))
  gen_images = generator.predict(z)

  gen_images = 0.5 * gen_images + 0.5

  fig, axs = plt.subplots(image_grid_rows, image_grid_colmuns, figsize=(4,4), sharex=True, sharey=True)

  cnt = 0
  for i in range(image_grid_rows):
    for j in range(image_grid_colmuns):
      axs[i, j].imshow(gen_images[cnt, :, :, 0], cmap='gray')
      axs[i, j].axis('off')
      cnt += 1

Recommended Posts

Ich habe versucht, GAN (mnist) mit Keras zu bewegen
Ich habe versucht, Keras in TFv1.1 zu integrieren
Ich habe versucht, maschinelles Lernen (Objekterkennung) mit TouchDesigner zu verschieben
Ich habe versucht, Grad-CAM mit Keras und Tensorflow zu implementieren
Ich habe versucht, Realness GAN zu implementieren
Ich habe versucht, den Ball zu bewegen
Ich habe versucht, MNIST nach GNN zu klassifizieren (mit PyTorch-Geometrie).
Ich habe versucht, Autoencoder mit TensorFlow zu implementieren
Ich habe versucht, AutoEncoder mit TensorFlow zu visualisieren
Ich habe versucht, mit Hy anzufangen
Ich habe versucht, CVAE mit PyTorch zu implementieren
Ich habe versucht, TSP mit QAOA zu lösen
Ich habe versucht, Deep Learning mit Spark × Keras × Docker skalierbar zu machen
Ich habe versucht, nächstes Jahr mit AI vorherzusagen
Ich habe versucht, das Lesen von Dataset mit PyTorch zu implementieren
Ich habe versucht, lightGBM, xg Boost mit Boruta zu verwenden
Ich habe versucht, die Daten mit Zwietracht zu speichern
Ich habe versucht, mit OpenCV Bewegungen schnell zu erkennen
Ich habe versucht, CloudWatch-Daten mit Python abzurufen
Ich habe versucht, LLVM IR mit Python auszugeben
Ich habe versucht zu debuggen.
Ich habe versucht, ein Objekt mit M2Det zu erkennen!
Ich habe versucht, die Herstellung von Sushi mit Python zu automatisieren
Ich habe versucht, das Überleben der Titanic mit PyCaret vorherzusagen
Ich habe versucht, Linux mit Discord Bot zu betreiben
Ich habe versucht, DP mit Fibonacci-Sequenz zu studieren
Ich habe versucht, Jupyter mit allen Amazon-Lichtern zu starten
Ich habe versucht, Tundele mit Naive Bays zu beurteilen
Ich habe versucht, Deep Learning mit Spark × Keras × Docker 2 Multi-Host-Edition skalierbar zu machen
Ich habe versucht, die Sündenfunktion mit Chainer zu trainieren
Ich habe versucht, mit VOICEROID2 2 automatisch zu lesen und zu speichern
Ich habe versucht, DCGAN mit PyTorch zu implementieren und zu lernen
Ich habe versucht, Mine Sweeper auf dem Terminal mit Python zu implementieren
Ich habe versucht, mit Blenders Python script_Part 01 zu beginnen
Ich habe versucht, eine CSV-Datei mit Python zu berühren
Ich habe versucht, Soma Cube mit Python zu lösen
Ich habe versucht, mit VOICEROID2 automatisch zu lesen und zu speichern
Ich habe versucht, mit Blenders Python script_Part 02 zu beginnen
Ich habe versucht, ObjectId (Primärschlüssel) mit Pymongo zu generieren
Ich habe versucht, künstliches Perzeptron mit Python zu implementieren
Ich habe versucht, eine ML-Pipeline mit Cloud Composer zu erstellen
Ich habe versucht, unsere Dunkelheit mit der Chatwork-API aufzudecken
[Einführung in Pytorch] Ich habe versucht, Cifar10 mit VGG16 ♬ zu kategorisieren
Ich habe versucht, das Problem mit Python Vol.1 zu lösen
Ich habe versucht, eine OCR-App mit PySimpleGUI zu erstellen
Ich habe versucht, SSD jetzt mit PyTorch zu implementieren (Dataset)
Ich habe versucht, Mask R-CNN mit Optical Flow zu interpolieren
Ich habe versucht, die Bayes'sche Optimierung zu durchlaufen. (Mit Beispielen)
Ich habe versucht, die alternative Klasse mit Tensorflow zu finden
[Einführung in AWS] Ich habe versucht, mit der Sprach-Text-Konvertierung zu spielen ♪
Ich habe versucht, AOJs Integer-Theorie mit Python zu lösen
Ich habe fp-Wachstum mit Python versucht
Ich habe versucht, mit Python zu kratzen
Ich habe versucht, PredNet zu lernen