[PYTHON] [PyTorch] Ein wenig Verständnis von CrossEntropyLoss mit mathematischen Formeln

Einführung

Weil `Kriterium = torch.nn.CrossEntropyLoss ()` häufig als Grundlage für die Verlustfunktion von Pytorch verwendet wird. Es wird ausgegeben, um die Details zu verstehen. Wenn Sie einen Fehler machen, lassen Sie es mich bitte wissen.

CrossEntropyLoss Pytorch-Beispiel (1)

torch.manual_seed(42) #Fixes Saatgut zur Aufrechterhaltung der Reproduzierbarkeit
loss = nn.CrossEntropyLoss()
input_num = torch.randn(1, 5, requires_grad=True)
target = torch.empty(1, dtype=torch.long).random_(5)
print('input_num:',input_num)
print('target:',target)
output = loss(input_num, target)
print('output:',output)
input_num: tensor([[ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229]], requires_grad=True)
target: tensor([0])
output: tensor(1.3472, grad_fn=<NllLossBackward>)

Unter der Annahme, dass die richtige Antwortklasse $ class $ und die Anzahl der Klassen $ n $ ist, kann der Fehler $ loss $ von CrossEntropyLoss durch die folgende Formel ausgedrückt werden.

loss=-\log(\frac{\exp(x[class])}{\sum_{j=0}^{n} \exp(x[j])}) \\
=-(\log(\exp(x[class])- \log(\sum_{j=0}^{n} \exp(x[j])) \\
=-\log(\exp(x[class]))+\log(\sum_{j=0}^{n} \exp(x[j])) \\
=-x[class]+\log(\sum_{j=0}^{n} \exp(x[j])) \\

Aus dem Quellcodebeispiel ist die richtige Antwortklasse $ class = 0 $ und die Anzahl der Klassen ist $ n = 5 $. Wenn Sie dies also überprüfen

loss=-x[0]+\log(\sum_{j=0}^{5} \exp(x[j]))\\
=-x[0]+\log(\exp(x[0])+\exp(x[1])+\exp(x[2])+\exp(x[3])+\exp(x[4])) \\
= -0.3367 + \log(\exp(0.3367)+\exp(0.1288)+\exp(0.2345)+\exp(0.2303)+\exp(-1.1229)) \\
= 1.34717 \cdots \\
\fallingdotseq 1.34712

Es stimmte sicher mit dem Ergebnis des Programms überein! Die Berechnung erfolgt übrigens mit folgendem Code (manuelle Berechnung ist nicht möglich ...)

from math import exp, log
x_sum = exp(0.3367)+exp( 0.1288)+exp(0.2345)+exp(0.2303)+exp(-1.1229)
x = 0.3367
ans = -x + log(x_sum)
print(ans) # 1.3471717976017477

Es ist ein Stoß.

schließlich

Rundungsfehler (Erzeugung von Kreisbrüchen aufgrund der binären Anzeige von Dezimalstellen) scheinen jetzt unnötig zu sein. Normalerweise ist es `random.seed (42)`, aber mit Pytorch ist es torch.manual_seed (42) `, also fühlt es sich an wie.

Verweise

(1)TORCH.NN

Recommended Posts

[PyTorch] Ein wenig Verständnis von CrossEntropyLoss mit mathematischen Formeln
LiNGAM (ICA-Version) mit mathematischen Formeln und Python zu verstehen
Vorhersage des Nikkei-Durchschnitts mit Pytorch 2
Vorhersage des Nikkei-Durchschnitts mit Pytorch
Ein bisschen im Kettenschiff stecken
Vorhersage des Nikkei-Durchschnitts mit Pytorch ~ Makuma ~
Eine Sammlung von Tipps zur Beschleunigung des Lernens und Denkens mit PyTorch
[PyTorch] Warum Sie eine Instanz von CrossEntropyLoss () wie eine Funktion behandeln können
Multi-Class Multi-Label-Klassifizierung von Bildern mit Pytorch
Eine kleine Nischenfunktion Einführung von Faiss
Versuchen Sie eine Formel mit Σ mit Python
Vollständiges Verständnis der asynchronen Python-Programmierung
Ein grobes Verständnis von Python-Feuer und ein Memo
Eine kleine Überprüfung von Pandas 1.0 und Dask
Summe der Variablen in einem mathematischen Modell
Memorandum zu Djangos QueryDict
Machen Sie ein Zeichnungsquiz mit kivy + PyTorch
Vollständiges Verständnis der objektorientierten Programmierung von Python
Memorandum für die Migration mit GORM
[AtCoder] Lösen Sie ein Problem von ABC101 ~ 169 mit Python
Löse A ~ D des Yuki-Codierers 247 mit Python
Die Geschichte des Versuchs, Tensorboard mit Pytorch zu verwenden
Holen Sie sich eine Liste der IAM-Benutzer mit Boto3
Die Jobplanung ist bei AP Schuler etwas fortgeschritten
[Python] Ein grobes Verständnis des Protokollierungsmoduls
Ablauf beim Erstellen einer virtuellen Umgebung mit Anaconda
[PyTorch] Ich war ein wenig verloren in torch.max ()
[Python] Ein grobes Verständnis von Iterablen, Iteratoren und Generatoren
Erstellen Sie eine Tabelle mit IPython Notebook
Mit der Docker-Version der Nginx-Einheit war es ein wenig schwierig, eine Flasche zu machen
Artikel, der Ihnen hilft, den Kollisionsalgorithmus für starre Kugeln ein wenig zu verstehen