[PyTorch] A little understanding of CrossEntropyLoss with mathematical formulas

Introduction

We often use `criterion = torch.nn.CrossEntropyLoss()` as the loss function in PyTorch, so this post works through exactly what it computes. If I have made a mistake, please let me know.

CrossEntropyLoss: checking the PyTorch sample (1)

```python
import torch
import torch.nn as nn

torch.manual_seed(42)  # Fix the seed for reproducibility
loss = nn.CrossEntropyLoss()
input_num = torch.randn(1, 5, requires_grad=True)  # logits: 1 sample, 5 classes
target = torch.empty(1, dtype=torch.long).random_(5)  # random class index in [0, 5)
print('input_num:', input_num)
print('target:', target)
output = loss(input_num, target)
print('output:', output)
```

```
input_num: tensor([[ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229]], requires_grad=True)
target: tensor([0])
output: tensor(1.3472, grad_fn=<NllLossBackward>)
```
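The PyTorch documentation describes `CrossEntropyLoss` as equivalent to applying `LogSoftmax` to the input followed by `NLLLoss`, which is consistent with the `grad_fn=<NllLossBackward>` in the output. The sketch below checks this on the same seeded input using `torch.nn.functional.log_softmax` and `torch.nn.functional.nll_loss`; the variable names `ce` and `two_step` are my own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)  # same seed as the sample above
input_num = torch.randn(1, 5, requires_grad=True)
target = torch.empty(1, dtype=torch.long).random_(5)

# Built-in loss
ce = nn.CrossEntropyLoss()(input_num, target)

# Same value in two steps: log-softmax over the class dimension, then negative log-likelihood
two_step = F.nll_loss(F.log_softmax(input_num, dim=1), target)

print(ce.item(), two_step.item())
print(torch.allclose(ce, two_step))  # True
```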

Let $x$ be the input (the raw, unnormalized score for each class), $class$ the index of the correct class, and $n$ the number of classes, indexed $0, \dots, n-1$. The loss $loss$ that CrossEntropyLoss computes for one sample can then be written as follows.

$$
\begin{aligned}
loss &= -\log\left(\frac{\exp(x[class])}{\sum_{j=0}^{n-1} \exp(x[j])}\right) \\
&= -\left(\log(\exp(x[class])) - \log\left(\sum_{j=0}^{n-1} \exp(x[j])\right)\right) \\
&= -\log(\exp(x[class])) + \log\left(\sum_{j=0}^{n-1} \exp(x[j])\right) \\
&= -x[class] + \log\left(\sum_{j=0}^{n-1} \exp(x[j])\right)
\end{aligned}
$$
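To confirm that the first and last lines of the derivation agree, here is a small pure-Python sketch (no PyTorch) that evaluates both forms for one sample. The function names `ce_softmax_form` and `ce_simplified_form` are hypothetical, chosen only for this example.

```python
import math

def ce_softmax_form(x, cls):
    """First line of the derivation: -log(exp(x[cls]) / sum_j exp(x[j]))."""
    denom = sum(math.exp(v) for v in x)
    return -math.log(math.exp(x[cls]) / denom)

def ce_simplified_form(x, cls):
    """Last line of the derivation: -x[cls] + log(sum_j exp(x[j]))."""
    return -x[cls] + math.log(sum(math.exp(v) for v in x))

x = [0.3367, 0.1288, 0.2345, 0.2303, -1.1229]  # logits printed by the sample above
print(ce_softmax_form(x, 0))
print(ce_simplified_form(x, 0))
```

Both calls should print about 1.3472, the same value PyTorch reported.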

In the sample above, the correct class is $class = 0$ and the number of classes is $n = 5$, so substituting these values gives

$$
\begin{aligned}
loss &= -x[0] + \log\left(\sum_{j=0}^{4} \exp(x[j])\right) \\
&= -x[0] + \log(\exp(x[0]) + \exp(x[1]) + \exp(x[2]) + \exp(x[3]) + \exp(x[4])) \\
&= -0.3367 + \log(\exp(0.3367) + \exp(0.1288) + \exp(0.2345) + \exp(0.2303) + \exp(-1.1229)) \\
&= 1.34717\cdots \\
&\approx 1.3472
\end{aligned}
$$
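Evaluating $\log\sum_j \exp(x[j])$ literally overflows for large logits (in Python, `math.exp(1000.0)` raises `OverflowError`). A common remedy is the log-sum-exp trick: with $m = \max_j x[j]$, use $\log\sum_j \exp(x[j]) = m + \log\sum_j \exp(x[j] - m)$. The sketch below is my own illustration of that identity, not PyTorch's source code.

```python
import math

def logsumexp(x):
    """log(sum(exp(v) for v in x)), shifted by the max so no exp() can overflow."""
    m = max(x)
    return m + math.log(sum(math.exp(v - m) for v in x))

def cross_entropy(x, cls):
    """-x[cls] + log(sum_j exp(x[j])), using the stable log-sum-exp."""
    return -x[cls] + logsumexp(x)

x = [0.3367, 0.1288, 0.2345, 0.2303, -1.1229]
print(cross_entropy(x, 0))  # about 1.3472, same as before

big = [1000.0, 999.0, 998.0]  # math.exp(1000.0) on its own would overflow
print(cross_entropy(big, 0))  # about 0.4076
```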

This matches the program's output! I did the arithmetic with the following code (doing it by hand is not realistic...).

```python
from math import exp, log
x_sum = exp(0.3367) + exp(0.1288) + exp(0.2345) + exp(0.2303) + exp(-1.1229)
x = 0.3367
ans = -x + log(x_sum)
print(ans)  # 1.3471717976017477
```
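The sample uses a batch of one sample. For a larger batch, `CrossEntropyLoss` with its default `reduction='mean'` and no class weights averages the per-sample losses. A pure-Python sketch of that averaging; the second row of logits and its target are made up for illustration.

```python
import math

def cross_entropy(x, cls):
    """Per-sample loss: -x[cls] + log(sum_j exp(x[j]))."""
    return -x[cls] + math.log(sum(math.exp(v) for v in x))

def cross_entropy_mean(batch, targets):
    """Mean of the per-sample losses, mirroring reduction='mean' without class weights."""
    losses = [cross_entropy(x, t) for x, t in zip(batch, targets)]
    return sum(losses) / len(losses)

batch = [
    [0.3367, 0.1288, 0.2345, 0.2303, -1.1229],  # the sample from above
    [2.0, 0.0, 0.0, 0.0, 0.0],                  # hypothetical second sample
]
targets = [0, 0]
print(cross_entropy_mean(batch, targets))
```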

It is brute force, but it works.

In closing

Rounding error (the printed tensor values are shown to only four decimal places) does not need much attention here. Also, in plain Python you would normally write `random.seed(42)`, but in PyTorch the equivalent is `torch.manual_seed(42)`, which is worth keeping in mind.
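A quick illustration of what seeding buys you, using only Python's standard `random` module (`torch.manual_seed` plays the same role for PyTorch's generators such as `torch.randn`): re-seeding with the same value reproduces the same sequence.

```python
import random

random.seed(42)
first = [random.random() for _ in range(3)]

random.seed(42)  # reset to the same seed
second = [random.random() for _ in range(3)]

print(first == second)  # True: same seed, same sequence
```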

References

(1) TORCH.NN (PyTorch documentation)
