[Python] Customize the progress display during training with tf.keras (a countermeasure for Google Colaboratory output-cell overflow)

[Details]

A colleague of mine was training a model for tens of thousands of epochs on Google Colaboratory with Keras (tf.keras), and he lamented that the browser became sluggish and the display eventually stopped updating. The cause was the progress log produced by specifying verbose=1 in model.fit: the log in the output cell kept growing, and once it passed some threshold the display stopped refreshing (although training itself kept running). Suppressing the log with verbose=0 avoids the problem, but then you can no longer check the progress.

I remembered that I had run into the same phenomenon in the past and solved it with a callback, so I would like to share the solution.

[About the Callback class of tf.keras]

You can customize the behavior during training by passing a class that inherits from tf.keras.callbacks.Callback to the callbacks argument of model.fit. See the official documentation for details: 【tf.keras.callbacks.Callback - TensorFlow】, [Callback - Keras Documentation]

tf.keras.callbacks.Callback provides several methods, each of which is called at a specific point during training. By overriding these methods, you can change the training behavior. This time I overrode the following methods.

Method            When it is called
on_train_begin    At the start of training
on_train_end      At the end of training
on_batch_begin    At the start of each batch
on_batch_end      At the end of each batch
on_epoch_begin    At the start of each epoch
on_epoch_end      At the end of each epoch

In addition to the above, there are methods that are called during inference and testing.
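As a minimal sketch of the mechanism (the class name EpochLogger and its messages are my own, not part of tf.keras), a callback that only reports epoch boundaries could look like this:

```python
import tensorflow as tf

class EpochLogger(tf.keras.callbacks.Callback):
    # Called once when model.fit() starts
    def on_train_begin(self, logs=None):
        print('training started')

    # Called at the start of each epoch; epoch is 0-based
    def on_epoch_begin(self, epoch, logs=None):
        print('epoch %d begins' % (epoch + 1))

    # Called at the end of each epoch; logs holds the metric values
    def on_epoch_end(self, epoch, logs=None):
        print('epoch %d done' % (epoch + 1))

    # Called once when model.fit() finishes
    def on_train_end(self, logs=None):
        print('training finished')
```

Passing callbacks=[EpochLogger()] to model.fit then triggers these prints at the corresponding points.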

【Policy】

By continuously overwriting the progress display on the same line, without line breaks, the output cell never grows, so it cannot overflow. The following code keeps overwriting the same line.

print('\rTest Print', end='')

The \r in the code above is a Carriage Return (CR), which moves the cursor back to the beginning of the line. This lets you overwrite the line that is already displayed.

However, as it stands, a line break would still occur every time the print statement executes. Therefore, pass end='' as an argument to print. This specifies that nothing should be output after the first argument, which suppresses the line break. By default, print uses end='\n'; \n is a Line Feed (LF), which moves the cursor to a new line (that is, a newline).

If you run the following code as a test, it keeps overwriting the digits 0 through 9, making it look as if it is counting up.

Overwrite sample


from time import sleep
for i in range(10):
  print('\r%d' % i, end='')
  sleep(1)

Here a thought occurs: wouldn't it be cleaner to specify end='\r' instead of printing '\r' first?

However, this does not work. In this environment, when the output ends with '\r', the content printed so far appears to be cleared: for example, running print('Test Print', end='\r') displays nothing, which is useless for our purpose. So there is no choice but to print '\r' immediately before the string you want to display.
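To see why the order matters, a small sketch that captures what each style actually writes to the stream (the strings AAA and BBB are just placeholders):

```python
import io
from contextlib import redirect_stdout

# Style 1: trailing CR -- the stream ends with '\r', and some consoles
# (including Colab output cells) do not render a line that ends this way
buf = io.StringIO()
with redirect_stdout(buf):
    print('AAA', end='\r')
    print('BBB', end='')
print(repr(buf.getvalue()))  # 'AAA\rBBB'

# Style 2: leading CR -- the visible text always comes after the '\r',
# so the latest string is always fully displayed
buf = io.StringIO()
with redirect_stdout(buf):
    print('\rAAA', end='')
    print('\rBBB', end='')
print(repr(buf.getvalue()))  # '\rAAA\rBBB'
```

With the leading-CR style, whatever follows the last '\r' is a complete string, which is why the article prints '\r' first.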

So, using the technique above, we code with the following policy.

At the start / end of training

Show a start/end marker together with the time it was executed. Normal line breaks are fine here.

At the end of each batch and each epoch

Display the epoch number, the number of samples processed, acc, and loss. This display is overwritten without a line break to keep the output cell small.

【Coding】

We implement it based on the above policy. The model itself is taken from the TensorFlow tutorial: 【TensorFlow 2 quickstart for beginners】

import tensorflow as tf
import datetime

# Define a callback class for a custom progress display
"""
Callback class for displaying progress.
Data is collected and displayed at the end of each batch and each epoch.
The point is to return the cursor to the beginning of the line by printing
'\r', and to suppress line breaks with the argument end=''.
"""

class DisplayCallBack(tf.keras.callbacks.Callback):
  # Constructor
  def __init__(self):
    self.last_acc, self.last_loss, self.last_val_acc, self.last_val_loss = None, None, None, None
    self.now_batch, self.now_epoch = None, None

    self.epochs, self.samples, self.batch_size = None, None, None

  # Custom progress display (main body)
  def print_progress(self):
    epoch = self.now_epoch
    batch = self.now_batch

    epochs = self.epochs
    samples = self.samples
    batch_size = self.batch_size
    sample = batch_size*(batch)

    # Use '\r' and end='' to avoid line breaks
    if self.last_val_acc and self.last_val_loss:
      # val_acc/val_loss are available
      print("\rEpoch %d/%d (%d/%d) -- acc: %f loss: %f - val_acc: %f val_loss: %f" % (epoch+1, epochs, sample, samples, self.last_acc, self.last_loss, self.last_val_acc, self.last_val_loss), end='')
    else:
      # val_acc/val_loss are not available yet
      print("\rEpoch %d/%d (%d/%d) -- acc: %f loss: %f" % (epoch+1, epochs, sample, samples, self.last_acc, self.last_loss), end='')
      

  # At the start of fit
  def on_train_begin(self, logs={}):
    print('\n##### Train Start ##### ' + str(datetime.datetime.now()))

    # Get parameters (note: newer TF versions report 'steps' instead of 'samples'/'batch_size')
    self.epochs = self.params['epochs']
    self.samples = self.params['samples']
    self.batch_size = self.params['batch_size']

    # Suppress the standard progress display
    self.params['verbose'] = 0


  # At the start of each batch
  def on_batch_begin(self, batch, logs={}):
    self.now_batch = batch

  # At the end of each batch (progress display)
  def on_batch_end(self, batch, logs={}):
    # Update the latest values ('acc' may be named 'accuracy' in newer TF versions)
    self.last_acc = logs.get('acc') if logs.get('acc') else 0.0
    self.last_loss = logs.get('loss') if logs.get('loss') else 0.0

    # Progress display
    self.print_progress()


  # At the start of each epoch
  def on_epoch_begin(self, epoch, logs={}):
    self.now_epoch = epoch

  # At the end of each epoch (progress display)
  def on_epoch_end(self, epoch, logs={}):
    # Update the latest values ('val_acc' may be named 'val_accuracy' in newer TF versions)
    self.last_val_acc = logs.get('val_acc') if logs.get('val_acc') else 0.0
    self.last_val_loss = logs.get('val_loss') if logs.get('val_loss') else 0.0

    # Progress display
    self.print_progress()


  # At the end of fit
  def on_train_end(self, logs={}):
    print('\n##### Train Complete ##### ' + str(datetime.datetime.now()))

# Instantiate the callback
cbDisplay = DisplayCallBack()

# Load and normalize the MNIST dataset
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# Build a tf.keras Sequential model
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# Train the model
# This is where the callback is used
history = model.fit(x_train, y_train,
                    validation_data = (x_test, y_test),
                    batch_size=128,
                    epochs=5,
                    verbose=1,              # The standard progress display is disabled inside the callback
                    callbacks=[cbDisplay])  # Set the custom progress display as a callback
# Evaluate the model
import pandas as pd

results = pd.DataFrame(history.history)
results.plot();

[Output example]

If you run the above, only the following three lines are displayed no matter how many epochs you run. The second line is rewritten with the latest information at the end of each batch and each epoch, and the last line is printed when training completes.

##### Train Start ##### 2019-12-24 02:17:27.484038
Epoch 5/5 (59904/60000) -- acc: 0.970283 loss: 0.066101 - val_acc: 0.973900 val_loss: 0.087803
##### Train Complete ##### 2019-12-24 02:17:34.443442
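One caveat: the metric key names in logs depend on the TensorFlow version. Older tf.keras reports 'acc'/'val_acc', while newer versions report 'accuracy'/'val_accuracy'. A defensive lookup helper (get_metric is my own hypothetical name, not a tf.keras API) that works with either naming could look like this:

```python
def get_metric(logs, name, default=0.0):
    """Look up a metric under both the old ('acc') and new ('accuracy') key styles."""
    aliases = {
        'acc': ('acc', 'accuracy'),
        'val_acc': ('val_acc', 'val_accuracy'),
    }
    # Try each known alias in order; fall back to the default value
    for key in aliases.get(name, (name,)):
        if key in logs:
            return logs[key]
    return default

# Works regardless of which key the running TF version uses
print(get_metric({'accuracy': 0.97, 'loss': 0.1}, 'acc'))  # 0.97
print(get_metric({'acc': 0.95}, 'acc'))                    # 0.95
```

Inside the callback, self.last_acc = get_metric(logs, 'acc') would then replace the direct logs.get('acc') call.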

