[PYTHON] A note on the default behavior of collate_fn in PyTorch

What is collate_fn?

Many people use DataLoader when loading datasets with PyTorch. (There are many good articles on how to use DataLoader. For example, this article is easy to understand.)

collate_fn is one of the arguments passed to the DataLoader constructor, and its role is to group the individual samples retrieved from the dataset into a mini-batch. More specifically, as described in the Official Documentation (https://pytorch.org/docs/stable/data.html#dataloader-collate-fn), collate_fn receives a list of the samples retrieved from the dataset, and its return value is what the DataLoader outputs.

Therefore, when you load data from your own dataset with DataLoader, you can handle it by writing your own collate_fn, as in the example below.

import torch

def simple_collate_fn(list_of_data):
    # Here we assume that each piece of data is a D-dimensional vector.
    tensors = [torch.FloatTensor(data) for data in list_of_data]
    # Stack along a new leading dimension into an N x D mini-batch (N is the number of samples).
    batched_tensor = torch.stack(tensors, dim=0)
    # This return value is what the DataLoader yields, e.g.
    #     for batched_tensor in dataloader:
    #         ...
    return batched_tensor
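As a minimal usage sketch, a collate function like the one above is passed to DataLoader through its collate_fn argument. The toy dataset vectors is hypothetical; a plain Python list works as a dataset here because it supports indexing and len().

```python
import torch
from torch.utils.data import DataLoader


def simple_collate_fn(list_of_data):
    # Convert each D-dimensional vector to a float tensor...
    tensors = [torch.FloatTensor(data) for data in list_of_data]
    # ...and stack them along a new leading dimension into an N x D mini-batch.
    return torch.stack(tensors, dim=0)


# Hypothetical toy dataset: four 3-dimensional vectors stored as plain lists.
vectors = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]
loader = DataLoader(vectors, batch_size=2, collate_fn=simple_collate_fn)

batches = list(loader)
for batched_tensor in batches:
    print(batched_tensor.shape)  # torch.Size([2, 3])
```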

Default behavior of collate_fn

To keep the implementation simple, I would rather not write my own collate_fn whenever the default behavior (i.e., not passing collate_fn at all) is good enough.

When I looked into it, I found that the default collate_fn is quite sophisticated: it does more than just join tensors with torch.stack(*, dim=0). So, as a memo to myself, I would like to summarize its default behavior here.

Official documentation

In fact, the default behavior of collate_fn is well documented in the Official Documentation (https://pytorch.org/docs/stable/data.html#dataloader-collate-fn):

  • It always prepends a new dimension as the batch dimension.
  • It automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.
  • It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values can not be converted into Tensors). Same for lists, tuples, namedtuples, etc.

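In other words, the default collate_fn (1) adds a new batch dimension, (2) converts NumPy arrays and Python numbers into tensors, and (3) keeps the structure of each sample (dict, list, tuple, namedtuple, etc.) while batching its contents.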

The third feature surprised me in particular, since I had never heard of it. (I'm embarrassed to admit I had been writing simple collate_fn functions that batch each of several data vectors by hand...)
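To illustrate the third feature, here is a sketch assuming a hypothetical dataset in which each sample is a (feature, target) pair of float32 NumPy arrays. Without any collate_fn, the DataLoader already returns one batched tensor per position in the pair, which is what a hand-written collate_fn would otherwise have to do.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader

# Hypothetical dataset: each sample is a pair (feature vector, target vector).
pairs = [
    (np.array([1.0, 2.0, 3.0], dtype=np.float32), np.array([0.0, 1.0], dtype=np.float32)),
    (np.array([4.0, 5.0, 6.0], dtype=np.float32), np.array([1.0, 0.0], dtype=np.float32)),
    (np.array([7.0, 8.0, 9.0], dtype=np.float32), np.array([1.0, 1.0], dtype=np.float32)),
]

# No collate_fn is given, so the default one is used.
loader = DataLoader(pairs, batch_size=3)

# The tuple structure is kept (as a list), and each position is batched separately.
features, targets = next(iter(loader))
print(features.shape, targets.shape)  # torch.Size([3, 3]) torch.Size([3, 2])
```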

Take a look at the implementation

However, the detailed behavior can't be understood without looking at the code, so let's take a look at the actual implementation (https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/data/_utils/collate.py).

Reading the code itself is probably the quickest way to understand it, but I'll give a rough summary here so that I won't have to reread the implementation the next time I need to check something.

The information below is as of version 1.5.

Case classification by type

The default collate_fn, default_collate, is a recursive function that branches on the type of the first element of its argument batch:

elem = batch[0]
elem_type = type(elem)
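Because only the type of the first element is inspected, the same values can be batched differently depending on their order. The sketch below imports default_collate from torch.utils.data (available in newer releases) and otherwise falls back to the private torch.utils.data._utils.collate module linked above; it shows the effect with a mix of int and float values.

```python
import torch

try:
    from torch.utils.data import default_collate  # exposed publicly in newer releases
except ImportError:
    from torch.utils.data._utils.collate import default_collate  # e.g. PyTorch 1.5

# First element is an int: the int branch runs torch.tensor(batch), which
# infers PyTorch's default float dtype because 2.5 is a float.
int_first = default_collate([1, 2.5])

# First element is a float: the float branch forces torch.float64.
float_first = default_collate([2.5, 1])

print(int_first.dtype)    # torch.float32 (assuming the default dtype is unchanged)
print(float_first.dtype)  # torch.float64
```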

Below, I summarize the specific processing for each type of elem.

torch.Tensor

If elem is a torch.Tensor, the tensors are simply stacked along a new first dimension.

return torch.stack(batch, 0)
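A quick check of this branch, with arbitrary sample shapes; the import falls back to the private module used in 1.5 when torch.utils.data.default_collate is not available. Note that torch.stack requires all tensors to have the same shape, so samples of different sizes raise an error here, which is one case where a custom collate_fn is still needed.

```python
import torch

try:
    from torch.utils.data import default_collate
except ImportError:
    from torch.utils.data._utils.collate import default_collate

# Two samples, each a 2 x 3 tensor.
samples = [torch.zeros(2, 3), torch.ones(2, 3)]

# The tensor branch stacks them along a new leading batch dimension.
batch = default_collate(samples)
print(batch.shape)  # torch.Size([2, 2, 3])
```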

NumPy types

For a NumPy ndarray, each array is converted to a tensor and the results are then stacked, just as in the torch.Tensor case.

return default_collate([torch.as_tensor(b) for b in batch])

For a NumPy scalar, on the other hand, batch is already a flat list of values (i.e., a vector), so it is converted to a tensor directly.

return torch.as_tensor(batch)
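Both NumPy branches can be checked with a short sketch (the sample values are arbitrary, and the import falls back to the private 1.5 module path if needed). torch.as_tensor keeps the array's dtype, so float32 arrays stay float32.

```python
import numpy as np
import torch

try:
    from torch.utils.data import default_collate
except ImportError:
    from torch.utils.data._utils.collate import default_collate

# ndarray branch: each array goes through torch.as_tensor, then the tensors are stacked.
arrays = [np.array([1.0, 2.0], dtype=np.float32), np.array([3.0, 4.0], dtype=np.float32)]
stacked = default_collate(arrays)
print(stacked.shape, stacked.dtype)  # torch.Size([2, 2]) torch.float32

# NumPy-scalar branch: the whole batch is passed to torch.as_tensor at once.
scalars = [np.int64(5), np.int64(7)]
vec = default_collate(scalars)
print(vec.tolist())  # [5, 7]
```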

float, int, str

Here too, batch is a flat list of values, so it is converted to a tensor (float and int) or returned unchanged (str), as shown below.

# float
return torch.tensor(batch, dtype=torch.float64)
# int
return torch.tensor(batch)
# str
return batch
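A short sketch of these three branches (arbitrary values; import fallback to the private 1.5 module path as needed). Python floats become float64, Python ints become int64, and strings come back as the same list that was passed in.

```python
import torch

try:
    from torch.utils.data import default_collate
except ImportError:
    from torch.utils.data._utils.collate import default_collate

floats = default_collate([1.5, 2.5])     # float branch: dtype forced to float64
ints = default_collate([1, 2])           # int branch: plain torch.tensor(batch)
names = default_collate(["Bob", "Tom"])  # str branch: batch returned unchanged

print(floats.dtype, ints.dtype)  # torch.float64 torch.int64
print(names)                     # ['Bob', 'Tom']
```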

Classes that inherit from collections.abc.Mapping such as dict

The values under each key are batched recursively and returned under the same key, as shown below.

return {key: default_collate([d[key] for d in batch]) for key in elem}
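For example, with hypothetical dict samples holding an int and a float (import fallback to the private 1.5 module path as needed), each key's values go through the int and float branches respectively:

```python
import torch

try:
    from torch.utils.data import default_collate
except ImportError:
    from torch.utils.data._utils.collate import default_collate

# Hypothetical samples: each one is a dict with the same keys.
samples = [{"id": 0, "score": 0.5}, {"id": 1, "score": 0.75}]
batch = default_collate(samples)

# Same keys as the samples; each value is batched recursively.
print(batch["id"])     # tensor([0, 1])
print(batch["score"])  # tensor([0.5000, 0.7500], dtype=torch.float64)
```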

namedtuple

Here too, each attribute is batched recursively, and the result is a namedtuple of the same type with the same attribute names.

return elem_type(*(default_collate(samples) for samples in zip(*batch)))
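A sketch with a hypothetical namedtuple type Sample (import fallback to the private 1.5 module path as needed). The result is again a Sample, so attribute access still works on the batch.

```python
from collections import namedtuple

import torch

try:
    from torch.utils.data import default_collate
except ImportError:
    from torch.utils.data._utils.collate import default_collate

# Hypothetical sample type with a tensor feature and an int label.
Sample = namedtuple("Sample", ["feature", "label"])

samples = [Sample(torch.tensor([1.0, 2.0]), 0), Sample(torch.tensor([3.0, 4.0]), 1)]
batch = default_collate(samples)

print(type(batch).__name__)  # Sample
print(batch.feature.shape)   # torch.Size([2, 2])
print(batch.label)           # tensor([0, 1])
```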

Classes that inherit from collections.abc.Sequence such as list

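The batch is transposed so that the i-th elements of all samples are grouped together, and each group is batched recursively, as shown below. Note that the result is a list, even when each sample is a tuple.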

transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
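A sketch with hypothetical tuple samples (import fallback to the private 1.5 module path as needed). After transposing, the ints go through the int branch, while the strings arrive as a tuple produced by zip and are returned unchanged.

```python
import torch

try:
    from torch.utils.data import default_collate
except ImportError:
    from torch.utils.data._utils.collate import default_collate

# Hypothetical samples: each one is an (id, name) tuple.
samples = [(0, "Bob"), (1, "Tom")]
batch = default_collate(samples)

print(type(batch).__name__)  # list
print(batch[0])              # tensor([0, 1])
print(batch[1])              # ('Bob', 'Tom')
```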

Concrete example

For example, let's load the following dataset, whose samples have a complex structure containing dictionaries and strings, using the default collate_fn.

import numpy as np
from torch.utils.data import DataLoader

if __name__=="__main__":
    complex_dataset = [
        [0, "Bob", {"height": 172.5, "feature": np.array([1,2,3])}],
        [1, "Tom", {"height": 153.1, "feature": np.array([3,2,1])}]
    ]
    dataloader = DataLoader(complex_dataset, batch_size=2)
    for batch in dataloader:
        print(batch)

The output confirms that it is batched correctly (line breaks added for readability):

[
    tensor([0, 1]),
    ('Bob', 'Tom'),
    {
        'height': tensor([172.5000, 153.1000], dtype=torch.float64),
        'feature': tensor([[1, 2, 3],[3, 2, 1]])
    }
]
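To see how the branches summarized above fit together, here is a stripped-down, hypothetical mini_collate that mirrors the v1.5 cases and reproduces the result above. It is a simplified sketch, not the library's code: it omits details such as shared-memory handling in worker processes and the error raised for NumPy string/object arrays, and it uses np.generic as a rough stand-in for the real NumPy-scalar check.

```python
import collections.abc

import numpy as np
import torch


def mini_collate(batch):
    """Hypothetical, simplified sketch of the default_collate branches (v1.5)."""
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        return torch.stack(batch, 0)
    if isinstance(elem, np.ndarray):
        return mini_collate([torch.as_tensor(b) for b in batch])
    if isinstance(elem, str):  # checked before NumPy scalars so np.str_ stays a string
        return batch
    if isinstance(elem, np.generic):  # simplification of the NumPy-scalar check
        return torch.as_tensor(batch)
    if isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    if isinstance(elem, int):
        return torch.tensor(batch)
    if isinstance(elem, collections.abc.Mapping):
        return {key: mini_collate([d[key] for d in batch]) for key in elem}
    if isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(mini_collate(samples) for samples in zip(*batch)))
    if isinstance(elem, collections.abc.Sequence):
        return [mini_collate(samples) for samples in zip(*batch)]
    raise TypeError(f"unsupported element type: {elem_type}")


complex_dataset = [
    [0, "Bob", {"height": 172.5, "feature": np.array([1, 2, 3])}],
    [1, "Tom", {"height": 153.1, "feature": np.array([3, 2, 1])}],
]
ids, names, info = mini_collate(complex_dataset)
print(ids)                    # tensor([0, 1])
print(names)                  # ('Bob', 'Tom')
print(info["height"].dtype)   # torch.float64
print(info["feature"].shape)  # torch.Size([2, 3])
```

The order of the checks matters: str must come before the Sequence check (a string is a Sequence), and NumPy scalars must come before float, because np.float64 is a subclass of Python float.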

Note that Python floats are converted to torch.float64 by default. Vectors and tensors are usually held in numpy.ndarray, whose dtype is carried over as-is, so this rarely causes problems; but if you are not aware of it, it is an easy trap to fall into.
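One way this trap shows up, sketched with hypothetical samples and a throwaway torch.nn.Linear layer: the float64 batch does not match the layer's float32 parameters, and casting with .float() avoids the error. Alternatively, storing the values in float32 NumPy arrays avoids the conversion in the first place, since the ndarray branch keeps the array's dtype.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical samples holding a Python float, as in the example above.
samples = [{"height": 172.5}, {"height": 153.1}]
batch = next(iter(DataLoader(samples, batch_size=2)))
heights = batch["height"].unsqueeze(1)  # shape (2, 1), dtype torch.float64

layer = torch.nn.Linear(1, 1)  # parameters are torch.float32 by default
try:
    layer(heights)
    mismatch_raised = False
except RuntimeError:
    # float64 inputs against float32 weights fail with a dtype mismatch.
    mismatch_raised = True

out = layer(heights.float())  # casting the batch to float32 avoids the error
print(mismatch_raised, out.dtype)  # True torch.float32
```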
