[PYTHON] Verhalten bei Container Trainable = False in Keras

Einführung

Oft möchten Sie das Gewicht Ihres Netzwerks mit Keras festlegen und nur eine weitere Ebene lernen. Es ist ein Memo, in dem ich untersucht habe, worauf ich damals achten sollte.

Versions

Überprüfung

Betrachten Sie das folgende Modell. model_normal.png

Angenommen, Sie möchten das Gewicht des "NormalContainer" -Teils hier "aktualisieren", und manchmal möchten Sie es nicht aktualisieren.

Intuitiv scheint es gut zu sein, False auf die Eigenschaft "Container # trainable" zu setzen, aber ich werde versuchen zu sehen, ob es wie beabsichtigt funktioniert.

Code

# coding: utf8

import numpy as np
from keras.engine.topology import Input, Container
from keras.engine.training import Model
from keras.layers.core import Dense
from keras.utils.vis_utils import plot_model



def all_weights(m):
    return [list(w.reshape((-1))) for w in m.get_weights()]


def random_fit(m):
    x1 = np.random.random(10).reshape((5, 2))
    y1 = np.random.random(5).reshape((5, 1))
    m.fit(x1, y1, verbose=False)

np.random.seed(100)

x = in_x = Input((2, ))

# Create 2 Containers shared same wights
x = Dense(1)(x)
x = Dense(1)(x)
fc_all = Container(in_x, x, name="NormalContainer")
fc_all_not_trainable = Container(in_x, x, name="FixedContainer")

# Create 2 Models using the Containers
x = fc_all(in_x)
x = Dense(1)(x)
model_normal = Model(in_x, x)

x = fc_all_not_trainable(in_x)
x = Dense(1)(x)
model_fixed = Model(in_x, x)

# Set one Container trainable=False
fc_all_not_trainable.trainable = False  # Case1

# Compile
model_normal.compile(optimizer="sgd", loss="mse")
model_fixed.compile(optimizer="sgd", loss="mse")

# fc_all_not_trainable.trainable = False  # Case2

# Watch which weights are updated by model.fit
print("Initial Weights")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))

random_fit(model_normal)

print("after training Model-Normal")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))

random_fit(model_fixed)

print("after training Model-Fixed")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))


# plot_model(model_normal, "model_normal.png ", show_shapes=True)

Erstellen Sie zwei "Container", "fc_all" und "fc_all_not_trainable". Letzteres lässt "trainierbar" auf "Falsch". Erstellen Sie damit "Model" mit dem Namen "model_normal" und "model_fixed".

Das erwartete Verhalten ist

Das ist.

Behältergewicht Anderes Gewicht
model_normal#fit() Veränderung Veränderung
model_fixed#fit() Es ändert sich nicht Veränderung

Ausführungsergebnis: Fall1

Initial Weights
Model-Normal: [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [-0.21052945], [0.0]]
Model-Fixed : [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [0.37929809], [0.0]]
after training Model-Normal
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37929809], [0.0]]
after training Model-Fixed
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37869808], [0.0091063408]]

Wie erwartet.

Hinweis: trainable = False muss vor compile () gesetzt werden

Was ist, wenn Sie im obigen Code trainable = False nach Model # compile () (wo Fall 2 ist) setzen?

Ausführungsergebnis: Fall2

Initial Weights
Model-Normal: [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [-0.21052945], [0.0]]
Model-Fixed : [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [0.37929809], [0.0]]
after training Model-Normal
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37929809], [0.0]]
after training Model-Fixed
Model-Normal: [[1.2910744, -0.53420025], [-0.0002913858], [-0.12900624], [0.0022280237], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2910744, -0.53420025], [-0.0002913858], [-0.12900624], [0.0022280237], [0.37869808], [0.0091063408]]

Gleich bis "nach dem Training Model-Normal", Wenn nach dem Training Model-Fixed, ändert sich auch das Gewicht von Container.

Model # compile () ruft beim Aufruf trainable_weights aus allen enthaltenen Layern ab. Wenn Sie zu diesem Zeitpunkt nicht "trainierbar" einstellen, ist dies daher bedeutungslos.

Ein weiterer Punkt ist, dass es nicht notwendig ist, für alle im Container ** enthaltenen Ebenen "trainierbar" festzulegen. "Container" ist eine Ebene, wenn man sie von "Modell" aus betrachtet. Model ruft Container # trainable_weights auf, gibt aber nichts zurück, wenn Container # trainable False ist (entsprechend /keras/engine/topology.py#L1891)), sodass alle in Container enthaltenen Ebenengewichte nicht aktualisiert werden. Es ist ein bisschen unklar, ob dies eine Spezifikation oder nur die Implementierung in dieser Phase ist, aber ich denke, es ist wahrscheinlich beabsichtigt.

schließlich

Der leichte Dunst wurde behoben.

Recommended Posts

Verhalten bei Container Trainable = False in Keras
Verhalten, wenn mehrere Server in Nameservern von dnspython angegeben sind
Verschlüsseln Sie das Bild beim Drücken fälschlicherweise
Verhalten beim Auflisten in Python heapq
Überprüfen Sie das Verhalten des Zerstörers in Python
Verhalten bei der Rückkehr in den with-Block
Verhaltensänderung von [Diagramm / Zeitleiste] in Choregraphe 2.5.5.5
Ich war in Schwierigkeiten, weil sich das Verhalten des Docker-Containers nicht geändert hat
Unterschiede im Verhalten jeder LL-Sprache, wenn der Listenindex übersprungen wird
Fügen Sie Python3 in den Docker-Container von Amazon Linux2 ein
Verhalten beim Speichern eines Python-Datetime-Objekts in MongoDB
Verhalten von numpy.dot beim Übergeben von 1d-Array und 2d-Array
Beachten Sie, wenn Sie lxml des Python-Pakets in Ubuntu 14.04 einfügen