[PYTHON] J'ai essayé d'implémenter et d'apprendre DCGAN avec PyTorch

Que faire cette fois

Depuis que j'ai créé un modèle de traitement d'image qui effectue une super-résolution auparavant, j'ai décidé de créer un modèle qui génère ensuite des images. Par conséquent, je pense que le GAN est le moyen le plus important de générer des images, et je pense à mettre en œuvre et à apprendre DCGAN, qui est un modèle relativement simple.

À propos de DCGAN

L'idée du GAN (Generative adversarial network) est simple, et la concurrence entre le créateur du faux et l'évaluateur qui le voit est de créer un faux plus précis. Par conséquent, la structure du réseau -Générateur qui génère une image avec un bruit approprié en entrée ・ Discriminateur qui juge l'authenticité en saisissant une image Il est essentiellement composé de deux choses. Pour le mécanisme détaillé, GAN (1) Comprendre la structure de base que je ne peux plus demander j'ai étudié ici. DCGAN (Deep Convolutional Generative adversarial network) est un modèle qui génère des images en utilisant la couche de convolution inverse.

la mise en oeuvre

Le code a été implémenté dans les quatre fichiers suivants. ・ Networks.py: Rédaction sur la structure du réseau -Utils.py: Écriture sur les fonctions de jeu de données et de perte ・ Train.py: Former le modèle -Generate.py: Génère une image en utilisant le modèle appris. networks.py Ici, nous définissons la structure du générateur et du discriminateur. Cette fois, il a une structure simple car il s'agit de DCGAN.

networks.py



import torch
from torch import nn

