La différenciation automatique est pratique lorsque vous souhaitez rechercher la différenciation d'une fonction compliquée. A cette époque, j'utilisais la différenciation automatique de Pytorch. Cependant, je voulais n'utiliser que la différenciation automatique, mais le paquet pytorch est assez lourd, donc quand je cherchais un paquet plus léger, je suis arrivé à JAX. ..
Officiel Une version mise à jour de Autograd (non actuellement maintenue). Vous pouvez utiliser le GPU pour calculer la différenciation automatique à grande vitesse (bien sûr, cela fonctionne également sur le CPU).
CPC only version
pip install --upgrade pip
pip install --upgrade jax jaxlib # CPU-only version
Si vous souhaitez utiliser GPU, veuillez vous reporter aux instructions d'installation de pip dans Officiel.
Maintenant, essayons de trouver la différenciation de second ordre de la fonction de journal.
import jax.numpy as jnp
from jax import grad
#définition de la fonction log
fn = lambda x0: jnp.log(x0)
# x =Différencier autour de 1
x = 1
#Substitution
y0 = fn(x)
#Différenciation unique
y1 = grad(fn)(x)
#Différentiel de second ordre
y2 = grad(grad(fn))(x)
Résultat d'exécution
>>> float(y0), float(y1), float(y2)
(0.0, 1.0, -1.0)
Avec JAX, vous pouvez facilement et facilement utiliser la différenciation automatique.
Recommended Posts