[PYTHON] Understand PyTorch's DataSet and DataLoader (2)

Recap of last time

Last time, we looked at how PyTorch's DataLoader and Dataset behave. This time, let's apply that knowledge and build our own dataset. I believe I referred to the source code here.

Let's make our own dataset

With what we covered last time, we can already do something a little more elaborate. Let's write our own dataset that returns exactly the data we need.

Example: a dataset that returns MNIST images in pairs

Metric learning, which has become popular recently, often requires feeding the model pairs of images. Various methods have been proposed, but I feel there is not much simple code around for just trying them out. So this time, as an example, let's write a custom dataset that makes pairs easy to handle.

Create a PairMnistDataset class

First, create a class that inherits from PyTorch's Dataset. Its constructor receives an MNIST dataset. In metric learning, positive and negative pairs are defined as follows.

Name           Contents
Positive Pair  Same label
Negative Pair  Different labels
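The rule in the table boils down to a single comparison. Here is a minimal sketch in plain Python; the function name `pair_target` is hypothetical and just mirrors the 1/0 flag that the dataset returns as `target`.

```python
def pair_target(label1: int, label2: int) -> int:
    """Return 1 for a positive pair (same label), 0 for a negative pair."""
    return 1 if label1 == label2 else 0


print(pair_target(3, 3))  # 1: positive pair
print(pair_target(3, 8))  # 0: negative pair
```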

The training data will be shuffled and its pairs drawn at random on every access, so for training the constructor only needs to record which indices belong to which label. The test data is not shuffled, so for testing the constructor fixes all the pairs up front as a list of index triples.
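The label-to-indices map is what makes pair sampling cheap: once it exists, finding an image with a given label is a single lookup. The sketch below builds the same structure with only the standard library on a hypothetical toy label list; the constructor does the equivalent with `np.where` on the MNIST labels. `build_label_to_indices` is a hypothetical helper name.

```python
from collections import defaultdict


def build_label_to_indices(labels):
    """Map each label to the list of dataset indices that carry it."""
    label_to_indices = defaultdict(list)
    for index, label in enumerate(labels):
        label_to_indices[label].append(index)
    return dict(label_to_indices)


# Toy labels standing in for MNIST targets (hypothetical data).
toy_labels = [0, 1, 0, 2, 1, 0]
print(build_label_to_indices(toy_labels))
# {0: [0, 2, 5], 1: [1, 4], 2: [3]}
```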

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class PairMnistDataset(Dataset):
    def __init__(self, mnist_dataset, train=True):
        self.train = train
        self.dataset = mnist_dataset
        self.transform = mnist_dataset.transform

        if self.train:
            # `data`/`targets` replace the deprecated `train_data`/`train_labels`
            self.train_data = self.dataset.data
            self.train_labels = self.dataset.targets
            self.train_label_set = set(self.train_labels.numpy())
            # label -> array of indices that have that label
            self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
                                     for label in self.train_label_set}
        else:
            # `data`/`targets` replace the deprecated `test_data`/`test_labels`
            self.test_data = self.dataset.data
            self.test_labels = self.dataset.targets
            self.test_label_set = set(self.test_labels.numpy())
            self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
                                     for label in self.test_label_set}

            # The test set is not shuffled, so fix every pair up front:
            # even indices get a positive partner, odd indices a negative one
            positive_pairs = [[i,
                               np.random.choice(self.label_to_indices[self.test_labels[i].item()]),
                               1]
                              for i in range(0, len(self.test_data), 2)]

            negative_pairs = [[i,
                               np.random.choice(self.label_to_indices[np.random.choice(list(self.test_label_set - set([self.test_labels[i].item()])))]),
                               0]
                              for i in range(1, len(self.test_data), 2)]

            self.test_pairs = positive_pairs + negative_pairs
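Because the test pairs are drawn once with the global `np.random` state, they change every time the dataset object is created. If you want a test set that is identical across runs, one option (my own assumption, not part of the code above) is to draw from a seeded generator. The sketch below follows the same even/odd scheme as the constructor, using the standard library's `random.Random` on hypothetical toy labels; `make_fixed_pairs` is a hypothetical helper.

```python
import random


def make_fixed_pairs(labels, seed=0):
    """Build [index1, index2, target] triples once, up front.

    Even indices get a positive partner (same label), odd indices a
    negative partner (different label). The seed is an added assumption
    so that the same pairs come out on every run.
    """
    rng = random.Random(seed)
    label_to_indices = {}
    for index, label in enumerate(labels):
        label_to_indices.setdefault(label, []).append(index)
    label_set = set(label_to_indices)

    positive_pairs = [[i, rng.choice(label_to_indices[labels[i]]), 1]
                      for i in range(0, len(labels), 2)]

    negative_pairs = []
    for i in range(1, len(labels), 2):
        # Pick a label different from the anchor's, then an image with it.
        other_label = rng.choice(sorted(label_set - {labels[i]}))
        negative_pairs.append([i, rng.choice(label_to_indices[other_label]), 0])

    return positive_pairs + negative_pairs


toy_labels = [0, 1, 0, 2, 1, 0, 2, 1]
pairs = make_fixed_pairs(toy_labels, seed=42)
print(len(pairs))  # 8: one pair per sample
```

Since every sample index appears exactly once as the first element of a pair, the number of pairs equals the number of samples, which is why `__len__` can simply return `len(self.dataset)`.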

Implement __getitem__ and __len__

Next, let's implement the __getitem__ method we studied in the previous article. All it has to do is describe what data to return for a given index.

    def __getitem__(self, index):
        if self.train:
            target = np.random.randint(0, 2)  # 1: positive pair, 0: negative pair

            # The anchor image img1 and its label come from the requested index
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            if target == 1:
                # positive pair
                # Pick a different index that has the same label
                siamese_index = index
                while siamese_index == index:
                    siamese_index = np.random.choice(self.label_to_indices[label1])
            else:
                # negative pair
                # Pick a different label, then an index that has that label
                siamese_label = np.random.choice(list(self.train_label_set - set([label1])))
                siamese_index = np.random.choice(self.label_to_indices[siamese_label])

            img2 = self.train_data[siamese_index]
        else:
            img1 = self.test_data[self.test_pairs[index][0]]
            img2 = self.test_data[self.test_pairs[index][1]]
            target = self.test_pairs[index][2]

        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            
        return (img1, img2), target  # target: 1 if both images share a label, else 0

    def __len__(self):
        return len(self.dataset)
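To see the training-time sampling in isolation, here is a sketch of the same logic with the images stripped out, working only on labels. It uses `random.Random` from the standard library instead of `np.random`; `sample_pair` and the toy labels are hypothetical.

```python
import random


def sample_pair(index, labels, label_to_indices, rng):
    """Draw a partner for `index` the way __getitem__ does in training mode.

    Returns (partner_index, target): target is 1 for a positive pair
    and 0 for a negative pair.
    """
    label1 = labels[index]
    # randint's upper bound is inclusive, unlike np.random.randint(0, 2)
    target = rng.randint(0, 1)
    if target == 1:
        # Positive: same label, but never the anchor image itself.
        partner = index
        while partner == index:
            partner = rng.choice(label_to_indices[label1])
    else:
        # Negative: first pick a different label, then an image with it.
        other_labels = sorted(set(label_to_indices) - {label1})
        partner = rng.choice(label_to_indices[rng.choice(other_labels)])
    return partner, target


toy_labels = [0, 1, 0, 2, 1, 2]
label_to_indices = {}
for i, label in enumerate(toy_labels):
    label_to_indices.setdefault(label, []).append(i)

rng = random.Random(0)
for index in range(len(toy_labels)):
    partner, target = sample_pair(index, toy_labels, label_to_indices, rng)
    print(index, partner, target)
```

One thing this makes visible: the `while` loop for positive pairs never terminates if some label has only a single example. That cannot happen with MNIST, but it is worth guarding against if you reuse the class on a small custom dataset.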

Use the dataset and DataLoader in main

All that remains is to call what we have built. The code so far is long and somewhat involved, but once it is in place, loading paired data becomes straightforward.

import torch
from torchvision import datasets, transforms


def main():
    # Start with the usual torchvision MNIST datasets
    train_dataset = datasets.MNIST(
        '~/dataset/MNIST',  #Change as appropriate
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]))

    test_dataset = datasets.MNIST(
        '~/dataset/MNIST',  #Change as appropriate
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]))

    # Wrap them in our custom pair datasets and DataLoaders
    pair_train_dataset = PairMnistDataset(train_dataset, train=True)
    pair_train_loader = torch.utils.data.DataLoader(
        pair_train_dataset,
        batch_size=16,
        shuffle=True  # the training pairs are meant to be shuffled
    )

    pair_test_dataset = PairMnistDataset(test_dataset, train=False)
    pair_test_loader = torch.utils.data.DataLoader(
        pair_test_dataset,
        batch_size=16
    )

    # For example, you can iterate over it like this (only the first batch is printed)
    for (data1, data2), label in pair_train_loader:
        print(data1.shape)
        print(data2.shape)
        print(label)
        break


if __name__ == '__main__':
    main()

Here is the output. The images come back as pairs, together with a flag indicating whether the two labels match. With this data, you can easily try out metric learning.

    torch.Size([16, 1, 28, 28])
    torch.Size([16, 1, 28, 28])
    tensor([1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1])

Summary

Across last time and this time, this turned into a fairly long two-part article on understanding PyTorch's DataLoader and Dataset. If you are working on metric learning, which has become popular recently, why not try loading your data this way?
