[PYTHON] J'ai essayé d'implémenter la lecture de Dataset avec PyTorch

Aperçu

Cela fait six mois que j'ai commencé à étudier le machine learning, et j'ai réussi à créer un Dataset avec PyTorch, je le posterai donc comme rappel. Quand j'étudiais le GAN, j'étudiais en supprimant le code de GitHub, mais comme je ne lisais que MNIST et CIFAR, je voulais l'exécuter avec mon propre jeu de données, j'ai donc créé mon propre jeu de données. (Je ne sais pas car c'est un article que je pratique depuis un certain temps en postant des articles ...)

environnement

Conditions préalables de l'ensemble de données

Donc, la configuration de base est comme ça.


class MyDataset(torch.utils.data.Dataset):

    def __init__(self, imageSize, dir_path, transform=None):
        pass

    def __len__(self):
        pass

    def __getitem__(self, idx):
        pass

En plus du chemin d'accès aux données, j'ai passé la taille d'entrée de l'image et la transformation pour le prétraitement comme arguments de la classe.

Définition du constructeur

Le constructeur, qui est automatiquement appelé lors de la création d'une classe, effectue le traitement suivant.

    def __init__(self, imageSize, dir_path, transform=None):
        self.transform = transforms.Compose([
            transforms.Resize(imageSize), #Redimensionnement de l'image
            transforms.ToTensor(), #Tensorisation
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), #Standardisation
        ])

        #Entrez les données d'entrée et l'étiquette ici
        self.image_paths = [str(p) for p in Path(dir_path).glob("**/*.png ")]

        self.data_num = len(self.image_paths) #Voici__len__Sera la valeur de retour de
        self.classes = ['carpet', 'ceramic', 'cloth', 'dirt', 'drywall', 'glass', 'grass', 'gravel', 'leaf', 'metal']
        self.class_to_idx = {'carpet':0, 'ceramic':1, 'cloth':2, 'dirt':3, 'drywall':4, 'glass':5, 'grass':6,'gravel':7, 'leaf':8, 'metal':9}

J'avais des données matérielles multi-classifiées, alors je les ai utilisées.

Définition de \ _ \ _ getitem \ _ \ _

Puisque \ _ \ _ getitem \ _ \ _ est une méthode de lecture des données et de son étiquette de réponse correcte pendant l'entraînement, nous l'implémenterons en utilisant les informations lues par le constructeur.


    def __getitem__(self, idx):
        p = self.image_paths[idx]
        image = Image.open(p)

        if self.transform:
            out_data = self.transform(image)

        out_label = p.split("\\")
        out_label = self.class_to_idx[out_label[3]]

        return out_data, out_label

Je pense que c'est correct de lire les données d'image avec le constructeur, mais j'étais inquiète pour la mémoire quand il y avait beaucoup de données, alors j'ai décidé de la lire à chaque fois. J'utilise également une méthode légèrement ennuyeuse pour créer un dictionnaire pour les étiquettes de classe.

Remise à DataLoader