class Generator(nn.Module):
    def __init__(self, latent_size = 100):
        super().__init__()
        self.main = nn.Sequential(

            nn.ConvTranspose2d(latent_size, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(

            nn.Conv2d(3, 32, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x):
        return self.main(x).squeeze()

utils.py

Ici, nous définissons l'ensemble de données pour une utilisation facile pendant la formation.

utils.py


import os

from torch.utils.data import Dataset
from torchvision import transforms
import torch
from PIL import Image

class Dcgan_Dataset(Dataset):
    def __init__(self, root, datamode = "train", transform = transforms.ToTensor(), latent_size=100):

        self.image_dir = os.path.join(root, datamode)
        self.image_paths = [os.path.join(self.image_dir, name) for name in os.listdir(self.image_dir)]
        self.data_length = len(self.image_paths)

        self.transform = transform
        self.latent_size = latent_size

    def __len__(self):
        return self.data_length
    
    def __getitem__(self, index):
        latent = torch.randn(size=(self.latent_size, 1, 1))
        img_path = self.image_paths[index]
        img = Image.open(img_path)

        if not self.transform is None:
            img = self.transform(img)
        
        return latent, img


train.py

Ici, nous définissons l'apprentissage. Cela fait un peu longtemps parce que je le teste et que je l'enregistre sur le tensorboard.


import os
import argparse

from networks import Generator, Discriminator
from utils import Dcgan_Dataset

import numpy as np

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

from tqdm import tqdm



def main(opt):
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)
    # ----- Device Setting -----
    if opt.gpu is True:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")
    print("Device :", device)

    # ----- Dataset Setting -----
    train_dataset = Dcgan_Dataset(opt.dataset, datamode="train",
                                  transform=transforms.Compose([transforms.RandomHorizontalFlip(),
                                                                transforms.ToTensor()]))
    
    test_dataset = Dcgan_Dataset(opt.dataset, datamode="test")

    print("Training Dataset :", os.path.join(opt.dataset, "train"))
    print("Testing Dataset :", os.path.join(opt.dataset, "test"))

    # ----- DataLoader Setting -----
    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=opt.test_batch_size, shuffle=True)

    print("batch_size :",opt.batch_size)
    print("test_batch_size :",opt.test_batch_size)

    # ----- Summary Writer Setting -----
    train_writer = SummaryWriter(log_dir=os.path.join(opt.tensorboard, opt.exper))
    test_writer = SummaryWriter(log_dir=os.path.join(opt.tensorboard, opt.exper + "_test"))

    print("log directory :",os.path.join(opt.tensorboard, opt.exper))
    print("log step :", opt.n_log_step)

    # ----- Net Work Setting -----
    latent_size = opt.latent_size
    model_D = Discriminator()
    model_G = Generator()

    # resume
    if opt.resume_epoch != 0:
        model_D_path = os.path.join(opt.checkpoints_dir, opt.exper, "model_D_{}.pth".format(str(opt.resume_epoch)))
        model_G_path = os.path.join(opt.checkpoints_dir, opt.exper, "model_G_{}.pth".format(str(opt.resume_epoch)))

        model_G.load_state_dict(torch.load(model_G_path, map_location="cpu"))
        model_D.load_state_dict(torch.load(model_D_path, map_location="cpu"))

    model_D.to(device)
    model_G.to(device)
    model_D.train()
    model_G.train()

    #Variable d'étiquette lors du calcul de la perte
    ones = torch.ones(opt.batch_size).to(device) #Exemple positif 1
    zeros = torch.zeros(opt.batch_size).to(device) #Exemple négatif 0

    val_latents = torch.randn(9, opt.latent_size, 1, 1).to(device)
    loss_f = nn.BCEWithLogitsLoss()

    optimizer_D = torch.optim.Adam(model_D.parameters(), lr=0.0002)
    optimizer_G = torch.optim.Adam(model_G.parameters(), lr=0.0002)

    print("Latent size :",opt.latent_size)

    # ----- Training Loop -----
    step = 0
    for epoch in tqdm(range(opt.resume_epoch, opt.resume_epoch + opt.epoch)):
        print("epoch :",epoch + 1,"/", opt.resume_epoch + opt.epoch)

        # for latent, real_img in tqdm(train_loader):
        for latent, real_img in train_loader:
            step += 1
            latent = latent.to(device)
            real_img = real_img.to(device)
            batch_len = len(real_img)

            fake_img = model_G(latent)

            pred_fake = model_D(fake_img)
            loss_G = loss_f(pred_fake, ones[: batch_len])

            model_D.zero_grad()
            model_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()

            pred_real = model_D(real_img)
            loss_D_real = loss_f(pred_real, ones[: batch_len])
            fake_img = model_G(latent)
            pred_fake = model_D(fake_img)
            loss_D_fake = loss_f(pred_fake, zeros[: batch_len])
            loss_D = loss_D_real + loss_D_fake

            model_D.zero_grad()
            model_G.zero_grad()
            loss_D.backward()
            optimizer_D.step()

            if step % opt.n_log_step == 0:
                # test step 
                model_G.eval()
                model_D.eval()
                test_d_losses = []
                test_d_real_losses = []
                test_d_fake_losses = []
                test_g_losses = []
                for test_latent, test_real_img in test_loader:
                    test_latent = test_latent.to(device)
                    test_real_img = test_real_img.to(device)
                    batch_len = len(test_latent)
                    test_pred_img = model_G(test_latent)
                    test_fake_g = model_D(test_pred_img)
                    test_g_loss = loss_f(test_fake_g, ones[: batch_len])

                    test_g_losses.append(test_g_loss.item())

                    test_fake_d = model_D(test_pred_img)
                    test_real_d = model_D(test_real_img)
                    test_d_real_loss = loss_f(test_real_d, ones[: batch_len])
                    test_d_fake_loss = loss_f(test_fake_d, zeros[: batch_len])
                    test_d_loss = test_d_real_loss + test_d_fake_loss

                    test_d_real_losses.append(test_d_real_loss.item())
                    test_d_fake_losses.append(test_d_fake_loss.item())
                    test_d_losses.append(test_d_loss.item())
                
                # record process
                test_g_loss = sum(test_g_losses)/len(test_g_losses)
                test_d_loss = sum(test_d_losses)/len(test_d_losses)
                test_d_real_loss = sum(test_d_real_losses)/len(test_d_real_losses)
                test_d_fake_loss = sum(test_d_fake_losses)/len(test_d_fake_losses)


                train_writer.add_scalar("loss/g_loss", loss_G.item(), step)
                train_writer.add_scalar("loss/d_loss", loss_D.item(), step)
                train_writer.add_scalar("loss/d_real_loss", loss_D_real.item(), step)
                train_writer.add_scalar("loss/d_fake_loss", loss_D_fake.item(), step)
                train_writer.add_scalar("loss/epoch", epoch + 1, step)

                test_writer.add_scalar("loss/g_loss", test_g_loss, step)
                test_writer.add_scalar("loss/d_loss", test_d_loss, step)
                test_writer.add_scalar("loss/d_real_loss", test_d_real_loss, step)
                test_writer.add_scalar("loss/d_fake_loss", test_d_fake_loss, step)

                pred_img = model_G(val_latents)
                grid_img = make_grid(pred_img, nrow=3, padding=0)
                grid_img = grid_img.mul(0.5).add_(0.5)

                train_writer.add_image("train/{}".format(epoch), grid_img, step)

                model_D.train()
                model_G.train()
                
        if (epoch + 1) % opt.n_save_epoch == 0:
            save_dir = os.path.join(opt.checkpoints_dir, opt.exper)

            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            model_g_path = os.path.join(save_dir, "model_G_{}.pth".format(str(epoch + 1)))
            model_d_path = os.path.join(save_dir, "model_D_{}.pth".format(str(epoch + 1)))
            torch.save(model_D.state_dict(), model_d_path)
            torch.save(model_G.state_dict(), model_g_path)

            print("save_model")

    # save model
    save_dir = os.path.join(opt.checkpoints_dir, opt.exper)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    model_g_path = os.path.join(save_dir, "model_G_{}.pth".format(str(opt.epoch)))
    model_d_path = os.path.join(save_dir, "model_D_{}.pth".format(str(opt.epoch)))
    torch.save(model_D.state_dict(), model_d_path)
    torch.save(model_G.state_dict(), model_g_path)
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="../dataset/face_crop_img")
    parser.add_argument("--checkpoints_dir", default="../checkpoints")
    parser.add_argument("--exper", default="dcgan")
    parser.add_argument("--tensorboard", default="../tensorboard")
    parser.add_argument("--gpu", action="store_true", default=False)
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--test_batch_size", type=int, default=4)
    parser.add_argument("--n_log_step", type=int, default=10)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--n_save_epoch", type=int, default=10)

    parser.add_argument("--latent_size", type=int, default=100)

    # resume
    parser.add_argument("--resume_epoch", type=int, default=0)

    opt = parser.parse_args()
    main(opt)

