[PYTHON] J'ai essayé de former le modèle RWA (Recurrent Weighted Average) dans Keras

Qu'est-ce que la RWA (moyenne pondérée récurrente)?

image.png

Cliquez ici pour l'article (Machine Learning on Sequential Data Using a Recurrent Weighted Average) Dans la figure ci-dessus, c est un diagramme schématique du modèle RWA (a est un LSTM normal, b est un LSTM avec attention).

RWA est l'un des dérivés des réseaux neuronaux récurrents (* RNN *) qui gère les données en série. Dans le document proposé, par rapport au LSTM, qui est souvent utilisé comme implémentation de RNN,

Et, ** toutes les bonnes choses ** sont écrites. J'ai été surpris de la force de la revendication et de la simplicité de l'architecture, et je me suis demandé s'il pouvait vraiment surpasser le LSTM, qui est maintenant presque le standard de facto, donc cette fois [implémentation Keras de RWA](https: // gist. J'ai réécrit un peu github.com/shamatar/55b804cf62b8ee0fa23efdb3ea5a4701) et j'ai essayé de reproduire certaines des expériences de l'article.

Architecture RWA

Vous pouvez considérer RWA comme une généralisation de l'attention et une redéfinition récursive qui l'intègre dans la structure du RNN. En d'autres termes, l'attention (en RNN) est un cas particulier de RWA.

Regardons de plus près. Tous les RNN, pas seulement les LSTM, sont des réseaux qui traitent des données en série. Puisque les données de série peuvent être facilement traitées en supposant le processus de Markov (= l'état actuel n'est déterminé que par les données actuelles et l'état passé), entrez «données actuelles» et «état passé» dans RNN. Il est modélisé de manière récursive afin de sortir "l'état actuel". Si vous écrivez dans la formule, $ h_t = Recurrent(x_t, h_{t-1}) $ est. $ h_t $ est l'état actuel de t, $ x_t $ est la donnée courante et $ h_ {t-1} $ est l'état passé. L'implémentation du réseau neuronal de cette fonction $ Recurrent $ est RNN, et la complexité de la fonction $ Recurrent $ en introduisant diverses portes est LSTM. [^ 1]

Cependant, comme vous pouvez le voir sur la figure du haut, le modèle Attention n'est pas défini de manière récursive et ne peut pas être représenté sous la forme de la fonction $ Recurrent $. [^ 2] RWA considère l'attention comme une moyenne mobile des états passés, et en transformant l'expression de manière équivalente, il la réduit à la forme de la fonction $ Recurrent $.

Plus précisément, RWA prend une moyenne mobile des états passés comme suit: image.png f est une fonction d'activation appropriée, où z régit les termes qui transforment récursivement les états, et a régit le poids moyen des états passés (a correspond à l'attention). Faire). Si cette formule est laissée telle quelle, on ne peut pas dire qu'elle est récursive car Σ inclut l'opération "d'addition d'états passés de 1 à t". D'une manière ou d'une autre, je veux redéfinir l'équation (2) uniquement dans "l'état précédent".

Maintenant, divisons l'intérieur de f dans l'équation (2) en le dénominateur d et la molécule n. image.png Ensuite, nous savons que n et d sont des sommes cumulatives, nous pouvons donc les réécrire comme suit: image.png À ce stade, n et d ont été transformés en un format qui ne dépend que du moment précédent. C'est tout pour l'essence.

Après cela, z est légèrement modifié par rapport au RNN ordinaire et l'expression est divisée en le terme u (c'est-à-dire, l'incorporation) qui ne voit que l'entrée et le terme g qui voit également l'état. image.png Est-ce que les mathématiques de RWA. [^ 3]

RWA a presque la même structure que le RNN le plus simple. Puisque RWA a commencé à l'origine sous la forme de "se référant à tous les états passés", on s'attend à ce que l'état puisse être mis à jour en se référant à l'état passé à tout moment, même s'il n'y a pas de porte Oublie ou de porte de sortie à l'intérieur comme LSTM. Sera fait.

la mise en oeuvre

Nous avons expérimenté du code qui modifiait l'implémentation Keras de RWA publié par un tiers (https://gist.github.com/shamatar/55b804cf62b8ee0fa23efdb3ea5a4701) afin que le paramètre return_sequences soit valide. Cliquez ici pour le code modifié et le script d'expérimentation / visualisation (return_sequences est un paramètre qui vous permet de définir si vous souhaitez afficher l'historique de tous les états passés, pas seulement le dernier état, dans la couche récurrente de Keras. Sans cela, vous ne pourrez pas visualiser l'état plus tard.)

Expérience

La plus simple des expériences mentionnées dans l'article à mettre en œuvre

Nous avons mené des expériences avec deux types.

