[PYTHON] [DL] Was ist Gewichtsverlust?

In einem tiefen neuronalen Netzwerk ist das Modell umso ausdrucksvoller, je mehr Schichten vorhanden sind. Je höher jedoch die Anzahl der Schichten ist, desto höher ist das Risiko einer ** Überanpassung **. Das Risiko einer ** Überanpassung ** wird verringert, indem die Parameterfreiheit unter Beibehaltung der Ausdruckskraft des Modells eingeschränkt wird. Eine der Methoden ist ** Gewichtsabnahme **.

Die Gewichtsaktualisierungsformel ist wie folgt geschrieben.

w \leftarrow w -\eta \frac{\partial C(w)}{\partial w} - \eta \lambda w

Die obige Formel ist etwas schwer zu verstehen, was Sie tun möchten, aber sie kommt tatsächlich von der Kostenfunktion, wie unten gezeigt.

\tilde C(w) = C(w) + \frac{\lambda}{2}||w||^2

Dies ist die ** Kostenfunktion ** mit der ** L2-Regularisierungsklausel **. Dieser Begriff reduziert den Gewichtswert. Wenn Sie es also tatsächlich implementieren, fügen Sie den Abschnitt ** L2-Regularisierung ** zu den Kosten hinzu.

Normalerweise wird die ** L2-Regularisierung ** nicht auf die Vorspannung angewendet. Dies ist auf die unterschiedlichen Rollen von Neuronengewicht und Voreingenommenheit zurückzuführen. Da das Gewicht die Rolle bei der Auswahl der Eingabe spielt, spielt es keine Rolle, ob der Wert kleiner wird, solange sich die Priorität nicht ändert. Dies liegt daran, dass die Vorspannung aufgrund der Rolle des Schwellenwerts möglicherweise groß sein muss.

Wenn Sie tatsächlich mit ** Gewichtsabnahme ** und ** ohne Gewichtsabnahme ** trainieren und das Gewichtshistogramm betrachten, sieht es wie in der folgenden Abbildung aus. Die linke ist ohne Gewichtsabnahme und die rechte ist mit Gewichtsabnahme. Sie können sehen, dass das Gewicht kleiner wird.

weightdecay1.png

Die Genauigkeit ist wie folgt. Blau ist das Ergebnis ohne Gewichtsabnahme, Rot ist das Ergebnis eines Gewichtsabfalls, die gepunktete Linie sind die Trainingsdaten und die durchgezogene Linie sind die Validierungsdaten. weightdecay2.png

image

import tensorflow as tf
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("./data/mnist/", one_hot=True)

image_size = 28
n_labels = 10
n_batch  = 128
n_train  = 10000
beta = 0.001

def accuracy(y, y_):
    return 100.0 * np.sum(np.argmax(y, 1) == np.argmax(y_, 1)) / y.shape[0]

with tf.variable_scope("slp"):
    x  = tf.placeholder(tf.float32, shape=(n_batch, image_size*image_size))
    y_ = tf.placeholder(tf.float32, shape=(n_batch, n_labels))
    w0 = tf.get_variable("w0", [image_size * image_size, n_labels], initializer=tf.truncated_normal_initializer(seed=0))
    b0 = tf.get_variable("b0", [n_labels], initializer=tf.constant_initializer(0.0))

    w1 = tf.get_variable("w1", [image_size * image_size, n_labels], initializer=tf.truncated_normal_initializer(seed=0))
    b1 = tf.get_variable("b1", [n_labels], initializer=tf.constant_initializer(0.0))
    
    y0 = tf.matmul( x, w0 ) + b0
    y1 = tf.matmul( x, w1 ) + b1
    
valid_data = mnist.validation.images
valid_labels = mnist.validation.labels
test_data = mnist.test.images
test_labels = mnist.test.labels
vx = tf.constant(valid_data)
vy_ = tf.constant(valid_labels)
tx = tf.constant(test_data)
ty_ = tf.constant(test_labels)

loss0 = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits( y0, y_))
optimizer0 = tf.train.GradientDescentOptimizer(0.5).minimize(loss0)

loss1 = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits( y1, y_) + beta * tf.nn.l2_loss(w1))
optimizer1 = tf.train.GradientDescentOptimizer(0.5).minimize(loss1)

train_prediction0 = tf.nn.softmax(y0)
valid_prediction0 = tf.nn.softmax(tf.matmul(vx, w0) + b0)
test_prediction0  = tf.nn.softmax(tf.matmul(tx, w0) + b0)

train_prediction1 = tf.nn.softmax(y1)
valid_prediction1 = tf.nn.softmax(tf.matmul(vx, w1) + b1)
test_prediction1  = tf.nn.softmax(tf.matmul(tx, w1) + b1)

sess = tf.InteractiveSession()
# sess = tf.Session()

init = tf.initialize_all_variables()
sess.run(init)
result_accuracy = []

