[PYTHON] I tried to implement reading Dataset with PyTorch

Overview

It's been half a year since I started studying machine learning, and I finally managed to write my own Dataset with PyTorch, so I'm posting it as a reminder to myself. While studying GANs I worked by pulling code from GitHub, but those examples only load MNIST and CIFAR, and I wanted to run them on my own data, so I wrote a custom Dataset. (This also doubles as practice at writing articles, so please bear with me.)

Environment

Dataset prerequisites

A custom Dataset inherits from torch.utils.data.Dataset and implements __init__, __len__, and __getitem__. The basic skeleton looks like this.


class MyDataset(torch.utils.data.Dataset):

    def __init__(self, imageSize, dir_path, transform=None):
        # set up the transforms and collect the image paths / labels
        pass

    def __len__(self):
        # return the number of samples
        pass

    def __getitem__(self, idx):
        # return one (data, label) pair for the given index
        pass

Besides the path to the data directory, the class takes the input image size and an optional transform for preprocessing as arguments.

Constructor definition

The constructor, which is called automatically when an instance is created, does the following.

    def __init__(self, imageSize, dir_path, transform=None):
        self.transform = transforms.Compose([
            transforms.Resize(imageSize),  # resize (an int resizes the shorter edge; pass (imageSize, imageSize) for a fixed square)
            transforms.ToTensor(),  # convert to a tensor
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalize
        ])

        # Collect the input data (image paths); the labels are derived from them later
        self.image_paths = [str(p) for p in Path(dir_path).glob("**/*.png")]

        self.data_num = len(self.image_paths)  # this becomes the return value of __len__
        self.classes = ['carpet', 'ceramic', 'cloth', 'dirt', 'drywall', 'glass', 'grass', 'gravel', 'leaf', 'metal']
        self.class_to_idx = {'carpet': 0, 'ceramic': 1, 'cloth': 2, 'dirt': 3, 'drywall': 4, 'glass': 5, 'grass': 6, 'gravel': 7, 'leaf': 8, 'metal': 9}

I happened to have material image data labeled with these ten classes, so I used that for the class list and the class-to-index dictionary.
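As an aside, the transform argument is accepted here but never used, since the constructor always builds its own pipeline. If you want to honor a caller-supplied transform, a minimal sketch of my own (not part of the original code) could look like this:

    def __init__(self, imageSize, dir_path, transform=None):
        # Use the caller-supplied transform if one is given, otherwise fall back to the default pipeline
        if transform is not None:
            self.transform = transform
        else:
            self.transform = transforms.Compose([
                transforms.Resize(imageSize),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])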

Definition of __getitem__

__getitem__ is the method that loads one sample and its ground-truth label during training, so we implement it using the information gathered in the constructor.


    def __getitem__(self, idx):
        p = self.image_paths[idx]
        image = Image.open(p).convert("RGB")  # force 3 channels in case of grayscale/RGBA PNGs

        if self.transform:
            out_data = self.transform(image)

        # The class name is the folder the image sits in; in my directory layout
        # it is the element at index 3 of the Windows-style path
        out_label = p.split("\\")
        out_label = self.class_to_idx[out_label[3]]

        return out_data, out_label

The images could also be loaded up front in the constructor, but I was worried about memory usage with a large dataset, so I open each image lazily here instead. I also use the somewhat clumsy approach of hard-coding a dictionary for the class labels.
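If you would rather not hard-code that dictionary, here is a small sketch of my own (assuming each image sits directly inside a folder named after its class) that derives the mapping from the directory names instead:

        # Inside __init__, after self.image_paths has been collected:
        self.classes = sorted({Path(p).parent.name for p in self.image_paths})
        self.class_to_idx = {name: i for i, name in enumerate(self.classes)}

        # __getitem__ can then look up the label without splitting the path by hand:
        out_label = self.class_to_idx[Path(p).parent.name]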

Handing over to DataLoader

To actually use the Dataset, wrap it in a DataLoader as shown below and it can be fed to the training loop. (The DataLoader argument shuffle=True randomizes the order in which samples are drawn.)

    data_set = MyDataset(32, dir_path=root_data)
    dataloader = torch.utils.data.DataLoader(data_set, batch_size=100, shuffle=True)
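As a quick sanity check of my own (assuming root_data points at the image directory), you can pull a single batch and inspect its shape:

    images, labels = next(iter(dataloader))
    print(images.shape)  # e.g. torch.Size([100, 3, 32, 32]) when the resized images come out square
    print(labels.shape)  # torch.Size([100])

If the shapes come out as expected, the Dataset is ready to plug into a training loop.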

Full source code

import torch.utils.data
import torchvision.transforms as transforms
from pathlib import Path
from PIL import Image

class MyDataset(torch.utils.data.Dataset):

    def __init__(self, imageSize, dir_path, transform=None):
        self.transform = transforms.Compose([
            transforms.Resize(imageSize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.image_paths = [str(p) for p in Path(dir_path).glob("**/*.png")]

        self.data_num = len(self.image_paths)
        self.classes = ['carpet', 'ceramic', 'cloth', 'dirt', 'drywall', 'glass', 'grass', 'gravel', 'leaf', 'metal']
        self.class_to_idx = {'carpet': 0, 'ceramic': 1, 'cloth': 2, 'dirt': 3, 'drywall': 4, 'glass': 5, 'grass': 6, 'gravel': 7, 'leaf': 8, 'metal': 9}


    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        p = self.image_paths[idx]
        image = Image.open(p).convert("RGB")

        if self.transform:
            out_data = self.transform(image)

        out_label = p.split("\\")
        out_label = self.class_to_idx[out_label[3]]

        return out_data, out_label

if __name__ == "__main__":
    root_data = 'Path to data'
    data_set = MyDataset(32, dir_path=root_data)
    dataloader = torch.utils.data.DataLoader(data_set, batch_size=100, shuffle=True)

Reference sites

I implemented this while referring to the following articles. Thank you very much.

Explanation of transforms, Datasets, Dataloader of pyTorch and creation and use of self-made Dataset
PyTorch: Dataset and DataLoader (Image Processing Task)
