For people who feel "I built my own DNN model, but the code is dirty" or "I'm tired of the clerical work (saving, logging, common DNN boilerplate)":
PyTorch Lightning, the explosively popular library for AI development, is a Python library that gives you clean code organization, training management, and TensorBoard visualization. It is among the most popular deep learning frameworks by GitHub star count.
```console
$ pip install pytorch-lightning
```
Inherit from `pytorch_lightning.LightningModule` and define

* the network
* three methods: `forward(self, x)`, `training_step(self, batch, batch_idx)`, `configure_optimizers(self)`

and you can use it right away. Note, however, that **the method names and their argument lists cannot be changed**!
(e.g. defining `training_step(self, batch)` because you don't need `batch_idx` will cause a bug)
#### **`MyModel.py`**

```python
from torch import nn, optim
from torch.nn import functional as F
from pytorch_lightning.core.lightning import LightningModule


class LitMyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # mnist images are (1, 28, 28) (channels, height, width)
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.log_softmax(x, dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def configure_optimizers(self):
        # the third required method: return the optimizer to use
        return optim.Adam(self.parameters(), lr=1e-3)
```
The three methods can contain any processing you like, as long as they respectively "return the network output", "do the work of one training iteration and return the loss", and "return the optimizer".
```python
# FC example learning MNIST
import pytorch_lightning as pl
from torch import nn, optim


class LitMyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.out_size = 28 * 28  # flattened MNIST image size
        # layers (encode / reparameterize / decode / loss_function are
        # assumed to be defined elsewhere in this class)
        self.fc1 = nn.Linear(self.out_size, 400)
        self.fc4 = nn.Linear(400, self.out_size)

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, self.out_size))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def training_step(self, batch, batch_idx):
        recon_batch, mu, logvar = self.forward(batch)
        loss = self.loss_function(
            recon_batch, batch, mu, logvar, out_size=self.out_size)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
```
Of course, if you already have a model, you can just port the code over.
After that, pass the model and your data loader to `fit()` of `pl.Trainer()`, and training starts!!
#### **`runtime`**

```python
dataloader = ...  # your own dataloader or datamodule
model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, dataloader)
```
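As a concrete example of "your own dataloader", a minimal MNIST DataLoader might look like the following sketch (torchvision, the `data/` path, and the batch size are my own illustrative choices, not from the original):

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# hypothetical minimal dataloader for the MNIST model above
train_ds = datasets.MNIST('data/', train=True, download=True,
                          transform=transforms.ToTensor())
dataloader = DataLoader(train_ds, batch_size=64, shuffle=True)
```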
Lightning: easy and awesome.
Now that training works, let's add the **test, validation, and other optional** methods to the class.
#### test

Just add a `test_step(self, batch, batch_idx)` method to the class. That's all. To run it:

```python
trainer.test()
```
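As a minimal sketch, a `test_step` for the MNIST classifier above could mirror `training_step` (the loss computation is simply reused from that example; `trainer.test()` also needs a test dataloader, e.g. from `test_dataloader()` or the DataModule described below):

```python
def test_step(self, batch, batch_idx):
    # same processing as training_step, but run on the test set
    x, y = batch
    logits = self(x)
    return F.nll_loss(logits, y)
```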
#### validation

This one is likewise done just by adding a `validation_step()` method and a `val_dataloader()` method~
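A minimal sketch, again assuming the MNIST classifier above (`self.val_dataset` is a hypothetical placeholder for your validation Dataset):

```python
def validation_step(self, batch, batch_idx):
    # runs once per validation batch
    x, y = batch
    logits = self(x)
    return F.nll_loss(logits, y)

def val_dataloader(self):
    # self.val_dataset is a placeholder; return your validation DataLoader
    return DataLoader(self.val_dataset, batch_size=64)
```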
#### dataloader

This can also be bundled into the model's class methods, but **for the Dataset & DataLoader it is recommended to define a separate `MyDataModule` class that inherits from `pytorch_lightning.LightningDataModule`**.
```python
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader


class MyDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.train_dims = None
        self.vocab_size = 0

    def prepare_data(self):
        # called only on 1 GPU
        download_dataset()
        tokenize()
        build_vocab()

    def setup(self, stage=None):
        # called on every GPU
        vocab = load_vocab()
        self.vocab_size = len(vocab)
        self.train, self.val, self.test = load_datasets()
        self.train_dims = self.train.next_batch.size()

    def train_dataloader(self):
        transforms = ...
        return DataLoader(self.train, batch_size=64)

    def val_dataloader(self):
        transforms = ...
        return DataLoader(self.val, batch_size=64)

    def test_dataloader(self):
        transforms = ...
        return DataLoader(self.test, batch_size=64)
```
#### **`runtime`**

```python
datamodule = MyDataModule()
model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, datamodule)
```
#### callback

Callbacks handle things like "processing to run only at the start of training" or "processing to run at the end of every epoch". There is plenty of information around https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html#callbacks. You just define a method for each timing at which you want your processing to run:
```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class MyPrintingCallback(Callback):
    def on_init_start(self, trainer):
        print('Starting to init trainer!')

    def on_init_end(self, trainer):
        print('Trainer is init now')

    def on_train_end(self, trainer, pl_module):
        print('do something when training ends')


trainer = Trainer(callbacks=[MyPrintingCallback()])
```
Defining them in a separate class like this keeps the code concise~
Now, from here on, the main topic: saving records. To display scalar values (loss, accuracy, etc.), images, audio, and so on in TensorBoard, with TensorFlow you had to write something like

#### **`tensorflow example`**

```python
with tf.name_scope('summary'):
    tf.summary.scalar('loss', loss)
merged = tf.summary.merge_all()
writer = tf.summary.FileWriter('./logs', sess.graph)
```

and I tended to end up with dirty code from sticking these snippets wherever I wanted to see a value. With pytorch_lightning it can be written concisely:
#### **`MyModel.py`**

```python
def training_step(self, batch, batch_idx):
    # ...
    loss = ...
    self.logger.summary.scalar('loss', loss, step=self.global_step)

    # equivalent
    result = TrainResult()
    result.log('loss', loss)
    return result
```
That is, when you want to record something, either write to `logger.summary` inside the method as above, or pass the `return loss` part through pytorch_lightning's `TrainResult()` class once; either way it is saved automatically to the save directory!
The logger itself goes into the constructor of the `Trainer()` class, which is also where the save directory is decided.
```python
from pytorch_lightning import Trainer
from pytorch_lightning import loggers as pl_loggers

tb_logger = pl_loggers.TensorBoardLogger('logs/')
trainer = Trainer(logger=tb_logger)
```
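To actually look at the logs, start TensorBoard on that directory (plain TensorBoard usage, nothing Lightning-specific):

```console
$ tensorboard --logdir logs/
```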
You can also save data such as text and images using the `.add_xxx()` methods of the `logger.experiment` object!
#### **`MyModel.py`**

```python
def training_step(...):
    ...
    # the logger you used (in this case tensorboard)
    tensorboard = self.logger.experiment
    tensorboard.add_histogram(...)
    tensorboard.add_figure(...)
```
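To make the `...` concrete: `logger.experiment` here is a TensorBoard `SummaryWriter`, so its standard `add_*` methods work as-is. A sketch using the MNIST model's first layer (the tag name is my own choice):

```python
def training_step(self, batch, batch_idx):
    ...
    tensorboard = self.logger.experiment
    # histogram of the first layer's weights at the current step
    tensorboard.add_histogram('layer_1/weight', self.layer_1.weight,
                              self.global_step)
```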
The official docs recommend doing this at callback timing as well.
It's awesome... (it's important, so I'm saying it twice)
My impression after using PyTorch Lightning: ~~(compared to ignite, whose readability suffers from the way processing gets inserted)~~ the rules are easy to understand, and the class design and documentation are properly maintained, so I felt it's a deep learning framework I can recommend trying first.