[PYTHON] Cross-validation with PyTorch

Introduction

This article shows how to do cross-validation when your data is wrapped in a PyTorch Dataset.

Split using Subset

You can use torch.utils.data.dataset.Subset to split a Dataset by specifying the indices each part should contain. Combine it with scikit-learn's sklearn.model_selection utilities, which generate those indices for you.
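To see what Subset does on its own, here is a minimal sketch on a hypothetical 10-sample TensorDataset (the toy data is an assumption for illustration). Subset stores a reference to the original Dataset plus the index list, so subset[i] returns dataset[indices[i]] and no data is copied.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.dataset import Subset

# Hypothetical toy Dataset: features 0..9, labels alternating 0/1
features = torch.arange(10, dtype=torch.float32).unsqueeze(1)
labels = torch.arange(10) % 2
dataset = TensorDataset(features, labels)

# Subset only remembers which indices of `dataset` belong to it
indices = [7, 2, 5]
subset = Subset(dataset, indices)

print(len(subset))         # 3
x, y = subset[0]           # same sample as dataset[7]
print(x.item(), y.item())  # 7.0 1
```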

train_test_split

Use sklearn.model_selection.train_test_split to split the indices into train_index and valid_index. Then use Subset to split the Dataset accordingly.

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import Subset
from sklearn.model_selection import train_test_split


# get_dataset() is a placeholder for the code that builds your own Dataset
dataset = get_dataset()

train_index, valid_index = train_test_split(range(len(dataset)), test_size=0.3)

batch_size = 16
train_dataset = Subset(dataset, train_index)
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
valid_dataset   = Subset(dataset, valid_index)
valid_dataloader = DataLoader(valid_dataset, batch_size, shuffle=False)

# Training code here
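If you want to try the split without your own data, the sketch below replaces get_dataset() with a synthetic TensorDataset. The random data and the random_state value are assumptions. Fixing random_state makes the split reproducible between runs, and the asserts check that the two index lists are disjoint and together cover every sample.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.dataset import Subset
from sklearn.model_selection import train_test_split

# Stand-in for get_dataset(): 100 random samples, 4 features, binary labels (assumption)
torch.manual_seed(0)
dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))

# random_state fixes the shuffle so the same split is produced on every run
train_index, valid_index = train_test_split(
    list(range(len(dataset))), test_size=0.3, random_state=0
)

batch_size = 16
train_dataloader = DataLoader(Subset(dataset, train_index), batch_size, shuffle=True)
valid_dataloader = DataLoader(Subset(dataset, valid_index), batch_size, shuffle=False)

# The index lists are disjoint and together cover every sample exactly once
assert set(train_index).isdisjoint(valid_index)
assert sorted(train_index + valid_index) == list(range(len(dataset)))
print(len(train_index), len(valid_index))
```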

KFold cross-validation

Use sklearn.model_selection.KFold to split the indices into train_index and valid_index for each fold. Then use Subset to split the Dataset accordingly.

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import Subset
from sklearn.model_selection import KFold


# get_dataset() is a placeholder for the code that builds your own Dataset
dataset = get_dataset()

batch_size = 16
kf = KFold(n_splits=3)

# KFold only needs the number of samples, so passing the index range is enough
for fold, (train_index, valid_index) in enumerate(kf.split(range(len(dataset)))):
    train_dataset = Subset(dataset, train_index)
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
    valid_dataset   = Subset(dataset, valid_index)
    valid_dataloader = DataLoader(valid_dataset, batch_size, shuffle=False)

    # Training code here (create a fresh model for each fold)
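The "training code here" placeholder can be filled in many ways. Below is a minimal sketch using a hypothetical toy regression problem, a single nn.Linear model, and plain SGD. The data, model, learning rate and epoch count are all assumptions. The key point is that the model and optimizer are rebuilt inside the loop, so weights trained in one fold never see that fold's validation data. This example also passes shuffle=True to KFold, which the code above does not do.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.dataset import Subset
from sklearn.model_selection import KFold

torch.manual_seed(0)
# Hypothetical toy regression data standing in for get_dataset()
X = torch.randn(60, 3)
y = X @ torch.tensor([1.0, -2.0, 0.5]) + 0.1 * torch.randn(60)
dataset = TensorDataset(X, y.unsqueeze(1))

batch_size = 16
kf = KFold(n_splits=3, shuffle=True, random_state=0)
fold_losses = []

for fold, (train_index, valid_index) in enumerate(kf.split(range(len(dataset)))):
    train_dataloader = DataLoader(Subset(dataset, train_index), batch_size, shuffle=True)
    valid_dataloader = DataLoader(Subset(dataset, valid_index), batch_size, shuffle=False)

    # Fresh model and optimizer for every fold
    model = nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    model.train()
    for epoch in range(20):
        for xb, yb in train_dataloader:
            optimizer.zero_grad()
            loss = loss_fn(model(xb), yb)
            loss.backward()
            optimizer.step()

    # Mean validation loss, weighted by batch size
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for xb, yb in valid_dataloader:
            total += loss_fn(model(xb), yb).item() * len(xb)
            count += len(xb)
    fold_losses.append(total / count)
    print(f"fold {fold}: valid MSE = {fold_losses[-1]:.4f}")

# The cross-validation score is the average over folds
print("mean valid MSE:", sum(fold_losses) / len(fold_losses))
```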

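For a classification Dataset, you can also run stratified K-fold with sklearn.model_selection.StratifiedKFold, which additionally needs the labels y. If the Dataset supports slicing, as TensorDataset does, dataset[:][1] returns all the labels at once. Otherwise, collect the labels by iterating over the Dataset.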
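Here is a minimal sketch of stratified K-fold. It assumes a hypothetical imbalanced TensorDataset (80 samples of class 0, 20 of class 1) so that dataset[:][1] works. StratifiedKFold.split only reads the sample count from its first argument, and the labels passed as its second argument drive the stratification. With these numbers, every validation fold should keep the 80:20 ratio.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.dataset import Subset
from sklearn.model_selection import StratifiedKFold

# Hypothetical imbalanced classification Dataset: 80 samples of class 0, 20 of class 1
features = torch.randn(100, 4)
labels = torch.cat([torch.zeros(80, dtype=torch.long), torch.ones(20, dtype=torch.long)])
dataset = TensorDataset(features, labels)

# TensorDataset supports slicing, so dataset[:][1] is the full label tensor.
# For a custom Dataset without slicing, collect labels with a loop instead:
#   y = np.array([int(label) for _, label in dataset])
y = dataset[:][1].numpy()

batch_size = 16
skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=0)
fold_class_counts = []

# The first argument only supplies the number of samples; y drives the stratification
for fold, (train_index, valid_index) in enumerate(skf.split(np.zeros(len(y)), y)):
    train_dataloader = DataLoader(Subset(dataset, train_index), batch_size, shuffle=True)
    valid_dataloader = DataLoader(Subset(dataset, valid_index), batch_size, shuffle=False)

    # Count each class in this validation fold to check the ratio is preserved
    counts = np.bincount(y[valid_index], minlength=2)
    fold_class_counts.append(counts.tolist())
    print(f"fold {fold}: valid class counts = {counts.tolist()}")

    # Training code here
```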
