[PYTHON] Ich habe ein Netzwerk erstellt, um Schwarzweißbilder in Farbbilder umzuwandeln (pix2pix)

1. Zuallererst

Ich möchte GAN (Hostile Generation Network) verwenden, um Graustufenbilder automatisch einzufärben. Es scheint technisch "pix2pix" genannt zu werden.

Dieses Graustufenbild ist
![0_gray.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/141993/66e565f7-1c0b-ca9b-2a7c-aabdc47b2977 .png) Es wurde automatisch wie folgt gefärbt !!
![0_fake.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/141993/96a8572a- 3471-96ee-e2ce-6328c0d46401.png) An einigen Stellen gibt es einige seltsame Teile und einige Bilder funktionieren nicht, aber die Farbgebung ist ganz natürlich.

Übrigens, wenn Sie nur die unterste Zeile des Originalbildes anzeigen, MSCOCO2014_557.pngMSCOCO2014_575.pngSailboat.jpgvoc2007_1963.png So was. Die Farben der Züge und Betten sind unterschiedlich, aber ich habe das Gefühl, dass sie insgesamt in den gleichen Farben gestrichen sind.

2. Grobes Bild dieses Lernens

Das grobe Bild dieser Studie ist wie folgt. Da es sich um ein GAN handelt, verwenden wir zwei Netzwerke, Generator und Discriminator.
(1)pix2pix.jpg

(2)Folie 2.JPG

(3)Folie 3.JPG

(4)Folie 4.JPG

(5)Folie 5.JPG

(6)Folie 6.JPG

Auf diese Weise werden der Generator und der Diskriminator trainiert, um die beiden Netzwerke abwechselnd auszutricksen.

3. Über das Lernnetzwerk

Dieses Mal wird Pytorch 1.1, Torchvision 0.30 verwendet. Importieren Sie vorerst die zu verwendende Bibliothek

import glob
import os
import pickle
import torch
import torch.nn.functional as F
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np                            #1.16.4
import matplotlib.pyplot as plt
from PIL import Image
from torch import nn
from skimage import io

Die Umwelt ist windows10, Anaconda1.9.7, core-i3 8100, RAN 16.0 GB GEFORCE GTX 1060

Die GPU wird empfohlen, da sie viel Lernzeit in Anspruch nimmt.

3-1.Generator

Das für die semantische Segmentierung verwendete U-Netz wird für den Generator verwendet. Sie können ein Ausgabebild mit der gleichen Form wie das Eingabebild im Encoder-Decoder-Netzwerk erhalten. Das Eingabebild ist ein graues Bild und das Ausgabebild ist ein Farbbild (falsches Bild). u-net-architecture.png Das Merkmal dieses U-Netzes ist der Teil Kopieren und Zuschneiden. Es ist (anscheinend) ein Gerät, eine Ausgabe in der Nähe der Eingabeebene zu einer Ebene in der Nähe der Ausgabeebene hinzuzufügen, damit die Form des Originalbilds nicht verloren geht.

Das Realisieren dieser Kopie und Ernte mit Pytorch ist ziemlich einfach,

Es ist jedoch erforderlich, die Form des Tensors anzupassen, der mit torch.cat kombiniert werden soll.

