Implementing the "slanted triangular learning rate" in Keras, which is effective for BERT fine-tuning

Overview

I was competing in a natural language processing competition on Kaggle and wanted to fine-tune BERT. I had read that the slanted triangular learning rate (STLR) works well, and when I implemented it in Keras, accuracy improved considerably.

(Update) After that, I went on to win a silver medal.

Slanted triangular learning rate

I referred to the paper below.

As Fig. 2 of the paper shows, both the warm-up of the learning rate at the start of training and the decay after the peak are linear. The schedule is called "slanted triangular" because it looks like a tilted triangle.
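To make the shape concrete, here is a minimal framework-free sketch of the schedule. The function name stlr and the total of 1000 iterations are assumptions chosen for illustration; the default values are the ones used by the Keras callback in this article.

```python
def stlr(t: int, total_steps: int, lr_max: float = 0.001,
         cut_frac: float = 0.1, ratio: float = 32) -> float:
    """Learning rate at iteration t (0-based) out of total_steps."""
    cut = total_steps * cut_frac  # iteration at which the peak is reached
    if t < cut:
        p = t / cut  # linear warm-up: p rises from 0 to 1
    else:
        p = 1 - (t - cut) / (cut * (1 / cut_frac - 1))  # linear decay: p falls from 1 towards 0
    return lr_max * (1 + p * (ratio - 1)) / ratio


# Hypothetical run of 1000 iterations.
schedule = [stlr(t, 1000) for t in range(1000)]
peak_step = max(range(1000), key=lambda t: schedule[t])
print("peak at iteration", peak_step)
print("first / peak / last:", schedule[0], schedule[peak_step], schedule[-1])
```

The rate starts at lr_max / ratio, climbs to lr_max during the first cut_frac of the iterations, and then falls back slowly towards lr_max / ratio, which gives the short steep side and the long shallow side of the triangle.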

Incidentally, the paper that prompted me to use STLR in the first place is this one ↓

Implementation in Keras

This can be done with Keras's callback mechanism. LearningRateScheduler cannot be used, because STLR has to change the learning rate at every iteration (a "step" in Keras terminology) rather than at every epoch. Instead, you need to subclass the Callback class and write the scheduler from scratch.

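# Assumed imports for standalone Keras; with TensorFlow's bundled Keras,
# import from tensorflow.keras instead.
from keras import backend as K
from keras.callbacks import Callback


class SlantedTriangularScheduler(Callback):
    """Sets the learning rate before every batch according to STLR."""

    def __init__(self,
                 lr_max: float = 0.001,
                 cut_frac: float = 0.1,
                 ratio: float = 32):
        super().__init__()
        self.lr_max = lr_max      # peak learning rate
        self.cut_frac = cut_frac  # fraction of iterations spent warming up
        self.ratio = ratio        # lr_max divided by the lowest learning rate

    def on_train_begin(self, logs=None):
        epochs = self.params['epochs']
        steps = self.params['steps']
        # Iteration at which the warm-up ends and the decay begins.
        self.cut = epochs * steps * self.cut_frac
        self.iteration = 0

    def on_batch_begin(self, batch: int, logs=None):
        t = self.iteration
        cut = self.cut
        if t < cut:
            p = t / cut  # linear warm-up
        else:
            p = 1 - (t - cut) / (cut * (1 / self.cut_frac - 1))  # linear decay
        lr = self.lr_max * (1 + p * (self.ratio - 1)) / self.ratio
        K.set_value(self.model.optimizer.lr, lr)
        self.iteration += 1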

The variable names follow Eq. (3) of the original paper as closely as possible.
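For reference, the schedule computed by the callback above can be written out as follows. This is transcribed from the code rather than quoted from the paper; the symbol T for the total number of iterations (epochs × steps) is my own notation, and the code leaves cut unrounded.

```latex
\begin{aligned}
cut &= T \cdot cut\_frac \\
p &=
\begin{cases}
  t / cut & \text{if } t < cut \\[4pt]
  1 - \dfrac{t - cut}{cut \cdot (1 / cut\_frac - 1)} & \text{otherwise}
\end{cases} \\
\eta_t &= \eta_{max} \cdot \frac{1 + p \cdot (ratio - 1)}{ratio}
\end{aligned}
```

Here t is the current iteration (self.iteration in the code), η_max corresponds to lr_max, and η_t is the learning rate set for that iteration.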

As shown in "How to Fine-Tune BERT for Text Classification?", the following combination worked well for fine-tuning BERT.

Hyperparameter   Value
lr_max           2e-5
cut_frac         0.1
ratio            32
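As a rough worked example of what these values imply, the arithmetic below computes the length of the warm-up and the lowest learning rate. The training size (4 epochs of 500 steps) is a hypothetical figure chosen only for illustration; it is not taken from the competition.

```python
lr_max, cut_frac, ratio = 2e-5, 0.1, 32  # values from the table above
epochs, steps_per_epoch = 4, 500         # hypothetical training size

total_steps = epochs * steps_per_epoch
warmup_steps = total_steps * cut_frac    # same quantity as self.cut in the callback
lr_start = lr_max / ratio                # learning rate at iteration 0, where p = 0

print(f"warm-up: {warmup_steps:.0f} of {total_steps} iterations")
print(f"learning rate: {lr_start:.3e} up to {lr_max:.1e}, then back towards {lr_start:.3e}")
```

In other words, with ratio 32 the rate never drops below 1/32 of its peak, and only the first tenth of training is spent warming up.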
