Die automatische Unterscheidung ist praktisch, wenn Sie die Unterscheidung einer komplizierten Funktion finden möchten. Zu dieser Zeit verwendete ich Pytorchs automatische Differenzierung. Ich wollte jedoch nur die automatische Differenzierung verwenden, aber das Pytorch-Paket ist ziemlich schwer. Als ich nach einem leichteren Paket suchte, kam ich zu JAX. ..
Offiziell Eine aktualisierte Version von Autograd (derzeit nicht gepflegt). Sie können die GPU verwenden, um die automatische Differenzierung mit hoher Geschwindigkeit zu berechnen (natürlich funktioniert sie auch auf der CPU).
CPC only version
pip install --upgrade pip
pip install --upgrade jax jaxlib # CPU-only version
Wenn Sie eine GPU verwenden möchten, lesen Sie bitte die Anleitung zur Pip-Installation unter Offiziell.
Versuchen wir nun, die Differenzierung zweiter Ordnung der Protokollfunktion zu finden.
import jax.numpy as jnp
from jax import grad
#Definition der Protokollfunktion
fn = lambda x0: jnp.log(x0)
# x =Differenziere um 1
x = 1
#Auswechslung
y0 = fn(x)
#Einmalige Differenzierung
y1 = grad(fn)(x)
#Differential zweiter Ordnung
y2 = grad(grad(fn))(x)
Ausführungsergebnis
>>> float(y0), float(y1), float(y2)
(0.0, 1.0, -1.0)
Mit JAX können Sie die automatische Differenzierung einfach und problemlos verwenden.
Recommended Posts