In einem tiefen neuronalen Netzwerk ist das Modell umso ausdrucksvoller, je mehr Schichten vorhanden sind. Je höher jedoch die Anzahl der Schichten ist, desto höher ist das Risiko einer ** Überanpassung **. Das Risiko einer ** Überanpassung ** wird verringert, indem die Parameterfreiheit unter Beibehaltung der Ausdruckskraft des Modells eingeschränkt wird. Eine der Methoden ist ** Gewichtsabnahme **.
Die Gewichtsaktualisierungsformel ist wie folgt geschrieben.
w \leftarrow w -\eta \frac{\partial C(w)}{\partial w} - \eta \lambda w
Die obige Formel ist etwas schwer zu verstehen, was Sie tun möchten, aber sie kommt tatsächlich von der Kostenfunktion, wie unten gezeigt.
\tilde C(w) = C(w) + \frac{\lambda}{2}||w||^2
Dies ist die ** Kostenfunktion ** mit der ** L2-Regularisierungsklausel **. Dieser Begriff reduziert den Gewichtswert. Wenn Sie es also tatsächlich implementieren, fügen Sie den Abschnitt ** L2-Regularisierung ** zu den Kosten hinzu.
Normalerweise wird die ** L2-Regularisierung ** nicht auf die Vorspannung angewendet. Dies ist auf die unterschiedlichen Rollen von Neuronengewicht und Voreingenommenheit zurückzuführen. Da das Gewicht die Rolle bei der Auswahl der Eingabe spielt, spielt es keine Rolle, ob der Wert kleiner wird, solange sich die Priorität nicht ändert. Dies liegt daran, dass die Vorspannung aufgrund der Rolle des Schwellenwerts möglicherweise groß sein muss.
Wenn Sie tatsächlich mit ** Gewichtsabnahme ** und ** ohne Gewichtsabnahme ** trainieren und das Gewichtshistogramm betrachten, sieht es wie in der folgenden Abbildung aus. Die linke ist ohne Gewichtsabnahme und die rechte ist mit Gewichtsabnahme. Sie können sehen, dass das Gewicht kleiner wird.
Die Genauigkeit ist wie folgt. Blau ist das Ergebnis ohne Gewichtsabnahme, Rot ist das Ergebnis eines Gewichtsabfalls, die gepunktete Linie sind die Trainingsdaten und die durchgezogene Linie sind die Validierungsdaten.
import tensorflow as tf
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./data/mnist/", one_hot=True)
image_size = 28
n_labels = 10
n_batch = 128
n_train = 10000
beta = 0.001
def accuracy(y, y_):
return 100.0 * np.sum(np.argmax(y, 1) == np.argmax(y_, 1)) / y.shape[0]
with tf.variable_scope("slp"):
x = tf.placeholder(tf.float32, shape=(n_batch, image_size*image_size))
y_ = tf.placeholder(tf.float32, shape=(n_batch, n_labels))
w0 = tf.get_variable("w0", [image_size * image_size, n_labels], initializer=tf.truncated_normal_initializer(seed=0))
b0 = tf.get_variable("b0", [n_labels], initializer=tf.constant_initializer(0.0))
w1 = tf.get_variable("w1", [image_size * image_size, n_labels], initializer=tf.truncated_normal_initializer(seed=0))
b1 = tf.get_variable("b1", [n_labels], initializer=tf.constant_initializer(0.0))
y0 = tf.matmul( x, w0 ) + b0
y1 = tf.matmul( x, w1 ) + b1
valid_data = mnist.validation.images
valid_labels = mnist.validation.labels
test_data = mnist.test.images
test_labels = mnist.test.labels
vx = tf.constant(valid_data)
vy_ = tf.constant(valid_labels)
tx = tf.constant(test_data)
ty_ = tf.constant(test_labels)
loss0 = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits( y0, y_))
optimizer0 = tf.train.GradientDescentOptimizer(0.5).minimize(loss0)
loss1 = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits( y1, y_) + beta * tf.nn.l2_loss(w1))
optimizer1 = tf.train.GradientDescentOptimizer(0.5).minimize(loss1)
train_prediction0 = tf.nn.softmax(y0)
valid_prediction0 = tf.nn.softmax(tf.matmul(vx, w0) + b0)
test_prediction0 = tf.nn.softmax(tf.matmul(tx, w0) + b0)
train_prediction1 = tf.nn.softmax(y1)
valid_prediction1 = tf.nn.softmax(tf.matmul(vx, w1) + b1)
test_prediction1 = tf.nn.softmax(tf.matmul(tx, w1) + b1)
sess = tf.InteractiveSession()
# sess = tf.Session()
init = tf.initialize_all_variables()
sess.run(init)
result_accuracy = []
for step in xrange(n_train):
bx, by = mnist.train.next_batch(n_batch)
_, L0, tp0 = sess.run([optimizer0, loss0, train_prediction0], feed_dict={x: bx, y_: by})
_, L1, tp1 = sess.run([optimizer1, loss1, train_prediction1], feed_dict={x: bx, y_: by})
if step % 500 == 0:
ac_wo_decay_train = accuracy(tp0, by)
ac_wo_decay_valid = accuracy(valid_prediction0.eval(), valid_labels)
ac_wt_decay_train = accuracy(tp1, by)
ac_wt_decay_valid = accuracy(valid_prediction1.eval(), valid_labels)
ac = {'step' : step, 'wo_decay' : {'training' : ac_wo_decay_train, 'validation' : ac_wo_decay_valid}, 'wt_decay' : {'training' : ac_wt_decay_train, 'validation' : ac_wt_decay_valid}}
result_accuracy.append(ac)
print "step = %d, train accuracy0: %.2f, validation accuracy0: %.2f, train accuracy1: %.2f, validation accuracy1: %.2f" % (step, ac_wo_decay_train, ac_wo_decay_valid, ac_wt_decay_train, ac_wt_decay_valid)
print "test accuracy0: %.2f" % accuracy(test_prediction0.eval(), test_labels)
print "test accuracy1: %.2f" % accuracy(test_prediction1.eval(), test_labels)
fig,axes = plt.subplots(ncols=2, figsize=(8,4))
axes[0].hist(w0.eval().flatten(), bins=sp.linspace(-3,3,50))
axes[0].set_title('without weight decay')
axes[0].set_xlabel('weight')
axes[1].hist(w1.eval().flatten(), bins=sp.linspace(-3,3,50))
axes[1].set_title('with weight decay')
axes[1].set_xlabel('weight')
fig.show()
tr_step = [ac['step'] for ac in result_accuracy]
ac_training_wo_decay = [ac['wo_decay']['training'] for ac in result_accuracy]
ac_training_wt_decay = [ac['wt_decay']['training'] for ac in result_accuracy]
ac_validation_wo_decay = [ac['wo_decay']['validation'] for ac in result_accuracy]
ac_validation_wt_decay = [ac['wt_decay']['validation'] for ac in result_accuracy]
fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(1,1,1)
ax.plot(tr_step, ac_training_wo_decay, color='blue', linestyle='dashed')
ax.plot(tr_step, ac_training_wt_decay, color='red', linestyle='dashed')
ax.plot(tr_step, ac_validation_wo_decay, color='blue', linestyle='solid')
ax.plot(tr_step, ac_validation_wt_decay, color='red', linestyle='solid')
ax.set_title('accuracy')
ax.set_xlabel('step')
ax.set_ylabel('accuracy')
ax.grid(True)
ax.set_xlim((0, 10000))
ax.set_ylim((0, 100))
fig.show()
Recommended Posts