[PYTHON] Définissez des règles de mise à jour pour chaque paramètre dans Chainer v2

introduction

UpdateRule a été introduit à partir de la v2 dans Chainer, et il est désormais possible de définir des règles de mise à jour pour chaque paramètre en exécutant l'instance UpdateRule. Par exemple, vous pouvez modifier le taux d'apprentissage ou supprimer la mise à jour de certains paramètres.

Qu'est-ce qu'un paramètre?

Les paramètres de cet article font référence à l'instance chainer.Parameter. chainer.Parameter est une classe qui hérite de chainer.Variable et est utilisée pour contenir les paramètres de chainer.Link. Par exemple, chainer.functions.Convolution2D a deux paramètres, W et b.

UpdateRule

chainer.UpdateRule est une classe qui définit comment mettre à jour les paramètres. Il existe des classes dérivées qui prennent en charge les algorithmes de mise à jour tels que SGD. ʻUpdate Rule` a les attributs suivants.

Vous pouvez arrêter la mise à jour des paramètres ou modifier le taux d'apprentissage en manipulant activé ou hyperparam.

Lorsque la règle de mise à jour est générée

L'instance UpdateRule de chaque paramètre est créée lorsque vous appelez setup () de l'instance chainer.Optimizer.

Exemple

Supposons que vous construisiez le réseau neuronal suivant.

class MLP(chainer.Chain):
    def __init__(self):
        super(MLP, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(2, 2)
            self.l2 = L.Linear(2, 1)

    def __call__(self, x):
        h = self.l1(x)
        h = self.l2(h)
        return h

Arrêter la mise à jour des paramètres

La mise à jour des paramètres peut être arrêtée pour chaque paramètre ou lien.

Spécifier en unités de paramètres

Pour empêcher la mise à jour de certains paramètres, définissez ʻupdate_rule.enabled` sur False. Exemple:

net.l1.W.update_rule.enabled = False

Spécifier dans les unités de lien

Pour empêcher la mise à jour du lien, vous pouvez appeler disable_update (). Inversement, appelez ʻenable_update` pour mettre à jour tous les paramètres de Link.

Exemple:

net.l1.disable_update()

Modifier les hyper paramètres

Les hyper paramètres tels que le taux d'apprentissage peuvent être modifiés en manipulant les attributs de hyperparam.

Exemple:

net.l1.W.update_rule.hyperparam.lr = 1.0

Ajouter une fonction de crochet

En appelant ʻupdate_rule.add_hook, des fonctions de hook telles que chainer.optimizer.WeightDecay` peuvent être définies pour chaque paramètre.

Exemple:

net.l1.W.update_rule.add_hook(chainer.optimizer.WeightDecay(0.0001))

Essaie

À titre d'exemple, augmentons le taux d'apprentissage de certains paramètres et arrêtons de mettre à jour d'autres paramètres.

# -*- coding: utf-8 -*-

import numpy as np

import chainer
from chainer import functions as F
from chainer import links as L
from chainer import optimizers


class MLP(chainer.Chain):
    def __init__(self):
        super(MLP, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(2, 2)
            self.l2 = L.Linear(2, 1)

    def __call__(self, x):
        h = self.l1(x)
        h = self.l2(h)
        return h


net = MLP()
optimizer = optimizers.SGD(lr=0.1)

#La configuration de l'appel générera une règle de mise à jour
optimizer.setup(net)

net.l1.W.update_rule.hyperparam.lr = 10.0
net.l1.b.update_rule.enabled = False

x = np.asarray([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
t = np.asarray([[0], [1], [1], [0]], dtype=np.int32)

y = net(x)

print('before')
print('l1.W')
print(net.l1.W.data)
print('l1.b')
print(net.l1.b.data)
print('l2.W')
print(net.l2.W.data)
print('l2.b')
print(net.l2.b.data)

loss = F.sigmoid_cross_entropy(y, t)
net.cleargrads()
loss.backward()
optimizer.update()

print('after')
print('l1.W')
print(net.l1.W.data)
print('l1.b')
print(net.l1.b.data)
print('l2.W')
print(net.l2.W.data)
print('l2.b')
print(net.l2.b.data)

Le résultat de l'exécution est le suivant. Vous pouvez voir que la quantité de changement dans l1.W est beaucoup plus grande que la quantité de changement dans l2.W, et l1.b n'a pas changé.

before
l1.W
[[ 0.0049778  -0.16282777]
 [-0.92988533  0.2546134 ]]
l1.b
[ 0.  0.]
l2.W
[[-0.45893994 -1.21258962]]
l2.b
[ 0.]
after
l1.W
[[ 0.53748596  0.01032409]
 [ 0.47708291  0.71210718]]
l1.b
[ 0.  0.]
l2.W
[[-0.45838338 -1.20276082]]
l2.b
[-0.01014706]

Recommended Posts

Définissez des règles de mise à jour pour chaque paramètre dans Chainer v2
Comment définir la résolution de sortie pour chaque image clé dans Blender
[Déprécié] Tutoriel pour débutant Chainer v1.24.0