Ich habe MNIST mit RNN (Recurrent Neural Networks) ausprobiert.
RNN Es verfügt über eine Netzwerkstruktur, die sequentielle Daten variabler Länge für Eingabe- und Ausgabewerte verarbeiten kann. Das RNN hat einen Zustand und kann zu jedem Zeitpunkt t basierend auf dem Eingabewert und dem Zustand in den nächsten Zustand übergehen. RNN hat einen Zustand im Inneren und hält den Zustand durch Übergang vom Eingang zum nächsten Zustand.
LSTM LSTM (Long Short-Term Memory) ist eine Art Modell oder Architektur für Zeitreihendaten (sequentielle Daten), die 1995 als Erweiterung von RNN (Recurrent Neural Network) erschien. Laut diesem Artikel, wenn Sie einen Satz generieren möchten, Es ist dafür verantwortlich, das nächste Wort vorherzusagen, das zu sein scheint. Durch wiederholtes Erinnern des LSTM an den richtigen Satz (Aktualisieren des Gewichtsvektors) lernt dieses LSTM "virtuell" die Regel, dass "ist" nach "diesem" kommt. Es scheint, dass Sie so etwas tun können. Das war's! Toll!
Das größte Merkmal ist, dass es möglich ist, langfristige Abhängigkeiten zu lernen, die mit herkömmlichen RNNs nicht gelernt werden konnten.
Mit Blick auf diese Site wechselt LSTM von oben nach unten, wie in der folgenden Abbildung gezeigt. Es scheint mit gelernt zu sein.
Ich habe ein Jupyter-Notizbuch benutzt. Importieren Sie die erforderlichen Bibliotheken und laden Sie MNIST-Daten
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
Definieren Sie als Nächstes Platzhalter für die Eingabe und korrigieren Sie die Beschriftungen
x = tf.placeholder("float", [None, 28, 28])
y = tf.placeholder("float", [None, 10])
Definieren Sie ein Modell für RNN.
LSTM-Modell mit 128 Hidden-Layer-Einheiten
Konvertieren Sie in einen Tensor, der für jeden Schritt geteilt wird. Konvertiert in eine Python-Liste mit 28 Tensoren von [Stapelgröße x 28] mit tf.unstack
.
def RNN(x):
x = tf.unstack(x, 28, 1)
#LSTM-Einstellungen
lstm_cell = rnn.BasicLSTMCell(128, forget_bias=1.0)
#Modelldefinition. Der Ausgabewert und der Status jedes Zeitschritts werden zurückgegeben
outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
#Gewichts- und Vorspannungseinstellungen
weight = tf.Variable(tf.random_normal([128, 10]))
bias = tf.Variable(tf.random_normal([10]))
return tf.matmul(outputs[-1], weight) + bias
Kostenfunktion definieren. Dieses Mal habe ich die Cross-Entropy-Error-Funktion und Adam Optimizer für das Training verwendet.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)
#Zur Auswertung
correct_pred = tf.equal(tf.argmax(preds, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
Trainiere mit dem erstellten Modell
batch_size = 128
n_training_iters = 100000
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step = 1
# Keep training until reach max iterations
while step * batch_size < n_training_iters:
batch_x, batch_y = mnist.train.next_batch(batch_size)
# next_Charge von Charge zurückgegeben_x ist[batch_size, 784]Weil es ein Tensor von ist_In Größe x 28 x 28 konvertieren.
batch_x = batch_x.reshape((batch_size, 28, 28))
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
if step % 10 == 0:
acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
print('step: {} / loss: {:.6f} / acc: {:.5f}'.format(step, loss, acc))
step += 1
#Prüfung
test_len = 128
test_data = mnist.test.images[:test_len].reshape((-1, 28, 28))
test_label = mnist.test.labels[:test_len]
test_acc = sess.run(accuracy, feed_dict={x: test_data, y: test_label})
print("Test Accuracy: {}".format(test_acc))
Ausgabe
step: 10 / loss: 1.751291 / acc: 0.42969
step: 20 / loss: 1.554639 / acc: 0.46875
step: 30 / loss: 1.365595 / acc: 0.57031
step: 40 / loss: 1.176470 / acc: 0.60156
step: 50 / loss: 0.787636 / acc: 0.75781
step: 60 / loss: 0.776735 / acc: 0.75781
step: 70 / loss: 0.586180 / acc: 0.79688
step: 80 / loss: 0.692503 / acc: 0.80469
step: 90 / loss: 0.550008 / acc: 0.82812
step: 100 / loss: 0.553710 / acc: 0.86719
step: 110 / loss: 0.423268 / acc: 0.86719
step: 120 / loss: 0.462931 / acc: 0.82812
step: 130 / loss: 0.365392 / acc: 0.85938
step: 140 / loss: 0.505170 / acc: 0.85938
step: 150 / loss: 0.273539 / acc: 0.91406
step: 160 / loss: 0.322731 / acc: 0.87500
step: 170 / loss: 0.531190 / acc: 0.85156
step: 180 / loss: 0.318869 / acc: 0.90625
step: 190 / loss: 0.351407 / acc: 0.86719
step: 200 / loss: 0.232232 / acc: 0.92188
step: 210 / loss: 0.245849 / acc: 0.92969
step: 220 / loss: 0.312085 / acc: 0.92188
step: 230 / loss: 0.276383 / acc: 0.89844
step: 240 / loss: 0.196890 / acc: 0.94531
step: 250 / loss: 0.221909 / acc: 0.91406
step: 260 / loss: 0.246551 / acc: 0.92969
step: 270 / loss: 0.242577 / acc: 0.92188
step: 280 / loss: 0.165623 / acc: 0.94531
step: 290 / loss: 0.232382 / acc: 0.94531
step: 300 / loss: 0.159169 / acc: 0.92969
step: 310 / loss: 0.229053 / acc: 0.92969
step: 320 / loss: 0.384319 / acc: 0.90625
step: 330 / loss: 0.151922 / acc: 0.93750
step: 340 / loss: 0.153512 / acc: 0.95312
step: 350 / loss: 0.113470 / acc: 0.96094
step: 360 / loss: 0.192841 / acc: 0.93750
step: 370 / loss: 0.169354 / acc: 0.96094
step: 380 / loss: 0.217942 / acc: 0.94531
step: 390 / loss: 0.151771 / acc: 0.95312
step: 400 / loss: 0.139619 / acc: 0.96094
step: 410 / loss: 0.236149 / acc: 0.92969
step: 420 / loss: 0.131790 / acc: 0.94531
step: 430 / loss: 0.172267 / acc: 0.96094
step: 440 / loss: 0.182242 / acc: 0.93750
step: 450 / loss: 0.131859 / acc: 0.94531
step: 460 / loss: 0.216793 / acc: 0.92969
step: 470 / loss: 0.082368 / acc: 0.96875
step: 480 / loss: 0.064672 / acc: 0.96094
step: 490 / loss: 0.119717 / acc: 0.96875
step: 500 / loss: 0.169831 / acc: 0.94531
step: 510 / loss: 0.106913 / acc: 0.98438
step: 520 / loss: 0.073209 / acc: 0.97656
step: 530 / loss: 0.131819 / acc: 0.96875
step: 540 / loss: 0.210754 / acc: 0.94531
step: 550 / loss: 0.141051 / acc: 0.93750
step: 560 / loss: 0.217726 / acc: 0.94531
step: 570 / loss: 0.121927 / acc: 0.96094
step: 580 / loss: 0.130969 / acc: 0.94531
step: 590 / loss: 0.125145 / acc: 0.95312
step: 600 / loss: 0.193178 / acc: 0.95312
step: 610 / loss: 0.114959 / acc: 0.95312
step: 620 / loss: 0.129038 / acc: 0.96094
step: 630 / loss: 0.151445 / acc: 0.95312
step: 640 / loss: 0.120206 / acc: 0.96094
step: 650 / loss: 0.107941 / acc: 0.96875
step: 660 / loss: 0.114320 / acc: 0.95312
step: 670 / loss: 0.094687 / acc: 0.94531
step: 680 / loss: 0.115308 / acc: 0.96875
step: 690 / loss: 0.125207 / acc: 0.96094
step: 700 / loss: 0.085296 / acc: 0.96875
step: 710 / loss: 0.119154 / acc: 0.94531
step: 720 / loss: 0.089058 / acc: 0.96875
step: 730 / loss: 0.054484 / acc: 0.97656
step: 740 / loss: 0.113646 / acc: 0.93750
step: 750 / loss: 0.051113 / acc: 0.99219
step: 760 / loss: 0.183365 / acc: 0.94531
step: 770 / loss: 0.112222 / acc: 0.95312
step: 780 / loss: 0.078913 / acc: 0.96094
Test Accuracy: 0.984375
Klassifizieren Sie mit 98% Genauigkeit! !!
Was sind wiederkehrende neuronale Netze, die Zeitreihendaten verarbeiten LSTM mit den neuesten Trends verstehen
Recommended Posts