Parce que j'utilise souvent `` critère = torch.nn.CrossEntropyLoss () '' comme base de la fonction de perte de Pytorch. Il est produit pour comprendre les détails. Si vous faites une erreur, faites-le moi savoir.
CrossEntropyLoss Exemple Pytorch (1)
torch.manual_seed(42) #Graine fixe pour maintenir la reproductibilité
loss = nn.CrossEntropyLoss()
input_num = torch.randn(1, 5, requires_grad=True)
target = torch.empty(1, dtype=torch.long).random_(5)
print('input_num:',input_num)
print('target:',target)
output = loss(input_num, target)
print('output:',output)
input_num: tensor([[ 0.3367, 0.1288, 0.2345, 0.2303, -1.1229]], requires_grad=True)
target: tensor([0])
output: tensor(1.3472, grad_fn=<NllLossBackward>)
En supposant que la classe de réponse correcte est $ class $ et que le nombre de classes est $ n $, l'erreur $ loss $ de CrossEntropyLoss peut être exprimée par la formule suivante.
loss=-\log(\frac{\exp(x[class])}{\sum_{j=0}^{n} \exp(x[j])}) \\
=-(\log(\exp(x[class])- \log(\sum_{j=0}^{n} \exp(x[j])) \\
=-\log(\exp(x[class]))+\log(\sum_{j=0}^{n} \exp(x[j])) \\
=-x[class]+\log(\sum_{j=0}^{n} \exp(x[j])) \\
À partir de l'exemple de code source, la classe de réponse correcte est $ class = 0 $ et le nombre de classes est $ n = 5 $, donc si vous la cochez
loss=-x[0]+\log(\sum_{j=0}^{5} \exp(x[j]))\\
=-x[0]+\log(\exp(x[0])+\exp(x[1])+\exp(x[2])+\exp(x[3])+\exp(x[4])) \\
= -0.3367 + \log(\exp(0.3367)+\exp(0.1288)+\exp(0.2345)+\exp(0.2303)+\exp(-1.1229)) \\
= 1.34717 \cdots \\
\fallingdotseq 1.34712
Il correspondait au résultat du programme en toute sécurité! Au fait, le calcul se fait avec le code suivant (le calcul manuel est impossible ...)
from math import exp, log
x_sum = exp(0.3367)+exp( 0.1288)+exp(0.2345)+exp(0.2303)+exp(-1.1229)
x = 0.3367
ans = -x + log(x_sum)
print(ans) # 1.3471717976017477
C'est une poussée.
L'erreur d'arrondi (génération de fractions circulaires due à l'affichage binaire des points décimaux) semble désormais inutile.
Normalement, c'est random.seed (42) '', mais avec Pytorch c'est
torch.manual_seed (42) '', donc c'est comme ça.
(1)TORCH.NN
Recommended Posts