generate.py Ici, nous définissons le comportement de chargement du modèle et de génération de l'image.

generate.py


import os
import argparse

from networks import Generator

import numpy as np
from tqdm import tqdm
import torch
from torchvision.utils import save_image


def generate(latents, model_G, width, height, save_path):
    assert len(latents) == width * height
    pred_images = model_G(latents)
    save_image(pred_images, save_path, nrow=width)

def generate_process(opt):

    # ----- Device Setting -----
    device = torch.device("cpu")

    # ----- Output Setting -----

    img_output_dir = opt.output_dir
    if not os.path.exists(img_output_dir):
        os.mkdir(img_output_dir)

    if not os.path.exists(os.path.join(img_output_dir, "images")):
        os.mkdir(os.path.join(img_output_dir, "images"))
    
    if opt.save_latent is True:
        if not os.path.exists(os.path.join(img_output_dir, "latents")):
            os.mkdir(os.path.join(img_output_dir, "latents"))
    
    print("Output :", img_output_dir)

    # ----- Model Loading -----
    print("Use model :", opt.model)
    model_g = Generator()
    model_g.load_state_dict(torch.load(opt.model, map_location="cpu"))
    model_g.to(device)

    model_g.eval()

    if opt.mode == "normal":
        latents = [torch.randn(size=(opt.width * opt.height, opt.latent_size, 1, 1) for i in range(opt.n_img)]
    
    elif opt.mode == "use_latent":
        assert opt.latent_dir != "None", "latent source directory is not set"
        latent_paths = [os.path.join(opt.latent_dir, name) for name in os.listdir(opt.latent_dir)]
        latents = [torch.from_numpy(np.load(path)) for path in latent_paths]

    elif opt.mode == "inter":
        latent_start = torch.from_numpy(np.load(opt.start_latent))
        latent_end = torch.from_numpy(np.load(opt.end_latent))
        alphas = [float(n / opt.latent_num) for n in range(opt.n_img)]
        latents = [alpha * latent_end + (1 - alpha) * latent_start for alpha in alphas]

    print("Generate image num :", len(latents))

    # ----- Generate Step -----
    print("Start Generate Process")
    for index,latent in tqdm(enumerate(latents)):
        img_path = os.path.join(img_output_dir, "images", str(index + 1) + ".png ")
        generate(latent, model_g, opt.width, opt.height, img_path)

        if opt.save_latent:
            latent_path = os.path.join(img_output_dir, "latents", str(index + 1) + ".npy")
            np.save(latent_path, latent.numpy())

    print("Finish Generate Process")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="../checkpoints/dcgan_test/model_G_1000.pth")
    parser.add_argument("--output_dir", default="./result")
    parser.add_argument("--save_latent", action="store_true", default=False)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--width", type=int, default=1)
    parser.add_argument("--height", type=int, default=1)
    parser.add_argument("--mode", choice=["normal", "use_latent", "inter"], default="normal")
    parser.add_argument("--n_img", type=int, default=1)

    # normal generation
    # use latent
    parser.add_argument("--latent_dir", default="None")

    # intermediate vectl
    parser.add_argument("--generate_inter", action="store_true", default=False)
    parser.add_argument("--start_latent", type=str, default=".")
    parser.add_argument("--end_latent", type=str, default=".")


    opt = parser.parse_args()
    generate(opt)
    

Résultat d'apprentissage

L'ensemble de données a été réalisé avec 18 000 images animées de visage et l'apprentissage a été réalisé à l'aide de Google Colaboratory. Voici les résultats de l'apprentissage. 10epoch 10.png 20epoch 20.png 30epoch 30.png 40epoch 40.png 50epoch 50.png 100epoch 100.png 200epoch 200.png 400epoch 400.png 600epoch 600.png 800epoch 800.png 1000epoch 1000.png

