[PYTHON] Perform Stratified Split with PyTorch

What is Stratified Split?

In machine learning, we often split a dataset into training and validation data. For classification problems in particular, you could split the data randomly without looking at the class labels, but it is preferable to split it so that the class label distribution in each part matches that of the original data. Splitting while preserving the ratio of each class in this way is called stratified sampling, or a stratified split.
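To make the idea concrete, here is a minimal pure-Python sketch of a stratified split. It is not taken from any library: the function name stratified_split and the toy "cat"/"dog" labels are hypothetical. It groups sample indices by class, shuffles each group, and sends the same fraction of every class to validation, so a 9:1 class ratio stays 9:1 in both parts.

```python
import random
from collections import Counter, defaultdict

def stratified_split(labels, val_fraction=0.2, seed=0):
    """Hypothetical helper: return (train_indices, val_indices) with
    per-class proportions preserved."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, val_idx = [], []
    for idxs in by_class.values():
        rng.shuffle(idxs)                       # random order within the class
        n_val = round(val_fraction * len(idxs)) # same fraction from every class
        val_idx.extend(idxs[:n_val])
        train_idx.extend(idxs[n_val:])
    return sorted(train_idx), sorted(val_idx)

labels = ["cat"] * 90 + ["dog"] * 10   # toy labels with a 9:1 imbalance
train_idx, val_idx = stratified_split(labels, val_fraction=0.2)
print(Counter(labels[i] for i in train_idx))  # Counter({'cat': 72, 'dog': 8})
print(Counter(labels[i] for i in val_idx))    # Counter({'cat': 18, 'dog': 2})
```

Rounding per class means that with many tiny classes the overall validation size can drift slightly from val_fraction; library implementations such as scikit-learn handle that allocation more carefully.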

Implementation example in PyTorch

In scikit-learn, you can perform Stratified Split by passing the stratify option to the function sklearn.model_selection.train_test_split.
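As a quick illustration of the stratify option on its own, this sketch splits toy features X and labels y (both made up for the example) with an 80/20 class imbalance. Passing stratify=y keeps that 80/20 ratio in both outputs; random_state only fixes the shuffle so the result is repeatable.

```python
from collections import Counter
from sklearn.model_selection import train_test_split

X = list(range(100))        # stand-in features (here just sample IDs)
y = [0] * 80 + [1] * 20     # imbalanced labels: 80% class 0, 20% class 1

X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)
print(Counter(y_train))  # Counter({0: 64, 1: 16})
print(Counter(y_val))    # Counter({0: 16, 1: 4})
```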

PyTorch, on the other hand, has no such option. torch.utils.data.random_split splits a dataset randomly, but it ignores the class labels, so it cannot perform a stratified split directly. Instead, we can achieve a stratified split by combining PyTorch with scikit-learn's train_test_split.
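To see why random_split alone is not enough, this sketch (toy labels, not from the original article) splits the indices of an 80/20-imbalanced label list with random_split under 20 different seeds and counts how many minority-class samples land in the 20-sample validation set. A stratified split would always put exactly 4 there; random_split gives no such guarantee, so the count varies from seed to seed.

```python
import torch
from torch.utils.data import random_split

labels = [0] * 80 + [1] * 20          # toy labels standing in for dataset.targets
indices = list(range(len(labels)))    # random_split accepts any sized, indexable object

minority_counts = []
for seed in range(20):
    _, val_subset = random_split(
        indices, [80, 20], generator=torch.Generator().manual_seed(seed)
    )
    # val_subset.indices holds the positions chosen for validation
    minority_counts.append(sum(1 for i in val_subset.indices if labels[i] == 1))

print(minority_counts)  # scatters around 4 instead of always being 4
```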

For example, you can do a Stratified Split with code like this:

import torch
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split

transformer = transforms.Compose([
    transforms.ToTensor(),
])

#Load image
dataset = torchvision.datasets.ImageFolder(root='directory_name', transform=transformer)

#Split dataset into train and validation
train_indices, val_indices = train_test_split(list(range(len(dataset.targets))), test_size=0.2, stratify=dataset.targets)
train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

#Create DataLoader
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=True, num_workers=4)
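If you want to try the pipeline without an image directory, the following sketch swaps ImageFolder for a hypothetical ToyDataset that holds random tensors and, like ImageFolder, exposes a targets list. The split and Subset steps are the same as above; the validation loader uses shuffle=False and num_workers=0 so the sketch runs anywhere, and it then counts the labels it yields to confirm the 6:3:1 class ratio survived the split.

```python
from collections import Counter
import torch
from torch.utils.data import Dataset, Subset, DataLoader
from sklearn.model_selection import train_test_split

class ToyDataset(Dataset):
    """Hypothetical stand-in for ImageFolder: random 'images' plus a .targets list."""
    def __init__(self, targets):
        self.targets = list(targets)
        self.images = torch.randn(len(self.targets), 3, 8, 8)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.images[idx], self.targets[idx]

dataset = ToyDataset([0] * 60 + [1] * 30 + [2] * 10)

# Same stratified index split as in the article
train_indices, val_indices = train_test_split(
    list(range(len(dataset.targets))), test_size=0.2,
    stratify=dataset.targets, random_state=0,
)
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)

# num_workers=0 keeps the sketch runnable in notebooks and on any OS
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=0)

val_counts = Counter()
for images, labels in val_loader:
    val_counts.update(labels.tolist())
print(val_counts)  # Counter({0: 12, 1: 6, 2: 2})
```

With num_workers > 0, as in the original code, platforms that start worker processes with "spawn" (Windows, and macOS by default) need the DataLoader code to run under an `if __name__ == "__main__":` guard in a script.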

Let's go through it step by step.

transformer = transforms.Compose([
    transforms.ToTensor(),
])

#Load image
dataset = torchvision.datasets.ImageFolder(root='directory_name', transform=transformer)

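This part is straightforward: ImageFolder reads the images and creates a Dataset from them. It expects one subdirectory per class under root and records each image's class index in dataset.targets.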

#Split dataset into train and validation
train_indices, val_indices = train_test_split(list(range(len(dataset.targets))), test_size=0.2, stratify=dataset.targets)

This is the key part. train_test_split takes the array to split as its first argument, but a Dataset cannot be passed to it directly. So instead we generate the Dataset's index array [0, 1, 2, ..., N - 1], where N is the number of samples, with list(range(len(dataset.targets))) and pass that. Then, by passing dataset.targets (the class label of each of those indices) as the stratify option, the index array is split into training and validation indices while preserving the class label ratio of the original data.
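Two practical details of this index-based call are worth checking, shown here with toy targets standing in for dataset.targets. First, the split is random, so fixing random_state (not used in the article's code) makes it reproducible across runs. Second, stratify requires at least two samples of every class; scikit-learn raises a ValueError otherwise, which can happen with rare classes in small datasets.

```python
from sklearn.model_selection import train_test_split

targets = [0] * 80 + [1] * 20          # stands in for dataset.targets
indices = list(range(len(targets)))    # the index array [0, 1, ..., 99]

# Same random_state -> identical split, which makes experiments repeatable
split_a = train_test_split(indices, test_size=0.2, stratify=targets, random_state=42)
split_b = train_test_split(indices, test_size=0.2, stratify=targets, random_state=42)
print(split_a[1] == split_b[1])  # True

# stratify needs at least two samples of every class
try:
    train_test_split(list(range(5)), test_size=0.2, stratify=[0, 0, 0, 0, 1])
    singleton_error = None
except ValueError as err:
    singleton_error = err
print(type(singleton_error).__name__)  # ValueError
```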

train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

Since it was the index array that we split, we now split the Dataset itself using those indices. As the name suggests, Subset creates a subset of a dataset: given the original Dataset and an index array, it produces a Dataset containing only the samples at those indices.
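Here is a small sketch of how Subset maps positions, using a plain list of made-up (sample, label) pairs as the parent dataset. subset[i] returns the parent's item at indices[i], and the chosen positions stay available as subset.indices. Subset has no targets attribute of its own, so reading labels through the parent via .indices is a convenient way to inspect a split without loading any samples.

```python
from torch.utils.data import Subset

# Hypothetical toy dataset: a plain list of (sample, label) pairs
data = [(f"img{i}", label) for i, label in enumerate([0, 1, 0, 1, 0, 1])]
subset = Subset(data, [4, 1, 2])

print(len(subset))     # 3
print(subset[0])       # ('img4', 0), i.e. data[subset.indices[0]]
print(subset.indices)  # [4, 1, 2]

# Read the subset's labels through the parent without touching the samples
targets = [label for _, label in data]
print([targets[i] for i in subset.indices])  # [0, 1, 0]
```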

#Create DataLoader
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=True, num_workers=4)

All you have to do is pass the Dataset to the DataLoader as usual.

Reference site

https://discuss.pytorch.org/t/how-to-do-a-stratified-split/62290
