Weil `Kriterium = torch.nn.CrossEntropyLoss ()`
häufig als Grundlage für die Verlustfunktion von Pytorch verwendet wird.
Es wird ausgegeben, um die Details zu verstehen.
Wenn Sie einen Fehler machen, lassen Sie es mich bitte wissen.
CrossEntropyLoss Pytorch-Beispiel (1)
torch.manual_seed(42) #Fixes Saatgut zur Aufrechterhaltung der Reproduzierbarkeit
loss = nn.CrossEntropyLoss()
input_num = torch.randn(1, 5, requires_grad=True)
target = torch.empty(1, dtype=torch.long).random_(5)
print('input_num:',input_num)
print('target:',target)
output = loss(input_num, target)
print('output:',output)
input_num: tensor([[ 0.3367, 0.1288, 0.2345, 0.2303, -1.1229]], requires_grad=True)
target: tensor([0])
output: tensor(1.3472, grad_fn=<NllLossBackward>)
Unter der Annahme, dass die richtige Antwortklasse $ class $ und die Anzahl der Klassen $ n $ ist, kann der Fehler $ loss $ von CrossEntropyLoss durch die folgende Formel ausgedrückt werden.
loss=-\log(\frac{\exp(x[class])}{\sum_{j=0}^{n} \exp(x[j])}) \\
=-(\log(\exp(x[class])- \log(\sum_{j=0}^{n} \exp(x[j])) \\
=-\log(\exp(x[class]))+\log(\sum_{j=0}^{n} \exp(x[j])) \\
=-x[class]+\log(\sum_{j=0}^{n} \exp(x[j])) \\
Aus dem Quellcodebeispiel ist die richtige Antwortklasse $ class = 0 $ und die Anzahl der Klassen ist $ n = 5 $. Wenn Sie dies also überprüfen
loss=-x[0]+\log(\sum_{j=0}^{5} \exp(x[j]))\\
=-x[0]+\log(\exp(x[0])+\exp(x[1])+\exp(x[2])+\exp(x[3])+\exp(x[4])) \\
= -0.3367 + \log(\exp(0.3367)+\exp(0.1288)+\exp(0.2345)+\exp(0.2303)+\exp(-1.1229)) \\
= 1.34717 \cdots \\
\fallingdotseq 1.34712
Es stimmte sicher mit dem Ergebnis des Programms überein! Die Berechnung erfolgt übrigens mit folgendem Code (manuelle Berechnung ist nicht möglich ...)
from math import exp, log
x_sum = exp(0.3367)+exp( 0.1288)+exp(0.2345)+exp(0.2303)+exp(-1.1229)
x = 0.3367
ans = -x + log(x_sum)
print(ans) # 1.3471717976017477
Es ist ein Stoß.
Rundungsfehler (Erzeugung von Kreisbrüchen aufgrund der binären Anzeige von Dezimalstellen) scheinen jetzt unnötig zu sein.
Normalerweise ist es `random.seed (42)`
, aber mit Pytorch ist es torch.manual_seed (42)
`, also fühlt es sich an wie.
(1)TORCH.NN
Recommended Posts