finalement

Cette fois, nous avons implémenté et appris DCGAN. Après tout, la précision de l'image générée n'était pas si élevée. Le jeu de données n'est pas non plus annoté, il semble donc être de mauvaise qualité. De plus, cette fois, ce n'était pas si difficile car cela générait une image 64x64, mais à mesure que la résolution augmente, je pense que le temps d'apprentissage deviendra énorme avec DCGAN, donc d'autres modèles Je voudrais étudier.

Recommended Posts

J'ai essayé d'implémenter et d'apprendre DCGAN avec PyTorch
J'ai essayé d'implémenter CVAE avec PyTorch
J'ai essayé d'implémenter la lecture de Dataset avec PyTorch
J'ai essayé d'implémenter Grad-CAM avec keras et tensorflow
J'ai essayé d'implémenter SSD avec PyTorch maintenant (Dataset)
J'ai essayé d'implémenter SSD avec PyTorch maintenant (édition du modèle)
J'ai essayé d'implémenter Autoencoder avec TensorFlow
J'ai essayé d'implémenter la classification des phrases par Self Attention avec PyTorch
J'ai implémenté DCGAN et essayé de générer des pommes
J'ai essayé d'apprendre PredNet
J'ai essayé d'implémenter PCANet
J'ai essayé d'apprendre l'angle du péché et du cos avec le chainer
J'ai essayé d'implémenter StarGAN (1)
J'ai essayé d'entraîner la fonction péché avec chainer
J'ai essayé de lire et d'enregistrer automatiquement avec VOICEROID2 2
J'ai essayé d'implémenter Mine Sweeper sur un terminal avec python
J'ai essayé de lire et d'enregistrer automatiquement avec VOICEROID2
J'ai essayé d'implémenter le perceptron artificiel avec python
[Introduction à Pytorch] J'ai essayé de catégoriser Cifar10 avec VGG16 ♬
J'ai essayé d'implémenter Deep VQE
J'ai essayé de mettre en place une validation contradictoire
J'ai essayé d'expliquer l'ensemble de données de Pytorch
J'ai essayé d'implémenter DeepPose avec PyTorch
J'ai essayé d'implémenter Realness GAN
J'ai essayé d'implémenter une ligne moyenne mobile de volume avec Quantx
J'ai essayé de prédire et de soumettre les survivants du Titanic avec Kaggle
J'ai essayé de mettre en œuvre une évasion (type d'évitement de tromperie) avec Quantx
J'ai essayé de classer MNIST par GNN (avec PyTorch géométrique)
J'ai essayé d'implémenter ListNet d'apprentissage de rang avec Chainer
J'ai essayé de mettre en œuvre le chapeau de regroupement de Harry Potter avec CNN
J'ai essayé de créer une interface graphique à trois yeux côte à côte avec Python et Tkinter
J'ai essayé d'implémenter PLSA en Python
J'ai essayé d'implémenter la permutation en Python
J'ai essayé de visualiser AutoEncoder avec TensorFlow
J'ai essayé de visualiser les signets volant vers Slack avec Doc2Vec et PCA
J'ai essayé de commencer avec Hy
J'ai essayé d'implémenter PLSA dans Python 2
J'ai essayé de faire un processus d'exécution périodique avec Selenium et Python
J'ai essayé d'implémenter ADALINE en Python
[Introduction à Pytorch] J'ai joué avec sinGAN ♬
J'ai essayé d'implémenter DeepPose avec PyTorch PartⅡ
J'ai essayé d'implémenter PPO en Python
J'ai essayé de créer des taureaux et des vaches avec un programme shell
J'ai essayé de détecter facilement les points de repère du visage avec python et dlib
J'ai essayé de résoudre TSP avec QAOA
J'ai essayé de mettre en œuvre un apprentissage en profondeur qui n'est pas profond avec uniquement NumPy
J'ai essayé de mettre en œuvre une blockchain qui fonctionne réellement avec environ 170 lignes
J'ai essayé d'exprimer de la tristesse et de la joie face au problème du mariage stable.
[Deep Learning from scratch] J'ai essayé d'implémenter la couche sigmoïde et la couche Relu
J'ai essayé de convertir la chaîne datetime <-> avec tzinfo en utilisant strftime () et strptime ()
J'ai essayé de contrôler la bande passante et le délai du réseau avec la commande tc
J'ai essayé de prédire l'année prochaine avec l'IA
J'ai essayé d'utiliser lightGBM, xg boost avec Boruta
J'ai essayé de sauvegarder les données avec discorde
J'ai essayé de détecter rapidement un mouvement avec OpenCV
J'ai essayé d'intégrer Keras dans TFv1.1
J'ai essayé de laisser VAE apprendre les animations