(*** Ceci est un article du 9 janvier 2020 **. Je pense qu'il n'y aura pas de problème dans un proche avenir.)
Pytorch Lightning est la trompette semblable à Keras de PyTorch. Il est intéressant de pouvoir écrire de manière compacte autour de modèles, d'apprentissage et de données.
Pour plus de détails, consultez l'article suivant de @fam_taro.
PyTorch Sangokushi (Ignite / Catalyst / Lightning) --Qiita
Cela semble très pratique, mais j'ai rencontré un bogue au début de l'installation, je vais donc vous rapporter le contenu et la solution.
OS: macOS 10.14.6
Python: 3.7.3
pytorch-lightning: 0.5.3.2
Comment installer Pytorch Lightning: pip install pytorch-lightning
L'arrêt anticipé est implémenté dans Pytorch Lightning, Vous pouvez l'écrire avec le code suivant (clean).
Partie de définition de modèle(Extrait)
import pytorch_lightning as pl
class MyModel(pl.LightningModule):
...
def validation_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'val_batch_loss': F.cross_entropy(y_hat, y)}
def validation_end(self, outputs):
val_loss = torch.stack([x['val_batch_loss'] for x in outputs]).mean()
log = {'val_loss': val_loss}
return {'log': log}
...
Early_Autour de l'arrêt
early_stop_callback = EarlyStopping(
min_delta=0.00,
patience=1,
verbose=False,
monitor='val_loss',
mode='min',
)
model = MyModel()
trainer = pl.Trainer(early_stop_callback=early_stop_callback)
trainer.fit(model)
Cependant, lorsque je l'ai exécuté, j'ai eu l'erreur suivante et cela n'a pas fonctionné. (Je ne sais pas si cela arrive toujours, mais pour le moment, cela a toujours été reproduit dans mon environnement d'exécution.)
Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: loss,train_loss
Ce bogue a également été signalé sur le numéro officiel.
https://github.com/williamFalcon/pytorch-lightning/issues/490
Il a été corrigé dans la dernière branche principale, donc l'installer avec la commande suivante corrigera le bogue.
pip install git+https://github.com/williamFalcon/pytorch-lightning.git@master --upgrade
L'installation de la dernière branche est susceptible de provoquer une divergence d'API entre ** 9 janvier 2020 ** Documentation actuelle.
Exemple: modifiez l'argument de la méthode d'initialisation de la classe d'enregistrement de point de contrôle pytorch_lightning.callbacks.ModelCheckpoint
(Page applicable)
pip install pytorch-lightning
(identique au document officiel)save_best_only
**: Spécifiez s'il faut enregistrer uniquement le meilleur modèle avec une valeur booléenne
--Si vous avez installé la dernière branche avec la commande "Solution" (différente de la documentation officielle)save_top_k
**: Spécifiez le nombre de sommets à enregistrer avec un entiersave_best_only
est supprimé d'une manière remplacée)Probablement, lorsqu'une version supérieure à 0.5.3.2 est publiée, l'habituel pip install pytorch-lightning
conviendra.
Recommended Posts