[PYTHON] Verwendung der Bibliothek "torchdiffeq", die den ODE-Block von Neural ODE implementiert

1. Zuallererst

Es ist sehr neu, aber ich zeige Ihnen, wie Sie die Implementierungsbibliothek für neuronale ODE verwenden. Neural ODE ist übrigens das beste Papier von NeuroIPS 2018.

Die offiziellen Autoren der Neuronalen ODE haben eine Bibliothek offizieller Repositories mit dem Namen [torchdiffeq] veröffentlicht (https://github.com/rtqichen/torchdiffeq).

Obwohl Neural ODE viele Artikel enthält, die die Theorie und Interpretation erklären, dachte ich, dass es nur wenige japanische Artikel gibt, die die tatsächliche Verwendung dieser Bibliothek beschreiben, und fasste die grundlegende Verwendung in diesem Artikel zusammen. Torchdiffeq ist übrigens eine Bibliothek für PyTorch.

Wenn Sie diesen Artikel lesen, können Sie:

--Torchdiffeq kann das Anfangswertproblem der gewöhnlichen Differentialgleichung erster Ordnung lösen --Torchdiffeq kann das Anfangswertproblem der gewöhnlichen Differentialgleichung zweiter Ordnung lösen

2. Voraussetzungen

Ich werde die erforderlichen Kenntnisse für die Verwendung von torchdiffeq überprüfen.

2.1 Was ist eine normale Differentialgleichung?

Von den Differentialgleichungen wird diejenige mit im wesentlichen nur einer unbekannten Variablen als normale Differentialgleichung bezeichnet. Beispielsweise haben Differentialgleichungen wie $ \ frac {dz} {dt} = f (z (t), t) $ und $ m \ ddot {x} = -kx $ mehrere Variablen, aber $ Da z $ und $ x $ Funktionen von $ t $ sind, gibt es im Wesentlichen nur eine unbekannte Variable, $ t $, und es kann gesagt werden, dass es sich um eine normale Differentialgleichung handelt.

2.2 Was ist neuronale ODE?

Es gibt viele andere leicht verständliche Artikel über neuronale ODE. Bitte lesen Sie sie.

Um den Umriss kurz zu erläutern, kann man von neuronaler ODE als "neuronales Netzwerk mit kontinuierlichen Schichten" sprechen.

Es gibt eine Theorie, dass es schwierig ist, das Konzept zu verstehen, weil es ein Wort "normale Differentialgleichung" gibt, das im neuronalen Netzbereich nicht zu hören ist, aber ich denke, dass der wichtige Punkt darin besteht, dass "Schichten kontinuierlich sind". .. Dies ermöglicht es beispielsweise, "die Ausgabe der 0,5-ten Schicht herauszunehmen", was mit dem herkömmlichen Modell nicht möglich war.

Referenzlink unten:

3. Wie benutzt man torchdiffeq?

Nun wollen wir sehen, wie man torchdiffeq benutzt.

3.1 Installation

Führen Sie zum Installieren den folgenden Befehl aus.

pip install torchdiffeq

3.2 Beispiel: Differentialgleichung erster Ordnung

Bevor wir mit der Implementierung der neuronalen ODE beginnen, nehmen wir als Beispiel eine einfache gewöhnliche Differentialgleichung erster Ordnung, um zu sehen, wie Torchdiffeq einfach verwendet werden kann.

Betrachten Sie die folgende Gleichung Differentialgleichung. $ z(0) = 0, \\\ \frac{dz(t)}{dt} = f(t) = t $

Diese Lösung verwendet $ C $ als Integrationskonstante

\int dz = \int tdt+C \\\ z(t) = \frac{t^2}{2} + C

Da $ z (0) = 0 $ ist, können wir sehen, dass die Lösung dieser Differentialgleichung wie folgt ist. $ z(t) = \frac{t^2}{2} $

Implementierung durch torchdiffeq

Die einfachste Implementierung zur Lösung dieses Problems mit torchdiffeq finden Sie unten.

first_order.py


from torchdiffeq import odeint

def func(t, z):
    return t

z0 = torch.Tensor([0])
t = torch.linspace(0,2,100)
out = odeint(func, z0, t)

Unten sind die Punkte aufgelistet.

――Es ist zu beachten, dass das Element von t eine Spalte sein muss, die im engeren Sinne zunimmt (abnimmt). Ein Fehler tritt auch dann auf, wenn derselbe Wert enthalten ist, z. B. "t = Tensor ([0, 0, 1])".

Zeichnen Sie die obigen Ergebnisse.

from matplotlib.pyplot as plt

plt.plot(t, out)
plt.axes().set_aspect('equal', 'datalim')  #Seitenverhältnis 1:Auf 1 setzen
plt.grid()
plt.xlim(0,2)
plt.show()

first_order.png

Sie können sehen, dass es mit der Lösung $ z = \ frac {t ^ 2} {2} $ der durch Handberechnung erhaltenen Differentialgleichung übereinstimmt.

3.3 (Referenz) Beispiel 2: Lösen der Differentialgleichung zweiter Ordnung

Wenn Sie torchdiffeq verwenden, können Sie auch die Differentialgleichung zweiter Ordnung lösen. Als Beispiel lösen wir die (?) Einfache Schwingungsdifferentialgleichung, die der Wissenschaft mit torchdiffeq vertraut ist. Die Differentialgleichung der einfachen Schwingung lautet wie folgt. $ m\ddot{x} = -kx $ Im Ausgangszustand, wenn $ t = 0 $, $ x = 1 $, $ \ dot {x} = \ frac {dx} {dt} = 0 $. Der Trick zum Lösen der Differentialgleichung zweiter Ordnung besteht darin, die Differentialgleichung zweiter Ordnung in zwei Differentialgleichungen erster Ordnung zu zerlegen. Gehen Sie insbesondere wie folgt vor.

\left[ \begin{array}{c} \dot{x} \\\ \ddot{x} \\\ \end{array} \right] = \left[ \begin{array}{cc} 0 & 1\\\ -\frac{k}{m} & 0\\\ \end{array} \right] \left[ \begin{array}{c} x \\\ \dot{x} \\\ \end{array} \right]

Hier ist $ \ boldsymbol {y} = \ left [ \begin{array}{c} x \
\dot{x} \
\end{array} Wenn Sie \ right] $ setzen, führt diese Differentialgleichung zweiter Ordnung zu der folgenden Differentialgleichung erster Ordnung.

\frac{d\boldsymbol{y}}{dt} = f(\boldsymbol{y})

Die Implementierung ist wie folgt. $ k = 1, m = 1 $.

oscillation.py


class Oscillation:
    def __init__(self, km):
        self.mat = torch.Tensor([[0, 1],
                                 [-km, 0]])

    def solve(self, t, x0, dx0):
        y0 = torch.cat([x0, dx0])
        out = odeint(self.func, y0, t)
        return out

    def func(self, t, y):
        # print(t)
        out = y @ self.mat  # @Ist das Matrixprodukt
        return out

if __name__=="__main__":
    x0 = torch.Tensor([1])
    dx0 = torch.Tensor([0])

    import numpy as np
    t = torch.linspace(0, 4 * np.pi, 1000)
    solver = Oscillation(1)
    out = solver.solve(t, x0, dx0)

Wenn Sie es zeichnen, können Sie sehen, dass die Lösung der einfachen Vibration richtig erhalten wird. osillation.png

4. Implementierung des ODE-Blocks

Nachdem Sie nun mit der Verwendung von torchdiffeq vertraut sind, wollen wir sehen, wie ODE Block tatsächlich implementiert wird. Der ODE-Block ist ein Modul, das die Dynamik von $ \ frac {dz} {dt} = f (t, z) $ bildet. Die eigentliche neuronale ODE wird unter Verwendung des ODE-Blocks zusammen mit der normalen Full-Connect-Schicht und der Faltungsschicht konstruiert.

Die folgende Implementierung betont die Einfachheit und ist nur ein Beispiel.

from torchdiffeq import odeint_adjoint as odeint

class ODEfunc(nn.Module):
    def __init__(self, dim):
        super(ODEfunc, self).__init__()
        self.seq = nn.Sequential(nn.Linear(dim, 124),
                                 nn.ReLU(),
                                 nn.Linear(124, 124),
                                 nn.ReLU(),
                                 nn.Linear(124, dim),
                                 nn.Tanh())

    def forward(self, t, x):
        out = self.seq(x)
        return out


class ODEBlock(nn.Module):
    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x)
        out = odeint(self.odefunc, x, self.integration_time)
        return out[1]  # out[0]Weil der Anfangswert in enthalten ist.

