[Python] [PyTorch] Memo about Dataset / DataLoader

Overview

A memo on the Dataset and DataLoader classes used when building a dataset in PyTorch.


Data preprocessing

For data preprocessing you can use a library such as torchvision.transforms or albumentations. Both are used in basically the same way: put instances of the preprocessing classes into a list and pass that list to Compose() to create a single transform. Compose defines a __call__ method, so calling the resulting instance on an image applies all the preprocessing steps in order. One difference: a torchvision Compose is called as transform(img), while an albumentations Compose takes keyword arguments, transform(image=img), and returns a dict.

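import albumentations as albu
from albumentations.pytorch import ToTensorV2


def get_augmentation(phase):
    transform_list = []
    if phase == 'train':
        # Random flips are applied only during training
        transform_list.extend([albu.HorizontalFlip(p=0.5),
                               albu.VerticalFlip(p=0.5)])
    # Normalize with ImageNet mean/std, then convert the HWC array to a CHW tensor
    transform_list.extend([albu.Normalize(mean=(0.485, 0.456, 0.406),
                                          std=(0.229, 0.224, 0.225),
                                          p=1),
                           ToTensorV2()])
    return albu.Compose(transform_list)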
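The Compose mechanism described above amounts to storing a list of callables and applying them in order inside __call__. The following is a minimal, dependency-free sketch of that idea; SimpleCompose, hflip and scale are hypothetical names, and this is not the actual torchvision or albumentations source. Images are modeled as flat lists of pixel values so the sketch runs without any library.

```python
from typing import Any, Callable, List


class SimpleCompose:
    """Hypothetical, minimal Compose: apply callables in list order."""

    def __init__(self, transforms: List[Callable[[Any], Any]]):
        self.transforms = transforms

    def __call__(self, img: Any) -> Any:
        # Each transform receives the output of the previous one
        for t in self.transforms:
            img = t(img)
        return img


# Toy "image": a flat list of pixel values in [0, 255]
def hflip(img):
    return img[::-1]


def scale(img):
    return [p / 255.0 for p in img]


pipeline = SimpleCompose([hflip, scale])
print(pipeline([0, 51, 255]))  # → [1.0, 0.2, 0.0]
```

Because the steps run in list order, the order in which transforms are packed into the list matters; for example, Normalize is usually placed after the pixel-level augmentations and before the tensor conversion.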

Dataset

A Dataset is a module that fetches the input data and its corresponding label one sample at a time. When preprocessing is needed, apply the transforms inside the Dataset and return the preprocessed data.

A Dataset class needs to:

- inherit from torch.utils.data.Dataset
- implement __getitem__ and __len__

As long as these two points are satisfied, the Dataset basically works. An instance of your Dataset subclass is passed as the first argument of DataLoader (DataLoader is covered later).
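The two requirements above are essentially Python's sequence protocol: __len__ tells the loader how many samples exist and __getitem__ returns one sample for an index. A dependency-free sketch follows; ToyDataset is hypothetical and, to stay runnable without PyTorch, it does not subclass torch.utils.data.Dataset as real code would.

```python
class ToyDataset:
    """Hypothetical map-style dataset over in-memory (sample, label) pairs."""

    def __init__(self, samples, labels, transform=None):
        assert len(samples) == len(labels)
        self.samples = samples
        self.labels = labels
        self.transform = transform

    def __getitem__(self, idx):
        x = self.samples[idx]
        if self.transform is not None:
            x = self.transform(x)  # preprocessing is applied per sample
        return x, self.labels[idx]

    def __len__(self):
        return len(self.samples)


ds = ToyDataset([[1, 2], [3, 4], [5, 6]], [0, 1, 0],
                transform=lambda x: [v * 10 for v in x])
print(len(ds))  # → 3
print(ds[1])    # → ([30, 40], 1)
```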

For example, suppose the dataset has the following directory structure.

datasets/
 |__ train_images/
 |__ test_images/
 |__ train.csv

Here we assume that train.csv holds the image path and the label of each sample in the dataset.
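For concreteness, a hypothetical train.csv might look like the one below. The column names image_path and label, and the file names, are assumptions for illustration only; the sketch parses the text with pandas just as a file on disk would be read.

```python
import io

import pandas as pd

# Hypothetical contents of datasets/train.csv (column names are assumptions)
csv_text = """image_path,label
train_images/0001.png,0
train_images/0002.png,1
train_images/0003.png,0
"""

df = pd.read_csv(io.StringIO(csv_text))
row = df.iloc[1]  # one row per sample, looked up by position
print(row['image_path'], row['label'])  # → train_images/0002.png 1
print(len(df))                          # → 3
```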

import os.path as osp

import cv2
import pandas as pd
from torch.utils.data import Dataset


class MyDataManager(Dataset):
    """My Dataset

    Args:
        root(str): root path of dataset directory
        df(DataFrame): DataFrame object from csv file
        phase(str): train or test
    """

    def __init__(self, root, df, phase):
        super().__init__()
        self.root = root
        self.df = df
        self.phase = phase
        # get_augmentation() is the preprocessing function defined above
        self.transform = get_augmentation(phase)

    def __getitem__(self, idx):
        # Assumption: the csv has an 'image_path' column (relative to root)
        # and a 'label' column; adjust these names to your own csv layout
        row = self.df.iloc[idx]
        img_path = osp.join(self.root, row['image_path'])
        img = cv2.imread(img_path)  # BGR array, or None if the file can't be read
        if img is None:
            raise FileNotFoundError(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # albumentations returns a dict; the processed image is under 'image'
        img = self.transform(image=img)['image']

        label = row['label']

        ret = {'image': img, 'label': label}

        return ret

    def __len__(self):
        return len(self.df)

Here __getitem__ returns a dict, but returning a tuple (return img, label) works just as well. For tasks such as segmentation, the label is a mask image, so in that case the mask must be transformed together with the image.
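The key point for segmentation is that random geometric augmentations must be applied identically to the image and the mask; otherwise they no longer line up. In albumentations this is done by passing both to the same Compose, e.g. transform(image=img, mask=mask), which returns a dict with 'image' and 'mask' keys. The dependency-free sketch below illustrates only the underlying idea with a hypothetical paired_hflip helper on 2-D lists; it is not how albumentations is implemented.

```python
import random


def paired_hflip(image, mask, p=0.5, rng=random):
    """Hypothetical helper: horizontally flip image and mask together.

    image and mask are 2-D lists (rows of pixels) of the same shape.
    """
    # Draw the random decision once and apply it to both, keeping them aligned
    if rng.random() < p:
        image = [row[::-1] for row in image]
        mask = [row[::-1] for row in mask]
    return image, mask


img = [[10, 20, 30],
       [40, 50, 60]]
mask = [[0, 0, 1],
        [0, 1, 1]]
flipped_img, flipped_mask = paired_hflip(img, mask, p=1.0)
print(flipped_img)   # → [[30, 20, 10], [60, 50, 40]]
print(flipped_mask)  # → [[1, 0, 0], [1, 1, 0]]
```

Flipping the image and the mask with two independent random draws would instead misalign them about half the time, which is why the decision is made once per sample.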

DataLoader

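The Dataset instance is passed to DataLoader as its first argument. The signature of DataLoader is as follows (this is the argument list of older PyTorch versions; newer releases add further keyword arguments such as prefetch_factor and persistent_workers):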

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)
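Roughly, for a map-style dataset DataLoader takes the indices 0 to len(dataset) - 1 (shuffled if shuffle=True), groups them into batches of batch_size, drops an incomplete last batch if drop_last=True, and collates the samples of each batch. The single-process, pure-Python sketch below shows only that index and batching logic; toy_loader and collate are hypothetical names, and the real DataLoader also handles worker processes, pin_memory and conversion to tensors.

```python
import random


def collate(batch):
    # Turn [(x0, y0), (x1, y1), ...] into ([x0, x1, ...], [y0, y1, ...]);
    # PyTorch's default collate function additionally stacks these into tensors
    return tuple(list(field) for field in zip(*batch))


def toy_loader(dataset, batch_size=1, shuffle=False, drop_last=False, seed=None):
    """Hypothetical single-process batching over a map-style dataset."""
    indices = list(range(len(dataset)))  # relies on __len__
    if shuffle:
        random.Random(seed).shuffle(indices)
    batch = []
    for idx in indices:
        batch.append(dataset[idx])  # relies on __getitem__
        if len(batch) == batch_size:
            yield collate(batch)
            batch = []
    if batch and not drop_last:
        yield collate(batch)  # incomplete final batch


data = [(i, i % 2) for i in range(5)]  # a plain list also has __getitem__/__len__
print(list(toy_loader(data, batch_size=2)))
# → [([0, 1], [0, 1]), ([2, 3], [0, 1]), ([4], [0])]
print(len(list(toy_loader(data, batch_size=2, drop_last=True))))  # → 2
```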

With DataLoader, the following helper function builds a loader directly from the dataset directory.


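import os.path as osp

import pandas as pd
from torch.utils.data import DataLoader


def dataloader(dir_path, phase, batch_size, num_workers, shuffle=False):
    # train.csv holds the image paths and labels (see the directory layout above)
    df_path = osp.join(dir_path, 'train.csv')
    df = pd.read_csv(df_path)

    dataset = MyDataManager(dir_path, df, phase)
    dl = DataLoader(dataset,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory=True,  # page-locked memory speeds up host-to-GPU copies
                    shuffle=shuffle)

    return dl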
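Because MyDataManager returns a dict, every batch produced by DataLoader is also a dict whose values are stacked along a new batch dimension by the default collate function. The sketch below demonstrates this without any image files: DictDataset is a hypothetical stand-in that returns random tensors instead of reading images, and the tensor shapes are arbitrary.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class DictDataset(Dataset):
    """Hypothetical stand-in for MyDataManager that returns random tensors."""

    def __init__(self, n=10):
        g = torch.Generator().manual_seed(0)
        self.images = torch.rand(n, 3, 8, 8, generator=g)  # fake CHW "images"
        self.labels = torch.arange(n) % 2                  # fake binary labels

    def __getitem__(self, idx):
        return {'image': self.images[idx], 'label': self.labels[idx]}

    def __len__(self):
        return len(self.images)


dl = DataLoader(DictDataset(), batch_size=4, shuffle=True, num_workers=0)
for batch in dl:
    # Default collation stacks each dict value along a new batch dimension
    print(batch['image'].shape, batch['label'].shape)
# → torch.Size([4, 3, 8, 8]) torch.Size([4]) twice, then
#   torch.Size([2, 3, 8, 8]) torch.Size([2])
```

In a training loop you would likewise read batch['image'] and batch['label']. When num_workers > 0 on platforms that start worker processes with spawn (Windows, and macOS by default), put the code that iterates over the loader under if __name__ == '__main__':.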
