In diesem Artikel möchte ich vorstellen, was zu tun ist, wenn Sie versuchen, model.save (oder model.to_json) mit Keras zu erstellen und `XX hat Argumente in` __init __` und muss daher` get_config
`` überschreiben. ..
In Keras gibt es viele vordefinierte Ebenen wie die dichte Ebene und die Conv-Ebene, die kombiniert werden, um ein Grundmodell zu entwerfen. Für Fortgeschrittene müssen Sie jedoch Ihre eigenen benutzerdefinierten Ebenen implementieren und Ihrem Modell hinzufügen. Wenn Sie beispielsweise den in der neuesten Veröffentlichung veröffentlichten Mechanismus nutzen möchten, ist er in der vordefinierten Keras-Ebene nicht vorhanden, und Sie müssen ihn aus Github zitieren oder selbst implementieren. (Wenn Sie an der Implementierung einer benutzerdefinierten Ebene interessiert sind, lesen Sie dieses offizielle Beispiel (https://github.com/keras-team/keras/blob/master/examples/antirectifier.py).) Alternativ können Anfänger unwissentlich ein Modell verwenden, das eine benutzerdefinierte Ebene enthält, wenn sie ein im Kernel von kaggle veröffentlichtes Skript abspalten. (Ich selbst bin diesem Fehler auf diese Weise begegnet.)
Übrigens hat der Fehler `XX (Name der benutzerdefinierten Ebene) Argumente in` __init__` und muss daher überschrieben werden` get_config
`` kann für das Modell einschließlich dieser ** benutzerdefinierten Ebene ** nicht korrekt behandelt werden Zu dieser Zeit ärgerte mich Keras: "Ich kenne eine solche Schicht nicht."
Dies kann durch Überschreiben von get_config () in der benutzerdefinierten Ebenenklasse behoben werden.
Insbesondere ** machen Sie das Argument `__init__``` der benutzerdefinierten Layer-Klasse in ein Wörterbuch, fügen Sie es der Konfiguration der übergeordneten Klasse hinzu und geben Sie ** usw. zurück.` `Get_config ()`
Definieren
Dies bedeutet, dass das Argument von `` __init __``` dem Designdokument dieser benutzerdefinierten Ebene ähnelt, daher habe ich es willkürlich erstellt. ** Ich bringe Keras ausdrücklich bei, wie die benutzerdefinierte Ebene funktioniert. * * Gleichwertig.
Übrigens muss das auf diese Weise gespeicherte Modell beim Laden auch explizit die benutzerdefinierte Ebene im Argument custom_objects angeben. Die Methode ist sehr einfach, gehen Sie wie folgt vor:
load_model('my_model.h5', custom_objects={'NameOfCustomLayer': NameOfCustomLayer})
Nehmen wir als Beispiel den öffentlichen Kernel von Kaggle. [GLRec] ResNet50 ArcFace (TF2.2)
In diesem Skript wird die eigentliche Modelldefinition unten ausgeführt. Das Backbone-Modell ist in Keras in ResNet 50 vordefiniert. (Das Gewicht kann nicht nur mit dem lokal gespeicherten wie diesem ermittelt werden, sondern auch mit dem Keras-Paket.) Die Pooling- und Dropout-Ebenen sind ebenfalls vordefiniert.
Hier sehen Sie, dass nur die Randebene eindeutig instanziiert wird. Dies ist die benutzerdefinierte Ebene für dieses Modell.
create_model.py
def create_model(input_shape,
n_classes,
dense_units=512,
dropout_rate=0.0,
scale=30,
margin=0.3):
backbone = tf.keras.applications.ResNet50(
include_top=False,
input_shape=input_shape,
weights=('../input/imagenet-weights/' +
'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
)
pooling = tf.keras.layers.GlobalAveragePooling2D(name='head/pooling')
dropout = tf.keras.layers.Dropout(dropout_rate, name='head/dropout')
dense = tf.keras.layers.Dense(dense_units, name='head/dense')
margin = ArcMarginProduct(
n_classes=n_classes,
s=scale,
m=margin,
name='head/arc_margin',
dtype='float32')
softmax = tf.keras.layers.Softmax(dtype='float32')
image = tf.keras.layers.Input(input_shape, name='input/image')
label = tf.keras.layers.Input((), name='input/label')
x = backbone(image)
x = pooling(x)
x = dropout(x)
x = dense(x)
x = margin([x, label])
x = softmax(x)
return tf.keras.Model(
inputs=[image, label], outputs=x)
Überprüfen Sie die Randebenenklasse ArcMarginProduct. Dann können Sie sehen, dass es sich um eine benutzerdefinierte Ebene handelt, die tf.keras.layers.Layer erbt. (Die implementierte Technologie heißt übrigens ArcFace.)
In dieser von mir definierten benutzerdefinierten Ebene trat der Fehler am Anfang auf, als ich model.save ausführte, als ich get_config () nicht richtig überschrieb.
In diesem Kernel ist `get_config ()`
in der Klasse nicht definiert. Wenn Sie also versuchen, so zu speichern, wie es ist, tritt ein Fehler auf.
custom_layer.py
class ArcMarginProduct(tf.keras.layers.Layer):
'''
Implements large margin arc distance.
Reference:
https://arxiv.org/pdf/1801.07698.pdf
https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
blob/master/src/modeling/metric_learning.py
'''
def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
ls_eps=0.0, **kwargs):
super(ArcMarginProduct, self).__init__(**kwargs)
self.n_classes = n_classes
self.s = s
self.m = m
self.ls_eps = ls_eps
self.easy_margin = easy_margin
self.cos_m = tf.math.cos(m)
self.sin_m = tf.math.sin(m)
self.th = tf.math.cos(math.pi - m)
self.mm = tf.math.sin(math.pi - m) * m
def build(self, input_shape):
super(ArcMarginProduct, self).build(input_shape[0])
self.W = self.add_weight(
name='W',
shape=(int(input_shape[0][-1]), self.n_classes),
initializer='glorot_uniform',
dtype='float32',
trainable=True,
regularizer=None)
def call(self, inputs):
X, y = inputs
y = tf.cast(y, dtype=tf.int32)
cosine = tf.matmul(
tf.math.l2_normalize(X, axis=1),
tf.math.l2_normalize(self.W, axis=0)
)
sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
phi = cosine * self.cos_m - sine * self.sin_m
if self.easy_margin:
phi = tf.where(cosine > 0, phi, cosine)
else:
phi = tf.where(cosine > self.th, phi, cosine - self.mm)
one_hot = tf.cast(
tf.one_hot(y, depth=self.n_classes),
dtype=cosine.dtype
)
if self.ls_eps > 0:
one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes
output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
output *= self.s
return output
Daher müssen Sie die folgenden Änderungen vornehmen.
Insbesondere überschreibt es `get_config ()`
und gibt das Argument `__init__
`und die Konfiguration der übergeordneten Klasse zurück.
new_custom_layer.py
class ArcMarginProduct(tf.keras.layers.Layer):
'''
Implements large margin arc distance.
Reference:
https://arxiv.org/pdf/1801.07698.pdf
https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
blob/master/src/modeling/metric_learning.py
'''
def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
ls_eps=0.0, **kwargs):
super(ArcMarginProduct, self).__init__(**kwargs)
self.n_classes = n_classes
self.s = s
self.m = m
self.ls_eps = ls_eps
self.easy_margin = easy_margin
self.cos_m = tf.math.cos(m)
self.sin_m = tf.math.sin(m)
self.th = tf.math.cos(math.pi - m)
self.mm = tf.math.sin(math.pi - m) * m
###Start Hinzugefügter Code
def get_config(self):
config = {
"n_classes" : self.n_classes,
"s" : self.s,
"m" : self.m,
"easy_margin" : self.easy_margin,
"ls_eps" : self.ls_eps
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
### End
def build(self, input_shape):
super(ArcMarginProduct, self).build(input_shape[0])
self.W = self.add_weight(
name='W',
shape=(int(input_shape[0][-1]), self.n_classes),
initializer='glorot_uniform',
dtype='float32',
trainable=True,
regularizer=None)
def call(self, inputs):
X, y = inputs
y = tf.cast(y, dtype=tf.int32)
cosine = tf.matmul(
tf.math.l2_normalize(X, axis=1),
tf.math.l2_normalize(self.W, axis=0)
)
sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
phi = cosine * self.cos_m - sine * self.sin_m
if self.easy_margin:
phi = tf.where(cosine > 0, phi, cosine)
else:
phi = tf.where(cosine > self.th, phi, cosine - self.mm)
one_hot = tf.cast(
tf.one_hot(y, depth=self.n_classes),
dtype=cosine.dtype
)
if self.ls_eps > 0:
one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes
output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
output *= self.s
return output
Und das Modell sollte wie folgt geladen werden.
load_model.py
loaded_model =keras.models.load_model("path_to_model", custom_objects = {"ArcMarginProduct": ArcMarginProduct})
So erstellen Sie eine benutzerdefinierte Ebene mit Keras
Benutzerdefinierte Ebenen mit Keras serialisieren
NotImplementedError: Layers with arguments in __init__
must override get_config