Dieses Mal werde ich, um das Verständnis von LSTM zu vertiefen, LSTM mit Scratch in TensorFlow schreiben.
Das Blockdiagramm von LSTM mit Forget_gate sieht wie folgt aus, und Sie können sehen, dass es aus 4 kleinen Netzwerken besteht (** output_gate, input_gate, forget_gate, z **).
** Z ** möchte das Gewicht W erhöhen, damit es nicht vergessen wird, wenn es eine Eingabe gibt, an die Sie sich erinnern möchten. Wenn Sie jedoch W erhöhen, werden Sie sich auch an Informationen erinnern, an die Sie sich nicht erinnern müssen, also am Ende an die Informationen, an die Sie sich erinnern möchten Wird überschrieben. Dies wird als ** Eingabegewichtskonflikt ** bezeichnet. Um dies zu vermeiden, blockiert ** input_gate ** irrelevante Informationen und verhindert, dass sie in die Speicherzelle C geschrieben werden. ** forget_gate ** löscht bei Bedarf Informationen in der Speicherzelle C. Dies liegt daran, dass sich die Reihe von Zeitreihendaten sofort ändern kann, wenn bestimmte Bedingungen erfüllt sind, und dass die zu diesem Zeitpunkt gespeicherten Informationen zurückgesetzt werden müssen.
** output_gate ** liest nicht den gesamten Inhalt der Speicherzelle C, löscht jedoch unnötige Inhalte, wie im Fall der Eingabe, um ** Konflikte mit dem Ausgabegewicht ** zu vermeiden.
Die vier Netzwerkgewichte self.W und Bias self.B haben dieselbe Form, daher deklarieren wir sie zusammen.
self.W = tf.Variable(tf.zeros([input_size + hidden_size, hidden_size *4 ]))
self.B = tf.Variable(tf.zeros([hidden_size * 4 ]))
Weiterleitungscode weiterleiten. Dieses Mal werden zur Vereinfachung der Nachbearbeitung h und c gestapelt. Stellen Sie sie daher zuerst wieder her. Anschließend wird die gewichtete lineare Summe der vier Netzwerke zusammen berechnet und das Ergebnis in vier geteilt.
def forward(self, prev_state, x):
# h,Wiederherstellen c
h, c = tf.unstack(prev_state)
#Berechnen Sie die gewichtete lineare Summe von vier Netzwerken zusammen
inputs = tf.concat([x, h], axis=1)
inputs = tf.matmul(inputs, self.W) + self.B
z, i, f, o = tf.split(inputs, 4, axis=1)
Führen Sie Sigmoid durch die Signale von den drei Toren.
#Führen Sie Sigmoid durch das Signal jedes Tors
input_gate = tf.sigmoid(i)
forget_gate = tf.sigmoid(f)
output_gate = tf.sigmoid(o)
Aktualisieren Sie die Speicherzelle basierend auf den Gate- und Middle-Layer-Eingängen, um den Middle-Layer-Ausgang zu berechnen. Es gibt kein Problem, auch wenn vor output_gate kein tanh steht, daher wird es weggelassen.
#Speicherzellenaktualisierung, Zwischenausgangsberechnung
next_c = c * forget_gate + tf.nn.tanh(z) * input_gate
next_h = next_c * output_gate
#Stapel aufgrund von Nachbearbeitung
return tf.stack([next_h, next_c])
Verwenden wir nun dieses LSTM, um den Code zu schreiben, der die Vorhersage tatsächlich ausführt. Der Datensatz verwendet ein Bild von ** Dagits ** mit ** Zahlen 0-9 ** (kleine 8 * 8 Pixel).
Lassen Sie LSTM basierend auf dem Ergebnis des achtmaligen Scannens einer Datenzeile Zeile für Zeile die Anzahl vorhersagen.
import numpy as np
import tensorflow as tf
from sklearn import datasets
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
class LSTM(object):
def __init__(self, input_size, hidden_size, output_size):
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
#Eingabeebene
self.inputs = tf.placeholder(tf.float32, shape=[None, None, self.input_size], name='inputs')
self.W = tf.Variable(tf.zeros([input_size + hidden_size, hidden_size *4 ]))
self.B = tf.Variable(tf.zeros([hidden_size * 4 ]))
#Ausgabeschicht
self.Wv = tf.Variable(tf.truncated_normal([hidden_size, output_size], mean=0, stddev=0.01))
self.bv = tf.Variable(tf.truncated_normal([output_size], mean=0, stddev=0.01))
self.init_hidden = tf.matmul(self.inputs[:,0,:], tf.zeros([input_size, hidden_size]))
self.init_hidden = tf.stack([self.init_hidden, self.init_hidden])
self.input_fn = self._get_batch_input(self.inputs)
def forward(self, prev_state, x):
# h,Wiederherstellen c
h, c = tf.unstack(prev_state)
#Berechnen Sie die gewichtete lineare Summe von vier Netzwerken zusammen
inputs = tf.concat([x, h], axis=1)
inputs = tf.matmul(inputs, self.W) + self.B
z, i, f, o = tf.split(inputs, 4, axis=1)
#Führen Sie Sigmoid durch das Signal jedes Tors
input_gate = tf.sigmoid(i)
forget_gate = tf.sigmoid(f)
output_gate = tf.sigmoid(o)
#Speicherzellenaktualisierung, Zwischenausgangsberechnung
next_c = c * forget_gate + tf.nn.tanh(z) * input_gate
next_h = next_c * output_gate
#Stapel aufgrund von Nachbearbeitung
return tf.stack([next_h, next_c])
def _get_batch_input(self, inputs):
return tf.transpose(tf.transpose(inputs, perm=[2, 0, 1]))
def calc_all_layers(self):
all_hidden_states = tf.scan(self.forward, self.input_fn, initializer=self.init_hidden, name='states')
return all_hidden_states[:, 0, :, :]
def calc_output(self, state):
return tf.nn.tanh(tf.matmul(state, self.Wv) + self.bv)
def calc_outputs(self):
all_states = self.calc_all_layers()
all_outputs = tf.map_fn(self.calc_output, all_states)
return all_outputs
#Datensatz laden( 8*8 image of a digit)
digits = datasets.load_digits()
X = digits.images
Y_= digits.target
Y=tf.keras.utils.to_categorical(Y_, 10)
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
print(Y.shape)
#Vorausschauende Ausführung
hidden_size = 50
input_size = 8
output_size = 10
y = tf.placeholder(tf.float32, shape=[None, output_size], name='inputs')
lstm = LSTM(input_size, hidden_size, output_size)
outputs = lstm.calc_outputs()
last_output = outputs[-1]
output = tf.nn.softmax(last_output)
loss = -tf.reduce_sum(y * tf.log(output))
train_step = tf.train.AdamOptimizer().minimize(loss)
correct_predictions = tf.equal(tf.argmax(y, 1), tf.argmax(output, 1))
acc = (tf.reduce_mean(tf.cast(correct_predictions, tf.float32)))
sess=tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
log_loss = []
log_acc = []
log_val_acc = []
for epoch in range(100):
start=0
end=100
for i in range(14):
X=X_train[start:end]
Y=y_train[start:end]
start=end
end=start+100
sess.run(train_step,feed_dict={lstm.inputs:X, y:Y})
log_loss.append(sess.run(loss,feed_dict={lstm.inputs:X, y:Y}))
log_acc.append(sess.run(acc,feed_dict={lstm.inputs:X_train[:500], y:y_train[:500]}))
log_val_acc.append(sess.run(acc,feed_dict={lstm.inputs:X_test, y:y_test}))
print("\r[%s] loss: %s acc: %s val acc: %s"%(epoch, log_loss[-1], log_acc[-1], log_val_acc[-1])),
#acc Grafik
plt.ylim(0., 1.)
plt.plot(log_acc, label='acc')
plt.plot(log_val_acc, label = 'val_acc')
plt.legend()
plt.show()
Es ist eine relativ gute Genauigkeit.
Recommended Posts