[PYTHON] I tried to train the RWA (Recurrent Weighted Average) model in Keras

What is RWA (Recurrent Weighted Average)?

[Figure: (a) standard LSTM, (b) LSTM with attention, (c) RWA]

The paper is "Machine Learning on Sequential Data Using a Recurrent Weighted Average". In the figure above, (c) is a schematic of the RWA model, (a) is a standard LSTM, and (b) is an LSTM with attention.

RWA is a variant of the recurrent neural network (RNN), a family of models for sequential data. The paper that proposes it claims that, compared with the LSTM, the most common RNN implementation, RWA offers

- better accuracy
- faster convergence
- fewer parameters

In other words, it claims **every advantage at once**. I was surprised by how strong this claim is given how simple the architecture is, and wondered whether RWA could really beat the LSTM, which is now practically the de facto standard. So I lightly modified an existing [Keras implementation of RWA](https://gist.github.com/shamatar/55b804cf62b8ee0fa23efdb3ea5a4701) and tried to reproduce some of the experiments in the paper.

RWA architecture

RWA can be viewed as a generalization of attention that has been redefined recursively so that it fits into the structure of an RNN. In other words, attention (in RNNs) is a special case of RWA.

Let's take a closer look. All RNNs, not just LSTMs, are networks for processing sequential data. Sequential data is easy to handle under a Markov assumption (the current state is determined only by the current data and the previous state), so an RNN recursively models a function that takes the "current data" and the "previous state" and outputs the "current state". As a formula, $h_t = \mathrm{Recurrent}(x_t, h_{t-1})$, where $h_t$ is the state at time $t$, $x_t$ is the current data, and $h_{t-1}$ is the previous state. A neural-network implementation of the function $\mathrm{Recurrent}$ is an RNN, and an LSTM is an RNN whose $\mathrm{Recurrent}$ function is made more elaborate by adding various gates.[^1]
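To make the recurrence $h_t = \mathrm{Recurrent}(x_t, h_{t-1})$ concrete, here is a minimal NumPy sketch of the simplest RNN (an Elman-style cell). It assumes the common choice $\mathrm{Recurrent}(x_t, h_{t-1}) = \tanh(W_x x_t + W_h h_{t-1} + b)$; the names `W_x`, `W_h`, `b` and `simple_rnn` are illustrative, not taken from the post.

```python
import numpy as np

def simple_rnn(xs, W_x, W_h, b, h0):
    """Apply h_t = tanh(W_x x_t + W_h h_{t-1} + b) along a sequence.

    xs has shape (T, input_dim); the result has shape (T, units).
    """
    h = h0
    states = []
    for x_t in xs:
        # The new state depends only on the current input and the previous state.
        h = np.tanh(W_x @ x_t + W_h @ h + b)
        states.append(h)
    return np.stack(states)

rng = np.random.default_rng(0)
T, input_dim, units = 5, 3, 4
xs = rng.normal(size=(T, input_dim))
W_x = rng.normal(size=(units, input_dim))
W_h = rng.normal(size=(units, units))
b = np.zeros(units)
states = simple_rnn(xs, W_x, W_h, b, np.zeros(units))
print(states.shape)  # (5, 4)
```

Because each step only needs `h` from the previous iteration, the loop carries a fixed-size state no matter how long the sequence is.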

However, as the figure at the top shows, the attention model is not defined recursively and cannot be written in the form of the function $\mathrm{Recurrent}$.[^2] RWA treats attention as a moving average of past states and transforms the expression equivalently so that it reduces to the form of the $\mathrm{Recurrent}$ function.

Specifically, RWA takes a weighted moving average of past states:

$$ h_t = f\left( \frac{\sum_{i=1}^{t} z(x_i, h_{i-1}) \odot e^{a(x_i, h_{i-1})}}{\sum_{i=1}^{t} e^{a(x_i, h_{i-1})}} \right) \qquad (2) $$

Here $f$ is a suitable activation function, $z$ is the term that transforms the state recursively, and $a$ controls how much weight each past state receives in the average ($a$ plays the role of attention). Left as is, this formula is not recursive, because each $\Sigma$ sums over all past steps from 1 to $t$. We want to rewrite Eq. (2) so that it depends only on the previous state.
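To see why the literal form of Eq. (2) is not recursive, here is a sketch that evaluates it directly: at every step it keeps all past $z$ and $a$ values and re-sums over them, so step $t$ costs $O(t)$ work. It assumes $f = \tanh$ and uses hypothetical callables `z_fn` and `a_fn` (simple linear maps of $[x_t, h_{t-1}]$) in place of the real $z$ and $a$.

```python
import numpy as np

def rwa_direct(xs, z_fn, a_fn, h0, f=np.tanh):
    """Evaluate Eq. (2) literally: at every step, re-sum over all past steps."""
    h = h0
    zs, as_ = [], []   # z(x_i, h_{i-1}) and a(x_i, h_{i-1}) for i = 1..t
    states = []
    for x_t in xs:
        zs.append(z_fn(x_t, h))
        as_.append(a_fn(x_t, h))
        weights = np.exp(np.stack(as_))            # shape (t, units)
        # Weighted average of every z term seen so far (O(t) work at step t).
        avg = (np.stack(zs) * weights).sum(axis=0) / weights.sum(axis=0)
        h = f(avg)
        states.append(h)
    return np.stack(states)

rng = np.random.default_rng(0)
units, input_dim = 4, 3
W_z = rng.normal(size=(units, input_dim + units))   # hypothetical weights for z
W_a = rng.normal(size=(units, input_dim + units))   # hypothetical weights for a
z_fn = lambda x, h: W_z @ np.concatenate([x, h])
a_fn = lambda x, h: W_a @ np.concatenate([x, h])
xs = rng.normal(size=(6, input_dim))
states = rwa_direct(xs, z_fn, a_fn, np.zeros(units))
```

The growing `zs` and `as_` lists are exactly what makes this form awkward as an RNN: the "state" that must be carried grows with the sequence length.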

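Now split the argument of $f$ in Eq. (2) into a numerator $n$ and a denominator $d$:

$$ n_t = \sum_{i=1}^{t} z(x_i, h_{i-1}) \odot e^{a(x_i, h_{i-1})}, \qquad d_t = \sum_{i=1}^{t} e^{a(x_i, h_{i-1})}, \qquad h_t = f\left(\frac{n_t}{d_t}\right) $$

Since $n$ and $d$ are cumulative sums, they can be rewritten as

$$ n_t = n_{t-1} + z(x_t, h_{t-1}) \odot e^{a(x_t, h_{t-1})}, \qquad d_t = d_{t-1} + e^{a(x_t, h_{t-1})} $$

Now $n_t$ and $d_t$ depend only on the previous time step. That is the essence of RWA.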
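The rewrite of $n$ and $d$ as running sums can be checked numerically. This sketch uses random stand-in arrays for the values $z(x_i, h_{i-1})$ and $a(x_i, h_{i-1})$ (an assumption for illustration; in a real RWA they come from the network) and compares the explicit sums over $i = 1..t$ with the version that only carries $n_{t-1}$ and $d_{t-1}$.

```python
import numpy as np

rng = np.random.default_rng(0)
T, units = 8, 4
z = rng.normal(size=(T, units))   # stand-ins for z(x_i, h_{i-1})
a = rng.normal(size=(T, units))   # stand-ins for a(x_i, h_{i-1})

# Non-recursive form: explicit sums over i = 1..t at every step t.
direct = np.stack([
    (z[:t + 1] * np.exp(a[:t + 1])).sum(axis=0) / np.exp(a[:t + 1]).sum(axis=0)
    for t in range(T)
])

# Recursive form: carry only n_{t-1} and d_{t-1} from the previous step.
n = np.zeros(units)
d = np.zeros(units)
recursive = []
for t in range(T):
    n = n + z[t] * np.exp(a[t])
    d = d + np.exp(a[t])
    recursive.append(n / d)
recursive = np.stack(recursive)

print(np.allclose(direct, recursive))  # True
```

Applying $f$ elementwise to either result gives the same $h_t$, so the recursive form changes only the cost (constant work per step), not the values.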

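Finally, $z$ is changed slightly from a plain RNN: it is split into a term $u$ that looks only at the input (that is, an embedding) and a term $g$ that also looks at the state. The full set of RWA equations is then

$$ \begin{aligned} u_t &= W_u x_t + b_u \\ g_t &= W_g [x_t, h_{t-1}] + b_g \\ a_t &= W_a [x_t, h_{t-1}] \\ z_t &= u_t \odot \tanh(g_t) \\ n_t &= n_{t-1} + z_t \odot e^{a_t} \\ d_t &= d_{t-1} + e^{a_t} \\ h_t &= f\left(\frac{n_t}{d_t}\right) \end{aligned} $$

with $n_0 = 0$, $d_0 = 0$ and $h_0 = f(s_0)$, where $s_0$ is a learned initial state.[^3]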
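Putting the equations together, here is a NumPy sketch of a single RWA layer's forward pass over one sequence. It is not the Keras code used in this post; it assumes $f = \tanh$, a bias-free $a$ and a learned initial state $s_0$ as in the equations above, and the parameter names (`W_u`, `b_u`, `W_g`, `b_g`, `W_a`, `s0`) and helpers (`init_params`, `rwa_forward`) are hypothetical. The `return_sequences` flag mirrors the Keras parameter of the same name.

```python
import numpy as np

def init_params(input_dim, units, rng):
    """Random parameters for the hypothetical RWA layer below."""
    scale = 1.0 / np.sqrt(input_dim + units)
    return {
        "W_u": rng.normal(scale=scale, size=(units, input_dim)),
        "b_u": np.zeros(units),
        "W_g": rng.normal(scale=scale, size=(units, input_dim + units)),
        "b_g": np.zeros(units),
        "W_a": rng.normal(scale=scale, size=(units, input_dim + units)),
        "s0": rng.normal(size=units),
    }

def rwa_forward(xs, params, return_sequences=False, f=np.tanh):
    """Run the RWA recurrence over xs of shape (T, input_dim)."""
    units = params["b_u"].shape[0]
    n = np.zeros(units)              # running numerator n_t
    d = np.zeros(units)              # running denominator d_t
    h = f(params["s0"])              # initial state h_0 = f(s_0)
    states = []
    for x_t in xs:
        xh = np.concatenate([x_t, h])
        u = params["W_u"] @ x_t + params["b_u"]   # looks only at the input
        g = params["W_g"] @ xh + params["b_g"]    # looks at input and state
        a = params["W_a"] @ xh                    # attention logits, no bias
        z = u * np.tanh(g)
        w = np.exp(a)
        n = n + z * w
        d = d + w
        h = f(n / d)
        states.append(h)
    # All states (T, units) or only the last one (units,).
    return np.stack(states) if return_sequences else h

rng = np.random.default_rng(0)
params = init_params(input_dim=1, units=8, rng=rng)
xs = rng.normal(size=(20, 1))
last = rwa_forward(xs, params)
seq = rwa_forward(xs, params, return_sequences=True)
print(last.shape, seq.shape)  # (8,) (20, 8)
```

For long sequences `np.exp(a)` can overflow, which is why, as footnote 3 notes, the actual implementation uses an equivalent, numerically safer form of $n$ and $d$ (Appendix B of the paper).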

RWA has almost the same structure as the simplest RNN. Because RWA starts from a formulation that "refers to all past states", it is expected to be able to update its state by consulting any past state, even without internal forget or output gates like the LSTM's.
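The claim of "fewer parameters" follows from the missing gates. Here is a rough count, assuming a standard Keras LSTM (four gates, each with input weights, recurrent weights and a bias) and an RWA layer with exactly the weights in the equations above ($W_u, b_u, W_g, b_g, W_a, s_0$). The sizes (1 input feature, 250 units, matching the state dimension in the plots below) are an assumption, and any output layer is excluded.

```python
def lstm_param_count(input_dim, units):
    # Four gates (input, forget, cell, output), each with input weights,
    # recurrent weights and a bias, as in a standard Keras LSTM layer.
    return 4 * (units * (input_dim + units) + units)

def rwa_param_count(input_dim, units):
    u = units * input_dim + units                 # W_u, b_u
    g = units * (input_dim + units) + units       # W_g, b_g
    a = units * (input_dim + units)               # W_a (no bias)
    s0 = units                                    # learned initial state
    return u + g + a + s0

print(lstm_param_count(1, 250), rwa_param_count(1, 250))  # 252000 126500
```

Under these assumptions, RWA needs roughly half the parameters of an LSTM with the same state size, because it has two state-dependent weight matrices ($W_g$, $W_a$) instead of four.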

Implementation

I ran the experiments with a third-party Keras implementation of RWA (https://gist.github.com/shamatar/55b804cf62b8ee0fa23efdb3ea5a4701), modified so that the return_sequences parameter works. I have published the modified code together with the experiment and visualization scripts. (In Keras recurrent layers, return_sequences controls whether the layer outputs the full history of states rather than only the last state; without it, the states cannot be visualized afterwards.)

Experiment

Of the experiments described in the paper, I ran the two that were easiest to implement.

Classifying by Sequence Length

The task is to decide whether the length of a given sequence exceeds a certain length. I generate a vector whose length is chosen at random between 0 and 1000; the label is 1 if the length exceeds 500 and 0 otherwise. Each element is filled with a value drawn from a normal distribution (note that the element values are irrelevant to the task; only the length matters). The objective function is binary_crossentropy. The paper used a mini-batch size of 100, but packing sequences of different lengths into the same batch was troublesome, so I used a batch size of 1 for this task. This makes training slow: the results below took about 12 hours on a GPU.
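A sketch of how such data can be generated with batch size 1. It is not the post's actual script: the helper names are hypothetical, and lengths start at 1 rather than 0 here (an assumption), since an empty sequence cannot be fed to a recurrent layer.

```python
import numpy as np

def make_length_sample(rng, max_len=1000, threshold=500):
    """One sample: a random-length sequence of N(0, 1) values and its label.

    The label is 1 if the length exceeds the threshold, else 0;
    the values themselves carry no information about the label.
    """
    length = int(rng.integers(1, max_len + 1))   # uniform over 1..max_len
    x = rng.normal(size=(length, 1))
    y = 1 if length > threshold else 0
    return x, y

def length_batches(seed=0):
    """Endless generator of batch-size-1 samples for training."""
    rng = np.random.default_rng(seed)
    while True:
        x, y = make_length_sample(rng)
        # Shapes (1, length, 1) and (1, 1): one sequence per batch.
        yield x[np.newaxis], np.array([[y]])
```

With batch size 1, every batch can have its own length, which avoids padding and masking at the cost of much slower training.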

The results are as follows (vertical axis: accuracy, higher is better; horizontal axis: number of epochs).

- Results from the paper: [Figure: accuracy vs. epochs, from the paper]

- Results of this experiment: [Figure: seq_length.png]

Due to time constraints, the LSTM was still learning when I stopped, but as in the paper, RWA converged overwhelmingly faster. (Because the batch sizes differ, I cannot say how many samples each model needed to converge.)

While running the model, I was curious what the RWA state looked like, so I plotted it as well. The vertical axis is time and the horizontal axis is the state dimension (250).

[Figure: 1000.png, RWA state over time for a sequence of length 1000]

The figure above shows a sequence of length 1000 (so the correct prediction is "1"), which the model predicted correctly. The state appears to change around step 500, and overall it forms a gradation along the time axis, so the model seems to have learned the task properly. Testing with various sequence lengths, I found that accuracy dropped sharply for lengths around 500, while it was 100% for very short or very long sequences (the figure above is an example of a very long one).

Adding Problem

[Figure: the adding problem]

The task is to add two randomly selected values in a vector of a given length. The model receives two vectors of length n: one contains real numbers, and the other contains 1 at exactly two positions and 0 everywhere else. The model learns to output the sum of the real numbers at the positions marked with 1. The objective function is MSE. For this task I used a mini-batch size of 100, as in the paper. Each run took less than an hour on a GPU.
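A sketch of how an adding-problem batch can be generated. The post does not say how the real values are drawn; uniform values in [0, 1) are an assumption here, as is feeding the two vectors as two channels per time step. `make_adding_batch` is a hypothetical helper, not the post's script.

```python
import numpy as np

def make_adding_batch(rng, batch_size=100, length=100):
    """Return inputs (batch, length, 2) and targets (batch, 1)."""
    values = rng.uniform(0.0, 1.0, size=(batch_size, length))  # assumed U[0, 1)
    markers = np.zeros((batch_size, length))
    for row in range(batch_size):
        # Mark exactly two distinct positions with 1.
        idx = rng.choice(length, size=2, replace=False)
        markers[row, idx] = 1.0
    x = np.stack([values, markers], axis=-1)           # two channels per step
    y = (values * markers).sum(axis=1, keepdims=True)  # sum of the marked values
    return x, y

rng = np.random.default_rng(0)
x, y = make_adding_batch(rng, batch_size=100, length=100)
print(x.shape, y.shape)  # (100, 100, 2) (100, 1)
```

Under the uniform assumption, the target has mean 1 and variance 1/6, so a model that always predicts 1.0 scores an MSE of about 0.167. That is a useful reference line when reading the MSE plots.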

The results are as follows (vertical axis: MSE, lower is better; horizontal axis: number of epochs).

- Results from the paper: [Figure: MSE vs. epochs, from the paper]

- Results of this experiment (length 100): [Figure: addtion100.png]

- Results of this experiment (length 1000): [Figure: addition1000.png]

Note that the horizontal scale differs from the paper's: I defined 1 epoch as 100 batches, so multiplying the horizontal axis by 100 gives the same scale as the paper. For RWA, I was able to reproduce the paper's results. The LSTM matched the paper for length 100 but did not learn well for length 1000. Judging from how the LSTM converged in the paper, its accuracy on length 1000 might begin to improve with another 100 epochs or so of training.

The paper also claims that RWA can solve its tasks without tuning hyperparameters or initialization, and within the range I tried this held. From a replication standpoint, then, the fact that only RWA reproduced the paper's results on the first try may actually be the more meaningful outcome.

The RWA states are shown below, one plot per sample. The vertical axis is time (100 or 1000 steps) and the horizontal axis is the state dimension (250). The title above each plot gives the positions of the marked (flagged) elements.

[Figures: RWA states for four samples (0.png, 94.png, 72.png, 12.png)]

When the model reaches an element to be added (that is, a marked position), some of the state dimensions change abruptly. RWA does seem able to learn to detect events contained in the sequence.
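The visual impression that some state dimensions "jump" at the marked positions can be quantified. This is a hypothetical analysis helper, not part of the post's visualization script: it scores each time step by the total absolute change of the state from the previous step and returns the largest jumps. The example uses synthetic states with jumps at steps 30 and 70, not real RWA output.

```python
import numpy as np

def change_points(states, k=2):
    """Return the k time steps whose state changed most from the previous step.

    states has shape (T, units), e.g. an RWA output with return_sequences.
    """
    jumps = np.abs(np.diff(states, axis=0)).sum(axis=1)   # shape (T - 1,)
    # Index i of the diff compares step i + 1 with step i, so shift by one.
    return np.sort(np.argsort(jumps)[-k:] + 1)

# Synthetic states: some dimensions shift at steps 30 and 70.
T, units = 100, 250
states = np.zeros((T, units))
states[30:, :10] += 1.0
states[70:, 10:20] -= 1.0
print(change_points(states))  # [30 70]
```

Applied to recorded RWA states, the returned steps could be compared with the marked positions shown in each plot's title.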

Summary

Personally, I find RWA much simpler and easier to understand than the LSTM, and a neat realization of an intuitive idea. The paper only makes the simplest comparison with the LSTM, so it remains open how RWA compares with an LSTM with attention, and what happens when RWA layers are stacked into a multi-layer network, as is often done with LSTMs. (Then again, attention models apply only in limited situations, and since RWA is a kind of generalization of attention, such a comparison may not be meaningful...) If more research is done, RWA may eventually replace the LSTM as the de facto standard.

[^1]: If the recurrence is a linear model without a hidden state, it is called an AR (autoregressive) model; if it is written explicitly as a hidden Markov model, it is called a state-space model.
[^2]: Because attention depends on all past states, not just the previous one.
[^3]: To reduce numerical error, the implementation uses equivalent transformed expressions for n and d. See Appendix B of the paper for details.
