[PYTHON] Einführung in die Verwendung von Pytorch Lightning ~ Bis Sie Ihr eigenes Modell formatieren und auf Tensorboard ausgeben ~

Was ist in diesem Artikel zu tun?

Für Leute, die "Ich habe mein eigenes DNN-Modell erstellt, aber der Code ist schmutzig" und "Ich habe die Büroarbeit satt (Speichern, Protokollieren, gemeinsamer DNN-Code)".

Was ist Pytorch Lightning?

Eine Python-Bibliothek, die dies tut. Es ist die Top-Github-Sternzahl und das beliebte Deep-Learning-Framework.

Wie benutzt man

1. Zuerst installieren

console


$ pip install pytorch-lightning

2. Schreiben Sie ein Deep-Learning-Modell gemäß pytorch_lightning

pytorch_lightning.LightningModule erben,



 * Netzwerk
 * 3 Methoden: forward (self, x), training_step (self, batch, batch_idx), configure_optimizers (self)

 Wenn Sie die beiden definieren, können Sie sie sofort verwenden. Beachten Sie jedoch, dass der ** Funktionsname und das Argumentpaar nicht geändert werden können **!
 (Zum Beispiel batch_idx Wenn Sie es wie `` `training_step (self, batch)` `` definieren, auch wenn Sie es nicht benötigen, wird es fehlerhaft sein)


#### **`MyModel.py`**
```python

import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule

class LitMyModel(LightningModule):

  def __init__(self):
    super().__init__()

    # mnist images are (1, 28, 28) (channels, width, height)
    self.layer_1 = torch.nn.Linear(28 * 28, 128)
    self.layer_2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    batch_size, channels, width, height = x.size()

    # (b, 1, 28, 28) -> (b, 1*28*28)
    x = x.view(batch_size, -1)
    x = self.layer_1(x)
    x = F.relu(x)
    x = self.layer_2(x)

    x = F.log_softmax(x, dim=1)
    return x

  def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.nll_loss(logits, y)
    return loss

Jede der drei Funktionen "Netzwerkausgabe zurückgeben" "In 1 Schleife arbeiten & Verlustfunktion zurückgeben" "Optimierer zurückgeben" Jede Verarbeitung ist in Ordnung

** Für diejenigen, die lang sind, aber ein VAE-Beispiel sehen möchten (Klicken) **
#FC Beispiel lernen MNIST
import pytorch_lightning as pl

class LitMyModel(pl.LightningModule):
    def __init__(self):
        # layers
        self.fc1 = nn.Linear(self.out_size, 400)
        self.fc4 = nn.Linear(400, self.out_size)

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, self.out_size))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def training_step(self, batch, batch_idx):
        recon_batch, mu, logvar = self.forward(batch)
        loss = self.loss_function(
            recon_batch, batch, mu, logvar, out_size=self.out_size)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        return optimizer

Wenn Sie bereits ein Modell haben, verschieben Sie einfach den Code Danach setzen Sie den Datenlader und das Modell in `pl.Trainer ()` fit () `und beginnen zu lernen!

Laufzeit


dataloader = #Your own dataloader or datamodule

model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, dataloader)

Blitz Einfach und großartig.


3. Fügen Sie den Methoden dieser Klasse weitere Arbeiten hinzu

Nachdem Sie die oben genannten Informationen erhalten haben, fügen Sie der Klasse die Methoden ** Test, Validierung und andere Optionen ** hinzu.

test Fügen Sie der Klassenmethode `test_step (self, batch, batch_idx)` hinzu. Nur. Ausführung

Beim Ausführen des Tests


trainer.test()

validation Dies wird auch durch Hinzufügen der `val_step ()` Methode und der `val_dataloader ()` Methode ~ vervollständigt

dataloader Dies kann auch in Klassenmethoden gruppiert werden, aber ** Dataset & Data Loader wird empfohlen, `pytorch_lightning.LightningDataModule``` von einer anderen Klasse zu erben und` MyDataModule``` Klasse ** zu definieren ** ..

** Für diejenigen, die lang sind, aber das MNIST-Beispiel sehen möchten (Klicken) **

class MyDataModule(LightningDataModule): def init(self): super().init() self.train_dims = None self.vocab_size = 0

def prepare_data(self):
    # called only on 1 GPU
    download_dataset()
    tokenize()
    build_vocab()

def setup(self):
    # called on every GPU
    vocab = load_vocab()
    self.vocab_size = len(vocab)

    self.train, self.val, self.test = load_datasets()
    self.train_dims = self.train.next_batch.size()

def train_dataloader(self):
    transforms = ...
    return DataLoader(self.train, batch_size=64)

def val_dataloader(self):
    transforms = ...
    return DataLoader(self.val, batch_size=64)

def test_dataloader(self):
    transforms = ...
    return DataLoader(self.test, batch_size=64)
Wenn Sie dies zum Zeitpunkt des Lernens und Testens in `` `.fit ()` `` beißen, wird es interpretiert, ohne data_loader zu übergeben.

Laufzeit


datamodule = MyDataModule()

model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, datamodule)

callback So etwas wie "Prozess nur am Anfang des Zuges" und "Prozess am Ende der Epoche". https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html#callbacks Es gibt viele Informationen. Es ist in Ordnung, wenn Sie eine Funktion für das Timing definieren, das Sie verarbeiten möchten

from pytorch_lightning.callbacks import Callback

class MyPrintingCallback(Callback):
    def on_init_start(self, trainer):
        print('Starting to init trainer!')

    def on_init_end(self, trainer):
        print('Trainer is init now')

    def on_train_end(self, trainer, pl_module):
        print('do something when training ends')

