[PYTHON] [Pytorch] Memo über Dataset / DataLoader

Überblick

Notieren Sie sich den Datensatz / DataLoader, der beim Erstellen eines Datensatzes mit Pytorch verwendet wird

Referenz:

Datenvorverarbeitung

Für die Datenvorverarbeitung gibt es eine Bibliothek pro "torchvision.transforms" oder "albumations". Die Grundoperation ist für beide gleich. Erstellen Sie eine Instanz, indem Sie die Vorverarbeitungsklasseninstanz in die Liste packen und als Argument für "Compose ()" verwenden. Compose hat eine __call__ (self, img) -Methode. Wenn Sie also ein Bild in das Argument der erstellten Instanz einfügen, wird es vorverarbeitet.

import alubmentations as alb

def get_augmentation(phase):
   transform_list = []
   if phase == 'train':
        transform_list.extend([albu.HorizonFlip(p=0.5),
                             albu.VerticalFlip(p=0.5)])
   transform_list.extend([albu.Normalize(mean=(0.485, 0.456, 0.406),
                                         std=(0.229, 0.224, 0.225),
                                         p=1),
                          albu.ToTensor()
                        ])
    return albu.Compose(transform_list)

Dataset ** Ein Modul, das nacheinander Eingabedaten und entsprechende Beschriftungen abruft **. Bei der Vorverarbeitung von Daten sollten ** Transformationen verwendet werden, um die vorverarbeiteten Daten ** zurückzugeben.

** **

Grundsätzlich OK, wenn das oben genannte erfüllt ist! Eine Instanz der Dataset-Vererbungsklasse ist das erste Argument von DataLoder. (Später für Data Lodaer)

Angenommen, der Datensatz hat die folgende Verzeichnisstruktur.

datasets/ ____ train_images/
           |__ test_images/
           |__ train.csv

Diesmal wird davon ausgegangen, dass die CSV-Datei den Datenpfad und die Beschriftungsinformationen für das Dataset enthält.

import os.path as osp

import cv2
import pandas as pd
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class MyDataManager(Dataset):
    """My Dataset

    Args:
        root(str): root path of dataset directory
        df(DataFrame): DataFrame object from csv file
        phase(str): train or test
    """
    
    def __init__(self, root, df, phase):
       super(MyDataManager, self).__init__()
       self.root = root
       self.df = df
       self.phase = phase
       self.transfoms = get_augmentation()

    def __getitem__(self, idx):
       img_path = osp.join(self.root, self.df.iloc[idx].name)
       img = cv2.imread(img_path)
       img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
       img = self.transform(image=img)

       label = osp.join(eslf.root, self.df.iloc[idx].value)
       
       ret = {'image': img, 'label': label}

       return ret

    def __len__(self):
       return len(self.df)
       

Diesmal ist der Rückgabewert vom Typ dict, aber es gibt kein Problem mit "return image, label". Bei der Segmentierung usw. ist es erforderlich, das Etikett als Maskenbild anzugeben, damit in diesem Fall auch das Maskenbild übertragen wird.

DataLoader

Die von Datset abgerufenen Daten können als Argument von DataLoader verwendet werden. Die Argumentstruktur von DataLoader lautet wie folgt

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

Erstellen Sie daher die folgende Funktion.


def dataloader(dir_path,phase,batch_size, num_workers, shuffle=False):
    df_path = osp.join(dir_path, 'train.csv')
    df = pd.read_csv(df_path)

    dataset = MyDataManager(dir_path, df, phase)
    dl = DataLoader(dataset,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    pin_memory=True,
                    shuffle=shuffle)
   
    return dl

Recommended Posts

[Pytorch] Memo über Dataset / DataLoader
Hinweise zu Pytorch
PyTorch-Memo (Dimensionsverwaltung)
[Python] Memo über Funktionen
Datensatzvorbereitung für PyTorch
[Python] Memo Über Fehler
PyTorch DataLoader ist langsam
Grundlegendes zu PyTorchs DataSet und DataLoader (2)
Grundlegendes zu PyTorchs DataSet und DataLoader (1)
Falten Sie den Pytorch-Datensatz in Schichten
Ich habe versucht, Pytorchs Datensatz zu erklären