This article is about PyTorch's DataLoader and Dataset, which I realized I had been using without really understanding them. I hope it serves as a reference if you want to do something a little more elaborate.
The second part is here.
If you use PyTorch, you have probably seen DataLoader. The PyTorch MNIST example, which everyone uses when starting machine learning, contains code like this.
import torch
from torchvision import datasets, transforms

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('~/dataset/MNIST',
                   train=True,
                   download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=256,
    shuffle=True)
Alternatively, if you search on Qiita and similar sites, you will often see it written this way.
train_dataset = datasets.MNIST(
    '~/dataset/MNIST',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]))

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True)
After that, you loop over the loader with a for statement, and each iteration fetches one batch of data for training.
for epoch in range(epochs):
    for img, label in train_loader:
        # write the training step here
Loosely speaking, a DataLoader is a convenient helper that fetches the data described by a Dataset according to some rule. In the example above, each iteration puts a mini-batch of 256 normalized MNIST images into img and their labels into label. Let's look at the internals to see how that is achieved.
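Before diving in, here is a quick way to check what one iteration actually yields (a small sketch, assuming the train_loader defined above):

imgs, labels = next(iter(train_loader))
print(imgs.shape)    # torch.Size([256, 1, 28, 28]) -- one normalized mini-batch
print(labels.shape)  # torch.Size([256])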
Let's look at the DataLoader implementation. You can see right away that it is a class.
class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.
    """
    # (abridged)
I will skip the details, but if you trace how it behaves as an iterator, you will find the following implementation.
def __next__(self):
    index = self._next_index()  # may raise StopIteration
    data = self.dataset_fetcher.fetch(index)  # may raise StopIteration
    if self.pin_memory:
        data = _utils.pin_memory.pin_memory(data)
    return data
When this __next__ is called, data is returned. And that data appears to be created by passing an index to the dataset.
At this stage you don't need to worry too much about how the index is created or how the dataset is called, but since we have come this far, let's go one step further.
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
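As a side note, collate_fn is what turns the list of individual samples into a batch; by default it stacks them into tensors. A small illustration (a sketch using torch.utils.data.default_collate, the public name of the default collate function in recent PyTorch versions):

import torch
from torch.utils.data import default_collate

samples = [(torch.randn(1, 28, 28), 3), (torch.randn(1, 28, 28), 7)]  # two (img, label) pairs
imgs, labels = default_collate(samples)
print(imgs.shape)  # torch.Size([2, 1, 28, 28])
print(labels)      # tensor([3, 7])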
The index is passed to the dataset. Indexing an instance of a class like this (self.dataset[idx]) means that the dataset's __getitem__ is being called. Let's head over to the dataset with that in mind.
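For anyone unfamiliar with that part of Python: square-bracket indexing on an object is just a call to its __getitem__ method. A minimal illustration (plain Python, not PyTorch code):

class Squares:
    def __getitem__(self, idx):
        return idx * idx

s = Squares()
print(s[3])              # 9
print(s.__getitem__(3))  # 9 -- s[3] is exactly this call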
Jumping to the definition of MNIST, you can see that it too is a class.
class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
    """
    # (abridged)
Let's look at its __getitem__.
def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    img, target = self.data[index], int(self.targets[index])

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img.numpy(), mode='L')

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target
It is written quite plainly: an index comes in, and the corresponding data is returned. So this is how MNIST data gets returned.
Looking a little more closely, PIL's Image.fromarray appears here as well. In other words, if you write this __getitem__ yourself, you can return whatever data you like; a small sketch follows.
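Here is a minimal sketch of a map-style dataset with its own __getitem__ (a hypothetical toy example, not from torchvision; __len__ is included because, as we will see below, the default sampler needs it):

import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Toy dataset that returns (x, x**2) pairs."""
    def __init__(self, n):
        self.xs = torch.arange(n, dtype=torch.float32)

    def __getitem__(self, index):
        x = self.xs[index]
        return x, x ** 2  # any per-sample processing could go here

    def __len__(self):
        return len(self.xs)

loader = torch.utils.data.DataLoader(SquaresDataset(8), batch_size=4)
for x, y in loader:
    print(x, y)  # batches of 4 inputs and their squares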
But there is still something I don't understand. How is the index created? The hint is here.
if sampler is None:  # give default samplers
    if self.dataset_kind == _DatasetKind.Iterable:
        # See NOTE [ Custom Samplers and IterableDataset ]
        sampler = _InfiniteConstantSampler()
    else:  # map-style
        if shuffle:
            sampler = RandomSampler(dataset)
        else:
            sampler = SequentialSampler(dataset)

self.sampler = sampler
@property
def _index_sampler(self):
    # The actual sampler used for generating indices for `_DatasetFetcher`
    # (see _utils/fetch.py) to read data at each time. This would be
    # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
    # We can't change `.sampler` and `.batch_sampler` attributes for BC
    # reasons.
    if self._auto_collation:
        return self.batch_sampler
    else:
        return self.sampler
So the index is created through a sampler. By default, which sampler is used is switched by the shuffle argument (True or False). For example, let's look at the implementation used when shuffle=False.
class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)
The data_source here is the dataset. At this point the overall picture is clear: the sampler simply iterates over indices for the length of the dataset. Conversely, this means the dataset needs to provide the special method __len__.
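To make this concrete, here is a small check of what the default samplers yield (a sketch; anything with a __len__ works as the data_source, and BatchSampler is what _index_sampler returns in auto-collation mode):

from torch.utils.data import SequentialSampler, RandomSampler, BatchSampler

data_source = list(range(10))  # stands in for a dataset; only its length is used here
print(list(SequentialSampler(data_source)))  # [0, 1, 2, ..., 9]
print(list(RandomSampler(data_source)))      # a random permutation, e.g. [3, 7, 0, ...]

# In auto-collation mode the sampler is wrapped in a BatchSampler,
# which groups the indices into mini-batches:
batches = BatchSampler(SequentialSampler(data_source), batch_size=4, drop_last=False)
print(list(batches))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]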
Let's check __len__ in datasets.MNIST.
def __len__(self):
    return len(self.data)
It just returns the length of data. Since data in MNIST has shape 60000x28x28, 60000 is returned. Things are a lot clearer now.
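A quick way to confirm this, assuming the train_dataset defined earlier:

print(train_dataset.data.shape)  # torch.Size([60000, 28, 28])
print(len(train_dataset))        # 60000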
The article is getting long, so that's it for Part 1. In Part 2, we will build our own dataset.