Classifying by Sequence Length C'est un problème de juger "si la longueur des données de la série donnée dépasse une certaine longueur?". Préparez un vecteur dont la longueur change de façon aléatoire dans la plage de 0 ou plus et 1000 ou moins, et si la longueur du vecteur dépasse 500, il est jugé comme 1, sinon il est jugé comme 0. La valeur de chaque élément du vecteur est correctement extraite de la distribution normale et renseignée (Remarque: la valeur d'élément du vecteur n'est pas liée à ce problème. C'est la longueur du vecteur qui est liée à ce problème) La fonction objectif est binary_crossentropy. Dans le papier, la taille du mini-lot a été définie sur 100, mais comme il était difficile d'incorporer des données de différentes longueurs de série dans le même lot, la taille du lot a été définie sur 1 dans l'expérience pour ce problème (cela prend beaucoup de temps). Les résultats suivants ont été obtenus en environ 12 heures en utilisant le GPU).

Les résultats expérimentaux sont les suivants (Axe vertical: précision (la plus élevée est meilleure), axe horizontal: nombre d'époques)

--Résultats de l'article image.png

--Résultats de cette expérience seq_length.png En raison de contraintes de temps, LSTM est toujours en cours d'apprentissage, mais c'était le même résultat que le résultat de l'article en ce sens que RWA converge extrêmement rapidement (combien d'échantillons sont appris et convergés car la taille du lot est différente. Je ne peux pas dire).

Lors du traitement des données, je me demandais quel était l'état de RWA, alors je l'ai également tracé. L'axe vertical est la dimension temporelle et l'axe horizontal est la dimension d'état (250). 1000.png

