[PYTHON] BigTransfer (BiT) Fine Tuning

- Official blog: https://blog.tensorflow.org/2020/05/bigtransfer-bit-state-of-art-transfer-learning-computer-vision.html

- Official sample code (TF2, PyTorch, JAX): https://github.com/google-research/big_transfer/tree/master/colabs

Below, I extract and summarize the information needed for BiT fine-tuning (the parts I actually used or examined while experimenting).

Text

- **Where to browse and download the pre-trained models**: the models are stored in a Google Cloud Storage bucket in h5 and npz formats, so they can also be used with frameworks other than TF.

- **What S, M, and L mean in BiT model names**: they indicate the dataset the model was pre-trained on. **L** is not publicly released.

| Model | Pre-training dataset |
|---|---|
| BiT-S | ILSVRC-2012 (1.3M images) |
| BiT-M | ImageNet-21k (14M images) |
| BiT-L | JFT (300M images) |

- **What R??x? means in BiT model names**: BiT models are ResNets, and this part of the name specifies the ResNet configuration.
  - How to read R50x3: a 50-layer ResNet in which every layer is 3 times the standard width.
- **Number of parameters**: the table below shows the approximate number of parameters (excluding the output layer) of each model published on TF Hub.

| ResNet | Parameters (approx.) |
|---|---|
| R50x1 | 23M |
| R101x1 | 42M |
| R50x3 | 211M |
| R101x3 | 381M |
| R152x4 | 928M |

The following code confirms the number of parameters.

import tensorflow as tf
import tensorflow_hub as tfhub
model = tfhub.KerasLayer('https://tfhub.dev/google/bit/s-r50x1/1')
print(sum(tf.math.reduce_prod(w.shape).numpy() for w in model.weights))
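As a rough sanity check on the "x3 = 3 times the width" reading: widening every layer by k multiplies both the input and output channels of each convolution by k, so the weight count should grow by roughly k². The sketch below only does plain arithmetic on the approximate numbers from the table; it ignores details such as the stem convolution and normalization layers, whose parameters do not scale exactly as k².

```python
# Approximate parameter counts from the table above, in millions.
params = {"R50x1": 23, "R101x1": 42, "R50x3": 211, "R101x3": 381}

k = 3  # width multiplier in "x3"
for depth in ("R50", "R101"):
    ratio = params[f"{depth}x3"] / params[f"{depth}x1"]
    # Conv kernels have (k * in) x (k * out) channels, so roughly k**2 more weights.
    print(depth, round(ratio, 1), "vs k**2 =", k ** 2)
```

Both ratios come out close to 9, which fits the "each layer is 3 times wider" interpretation.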

BiT-HyperRule

A heuristic for setting the fine-tuning hyperparameters of BiT. The pitch is roughly: "Use this and you get good results in one shot."

Of course, a hyperparameter search may give you a better model, at a correspondingly higher cost.

Depends on the image size of the dataset → resize and crop sizes

Each image is resized and then randomly cropped to a size chosen according to the dataset's image size, and at the same time randomly flipped left-right. The size correspondence is given in Table 1 of the official blog (BigTransfer (BiT): State-of-the-art transfer learning for computer vision). Validation images only need to be resized.
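The table itself was an image. As far as I can tell from `bit_hyperrule.py` in the official repository, the size rule boils down to the following. This is a sketch: the function name mirrors the repository, but this two-argument signature is my own simplification, and the 96×96 area threshold is taken from my reading of that file.

```python
def get_resolution(height, width):
    """Return (resize_to, crop_to) following the BiT-HyperRule size rule.

    Assumption (from my reading of bit_hyperrule.py): images whose area is
    below 96x96 pixels are "small" and get resized to 160 and cropped to 128;
    all other images are resized to 512 and cropped to 480.
    """
    area = height * width
    if area < 96 * 96:
        return 160, 128
    return 512, 480


print(get_resolution(32, 32))    # (160, 128), e.g. CIFAR-sized images
print(get_resolution(375, 500))  # (512, 480), e.g. typical photos
```

The TF2 code that follows uses the second case (resize to 512, crop to 480).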

In TF2, it would look something like this.

def augmentation(image, label):
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.image.resize(image, [512, 512], method=tf.image.ResizeMethod.BILINEAR)
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_crop(image, [480, 480, 3])
    return image, label

ds_train: tf.data.Dataset
ds_train = (ds_train
            .shuffle(1024)
            .repeat()
            .map(augmentation, tf.data.experimental.AUTOTUNE)
            .batch(64)
            .prefetch(tf.data.experimental.AUTOTUNE))
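On the validation side only resizing is needed. A minimal sketch, assuming validation images are resized straight to the 480×480 crop size so that training and validation inputs have the same shape (my understanding is that the official PyTorch code does this too). `preprocess_val`, `CROP_SIZE`, and `ds_val` are names I made up for illustration.

```python
import tensorflow as tf

CROP_SIZE = 480  # hypothetical constant: same size as the training random_crop


def preprocess_val(image, label):
    """Validation preprocessing: scale to [0, 1] and resize, no crop or flip."""
    image = tf.cast(image, tf.float32) / 255.0
    # Resize directly to the crop size used during training.
    image = tf.image.resize(image, [CROP_SIZE, CROP_SIZE],
                            method=tf.image.ResizeMethod.BILINEAR)
    return image, label


# ds_val is assumed to be a tf.data.Dataset of (image, label) pairs:
# ds_val = ds_val.map(preprocess_val, tf.data.experimental.AUTOTUNE).batch(64)
```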

Note

- You might worry that image sizes vary within the dataset. However, the public source code resizes every image uniformly, so there is no need to worry about it.
- If only a few images bother you, filter them out. (I tried removing some in the bonus sample.)

- Some augmentations should be skipped depending on the task, because they would make the ground-truth label wrong:
  - Counting objects ⇒ no random crop (a crop can cut objects out of the image)
  - Locating objects ⇒ no random flip (the object's position changes)
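To see why a flip breaks position labels while leaving class labels alone, here is a tiny NumPy sketch on toy data (purely illustrative, not part of BiT): after a horizontal flip, an object at column x moves to column W - 1 - x, so an unchanged position label would now be wrong.

```python
import numpy as np

# Toy 1x5 single-channel "image" with one bright object at column 1.
image = np.array([[0, 1, 0, 0, 0]])
width = image.shape[1]
object_x = int(np.argmax(image[0]))  # position label: 1

flipped = image[:, ::-1]  # horizontal flip, like tf.image.random_flip_left_right
flipped_x = int(np.argmax(flipped[0]))

print(object_x, flipped_x)                   # 1 3
print(flipped_x == width - 1 - object_x)     # True: the label would need mirroring
```

A class label such as "cat" is the same before and after the flip, which is why flipping is harmless for classification.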

Depends on the number of samples in the dataset → number of training steps, MixUp

The following is based on Table 2 of the official blog (BigTransfer (BiT): State-of-the-art transfer learning for computer vision).

`boundaries` is used later for the learning-rate schedule.

if dataset_size < 20 * 10 ** 3:
    schedule_len, boundaries = 500, [200, 300, 400]
elif 20 * 10 ** 3 <= dataset_size < 500 * 10 ** 3:
    schedule_len, boundaries = 10000, [3000, 6000, 9000]
else:
    schedule_len, boundaries = 20000, [6000, 12000, 18000]
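The code above only picks the step counts. Table 2 also decides whether to use MixUp: no MixUp below 20k examples, and MixUp with α = 0.1 otherwise. Folding both into one helper might look like this; `bit_hyperrule_schedule` is my own hypothetical function, not something from the official code.

```python
def bit_hyperrule_schedule(dataset_size):
    """Return (schedule_len, boundaries, mixup_alpha) per the BiT-HyperRule.

    Step counts and boundaries follow the snippet above; mixup_alpha is 0.0
    (MixUp off) below 20k examples and 0.1 otherwise, as in Table 2.
    """
    if dataset_size < 20_000:
        return 500, [200, 300, 400], 0.0
    if dataset_size < 500_000:
        return 10_000, [3000, 6000, 9000], 0.1
    return 20_000, [6000, 12000, 18000], 0.1


print(bit_hyperrule_schedule(5_000))    # (500, [200, 300, 400], 0.0)
print(bit_hyperrule_schedule(100_000))  # (10000, [3000, 6000, 9000], 0.1)
```

When `mixup_alpha` is 0.0, you would simply skip the `mixup` map step below.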

How to apply MixUp

Reference: https://github.com/google-research/big_transfer/blob/master/input_pipeline_tf2_or_jax.py#L118

import tensorflow_probability as tfp

def mixup(image, label):
    beta_dist = tfp.distributions.Beta(0.1, 0.1)  # alpha=0.1
    beta = tf.cast(beta_dist.sample([]), tf.float32)
    image = (beta * image + (1 - beta) * tf.reverse(image, axis=[0]))
    label = (beta * label + (1 - beta) * tf.reverse(label, axis=[0]))
    return image, label

Adjust the arguments (for example, the Beta distribution's alpha) as needed. Apply this function to the dataset **after** `batch()`, because it mixes examples within a mini-batch.

Because MixUp also mixes the labels, the labels must be one-hot vectors.

Incidentally, the code above uses `tf.reverse` to mix each example with the one in the mirror-image position of the batch. For example, with a batch of 16, example [0] (top left in the original figure) is mixed with example [15] (bottom right).

[Figure: MixUp applied to a Cats vs Dogs batch]
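The pairing and the label mixing can be checked without TensorFlow. The NumPy sketch below uses random toy labels (not real data), builds one-hot labels, and applies the same reverse-and-blend step; `[::-1]` on the batch axis plays the role of `tf.reverse(..., axis=[0])`.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, num_classes = 16, 2  # toy values, e.g. cat = 0, dog = 1

labels = rng.integers(0, num_classes, size=batch_size)
one_hot = np.eye(num_classes, dtype=np.float32)[labels]  # MixUp needs one-hot labels

beta = rng.beta(0.1, 0.1)  # same alpha = 0.1 as the TF code above
mixed = beta * one_hot + (1 - beta) * one_hot[::-1]

partner = np.arange(batch_size)[::-1]  # index each example is mixed with
print(partner[0], partner[15])               # 15 0
print(np.allclose(mixed.sum(axis=1), 1.0))   # True: mixed labels still sum to 1
```

Because alpha = 0.1 puts most of the Beta distribution's mass near 0 and 1, most mixed examples stay close to one of the two originals.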

Batch size = 512

If it does not fit in memory, you can lower it.

In the TF2 sample code, the learning rate and the number of steps are scaled according to the batch size, but the other samples do not do this. I'm not sure why the number of steps is left unchanged there when the learning rate is changed.

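batch_size = 64
schedule_len = int(schedule_len * 512 / batch_size)  # steps must be integers
boundaries = [int(b * 512 / batch_size) for b in boundaries]  # keep decays at the same fraction of training
lr = 0.003 * batch_size / 512  # linear scaling from the reference batch size 512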
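A worked example of this scaling for the small-dataset case (500 steps at batch size 512) with batch size 64. The point of scaling the step count inversely with the batch size is that the total number of training examples seen stays the same.

```python
ref_batch = 512   # batch size the BiT-HyperRule assumes
batch_size = 64   # smaller batch that fits in memory (example value)
base_len = 500    # schedule length for < 20k samples

scaled_len = int(base_len * ref_batch / batch_size)
scaled_lr = 0.003 * batch_size / ref_batch

print(scaled_len, scaled_lr)                           # 4000 0.000375
print(base_len * ref_batch, scaled_len * batch_size)   # 256000 256000
```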

Optimization algorithm = SGD

The learning rate above is only the initial value; during training it is changed according to the following schedule.

Learning rate scheduling

When training reaches 30%, 60%, and 90% of the total steps, the learning rate is multiplied by $\frac{1}{10}$.

In the official sample code, the decay points are not exactly at 30%, 60%, and 90%. For example, with fewer than 20k samples, the learning rate is 0.003 up to step 200 and 0.0003 from step 201 to step 300.

lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=boundaries, values=[lr, lr * 1e-1, lr * 1e-2, lr * 1e-3])

optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
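To check the numbers above, here is a pure-Python stand-in for the lookup `PiecewiseConstantDecay` performs (`values[0]` while `step <= boundaries[0]`, `values[1]` while `step <= boundaries[1]`, and so on). `piecewise_lr` is a hypothetical helper for illustration only; the sketch also compares exact 30/60/90% points with the official boundaries for a 500-step schedule.

```python
def piecewise_lr(step, boundaries, values):
    """Mimic PiecewiseConstantDecay: values[i] while step <= boundaries[i]."""
    for boundary, value in zip(boundaries, values):
        if step <= boundary:
            return value
    return values[-1]  # after the last boundary


schedule_len = 500
exact = [round(schedule_len * f) for f in (0.3, 0.6, 0.9)]
official = [200, 300, 400]
lr = 0.003
values = [lr, lr * 1e-1, lr * 1e-2, lr * 1e-3]

print(exact)                                 # [150, 300, 450]
print([b / schedule_len for b in official])  # [0.4, 0.6, 0.8]
print(piecewise_lr(200, official, values))   # 0.003
print(piecewise_lr(201, official, values))   # about 0.0003
```

So for the small-dataset schedule the decays actually happen at 40%, 60%, and 80% of training.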

Bonus

A Jupyter Notebook from my experiments fine-tuning on Cats vs Dogs, based on the sample code.
