[PYTHON] Hinweis zum Standardverhalten von collate_fn in PyTorch

Was ist collate_fn?

Viele Benutzer verwenden DataLoader, wenn Datensätze mit PyTorch geladen werden. (Es gibt viele gute Artikel zur Verwendung von DataLoader. Beispielsweise ist dieser Artikel leicht zu verstehen.)

collate_fn ist eines der Argumente, die dem Konstruktor beim Erstellen einer DataLoader-Instanz gegeben werden, und hat die Aufgabe, die einzelnen aus dem Dataset abgerufenen Daten in einem Mini-Batch zu gruppieren. Insbesondere stammt collate_fn aus dem ** Datensatz, wie in der offiziellen Dokumentation (https://pytorch.org/docs/stable/data.html#dataloader-collate-fn) beschrieben. Geben Sie die Liste der abgerufenen Daten ein **. Dann wird der Rückgabewert von "collate_fn" von "DataLoader" ausgegeben.

Wenn Sie also Daten aus Ihrem eigenen Dataset mit "DataLoader" lesen, können Sie damit umgehen, indem Sie "collate_fn" erstellen, wie im folgenden Beispiel gezeigt.

def simple_collate_fn(list_of_data):
    #Hier nehmen wir an, dass jedes Daten ein D-dimensionaler Vektor ist.
    tensors = [torch.FloatTensor(data) for data in list_of_data]
    #Kombinieren Sie die neu hinzugefügten Dimensionen zu einem Mini-Batch zu einer N x D-Matrix.(N ist die Anzahl der Daten)
    batched_tensor = tensor.stack(tensors, dim=0)
    #Dieser Rückgabewert ist
    # for batched_tensor in dataloader:
    #Wird vom Data Loader ausgegeben.
    return batched_tensor

Standardverhalten von collate_fn

Um die Implementierung zu vereinfachen, möchte ich die Implementierung meiner eigenen "collate_fn" vermeiden, wenn das Standardverhalten ohne "collate_fn" verwendet werden kann.

Als ich es nachgeschlagen habe, ist collate_fn selbst standardmäßig ziemlich ausgefeilt, und es scheint, dass es nicht nur eine Kombination von Tensoren wie torch.stack (*, dim = 0) ist, also diesmal als Memorandum diese Standardeinstellung Ich möchte die Funktionen zusammenfassen.

Offizielle Dokumentation

Tatsächlich ist das Standardverhalten von "collate_fn" in der offiziellen Dokumentation (https://pytorch.org/docs/stable/data.html#dataloader-collate-fn) gut dokumentiert.

  • It always prepends a new dimension as the batch dimension.
  • It automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.
  • It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values can not be converted into Tensors). Same for list s, tuple s, namedtuple s, etc.

Mit anderen Worten, es scheint die folgenden Funktionen zu haben.

Ich war besonders überrascht, weil ich noch nie von der Existenz der dritten Funktion gehört hatte. (Es ist mir peinlich, ein einfaches collate_fn implementiert zu haben, das mehrere Datenvektoren stapelt ...)

Schauen Sie sich die Implementierung an

Da das detaillierte Verhalten jedoch nicht verstanden werden kann, ohne die Implementierung tatsächlich zu betrachten, Tatsächliche Implementierung Ich würde gerne einen Blick auf (/collate.py) werfen.

Ich denke, es ist am schnellsten, es tatsächlich zu lesen, aber ich werde es grob zusammenfassen, damit Sie die Implementierung nicht erneut lesen müssen, wenn Sie sie in Zukunft erneut überprüfen.

Informationen ab Version 1.5.

Fallklassifizierung nach Typ

Der Standardwert "collate_fn", "default_collate", ist ein rekursiver Prozess, und der Prozess wird nach dem Typ des ersten Elements des Arguments "batch" klassifiziert.

elem = batch[0]
elem_type = type(elem)

Im Folgenden werden wir die spezifische Verarbeitung nach dem Typ "elem" zusammenfassen.

torch.Tensor

Wenn batch`` torch.Tensor ist, wird einfach zuerst eine Dimension hinzugefügt und verbunden.

return torch.stack(batch, 0)

Art von numpy

Im Fall von "ndarray" von numpy wird es tensorisiert und dann wie im Fall von "torch.Tensor" kombiniert.

return default_collate([torch.as_tensor(b) for b in batch])

Andererseits ist im Fall eines numpy-Skalars der aktuelle "Stapel" ein Vektor, so dass er so wie er ist in einen Tensor umgewandelt wird.

return torch.as_tensor(batch)

float, int, str

Auch in diesem Fall ist "Batch" ein Vektor, daher wird er wie unten gezeigt als Tensor oder Liste zurückgegeben.

# float
return torch.tensor(batch, dtype=torch.float64)
# int
return torch.tensor(batch)
# str
return batch

Klassen, die von "collection.abc.Mapping" erben, wie z. B. "dict"

Wie unten gezeigt, wird jeder Schlüssel gestapelt und als ursprünglicher Schlüsselwert zurückgegeben.

return {key: default_collate([d[key] for d in batch]) for key in elem}

namedtuple

Auch in diesem Fall wird für jedes Attribut eine Stapelverarbeitung durchgeführt, wobei derselbe Attributname wie beim ursprünglichen "namedtuple" beibehalten wird.

return elem_type(*(default_collate(samples) for samples in zip(*batch)))

Klassen, die von "collection.abc.Sequence" erben, wie z. B. "list"

Die Stapelverarbeitung wird für jedes Element wie unten gezeigt durchgeführt.

transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]

Konkretes Beispiel

Versuchen Sie beispielsweise, ein Dataset mit einer komplexen Struktur zu lesen, die Wörterbücher und Zeichenfolgen mit dem Standardwert "collate_fn" enthält (siehe unten).

