Je fais du Deep Learning dans mes recherches, mais l'autre jour, j'ai appris que la mise à jour de Chainer touchait à sa fin, j'ai donc décidé de changer le framework en Pytorch ainsi que le développeur. Comme point de départ, j'ai décidé de porter le programme Chainer existant vers Pytorch.
En gros, tout ce que j'avais à faire était de changer le nom de la fonction, mais en chemin, j'ai remarqué que Pytorch n'avait pas HardSigmoid. Alors faisons-le nous-mêmes.
... mais c'est écrit dans la référence officielle, donc je l'ai fait presque exactement. --> https://pytorch.org/docs/master/autograd.html
python
class MyHardSigmoid(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
ctx.save_for_backward(i)
result = (0.2 * i + 0.5).clamp(min=0.0, max=1.0)
return result
@staticmethod
def backward(ctx, grad_output):
grad_input = grad_output.clone()
result, = ctx.saved_tensors
grad_input *= 0.2
grad_input[result < -2.5] = 0
grad_input[result > -2.5] = 0
return grad_input
Si vous n'écrivez pas @staticmethod, un avertissement apparaîtra. La fonction officielle est une fonction exponentielle, mais nous la changerons en sigmoïde dur.
Premièrement, forward () est propagé vers l'avant. hard sigmoid () a la formule suivante, donc je l'ai écrit pour qu'il le soit.
h(x) = \left\{
\begin{array}{ll}
0 & (x \lt -2.5) \\
0.2x + 0.5 & (-2.5 \leq x \leq 2.5) \\
1 & (2.5 \lt x)
\end{array}
\right.
Ensuite, avec backward (), cela écrit la rétropropagation. Le coefficient différentiel est le suivant.
\frac{\partial h(x)}{\partial x} = \left\{
\begin{array}{ll}
0 & (x \lt -2.5) \\
0.2 & (-2.5 \leq x \leq 2.5) \\
0 & (2.5 \lt x)
\end{array}
\right.
Et enfin, appliquez cela au modèle. (Le contenu du modèle est approprié.)
model.py
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
hard_sigmoid = MyHardSigmoid.apply
return hard_sigmoid(self.conv2(x))
C'est parfait! !! ... devrait être
Recommended Posts