Wenn Sie dieses U-Netz so verwenden, wie es ist, wird es ein ziemlich großes Netzwerk sein. (Es sieht so aus, als gäbe es ungefähr 18 CNNs) Verkleinern Sie daher das Netzwerk und reduzieren Sie die Größe der Eingabe- / Ausgabebilder auf 3 x 128 x 128.

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(32)
        
        self.av2 = nn.AvgPool2d(kernel_size=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        
        self.av3 = nn.AvgPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        
        self.av4 = nn.AvgPool2d(kernel_size=2)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        
        self.av5 = nn.AvgPool2d(kernel_size=2)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        
        self.un6 = nn.UpsamplingNearest2d(scale_factor=2)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        
        #Die Ausgabe von conv6 und die Ausgabe von conv4 werden an conv7 gesendet.,Doppelter Eingangskanal
        self.un7 = nn.UpsamplingNearest2d(scale_factor=2)
        self.conv7 = nn.Conv2d(256 * 2, 128, kernel_size=3, stride=1, padding=1)
        self.bn7 = nn.BatchNorm2d(128)
        
        #Senden Sie die Ausgabe von conv7 und die Ausgabe von conv3 an conv8,Doppelter Eingangskanal
        self.un8 = nn.UpsamplingNearest2d(scale_factor=2)
        self.conv8 = nn.Conv2d(128 * 2, 64, kernel_size=3, stride=1, padding=1)
        self.bn8 = nn.BatchNorm2d(64)
        
        #Die Ausgabe von conv8 und die Ausgabe von conv2 werden an conv9 gesendet.,Doppelter Eingangskanal
        self.un9 = nn.UpsamplingNearest2d(scale_factor=4)
        self.conv9 = nn.Conv2d(64 * 2, 32, kernel_size=3, stride=1, padding=1)
        self.bn9 = nn.BatchNorm2d(32)
        
        self.conv10 = nn.Conv2d(32 * 2, 3, kernel_size=5, stride=1, padding=2)
        self.tanh = nn.Tanh()
    
    def forward(self, x):
        #x1-x4 ist Fackel.Weil ich Katze muss,Verlassen
        x1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
        x2 = F.relu(self.bn2(self.conv2(self.av2(x1))), inplace=True)
        x3 = F.relu(self.bn3(self.conv3(self.av3(x2))), inplace=True)
        x4 = F.relu(self.bn4(self.conv4(self.av4(x3))), inplace=True)
        x = F.relu(self.bn5(self.conv5(self.av5(x4))), inplace=True)
        x = F.relu(self.bn6(self.conv6(self.un6(x))), inplace=True)
        x = torch.cat([x, x4], dim=1)
        x = F.relu(self.bn7(self.conv7(self.un7(x))), inplace=True)
        x = torch.cat([x, x3], dim=1)
        x = F.relu(self.bn8(self.conv8(self.un8(x))), inplace=True)
        x = torch.cat([x, x2], dim=1)
        x = F.relu(self.bn9(self.conv9(self.un9(x))), inplace=True)
        x = torch.cat([x, x1], dim=1)
        x = self.tanh(self.conv10(x))
        return x

3-2.Discriminator Der Diskriminator ähnelt einem normalen Bildidentifikationsnetzwerk. Die Ausgabe ist jedoch n x n Zahlen, nicht eindimensional. Gibt für jeden dieser unterteilten Bereiche True oder False aus. Im Fall des Bildes unten ist es 4x4.

ddd.png Diese Technik wird als Patch-GAN bezeichnet.

Danach ist die Aktivierungsfunktion GANs klassisches Leakly Relu. InstanceNorm2d wird anstelle von BatchNorm2d verwendet.

Ich habe sowohl InstanceNorm2d als auch BatchNorm2d ausprobiert, aber ich habe keinen großen Unterschied in den Ergebnissen bemerkt. InstanceNorm2d war gut für Pix2Pix, daher verwende ich dieses Mal.

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2)
        self.in1 = nn.InstanceNorm2d(16)
        
        self.av2 = nn.AvgPool2d(kernel_size=2)
        self.conv2_1 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.in2_1 = nn.InstanceNorm2d(32)
        self.conv2_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.in2_2 = nn.InstanceNorm2d(32)
        
        self.av3 = nn.AvgPool2d(kernel_size=2)
        self.conv3_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.in3_1 = nn.InstanceNorm2d(64)
        self.conv3_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.in3_2 = nn.InstanceNorm2d(64)
        
        self.av4 = nn.AvgPool2d(kernel_size=2)
        self.conv4_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.in4_1 = nn.InstanceNorm2d(128)
        self.conv4_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.in4_2 = nn.InstanceNorm2d(128)
        
        self.av5 = nn.AvgPool2d(kernel_size=2)
        self.conv5_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.in5_1 = nn.InstanceNorm2d(256)
        self.conv5_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.in5_2 = nn.InstanceNorm2d(256)
        
        self.av6 = nn.AvgPool2d(kernel_size=2)
        self.conv6 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.in6 = nn.InstanceNorm2d(512)
        
        self.conv7 = nn.Conv2d(512, 1, kernel_size=1)
            
    def forward(self, x):      
        x = F.leaky_relu(self.in1(self.conv1(x)), 0.2, inplace=True)
        x = F.leaky_relu(self.in2_1(self.conv2_1(self.av2(x))), 0.2, inplace=True)
        x = F.leaky_relu(self.in2_2(self.conv2_2(x)), 0.2, inplace=True)
        x = F.leaky_relu(self.in3_1(self.conv3_1(self.av3(x))), 0.2, inplace=True)
        x = F.leaky_relu(self.in3_2(self.conv3_2(x)), 0.2, inplace=True)
        x = F.leaky_relu(self.in4_1(self.conv4_1(self.av4(x))), 0.2, inplace=True)
        x = F.leaky_relu(self.in4_2(self.conv4_2(x)), 0.2, inplace=True)
        x = F.leaky_relu(self.in5_1(self.conv5_1(self.av5(x))), 0.2, inplace=True)
        x = F.leaky_relu(self.in5_2(self.conv5_2(x)), 0.2, inplace=True)
        x = F.leaky_relu(self.in6(self.conv6(self.av6(x))), 0.2, inplace=True)
        x = self.conv7(x)
        
        return x

