Schreiben Sie den Code, um den Gaußschen Kernel zu berechnen, der im Gaußschen Prozess usw. mit Baku-Geschwindigkeit berechnet werden muss.
Geschrieben in Numba, um zu beschleunigen. Es ist schade, dass der Code selbst aufgrund der von Numba auferlegten Einschränkungen etwas überflüssig wird, aber der Effekt der Beschleunigung ist enorm. Es gab Einschränkungen wie das Schreiben durch Multiplizieren des Quadrats, um mit Numba zu beschleunigen, und das Nichtverwenden der Numpy-Funktion in der von mir definierten Funktion.
from numba import jit, void, f8
import numpy as np
import time
@jit(void(f8[:, :], f8[:, :]))
def gauss_gram_mat(x, K):
n_points = len(x)
n_dim = len(x[0])
b = 0
sgm = 0.2
for j in range(n_points):
for i in range(n_points):
for k in range(n_dim):
b = (x[i][k] - x[j][k]) / sgm
K[i][j] += b * b
def gauss_gram_mat_normal(x, K):
n_points = len(x)
n_dim = len(x[0])
b = 0
sgm = 0.2
for j in range(n_points):
for i in range(n_points):
for k in range(n_dim):
b = (x[i][k] - x[j][k]) / sgm
K[i][j] += b * b
n_dim = 10
n_points = 2000
x = np.random.rand(n_points, n_dim)
K = np.zeros((n_points, n_points))
start = time.time()
gauss_gram_mat(x, K)
K = np.exp(- K / 2)
print("Namba: {}".format(time.time() - start))
start = time.time()
gauss_gram_mat_normal(x, K)
K = np.exp(- K / 2)
print("Normal: {}".format(time.time() - start))
Obwohl es nur ein Muster gibt, haben wir die Berechnungsgeschwindigkeit mit dem normalen Code und dem Code von Numba in Bezug auf die Anzahl der Punkte und die Anzahl der oben genannten Dimensionen verglichen.
Anscheinend ist es fast 500 mal schneller. (Wenn Sie die Einschlussnotation usw. verwenden, ist es ohne Numba schneller, aber es ist bisher unmöglich.)
Numba: 0.11480522155761719
Normal: 50.70034885406494
Ich habe es auch mit Numpy überprüft.
import numpy as np
import time
n_dim = 10
n_points = 2000
sgm = 0.2
x = np.random.rand(n_points, n_dim)
now = time.time()
K = np.exp(- 0.5 * (((x - x[:, None]) / sgm) ** 2).sum(axis=2))
print("Numpy: {}".format(time.time() - start))
Das Ergebnis ist, dass Numba schneller als Numpy ist.
Numpy: 0.3936312198638916