[PYTHON] [DL] Qu'est-ce que la décroissance du poids?

Dans un réseau neuronal profond, plus il y a de couches, plus le modèle sera expressif. Cependant, plus le nombre de couches est élevé, plus le risque de ** surajustement ** est élevé. Le risque de ** sur-ajustement ** est réduit en limitant la liberté des paramètres tout en maintenant l'expressivité du modèle. L'une des méthodes est la ** décroissance du poids **.

La formule de mise à jour du poids est écrite comme suit.

w \leftarrow w -\eta \frac{\partial C(w)}{\partial w} - \eta \lambda w

La formule ci-dessus est un peu difficile à comprendre ce que vous voulez faire, mais elle provient en fait de la fonction de coût comme indiqué ci-dessous.

\tilde C(w) = C(w) + \frac{\lambda}{2}||w||^2

Il s'agit de la ** fonction de coût ** avec la clause ** de régularisation L2 **. Ce terme réduit la valeur de poids. Ainsi, lorsque vous l'implémenterez réellement, vous ajouterez la section ** Régularisation L2 ** au coût.

Normalement, la ** régularisation L2 ** n'est pas appliquée au biais. Cela vient de la différence des rôles du poids et du biais des neurones. Puisque le poids est le rôle de la sélection de l'entrée, peu importe si la valeur devient plus petite tant que la priorité ne change pas, En effet, le biais peut devoir être important en raison du rôle du seuil.

Si vous vous entraînez réellement avec ** perte de poids ** et ** sans perte de poids ** et regardez l'histogramme de poids, cela ressemblera à la figure ci-dessous. La gauche est sans décroissance du poids et la droite avec décroissance du poids. Vous pouvez voir que le poids diminue.

weightdecay1.png

La précision est la suivante. Le bleu est le résultat de l'absence de décroissance du poids, le rouge est le résultat de la décroissance du poids, la ligne en pointillé correspond aux données d'entraînement et la ligne continue correspond aux données de validation. weightdecay2.png

image

import tensorflow as tf
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("./data/mnist/", one_hot=True)

image_size = 28
n_labels = 10
n_batch  = 128
n_train  = 10000
beta = 0.001

def accuracy(y, y_):
    return 100.0 * np.sum(np.argmax(y, 1) == np.argmax(y_, 1)) / y.shape[0]

with tf.variable_scope("slp"):
    x  = tf.placeholder(tf.float32, shape=(n_batch, image_size*image_size))
    y_ = tf.placeholder(tf.float32, shape=(n_batch, n_labels))
    w0 = tf.get_variable("w0", [image_size * image_size, n_labels], initializer=tf.truncated_normal_initializer(seed=0))
    b0 = tf.get_variable("b0", [n_labels], initializer=tf.constant_initializer(0.0))

    w1 = tf.get_variable("w1", [image_size * image_size, n_labels], initializer=tf.truncated_normal_initializer(seed=0))
    b1 = tf.get_variable("b1", [n_labels], initializer=tf.constant_initializer(0.0))
    
    y0 = tf.matmul( x, w0 ) + b0
    y1 = tf.matmul( x, w1 ) + b1
    
valid_data = mnist.validation.images
valid_labels = mnist.validation.labels
test_data = mnist.test.images
test_labels = mnist.test.labels
vx = tf.constant(valid_data)
vy_ = tf.constant(valid_labels)
tx = tf.constant(test_data)
ty_ = tf.constant(test_labels)

loss0 = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits( y0, y_))
optimizer0 = tf.train.GradientDescentOptimizer(0.5).minimize(loss0)

loss1 = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits( y1, y_) + beta * tf.nn.l2_loss(w1))
optimizer1 = tf.train.GradientDescentOptimizer(0.5).minimize(loss1)

train_prediction0 = tf.nn.softmax(y0)
valid_prediction0 = tf.nn.softmax(tf.matmul(vx, w0) + b0)
test_prediction0  = tf.nn.softmax(tf.matmul(tx, w0) + b0)

train_prediction1 = tf.nn.softmax(y1)
valid_prediction1 = tf.nn.softmax(tf.matmul(vx, w1) + b1)
test_prediction1  = tf.nn.softmax(tf.matmul(tx, w1) + b1)

sess = tf.InteractiveSession()
# sess = tf.Session()

init = tf.initialize_all_variables()
sess.run(init)
result_accuracy = []

for step in xrange(n_train):
    bx, by = mnist.train.next_batch(n_batch)    
    _, L0, tp0 = sess.run([optimizer0, loss0, train_prediction0], feed_dict={x: bx, y_: by})
    _, L1, tp1 = sess.run([optimizer1, loss1, train_prediction1], feed_dict={x: bx, y_: by})
    if step % 500 == 0:
        ac_wo_decay_train = accuracy(tp0, by)
        ac_wo_decay_valid = accuracy(valid_prediction0.eval(), valid_labels)
        ac_wt_decay_train = accuracy(tp1, by)
        ac_wt_decay_valid = accuracy(valid_prediction1.eval(), valid_labels)
        ac = {'step' : step, 'wo_decay' : {'training' : ac_wo_decay_train, 'validation' : ac_wo_decay_valid}, 'wt_decay' : {'training' : ac_wt_decay_train, 'validation' : ac_wt_decay_valid}}
        result_accuracy.append(ac)
        print "step = %d, train accuracy0: %.2f, validation accuracy0: %.2f, train accuracy1: %.2f, validation accuracy1: %.2f" % (step, ac_wo_decay_train, ac_wo_decay_valid, ac_wt_decay_train, ac_wt_decay_valid)
        
