[PYTHON] Validation croisée avec PyTorch

introduction

Apprenez à effectuer une contre-validation lors de l'utilisation de Dataset avec Pytorch.

Fractionner à l'aide d'un sous-ensemble

Vous pouvez utiliser torch.utils.data.dataset.Subset pour diviser un Dataset en spécifiant un index. Combinez cela avec le scikit-learn sklearn.model_selection.

train_test_split Utilisez sklearn.model_selection.train_test_split pour diviser l'index en train_index et valid_index, et utilisez Subset pour diviser le jeu de données.

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import Subset
from sklearn.model_selection import train_test_split


dataset = get_dataset()

train_index, valid_index = train_test_split(range(len(dataset)), test_size=0.3)

batch_size = 16
train_dataset = Subset(dataset, train_index)
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
valid_dataset   = Subset(dataset, valid_index)
valid_dataloader = DataLoader(valid_dataset, batch_size, shuffle=False)

#Code d'apprentissage ici

Validation croisée KFold

Utilisez sklearn.model_selection.KFold pour diviser l'index en train_index et valid_index, et utilisez Subset pour diviser le jeu de données.

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import Subset
from sklearn.model_selection import KFold


dataset = get_dataset()

batch_size = 16
kf = KFold(n_splits=3)

for _fold, (train_index, test_index) in enumerate(kf.split(X)):
    train_dataset = Subset(dataset, train_index)
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
    valid_dataset   = Subset(dataset, valid_index)
    valid_dataloader = DataLoader(valid_dataset, batch_size, shuffle=False)

    #Code d'apprentissage ici

S'il s'agit d'un Dataset de classification de classe, vous devriez être capable d'obtenir la valeur de y en faisant dataset [:] [1], donc vous devriez être capable de faire aussi Stratified KFold.

Recommended Posts

Validation croisée avec PyTorch
Jouez avec PyTorch
À partir de PyTorch
Utilisez RTX 3090 avec PyTorch
Installer la diffusion de la torche avec PyTorch 1.7
Essayez d'implémenter XOR avec PyTorch
Implémenter le GPU PyTorch + avec Docker
Prédiction de la moyenne Nikkei avec Pytorch 2
Démineur d'apprentissage automatique avec PyTorch
PyTorch avec AWS Lambda [importation Lambda]
Prédiction de la moyenne Nikkei avec Pytorch
Effectuer un fractionnement stratifié avec PyTorch
J'ai créé Word2Vec avec Pytorch
[Tutoriel PyTorch ⑤] Apprentissage de PyTorch avec des exemples (Partie 2)
Apprenez avec les réseaux convolutifs PyTorch Graph
J'ai essayé d'implémenter Attention Seq2Seq avec PyTorch
J'ai essayé d'implémenter DeepPose avec PyTorch
Comment augmenter les données avec PyTorch
[Tutoriel PyTorch ⑤] Apprentissage de PyTorch avec des exemples (Partie 1)
Construction de l'environnement pytorch @ python3.8 avec pipenv
Obtenez un rembourrage de réflexion Pytorch avec Tensorflow
Prédiction des ondes de Sin (retour) avec Pytorch
Installer pytorch
Classification multi-étiquette d'images multi-classes avec pytorch
J'ai essayé d'implémenter la régularisation Shake-Shake (ShakeNet) avec PyTorch
Liens PyTorch
Créez un quiz de dessin avec kivy + PyTorch
Pratiquez Pytorch
Classification des documents avec texte toch de PyTorch
[Introduction à Pytorch] J'ai joué avec sinGAN ♬
J'ai essayé d'implémenter DeepPose avec PyTorch PartⅡ
J'ai essayé d'implémenter CVAE avec PyTorch
Apprentissage automatique avec Pytorch sur Google Colab
Installez PyTorch
J'ai essayé d'implémenter la lecture de Dataset avec PyTorch
Histoire d'essayer d'utiliser Tensorboard avec Pytorch
Afficher l'image après l'augmentation des données avec Pytorch