Lorsque mon collègue utilisait Keras (tf.keras) pour apprendre des dizaines de milliers d'Epoch sur Google Colaboratory, il a déploré que le navigateur devienne lourd et que l'affichage n'ait pas été mis à jour après tout.
La cause était que le journal de progression était affiché en spécifiant verbose = 1
dans model.fit
, mais le journal devenait gonflé et lourd, et lorsqu'un certain seuil (?) Était atteint, l'affichage était mis à jour. C'est parti. (L'opération se poursuit)
Il suffit d'arrêter la journalisation à verbose = 0
et de l'activer, mais vous ne pourrez alors pas vérifier la progression.
Au fait, je me suis souvenu que j'avais troublé le même phénomène dans le passé et l'avais résolu en utilisant la fonction de rappel, alors j'aimerais le partager.
Vous pouvez personnaliser le comportement pendant l'entraînement en spécifiant une classe qui hérite de la classe tf.keras.callbacks.Callback
dans l'argument callbacks
de model.fit
.
Consultez la documentation officielle pour plus de détails.
【tf.keras.callbacks.Callback - TensorFlow】
[Rappel - Documentation Keras]
tf.keras.callbacks.Callback
fournit plusieurs méthodes, mais elles sont supposées être appelées à un moment donné.
En remplaçant ces méthodes, vous pouvez modifier le comportement pendant l'apprentissage.
Cette fois, j'ai remplacé la méthode suivante.
Méthode | Moment d'être appelé |
---|---|
on_train_begin | Au début de l'apprentissage |
on_train_end | À la fin de l'apprentissage |
on_batch_begin | Au début du lot |
on_batch_end | À la fin du lot |
on_epoch_begin | Au début de l'époque |
on_epoch_end | À la fin de l'époque |
En plus de ce qui précède, il existe des méthodes qui sont appelées pendant l'inférence et les tests.
En continuant à écraser l'affichage de la progression pendant l'apprentissage sur la même ligne sans casser la ligne, la cellule de sortie ne grandira pas et ne débordera pas. Utilisez le code suivant pour continuer à écraser sur la même ligne.
print('\rTest Print', end='')
Le \ r
dans le code ci-dessus signifie Carriage Return (CR), qui vous permet de déplacer le curseur au début d'une ligne.
Cela vous permet d'écraser les lignes affichées.
Cependant, si cela est laissé tel quel, un saut de ligne se produira chaque fois que l'instruction d'impression est exécutée.
Par conséquent, spécifiez ʻend = '' comme argument de l'instruction d'impression. Le but est de supprimer les sauts de ligne en spécifiant que le premier argument ne doit pas être sorti après la sortie. Par défaut, ʻend = '\ n'
est spécifié dans l'instruction d'impression.
\ n
signifie Saut de ligne (LF), qui envoie le curseur sur une nouvelle ligne (c'est-à-dire un saut de ligne).
Si vous exécutez le code suivant à titre d'essai, il continuera à écraser 0 à 9 et peut être exprimé comme s'il comptait.
Écraser l'échantillon
from time import sleep
for i in range(10):
print('\r%d' % i, end='')
sleep(1)
Je pense ici.
Je pense également qu'il vaut mieux définir ʻend = '\ r'au lieu d'imprimer
' \ r'`.
Cependant, cette tentative ne fonctionne pas.
Parce qu'en Python, lorsque \ r'` est sorti, il semble que le contenu sorti jusqu'à présent soit effacé. Par exemple, si vous exécutez `print ('Test Print', end = '\ r')`, rien ne sera affiché, ce qui n'est pas pratique à cet effet. Par conséquent, il n'y a pas d'autre choix que de sortir la chaîne de caractères que vous souhaitez afficher après avoir sorti
\ r'` juste avant la sortie de caractères.
Utilisez donc la technique ci-dessus pour coder avec la stratégie suivante.
Affiche le début / la fin et l'heure à laquelle il a été exécuté. Il s'agit d'un saut de ligne normal.
Le nombre d'Epoch, le nombre de données traitées, les acc et les pertes sont affichés. Cet affichage est écrasé sans saut de ligne pour réduire la taille de la cellule de sortie.
Nous le mettrons en œuvre sur la base de la politique ci-dessus. La partie modèle est basée sur le didacticiel de TensorFlow. 【TensorFlow 2 quickstart for beginners】
import tensorflow as tf
#Définition de la fonction de rappel pour l'affichage de la progression personnalisé
"""
Fonction de rappel pour afficher la progression.
Les données sont collectées et affichées à la fin du lot et de l'époque.
Le point est lorsque l'impression est sortie/Argument fin en ramenant le curseur au début de la ligne avec r=''Le fait est que les sauts de ligne sont supprimés.
"""
import datetime
class DisplayCallBack(tf.keras.callbacks.Callback):
#constructeur
def __init__(self):
self.last_acc, self.last_loss, self.last_val_acc, self.last_val_loss = None, None, None, None
self.now_batch, self.now_epoch = None, None
self.epochs, self.samples, self.batch_size = None, None, None
#Affichage de progression personnalisé(Corps d'affichage)
def print_progress(self):
epoch = self.now_epoch
batch = self.now_batch
epochs = self.epochs
samples = self.samples
batch_size = self.batch_size
sample = batch_size*(batch)
# '\r'Et fin=''Pour éviter les sauts de ligne en utilisant
if self.last_val_acc and self.last_val_loss:
# val_acc/val_la perte peut être affichée
print("\rEpoch %d/%d (%d/%d) -- acc: %f loss: %f - val_acc: %f val_loss: %f" % (epoch+1, epochs, sample, samples, self.last_acc, self.last_loss, self.last_val_acc, self.last_val_loss), end='')
else:
# val_acc/val_la perte ne peut pas être affichée
print("\rEpoch %d/%d (%d/%d) -- acc: %f loss: %f" % (epoch+1, epochs, sample, samples, self.last_acc, self.last_loss), end='')
#Au début de l'ajustement
def on_train_begin(self, logs={}):
print('\n##### Train Start ##### ' + str(datetime.datetime.now()))
#Obtenir les paramètres
self.epochs = self.params['epochs']
self.samples = self.params['samples']
self.batch_size = self.params['batch_size']
#Évitez l'affichage standard de la progression
self.params['verbose'] = 0
#Au début du lot
def on_batch_begin(self, batch, logs={}):
self.now_batch = batch
#Lorsque le lot est terminé(Affichage de la progression)
def on_batch_end(self, batch, logs={}):
#Mise à jour des dernières informations
self.last_acc = logs.get('acc') if logs.get('acc') else 0.0
self.last_loss = logs.get('loss') if logs.get('loss') else 0.0
#Affichage de la progression
self.print_progress()
#Au début de l'époque
def on_epoch_begin(self, epoch, log={}):
self.now_epoch = epoch
#Quand l'époque est terminée(Affichage de la progression)
def on_epoch_end(self, epoch, logs={}):
#Mise à jour des dernières informations
self.last_val_acc = logs.get('val_acc') if logs.get('val_acc') else 0.0
self.last_val_loss = logs.get('val_loss') if logs.get('val_loss') else 0.0
#Affichage de la progression
self.print_progress()
#Lorsque l'ajustement est terminé
def on_train_end(self, logs={}):
print('\n##### Train Complete ##### ' + str(datetime.datetime.now()))
#Création d'instance pour la fonction de rappel
cbDisplay = DisplayCallBack()
#Lire et normaliser l'ensemble de données MNIST
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# tf.keras.Construire un modèle séquentiel
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
#Apprentissage de modèle
#Ici, nous utilisons la fonction de rappel
history = model.fit(x_train, y_train,
validation_data = (x_test, y_test),
batch_size=128,
epochs=5,
verbose=1, #L'affichage de progression standard est ignoré dans la fonction de rappel
callbacks=[cbDisplay]) #Définir l'affichage de la progression personnalisé comme fonction de rappel
#Évaluation du modèle
import pandas as pd
results = pd.DataFrame(history.history)
results.plot();
Si vous exécutez ce qui précède, quel que soit le nombre d'Époque que vous tournez, seules les 3 lignes suivantes seront affichées. La deuxième ligne est réécrite avec les dernières informations à la fin de Batch et Epoch, et la dernière ligne est sortie lorsque l'apprentissage est terminé.
##### Train Start ##### 2019-12-24 02:17:27.484038
Epoch 5/5 (59904/60000) -- acc: 0.970283 loss: 0.066101 - val_acc: 0.973900 val_loss: 0.087803
##### Train Complete ##### 2019-12-24 02:17:34.443442
Recommended Posts