[PYTHON] Produisez de belles vaches de mer par apprentissage profond

introduction

Ceci est un mémo de l'enregistrement lorsque DCGAN, qui est un dérivé de GAN, a été effectué à l'aide de Tensorflow. Je vais vous expliquer grossièrement sans aller trop loin. J'ai écrit presque le même article l'autre jour, mais il a été foiré, alors je vais le réorganiser un peu.

Je veux générer une vache de mer dans le titre! Cela dit, mais au début, je pensais générer des Pokémon avec DCGAN. Donc, pour le moment, je pense que je vais l'écrire brièvement à partir de la tentative de génération de Pokémon.

Au fait, Umiushi est une telle créature. Il existe de nombreux types colorés et ils sont beaux image.pngimage.pngimage.png

Que sont GAN et DCGAN?

J'écrirai brièvement sur GAN. GAN est un gars qui forme deux choses, ** un "Générateur" qui crée un faux ** et un "Discriminateur" ** qui distingue **, et génère des données aussi proches que possible de la réalité. Generator crée une nouvelle image à partir de bruit aléatoire en référence à des données réelles. Discriminator discrimine l'image générée par Generator comme "fausse ou authentique". Generator et Discriminator sont de bons rivaux. En répétant cela encore et encore, le générateur et le discriminateur deviendront plus intelligents et plus intelligents. En conséquence, des images proches des données réelles seront générées.

↓ ça ressemble à ça image.png

↓ Ce que j'ai écrit plus facilement image.png

C'est le mécanisme de base du GAN. DCGAN est celui qui utilise CNN (réseau de neurones convolutifs) pour ce GAN. CNN est compliqué de diverses manières, mais pour le dire simplement, il est possible de partager des poids entre les réseaux de neurones en faisant du réseau de neurones une structure multicouche utilisant deux couches, ** couche de convolution ** et ** couche de pooling **. Sera possible. En conséquence, DCGAN peut effectuer un apprentissage plus efficace et plus précis que le GAN.

J'utiliserai ce DCGAN pour générer des Pokémon et Umiushi. Aussi, l'explication de GAN et DCGAN GAN (1) Comprendre la structure de base que je n'entends plus GAN que je n'entends plus (2) Génération d'images par DCGAN Est facile à comprendre.

Pokémon

Il existe de nombreux types de Pokémon, et je l'ai choisi comme thème parce que je pensais que ce serait amusant avec un thème familier. Il existe tellement de types de Pokémon aujourd'hui. L'image Pokemon est [ici](https://kamigame.jp/%E3%83%9D%E3%82%B1%E3%83%A2%E3%83%B3USUM/%E3%83%9D%E3%82 % B1% E3% 83% A2% E3% 83% B3 /% E8% 89% B2% E9% 81% 95% E3% 81% 84% E3% 83% 9D% E3% 82% B1% E3% 83% Téléchargé depuis A2% E3% 83% B3% E4% B8% 80% E8% A6% A7.html).

D'ailleurs, cette fois, j'ai utilisé l'extension Chrome ** "Image Downloader" ** pour collecter des images Pokémon. Il est recommandé car il peut être utilisé facilement sans écrire de code. Je pensais que le nombre de données était trop petit, alors j'ai ajouté la rotation et l'inversion avec le code suivant et l'ai gonflé. À propos, il est enregistré au format `` .npy '' pour une lecture facile.

import os,glob
import numpy as np
from tqdm import tqdm
from keras.preprocessing.image import load_img,img_to_array
from keras.utils import np_utils
from sklearn import model_selection
from PIL import Image

#Stocker les classes dans un tableau
classes = ["class1", "class2"]

num_classes = len(classes)
img_size = 128
color=False

#Chargement des images
#Enfin l'image et l'étiquette sont stockées dans la liste

temp_img_array_list=[]
temp_index_array_list=[]
for index,classlabel in enumerate(classes):
    photos_dir = "./" + classlabel
    #Obtenez une liste d'images pour chaque classe avec glob
    img_list = glob.glob(photos_dir + "/*.jpg ")
    for img in tqdm(img_list):
        temp_img=load_img(img,grayscale=color,target_size=(img_size, img_size))
        temp_img_array=img_to_array(temp_img)
        temp_img_array_list.append(temp_img_array)
        temp_index_array_list.append(index)
        #Traitement de rotation
        for angle in range(-20,20,5):
            #rotation
            img_r = temp_img.rotate(angle)
            data = np.asarray(img_r)
            temp_img_array_list.append(data)
            temp_index_array_list.append(index)
            #Inverser
            img_trans = img_r.transpose(Image.FLIP_LEFT_RIGHT)
            data = np.asarray(img_trans)
            temp_img_array_list.append(data)
            temp_index_array_list.append(index)

            X=np.array(temp_img_array_list)
            Y=np.array(temp_index_array_list)

