Zum Glück ist es lange her, dass TensorFlow 2.x für die Welt veröffentlicht wurde. Bisher hatte v1.x große Probleme, aber in v2.x wurden viele sehr nette Funktionen implementiert, z. B. die einfache Möglichkeit, mit einem Dekorateur zu einem tf-ähnlichen Diagramm zu wechseln, und wir Entwickler können Modelle auch schnell lernen. , Auswertung etc. können nun durchgeführt werden.
Einige Benutzer müssen jedoch v1.x verwenden, da dies auf Kompatibilitätsprobleme mit v2.x zurückzuführen ist (was ich verbergen soll, bin ich). Nachdem in der aktualisierten Dokumentation 2.x erwähnt wird, ist es schwierig, die 1.x-Dokumentation zu erweitern und die richtigen Informationen zu erhalten. Ich bin jetzt (Stand 17:10 2020/11/4) und es hat lange gedauert, bis ich tatsächlich den Prozess des teilweisen Ladens von Parametern in v1.x erreicht habe. Wenn Sie einfach ResNet oder MobileNet oder ein öffentlich verfügbares Modell verwenden, werden Sie keine großen Probleme haben, da Sie das gesamte Berechnungsdiagramm lesen können. Ich möchte jedoch das zuvor erlernte ResNet als Bildcodierer für das nachfolgende selbst erstellte Netzwerk verwenden. Im Fall von .. müssen die Parameter teilweise gelesen werden. Für diejenigen, die in Zukunft Parameter in v1.x teilweise lesen werden (was ich verbergen soll, bin ich), werde ich aufzeichnen, wie Parameter teilweise gelesen werden, wie im Titel beschrieben.
Grundsätzlich ist dies mit folgendem Code möglich.
...
with tf.Session() as sess:
saver = tf.train.Saver({'Knotenname des Modells, das Sie laden möchten':Tf mit diesem Knotennamen.Variable Variable, ...})
saver.restore(sess, 'path/to/checkpoint')
Es heißt jedoch "Wie kann ich ein Wörterbuch mit Knotennamen und Variablen erstellen, die an tf.train.Saver übergeben werden sollen!". In diesem Fall also
variables = tf.trainable_variables()
restore_variables = {}
for v in variables:
if 'Modell-Namespace' in v.name:
restore_variables[v.name] = v
Auf diese Weise können Sie nur einen bestimmten Knoten aus dem aktuell verwendeten Knoten extrahieren und in das Wörterbuch aufnehmen.
Beim Lernen kann ": 0" ohne Erlaubnis gegeben werden, also in diesem Fall
fixed_name = v.name[:-2]
restore_variables[fixed_name] = v
Wenn Sie wie reagieren, wird es lesbar.
Tensorflow.python.tools.inspect_checkpoint.print_tensors_in_checkpoint_file
ist nützlich, wenn Sie sehen möchten, wie jeder Variablenname in jedem Prüfpunkt gespeichert ist, wenn er vorab gelernt wurde.
import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file(file_name='path/to/checkpoint', tensor_name='', all_tensors=False)
# beta1_power (DT_FLOAT) []
# beta2_power (DT_FLOAT) []
# cae/conv0/convolution2d/biases (DT_FLOAT) [64]
# cae/conv0/convolution2d/biases/Adam (DT_FLOAT) [64]
# cae/conv0/convolution2d/biases/Adam_1 (DT_FLOAT) [64]
# cae/conv0/convolution2d/weights (DT_FLOAT) [7,7,3,64]
# cae/conv0/convolution2d/weights/Adam (DT_FLOAT) [7,7,3,64]
# cae/conv0/convolution2d/weights/Adam_1 (DT_FLOAT) [7,7,3,64]
# cae/conv1/convolution2d/biases (DT_FLOAT) [32]
# cae/conv1/convolution2d/biases/Adam (DT_FLOAT) [32]
# cae/conv1/convolution2d/biases/Adam_1 (DT_FLOAT) [32]
# cae/conv1/convolution2d/weights (DT_FLOAT) [5,5,64,32]
# cae/conv1/convolution2d/weights/Adam (DT_FLOAT) [5,5,64,32]
# cae/conv1/convolution2d/weights/Adam_1 (DT_FLOAT) [5,5,64,32]
Ich habe auf den folgenden Artikel verwiesen. https://blog.metaflow.fr/tensorflow-saving-restoring-and-mixing-multiple-models-c4c94d5d7125
Recommended Posts