[PYTHON] Produce beautiful sea slugs by deep learning

Introduction

This is a memo of my experience running DCGAN, a derivative of GAN, with TensorFlow. I will explain it roughly without going too deep. I wrote almost the same article the other day, but it got messy, so I am reorganizing it a little here.

The title says I want to generate sea slugs, but at first I was planning to generate Pokemon with DCGAN. So I will start by briefly describing my attempt at generating Pokemon.

By the way, this is what sea slugs look like. There are many colorful species, and they are beautiful. [images]

What are GAN and DCGAN?

I would like to explain GAN briefly. A GAN trains two networks, a **Generator** that creates fakes and a **Discriminator** that tells fake from real, so that it ends up generating data as close to the real thing as possible. The Generator creates a new image from random noise, with real data as its reference. The Discriminator judges whether each image it sees is fake or genuine. The Generator and Discriminator are good rivals: by repeating this contest over and over, both become smarter and smarter. As a result, images close to the real data are generated.

↓ It looks like this: [image]

↓ A simpler version I drew: [image]

This is the basic mechanism of GAN. DCGAN is a GAN that uses a CNN (Convolutional Neural Network). CNNs are complicated in various ways, but put simply, they stack two kinds of layers, **convolution layers** and **pooling layers**, into a multi-layer network, which makes it possible to share weights across the network. As a result, DCGAN can learn with higher efficiency and accuracy than a plain GAN.
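
The rivalry between the Generator and the Discriminator described above can be written down as two loss functions. The following is only a minimal sketch using the commonly used binary cross-entropy form of the GAN losses, not the loss code from this article (which is not shown). The scores in `d_real` and `d_fake` are made-up values for illustration.

```python
import math

def bce(probs, target, eps=1e-7):
    """Mean binary cross-entropy of predicted probabilities against one 0/1 target."""
    total = 0.0
    for p in probs:
        p = min(max(p, eps), 1.0 - eps)   # avoid log(0)
        total += -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))
    return total / len(probs)

def discriminator_loss(d_real, d_fake):
    """D wants real images scored as 1 and generated images scored as 0."""
    return bce(d_real, 1.0) + bce(d_fake, 0.0)

def generator_loss(d_fake):
    """G wants D to score its generated images as 1 (non-saturating form)."""
    return bce(d_fake, 1.0)

# Hypothetical Discriminator outputs (probability that an image is real).
d_real = [0.9, 0.8, 0.95]   # scores for real images
d_fake = [0.1, 0.2, 0.05]   # scores for generated images

print("D loss:", discriminator_loss(d_real, d_fake))  # small: D is winning
print("G loss:", generator_loss(d_fake))              # large: G is losing
```

When the Discriminator is much stronger than the Generator, its loss stays small while the Generator's loss stays large, which is the kind of imbalance to look for in the loss curves.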

I will use this DCGAN to generate Pokemon and sea slugs. For explanations of GAN and DCGAN, the articles "GAN that I can't ask about anymore (1): Understanding the basic structure" and "GAN that I can't ask about anymore (2): Image generation by DCGAN" are easy to understand.

Pokémon

There are so many kinds of Pokemon these days, and I chose them as a theme because I thought a familiar subject would be fun. The Pokemon images were downloaded from [here](https://kamigame.jp/%E3%83%9D%E3%82%B1%E3%83%A2%E3%83%B3USUM/%E3%83%9D%E3%82%B1%E3%83%A2%E3%83%B3/%E8%89%B2%E9%81%95%E3%81%84%E3%83%9D%E3%82%B1%E3%83%A2%E3%83%B3%E4%B8%80%E8%A6%A7.html).

By the way, this time I used the Chrome extension **Image Downloader** to collect the Pokemon images. I recommend it because it is easy to use without writing any code. I thought there was too little data, so I inflated it by adding rotated and flipped copies with the following code. The result is saved in .npy format so that it is easy to load later.

import os,glob
import numpy as np
from tqdm import tqdm
from keras.preprocessing.image import load_img,img_to_array
from keras.utils import np_utils
from sklearn import model_selection
from PIL import Image

#Store classes in an array
classes = ["class1", "class2"]

num_classes = len(classes)
img_size = 128
color=False

#Loading images
#Finally images and labels are stored in the list

temp_img_array_list=[]
temp_index_array_list=[]
for index,classlabel in enumerate(classes):
    photos_dir = "./" + classlabel
    #Get a list of images for each class with glob
    img_list = glob.glob(photos_dir + "/*.jpg")
    for img in tqdm(img_list):
        temp_img=load_img(img,grayscale=color,target_size=(img_size, img_size))
        temp_img_array=img_to_array(temp_img)
        temp_img_array_list.append(temp_img_array)
        temp_index_array_list.append(index)
        #Rotation processing
        for angle in range(-20,20,5):
            #rotation
            img_r = temp_img.rotate(angle)
            data = np.asarray(img_r)
            temp_img_array_list.append(data)
            temp_index_array_list.append(index)
            #Invert
            img_trans = img_r.transpose(Image.FLIP_LEFT_RIGHT)
            data = np.asarray(img_trans)
            temp_img_array_list.append(data)
            temp_index_array_list.append(index)

#Convert the lists to arrays once, after all images have been processed
X=np.array(temp_img_array_list)
Y=np.array(temp_index_array_list)

np.save("./img_128RGB.npy", X)
np.save("./index_128RGB.npy", Y)
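
As a quick check of how much the rotation loop above inflates the data, the number of copies per image can be counted directly from the `range(-20,20,5)` used there. This is just a sketch of the arithmetic; the variable names are made up for illustration.

```python
# Angles used by the rotation loop above: -20, -15, ..., 15 (20 itself is excluded).
angles = list(range(-20, 20, 5))

rotated = len(angles)                       # one rotated copy per angle
flipped = len(angles)                       # one left-right flipped copy per rotated copy
augmented_per_image = rotated + flipped     # copies added by the loop
total_per_image = 1 + augmented_per_image   # the original is stored as well

print(angles)
print(augmented_per_image)
print(total_per_image)
```

Note that angle 0 is part of the range, so one of the rotated copies is identical to the original image.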

I wanted to mix lots of Pokemon with DCGAN and make a chimera Pokemon. [images]

But what I actually got was this: [image]

It was clearly overfitting, as you can see from both the generated images and the loss. The Discriminator is insanely strong. So next I thought about possible causes and tried to fix them.

Causes of overfitting

Is Pokemon difficult?

- Pokemon vary a lot in color and shape, so maybe it is easy to end up generating chaotic blobs?
- I want to use a subject whose shape is unified to some extent, so I switched from Pokemon generation to **sea slug (nudibranch) generation**.
- That said, the colors and shapes of sea slugs are not that unified either, so I feel the choice of subject is questionable. But working on something you like keeps you motivated.

The number of data is small

- Switching from the Pokemon images, I collected a little over 500 sea slug images. Rotation (-20° to 20°) and flipping should increase that roughly 16 times, so the data grew to about **500 x 16 = 8000** images.
- The images were collected with **Flickr** and **icrawler**.
- I will roughly explain how to use Flickr. Go to the Flickr API site (https://www.flickr.com/services/api/) and open the page labeled **API key**. [screenshot] If you get a Yahoo account and log in, the following screen appears, so get the key from there (it is blacked out in the screenshot). [screenshot] Use this key to fetch the images with the code below.

from flickrapi import FlickrAPI
from urllib.request import urlretrieve
from pprint import pprint
import os, time, sys

#API key information
key = "********"
secret = "********"
wait_time = 1

#Specify save folder (created if it does not exist yet)
savedir = "./gazou"
os.makedirs(savedir, exist_ok=True)

flickr = FlickrAPI(key, secret, format="parsed-json")
result = flickr.photos.search(
        per_page = 100,
        tags = "seaslug",
        media = "photos",
        sort = "relevance",
        safe_search = 1,
        extras = "url_q,license"
)

photos = result["photos"]

#Loop over the photos and download each one
for i, photo in enumerate(photos['photo']):
    url_q = photo["url_q"]
    filepath = savedir + "/" + photo["id"] + ".jpg"
    if os.path.exists(filepath): continue
    urlretrieve(url_q, filepath)
    time.sleep(wait_time)

This collects some data, but I wanted more, so I also collected images with **icrawler**. It is insanely easy to use.

$ pip install icrawler

from icrawler.builtin import GoogleImageCrawler

crawler = GoogleImageCrawler(storage={"root_dir": "gazou"})
crawler.crawl(keyword="Nudibranch", max_num=100)

This alone saves sea slug images into the specified folder. [image] As with the Pokemon, I inflated these images by rotating and flipping them.

No dropout

- To explain dropout briefly: it prevents overfitting by ignoring a set fraction of the nodes during training.
- For details, this article seems to be good.
- The following is the actual Discriminator with dropout applied.

import tensorflow as tf  # TensorFlow 1.x API (tf.layers, tf.variable_scope)

def discriminator(x, reuse=False, alpha=0.2):
    with tf.variable_scope("discriminator", reuse=reuse):
        x1 = tf.layers.conv2d(x, 32, 5, strides=2, padding="same")
        x1 = tf.maximum(alpha * x1, x1)
        x1_drop = tf.nn.dropout(x1, 0.5)
        
        x2 = tf.layers.conv2d(x1_drop, 64, 5, strides=2, padding="same")
        x2 = tf.layers.batch_normalization(x2, training=True)
        x2 = tf.maximum(alpha * x2, x2)
        x2_drop = tf.nn.dropout(x2, 0.5)
        
        x3 = tf.layers.conv2d(x2_drop, 128, 5, strides=2, padding="same")
        x3 = tf.layers.batch_normalization(x3, training=True)
        x3 = tf.maximum(alpha * x3, x3)
        x3_drop = tf.nn.dropout(x3, 0.5)
        
        x4 = tf.layers.conv2d(x3_drop, 256, 5, strides=2, padding="same")
        x4 = tf.layers.batch_normalization(x4, training=True)
        x4 = tf.maximum(alpha * x4, x4)
        x4_drop = tf.nn.dropout(x4, 0.5)
        
        x5 = tf.layers.conv2d(x4_drop, 512, 5, strides=2, padding="same")
        x5 = tf.layers.batch_normalization(x5, training=True)
        x5 = tf.maximum(alpha * x5, x5)
        x5_drop = tf.nn.dropout(x5, 0.5)
        
        flat = tf.reshape(x5_drop, (-1, 4*4*512))
        logits = tf.layers.dense(flat, 1)
        logits_drop = tf.nn.dropout(logits, 0.5)
        out = tf.sigmoid(logits_drop)
        
        return out, logits
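
To make it concrete what each dropout call in the Discriminator above does to its input, here is a minimal sketch of inverted dropout in plain Python. It is an illustration under the assumption that 0.5 is the probability of keeping a node (in TensorFlow 1.x the second argument of `tf.nn.dropout` is `keep_prob`). The `dropout` helper and the `activations` values are made up for this example.

```python
import random

def dropout(values, keep_prob, rng):
    """Inverted dropout: zero each value with probability 1 - keep_prob and
    scale the survivors by 1 / keep_prob so the expected value is unchanged."""
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in values]

rng = random.Random(0)        # fixed seed so the run is repeatable
activations = [1.0] * 10      # hypothetical layer outputs
print(dropout(activations, keep_prob=0.5, rng=rng))   # a mix of 0.0 and 2.0
```

Because a different random subset of nodes is ignored on every step, the network cannot rely on any single node, which is why dropout tends to weaken an overly strong Discriminator.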

High learning rate?

- If the learning rate is high, training proceeds quickly, but it diverges easily and learning becomes difficult.
- When I actually tried various values starting from 1e-2, 1e-4 seemed to be about right. In my case, learning was too slow at 1e-5.
- For how various learning rates behave, this article is easy to understand.
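
The trade-off between diverging and learning too slowly can be seen even on a toy problem. The sketch below runs plain gradient descent on f(x) = x² with three learning rates. The function and the rates are chosen only for illustration and are not comparable to the values tried for DCGAN in this article.

```python
def gradient_descent(lr, steps=50, x0=1.0):
    """Minimize f(x) = x**2 (gradient 2*x) and return the final |x|."""
    x = x0
    for _ in range(steps):
        x -= lr * 2.0 * x
    return abs(x)

for lr in (1.1, 1e-1, 1e-4):
    print(f"lr={lr:g}: |x| after 50 steps = {gradient_descent(lr):.4g}")
```

With the largest rate, |x| grows at every step (divergence). With the middle rate it shrinks quickly toward the minimum at 0. With the smallest rate it has barely moved after 50 steps.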

Too much training data?

- Initially the ratio was about 8:2, but I changed it to 6:4. I couldn't really feel any effect.

Nudibranch (improvement result from Pokemon)

100 epochs: [image]

200 epochs: [image]

300 epochs: [image]

400 epochs: [image]

500 epochs: [image]

- For now, I ran it for about 500 epochs. Looking from a distance, I feel that sea slugs are being produced.
- But honestly, the result is so-so...
- Possible factors include: "Are there enough epochs?", "Do the images contain too much extra material (rocky backgrounds, etc.)?", "Are the layers too deep?", and "Would somewhat simpler images be better after all?"
- I wanted to improve it further and run it a bit longer, but I am running on **Google Colaboratory**, and the connection time limit makes that quite difficult.
- There are a few things I would like to write about Colaboratory, so I will give it its own section next.

Colaboratory

Colaboratory is a Jupyter notebook environment provided by Google that runs in the cloud, and it lets you use a GPU worth about 800,000 yen. Moreover, there is no need to build an environment or apply for Datalab. On top of that, it is free. It is insanely convenient, but it has the following restrictions.

- If you stay connected to a GPU for a certain amount of time in a day (recently about 4 hours [500 epochs]), you can no longer use one that day. (This is due to the lack of GPU resources in Colaboratory, so there is no workaround and no choice but to wait. GPUs are said to be assigned preferentially to users who do not use them constantly.)
- The runtime is disconnected after 90 minutes of inactivity, or after 12 hours at most, and the notebook's learning results are reset as well.
- So I used **Hyperdash** to deal with the 90-minute problem. It lets you keep the runtime connected for more than 90 minutes.
- Besides the 90-minute problem, sending the learning log to Hyperdash also solves the "Buffered data was truncated after reaching the output size limit." problem that makes it impossible to check the log on Colaboratory.
- Hyperdash is a smartphone app, so you can check the log even when you are out, which is convenient.
- Hyperdash can show plots and parameters of the learning progress, but this time the only goal is to prevent runtime disconnection, so the steps below are enough.

#First, start the smartphone app Hyperdash and create an account.
#Install Hyperdash

!pip install hyperdash
from hyperdash import monitor_cell
!hyperdash login --email

You will be asked for your Hyperdash email address and password, so enter them. [screenshot] Next, write the code that uses Hyperdash and you are done.

#Using Hyperdash

from tensorflow.keras.callbacks import Callback
from hyperdash import Experiment

class Hyperdash(Callback):
    def __init__(self, entries, exp):
        super(Hyperdash, self).__init__()
        self.entries = entries
        self.exp = exp

    def on_epoch_end(self, epoch, logs=None):
        for entry in self.entries:
            log = logs.get(entry)            
            if log is not None:
                self.exp.metric(entry, log)

exp = Experiment("Any name")
hd_callback = Hyperdash(["val_loss", "loss", "val_accuracy", "accuracy"], exp)


# ~~~ Training execution code (for a Keras model, e.g. model.fit(..., callbacks=[hd_callback])) ~~~


exp.end()

Now, if you look at the Hyperdash smartphone app, you should see the learning log. Hyperdash solved the 90-minute problem, but the runtime can still be disconnected for other reasons, so I think it is a good idea to split the training into smaller chunks and save each one as a .ckpt. The .ckpt also disappears when the runtime is disconnected, so save it early.

#Save the learning results as .ckpt
saver.save(sess, "/****1.ckpt")

#Load the learning results saved as .ckpt and resume from there
saver.restore(sess, "/****1.ckpt")

#Download the .ckpt data file to your local machine
from google.colab import files
files.download("/****1.ckpt.data-00000-of-00001")
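
The idea of splitting training into smaller chunks and resuming from the last save can be sketched without TensorFlow. The example below is only an illustration: the "training" is a dummy loop, and a JSON file (a hypothetical `progress.json`) stands in for what `saver.save` and `saver.restore` do with the .ckpt.

```python
import json
import os
import tempfile

def train_in_chunks(total_epochs, chunk_size, ckpt_path):
    """Run one chunk of a (dummy) training loop, resuming from ckpt_path if it exists."""
    state = {"epoch": 0, "loss": 1.0}
    if os.path.exists(ckpt_path):
        with open(ckpt_path) as f:
            state = json.load(f)          # resume from the last saved epoch

    end = min(state["epoch"] + chunk_size, total_epochs)
    for _ in range(state["epoch"], end):
        state["loss"] *= 0.9              # placeholder for one real training epoch
        state["epoch"] += 1

    with open(ckpt_path, "w") as f:
        json.dump(state, f)               # save at the end of every chunk
    return state

ckpt = os.path.join(tempfile.mkdtemp(), "progress.json")
first = train_in_chunks(total_epochs=10, chunk_size=4, ckpt_path=ckpt)   # epochs 1-4
second = train_in_chunks(total_epochs=10, chunk_size=4, ckpt_path=ckpt)  # resumes at epoch 5
print(first["epoch"], second["epoch"])
```

If the runtime is disconnected between two calls, only the current chunk is lost, as long as the saved file has already been copied somewhere that survives the disconnection.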

Reflection / Conclusion

- DCGAN is difficult because the model is complicated and overfitting occurs easily. The first thing to consider is building a simpler model with fewer layers.
- Although they do not seem to be directly related to overfitting, also pay attention to the points mentioned above: the number of epochs, simpler images, and a simpler subject.
- The latent variable may also be a fairly important parameter, so I will investigate it more.
- This may have been a hard article to read because I just wrote down what I did. Thank you for reading to the end. DCGAN is fun because the results appear as images. I will keep trying improvements and changes.
