[PYTHON] Remarque sur le comportement par défaut de collate_fn dans PyTorch

Qu'est-ce que collate_fn?

De nombreuses personnes utilisent DataLoader lors du chargement des ensembles de données avec PyTorch. (Il existe de nombreux bons articles sur l'utilisation de DataLoader. Par exemple, cet article est facile à comprendre.)

collate_fn est l'un des arguments donnés au constructeur lors de la création d'une instance DataLoader, et a pour rôle de regrouper les données individuelles extraites de l'ensemble de données dans un mini-lot. Plus précisément, collate_fn provient de l'ensemble de données **, comme décrit dans la documentation officielle (https://pytorch.org/docs/stable/data.html#dataloader-collate-fn). Entrez la liste des données récupérées **. Ensuite, la valeur de retour de collate_fn sera sortie de DataLoader.

Par conséquent, lorsque vous lisez des données de votre propre ensemble de données avec DataLoader, vous pouvez les gérer en créant votre propre collate_fn comme indiqué dans l'exemple ci-dessous.

def simple_collate_fn(list_of_data):
    #Ici, on suppose que chaque donnée est un vecteur D-dimensionnel.
    tensors = [torch.FloatTensor(data) for data in list_of_data]
    #Combinez les dimensions nouvellement ajoutées dans un mini-lot dans une matrice N x D.(N est le nombre de données)
    batched_tensor = tensor.stack(tensors, dim=0)
    #Cette valeur de retour est
    # for batched_tensor in dataloader:
    #Est sorti du chargeur de données.
    return batched_tensor

Comportement par défaut de collate_fn

Afin de simplifier l'implémentation, je voudrais éviter d'implémenter mon propre collate_fn si le comportement par défaut sans donner collate_fn peut être utilisé.

Quand je l'ai recherché, collate_fn est assez sophistiqué même par défaut, et il semble que ce ne soit pas seulement une combinaison de tenseurs comme torch.stack (*, dim = 0), donc cette fois comme mémorandum ce défaut Je voudrais résumer les fonctions.

Documentation officielle

En fait, le comportement par défaut de collate_fn est bien documenté dans la documentation officielle (https://pytorch.org/docs/stable/data.html#dataloader-collate-fn).

  • It always prepends a new dimension as the batch dimension.
  • It automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.
  • It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values can not be converted into Tensors). Same for list s, tuple s, namedtuple s, etc.

En d'autres termes, il semble avoir les fonctions suivantes.

J'ai été particulièrement surpris car je n'avais jamais entendu parler de l'existence de la troisième fonction. (Je suis gêné d'avoir implémenté un simple collate_fn qui regroupe respectivement plusieurs vecteurs de données ...)

Jetez un œil à la mise en œuvre

Cependant, comme le comportement détaillé ne peut être compris sans examiner réellement l'implémentation, Implémentation réelle Je voudrais jeter un œil à (/collate.py).

Je pense que c'est le plus rapide pour le lire, mais je vais le résumer grossièrement pour que vous n'ayez pas à relire l'implémentation lorsque vous la vérifierez à nouveau à l'avenir.

Informations à partir de la version 1.5.

Classification des cas par type

La valeur par défaut collate_fn, default_collate, est un processus récursif, et le processus est classé en fonction du type du premier élément de l'argument batch.

elem = batch[0]
elem_type = type(elem)

Ci-dessous, nous résumerons le traitement spécifique par type d''elem '.

torch.Tensor

Si batch est torch.Tensor, il ajoute simplement une dimension en premier et se joint.

return torch.stack(batch, 0)

Type de «numpy»

Dans le cas de "ndarray" de numpy, il est tensorisé puis combiné comme dans le cas de "torch.Tensor".

return default_collate([torch.as_tensor(b) for b in batch])

Par contre, dans le cas du scalaire numpy, le "batch" courant est un vecteur, il est donc converti en un tenseur tel quel.

return torch.as_tensor(batch)

float, int, str

Dans ce cas également, «batch» est un vecteur, il est donc renvoyé sous forme de tensorisé ou de liste comme indiqué ci-dessous.

# float
return torch.tensor(batch, dtype=torch.float64)
# int
return torch.tensor(batch)
# str
return batch

Classes qui héritent de collections.abc.Mapping telles que dict

Comme indiqué ci-dessous, chaque clé est groupée et renvoyée en tant que valeur de clé d'origine.

return {key: default_collate([d[key] for d in batch]) for key in elem}

namedtuple

Dans ce cas également, le traitement par lots est effectué pour chaque attribut tout en conservant le même nom d'attribut que le «namedtuple» d'origine.

return elem_type(*(default_collate(samples) for samples in zip(*batch)))

Classes qui héritent de collections.abc.Sequence telles que list

Le traitement par lots est effectué pour chaque élément comme indiqué ci-dessous.

transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]

Exemple concret

