Pour les personnes qui "J'ai fait mon propre modèle DNN mais le code est sale" et "J'en ai marre du travail de bureau (sauvegarde, journalisation, code commun DNN)"
--Avec la bibliothèque d'explosifs de développement IA Pytorch Lightning
Une bibliothèque python qui le fait. C'est le numéro d'étoile Github le plus populaire et le cadre d'apprentissage en profondeur populaire.
console
$ pip install pytorch-lightning
pytorch_lightning.Hériter de LightningModule,
* Réseau
* 3 méthodes: forward (self, x), training_step (self, batch, batch_idx), configure_optimizers (self)
Si vous définissez les deux, vous pouvez l'utiliser immédiatement. Cependant, notez que le ** nom de la fonction et la paire d'arguments ne peuvent pas être modifiés **!
(Par exemple batch_idx Si vous le définissez comme `` `` training_step (self, batch) '' `` même si vous n'en avez pas besoin, ce sera bogué)
#### **`MyModel.py`**
```python
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
class LitMyModel(LightningModule):
def __init__(self):
super().__init__()
# mnist images are (1, 28, 28) (channels, width, height)
self.layer_1 = torch.nn.Linear(28 * 28, 128)
self.layer_2 = torch.nn.Linear(128, 10)
def forward(self, x):
batch_size, channels, width, height = x.size()
# (b, 1, 28, 28) -> (b, 1*28*28)
x = x.view(batch_size, -1)
x = self.layer_1(x)
x = F.relu(x)
x = self.layer_2(x)
x = F.log_softmax(x, dim=1)
return x
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
return loss
Chacune des trois fonctions "Sortie réseau de retour" "Fonction de travail en 1 boucle et perte de retour" "Optimiseur de retour" Tout traitement est OK
#Exemple de FC apprentissage MNIST
import pytorch_lightning as pl
class LitMyModel(pl.LightningModule):
def __init__(self):
# layers
self.fc1 = nn.Linear(self.out_size, 400)
self.fc4 = nn.Linear(400, self.out_size)
def forward(self, x):
mu, logvar = self.encode(x.view(-1, self.out_size))
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
def training_step(self, batch, batch_idx):
recon_batch, mu, logvar = self.forward(batch)
loss = self.loss_function(
recon_batch, batch, mu, logvar, out_size=self.out_size)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(model.parameters(), lr=1e-3)
return optimizer
Bien sûr, si vous avez déjà un modèle, déplacez simplement le code
Après cela, mettez le chargeur de données et le modèle dans pl.Trainer ()
fit () '' et commencez à apprendre!
Durée
dataloader = #Your own dataloader or datamodule
model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, dataloader)
éclair Facile et génial.
Maintenant que vous pouvez apprendre tout ce qui précède, ajoutez les méthodes ** test, validation et autres options ** à la classe.
test Ajoutez `` test_step (self, batch, batch_idx) '' à la méthode de classe. Seulement. Exécution
Lors de l'exécution du test
trainer.test()
validation
Ceci est également complété par l'ajout de la méthode
val_step () '' et de la méthode
val_dataloader () '' ~
dataloader
Cela peut également être regroupé en méthodes de classe, mais ** Dataset & Data Loader est recommandé d'hériter
pytorch_lightning.LightningDataModule d'une autre classe et de définir la classe `` MyDataModule
** ..
class MyDataModule(LightningDataModule): def init(self): super().init() self.train_dims = None self.vocab_size = 0
def prepare_data(self):
# called only on 1 GPU
download_dataset()
tokenize()
build_vocab()
def setup(self):
# called on every GPU
vocab = load_vocab()
self.vocab_size = len(vocab)
self.train, self.val, self.test = load_datasets()
self.train_dims = self.train.next_batch.size()
def train_dataloader(self):
transforms = ...
return DataLoader(self.train, batch_size=64)
def val_dataloader(self):
transforms = ...
return DataLoader(self.val, batch_size=64)
def test_dataloader(self):
transforms = ...
return DataLoader(self.test, batch_size=64)
Durée
datamodule = MyDataModule()
model = LitMyModel()
trainer = pl.Trainer()
trainer.fit(model, datamodule)
callback Quelque chose comme "processus à effectuer uniquement au début du train" et "processus à effectuer à la fin de l'époque" https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html#callbacks Il y a beaucoup d'informations autour. C'est OK si vous définissez une fonction pour le timing que vous souhaitez traiter
from pytorch_lightning.callbacks import Callback
class MyPrintingCallback(Callback):
def on_init_start(self, trainer):
print('Starting to init trainer!')
def on_init_end(self, trainer):
print('Trainer is init now')
def on_train_end(self, trainer, pl_module):
print('do something when training ends')
trainer = Trainer(callbacks=[MyPrintingCallback()])
Si vous le définissez dans une autre classe, vous pouvez l'écrire de manière concise ~
Maintenant, à partir de maintenant, la principale relation de préservation des enregistrements. Pour afficher des valeurs numériques (perte, précision, etc.), des images, des sons, etc. sur le tensorboard
exemple de tensorflow
with tf.name_scope('summary'):
tf.summary.scalar('loss', loss)
merged = tf.summary.merge_all()
writer = tf.summary.FileWriter('./logs', sess.graph)
J'avais tendance à faire du code sale en collant le code que je voulais voir au milieu, mais pytorch_lightning peut être écrit de manière concise,
MyModel.py
def training_step(self, batch, batch_idx):
# ...
loss = ...
self.logger.summary.scalar('loss', loss, step=self.global_step)
# equivalent
result = TrainResult()
result.log('loss', loss)
return result
Ajoutez à
logger.summary dans la méthode lors de l'enregistrement comme, ou ajoutez une fois la partie `` `` return loss
à la classe
pytorch_lightning.LightningModule.TrainResult () ``. Il suffit de le mordre et il sera automatiquement enregistré dans le répertoire de sauvegarde!
logger est OK si vous l'ajoutez au constructeur de la classe
Trainer () '', et le répertoire de stockage est également décidé ici.
from pytorch_lightning import loggers as pl_loggers
tb_logger = pl_loggers.TensorBoardLogger('logs/')
trainer = Trainer(logger=tb_logger)
Vous pouvez également enregistrer des données telles que du texte et des images en utilisant le
.add_hogehoge () '' de l'objet
`` logger.experiment```!
MyModel.py
def training_step(...):
...
# the logger you used (in this case tensorboard)
tensorboard = self.logger.experiment
tensorboard.add_histogram(...)
tensorboard.add_figure(...)
Le fonctionnaire dit que le moment du rappel est également recommandé.
C'est génial ... (C'est important, alors je vais le dire deux fois
Comme impression d'utiliser Pytorch Lightning ~ ~ (par rapport à la mauvaise lisibilité de ignite en raison du traitement inséré) ~ ~ Les règles sont faciles à comprendre, et la conception de la classe et la maintenance des documents étaient correctes, je vais donc les utiliser en premier J'ai senti que c'était un cadre d'apprentissage en profondeur recommandé pour