Cutting the PyTorch computation graph

What I want to do

Break the computation graph and define our own derivative for it

What I'll do this time

There are all sorts of situations in which the computation graph gets cut. I wasn't sure whether a complicated function would make a good example, so this time I prepared a function that turns a simple expression into a string and evaluates it with eval. No matter how simple the calculation is, PyTorch can't track it this way, so the computation graph gets cut.

Doing this means you have to define the derivative yourself. This time I'll use the forward difference method, $\frac{\partial f(x, w)}{\partial w} \approx \frac{f(x, w + \Delta w) - f(x, w)}{\Delta w}$.
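Here is a minimal sketch of that forward difference in plain Python. It assumes the same f as below, applied to ordinary floats instead of tensors. `forward_difference_grad` is a hypothetical helper name and is not part of PyTorch. The helper perturbs one weight at a time by Δw.

```python
def f(x, w):
    # Same expression as the article's f, applied here to plain Python floats.
    return 2 * x * w[0] + x**2 * w[1]

def forward_difference_grad(f, x, w, dw=1e-3):
    """Approximate [df/dw_0, df/dw_1, ...] at (x, w) with forward differences."""
    base = f(x, w)
    grads = []
    for i in range(len(w)):
        shifted = list(w)
        shifted[i] += dw                       # perturb only the i-th weight
        grads.append((f(x, shifted) - base) / dw)
    return grads

x, w = 1.5, [1.0, 1.0]
print(forward_difference_grad(f, x, w))        # close to the exact [2*x, x**2] = [3.0, 2.25]
```

This f is linear in w, so the forward difference matches the exact gradient up to rounding. For nonlinear functions, the error is roughly proportional to Δw.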

Disadvantages

The code gets more complicated because it gives up automatic differentiation, which is one of the main reasons for using PyTorch in the first place. In this example it isn't necessary at all, and it is also slower.

Advantages

There is no advantage in this particular case. However, if you really do have to cut the computation graph, there is one irreplaceable advantage: you can at least get it working for now. The flip side is what happens when something is completely impossible: you would normally rethink the model, or ask whether PyTorch is the right tool at all. Because this approach makes it work for the time being, you lose that chance to rethink.

If you still want to do it, go ahead. (I still had to.)

**Let's do it**

Consider the following function

Let's take something extremely simple.

Consider a function like this


import torch

# Both compute the same thing, but PyTorch cannot differentiate f_str because it goes through a string.
# The inputs x and w are both assumed to be PyTorch tensors.
def f(x, w):
    return 2 * x * w[0] + x**2 * w[1]

def f_str(x, w):
    # ({x_}) is parenthesized so that a negative x_ is squared correctly.
    return torch.tensor([eval(f'2 * {x_} * {w[0]} + ({x_})**2 * {w[1]}') for x_ in x])

f is exactly what it looks like. f_str computes the same thing as f, but it converts the expression into a string and has eval reinterpret it as a Python expression. The input x may arrive as a batch, so f_str takes it apart element by element and builds a new tensor.
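The following sketch checks that f and f_str agree. The inputs are arbitrary and chosen only for illustration. It also shows why the f-string works at all: in recent PyTorch versions, formatting a 0-dim tensor with an empty format spec goes through `.item()`. As a result, eval only ever sees plain numbers rather than `tensor(...)`.

```python
import torch

def f(x, w):
    return 2 * x * w[0] + x**2 * w[1]

def f_str(x, w):
    return torch.tensor([eval(f'2 * {x_} * {w[0]} + ({x_})**2 * {w[1]}') for x_ in x])

x = torch.tensor([0.5, 1.0, 2.0])                 # a small batch of inputs
w = torch.tensor([1.2, -3.4]).requires_grad_()

print(f'{w[0]}')                                  # a plain number, not "tensor(...)"
print(torch.allclose(f(x, w), f_str(x, w)))       # True: same values
```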

PyTorch's smart automatic differentiation

f can be automatically differentiated


x = torch.tensor([1.])
w = torch.tensor([1., 1.]).requires_grad_()

y = f(x, w)  # => tensor([3.], grad_fn=<AddBackward0>)

y.backward()
w.grad  # => tensor([2., 1.])

PyTorch is smart: it sees everything done inside f, so after y.backward(), w.grad is filled in automatically. I think this is one of the reasons to use a machine learning framework such as PyTorch.
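The example above uses a single input. With a mini-batch, y has one element per input, but backward() needs a scalar. A common approach (my addition, not from the original notebook) is to sum the outputs first. After that, w.grad holds the per-sample derivatives added together.

```python
import torch

def f(x, w):
    return 2 * x * w[0] + x**2 * w[1]

x = torch.tensor([1.0, 2.0, 3.0])                 # a mini-batch of three inputs
w = torch.tensor([1.0, 1.0], requires_grad=True)

y = f(x, w)             # shape (3,): one output per input
y.sum().backward()      # reduce to a scalar before calling backward()
print(w.grad)           # tensor([12., 14.]), i.e. [sum(2*x), sum(x**2)]
```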

A dumb function that can't be automatically differentiated

Remember that I deliberately made f_str very dumb earlier. With it, you run into the following misery.

f_str cannot be automatically differentiated


x = torch.tensor([1.])
w = torch.tensor([1., 1.]).requires_grad_()

y = f_str(x, w)  # => tensor([3.])  No grad_fn!!!
y.backward()  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

w.grad  # => None

There was no grad_fn, and sure enough, calling backward raised an error. grad was never set, and it was a disaster.

First, let's train the normally written version

Let's make data

Data creation


import numpy as np
import torch

actual_w = 1.2, -3.4

xs = np.random.rand(200).astype(np.float32)
ys = np.array([f(x, actual_w) for x in xs], dtype=np.float32)
train_d = torch.utils.data.TensorDataset(torch.from_numpy(xs), torch.from_numpy(ys))
train_loader = torch.utils.data.DataLoader(train_d, batch_size=10)

v_xs = np.random.rand(10).astype(np.float32)
v_ys = np.array([f(x, actual_w) for x in v_xs], dtype=np.float32)
valid_d = torch.utils.data.TensorDataset(torch.from_numpy(v_xs), torch.from_numpy(v_ys))
valid_loader = torch.utils.data.DataLoader(valid_d, batch_size=1)

I set suitable values as the true w, then generated random x values and paired each x with f(x, true w). It would be nicer to add Gaussian noise, but that's a hassle, so I won't this time. Using PyTorch for this is overkill. Something like scipy.optimize.minimize would be enough.
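Here is a sketch of that point, which was not part of the original notebook. Because f is linear in w, with f(x, w) = [2x, x²] · w, an ordinary least-squares fit recovers w directly from the noise-free data. For brevity it uses NumPy's `np.linalg.lstsq` in place of scipy.optimize.minimize.

```python
import numpy as np

def f(x, w):
    return 2 * x * w[0] + x**2 * w[1]

actual_w = 1.2, -3.4
xs = np.random.rand(200)
ys = f(xs, actual_w)                  # f also works elementwise on NumPy arrays

# Design matrix with one column per weight: f(x, w) = A @ w
A = np.stack([2 * xs, xs**2], axis=1)
w_hat, *_ = np.linalg.lstsq(A, ys, rcond=None)
print(w_hat)                          # approximately [1.2, -3.4]
```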

Training normally

Training with f


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.parameter.Parameter(torch.tensor([0., 0.]))
        
    def forward(self, x):
        return f(x, self.weight)

model = Model()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()
loss_hist = []
model.train()
for epoch in range(20):
    for i, (xs, l) in enumerate(train_loader):
        out = model(xs)
        loss = criterion(out, l)
        loss_hist.append(loss.item())  # store a float so the graph is not kept alive
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(epoch, loss, model.weight)

model.weight ended up at 1.1996, -3.3987. The true values I set were 1.2, -3.4, so they're almost identical. The figure below shows the loss during training and the loss on the data set aside for validation. Both should be close to 0 if training succeeded.

[Figure: training loss and validation loss of the model using f]

Looks pretty good.
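The article doesn't show the validation-loss code behind that figure. Below is a minimal sketch of one way to compute it, assuming the same f and Model as above. `evaluate` is a hypothetical helper. The weights are set to the true values by hand instead of by training, so the block runs on its own.

```python
import torch
import torch.utils.data

def f(x, w):
    return 2 * x * w[0] + x**2 * w[1]

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.tensor([0., 0.]))

    def forward(self, x):
        return f(x, self.weight)

def evaluate(model, loader, criterion):
    """Average loss over a validation loader, without recording a graph."""
    model.eval()
    total, n = 0.0, 0
    with torch.no_grad():
        for xs, ys in loader:
            total += criterion(model(xs), ys).item() * len(xs)
            n += len(xs)
    return total / n

v_xs = torch.rand(10)
v_ys = f(v_xs, (1.2, -3.4))
valid_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(v_xs, v_ys), batch_size=1)

model = Model()
with torch.no_grad():
    model.weight.copy_(torch.tensor([1.2, -3.4]))   # stand-in for a trained model
print(evaluate(model, valid_loader, torch.nn.MSELoss()))  # close to 0
```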

Try changing f to f_str

Now let's replace f with the awful function f_str.

[Figure: the error raised after replacing f with f_str]

It's a rather confusing error, but put simply: f_str cut the computation graph, so automatic differentiation became impossible and backward failed. That means you have to define the derivative yourself.

Defining our own derivative

The official documentation for this situation is here: https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd. However, when I actually tried it, it didn't quite cover the details I needed, so I'll go over a few points.

First the code, then the commentary

Define the derivative by forward difference for a general function


class GeneralFunctionWithForwardDifference(torch.autograd.Function):
    @staticmethod
    def forward(ctx, f, xs, weight):
        ys = f(xs, weight)
        ctx.save_for_backward(xs, ys, weight)
        ctx.f = f  # Non-tensor objects can be stored on ctx and used in backward
        return ys

    @staticmethod
    def backward(ctx, grad_output):
        xs, ys, weight = ctx.saved_tensors
        f = ctx.f
        dw = 0.001
        diff = []
        # Detach to avoid recording extra history, and clone so the in-place
        # perturbation below does not modify the parameter itself
        # (+= dw followed by -= dw does not always restore a float exactly).
        weight = weight.detach().clone()
        for i in range(len(weight)):
            weight[i] += dw
            # Forward difference for w_i, weighted by grad_output and summed over the mini-batch
            diff.append(torch.sum(grad_output * (f(xs, weight) - ys)))
            weight[i] -= dw
        diff = torch.stack(diff) / dw
        return None, None, diff  # one entry per forward input: f, xs, weight

Create a class that inherits from torch.autograd.Function and define forward and backward as @staticmethods, as the documentation says. The first argument of each is ctx. Another name also works, but it's best to follow this convention.
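To show the bare structure without the forward-difference machinery, here is a toy Function with a hand-written analytic derivative. The name `Cube` is made up for illustration. `torch.autograd.gradcheck` compares that backward against numerical derivatives and expects double-precision inputs.

```python
import torch

class Cube(torch.autograd.Function):
    """Toy Function computing x**3 with a hand-written derivative."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)          # keep the input for backward
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 3 * x ** 2   # chain rule: dL/dx = dL/dy * 3x^2

x = torch.randn(5, dtype=torch.double, requires_grad=True)
# gradcheck compares the hand-written backward with finite differences.
print(torch.autograd.gradcheck(Cube.apply, (x,)))  # True
```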

Saving data other than tensors

The documentation says you can save tensors with ctx.save_for_backward, but that method only accepts torch.Tensor objects.

This time, however, I want to pass f_str as an argument to forward and keep it for backward. It turns out you can store it as ctx.whatever = ..., and it is then available in backward. PyTorch itself does this internally, so I think it's probably fine, and I'll use it.

What should backward return?

The values returned by backward correspond to the arguments of forward. Return the derivative with respect to each argument of forward except ctx, in the same order.

For arguments that don't need a derivative (anything that isn't a tensor, or a tensor without requires_grad=True), you can return None. This time only w needs a derivative.
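Here is a small sketch showing both points at once. The `ScaleBy` name is hypothetical. A non-tensor argument is stored directly on ctx, and backward returns one value per forward input, with None for the one that can't be differentiated.

```python
import torch

class ScaleBy(torch.autograd.Function):
    """Multiply by a Python float; the float is stored directly on ctx."""

    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor           # non-tensor: plain attribute instead of save_for_backward
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input (excluding ctx): x, then factor.
        # factor is a Python float, so its "gradient" is None.
        return grad_output * ctx.factor, None

x = torch.tensor([1.0, 2.0], requires_grad=True)
ScaleBy.apply(x, 2.5).sum().backward()
print(x.grad)  # tensor([2.5000, 2.5000])
```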

If the input tensor is $\mathbf{w} = [w_0, w_1, \ldots, w_{n-1}]$, the value to return is

$$\left[\sum_i \mathrm{grad\_output}_i \frac{\partial f(x_i, \mathbf{w})}{\partial w_0},\ \sum_i \mathrm{grad\_output}_i \frac{\partial f(x_i, \mathbf{w})}{\partial w_1},\ \ldots,\ \sum_i \mathrm{grad\_output}_i \frac{\partial f(x_i, \mathbf{w})}{\partial w_{n-1}}\right]$$

Here $\sum_i$ means that when the input x arrives as a mini-batch $[x_0, x_1, \ldots]$, the results for the individual elements are added together. grad_output has one entry per element of the mini-batch, so each term is multiplied by its corresponding entry before summing, as above.
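The following sketch checks that formula using made-up numbers, not data from the original notebook. The vector returned by backward is the vector-Jacobian product $\mathrm{grad\_output}^\top J$, where $J_{ij} = \partial f(x_i, \mathbf{w}) / \partial w_j$. `torch.autograd.functional.jacobian` builds J explicitly so it can be compared with what autograd produces.

```python
import torch

def f(x, w):
    return 2 * x * w[0] + x**2 * w[1]

x = torch.tensor([1.0, 2.0, 3.0])              # mini-batch x_0, x_1, x_2
w = torch.tensor([1.0, 1.0])
grad_output = torch.tensor([0.5, -1.0, 2.0])   # an arbitrary upstream gradient

# J[i, j] = d f(x_i, w) / d w_j, shape (batch, number of weights)
J = torch.autograd.functional.jacobian(lambda w_: f(x, w_), w)

# The formula above: component j is sum_i grad_output_i * J[i, j]
vjp = grad_output @ J

# Autograd computes the same vector-Jacobian product
w_req = w.clone().requires_grad_()
f(x, w_req).backward(grad_output)
print(vjp, w_req.grad)                         # both tensor([ 9.0000, 14.5000])
```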

Finally, the model using f_str can be trained too

Training a model that uses f_str


class Model2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.parameter.Parameter(torch.tensor([0., 0.]))
        
    def forward(self, x):
        #It's a little annoying to write.
        return GeneralFunctionWithForwardDifference.apply(f_str, x, self.weight)

model2 = Model2()
optimizer = torch.optim.Adam(model2.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()
loss_hist2 = []
model2.train()
for epoch in range(20):
    for i, (xs, l) in enumerate(train_loader):
        out = model2(xs)
        loss = criterion(out, l)
        loss_hist2.append(loss.item())  # store a float so the graph is not kept alive
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(epoch, loss, model2.weight)

Note that the Function you define is called via .apply.

[Figure: training loss and validation loss of the model using f_str]

The graph came out so similar to the previous one that I thought I'd made a mistake. The final parameters are also 1.1996, -3.3987, almost the same as the true values 1.2, -3.4. Then again, both runs use the same data in the same order and start from the same initial parameters, with nothing random differing between them. Maybe that's to be expected. I'm not sure.

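By the way, I also plotted the two losses on top of each other and took the difference between the two models' validation predictions.

[Figure: the two training losses overlaid, and the difference between the two models' validation predictions]

They're almost the same. I used a rough derivative, so I expected at least a small difference, but there's hardly any. I'm glad. (Because f is linear in w, the forward difference is exact apart from rounding error, which likely explains this.)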
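The loss curves compare the two models only indirectly. A more direct sanity check, sketched here as an addition to the article, compares two gradients at the same point: the one from the forward-difference backward and the one autograd computes for f. The class is repeated so the block runs on its own, and it copies the weight before perturbing it. Since f is linear in w, the only discrepancy should be float32 rounding, which the division by dw amplifies, hence the loose tolerance.

```python
import torch

def f(x, w):
    return 2 * x * w[0] + x**2 * w[1]

def f_str(x, w):
    return torch.tensor([eval(f'2 * {x_} * {w[0]} + ({x_})**2 * {w[1]}') for x_ in x])

class GeneralFunctionWithForwardDifference(torch.autograd.Function):
    @staticmethod
    def forward(ctx, f, xs, weight):
        ys = f(xs, weight)
        ctx.save_for_backward(xs, ys, weight)
        ctx.f = f
        return ys

    @staticmethod
    def backward(ctx, grad_output):
        xs, ys, weight = ctx.saved_tensors
        dw = 0.001
        weight = weight.detach().clone()   # work on a copy, not the parameter's storage
        diff = []
        for i in range(len(weight)):
            weight[i] += dw
            diff.append(torch.sum(grad_output * (ctx.f(xs, weight) - ys)))
            weight[i] -= dw
        return None, None, torch.stack(diff) / dw

xs = torch.rand(10)
w_auto = torch.tensor([0.3, -0.7], requires_grad=True)
w_fd = w_auto.detach().clone().requires_grad_()

f(xs, w_auto).sum().backward()                                           # autograd through f
GeneralFunctionWithForwardDifference.apply(f_str, xs, w_fd).sum().backward()  # forward difference

print(w_auto.grad, w_fd.grad)   # should agree to within float32 rounding effects
```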

Summary

I've shown how to use a strange function that cuts the computation graph in PyTorch by defining the derivative yourself. I don't want to do this ever again.

I've put the notebook here: https://gist.github.com/gyu-don/f5cc025139312ccfd39e48400018118d
