[PYTHON] NikuGan ~ I want to see a lot of delicious meat!!



Introduction

Back at my parents' house on New Year's Eve with nothing to do, I came up with the idea of generating images of **delicious** meat with DCGAN. (I did it in 4 hours, so the quality is extremely low ...)

Collecting images of real delicious meat

I collected images of **delicious-looking** meat from the following site.

onikuimages

お肉

I had wanted to use this site ever since I saw it on Twitter, so I took the chance this time. I downloaded 60 images from it and used them as the real images.

(To do this properly you would scrape and collect many more images, but I didn't because I had no time.)

Implement DCGAN

My implementation is here. Please refer to it if you want to try this properly. I wrote it with reference to the official PyTorch DCGAN tutorial and an article by hkthirano.

Dataset preprocessing

In the preprocessing, every image is resized and center-cropped to 64 × 64. I also applied the same normalization as in the official tutorial.

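import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

image_size = 64
batch_size = 2
workers = 0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the images and preprocess them.
# IMG_DIR is the dataset root; ImageFolder expects the images inside at least one subfolder.
dataset = datasets.ImageFolder(IMG_DIR,
                                transform=transforms.Compose([
                                    transforms.Resize(image_size),      # shorter side -> 64 px
                                    transforms.CenterCrop(image_size),  # 64 x 64 square
                                    transforms.ToTensor(),              # pixel values in [0, 1]
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # -> [-1, 1]
                                ]))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers)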
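
The `Normalize` step with mean 0.5 and standard deviation 0.5 maps the [0, 1] pixel values produced by `ToTensor` to [-1, 1], the same range as the `Tanh` at the end of the Generator. Here is a minimal sketch of that arithmetic in plain Python; the helper names `normalize` and `denormalize` are made up for illustration and are not torchvision functions.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Per-channel formula applied by Normalize: (x - mean) / std."""
    return (pixel - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping, useful when displaying generated images."""
    return value * std + mean


# Black, mid-gray, and white pixels after ToTensor are 0.0, 0.5, and 1.0.
black, gray, white = normalize(0.0), normalize(0.5), normalize(1.0)
print(black, gray, white)  # -1.0 0.0 1.0
```

When plotting, `make_grid(..., normalize=True)` rescales the values back into a displayable range, so the inverse mapping rarely needs to be written by hand.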

Let's take a peek at the preprocessed images:

import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils as vutils

# Take a look at one batch from the dataloader
real_batch = next(iter(dataloader))
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.title('Training images')
grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True)
plt.imshow(np.transpose(grid.cpu(), (1, 2, 0)))  # (C, H, W) -> (H, W, C)

[Image: dataloader]

The preprocessing seems to have worked properly. However, because I lowered the resolution, the meat no longer looks all that **delicious** at this point ...

Generator

The Generator's job is to create fake images that the Discriminator cannot tell apart from real ones. The latent variable has 100 dimensions, and a 64 × 64 × 3 image is generated from it.

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # Input: latent vector z of shape (N, 100, 1, 1)
            nn.ConvTranspose2d(
                in_channels=100,
                out_channels=256,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False
            ),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # -> (N, 256, 4, 4)
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # -> (N, 128, 8, 8)
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # -> (N, 64, 16, 16)
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            # -> (N, 32, 32, 32)
            nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),
            nn.Tanh()
            # -> (N, 3, 64, 64), values in [-1, 1]
        )

    def forward(self, x):
        return self.main(x)
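
To see how a 1 × 1 latent input grows into a 64 × 64 image, the spatial size after each transposed convolution can be worked out by hand. The sketch below uses the output-size formula for `ConvTranspose2d` and assumes the defaults `dilation=1` and `output_padding=0`; the helper `conv_transpose_out` is a made-up name for illustration.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output size of a transposed convolution (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel_size, stride, padding) of the five ConvTranspose2d layers above
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the latent vector is treated as a 1 x 1 "image" with 100 channels
sizes = []
for kernel, stride, padding in layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [4, 8, 16, 32, 64]
```

The first layer expands 1 × 1 to 4 × 4, and each of the following stride-2 layers doubles the side length.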

Discriminator

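The Discriminator's job is to correctly tell real images apart from the fake images created by the Generator.

It takes a 64 × 64 × 3 image and outputs one scalar score per image that indicates whether the image is real (label 1) or fake (label 0). The network below ends without a Sigmoid, so the score is a raw value rather than a probability.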

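class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # Input: image of shape (N, 3, 64, 64)
            nn.Conv2d(3, 32, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 32, 32, 32)
            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 64, 16, 16)
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 128, 8, 8)
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 256, 4, 4)
            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
            # -> (N, 1, 1, 1), one raw score per image
        )

    def forward(self, x):
        # Flatten to shape (N,); unlike squeeze(), this keeps the batch
        # dimension even when a batch contains a single image
        return self.main(x).view(-1)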
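
The Discriminator runs in the opposite direction and shrinks a 64 × 64 image down to a single value. The sketch below checks this with the output-size formula for `Conv2d`, assuming the default `dilation=1`; the helper `conv_out` is a made-up name for illustration.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a regular convolution (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel_size, stride, padding) of the five Conv2d layers above
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64  # side length of the input image
sizes = []
for kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [32, 16, 8, 4, 1]
```

Each stride-2 layer halves the side length, and the final 4 × 4 kernel collapses the remaining 4 × 4 feature map into one score.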

Adversarial training

While the Generator learns to create fake images that are hard to identify, the Discriminator learns to identify them correctly.

Repeating the following loop many times trains the two networks adversarially.

  1. Generate fake images from the latent variable z and train the Generator so that it fools the Discriminator (← the Generator improves)
  2. Train the Discriminator on real and fake images so that it tells them apart correctly (← the Discriminator improves)

from statistics import mean

# Training function: runs one epoch over the dataloader.
# params_G and params_D are the optimizers of the two models.
# nz (latent size, 100), loss_f, ones, zeros, and device are globals defined elsewhere.
def train_dcgan(model_G, model_D, params_G, params_D, dataloader):
    log_loss_G = []
    log_loss_D = []
    for real_img, _ in dataloader:
        batch_len = len(real_img)

        # === Train the Generator ===
        # Generate fake images from random latent vectors
        z = torch.randn(batch_len, nz, 1, 1).to(device)
        fake_img = model_G(z)

        # Keep a detached copy for the Discriminator step,
        # so the fake images do not have to be generated twice
        fake_img_tensor = fake_img.detach()

        # Generator loss: the Discriminator should judge the fakes as real (label 1)
        out = model_D(fake_img)
        loss_G = loss_f(out, ones[:batch_len])
        log_loss_G.append(loss_G.item())

        # Update the Generator
        model_D.zero_grad()
        model_G.zero_grad()
        loss_G.backward()
        params_G.step()

        # === Train the Discriminator ===
        real_img = real_img.to(device)

        # Loss for judging real images as real (label 1)
        real_out = model_D(real_img)
        loss_D_real = loss_f(real_out, ones[:batch_len])

        # Loss for judging the saved fake images as fake (label 0)
        fake_out = model_D(fake_img_tensor)
        loss_D_fake = loss_f(fake_out, zeros[:batch_len])

        # Sum the real and fake losses
        loss_D = loss_D_real + loss_D_fake
        log_loss_D.append(loss_D.item())

        # Update the Discriminator
        model_D.zero_grad()
        model_G.zero_grad()
        loss_D.backward()
        params_D.step()

    return mean(log_loss_G), mean(log_loss_D)
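
The post does not show how `loss_f`, `ones`, and `zeros` are defined. Since the Discriminator ends without a Sigmoid, one plausible choice is a binary cross-entropy that works on raw scores, such as `torch.nn.BCEWithLogitsLoss`, with `ones` and `zeros` as label tensors; this is an assumption, not something the source confirms. Under that assumption, the per-image loss can be sketched in plain Python (the function `bce_with_logits` and the example score are hypothetical):

```python
import math


def bce_with_logits(logit, target):
    """Binary cross-entropy on a raw score, written in a numerically stable form."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


score = 2.0  # hypothetical Discriminator score for one fake image ("looks real")

loss_g = bce_with_logits(score, 1.0)       # Generator step: target label 1 (real)
loss_d_fake = bce_with_logits(score, 0.0)  # Discriminator step: target label 0 (fake)

print(round(loss_g, 4), round(loss_d_fake, 4))
```

The same score gives a small loss for the Generator and a large one for the Discriminator, which is the tug-of-war the two training steps act out.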

Train and generate!

I trained for 1000 epochs with a batch size of 2. The GIF below shows the generated images every 100 epochs.

[Image: first_gif.gif]

How is it? Doesn't it look like meat??

At first the output is just noise, but around 300-500 epochs it looks like meat on a white plate against a white background. After 500 epochs, however, it goes back to being just meat on a black background ... (Maybe epoch 500 looks the most like the real images?)

[Image: image.png]

Postscript (2021/1/1): I trained again for 5000 epochs with a batch size of 8. The results look meatier than last time. However, only similar-looking images are generated, so mode collapse has occurred. Perhaps the 100-dimensional latent vector is too weak? The generated meat lacks diversity, but I feel it is closer to real meat.

[Image: 8_5000.gif]

Reflections and impressions

I think the images did not come out cleanly because of the image quality and the number of images. I had gone out of my way to collect **delicious**-looking images, so it felt like a waste to lower their resolution for training. And since all the images came from a single site, there were far too few of them.

It was great to go from a passing idea to generated meat before the end of 2020. DCGAN is amazing: it can generate something meat-like from only 60 images!!

If I find the time, I'd like to generate higher-quality **delicious** meat!
