I would like to use a GAN (Generative Adversarial Network) to automatically colorize grayscale images.
The specific technique is called "pix2pix".
This grayscale image
![0_gray.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/141993/66e565f7-1c0b-ca9b-2a7c-aabdc47b2977.png)
could be colorized automatically, like this!
![0_fake.png](https://qiita-image-store.s3.ap-northeast-1.amazonaws.com/0/141993/96a8572a-3471-96ee-e2ce-6328c0d46401.png)
In some places there are odd patches, and some images do not work at all, but overall the coloring is quite natural.
By the way, here is the bottom row again with the original images for comparison. The colors of the trains and beds differ from the originals, but overall they are painted in similar tones.
The rough outline of this approach is as follows. Since it is a GAN, two networks are used: a Generator and a Discriminator.
(Figure: training flow of the Generator and the Discriminator, steps (1)-(6).)
In this way, the Generator and the Discriminator are trained alternately, each trying to outwit the other.
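For reference, this corresponds roughly to the pix2pix objective: an adversarial loss plus an L1 reconstruction term (note that in the simplified implementation below, the Discriminator sees only the color image rather than the pair of gray input and color output):

G^* = \arg\min_G \max_D \, \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G), \quad \mathcal{L}_{L1}(G) = \mathbb{E}\left[ \lVert y - G(x) \rVert_1 \right]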
This time I use PyTorch 1.1 and torchvision 0.3.0. First, import the libraries to be used.
import glob
import os
import pickle
import torch
import torch.nn.functional as F
import torchvision
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np #1.16.4
import matplotlib.pyplot as plt
from PIL import Image
from torch import nn
from skimage import io
The environment is Windows 10, Anaconda 1.9.7, Core i3-8100, 16.0 GB RAM, GeForce GTX 1060.
A GPU is recommended because training takes a long time.
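As a minimal sketch for checking this (assuming the imports above have been run), you can confirm whether PyTorch sees a GPU like this:

# Use the GPU if one is available, otherwise fall back to the CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)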
3-1. Generator
The U-Net used for semantic segmentation is used as the Generator. It is an encoder-decoder network that produces an output image with the same shape as the input image. Here the input is a grayscale image and the output is a color image (the fake image). The characteristic feature of U-Net is the "copy and crop" part: feature maps from layers close to the input are added onto layers close to the output, so that the shape of the original image is not lost.
Realizing this copy and crop in PyTorch is quite easy: use torch.cat to combine the inputs, and double the number of input channels of the following Conv2d and BatchNorm2d. That is all. When I first saw it, I was quite impressed.
However, the shapes of the tensors combined with torch.cat must match.
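For example (illustrative tensors only, not part of the model), channel-wise concatenation works when the spatial sizes match:

a = torch.randn(1, 256, 16, 16)   # e.g. an upsampled decoder feature map
b = torch.randn(1, 256, 16, 16)   # e.g. a saved encoder feature map with the same H x W
c = torch.cat([a, b], dim=1)      # dim=1 is the channel dimension
print(c.size())                   # torch.Size([1, 512, 16, 16])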
If you use the U-Net as it is, it becomes a fairly huge network (it looks like around 18 convolution layers). Therefore I make the network smaller and reduce the input / output images to 3 x 128 x 128. With the settings below, the encoder reduces the resolution 128 → 32 → 16 → 8 → 4, and the decoder upsamples it back to 128.
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(32)
        self.av2 = nn.AvgPool2d(kernel_size=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.av3 = nn.AvgPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.av4 = nn.AvgPool2d(kernel_size=2)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.av5 = nn.AvgPool2d(kernel_size=2)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.un6 = nn.UpsamplingNearest2d(scale_factor=2)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        # The outputs of conv6 and conv4 are concatenated and sent to conv7, so the input channels are doubled
        self.un7 = nn.UpsamplingNearest2d(scale_factor=2)
        self.conv7 = nn.Conv2d(256 * 2, 128, kernel_size=3, stride=1, padding=1)
        self.bn7 = nn.BatchNorm2d(128)
        # The outputs of conv7 and conv3 are concatenated and sent to conv8, so the input channels are doubled
        self.un8 = nn.UpsamplingNearest2d(scale_factor=2)
        self.conv8 = nn.Conv2d(128 * 2, 64, kernel_size=3, stride=1, padding=1)
        self.bn8 = nn.BatchNorm2d(64)
        # The outputs of conv8 and conv2 are concatenated and sent to conv9, so the input channels are doubled
        self.un9 = nn.UpsamplingNearest2d(scale_factor=4)
        self.conv9 = nn.Conv2d(64 * 2, 32, kernel_size=3, stride=1, padding=1)
        self.bn9 = nn.BatchNorm2d(32)
        self.conv10 = nn.Conv2d(32 * 2, 3, kernel_size=5, stride=1, padding=2)
        self.tanh = nn.Tanh()

    def forward(self, x):
        # Keep x1-x4 because they are needed later for torch.cat (the skip connections)
        x1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
        x2 = F.relu(self.bn2(self.conv2(self.av2(x1))), inplace=True)
        x3 = F.relu(self.bn3(self.conv3(self.av3(x2))), inplace=True)
        x4 = F.relu(self.bn4(self.conv4(self.av4(x3))), inplace=True)
        x = F.relu(self.bn5(self.conv5(self.av5(x4))), inplace=True)
        x = F.relu(self.bn6(self.conv6(self.un6(x))), inplace=True)
        x = torch.cat([x, x4], dim=1)
        x = F.relu(self.bn7(self.conv7(self.un7(x))), inplace=True)
        x = torch.cat([x, x3], dim=1)
        x = F.relu(self.bn8(self.conv8(self.un8(x))), inplace=True)
        x = torch.cat([x, x2], dim=1)
        x = F.relu(self.bn9(self.conv9(self.un9(x))), inplace=True)
        x = torch.cat([x, x1], dim=1)
        x = self.tanh(self.conv10(x))
        return x
3-2. Discriminator
The Discriminator is similar to an ordinary image classification network, except that the output is an n x n grid of numbers rather than a single value. It outputs a True/False judgement for each of these regions. Here the input is halved five times by the pooling layers, so the output is 128 / 2^5 = 4, i.e. a 4 x 4 grid.
This technique is called PatchGAN.
Beyond that, the activation function is the GAN classic Leaky ReLU, and InstanceNorm2d is used instead of BatchNorm2d. Note that there is no sigmoid at the end; BCEWithLogitsLoss, used later for training, applies it internally.
I tried both InstanceNorm2d and BatchNorm2d, but I did not notice much difference in the results. InstanceNorm2d is reportedly a good fit for pix2pix, so I use it this time.
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2)
        self.in1 = nn.InstanceNorm2d(16)
        self.av2 = nn.AvgPool2d(kernel_size=2)
        self.conv2_1 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.in2_1 = nn.InstanceNorm2d(32)
        self.conv2_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.in2_2 = nn.InstanceNorm2d(32)
        self.av3 = nn.AvgPool2d(kernel_size=2)
        self.conv3_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.in3_1 = nn.InstanceNorm2d(64)
        self.conv3_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.in3_2 = nn.InstanceNorm2d(64)
        self.av4 = nn.AvgPool2d(kernel_size=2)
        self.conv4_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.in4_1 = nn.InstanceNorm2d(128)
        self.conv4_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.in4_2 = nn.InstanceNorm2d(128)
        self.av5 = nn.AvgPool2d(kernel_size=2)
        self.conv5_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.in5_1 = nn.InstanceNorm2d(256)
        self.conv5_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.in5_2 = nn.InstanceNorm2d(256)
        self.av6 = nn.AvgPool2d(kernel_size=2)
        self.conv6 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.in6 = nn.InstanceNorm2d(512)
        self.conv7 = nn.Conv2d(512, 1, kernel_size=1)

    def forward(self, x):
        x = F.leaky_relu(self.in1(self.conv1(x)), 0.2, inplace=True)
        x = F.leaky_relu(self.in2_1(self.conv2_1(self.av2(x))), 0.2, inplace=True)
        x = F.leaky_relu(self.in2_2(self.conv2_2(x)), 0.2, inplace=True)
        x = F.leaky_relu(self.in3_1(self.conv3_1(self.av3(x))), 0.2, inplace=True)
        x = F.leaky_relu(self.in3_2(self.conv3_2(x)), 0.2, inplace=True)
        x = F.leaky_relu(self.in4_1(self.conv4_1(self.av4(x))), 0.2, inplace=True)
        x = F.leaky_relu(self.in4_2(self.conv4_2(x)), 0.2, inplace=True)
        x = F.leaky_relu(self.in5_1(self.conv5_1(self.av5(x))), 0.2, inplace=True)
        x = F.leaky_relu(self.in5_2(self.conv5_2(x)), 0.2, inplace=True)
        x = F.leaky_relu(self.in6(self.conv6(self.av6(x))), 0.2, inplace=True)
        x = self.conv7(x)
        return x
Generate pseudo images with torch.randn and check the output sizes of the Generator and the Discriminator.
Here, two images with a size of 3 x 128 x 128 are generated and input to the Generator and Discriminator.
g, d = Generator(), Discriminator()
#Pseudo image with random numbers
test_imgs = torch.randn([2, 3, 128, 128])
test_imgs = g(test_imgs)
test_res = d(test_imgs)
print("Generator_output", test_imgs.size())
print("Discriminator_output",test_res.size())
The output looks like this:
Generator_output torch.Size([2, 3, 128, 128])
Discriminator_output torch.Size([2, 1, 4, 4])
The output size of the Generator is the same as the input, and the output of the Discriminator is 4 x 4, as expected.
This time, the data is loaded according to the flow (a)-(d): get the image path, augment the image, convert it to grayscale, and convert both the color and gray images to tensors.
Part (b) is the data augmentation.
class DataAugment():
    # Data augmentation: takes a PIL image and returns a PIL image
    def __init__(self, resize):
        self.data_transform = transforms.Compose([
            transforms.RandomResizedCrop(resize, scale=(0.9, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip()])

    def __call__(self, img):
        return self.data_transform(img)
Data normalization is performed at the same time in the part that converts to a tensor.
class ImgTransform():
    # Resize a PIL image, convert it to a normalized tensor and return it
    def __init__(self, resize, mean, std):
        self.data_transform = transforms.Compose([
            transforms.Resize(resize),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)])

    def __call__(self, img):
        return self.data_transform(img)
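One detail worth noting: with mean = std = (0.5, 0.5, 0.5), Normalize maps the tensor values from [0, 1] to [-1, 1], which matches the Tanh output range of the Generator. A quick illustrative check (the sample image here is just a dummy):

to_tensor = ImgTransform(resize=128, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
sample = Image.new("RGB", (200, 200), color=(255, 0, 128))  # dummy PIL image
t = to_tensor(sample)
print(t.min().item(), t.max().item())  # values lie in [-1, 1]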
The dataset is a class that inherits PyTorch's Dataset class, and the flow (a)-(d) is written inside __getitem__. Once the input and output flow for a single image is written in __getitem__, the data loader is easy to create.
class MonoColorDataset(data.Dataset):
    """
    Inherits PyTorch's Dataset class
    """
    def __init__(self, file_list, transform_tensor, augment=None):
        self.file_list = file_list
        self.augment = augment                    # PIL to PIL
        self.transform_tensor = transform_tensor  # PIL to Tensor

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        # Get the file path for the given index
        img_path = self.file_list[index]
        img = Image.open(img_path)
        img = img.convert("RGB")
        if self.augment is not None:
            img = self.augment(img)
        # Copy for the monochrome image
        img_gray = img.copy()
        # Convert the color image to a monochrome image (3 identical channels)
        img_gray = transforms.functional.to_grayscale(img_gray,
                                                      num_output_channels=3)
        # Convert PIL to tensor
        img = self.transform_tensor(img)
        img_gray = self.transform_tensor(img_gray)
        return img, img_gray
By setting augment=None, the data is not augmented, i.e. it becomes a dataset for test data. The function that creates the data loader is as follows.
def load_train_dataloader(file_path, batch_size):
    """
    Input
        file_path   List of file paths of the images to load
        batch_size  Batch size of the data loader
    return
        train_dataloader, which yields (RGB_images, Gray_images)
    """
    size = 128              # Length of one side of the image
    mean = (0.5, 0.5, 0.5)  # Per-channel mean used for normalization
    std = (0.5, 0.5, 0.5)   # Per-channel standard deviation used for normalization
    # Dataset
    train_dataset = MonoColorDataset(file_path,
                                     transform_tensor=ImgTransform(size, mean, std),
                                     augment=DataAugment(size))
    # Data loader
    train_dataloader = data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True)
    return train_dataloader
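A usage sketch (train_paths here is just an illustrative list of image paths; the actual list is built in the training section below):

# Usage sketch: train_paths is assumed to be a list of image file paths
train_paths = glob.glob("trans/*")
train_dataloader = load_train_dataloader(train_paths, batch_size=32)
real_color, input_gray = next(iter(train_dataloader))
print(real_color.size(), input_gray.size())  # torch.Size([32, 3, 128, 128]) each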
It is convenient to use torchvision.utils.make_grid to arrange multiple images in tiles. After generating the tiled image as a tensor, convert it to numpy and draw it with matplotlib.
def mat_grid_imgs(imgs, nrow, save_path=None):
    """
    Draws a pytorch tensor (imgs) as a tiled image
    nrow determines the number of tiles per side
    """
    imgs = torchvision.utils.make_grid(
        imgs[0:(nrow**2), :, :, :], nrow=nrow, padding=5)
    imgs = imgs.numpy().transpose([1, 2, 0])
    imgs -= np.min(imgs)  # Shift the minimum value to 0
    imgs /= np.max(imgs)  # Scale the maximum value to 1
    plt.imshow(imgs)
    plt.xticks([])
    plt.yticks([])
    plt.show()
    if save_path is not None:
        io.imsave(save_path, imgs)
The following function loads the test images and draws the gray images and the fake images as tiles.
def evaluate_test(file_path_test, model_G, device="cuda:0", nrow=4):
    """
    Load the test images, then draw the gray images and the fake images as tiles
    """
    model_G = model_G.to(device)
    size = 128
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    test_dataset = MonoColorDataset(file_path_test,
                                    transform_tensor=ImgTransform(size, mean, std),
                                    augment=None)
    test_dataloader = data.DataLoader(test_dataset,
                                      batch_size=nrow**2,
                                      shuffle=False)
    # Draw the images for each batch of the data loader
    for img, img_gray in test_dataloader:
        mat_grid_imgs(img_gray, nrow=nrow)
        img = img.to(device)
        img_gray = img_gray.to(device)
        # Generate a fake RGB image from img_gray with the Generator
        img_fake = model_G(img_gray)
        img_fake = img_fake.to("cpu")
        img_fake = img_fake.detach()
        mat_grid_imgs(img_fake, nrow=nrow)
g = Generator()
file_path_test = glob.glob("test/*")
evaluate_test(file_path_test, g)
This is the result before training, but you can already vaguely make out the shape of the input image.
For training I simply needed a large amount of image data, so I used COCO 2014, PASCAL VOC 2007, Labeled Faces in the Wild, and so on. A fair percentage of these images are already grayscale. Since the goal is to colorize black-and-white images, a grayscale image cannot serve as a training target, so I remove them. For a grayscale image the R, G and B channels should be identical, and I use that to detect them. At the same time I also remove images that are too bright, too dark, or that have little variation in brightness (small standard deviation).
from skimage import io, color, transform
def color_mono(image, threshold=150):
    # Determine whether a 3-channel input image is color or monochrome
    # A larger threshold classifies photos with only slightly mixed colors as "mono" as well
    image_size = image.shape[0] * image.shape[1]
    # Three channel combinations: (0, 1), (0, 2), (1, 2). Look at the difference for each pair
    diff = np.abs(np.sum(image[:, :, 0] - image[:, :, 1])) / image_size
    diff += np.abs(np.sum(image[:, :, 0] - image[:, :, 2])) / image_size
    diff += np.abs(np.sum(image[:, :, 1] - image[:, :, 2])) / image_size
    if diff > threshold:
        return "color"
    else:
        return "mono"


def bright_check(image, ave_thres=0.15, std_thres=0.1):
    # Return False for images that are too bright, too dark,
    # or whose brightness hardly varies
    try:
        # Convert to grayscale
        image = color.rgb2gray(image)
        if image.shape[0] < 144:
            return False
        # Too bright
        if np.average(image) > (1. - ave_thres):
            return False
        # Too dark
        if np.average(image) < ave_thres:
            return False
        # Brightness is almost uniform
        if np.std(image) < std_thres:
            return False
        return True
    except:
        return False
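As a quick illustrative check (synthetic arrays, not from the dataset), an image with identical channels should be classified as "mono" and one with very different channels as "color":

same = np.stack([np.full((100, 100), 120, dtype=np.uint8)] * 3, axis=-1)
diff = np.dstack([np.zeros((100, 100), np.uint8),
                  np.full((100, 100), 255, np.uint8),
                  np.zeros((100, 100), np.uint8)])
print(color_mono(same))  # "mono"  (channel differences are zero)
print(color_mono(diff))  # "color" (channel differences are large)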
paths = glob.glob("./test2014/*")
for i, path in enumerate(paths):
    image = io.imread(path)
    save_name = "./trans\\mscoco_" + str(i) + ".png"
    x = image.shape[0]  # Number of pixels along the x axis
    y = image.shape[1]  # Number of pixels along the y axis
    try:
        # Half the length of the shorter of the x and y axes
        clip_half = min(x, y) / 2
        # Crop a square from the center of the image
        image = image[int(x/2 - clip_half): int(x/2 + clip_half),
                      int(y/2 - clip_half): int(y/2 + clip_half), :]
        if color_mono(image) == "color":
            if bright_check(image):
                image = transform.resize(image, (144, 144, 3),
                                         anti_aliasing=True)
                image = np.uint8(image * 255)
                io.imsave(save_name, image)
    except:
        pass
I cropped the images into squares and put them all in one folder. The images are saved at 144 x 144 rather than 128 x 128 so that there is room for data augmentation.
This generally worked, but for some reason a few grayscale and sepia-toned images slipped through, so I deleted them manually.
I ended up with about 110,000 images in the "trans" folder. glob is used to create the list of image paths to load.
Training took about 20 minutes per epoch. The code is long because it performs both the Generator training and the Discriminator training.
The point to note is the labels used to compute the loss: as confirmed above, the Discriminator output has the shape [batch_size, 1, 4, 4], so true_labels and false_labels are generated to match it.
def train(model_G, model_D, epoch, epoch_plus=0):
    device = "cuda:0"
    batch_size = 32
    model_G = model_G.to(device)
    model_D = model_D.to(device)
    params_G = torch.optim.Adam(model_G.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    params_D = torch.optim.Adam(model_D.parameters(),
                                lr=0.0002, betas=(0.5, 0.999))
    # Labels for the loss calculation; note that they match the Discriminator output size
    true_labels = torch.ones(batch_size, 1, 4, 4).to(device)    # True
    false_labels = torch.zeros(batch_size, 1, 4, 4).to(device)  # False
    # Loss functions
    bce_loss = nn.BCEWithLogitsLoss()
    mae_loss = nn.L1Loss()
    # Record the loss history
    log_loss_G_sum, log_loss_G_bce, log_loss_G_mae = list(), list(), list()
    log_loss_D = list()
    for i in range(epoch):
        # Record the losses of the current epoch
        loss_G_sum, loss_G_bce, loss_G_mae = list(), list(), list()
        loss_D = list()
        train_dataloader = load_train_dataloader(file_path_train, batch_size)
        for real_color, input_gray in train_dataloader:
            batch_len = len(real_color)
            real_color = real_color.to(device)
            input_gray = input_gray.to(device)

            # Generator training
            # Generate a fake color image
            fake_color = model_G(input_gray)
            # Keep a detached copy of the fake image for the Discriminator step
            fake_color_tensor = fake_color.detach()
            # Compute the loss so that the fake image fools the Discriminator
            LAMBD = 100.0  # Weight between the BCE and MAE terms
            # Output of D for the fake image; D tries to push this towards 0 (False)
            out = model_D(fake_color)
            # Loss on the output of D; the target is true_labels because G should look real
            loss_G_bce_tmp = bce_loss(out, true_labels[:batch_len])
            # Loss on the output of G itself
            loss_G_mae_tmp = LAMBD * mae_loss(fake_color, real_color)
            loss_G_sum_tmp = loss_G_bce_tmp + loss_G_mae_tmp
            loss_G_bce.append(loss_G_bce_tmp.item())
            loss_G_mae.append(loss_G_mae_tmp.item())
            loss_G_sum.append(loss_G_sum_tmp.item())
            # Compute the gradients and update the weights of G
            params_D.zero_grad()
            params_G.zero_grad()
            loss_G_sum_tmp.backward()
            params_G.step()

            # Discriminator training
            real_out = model_D(real_color)
            fake_out = model_D(fake_color_tensor)
            # Loss calculation
            loss_D_real = bce_loss(real_out, true_labels[:batch_len])
            loss_D_fake = bce_loss(fake_out, false_labels[:batch_len])
            loss_D_tmp = loss_D_real + loss_D_fake
            loss_D.append(loss_D_tmp.item())
            # Compute the gradients and update the weights of D
            params_D.zero_grad()
            params_G.zero_grad()
            loss_D_tmp.backward()
            params_D.step()

        i = i + epoch_plus
        print(i, "loss_G", np.mean(loss_G_sum), "loss_D", np.mean(loss_D))
        log_loss_G_sum.append(np.mean(loss_G_sum))
        log_loss_G_bce.append(np.mean(loss_G_bce))
        log_loss_G_mae.append(np.mean(loss_G_mae))
        log_loss_D.append(np.mean(loss_D))
        file_path_test = glob.glob("test/*")
        evaluate_test(file_path_test, model_G, device)
    return model_G, model_D, [log_loss_G_sum, log_loss_G_bce, log_loss_G_mae, log_loss_D]
Run the training.
file_path_train = glob.glob("trans/*")
model_G = Generator()
model_D = Discriminator()
model_G, model_D, logs = train(model_G, model_D, 40)
After 2 epochs
Hmm? It already looks pretty good, except that the airplane image is not colored at all.
After 11 epochs
After 21 epochs
After 40 epochs (the images shown at the beginning)
Unexpectedly, I felt the images after only 2 epochs already looked good...
Here are some more images, after 11 epochs. I deliberately picked many that seem to have failed. The bad ones are really bad, with almost no color, and in images like the baseball one the color is painted across object boundaries.
It seems to be good at greens such as grass and trees and blues such as the sky. This probably depends on the bias of the original dataset and on how easy the content is to recognize.
I used pix2pix to colorize grayscale images.
This time I simply threw in every kind of image I could get, but since the network is shallow and its expressive power is low, I feel it would work better to narrow down the types of images.
To be honest, I feel the references below are easier to understand than what I wrote.
U-Net: Convolutional Networks for Biomedical Image Segmentation https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/
I implemented pix2pix from scratch and tried to colorize black and white images (PyTorch) https://blog.shikoan.com/pytorch_pix2pix_colorization/
I want to understand pix2pix https://qiita.com/mine820/items/36ffc3c0aea0b98027fd
COCO https://cocodataset.org/#home
Labeled Faces in the Wild http://vis-www.cs.umass.edu/lfw/
The PASCAL Visual Object Classes Homepage http://host.robots.ox.ac.uk/pascal/VOC/