[PYTHON] An implementation of ArcFace for TensorFlow

Introduction

I implemented ArcFace for TensorFlow 2.x as a combination of a custom layer and a custom loss function.

Background

Among the various deep metric learning methods, ArcFace is simple and clear: it can be set up just by adding it to the output layer of an ordinary classification network. (Reference: Modern deep metric learning method: SphereFace, CosFace, ArcFace)

Implementing it in TensorFlow 2.x (Keras), however, takes some ingenuity. A normal fully connected layer needs only its input and its weights to compute its output, whereas ArcFace also needs the correct label. One earlier Keras implementation is "[Keras] MobileNetV2 + I tried to classify PET bottles using ArcFace!", in which ArcFace is a custom layer with two inputs and the correct label is routed into the network from the training-data generator. I first tried that approach too, but switching between ArcFace and a normal classification head then changed the network structure, which I found cumbersome, so I looked for another way to implement it.

Implementation method

The calculation in ArcFace roughly proceeds as follows.

  1. L2-normalize the input X and the weights W
  2. Calculate cos(θ) = X · W (inner product) for the labels other than the correct one
  3. Calculate cos(θ + m) for the correct label
  4. Calculate Softmax
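As a minimal sketch, assuming a single sample and plain Python lists instead of TensorFlow tensors, the four steps can be written out directly. The names `l2_normalize` and `arcface_probs` and the example vectors are hypothetical; the defaults m = 0.5 and s = 30 are taken from `ArcFaceLoss` in this post. For readability, step 3 here recovers θ with `acos`, which the TensorFlow code in this post avoids.

```python
import math


def l2_normalize(v):
    """Step 1: scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]


def arcface_probs(x, weights, label, m=0.5, s=30.0):
    """Steps 1-4 of ArcFace for one sample (hypothetical helper).

    x       -- feature vector
    weights -- one weight vector per class (the columns of W)
    label   -- index of the correct class
    m, s    -- angular margin (radians) and scale
    """
    x_n = l2_normalize(x)                          # step 1: normalize x
    w_n = [l2_normalize(w) for w in weights]       # step 1: normalize each w_j
    cos_theta = [sum(a * b for a, b in zip(x_n, w)) for w in w_n]  # step 2: cos(θ_j) = x·w_j

    # Step 3: replace cos(θ) by cos(θ + m) for the correct class only.
    # The argument of acos is clamped to [-1, 1] to survive rounding.
    logits = list(cos_theta)
    theta = math.acos(max(-1.0, min(1.0, cos_theta[label])))
    logits[label] = math.cos(theta + m)

    # Step 4: softmax of the scaled logits (max subtracted for stability).
    z = [s * v for v in logits]
    z_max = max(z)
    exps = [math.exp(v - z_max) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


x = [1.0, 2.0, 0.5]
W = [[1.0, 2.0, 0.4], [-1.0, 0.5, 2.0], [0.0, -1.0, 1.0]]
print(arcface_probs(x, W, label=0, s=4.0))
```

Because the margin lowers only the correct class's logit, the probability assigned to that class is smaller with m > 0 than with m = 0, which is what pushes training toward tighter clusters.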

Here, cos(θ + m) = cos(θ)·cos(m) - sin(θ)·sin(m) by the addition theorem, and sin(θ) = √(1 - cos²(θ)) (θ lies between 0 and π, so sin(θ) ≥ 0). Therefore cos(θ + m) can be computed from cos(θ) alone, and in practice step 2 should compute cos(θ) for the correct label as well:

  2'. Calculate cos(θ) = X · W (inner product) for all labels
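A quick numerical check of the addition-theorem route, as a sketch in plain Python (the helper name `cos_theta_plus_m` is hypothetical): computing cos(θ + m) from cos(θ) alone gives the same value as computing θ explicitly, as long as θ is in [0, π], where the non-negative square root is the correct sign for sin(θ).

```python
import math


def cos_theta_plus_m(cos_t, m):
    """cos(θ + m) computed from cos(θ) alone via the addition theorem.

    Assumes θ is in [0, π] (true for the angle between two vectors),
    so sin(θ) >= 0 and the non-negative square root is the right branch.
    """
    sin_t = math.sqrt(max(0.0, 1.0 - cos_t * cos_t))  # clamp guards tiny negatives
    return cos_t * math.cos(m) - sin_t * math.sin(m)


# Compare against computing θ explicitly and adding the margin.
m = 0.5
for theta in [0.0, 0.3, 1.0, 1.5, 2.5, 3.0]:
    direct = math.cos(theta + m)
    via_theorem = cos_theta_plus_m(math.cos(theta), m)
    print(f"θ={theta:.1f}  direct={direct:+.6f}  addition theorem={via_theorem:+.6f}")
```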

With this change, no correct label is needed up to step 2', so steps 1 and 2' can be implemented as a custom layer. Steps 3 and 4 need neither the input nor the weights, so they can be implemented in the loss function. The catch is that Softmax (step 4) then lives only inside the loss, so it has to be computed again, without the margin, when calculating accuracy.
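The split can be sketched in plain Python for a single sample (all function names here are hypothetical, and a one-hot label is assumed): the "layer" part takes only x and W, the "loss" part takes only the label and cos(θ), and prediction or accuracy uses softmax(s·cos(θ)) with no margin, since no label is available at that point.

```python
import math


def arcface_layer(x, weights):
    """Steps 1 and 2': no label needed, so this fits in a custom layer."""
    def unit(v):
        n = math.sqrt(sum(a * a for a in v))
        return [a / n for a in v]
    x_n = unit(x)
    return [sum(a * b for a, b in zip(x_n, unit(w))) for w in weights]


def arcface_loss(y_true, cos_theta, m=0.5, s=30.0):
    """Steps 3 and 4: needs the one-hot label but not x or W,
    so this fits in the loss function."""
    cos_m, sin_m = math.cos(m), math.sin(m)
    logits = []
    for t, c in zip(y_true, cos_theta):
        sin_t = math.sqrt(max(0.0, 1.0 - c * c))
        phi = c * cos_m - sin_t * sin_m       # cos(θ + m) by the addition theorem
        phi = phi if c > 0 else c             # same fallback as tf.where(y_pred > 0, ...)
        logits.append(s * (t * phi + (1.0 - t) * c))
    # Categorical cross-entropy on the scaled logits, via log-sum-exp.
    z_max = max(logits)
    log_sum = z_max + math.log(sum(math.exp(z - z_max) for z in logits))
    return -sum(t * (z - log_sum) for t, z in zip(y_true, logits))


def arcface_predict(cos_theta, s=30.0):
    """Accuracy/inference: no label, so no margin; softmax(s * cos θ)."""
    z = [s * c for c in cos_theta]
    z_max = max(z)
    exps = [math.exp(v - z_max) for v in z]
    total = sum(exps)
    probs = [e / total for e in exps]
    return max(range(len(probs)), key=probs.__getitem__)


x = [1.0, 2.0, 0.5]
W = [[1.0, 2.0, 0.4], [-1.0, 0.5, 2.0], [0.0, -1.0, 1.0]]
cos_theta = arcface_layer(x, W)                       # label-free forward pass
print(arcface_loss([1.0, 0.0, 0.0], cos_theta, s=4.0))
print(arcface_predict(cos_theta))
```

Swapping ArcFace for an ordinary classifier then only means changing the loss and metric, not the layer wiring.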

An implementation example is shown below.

arcface.py


import math
import tensorflow as tf

#First half of ArcFace: outputs cos(θ) for every class, no label needed
class ArcFaceLayer0(tf.keras.layers.Layer) :
    def __init__(self, num_outputs, kernel_regularizer = None, **kwargs) :
        super(ArcFaceLayer0, self).__init__(**kwargs)
        self.num_outputs = num_outputs
        self.kernel_regularizer = kernel_regularizer

    def build(self, input_shape) :
        weight_shape = (input_shape[-1] , self.num_outputs)
        self.kernel = self.add_weight(
            name='kernel',
            shape = weight_shape,
            initializer = tf.keras.initializers.TruncatedNormal(),
            regularizer = self.kernel_regularizer,
            trainable = True
            )
        super(ArcFaceLayer0, self).build(input_shape)

    def call(self, inputs) :
        n_input = tf.math.l2_normalize(inputs, axis = 1)              #L2-normalize each input vector x
        n_kernel = tf.math.l2_normalize(self.kernel, axis = 0)        #L2-normalize each class weight column w_j
        return tf.matmul(n_input, n_kernel)      #cos(θ_j) = x·w_j for every class j

#Second half of ArcFace, implemented on the loss function side
class ArcFaceLoss(tf.keras.losses.Loss) :
    # m: angular margin (radians)
    # s: scale factor
    # loss_func: the original loss function, e.g. tf.keras.losses.CategoricalCrossentropy(from_logits = True)
    def __init__(self, loss_func, m = 0.5, s = 30, name = "arcface_loss", **kwargs) :
        self.loss_func = loss_func
        self.margin = m
        self.s = s
        self.cos_m = math.cos(m)     #cos(m), precomputed for the addition theorem
        self.sin_m = math.sin(m)     #sin(m), precomputed for the addition theorem
        self.enable = True
        super(ArcFaceLoss, self).__init__(name = name, **kwargs)

    def call(self, y_true, y_pred):
        # y_pred is cos(θ); y_true is a one-hot label
        y_true = tf.cast(y_true, y_pred.dtype)
        #sin(θ) for the addition theorem (clipped so sqrt never sees 0 or a negative value)
        sine = tf.sqrt(tf.clip_by_value(1.0 - tf.square(y_pred), 1e-7, 1.0))
        phi = y_pred * self.cos_m - sine * self.sin_m       # cos(θ+m) by the addition theorem
        phi = tf.where(y_pred > 0, phi, y_pred)             #Keep cos(θ) unchanged when θ >= 90° (pointing the wrong way)

        #Correct class: cos(θ+m), other classes: cos(θ)
        logits = (y_true * phi) + ((1.0 - y_true) * y_pred)

        #Call the original loss function on the scaled logits
        return self.loss_func(y_true, logits * self.s)

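#Evaluation function for ArcFace: softmax(s·cos(θ)) without the margin, then metrics_func (e.g. tf.keras.metrics.categorical_accuracy)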
class ArcFaceAccuracy(tf.keras.metrics.Mean) :
    def __init__(self, metrics_func, s = 30, name = "arcface_accuracy", dtype = None) :
        self.metrics_func = metrics_func
        self.s = s
        super(ArcFaceAccuracy, self).__init__(name, dtype)

    def update_state(self, y_true, y_pred, sample_weight = None) :
        output = tf.nn.softmax(y_pred * self.s)
        matches = self.metrics_func(y_true, output)

        return super(ArcFaceAccuracy, self).update_state(matches, sample_weight = sample_weight) 
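The `tf.where(y_pred > 0, phi, y_pred)` line in `ArcFaceLoss` exists because cos(θ + m) is not monotonic over the whole range: once θ exceeds π - m, θ + m passes π and cos(θ + m) starts to rise again, so a sample further from its class would get a higher target logit. The plain-Python sketch below (hypothetical helper names, margin m = 0.5 as in `ArcFaceLoss`) compares the unconditional margin with the fallback used in this post, which simply drops the margin whenever cos(θ) ≤ 0.

```python
import math

M = 0.5  # margin, same default as ArcFaceLoss


def target_logit_no_fallback(theta, m=M):
    """cos(θ + m) applied unconditionally."""
    return math.cos(theta + m)


def target_logit_with_fallback(theta, m=M):
    """Mirrors tf.where(y_pred > 0, phi, y_pred): use cos(θ + m) only
    while cos(θ) > 0 (θ < 90°), otherwise keep plain cos(θ)."""
    c = math.cos(theta)
    return math.cos(theta + m) if c > 0 else c


# Past θ = π - m (about 2.64 rad), the unconditional version increases again.
for theta in [2.0, 2.5, 2.8, 3.0, math.pi]:
    print(f"θ={theta:.2f}  cos(θ+m)={target_logit_no_fallback(theta):+.4f}"
          f"  with fallback={target_logit_with_fallback(theta):+.4f}")
```

Note that this fallback trades the non-monotonic tail for a jump at θ = 90°, where the target logit goes from cos(90° + m) = -sin(m) to about 0; it is a simplification rather than the only possible choice.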