np.save("./img_128RGB.npy", X)
np.save("./index_128RGB.npy", Y)

Je voulais faire un Pokémon chimérique en mélangeant des Pokémon pleins de DCGAN image.pngimage.pngimage.pngimage.pngimage.png

character_kimera_chimaira.png

Mais ce que j'ai fait était image.png

Il était clairement surentraîné, comme vous pouvez le voir à la fois sur l'image générée et sur la perte. Le discriminateur est incroyablement fort. Alors ensuite, j'ai pensé à la cause et l'ai résolue.

Causes du surapprentissage

Est-ce que Pokémon est difficile?

――Pokemon a différentes couleurs et formes, il est donc facile de générer des gars chaotiques? «Je veux utiliser quelque chose qui a une forme unifiée dans une certaine mesure. Passez maintenant de la génération Pokémon à la ** génération des vaches de mer ** ――Cependant, la couleur et la forme d'Umiushi ne sont pas si unifiées, donc je pense que c'est un sujet délicat. Mais je vais vous dire que faire quelque chose que vous aimez vous garde motivé.

Le nombre de données est petit

――Nous avons collecté plus de 500 images d'Umiushi à partir d'images Pokemon. La rotation (-20 ° ~ 20 °) et l'inversion augmenteront probablement 16 fois, donc la quantité de données a augmenté de ** "500 x 16 = 8000" **. --Les images ont été collectées par ** Flickr ** et ** icrawler **. ――Je vais vous expliquer comment utiliser Flickr en gros. Accédez au site API Flickr où ** API key ** est écrit. 名称未設定ファイル (1).png Si vous obtenez un compte Yahoo ici et que vous vous connectez, cet écran apparaîtra, alors récupérez la clé d'ici. (Il est peint en noir) 名称未設定ファイル (2).png Utilisez cette touche pour obtenir l'image avec le code ci-dessous

from flickrapi import FlickrAPI
from urllib.request import urlretrieve
from pprint import pprint
import os, time, sys

#Informations sur la clé AP I
key = "********"
secret = "********"
wait_time = 1

#Spécifiez le dossier de sauvegarde
savedir = "./gazou"

flickr = FlickrAPI(key, secret, format="parsed-json")
result = flickr.photos.search(
        per_page = 100,
        tags = "seaslug",
        media = "photos",
        sort = "relevance",
        safe_search = 1,
        extras = "url_q, licence"
)

photos = result["photos"]

#Stocker des informations sur la photo par traitement en boucle
for i, photo in enumerate(photos['photo']):
    url_q = photo["url_q"]
    filepath = savedir + "/" + photo["id"] + ".jpg "
    if os.path.exists(filepath): continue
    urlretrieve(url_q, filepath)
    time.sleep(wait_time)

Cela collectera des données, mais je voulais plus, donc je vais collecter des images avec ** icrawler **. C'est incroyablement facile à utiliser.

$ pip install icrawler
from icrawler.builtin import GoogleImageCrawler

crawler = GoogleImageCrawler(storage={"root_dir": "gazou"})
crawler.crawl(keyword="Vache de mer", max_num=100)

Cela seul enregistrera l'image d'oursin dans le dossier spécifié. f4b244b3be30f5fed4837d57fb64219c.jpg Comme Pokemon, cette image a été gonflée en la tournant et en la retournant.

Pas de décrochage

--Pour expliquer brièvement le décrochage, il empêche le surapprentissage en ignorant le ratio défini de nœuds. --Pour plus de détails, cet article semble être bon. ――Ce qui suit est le discriminateur réel avec la suppression appliquée.

