[PYTHON] Ich habe versucht, das RWA-Modell (Recurrent Weighted Average) in Keras zu trainieren

Was ist RWA (Recurrent Weighted Average)?

image.png

Klicken Sie hier für das Papier (Maschinelles Lernen für sequentielle Daten unter Verwendung eines wiederkehrenden gewichteten Durchschnitts) In der obigen Abbildung ist c ein schematisches Diagramm des RWA-Modells (a ist ein normales LSTM, b ist ein LSTM mit Aufmerksamkeit).

RWA ist eine der Ableitungen von wiederkehrenden neuronalen Netzen (* RNN *), die Seriendaten verarbeiten. In dem vorgeschlagenen Papier wird im Vergleich zu LSTM, das häufig als Implementierung von RNN verwendet wird,

Und ** alle guten Dinge ** sind geschrieben. Ich war überrascht über die Stärke des Anspruchs und die Einfachheit der Architektur und fragte mich, ob sie das LSTM, das mittlerweile fast der De-facto-Standard ist, wirklich übertreffen könnte, also diesmal [Keras-Implementierung von RWA](https: // gist. Ich habe github.com/shamatar/55b804cf62b8ee0fa23efdb3ea5a4701) ein wenig umgeschrieben und versucht, einige der Experimente in der Arbeit zu reproduzieren.

RWA-Architektur

Sie können sich RWA als eine Verallgemeinerung der Aufmerksamkeit und eine rekursive Neudefinition vorstellen, die sie in die Struktur des RNN einbezieht. Mit anderen Worten, Aufmerksamkeit (in RNN) ist ein Sonderfall von RWA.

Lass uns genauer hinschauen. Alle RNNs, nicht nur LSTMs, sind Netzwerke, die Seriendaten verarbeiten. Da die Seriendaten leicht verarbeitet werden können, indem der Markov-Prozess angenommen wird (= der aktuelle Zustand wird nur durch die aktuellen Daten und den vergangenen Zustand bestimmt), geben Sie die "aktuellen Daten" und den "vergangenen Zustand" in RNN ein. Es wird rekursiv modelliert, so dass es den "aktuellen Zustand" ausgibt. Wenn Sie in die Formel schreiben, $ h_t = Recurrent(x_t, h_{t-1}) $ ist. $ h_t $ ist der aktuelle Zustand von t, $ x_t $ sind die aktuellen Daten und $ h_ {t-1} $ ist der vergangene Zustand. Die neuronale Netzwerkimplementierung dieser Funktion $ Recurrent $ ist RNN, und die Komplexität der $ Recurrent $ -Funktion durch Einführung verschiedener Gatter ist LSTM. [^ 1]

Wie Sie in der oberen Abbildung sehen können, ist das Aufmerksamkeitsmodell jedoch nicht rekursiv definiert und kann nicht in Form der Funktion $ Recurrent $ dargestellt werden. [^ 2] RWA betrachtet Aufmerksamkeit als einen gleitenden Durchschnitt vergangener Zustände und reduziert sie durch äquivalente Transformation des Ausdrucks auf die Form der $ Recurrent $ -Funktion.

Insbesondere nimmt RWA einen gleitenden Durchschnitt vergangener Zustände wie folgt: image.png f ist eine geeignete Aktivierungsfunktion, bei der z die Begriffe regelt, die Zustände rekursiv transformieren, und a regelt, wie viel Gewicht die vergangenen Zustände gemittelt werden (a entspricht der Aufmerksamkeit). Machen). Wenn diese Formel unverändert bleibt, kann sie nicht als rekursiv bezeichnet werden, da Σ die Operation "Hinzufügen vergangener Zustände von 1 zu t" umfasst. Irgendwie möchte ich Gleichung (2) nur im "vorherigen Zustand" neu definieren.

Teilen wir nun das Innere von f in Gleichung (2) in den Nenner d und das Molekül n. image.png Dann wissen wir, dass n und d kumulative Summen sind, sodass wir sie wie folgt umschreiben können: image.png Zu diesem Zeitpunkt wurden n und d in ein Format umgewandelt, das nur vom vorherigen Zeitpunkt abhängt. Das ist alles für das Wesentliche.

Danach wird z geringfügig von der normalen RNN geändert, und der Ausdruck wird in den Term u (dh Einbettung) unterteilt, der nur die Eingabe sieht, und den Term g, der auch den Zustand sieht. image.png Ist die Mathematik der RWA. [^ 3]

