[PYTHON] Comprendre le DataSet et le DataLoader de PyTorch (1)

introduction

DataLoader et DataSet de PyTorch que j'ai utilisés sans trop comprendre quand j'ai remarqué. Je vous serais reconnaissant si vous pouviez vous y référer si vous voulez faire quelque chose d'un peu élaboré.

La deuxième partie est ici.

Vérifiez l'exemple PyTorch

Si vous utilisez PyTorch, vous avez probablement vu DataLoader. L'exemple PyTorch de MNIST, que tout le monde utilise pour l'apprentissage automatique, a également cette description.

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('~/dataset/MNIST',
                    train=True,
                    download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
    batch_size=256,
    shuffle=True)

Ou si vous recherchez avec Qiita etc., vous verrez cette façon d'écrire.

train_dataset = datasets.MNIST(
    '~/dataset/MNIST',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]))
    
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True)

Après cela, vous pouvez voir comment les données sont acquises et apprises par lots en les tournant avec l'instruction For.

for epoch in epochs:
    for img, label in train_loader:
        #Décrivez le processus d'apprentissage dans ce

Parlant verbalement, DataLoader est un gars pratique qui suit une certaine règle et transporte les données comme décrit dans le DataSet. Par exemple, dans l'exemple ci-dessus, 256 données MNIST (mini-lot) seront incluses dans img et label avec des données normalisées. Jetons un coup d'œil au contenu pour voir comment cela est réalisé.

Jetez un œil à torch.utils.data.DataLoader

Jetons un coup d'œil à l'implémentation de DataLoader. Vous pouvez immédiatement voir qu'il est en classe.

class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.
    """
    #réduction

Je vais omettre les détails, mais si vous regardez de plus près les informations en tant qu'itérateur, vous trouverez l'implémentation suivante.

def __next__(self):
    index = self._next_index()  # may raise StopIteration
    data = self.dataset_fetcher.fetch(index)  # may raise StopIteration
    if self.pin_memory:
        data = _utils.pin_memory.pin_memory(data)
    return data

Lorsque ce __next__ est appelé, les données sont renvoyées. Et ces données semblent être créées en passant l'index à l'ensemble de données. À ce stade, vous n'avez pas à être si nerveux à propos de la façon dont l'index est créé et de la façon dont le jeu de données est appelé, mais comme c'est un gros problème, allons plus loin.

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
            if self.auto_collation:
                data = [self.dataset[idx] for idx in possibly_batched_index]
            else:
                data = self.dataset[possibly_batched_index]
            return self.collate_fn(data)

L'index est passé à l'ensemble de données. Appeler une instance de la classe de cette manière devrait signifier que __getitem__ est appelé dans l'ensemble de données. (Ici est détaillé. Passons à l'ensemble de données basé sur cela.

Jetez un œil aux ensembles de données.

Dès que vous accédez à la définition de MNIST, vous pouvez voir qu'il s'agit d'une classe.

class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
    """

    #réduction

Allons voir __getitem__.

def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    img, target = self.data[index], int(self.targets[index])

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img.numpy(), mode='L')

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target

Il est écrit d'une manière facile à comprendre que l'index est passé et les données correspondantes sont renvoyées. Je vois, c'est ainsi que les données MNIST sont renvoyées. En regardant le processus pendant un moment, Image.fromarray de PIL est également écrit. En d'autres termes, si vous concevez et écrivez ce __getitem__, il est possible de renvoyer n'importe quelle donnée.

Jetez à nouveau un œil à torch.utils.data.DataLoader

Mais il y a encore quelque chose que je ne comprends pas. Comment l'index est-il créé? L'indice est ici.

if sampler is None:  # give default samplers
    if self.dataset_kind == _DatasetKind.Iterable:
        # See NOTE [ Custom Samplers and IterableDataset ]
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            sampler = RandomSampler(dataset)
        else:
            sampler = SequentialSampler(dataset)

self.sampler = sampler
@property
def _index_sampler(self):
    # The actual sampler used for generating indices for `_DatasetFetcher`
    # (see _utils/fetch.py) to read data at each time. This would be
    # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
    # We can't change `.sampler` and `.batch_sampler` attributes for BC
    # reasons.
    if self._auto_collation:
        return self.batch_sampler
    else:
        return self.sampler

L'index semble être créé par échantillonneur. Par défaut, l'échantillonneur est commuté par l'argument True, False appelé shuffle. Par exemple, regardons l'implémentation lorsque shuffle = False.

class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

Le data_source ici est un ensemble de données. Il semble que je sois arrivé jusqu'ici et que j'ai une idée générale. En d'autres termes, il se répète pour la longueur de l'ensemble de données. Inversement, il semble nécessaire de préparer une méthode spéciale appelée «len» dans l'ensemble de données.

Jetez un œil aux ensembles de données.

Vérifions __len__ dans datasets.MNIST.

def __len__(self):
        return len(self.data)

Vous renvoyez la longueur des données. Étant donné que les données dans MNIST ont une taille de 60000x28x28, 60000 seront renvoyées. Cela a été assez rafraîchissant.

Continuer à la prochaine fois

L'article s'allonge, c'est donc tout pour la première partie. Dans Partie 2, vous allez créer votre propre ensemble de données.

Recommended Posts

Comprendre le DataSet et le DataLoader de PyTorch (2)
Comprendre le DataSet et le DataLoader de PyTorch (1)
Comprendre t-SNE et améliorer la visualisation
Apprenez à connaître les packages et les modules Python
[Pytorch] Mémo sur Dataset / DataLoader
[Python / matplotlib] Comprendre et utiliser FuncAnimation
Comprendre les rouages et les extensions dans discord.py
Comprendre les règles et les fonctions convexes d'Armijo