Ich mache Deep Learning in meiner Forschung, aber neulich habe ich erfahren, dass das Chainer-Update zu Ende geht, und habe beschlossen, das Framework sowohl auf Pytorch als auch auf den Entwickler zu ändern. Als Ausgangspunkt habe ich mich entschlossen, vom bestehenden Chainer-Programm nach Pytorch zu portieren.
Grundsätzlich musste ich nur den Namen der Funktion ändern, aber unterwegs bemerkte ich, dass Pytorch kein Hard Sigmoid hatte. Also machen wir es uns selbst.
... aber es steht in der offiziellen Referenz, also habe ich es fast genau gemacht. --> https://pytorch.org/docs/master/autograd.html
python
class MyHardSigmoid(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
ctx.save_for_backward(i)
result = (0.2 * i + 0.5).clamp(min=0.0, max=1.0)
return result
@staticmethod
def backward(ctx, grad_output):
grad_input = grad_output.clone()
result, = ctx.saved_tensors
grad_input *= 0.2
grad_input[result < -2.5] = 0
grad_input[result > -2.5] = 0
return grad_input
Wenn Sie @staticmethod nicht schreiben, wird eine Warnung angezeigt. Die offizielle ist eine Exponentialfunktion, aber wir werden dies in hartes Sigmoid ändern.
Zunächst wird forward () vorwärts weitergegeben. hard sigmoid () hat die folgende Formel, also habe ich es so geschrieben, dass es so wäre.
h(x) = \left\{
\begin{array}{ll}
0 & (x \lt -2.5) \\
0.2x + 0.5 & (-2.5 \leq x \leq 2.5) \\
1 & (2.5 \lt x)
\end{array}
\right.
Dann schreibt dies mit backward () die Backpropagation. Der Differentialkoeffizient ist wie folgt.
\frac{\partial h(x)}{\partial x} = \left\{
\begin{array}{ll}
0 & (x \lt -2.5) \\
0.2 & (-2.5 \leq x \leq 2.5) \\
0 & (2.5 \lt x)
\end{array}
\right.
Und schließlich wenden Sie dies auf das Modell an. (Der Inhalt des Modells ist angemessen.)
model.py
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
hard_sigmoid = MyHardSigmoid.apply
return hard_sigmoid(self.conv2(x))
Dies ist perfekt! !! ... sollte sein
Recommended Posts