3-3. Bestätigung

Generieren Sie mit torch.randn ein Pseudobild Überprüfen Sie die Ausgabegröße von Generator und Diskriminator.

Hier werden zwei Bilder mit einer Größe von 3 x 128 x 128 erzeugt und in den Generator und Diskriminator eingegeben.

g, d = Generator(), Discriminator()

#Pseudobild durch Zufallszahl
test_imgs = torch.randn([2, 3, 128, 128])
test_imgs = g(test_imgs)
test_res = d(test_imgs)

print("Generator_output", test_imgs.size())
print("Discriminator_output",test_res.size())

Die Ausgabe sieht folgendermaßen aus:

Generator_output torch.Size([2, 3, 128, 128])  Discriminator_output torch.Size([2, 1, 4, 4])

Die Ausgabegröße des Generators entspricht der Eingabe. Die Ausgabegröße von Discriminator beträgt 4x4.

4. Über den Datenlader

Dieses Mal erfassen wir die Daten gemäß dem folgenden Ablauf. flow.png

Datenerweiterung von Teil b.

class DataAugment():
    #Datenerweiterung des PIL-Bildes,PIL zurückgeben
    def __init__(self, resize):
        self.data_transform = transforms.Compose([
                    transforms.RandomResizedCrop(resize, scale=(0.9, 1.0)),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomVerticalFlip()])
    def __call__(self, img):
        return self.data_transform(img)

In dem Teil, der in d-Tensor konvertiert wird, wird gleichzeitig auch eine Datennormalisierung durchgeführt.

class ImgTransform():
    #PIL-Bildgröße ändern,Normalisieren und Tensor zurückgeben
    def __init__(self, resize, mean, std):
        self.data_transform = transforms.Compose([
                    transforms.Resize(resize),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std)])
        
    def __call__(self, img):
        return self.data_transform(img)

Es ist eine Klasse, die die Dataset-Klasse von Pytorch erbt, und der Fluss bis a-d wird anstelle von getitem geschrieben. Sie können einfach einen Datenlader erstellen, indem Sie einen Eingabe- und Ausgabefluss für ein Bild im getitem-Teil erstellen.

class MonoColorDataset(data.Dataset):
    """
Erben Sie die Dataset-Klasse von Pytorch
    """
    def __init__(self, file_list, transform_tensor, augment=None):
        self.file_list = file_list
        self.augment = augment     #PIL to PIL
        self.transform_tensor = transform_tensor  #PIL to Tensor

    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, index):
        #Rufen Sie den Dateipfad der Indexnummer ab
        img_path = self.file_list[index]
        img = Image.open(img_path)
        img = img.convert("RGB")
        
        if self.augment is not None:
            img = self.augment(img)
        
        #Kopie für Schwarzweißbild
        img_gray = img.copy()
        #Konvertieren Sie ein Farbbild in ein Schwarzweißbild
        img_gray = transforms.functional.to_grayscale(img_gray,
                                                      num_output_channels=3)
        #Konvertieren Sie PIL in Tensor
        img = self.transform_tensor(img)
        img_gray = self.transform_tensor(img_gray)
        
        return img, img_gray

