[PYTHON] Super résolution avec SRGAN et ESRGAN

SRGAN SRGAN est un algorithme qui utilise un réseau de neurones pour augmenter la résolution des images, et cette fois je l'ai implémenté. référence https://qiita.com/pacifinapacific/items/ec338a500015ae8c33fe https://buildersbox.corp-sansan.com/entry/2019/04/29/110000

Mis en œuvre pour le moment

github https://github.com/AokiMasataka/Super-resolution L'ensemble de données utilise le même SRResNet que j'ai créé il y a longtemps. Article SResNet https://qiita.com/AokiMasataka/items/3d382310d8a78f711c71 Le réseau sera implémenté dans PyTorch ainsi que la pratique de PyTorch. Le réseau de générateurs de SRGAN se compose de ResNet + Pixcelshuffer.

Si vous l'écrivez dans le code, cela ressemblera à ceci.

class ResidualBlock(nn.Module):
    def __init__(self, nf=64):
        super(ResidualBlock, self).__init__()
        self.Block = nn.Sequential(
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.BatchNorm2d(nf),
            nn.PReLU(),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.BatchNorm2d(nf),
        )

    def forward(self, x):
        out = self.Block(x)
        return x + out


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.relu = nn.PReLU()

        self.residualLayer = nn.Sequential(
            ResidualBlock(),
            ResidualBlock(),
            ResidualBlock(),
            ResidualBlock(),
            ResidualBlock()
        )

        self.pixelShuffle = nn.Sequential(
            nn.Conv2d(64, 64*4, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.PixelShuffle(2),
            nn.Conv2d(64, 3, kernel_size=9, padding=4),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.conv1(x)
        skip = self.relu(x)

        x = self.residualLayer(skip)
        x = self.pixelShuffle(x + skip)
        return x

Discriminator utilise un réseau convolutif non conventionnel. La taille de l'argument est la taille verticale et horizontale de l'image, cette fois la taille de l'image d'entrée est de 64x64.

class Discriminator(nn.Module):
    def __init__(self, size=64):
        super(Discriminator, self).__init__()
        size = int(size / 8) ** 2

        self.net = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            Flatten(),
            nn.Linear(128 * size, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        return self.net(x)

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)

Generator loss Vggloss est utilisé pour la perte du générateur, vggloss est plus clair en laissant la moyenne de ses caractéristiques être la perte à travers les couches du modèle vgg entraîné, tandis que mseloss est la moyenne des pixels de l'image. Générez une image.

class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        vgg = models.vgg16(pretrained=True)
        self.contentLayers = nn.Sequential(*list(vgg.features)[:31]).cuda().eval()
        for param in self.contentLayers.parameters():
            param.requires_grad = False

    def forward(self, fakeFrame, frameY):
        MSELoss = nn.MSELoss()
        content_loss = MSELoss(self.contentLayers(fakeFrame), self.contentLayers(frameY))
        return content_loss

La perte du générateur est la somme de ce content_loss et de la sortie BCE Loss du discriminateur. Sur cette base, nous allons créer une fonction de train

def train(loader):
    tensor_x, tensor_y = torch.tensor(x, dtype=torch.float), torch.tensor(y, dtype=torch.float)
    DS = TensorDataset(tensor_x, tensor_y)
    loader = DataLoader(DS, batch_size=BATCH_SIZE, shuffle=True)
    D.train()
    G.train()

    D_optimizer = torch.optim.Adam(D.parameters(), lr=DiscriminatorLR, betas=(0.9, 0.999))
    G_optimizer = torch.optim.Adam(G.parameters(), lr=GeneratorLR, betas=(0.9, 0.999))

    realLabel = torch.ones(BATCH_SIZE, 1).cuda()
    fakeLabel = torch.zeros(BATCH_SIZE, 1).cuda()
    BCE = torch.nn.BCELoss()
    VggLoss = VGGLoss()

    for batch_idx, (X, Y) in enumerate(loader):
        if X.shape[0] < BATCH_SIZE:
            break

        X = X.cuda()
        Y = Y.cuda()

        fakeFrame = G(X)

        D.zero_grad()
        DReal = D(Y)
        DFake = D(fakeFrame)

        D_loss = (BCE(DFake, fakeLabel) + BCE(DReal, realLabel)) / 2
        D_loss.backward(retain_graph=True)
        D_optimizer.step()

        G.zero_grad()
        G_label_loss= BCE(DFake, realLabel)
        G_loss = VggLoss(fakeFrame, Y) + 1e-3 * G_label_loss

        G_loss.backward()
        G_optimizer.step()

        print("G_loss :", G_loss, " D_loss :", D_loss)

L'image ci-dessous montre le résultat d'un entraînement à 32 époques. Le haut est une image réduite, le milieu est la sortie dans SRGAN et le bas est l'image d'origine. Se sentir pas mal pour la précision, SRGAN.png

ESRGAN

Différence avec SRGAN

RRDN(Residual in Residual Dense Network) ・ Il semble que la capacité de production augmentera en supprimant la normalisation par lots. DenseBlock ajoute une sortie de couche à toutes les entrées de couche ・ De plus, connectez trois blocs denses de la même manière que ResNet. residual-in-residual-dense-block-RRDB.png Une fois mis en œuvre, il ressemble à ceci

class ResidualDenseBlock(nn.Module):
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(nf, gc, 3, padding=1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, padding=1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, padding=1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, padding=1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, padding=1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), dim=1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), dim=1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), dim=1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), dim=1))
        return x5 * 0.2 + x


