[PYTHON] Comprendre le DataSet et le DataLoader de PyTorch (2)

Jusqu'à la dernière fois

Jusqu'à la dernière fois, j'ai compris le comportement des DataLoader et DataSet de PyTorch. Cette fois, appliquons-le et créons votre propre ensemble de données. J'ai peut-être mentionné la source de ici.

Créons notre propre jeu de données

Je sens que je peux faire quelque chose d'un peu élaboré avec le contenu jusqu'à la dernière fois. Rendons possible un bon retour des données en créant un jeu de données.

Créer un échantillon qui renvoie les données MNIST par paires

Dans la tendance récente du Metric Learning, il est nécessaire de faire une paire d'images. Diverses méthodes ont été proposées, mais j'estime qu'il n'y a pas beaucoup de bon code à essayer pour le moment. Donc, cette fois, facilitons la gestion des paires en créant vous-même un ensemble de données à titre d'exemple.

Créer une classe PairMnistDataset

Commencez par créer une classe. Hériter du DataSet de Torch. En plus de cela, le constructeur doit recevoir le jeu de données MNIST. La paire positive et la paire négative de Metric Learning ont la relation suivante.

Nom Contenu
Positive Pair Même étiquette
Negative Pair Étiquette non identique

Puisque je veux mélanger les données d'entraînement, je n'ai besoin que de créer la relation de position des étiquettes dans le constructeur, et pour les données de test, je n'ai besoin que de créer d'abord le modèle Pair, donc je vais créer une liste d'index.

from torch.utils.data import Dataset

class PairMnistDataset(Dataset):
    def __init__(self, mnist_dataset, train=True):
        self.train = train
        self.dataset = mnist_dataset
        self.transform = mnist_dataset.transform

        if self.train:
            self.train_data = self.dataset.train_data
            self.train_labels = self.dataset.train_labels
            self.train_label_set = set(self.train_labels.numpy())
            self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
                                     for label in self.train_label_set}
        else:
            self.test_data = self.dataset.test_data
            self.test_labels = self.dataset.test_labels
            self.test_label_set = set(self.test_labels.numpy())
            self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
                                     for label in self.test_label_set}

            #Je ne vais pas mélanger, alors décidez d'abord de la paire
            positive_pairs = [[i,
                               np.random.choice(self.label_to_indices[self.test_labels[i].item()]),
                               1]
                              for i in range(0, len(self.test_data), 2)]

            negative_pairs = [[i,
                               np.random.choice(self.label_to_indices[np.random.choice(list(self.test_label_set - set([self.test_labels[i].item()])))]),
                               0]
                              for i in range(1, len(self.test_data), 2)]

            self.test_pairs = positive_pairs + negative_pairs

Faire __getitem__

Faisons le «getitem» que nous avons étudié dans l'article précédent. Tout ce que vous avez à faire est de décrire les données à renvoyer lorsque l'index est passé.

    def __getitem__(self, index):
        if self.train:
            target = np.random.randint(0, 2)

            # img1,label1 sera décidé en premier
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            if target == 1:
                # positive pair
                #Traitement pour sélectionner des index avec la même étiquette
                siamese_index = index
                while siamese_index == index:
                    siamese_index = np.random.choice(self.label_to_indices[label1])
            else:
                # negative pair
                #Traitement pour sélectionner des index avec des étiquettes différentes
                siamese_label = np.random.choice(list(self.train_label_set - set([label1])))
                siamese_index = np.random.choice(self.label_to_indices[siamese_label])

            img2 = self.train_data[siamese_index]
        else:
            img1 = self.test_data[self.test_pairs[index][0]]
            img2 = self.test_data[self.test_pairs[index][1]]
            target = self.test_pairs[index][2]

        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            
        return (img1, img2), target  #Si les étiquettes de l'apprentissage métrique sont les mêmes

    def __len__(self):
        return len(self.dataset)

Essayez d'appeler l'ensemble de données et le chargeur de données dans main

Tout ce que vous avez à faire est d'appeler ce que vous avez fait jusqu'à présent. Le code est long et compliqué jusqu'à présent, mais je pense que si vous l'utilisez bien, vous pouvez charger les données en douceur.

def main():
    #L'habituel au début
    train_dataset = datasets.MNIST(
        '~/dataset/MNIST',  #Changer le cas échéant
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]))

    test_dataset = datasets.MNIST(
        '~/dataset/MNIST',  #Changer le cas échéant
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]))

    #Ensemble de données et chargeur de données personnalisés
    pair_train_dataset = PairMnistDataset(train_dataset, train=True)
    pair_train_loader = torch.utils.data.DataLoader(
        pair_train_dataset,
        batch_size=16
    )

    pair_test_dataset = PairMnistDataset(test_dataset, train=False)
    pair_test_loader = torch.utils.data.DataLoader(
        pair_test_dataset,
        batch_size=16
    )

    #Par exemple, vous pouvez l'appeler comme ceci
    for (data1, data2), label in pair_train_loader:
        print(data1.shape)
        print(data2.shape)
        print(label)

Cliquez ici pour afficher les résultats. Il est renvoyé correctement sous forme de paire et l'indicateur indiquant s'ils ont ou non la même étiquette est également renvoyé. Si vous utilisez ces données, vous pouvez facilement effectuer un apprentissage métrique.

    torch.Size([16, 1, 28, 28])
    torch.Size([16, 1, 28, 28])
    tensor([1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1])

Résumé

La dernière fois, cette fois, c'était assez long, mais c'était un article sur la compréhension des DataLoader et DataSet de PyTorch. Que diriez-vous de lire les données comme celles-ci pour le très populaire Metric Learning?

Recommended Posts

Comprendre le DataSet et le DataLoader de PyTorch (2)
Comprendre le DataSet et le DataLoader de PyTorch (1)
Comprendre t-SNE et améliorer la visualisation
Apprenez à connaître les packages et les modules Python
[Pytorch] Mémo sur Dataset / DataLoader
[Python / matplotlib] Comprendre et utiliser FuncAnimation
Comprendre les rouages et les extensions dans discord.py
Comprendre les règles et les fonctions convexes d'Armijo