[PYTHON] What to do if you get a "must override `get_config`" error when calling model.save in Keras


In this article, I would like to explain what to do when you call model.save (or model.to_json) in Keras and get the error "XX has arguments in `__init__` and therefore must override `get_config`".

Background and cause

Keras provides many predefined layers, such as Dense and Conv layers, and a basic model is designed by combining them. For more advanced models, however, you may need to implement your own custom layers and add them to your model. For example, if you want to use a mechanism published in a recent paper, it usually does not exist among Keras's predefined layers, so you have to copy it from GitHub or implement it yourself. (If you are curious about implementing custom layers, check out the official example here: https://github.com/keras-team/keras/blob/master/examples/antirectifier.py.) Beginners may also unknowingly use a model that contains custom layers when forking a script published as a Kaggle kernel, for example. (That is how I ran into this error myself.)

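The error "XX has arguments in `__init__` and therefore must override `get_config`" (where XX is the name of the custom layer) appears when a **model containing a custom layer** has not been properly prepared for saving. When model.save writes an HDF5 file, or when model.to_json is called, Keras records the configuration of every layer by calling its get_config(), and it does not know how to describe a layer it has never seen. In effect, Keras is complaining, "I don't know such a layer."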
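To see where the error comes from, here is a minimal sketch that calls get_config() directly on a custom layer that takes an `__init__` argument but does not override get_config(). `ScaleNoConfig` and `try_get_config` are hypothetical names used only for illustration. The outcome depends on the version: older tf.keras releases (such as the TF 2.2 used in the kernel below) raise NotImplementedError here, while newer Keras versions may generate the config automatically when the arguments are simple values, in which case a dict comes back instead.

```python
import tensorflow as tf


class ScaleNoConfig(tf.keras.layers.Layer):
    """Hypothetical custom layer: takes an __init__ argument, no get_config()."""

    def __init__(self, factor, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor


def try_get_config(layer):
    """Return the layer's config, or the error message if Keras cannot build one."""
    try:
        return layer.get_config()
    except NotImplementedError as err:
        return str(err)


# model.save (HDF5) and model.to_json call get_config() on each layer,
# so this is the call that fails inside them.
result = try_get_config(ScaleNoConfig(factor=2.0))
print(result)
```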


This can be solved by overriding get_config() in the custom layer class. More specifically, define get_config() so that it **puts the arguments of the custom layer's `__init__` into a dictionary, merges it with the config of the parent class, and returns the result**. The `__init__` arguments act as the blueprint of the custom layer, so defining get_config() amounts to **explicitly telling Keras how to rebuild the layer you made**.
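A minimal sketch of this pattern, assuming a hypothetical layer `Scale` with a single `__init__` argument `factor`. It also shows why every argument has to be in the config: the default from_config() essentially calls the class with the config as keyword arguments, so the dictionary returned by get_config() is what Keras uses to rebuild the layer.

```python
import tensorflow as tf


class Scale(tf.keras.layers.Layer):
    """Hypothetical custom layer that multiplies its input by a constant factor."""

    def __init__(self, factor=1.0, **kwargs):
        super().__init__(**kwargs)  # forwards name, dtype, trainable, ...
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        # Start from the parent config (name, trainable, dtype, ...) and add
        # every __init__ argument needed to rebuild this layer.
        base_config = super().get_config()
        return {**base_config, "factor": self.factor}


layer = Scale(factor=2.0, name="scale")
config = layer.get_config()
# from_config() rebuilds the layer from the dictionary alone.
clone = Scale.from_config(config)
print(config["factor"], clone.name)  # 2.0 scale
```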

Note that when you load a model saved this way, you must also name its custom layers explicitly in the custom_objects argument. This is very simple:

from tensorflow.keras.models import load_model

load_model('my_model.h5', custom_objects={'NameOfCustomLayer': NameOfCustomLayer})
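A sketch of the whole round trip, assuming a hypothetical `Scale` layer that already overrides get_config() and a throwaway HDF5 file in a temporary directory. The model is saved with model.save, loaded back with custom_objects, and both models should then give the same predictions. Newer Keras versions treat HDF5 as a legacy format and may print a warning, but they still accept it.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class Scale(tf.keras.layers.Layer):
    """Hypothetical custom layer that multiplies its input by `factor`."""

    def __init__(self, factor=1.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        return {**super().get_config(), "factor": self.factor}


# A small functional model that contains the custom layer.
inputs = tf.keras.Input(shape=(3,))
hidden = tf.keras.layers.Dense(4)(inputs)
outputs = Scale(factor=2.0, name="scale")(hidden)
model = tf.keras.Model(inputs, outputs)

path = os.path.join(tempfile.mkdtemp(), "my_model.h5")
model.save(path)

# custom_objects maps the class name stored in the file back to the class.
restored = tf.keras.models.load_model(path, custom_objects={"Scale": Scale})

x = np.ones((2, 3), dtype="float32")
print(np.allclose(model.predict(x, verbose=0), restored.predict(x, verbose=0)))  # True
```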

Concrete example

Let's take a Kaggle public kernel, "[GLRec] ResNet50 ArcFace (TF2.2)", as an example.

In this script, the model itself is defined as shown below. The backbone, ResNet50, is a model predefined in Keras. (Its weights can be loaded from a locally saved file, as in this kernel, or obtained through the Keras package.) The pooling, dropout and dense layers are also predefined.

Only the margin layer is instantiated from a class that the script defines itself. This is the custom layer of this model.


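# Part of the kernel's script: ArcMarginProduct is the custom layer class
# shown further below.
import tensorflow as tf


def create_model(input_shape,
                 n_classes,
                 dense_units,
                 dropout_rate):
    # The signature is cut off in the source; these are the arguments
    # that the body below relies on.

    backbone = tf.keras.applications.ResNet50(
        include_top=False,  # keep the feature map so the custom head can follow
        input_shape=input_shape,
        # The kernel loads local ImageNet weights from
        # '../input/imagenet-weights/' + ... (the file name is cut off in the source);
        # weights='imagenet' fetches ImageNet weights through Keras instead.
        weights='imagenet')

    pooling = tf.keras.layers.GlobalAveragePooling2D(name='head/pooling')
    dropout = tf.keras.layers.Dropout(dropout_rate, name='head/dropout')
    dense = tf.keras.layers.Dense(dense_units, name='head/dense')

    # The custom layer. Its remaining arguments (s, m, ...) are cut off in
    # the source, so the defaults of ArcMarginProduct are used here.
    margin = ArcMarginProduct(n_classes=n_classes)

    softmax = tf.keras.layers.Softmax(dtype='float32')

    image = tf.keras.layers.Input(input_shape, name='input/image')
    label = tf.keras.layers.Input((), name='input/label')

    x = backbone(image)
    x = pooling(x)
    x = dropout(x)
    x = dense(x)
    x = margin([x, label])  # the margin layer also takes the labels as input
    x = softmax(x)
    return tf.keras.Model(
        inputs=[image, label], outputs=x)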
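If you forked someone else's script, as in the Kaggle case mentioned earlier, it may not be obvious which layers are custom. The following is a heuristic sketch, not a Keras API: the hypothetical helper `find_custom_layers` walks `model.layers` (recursing into nested models such as a backbone) and reports every layer whose class is defined outside the keras, tensorflow or tf_keras packages. Those are the classes that need a get_config() override and an entry in custom_objects. `Double` is a made-up layer for the usage example.

```python
import tensorflow as tf

# Assumption: built-in layers live in one of these top-level packages.
BUILTIN_PACKAGES = {"keras", "tensorflow", "tf_keras"}


def find_custom_layers(model):
    """Return (layer name, class name) pairs for layers defined outside Keras."""
    found = []
    for layer in model.layers:
        top_package = type(layer).__module__.split(".")[0]
        if top_package not in BUILTIN_PACKAGES:
            found.append((layer.name, type(layer).__name__))
        # Nested models (e.g. a ResNet50 backbone) have layers of their own.
        if isinstance(layer, tf.keras.Model):
            found.extend(find_custom_layers(layer))
    return found


class Double(tf.keras.layers.Layer):
    """Hypothetical custom layer used only for this example."""

    def call(self, inputs):
        return inputs * 2.0


inputs = tf.keras.Input(shape=(4,))
x = tf.keras.layers.Dense(8, name="dense")(inputs)
outputs = Double(name="double")(x)
model = tf.keras.Model(inputs, outputs)

print(find_custom_layers(model))  # [('double', 'Double')]
```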

Looking at the class of the margin layer, ArcMarginProduct, you can see that it is a custom layer that inherits from tf.keras.layers.Layer. (Incidentally, the technique it implements is called ArcFace.)

In this kernel, get_config() is not defined in the class, so trying to save the model as-is raises the error quoted at the beginning of this article. That is exactly what happened to me when I used this layer without overriding get_config() and called model.save.


import math

import tensorflow as tf


class ArcMarginProduct(tf.keras.layers.Layer):
    '''
    Implements large margin arc distance.
    '''
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):

        super(ArcMarginProduct, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        self.th = tf.math.cos(math.pi - m)
        self.mm = tf.math.sin(math.pi - m) * m

    def build(self, input_shape):
        # input_shape is a list: [features shape, labels shape]
        super(ArcMarginProduct, self).build(input_shape[0])

        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            trainable=True)

    def call(self, inputs):
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        # Cosine similarity between normalized features and class weights
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
        # cos(theta + m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            # Label smoothing
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        # Apply the margin to the target class only, then scale
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

Therefore, you need to make the following change: override get_config() so that it returns the `__init__` arguments merged with the config of the parent class.


import math

import tensorflow as tf


class ArcMarginProduct(tf.keras.layers.Layer):
    '''
    Implements large margin arc distance.
    '''
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):

        super(ArcMarginProduct, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        self.th = tf.math.cos(math.pi - m)
        self.mm = tf.math.sin(math.pi - m) * m

    ### Start of added code
    def get_config(self):
        # The __init__ arguments needed to rebuild this layer...
        config = {
            "n_classes": self.n_classes,
            "s": self.s,
            "m": self.m,
            "easy_margin": self.easy_margin,
            "ls_eps": self.ls_eps,
        }
        # ...merged with the parent config (name, trainable, dtype, ...)
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
    ### End of added code

    def build(self, input_shape):
        # input_shape is a list: [features shape, labels shape]
        super(ArcMarginProduct, self).build(input_shape[0])

        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            trainable=True)

    def call(self, inputs):
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        # Cosine similarity between normalized features and class weights
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
        # cos(theta + m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            # Label smoothing
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        # Apply the margin to the target class only, then scale
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output
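Forgetting one argument in get_config() is easy, and the mistake only shows up when the model is loaded. The hypothetical helper below is a heuristic check, assuming that each `__init__` parameter is stored in the config under its own name (as ArcMarginProduct does): it compares the parameter names of the layer's `__init__` with the keys returned by get_config(), skipping `**kwargs`, which is forwarded to the parent class. `Partial` is a made-up layer that omits `s` on purpose.

```python
import inspect

import tensorflow as tf


def missing_config_args(layer):
    """Return the __init__ parameters of `layer` that get_config() does not store.

    Heuristic: assumes each parameter is saved under its own name. *args and
    **kwargs are skipped because they are forwarded to the parent Layer.
    """
    config = layer.get_config()
    params = inspect.signature(type(layer).__init__).parameters
    return [
        name
        for name, param in params.items()
        if name != "self"
        and param.kind not in (inspect.Parameter.VAR_POSITIONAL,
                               inspect.Parameter.VAR_KEYWORD)
        and name not in config
    ]


class Partial(tf.keras.layers.Layer):
    """Hypothetical layer whose get_config() forgets the `s` argument."""

    def __init__(self, n_classes, s=30.0, **kwargs):
        super().__init__(**kwargs)
        self.n_classes = n_classes
        self.s = s

    def call(self, inputs):
        return inputs * self.s

    def get_config(self):
        base_config = super().get_config()
        return {**base_config, "n_classes": self.n_classes}  # "s" is missing


print(missing_config_args(Partial(n_classes=10)))  # ['s']
```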

Then load the model as follows:


loaded_model = tf.keras.models.load_model("path_to_model", custom_objects={"ArcMarginProduct": ArcMarginProduct})
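If you load models in several places, an alternative to passing custom_objects to every call is `tf.keras.utils.custom_object_scope`, which makes the given names resolvable for all deserialization inside a `with` block. The sketch below uses model.to_json and `tf.keras.models.model_from_json`, which restore only the architecture (the weights of the restored model are freshly initialized); `Scale` is a hypothetical layer with its own get_config().

```python
import tensorflow as tf


class Scale(tf.keras.layers.Layer):
    """Hypothetical custom layer that multiplies its input by `factor`."""

    def __init__(self, factor=1.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        return {**super().get_config(), "factor": self.factor}


inputs = tf.keras.Input(shape=(3,))
model = tf.keras.Model(inputs, Scale(factor=3.0, name="scale")(inputs))
json_config = model.to_json()  # architecture only, no weights

# Inside the scope, the name "Scale" resolves to the class, so
# custom_objects does not have to be repeated for each call.
with tf.keras.utils.custom_object_scope({"Scale": Scale}):
    restored = tf.keras.models.model_from_json(json_config)

print(restored.get_layer("scale").factor)  # 3.0
```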


