**I want to display the images after Data Augmentation!**
So I thought about how to do it and implemented it.
Data Augmentation is a technique for increasing the amount of training data by generating varied versions of each image. Typical operations include the following:
- Random Crop (crops the image at a random position)
- Random Horizontal Flip (flips the image horizontally with a given probability)
- Random Erasing (erases a randomly chosen rectangular region of the image)
- Random Affine (applies a random affine transform such as rotation or scaling)
There are many other operations as well.
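To get a feel for what two of these operations do, here is a NumPy-only sketch of a random horizontal flip and random erasing applied to a [channels, height, width] array (the layout produced by ToTensor). It is a simplified reimplementation for illustration, not torchvision's code: the function names are hypothetical, and only the parameter names p, scale and ratio are borrowed from torchvision's RandomErasing.

```python
import math
import random

import numpy as np


def random_horizontal_flip(img, p=0.5, rng=random):
    """Hypothetical helper: mirror a CHW image left-to-right with probability p."""
    if rng.random() < p:
        return img[..., ::-1].copy()  # reverse the width (last) axis
    return img


def random_erasing(img, p=0.5, scale=(0.02, 0.4), ratio=(0.33, 3.0), value=0.0, rng=random):
    """Hypothetical helper: fill one random rectangle of a CHW image with `value`.

    Simplified sketch: sample the erased area as a fraction `scale` of the image
    and a log-uniform aspect ratio from `ratio`, retrying if the box does not fit.
    """
    if rng.random() >= p:
        return img
    _, height, width = img.shape
    area = height * width
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(10):  # a few attempts to find a box that fits inside the image
        erase_area = area * rng.uniform(*scale)
        aspect = math.exp(rng.uniform(log_lo, log_hi))
        h = int(round(math.sqrt(erase_area * aspect)))
        w = int(round(math.sqrt(erase_area / aspect)))
        if 0 < h < height and 0 < w < width:
            top = rng.randint(0, height - h)   # inclusive bounds, so the box stays inside
            left = rng.randint(0, width - w)
            out = img.copy()                   # leave the input array untouched
            out[:, top:top + h, left:left + w] = value
            return out
    return img  # no box fit: return the image unchanged
```

With value=0.0 the erased rectangle shows up as a black patch once the image is displayed.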
This time, I loaded the CIFAR-10 training dataset and added Random Horizontal Flip and Random Erasing to its transforms.
test.py

```python
import torch
import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as dsets
import matplotlib.pyplot as plt

# Loading images
batch_size = 100
# Training data: random horizontal flip -> tensor -> random erasing
# (RandomErasing works on tensors, so it must come after ToTensor).
# download=True fetches CIFAR-10 only if it is not already under root.
train_data = dsets.CIFAR10(
    root='./tmp/cifar-10', train=True, download=True,
    transform=transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.RandomErasing(p=0.5, scale=(0.02, 0.4), ratio=(0.33, 3.0)),
    ]))
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Test data: no augmentation, only conversion to a tensor
test_data = dsets.CIFAR10(
    root='./tmp/cifar-10', train=False, download=True,
    transform=transforms.Compose([transforms.ToTensor()]))
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)


def image_show(data_loader, n):
    # Read one mini-batch of augmented image data
    tmp = iter(data_loader)
    images, labels = next(tmp)  # built-in next(); newer PyTorch iterators have no .next()
    # Convert the images from a tensor to a NumPy array
    images = images.numpy()
    # Take out n images one by one and display them
    for i in range(n):
        # [channels, height, width] -> [height, width, channels]
        image = np.transpose(images[i], [1, 2, 0])
        plt.imshow(image)
        plt.show()


image_show(train_loader, 10)
```
The image_show function displays the images after augmentation.
iter() creates an iterator over the DataLoader, and next() takes one mini-batch from it, storing the image data in images and the labels in labels.
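The iter()/next() pattern works on any Python iterable, so it can be tried without downloading CIFAR-10. The sketch below uses a hypothetical FakeLoader class that yields NumPy batches shaped like the CIFAR-10 ones; it only assumes that iterating a DataLoader yields (images, labels) pairs, as in test.py. In Python 3, iterators implement `__next__`, so the built-in next() is the portable way to pull a batch.

```python
import numpy as np


class FakeLoader:
    """Hypothetical stand-in for a DataLoader: iterating it yields (images, labels) batches."""

    def __init__(self, num_batches=3, batch_size=4):
        self.num_batches = num_batches
        self.batch_size = batch_size

    def __iter__(self):
        for b in range(self.num_batches):
            # Every pixel of batch b is set to b, in [batch, channels, height, width] layout
            images = np.full((self.batch_size, 3, 32, 32), b, dtype=np.float32)
            labels = np.arange(self.batch_size) + b * self.batch_size
            yield images, labels


loader = FakeLoader()
batches = iter(loader)          # a fresh iterator over the loader
images, labels = next(batches)  # first mini-batch
print(images.shape)             # (4, 3, 32, 32)
```

Calling iter() again starts a new pass from the first batch, which is why image_show shows a different shuffled batch each time it is called on a loader with shuffle=True.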
images = images.numpy() converts the image data from a tensor to a NumPy array.
At this point, images has the shape **[batch size, number of channels, height, width]**, but to display an image with matplotlib's pyplot.imshow it must be **[height, width, number of channels]**.
Therefore, each image is rearranged with np.transpose(images[i], [1, 2, 0]).
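Since CIFAR-10 images are square (32 x 32), it is easy to mix up height and width. This small NumPy sketch uses a random, non-square batch (an assumption for illustration only) to show that the axis order [1, 2, 0] moves the channel axis to the end while keeping height before width.

```python
import numpy as np

# A batch laid out like the DataLoader output: [batch, channels, height, width].
# Non-square (4 x 5) so the height and width axes can be told apart.
batch = np.random.rand(2, 3, 4, 5).astype(np.float32)

# [channels, height, width] -> [height, width, channels] for plt.imshow
image = np.transpose(batch[0], [1, 2, 0])
print(image.shape)  # (4, 5, 3)

# Same result with moveaxis: move the channel axis (0) to the end (-1).
assert np.array_equal(image, np.moveaxis(batch[0], 0, -1))
# Pixel (row 1, column 2) of channel 0 keeps its value, only its index changes.
assert image[1, 2, 0] == batch[0, 0, 1, 2]
```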
I confirmed that the images were flipped horizontally and that a rectangular region was erased by Random Erasing.