[PYTHON] How to use the library "torchdiffeq" that implements Neural ODE's ODE Block

1. Introduction

It is still a fairly new library, but in this article I will show you how to use the library that implements Neural ODE. By the way, Neural ODE was a best paper of NeurIPS 2018.

For Neural ODE, the authors have published an official implementation library called torchdiffeq.

Although there are many articles explaining the theory and interpretation of Neural ODE, I found few Japanese articles describing how to actually use this library, so I have summarized its basic usage in this article. Note that torchdiffeq is a library for PyTorch.

By reading this article, you will be able to:

- Solve the initial value problem of a first-order ordinary differential equation with torchdiffeq
- Solve the initial value problem of a second-order ordinary differential equation with torchdiffeq
- Implement the ODE Block layer, which is the layer that makes up Neural ODE

2. Prerequisite knowledge

I will review the prerequisite knowledge for using torchdiffeq.

2.1 What is an ordinary differential equation?

Of the differential equations, those with essentially only one independent variable are called ordinary differential equations. For example, differential equations such as $\frac{dz}{dt} = f(z(t), t)$ and $m\ddot{x} = -kx$ involve several symbols, but since $z$ and $x$ are functions of $t$, there is essentially only one independent variable, $t$, so they are ordinary differential equations.

2.2 What is Neural ODE?

There are many other easy-to-understand articles about Neural ODE, so please refer to them.

To summarize briefly, Neural ODE can be described as a "neural network with continuous layers".

Some say the concept is hard to grasp because "ordinary differential equation" is a term rarely heard in the neural network field, but I think the important point is simply that the layers are continuous. This makes it possible, for example, to take out the output of the "0.5th layer", which is not possible with conventional models.

Reference links:

- Neural network capable of expressing continuous dynamics
- [NIPS 2018 Grand Prize Paper] From the University of Toronto: a completely new neural network model that connects the intermediate layers in a differentiable continuous space

3. How to use torchdiffeq

Now let's see how to use torchdiffeq.

3.1 Installation

To install, execute the following command.

pip install torchdiffeq
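After installation, you can check that the package imports correctly. This is just a minimal sanity check; the `__version__` attribute is an assumption and may not exist in every release.

import torchdiffeq

# Print the installed version if the attribute is available.
print(getattr(torchdiffeq, "__version__", "version attribute not found"))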

3.2 Example: First-order differential equation

Before actually implementing Neural ODE, let's use a simple first-order ODE as an example to get a feel for how to use torchdiffeq.

Consider the following differential equation.

z(0) = 0, \qquad \frac{dz(t)}{dt} = f(t) = t

Writing $C$ for the constant of integration, the solution is

\int dz = \int t\,dt + C \\
z(t) = \frac{t^2}{2} + C

Since $z(0) = 0$, the solution of this differential equation is as follows.

z(t) = \frac{t^2}{2}

Implementation with torchdiffeq

The simplest implementation to solve this problem with torchdiffeq is below.

first_order.py


import torch
from torchdiffeq import odeint


def func(t, z):
    # right-hand side of dz/dt = f(t, z); here f depends only on t
    return t


z0 = torch.Tensor([0])          # initial value z(0) = 0
t = torch.linspace(0, 2, 100)   # 100 evaluation times in [0, 2]
out = odeint(func, z0, t)       # solve the initial value problem

The key points are itemized below (a short check of these behaviors follows the list).

- The function `func` corresponds to $f$ in the differential equation $\frac{dz}{dt} = f(t, z)$ above. Its arguments are in the order `(t, z)`, and the dimensions of its output must match `z`. You do not have to use both `z` and `t`.
- The variable `z0` is the initial value of the dynamics.
- The variable `t` represents time and must be a one-dimensional tensor such as `tensor([0., 0.1, 0.2, ..., 0.9, 1.0])`. `t[0]` is the time corresponding to the initial value.
- Note that the elements of `t` must be strictly monotonically increasing (or decreasing). An error occurs even if the same value appears twice, as in `t = tensor([0, 0, 1])`.
- The differential equation is solved with `odeint(func, z0, t)`. Its arguments are the function `func`, the initial value `z0`, and the times `t`, in that order.
- The solver returns the value of `z` at the times specified by `t`. That is, when `t = tensor([t0, t1, ..., tn])`, `out = tensor([z0, z1, ..., zn])`. Since `t[0]` is the initial time, the output `out[0]` always matches `z0`.
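As a minimal check of the last two points, using `z0`, `t`, and `out` from the script above:

print(out.shape)                 # torch.Size([100, 1]): one value of z per entry of t
print(torch.equal(out[0], z0))   # expected True: out[0] is the initial value itself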

Plot the above results.

import matplotlib.pyplot as plt

plt.plot(t, out)
plt.gca().set_aspect('equal', 'datalim')  # set the aspect ratio to 1:1
plt.grid()
plt.xlim(0, 2)
plt.show()

first_order.png

You can see that it matches the solution of the differential equation obtained by hand, $z = \frac{t^2}{2}$.
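The agreement can also be checked numerically rather than by eye; the tolerance below is an arbitrary choice.

analytic = t ** 2 / 2   # hand-derived solution z(t) = t^2 / 2
print(torch.allclose(out.squeeze(), analytic, atol=1e-4))  # expected: True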

3.3 (Reference) Example 2: Solving the second-order differential equation

torchdiffeq can also solve second-order differential equations. As an example, let's solve the differential equation of simple harmonic motion, familiar to anyone with a science background, with torchdiffeq. The equation of simple harmonic motion is

m\ddot{x} = -kx

with the initial conditions $x = 1$ and $\dot{x} = \frac{dx}{dt} = 0$ at $t = 0$.

The trick to solving a second-order differential equation is to decompose it into two first-order differential equations. Specifically, this is done as follows.

\left[ \begin{array}{c} \dot{x} \\ \ddot{x} \end{array} \right] = \left[ \begin{array}{cc} 0 & 1 \\ -\frac{k}{m} & 0 \end{array} \right] \left[ \begin{array}{c} x \\ \dot{x} \end{array} \right]

Here, if we put $\boldsymbol{y} = \left[ \begin{array}{c} x \\ \dot{x} \end{array} \right]$, this second-order differential equation reduces to the following first-order differential equation.

\frac{d\boldsymbol{y}}{dt} = f(\boldsymbol{y})

The implementation is as follows, with $k = 1$ and $m = 1$.

oscillation.py


import numpy as np
import torch
from torchdiffeq import odeint


class Oscillation:
    def __init__(self, km):
        # km = k / m; the dynamics are dy/dt = M y with M = [[0, 1], [-k/m, 0]]
        self.mat = torch.Tensor([[0, 1],
                                 [-km, 0]])

    def solve(self, t, x0, dx0):
        y0 = torch.cat([x0, dx0])       # y = [x, dx/dt]
        out = odeint(self.func, y0, t)
        return out

    def func(self, t, y):
        out = self.mat @ y              # @ is the matrix product: dy/dt = M y
        return out


if __name__ == "__main__":
    x0 = torch.Tensor([1])
    dx0 = torch.Tensor([0])

    t = torch.linspace(0, 4 * np.pi, 1000)
    solver = Oscillation(1)             # k = 1, m = 1
    out = solver.solve(t, x0, dx0)

Plotting the result shows that the solution of simple harmonic motion is obtained properly. osillation.png
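As a reference, a minimal plotting sketch that also overlays the analytic solution $x(t) = \cos t$ (valid for $k = m = 1$) could look like the following; it is not part of the original script.

import matplotlib.pyplot as plt

plt.plot(t, out[:, 0], label="x(t) from odeint")            # position component of y
plt.plot(t, torch.cos(t), "--", label="cos(t), analytic")   # analytic solution for k = m = 1
plt.xlabel("t")
plt.legend()
plt.grid()
plt.show()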

4. Implementation of ODE Block

Now that you are familiar with how to use torchdiffeq, let's see how to actually implement an ODE Block. The ODE Block is a module that represents the dynamics $\frac{dz}{dt} = f(t, z)$. An actual Neural ODE is constructed by combining ODE Blocks with ordinary fully connected and convolutional layers.

The following implementation emphasizes simplicity and is just an example.

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint


class ODEfunc(nn.Module):
    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.seq = nn.Sequential(nn.Linear(dim, 124),
                                 nn.ReLU(),
                                 nn.Linear(124, 124),
                                 nn.ReLU(),
                                 nn.Linear(124, dim),
                                 nn.Tanh())

    def forward(self, t, x):
        # f(t, x): here the dynamics do not depend on t
        out = self.seq(x)
        return out


class ODEBlock(nn.Module):
    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time)
        return out[1]  # out[0] is the initial value, so return the state at t = 1

To briefly explain (a quick smoke test is sketched after this list):

- The ODE Block treats the received input `x` as the initial value of the differential equation.
- `ODEfunc` is the $f$ that describes the dynamics of the system.
- The integration interval of the ODE Block is fixed to $[0, 1]$, and the block returns the output of the layer at $t = 1$.
- Note that `odeint_adjoint` is imported as `odeint` here; it backpropagates through the solver using the adjoint method from the Neural ODE paper, which keeps memory usage low.
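As a quick sanity check of the block above (the feature dimension and batch size here are arbitrary choices for illustration):

func = ODEfunc(dim=16)
block = ODEBlock(func)

x = torch.randn(8, 16)   # a batch of 8 samples with 16 features
y = block(x)
print(y.shape)           # torch.Size([8, 16]): same shape as the input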

By doing this, you can use the ODE Block as one module of a neural network, as shown below.

class ODEnet(nn.Module):
    def __init__(self, in_dim, mid_dim, out_dim):
        super(ODEnet, self).__init__()

        odefunc = ODEfunc(dim=mid_dim)
        
        self.fc1 = nn.Linear(in_dim, mid_dim)
        self.relu1 = nn.ReLU(inplace=True)
        self.norm1 = nn.BatchNorm1d(mid_dim)
        self.ode_block = ODEBlock(odefunc)  # use the ODE Block
        self.norm2 = nn.BatchNorm1d(mid_dim)
        self.fc2 = nn.Linear(mid_dim, out_dim)

    def forward(self, x):
        batch_size = x.shape[0]
        x = x.view(batch_size, -1)

        out = self.fc1(x)
        out = self.relu1(out)
        out = self.norm1(out)
        out = self.ode_block(out)
        out = self.norm2(out)
        out = self.fc2(out)

        return out
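For reference, here is a minimal sketch of how such a model might be trained on dummy classification data; the input size, batch size, and optimizer settings are arbitrary assumptions, not values from the article.

import torch.optim as optim

model = ODEnet(in_dim=28 * 28, mid_dim=64, out_dim=10)  # e.g. MNIST-sized inputs (assumption)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

x = torch.randn(32, 1, 28, 28)        # dummy batch of images; forward() flattens them
labels = torch.randint(0, 10, (32,))  # dummy class labels

logits = model(x)
loss = criterion(logits, labels)

optimizer.zero_grad()
loss.backward()   # gradients flow through odeint_adjoint
optimizer.step()
print(loss.item())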

In my case, this model was slow to train. However, using torchdiffeq does not always seem to be slow: as far as I have tried, the Neural ODE model in the official repository runs about as fast as a normal neural network (even though mine should be the smaller model...).

5. Summary

I have introduced the basic usage of torchdiffeq, which is useful for implementing Neural ODE. If you would like to see a program that actually trains a model, please see the official torchdiffeq repository below or my implementation repository (https://github.com/TakadaTakumi/neuralODE_sample).

Reference

- torchdiffeq - GitHub
- My implementation repository
