[PYTHON] Essayez TensorFlow MNIST avec RNN

J'ai essayé MNIST avec RNN (réseaux neuronaux récurrents).

RNN Il a une structure de réseau qui peut gérer des données séquentielles de longueur variable pour les valeurs d'entrée et de sortie. Le RNN a un état, et à chaque instant t, il peut passer à l'état suivant en fonction de la valeur d'entrée et de l'état. RNN a un état à l'intérieur et conserve l'état en passant de l'entrée à l'état suivant.

LSTM LSTM (Long short-term memory) est un type de modèle ou d'architecture de données chronologiques (données séquentielles) apparu en 1995 comme une extension de RNN (Recurrent Neural Network). Selon cet article, si vous souhaitez générer une phrase, Il est responsable de prédire le prochain mot qui semble être. En rappelant à plusieurs reprises au LSTM la phrase correcte (mise à jour du vecteur de poids), ce LSTM apprend «virtuellement» la règle selon laquelle «est» vient après «ceci». Il semble que vous puissiez faire quelque chose comme ça. Je vois! Génial!

La principale caractéristique est qu'il est possible d'apprendre des dépendances à long terme qui ne pourraient pas être apprises avec les RNN conventionnels.

Essayons-le pour le moment

En regardant ce site, LSTM est en train de passer de haut en bas comme indiqué dans l'image ci-dessous. Cela semble être appris avec. mnist-gif.gif

Flux de programme

J'ai utilisé un cahier Jupyter. Importez les bibliothèques requises et chargez les données MNIST

import tensorflow as tf
from tensorflow.contrib import rnn

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

Ensuite, définissez des espaces réservés pour l'entrée et corrigez les étiquettes

x = tf.placeholder("float", [None, 28, 28])
y = tf.placeholder("float", [None, 10])

Définissez un modèle pour RNN. Modèle LSTM avec 128 unités de couches cachées Convertissez en un tenseur divisé pour chaque étape. Converti en une liste Python avec 28 tenseurs de [taille du lot x 28] avec tf.unstack.

def RNN(x):
    x = tf.unstack(x, 28, 1)

    #Paramètres LSTM
    lstm_cell = rnn.BasicLSTMCell(128, forget_bias=1.0)

    #Définition du modèle. La valeur de sortie et l'état de chaque pas de temps sont renvoyés
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)

    #Paramètres de pondération et de biais
    weight = tf.Variable(tf.random_normal([128, 10]))
    bias = tf.Variable(tf.random_normal([10]))

    return tf.matmul(outputs[-1], weight) + bias

Définissez la fonction de coût. Cette fois, j'ai utilisé la fonction d'erreur d'entropie croisée et Adam Optimizer pour la formation.

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)

#Pour évaluation
correct_pred = tf.equal(tf.argmax(preds, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

Entraînez-vous en utilisant le modèle créé

batch_size = 128
n_training_iters = 100000
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    step = 1
    # Keep training until reach max iterations
    while step * batch_size < n_training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # next_lot retourné par lot_x est[batch_size, 784]Parce que c'est un tenseur de_Convertir en taille x 28 x 28.
        batch_x = batch_x.reshape((batch_size, 28, 28))
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % 10 == 0:
            acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
            loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
            print('step: {} / loss: {:.6f} / acc: {:.5f}'.format(step, loss, acc))
        step += 1

    #tester
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, 28, 28))
    test_label = mnist.test.labels[:test_len]
    test_acc = sess.run(accuracy, feed_dict={x: test_data, y: test_label})
    print("Test Accuracy: {}".format(test_acc))

production


