[PYTHON] PyTorch DataLoader is slow

In PyTorch, DataLoader (torch.utils.data.DataLoader) is commonly used to retrieve mini-batches from a dataset, but when experimenting with large data, I found that PyTorch's DataLoader is very time-consuming. For comparison, I wrote my own iterator that retrieves mini-batches from the dataset and ran the same experiment, and it turned out that PyTorch's DataLoader was considerably slower. This can become a bottleneck, especially when working with large data.

[Addition: 2020/03/23] I received a comment pointing out that the cause of the slowdown is the BatchSampler that DataLoader uses by default. See the comments for details.
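Building on that comment, a commonly suggested workaround is to hand a BatchSampler to DataLoader as its sampler and disable automatic batching, so that the dataset is indexed once per batch (with a whole list of indices) instead of once per sample. A minimal sketch, assuming PyTorch 1.2+ where batch_size=None is supported; adapt to your setup:

import torch

label  = torch.randn(1000000, 10)
target = torch.randn(1000000, 10)
dataset = torch.utils.data.TensorDataset(label, target)

# Yield whole batches of indices at once, so TensorDataset.__getitem__ is
# called once per batch instead of once per sample.
batch_sampler = torch.utils.data.BatchSampler(
    torch.utils.data.RandomSampler(dataset),   # SequentialSampler for no shuffle
    batch_size=10000,
    drop_last=False,
)

# batch_size=None disables automatic batching and the default collate step.
loader = torch.utils.data.DataLoader(dataset, sampler=batch_sampler, batch_size=None)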

Setup

In the following, mini-batches of size 10,000 are repeatedly extracted from two tensors, label and target, each holding 1 million rows. The experiments were run on Google Colaboratory.

import torch

label  = torch.randn(1000000,10)
target = torch.randn(1000000,10)
batch_size = 10000

Create a loader that retrieves the mini-batches, and measure the execution time with the following function, which simply iterates through all of them.

def run_loader(loader):
    for label,target in loader:
        pass

PyTorch DataLoader

Creating a loader with torch.utils.data.DataLoader (without shuffling) and measuring the execution time gave 6.8 seconds. That is a long time just to fetch data.

dataset = torch.utils.data.TensorDataset(label,target)
loader1 = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=False)

%timeit -n1 -r1 run_loader(loader1)

 1 loop, best of 1: 6.83 s per loop
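A plausible reason for the cost: with automatic batching, DataLoader draws one index at a time from its sampler, calls dataset[i] separately for each of the 10,000 samples, and then collates the resulting single-sample tuples into batch tensors. Schematically, each iteration does something like the hypothetical fetch_one_batch below (a simplified sketch, not PyTorch's actual implementation):

# Simplified sketch of one DataLoader iteration with automatic batching.
# Illustration only; the real logic lives inside torch.utils.data.
def fetch_one_batch(dataset, indices):
    samples = [dataset[i] for i in indices]          # one __getitem__ call per sample
    labels  = torch.stack([s[0] for s in samples])   # collate: stack 10,000 tiny tensors
    targets = torch.stack([s[1] for s in samples])
    return labels, targets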

With shuffling enabled, it took 7.0 seconds.

dataset = torch.utils.data.TensorDataset(label,target)
loader2 = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)

%timeit -n1 -r1 run_loader(loader2)

 1 loop, best of 1: 6.97 s per loop

Homebrew DataLoader

For comparison, I created an iterator that extracts mini-batches from the dataset and ran the same experiment.

class DataLoader:

    def __init__(self, dataset, batch_size=1, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        # every tensor in the dataset must have the same number of rows
        assert all(dataset[i].size(0) == dataset[0].size(0) for i in range(len(dataset))), \
            'all the elements must have the same length'
        self.data_size = dataset[0].size(0)

    def __iter__(self):
        self._i = 0

        # shuffle by permuting every tensor once per epoch
        if self.shuffle:
            index_shuffle = torch.randperm(self.data_size)
            self.dataset = [v[index_shuffle] for v in self.dataset]

        return self

    def __next__(self):

        i1 = self.batch_size * self._i
        i2 = min(self.batch_size * (self._i + 1), self.data_size)

        if i1 >= self.data_size:
            raise StopIteration()

        # basic slicing returns views, so no data is copied here
        value = [v[i1:i2] for v in self.dataset]

        self._i += 1

        return value

With my own DataLoader (without shuffling), the execution time is about 500 microseconds; retrieving the mini-batches takes almost no time.

loader3 = DataLoader([label,target],batch_size=batch_size,shuffle=False)

%timeit -n1 -r1 run_loader(loader3)

 1 loop, best of 1: 468 µs per loop
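The speed is easy to explain: without shuffling, each mini-batch is produced by basic slicing (v[i1:i2]), which returns a view of the original tensor rather than a copy. A quick check:

x = torch.randn(1000000, 10)
batch = x[0:10000]                        # basic slicing returns a view
print(batch.data_ptr() == x.data_ptr())   # True: same underlying storage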

With shuffling, the execution time rises to about 300 ms, longer than without, but still negligible compared to PyTorch's DataLoader.

loader4 = DataLoader([label,target],batch_size=batch_size,shuffle=True)

%timeit -n1 -r1 run_loader(loader4)

 1 loop, best of 1: 296 ms per loop
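The extra cost of shuffling is consistent with the torch.randperm and advanced indexing in __iter__: indexing with an index tensor materializes a full copy of each tensor once per epoch, after which the per-batch slices are views again. For example:

x = torch.randn(1000000, 10)
idx = torch.randperm(x.size(0))
shuffled = x[idx]                            # advanced indexing makes a copy
print(shuffled.data_ptr() == x.data_ptr())   # False: new storage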

Summary

Retrieving mini-batches with PyTorch's DataLoader turns out to take a surprisingly long time, and the overhead is especially noticeable when dealing with large data.
