Inspiré par L'écriture explicite d'une boucle avec numpy est extrêmement lente
Dans l'article ci-dessus, il a été dit qu'écrire explicitement serait extrêmement lent. Par exemple
def matmul1(a, b):
lenI = a.shape[0]
lenJ = a.shape[1]
lenK = b.shape[1]
c = np.zeros((lenI, lenJ))
for i in range(lenI):
for j in range(lenJ):
for k in range(lenK):
c[i, j] += a[i, k] * b[k, j]
return c
Le code comme est plus lent que np.dot
.
%timeit matmul1(a, b)
1 loops, best of 3: 12.9 s per loop
%timeit np.dot(a, b)
10 loops, best of 3: 20.7 ms per loop
C'est lent car il est calculé sur mon ordinateur portable. De plus, atlas / mkl n'est pas lié.
Utilisez maintenant Numba.
import numba
@numba.jit #Ajouter seulement ici
def matmul1_jit(a, b):
lenI = a.shape[0]
lenJ = a.shape[1]
lenK = b.shape[1]
c = np.zeros((lenI, lenJ))
for i in range(lenI):
for j in range(lenJ):
for k in range(lenK):
c[i, j] += a[i, k] * b[k, j]
return c
Il s'agit d'une compilation JIT de code Python utilisant LLVM, donc elle peut s'exécuter très rapidement. Le premier appel comprend le temps de compilation, donc si vous mesurez la vitesse sur les appels suivants:
%timeit matmul1_jit(a, b)
10 loops, best of 3: 24.4 ms per loop
Le simple fait d'ajouter une ligne comme celle-ci le rendait à peu près identique à np.dot
(environ 20% plus lent).
Placez le [ipynb] entier (http://nbviewer.ipython.org/gist/termoshtt/824ff3e766de5fe9fdd6) dans gist. J'aimerais pouvoir intégrer nbviewer dans Qiita.
Recommended Posts