[PYTHON] Grundlegendes zu PyTorchs DataSet und DataLoader (1)

Einführung

PyTorchs DataLoader und DataSet, die ich ohne viel Verständnis verwendet habe, als ich es bemerkte. Ich würde es begrüßen, wenn Sie sich darauf beziehen könnten, wenn Sie etwas Aufwändiges tun möchten.

Der zweite Teil ist hier.

Überprüfen Sie das PyTorch-Beispiel

Wenn Sie PyTorch verwenden, haben Sie wahrscheinlich DataLoader gesehen. Das PyTorch-Beispiel von MNIST, das jeder für maschinelles Lernen verwendet, enthält ebenfalls diese Beschreibung.

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('~/dataset/MNIST',
                    train=True,
                    download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
    batch_size=256,
    shuffle=True)

Oder wenn Sie mit Qiita usw. suchen, werden Sie diese Schreibweise sehen.

train_dataset = datasets.MNIST(
    '~/dataset/MNIST',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]))
    
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True)

Danach können Sie sehen, wie die Daten im Stapel erfasst und gelernt werden, indem Sie sie mit der For-Anweisung drehen.

for epoch in epochs:
    for img, label in train_loader:
        #Beschreiben Sie hier den Lernprozess

Mündlich gesprochen ist DataLoader ein praktischer Typ, der einer bestimmten Regel folgt und die Daten wie im DataSet beschrieben überträgt. Im obigen Beispiel werden beispielsweise 256 MNIST-Daten (Mini-Batch) in das Bild aufgenommen und mit normalisierten Daten beschriftet. Schauen wir uns den Inhalt an, um zu sehen, wie dies erreicht wird.

Schauen Sie sich torch.utils.data.DataLoader an

Werfen wir einen Blick auf die DataLoader-Implementierung. Sie können sofort sehen, dass es in der Klasse ist.

class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.
    """
    #Kürzung

Ich werde die Details weglassen, aber wenn Sie sich die Informationen als Iterator genauer ansehen, finden Sie die folgende Implementierung.

def __next__(self):
    index = self._next_index()  # may raise StopIteration
    data = self.dataset_fetcher.fetch(index)  # may raise StopIteration
    if self.pin_memory:
        data = _utils.pin_memory.pin_memory(data)
    return data

Wenn dieses __next__ aufgerufen wird, werden Daten zurückgegeben. Und diese Daten scheinen durch Übergabe des Index an den Datensatz erstellt zu werden. Zu diesem Zeitpunkt müssen Sie nicht so nervös sein, wie der Index erstellt und wie der Datensatz aufgerufen wird. Da dies jedoch eine große Sache ist, gehen wir noch einen Schritt weiter.

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
            if self.auto_collation:
                data = [self.dataset[idx] for idx in possibly_batched_index]
            else:
                data = self.dataset[possibly_batched_index]
            return self.collate_fn(data)

Der Index wird an den Datensatz übergeben. Das Aufrufen einer Instanz der Klasse auf diese Weise sollte bedeuten, dass "getitem" im Dataset aufgerufen wird. (Hier wird detailliert beschrieben. Gehen wir zu dem darauf basierenden Datensatz.

Schauen Sie sich datasets.MNIST an

Sobald Sie zur Definition von MNIST gehen, können Sie sehen, dass es sich um Klasse handelt.

class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
    """

    #Kürzung

Schauen wir uns __getitem__ an.

def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    img, target = self.data[index], int(self.targets[index])

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img.numpy(), mode='L')

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target

Es ist leicht verständlich geschrieben, dass der Index übergeben und die entsprechenden Daten zurückgegeben werden. Ich sehe, so werden MNIST-Daten zurückgegeben. Wenn Sie sich den Prozess für einen Moment ansehen, wird auch Image.fromarray von PIL geschrieben. Mit anderen Worten, wenn Sie dieses getitem entwerfen und schreiben, können Sie beliebige Daten zurückgeben.

Schauen Sie sich noch einmal torch.utils.data.DataLoader an

Aber es gibt noch etwas, das ich nicht verstehe. Wie wird der Index erstellt? Der Hinweis ist hier.

if sampler is None:  # give default samplers
    if self.dataset_kind == _DatasetKind.Iterable:
        # See NOTE [ Custom Samplers and IterableDataset ]
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            sampler = RandomSampler(dataset)
        else:
            sampler = SequentialSampler(dataset)

self.sampler = sampler
@property
def _index_sampler(self):
    # The actual sampler used for generating indices for `_DatasetFetcher`
    # (see _utils/fetch.py) to read data at each time. This would be
    # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
    # We can't change `.sampler` and `.batch_sampler` attributes for BC
    # reasons.
    if self._auto_collation:
        return self.batch_sampler
    else:
        return self.sampler

Der Index scheint durch Sampler erstellt zu werden. Standardmäßig wird der Sampler durch das Argument True, False (Shuffle) umgeschaltet. Schauen wir uns zum Beispiel die Implementierung an, wenn shuffle = False ist.

class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

Die Datenquelle hier ist ein Datensatz. Es sieht so aus, als wäre ich so weit gekommen und hätte eine allgemeine Vorstellung. Mit anderen Worten, es wird für die Länge des Datensatzes wiederholt. Umgekehrt scheint es notwendig zu sein, eine spezielle Methode namens "len" im Datensatz vorzubereiten.

Schauen Sie sich noch einmal Datasets.MNIST an

Lassen Sie uns __len__ in datasets.MNIST überprüfen.

def __len__(self):
        return len(self.data)

Sie geben die Datenlänge zurück. Da Daten in MNIST eine Größe von 60000 x 28 x 28 haben, wird 60000 zurückgegeben. Es war ziemlich erfrischend.

Fahren Sie mit dem nächsten Mal fort

Der Artikel wird länger, das ist es also für den ersten Teil. In Teil 2 erstellen Sie Ihren eigenen Datensatz.

Recommended Posts

Grundlegendes zu PyTorchs DataSet und DataLoader (2)
Grundlegendes zu PyTorchs DataSet und DataLoader (1)
Verstehen Sie t-SNE und verbessern Sie die Visualisierung
Lernen Sie Python-Pakete und -Module kennen
[Pytorch] Memo über Dataset / DataLoader
[Python / matplotlib] FuncAnimation verstehen und verwenden
Verstehen Sie Zahnräder und Erweiterungen in discord.py
Verstehe die Regeln und konvexen Funktionen von Armijo