[Python] I tried training PredNet

(Figure: caltech_montage_1.gif, taken from the CoxLab project site where PredNet was developed.)

What is PredNet?

https://coxlab.github.io/prednet/ (Bill Lotter, Gabriel Kreiman, and David Cox, 2016)

A deep learning model that predicts future frames of a video. CoxLab, the lab that developed it, has various demos on its site (link above), which should give you a good idea of what it does.

PredNet generated a lot of buzz around mid-2016, so this may feel a little dated now, but I will summarize what I found by actually training it. (The focus is on training and the environment rather than on explaining the algorithm and concepts.)

Before training

Since the code is publicly available (https://github.com/coxlab/prednet), you can train the model without reading the paper.

About the environment

To run the original code as-is, you need a Python 2 + Keras 1 environment. Note that this is somewhat dated, since the current versions are Python 3 + Keras 2. I tried porting the code to Keras 2, but training did not work, so I went with Keras 1. I haven't tried it myself, but I think porting it to Python 3 should be possible.
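Before running the scripts, it may help to check the interpreter and Keras versions first. The following is a minimal sketch of such a check; `check_prednet_env` is a hypothetical helper (not part of the PredNet repository), and it only compares major versions against the Python 2 + Keras 1 requirement described above.

```python
import sys


def check_prednet_env(python_major, keras_version):
    """Return a list of warnings if the environment is not Python 2 + Keras 1.

    Hypothetical helper: only the major versions are compared.
    """
    warnings = []
    if python_major != 2:
        warnings.append("the original code targets Python 2, found Python %d" % python_major)
    if keras_version is None:
        warnings.append("Keras is not installed")
    elif int(keras_version.split(".")[0]) != 1:
        warnings.append("the original code targets Keras 1, found Keras %s" % keras_version)
    return warnings


if __name__ == "__main__":
    try:
        import keras
        keras_version = keras.__version__
    except ImportError:
        keras_version = None
    # Print one line per mismatch; no output means the versions look right.
    for message in check_prednet_env(sys.version_info[0], keras_version):
        print("WARNING: " + message)
```

The check is written so that it also runs under Python 2, which is the interpreter it is meant to confirm.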

Data preparation

The main dataset used in the paper is the KITTI in-vehicle camera footage [^1]. Among the scripts you get from git clone is process_kitti.py, and running it takes care of the following automatically:

- Data download

However, this is **very** slow. In my environment, the download took three full days. The data is somewhere around 100 GB in total, but even so it felt far too slow.
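For a sense of scale, the figures above (roughly 100 GB in three days, both approximate) imply the following average download speed. This is only back-of-the-envelope arithmetic in decimal units; `average_throughput` is a hypothetical helper.

```python
def average_throughput(total_bytes, seconds):
    """Return the average throughput as (MB/s, Mbit/s), using decimal units."""
    bytes_per_second = total_bytes / float(seconds)
    return bytes_per_second / 1e6, bytes_per_second * 8 / 1e6


total_bytes = 100e9            # roughly 100 GB (approximate figure from the text)
seconds = 3 * 24 * 60 * 60     # three full days
mb_per_s, mbit_per_s = average_throughput(total_bytes, seconds)
print("%.2f MB/s (%.1f Mbit/s)" % (mb_per_s, mbit_per_s))  # prints "0.39 MB/s (3.1 Mbit/s)"
```

In other words, the download averaged only about 0.39 MB/s.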

Training

Once the data is ready, run kitti_train.py to start training. There are a few caveats here as well.

Training time

It depends heavily on the environment, so take this only as a rough guide: on my setup (GeForce GTX TITAN, 6082 MiB), one epoch took about 5 minutes. The default is 150 epochs, so the full run takes a little over 12 hours (150 × 5 min = 750 min = 12.5 h). I think that is a fairly reasonable training time for a model that includes LSTMs.

Training progress

The learning curve is shown below.
(Figure: learning curve, loss on the vertical axis; training loss in blue, validation loss in orange.)
The parameters appear to have settled well enough. With the default settings, the learning rate is reduced to 1/10 at epoch 75, and the loss drops sharply right at that point. The training loss (blue line) has leveled off, but the validation loss (orange line) is still trending downward, so training a bit longer might be worthwhile.
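The learning-rate step described above can be written as a schedule function of the epoch index. This is a minimal sketch: only the 1/10 drop at epoch 75 and the 150-epoch default come from the text, while the base learning rate of 0.001 is my assumption.

```python
BASE_LR = 0.001      # assumed base learning rate (not stated in the text)
DROP_EPOCH = 75      # the learning rate is cut at this epoch (from the text)
DROP_FACTOR = 0.1    # reduced to 1/10 (from the text)
NB_EPOCH = 150       # default number of epochs (from the text)


def lr_schedule(epoch):
    """Return the learning rate for a 0-indexed epoch."""
    if epoch < DROP_EPOCH:
        return BASE_LR
    return BASE_LR * DROP_FACTOR


if __name__ == "__main__":
    # Show the rate just before and just after the drop, and at the last epoch.
    for epoch in (0, 74, 75, NB_EPOCH - 1):
        print("epoch %3d: lr = %g" % (epoch, lr_schedule(epoch)))
```

In Keras, a function of this shape can be wrapped as `keras.callbacks.LearningRateScheduler(lr_schedule)` and passed in the `callbacks` list, which calls it with the 0-based epoch index. Training longer, as suggested above, would simply mean running for more than `NB_EPOCH` epochs.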

Another thing to note: the loss (MAE) reported during training is **measured differently** from the MAE reported in the paper. The paper reports a final MAE of 3.13e-3 (= 0.00313), and that MAE measures the "pixel error between the predicted frame and the actual frame". The loss Keras reports during training is also an MAE, but it measures the "error between the output of the error units E and a zero matrix" [^2]. In fact, the final loss of the model I trained (the vertical axis of the graph above) was 1.780e-2 (= 0.01780) for training and 2.36e-2 (= 0.0236) for validation, a different order of magnitude.
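To make the difference concrete, here is a small NumPy sketch that computes both quantities for the same toy prediction. It follows the description of E in footnote 2 (ReLU applied to the positive and negative parts, stacked along the channel axis, channels-last here), considers only the bottom layer, and ignores any per-layer or per-time-step weighting the real training script may apply. The random data and the array shape are hypothetical.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def pixel_mae(actual, predicted):
    """Mean absolute pixel error between predicted and actual frames."""
    return np.mean(np.abs(actual - predicted))


def error_unit_loss(actual, predicted):
    """MAE between a bottom-layer error unit E and a zero tensor.

    E stacks ReLU(predicted - actual) and ReLU(actual - predicted)
    along the last (channel) axis.
    """
    e = np.concatenate([relu(predicted - actual), relu(actual - predicted)], axis=-1)
    return np.mean(np.abs(e - np.zeros_like(e)))


rng = np.random.RandomState(0)
actual = rng.rand(4, 128, 160, 3)  # hypothetical batch of frames scaled to [0, 1]
predicted = np.clip(actual + rng.normal(0, 0.05, actual.shape), 0, 1)
print("pixel MAE:       %.5f" % pixel_mae(actual, predicted))
print("error-unit loss: %.5f" % error_unit_loss(actual, predicted))
```

Even for the identical prediction the two values differ, which is why the Keras training curve cannot be compared directly with the figure in the paper.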

Prediction results

When I applied the model to the test data, the pixel MAE between the predicted and actual frames (measured the same way as in the paper) was 3.927e-3. This is slightly worse than the 3.13e-3 reported in the paper, but reasonably close.

The result images are shown below, but first let me explain how to read them. The video is 10 Hz (10 fps, i.e. 10 frames per second), and each prediction covers 10 frames, i.e. 1 second. Although 10 frames are predicted, each one is effectively a prediction only one frame ahead, because the model receives the previous actual frame as input before making each prediction. A diagram of which frame is predicted from which looks like this:
(Figure: diagram mapping each input frame to the frame predicted from it.)
**(Note that because the model has a recurrent (RNN) structure, it actually "remembers" all the inputs so far and uses that information as well.)**
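The one-step-ahead mode above can be sketched as a loop in which the model always receives the actual frame before predicting the next one. This is a minimal sketch under stated assumptions: `step` stands for one recurrent update of the model, and the toy `running_mean_step` is hypothetical and not PredNet; it just carries state between calls to mimic the "memory" mentioned in the note.

```python
def running_mean_step(state, frame):
    """Toy stand-in for the model: predicts the mean of all frames seen so far.

    The state (running total and count) plays the role of the RNN memory.
    """
    total, count = state if state is not None else (0.0, 0)
    total, count = total + frame, count + 1
    return (total, count), total / count


def predict_one_step_ahead(frames, step):
    """Return (target_index, prediction) pairs in one-step-ahead mode."""
    state = None
    results = []
    for t in range(len(frames) - 1):
        # The model receives the actual frame t ...
        state, prediction = step(state, frames[t])
        # ... and its output is the prediction for frame t + 1.
        results.append((t + 1, prediction))
    return results


if __name__ == "__main__":
    # Scalars stand in for images; a real clip would have 10 frames at 10 fps.
    print(predict_one_step_ahead([1.0, 2.0, 3.0, 4.0], running_mean_step))
    # prints [(1, 1.0), (2, 1.5), (3, 2.0)]
```

Every input is a real frame, so errors never accumulate from one prediction to the next.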

The paper also covers (and the code also supports) genuinely predicting several frames ahead, which is called extrapolation. In that case the diagram looks like this:
(Figure: diagram of extrapolation, with warm-up frames followed by predictions fed back as input.)
The first few frames warm the model up with actual images so that its RNN state is built up. After that, no more actual frames are given: the model treats its own output as the correct frame, feeds it back in, and extrapolates from there. Note that none of the prediction results shown below use extrapolation.
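Extrapolation differs only in what is fed back in after the warm-up: the model's own prediction instead of the actual frame. Below is a minimal sketch; `step` stands for one recurrent update of the model, and both `warmup` and the toy `constant_velocity_step` (which simply assumes the value keeps changing at the same rate) are hypothetical and not part of the PredNet code.

```python
def constant_velocity_step(state, frame):
    """Toy stand-in for the model: assumes the frame keeps changing at the same rate.

    The state is simply the previous input frame.
    """
    previous = frame if state is None else state
    return frame, frame + (frame - previous)


def predict_with_extrapolation(frames, step, warmup):
    """Return (target_index, prediction) pairs.

    The first `warmup` actual frames are fed in to build up the state;
    after that, each prediction is fed back in as the next input.
    """
    state = None
    results = []
    current = frames[0]
    for t in range(len(frames) - 1):
        state, prediction = step(state, current)
        results.append((t + 1, prediction))
        if t + 1 < warmup:
            current = frames[t + 1]   # still warming up: use the actual frame
        else:
            current = prediction      # extrapolating: treat our own output as correct
    return results


if __name__ == "__main__":
    # Only frames 0 and 1 are fed in; the predictions for frames 3 and 4
    # are built on the model's own outputs.
    print(predict_with_extrapolation([0, 1, 2, 3, 4], constant_velocity_step, warmup=2))
    # prints [(1, 0), (2, 2), (3, 3), (4, 4)]
```

Setting `warmup` to the length of the clip feeds every actual frame, which reduces this loop to ordinary one-step-ahead prediction.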

The prediction results are as follows (excerpt):
(Figures: plot_000.png, plot_001.png, plot_005.png, plot_011.png, plot_041.png, plot_043.png, plot_081.png, in that order.)

As reported in the paper, the white lane lines and the shadows on the road are predicted correctly. Apart from unusual car movements, the predictions are quite good. (Perhaps there is too little training data of cars moving sideways across the view to learn from? For example, in the 5th image a car is crossing in front of the camera, and the prediction looks odd.) Turning the steering wheel (6th image) also seems to be handled well.

[^1]: KITTI is a joint project of the Karlsruhe Institute of Technology (KIT) and the Toyota Technological Institute at Chicago (TTI-C). In other words, Toyota is also involved in publishing an in-vehicle camera video dataset.
[^2]: The error unit E splits the difference between the predicted frame and the actual frame into a "positive part" and a "negative part" and applies ReLU to each, so it is not the same as the "pixel error between the predicted frame and the actual frame".
