Viele Benutzer verwenden DataLoader, wenn Datensätze mit PyTorch geladen werden. (Es gibt viele gute Artikel zur Verwendung von DataLoader. Beispielsweise ist dieser Artikel leicht zu verstehen.)
collate_fn
ist eines der Argumente, die dem Konstruktor beim Erstellen einer DataLoader
-Instanz gegeben werden, und hat die Aufgabe, die einzelnen aus dem Dataset abgerufenen Daten in einem Mini-Batch zu gruppieren.
Insbesondere stammt collate_fn
aus dem ** Datensatz, wie in der offiziellen Dokumentation (https://pytorch.org/docs/stable/data.html#dataloader-collate-fn) beschrieben. Geben Sie die Liste der abgerufenen Daten ein **. Dann wird der Rückgabewert von "collate_fn" von "DataLoader" ausgegeben.
Wenn Sie also Daten aus Ihrem eigenen Dataset mit "DataLoader" lesen, können Sie damit umgehen, indem Sie "collate_fn" erstellen, wie im folgenden Beispiel gezeigt.
def simple_collate_fn(list_of_data):
#Hier nehmen wir an, dass jedes Daten ein D-dimensionaler Vektor ist.
tensors = [torch.FloatTensor(data) for data in list_of_data]
#Kombinieren Sie die neu hinzugefügten Dimensionen zu einem Mini-Batch zu einer N x D-Matrix.(N ist die Anzahl der Daten)
batched_tensor = tensor.stack(tensors, dim=0)
#Dieser Rückgabewert ist
# for batched_tensor in dataloader:
#Wird vom Data Loader ausgegeben.
return batched_tensor
Um die Implementierung zu vereinfachen, möchte ich die Implementierung meiner eigenen "collate_fn" vermeiden, wenn das Standardverhalten ohne "collate_fn" verwendet werden kann.
Als ich es nachgeschlagen habe, ist collate_fn
selbst standardmäßig ziemlich ausgefeilt, und es scheint, dass es nicht nur eine Kombination von Tensoren wie torch.stack (*, dim = 0)
ist, also diesmal als Memorandum diese Standardeinstellung Ich möchte die Funktionen zusammenfassen.
Tatsächlich ist das Standardverhalten von "collate_fn" in der offiziellen Dokumentation (https://pytorch.org/docs/stable/data.html#dataloader-collate-fn) gut dokumentiert.
- It always prepends a new dimension as the batch dimension.
- It automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.
- It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values can not be converted into Tensors). Same for list s, tuple s, namedtuple s, etc.
Mit anderen Worten, es scheint die folgenden Funktionen zu haben.
Ich war besonders überrascht, weil ich noch nie von der Existenz der dritten Funktion gehört hatte. (Es ist mir peinlich, ein einfaches collate_fn
implementiert zu haben, das mehrere Datenvektoren stapelt ...)
Da das detaillierte Verhalten jedoch nicht verstanden werden kann, ohne die Implementierung tatsächlich zu betrachten, Tatsächliche Implementierung Ich würde gerne einen Blick auf (/collate.py) werfen.
Ich denke, es ist am schnellsten, es tatsächlich zu lesen, aber ich werde es grob zusammenfassen, damit Sie die Implementierung nicht erneut lesen müssen, wenn Sie sie in Zukunft erneut überprüfen.
Informationen ab Version 1.5.
Der Standardwert "collate_fn", "default_collate", ist ein rekursiver Prozess, und der Prozess wird nach dem Typ des ersten Elements des Arguments "batch" klassifiziert.
elem = batch[0]
elem_type = type(elem)
Im Folgenden werden wir die spezifische Verarbeitung nach dem Typ "elem" zusammenfassen.
torch.Tensor
Wenn batch`` torch.Tensor
ist, wird einfach zuerst eine Dimension hinzugefügt und verbunden.
return torch.stack(batch, 0)
numpy
Im Fall von "ndarray" von numpy wird es tensorisiert und dann wie im Fall von "torch.Tensor" kombiniert.
return default_collate([torch.as_tensor(b) for b in batch])
Andererseits ist im Fall eines numpy-Skalars der aktuelle "Stapel" ein Vektor, so dass er so wie er ist in einen Tensor umgewandelt wird.
return torch.as_tensor(batch)
float
, int
, str
Auch in diesem Fall ist "Batch" ein Vektor, daher wird er wie unten gezeigt als Tensor oder Liste zurückgegeben.
# float
return torch.tensor(batch, dtype=torch.float64)
# int
return torch.tensor(batch)
# str
return batch
Wie unten gezeigt, wird jeder Schlüssel gestapelt und als ursprünglicher Schlüsselwert zurückgegeben.
return {key: default_collate([d[key] for d in batch]) for key in elem}
namedtuple
Auch in diesem Fall wird für jedes Attribut eine Stapelverarbeitung durchgeführt, wobei derselbe Attributname wie beim ursprünglichen "namedtuple" beibehalten wird.
return elem_type(*(default_collate(samples) for samples in zip(*batch)))
Die Stapelverarbeitung wird für jedes Element wie unten gezeigt durchgeführt.
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
Versuchen Sie beispielsweise, ein Dataset mit einer komplexen Struktur zu lesen, die Wörterbücher und Zeichenfolgen mit dem Standardwert "collate_fn" enthält (siehe unten).
import numpy as np
from torch.utils.data import DataLoader
if __name__=="__main__":
complex_dataset = [
[0, "Bob", {"height": 172.5, "feature": np.array([1,2,3])}],
[1, "Tom", {"height": 153.1, "feature": np.array([3,2,1])}]
]
dataloader = DataLoader(complex_dataset, batch_size=2)
for batch in dataloader:
print(batch)
Anschließend können Sie wie folgt bestätigen, dass die Stapelverarbeitung erfolgreich durchgeführt wurde.
[
tensor([0, 1]),
('Bob', 'Tom'),
{
'height': tensor([172.5000, 153.1000], dtype=torch.float64),
'feature': tensor([[1, 2, 3],[3, 2, 1]])
}
]
Übrigens wird Pythons float
standardmäßig in torch.float64
konvertiert. Normalerweise drückt numpy.ndarray
einen Vektor oder Tensor aus, daher denke ich, dass es kein Problem gibt, aber wenn Sie es nicht wissen, werden Sie in eine Falle tappen.
Recommended Posts