[PYTHON] View image after Data Augmentation in PyTorch

Here is a summary of some of the most commonly used image augmentations in PyTorch. Refer to it when you think, "I want to erase part of an image, but I forgot the name of the transform that does it...".

Only some of the available transforms are shown here. For more details, see the official PyTorch transforms documentation.

Preparation

First, import the required libraries and download the CIFAR10 dataset used in this article. We also define a function to visualize the images.

# Import modules
import torch
from torchvision import transforms, datasets
import matplotlib.pyplot as plt

# Download the CIFAR10 dataset and apply the given transform
def load_cifar10(transform):
    cifar10_dataset = datasets.CIFAR10(
        root='./',
        transform=transform,
        download=True)
    return cifar10_dataset

# Visualize the first five CIFAR10 images and save them to a.png
def show_img(dataset):
    plt.figure(figsize=(15, 3))
    for i in range(5):
        image, label = dataset[i]
        # Tensor images are (C, H, W); imshow expects (H, W, C)
        image = image.permute(1, 2, 0)
        plt.subplot(1, 5, i+1)
        # Hide tick labels and tick marks
        plt.tick_params(labelbottom=False, labelleft=False, labelright=False, labeltop=False)
        plt.tick_params(bottom=False, left=False, right=False, top=False)
        plt.imshow(image)
    plt.savefig('a.png')

CIFAR10 dataset

Let's take a look at the CIFAR10 dataset. CIFAR10 images are (32, 32) in size, and the dataset consists of 50,000 training images and 10,000 test images. There are 10 class labels: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.

# Convert each image to a PyTorch tensor
transform = transforms.Compose([
    transforms.ToTensor()
])
# Download CIFAR10 and visualize the images
cifar10_dataset = load_cifar10(transform)
show_img(cifar10_dataset)

a.png

Resize

Changes the resolution of the image. The original size of CIFAR10 is (32, 32). Here, the resolution is reduced to (16, 16).

transform = transforms.Compose([
    transforms.Resize(16),
    transforms.ToTensor()
])
cifar10_dataset = load_cifar10(transform)
show_img(cifar10_dataset)

a.png

CenterCrop

Crops the central part of the image. The original size of CIFAR10 is (32, 32). Here, the central (24, 24) region is cut out. This is often applied to test data. (To avoid repetition, the visualization code is omitted from here on; it is the same as above.)

transform = transforms.Compose([
    transforms.CenterCrop(24),
    transforms.ToTensor()
])

a.png

RandomCrop

Crops a (24, 24) region at a random position. CenterCrop always takes the central part, but RandomCrop may take the central part, a part closer to the upper left, or a part closer to the lower right. This is often applied to training images.

transform = transforms.Compose([
    transforms.RandomCrop(24),
    transforms.ToTensor()
])

a.png

RandomHorizontalFlip

Randomly flips the image horizontally with the given probability (p). The default value of p is 0.5, so if you do not set p, each image is flipped horizontally with a probability of one half. In this run, the 1st, 2nd, and 3rd images were flipped horizontally.

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor()
])

a.png

RandomVerticalFlip

Randomly flips the image vertically with the given probability (p). The default value of p is 0.5, so if you do not set p, each image is flipped vertically with a probability of one half. In this run, the 1st, 4th, and 5th images were flipped vertically.

transform = transforms.Compose([
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor()
])

a.png

RandomRotation

Rotates the image by a random angle within the given degrees. Since degrees = 30 here, the angle is drawn at random from the range -30 degrees to +30 degrees. You can also change the center of rotation (center, default is the image center) and the value used for the area exposed by the rotation (fill, default is 0, which is black).

transform = transforms.Compose([
    transforms.RandomRotation(degrees=30),
    transforms.ToTensor()
])

a.png

RandomErasing

Randomly erases a rectangular region of the image.

p: Probability that erasing is performed.
scale: Range of the erased area as a proportion of the total image area. By default, between 1/50 and about 1/3 of the image is erased.
ratio: Range of the aspect ratio of the erased rectangle, from horizontally long rectangles to vertically long rectangles.
value: The value used to fill the erased region. The default is 0, which appears black.

(Note) RandomErasing operates on tensors, not PIL images, so it is placed after ToTensor.

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0)
])

a.png

Normalize

Normalizes a tensor image with a mean and standard deviation. For n channels, the means (M1, ..., Mn) and standard deviations (S1, ..., Sn) are given, and each channel is transformed as (input - mean) / std. A typical RGB color image has 3 channels, so 3 values each are specified for mean and std. Note that the values used in this example are the widely used ImageNet channel statistics rather than values computed from CIFAR10; statistics computed from the CIFAR10 training set differ somewhat. (Note) Unlike the previous transforms, Normalize is applied after converting to the Tensor type.

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

a.png

Grayscale

Converts the image to grayscale. With num_output_channels, you can choose whether the output image has 1 channel or 3 channels. If you set it to 3, the same value is stored in all 3 channels of each pixel.

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor()
])

a.png

RandomApply

Applies a list of transforms with the given probability. As an example, the Grayscale transform above is applied with a probability of 0.5. In this run, the 1st, 2nd, and 5th images were converted to grayscale.

transform = transforms.Compose([
    transforms.RandomApply([transforms.Grayscale(num_output_channels=3)], p=0.5),
    transforms.ToTensor()
])

a.png

In closing

I'm new to PyTorch, so please let me know if I have made a mistake. I often forget the names of transforms, and I thought it would be nice to have a list in Japanese, so I summarized them in this article. I would be happy if it helps someone! If you have any concerns, please comment!