Durch Setzen von augment = None werden die Daten nicht erweitert, dh es handelt sich um einen Datensatz für Testdaten. Die Funktion zum Erstellen des Datenladers lautet wie folgt.

def load_train_dataloader(file_path, batch_size):
    """
    Input
     file_Pfad Liste der Dateipfade für das Bild, das Sie abrufen möchten
     batch_Größe Datenlader-Stapelgröße
    return
     train_loader, RGB_images and Gray_images
    """ 
    size = 128             #Eine Seitengröße des Bildes
    mean = (0.5, 0.5, 0.5) #Durchschnittswert für jeden Kanal, wenn das Bild normalisiert wird
    std = (0.5, 0.5, 0.5)  #Standardabweichung für jeden Kanal, wenn das Bild normalisiert wird
    
    #Datensatz
    train_dataset = MonoColorDataset(file_path_train, 
                                 transform=ImgTransform(size, mean, std), 
                                 augment=DataAugment(size))
    #Datenlader
    train_dataloader = data.DataLoader(train_dataset,
                                    batch_size=batch_size,
                                    shuffle=True)
    return train_dataloader

5. Visualisierungsmethode

5.1 Zu visualisierende Funktionen

Es ist praktisch, "torchvision.utils.make_grid" zu verwenden, um mehrere Bilder in einer Kachel anzuordnen. Nachdem Sie ein gekacheltes Bild mit Tensor erstellt haben, konvertieren Sie es in Numpy und zeichnen Sie es mit Matplotlib.

def mat_grid_imgs(imgs, nrow, save_path = None):
    """
Pytorch Tensor(imgs)Eine Funktion, die Kacheln zeichnet
Bestimmen Sie die Anzahl der Seiten einer Kachel mit nrow
    """
    imgs = torchvision.utils.make_grid(
                imgs[0:(nrow**2), :, :, :], nrow=nrow, padding=5)
    imgs = imgs.numpy().transpose([1,2,0])

    imgs -= np.min(imgs)   #Mindestwert ist 0
    imgs /= np.max(imgs)   #Maximalwert ist 1
    
    plt.imshow(imgs)
    plt.xticks([])
    plt.yticks([])
    plt.show()
    
    if save_path is not None:
        io.imsave(save_path, imgs)

Eine Funktion, die ein Testbild lädt und ein graues Bild und ein falsches Bild in Kacheln zeichnet.

def evaluate_test(file_path_test, model_G, device="cuda:0", nrow=4):
    """
Testbild laden,Zeichnen Sie graue und gefälschte Bilder in Kacheln
    """
    model_G = model_G.to(device)
    size = 128
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    test_dataset = MonoColorDataset(file_path_test, 
                                 transform=ImgTransform(size, mean, std), 
                                 augment=None)
    test_dataloader = data.DataLoader(test_dataset,
                                    batch_size=nrow**2,
                                    shuffle=False)
    #Zeichnen Sie für jeden Datenlader ein Bild
    for img, img_gray in test_dataloader:
        mat_grid_imgs(img_gray, nrow=nrow)
        img = img.to(device)
        img_gray = img_gray.to(device)
        #img_Von Grau mit Generator,RGB-Bild von Fake
        img_fake = model_G(img_gray)
        img_fake = img_fake.to("cpu")
        img_fake = img_fake.detach()
        mat_grid_imgs(img_fake, nrow=nrow)

5.2 Visualisierungsergebnis (vor dem Lernen)

g = Generator()
file_path_test = glob.glob("test/*")
evaluate_test(file_path_test, g)

weew.png Obwohl es das Ergebnis vor dem Lernen ist, kann die Form des Eingabebildes vage verstanden werden.

6. So erhalten Sie Trainingsdaten

