[PYTHON] Schreiben Sie mit Broadcast präzise Operationen für jedes Paar in die Daten

In 3 Zeilen

Problemstellung

Sei $ A $ die Menge aller $ (x, y) $, die $ -1 \ leq x \ leq 1 $, $ -1 \ leq y \ leq 1 $ erfüllen. Wenn ein Element aus $ A $ ausgewählt wird, wird der L1-Abstand zwischen den in $ A $ enthaltenen Gitterpunkten und dem ausgewählten Element berechnet.

Der L1-Abstand zwischen $ (x_1, y_1) $ und $ (x_2, y_2) $ wird wie folgt ausgedrückt.

d_1((x_1, y_1),(x_2,y_2)) = |x_1 - x_2|+|y_1 - y_2|

Denkweise

Berechnen Sie den L1-Abstand zwischen dem ausgewählten Element und allen Gitterpunkten und wählen Sie 5 in aufsteigender Reihenfolge des Abstands aus.

Um den L1-Abstand auf einmal zu berechnen, betrachten Sie den folgenden np.ndarray (oder tf.tensor).

lattice=np.array([[ 1,  1],
                  [ 1,  0],
                  [ 1, -1],
                  [ 0,  1],
                  [ 0,  0],
                  [ 0, -1],
                  [-1,  1],
                  [-1,  0],
                  [-1, -1]]) #shape = (9, 2)

Angenommen, das von Ihnen ausgewählte Original war $ (0.1,0.5) $. Tatsächlich ist die folgende Notation zur Berechnung des L1-Abstands wirksam

data = np.array([0.1,0.5])
l1_dist = np.sum(np.abs(data-lattice),axis=1)

Auf den ersten Blick sieht es wie eine normale Formel aus, aber der Teil "Datengitter" wird mit unterschiedlichen Formen voneinander subtrahiert. Hier werden die beiden Formen durch Rundfunk automatisch angepasst.

(Referenz: https://numpy.org/doc/stable/user/basics.broadcasting.html)

Gemäß der Formel werden die Bemaßungen von der Rückseite der Form verglichen, und wenn eine Bemaßung 1 ist, wird die Bemaßung durch Kopieren gemäß der anderen vergrößert. In diesem Fall ist die Form von "Daten" (2,) "und die" Form "von" Gitter "ist" (9,2) ", so dass die Abmessung auf der" Datenseite "angepasst wird.

array([[0.1, 0.5],
       [0.1, 0.5],
       [0.1, 0.5],
       [0.1, 0.5],
       [0.1, 0.5],
       [0.1, 0.5],
       [0.1, 0.5],
       [0.1, 0.5],
       [0.1, 0.5]])

Wurde berücksichtigt und die Subtraktion durchgeführt. Verwenden Sie dann "np.abs", um den absoluten Wert des Elements zu berechnen und entlang der entsprechenden "Achse" zu summieren. l1_dist hat die folgende Form als np.ndarray von(9,)

array([1.4, 1.4, 2.4, 0.6, 0.6, 1.6, 1.6, 1.6, 2.6])

Anwendung auf die Stapelverarbeitung

Die gleiche Idee kann erhalten werden, indem die Anzahl der Elemente, für die der L1-Abstand berechnet wird, auf zwei oder mehr erhöht wird. Nehmen wir an, dass es zwei Zielelemente gibt, und jedes ist $ (0,1,0,5) und (0,7,0,8) $. Dieses Mal werden wahrscheinlich "Daten" im folgenden Format geliefert

data = np.array([[0.1, 0.5],
                 [0.7, 0.8]]) # shape = (2,2)

In diesem Fall erzeugt "Datengitter" keine Sendung und es tritt ein Fehler auf. Dies liegt daran, dass die Abmessungen hinter der Form verglichen werden und es nicht mehr der Fall ist, wenn eine Abmessung 1 ist. Die Problemumgehung besteht darin, eine Achse mit der Dimension 1 in "np.expand_dims" hinzuzufügen.

data = np.expand_dims(data,axis=1) #Datenform= (2,1,2)von Gitter(9,2)Vergleiche mit der Achse=9 Daten von 1 sind dupliziert, Achse=Zwei 0-Gitter werden dupliziert
l1_dist = np.sum(np.abs(data-lattice),axis=2) # (2,9,2)Nach Subtraktion untereinander Achse=2 Summen sind erledigt. Beachten Sie, dass sich die Summenachse aufgrund der Erweiterung geändert hat

Wenn ja, ist l1_dist

array([[1.4, 1.4, 2.4, 0.6, 0.6, 1.6, 1.6, 1.6, 2.6],
       [0.5, 1.1, 2.1, 0.9, 1.5, 2.5, 1.9, 2.5, 3.5]]) # shape = (2, 9)

Wird sein.

Verschwommen

Nun, der Code ist einfacher, aber ich denke, die implizite Änderung der Form würde die Lesbarkeit beeinträchtigen. Ich würde gerne wissen, ob es eine bessere Vorgehensweise gibt, die Verarbeitungszeit, Lesbarkeit und Einfachheit kombiniert.

Referenz

TensorFlow Machine Learning Kochbuch Python-basiertes Verwendungsrezept 60+ (Ich empfehle jedoch nicht, dieses Buch zu kaufen ...)

Recommended Posts

Schreiben Sie mit Broadcast präzise Operationen für jedes Paar in die Daten
Online-Übertragung mit Python
Versuchen Sie, sich mit Python auf Ihrem PC automatisch bei Netflix anzumelden
Verwenden Sie ScraperWiki, um regelmäßig Daten von Ihrer Website abzurufen