Lorsque vous le lisez réellement dans le code, vous pouvez l'utiliser pour apprendre en l'utilisant comme suit. (L'argument DataLoader shuffle rend aléatoire la façon dont les données sont référencées)

    data_set = MyDataset(32, dir_path=root_data)
    dataloader = torch.utils.data.DataLoader(data_set, batch_size=100, shuffle=True)

Résumé du code source

import torch.utils.data
import torchvision.transforms as transforms
from pathlib import Path
from PIL import Image

class MyDataset(torch.utils.data.Dataset):

    def __init__(self, imageSize, dir_path, transform=None):
        self.transform = transforms.Compose([
            transforms.Resize(imageSize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        self.image_paths = [str(p) for p in Path(dir_path).glob("**/*.png ")]

        self.data_num = len(self.image_paths)
        self.classes = ['carpet', 'ceramic', 'cloth', 'dirt', 'drywall', 'glass', 'grass', 'gravel', 'leaf', 'metal']
        self.class_to_idx = {'carpet':0, 'ceramic':1, 'cloth':2, 'dirt':3, 'drywall':4, 'glass':5, 'grass':6,'gravel':7, 'leaf':8, 'metal':9}


    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        p = self.image_paths[idx]
        image = Image.open(p)

        if self.transform:
            out_data = self.transform(image)

        out_label = p.split("\\")
        out_label = self.class_to_idx[out_label[3]]

        return out_data, out_label

if __name__ == "__main__":
    root_data = 'Chemin d'accès aux données'
    data_set = MyDataset(32, dir_path=root_data)
    dataloader = torch.utils.data.DataLoader(data_set, batch_size=100, shuffle=True)

Site de référence

Je l'ai implémenté en regardant le site suivant. Merci beaucoup. Explication des transformations, des ensembles de données, du chargeur de données de pyTorch et de la création et de l'utilisation d'un ensemble de données personnalisé PyTorch: Dataset and DataLoader (tâche de traitement d'image)

Recommended Posts

J'ai essayé d'implémenter la lecture de Dataset avec PyTorch
J'ai essayé d'implémenter SSD avec PyTorch maintenant (Dataset)
J'ai essayé d'implémenter CVAE avec PyTorch
J'ai essayé d'expliquer l'ensemble de données de Pytorch
J'ai essayé d'implémenter et d'apprendre DCGAN avec PyTorch
J'ai essayé d'implémenter Autoencoder avec TensorFlow
J'ai essayé d'implémenter SSD avec PyTorch maintenant (édition du modèle)
J'ai essayé d'implémenter la classification des phrases par Self Attention avec PyTorch
J'ai essayé d'implémenter PCANet
J'ai essayé d'implémenter StarGAN (1)
J'ai essayé de déplacer Faster R-CNN rapidement avec pytorch
J'ai essayé d'implémenter Mine Sweeper sur un terminal avec python
J'ai essayé d'implémenter le perceptron artificiel avec python
[Introduction à Pytorch] J'ai essayé de catégoriser Cifar10 avec VGG16 ♬
J'ai essayé d'implémenter Grad-CAM avec keras et tensorflow
J'ai essayé d'implémenter Deep VQE
J'ai essayé de mettre en place une validation contradictoire
J'ai essayé d'implémenter DeepPose avec PyTorch
J'ai essayé d'implémenter Realness GAN
J'ai essayé de mettre en œuvre une évasion (type d'évitement de tromperie) avec Quantx
J'ai essayé de classer MNIST par GNN (avec PyTorch géométrique)
J'ai essayé d'implémenter ListNet d'apprentissage de rang avec Chainer
J'ai essayé de mettre en œuvre le chapeau de regroupement de Harry Potter avec CNN
J'ai essayé d'implémenter PLSA en Python
J'ai essayé d'implémenter la permutation en Python
J'ai essayé de visualiser AutoEncoder avec TensorFlow
J'ai essayé de commencer avec Hy
J'ai essayé d'implémenter PLSA dans Python 2
J'ai essayé d'implémenter ADALINE en Python
[Introduction à Pytorch] J'ai joué avec sinGAN ♬
J'ai essayé d'implémenter DeepPose avec PyTorch PartⅡ
J'ai essayé d'implémenter PPO en Python
J'ai essayé de résoudre TSP avec QAOA
J'ai essayé de prédire l'année prochaine avec l'IA
J'ai essayé d'utiliser lightGBM, xg boost avec Boruta
J'ai essayé de mettre en œuvre un apprentissage en profondeur qui n'est pas profond avec uniquement NumPy
J'ai essayé d'apprendre le fonctionnement logique avec TF Learn
J'ai essayé de déplacer GAN (mnist) avec keras
J'ai essayé de mettre en œuvre une blockchain qui fonctionne réellement avec environ 170 lignes
J'ai essayé de sauvegarder les données avec discorde
J'ai essayé de détecter rapidement un mouvement avec OpenCV
J'ai essayé d'intégrer Keras dans TFv1.1
J'ai essayé d'obtenir des données CloudWatch avec Python
J'ai essayé de sortir LLVM IR avec Python
J'ai essayé de déboguer.
J'ai essayé d'implémenter TOPIC MODEL en Python
J'ai essayé de détecter un objet avec M2Det!
J'ai essayé d'automatiser la fabrication des sushis avec python
J'ai essayé de prédire la survie du Titanic avec PyCaret
J'ai essayé d'utiliser Linux avec Discord Bot
J'ai essayé d'implémenter le tri sélectif en python
J'ai essayé d'étudier DP avec séquence de Fibonacci
J'ai essayé de démarrer Jupyter avec toutes les lumières d'Amazon
J'ai essayé de juger Tundele avec Naive Bays
J'ai essayé de mettre en œuvre le problème du voyageur de commerce
J'ai essayé d'implémenter le tri par fusion en Python avec le moins de lignes possible
J'ai essayé d'implémenter Cifar10 avec la bibliothèque SONY Deep Learning NNabla [Nippon Hurray]