Um es kurz zu erklären,

--ODE Block behandelt die empfangene Eingabe "x" als Anfangswert der Differentialgleichung. --ODEfunc ist $ f $, das die Dynamik des Systems beschreibt.

Auf diese Weise können Sie den ODE-Block wie unten gezeigt als ein Modul des neuronalen Netzes verwenden.

class ODEnet(nn.Module):
    def __init__(self, in_dim, mid_dim, out_dim):
        super(ODEnet, self).__init__()

        odefunc = ODEfunc(dim=mid_dim)
        
        self.fc1 = nn.Linear(in_dim, mid_dim)
        self.relu1 = nn.ReLU(inplace=True)
        self.norm1 = nn.BatchNorm1d(mid_dim)
        self.ode_block = ODEBlock(odefunc)  #Verwenden Sie den ODE-Block
        self.norm2 = nn.BatchNorm1d(mid_dim)
        self.fc2 = nn.Linear(mid_dim, out_dim)

    def forward(self, x):
        batch_size = x.shape[0]
        x = x.view(batch_size, -1)

        out = self.fc1(x)
        out = self.relu1(out)
        out = self.norm1(out)
        out = self.ode_block(out)
        out = self.norm2(out)
        out = self.fc2(out)

        return out

Dieses Modell war langsam zu berechnen. Die Verwendung von torchdiffeq scheint es jedoch nicht zu verlangsamen, und soweit ich es versucht habe, ist das neuronale ODE-Modell im offiziellen Repository so schnell wie ein normales neuronales Netzwerk. (Dieses sollte ein kleineres Modell sein ...)