RWA hat fast die gleiche Struktur wie das einfachste RNN. Da RWA ursprünglich von der Form "Verweisen auf alle vergangenen Zustände" ausgegangen ist, wird erwartet, dass der Zustand jederzeit durch Verweisen auf den vergangenen Zustand aktualisiert werden kann, selbst wenn kein Vergessens- oder Ausgangsgatter wie LSTM enthalten ist. Getan werden.

Implementierung

Wir haben mit Code experimentiert, der die von einem Drittanbieter veröffentlichte Keras-Implementierung von RWA (https://gist.github.com/shamatar/55b804cf62b8ee0fa23efdb3ea5a4701) so geändert hat, dass der Parameter "return_sequences" gültig ist. Klicken Sie hier für geänderten Code und Experiment / Visualisierungsskript (return_sequences ist ein Parameter, mit dem Sie festlegen können, ob der Verlauf aller vergangenen Zustände und nicht nur des letzten Zustands in Keras 'Recurrent Layer ausgegeben werden soll. Ohne diesen Parameter können Sie den Zustand später nicht visualisieren.)

Experiment

Das am einfachsten zu implementierende der im Papier erwähnten Experimente

Wir haben Experimente mit zwei Typen durchgeführt.

Classifying by Sequence Length Es ist ein Problem zu beurteilen, "ob die Länge der angegebenen Seriendaten eine bestimmte Länge überschreitet?". Bereiten Sie einen Vektor vor, dessen Länge sich zufällig im Bereich von 0 oder mehr und 1000 oder weniger ändert. Wenn die Länge des Vektors 500 überschreitet, wird er als 1 beurteilt, andernfalls wird er als 0 beurteilt. Der Wert jedes Elements des Vektors wird in geeigneter Weise aus der Normalverteilung entnommen und ausgefüllt (Hinweis: Der Elementwert des Vektors hängt nicht mit diesem Problem zusammen. Es ist die Länge des Vektors, die mit diesem Problem zusammenhängt.) Die Zielfunktion ist binary_crossentropy. In der Arbeit wurde die Mini-Batch-Größe auf 100 festgelegt. Da es jedoch schwierig war, Daten mit unterschiedlichen Serienlängen in denselben Batch zu integrieren, wurde die Batch-Größe im Experiment für dieses Problem auf 1 festgelegt (dies nimmt viel Zeit in Anspruch). Die folgenden Ergebnisse wurden in etwa 12 Stunden unter Verwendung einer GPU erhalten.

Die experimentellen Ergebnisse sind wie folgt (Vertikale Achse: Genauigkeit (höher ist besser), Horizontale Achse: Anzahl der Epochen)

Während der Verarbeitung der Daten habe ich mich gefragt, wie der Status von RWA ist, also habe ich ihn auch aufgezeichnet. Die vertikale Achse ist die Zeitdimension und die horizontale Achse ist die Zustandsdimension (250). 1000.png

Die obige Abbildung ist ein Beispiel mit einer Serienlänge von 1000 (dh ich möchte, dass das Vorhersageergebnis "1" ist). In diesem Fall konnte ich richtig vorhersagen. Betrachtet man die Darstellung des Zustands, so scheint sich der Zustand zu ändern, wenn die Serienlänge nahe bei 500 liegt, und es scheint, dass der Zustand wie eine Abstufung in der Zeitrichtung als Ganzes ist. Anscheinend konnte ich richtig lernen. Ich habe verschiedene Tests mit unterschiedlichen Längen der Serie versucht, aber die Genauigkeit verschlechterte sich stark, wenn die Serienlänge etwa 500 betrug, aber die Genauigkeit betrug 100% für die Serien, die extrem kurz oder lang waren. (Die obige Abbildung ist auch ein Beispiel für eine extrem lange Serienlänge.)

Adding Problem image.png

Das Problem besteht darin, "einen Vektor geeigneter Länge vorzubereiten und die Werte von zwei zufällig ausgewählten Stellen zu addieren". Die dem Modell gegebenen Daten sind zwei Vektoren der Länge n. Einer ist ein reeller Zahlenvektor, der andere ist ein Vektor, bei dem 1s nur an zwei Stellen stehen und der Rest 0s ist. Lassen Sie sie lernen, die reellen Zahlen zu addieren, in denen 1 steht. Die Zielfunktion ist MSE. Dieses Problem wurde mit einer Mini-Chargengröße von 100 gemäß Papier experimentiert. Die Versuchszeit beträgt mit GPU weniger als eine Stunde.

Die experimentellen Ergebnisse sind wie folgt (Vertikale Achse: MSE (niedriger ist besser), horizontale Achse: Anzahl der Epochen)

Bitte beachten Sie, dass sich der Maßstab auf der horizontalen Achse geändert hat (da ich mit 1epoch = 100batch experimentiert habe, führt das Multiplizieren des Werts auf der horizontalen Achse mit 100 zum gleichen Maßstab wie das Originalpapier). In Bezug auf RWA konnte ich die Ergebnisse des Papiers reproduzieren. LSTM lieferte bei einer Länge von 100 die gleichen Ergebnisse wie das Papier, lernte jedoch bei einer Länge von 1000 nicht gut. Verbessert sich die Genauigkeit im Vergleich zur Konvergenz von LSTM aufgrund des Originalpapiers mit zusätzlichen 100 Lernphasen für eine Reihe von 1000?

RWA behauptet auch, dass "Sie jedes Problem (innerhalb des von Ihnen versuchten Bereichs) lösen können, ohne sich mit Hyperparametern oder Initialisierungseinstellungen herumschlagen zu müssen", sodass nur RWA die Ergebnisse des Papiers auf einmal reproduzieren kann. Es kann als Folgeprüfung wünschenswerter sein.

Der Zustand von RWA ist wie folgt (ein Graph entspricht einer Stichprobe) Die vertikale Achse ist die Zeit (100 oder 1000) und die horizontale Achse ist die Zustandsdimension (250). Wo oben geschrieben, sind die Daten angegeben, wo sich das richtige Flag befand.

0.png 94.png

72.png 12.png

Wenn Sie die Elemente finden, die hinzugefügt werden sollen (dh wo), können Sie sehen, dass sich einige Dimensionen des Status schnell ändern. Es scheint sicher, dass gelernt wird, damit die in den Seriendaten enthaltenen Ereignisse erkannt werden können.

Zusammenfassung

Persönlich bin ich der Meinung, dass RWA viel einfacher und verständlicher ist als LSTM und dass dies eine gute Möglichkeit ist, intuitive Ideen zu verwirklichen. In dem vorgeschlagenen Papier wird nur der einfachste Vergleich mit LSTM durchgeführt, und das Problem besteht darin, wie es mit LSTM mit Aufmerksamkeit verglichen wird und ob Schichten gestapelt werden, um es mehrschichtig (gestapelt) zu machen, wie dies häufig bei LSTM der Fall ist. Ich weiß immer noch nicht, was passieren wird. (Die Situation, die auf das Aufmerksamkeitsmodell angewendet werden kann, ist jedoch begrenzt, und da RWA wie eine Verallgemeinerung der Aufmerksamkeit ist, kann es nicht verglichen werden ...) Ich denke, wenn in Zukunft mehr Forschung betrieben wird, könnte es möglich sein, LSTM zu ersetzen und RWA als De-facto-Standard zu verwenden.

[^ 1]: Es heißt AR, wenn es durch ein lineares Modell ohne Zustand ausgedrückt wird, und es wird als Zustandsraummodell bezeichnet, wenn es durch ein verstecktes Markov-Modell ausgedrückt wird, in dem die Gleichung explizit geschrieben ist. [^ 2]: Weil es von allen vergangenen Zuständen abhängt, nicht nur vom vorherigen Zustand. [^ 3]: Um den numerischen Fehler zu reduzieren, werden n und d hinsichtlich der Implementierung in Äquivalente umgewandelt. Einzelheiten finden Sie in Anhang B des Dokuments.

Recommended Posts

Ich habe versucht, das RWA-Modell (Recurrent Weighted Average) in Keras zu trainieren
Ich habe das VGG16-Modell mit Keras implementiert und versucht, CIFAR10 zu identifizieren
Ich habe versucht, Keras in TFv1.1 zu integrieren
Ich habe versucht, TOPIC MODEL in Python zu implementieren
Ich habe versucht, die beim maschinellen Lernen verwendeten Bewertungsindizes zu organisieren (Regressionsmodell).
Ich habe versucht, die in Python installierten Pakete grafisch darzustellen
Ich habe versucht, das grundlegende Modell des wiederkehrenden neuronalen Netzwerks zu implementieren
Ich habe versucht, mit TensorFlow den Durchschnitt mehrerer Spalten zu ermitteln
Ich habe versucht, die Zugverspätungsinformationen mit LINE Notify zu benachrichtigen
Ich habe versucht, den in Pandas häufig verwendeten Code zusammenzufassen
Ich habe versucht, die Zeit und die Zeit der C-Sprache zu veranschaulichen
Ich habe versucht, die im Geschäftsleben häufig verwendeten Befehle zusammenzufassen
Ich habe versucht, den Ball zu bewegen
Ich habe versucht, den Abschnitt zu schätzen.
Ich habe versucht, den Datenverkehr mit WebSocket in Echtzeit zu beschreiben
Ich habe versucht, das Bild mit OpenCV im "Skizzenstil" zu verarbeiten
Ich habe versucht, das Bild mit OpenCV im "Bleistift-Zeichenstil" zu verarbeiten
Ich habe versucht, PLSA in Python zu implementieren
Ich habe versucht, den Befehl umask zusammenzufassen
Ich habe versucht, Permutation in Python zu implementieren
Ich habe versucht, PLSA in Python 2 zu implementieren
Ich habe versucht, die grafische Modellierung zusammenzufassen.
Ich habe versucht, ADALINE in Python zu implementieren
Ich habe versucht, das Umfangsverhältnis π probabilistisch abzuschätzen
Ich habe versucht, die COTOHA-API zu berühren
Ich habe versucht, PPO in Python zu implementieren
[Python] Ich habe versucht, den kollektiven Typ (Satz) auf leicht verständliche Weise zusammenzufassen.
Ich habe versucht, das Modell mit der Low-Code-Bibliothek für maschinelles Lernen "PyCaret" zu visualisieren.
Ich habe versucht, den Höhenwert von DTM in einem Diagramm anzuzeigen
Ich habe versucht, das Verhalten des neuen Koronavirus mit dem SEIR-Modell vorherzusagen.
Ich habe Web Scraping versucht, um die Texte zu analysieren.
[Python] Ich habe versucht, das Mitgliederbild der Idolgruppe mithilfe von Keras zu beurteilen
Ich habe versucht, GAN (mnist) mit Keras zu bewegen
Ich habe versucht, beim Trocknen der Wäsche zu optimieren
Ich habe versucht, die Daten mit Zwietracht zu speichern
Ich habe versucht, "Birthday Paradox" mit Python zu simulieren
Ich habe die Methode der kleinsten Quadrate in Python ausprobiert
Ich habe versucht, die Trapezform des Bildes zu korrigieren
Ich habe versucht, die Veränderung der Schneemenge für 2 Jahre durch maschinelles Lernen vorherzusagen
Ich habe versucht, eine selektive Sortierung in Python zu implementieren
LeetCode Ich habe versucht, die einfachen zusammenzufassen
Ich habe versucht, das Problem des Handlungsreisenden umzusetzen
Ich möchte den Fortschritt in Python anzeigen!
Ich habe versucht, ein Modell mit dem Beispiel von Amazon SageMaker Autopilot zu erstellen
Ich habe versucht, die Texte von Hinatazaka 46 zu vektorisieren!
Ich habe "Lobe" ausprobiert, mit dem das von Microsoft veröffentlichte Modell des maschinellen Lernens problemlos trainiert werden kann.
Ich habe versucht, den Text in der Bilddatei mit Tesseract der OCR-Engine zu extrahieren
Ich habe versucht, die neuen mit dem Corona-Virus infizierten Menschen in Ichikawa City, Präfektur Chiba, zusammenzufassen
Ich habe versucht, HULFT IoT (Agent) in das Gateway Rooster von Sun Electronics zu integrieren
[Erste Datenwissenschaft ⑥] Ich habe versucht, den Marktpreis von Restaurants in Tokio zu visualisieren
Ich habe versucht, Iris aus dem Kamerabild zu erkennen
Ich habe versucht, die Grundform von GPLVM zusammenzufassen
Ich habe versucht, eine CSV-Datei mit Python zu berühren
Ich habe versucht, Soma Cube mit Python zu lösen
Ich habe versucht, einen Pseudo-Pachislot in Python zu implementieren
Ich habe einen Code erstellt, um illustration2vec in ein Keras-Modell zu konvertieren
Ich habe versucht, Drakues Poker in Python zu implementieren