(*** Dies ist ein Artikel vom 9. Januar 2020 **. Ich denke, dass es in naher Zukunft kein Problem geben wird.)
Pytorch Lightning ist PyTorchs Keras-ähnliche Trompete. Es ist attraktiv, kompakt um Modelle, Lernen und Daten schreiben zu können.
Einzelheiten finden Sie im folgenden Artikel von @fam_taro.
PyTorch Sangokushi (Ignite / Catalyst / Lightning) - Qiita
Es scheint sehr praktisch zu sein, aber ich bin zu Beginn der Installation auf einen Fehler gestoßen, sodass ich über den Inhalt und die Lösung berichten werde.
OS: macOS 10.14.6
Python: 3.7.3
pytorch-lightning: 0.5.3.2
So installieren Sie Pytorch Lightning: pip install pytorch-lightning
Early Popping ist in Pytorch Lightning implementiert. Sie können es mit dem folgenden Code schreiben (sauber).
Modelldefinitionsteil(Auszug)
import pytorch_lightning as pl
class MyModel(pl.LightningModule):
...
def validation_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'val_batch_loss': F.cross_entropy(y_hat, y)}
def validation_end(self, outputs):
val_loss = torch.stack([x['val_batch_loss'] for x in outputs]).mean()
log = {'val_loss': val_loss}
return {'log': log}
...
Early_Um anzuhalten
early_stop_callback = EarlyStopping(
min_delta=0.00,
patience=1,
verbose=False,
monitor='val_loss',
mode='min',
)
model = MyModel()
trainer = pl.Trainer(early_stop_callback=early_stop_callback)
trainer.fit(model)
Als ich es ausführte, bekam ich jedoch den folgenden Fehler und es funktionierte nicht. (Ich weiß nicht, ob es immer passiert, aber vorerst wurde es immer in meiner Ausführungsumgebung reproduziert.)
Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: loss,train_loss
Dieser Fehler wurde auch in der offiziellen Ausgabe gemeldet.
https://github.com/williamFalcon/pytorch-lightning/issues/490
Es wurde in der neuesten Hauptniederlassung behoben. Wenn Sie es also mit dem folgenden Befehl installieren, wird der Fehler behoben.
pip install git+https://github.com/williamFalcon/pytorch-lightning.git@master --upgrade
Die Installation des neuesten Zweigs führt wahrscheinlich zu einer API-Diskrepanz zwischen ** 9. Januar 2020 ** Aktuelle Dokumentation.
Beispiel: Ändern Sie das Argument der Initialisierungsmethode der Checkpoint-Speicherklasse "pytorch_lightning.callbacks.ModelCheckpoint" (Anwendbare Seite)
save_best_only
**: Geben Sie an, ob nur das beste Modell mit dem Bool-Wert gespeichert werden soll
--Wenn Sie den neuesten Zweig mit dem Befehl "Lösung" installiert haben (anders als in der offiziellen Dokumentation)save_top_k
**: Geben Sie an, wie viele Spitzen mit einer Ganzzahl gespeichert werden sollensave_best_only
wird ersetzt ersetzt)Wenn eine Version über 0.5.3.2 veröffentlicht wird, ist der übliche "pip install pytorch-lightning" wahrscheinlich in Ordnung.
Recommended Posts