I implemented ArcFace for TensorFlow 2.x as a combination of a custom layer and a custom loss function.
Among deep metric learning methods, ArcFace is simple and clear: it can be used just by adding it to the output layer of a classification network. (Reference: Modern deep metric learning method: SphereFace, CosFace, ArcFace)

Implementing it in TensorFlow 2.x (Keras) takes some care. An ordinary fully connected layer computes its output from only its input and its weights, whereas ArcFace also needs the correct label.

One existing Keras implementation is "[Keras] MobileNetV2 + I tried to classify PET bottles using ArcFace!". There, ArcFace is implemented as a two-input custom layer, and the correct label is routed into that layer from the generator of the training data set. I tried that approach first, but the network structure changes whenever ArcFace is swapped for an ordinary classification head. I found that cumbersome, so I looked for another way to implement it.
The rough calculation procedure of ArcFace is shown below.
Here, cos(θ + m) = cos(θ)·cos(m) − sin(θ)·sin(m) (addition theorem) and sin(θ) = √(1 − cos(θ)^2), so cos(θ + m) can be obtained from cos(θ) alone. In practice, step 2 can therefore be replaced by computing cos(θ) for every label, the correct one included:
2'. Calculate cos(θ) = X · W (inner product) for all labels
This removes the need for the correct label up to step 2', so everything up to that point can be implemented as a custom layer. Step 3 and the subsequent steps need neither the input nor the weights, so they can be implemented in the loss function. Note, however, that softmax then has to be computed separately when calculating accuracy.
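The split described above can be illustrated without any framework. The following is only a minimal sketch for a single sample, using the Python standard library; the function names `cosines` and `arcface_logits` are hypothetical and are not part of the implementation shown later. It also leaves out the guard that the implementation applies when cos(θ) is not positive.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]


def cosines(x, weights):
    """Step 2': cos(theta) for every class. No label is needed here.

    `weights` holds one weight vector per class.
    """
    xn = l2_normalize(x)
    return [sum(a * b for a, b in zip(xn, l2_normalize(w))) for w in weights]


def arcface_logits(cos_all, label, m=0.5, s=30.0):
    """Step 3 onward: needs the label, but neither the input nor the weights."""
    logits = []
    for k, c in enumerate(cos_all):
        if k == label:
            # cos(theta + m) from cos(theta) alone, via the addition theorem
            sin_theta = math.sqrt(max(0.0, 1.0 - c * c))
            c = c * math.cos(m) - sin_theta * math.sin(m)
        logits.append(s * c)
    return logits


cos_all = cosines([1.0, 2.0, 0.5], [[1.0, 2.0, 0.4], [-1.0, 0.3, 2.0]])
print(arcface_logits(cos_all, label=0))
```

Only `arcface_logits` takes the label, which is why that part can live in the loss function while `cosines` corresponds to the layer.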
An implementation example is shown below.
arcface.py
import tensorflow as tf

# First half of ArcFace: outputs cos(θ) for every class
class ArcFaceLayer0(tf.keras.layers.Layer):
    def __init__(self, num_outputs, kernel_regularizer=None, **kwargs):
        super(ArcFaceLayer0, self).__init__(**kwargs)
        self.num_outputs = num_outputs
        self.kernel_regularizer = kernel_regularizer

    def build(self, input_shape):
        weight_shape = (input_shape[-1], self.num_outputs)
        self.kernel = self.add_weight(
            name='kernel',
            shape=weight_shape,
            initializer=tf.keras.initializers.TruncatedNormal(),
            regularizer=self.kernel_regularizer,
            trainable=True
        )
        super(ArcFaceLayer0, self).build(input_shape)

    def call(self, inputs):
        n_input = tf.math.l2_normalize(inputs, axis=1)  # L2-normalize each input vector
        n_kernel = tf.math.l2_normalize(self.kernel, axis=0)  # L2-normalize each class weight vector
        return tf.matmul(n_input, n_kernel)  # inner product of the normalized vectors = cos(θ)

# Second half of ArcFace, implemented on the loss function side
class ArcFaceLoss(tf.keras.losses.Loss):
    # m: angular margin (radians)
    # s: scale factor applied to the logits
    # loss_func: the underlying loss function, e.g.
    #            tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    def __init__(self, loss_func, m=0.5, s=30, name="arcface_loss", **kwargs):
        self.loss_func = loss_func
        self.margin = m
        self.s = s
        self.enable = True
        # cos(m) and sin(m) are constants, so compute them once
        self.cos_m = tf.math.cos(float(m))
        self.sin_m = tf.math.sin(float(m))
        super(ArcFaceLoss, self).__init__(name=name, **kwargs)

    def call(self, y_true, y_pred):
        # y_pred is cos(θ); y_true is a one-hot label
        # Compute sin(θ) for the addition theorem; clamp at 0 so that
        # rounding error cannot make the sqrt argument negative
        sine = tf.sqrt(tf.maximum(1.0 - tf.square(y_pred), 0.0))
        phi = y_pred * self.cos_m - sine * self.sin_m  # cos(θ+m) by the addition theorem
        phi = tf.where(y_pred > 0, phi, y_pred)  # leave cos(θ) unchanged when θ is 90° or more
        # Correct class: cos(θ+m); other classes: cos(θ)
        logits = (y_true * phi) + ((1.0 - y_true) * y_pred)
        # Call the underlying loss function on the scaled logits
        return self.loss_func(y_true, logits * self.s)

# Evaluation metric for ArcFace
class ArcFaceAccuracy(tf.keras.metrics.Mean):
    # metrics_func: the underlying metric function applied after softmax
    # s: scale factor; should match the one given to ArcFaceLoss
    def __init__(self, metrics_func, s=30, name="arcface_accuracy", dtype=None):
        super(ArcFaceAccuracy, self).__init__(name, dtype)
        self.metrics_func = metrics_func
        self.s = s

    def update_state(self, y_true, y_pred, sample_weight=None):
        # y_pred is cos(θ): scale it and apply softmax to get class probabilities
        output = tf.nn.softmax(y_pred * self.s)
        matches = self.metrics_func(y_true, output)
        return super(ArcFaceAccuracy, self).update_state(matches, sample_weight=sample_weight)
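
What the evaluation side does can be checked by hand. The sketch below is a framework-free illustration using only the standard library; `softmax`, `predict`, and `accuracy` are hypothetical helper names, and the cosine values are made up for the example. It assumes, as in the metric above, that no margin is applied at evaluation time: the cosines are scaled by s, passed through softmax, and the most probable class is compared with the label.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of logits."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def predict(cos_all, s=30.0):
    """Turn one sample's cos(theta) values into probabilities and a class index."""
    probs = softmax([s * c for c in cos_all])
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, probs


def accuracy(batch_cos, labels, s=30.0):
    """Fraction of samples whose most probable class equals the label."""
    hits = [predict(cos_all, s)[0] == y for cos_all, y in zip(batch_cos, labels)]
    return sum(hits) / len(hits)


# Made-up layer outputs for three samples and three classes
batch_cos = [[0.9, 0.1, -0.3], [0.2, 0.7, 0.1], [0.1, 0.0, 0.6]]
print(accuracy(batch_cos, labels=[0, 1, 0]))
```

Because scaling by a positive s and applying softmax both preserve the ordering of the values, the predicted class is simply the one with the largest cos(θ); the softmax step matters for metrics that expect probabilities.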