Super-resolution with SRGAN and ESRGAN

SRGAN

SRGAN is an algorithm that uses a neural network to increase the resolution of images; this post covers my implementation of it.

References:
https://qiita.com/pacifinapacific/items/ec338a500015ae8c33fe
https://buildersbox.corp-sansan.com/entry/2019/04/29/110000

Implementation

GitHub: https://github.com/AokiMasataka/Super-resolution

The dataset is the same one I used for the SRResNet I made a while back. SRResNet article: https://qiita.com/AokiMasataka/items/3d382310d8a78f711c71

The network is implemented in PyTorch, partly as PyTorch practice. SRGAN's Generator consists of ResNet-style residual blocks plus PixelShuffle upsampling. Written in code, it looks like this:

import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, nf=64):
        super(ResidualBlock, self).__init__()
        self.Block = nn.Sequential(
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.BatchNorm2d(nf),
            nn.PReLU(),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.BatchNorm2d(nf),
        )

    def forward(self, x):
        out = self.Block(x)
        return x + out


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.relu = nn.PReLU()

        self.residualLayer = nn.Sequential(
            ResidualBlock(),
            ResidualBlock(),
            ResidualBlock(),
            ResidualBlock(),
            ResidualBlock()
        )

        self.pixelShuffle = nn.Sequential(
            nn.Conv2d(64, 64*4, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.PixelShuffle(2),
            nn.Conv2d(64, 3, kernel_size=9, padding=4),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.conv1(x)
        # keep the first activation as a global skip connection
        skip = self.relu(x)

        x = self.residualLayer(skip)
        # add the skip back, then upsample 2x with PixelShuffle
        x = self.pixelShuffle(x + skip)
        return x
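PixelShuffle trades channels for resolution: the conv expands 64 channels to 64 * 4, and PixelShuffle(2) rearranges them into 64 channels at twice the height and width. A quick illustrative shape check (not from the repository):

x = torch.randn(1, 256, 32, 32)
print(nn.PixelShuffle(2)(x).shape)         # torch.Size([1, 64, 64, 64])

# the full Generator therefore upscales 2x: 32px -> 64px
g = Generator()
print(g(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 3, 64, 64])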

The Discriminator is a fairly ordinary convolutional network. The size argument is the height and width of the input image; this time the input images are 64x64.

class Discriminator(nn.Module):
    def __init__(self, size=64):
        super(Discriminator, self).__init__()
        # three stride-2 convs shrink the input by 8x, so the flattened
        # feature map has 128 * (size / 8)^2 elements
        size = int(size / 8) ** 2

        self.net = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            Flatten(),
            nn.Linear(128 * size, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        return self.net(x)

class Flatten(nn.Module):
    # flattens (N, C, H, W) feature maps to (N, C*H*W) for the linear layers
    def forward(self, x):
        return x.view(x.shape[0], -1)
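Note that the last Linear layer is not followed by a sigmoid, so the Discriminator outputs raw logits. A quick illustrative check (not from the repository):

d = Discriminator(size=64)
scores = d(torch.randn(4, 3, 64, 64))  # raw logits, shape (4, 1)
probs = torch.sigmoid(scores)          # probabilities, if needed

Because of this, the training code below uses BCEWithLogitsLoss, which applies the sigmoid internally.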

Generator loss

VGG loss is used for the Generator. Whereas MSE loss averages the differences between raw image pixels, VGG loss averages the differences between feature maps produced by passing the images through a pretrained VGG model, which yields sharper generated images.

from torchvision import models


class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        vgg = models.vgg16(pretrained=True)
        # frozen VGG16 feature extractor for computing the content loss
        self.contentLayers = nn.Sequential(*list(vgg.features)[:31]).cuda().eval()
        for param in self.contentLayers.parameters():
            param.requires_grad = False

    def forward(self, fakeFrame, frameY):
        MSELoss = nn.MSELoss()
        content_loss = MSELoss(self.contentLayers(fakeFrame), self.contentLayers(frameY))
        return content_loss
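For reference, a minimal usage sketch (fake and real are assumed to be (N, 3, H, W) CUDA tensors):

criterion = VGGLoss()
fake = torch.rand(2, 3, 64, 64).cuda()
real = torch.rand(2, 3, 64, 64).cuda()
loss = criterion(fake, real)  # scalar MSE between VGG16 feature maps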

The Generator's loss is the sum of this content_loss and a BCE loss on the Discriminator's output. Based on these, we create the train function:

from torch.utils.data import TensorDataset, DataLoader


def train(x, y):
    tensor_x, tensor_y = torch.tensor(x, dtype=torch.float), torch.tensor(y, dtype=torch.float)
    DS = TensorDataset(tensor_x, tensor_y)
    loader = DataLoader(DS, batch_size=BATCH_SIZE, shuffle=True)
    D.train()
    G.train()

    D_optimizer = torch.optim.Adam(D.parameters(), lr=DiscriminatorLR, betas=(0.9, 0.999))
    G_optimizer = torch.optim.Adam(G.parameters(), lr=GeneratorLR, betas=(0.9, 0.999))

    realLabel = torch.ones(BATCH_SIZE, 1).cuda()
    fakeLabel = torch.zeros(BATCH_SIZE, 1).cuda()
    # the Discriminator outputs raw logits, so use the logit version of BCE
    BCE = torch.nn.BCEWithLogitsLoss()
    VggLoss = VGGLoss()

    for batch_idx, (X, Y) in enumerate(loader):
        # skip the last, smaller batch so the label tensors always match
        if X.shape[0] < BATCH_SIZE:
            break

        X = X.cuda()
        Y = Y.cuda()

        fakeFrame = G(X)

        # --- Discriminator update ---
        D.zero_grad()
        DReal = D(Y)
        # detach so this update does not backpropagate into the Generator
        DFake = D(fakeFrame.detach())

        D_loss = (BCE(DFake, fakeLabel) + BCE(DReal, realLabel)) / 2
        D_loss.backward()
        D_optimizer.step()

        # --- Generator update ---
        G.zero_grad()
        # re-evaluate D on the fake so gradients flow through the updated D
        G_label_loss = BCE(D(fakeFrame), realLabel)
        G_loss = VggLoss(fakeFrame, Y) + 1e-3 * G_label_loss

        G_loss.backward()
        G_optimizer.step()

        print("G_loss :", G_loss.item(), " D_loss :", D_loss.item())

The image below shows the result after 32 epochs of training. The top row is the downscaled image, the middle is the SRGAN output, and the bottom is the original image. Not bad, as far as accuracy goes.

[SRGAN.png]

ESRGAN

Differences from SRGAN

RRDB (Residual in Residual Dense Block)
- Removing batch normalization is said to improve generative capacity.
- A DenseBlock feeds each layer's output into the inputs of all subsequent layers.
- Three Dense Blocks are then connected with a residual connection, in the same way as ResNet.

[residual-in-residual-dense-block-RRDB.png]

Implemented, it looks like this:

class ResidualDenseBlock(nn.Module):
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(nf, gc, 3, padding=1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, padding=1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, padding=1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, padding=1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, padding=1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), dim=1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), dim=1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), dim=1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), dim=1))
        # residual scaling: scale the block output by 0.2 before the skip connection
        return x5 * 0.2 + x


class Generator(nn.Module):
    def __init__(self, nf=64):
        super(Generator, self).__init__()
        self.conv1 = nn.Conv2d(3, nf, kernel_size=3, padding=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.blockLayer = nn.Sequential(
            ResidualDenseBlock(),
            ResidualDenseBlock(),
            ResidualDenseBlock(),
        )

        self.pixelShuffle = nn.Sequential(
            nn.Conv2d(nf, nf * 4, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.PixelShuffle(2),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.Conv2d(nf, 3, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        # same global skip pattern as the SRGAN Generator
        skip = self.relu(self.conv1(x))
        x = self.blockLayer(skip)
        x = self.pixelShuffle(x + skip)
        return x
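Each conv in the dense block adds gc new channels to the concatenated input, and conv5 projects back to nf channels, so the block preserves its input shape; the generator again upscales 2x. A quick illustrative check (not from the repository):

blk = ResidualDenseBlock(nf=64, gc=32)
print(blk(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 64, 32, 32])

g = Generator()
print(g(torch.randn(1, 3, 32, 32)).shape)     # torch.Size([1, 3, 64, 64])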

Relativistic GAN

The SRGAN Discriminator is trained to output 1 for real images and 0 for fakes. A Relativistic GAN instead compares the real and fake images, feeding the difference between the two Discriminator outputs into the BCE loss (see the sketch after the Perceptual Loss code below). Reference: https://github.com/Yagami360/MachineLearning-Papers_Survey/issues/51

VGG Perceptual Loss

SRGAN extracted features from a single VGG16 layer; Perceptual Loss instead adds up an L1 loss over the feature maps at each pooling stage of VGG16. Written roughly, it looks like this:

class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        # split VGG16's feature extractor at each pooling stage
        blocks = []
        blocks.append(models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(models.vgg16(pretrained=True).features[16:23].eval())
        blocks.append(models.vgg16(pretrained=True).features[23:30].eval())
        for bl in blocks:
            # freeze the parameters, not the block modules themselves
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks).cuda()
        # ImageNet normalization constants expected by the pretrained VGG
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
        self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()

    def forward(self, fakeFrame, frameY):
        fakeFrame = (fakeFrame - self.mean) / self.std
        frameY = (frameY - self.mean) / self.std
        loss = 0.0
        x = fakeFrame
        y = frameY
        for block in self.blocks:
            x = block(x)
            y = block(y)
            loss += torch.nn.functional.l1_loss(x, y)
        return loss
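For completeness, here is a minimal sketch of the relativistic discriminator loss described above. This is my reading of the linked reference, not code from the repository; D is assumed to output raw logits and BCE to be BCEWithLogitsLoss:

def relativistic_D_loss(D, realFrame, fakeFrame, BCE, realLabel, fakeLabel):
    DReal = D(realFrame)
    DFake = D(fakeFrame.detach())
    # each output is judged relative to the mean output on the other class
    lossReal = BCE(DReal - DFake.mean(), realLabel)
    lossFake = BCE(DFake - DReal.mean(), fakeLabel)
    return (lossReal + lossFake) / 2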

Training results

The top row is the downscaled image, the middle is the ESRGAN output, and the bottom is the original image. As with SRGAN, training ran for 32 epochs, upscaling from 32px to 64px.

[ESRGAN.png]

Comparing the generated images side by side (top: SRGAN, bottom: ESRGAN), noise is noticeable in the SRGAN output, while ESRGAN shows less noise and overall sharper outlines.

[SRGAN.png]
[ESRGAN.png]
