[PYTHON] Understand PyTorch's Dataset and DataLoader (1)

Introduction

Before I knew it, I had been using PyTorch's DataLoader and Dataset without really understanding them. I hope this article serves as a reference if you ever want to do something a little more elaborate with them.

The second part is here.

Check PyTorch Example

If you use PyTorch, you have probably seen DataLoader. The PyTorch MNIST example, everyone's first stop in machine learning, contains code like this.

import torch
from torchvision import datasets, transforms

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('~/dataset/MNIST',
                   train=True,
                   download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=256,
    shuffle=True)

Alternatively, if you search Qiita and similar sites, you will often see it written like this.

train_dataset = datasets.MNIST(
    '~/dataset/MNIST',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]))
    
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True)

After that, you loop over it with a for statement to fetch the data batch by batch and train on it.

for epoch in range(num_epochs):
    for img, label in train_loader:
        # describe the training step here
        ...

Loosely speaking, a DataLoader is a handy helper that delivers the data described in the Dataset according to a certain rule. In the example above, each iteration fills img with a mini-batch of 256 normalized MNIST images and label with the 256 corresponding labels. Let's look inside to see how that is achieved.
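As a quick sanity check, you can pull a single mini-batch out of the loader and inspect it; a minimal sketch, assuming the train_loader defined above:

img, label = next(iter(train_loader))
print(img.shape)    # torch.Size([256, 1, 28, 28]) -- one normalized mini-batch
print(label.shape)  # torch.Size([256])            -- the corresponding labels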

Take a look at torch.utils.data.DataLoader

Let's take a look at the DataLoader implementation. You can see right away that it is a class.

class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.
    """
    # (omitted)

I'll omit the details, but if you follow how it behaves as an iterator, you arrive at the following implementation.

def __next__(self):
    index = self._next_index()  # may raise StopIteration
    data = self.dataset_fetcher.fetch(index)  # may raise StopIteration
    if self.pin_memory:
        data = _utils.pin_memory.pin_memory(data)
    return data

When this __next__ is called, data is returned, and that data is created by passing an index to the dataset. At this stage you don't need to worry too much about how the index is created or how the dataset is called, but while we're at it, let's go one step further.
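Incidentally, the for statement from earlier simply drives this iterator protocol; a minimal sketch, again assuming the train_loader from the example above:

it = iter(train_loader)  # the for statement calls iter() behind the scenes
img, label = next(it)    # next() invokes the __next__ shown above, yielding one mini-batch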

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

The index is passed to the dataset. Indexing an instance of a class like this (self.dataset[idx]) means that the dataset's __getitem__ is being called (this is explained in detail here).
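As a minimal illustration of that indexing protocol, with nothing PyTorch-specific in it:

class Squares:
    def __getitem__(self, index):
        # s[i] is translated by Python into s.__getitem__(i)
        return index ** 2

s = Squares()
print(s[3])  # 9

With that in mind, let's head over to the dataset.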

Take a look at datasets.MNIST

Jump to the definition of MNIST and you can see right away that it, too, is a class.

class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
    """

    # (omitted)

Let's take a look at __getitem__.

def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    img, target = self.data[index], int(self.targets[index])

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img.numpy(), mode='L')

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target

It is written very clearly: an index comes in, and the corresponding data goes out. I see, so this is how MNIST data is returned. Looking at the processing a little more closely, the conversion to a PIL image via Image.fromarray also happens here. In other words, if you write this __getitem__ yourself, you can return any data you like.
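In fact, that is the heart of a map-style dataset. Here is a minimal sketch of the idea (the class name and fields are made up for illustration, not part of torchvision):

from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, targets, transform=None):
        self.data = data           # e.g. a tensor of images
        self.targets = targets     # e.g. a tensor of labels
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.data[index], self.targets[index]
        if self.transform is not None:
            x = self.transform(x)  # any processing you like goes here
        return x, y

    def __len__(self):             # needed by the sampler, as we will see below
        return len(self.data)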

Take a look at torch.utils.data.DataLoader again

But there is still one thing I don't understand: how is the index created? The hint is here.

if sampler is None:  # give default samplers
    if self.dataset_kind == _DatasetKind.Iterable:
        # See NOTE [ Custom Samplers and IterableDataset ]
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            sampler = RandomSampler(dataset)
        else:
            sampler = SequentialSampler(dataset)

self.sampler = sampler

@property
def _index_sampler(self):
    # The actual sampler used for generating indices for `_DatasetFetcher`
    # (see _utils/fetch.py) to read data at each time. This would be
    # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
    # We can't change `.sampler` and `.batch_sampler` attributes for BC
    # reasons.
    if self._auto_collation:
        return self.batch_sampler
    else:
        return self.sampler

So the index is created through a sampler. By default, which sampler is used is switched by the shuffle argument (True or False). As an example, let's look at the implementation used when shuffle=False.

class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

The data_source here is the dataset. Having come this far, the overall picture is clear: the sampler simply yields indices over the length of the dataset. Conversely, this means the dataset must provide the special method __len__.
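You can check this behaviour directly by handing the samplers a small dummy object; a quick sketch, where a plain list stands in for the dataset (the sampler only needs its length):

from torch.utils.data import RandomSampler, SequentialSampler

dummy = ["a", "b", "c", "d"]           # anything with a __len__ will do
print(list(SequentialSampler(dummy)))  # [0, 1, 2, 3] -- always in order
print(list(RandomSampler(dummy)))      # e.g. [2, 0, 3, 1] -- shuffled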

Take a look at datasets.MNIST again

Let's check __len__ in datasets.MNIST.

def __len__(self):
    return len(self.data)

It just returns the length of data. Since data in MNIST has shape 60000x28x28, 60000 is returned. Everything is much clearer now.
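Putting the pieces together: the length of the dataset determines how many indices the sampler yields, and therefore how many mini-batches the DataLoader produces. A quick check, assuming the train_dataset and train_loader from the second example at the top:

print(len(train_dataset))  # 60000 -- comes from MNIST.__len__
print(len(train_loader))   # 235   -- ceil(60000 / 256) mini-batches per epoch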

To be continued

The article is getting long, so that's it for Part 1. In Part 2, we will create our own dataset.