import numpy as np
from torch.utils.data import DataLoader

if __name__=="__main__":
    complex_dataset = [
        [0, "Bob", {"height": 172.5, "feature": np.array([1,2,3])}],
        [1, "Tom", {"height": 153.1, "feature": np.array([3,2,1])}]
    ]
    dataloader = DataLoader(complex_dataset, batch_size=2)
    for batch in dataloader:
        print(batch)

Anschließend können Sie wie folgt bestätigen, dass die Stapelverarbeitung erfolgreich durchgeführt wurde.

[
    tensor([0, 1]),
    ('Bob', 'Tom'),
    {
        'height': tensor([172.5000, 153.1000], dtype=torch.float64),
        'feature': tensor([[1, 2, 3],[3, 2, 1]])
    }
]

Übrigens wird Pythons float standardmäßig in torch.float64 konvertiert. Normalerweise drückt numpy.ndarray einen Vektor oder Tensor aus, daher denke ich, dass es kein Problem gibt, aber wenn Sie es nicht wissen, werden Sie in eine Falle tappen.

Recommended Posts

Hinweis zum Standardverhalten von collate_fn in PyTorch
[Hinweis] Import von Dateien in das übergeordnete Verzeichnis in Python
Finden Sie den Rang der Matrix in der XOR-Welt (Rang der Matrix auf F2)
[Python] Ein Hinweis, dass ich das Verhalten von matplotlib.pyplot zu verstehen begann
Holen Sie sich die Anzahl der Leser von Artikeln über Mendeley in Python
Überprüfen Sie das Verhalten des Zerstörers in Python
Holen Sie sich den Aufrufer einer Funktion in Python
Das Verhalten von signal () hängt von der Kompilierungsoption ab
Hinweise zum Anpassen der Diktatlistenklasse
Kopieren Sie die Liste in Python
Finden Sie die Anzahl der Tage in einem Monat
Schreiben Sie eine Notiz über die Python-Version von Python Virtualenv
Berechnen Sie die Wahrscheinlichkeit von Ausreißern auf den Box-Whiskern
[Hinweis] Über die Rolle des Unterstrichs "_" in Python
Ausgabe in Form eines Python-Arrays
Gibt es keinen Standardwert im Wörterbuch?
Ändern Sie in Python das Verhalten der Methode je nach Aufruf
Ein Briefing über die Wut, die durch das Schaben verursacht wurde
Notieren Sie sich die Liste der grundlegenden Verwendungszwecke von Pandas
Unterschied in den Ergebnissen abhängig vom Argument von multiprocess.Process
Schreiben Sie in Python ein logarithmisches Histogramm auf die x-Achse
Ein Hinweis zum Ausprobieren eines einfachen MCMC-Tutorials auf PyMC3
Eine Überlegung zur Visualisierung des Anwendungsbereichs des Vorhersagemodells
Ein Memorandum über die Umsetzung von Empfehlungen in Python
Python Hinweis: Das Rätsel, einer Variablen eine Variable zuzuweisen
[Beispiel für eine Python-Verbesserung] In 2 Wochen wurden die Grundlagen von Python auf einer kostenlosen Website erlernt
Ein Hinweis zur Bibliotheksimplementierung, in der Hyperparameter mithilfe der Bayes'schen Optimierung in Python untersucht werden
rsync Das Verhalten ändert sich abhängig vom Vorhandensein oder Fehlen des Schrägstrichs der Kopierquelle
Code, der bei AttributeError Standardwerte festlegt
Finden Sie die scheinbare Breite einer Zeichenfolge in Python heraus
Notieren Sie sich, was Sie in Zukunft mit Razpai machen möchten
Umfrage zum Einsatz von maschinellem Lernen in realen Diensten
Hinweise zum Einbetten der Skriptsprache in Bash-Skripte
Hinweis 2 zum Einbetten der Skriptsprache in ein Bash-Skript
Zählen Sie die Anzahl der Zeichen im Text in der Zwischenablage auf dem Mac
Holen Sie sich die Anzahl der spezifischen Elemente in der Python-Liste
Ein Hinweis beim Berühren der Gesichtserkennungs-API von Microsoft mit Python
Hinweise zum Laden einer virtuellen Umgebung mit PyCharm
Ein Hinweis zum Verhalten von bowtie2 bei mehreren Treffern
[Hinweis] Ein Shell-Skript, das die CPU-Auslastung eines bestimmten Prozesses in einer while-Schleife überprüft.
Ich habe ein wenig versucht, das Verhalten der Zip-Funktion
Finden Sie die Eigenwerte einer reellen symmetrischen Matrix in Python
Ausbeute in einer Klasse, die unittest geerbt hat. TestCase funktionierte nicht mit der Nase (abhängig von der Version der Nase?)
Überprüfung der Scherzausbreitung der "Notfallerklärung am 1. April"
Verarbeiten Sie den Inhalt der Datei der Reihe nach mit einem Shell-Skript
So bestimmen Sie die Existenz eines Selenelements in Python
Ein Hinweis zum Überprüfen der Verbindung zum Lizenzserver-Port
Erstellen Sie in kürzester Zeit eine Selenium-Umgebung unter Amazon Linux 2
Warum in der Substitutionsformel eine Scheibe auf die linke Seite setzen?
So überprüfen Sie die Speichergröße einer Variablen in Python
Unter Linux ist der Zeitstempel einer Datei etwas vorbei.
Wenn Sie eine Liste mit dem Standardargument der Funktion angeben ...
Lesen Sie die Standardausgabe eines Unterprozesses zeilenweise in Python
So überprüfen Sie die Speichergröße eines Wörterbuchs in Python