Inspiriert von Das explizite Schreiben einer Schleife mit Numpy ist extrem langsam
In dem obigen Artikel wurde gesagt, dass das explizite Schreiben für extrem langsam sein würde. Zum Beispiel
def matmul1(a, b):
lenI = a.shape[0]
lenJ = a.shape[1]
lenK = b.shape[1]
c = np.zeros((lenI, lenJ))
for i in range(lenI):
for j in range(lenJ):
for k in range(lenK):
c[i, j] += a[i, k] * b[k, j]
return c
Code like ist langsamer als np.dot
.
%timeit matmul1(a, b)
1 loops, best of 3: 12.9 s per loop
%timeit np.dot(a, b)
10 loops, best of 3: 20.7 ms per loop
Es ist langsam, weil es auf meinem Notebook-PC berechnet wird. Außerdem ist atlas / mkl nicht verknüpft.
Verwenden Sie hier Numba.
import numba
@numba.jit #Nur hier hinzufügen
def matmul1_jit(a, b):
lenI = a.shape[0]
lenJ = a.shape[1]
lenK = b.shape[1]
c = np.zeros((lenI, lenJ))
for i in range(lenI):
for j in range(lenJ):
for k in range(lenK):
c[i, j] += a[i, k] * b[k, j]
return c
Dies ist eine JIT-Kompilierung von Python-Code mit LLVM, sodass sie sehr schnell ausgeführt werden kann. Der erste Aufruf enthält die Zeit zum Kompilieren. Wenn Sie also die Geschwindigkeit bei nachfolgenden Aufrufen messen:
%timeit matmul1_jit(a, b)
10 loops, best of 3: 24.4 ms per loop
Nur eine Zeile wie diese hinzuzufügen, machte es ungefähr gleich wie "np.dot" (ungefähr 20% langsamer).
Platzieren Sie das gesamte ipynb in gist. Ich wünschte, ich könnte nbviewer in Qiita einbetten.
Recommended Posts