[PYTHON] Find the second derivative with JAX automatic differentiation

Introduction

Automatic differentiation is convenient when you need the derivative of a complicated function. I used to rely on PyTorch's automatic differentiation, but I only needed the differentiation part, and the PyTorch package is quite heavy. While looking for a lighter package, I came across JAX.

What is JAX?

JAX is an updated version of Autograd (which is no longer actively maintained). It can use a GPU to compute derivatives quickly, and of course it also runs on the CPU.

Installation method

CPU-only version

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

If you want to use a GPU, follow the pip installation instructions in the official JAX documentation.
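A quick way to confirm that the installation works is to print the version and the backend JAX selected. The sketch below uses `jax.__version__`, `jax.default_backend()` and `jax.devices()` from JAX's public API. With the CPU-only install above, the backend should be reported as `cpu`. With a GPU build it should be `gpu`. The tiny sum at the end only checks that jaxlib actually executes a computation.

```python
# Post-install sanity check: version, selected backend, and a tiny computation.
import jax
import jax.numpy as jnp

print("JAX version:", jax.__version__)

# "cpu" for the CPU-only install; an accelerator build reports e.g. "gpu"
backend = jax.default_backend()
print("default backend:", backend)
print("devices:", jax.devices())

# 0 + 1 + 2 + 3, computed by jaxlib on the default device
total = float(jnp.sum(jnp.arange(4.0)))
print("sum check:", total)
```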

Second derivative in JAX

Let's try to find the second derivative of the log function.
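For reference, working the derivatives out by hand gives the values JAX should reproduce. Here "log" means the natural logarithm, which is what `jnp.log` computes:

```latex
f(x) = \ln x, \qquad f'(x) = \frac{1}{x}, \qquad f''(x) = -\frac{1}{x^{2}}

f(1) = \ln 1 = 0, \qquad f'(1) = 1, \qquad f''(1) = -1
```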

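import jax.numpy as jnp
from jax import grad

# Function to differentiate: the natural logarithm
def fn(x0):
    return jnp.log(x0)

# Evaluate at x = 1. Use a float: grad raises a TypeError for integer inputs
x = 1.0

# Function value
y0 = fn(x)
# First derivative (grad(fn) returns a new function, which is then called at x)
y1 = grad(fn)(x)
# Second derivative: apply grad to the first-derivative function
y2 = grad(grad(fn))(x)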

Execution result


>>> float(y0), float(y1), float(y2)
(0.0, 1.0, -1.0)
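The same idea extends beyond a single point and a single order. The following is a minimal sketch. `nth_derivative` is a hypothetical helper that simply applies `grad` n times. Because `grad` works on scalar functions, `jax.vmap` is used to evaluate the second derivative at several points at once, and `jax.jit` compiles the result. The output is compared with the analytic value -1/x² and with the third derivative 2/x³.

```python
import jax
import jax.numpy as jnp
from jax import grad

def fn(x0):
    # Natural logarithm, as in the example above
    return jnp.log(x0)

def nth_derivative(f, n):
    """Hypothetical helper: differentiate the scalar function f n times with grad."""
    for _ in range(n):
        f = grad(f)
    return f

# grad expects a scalar input, so vmap maps the second derivative over an array;
# jit compiles the composed function with XLA.
d2 = jax.jit(jax.vmap(nth_derivative(fn, 2)))

xs = jnp.array([0.5, 1.0, 2.0, 4.0])
numeric = d2(xs)
analytic = -1.0 / xs**2  # f''(x) = -1/x^2

print("JAX:     ", numeric)
print("analytic:", analytic)

# Third derivative at x = 1; analytically f'''(x) = 2/x^3, i.e. 2 at x = 1
third = float(nth_derivative(fn, 3)(1.0))
print("f'''(1) =", third)
```

Applying `grad` repeatedly is the same composition used for `y2` above. The helper only avoids writing `grad(grad(grad(fn)))` by hand.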

Finally

With JAX, automatic differentiation, including higher-order derivatives, is easy to use.