La figure ci-dessus est un exemple avec une longueur de série de 1000 (c'est-à-dire que je veux que le résultat de la prédiction soit "1"). Dans ce cas, j'ai pu prédire correctement. En regardant le tracé de l'état, il semble que l'état change lorsque la longueur de la série est proche de 500, et il semble que l'état est comme une gradation dans la direction du temps dans son ensemble. Apparemment, j'ai pu apprendre correctement. J'ai essayé différents tests avec différentes longueurs de série, mais la précision se détériorait fortement lorsque la longueur de la série était d'environ 500, mais la précision était de 100% pour les séries extrêmement courtes ou longues. (La figure ci-dessus est également un exemple de longueur de série extrêmement longue)

Adding Problem image.png

Le problème est de "préparer un vecteur de longueur appropriée et d'ajouter les valeurs de deux emplacements choisis au hasard". Les données fournies au modèle sont deux vecteurs de longueur n. L'un est un vecteur de nombres réels, l'autre est un vecteur avec des 1 à seulement deux endroits et le reste étant des 0. Laissez-les apprendre à additionner les nombres réels où 1 se situe. La fonction objectif est MSE. Ce problème a été expérimenté avec une taille de lot mini de 100 selon le papier. La durée de l'expérience est inférieure à une heure en utilisant le GPU.

Les résultats expérimentaux sont les suivants (Axe vertical: MSE (le plus bas est meilleur), axe horizontal: nombre d'époques)

--Résultats de l'article image.png

--Résultats de cette expérience (longueur 100) addtion100.png

--Résultats de cette expérience (longueur 1000) addition1000.png

Veuillez noter que l'échelle sur l'axe horizontal a changé (depuis que j'ai expérimenté avec 1epoch = 100batch, multiplier la valeur sur l'axe horizontal par 100 donnera la même échelle que le papier d'origine). En ce qui concerne RWA, j'ai pu reproduire les résultats de l'article. LSTM a donné les mêmes résultats que le papier pour une longueur de 100, mais n'a pas bien appris pour une longueur de 1000. Par rapport à la convergence du LSTM suite à l'article original, la précision commence-t-elle à s'améliorer avec une période d'apprentissage supplémentaire de 100 pour une série de longueur 1000?

RWA affirme également que "vous pouvez résoudre n'importe quel problème (dans la plage que vous avez essayée) sans avoir à jouer avec les hyperparamètres ou les paramètres d'initialisation", donc seul RWA peut reproduire les résultats du papier d'un seul coup. Il peut être plus souhaitable comme examen de suivi.

L'état de RWA est le suivant (un graphique correspond à un échantillon) L'axe vertical est le temps (100 ou 1000) et l'axe horizontal est la dimension d'état (250). Où est écrit au-dessus de la figure les données de l'emplacement du drapeau correct.

0.png 94.png

72.png 12.png

Lorsque vous trouvez les éléments à additionner (c'est-à-dire où), vous pouvez voir que certaines des dimensions de l'état changent rapidement. Certes, il semble que l'apprentissage se fasse pour que les événements contenus dans les données de la série puissent être détectés.

Résumé

Personnellement, je pense que RWA est beaucoup plus simple et plus facile à comprendre que LSTM, et que c'est un bon moyen de réaliser des idées intuitives. Dans le document proposé, seule la comparaison la plus simple avec LSTM est faite, et le problème est de savoir comment il se compare avec LSTM avec attention, et si les couches sont empilées pour le rendre multicouche (empilé) comme cela se fait souvent avec LSTM. Je ne sais toujours pas ce qui va se passer. (Cependant, la situation qui peut être appliquée au modèle d'attention est limitée, et comme RWA est comme une généralisation de l'attention, elle ne peut être comparée ...) Je pense que si plus de recherches sont effectuées à l'avenir, il sera peut-être possible de remplacer LSTM et d'utiliser RWA comme norme de facto.

[^ 1]: On l'appelle AR s'il est exprimé par un modèle linéaire sans état, et on l'appelle un modèle d'espace d'états s'il est exprimé par un modèle de Markov caché dans lequel l'équation est explicitement écrite. [^ 2]: Parce que cela dépend de tous les états passés, pas seulement de l'état précédent. [^ 3]: Afin de réduire l'erreur numérique, n et d sont transformés en équivalents en termes d'implémentation. Voir l'annexe B du document pour plus de détails.

Recommended Posts

J'ai essayé de former le modèle RWA (Recurrent Weighted Average) dans Keras
J'ai implémenté le modèle VGG16 avec Keras et essayé d'identifier CIFAR10
J'ai essayé d'intégrer Keras dans TFv1.1
J'ai essayé d'implémenter TOPIC MODEL en Python
J'ai essayé d'organiser les index d'évaluation utilisés en machine learning (modèle de régression)
J'ai essayé de représenter graphiquement les packages installés en Python
J'ai essayé de mettre en œuvre le modèle de base du réseau neuronal récurrent
J'ai essayé de trouver la moyenne de plusieurs colonnes avec TensorFlow
J'ai essayé de notifier les informations de retard de train avec LINE Notify
J'ai essayé de résumer le code souvent utilisé dans Pandas
J'ai essayé d'illustrer le temps et le temps du langage C
J'ai essayé de résumer les commandes souvent utilisées en entreprise
J'ai essayé de déplacer le ballon
J'ai essayé d'estimer la section.
J'ai essayé de décrire le trafic en temps réel avec WebSocket
J'ai essayé de traiter l'image en "style croquis" avec OpenCV
J'ai essayé de traiter l'image dans un "style de dessin au crayon" avec OpenCV
J'ai essayé d'implémenter PLSA en Python
J'ai essayé de résumer la commande umask
J'ai essayé d'implémenter la permutation en Python
J'ai essayé d'implémenter PLSA dans Python 2
J'ai essayé de résumer la modélisation graphique.
J'ai essayé d'implémenter ADALINE en Python
J'ai essayé d'estimer le rapport de circonférence π de manière probabiliste
J'ai essayé de toucher l'API COTOHA
J'ai essayé d'implémenter PPO en Python
[Python] J'ai essayé de résumer le type collectif (ensemble) d'une manière facile à comprendre.
J'ai essayé de visualiser le modèle avec la bibliothèque d'apprentissage automatique low-code "PyCaret"
J'ai essayé d'afficher la valeur d'altitude du DTM dans un graphique
J'ai essayé de prédire le comportement du nouveau virus corona avec le modèle SEIR.
J'ai essayé Web Scraping pour analyser les paroles.
[Python] J'ai essayé de juger l'image du membre du groupe d'idols en utilisant Keras
J'ai essayé de déplacer GAN (mnist) avec keras
J'ai essayé d'optimiser le séchage du linge
J'ai essayé de sauvegarder les données avec discorde
J'ai essayé de simuler "Birthday Paradox" avec Python
J'ai essayé la méthode des moindres carrés en Python
J'ai essayé de corriger la forme trapézoïdale de l'image
J'ai essayé de prédire l'évolution de la quantité de neige pendant 2 ans par apprentissage automatique
J'ai essayé d'implémenter le tri sélectif en python
LeetCode j'ai essayé de résumer les plus simples
J'ai essayé de mettre en œuvre le problème du voyageur de commerce
Je veux afficher la progression en Python!
J'ai essayé de créer un modèle avec l'exemple d'Amazon SageMaker Autopilot
J'ai essayé de vectoriser les paroles de Hinatazaka 46!
J'ai essayé "Lobe" qui peut facilement entraîner le modèle d'apprentissage automatique publié par Microsoft.
J'ai essayé d'extraire le texte du fichier image en utilisant Tesseract du moteur OCR
J'ai essayé de résumer les nouvelles personnes infectées par le virus corona dans la ville d'Ichikawa, préfecture de Chiba
J'ai essayé de mettre HULFT IoT (Agent) dans la passerelle Rooster de Sun Electronics
[First data science ⑥] J'ai essayé de visualiser le prix du marché des restaurants à Tokyo
J'ai essayé de détecter l'iris à partir de l'image de la caméra
J'ai essayé de résumer la forme de base de GPLVM
J'ai essayé de toucher un fichier CSV avec Python
J'ai essayé de résoudre Soma Cube avec python
J'ai essayé d'implémenter un pseudo pachislot en Python
J'ai créé un code pour convertir illustration2vec en modèle Keras
J'ai essayé d'implémenter le poker de Drakue en Python