Par exemple, essayez de lire un ensemble de données avec une structure compliquée qui inclut des dictionnaires et des chaînes de caractères avec la valeur par défaut collate_fn comme indiqué ci-dessous.

import numpy as np
from torch.utils.data import DataLoader

if __name__=="__main__":
    complex_dataset = [
        [0, "Bob", {"height": 172.5, "feature": np.array([1,2,3])}],
        [1, "Tom", {"height": 153.1, "feature": np.array([3,2,1])}]
    ]
    dataloader = DataLoader(complex_dataset, batch_size=2)
    for batch in dataloader:
        print(batch)

Ensuite, vous pouvez confirmer qu'il est correctement groupé comme suit.

[
    tensor([0, 1]),
    ('Bob', 'Tom'),
    {
        'height': tensor([172.5000, 153.1000], dtype=torch.float64),
        'feature': tensor([[1, 2, 3],[3, 2, 1]])
    }
]

Au fait, le «float» de python est converti en «torch.float64» par défaut. Normalement, numpy.ndarray exprime un vecteur ou un tenseur, donc je pense qu'il n'y a pas de problème, mais si vous ne le savez pas, vous tomberez dans un piège.

Recommended Posts

Remarque sur le comportement par défaut de collate_fn dans PyTorch
[Note] Importation de fichiers dans le répertoire parent en Python
Trouvez le rang de la matrice dans le monde XOR (rang de la matrice sur F2)
[python] Une note que j'ai commencé à comprendre le comportement de matplotlib.pyplot
Obtenez le nombre de lecteurs d'articles sur Mendeley en Python
Vérifiez le comportement du destroyer en Python
Récupérer l'appelant d'une fonction en Python
Le comportement de signal () dépend de l'option de compilation
Remarques sur la personnalisation de la classe de liste de dict
Copiez la liste en Python
Trouvez le nombre de jours dans un mois
Écrire une note sur la version python de python virtualenv
Calculer la probabilité de valeurs aberrantes sur les moustaches de la boîte
[Note] À propos du rôle du trait de soulignement "_" en Python
Sortie sous la forme d'un tableau python
N'y a-t-il pas une valeur par défaut dans le dictionnaire?
En Python, changez le comportement de la méthode en fonction de la façon dont elle est appelée
Un briefing sur la colère provoquée par le grattage
Prenez note de la liste des utilisations de base de Pandas
Différence de résultats en fonction de l'argument du multiprocessus.
Ecrire un histogramme à l'échelle logarithmique sur l'axe des x en python
Une note d'essayer un simple tutoriel MCMC sur PyMC3
Une réflexion sur la visualisation du champ d'application du modèle de prédiction
Un mémorandum sur la mise en œuvre des recommandations en Python
Note Python: Le mystère de l'attribution d'une variable à une variable
[Exemple d'amélioration de Python] Apprentissage des bases de Python sur un site gratuit en 2 semaines
Une note sur l'implémentation de la bibliothèque qui explore les hyperparamètres à l'aide de l'optimisation bayésienne en Python
rsync Le comportement change en fonction de la présence ou de l'absence de la barre oblique de la source de copie
Code qui définit les valeurs par défaut en cas d'AttributeError
Découvrez la largeur apparente d'une chaîne en python
Notez ce que vous voulez faire à l'avenir avec Razpai
Enquête sur l'utilisation du machine learning dans les services réels
Remarques sur l'intégration du langage de script dans les scripts bash
Remarque 2 pour intégrer le langage de script dans un script bash
Comptez le nombre de caractères dans le texte dans le presse-papiers sur Mac
Obtenez le nombre d'éléments spécifiques dans la liste python
Une note quand j'ai touché l'API de reconnaissance faciale de Microsoft avec Python
Remarques sur la façon de charger un environnement virtuel avec PyCharm
Une note sur le comportement de bowtie2 lors de plusieurs coups
[Note] Un script shell qui vérifie l'utilisation du processeur d'un processus spécifique dans une boucle while.
J'ai essayé un peu le comportement de la fonction zip
Trouver les valeurs propres d'une vraie matrice symétrique en Python
Rendement dans la classe qui a hérité de l'unittest.TestCase ne fonctionnait pas avec le nez (selon la version du nez?)
Vérification de la propagation du canular de "Déclaration d'urgence le 1er avril"
Traitez le contenu du fichier dans l'ordre avec un script shell
Comment déterminer l'existence d'un élément sélénium en Python
Remarque sur la façon de vérifier la connexion au port du serveur de licences
Créez un environnement Selenium sur Amazon Linux 2 dans les plus brefs délais
Pourquoi mettre une tranche sur le côté gauche dans la formule de substitution
Comment vérifier la taille de la mémoire d'une variable en Python
Sous Linux, l'horodatage d'un fichier est un peu dépassé.
Si vous donnez une liste avec l'argument par défaut de la fonction ...
Lire la sortie standard d'un sous-processus ligne par ligne en Python
Comment vérifier la taille de la mémoire d'un dictionnaire en Python