[PYTHON] Comment éviter BrokenPipeError avec la note DataLoader de PyTorch

DataLoader de PyTorch dispose d'un mécanisme de chargement de données multi-processus. Quand j'ai essayé de l'utiliser sous Windows, cela ne fonctionnait pas avec la même erreur que here. J'ai enquêté sur diverses choses et je l'ai résolu, je vais donc noter la méthode.

Processus multiple DataLoader

Citation de Documents officiels

A DataLoader uses single-process data loading by default.

Within a Python process, the Global Interpreter Lock (GIL) prevents true fully parallelizing Python code across threads. To avoid blocking computation code with data loading, PyTorch provides an easy switch to perform multi-process data loading by simply setting the argument num_workers to a positive integer.

Et cela. En gros, si la valeur de la variable «num_workers» de la classe DataLoader est définie sur 1 ou plus, la lecture des données peut être parallélisée. BrokenPipeError Ainsi, lorsque je règle num_workers sur une valeur de 1 ou plus et que je la déplace,

BrokenPipeError: [Errno 32] Broken pipe

Cela n'a pas fonctionné avec une erreur. Même si vous définissez Dataset dans un autre fichier en vous référant à Erreur lorsque vous voulez charger Pytorch Dataset en parallèle avec DataLoader (Windows) Une erreur similaire s'est produite.

Solution

Si vous vous référez à here, il semble que lors de l'exécution de plusieurs processus sous Windows, ʻif name == "main" `doit exécuter une fonction qui exécute plusieurs processus.

Avant correction

train.py


from torch.utils.data import DataLoader
from dataloader import MyDataset #Ensemble de données créé

def train():
    dataset = MyDataset()
    train_loader = DataLoader(dataset, num_workers=2, shuffle=True,
                              batch_size=4,
                              pin_memory=True,
                              drop_last=True)

    for batch in train_loader:
        #do some process...

if __name__ == "__main__":
    train()

modifié

train.py


from torch.utils.data import DataLoader
from dataloader import MyDataset #Ensemble de données créé

def train(train_loader):
    for batch in train_loader:
        #do some process...

if __name__ == "__main__":
    #dataset,Déplacer DataLoader
    dataset = MyDataset()
    train_loader = DataLoader(dataset, num_workers=2, shuffle=True,
                              batch_size=4,
                              pin_memory=True,
                              drop_last=True)

    train(train_loader)

Dans le cas de DataLoader, si l'instance a été créée dans ʻif name == "main" `, le multi-processus fonctionnait même si la lecture des données elle-même était exécutée dans une autre fonction.

Résumé

J'ai écrit un mémo pour paralléliser DataLoader dans l'environnement Windows. En matière d'apprentissage en profondeur, il existe de nombreuses tâches qui ne fonctionnent pas sous Windows ou qui ne peuvent être effectuées sans une certaine ingéniosité. Par conséquent, j'aimerais écrire régulièrement des articles sur les erreurs qui se produisent autour de Windows.

Recommended Posts

Comment éviter BrokenPipeError avec la note DataLoader de PyTorch
Comment afficher des images en continu avec matplotlib Memo
Comment mettre à jour avec SQLAlchemy?
[Note] Comment utiliser virtualenv
Comment lancer avec Theano
Comment modifier avec SQLAlchemy?
Comment séparer les chaînes avec ','
Comment faire RDP sur Fedora31
Comment supprimer avec SQLAlchemy?
Python: comment utiliser async avec
Pour utiliser virtualenv avec PowerShell
Comment installer python-pip avec ubuntu20.04LTS
Comment gérer les données déséquilibrées
Comment démarrer avec Scrapy
Comment démarrer avec Python
Comment gérer l'erreur DistributionNotFound
Comment démarrer avec Django
Comment calculer la date avec python
Comment installer mysql-connector avec pip3
Comment INNER JOIN avec SQL Alchemy
Comment installer Anaconda avec pyenv
[Note] Comment changer DocumentRoot après la conversion SSL avec Let's Encrypt
Comment effectuer un traitement arithmétique avec le modèle Django
Comment titrer plusieurs figures avec matplotlib
Comment obtenir l'identifiant du parent avec sqlalchemy
Comment ajouter un package avec PyCharm
Comment utiliser OpenVPN avec Ubuntu 18.04.3 LTS
Comment utiliser Cmder avec PyCharm (Windows)
Comment empêcher les mises à jour de paquets avec apt
Comment utiliser BigQuery en Python
Comment utiliser Ass / Alembic avec HtoA
Comment gérer les erreurs de compatibilité d'énumération
Comment utiliser le japonais avec le tracé NLTK
Comment faire un test de sac avec python
Comment rechercher Google Drive dans Google Colaboratory
Comment afficher le japonais python avec lolipop
Comment télécharger des vidéos YouTube avec youtube-dl
Comment utiliser le notebook Jupyter avec ABCI
Essayer de gérer SQLite3 avec Python [Note]
Comment mettre hors tension de Linux sur Ultra96-V2
Comment utiliser la commande CUT (avec exemple)
Comment entrer le japonais avec les malédictions Python
Comment installer zsh (avec la personnalisation .zshrc)
Comment lire les données de problème avec Paiza
Comment utiliser SQLAlchemy / Connect avec aiomysql
Comment regrouper des volumes avec LVM
Comment installer python3 avec docker centos
Comment utiliser le pilote JDBC avec Redash
Liste de contrôle pour éviter de transformer les éléments de array of numpy avec for
Remarque: Comment obtenir le dernier jour du mois avec python (ajouté le premier jour du mois)
Comment supprimer sélectivement les anciens tweets avec Tweepy
Comment télécharger avec Heroku, Flask, Python, Git (4)
Comment gérer les fuites de mémoire dans matplotlib.pyplot