print "test accuracy0: %.2f" % accuracy(test_prediction0.eval(), test_labels)
print "test accuracy1: %.2f" % accuracy(test_prediction1.eval(), test_labels)

fig,axes = plt.subplots(ncols=2, figsize=(8,4))
axes[0].hist(w0.eval().flatten(), bins=sp.linspace(-3,3,50))
axes[0].set_title('without weight decay')
axes[0].set_xlabel('weight')
axes[1].hist(w1.eval().flatten(), bins=sp.linspace(-3,3,50))
axes[1].set_title('with weight decay')
axes[1].set_xlabel('weight')
fig.show()

tr_step = [ac['step'] for ac in result_accuracy]
ac_training_wo_decay = [ac['wo_decay']['training'] for ac in result_accuracy]
ac_training_wt_decay = [ac['wt_decay']['training'] for ac in result_accuracy]
ac_validation_wo_decay = [ac['wo_decay']['validation'] for ac in result_accuracy]
ac_validation_wt_decay = [ac['wt_decay']['validation'] for ac in result_accuracy]

fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(1,1,1)

ax.plot(tr_step, ac_training_wo_decay, color='blue', linestyle='dashed')
ax.plot(tr_step, ac_training_wt_decay, color='red', linestyle='dashed')
ax.plot(tr_step, ac_validation_wo_decay, color='blue', linestyle='solid')
ax.plot(tr_step, ac_validation_wt_decay, color='red', linestyle='solid')
ax.set_title('accuracy')
ax.set_xlabel('step')
ax.set_ylabel('accuracy')
ax.grid(True)
ax.set_xlim((0, 10000))
ax.set_ylim((0, 100))
fig.show()

Recommended Posts

[DL] Qu'est-ce que la décroissance du poids?
Qu'est-ce que l'espace de noms
Qu'est-ce que Django? .. ..
Qu'est-ce que dotenv?
Qu'est-ce que POSIX
Qu'est-ce que Linux
Qu'est-ce que le klass?
Qu'est-ce que SALOME?
Qu'est-ce que Linux?
Qu'est-ce que python
Qu'est-ce que l'hyperopt?
Qu'est-ce que Linux
Qu'est-ce que pyvenv
Qu'est-ce que __call__
Qu'est-ce que Linux
Qu'est-ce que Python
[Python] Qu'est-ce que Pipeline ...
Qu'est-ce que Calmar Ratio?
Qu'est-ce qu'un terminal?
[Tutoriel PyTorch ①] Qu'est-ce que PyTorch?
Qu'est-ce que le réglage des hyper paramètres?
Qu'est-ce qu'un hacker?
Qu'est-ce que JSON? .. [Remarque]
À quoi sert Linux?
Qu'est-ce qu'un pointeur?
Qu'est-ce que l'apprentissage d'ensemble?
Qu'est-ce que TCP / IP?
Qu'est-ce que __init__.py de Python?
Qu'est-ce qu'un itérateur?
Qu'est-ce que UNIT-V Linux?
[Python] Qu'est-ce que virtualenv
Qu'est-ce que l'apprentissage automatique?
Qu'est-ce que Mini Sam ou Mini Max?
Qu'est-ce que l'analyse de régression logistique?
Quelle est la fonction d'activation?
Qu'est-ce qu'une variable d'instance?
Qu'est-ce qu'un changement de contexte?
Qu'est-ce que Google Cloud Dataflow?
[Python] Python et sécurité-① Qu'est-ce que Python?
Qu'est-ce qu'un super utilisateur?
La programmation du concours, c'est quoi (bonus)
[Python] * args ** Qu'est-ce que kwrgs?
Qu'est-ce qu'un appel système
[Définition] Qu'est-ce qu'un cadre?
A quoi sert l'interface ...
Qu'est-ce que Project Euler 3 Acceleration?
Qu'est-ce qu'une fonction de rappel?
Qu'est-ce que la fonction de rappel?
Quel est votre "coefficient de Tanimoto"?
Cours de base Python (1 Qu'est-ce que Python)
[Python] Qu'est-ce qu'une fonction zip?
[Python] Qu'est-ce qu'une instruction with?
Qu'est-ce que l'étiquetage dans les prévisions financières?
Qu'est-ce que la régression de crête de rang réduit?
Qu'est-ce que Azure Automation Update Management?
[Python] Qu'est-ce que @? (À propos des décorateurs)
Qu'est-ce qu'une portée lexicale / une portée dynamique?
Qu'est-ce que le réseau neuronal convolutif?