J'ai essayé MNIST avec RNN (réseaux neuronaux récurrents).
RNN Il a une structure de réseau qui peut gérer des données séquentielles de longueur variable pour les valeurs d'entrée et de sortie. Le RNN a un état, et à chaque instant t, il peut passer à l'état suivant en fonction de la valeur d'entrée et de l'état. RNN a un état à l'intérieur et conserve l'état en passant de l'entrée à l'état suivant.
LSTM LSTM (Long short-term memory) est un type de modèle ou d'architecture de données chronologiques (données séquentielles) apparu en 1995 comme une extension de RNN (Recurrent Neural Network). Selon cet article, si vous souhaitez générer une phrase, Il est responsable de prédire le prochain mot qui semble être. En rappelant à plusieurs reprises au LSTM la phrase correcte (mise à jour du vecteur de poids), ce LSTM apprend «virtuellement» la règle selon laquelle «est» vient après «ceci». Il semble que vous puissiez faire quelque chose comme ça. Je vois! Génial!
La principale caractéristique est qu'il est possible d'apprendre des dépendances à long terme qui ne pourraient pas être apprises avec les RNN conventionnels.
En regardant ce site, LSTM est en train de passer de haut en bas comme indiqué dans l'image ci-dessous. Cela semble être appris avec.
J'ai utilisé un cahier Jupyter. Importez les bibliothèques requises et chargez les données MNIST
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
Ensuite, définissez des espaces réservés pour l'entrée et corrigez les étiquettes
x = tf.placeholder("float", [None, 28, 28])
y = tf.placeholder("float", [None, 10])
Définissez un modèle pour RNN.
Modèle LSTM avec 128 unités de couches cachées
Convertissez en un tenseur divisé pour chaque étape. Converti en une liste Python avec 28 tenseurs de [taille du lot x 28] avec tf.unstack
.
def RNN(x):
x = tf.unstack(x, 28, 1)
#Paramètres LSTM
lstm_cell = rnn.BasicLSTMCell(128, forget_bias=1.0)
#Définition du modèle. La valeur de sortie et l'état de chaque pas de temps sont renvoyés
outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
#Paramètres de pondération et de biais
weight = tf.Variable(tf.random_normal([128, 10]))
bias = tf.Variable(tf.random_normal([10]))
return tf.matmul(outputs[-1], weight) + bias
Définissez la fonction de coût. Cette fois, j'ai utilisé la fonction d'erreur d'entropie croisée et Adam Optimizer pour la formation.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)
#Pour évaluation
correct_pred = tf.equal(tf.argmax(preds, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
Entraînez-vous en utilisant le modèle créé
batch_size = 128
n_training_iters = 100000
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step = 1
# Keep training until reach max iterations
while step * batch_size < n_training_iters:
batch_x, batch_y = mnist.train.next_batch(batch_size)
# next_lot retourné par lot_x est[batch_size, 784]Parce que c'est un tenseur de_Convertir en taille x 28 x 28.
batch_x = batch_x.reshape((batch_size, 28, 28))
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
if step % 10 == 0:
acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
print('step: {} / loss: {:.6f} / acc: {:.5f}'.format(step, loss, acc))
step += 1
#tester
test_len = 128
test_data = mnist.test.images[:test_len].reshape((-1, 28, 28))
test_label = mnist.test.labels[:test_len]
test_acc = sess.run(accuracy, feed_dict={x: test_data, y: test_label})
print("Test Accuracy: {}".format(test_acc))
production
step: 10 / loss: 1.751291 / acc: 0.42969
step: 20 / loss: 1.554639 / acc: 0.46875
step: 30 / loss: 1.365595 / acc: 0.57031
step: 40 / loss: 1.176470 / acc: 0.60156
step: 50 / loss: 0.787636 / acc: 0.75781
step: 60 / loss: 0.776735 / acc: 0.75781
step: 70 / loss: 0.586180 / acc: 0.79688
step: 80 / loss: 0.692503 / acc: 0.80469
step: 90 / loss: 0.550008 / acc: 0.82812
step: 100 / loss: 0.553710 / acc: 0.86719
step: 110 / loss: 0.423268 / acc: 0.86719
step: 120 / loss: 0.462931 / acc: 0.82812
step: 130 / loss: 0.365392 / acc: 0.85938
step: 140 / loss: 0.505170 / acc: 0.85938
step: 150 / loss: 0.273539 / acc: 0.91406
step: 160 / loss: 0.322731 / acc: 0.87500
step: 170 / loss: 0.531190 / acc: 0.85156
step: 180 / loss: 0.318869 / acc: 0.90625
step: 190 / loss: 0.351407 / acc: 0.86719
step: 200 / loss: 0.232232 / acc: 0.92188
step: 210 / loss: 0.245849 / acc: 0.92969
step: 220 / loss: 0.312085 / acc: 0.92188
step: 230 / loss: 0.276383 / acc: 0.89844
step: 240 / loss: 0.196890 / acc: 0.94531
step: 250 / loss: 0.221909 / acc: 0.91406
step: 260 / loss: 0.246551 / acc: 0.92969
step: 270 / loss: 0.242577 / acc: 0.92188
step: 280 / loss: 0.165623 / acc: 0.94531
step: 290 / loss: 0.232382 / acc: 0.94531
step: 300 / loss: 0.159169 / acc: 0.92969
step: 310 / loss: 0.229053 / acc: 0.92969
step: 320 / loss: 0.384319 / acc: 0.90625
step: 330 / loss: 0.151922 / acc: 0.93750
step: 340 / loss: 0.153512 / acc: 0.95312
step: 350 / loss: 0.113470 / acc: 0.96094
step: 360 / loss: 0.192841 / acc: 0.93750
step: 370 / loss: 0.169354 / acc: 0.96094
step: 380 / loss: 0.217942 / acc: 0.94531
step: 390 / loss: 0.151771 / acc: 0.95312
step: 400 / loss: 0.139619 / acc: 0.96094
step: 410 / loss: 0.236149 / acc: 0.92969
step: 420 / loss: 0.131790 / acc: 0.94531
step: 430 / loss: 0.172267 / acc: 0.96094
step: 440 / loss: 0.182242 / acc: 0.93750
step: 450 / loss: 0.131859 / acc: 0.94531
step: 460 / loss: 0.216793 / acc: 0.92969
step: 470 / loss: 0.082368 / acc: 0.96875
step: 480 / loss: 0.064672 / acc: 0.96094
step: 490 / loss: 0.119717 / acc: 0.96875
step: 500 / loss: 0.169831 / acc: 0.94531
step: 510 / loss: 0.106913 / acc: 0.98438
step: 520 / loss: 0.073209 / acc: 0.97656
step: 530 / loss: 0.131819 / acc: 0.96875
step: 540 / loss: 0.210754 / acc: 0.94531
step: 550 / loss: 0.141051 / acc: 0.93750
step: 560 / loss: 0.217726 / acc: 0.94531
step: 570 / loss: 0.121927 / acc: 0.96094
step: 580 / loss: 0.130969 / acc: 0.94531
step: 590 / loss: 0.125145 / acc: 0.95312
step: 600 / loss: 0.193178 / acc: 0.95312
step: 610 / loss: 0.114959 / acc: 0.95312
step: 620 / loss: 0.129038 / acc: 0.96094
step: 630 / loss: 0.151445 / acc: 0.95312
step: 640 / loss: 0.120206 / acc: 0.96094
step: 650 / loss: 0.107941 / acc: 0.96875
step: 660 / loss: 0.114320 / acc: 0.95312
step: 670 / loss: 0.094687 / acc: 0.94531
step: 680 / loss: 0.115308 / acc: 0.96875
step: 690 / loss: 0.125207 / acc: 0.96094
step: 700 / loss: 0.085296 / acc: 0.96875
step: 710 / loss: 0.119154 / acc: 0.94531
step: 720 / loss: 0.089058 / acc: 0.96875
step: 730 / loss: 0.054484 / acc: 0.97656
step: 740 / loss: 0.113646 / acc: 0.93750
step: 750 / loss: 0.051113 / acc: 0.99219
step: 760 / loss: 0.183365 / acc: 0.94531
step: 770 / loss: 0.112222 / acc: 0.95312
step: 780 / loss: 0.078913 / acc: 0.96094
Test Accuracy: 0.984375
Classez avec une précision de 98%! !!
Qu'est-ce que les réseaux de neurones récurrents qui gèrent les données de séries chronologiques Comprendre le LSTM-avec les tendances récentes
Recommended Posts