[PYTHON] I tried batch normalization with PyTorch (+ note)

Introduction

I didn't understand batch normalization well, so I tried it with PyTorch. As a result, I understood that it normalizes the input so that each column (feature) has a mean of 0 and a variance of 1 across the batch. I also noticed a few caveats while running it, so I've noted them below.

Trying it

First, the imports.

import torch
import torch.nn.functional as F
from torch import nn

Set the input size and generate some random input values.

input_samples = 100
input_features = 10
x = torch.rand((input_samples, input_features)) * 10  # uniform values in [0, 10)

The columns are not wildly different, but each one ends up with a slightly different mean and variance.

Mean

torch.mean(x, 0)
tensor([5.0644, 5.0873, 5.0446, 5.3872, 5.2406, 5.3518, 5.3203, 4.9909, 5.0590,
        5.2169])

Variance

torch.var(x, 0)
tensor([ 9.4876,  8.6519,  8.4050,  9.8280, 10.0146,  8.6054,  7.0800,  8.6111,
         7.7851,  8.5604])

Now let's apply batch normalization.

batch_norm = nn.BatchNorm1d(input_features)  # one mean/variance per column (feature)
y = batch_norm(x)
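As a rough check of what `nn.BatchNorm1d` does in training mode, the sketch below recomputes the output by hand: subtract each column's mean and divide by the square root of its variance plus the module's small `eps`. It assumes a freshly created module, so `weight` is all ones and `bias` all zeros; the fixed seed is my addition so the run is reproducible.

```python
import torch
from torch import nn

torch.manual_seed(0)  # fixed seed so the comparison is reproducible

input_samples = 100
input_features = 10
x = torch.rand((input_samples, input_features)) * 10

batch_norm = nn.BatchNorm1d(input_features)  # a new module starts in training mode
y = batch_norm(x)

# Manual version: per-column mean and *biased* variance over the batch (dim 0)
mean = x.mean(dim=0)
var_biased = x.var(dim=0, unbiased=False)
y_manual = (x - mean) / torch.sqrt(var_biased + batch_norm.eps)

# Default affine parameters are weight=1, bias=0, so this does not change anything yet
y_manual = y_manual * batch_norm.weight + batch_norm.bias

print(torch.allclose(y, y_manual, atol=1e-5))
```

The learnable `weight` and `bias` are what later let the network undo or rescale the normalization during training.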

Mean

torch.mean(y, 0)
tensor([ 1.9073e-08,  5.2452e-08, -4.7684e-09,  3.8743e-08, -3.8147e-08,
         4.1723e-08, -7.8678e-08, -5.9605e-08,  5.7220e-08,  4.2915e-08],
       grad_fn=<MeanBackward1>)

⇒ Almost 0.

Variance

torch.var(y, 0)
tensor([1.0101, 1.0101, 1.0101, 1.0101, 1.0101, 1.0101, 1.0101, 1.0101, 1.0101,
        1.0101], grad_fn=<VarBackward1>)

⇒ Almost 1.

Caveats

Here are some things I noticed while changing the input in various ways.

With only one sample, an error is raised

With a single sample, the following error is output:

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 10])

With a single sample, the mean is the sample itself and the variance is 0, so there is nothing meaningful to compute.
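A small sketch that reproduces the error. It also shows, as my own addition, that the check only applies in training mode: after `.eval()`, `BatchNorm1d` normalizes with its stored `running_mean` and `running_var`, so a batch of one goes through.

```python
import torch
from torch import nn

batch_norm = nn.BatchNorm1d(10)
single = torch.rand((1, 10))

# Training mode (the default): batch statistics of one sample are meaningless, so PyTorch refuses
try:
    batch_norm(single)
    raised = False
except ValueError as e:
    raised = True
    print(e)

# Evaluation mode uses the running statistics instead of the batch statistics
batch_norm.eval()
out = batch_norm(single)
print(out.shape)
```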

If all values in a column are identical, the variance is naturally 0

Making the mean 0 turns every value in that column into 0, so of course the variance is 0 rather than 1.

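With only a few samples, the variance does not come out as 1

With 3 samples it comes out as 1.5, with 10 samples as 1.111, and so on, approaching 1 as the batch size grows. I haven't dug into it in detail, but it comes from the formula: according to the `BatchNorm1d` documentation, the variance used for normalization is the biased estimator (dividing by N), while `torch.var` divides by N - 1 by default. The measured variance is therefore about N/(N - 1): 3/2 = 1.5, 10/9 ≈ 1.111, and 100/99 ≈ 1.0101, which matches the 100-sample result above.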
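A quick numerical check: compare `torch.var` with its default setting (dividing by N - 1) against the biased variance (`unbiased=False`, dividing by N) for a few batch sizes. The seed and the batch sizes are arbitrary choices of mine.

```python
import torch
from torch import nn

torch.manual_seed(0)

results = {}
for n in (3, 10, 100):
    x = torch.rand((n, 10)) * 10
    y = nn.BatchNorm1d(10)(x)
    unbiased_var = y.var(dim=0).mean().item()                # torch.var default: divide by n - 1
    biased_var = y.var(dim=0, unbiased=False).mean().item()  # divide by n, as BatchNorm does
    results[n] = (unbiased_var, biased_var)
    print(f"n={n}: torch.var={unbiased_var:.4f}  n/(n-1)={n / (n - 1):.4f}  biased={biased_var:.4f}")
```

The biased variance stays close to 1 for every batch size; only the unbiased one shows the N/(N - 1) factor (values can be very slightly below it because of `eps`).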

Other

If batch normalization is applied right after the input, can it replace normalizing the input data? Searching, I found the following Q&A:

https://www.366service.com/jp/qa/9a05f9f614c8ca449ef8693928b7921c

According to it, it is simpler and more efficient to compute the mean and variance over the entire dataset once. That's certainly true, but it's a hassle!
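For comparison, a minimal sketch of the "compute once over the entire dataset" approach. The `data` tensor is a hypothetical stand-in for a full training set and `normalize` is a hypothetical helper, not something from the Q&A.

```python
import torch

torch.manual_seed(0)
data = torch.rand((1000, 10)) * 10  # hypothetical stand-in for the whole dataset

# Compute the per-column statistics once over the entire dataset...
mean = data.mean(dim=0)
std = data.std(dim=0)

# ...and reuse them for every mini-batch
def normalize(batch: torch.Tensor) -> torch.Tensor:
    return (batch - mean) / std

normalized = normalize(data)
print(normalized.mean(dim=0))  # close to 0
print(normalized.std(dim=0))   # close to 1
```

Unlike batch normalization, these statistics do not change from batch to batch, so even a single sample can be normalized.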
