[PYTHON] Finden Sie das Differential zweiter Ordnung mit der automatischen Differenzierung von JAX

Einführung

Die automatische Unterscheidung ist praktisch, wenn Sie die Unterscheidung einer komplizierten Funktion finden möchten. Zu dieser Zeit verwendete ich Pytorchs automatische Differenzierung. Ich wollte jedoch nur die automatische Differenzierung verwenden, aber das Pytorch-Paket ist ziemlich schwer. Als ich nach einem leichteren Paket suchte, kam ich zu JAX. ..

Was ist JAX?

Offiziell image.png Eine aktualisierte Version von Autograd (derzeit nicht gepflegt). Sie können die GPU verwenden, um die automatische Differenzierung mit hoher Geschwindigkeit zu berechnen (natürlich funktioniert sie auch auf der CPU).

Wie installiert man

CPC only version

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

Wenn Sie eine GPU verwenden möchten, lesen Sie bitte die Anleitung zur Pip-Installation unter Offiziell.

Differenzierung zweiter Ordnung in JAX

Versuchen wir nun, die Differenzierung zweiter Ordnung der Protokollfunktion zu finden.

import jax.numpy as jnp
from jax import grad

#Definition der Protokollfunktion
fn = lambda x0: jnp.log(x0)

# x =Differenziere um 1
x = 1

#Auswechslung
y0 = fn(x)
#Einmalige Differenzierung
y1 = grad(fn)(x)
#Differential zweiter Ordnung
y2 = grad(grad(fn))(x)

Ausführungsergebnis


>>> float(y0), float(y1), float(y2)
(0.0, 1.0, -1.0)

Schließlich

Mit JAX können Sie die automatische Differenzierung einfach und problemlos verwenden.

Recommended Posts

Finden Sie das Differential zweiter Ordnung mit der automatischen Differenzierung von JAX
Finden Sie die Bearbeitungsentfernung (Levenshtein-Entfernung) mit Python
Finden Sie den SHA256-Wert mit R (mit Bonus)
Die zweite Nacht der Runde mit für
Finden Sie den Stimmungswert mit Python (Rike Koi)
Finden Sie mit NumPy die Position über dem Schwellenwert
Prognostizieren Sie die zweite Runde des Sommers 2016 mit scikit-learn
Finden Sie den Tag nach Datum / Uhrzeit heraus
Finden Sie den kürzesten Weg mit dem Python Dijkstra-Algorithmus
Berechnen Sie die Summe der eindeutigen Werte durch Pandas-Kreuztabellen
Finden Sie den Speicherort der mit pip installierten Pakete heraus