[Python] How to avoid BrokenPipeError with PyTorch's DataLoader (note)

PyTorch's DataLoader supports multi-process data loading. When I tried to use it on Windows, it failed with a BrokenPipeError, an error that others have also reported. I looked into the cause and found a fix, so I am recording the method here.

Multi-process data loading with DataLoader

Quoted from the official documentation:

A DataLoader uses single-process data loading by default.

Within a Python process, the Global Interpreter Lock (GIL) prevents true fully parallelizing Python code across threads. To avoid blocking computation code with data loading, PyTorch provides an easy switch to perform multi-process data loading by simply setting the argument num_workers to a positive integer.

In short, setting the DataLoader argument num_workers to 1 or more makes data loading run in parallel worker processes.

BrokenPipeError

However, when I set num_workers to 1 or more and ran my script, I got the following error:

BrokenPipeError: [Errno 32] Broken pipe

The script failed with this error. Following the article "Error when you want to load Pytorch Dataset in parallel with DataLoader (Windows)", I moved the Dataset definition into a separate file, but a similar error still occurred.

Solution

According to the reference I found, code that uses multiple processes on Windows has to be launched from inside an `if __name__ == "__main__":` block.
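As background on why Windows behaves differently: Python's multiprocessing module starts child processes with the spawn method on Windows, and a spawned child imports the main script again, so unguarded top-level code that starts workers would run a second time in each child. The following is a minimal sketch, using only the standard library, for checking which start method the current platform uses. The helper name needs_main_guard is hypothetical and is not part of PyTorch.

```python
import multiprocessing as mp
import sys


def needs_main_guard() -> bool:
    """Return True when child processes re-import the main script.

    This is the case for the 'spawn' and 'forkserver' start methods;
    with 'fork' the child is a copy of the parent instead.
    """
    # get_start_method() reports the method multiprocessing uses by default.
    return mp.get_start_method() in ("spawn", "forkserver")


if __name__ == "__main__":
    print(f"platform: {sys.platform}")
    print(f"start method: {mp.get_start_method()}")
    print(f"main guard required: {needs_main_guard()}")
```

On Windows this should report spawn, which is why the guard matters there even when the same script runs unchanged on a platform that forks.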

Before correction

train.py


from torch.utils.data import DataLoader
from dataloader import MyDataset  # custom Dataset defined in dataloader.py

def train():
    # The Dataset and DataLoader are created inside the function.
    dataset = MyDataset()
    train_loader = DataLoader(dataset, num_workers=2, shuffle=True,
                              batch_size=4,
                              pin_memory=True,
                              drop_last=True)

    for batch in train_loader:
        pass  # do some process...

if __name__ == "__main__":
    train()

Revised

train.py


from torch.utils.data import DataLoader
from dataloader import MyDataset  # custom Dataset defined in dataloader.py

def train(train_loader):
    for batch in train_loader:
        pass  # do some process...

if __name__ == "__main__":
    # The Dataset and DataLoader creation has been moved here.
    dataset = MyDataset()
    train_loader = DataLoader(dataset, num_workers=2, shuffle=True,
                              batch_size=4,
                              pin_memory=True,
                              drop_last=True)

    train(train_loader)

In the case of DataLoader, as long as the instance was created inside the `if __name__ == "__main__":` block, multi-process loading worked even when the loop that actually reads the data ran in another function.
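A related check that may help if the error persists even with the guard in place: spawned worker processes receive the dataset through pickling, so an object that cannot be pickled cannot reach them. The sketch below uses plain pickle as a rough proxy for that transfer. The function name can_be_pickled is hypothetical, and multiprocessing uses its own pickler with extra rules, so treat a passing result as a hint rather than a guarantee.

```python
import pickle


def can_be_pickled(obj) -> bool:
    """Return True if obj survives pickle.dumps.

    This is only an approximation of whether the object can be
    handed to a spawned worker process.
    """
    try:
        pickle.dumps(obj)
    except Exception:
        # Lambdas, locally defined classes, open file handles and
        # similar objects typically end up here.
        return False
    return True


if __name__ == "__main__":
    print(can_be_pickled([1, 2, 3]))         # plain data
    print(can_be_pickled(lambda x: x * 2))   # lambda
```

Running the same check on a dataset instance, or on a custom collate function, before passing it to DataLoader is one way to narrow down where a worker start-up failure comes from.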

Summary

This was a memo on parallelizing DataLoader in a Windows environment. In deep learning, there are many tasks that do not work on Windows or that need some extra ingenuity. I would like to keep writing articles about errors that occur on Windows.