Zur Zeit gebe ich nur eine große Menge von Bilddaten ein, also gebe ich COCO2014, PASCAL Voc2007, Labeled Faces in the Wild usw. ein. Diese Daten enthalten einen guten Anteil an Graubildern. Ich möchte dieses Mal ein Schwarzweißbild in Farbe erstellen, aber das Bild, das ein Modell sein sollte, kann im grauen Bild (?) Nicht angezeigt werden. Also möchte ich das graue Bild entfernen. Für graue Bilder sollten die Farben von R-Kanal, G-Kanal und B-Kanal gleich sein, daher möchte ich sie zum Entfernen verwenden. Gleichzeitig habe ich auch Bilder extrahiert, die zu weiß sind, Bilder, die zu dunkel sind, und Bilder, die nicht viel Farbschattierung aufweisen (Standardabweichung ist klein).

from skimage import io, color, transform

def color_mono(image, threshold=150):
    #Stellen Sie fest, ob das Eingabebild von 3chnnel farbig ist
    #Wenn Sie einen großen Schwellenwert festlegen, können Sie Mono auch für Fotos mit leicht gemischten Farben festlegen.
    image_size = image.shape[0] * image.shape[1]
    
    #Die Kombination von Kanälen(0, 1),(0, 2),(1, 2)3 Möglichkeiten,Sehen Sie den Unterschied für jeden Kanal
    diff = np.abs(np.sum(image[:,:, 0] - image[:,:, 1])) / image_size
    diff += np.abs(np.sum(image[:,:, 0] - image[:,:, 2])) / image_size
    diff += np.abs(np.sum(image[:,:, 1] - image[:,:, 2])) / image_size
    if diff > threshold:
        return "color"
    else:
        return "mono"

def bright_check(image, ave_thres = 0.15, std_thres = 0.1):
    try:
    #Bild, das zu hell ist,Bild zu dunkel,Bild mit ähnlicher Helligkeit Falsch
    #In Schwarzweiß konvertieren
        image = color.rgb2gray(image)
    
        if image.shape[0] < 144:
            return False    
        #Für zu helle Bilder
        if np.average(image) > (1.-ave_thres):
            return False
        #Für zu dunkle Bilder
        if np.average(image) < ave_thres:
            return False
        #Wenn die gesamte Helligkeit ähnlich ist
        if np.std(image) < std_thres:
            return False
        return True
    except:
        return False

paths = glob.glob("./test2014/*")

for i, path in enumerate(paths):
    image = io.imread(path)
    save_name = "./trans\\mscoco_" + str(i) +".png "
    
    x = image.shape[0] #Anzahl der Pixel entlang der x-Achse
    y = image.shape[1] #Anzahl der Pixel in Richtung der y-Achse
    
    try:
        #Die kürzere der x- und y-Achsen/2
        clip_half = min(x, y)/2
        #Schneiden Sie ein Quadrat im Bild aus
        image = image[int(x/2 -clip_half): int(x/2 + clip_half),
                  int(y/2 -clip_half): int(y/2 + clip_half), :]

        if color_mono(image) == "color":
            if bright_check(image):
                image = transform.resize(image, (144, 144, 3),
                                        anti_aliasing = True)
                image = np.uint8(image*255)
                io.imsave(save_name, image)
    except:
        pass

Ich schneide die Bilder in Quadrate und lege sie alle in einen Ordner. Das Bild ist 144x144 statt 128x128, damit die Daten erweitert werden können. coco.png

Dies ist im Allgemeinen in Ordnung, aber aus irgendeinem Grund gab es einige Auslassungen und sepiafarbene Bilder, sodass ich sie manuell löschte.

Ich habe ungefähr 110.000 Bilder in den Ordner "trans" gelegt. Verwenden Sie glob, um eine Liste von Bildpfaden zu erstellen und zu laden.

7. Lernen

7.1 Lernfunktionen

Das Lernen dauerte ungefähr 20 Minuten pro Epoche. Der Code ist lang, weil wir sowohl Generator als auch Diskriminator trainieren.

