[PYTHON] Is einsum relatively fast?

Comparing matrix/vector calculation times across several patterns

I didn't originally intend to benchmark this, but once I started I ended up trying all sorts of things, so here are my notes.

\boldsymbol{A}_k = \left(
  \begin{array}{c}
    a_{k1} \\
    a_{k2} \\
    \vdots \\
    a_{kn}
  \end{array}
\right) \\

\boldsymbol{B}_k = \left(
  \begin{array}{ccccc}
    b_{k11} & \cdots & b_{k1i} & \cdots & b_{k1n} \\
    \vdots  & \ddots &         &        & \vdots  \\
    b_{ki1} &        & b_{kii} &        & b_{kin} \\
    \vdots  &        &         & \ddots & \vdots  \\
    b_{kn1} & \cdots & b_{kni} & \cdots & b_{knn}
  \end{array}
\right) \\

Given vectors $\boldsymbol{A}_k$ and matrices $\boldsymbol{B}_k$ as above, I computed

\boldsymbol{A}_k^T \boldsymbol{B}_k \\
\boldsymbol{A}_k^T \boldsymbol{B}_k \boldsymbol{A}_k \\
\boldsymbol{B}_k \boldsymbol{A}_k \boldsymbol{A}_k^T \boldsymbol{B}_k

for all $ k $ at once. Even so, batched matrix calculation in numpy gets complicated. einsum turns out to be relatively fast, and once you get used to it, it may even be easier to write. (I still managed to get the subscripts wrong and computed an incorrect result at first.)
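Since it is easy to get einsum subscripts wrong, one way to sanity-check them is to compare against an explicit loop on a small example. A minimal sketch (the shapes here are my own illustration, not from the benchmark):

```python
import numpy as np

A = np.random.random((4, 3))
B = np.random.random((4, 3, 3))

# 'km,kmn->kn' computes, for each k: C[k, n] = sum_m A[k, m] * B[k, m, n],
# i.e. the row vector A_k^T times the matrix B_k.
C = np.einsum('km,kmn->kn', A, B)

# Explicit loop to confirm the subscripts mean what we think.
C_loop = np.stack([A[k] @ B[k] for k in range(4)])
print(np.allclose(C, C_loop))  # True
```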

For $\boldsymbol{A}_k^T \boldsymbol{B}_k$

\boldsymbol{A}_k^T \boldsymbol{B}_k \\

The results may change depending on how each version is written, but here is how it came out:

  1. einsum is the overwhelming winner here. (func3)
  2. matmul and numba + array expression are roughly tied for second. (func4, func2)
  3. numba + a for loop comes last. (func1)

numba is not that fast despite accepting the restrictions that compilation imposes (converting a list to np.array is awkward, np.newaxis cannot be used, etc.). Maybe my code is just badly written?

code

import numpy as np
import numba
from numba import njit

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

@njit(cache=True)
def func1(A, B):
    I, J, K = np.shape(B)
    C = np.zeros((I, K))
    for i in range(I):
        C[i] = A[i].dot(B[i])
    return C
    #return np.array([A[i].dot(B[i]) for i in range(len(A))])

@njit(cache=True)
def func2(A, B):
    return (np.expand_dims(A, -1) * B).sum(1)
def func2a(A, B):
    return (np.expand_dims(A, -1) * B).sum(1)
def func3(A, B):
    return np.einsum('km,kmn->kn', A, B)
def func4(A, B):
    return np.matmul(np.expand_dims(A, 1), B).squeeze()

C1 = func1(A, B)
C2 = func2(A, B)
C3 = func3(A, B)
C4 = func4(A, B)
print("allclose", np.allclose(C1, C2), np.allclose(C1, C3), np.allclose(C1, C4))

print("func1")
%timeit func1(A, B)
print("func2")
%timeit func2(A, B)
print("func2a")
%timeit func2a(A, B)
print("func3")
%timeit func3(A, B)
print("func4")
%timeit func4(A, B)

Execution result.

allclose True True True
func1
94.9 µs ± 3.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2
43.3 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2a
114 µs ± 1.53 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3
20.1 µs ± 907 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
func4
42.8 µs ± 2.87 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

For $\boldsymbol{A}_k^T \boldsymbol{B}_k \boldsymbol{A}_k$

\boldsymbol{A}_k^T \boldsymbol{B}_k \boldsymbol{A}_k \\

einsum slows down here, but it is still faster than numba + a for loop.

  1. numba + array expression (func2)
  2. einsum (func3)
  3. numba + a for loop comes last. (func1)

einsum has an optimize option, so I tried specifying it (func3a), but that was actually slower here. I also couldn't figure out how to write this case with matmul.
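For what it's worth, this case can be written with matmul by treating each $A_k$ as both a row vector and a column vector. This is my own sketch (`func_matmul` is a hypothetical name, not from the original code):

```python
import numpy as np

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

def func_matmul(A, B):
    Ar = np.expand_dims(A, 1)    # row vectors,    shape (k, 1, n)
    Ac = np.expand_dims(A, -1)   # column vectors, shape (k, n, 1)
    # (k,1,n) @ (k,n,n) -> (k,1,n), then @ (k,n,1) -> (k,1,1); squeeze to (k,)
    return (Ar @ B @ Ac).squeeze()

print(np.allclose(func_matmul(A, B), np.einsum('km,kmn,kn->k', A, B, A)))  # True
```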

code

import numpy as np
import numba
from numba import njit

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

@njit(cache=True)
def func1(A, B):
    I, J, K = np.shape(B)
    C = np.zeros(I)
    for i in range(I):
        C[i] = A[i].dot(B[i]).dot(A[i])
    return C

@njit(cache=True)
def func2(A, B):
    return ((np.expand_dims(A, -1) * B).sum(1) * A).sum(1)
def func2a(A, B):
    return ((np.expand_dims(A, -1) * B).sum(1) * A).sum(1)
def func3(A, B):
    return np.einsum('km,kmn,kn->k', A, B, A)
def func3a(A, B):
    return np.einsum('km,kmn,kn->k', A, B, A, optimize=True)
C1 = func1(A, B)
C2 = func2(A, B)
C3 = func3(A, B)
print("allclose", np.allclose(C1, C2), np.allclose(C1, C3))

print("func1")
%timeit func1(A, B)
print("func2")
%timeit func2(A, B)
print("func2a")
%timeit func2a(A, B)
print("func3")
%timeit func3(A, B)
print("func3a")
%timeit func3a(A, B)

Execution result

allclose True True
func1
101 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2
45.4 µs ± 1.11 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2a
120 µs ± 1.45 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3
56.2 µs ± 500 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3a
139 µs ± 3.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

For $\boldsymbol{B}_k \boldsymbol{A}_k \boldsymbol{A}_k^T \boldsymbol{B}_k$

\boldsymbol{B}_k \boldsymbol{A}_k \boldsymbol{A}_k^T \boldsymbol{B}_k

einsum gets even slower here. Specifying the contraction path speeds it up considerably, but numba + an array expression is still faster.

  1. numba + array expression (func2)
  2. einsum (func3)
  3. numba + a for loop comes last. (func1)

The contraction order can apparently be tuned with einsum_path, though I haven't looked into it in detail.
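As a quick sketch of what einsum_path offers: it returns a contraction order plus a human-readable cost report, and the returned path can be passed back to einsum via the optimize argument to skip re-planning on every call.

```python
import numpy as np

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

# Ask numpy for the optimal contraction order of this four-operand product.
path, report = np.einsum_path('kab,kb,kc,kcd->kad', B, A, A, B,
                              optimize='optimal')
print(path)    # e.g. ['einsum_path', (0, 1), (0, 1), (0, 1)]
print(report)  # cost estimates: naive vs. optimized contraction
```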

import numpy as np
import numba
from numba import njit

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

@njit(cache=True)
def func1(A, B):
    I, J, K = np.shape(B)
    C = np.zeros((I, J, K))
    for i in range(I):
        C[i] = np.outer(B[i] @ A[i], A[i] @ B[i])
    return C

@njit(cache=True)
def func2(A, B):
    return np.expand_dims((B * np.expand_dims(A, 1)).sum(2), -1) * np.expand_dims((np.expand_dims(A, -1) * B).sum(1), 1)
def func2a(A, B):
    return np.expand_dims((B * np.expand_dims(A, 1)).sum(2), -1) * np.expand_dims((np.expand_dims(A, -1) * B).sum(1), 1)
def func3(A, B):
    return np.einsum('kab,kb,kc,kcd->kad', B, A, A, B, optimize=['einsum_path', (0, 1), (0, 1), (0, 1)])
def func3a(A, B):
    return np.einsum('kab,kb,kc,kcd->kad', B, A, A, B, optimize=True)

C1 = func1(A, B)
C2 = func2(A, B)
C3 = func3(A, B)
print("allclose", np.allclose(C1, C2), np.allclose(C1, C3))

print("func1")
%timeit func1(A, B)
print("func2")
%timeit func2(A, B)
print("func2a")
%timeit func2a(A, B)
print("func3")
%timeit func3(A, B)
print("func3a")
%timeit func3a(A, B)

Execution result

allclose True True
func1
335 µs ± 7.15 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
func2
97.8 µs ± 1.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2a
246 µs ± 20.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
func3
154 µs ± 3.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3a
250 µs ± 6.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
