[PYTHON] Ist einsum relativ schnell?

Vergleichen Sie die Berechnungszeit für Matrix / Vektor mit mehreren Mustern

Ich wollte es nicht tun, aber als ich anfing, versuchte ich verschiedene Dinge. .. .. Es tut mir leid, also mach dir eine Notiz.

  \boldsymbol{A}_k = \left(
    \begin{array}{c}
      a_{k1} \\
      a_{k2} \\
      \vdots \\
      a_{kn}
    \end{array}
  \right) \\

\boldsymbol B_k = \left(
\begin{array}{ccccc}
a_{k11} & \cdots & a_{k1i} & \cdots & a_{k1n} \\
 \vdots & \ddots &         &        & \vdots \\
a_{ki1} &        & a_{kii} &        & a_{kin} \\
 \vdots &        &         & \ddots & \vdots \\
a_{kn1} & \cdots & a_{kni} & \cdots & a_{knn} \\
\end{array}
\right) \\

Im Falle von

\boldsymbol{A}_k^T \boldsymbol{B}_k \\
\boldsymbol{A}_k^T \boldsymbol{B}_k \boldsymbol{A}_k \\
\boldsymbol{B}_k \boldsymbol{A}_k \boldsymbol{A}_k^T \boldsymbol{B}_k

Ich habe versucht, alle $ k $ auf einmal zu berechnen. Trotzdem ist die Matrixberechnung von numpy kompliziert. einsum ist relativ schnell und es kann einfacher sein zu schreiben, wenn man sich daran gewöhnt. (Trotzdem habe ich einen Fehler gemacht und ihn falsch berechnet)

Für Ak ^ T Bk

\boldsymbol{A}_k^T \boldsymbol{B}_k \\

Wenn Sie es anders schreiben, kann es sich ändern, aber es sieht so aus.

  1. In diesem Fall ist einsum überwältigend schnell. (func3)
  2. Der Zweitplatzierte ist matmul oder numba + expression. (func4, func2)
  3. Zum Schluss numba + für Anweisung. (func1)

numba ist wegen seiner Kompilierungsbeschränkungen nicht schnell. Ist es schlecht geschrieben? (Konvertieren von Liste zu np.array, np.newaxis kann nicht verwendet werden usw.)

Code

import numba
from numba import jit

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

@njit(cache=True)
def func1(A, B):
    I, J, K = np.shape(B)
    C = np.zeros((I, K))
    for i in range(I):
        C[i] = A[i].dot(B[i])
    return C
    #return np.array([A[i].dot(B[i]) for i in range(len(A))])

@njit(cache=True)
def func2(A, B):
    return (np.expand_dims(A, -1) * B).sum(1)
def func2a(A, B):
    return (np.expand_dims(A, -1) * B).sum(1)
def func3(A, B):
    return np.einsum('km,kmn->kn', A, B)
def func4(A, B):
    return np.matmul(np.expand_dims(A, 1), B).squeeze()

C1 = func1(A, B)
C2 = func2(A, B)
C3 = func3(A, B)
C4 = func4(A, B)
print("allclose", np.allclose(C1, C2), np.allclose(C1, C3), np.allclose(C1, C4))

print("func1")
%timeit func1(A, B)
print("func2")
%timeit func2(A, B)
print("func2a")
%timeit func2a(A, B)
print("func3")
%timeit func3(A, B)
print("func4")
%timeit func4(A, B)

Ausführungsergebnis.

allclose True True True
func1
94.9 µs ± 3.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2
43.3 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2a
114 µs ± 1.53 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3
20.1 µs ± 907 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
func4
42.8 µs ± 2.87 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Für Ak ^ T Bk Ak

\boldsymbol{A}_k^T \boldsymbol{B}_k \boldsymbol{A}_k \\

einsum wird langsamer. Immer noch schneller als die numba + for-Anweisung.

  1. numba + Ausdruck (func2)
  2. einsum (func3)
  3. Zum Schluss numba + für Anweisung. (func1)

