[PYTHON] Essayez d'utiliser collate_fn de Pytorch

Pytorch collate_fn est un argument de Dataloader.

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

Cette fois, je voudrais confirmer son comportement et son utilisation.

Qu'est-ce que collate_fn

Lorsque le \ _ \ _ getitem \ _ \ _ défini dans le jeu de données se présente sous la forme d'un lot, chaque élément (image, cible, etc.) est d'abord consolidé dans une liste. Collate_fn le manipule comme décrit dans Pytroch Official, le rendant finalement torche.Tensor C'est une fonction.

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

Par défaut, torch.stack () est utilisé pour créer Tensor, mais vous pouvez créer un lot avancé en utilisant votre propre collate_fn.

Créez votre propre collate_fn

Le comportement par défaut est presque le même que ci-dessous. (Bien que le nombre de retours dépende de \ _ \ _ getitem \ _ \ _) Il prend le lot comme argument, l'empile et le renvoie.

def collate_fn(batch):
    images, targets= list(zip(*batch))
    images = torch.stack(images)
    targets = torch.stack(targets)
    return images, targets

Vous pouvez modifier le contenu de votre propre collate_fn.

Cette fois, nous allons créer un lot de détection d'objets. La détection d'objet entre essentiellement le rectangle de l'objet et son étiquette, mais comme il peut y avoir plusieurs rectangles dans une image, il est nécessaire de connecter quelle image est quel rectangle lors du traitement par lots, et l'index Doit être joint.

[[label, xc, yx, w, h],
 [                   ],
 [                   ],...]

#Changer cela vers le bas

[[0, label xc, yx, w, h],
 [0,                   ],
 [1,                   ],...]

La mise en œuvre elle-même n'est pas si difficile.

def batch_idx_fn(batch):
    images, bboxes = list(zip(*batch))
    targets = []
    for idx, bbox in enumerate(bboxes):
        target = np.zeros((len(bbox), 6))
        target[:, 1:] = bbox
        target[:, 0] = idx
        targets.append(target)
    images = torch.stack(images)
    targets = torch.Tensor(np.concatenate(targets)) # [[batch_idx, label, xc, yx, w, h], ...]
    return images, targets

Lorsque vous l'utilisez réellement, ce sera comme suit.

test_data_loader = torch.utils.data.DataLoader(
                       test_dataset, 
                       batch_size=1, 
                       shuffle=False, 
                       collate_fn=batch_idx_fn
                       )
print(iter(test_data_loader).next()[0])
# [[0.0000, 0.0000, 0.6001, 0.5726, 0.1583, 0.1119],
# [0.0000, 9.0000, 0.0568, 0.5476, 0.1150, 0.1143],
# [1.0000, 5.0000, 0.8316, 0.4113, 0.1080, 0.3452],
# [1.0000, 0.0000, 0.3476, 0.6494, 0.1840, 0.1548],
# [2.0000, 2.0000, 0.8276, 0.6763, 0.1720, 0.3240],
# [2.0000, 4.0000, 0.1626, 0.0496, 0.0900, 0.0880],
# [2.0000, 5.0000, 0.2476, 0.2736, 0.1400, 0.5413],
# [2.0000, 5.0000, 0.5786, 0.4523, 0.4210, 0.5480],
# [3.0000, 0.0000, 0.4636, 0.4618, 0.0400, 0.1024],
# [3.0000, 0.0000, 0.5706, 0.5061, 0.0380, 0.0683]]

en conclusion

Autre que lors de l'indexation dans cet article Lorsque la cible change pour chaque lot, Lorsque la cible n'est pas des données numériques qui ne peuvent pas être empilées Je pense qu'il peut être utilisé lorsque vous souhaitez utiliser le même jeu de données avec des modifications légèrement différentes.

Recommended Posts

Essayez d'utiliser collate_fn de Pytorch
Essayez d'utiliser Tkinter
Essayez d'utiliser docker-py
Essayez d'utiliser Cookiecutter
Essayez d'utiliser des géopandas
Essayez d'utiliser Selenium
Essayez d'utiliser scipy
Essayez d'utiliser pandas.DataFrame
Essayez d'utiliser django-swiftbrowser
Essayez d'utiliser matplotlib
Essayez d'utiliser tf.metrics
Essayez d'utiliser virtualenv (virtualenvwrapper)
Essayez d'utiliser virtualenv maintenant
[Kaggle] Essayez d'utiliser LGBM
Essayez d'utiliser l'analyseur de flux de Python.
Essayez d'utiliser Tkinter de Python
Essayez d'utiliser Tweepy [Python2.7]
Essayez d'utiliser PythonTex avec Texpad.
[Python] Essayez d'utiliser le canevas de Tkinter
Essayez d'utiliser l'image Docker de Jupyter
Essayez d'utiliser scikit-learn (1) - Clustering K par méthode moyenne
Essayez d'utiliser matplotlib avec PyCharm
Essayez d'utiliser Azure Logic Apps
Essayez d'utiliser Kubernetes Client -Python-
Essayez d'utiliser l'API Twitter
Essayez d'utiliser OpenCV sur Windows
Essayez d'utiliser Jupyter Notebook de manière dynamique
Essayez d'utiliser AWS SageMaker Studio
Essayez de tweeter automatiquement en utilisant Selenium.
Essayez d'utiliser SQLAlchemy + MySQL (partie 1)
Essayez d'utiliser l'API Twitter
Essayez d'utiliser SQLAlchemy + MySQL (partie 2)
Essayez d'utiliser l'API PeeringDB 2.0
Essayez d'utiliser la fonction de brouillon de Pelican
Essayez d'utiliser pytest-Overview and Samples-
Essayez d'utiliser le folium avec anaconda
Essayez d'utiliser l'API Admin de la passerelle Janus
Essayez d'utiliser Spyder inclus dans Anaconda
Essayez d'utiliser des modèles de conception (édition exportateur)
Essayez d'utiliser Pillow sur iPython (partie 1)
Essayez d'utiliser Pillow sur iPython (partie 2)
Essayez d'utiliser l'API de Pleasant (python / FastAPI)
Essayez d'utiliser LevelDB avec Python (plyvel)
Essayez d'utiliser pynag pour configurer Nagios
Essayez d'utiliser la fonction de débogage à distance de PyCharm
Essayez d'utiliser ArUco avec Raspberry Pi
Essayez d'utiliser LiDAR bon marché (Camsense X1)
[Serveur de location Sakura] Essayez d'utiliser flask.
Essayez d'obtenir des statistiques en utilisant e-Stat
Essayez d'utiliser le module Python Cmd
Essayez d'utiliser le networkx de Python avec AtCoder
Essayez d'utiliser LeapMotion avec Python
Essayez d'utiliser la reconnaissance de caractères manuscrits (OCR) de GCP
Essayez d'utiliser Amazon DynamoDB à partir de Python
code-server Environnement local (3) Essayez d'utiliser le plug-in de VSCode
Essayez d'utiliser l'API Wunderlist en Python
Essayez une formule utilisant Σ avec python
Essayez d'utiliser le framework d'application Web Flask
Essayez d'utiliser Bash sur Windows 10 2 (installation de TensorFlow)
Essayez d'utiliser l'API Kraken avec Python
Essayez d'utiliser le LiDAR de 6 $ de réduction (Camsense X1)