PyTorchs DataLoader und DataSet, die ich ohne viel Verständnis verwendet habe, als ich es bemerkte. Ich würde es begrüßen, wenn Sie sich darauf beziehen könnten, wenn Sie etwas Aufwändiges tun möchten.
Der zweite Teil ist hier.
Wenn Sie PyTorch verwenden, haben Sie wahrscheinlich DataLoader gesehen. Das PyTorch-Beispiel von MNIST, das jeder für maschinelles Lernen verwendet, enthält ebenfalls diese Beschreibung.
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('~/dataset/MNIST',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=256,
shuffle=True)
Oder wenn Sie mit Qiita usw. suchen, werden Sie diese Schreibweise sehen.
train_dataset = datasets.MNIST(
'~/dataset/MNIST',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=256,
shuffle=True)
Danach können Sie sehen, wie die Daten im Stapel erfasst und gelernt werden, indem Sie sie mit der For-Anweisung drehen.
for epoch in epochs:
for img, label in train_loader:
#Beschreiben Sie hier den Lernprozess
Mündlich gesprochen ist DataLoader ein praktischer Typ, der einer bestimmten Regel folgt und die Daten wie im DataSet beschrieben überträgt. Im obigen Beispiel werden beispielsweise 256 MNIST-Daten (Mini-Batch) in das Bild aufgenommen und mit normalisierten Daten beschriftet. Schauen wir uns den Inhalt an, um zu sehen, wie dies erreicht wird.
Werfen wir einen Blick auf die DataLoader-Implementierung. Sie können sofort sehen, dass es in der Klasse ist.
class DataLoader(object):
r"""
Data loader. Combines a dataset and a sampler, and provides an iterable over
the given dataset.
"""
#Kürzung
Ich werde die Details weglassen, aber wenn Sie sich die Informationen als Iterator genauer ansehen, finden Sie die folgende Implementierung.
def __next__(self):
index = self._next_index() # may raise StopIteration
data = self.dataset_fetcher.fetch(index) # may raise StopIteration
if self.pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
Wenn dieses __next__
aufgerufen wird, werden Daten zurückgegeben.
Und diese Daten scheinen durch Übergabe des Index an den Datensatz erstellt zu werden.
Zu diesem Zeitpunkt müssen Sie nicht so nervös sein, wie der Index erstellt und wie der Datensatz aufgerufen wird. Da dies jedoch eine große Sache ist, gehen wir noch einen Schritt weiter.
class _MapDatasetFetcher(_BaseDatasetFetcher):
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
Der Index wird an den Datensatz übergeben. Das Aufrufen einer Instanz der Klasse auf diese Weise sollte bedeuten, dass "getitem" im Dataset aufgerufen wird. (Hier wird detailliert beschrieben. Gehen wir zu dem darauf basierenden Datensatz.
Sobald Sie zur Definition von MNIST gehen, können Sie sehen, dass es sich um Klasse handelt.
class MNIST(VisionDataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
"""
#Kürzung
Schauen wir uns __getitem__
an.
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], int(self.targets[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
Es ist leicht verständlich geschrieben, dass der Index übergeben und die entsprechenden Daten zurückgegeben werden. Ich sehe, so werden MNIST-Daten zurückgegeben. Wenn Sie sich den Prozess für einen Moment ansehen, wird auch Image.fromarray von PIL geschrieben. Mit anderen Worten, wenn Sie dieses getitem entwerfen und schreiben, können Sie beliebige Daten zurückgeben.
Aber es gibt noch etwas, das ich nicht verstehe. Wie wird der Index erstellt? Der Hinweis ist hier.
if sampler is None: # give default samplers
if self.dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
self.sampler = sampler
@property
def _index_sampler(self):
# The actual sampler used for generating indices for `_DatasetFetcher`
# (see _utils/fetch.py) to read data at each time. This would be
# `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
# We can't change `.sampler` and `.batch_sampler` attributes for BC
# reasons.
if self._auto_collation:
return self.batch_sampler
else:
return self.sampler
Der Index scheint durch Sampler erstellt zu werden. Standardmäßig wird der Sampler durch das Argument True, False (Shuffle) umgeschaltet. Schauen wir uns zum Beispiel die Implementierung an, wenn shuffle = False ist.
class SequentialSampler(Sampler):
r"""Samples elements sequentially, always in the same order.
Arguments:
data_source (Dataset): dataset to sample from
"""
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
Die Datenquelle hier ist ein Datensatz. Es sieht so aus, als wäre ich so weit gekommen und hätte eine allgemeine Vorstellung. Mit anderen Worten, es wird für die Länge des Datensatzes wiederholt. Umgekehrt scheint es notwendig zu sein, eine spezielle Methode namens "len" im Datensatz vorzubereiten.
Lassen Sie uns __len__
in datasets.MNIST überprüfen.
def __len__(self):
return len(self.data)
Sie geben die Datenlänge zurück. Da Daten in MNIST eine Größe von 60000 x 28 x 28 haben, wird 60000 zurückgegeben. Es war ziemlich erfrischend.
Der Artikel wird länger, das ist es also für den ersten Teil. In Teil 2 erstellen Sie Ihren eigenen Datensatz.
Recommended Posts