einsum scheint eine Optimierungsoption zu haben, also habe ich versucht, sie anzugeben, aber sie war langsam. Außerdem wusste ich nicht, wie man mit Matmul schreibt.

Code

import numba
from numba import jit, njit, prange

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

@njit(cache=True)
def func1(A, B):
    I, J, K = np.shape(B)
    C = np.zeros(I)
    for i in range(I):
        C[i] = A[i].dot(B[i]).dot(A[i])
    return C

@njit(cache=True)
def func2(A, B):
    return ((np.expand_dims(A, -1) * B).sum(1) * A).sum(1)
def func2a(A, B):
    return ((np.expand_dims(A, -1) * B).sum(1) * A).sum(1)
def func3(A, B):
    return np.einsum('km,kmn,kn->k', A, B, A)
def func3a(A, B):
    return np.einsum('km,kmn,kn->k', A, B, A, optimize=True)
C1 = func1(A, B)
C2 = func2(A, B)
C3 = func3(A, B)
print("allclose", np.allclose(C1, C2), np.allclose(C1, C3))

print("func1")
%timeit func1(A, B)
print("func2")
%timeit func2(A, B)
print("func2a")
%timeit func2a(A, B)
print("func3")
%timeit func3(A, B)
print("func3a")
%timeit func3a(A, B)

Ausführungsergebnis

allclose True True
func1
101 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2
45.4 µs ± 1.11 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2a
120 µs ± 1.45 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3
56.2 µs ± 500 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3a
139 µs ± 3.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Für Bk Ak Ak ^ T Bk

\boldsymbol{B}_k \boldsymbol{A}_k \boldsymbol{A}_k^T \boldsymbol{B}_k

einsum ist noch langsamer. Die Optimierung ist viel schneller, aber der Ausdruck numba + ist schneller.

  1. numba + Ausdruck (func2)
  2. einsum (func3)
  3. Zum Schluss numba + für Anweisung. (func1)

Es scheint, dass die Berechnungsreihenfolge mit einsum_path optimiert werden kann. Ich habe es nicht im Detail betrachtet.

import numba
from numba import jit, njit, prange

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

@njit(cache=True)
def func1(A, B):
    I, J, K = np.shape(B)
    C = np.zeros((I, J, K))
    for i in range(I):
        C[i] = np.outer(B[i] @ A[i], A[i] @ B[i])
    return C

@njit(cache=True)
def func2(A, B):
    return np.expand_dims((B * np.expand_dims(A, 1)).sum(2), -1) * np.expand_dims((np.expand_dims(A, -1) * B).sum(1), 1)
def func2a(A, B):
    return np.expand_dims((B * np.expand_dims(A, 1)).sum(2), -1) * np.expand_dims((np.expand_dims(A, -1) * B).sum(1), 1)
def func3(A, B):
    return np.einsum('kab,kb,kc,kcd->kad', B, A, A, B, optimize=['einsum_path', (0, 1), (0, 1), (0, 1)])
def func3a(A, B):
    return np.einsum('kab,kb,kc,kcd->kad', B, A, A, B, optimize=True)

C1 = func1(A, B)
C2 = func2(A, B)
C3 = func3(A, B)
print("allclose", np.allclose(C1, C2), np.allclose(C1, C3))

print("func1")
%timeit func1(A, B)
print("func2")
%timeit func2(A, B)
print("func2a")
%timeit func2a(A, B)
print("func3")
%timeit func3(A, B)
print("func3a")
%timeit func3a(A, B)

Ausführungsergebnis

allclose True True
func1
335 µs ± 7.15 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
func2
97.8 µs ± 1.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2a
246 µs ± 20.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
func3
154 µs ± 3.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3a
250 µs ± 6.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Recommended Posts

Ist einsum relativ schnell?
Die schnelle Bitanzahl ist bin (). Count ('1')