I didn't originally intend to go this far, but once I started I ended up trying all sorts of things... It would be a waste not to record it, so here is a note.
\boldsymbol{A}_k = \left(
\begin{array}{c}
a_{k1} \\
a_{k2} \\
\vdots \\
a_{kn}
\end{array}
\right) \\
\boldsymbol B_k = \left(
\begin{array}{ccccc}
b_{k11} & \cdots & b_{k1i} & \cdots & b_{k1n} \\
\vdots & \ddots & & & \vdots \\
b_{ki1} & & b_{kii} & & b_{kin} \\
\vdots & & & \ddots & \vdots \\
b_{kn1} & \cdots & b_{kni} & \cdots & b_{knn} \\
\end{array}
\right) \\
I wanted to compute the following for all $k$ at once:
\boldsymbol{A}_k^T \boldsymbol{B}_k \\
\boldsymbol{A}_k^T \boldsymbol{B}_k \boldsymbol{A}_k \\
\boldsymbol{B}_k \boldsymbol{A}_k \boldsymbol{A}_k^T \boldsymbol{B}_k
Here the $\boldsymbol{A}_k$ are stacked into an array of shape $(K, n)$ and the $\boldsymbol{B}_k$ into $(K, n, n)$; in the benchmarks below $K = 257$ and $n = 6$. Even batched like this, plain numpy matrix notation gets messy. einsum is relatively fast and may be easier to write once you get used to it. (Still, I managed to make a mistake and compute it incorrectly at first.)
\boldsymbol{A}_k^T \boldsymbol{B}_k
The timings may change depending on how you write it, but this is what I got.
numba is not as fast as I expected, given the restrictions its compilation imposes. Maybe I'm just writing it badly? (For example, converting a list to np.array and indexing with np.newaxis don't work inside @njit functions, etc.)
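To make the np.newaxis point concrete, here is a minimal sketch of my own (the function name scale_rows is made up, it is not part of the benchmark, and the exact restriction may depend on the numba version): inside an @njit function I replace the newaxis indexing I would normally write with np.expand_dims.
import numpy as np
from numba import njit

@njit(cache=True)
def scale_rows(A, B):
    # In plain numpy I would write A[:, :, np.newaxis] * B here, but that
    # kind of indexing is what does not compile under @njit (per the note
    # above), so np.expand_dims is used instead.
    return np.expand_dims(A, -1) * B

C = scale_rows(np.random.random((257, 6)), np.random.random((257, 6, 6)))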
import numpy as np
import numba
from numba import njit

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

@njit(cache=True)
def func1(A, B):
    # numba + plain loop over k
    I, J, K = np.shape(B)
    C = np.zeros((I, K))
    for i in range(I):
        C[i] = A[i].dot(B[i])
    return C
    #return np.array([A[i].dot(B[i]) for i in range(len(A))])

@njit(cache=True)
def func2(A, B):
    # numba + broadcasting: expand A to (k, m, 1) and sum over m
    return (np.expand_dims(A, -1) * B).sum(1)

def func2a(A, B):
    # same expression as func2, without numba
    return (np.expand_dims(A, -1) * B).sum(1)

def func3(A, B):
    # einsum: contract the shared index m for each k
    return np.einsum('km,kmn->kn', A, B)

def func4(A, B):
    # matmul: (k, 1, m) @ (k, m, n) -> (k, 1, n), then drop the singleton axis
    return np.matmul(np.expand_dims(A, 1), B).squeeze()
C1 = func1(A, B)
C2 = func2(A, B)
C3 = func3(A, B)
C4 = func4(A, B)
print("allclose", np.allclose(C1, C2), np.allclose(C1, C3), np.allclose(C1, C4))
print("func1")
%timeit func1(A, B)
print("func2")
%timeit func2(A, B)
print("func2a")
%timeit func2a(A, B)
print("func3")
%timeit func3(A, B)
print("func4")
%timeit func4(A, B)
allclose True True True
func1
94.9 µs ± 3.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2
43.3 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2a
114 µs ± 1.53 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3
20.1 µs ± 907 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
func4
42.8 µs ± 2.87 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
\boldsymbol{A}_k^T \boldsymbol{B}_k \boldsymbol{A}_k
With this one, einsum gets slower, but it is still faster than the numba + for-loop version.
einsum has an optimize option, so I tried specifying it, but it was slower here. I also couldn't figure out how to write this case with matmul.
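For reference, here is one way I think this case could be written with matmul (my own untimed sketch; func_matmul is a made-up name and it is not part of the benchmark below): treat each $\boldsymbol{A}_k$ as a row vector on the left and a column vector on the right.
import numpy as np

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

def func_matmul(A, B):
    # (k, 1, n) @ (k, n, n) -> (k, 1, n), then (k, 1, n) @ (k, n, 1) -> (k, 1, 1)
    AB = np.matmul(np.expand_dims(A, 1), B)
    return np.matmul(AB, np.expand_dims(A, -1)).squeeze()

# Should agree with func3 below: np.allclose(func_matmul(A, B), func3(A, B))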
import numpy as np
import numba
from numba import njit

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

@njit(cache=True)
def func1(A, B):
    # numba + plain loop over k
    I, J, K = np.shape(B)
    C = np.zeros(I)
    for i in range(I):
        C[i] = A[i].dot(B[i]).dot(A[i])
    return C

@njit(cache=True)
def func2(A, B):
    # numba + broadcasting: contract over m first, then over n
    return ((np.expand_dims(A, -1) * B).sum(1) * A).sum(1)

def func2a(A, B):
    # same expression as func2, without numba
    return ((np.expand_dims(A, -1) * B).sum(1) * A).sum(1)

def func3(A, B):
    # einsum over both shared indices at once
    return np.einsum('km,kmn,kn->k', A, B, A)

def func3a(A, B):
    # same einsum with the optimize option
    return np.einsum('km,kmn,kn->k', A, B, A, optimize=True)
C1 = func1(A, B)
C2 = func2(A, B)
C3 = func3(A, B)
print("allclose", np.allclose(C1, C2), np.allclose(C1, C3))
print("func1")
%timeit func1(A, B)
print("func2")
%timeit func2(A, B)
print("func2a")
%timeit func2a(A, B)
print("func3")
%timeit func3(A, B)
print("func3a")
%timeit func3a(A, B)
allclose True True
func1
101 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2
45.4 µs ± 1.11 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2a
120 µs ± 1.45 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3
56.2 µs ± 500 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3a
139 µs ± 3.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
\boldsymbol{B}_k \boldsymbol{A}_k \boldsymbol{A}_k^T \boldsymbol{B}_k
einsum is even slower here. Specifying the contraction order (optimize) makes it much faster than plain einsum, but the numba + broadcasting expression (func2) is still the fastest.
It seems the contraction order can be precomputed with np.einsum_path. I haven't looked into it in detail.
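As far as I understand, np.einsum_path is what computes that contraction order; roughly like this (a sketch using the same shapes as below; the printed path and cost report depend on the operands):
import numpy as np

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

# einsum_path returns a contraction order plus a human-readable report;
# the returned path can be passed straight to np.einsum's optimize=
# argument, which is what func3 below does with a hard-coded path.
path, report = np.einsum_path('kab,kb,kc,kcd->kad', B, A, A, B, optimize='optimal')
print(path)    # e.g. ['einsum_path', (0, 1), (0, 1), (0, 1)]
print(report)  # per-step intermediates and estimated FLOP counts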
import numpy as np
import numba
from numba import njit

A = np.random.random((257, 6))
B = np.random.random((257, 6, 6))

@njit(cache=True)
def func1(A, B):
    # numba + plain loop: outer product of (B_k A_k) and (A_k^T B_k)
    I, J, K = np.shape(B)
    C = np.zeros((I, J, K))
    for i in range(I):
        C[i] = np.outer(B[i] @ A[i], A[i] @ B[i])
    return C

@njit(cache=True)
def func2(A, B):
    # numba + broadcasting: (B A) as (k, n, 1) times (A^T B) as (k, 1, n)
    BA = (B * np.expand_dims(A, 1)).sum(2)
    AB = (np.expand_dims(A, -1) * B).sum(1)
    return np.expand_dims(BA, -1) * np.expand_dims(AB, 1)

def func2a(A, B):
    # same expression as func2, without numba
    BA = (B * np.expand_dims(A, 1)).sum(2)
    AB = (np.expand_dims(A, -1) * B).sum(1)
    return np.expand_dims(BA, -1) * np.expand_dims(AB, 1)

def func3(A, B):
    # einsum with an explicit, precomputed contraction path
    return np.einsum('kab,kb,kc,kcd->kad', B, A, A, B,
                     optimize=['einsum_path', (0, 1), (0, 1), (0, 1)])

def func3a(A, B):
    # einsum with automatic path optimization
    return np.einsum('kab,kb,kc,kcd->kad', B, A, A, B, optimize=True)
C1 = func1(A, B)
C2 = func2(A, B)
C3 = func3(A, B)
print("allclose", np.allclose(C1, C2), np.allclose(C1, C3))
print("func1")
%timeit func1(A, B)
print("func2")
%timeit func2(A, B)
print("func2a")
%timeit func2a(A, B)
print("func3")
%timeit func3(A, B)
print("func3a")
%timeit func3a(A, B)
allclose True True
func1
335 µs ± 7.15 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
func2
97.8 µs ± 1.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func2a
246 µs ± 20.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
func3
154 µs ± 3.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
func3a
250 µs ± 6.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)