def discriminator(x, reuse=False, alpha=0.2):
    with tf.variable_scope("discriminator", reuse=reuse):
        x1 = tf.layers.conv2d(x, 32, 5, strides=2, padding="same")
        x1 = tf.maximum(alpha * x1, x1)
        x1_drop = tf.nn.dropout(x1, 0.5)
        
        x2 = tf.layers.conv2d(x1_drop, 64, 5, strides=2, padding="same")
        x2 = tf.layers.batch_normalization(x2, training=True)
        x2 = tf.maximum(alpha * x2, x2)
        x2_drop = tf.nn.dropout(x2, 0.5)
        
        x3 = tf.layers.conv2d(x2_drop, 128, 5, strides=2, padding="same")
        x3 = tf.layers.batch_normalization(x3, training=True)
        x3 = tf.maximum(alpha * x3, x3)
        x3_drop = tf.nn.dropout(x3, 0.5)
        
        x4 = tf.layers.conv2d(x3_drop, 256, 5, strides=2, padding="same")
        x4 = tf.layers.batch_normalization(x4, training=True)
        x4 = tf.maximum(alpha * x4, x4)
        x4_drop = tf.nn.dropout(x4, 0.5)
        
        x5 = tf.layers.conv2d(x4_drop, 512, 5, strides=2, padding="same")
        x5 = tf.layers.batch_normalization(x5, training=True)
        x5 = tf.maximum(alpha * x5, x5)
        x5_drop = tf.nn.dropout(x5, 0.5)
        
        flat = tf.reshape(x5_drop, (-1, 4*4*512))
        logits = tf.layers.dense(flat, 1)
        logits_drop = tf.nn.dropout(logits, 0.5)
        out = tf.sigmoid(logits_drop)
        
        return out, logits

Taux d'apprentissage élevé?

――Si le taux d'apprentissage est élevé, la formation se déroulera rapidement, mais elle divergera facilement et il sera difficile d'apprendre. ――Lorsque j'ai vérifié avec différentes valeurs commençant par 1e-2, 1e-4 est-il juste? C'était comme ça. Dans mon cas, l'apprentissage est devenu trop lent à 1e-5.

Trop de données d'entraînement?

―― Au départ, c'était environ 8: 2, mais il a été changé en 6: 4. Je ne pouvais pas vraiment ressentir l'effet