step: 10 / loss: 1.751291 / acc: 0.42969
step: 20 / loss: 1.554639 / acc: 0.46875
step: 30 / loss: 1.365595 / acc: 0.57031
step: 40 / loss: 1.176470 / acc: 0.60156
step: 50 / loss: 0.787636 / acc: 0.75781
step: 60 / loss: 0.776735 / acc: 0.75781
step: 70 / loss: 0.586180 / acc: 0.79688
step: 80 / loss: 0.692503 / acc: 0.80469
step: 90 / loss: 0.550008 / acc: 0.82812
step: 100 / loss: 0.553710 / acc: 0.86719
step: 110 / loss: 0.423268 / acc: 0.86719
step: 120 / loss: 0.462931 / acc: 0.82812
step: 130 / loss: 0.365392 / acc: 0.85938
step: 140 / loss: 0.505170 / acc: 0.85938
step: 150 / loss: 0.273539 / acc: 0.91406
step: 160 / loss: 0.322731 / acc: 0.87500
step: 170 / loss: 0.531190 / acc: 0.85156
step: 180 / loss: 0.318869 / acc: 0.90625
step: 190 / loss: 0.351407 / acc: 0.86719
step: 200 / loss: 0.232232 / acc: 0.92188
step: 210 / loss: 0.245849 / acc: 0.92969
step: 220 / loss: 0.312085 / acc: 0.92188
step: 230 / loss: 0.276383 / acc: 0.89844
step: 240 / loss: 0.196890 / acc: 0.94531
step: 250 / loss: 0.221909 / acc: 0.91406
step: 260 / loss: 0.246551 / acc: 0.92969
step: 270 / loss: 0.242577 / acc: 0.92188
step: 280 / loss: 0.165623 / acc: 0.94531
step: 290 / loss: 0.232382 / acc: 0.94531
step: 300 / loss: 0.159169 / acc: 0.92969
step: 310 / loss: 0.229053 / acc: 0.92969
step: 320 / loss: 0.384319 / acc: 0.90625
step: 330 / loss: 0.151922 / acc: 0.93750
step: 340 / loss: 0.153512 / acc: 0.95312
step: 350 / loss: 0.113470 / acc: 0.96094
step: 360 / loss: 0.192841 / acc: 0.93750
step: 370 / loss: 0.169354 / acc: 0.96094
step: 380 / loss: 0.217942 / acc: 0.94531
step: 390 / loss: 0.151771 / acc: 0.95312
step: 400 / loss: 0.139619 / acc: 0.96094
step: 410 / loss: 0.236149 / acc: 0.92969
step: 420 / loss: 0.131790 / acc: 0.94531
step: 430 / loss: 0.172267 / acc: 0.96094
step: 440 / loss: 0.182242 / acc: 0.93750
step: 450 / loss: 0.131859 / acc: 0.94531
step: 460 / loss: 0.216793 / acc: 0.92969
step: 470 / loss: 0.082368 / acc: 0.96875
step: 480 / loss: 0.064672 / acc: 0.96094
step: 490 / loss: 0.119717 / acc: 0.96875
step: 500 / loss: 0.169831 / acc: 0.94531
step: 510 / loss: 0.106913 / acc: 0.98438
step: 520 / loss: 0.073209 / acc: 0.97656
step: 530 / loss: 0.131819 / acc: 0.96875
step: 540 / loss: 0.210754 / acc: 0.94531
step: 550 / loss: 0.141051 / acc: 0.93750
step: 560 / loss: 0.217726 / acc: 0.94531
step: 570 / loss: 0.121927 / acc: 0.96094
step: 580 / loss: 0.130969 / acc: 0.94531
step: 590 / loss: 0.125145 / acc: 0.95312
step: 600 / loss: 0.193178 / acc: 0.95312
step: 610 / loss: 0.114959 / acc: 0.95312
step: 620 / loss: 0.129038 / acc: 0.96094
step: 630 / loss: 0.151445 / acc: 0.95312
step: 640 / loss: 0.120206 / acc: 0.96094
step: 650 / loss: 0.107941 / acc: 0.96875
step: 660 / loss: 0.114320 / acc: 0.95312
step: 670 / loss: 0.094687 / acc: 0.94531
step: 680 / loss: 0.115308 / acc: 0.96875
step: 690 / loss: 0.125207 / acc: 0.96094
step: 700 / loss: 0.085296 / acc: 0.96875
step: 710 / loss: 0.119154 / acc: 0.94531
step: 720 / loss: 0.089058 / acc: 0.96875
step: 730 / loss: 0.054484 / acc: 0.97656
step: 740 / loss: 0.113646 / acc: 0.93750
step: 750 / loss: 0.051113 / acc: 0.99219
step: 760 / loss: 0.183365 / acc: 0.94531
step: 770 / loss: 0.112222 / acc: 0.95312
step: 780 / loss: 0.078913 / acc: 0.96094
Test Accuracy: 0.984375

