Notieren Sie sich den Datensatz / DataLoader, der beim Erstellen eines Datensatzes mit Pytorch verwendet wird
Referenz:
Für die Datenvorverarbeitung gibt es eine Bibliothek pro "torchvision.transforms" oder "albumations".
Die Grundoperation ist für beide gleich.
Erstellen Sie eine Instanz, indem Sie die Vorverarbeitungsklasseninstanz in die Liste packen und als Argument für "Compose ()" verwenden.
Compose
hat eine __call__ (self, img)
-Methode. Wenn Sie also ein Bild in das Argument der erstellten Instanz einfügen, wird es vorverarbeitet.
import alubmentations as alb
def get_augmentation(phase):
transform_list = []
if phase == 'train':
transform_list.extend([albu.HorizonFlip(p=0.5),
albu.VerticalFlip(p=0.5)])
transform_list.extend([albu.Normalize(mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
p=1),
albu.ToTensor()
])
return albu.Compose(transform_list)
Dataset ** Ein Modul, das nacheinander Eingabedaten und entsprechende Beschriftungen abruft **. Bei der Vorverarbeitung von Daten sollten ** Transformationen verwendet werden, um die vorverarbeiteten Daten ** zurückzugeben.
**
__getitem__
, __len__
Grundsätzlich OK, wenn das oben genannte erfüllt ist! Eine Instanz der Dataset-Vererbungsklasse ist das erste Argument von DataLoder. (Später für Data Lodaer)
Angenommen, der Datensatz hat die folgende Verzeichnisstruktur.
datasets/ ____ train_images/
|__ test_images/
|__ train.csv
Diesmal wird davon ausgegangen, dass die CSV-Datei den Datenpfad und die Beschriftungsinformationen für das Dataset enthält.
import os.path as osp
import cv2
import pandas as pd
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class MyDataManager(Dataset):
"""My Dataset
Args:
root(str): root path of dataset directory
df(DataFrame): DataFrame object from csv file
phase(str): train or test
"""
def __init__(self, root, df, phase):
super(MyDataManager, self).__init__()
self.root = root
self.df = df
self.phase = phase
self.transfoms = get_augmentation()
def __getitem__(self, idx):
img_path = osp.join(self.root, self.df.iloc[idx].name)
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = self.transform(image=img)
label = osp.join(eslf.root, self.df.iloc[idx].value)
ret = {'image': img, 'label': label}
return ret
def __len__(self):
return len(self.df)
Diesmal ist der Rückgabewert vom Typ dict, aber es gibt kein Problem mit "return image, label". Bei der Segmentierung usw. ist es erforderlich, das Etikett als Maskenbild anzugeben, damit in diesem Fall auch das Maskenbild übertragen wird.
DataLoader
Die von Datset abgerufenen Daten können als Argument von DataLoader verwendet werden. Die Argumentstruktur von DataLoader lautet wie folgt
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
Erstellen Sie daher die folgende Funktion.
def dataloader(dir_path,phase,batch_size, num_workers, shuffle=False):
df_path = osp.join(dir_path, 'train.csv')
df = pd.read_csv(df_path)
dataset = MyDataManager(dir_path, df, phase)
dl = DataLoader(dataset,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
shuffle=shuffle)
return dl
Recommended Posts