[PYTHON] Effectuer un fractionnement stratifié avec PyTorch

Qu'est-ce que le fractionnement stratifié?

Lors de l'entraînement machine, il est courant de diviser un ensemble de données en données d'entraînement et en données de validation. En particulier dans le cas de problèmes de classification, il est possible de diviser au hasard sans tenir compte de l'étiquette de classe, mais il est souhaitable de diviser afin que la distribution de l'étiquette de classe des données divisées soit la même que les données d'origine. La division tout en maintenant le rapport de chaque classe de cette manière est appelée extraction stratifiée ou division stratifiée.

Exemple d'implémentation dans PyTorch

Dans scikit-learn, vous pouvez effectuer un fractionnement stratifié en passant l'option stratifier à la fonction sklearn.model_selection.train_test_split.

En revanche, PyTorch ne dispose pas d'un tel mécanisme. Vous pouvez utiliser une fonction comme torch.utils.data.random_split pour diviser aléatoirement un jeu de données, mais vous ne pouvez pas faire un fractionnement stratifié direct. Par conséquent, la division stratifiée est réalisée en combinant avec train_test_split de scicit-learn.

Par exemple, vous pouvez faire un fractionnement stratifié avec un code comme celui-ci:

import torch
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split

transformer = transforms.Compose([
    transforms.ToTensor(),
])

#Charger l'image
dataset = torchvision.datasets.ImageFolder(root='directory_name', transform=transformer)

#Divisez l'ensemble de données en train et validation
train_indices, val_indices = train_test_split(list(range(len(dataset.targets))), test_size=0.2, stratify=dataset.targets)
train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

#Créer DataLoader
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=True, num_workers=4)

Je vais expliquer dans l'ordre.

transformer = transforms.Compose([
    transforms.ToTensor(),
])

#Charger l'image
dataset = torchvision.datasets.ImageFolder(root='directory_name', transform=transformer)

C'est bien, non? Le dossier d'images est utilisé pour charger l'image et créer un ensemble de données.

#Divisez l'ensemble de données en train et validation
train_indices, val_indices = train_test_split(list(range(len(dataset.targets))), test_size=0.2, stratify=dataset.targets)

Cette partie est la principale. Passez le tableau que vous voulez fractionner au premier argument de train_test_split, mais vous ne pouvez pas passer directement le Dataset, utilisez donc list (range (len (dataset.targets))) pour indexer le Dataset ([0,1,1) 2,3, ... Nombre de données] ) est généré et passé à sa place. Ensuite, en passant l'étiquette de classe dataset.targets pour ce tableau d'index comme option de stratification, le tableau d'index peut être divisé pour l'apprentissage et la validation tout en conservant le rapport de l'étiquette de classe des données d'origine.

train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

Étant donné que c'est le tableau d'index qui est divisé, l'ensemble de données est divisé en fonction de l'index. Comme son nom l'indique, Subset est une classe permettant de créer un sous-ensemble de données, et vous pouvez générer un Dataset correspondant à un index en passant le Dataset et le tableau d'index d'origine.

#Créer DataLoader
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=True, num_workers=4)

Tout ce que vous avez à faire est de transmettre le Dataset au DataLoader comme d'habitude.

Site de référence

https://discuss.pytorch.org/t/how-to-do-a-stratified-split/62290

Recommended Posts

Effectuer un fractionnement stratifié avec PyTorch
Jouez avec PyTorch
Validation croisée avec PyTorch
À partir de PyTorch
Utilisez RTX 3090 avec PyTorch
Installer la diffusion de la torche avec PyTorch 1.7
Essayez Auto Encoder avec Pytorch
Effectuer des opérations logiques à l'aide de Perceptron
Implémenter le GPU PyTorch + avec Docker
PyTorch avec AWS Lambda [importation Lambda]
Prédiction de la moyenne Nikkei avec Pytorch
J'ai créé Word2Vec avec Pytorch
Écran divisé en 3 avec keyhac
[Tutoriel PyTorch ⑤] Apprentissage de PyTorch avec des exemples (Partie 2)
Conférence ROS 113 Effectuer des tâches avec smach
Apprenez avec les réseaux convolutifs PyTorch Graph
J'ai essayé d'implémenter Attention Seq2Seq avec PyTorch
J'ai essayé d'implémenter DeepPose avec PyTorch
Prédiction de la moyenne Nikkei avec Pytorch ~ Makuma ~
Comment augmenter les données avec PyTorch
Construction de l'environnement pytorch @ python3.8 avec pipenv
Obtenez un rembourrage de réflexion Pytorch avec Tensorflow
Prédiction des ondes de Sin (retour) avec Pytorch