Classez avec une précision de 98%! !!

Site de référence

Qu'est-ce que les réseaux de neurones récurrents qui gèrent les données de séries chronologiques Comprendre le LSTM-avec les tendances récentes

Recommended Posts

Essayez TensorFlow MNIST avec RNN
Essayez la régression avec TensorFlow
Essayez TensorFlow RNN avec un modèle de base
Essayez l'apprentissage en profondeur avec TensorFlow
MNIST (DCNN) avec Keras (backend TensorFlow)
[TensorFlow 2] Apprendre RNN avec perte CTC
Essayez l'apprentissage en profondeur avec TensorFlow Partie 2
Essayez les données en parallèle avec TensorFlow distribué
Essayez Distributed Tensor Flow
Pratiquez RNN TensorFlow
Zundokokiyoshi avec TensorFlow
Casser des blocs avec Tensorflow
Essayez Tensorflow avec une instance GPU sur AWS
Essayez Theano avec les données MNIST de Kaggle ~ Retour logistique ~
MNIST (DCNN) avec skflow
Lecture de données avec TensorFlow
Prévisions de courses de bateaux avec TensorFlow
Essayez SNN avec BindsNET
Code pour TensorFlow MNIST débutant / expert avec commentaires japonais
Essayez MNIST avec VAT (Virtual Adversarial Training) avec Keras
Challenge classification des images par TensorFlow2 + Keras 3 ~ Visualiser les données MNIST ~
Traduire Premiers pas avec TensorFlow
Essayez de défier le sol par récursif
Essayez l'optimisation des fonctions avec Optuna
Utiliser TensorFlow avec Intellij IDEA
Essayez d'utiliser PythonTex avec Texpad.
Essayez la détection des bords avec OpenCV
Essayez d'implémenter RBM avec chainer.
Essayez Google Mock avec C
Fonction sinueuse approximative avec TensorFlow
Essayez d'utiliser matplotlib avec PyCharm
Essayez de programmer avec un shell!
Essayez la programmation GUI avec Hy
Essayez Auto Encoder avec Pytorch
Essayez la sortie Python avec Haxe 3.2
Essayez l'opération matricielle avec NumPy
Essayez d'implémenter XOR avec PyTorch
Essayez d'exécuter CNN avec ChainerRL
Essayez différentes choses avec PhantomJS
Essayez le Deep Learning avec FPGA
Prévision du cours de l'action avec tensorflow
Essayez d'exécuter Python avec Try Jupyter
Essayez d'implémenter le parfum avec Go
Essayez Selenium Grid avec Docker
Essayez la reconnaissance faciale avec Python
Essayez OpenCV avec Google Colaboratory
Essayez le machine learning à la légère avec Kaggle
Pensez aux abandons avec MNIST
Essayez de créer Jupyter Hub avec Docker
Essayez d'utiliser le folium avec anaconda
Création d'un modèle séquentiel Tensorflow avec une image originale ajoutée à MNIST
[Version améliorée] Essayez MNIST avec VAT (Virtual Adversarial Training) sur Keras
Assurer la reproductibilité avec tf.keras dans Tensorflow 2.3
TensorFlow 2.2 ne peut pas être installé avec Python 3.8!
Essayez le Deep Learning avec les concombres FPGA-Select
[Explication pour les débutants] Tutoriel TensorFlow MNIST (pour les débutants)
Traduction TensorFlow MNIST pour les débutants en ML
Renforcer l'apprentissage 13 Essayez Mountain_car avec ChainerRL.
Essayez d'exécuter tensorflow sur Docker + anaconda