class Generator(nn.Module):
    def __init__(self, nf=64):
        super(Generator, self).__init__()
        self.conv1 = nn.Conv2d(3, nf, kernel_size=3, padding=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.blockLayer = nn.Sequential(
            ResidualDenseBlock(),
            ResidualDenseBlock(),
            ResidualDenseBlock(),
        )

        self.pixelShuffle = nn.Sequential(
            nn.Conv2d(nf, nf * 4, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.PixelShuffle(2),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.Conv2d(nf, 3, kernel_size=3, padding=1),
            nn.Tanh()
        )

Relativistic GAN Le discriminateur SRGAN entraîne le vrai à produire 1 faux comme 0, mais le GAN relativiste compare l'image réelle avec la fausse image et définit la différence et l'étiquette comme BC Eloss. référence https://github.com/Yagami360/MachineLearning-Papers_Survey/issues/51 VGG Perceptual Loss Dans SRGAN, la quantité de caractéristiques a été extraite à l'aide de VGG16, mais dans Perceptual Loss, la structure est telle que L1_loss pour chaque couche de mise en commun de VGG16 est ajoutée. Ça ressemble à ça quand je l'écris à peu près

class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        blocks = []
        blocks.append(models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(models.vgg16(pretrained=True).features[16:23].eval())
        blocks.append(models.vgg16(pretrained=True).features[23:30].eval())
        for bl in blocks:
            for p in bl:
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks).cuda()
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), requires_grad=False).cuda()
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), requires_grad=False).cuda()

    def forward(self, fakeFrame, frameY):
        fakeFrame = (fakeFrame - self.mean) / self.std
        frameY = (frameY - self.mean) / self.std
        loss = 0.0
        x = fakeFrame
        y = frameY
        for block in self.blocks:
            x = block(x)
            y = block(y)
            loss += torch.nn.functional.l1_loss(x, y)
        return loss

Résultat d'apprentissage

Le haut est une image réduite, le milieu est la sortie avec ESRGAN, et le bas est l'image originale Comme SRGAN, 32epoch passe de 32px à 64px. ESRGAN.png Comparons les images générées côte à côte, le haut est SRGAN et le bas est ESRGAN. Le bruit est perceptible dans SRGAN, mais moins dans ESRGAN, et le contour général est plus clair que dans SRGAN. SRGAN.png ESRGAN.png

Recommended Posts

Super résolution avec SRGAN et ESRGAN
Avec et sans WSGI
Avec moi, cp et sous-processus
Programmation avec Python et Tkinter
Travailler avec le tkinter et la souris
Python et matériel - Utilisation de RS232C avec Python -
Group_by avec sqlalchemy et sum
python avec pyenv et venv
Avec moi, NER et Flair
Fonctionne avec Python et R
Communiquez avec FX-5204PS avec Python et PyUSB
Briller la vie avec Python et OpenCV
Fonctionnement de la souris et du clavier Python avec pyautogui
Tri avec un mélange de chiffres et de lettres
Robot fonctionnant avec Arduino et python
Installez Python 2.7.9 et Python 3.4.x avec pip.
Réseau neuronal avec OpenCV 3 et Python 3
Scraping avec Node, Ruby et Python
Easy Slackbot avec Docker et Errbot
Segmentation d'image avec scikit-image et scikit-learn
Processus d'authentification avec gRPC et authentification Firebase
Grattage avec Python, Selenium et Chromedriver
Jouez avec la série Poancare et SymPy
HTTPS avec Django et Let's Encrypt
Segmentation et regroupement de photos avec DBSCAN
Grattage avec Python et belle soupe
Sauvegarde NAS avec php et rsync
Encodage et décodage JSON avec python
Traitement de chemin avec take while et drop while
Authentification de base, authentification Digest avec Flask
Introduction à Hadoop et MapReduce avec Python
[GUI en Python] PyQt5-Glisser-déposer-
Comparez DCGAN et pix2pix avec Keras
Introduisez errBot et travaillez avec Slack
Enregistrer et récupérer des fichiers avec Pepper
Async / await avec Kivy et tkinter
J'ai joué avec PyQt5 et Python3
Connectez-vous avec PycURL et recevez une réponse
Expérimenté avec unicode, décoder et encoder
Lire et écrire du CSV avec Python
Intégration multiple avec Python et Sympy
Coexistence de Python2 et 3 avec CircleCI (1.0)
Dessinez des figures avec OpenCV et PIL
Jeu Sugoroku et jeu d'addition avec Python
Télécharger et télécharger des images avec Falcon
Modulation et démodulation FM avec Python
Créer un environnement avec pyenv et pyenv-virtualenv