Vache de mer (résultat d'amélioration de Pokemon)

100epoch ダウンロード (4).png

200epoch ダウンロード (8).png

300epoch ダウンロード (10).png

400epoch ダウンロード (6).png

500epoch ダウンロード (3).png

«Pour le moment, je l'ai tourné autour de 500 époques. En le regardant de loin, je sens que des vaches marines sont produites. ――Mais honnêtement, le résultat est subtil ... ―― Les facteurs possibles sont "N'y a-t-il pas assez d'époque?" "L'image contient-elle trop d'éléments supplémentaires (fond rocheux, etc.)?" "La couche est-elle trop profonde?" Différentes choses peuvent être envisagées, telles que "Est-ce que c'est bon?" «Je voulais l'améliorer davantage et le tourner un peu plus, mais il fonctionne sur ** Google Colaboratory **, et c'est assez difficile en raison du temps de connexion.

Colaboratory Colaboratory est un environnement de notebook Jupyter qui s'exécute sur le cloud fourni par Google, et vous pouvez utiliser un GPU d'environ 800 000 yens. De plus, il n'est pas nécessaire de créer un environnement ou de postuler à Datalab. Plus gratuit. C'est incroyablement pratique, mais il comporte les restrictions suivantes.

――Si vous disposez d'une connexion GPU pendant un certain temps par jour (récemment environ 4 heures [500 époques]), vous ne pourrez pas l'utiliser ce jour-là. (Cela est dû au manque de ressources GPU dans Colaboratory, il n'y a donc pas de solution de contournement et il n'y a pas d'autre choix que d'attendre. Il semble que les GPU sont préférentiellement attribués aux utilisateurs qui ne l'utilisent pas constamment.)

#Tout d'abord, démarrez Hyperdash, une application pour smartphone, et créez un compte.
#Installation Hyperdash

!pip install hyperdash
from hyperdash import monitor_cell
!hyperdash login --email

Il vous sera demandé votre adresse e-mail et votre mot de passe Hyperdash, alors saisissez-les. 名称未設定ファイル (4).png Ensuite, écrivez le code qui utilise Hyperdash et vous êtes prêt à partir.

#Utiliser Hyperdash

from tensorflow.keras.callbacks import Callback
from hyperdash import Experiment

class Hyperdash(Callback):
    def __init__(self, entries, exp):
        super(Hyperdash, self).__init__()
        self.entries = entries
        self.exp = exp

    def on_epoch_end(self, epoch, logs=None):
        for entry in self.entries:
            log = logs.get(entry)            
            if log is not None:
                self.exp.metric(entry, log)

exp = Experiment("N'importe quel nom")
hd_callback = Hyperdash(["val_loss", "loss", "val_accuracy", "accuracy"], exp)


~~~Code d'exécution de la formation~~~


exp.end()

Maintenant, si vous regardez l'application pour smartphone Hyperdash, vous devriez voir le journal d'apprentissage. L'utilisation d'Hyperdash a résolu le problème pendant 90 minutes, mais le runtime peut être déconnecté pour une raison quelconque, donc je pense que c'est une bonne idée de diviser la formation en petits morceaux et de les enregistrer sous .ckpt ''. Ce .ckpt`` disparaît également lorsque le runtime est déconnecté, donc enregistrez-le tôt.

#Résultats d'apprentissage.Économisez avec ckpt
saver.save(sess, "/****1.ckpt")

# .Lisez le résultat d'apprentissage enregistré par ckpt et recommencez à partir de là
saver.restore(sess, "/****1.ckpt")

# .Enregistrer ckpt dans le répertoire spécifié
from google.colab import files
files.download( "/****1.ckpt.data-00000-of-00001" ) 

Réflexion / Conclusion

--DCGAN est difficile car le modèle est compliqué et un surentraînement est susceptible de se produire. La première considération est de construire un modèle simple avec une couche moins profonde. TIl semble que cela ne soit pas directement lié au surapprentissage, mais attention au "numéro d'époque", à "l'image simple" et "simplifier le sujet" mentionnés ci-dessus. ―― La variable latente est-elle également un paramètre assez important? Je vais enquêter davantage. «Cela a peut-être été un article difficile à lire parce que je viens d'écrire ce que je faisais. Merci d'avoir lu jusqu'au bout. DCGAN est amusant car les résultats apparaissent sous forme d'images. J'essaierai également d'apporter des améliorations et des changements.

Recommended Posts

Produisez de belles vaches de mer par apprentissage profond
Apprentissage profond appris par l'implémentation 1 (édition de retour)
L'apprentissage en profondeur
Deep learning 2 appris par l'implémentation (classification d'images)
Détection d'objets par apprentissage profond pour comprendre en profondeur par Keras
Chainer et deep learning appris par approximation de fonction
Apprentissage profond appris par la mise en œuvre ~ Détection d'anomalies (apprentissage sans enseignant) ~
Mémorandum d'apprentissage profond
Commencer l'apprentissage en profondeur
99,78% de précision avec apprentissage en profondeur en reconnaissant les hiragana manuscrits
Interpolation d'images vidéo par apprentissage en profondeur, partie 1 [Python]
Apprentissage en profondeur Python
Apprentissage parallèle du deep learning par Keras et Kubernetes
Apprentissage profond × Python
Apprentissage profond appris par mise en œuvre (segmentation) ~ Mise en œuvre de SegNet ~
Investissement en actions par apprentissage approfondi (méthode du gradient de politique) (1)
[Détection d'anomalies] Détecter la distorsion de l'image par apprentissage à distance
Classer les visages d'anime par suite / apprentissage profond avec Keras
Premier apprentissage profond ~ Lutte ~
Apprentissage profond à partir de zéro
Deep learning 1 Pratique du deep learning
Apprentissage profond / entropie croisée
Premier apprentissage profond ~ Préparation ~
Première solution d'apprentissage en profondeur ~
[AI] Apprentissage métrique profond
J'ai essayé le deep learning
Python: réglage du Deep Learning
Technologie d'apprentissage en profondeur à grande échelle
Fonction d'apprentissage profond / softmax
"Apprenez en créant! Développement en deep learning par PyTorch" sur Colaboratory.
Compréhension de base de l'estimation de la profondeur par caméra mono (Deep Learning)
Créez une IA qui identifie le visage de Zuckerberg grâce à l'apprentissage en profondeur ③ (Apprentissage des données)
Chanson auto-exploitée par apprentissage en profondeur (édition Stacked LSTM) [DW Day 6]