[PYTHON] Creating an unknown Pokemon with StyleGAN2 [Part 2]

Last time I covered the outline and the generation results. This time I will go over the changes I made to the official StyleGAN2 implementation to actually get it running on a GTX 1070.

My trained model also got wiped out for various reasons, so I'd like to leave a note about that as well!

Official implementation of StyleGAN2: https://github.com/NVlabs/stylegan2

Changes to the official implementation

The files I actually modified are run_training.py and training/dataset.py.

Each model snapshot is about 300 MB, which is quite heavy, so if you don't think about how often snapshots are saved you will run out of disk space and training will stop. In my case, training stopped right in the middle of overwriting a snapshot, so the file was left empty and two days' worth of training was lost. I recommend regularly keeping copies of multiple snapshots.
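As a rough sketch of one way to guard against this (my own addition, not part of the original setup; the paths are hypothetical), a small script can periodically copy the newest snapshot out of the results directory so that a crash or a full disk cannot wipe everything:

import glob
import os
import shutil
import time

# Hypothetical paths: adjust to your own run directory and backup location
RUN_DIR = "./results/00000-stylegan2-tf_images-1gpu-config-f"
BACKUP_DIR = "./snapshot_backups"
os.makedirs(BACKUP_DIR, exist_ok=True)

while True:
    snapshots = sorted(glob.glob(os.path.join(RUN_DIR, "network-snapshot-*.pkl")))
    if snapshots:
        latest = snapshots[-1]
        dst = os.path.join(BACKUP_DIR, os.path.basename(latest))
        if not os.path.exists(dst):
            shutil.copy2(latest, dst)  # keep a copy the training run will never overwrite
    time.sleep(1800)  # check every 30 minutes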

run_training.py


def run(...):
  ...
  # Output generated sample images every tick; save a network snapshot every 5 ticks
  train.image_snapshot_ticks = 1
  train.network_snapshot_ticks = 5

  # This time I trained a model at image size 64
  dataset_args = EasyDict(tfrecord_dir=dataset, resolution=64)

Here is the fix for the memory error.

training/dataset.py


class TFRecordDataset:
  def __init__(...):
    ...
    # Load labels.
    assert max_label_size == 'full' or max_label_size >= 0
    # Shrink the preallocated label array to avoid the memory error
    #self._np_labels = np.zeros([1<<30, 0], dtype=np.float32)
    self._np_labels = np.zeros([1<<20, 0], dtype=np.float32)

Also, if a single GPU operation takes too long, Windows kills the training task. I was hit by this as well. As a countermeasure, the timeout can be disabled by setting a DWORD registry value (http://www.field-and-network.jp/rihei/20121028223437.php). I recommend doing this before you start training.
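The linked page has the details; as I understand it, this corresponds to the standard Windows TDR (Timeout Detection and Recovery) registry values. The snippet below is only a sketch based on that assumption (TdrDelay under the GraphicsDrivers key), not something taken from the post, and it needs an administrator prompt plus a reboot to take effect:

import winreg

# Assumed standard TDR location; verify against the page linked above before using
key_path = r"SYSTEM\CurrentControlSet\Control\GraphicsDrivers"
with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, key_path, 0, winreg.KEY_SET_VALUE) as key:
    # TdrDelay: seconds the driver waits for the GPU before resetting it (default is 2)
    winreg.SetValueEx(key, "TdrDelay", 0, winreg.REG_DWORD, 60)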

Model evaluation metric / monitoring training

Quantitative evaluation of GANs is a notoriously difficult task, but many studies evaluate the quality of generated images with a metric called FID (Fréchet Inception Distance). Dataset images and generated images are fed into a feature extraction model, and the Fréchet distance between the two feature distributions is computed.
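Concretely, a Gaussian is fitted to each set of feature vectors and the Fréchet distance between the two Gaussians is computed. A minimal sketch of that distance calculation (my own illustration; the official metric code additionally handles the Inception network and the sampling):

import numpy as np
from scipy import linalg

def frechet_distance(feats_real, feats_fake):
    # Fit a Gaussian (mean and covariance) to each set of feature vectors (rows = samples)
    mu_r, mu_f = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    sigma_r = np.cov(feats_real, rowvar=False)
    sigma_f = np.cov(feats_fake, rowvar=False)
    # Frechet distance between the two Gaussians
    covmean, _ = linalg.sqrtm(sigma_r.dot(sigma_f), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # numerical noise can leave tiny imaginary parts
    diff = mu_r - mu_f
    return diff.dot(diff) + np.trace(sigma_r + sigma_f - 2.0 * covmean)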

Because of the dimensionality of the multivariate normal distributions involved, at least around 4,000 samples each are needed from the dataset and from the generated images, and reading the dataset in particular takes time. If this step is skipped the metric disappears and I would have no idea when to stop training, so I can't drop it, but I do wish it could be sped up.

The FID is written to results/.../metric-fid50k.txt, so check it regularly to confirm that training is proceeding smoothly.

metric-fid50k.txt


network-snapshot-              time 19m 34s      fid50k 278.0748
network-snapshot-              time 19m 34s      fid50k 382.7474
network-snapshot-              time 19m 34s      fid50k 338.3625
network-snapshot-              time 19m 24s      fid50k 378.2344
network-snapshot-              time 19m 33s      fid50k 306.3552
network-snapshot-              time 19m 33s      fid50k 173.8370
network-snapshot-              time 19m 30s      fid50k 112.3612
network-snapshot-              time 19m 31s      fid50k 99.9480
network-snapshot-              time 19m 35s      fid50k 90.2591
network-snapshot-              time 19m 38s      fid50k 75.5776
network-snapshot-              time 19m 39s      fid50k 67.8876
network-snapshot-              time 19m 39s      fid50k 66.0221
network-snapshot-              time 19m 46s      fid50k 63.2856
network-snapshot-              time 19m 40s      fid50k 64.6719
network-snapshot-              time 19m 31s      fid50k 64.2135
network-snapshot-              time 19m 39s      fid50k 63.6304
network-snapshot-              time 19m 42s      fid50k 60.5562
network-snapshot-              time 19m 36s      fid50k 59.4038
network-snapshot-              time 19m 36s      fid50k 57.2236
network-snapshot-              time 19m 40s      fid50k 56.9055
network-snapshot-              time 19m 47s      fid50k 56.5965
network-snapshot-              time 19m 34s      fid50k 56.5844
network-snapshot-              time 19m 38s      fid50k 56.4158
network-snapshot-              time 19m 34s      fid50k 54.0568
network-snapshot-              time 19m 32s      fid50k 54.0307
network-snapshot-              time 19m 40s      fid50k 54.0492
network-snapshot-              time 19m 32s      fid50k 54.1482
network-snapshot-              time 19m 38s      fid50k 53.3513
network-snapshot-              time 19m 32s      fid50k 53.8889
network-snapshot-              time 19m 39s      fid50k 53.5233
network-snapshot-              time 19m 40s      fid50k 53.9403
network-snapshot-              time 19m 43s      fid50k 53.1017
network-snapshot-              time 19m 39s      fid50k 53.3370
network-snapshot-              time 19m 36s      fid50k 53.0706
network-snapshot-              time 19m 43s      fid50k 52.6289
network-snapshot-              time 19m 39s      fid50k 51.8526
network-snapshot-              time 19m 35s      fid50k 52.3760
network-snapshot-              time 19m 42s      fid50k 52.7780
network-snapshot-              time 19m 36s      fid50k 52.3064
network-snapshot-              time 19m 42s      fid50k 52.4976
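Since the file is plain text in the format shown above, a short helper (my own sketch; the results path is an assumption) can pull out the fid50k values for plotting the training curve:

import re

fids = []
with open("./results/00000-stylegan2-tf_images-1gpu-config-f/metric-fid50k.txt") as f:
    for line in f:
        m = re.search(r"fid50k\s+([0-9.]+)", line)
        if m:
            fids.append(float(m.group(1)))

print(fids[-5:])  # the most recent scores; lower is better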

If training stops midway

If training stops due to some error, as it did for me, you can resume from that point by loading one of the network-snapshot-*.pkl models saved under results. The required settings are as follows.

run_training.py


def run(...):
  ...
  # Snapshot to resume from, the kimg count at that snapshot, and the elapsed training time in seconds
  train.resume_pkl = "./results/00000-stylegan2-tf_images-1gpu-config-f/network-snapshot-00640.pkl"
  train.resume_kimg = 640
  train.resume_time = 150960

For train.resume_time, take the elapsed training time that was logged when the snapshot was saved (from log.txt, etc.), convert it to seconds, and enter that value.

Transfer learning could probably be done the same way, by pointing to an existing trained model; you would then need to manually handle details such as which weights to freeze after loading it. For a task like this one, though, I feel you might as well just retrain everything...

log.txt


dnnlib: Running training.training_loop.training_loop() on localhost...
...
tick 40    kimg 640.1    lod 0.00  minibatch 32   time 1d 17h 56m   sec/tick 2588.1  sec/kimg 161.76  maintenance 1203.1 gpumem 5.1
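For example, the tick 40 line above reports time 1d 17h 56m, which works out to the 150960 seconds entered for train.resume_time earlier:

# Convert the elapsed time "1d 17h 56m" from log.txt into seconds for train.resume_time
days, hours, minutes = 1, 17, 56
resume_time = days * 86400 + hours * 3600 + minutes * 60
print(resume_time)  # 150960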

Results up to the point I was able to train

Training died at 640 kimg due to lack of storage space (crying). It's really painful because the FID was coming down so nicely.

fakes000640.png

The outlines are getting clearer, and not only the body shapes but also the faces are starting to be reproduced.

I'd like to see further training results as soon as possible, but since I'm retraining from scratch it will take a little while. I'll add the results as training progresses.

I'm not confident that I've covered every change from the original implementation above, so I've published the code at the link below. Usage instructions are also on GitHub, so please take a look.

https://github.com/Takuya-Shuto-engineer/PokemonGAN
