[PYTHON] Was tun, wenn beim Versuch, model.save mit Keras zu modellieren, der Fehler "get_config" überschrieben werden muss?

Einführung

In diesem Artikel möchte ich vorstellen, was zu tun ist, wenn Sie versuchen, model.save (oder model.to_json) mit Keras zu erstellen und `XX hat Argumente in` __init __` und muss daher` get_config `` überschreiben. ..

Hintergrund und Ursache

In Keras gibt es viele vordefinierte Ebenen wie die dichte Ebene und die Conv-Ebene, die kombiniert werden, um ein Grundmodell zu entwerfen. Für Fortgeschrittene müssen Sie jedoch Ihre eigenen benutzerdefinierten Ebenen implementieren und Ihrem Modell hinzufügen. Wenn Sie beispielsweise den in der neuesten Veröffentlichung veröffentlichten Mechanismus nutzen möchten, ist er in der vordefinierten Keras-Ebene nicht vorhanden, und Sie müssen ihn aus Github zitieren oder selbst implementieren. (Wenn Sie an der Implementierung einer benutzerdefinierten Ebene interessiert sind, lesen Sie dieses offizielle Beispiel (https://github.com/keras-team/keras/blob/master/examples/antirectifier.py).) Alternativ können Anfänger unwissentlich ein Modell verwenden, das eine benutzerdefinierte Ebene enthält, wenn sie ein im Kernel von kaggle veröffentlichtes Skript abspalten. (Ich selbst bin diesem Fehler auf diese Weise begegnet.)

Übrigens hat der Fehler `XX (Name der benutzerdefinierten Ebene) Argumente in` __init__` und muss daher überschrieben werden` get_config `` kann für das Modell einschließlich dieser ** benutzerdefinierten Ebene ** nicht korrekt behandelt werden Zu dieser Zeit ärgerte mich Keras: "Ich kenne eine solche Schicht nicht."

Lösungen

Dies kann durch Überschreiben von get_config () in der benutzerdefinierten Ebenenklasse behoben werden. Insbesondere ** machen Sie das Argument `__init__``` der benutzerdefinierten Layer-Klasse in ein Wörterbuch, fügen Sie es der Konfiguration der übergeordneten Klasse hinzu und geben Sie ** usw. zurück.` `Get_config ()` Definieren Dies bedeutet, dass das Argument von `` __init __``` dem Designdokument dieser benutzerdefinierten Ebene ähnelt, daher habe ich es willkürlich erstellt. ** Ich bringe Keras ausdrücklich bei, wie die benutzerdefinierte Ebene funktioniert. * * Gleichwertig.

Übrigens muss das auf diese Weise gespeicherte Modell beim Laden auch explizit die benutzerdefinierte Ebene im Argument custom_objects angeben. Die Methode ist sehr einfach, gehen Sie wie folgt vor:

load_model('my_model.h5', custom_objects={'NameOfCustomLayer': NameOfCustomLayer})

Konkretes Beispiel

Nehmen wir als Beispiel den öffentlichen Kernel von Kaggle. [GLRec] ResNet50 ArcFace (TF2.2)

In diesem Skript wird die eigentliche Modelldefinition unten ausgeführt. Das Backbone-Modell ist in Keras in ResNet 50 vordefiniert. (Das Gewicht kann nicht nur mit dem lokal gespeicherten wie diesem ermittelt werden, sondern auch mit dem Keras-Paket.) Die Pooling- und Dropout-Ebenen sind ebenfalls vordefiniert.

Hier sehen Sie, dass nur die Randebene eindeutig instanziiert wird. Dies ist die benutzerdefinierte Ebene für dieses Modell.

create_model.py



def create_model(input_shape,
                 n_classes,
                 dense_units=512,
                 dropout_rate=0.0,
                 scale=30,
                 margin=0.3):

    backbone = tf.keras.applications.ResNet50(
        include_top=False,
        input_shape=input_shape,
        weights=('../input/imagenet-weights/' +
                 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
    )

    pooling = tf.keras.layers.GlobalAveragePooling2D(name='head/pooling')
    dropout = tf.keras.layers.Dropout(dropout_rate, name='head/dropout')
    dense = tf.keras.layers.Dense(dense_units, name='head/dense')

    margin = ArcMarginProduct(
        n_classes=n_classes,
        s=scale,
        m=margin,
        name='head/arc_margin',
        dtype='float32')

    softmax = tf.keras.layers.Softmax(dtype='float32')

    image = tf.keras.layers.Input(input_shape, name='input/image')
    label = tf.keras.layers.Input((), name='input/label')

    x = backbone(image)
    x = pooling(x)
    x = dropout(x)
    x = dense(x)
    x = margin([x, label])
    x = softmax(x)
    return tf.keras.Model(
        inputs=[image, label], outputs=x)

Überprüfen Sie die Randebenenklasse ArcMarginProduct. Dann können Sie sehen, dass es sich um eine benutzerdefinierte Ebene handelt, die tf.keras.layers.Layer erbt. (Die implementierte Technologie heißt übrigens ArcFace.)

In dieser von mir definierten benutzerdefinierten Ebene trat der Fehler am Anfang auf, als ich model.save ausführte, als ich get_config () nicht richtig überschrieb.

In diesem Kernel ist `get_config ()` in der Klasse nicht definiert. Wenn Sie also versuchen, so zu speichern, wie es ist, tritt ein Fehler auf.

custom_layer.py


class ArcMarginProduct(tf.keras.layers.Layer):
    '''
    Implements large margin arc distance.

    Reference:
        https://arxiv.org/pdf/1801.07698.pdf
        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
            blob/master/src/modeling/metric_learning.py
    '''
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):

        super(ArcMarginProduct, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        self.th = tf.math.cos(math.pi - m)
        self.mm = tf.math.sin(math.pi - m) * m

    def build(self, input_shape):
        super(ArcMarginProduct, self).build(input_shape[0])

        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            dtype='float32',
            trainable=True,
            regularizer=None)

    def call(self, inputs):
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

Daher müssen Sie die folgenden Änderungen vornehmen.

Insbesondere überschreibt es `get_config ()` und gibt das Argument `__init__ `und die Konfiguration der übergeordneten Klasse zurück.

new_custom_layer.py



class ArcMarginProduct(tf.keras.layers.Layer):
    '''
    Implements large margin arc distance.

    Reference:
        https://arxiv.org/pdf/1801.07698.pdf
        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
            blob/master/src/modeling/metric_learning.py
    '''
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):

        super(ArcMarginProduct, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        self.th = tf.math.cos(math.pi - m)
        self.mm = tf.math.sin(math.pi - m) * m


###Start Hinzugefügter Code
    def get_config(self):
        config = {
            "n_classes" : self.n_classes,
            "s" : self.s,
            "m" : self.m,
            "easy_margin" : self.easy_margin,
            "ls_eps" : self.ls_eps
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

###  End       
        
    def build(self, input_shape):
        super(ArcMarginProduct, self).build(input_shape[0])

        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            dtype='float32',
            trainable=True,
            regularizer=None)

    def call(self, inputs):
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

Und das Modell sollte wie folgt geladen werden.

load_model.py



loaded_model =keras.models.load_model("path_to_model", custom_objects = {"ArcMarginProduct": ArcMarginProduct})

Referenz

So erstellen Sie eine benutzerdefinierte Ebene mit Keras Benutzerdefinierte Ebenen mit Keras serialisieren NotImplementedError: Layers with arguments in __init__ must override get_config

Recommended Posts

Was tun, wenn beim Versuch, model.save mit Keras zu modellieren, der Fehler "get_config" überschrieben werden muss?
Was tun, wenn beim Versuch, eine Nachricht in task.loop () unmittelbar nach dem Start zu senden, eine Fehlermeldung angezeigt wird?
Was tun, wenn beim Laden von mnist eine Fehlermeldung angezeigt wird?
Was tun, wenn auf pipenv der Fehler "Keine Versionen gefunden" angezeigt wird?
Was tun, wenn beim Konvertieren von PySparkDataFrame in PandasDataFrame ein Speicherfehler auftritt?
Was tun, wenn beim Importieren von matplotlib in Python (Mac) eine Fehlermeldung angezeigt wird?
Was tun, wenn bei yum ein Metalink für Repository-Fehler nicht abgerufen werden kann?
Was tun, wenn beim Ausführen von "certbot erneuern" in der CakePHP-Umgebung eine Fehlermeldung angezeigt wird?
Was tun, wenn beim Versuch, pip mit pyenv zu verwenden, ein undefinierter Fehler angezeigt wird?
Was tun, wenn der Fehler RuntimeError angezeigt wird: Python wird nicht als Framework installiert, wenn Sie versuchen, matplitlib und pylab in Python 3.3 zu verwenden
Was tun, wenn bei der Installation von Python mit pyenv eine Fehlermeldung angezeigt wird?
[Python] Dinge, die überprüft werden müssen, wenn in Django ein Unicode-Dekodierungsfehler auftritt
Was tun, wenn Overalls "Abdeckung unbekannt" werden?
Was tun, wenn in tf.train.start_queue_runners () ein 0xC0000005-Fehler auftritt?
Was tun, wenn bei der Installation von Python 2 mit pyenv ein OpenSSL-Fehler auftritt?
Was tun, wenn in pycurl (einer von ihnen) "(35, 'SSL-Verbindungsfehler')" angezeigt wird?
Was tun, wenn beim Importieren von matplotlib mit Jupyter ein Importfehler auftritt?
Was tun, wenn bei Verwendung von ts-node-dev unter Linux der Fehler "ERR_FEATURE_UNAVAILABLE_ON_PLATFORM" angezeigt wird?
Was tun, wenn bei der Pip-Installation ein Unicode-Dekodierungsfehler auftritt?
Was tun, wenn Swagger-Codegen mit Python und Importfehler ausgeführt wird? Es wird kein Modul mit dem Namen angezeigt
Was tun, wenn bei Do and Return in einem Golang-Test ein Fehler mit zu vielen Eingabeargumenten auftritt?
Was tun, wenn Sie sich mit FileNotFoundError in der Dateireferenz verlieren?
Was tun, wenn Sie sich über TensorFlow v2 ohne Attribut 'app' ärgern?
Was tun, wenn die Fehlermeldung angezeigt wird, dass der c-Compiler in configure keine ausführbaren Dateien erstellen kann?
Was tun, wenn TypeError in min und max von numpy auftritt?
Was tun, wenn Sie "locale.Error: nicht unterstützte Gebietsschemaeinstellung" erhalten, wenn Sie den Tag vom Datum in Python abrufen?
Was tun, wenn Sie sich über "Wertefehler: unbekannt lokal: UTF-8" in python manage.py syncdb ärgern?
Was tun, wenn beim Importieren von Lebensläufen ein symbolischer Linkfehler auftritt, während versucht wird, OpenCV in Python zu installieren?
Problemumgehung, wenn beim Versuch, PySide mit pip zu installieren, eine Fehlermeldung angezeigt wird
Was tun, wenn in Sublime Text Python ein Unicode-Codierungsfehler auftritt?
Was tun, wenn "Python nicht konfiguriert" angezeigt wird? Verwenden von PyDev in Eclipse
Was tun, wenn im Selenium Chrome-Treiber ein Versionsfehler auftritt?
Was tun, wenn in pip ein Unicode-Dekodierungsfehler auftritt?
Was tun, wenn Sie beim Erstellen einer virtuellen Umgebung mit virtualenv die Meldung "Importfehler: Name 'HTTPSHandler' kann nicht importiert werden" erhalten
Was tun, wenn ein Fehler "unbekannter Dienst" vom gRPC-Server zurückgegeben wird?
Was tun, wenn in Hydrogen "Kein Kernel für Sprachpython gefunden" angezeigt wird?
Was tun, wenn Sie Python auf IntelliJ ausführen und mit einem Fehler beenden?
Was tun, wenn pip in Homebrew einen DistributionError ausgibt?
Was tun, wenn beim Aktualisieren von conda ein Fehler beim Entfernen auftritt?
Was tun, wenn Sie sich nicht als root anmelden können?
Was tun, wenn bei der Installation von openCV 3 der Fehler "Fehler: opencv3: Unterstützt nicht das Erstellen von Python 2- und 3-Wrappern" angezeigt wird
Was tun, wenn unter Linux die Fehlermeldung "Namensauflösung vorübergehend fehlgeschlagen" ausgegeben wird?
Was tun, wenn beim Aktivieren von public_network oder private_network unter Vagrant + Arch Linux beim Vagrant + Arch Linux → Install netctl eine Fehlermeldung angezeigt wird?
Was tun, wenn Sie den Papierkorb in Lubuntu 18.04 nicht verwenden können?
Was tun, wenn "Ich kann die Site nicht sehen !!!!"
Was tun, wenn Sie wütend werden, wenn Sie bei der Installation von lxml unter CentOS nicht über libxml / xmlversion.h verfügen?
Was tun, wenn pip --user in einer mit pyenv erstellten virtuellen Umgebung einen Fehler zurückgibt?
Was tun, wenn in python json .dumps eine Dezimalstelle enthalten ist?
Was tun, wenn PDO nicht in Laravel oder CakePHP gefunden wird?
Bei Programmierfehler: (1146, "Tabelle '<Tabellenname>' existiert nicht") tritt in Django auf
Was tun, wenn Sie die Rastersuche von sklearn in Python nicht verwenden können?
Was tun, wenn Sie während der Anaconda-Installation unter Linux nicht weiterkommen?
Was tun, wenn beim Importieren von numpy mit VScode ein Fehler auftritt?
Was tun, wenn Sie nicht mit pip in einer Babun-Umgebung installieren können?
Was tun, wenn Sie URL 443 mit pip nicht abrufen konnten?
Was passiert, wenn Sie in Python "A, B als C importieren"?
[OSX] [pyenv] Was tun, wenn in pip ein SSL-Fehler auftritt?
Was tun, wenn eine Warnmeldung in der Pip-Liste angezeigt wird?