[PYTHON] Personnalisez l'affichage de la progression pendant l'apprentissage avec tf.keras (contre-mesures de débordement de cellules Google Colaboratory)

[Détails]

Lorsque mon collègue utilisait Keras (tf.keras) pour apprendre des dizaines de milliers d'Epoch sur Google Colaboratory, il a déploré que le navigateur devienne lourd et que l'affichage n'ait pas été mis à jour après tout. La cause était que le journal de progression était affiché en spécifiant verbose = 1 dans model.fit, mais le journal devenait gonflé et lourd, et lorsqu'un certain seuil (?) Était atteint, l'affichage était mis à jour. C'est parti. (L'opération se poursuit) Il suffit d'arrêter la journalisation à verbose = 0 et de l'activer, mais vous ne pourrez alors pas vérifier la progression.

Au fait, je me suis souvenu que j'avais troublé le même phénomène dans le passé et l'avais résolu en utilisant la fonction de rappel, alors j'aimerais le partager.

[À propos de la fonction de rappel de tf.keras]

Vous pouvez personnaliser le comportement pendant l'entraînement en spécifiant une classe qui hérite de la classe tf.keras.callbacks.Callback dans l'argument callbacks de model.fit. Consultez la documentation officielle pour plus de détails. 【tf.keras.callbacks.Callback - TensorFlow】 [Rappel - Documentation Keras]

tf.keras.callbacks.Callback fournit plusieurs méthodes, mais elles sont supposées être appelées à un moment donné. En remplaçant ces méthodes, vous pouvez modifier le comportement pendant l'apprentissage. Cette fois, j'ai remplacé la méthode suivante.

Méthode Moment d'être appelé
on_train_begin Au début de l'apprentissage
on_train_end À la fin de l'apprentissage
on_batch_begin Au début du lot
on_batch_end À la fin du lot
on_epoch_begin Au début de l'époque
on_epoch_end À la fin de l'époque

En plus de ce qui précède, il existe des méthodes qui sont appelées pendant l'inférence et les tests.

【politique】

En continuant à écraser l'affichage de la progression pendant l'apprentissage sur la même ligne sans casser la ligne, la cellule de sortie ne grandira pas et ne débordera pas. Utilisez le code suivant pour continuer à écraser sur la même ligne.

print('\rTest Print', end='')

Le \ r dans le code ci-dessus signifie Carriage Return (CR), qui vous permet de déplacer le curseur au début d'une ligne. Cela vous permet d'écraser les lignes affichées.

Cependant, si cela est laissé tel quel, un saut de ligne se produira chaque fois que l'instruction d'impression est exécutée. Par conséquent, spécifiez ʻend = '' comme argument de l'instruction d'impression. Le but est de supprimer les sauts de ligne en spécifiant que le premier argument ne doit pas être sorti après la sortie. Par défaut, ʻend = '\ n' est spécifié dans l'instruction d'impression. \ n signifie Saut de ligne (LF), qui envoie le curseur sur une nouvelle ligne (c'est-à-dire un saut de ligne).

Si vous exécutez le code suivant à titre d'essai, il continuera à écraser 0 à 9 et peut être exprimé comme s'il comptait.

Écraser l'échantillon


from time import sleep
for i in range(10):
  print('\r%d' % i, end='')
  sleep(1)

Je pense ici. Je pense également qu'il vaut mieux définir ʻend = '\ r'au lieu d'imprimer' \ r'`.

Cependant, cette tentative ne fonctionne pas. Parce qu'en Python, lorsque \ r'` est sorti, il semble que le contenu sorti jusqu'à présent soit effacé. Par exemple, si vous exécutez `print ('Test Print', end = '\ r')`, rien ne sera affiché, ce qui n'est pas pratique à cet effet. Par conséquent, il n'y a pas d'autre choix que de sortir la chaîne de caractères que vous souhaitez afficher après avoir sorti \ r'` juste avant la sortie de caractères.

Utilisez donc la technique ci-dessus pour coder avec la stratégie suivante.

Au début / à la fin de l'apprentissage

Affiche le début / la fin et l'heure à laquelle il a été exécuté. Il s'agit d'un saut de ligne normal.

Lorsque le lot est terminé et lorsque l’époque est terminée

Le nombre d'Epoch, le nombre de données traitées, les acc et les pertes sont affichés. Cet affichage est écrasé sans saut de ligne pour réduire la taille de la cellule de sortie.

【codage】

Nous le mettrons en œuvre sur la base de la politique ci-dessus. La partie modèle est basée sur le didacticiel de TensorFlow. 【TensorFlow 2 quickstart for beginners】

import tensorflow as tf
#Définition de la fonction de rappel pour l'affichage de la progression personnalisé
"""
Fonction de rappel pour afficher la progression.
Les données sont collectées et affichées à la fin du lot et de l'époque.
Le point est lorsque l'impression est sortie/Argument fin en ramenant le curseur au début de la ligne avec r=''Le fait est que les sauts de ligne sont supprimés.
"""
import datetime

class DisplayCallBack(tf.keras.callbacks.Callback):
  #constructeur
  def __init__(self):
    self.last_acc, self.last_loss, self.last_val_acc, self.last_val_loss = None, None, None, None
    self.now_batch, self.now_epoch = None, None

    self.epochs, self.samples, self.batch_size = None, None, None

  #Affichage de progression personnalisé(Corps d'affichage)
  def print_progress(self):
    epoch = self.now_epoch
    batch = self.now_batch

    epochs = self.epochs
    samples = self.samples
    batch_size = self.batch_size
    sample = batch_size*(batch)

    # '\r'Et fin=''Pour éviter les sauts de ligne en utilisant
    if self.last_val_acc and self.last_val_loss:
      # val_acc/val_la perte peut être affichée
      print("\rEpoch %d/%d (%d/%d) -- acc: %f loss: %f - val_acc: %f val_loss: %f" % (epoch+1, epochs, sample, samples, self.last_acc, self.last_loss, self.last_val_acc, self.last_val_loss), end='')
    else:
      # val_acc/val_la perte ne peut pas être affichée
      print("\rEpoch %d/%d (%d/%d) -- acc: %f loss: %f" % (epoch+1, epochs, sample, samples, self.last_acc, self.last_loss), end='')
      

  #Au début de l'ajustement
  def on_train_begin(self, logs={}):
    print('\n##### Train Start ##### ' + str(datetime.datetime.now()))

    #Obtenir les paramètres
    self.epochs = self.params['epochs']
    self.samples = self.params['samples']
    self.batch_size = self.params['batch_size']

    #Évitez l'affichage standard de la progression
    self.params['verbose'] = 0


  #Au début du lot
  def on_batch_begin(self, batch, logs={}):
    self.now_batch = batch

  #Lorsque le lot est terminé(Affichage de la progression)
  def on_batch_end(self, batch, logs={}):
    #Mise à jour des dernières informations
    self.last_acc = logs.get('acc') if logs.get('acc') else 0.0
    self.last_loss = logs.get('loss') if logs.get('loss') else 0.0

    #Affichage de la progression
    self.print_progress()


  #Au début de l'époque
  def on_epoch_begin(self, epoch, log={}):
    self.now_epoch = epoch
  
  #Quand l'époque est terminée(Affichage de la progression)
  def on_epoch_end(self, epoch, logs={}):
    #Mise à jour des dernières informations
    self.last_val_acc = logs.get('val_acc') if logs.get('val_acc') else 0.0
    self.last_val_loss = logs.get('val_loss') if logs.get('val_loss') else 0.0

    #Affichage de la progression
    self.print_progress()


  #Lorsque l'ajustement est terminé
  def on_train_end(self, logs={}):
    print('\n##### Train Complete ##### ' + str(datetime.datetime.now()))
#Création d'instance pour la fonction de rappel
cbDisplay = DisplayCallBack()
#Lire et normaliser l'ensemble de données MNIST
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# tf.keras.Construire un modèle séquentiel
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
#Apprentissage de modèle
#Ici, nous utilisons la fonction de rappel
history = model.fit(x_train, y_train,
                    validation_data = (x_test, y_test),
                    batch_size=128,
                    epochs=5,
                    verbose=1,              #L'affichage de progression standard est ignoré dans la fonction de rappel
                    callbacks=[cbDisplay])  #Définir l'affichage de la progression personnalisé comme fonction de rappel
#Évaluation du modèle
import pandas as pd

results = pd.DataFrame(history.history)
results.plot();

[Exemple de sortie]

Si vous exécutez ce qui précède, quel que soit le nombre d'Époque que vous tournez, seules les 3 lignes suivantes seront affichées. La deuxième ligne est réécrite avec les dernières informations à la fin de Batch et Epoch, et la dernière ligne est sortie lorsque l'apprentissage est terminé.

##### Train Start ##### 2019-12-24 02:17:27.484038
Epoch 5/5 (59904/60000) -- acc: 0.970283 loss: 0.066101 - val_acc: 0.973900 val_loss: 0.087803
##### Train Complete ##### 2019-12-24 02:17:34.443442

image.png

Recommended Posts

Personnalisez l'affichage de la progression pendant l'apprentissage avec tf.keras (contre-mesures de débordement de cellules Google Colaboratory)
Ingénierie des fonctionnalités pour l'apprentissage automatique à partir du 4e Google Colaboratory - Fonctionnalités interactives
Apprenez facilement 100 traitements linguistiques Knock 2020 avec "Google Colaboratory"
Un mémo lors de l'exécution de l'exemple de code de Deep Learning créé à partir de zéro avec Google Colaboratory
Déplaçons word2vec avec Chainer et voyons la progression de l'apprentissage