[PYTHON] [Pytorch] Mémo sur Dataset / DataLoader

Aperçu

Notez le Dataset / DataLoader utilisé lors de la création d'un dataset avec Pytorch

référence:

Prétraitement des données

Pour le prétraitement des données, il existe une bibliothèque par torchvision.transforms ou ʻalbumentations. L'opération de base est la même pour les deux. Créez une instance en compressant l'instance de la classe de prétraitement dans la liste et en l'utilisant comme argument de Compose (). Compose a une méthode call (self, img)`, donc si vous mettez une image dans l'argument de l'instance créée, elle sera prétraitée.

import alubmentations as alb

def get_augmentation(phase):
   transform_list = []
   if phase == 'train':
        transform_list.extend([albu.HorizonFlip(p=0.5),
                             albu.VerticalFlip(p=0.5)])
   transform_list.extend([albu.Normalize(mean=(0.485, 0.456, 0.406),
                                         std=(0.229, 0.224, 0.225),
                                         p=1),
                          albu.ToTensor()
                        ])
    return albu.Compose(transform_list)

Dataset ** Un module qui récupère les données d'entrée et les étiquettes correspondantes une par une **. Lors du prétraitement des données, ** les transformations doivent être utilisées pour renvoyer les données prétraitées **.

** **

--Héritage de Dataset

Fondamentalement OK si ce qui précède est satisfait! Une instance de la classe d'héritage Dataset est le premier argument de DataLoder. (Plus tard pour Data Lodaer)

Par exemple, supposons que l'ensemble de données ait la structure de répertoires suivante.

datasets/ ____ train_images/
           |__ test_images/
           |__ train.csv

Cette fois, nous supposons que le fichier .csv contient des informations de chemin de données et d'étiquette pour l'ensemble de données.

import os.path as osp

import cv2
import pandas as pd
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class MyDataManager(Dataset):
    """My Dataset

    Args:
        root(str): root path of dataset directory
        df(DataFrame): DataFrame object from csv file
        phase(str): train or test
    """
    
    def __init__(self, root, df, phase):
       super(MyDataManager, self).__init__()
       self.root = root
       self.df = df
       self.phase = phase
       self.transfoms = get_augmentation()

    def __getitem__(self, idx):
       img_path = osp.join(self.root, self.df.iloc[idx].name)
       img = cv2.imread(img_path)
       img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
       img = self.transform(image=img)

       label = osp.join(eslf.root, self.df.iloc[idx].value)
       
       ret = {'image': img, 'label': label}

       return ret

    def __len__(self):
       return len(self.df)
       

Cette fois, la valeur de retour est de type dict, mais il n'y a pas de problème avec return image, label. Lors de l'exécution d'une segmentation, etc., il est nécessaire de donner l'étiquette en tant qu'image de masque, donc dans ce cas, transférez également l'image de masque.

DataLoader

Les données récupérées par Datset peuvent être utilisées comme argument de DataLoader. La structure des arguments de DataLoader est la suivante

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

Par conséquent, créez la fonction suivante.


def dataloader(dir_path,phase,batch_size, num_workers, shuffle=False):
    df_path = osp.join(dir_path, 'train.csv')
    df = pd.read_csv(df_path)

    dataset = MyDataManager(dir_path, df, phase)
    dl = DataLoader(dataset,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory=True,
                    shuffle=shuffle)
   
    return dl

Recommended Posts

[Pytorch] Mémo sur Dataset / DataLoader
Notes sur Pytorch
Mémo PyTorch (gestion des dimensions)
[Python] Mémo sur les fonctions
Préparation du jeu de données pour PyTorch
[Python] Mémo sur les erreurs
PyTorch DataLoader est lent
Comprendre le DataSet et le DataLoader de PyTorch (2)
Comprendre le DataSet et le DataLoader de PyTorch (1)
Pliez le jeu de données Pytorch en couches
J'ai essayé d'expliquer l'ensemble de données de Pytorch