Als mein Kollege Keras (tf.keras) verwendete, um Zehntausende von Epochen in Google Colaboratory zu lernen, beklagte er, dass der Browser schwer wurde und die Anzeige schließlich nicht aktualisiert wurde. Die Ursache war, dass das Fortschrittsprotokoll durch Angabe von "verbose = 1" in "model.fit" angezeigt wurde, das Protokoll jedoch aufgebläht und schwer wurde. Wenn ein bestimmter Schwellenwert (?) Erreicht wurde, wurde die Anzeige aktualisiert. Es ist weg. (Der Betrieb wird fortgesetzt.) Es reicht aus, die Protokollierung bei "verbose = 0" zu beenden und zu drehen, aber dann können Sie den Fortschritt nicht überprüfen.
Übrigens erinnerte ich mich daran, dass ich in der Vergangenheit das gleiche Phänomen gestört und es mithilfe der Rückruffunktion gelöst hatte, also möchte ich es teilen.
Sie können das Verhalten während des Trainings anpassen, indem Sie eine Klasse angeben, die die Klasse tf.keras.callbacks.Callback
im Argument callbacks
der model.fit
erbt.
Überprüfen Sie die offizielle Dokumentation für Details.
【tf.keras.callbacks.Callback - TensorFlow】
[Rückruf - Keras-Dokumentation]
tf.keras.callbacks.Callback
bietet verschiedene Methoden, die aber irgendwann aufgerufen werden sollen.
Durch Überschreiben dieser Methoden können Sie das Verhalten während des Lernens ändern.
Diesmal habe ich die folgende Methode überschrieben.
Methode | Timing aufgerufen werden |
---|---|
on_train_begin | Zu Beginn des Lernens |
on_train_end | Am Ende des Lernens |
on_batch_begin | Zu Beginn der Charge |
on_batch_end | Am Ende der Charge |
on_epoch_begin | Zu Beginn der Epoche |
on_epoch_end | Am Ende der Epoche |
Darüber hinaus gibt es Methoden, die beim Inferenz- und Testaufruf aufgerufen werden.
Wenn Sie die Fortschrittsanzeige während des Lernens in derselben Zeile weiter überschreiben, ohne die Zeile zu unterbrechen, wächst die Ausgabezelle nicht und läuft nicht über. Verwenden Sie den folgenden Code, um das Überschreiben in derselben Zeile zu halten.
print('\rTest Print', end='')
Das "\ r" im obigen Code bedeutet "Wagenrücklauf" (Carriage Return, CR), mit dem Sie den Cursor an den Anfang einer Zeile bewegen können. Auf diese Weise können Sie die angezeigten Zeilen überschreiben.
Wenn dies jedoch unverändert bleibt, tritt bei jeder Ausführung der Druckanweisung ein Zeilenumbruch auf.
Geben Sie daher "end =" als Argument der print-Anweisung an.
Der Punkt besteht darin, Zeilenumbrüche zu unterdrücken, indem angegeben wird, dass das erste Argument nicht nach der Ausgabe ausgegeben werden soll.
Standardmäßig ist in der print-Anweisung end = '\ n'
angegeben.
\ n
steht für Line Feed (LF), der den Cursor auf eine neue Zeile (dh einen Zeilenumbruch) sendet.
Wenn Sie den folgenden Code als Test ausführen, überschreibt er weiterhin 0 bis 9 und kann so ausgedrückt werden, als würde er hochzählen.
Probe überschreiben
from time import sleep
for i in range(10):
print('\r%d' % i, end='')
sleep(1)
Ich denke hier.
Ich denke auch, dass es besser ist, "end =" \ r' zu setzen, anstatt "'\ r'
zu drucken.
Dieser Versuch funktioniert jedoch nicht. Denn wenn in Python "\ r" ausgegeben wird, scheint der bisher ausgegebene Inhalt gelöscht zu sein. Wenn Sie beispielsweise "print" ("Test Print", end = "\ r") "ausführen, wird nichts angezeigt, was für diesen Zweck unpraktisch ist. Daher bleibt keine andere Wahl, als die Zeichenfolge auszugeben, die Sie nach der Ausgabe von "\ r" kurz vor der Zeichenausgabe ausgeben möchten.
Verwenden Sie also die oben beschriebene Technik, um mit der folgenden Richtlinie zu codieren.
Zeigt den Start / das Ende und die Uhrzeit der Ausführung an. Dies ist ein normaler Zeilenumbruch.
Die Anzahl der Epochen, die Anzahl der verarbeiteten Daten, acc und Verlust werden angezeigt. Diese Anzeige wird ohne Zeilenumbrüche überschrieben, um die Größe der Ausgabezelle zu verringern.
Wir werden es basierend auf der oben genannten Richtlinie implementieren. Das Modellteil basiert auf TensorFlows Tutoria. 【TensorFlow 2 quickstart for beginners】
import tensorflow as tf
#Definition der Rückruffunktion für die benutzerdefinierte Fortschrittsanzeige
"""
Rückruffunktion zur Anzeige des Fortschritts.
Die Daten werden am Ende von Batch und Epoche gesammelt und angezeigt.
Der Punkt ist, wenn der Druck ausgegeben wird/Das Argument endet, während der Cursor mit r an den Zeilenanfang zurückgesetzt wird=''Der Punkt ist, dass Zeilenumbrüche unterdrückt werden.
"""
import datetime
class DisplayCallBack(tf.keras.callbacks.Callback):
#Konstrukteur
def __init__(self):
self.last_acc, self.last_loss, self.last_val_acc, self.last_val_loss = None, None, None, None
self.now_batch, self.now_epoch = None, None
self.epochs, self.samples, self.batch_size = None, None, None
#Benutzerdefinierte Fortschrittsanzeige(Displaykörper)
def print_progress(self):
epoch = self.now_epoch
batch = self.now_batch
epochs = self.epochs
samples = self.samples
batch_size = self.batch_size
sample = batch_size*(batch)
# '\r'Und Ende=''Um Zeilenumbrüche zu vermeiden, verwenden Sie
if self.last_val_acc and self.last_val_loss:
# val_acc/val_Verlust kann angezeigt werden
print("\rEpoch %d/%d (%d/%d) -- acc: %f loss: %f - val_acc: %f val_loss: %f" % (epoch+1, epochs, sample, samples, self.last_acc, self.last_loss, self.last_val_acc, self.last_val_loss), end='')
else:
# val_acc/val_Verlust kann nicht angezeigt werden
print("\rEpoch %d/%d (%d/%d) -- acc: %f loss: %f" % (epoch+1, epochs, sample, samples, self.last_acc, self.last_loss), end='')
#Zu Beginn der Passform
def on_train_begin(self, logs={}):
print('\n##### Train Start ##### ' + str(datetime.datetime.now()))
#Parameter abrufen
self.epochs = self.params['epochs']
self.samples = self.params['samples']
self.batch_size = self.params['batch_size']
#Vermeiden Sie die Standard-Fortschrittsanzeige
self.params['verbose'] = 0
#Zu Beginn der Charge
def on_batch_begin(self, batch, logs={}):
self.now_batch = batch
#Wenn die Charge abgeschlossen ist(Fortschrittsanzeige)
def on_batch_end(self, batch, logs={}):
#Aktualisierung der neuesten Informationen
self.last_acc = logs.get('acc') if logs.get('acc') else 0.0
self.last_loss = logs.get('loss') if logs.get('loss') else 0.0
#Fortschrittsanzeige
self.print_progress()
#Zu Beginn der Epoche
def on_epoch_begin(self, epoch, log={}):
self.now_epoch = epoch
#Wenn die Epoche beendet ist(Fortschrittsanzeige)
def on_epoch_end(self, epoch, logs={}):
#Aktualisierung der neuesten Informationen
self.last_val_acc = logs.get('val_acc') if logs.get('val_acc') else 0.0
self.last_val_loss = logs.get('val_loss') if logs.get('val_loss') else 0.0
#Fortschrittsanzeige
self.print_progress()
#Wenn die Anpassung abgeschlossen ist
def on_train_end(self, logs={}):
print('\n##### Train Complete ##### ' + str(datetime.datetime.now()))
#Instanzerstellung für Rückruffunktion
cbDisplay = DisplayCallBack()
#Lesen und normalisieren Sie den MNIST-Datensatz
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# tf.keras.Erstellen eines sequentiellen Modells
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
#Modelllernen
#Hier verwenden wir die Rückruffunktion
history = model.fit(x_train, y_train,
validation_data = (x_test, y_test),
batch_size=128,
epochs=5,
verbose=1, #Die Standard-Fortschrittsanzeige wird in der Rückruffunktion ignoriert
callbacks=[cbDisplay]) #Stellen Sie die benutzerdefinierte Fortschrittsanzeige als Rückruffunktion ein
#Modellbewertung
import pandas as pd
results = pd.DataFrame(history.history)
results.plot();
Wenn Sie die obigen Schritte ausführen, werden unabhängig von der Anzahl der Epochen nur die folgenden 3 Zeilen angezeigt. Die zweite Zeile wird mit den neuesten Informationen am Ende von Batch und Epoche neu geschrieben, und die letzte Zeile wird ausgegeben, wenn das Lernen abgeschlossen ist.
##### Train Start ##### 2019-12-24 02:17:27.484038
Epoch 5/5 (59904/60000) -- acc: 0.970283 loss: 0.066101 - val_acc: 0.973900 val_loss: 0.087803
##### Train Complete ##### 2019-12-24 02:17:34.443442
Recommended Posts