for step in xrange(n_train):
    bx, by = mnist.train.next_batch(n_batch)    
    _, L0, tp0 = sess.run([optimizer0, loss0, train_prediction0], feed_dict={x: bx, y_: by})
    _, L1, tp1 = sess.run([optimizer1, loss1, train_prediction1], feed_dict={x: bx, y_: by})
    if step % 500 == 0:
        ac_wo_decay_train = accuracy(tp0, by)
        ac_wo_decay_valid = accuracy(valid_prediction0.eval(), valid_labels)
        ac_wt_decay_train = accuracy(tp1, by)
        ac_wt_decay_valid = accuracy(valid_prediction1.eval(), valid_labels)
        ac = {'step' : step, 'wo_decay' : {'training' : ac_wo_decay_train, 'validation' : ac_wo_decay_valid}, 'wt_decay' : {'training' : ac_wt_decay_train, 'validation' : ac_wt_decay_valid}}
        result_accuracy.append(ac)
        print "step = %d, train accuracy0: %.2f, validation accuracy0: %.2f, train accuracy1: %.2f, validation accuracy1: %.2f" % (step, ac_wo_decay_train, ac_wo_decay_valid, ac_wt_decay_train, ac_wt_decay_valid)
        
print "test accuracy0: %.2f" % accuracy(test_prediction0.eval(), test_labels)
print "test accuracy1: %.2f" % accuracy(test_prediction1.eval(), test_labels)

fig,axes = plt.subplots(ncols=2, figsize=(8,4))
axes[0].hist(w0.eval().flatten(), bins=sp.linspace(-3,3,50))
axes[0].set_title('without weight decay')
axes[0].set_xlabel('weight')
axes[1].hist(w1.eval().flatten(), bins=sp.linspace(-3,3,50))
axes[1].set_title('with weight decay')
axes[1].set_xlabel('weight')
fig.show()

tr_step = [ac['step'] for ac in result_accuracy]
ac_training_wo_decay = [ac['wo_decay']['training'] for ac in result_accuracy]
ac_training_wt_decay = [ac['wt_decay']['training'] for ac in result_accuracy]
ac_validation_wo_decay = [ac['wo_decay']['validation'] for ac in result_accuracy]
ac_validation_wt_decay = [ac['wt_decay']['validation'] for ac in result_accuracy]

fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(1,1,1)

ax.plot(tr_step, ac_training_wo_decay, color='blue', linestyle='dashed')
ax.plot(tr_step, ac_training_wt_decay, color='red', linestyle='dashed')
ax.plot(tr_step, ac_validation_wo_decay, color='blue', linestyle='solid')
ax.plot(tr_step, ac_validation_wt_decay, color='red', linestyle='solid')
ax.set_title('accuracy')
ax.set_xlabel('step')
ax.set_ylabel('accuracy')
ax.grid(True)
ax.set_xlim((0, 10000))
ax.set_ylim((0, 100))
fig.show()

Recommended Posts

[DL] Was ist Gewichtsverlust?
Was ist ein Namespace?
Was ist Django? .. ..
Was ist dotenv?
Was ist POSIX?
Was ist Linux?
Was ist klass?
Was ist SALOME?
Was ist Linux?
Was ist Python?
Was ist Hyperopt?
Was ist Linux?
Was ist Pyvenv?
Was ist __call__?
Was ist Linux?
Was ist Python?
[Python] Was ist Pipeline ...
Was ist das Calmar-Verhältnis?
Was ist ein Terminal?
[PyTorch Tutorial ①] Was ist PyTorch?
Was ist Hyperparameter-Tuning?
Was ist ein Hacker?
Was ist JSON? .. [Hinweis]
Wofür ist Linux?
Was ist ein Zeiger?
Was ist Ensemble-Lernen?
Was ist TCP / IP?
Was ist Pythons __init__.py?
Was ist ein Iterator?
Was ist UNIT-V Linux?
[Python] Was ist virtualenv?
Was ist maschinelles Lernen?
Was ist Mini Sam oder Mini Max?
Was ist eine logistische Regressionsanalyse?
Was ist die Aktivierungsfunktion?
Was ist eine Instanzvariable?
Was ist ein Kontextwechsel?
Was ist Google Cloud Dataflow?
[Python] Python und Sicherheit - is Was ist Python?
Was ist ein Superuser?
Wettbewerbsprogrammierung ist was (Bonus)
[Python] * args ** Was ist kwrgs?
Was ist ein Systemaufruf?
[Definition] Was ist ein Framework?
Was ist die Schnittstelle für ...
Was ist Project Euler 3-Beschleunigung?
Was ist eine Rückruffunktion?
Was ist die Rückruffunktion?
Was ist Ihr "Tanimoto-Koeffizient"?
Python-Grundkurs (1 Was ist Python?)
[Python] Was ist eine Zip-Funktion?
[Python] Was ist eine with-Anweisung?
Was ist Kennzeichnung in der Finanzprognose?
Was ist eine reduzierte Rangkammregression?
Was ist Azure Automation Update Management?
[Python] Was ist @? (Über Dekorateure)
Was ist ein lexikalischer / dynamischer Bereich?
Was ist das Convolutional Neural Network?