Der zu beachtende Punkt ist die Bezeichnung für die Berechnung des Verlusts, und die Größe der Diskriminatorausgabe entspricht der Größe der Diskriminatorausgabe in der Bestätigung von 4. Ich habe bestätigt, dass es [batch_size, 1, 4, 4] sein wird Erzeugt true_labels und false_labels.

def train(model_G, model_D, epoch, epoch_plus):
    device = "cuda:0"
    batch_size = 32
    
    model_G = model_G.to(device)
    model_D = model_D.to(device)
    
    params_G = torch.optim.Adam(model_G.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    params_D = torch.optim.Adam(model_D.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    
    #Etikett zur Berechnung des Verlustes,Achten Sie auf die Größe des Diskriminators
    true_labels = torch.ones(batch_size, 1, 4, 4).to(device)    #True
    false_labels = torch.zeros(batch_size, 1, 4, 4).to(device)  #False
    
    #loss_function
    bce_loss = nn.BCEWithLogitsLoss()
    mae_loss = nn.L1Loss()
    
    #Fehlerübergang aufzeichnen
    log_loss_G_sum, log_loss_G_bce, log_loss_G_mae = list(), list(), list()
    log_loss_D = list()
    
    for i in range(epoch):
        #Temporären Fehler aufzeichnen
        loss_G_sum, loss_G_bce, loss_G_mae = list(), list(), list()
        loss_D = list()
        
        train_dataloader = load_train_dataloader(file_path_train, batch_size)
        
        for real_color, input_gray in train_dataloader:
            batch_len = len(real_color)
            real_color = real_color.to(device)
            input_gray = input_gray.to(device)
            
            #Generator Training
            #Generieren Sie ein gefälschtes Farbbild
            fake_color = model_G(input_gray)
            
            #Falsches Bild vorübergehend speichern
            fake_color_tensor = fake_color.detach()
            
            #Berechnen Sie den Verlust so, dass das gefälschte Bild als echt getäuscht werden kann
            LAMBD = 100.0 #BCE- und MAE-Koeffizienten
            
            #aus, wenn gefälschtes Bild in den Klassifikator gelegt wird,D versucht, näher an 0 heranzukommen.
            out = model_D(fake_color)
            
            #Verlust für die Ausgabe von D.,Das Ziel ist wahr, weil ich G näher an die Realität bringen möchte_labels
            loss_G_bce_tmp = bce_loss(out, true_labels[:batch_len])
            
            #Verlust für G-Ausgang
            loss_G_mae_tmp = LAMBD * mae_loss(fake_color, real_color)
            loss_G_sum_tmp = loss_G_bce_tmp + loss_G_mae_tmp
            
            loss_G_bce.append(loss_G_bce_tmp.item())
            loss_G_mae.append(loss_G_mae_tmp.item())
            loss_G_sum.append(loss_G_sum_tmp.item())
            
            #Berechnen Sie den Gradienten,G Gewichtsaktualisierung
            params_D.zero_grad()
            params_G.zero_grad()
            loss_G_sum_tmp.backward()
            params_G.step()
            
            #Diskriminatorentraining
            real_out = model_D(real_color)
            fake_out = model_D(fake_color_tensor)
            
            #Berechnung der Verlustfunktion
            loss_D_real = bce_loss(real_out, true_labels[:batch_len])
            loss_D_fake = bce_loss(fake_out, false_labels[:batch_len])
            
            loss_D_tmp = loss_D_real + loss_D_fake
            loss_D.append(loss_D_tmp.item())
            
            #Berechnen Sie den Gradienten,D Gewichtsaktualisierung
            params_D.zero_grad()
            params_G.zero_grad()
            loss_D_tmp.backward()
            params_D.step()
        
        i = i + epoch_plus
        print(i, "loss_G", np.mean(loss_G_sum), "loss_D", np.mean(loss_D))
        log_loss_G_sum.append(np.mean(loss_G_sum))
        log_loss_G_bce.append(np.mean(loss_G_bce))
        log_loss_G_mae.append(np.mean(loss_G_mae))
        log_loss_D.append(np.mean(loss_D))
        
        file_path_test = glob.glob("test/*")
        evaluate_test(file_path_test, model_G, device)
        
    return model_G, model_D, [log_loss_G_sum, log_loss_G_bce, log_loss_G_mae, log_loss_D]

Führen Sie das Lernen durch.

file_path_train = glob.glob("trans/*")
model_G = Generator()
model_D = Discriminator()
model_G, model_D, logs = train(model_G, model_D, 40)

## 7.2 Lernergebnisse Der Verlust von Trainingsdaten sieht so aus. ![loss.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/141993/59885ca0-eb16-7c5b-4e6c-c413ed27e49c.png)

Nach 2 Epochen 1epoch_fake.png

Das? Es fühlt sich ziemlich gut an, außer dass das Flugzeugbild überhaupt nicht gemalt ist?

Nach 11 Epochen 10_fake.png

Nach 21 Epochen 20_fake.png

Nach 40 Epochen endet (Bild am Anfang gezeigt) 0_fake.png

Unerwartet fand ich, dass das Bild nach 2 Epochen gut war ...

Ich werde auch andere Bilder posten. 11 Nach dem Ende der Epoche. Ich habe viele Bilder ausgewählt, die anscheinend fehlgeschlagen sind. Das schreckliche Bild ist wirklich schrecklich, fast ohne Farbe Wie das Bild von Baseball male ich es, ohne die Grenze zu ignorieren.

0_2gray.png

10_2_fake.png

Ich fühle mich gut in Grüns wie Gras und Bäumen und Blues wie dem Himmel. Dies scheint von der Verzerrung des Originaldatensatzes und der Leichtigkeit des Malens (Leichtigkeit der Erkennung) abzuhängen.

8. Matome, Eindruck

Das graue Bild wurde mit pix2pix eingefärbt.

Dieses Mal habe ich beschlossen, ein Bild hinzuzufügen und ein Farbbild zu erstellen, sobald ich etwas tun konnte. Da das Netzwerk flach ist, ist die Ausdruckskraft gering Ich bin der Meinung, dass es besser funktioniert, die Bildtypen einzugrenzen.

Verweise

Um ehrlich zu sein, denke ich, dass dies leichter zu verstehen ist als das, was ich geschrieben habe.

U-Net: Convolutional Networks for Biomedical Image Segmentation     https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/

Ich habe pix2pix von 1 implementiert und versucht, ein Schwarzweißbild (PyTorch) einzufärben.     https://blog.shikoan.com/pytorch_pix2pix_colorization/

pix2 Ich möchte pix verstehen     https://qiita.com/mine820/items/36ffc3c0aea0b98027fd

Bild

CoCo https://cocodataset.org/#home

Labeled Faces in the Wild http://vis-www.cs.umass.edu/lfw/

The PASCAL Visual Object Classes Homepage http://host.robots.ox.ac.uk/pascal/VOC/

Recommended Posts

Ich habe ein Netzwerk erstellt, um Schwarzweißbilder in Farbbilder umzuwandeln (pix2pix)
Ich habe ein Programm erstellt, um Bilder mit Python und OpenCV in ASCII-Grafik umzuwandeln
Ich habe ein CLI-Tool erstellt, um Bilder in jedem Verzeichnis in PDF zu konvertieren
Ich habe einen Code erstellt, um illustration2vec in ein Keras-Modell zu konvertieren
Konvertieren Sie Videos mit ffmpeg + python + opencv in Schwarzweiß
Ich habe eine einfache Netzwerkkamera hergestellt, indem ich ESP32-CAM und RTSP kombiniert habe.
Ich habe ein Skript erstellt, um Piktogramme anzuzeigen
Ich habe ein Skript in Python erstellt, um MDD-Dateien in das Scrapbox-Format zu konvertieren
Ich habe ein Programm erstellt, um einzugeben, was ich gegessen habe, und um Kalorien und Zucker anzuzeigen
Ich habe ein Tool erstellt, um Jupyter py mit VS Code in ipynb zu konvertieren
Ich habe ein Tool erstellt, um Slack über Connpass-Ereignisse zu informieren, und es zu Terraform gemacht
Einführung in die KI-Erstellung mit Python! Teil 3 Ich habe versucht, Bilder mit einem Convolutional Neural Network (CNN) zu klassifizieren und vorherzusagen.
Ich habe ein Tool erstellt, um Hy nativ zu kompilieren
Ich möchte ein Programm ausführen und verteilen, das die Größe von Bildern in Python3 + Pyinstaller ändert
Ich habe ein Modul in C-Sprache erstellt, das von Python geladene Bilder filtert
Ich habe ein Tool erstellt, um neue Artikel zu erhalten
Ich habe eine Bibliothek erstellt, um japanische Sätze schön zu brechen
Ich habe ein Skript erstellt, um ein Snippet in README.md einzufügen
Ich habe ein Python-Modul erstellt, um Kommentare zu übersetzen
Ich habe versucht, LINE BOT mit Python und Heroku zu machen
Ich habe einen Befehl zum Markieren des Tabellenclips gegeben
Ich habe eine Python-Bibliothek erstellt, die einen rollierenden Rang hat
〇✕ Ich habe ein Spiel gemacht
Ich habe ein Tool erstellt, mit dem das Erstellen und Installieren eines öffentlichen Schlüssels etwas einfacher ist.
Ich habe ein Skript in Python erstellt, um eine Textdatei für JSON zu konvertieren (für das vscode-Benutzer-Snippet).
Erstellt eine Methode zur automatischen Auswahl und Visualisierung eines geeigneten Diagramms für Pandas DataFrame
Ich habe versucht, das grundlegende Modell des wiederkehrenden neuronalen Netzwerks zu implementieren
Ich habe versucht, eine Rangliste zu erstellen, indem ich das Mitgliederteam der Organisation abgekratzt habe
Ich habe ein Paket erstellt, um Zeitreihen mit Python zu filtern
Demosaic Bayer FITS-Dateien und konvertieren sie in Farbe TIFF
Ich habe einen Befehl zum Generieren eines Kommentars für eine Tabelle in Django eingegeben
Ich habe ein Tool erstellt, um eine Wortwolke aus Wikipedia zu erstellen
Konvertieren Sie verstümmelte gescannte Bilder mit Pillow und PyPDF in PDF
[Titan Craft] Ich habe ein Werkzeug gemacht, um einen Riesen nach Minecraft zu rufen
Ich habe Chatbot mit LINE Messaging API und Python erstellt
Ich habe Sie dazu gebracht, Befehle über einen WEB-Browser auszuführen
Ich habe einen neuronalen Netzwerkgenerator erstellt, der auf FPGA läuft
Ich habe ein Drehbuch gemacht, um bei meinem Koshien Hallo zu sagen
Ich habe mein eigenes neuronales 3-Layer-Forward-Propagation-Netzwerk erstellt und versucht, die Berechnung genau zu verstehen.
Ich habe einen Python-Text gemacht
Ich habe einen Zwietrachtbot gemacht
Ich habe eine Bibliothek erstellt, die Konfigurationsdateien mit Python einfach lesen kann
Ich habe versucht, eine Python-Datei in eine EXE-Datei zu verwandeln (Rekursionsfehler unterstützt)
Ich habe mit Razpai einen Webserver erstellt, um Anime zu schauen
Ich wollte mein Gesichtsfoto in einen Yuyu-Stil umwandeln.
Ich habe versucht, mit Selenium und Python einen regelmäßigen Ausführungsprozess durchzuführen
Holz kratzen und essen - ich möchte ein gutes Restaurant finden! ~ (Arbeit)
Ich habe versucht, Bulls and Cows mit einem Shell-Programm zu erstellen
Ich möchte eine Pipfile erstellen und im Docker wiedergeben
Ich habe den Befehl gegeben, einen farbenfrohen Kalender im Terminal anzuzeigen
Ich habe Chatbot mit der LINE Messaging API und Python (2) ~ Server ~ erstellt
Ich habe ein Skript geschrieben, um goodnotes5 und Anki bei der Zusammenarbeit zu unterstützen
Ich habe einen Chat-Chat-Bot mit Tensor2Tensor erstellt und diesmal hat es funktioniert
Ich habe ein POST-Skript erstellt, um ein Problem in Github zu erstellen und es im Projekt zu registrieren