5. Zusammenfassung

Wir haben eine rudimentäre Verwendung von torchdiffeq eingeführt, die für die Implementierung von neuronaler ODE nützlich ist. Wenn Sie das Programm sehen möchten, das das Modell tatsächlich trainiert, lesen Sie bitte das folgende Offizielle Torchdiffeq-Repository oder [Mein Implementierungs-Repository](https: // github). com / TakadaTakumi / neuralODE_sample).

Referenz

torchdiffeq - GitHub Mein Implementierungs-Repository

Recommended Posts

Verwendung der Bibliothek "torchdiffeq", die den ODE-Block von Neural ODE implementiert
Verwendung der C-Bibliothek in Python
Verwendung der Grafikzeichnungsbibliothek Bokeh
[Python] Verwendung der Diagrammerstellungsbibliothek Altair
Verwendung der Solver-Bibliothek "kociemba" von Rubik Cube
Wie benutzt man den Dekorateur?
[Python] Verwendung von Matplotlib, einer Bibliothek zum Zeichnen von Diagrammen
Hinweise zur Verwendung von Marshmallow in der Schemabibliothek
Verwendung der Zip-Funktion
Verwendung des optparse-Moduls
Verwendung des ConfigParser-Moduls
Verwendung der Spark ML-Pipeline
[Linux] Verwendung des Befehls echo
Verwendung des IPython-Debuggers (ipdb)
Verwendung von hmmlearn, einer Python-Bibliothek, die versteckte Markov-Modelle realisiert
Python Ich weiß nicht, wie ich den Druckernamen bekomme, den ich normalerweise benutze.
So verwenden Sie MkDocs zum ersten Mal
Verwendung der Python-Bildbibliothek in der Python3-Serie
Verwendung der Google Cloud Translation API
Verwendung der NHK-Programmführer-API
[Algorithmus x Python] Verwendung der Liste
Verwendung der PyTorch-basierten Bildverarbeitungsbibliothek "Kornia"
So verwenden Sie eine Bibliothek, die ursprünglich nicht in Google App Engine enthalten war
Eine grobe Einführung in die neuronale maschinelle Übersetzungsbibliothek
So lösen Sie die rekursive Funktion, die abc115-D gelöst hat
Verwendung von Raspeye Relay Module Python
Ich wollte die Python-Bibliothek von MATLAB verwenden
Linux-Benutzer hinzufügen, wie der Befehl useradd verwendet wird
Verwendung des Befehls grep und häufiger Samples
Verwendung der Exist-Klausel in Django Queryset
[Einführung in die Udemy Python3 + -Anwendung] 27. Verwendung des Wörterbuchs
[Einführung in die Udemy Python3 + -Anwendung] 30. Verwendung des Sets
Wie man Argparse benutzt und den Unterschied zwischen Optparse
Verwendung des in Lobe in Python erlernten Modells
(Denken Sie schnell daran) Verwendung der LINUX-Befehlszeile
Verwendung von xml.etree.ElementTree
Wie benutzt man Python-Shell
Hinweise zur Verwendung von tf.data
Verwendung von virtualenv
Wie benutzt man Seaboan?
Verwendung von Image-Match
Verwendung von Pandas 2
Verwendung von Virtualenv
Verwendung von pytest_report_header
Wie man Bio.Phylo benutzt
Verwendung von SymPy
Wie man x-means benutzt
Verwendung von WikiExtractor.py
Verwendung von IPython
Verwendung von virtualenv
Wie benutzt man Matplotlib?
Verwendung von iptables
Wie benutzt man numpy?
Verwendung von TokyoTechFes2015
Wie benutzt man venv
Verwendung des Wörterbuchs {}
Wie benutzt man Pyenv?
Verwendung der Liste []
Wie man Python-Kabusapi benutzt
Verwendung von OptParse