[PYTHON] Implement Style Transfer in PyTorch

Let's implement Style Transfer with PyTorch to generate a composite image that applies the style of one image to another. We will use Google Colaboratory (Colab). I am still studying this topic, so there may be some misinterpretations.

Implement Style Transfer in PyTorch to generate composite images

This is the finished product of this article.

[Image: final style-transferred image]

Below are the original image and the Van Gogh painting used as the style source. The original image was taken from Unsplash. Van Gogh's style is applied well to the image of the torii gate.

[Image: original image and Van Gogh painting]

Preparation

First, let's implement everything that comes before training.

Required libraries and Google Drive mount

Import the required libraries. This time I will mount Google Drive and train on images stored in the drive.

Install PyTorch, and Pillow for loading and displaying images.

# Colab usually has these preinstalled; the commands are harmless if so
!pip install torch torchvision
!pip install Pillow

Import the library.

%matplotlib inline
import torch
import torch.optim as optim
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

Mount Google Drive so that Colab can access it. Upload the content image (the base of the composite image) and the style image to the drive.

from google.colab import drive
drive.mount("/content/drive")

Model definition

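Use VGG19 pretrained on ImageNet, keeping only its convolutional part (features). Its parameters are frozen because only the generated image is optimized, not the network. Also select the GPU as the device when one is available.

# VGG19 pretrained on ImageNet; only the convolutional part is needed
# (weights= requires torchvision >= 0.13; older versions use pretrained=True)
vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features

# Freeze the weights: only the target image will be optimized
for param in vgg.parameters():
  param.requires_grad_(False)

# Use the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)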
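Freezing the parameters means that backpropagation computes gradients only for the image being optimized. A minimal sketch of that idea, using a small randomly initialized convolution stack as a stand-in for VGG (the layers and sizes are arbitrary assumptions, chosen so that nothing has to be downloaded):

```python
import torch
import torch.nn as nn

# Stand-in for vgg19().features: any stack of layers works for this check
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 8, 3, padding=1))

# Freeze the network weights, as done for VGG above
for param in net.parameters():
  param.requires_grad_(False)

# The image is the only tensor we want gradients for
image = torch.rand(1, 3, 32, 32, requires_grad=True)

loss = net(image).pow(2).mean()
loss.backward()

# The image receives a gradient; the frozen weights do not
print(image.grad.shape)
print(all(p.grad is None for p in net.parameters()))
```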

Define a function to load an image

Define a function that loads an image from a path and returns it as a Tensor with a batch dimension. You can limit the size with max_size or specify an exact shape.

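def load_image(img_path, max_size=400, shape=None):
  # Load an image as RGB and return a (1, 3, H, W) float Tensor in [0, 1]
  image = Image.open(img_path).convert('RGB')

  if shape is not None:
    # Explicit (height, width), e.g. to match the content image
    size = tuple(shape)
  else:
    # Scale so that the longer side is at most max_size
    # (transforms.Resize with a single int would resize the shorter side instead)
    w, h = image.size
    scale = min(1.0, max_size / max(w, h))
    size = (round(h * scale), round(w * scale))

  in_transform = transforms.Compose([
                    transforms.Resize(size),
                    transforms.ToTensor(),
                    ])

  # Add a batch dimension: (3, H, W) -> (1, 3, H, W)
  image = in_transform(image).unsqueeze(0)

  return image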

Load and display images

Define the image paths. Set them to match where you stored the images. The image that forms the base of the final image is content_path, and the image that provides the style is style_path.

images_path = '/content/drive/My Drive/'
content_path = images_path + 'content.jpg'
style_path = images_path + 'style.jpg'

Let's load them. The style image is resized to the content image's height and width so that the two sizes match.

content = load_image(content_path).to(device)
style = load_image(style_path, shape=content.shape[-2:]).to(device)

Define a function that converts a Tensor to a NumPy array so that it can be displayed as an image.

def im_convert(tensor):
  # Convert a (1, 3, H, W) Tensor to an (H, W, 3) NumPy array for display
  image = tensor.to("cpu").clone().detach()
  image = image.numpy().squeeze()
  image = image.transpose(1, 2, 0)
  # load_image applies no normalization, so values are already on a [0, 1] scale;
  # just clip anything the optimizer pushed out of range
  image = image.clip(0, 1)

  return image
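The shape bookkeeping in im_convert can be checked without an image or a GPU. This sketch applies the same squeeze, transpose, and clip steps to a dummy NumPy array (the 4×5 size and value range are arbitrary assumptions):

```python
import numpy as np

# Dummy batch of one RGB image in PyTorch layout: (batch, channels, height, width)
batch = np.random.default_rng(0).uniform(-0.5, 1.5, size=(1, 3, 4, 5)).astype(np.float32)

image = batch.squeeze()            # drop the batch dimension -> (3, 4, 5)
image = image.transpose(1, 2, 0)   # channels last for matplotlib -> (4, 5, 3)
image = image.clip(0, 1)           # imshow expects floats in [0, 1]

print(image.shape)
```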

Display the two images side by side in Colab.

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
ax1.imshow(im_convert(content))
ax1.axis("off")
ax2.imshow(im_convert(style))
ax2.axis("off")

[Image: content image (left) and style image (right)]

Extract features

A CNN classifier can be roughly divided into two parts: the first half, which extracts features from the image, and the second half, which classifies the image based on those features. The features module of VGG used here is the first half.

We extract the features of the images with the CNN and use the differences between them as the loss for optimizing the final image. So let's feed each image through the CNN and collect the outputs of specific layers: indices 0, 5, 10, 19, 21, and 28 of VGG19's features module, which correspond to conv1_1, conv2_1, conv3_1, conv4_1, conv4_2, and conv5_1.

First, define the function.

def get_features(image, model):
  # Extract feature maps from selected layers of the model.
  # Map from layer index in vgg19().features to the conventional layer name
  layers = {'0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',
            '28': 'conv5_1'}

  features = {}

  for name, layer in model.named_children():
    # Pass the image through the network one layer at a time
    image = layer(image)
    # Keep the output of the selected layers (0, 5, 10, 19, 21, 28)
    if name in layers:
      features[layers[name]] = image

  return features

Let's extract the features.

content_features = get_features(content, vgg)
style_features = get_features(style, vgg)

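For the style image, we do not compare the features directly but their Gram matrices. The Gram matrix of a layer holds the inner products between every pair of feature channels, so it captures which features occur together regardless of where they appear in the image. Let's define a function that computes the Gram matrix.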

def gram_matrix(tensor):
  # Gram matrix of a (1, d, h, w) feature map (assumes a batch size of 1)
  _, d, h, w = tensor.size()
  # Flatten each channel: (d, h * w)
  tensor = tensor.view(d, h * w)
  # Multiply by the transpose: (d, h*w) x (h*w, d) -> (d, d)
  gram = torch.mm(tensor, tensor.t())

  return gram
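Because the Gram matrix sums products over all spatial positions, it keeps which channels respond together but discards where they respond. A small NumPy sketch of the same computation as gram_matrix (the array sizes are arbitrary assumptions) shows that shuffling the spatial positions leaves the Gram matrix unchanged even though the raw features change:

```python
import numpy as np

def gram_matrix_np(features):
  # features: (channels, height, width) -> (channels, channels)
  d, h, w = features.shape
  flat = features.reshape(d, h * w)   # one row per channel
  return flat @ flat.T                # G[i, j] = sum_k F[i, k] * F[j, k]

rng = np.random.default_rng(0)
features = rng.standard_normal((4, 3, 5))

# Apply the same random permutation of the 15 spatial positions to every channel
perm = rng.permutation(3 * 5)
shuffled = features.reshape(4, 15)[:, perm].reshape(4, 3, 5)

# The Gram matrices match, while the feature maps themselves differ
print(np.allclose(gram_matrix_np(features), gram_matrix_np(shuffled)))
print(np.allclose(features, shuffled))
```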

Precompute the Gram matrix of the style image's features for every extracted layer.

style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

Set the weight of each style layer, as well as the overall weights of the content loss (alpha) and the style loss (beta).

style_weights = {'conv1_1':1.,
                 'conv2_1':0.75,
                 'conv3_1':0.2,
                 'conv4_1':0.2,
                 'conv5_1':0.2}

content_weight = 1  # alpha
style_weight = 1e6  # beta
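Written out, my reading of the loss that the training loop minimizes is as follows. Here $F^l_X$ is the feature map of image $X$ at layer $l$ ($T$ = target, $C$ = content, $S$ = style) with $d_l$ channels and spatial size $h_l \times w_l$, reshaped to $d_l \times h_l w_l$ for the Gram matrix; $d, h, w$ in the content term are the conv4_2 dimensions; $l$ runs over the keys of style_weights; $\lambda_l$ is style_weights[l]; and $\alpha$, $\beta$ are content_weight and style_weight.

```latex
G^{l}_{X,ij} = \sum_{k=1}^{h_l w_l} F^{l}_{X,ik} \, F^{l}_{X,jk}

\mathcal{L}_{\text{content}} = \frac{1}{d\,h\,w} \sum_{c,y,x} \left( F^{\text{conv4\_2}}_{T,cyx} - F^{\text{conv4\_2}}_{C,cyx} \right)^2

\mathcal{L}_{\text{style}} = \sum_{l} \frac{\lambda_l}{d_l h_l w_l} \cdot \frac{1}{d_l^{2}} \sum_{i,j=1}^{d_l} \left( G^{l}_{T,ij} - G^{l}_{S,ij} \right)^2

\mathcal{L}_{\text{total}} = \alpha \, \mathcal{L}_{\text{content}} + \beta \, \mathcal{L}_{\text{style}}
```

The factor $1/d_l^2$ comes from torch.mean over the $d_l \times d_l$ Gram difference, and $1/(d_l h_l w_l)$ from the explicit division in the loop.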

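The final image is optimized starting from a copy of the content image, so clone the content image and define the clone as the target. Only this tensor requires gradients.

# Start from a copy of the content image; only this tensor is optimized
target = content.clone().to(device).requires_grad_(True)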
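The order of clone(), to(), and requires_grad_() matters because torch.optim only accepts leaf tensors. Calling to() after requires_grad_() creates a new, non-leaf tensor whenever a conversion actually happens (for example, moving from CPU to CUDA). In this sketch a dtype change stands in for the device move so that it runs on a CPU:

```python
import torch
import torch.optim as optim

content = torch.rand(1, 3, 8, 8)

# requires_grad_ first, then a real conversion: the result is a non-leaf tensor
bad = content.clone().requires_grad_(True).to(torch.float64)
# conversion first, then requires_grad_: the result is a leaf tensor
good = content.clone().to(torch.float64).requires_grad_(True)

print(bad.is_leaf, good.is_leaf)

try:
  optim.Adam([bad], lr=0.003)
except ValueError as err:
  # The optimizer refuses tensors that are not leaves of the autograd graph
  print('Adam rejected the tensor:', err)

optimizer = optim.Adam([good], lr=0.003)
```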

Learning

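The preparation is complete. Let's train! We also store intermediate images so that we can make a video of the progress later. First, set the hyperparameters.

show_every = 300                  # display the target every 300 iterations
optimizer = optim.Adam([target], lr=0.003)
steps = 10000                     # total number of optimization steps
total_capture_frame_number = 500  # number of frames stored for the video

# Buffer for the video frames (float32 halves the memory of the default float64)
height, width, channels = im_convert(target).shape
image_array = np.empty(shape=(total_capture_frame_number, height, width, channels), dtype=np.float32)
# Store a frame every capture_frame iterations (an integer, here 20)
capture_frame = steps // total_capture_frame_number
counter = 0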
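With these settings a frame is stored every 10000 / 500 = 20 iterations. The sketch below (plain Python; the helper name capture_iterations is hypothetical) lists the captured iterations and the resulting video length at the 30 frames per second used in the export step:

```python
def capture_iterations(steps, total_frames):
  # Iterations at which the training loop stores a frame
  interval = steps // total_frames
  return [ii for ii in range(1, steps + 1) if ii % interval == 0]

frames = capture_iterations(10000, 500)
print(len(frames), frames[0], frames[-1])
# Video length in seconds at 30 fps
print(len(frames) / 30)
```

Keeping steps a multiple of total_capture_frame_number matters: otherwise the number of captured frames can differ from the size of image_array, which either overflows the buffer or leaves it partly empty.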

Now run the training loop. The comments explain each step.

for ii in range(1, steps+1):
  target_features = get_features(target, vgg)
  # Content loss: mean squared error between the conv4_2 features
  content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)
  # Style loss: weighted sum over the style layers
  style_loss = 0

  for layer in style_weights:
    target_feature = target_features[layer]
    target_gram = gram_matrix(target_feature)
    style_gram = style_grams[layer]
    # Layer weight * mean squared error between the Gram matrices
    layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
    # Normalize by the size of the feature map
    _, d, h, w = target_feature.shape
    style_loss += layer_style_loss / (d * h * w)

  # Total loss: alpha * content loss + beta * style loss
  total_loss = content_weight * content_loss + style_weight * style_loss

  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()

  # Show progress
  if ii % show_every == 0:
    print('Total loss: ', total_loss.item())
    print('Iteration: ', ii)
    plt.imshow(im_convert(target))
    plt.axis("off")
    plt.show()

  # Store a frame for the video
  if ii % capture_frame == 0:
    image_array[counter] = im_convert(target)
    counter = counter + 1
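For experimenting with the weights, the loss computation in the loop can be pulled out into a function. The following is a sketch of that refactoring (the name style_transfer_loss is hypothetical; gram_matrix is repeated so that the block is self-contained, and the feature maps are random stand-ins). As a sanity check, if the target's features equal both the content features and the style features, every term is zero.

```python
import torch

def gram_matrix(tensor):
  # (1, d, h, w) -> (d, d) Gram matrix, as defined earlier
  _, d, h, w = tensor.size()
  tensor = tensor.view(d, h * w)
  return torch.mm(tensor, tensor.t())

def style_transfer_loss(target_features, content_features, style_grams,
                        style_weights, content_weight=1, style_weight=1e6):
  # Content term: mean squared error between conv4_2 feature maps
  content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)
  # Style term: weighted, size-normalized mean squared error between Gram matrices
  style_loss = 0
  for layer, weight in style_weights.items():
    target_feature = target_features[layer]
    _, d, h, w = target_feature.shape
    layer_loss = weight * torch.mean((gram_matrix(target_feature) - style_grams[layer])**2)
    style_loss += layer_loss / (d * h * w)
  return content_weight * content_loss + style_weight * style_loss

# Random stand-in feature maps (shapes are arbitrary)
torch.manual_seed(0)
feats = {name: torch.rand(1, 4, 6, 6) for name in ['conv1_1', 'conv2_1', 'conv4_2']}
grams = {name: gram_matrix(f) for name, f in feats.items()}
weights = {'conv1_1': 1., 'conv2_1': 0.75}

print(style_transfer_loss(feats, feats, grams, weights).item())
```

In the loop, total_loss = style_transfer_loss(target_features, content_features, style_grams, style_weights, content_weight, style_weight) would then replace the inline computation.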

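Training takes a while. Here are snapshots of the target image along the way.

Iteration 600

[Image: target image at iteration 600]

Iteration 5100

[Image: target image at iteration 5100]

Iteration 9900

[Image: target image at iteration 9900]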

The training is working properly.

Export video

Create a video using OpenCV.

import cv2

# All frames share the size of the target image
frame_height, frame_width, _ = im_convert(target).shape
# 'mp4v' is the usual FourCC for an .mp4 container; 30 frames per second
vid = cv2.VideoWriter('output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 30, (frame_width, frame_height))

for i in range(total_capture_frame_number):
  # Frames are RGB floats in [0, 1]; OpenCV expects BGR uint8 in [0, 255]
  img = (image_array[i] * 255).astype(np.uint8)
  img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  vid.write(img)

vid.release()
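OpenCV stores frames as 8-bit BGR, whereas the frames in image_array are RGB floats in [0, 1]. Swapping between RGB and BGR just reverses the channel order (cv2.COLOR_RGB2BGR and cv2.COLOR_BGR2RGB perform the same swap), so the conversion can be sketched in NumPy alone. The function name to_opencv_frame and the 2×2 test frame are assumptions for illustration.

```python
import numpy as np

def to_opencv_frame(rgb_float):
  # (H, W, 3) RGB floats in [0, 1] -> (H, W, 3) BGR uint8 in [0, 255]
  frame = (np.clip(rgb_float, 0, 1) * 255).astype(np.uint8)
  # Reverse the channels (RGB -> BGR); OpenCV functions may reject
  # negative-stride views, so make a contiguous copy
  return np.ascontiguousarray(frame[..., ::-1])

# A 2x2 frame that is pure red in RGB
red = np.zeros((2, 2, 3), dtype=np.float32)
red[..., 0] = 1.0

frame = to_opencv_frame(red)
print(frame[0, 0])
```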

Let's download it.

from google.colab import files
files.download('output.mp4')

You can watch the learning process by playing the video.

In conclusion

Good work! This time I implemented Style Transfer with PyTorch. You can easily try it with the GPU available in Google Colaboratory. If you keep your images in Google Drive, they stay available even after your Colab instance is reset.

Please try various other images!
