Für Leute, die "Ich habe mein eigenes DNN-Modell erstellt, aber der Code ist schmutzig" und "Ich habe die Büroarbeit satt (Speichern, Protokollieren, gemeinsamer DNN-Code)".
Eine Python-Bibliothek, die dies tut. Es ist die Top-Github-Sternzahl und das beliebte Deep-Learning-Framework.
console
$ pip install pytorch-lightning
pytorch_lightning.LightningModule erben,
* Netzwerk
* 3 Methoden: forward (self, x), training_step (self, batch, batch_idx), configure_optimizers (self)
Wenn Sie die beiden definieren, können Sie sie sofort verwenden. Beachten Sie jedoch, dass der ** Funktionsname und das Argumentpaar nicht geändert werden können **!
(Zum Beispiel batch_idx Wenn Sie es wie `` `training_step (self, batch)` `` definieren, auch wenn Sie es nicht benötigen, wird es fehlerhaft sein)
#### **`MyModel.py`**
```python
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
class LitMyModel(LightningModule):
def __init__(self):
super().__init__()
# mnist images are (1, 28, 28) (channels, width, height)
self.layer_1 = torch.nn.Linear(28 * 28, 128)
self.layer_2 = torch.nn.Linear(128, 10)
def forward(self, x):
batch_size, channels, width, height = x.size()
# (b, 1, 28, 28) -> (b, 1*28*28)
x = x.view(batch_size, -1)
x = self.layer_1(x)
x = F.relu(x)
x = self.layer_2(x)
x = F.log_softmax(x, dim=1)
return x
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
return loss
Jede der drei Funktionen "Netzwerkausgabe zurückgeben" "In 1 Schleife arbeiten & Verlustfunktion zurückgeben" "Optimierer zurückgeben" Jede Verarbeitung ist in Ordnung
#FC Beispiel lernen MNIST
import pytorch_lightning as pl
class LitMyModel(pl.LightningModule):
def __init__(self):
# layers
self.fc1 = nn.Linear(self.out_size, 400)
self.fc4 = nn.Linear(400, self.out_size)
def forward(self, x):
mu, logvar = self.encode(x.view(-1, self.out_size))
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
def training_step(self, batch, batch_idx):
recon_batch, mu, logvar = self.forward(batch)
loss = self.loss_function(
recon_batch, batch, mu, logvar, out_size=self.out_size)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(model.parameters(), lr=1e-3)
return optimizer
Wenn Sie bereits ein Modell haben, verschieben Sie einfach den Code
Danach setzen Sie den Datenlader und das Modell in `pl.Trainer ()`
fit ()
`und beginnen zu lernen!
Laufzeit
dataloader = #Your own dataloader or datamodule
model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, dataloader)
Blitz Einfach und großartig.
Nachdem Sie die oben genannten Informationen erhalten haben, fügen Sie der Klasse die Methoden ** Test, Validierung und andere Optionen ** hinzu.
test
Fügen Sie der Klassenmethode `test_step (self, batch, batch_idx)`
hinzu. Nur. Ausführung
Beim Ausführen des Tests
trainer.test()
validation
Dies wird auch durch Hinzufügen der `val_step ()`
Methode und der `val_dataloader ()`
Methode ~ vervollständigt
dataloader
Dies kann auch in Klassenmethoden gruppiert werden, aber ** Dataset & Data Loader wird empfohlen, `pytorch_lightning.LightningDataModule``` von einer anderen Klasse zu erben und`
MyDataModule``` Klasse ** zu definieren ** ..
class MyDataModule(LightningDataModule): def init(self): super().init() self.train_dims = None self.vocab_size = 0
def prepare_data(self):
# called only on 1 GPU
download_dataset()
tokenize()
build_vocab()
def setup(self):
# called on every GPU
vocab = load_vocab()
self.vocab_size = len(vocab)
self.train, self.val, self.test = load_datasets()
self.train_dims = self.train.next_batch.size()
def train_dataloader(self):
transforms = ...
return DataLoader(self.train, batch_size=64)
def val_dataloader(self):
transforms = ...
return DataLoader(self.val, batch_size=64)
def test_dataloader(self):
transforms = ...
return DataLoader(self.test, batch_size=64)
Laufzeit
datamodule = MyDataModule()
model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, datamodule)
callback So etwas wie "Prozess nur am Anfang des Zuges" und "Prozess am Ende der Epoche". https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html#callbacks Es gibt viele Informationen. Es ist in Ordnung, wenn Sie eine Funktion für das Timing definieren, das Sie verarbeiten möchten
from pytorch_lightning.callbacks import Callback
class MyPrintingCallback(Callback):
def on_init_start(self, trainer):
print('Starting to init trainer!')
def on_init_end(self, trainer):
print('Trainer is init now')
def on_train_end(self, trainer, pl_module):
print('do something when training ends')
trainer = Trainer(callbacks=[MyPrintingCallback()])
Wenn Sie es in einer anderen Klasse definieren, können Sie es kurz schreiben ~
Von nun an die Hauptbeziehung zur Aufbewahrung von Datensätzen. Anzeigen von Zahlenwerten (Verlust, Genauigkeit usw.), Bildern, Tönen usw. auf dem Tensorboard
Tensorflow Beispiel
with tf.name_scope('summary'):
tf.summary.scalar('loss', loss)
merged = tf.summary.merge_all()
writer = tf.summary.FileWriter('./logs', sess.graph)
Ich neigte dazu, schmutzigen Code zu erstellen, indem ich den Code, den ich sehen wollte, in die Mitte steckte, aber pytorch_lightning kann präzise geschrieben werden.
MyModel.py
def training_step(self, batch, batch_idx):
# ...
loss = ...
self.logger.summary.scalar('loss', loss, step=self.global_step)
# equivalent
result = TrainResult()
result.log('loss', loss)
return result
Fügen Sie in der Methode `logger.summary``` hinzu, wenn Sie like aufnehmen, oder fügen Sie den Teil`
return loss``` einmal zur Klasse `` pytorch_lightning.LightningModule.TrainResult ()
`hinzu Beißen Sie es einfach und es wird automatisch im Speicherverzeichnis gespeichert!
logger kann dem Konstruktor der Klasse `` Trainer ()
`hinzugefügt werden, und das Speicherverzeichnis wird auch hier bestimmt.
from pytorch_lightning import loggers as pl_loggers
tb_logger = pl_loggers.TensorBoardLogger('logs/')
trainer = Trainer(logger=tb_logger)
Sie können Daten wie Text und Bilder auch mit dem `.add_hogehoge ()`
des `` `logger.experiment```-Objekts speichern!
MyModel.py
def training_step(...):
...
# the logger you used (in this case tensorboard)
tensorboard = self.logger.experiment
tensorboard.add_histogram(...)
tensorboard.add_figure(...)
Der Beamte sagt, dass der Zeitpunkt des Rückrufs ebenfalls empfohlen wird.
Es ist großartig ... (Es ist wichtig, also sag es zweimal s (ry)
Als Gefühl der Verwendung von Pytorch Lightning ~ ~ (im Vergleich zu der schlechten Lesbarkeit von ignite aufgrund der eingefügten Verarbeitung) ~ ~ Die Regeln sind leicht zu verstehen, und das Klassendesign und die Dokumentenpflege waren korrekt, daher werde ich sie zuerst verwenden Ich hatte das Gefühl, dass dies ein empfohlener Rahmen für tiefes Lernen ist
Recommended Posts