[PYTHON] PyTorch DataLoader ist langsam

In PyTorch wird DataLoader (torch.utils.data.DataLoader) häufig zum Abrufen von Mini-Batch aus einem Dataset verwendet. Beim Experimentieren mit großen Datenmengen kann jedoch der DataLoader von PyTorch verwendet werden. Es stellte sich als sehr zeitaufwändig heraus. Zum Vergleich habe ich meinen eigenen Kursivator erstellt, der Mini-Batch aus dem Datensatz abruft und ausprobiert, aber festgestellt, dass der DataLoader von Pytorch erheblich langsamer ist. Dies kann ein Engpass sein, insbesondere bei Verwendung großer Datenmengen.

[Ergänzung: 2020/03/23] Ich habe einen Kommentar erhalten, dass die Ursache für die Langsamkeit BatchSampler ist, der standardmäßig in DataLoader verwendet wird. Siehe Kommentare für Details.

Aufbau

Im Folgenden wird davon ausgegangen, dass eine Mini-Charge mit einer Chargengröße von 10.000 wiederholt mit 1 Million Daten aus "Label" und "Target" extrahiert wird. Als Berechnungsumgebung wurde Google Colaboratory verwendet.

import torch

label  = torch.randn(1000000,10)
target = torch.randn(1000000,10)
batch_size = 10000

Erstellen Sie einen "Loader", um den Mini-Batch abzurufen, und messen Sie die Ausführungszeit mit der folgenden Funktion, die das Abrufen des Mini-Batch einfach wiederholt.

def run_loader(loader):
    for label,target in loader:
        pass

Pytorch Data Loader

Als ich mit torch.utils.data.DataLoader (ohne Shuffle) einen Loader erstellte und die Ausführungszeit maß, waren es 6,8 Sekunden. Es scheint, als würde es lange dauern, nur um die Daten abzurufen.

dataset = torch.utils.data.TensorDataset(label,target)
loader1 = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=False)

%timeit -n1 -r1 run_loader(loader1)

 1 loop, best of 1: 6.83 s per loop

Wenn das Mischen durchgeführt wurde, dauerte es 7,0 Sekunden.

dataset = torch.utils.data.TensorDataset(label,target)
loader2 = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)

%timeit -n1 -r1 run_loader(loader2)

 1 loop, best of 1: 6.97 s per loop

Homebrew Data Loader

Zum Vergleich habe ich einen Kursivator erstellt, der einen Mini-Batch aus dem Datensatz extrahiert, und das gleiche Experiment durchgeführt.

class DataLoader:

    def __init__(self,dataset,batch_size=1,shuffle=False):
        self.dataset = dataset 
        self.batch_size = batch_size
        self.shuffle = shuffle
        assert all([ dataset[i].size(0) == dataset[0].size(0) for i in range(len(dataset)) ]), 'all the elemtnes must have the same length'
        self.data_size = dataset[0].size(0)

    def __iter__(self):
        self._i = 0
        
        if self.shuffle:
            index_shuffle = torch.randperm(self.data_size)
            self.dataset = [ v[index_shuffle] for v in self.dataset ]

        return self

    def __next__(self):

        i1 = self.batch_size * self._i
        i2 = min( self.batch_size * ( self._i + 1 ), self.data_size )
        
        if i1 >= self.data_size:
            raise StopIteration()

        value = [ v[i1:i2] for v in self.dataset ]

        self._i += 1

        return value

Wenn Sie Ihren eigenen DataLoader (ohne Shuffle) verwenden, können Sie feststellen, dass die Ausführungszeit 500 Mikrosekunden beträgt und das Abrufen fast keine Zeit in Anspruch nimmt.

loader3 = DataLoader([label,target],batch_size=batch_size,shuffle=False)

%timeit -n1 -r1 run_loader(loader3)

 1 loop, best of 1: 468 µs per loop

Die Ausführungszeit für das Mischen beträgt 300 Millisekunden, was länger ist als ohne, aber im Vergleich zur Verwendung des Datenladers von Pytorch immer noch vernachlässigbar.

loader4 = DataLoader([label,target],batch_size=batch_size,shuffle=True)

%timeit -n1 -r1 run_loader(loader4)

 1 loop, best of 1: 296 ms per loop

Zusammenfassung

Es stellt sich heraus, dass das Abrufen eines Mini-Batches mit dem DataLoader von PyTorch viel Zeit in Anspruch nimmt. Dieser Effekt ist sehr groß, insbesondere bei großen Datenmengen.

Recommended Posts

PyTorch DataLoader ist langsam
[PyTorch Tutorial ①] Was ist PyTorch?
pandas idxmax ist langsam
Pypy Bool Typ ist langsam
[Pytorch] Memo über Dataset / DataLoader
Das PyTorch-Modul sagt, dass libcusparse.so.10 fehlt
[PyTorch Tutorial ⑥] Was ist torch.nn wirklich?
Ist Pythons Operator langsam? (Von ABC167D)