[PYTHON] Essayez le nouveau chaînage du planificateur dans PyTorch 1.4

introduction

Avez-vous mis à niveau vers PyTorch 1.4? Si vous ne l'avez pas déjà fait, vous pouvez découvrir comment mettre à jour depuis Officiel ici. (Veuillez noter que la série Python 2 ne sera pas prise en charge à partir de la prochaine version)

En tant que nouvelle fonctionnalité de PyTorch 1.4, la ** fonction de chaînage du planificateur ** a été discrètement ajoutée. (Cliquez ici pour les notes de version) Essayez immédiatement.

Qu'est-ce que Scheduler

Vous pouvez utiliser Scheduler pour modifier le taux d'apprentissage pour chaque époque. Plus le taux d'apprentissage est élevé, plus l'apprentissage progresse rapidement, mais si le taux d'apprentissage reste trop élevé, il y a un risque de sauter la solution optimale. Par conséquent, il est courant d'utiliser Scheduler lors de l'apprentissage de NN et de réduire progressivement le taux d'apprentissage à mesure que le nombre d'époques augmente. (Bien que cela ne soit pas directement lié à cette histoire, veuillez noter que le planificateur PyTorch est implémenté pour renvoyer un multiplicateur au taux d'apprentissage d'origine, contrairement à Keras, etc.)

Qu'est-ce que le nouveau chaînage de fonctions?

Selon le fonctionnaire, ** "Une fonction qui vous permet de combiner des effets en définissant et en faisant avancer deux ordonnanceurs l'un après l'autre" **. Jusqu'à présent, seul le taux d'apprentissage décidé par un ordonnanceur pouvait être utilisé, mais il semble qu'il soit possible de réaliser des techniques d'appariement. Cela ne sort pas très bien, alors déplaçons-le et voyons comment le taux d'apprentissage change.

Tout d'abord, vérifiez son comportement avec PyTorch 1.3

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR, StepLR
model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler1 = StepLR(optimizer, step_size=3, gamma=0.5)
scheduler2 = ExponentialLR(optimizer, gamma=0.9)
s1, s2, lr = [], [], []
for epoch in range(100):
    optimizer.step()
    scheduler1.step()
    scheduler2.step()
    s1.append(scheduler1.get_lr()[0])
    s2.append(scheduler2.get_lr()[0])
    for param_group in optimizer.param_groups:
        lr.append(param_group['lr'])

Nous utilisons deux ordonnanceurs, StepLR et ExponentialLR. Appelons-les respectivement les ordonnanceurs 1 et 2. Tracez le taux d'apprentissage (s1, s2) de l'ordonnanceur obtenu respectivement.

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
plt.plot(s1, label='StepLR (scheduler1)')
plt.plot(s2, label='ExponentialLR (scheduler2)')
plt.legend()

image.png

Vous pouvez voir le type de caractéristiques de chaque planificateur. Ensuite, tracez le taux d'apprentissage d'Optimizer.

plt.plot(lr, label='Learning Rate')
plt.legend()

image.png

Comme vous pouvez le voir en un coup d'œil, vous pouvez voir que le taux d'apprentissage de Exponential LR (scheduler 2) est utilisé. Apparemment, PyTorch 1.3 a utilisé le taux d'apprentissage du planificateur où l'étape a été appelée pour la dernière fois. En outre, il semble que les planificateurs les uns des autres ont travaillé indépendamment sans aucune influence particulière.

Confirmez enfin l'effet du chaînage avec PyTorch 1.4

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR, StepLR
model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler1 = StepLR(optimizer, step_size=3, gamma=0.5)
scheduler2 = ExponentialLR(optimizer, gamma=0.9)
s1, s2, lr = [], [], []
for epoch in range(100):
    optimizer.step()
    scheduler1.step()
    scheduler2.step()
    s1.append(scheduler1.get_last_lr()[0])
    s2.append(scheduler2.get_last_lr()[0])
    for param_group in optimizer.param_groups:
        lr.append(param_group['lr'])

** Notez que dans PyTorch 1.3, j'obtenais le taux d'apprentissage du planificateur avec .get_lr (), mais dans PyTorch 1.4, il s'agit de .get_last_lr (). ** ** Vous pouvez également utiliser .get_lr (), mais sachez que la valeur correcte peut ne pas être sortie. Ces changements ont été officiellement annoncés.

Maintenant, traçons le taux d'apprentissage de chaque planificateur.

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
plt.plot(s1, label='StepLR (scheduler1)')
plt.plot(s2, label='ExponentialLR (scheduler2)')
plt.legend()

image.png

À ce stade, la situation est différente de 1.3. Tracez ensuite le taux d'apprentissage d'Optimizer.

plt.plot(lr, label='Learning Rate')
plt.legend()

image.png

De cette façon, vous pouvez voir que les taux d'apprentissage des deux planificateurs sont multipliés l'un après l'autre pour obtenir le taux d'apprentissage final de l'Optimizer. Il semble bon d'être conscient que le taux d'apprentissage change en raison de l'influence des deux ordonnanceurs eux-mêmes. Il semble que le taux d'apprentissage d'Optimizer change également lorsque le taux d'apprentissage de chaque planificateur change.

Il semble qu'il sera plus facile d'écrire un ordonnanceur qui modifie le taux d'apprentissage de manière compliquée, qui jusqu'à présent devait être écrit par vous-même. Cela semble particulièrement utile lorsque vous souhaitez ajouter un petit comportement cyclique.

c'est tout.

Recommended Posts

Essayez le nouveau chaînage du planificateur dans PyTorch 1.4
Essayez Cython dans les plus brefs délais
Essayez d'utiliser l'API Kraken avec Python
Essayez d'utiliser la bande HL dans l'ordre
Essayez d'accéder à l'API Spotify dans Django.
Essayez d'utiliser l'API BitFlyer Ligntning en Python
Essayez d'implémenter la méthode Monte Carlo en Python
Essayez d'utiliser l'API DropBox Core avec Python
Essayez de charger l'image dans un thread séparé (OpenCV-Python)
Essayez de déchiffrer les données de connexion stockées dans Firefox
Essayez la segmentation sémantique (Pytorch)
Essayez gRPC en Python
Quoi de neuf dans Python 3.5
Nouveau dans Python 3.4.0 (1) --pathlib
Essayez 9 tranches en Python
Quoi de neuf dans Python 3.6
Essayez de gratter les données COVID-19 Tokyo avec Python
Remarque sur le comportement par défaut de collate_fn dans PyTorch
Obtenez et créez des nœuds ajoutés et mis à jour dans la nouvelle version
Essayez d'extraire les mots-clés populaires dans COTOHA