[PYTHON] [PyTorch] Data Augmentation for segmentation

0. Who this article is for

  1. Those implementing image **segmentation** with **PyTorch**
  2. Those who want to augment (inflate) their data with **Data Augmentation**
  3. Those who want to apply **exactly the same processing** to an original image and its corresponding mask image
  4. Especially those using their **own dataset** (data not included in torchvision.datasets)

1. Overview

This article explains how to apply identical random Data Augmentation to image/mask pairs, primarily for supervised or semi-supervised segmentation datasets.

2. Problems

Before looking at the problematic case, let's first consider a case where there is no problem.

2.1 The unproblematic case (object class recognition, etc.)

When doing Data Augmentation with PyTorch, you usually define the transforms as follows:


import torchvision.transforms as transforms

transform = transforms.Compose([
    # Rotate by a random angle in (-degrees, +degrees)
    transforms.RandomRotation(degrees),
    # Randomly flip horizontally
    transforms.RandomHorizontalFlip(),
    # Randomly flip vertically
    transforms.RandomVerticalFlip()
])

and pass it to the Dataset as an argument:


dataset = HogeDataset.HogeDataset(
    train=True, transform=transform
)

For object class recognition and the like, this causes no problem: the teacher data (label) is not an image, so only the original image needs to be transformed.
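
For reference, here is a minimal sketch of such a classification-style Dataset (a hypothetical example, not the actual HogeDataset; it assumes the images are stored as PIL images and the labels as integer class indices):


import torch

class ClassificationDataset(torch.utils.data.Dataset):
    # Hypothetical classification-style dataset: the labels are integers, not images
    def __init__(self, images, labels, transform=None):
        self.images = images        # list of PIL images
        self.labels = labels        # list of int class indices
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]    # not an image, so there is nothing to keep in sync
        # Only the input image needs to be transformed
        if self.transform:
            image = self.transform(image)
        return image, label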

2.2 The problematic case (segmentation, etc.)

Now for the problematic case. The difference from the previous one is that the teacher data is also given as an image. Suppose we define the same transforms as before:


transform = transforms.Compose([
    # Rotate by a random angle in (-degrees, +degrees)
    transforms.RandomRotation(degrees),
    # Randomly flip horizontally
    transforms.RandomHorizontalFlip(),
    # Randomly flip vertically
    transforms.RandomVerticalFlip()
])

and pass it to the Dataset, this time also as target_transform:


dataset = HogeDataset.HogeDataset(
    train=True, transform=transform, target_transform=transform
)

However, in this case, when data is retrieved from HogeDataset, the transforms applied to the original image and to the mask image do not correspond (for example, the original image might be rotated by 90 degrees while the mask image is rotated by 270 degrees). If that happens, the augmented data no longer functions as teacher data. Why, then, does the argument target_transform exist at all? Presumably it is intended for non-random processing of the mask image, such as torchvision.transforms.Resize() and torchvision.transforms.ToTensor().
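
To see this concretely, here is a minimal sketch (the image paths are placeholders): each call to a random transform object draws new random parameters, so applying the same transform separately to the image and to the mask does not keep them aligned.


import torchvision.transforms as transforms
from PIL import Image

img = Image.open('[Original image path]')
mask = Image.open('[Mask image path]')

random_rotation = transforms.RandomRotation(degrees=90)

# Each call draws a new random angle, so the two results are generally
# rotated by different amounts and no longer correspond to each other
rotated_img = random_rotation(img)
rotated_mask = random_rotation(mask)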

3. Solution

So, how can we apply the same processing to the mask image as to the original image? One solution is to create your own Dataset class, as shown below.

HogeDataset.py


import os
import glob
import torch
from torchvision import transforms
from torchvision.transforms import functional as tvf
import random
from PIL import Image

DATA_PATH = '[Original image directory path]'
MASK_PATH = '[Mask image directory path]'
TRAIN_NUM = [Number of training data]

class HogeDataset(torch.utils.data.Dataset):
    def __init__(self, transform = None, target_transform = None, train = True):
        # transform and target_transform are for non-random transforms such as tensorization
        self.transform = transform
        self.target_transform = target_transform


        # Sort so that each original image lines up with its mask image
        data_files = sorted(glob.glob(DATA_PATH + '/*.[File extension]'))
        mask_files = sorted(glob.glob(MASK_PATH + '/*.[File extension]'))

        self.dataset = []
        self.maskset = []

        # Load the original images
        for data_file in data_files:
            self.dataset.append(Image.open(data_file))

        # Load the mask images
        for mask_file in mask_files:
            self.maskset.append(Image.open(mask_file))

        # Split into training data and test data
        if train:
            self.dataset = self.dataset[:TRAIN_NUM]
            self.maskset = self.maskset[:TRAIN_NUM]
        else:
            self.dataset = self.dataset[TRAIN_NUM:]
            self.maskset = self.maskset[TRAIN_NUM:]

        # Data Augmentation
        # The random transforms are applied here, once, with identical
        # parameters for each image and its mask
        self.augmented_dataset = []
        self.augmented_maskset = []
        for num in range(len(self.dataset)):
            data = self.dataset[num]
            mask = self.maskset[num]
            # Random crop: the crop position is drawn at random once,
            # then applied to both the image and the mask
            for crop_num in range(16):
                i, j, h, w = transforms.RandomCrop.get_params(data, output_size=(250, 250))
                cropped_data = tvf.crop(data, i, j, h, w)
                cropped_mask = tvf.crop(mask, i, j, h, w)
                
                # Rotation (0, 90, 180, 270 degrees)
                for rotation_num in range(4):
                    rotated_data = tvf.rotate(cropped_data, angle=90*rotation_num)
                    rotated_mask = tvf.rotate(cropped_mask, angle=90*rotation_num)

                    # Flip: use either the horizontal or the vertical direction, not both
                    # Horizontal flip (p=0 leaves the image as is, p=1 always flips)
                    for h_flip_num in range(2):
                        h_flipped_data = transforms.RandomHorizontalFlip(p=h_flip_num)(rotated_data)
                        h_flipped_mask = transforms.RandomHorizontalFlip(p=h_flip_num)(rotated_mask)

                        """
                        # Vertical flip (enable this instead of the horizontal flip if preferred)
                        for v_flip_num in range(2):
                            v_flipped_data = transforms.RandomVerticalFlip(p=v_flip_num)(h_flipped_data)
                            v_flipped_mask = transforms.RandomVerticalFlip(p=v_flip_num)(h_flipped_mask)
                        """

                        # Store the augmented image/mask pair
                        self.augmented_dataset.append(h_flipped_data)
                        self.augmented_maskset.append(h_flipped_mask)

        self.datanum = len(self.augmented_dataset)

    # Returns the number of augmented samples
    def __len__(self):
        return self.datanum

    # Returns one augmented image/mask pair
    # The non-random transforms (transform / target_transform) are applied here
    def __getitem__(self, idx):
        out_data = self.augmented_dataset[idx]
        out_mask = self.augmented_maskset[idx]

        if self.transform:
            out_data = self.transform(out_data)

        if self.target_transform:
            out_mask = self.target_transform(out_mask)

        return out_data, out_mask

What we are doing is simple: the Data Augmentation is performed inside __init__(). For each original image/mask pair, every combination of 16 random crops, 4 rotations (0, 90, 180, 270 degrees), and 2 horizontal-flip states is generated exhaustively, with the same randomly drawn parameters applied to the image and to its mask.

In this way you can apply exactly the same processing to the mask image as to the original image and perform Data Augmentation. **[Note] Combining rotation with both horizontal and vertical flips produces duplicate images, so it is recommended to use only one of the two flip directions!**
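
A small illustrative check of that duplication (the image path is a placeholder): rotating by 180 degrees and then flipping horizontally gives the same result as a single vertical flip, so enumerating the four rotations together with both flip axes produces each image twice.


from PIL import Image
from torchvision.transforms import functional as tvf

img = Image.open('[Original image path]')

# rotate 180 degrees + horizontal flip is the same as a vertical flip
a = tvf.hflip(tvf.rotate(img, angle=180))
b = tvf.vflip(img)
print(a.tobytes() == b.tobytes())  # expected: True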

4. How to use

Let's try using the custom Dataset class from Section 3.


import torch
import torchvision
import HogeDataset

BATCH_SIZE = [Batch size]

#Preprocessing
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
target_transform = torchvision.transforms.Compose([
    # interpolation=0 (nearest neighbor) keeps mask label values intact
    torchvision.transforms.Resize((224, 224), interpolation=0),
    torchvision.transforms.ToTensor()
])

#Preparation of training data and test data
trainset = HogeDataset.HogeDataset(
    train=True,
    transform=transform, 
    target_transform=target_transform
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True
)

testset = HogeDataset.HogeDataset(
    train=False,
    transform=transform,
    target_transform=target_transform
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=True
)
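
From here the loaders can be consumed like any other PyTorch DataLoader. Below is a minimal sketch of a training loop; model, criterion, optimizer and NUM_EPOCHS are hypothetical placeholders that are not defined in this article.


for epoch in range(NUM_EPOCHS):
    for images, masks in trainloader:
        # images: (BATCH_SIZE, 3, 224, 224); masks: (BATCH_SIZE, C, 224, 224),
        # where C depends on the mask images
        outputs = model(images)            # model: a placeholder segmentation network
        loss = criterion(outputs, masks)   # criterion: a placeholder loss function

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()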
