[PYTHON] I tried to implement and learn DCGAN with PyTorch

What to do this time

Since I previously built an image-processing model that performs super-resolution, I decided to build an image-generation model next. When it comes to image generation, GANs come to mind, so this time I will implement and train DCGAN, which is a relatively simple model.

About DCGAN

The idea of a GAN (Generative Adversarial Network) is simple: a counterfeiter and an appraiser compete with each other, and through that competition the counterfeiter learns to produce increasingly convincing forgeries. The network is therefore basically made up of two parts:
- a Generator, which takes random noise as input and produces an image
- a Discriminator, which takes an image as input and judges whether it is real or generated
For the detailed mechanism, I studied the article "GAN (1) Understanding the basic structure that I can't ask anymore". DCGAN (Deep Convolutional Generative Adversarial Network) is a GAN that generates images using deconvolution (transposed convolution) layers.
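
As a reminder (my own summary, not a detail taken from the article above), the two networks are trained against each other with the standard GAN minimax objective, where D tries to tell real images x from fakes G(z) and G tries to fool it:

\min_G \max_D \; \mathbb{E}_{x \sim p_{\mathrm{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]

In the code below this is implemented with BCEWithLogitsLoss: the Discriminator is trained toward label 1 on real images and 0 on fakes, while the Generator is trained toward making the Discriminator output 1 on its fakes (the usual non-saturating variant).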

Implementation

The code is split across the following four files:
- networks.py: the network structures
- utils.py: the Dataset and related utilities
- train.py: training the model
- generate.py: generating images with the trained model

networks.py defines the structure of the Generator and the Discriminator. Since this is DCGAN, both networks are simple.

networks.py



import torch
from torch import nn

class Generator(nn.Module):
    def __init__(self, latent_size=100):
        super().__init__()
        self.main = nn.Sequential(
            # (latent_size, 1, 1) -> (256, 4, 4)
            nn.ConvTranspose2d(latent_size, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # (256, 4, 4) -> (128, 8, 8)
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # (128, 8, 8) -> (64, 16, 16)
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            # (64, 16, 16) -> (32, 32, 32)
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            # (32, 32, 32) -> (3, 64, 64), values in [-1, 1]
            nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # (3, 64, 64) -> (32, 32, 32)
            nn.Conv2d(3, 32, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # (32, 32, 32) -> (64, 16, 16)
            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            # (64, 16, 16) -> (128, 8, 8)
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            # (128, 8, 8) -> (256, 4, 4)
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            # (256, 4, 4) -> (1, 1, 1): one real/fake logit per image
            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x):
        return self.main(x).squeeze()
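
With kernel 4, stride 2, padding 1 at each stage, the Generator upsamples a (latent_size, 1, 1) noise vector to a 3x64x64 image, and the Discriminator downsamples a 64x64 image to a single logit. A quick shape check (my own sketch, not part of the original files) looks like this:

import torch
from networks import Generator, Discriminator

model_G = Generator(latent_size=100)
model_D = Discriminator()

z = torch.randn(8, 100, 1, 1)   # a batch of 8 latent vectors
fake = model_G(z)               # torch.Size([8, 3, 64, 64])
logit = model_D(fake)           # torch.Size([8]), raw logits for BCEWithLogitsLoss
print(fake.shape, logit.shape)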

utils.py

Here we define a Dataset class so the data is easy to handle during training.

utils.py


import os

from torch.utils.data import Dataset
from torchvision import transforms
import torch
from PIL import Image

class Dcgan_Dataset(Dataset):
    def __init__(self, root, datamode = "train", transform = transforms.ToTensor(), latent_size=100):

        self.image_dir = os.path.join(root, datamode)
        self.image_paths = [os.path.join(self.image_dir, name) for name in os.listdir(self.image_dir)]
        self.data_length = len(self.image_paths)

        self.transform = transform
        self.latent_size = latent_size

    def __len__(self):
        return self.data_length
    
    def __getitem__(self, index):
        # draw a fresh latent vector for every sample; shape (latent_size, 1, 1)
        latent = torch.randn(size=(self.latent_size, 1, 1))
        img_path = self.image_paths[index]
        img = Image.open(img_path).convert("RGB")  # force 3 channels (drops any alpha channel)

        if self.transform is not None:
            img = self.transform(img)
        
        return latent, img
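
Each call to __getitem__ returns a freshly sampled latent vector together with an image, so the DataLoader batches them into shapes (N, latent_size, 1, 1) and (N, 3, H, W). A minimal usage sketch (assuming the face crops are already 64x64 and live under ../dataset/face_crop_img/train, the default path used in train.py below):

from torch.utils.data import DataLoader
from torchvision import transforms
from utils import Dcgan_Dataset

dataset = Dcgan_Dataset("../dataset/face_crop_img", datamode="train",
                        transform=transforms.ToTensor())
loader = DataLoader(dataset, batch_size=16, shuffle=True)
latent, img = next(iter(loader))
print(latent.shape, img.shape)  # torch.Size([16, 100, 1, 1]) torch.Size([16, 3, 64, 64])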


train.py

Here we define the training loop. It is a bit longer than the other files because it also runs a test step and logs everything to TensorBoard.


import os
import argparse

from networks import Generator, Discriminator
from utils import Dcgan_Dataset

import numpy as np

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

from tqdm import tqdm



def main(opt):
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)
    # ----- Device Setting -----
    if opt.gpu:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")
    print("Device :", device)

    # ----- Dataset Setting -----
    train_dataset = Dcgan_Dataset(opt.dataset, datamode="train",
                                  transform=transforms.Compose([transforms.RandomHorizontalFlip(),
                                                                transforms.ToTensor()]))
    
    test_dataset = Dcgan_Dataset(opt.dataset, datamode="test")

    print("Training Dataset :", os.path.join(opt.dataset, "train"))
    print("Testing Dataset :", os.path.join(opt.dataset, "test"))

    # ----- DataLoader Setting -----
    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=opt.test_batch_size, shuffle=True)

    print("batch_size :",opt.batch_size)
    print("test_batch_size :",opt.test_batch_size)

    # ----- Summary Writer Setting -----
    train_writer = SummaryWriter(log_dir=os.path.join(opt.tensorboard, opt.exper))
    test_writer = SummaryWriter(log_dir=os.path.join(opt.tensorboard, opt.exper + "_test"))

    print("log directory :",os.path.join(opt.tensorboard, opt.exper))
    print("log step :", opt.n_log_step)

    # ----- Net Work Setting -----
    latent_size = opt.latent_size
    model_D = Discriminator()
    model_G = Generator()

    # resume
    if opt.resume_epoch != 0:
        model_D_path = os.path.join(opt.checkpoints_dir, opt.exper, "model_D_{}.pth".format(str(opt.resume_epoch)))
        model_G_path = os.path.join(opt.checkpoints_dir, opt.exper, "model_G_{}.pth".format(str(opt.resume_epoch)))

        model_G.load_state_dict(torch.load(model_G_path, map_location="cpu"))
        model_D.load_state_dict(torch.load(model_D_path, map_location="cpu"))

    model_D.to(device)
    model_G.to(device)
    model_D.train()
    model_G.train()

    # Label tensors used when computing the loss
    ones = torch.ones(opt.batch_size).to(device)   # label 1 for real images
    zeros = torch.zeros(opt.batch_size).to(device) # label 0 for fake images

    val_latents = torch.randn(9, opt.latent_size, 1, 1).to(device)
    loss_f = nn.BCEWithLogitsLoss()

    optimizer_D = torch.optim.Adam(model_D.parameters(), lr=0.0002)
    optimizer_G = torch.optim.Adam(model_G.parameters(), lr=0.0002)

    print("Latent size :",opt.latent_size)

    # ----- Training Loop -----
    step = 0
    for epoch in tqdm(range(opt.resume_epoch, opt.resume_epoch + opt.epoch)):
        print("epoch :",epoch + 1,"/", opt.resume_epoch + opt.epoch)

        # for latent, real_img in tqdm(train_loader):
        for latent, real_img in train_loader:
            step += 1
            latent = latent.to(device)
            real_img = real_img.to(device)
            batch_len = len(real_img)

            # --- Generator update: make D output "real" (1) on the generated images ---
            fake_img = model_G(latent)

            pred_fake = model_D(fake_img)
            loss_G = loss_f(pred_fake, ones[: batch_len])

            model_D.zero_grad()
            model_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()

            # --- Discriminator update: real images toward 1, freshly generated fakes toward 0 ---
            pred_real = model_D(real_img)
            loss_D_real = loss_f(pred_real, ones[: batch_len])
            fake_img = model_G(latent)
            pred_fake = model_D(fake_img)
            loss_D_fake = loss_f(pred_fake, zeros[: batch_len])
            loss_D = loss_D_real + loss_D_fake

            model_D.zero_grad()
            model_G.zero_grad()
            loss_D.backward()
            optimizer_D.step()

            if step % opt.n_log_step == 0:
                # test step 
                model_G.eval()
                model_D.eval()
                test_d_losses = []
                test_d_real_losses = []
                test_d_fake_losses = []
                test_g_losses = []
                for test_latent, test_real_img in test_loader:
                    test_latent = test_latent.to(device)
                    test_real_img = test_real_img.to(device)
                    batch_len = len(test_latent)
                    test_pred_img = model_G(test_latent)
                    test_fake_g = model_D(test_pred_img)
                    test_g_loss = loss_f(test_fake_g, ones[: batch_len])

                    test_g_losses.append(test_g_loss.item())

                    test_fake_d = model_D(test_pred_img)
                    test_real_d = model_D(test_real_img)
                    test_d_real_loss = loss_f(test_real_d, ones[: batch_len])
                    test_d_fake_loss = loss_f(test_fake_d, zeros[: batch_len])
                    test_d_loss = test_d_real_loss + test_d_fake_loss

                    test_d_real_losses.append(test_d_real_loss.item())
                    test_d_fake_losses.append(test_d_fake_loss.item())
                    test_d_losses.append(test_d_loss.item())
                
                # record process
                test_g_loss = sum(test_g_losses)/len(test_g_losses)
                test_d_loss = sum(test_d_losses)/len(test_d_losses)
                test_d_real_loss = sum(test_d_real_losses)/len(test_d_real_losses)
                test_d_fake_loss = sum(test_d_fake_losses)/len(test_d_fake_losses)


                train_writer.add_scalar("loss/g_loss", loss_G.item(), step)
                train_writer.add_scalar("loss/d_loss", loss_D.item(), step)
                train_writer.add_scalar("loss/d_real_loss", loss_D_real.item(), step)
                train_writer.add_scalar("loss/d_fake_loss", loss_D_fake.item(), step)
                train_writer.add_scalar("loss/epoch", epoch + 1, step)

                test_writer.add_scalar("loss/g_loss", test_g_loss, step)
                test_writer.add_scalar("loss/d_loss", test_d_loss, step)
                test_writer.add_scalar("loss/d_real_loss", test_d_real_loss, step)
                test_writer.add_scalar("loss/d_fake_loss", test_d_fake_loss, step)

                pred_img = model_G(val_latents)
                grid_img = make_grid(pred_img, nrow=3, padding=0)
                grid_img = grid_img.mul(0.5).add_(0.5)

                train_writer.add_image("train/{}".format(epoch), grid_img, step)

                model_D.train()
                model_G.train()
                
        if (epoch + 1) % opt.n_save_epoch == 0:
            save_dir = os.path.join(opt.checkpoints_dir, opt.exper)

            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            model_g_path = os.path.join(save_dir, "model_G_{}.pth".format(str(epoch + 1)))
            model_d_path = os.path.join(save_dir, "model_D_{}.pth".format(str(epoch + 1)))
            torch.save(model_D.state_dict(), model_d_path)
            torch.save(model_G.state_dict(), model_g_path)

            print("save_model")

    # save model
    save_dir = os.path.join(opt.checkpoints_dir, opt.exper)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    model_g_path = os.path.join(save_dir, "model_G_{}.pth".format(str(opt.resume_epoch + opt.epoch)))  # final epoch index, including any resumed epochs
    model_d_path = os.path.join(save_dir, "model_D_{}.pth".format(str(opt.resume_epoch + opt.epoch)))
    torch.save(model_D.state_dict(), model_d_path)
    torch.save(model_G.state_dict(), model_g_path)
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="../dataset/face_crop_img")
    parser.add_argument("--checkpoints_dir", default="../checkpoints")
    parser.add_argument("--exper", default="dcgan")
    parser.add_argument("--tensorboard", default="../tensorboard")
    parser.add_argument("--gpu", action="store_true", default=False)
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--test_batch_size", type=int, default=4)
    parser.add_argument("--n_log_step", type=int, default=10)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--n_save_epoch", type=int, default=10)

    parser.add_argument("--latent_size", type=int, default=100)

    # resume
    parser.add_argument("--resume_epoch", type=int, default=0)

    opt = parser.parse_args()
    main(opt)
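
For reference, an example invocation might look like the following (the dataset path matches the argparse default above; the epoch count and logging interval are only illustrative):

python train.py --dataset ../dataset/face_crop_img --gpu --epoch 1000 --batch_size 16 --n_log_step 100 --n_save_epoch 50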

In generate.py we define how to load the trained model and generate images from it.

generate.py


import os
import argparse

from networks import Generator

import numpy as np
from tqdm import tqdm
import torch
from torchvision.utils import save_image


def generate(latents, model_G, width, height, save_path):
    assert len(latents) == width * height
    with torch.no_grad():  # no gradients are needed at generation time
        pred_images = model_G(latents)
    pred_images = pred_images.mul(0.5).add(0.5)  # map tanh output [-1, 1] to [0, 1], as in train.py
    save_image(pred_images, save_path, nrow=width)

def generate_process(opt):
    torch.manual_seed(opt.seed)  # use the --seed argument (otherwise unused) for reproducible sampling

    # ----- Device Setting -----
    device = torch.device("cpu")

    # ----- Output Setting -----

    img_output_dir = opt.output_dir
    if not os.path.exists(img_output_dir):
        os.mkdir(img_output_dir)

    if not os.path.exists(os.path.join(img_output_dir, "images")):
        os.mkdir(os.path.join(img_output_dir, "images"))
    
    if opt.save_latent:
        if not os.path.exists(os.path.join(img_output_dir, "latents")):
            os.mkdir(os.path.join(img_output_dir, "latents"))
    
    print("Output :", img_output_dir)

    # ----- Model Loading -----
    print("Use model :", opt.model)
    model_g = Generator()
    model_g.load_state_dict(torch.load(opt.model, map_location="cpu"))
    model_g.to(device)

    model_g.eval()

    if opt.mode == "normal":
        latents = [torch.randn(size=(opt.width * opt.height, opt.latent_size, 1, 1) for i in range(opt.n_img)]
    
    elif opt.mode == "use_latent":
        assert opt.latent_dir != "None", "latent source directory is not set"
        latent_paths = [os.path.join(opt.latent_dir, name) for name in os.listdir(opt.latent_dir)]
        latents = [torch.from_numpy(np.load(path)) for path in latent_paths]

    elif opt.mode == "inter":
        latent_start = torch.from_numpy(np.load(opt.start_latent))
        latent_end = torch.from_numpy(np.load(opt.end_latent))
        alphas = [float(n / opt.latent_num) for n in range(opt.n_img)]
        latents = [alpha * latent_end + (1 - alpha) * latent_start for alpha in alphas]

    print("Generate image num :", len(latents))

    # ----- Generate Step -----
    print("Start Generate Process")
    for index, latent in tqdm(enumerate(latents)):
        img_path = os.path.join(img_output_dir, "images", str(index + 1) + ".png")
        generate(latent, model_g, opt.width, opt.height, img_path)

        if opt.save_latent:
            latent_path = os.path.join(img_output_dir, "latents", str(index + 1) + ".npy")
            np.save(latent_path, latent.numpy())

    print("Finish Generate Process")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="../checkpoints/dcgan_test/model_G_1000.pth")
    parser.add_argument("--output_dir", default="./result")
    parser.add_argument("--save_latent", action="store_true", default=False)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--width", type=int, default=1)
    parser.add_argument("--height", type=int, default=1)
    parser.add_argument("--mode", choice=["normal", "use_latent", "inter"], default="normal")
    parser.add_argument("--n_img", type=int, default=1)

    # normal generation
    # use latent
    parser.add_argument("--latent_dir", default="None")

    # interpolation between two latent vectors
    parser.add_argument("--generate_inter", action="store_true", default=False)
    parser.add_argument("--start_latent", type=str, default=".")
    parser.add_argument("--end_latent", type=str, default=".")


    opt = parser.parse_args()
    generate_process(opt)
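
For reference, example invocations might be (the checkpoint path and file names are only illustrative; point --model at whichever checkpoint you actually saved):

python generate.py --model ../checkpoints/dcgan/model_G_1000.pth --width 3 --height 3 --n_img 5 --save_latent
python generate.py --model ../checkpoints/dcgan/model_G_1000.pth --mode inter --width 3 --height 3 --start_latent ./result/latents/1.npy --end_latent ./result/latents/2.npy --n_img 8

The first command generates five 3x3 grids of random faces and saves their latent vectors; the second interpolates between two of those saved latent vectors.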
    

Training results

The dataset consisted of 18,000 anime face images, and training was done on Google Colaboratory. Below are generated samples at 10, 20, 30, 40, 50, 100, 200, 400, 600, 800, and 1000 epochs (result images 10.png through 1000.png omitted here).

Finally

This time I implemented and trained DCGAN. In the end, the quality of the generated images was not very high; the dataset is not annotated either, so its quality is probably part of the problem. Also, generating 64x64 images was manageable this time, but as the resolution increases the training time with DCGAN would become enormous, so I would like to study other models as well.
