Cela fait six mois que j'ai commencé à étudier le machine learning, et j'ai réussi à créer un Dataset avec PyTorch, je le posterai donc comme rappel. Quand j'étudiais le GAN, j'étudiais en supprimant le code de GitHub, mais comme je ne lisais que MNIST et CIFAR, je voulais l'exécuter avec mon propre jeu de données, j'ai donc créé mon propre jeu de données. (Je ne sais pas car c'est un article que je pratique depuis un certain temps en postant des articles ...)
Héritage du jeu de données PyTorch Pour transmettre l'objet de cette classe d'héritage de Dataset à DataLoader lors de sa transmission au modèle d'apprentissage
\ _ \ _ Getitem \ _ \ _ et \ _ \ _len \ _ \ _ méthodes \ _ \ _ Getitem \ _ \ _ est une méthode qui renvoie des données et des étiquettes dans tapple \ _ \ _ Len \ _ \ _ signifie tel quel, une méthode qui renvoie le nombre de données
Donc, la configuration de base est comme ça.
class MyDataset(torch.utils.data.Dataset):
def __init__(self, imageSize, dir_path, transform=None):
pass
def __len__(self):
pass
def __getitem__(self, idx):
pass
En plus du chemin d'accès aux données, j'ai passé la taille d'entrée de l'image et la transformation pour le prétraitement comme arguments de la classe.
Le constructeur, qui est automatiquement appelé lors de la création d'une classe, effectue le traitement suivant.
def __init__(self, imageSize, dir_path, transform=None):
self.transform = transforms.Compose([
transforms.Resize(imageSize), #Redimensionnement de l'image
transforms.ToTensor(), #Tensorisation
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), #Standardisation
])
#Entrez les données d'entrée et l'étiquette ici
self.image_paths = [str(p) for p in Path(dir_path).glob("**/*.png ")]
self.data_num = len(self.image_paths) #Voici__len__Sera la valeur de retour de
self.classes = ['carpet', 'ceramic', 'cloth', 'dirt', 'drywall', 'glass', 'grass', 'gravel', 'leaf', 'metal']
self.class_to_idx = {'carpet':0, 'ceramic':1, 'cloth':2, 'dirt':3, 'drywall':4, 'glass':5, 'grass':6,'gravel':7, 'leaf':8, 'metal':9}
J'avais des données matérielles multi-classifiées, alors je les ai utilisées.
Puisque \ _ \ _ getitem \ _ \ _ est une méthode de lecture des données et de son étiquette de réponse correcte pendant l'entraînement, nous l'implémenterons en utilisant les informations lues par le constructeur.
def __getitem__(self, idx):
p = self.image_paths[idx]
image = Image.open(p)
if self.transform:
out_data = self.transform(image)
out_label = p.split("\\")
out_label = self.class_to_idx[out_label[3]]
return out_data, out_label
Je pense que c'est correct de lire les données d'image avec le constructeur, mais j'étais inquiète pour la mémoire quand il y avait beaucoup de données, alors j'ai décidé de la lire à chaque fois. J'utilise également une méthode légèrement ennuyeuse pour créer un dictionnaire pour les étiquettes de classe.
Lorsque vous le lisez réellement dans le code, vous pouvez l'utiliser pour apprendre en l'utilisant comme suit. (L'argument DataLoader shuffle rend aléatoire la façon dont les données sont référencées)
data_set = MyDataset(32, dir_path=root_data)
dataloader = torch.utils.data.DataLoader(data_set, batch_size=100, shuffle=True)
import torch.utils.data
import torchvision.transforms as transforms
from pathlib import Path
from PIL import Image
class MyDataset(torch.utils.data.Dataset):
def __init__(self, imageSize, dir_path, transform=None):
self.transform = transforms.Compose([
transforms.Resize(imageSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
self.image_paths = [str(p) for p in Path(dir_path).glob("**/*.png ")]
self.data_num = len(self.image_paths)
self.classes = ['carpet', 'ceramic', 'cloth', 'dirt', 'drywall', 'glass', 'grass', 'gravel', 'leaf', 'metal']
self.class_to_idx = {'carpet':0, 'ceramic':1, 'cloth':2, 'dirt':3, 'drywall':4, 'glass':5, 'grass':6,'gravel':7, 'leaf':8, 'metal':9}
def __len__(self):
return self.data_num
def __getitem__(self, idx):
p = self.image_paths[idx]
image = Image.open(p)
if self.transform:
out_data = self.transform(image)
out_label = p.split("\\")
out_label = self.class_to_idx[out_label[3]]
return out_data, out_label
if __name__ == "__main__":
root_data = 'Chemin d'accès aux données'
data_set = MyDataset(32, dir_path=root_data)
dataloader = torch.utils.data.DataLoader(data_set, batch_size=100, shuffle=True)
Je l'ai implémenté en regardant le site suivant. Merci beaucoup. Explication des transformations, des ensembles de données, du chargeur de données de pyTorch et de la création et de l'utilisation d'un ensemble de données personnalisé PyTorch: Dataset and DataLoader (tâche de traitement d'image)
Recommended Posts