[PYTHON] Einsum est-il relativement rapide?

Comparez le temps de calcul matriciel / vectoriel avec plusieurs modèles

Je ne voulais pas le faire, mais quand j'ai commencé à le faire, j'ai essayé diverses choses. .. .. C'est un gaspillage, alors prenez note.

  \boldsymbol{A}_k = \left(
    \begin{array}{c}
      a_{k1} \\
      a_{k2} \\
      \vdots \\
      a_{kn}
    \end{array}
  \right) \\

\boldsymbol B_k = \left(
\begin{array}{ccccc}
a_{k11} & \cdots & a_{k1i} & \cdots & a_{k1n} \\
 \vdots & \ddots &         &        & \vdots \\
a_{ki1} &        & a_{kii} &        & a_{kin} \\
 \vdots &        &         & \ddots & \vdots \\
a_{kn1} & \cdots & a_{kni} & \cdots & a_{knn} \\
\end{array}
\right) \\

Dans le cas de

\boldsymbol{A}_k^T \boldsymbol{B}_k \\
\boldsymbol{A}_k^T \boldsymbol{B}_k \boldsymbol{A}_k \\
\boldsymbol{B}_k \boldsymbol{A}_k \boldsymbol{A}_k^T \boldsymbol{B}_k

J'ai essayé de calculer tous les $ k $ à la fois. Même avec cela, le calcul de la matrice de numpy est compliqué. einsum est relativement rapide et peut être plus facile à écrire si vous vous y habituez. (Pourtant, j'ai fait une erreur et je l'ai mal calculée)

Pour Ak ^ T Bk

\boldsymbol{A}_k^T \boldsymbol{B}_k \\

Si vous l'écrivez différemment, cela peut changer, mais cela ressemble à ceci.

  1. Dans ce cas, einsum est extrêmement rapide. (func3)
  2. Le finaliste est matmul ou numba + expression. (func4, func2)
  3. Enfin, numba + pour instruction. (func1)

numba n'est pas rapide pour ses restrictions de compilation. Est-ce mal écrit? (Conversion de liste en np.array, np.newaxis ne peut pas être utilisé, etc.)

code

import numba
from numba import jit

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

@njit(cache=True)
def func1(A, B):
    I, J, K = np.shape(B)
    C = np.zeros((I, K))
    for i in range(I):
        C[i] = A[i].dot(B[i])
    return C
    #return np.array([A[i].dot(B[i]) for i in range(len(A))])

@njit(cache=True)
def func2(A, B):
    return (np.expand_dims(A, -1) * B).sum(1)
def func2a(A, B):
    return (np.expand_dims(A, -1) * B).sum(1)
def func3(A, B):
    return np.einsum('km,kmn->kn', A, B)
def func4(A, B):
    return np.matmul(np.expand_dims(A, 1), B).squeeze()

C1 = func1(A, B)
C2 = func2(A, B)
C3 = func3(A, B)
C4 = func4(A, B)
print("allclose", np.allclose(C1, C2), np.allclose(C1, C3), np.allclose(C1, C4))

print("func1")
%timeit func1(A, B)
print("func2")
%timeit func2(A, B)
print("func2a")
%timeit func2a(A, B)
print("func3")
%timeit func3(A, B)
print("func4")
%timeit func4(A, B)

Résultat de l'exécution.

allclose True True True
func1
94.9 µs ± 3.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2
43.3 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2a
114 µs ± 1.53 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3
20.1 µs ± 907 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
func4
42.8 µs ± 2.87 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Pour Ak ^ T Bk Ak

\boldsymbol{A}_k^T \boldsymbol{B}_k \boldsymbol{A}_k \\

einsum ralentit. Toujours plus rapide que le numba + pour la déclaration.

  1. expression numba + (func2)
  2. einsum (func3)
  3. Enfin, numba + pour instruction. (func1)

einsum semble avoir une option d'optimisation, j'ai donc essayé de la spécifier, mais c'était lent. De plus, je ne savais pas comment écrire avec matmul.

code

import numba
from numba import jit, njit, prange

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

@njit(cache=True)
def func1(A, B):
    I, J, K = np.shape(B)
    C = np.zeros(I)
    for i in range(I):
        C[i] = A[i].dot(B[i]).dot(A[i])
    return C

@njit(cache=True)
def func2(A, B):
    return ((np.expand_dims(A, -1) * B).sum(1) * A).sum(1)
def func2a(A, B):
    return ((np.expand_dims(A, -1) * B).sum(1) * A).sum(1)
def func3(A, B):
    return np.einsum('km,kmn,kn->k', A, B, A)
def func3a(A, B):
    return np.einsum('km,kmn,kn->k', A, B, A, optimize=True)
C1 = func1(A, B)
C2 = func2(A, B)
C3 = func3(A, B)
print("allclose", np.allclose(C1, C2), np.allclose(C1, C3))

print("func1")
%timeit func1(A, B)
print("func2")
%timeit func2(A, B)
print("func2a")
%timeit func2a(A, B)
print("func3")
%timeit func3(A, B)
print("func3a")
%timeit func3a(A, B)

Résultat d'exécution

allclose True True
func1
101 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2
45.4 µs ± 1.11 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2a
120 µs ± 1.45 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3
56.2 µs ± 500 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3a
139 µs ± 3.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Pour Bk Ak Ak ^ T Bk

\boldsymbol{B}_k \boldsymbol{A}_k \boldsymbol{A}_k^T \boldsymbol{B}_k

einsum est encore plus lent. L'optimisation est beaucoup plus rapide, mais l'expression numba + est plus rapide.

  1. expression numba + (func2)
  2. einsum (func3)
  3. Enfin, numba + pour instruction. (func1)

Il semble que l'ordre de calcul puisse être optimisé avec einsum_path. Je ne l'ai pas regardé en détail.

import numba
from numba import jit, njit, prange

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

@njit(cache=True)
def func1(A, B):
    I, J, K = np.shape(B)
    C = np.zeros((I, J, K))
    for i in range(I):
        C[i] = np.outer(B[i] @ A[i], A[i] @ B[i])
    return C

@njit(cache=True)
def func2(A, B):
    return np.expand_dims((B * np.expand_dims(A, 1)).sum(2), -1) * np.expand_dims((np.expand_dims(A, -1) * B).sum(1), 1)
def func2a(A, B):
    return np.expand_dims((B * np.expand_dims(A, 1)).sum(2), -1) * np.expand_dims((np.expand_dims(A, -1) * B).sum(1), 1)
def func3(A, B):
    return np.einsum('kab,kb,kc,kcd->kad', B, A, A, B, optimize=['einsum_path', (0, 1), (0, 1), (0, 1)])
def func3a(A, B):
    return np.einsum('kab,kb,kc,kcd->kad', B, A, A, B, optimize=True)

C1 = func1(A, B)
C2 = func2(A, B)
C3 = func3(A, B)
print("allclose", np.allclose(C1, C2), np.allclose(C1, C3))

print("func1")
%timeit func1(A, B)
print("func2")
%timeit func2(A, B)
print("func2a")
%timeit func2a(A, B)
print("func3")
%timeit func3(A, B)
print("func3a")
%timeit func3a(A, B)

Résultat d'exécution

allclose True True
func1
335 µs ± 7.15 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
func2
97.8 µs ± 1.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2a
246 µs ± 20.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
func3
154 µs ± 3.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3a
250 µs ± 6.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Recommended Posts

Einsum est-il relativement rapide?
Le nombre de bits rapides est bin (). Count ('1')