trainer = Trainer(callbacks=[MyPrintingCallback()])

Wenn Sie es in einer anderen Klasse definieren, können Sie es kurz schreiben ~

4. Mit Tensorboard verknüpfen und Aufnahmeeinstellungen hinzufügen

Von nun an die Hauptbeziehung zur Aufbewahrung von Datensätzen. Anzeigen von Zahlenwerten (Verlust, Genauigkeit usw.), Bildern, Tönen usw. auf dem Tensorboard

Tensorflow Beispiel


with tf.name_scope('summary'):
  tf.summary.scalar('loss', loss)
  merged = tf.summary.merge_all()
  writer = tf.summary.FileWriter('./logs', sess.graph)

Ich neigte dazu, schmutzigen Code zu erstellen, indem ich den Code, den ich sehen wollte, in die Mitte steckte, aber pytorch_lightning kann präzise geschrieben werden.

MyModel.py


def training_step(self, batch, batch_idx):
  # ...
  loss = ...
  self.logger.summary.scalar('loss', loss, step=self.global_step)

  # equivalent
  result = TrainResult()
  result.log('loss', loss)

  return result

Fügen Sie in der Methode `logger.summary``` hinzu, wenn Sie like aufnehmen, oder fügen Sie den Teil` return loss``` einmal zur Klasse `` pytorch_lightning.LightningModule.TrainResult () `hinzu Beißen Sie es einfach und es wird automatisch im Speicherverzeichnis gespeichert!

logger kann dem Konstruktor der Klasse `` Trainer () `hinzugefügt werden, und das Speicherverzeichnis wird auch hier bestimmt.

from pytorch_lightning import loggers as pl_loggers

tb_logger = pl_loggers.TensorBoardLogger('logs/')
trainer = Trainer(logger=tb_logger)

Sie können Daten wie Text und Bilder auch mit dem `.add_hogehoge ()` des `` `logger.experiment```-Objekts speichern!

MyModel.py


def training_step(...):
  ...
  # the logger you used (in this case tensorboard)
  tensorboard = self.logger.experiment
  tensorboard.add_histogram(...)
  tensorboard.add_figure(...)

Der Beamte sagt, dass der Zeitpunkt des Rückrufs ebenfalls empfohlen wird.

Es ist großartig ... (Es ist wichtig, also sag es zweimal s (ry)

Am Ende

Als Gefühl der Verwendung von Pytorch Lightning ~ ~ (im Vergleich zu der schlechten Lesbarkeit von ignite aufgrund der eingefügten Verarbeitung) ~ ~ Die Regeln sind leicht zu verstehen, und das Klassendesign und die Dokumentenpflege waren korrekt, daher werde ich sie zuerst verwenden Ich hatte das Gefühl, dass dies ein empfohlener Rahmen für tiefes Lernen ist

Recommended Posts

Einführung in die Verwendung von Pytorch Lightning ~ Bis Sie Ihr eigenes Modell formatieren und auf Tensorboard ausgeben ~
Erstellen Sie Ihre eigene Ausnahme
Einführung in die Verwendung von Pytorch Lightning ~ Bis Sie Ihr eigenes Modell formatieren und auf Tensorboard ausgeben ~
Bis Sie Ihren eigenen Dolmetscher selbst hosten
Bis Sie einen Schnappschuss des Amazon Elasticsearch-Dienstes erhalten und wiederherstellen
Wie Sie pyenv und pyenv-virtualenv auf Ihre eigene Weise verwenden
So installieren Sie den Cascade-Detektor und wie verwenden Sie ihn
Einführung in Lightning Pytorch
Wie man Decorator in Django benutzt und wie man es macht
Was ist pip und wie benutzt du es?
[Python] Wenn Sie Ihr eigenes Paket im oberen Verzeichnis importieren und verwenden möchten
[Einführung in die Udemy Python3 + -Anwendung] 36. Verwendung von In und Not
Einführung von DataLiner Version 1.3 und Verwendung von Union Append
[Einführung] Verwendung von open3d
So geben Sie die im Django-Modell enthaltenen Daten im JSON-Format zurück und ordnen sie der Broschüre zu
Verwendung von Google Colaboratory und Verwendungsbeispiel (PyTorch × DCGAN)
[Einführung in Python] Verwendung des Booleschen Operators (und ・ oder ・ nicht)
So installieren und verwenden Sie Tesseract-OCR
Bis Sie Ihren eigenen Dolmetscher selbst hosten
Verwendung von .bash_profile und .bashrc
So installieren und verwenden Sie Graphviz
Von der Einführung der GoogleCloudPlatform Natural Language API bis zur Verwendung
Einführung des Cyber-Sicherheits-Frameworks "MITRE CALDERA": Verwendung und Schulung
Es ist praktisch, stac_info und exc_info zu verwenden, wenn Sie Traceback in der Protokollausgabe durch Protokollierung anzeigen möchten.
[Einführung in Python] Wie verwende ich eine Klasse in Python?
So installieren und verwenden Sie pandas_datareader [Python]
[TF] Verwendung von Tensorboard von Keras
Bis Sie Ihre eigene Python-Bibliothek installieren
So installieren Sie Ihre eigene (Root-) Zertifizierungsstelle
Python: Verwendung von Einheimischen () und Globalen ()
Grundlagen von PyTorch (1) - Verwendung von Tensor-
Verwendung von Python zip und Aufzählung
Verwendung ist und == in Python
Verwendung von pandas Timestamp und date_range
[Python] So benennen Sie Tabellendaten und geben sie mit csv aus (to_csv-Methode)