Heureusement, cela fait longtemps que TensorFlow 2.x est sorti dans le monde. Jusqu'à présent, la v1.x a beaucoup de mal, mais dans la v2.x, de nombreuses fonctions très intéressantes telles que la possibilité de passer facilement à un graphe de type tf avec un décorateur ont été implémentées, et nous, les développeurs, pouvons également apprendre les modèles rapidement. , Une évaluation, etc. peut maintenant être effectuée.
Cependant, certaines personnes doivent utiliser la v1.x car cela est dû à des problèmes de compatibilité avec la v2.x (ce qu'il faut cacher, c'est moi). Maintenant que toute la documentation mise à jour mentionne 2.x, la documentation 1.x est difficile à développer et il est difficile d'obtenir les bonnes informations. Je suis maintenant (à partir de 17:10 2020/11/4) et cela fait longtemps avant que j'entre réellement dans le processus de chargement partiel des paramètres dans la v1.x. Si vous utilisez simplement ResNet ou MobileNet ou le modèle publié tel quel, vous n'aurez pas beaucoup de problèmes car vous pouvez lire l'intégralité du graphique de calcul, mais j'aimerais utiliser ResNet appris à l'avance comme encodeur d'image pour le réseau auto-créé suivant. Dans le cas de .., il est nécessaire de lire partiellement les paramètres. Pour ceux qui liront partiellement les paramètres de la v1.x à l'avenir (ce qu'il faut cacher, c'est moi), je vais enregistrer comment lire partiellement les paramètres comme décrit dans le titre.
En principe, c'est possible avec le code suivant.
...
with tf.Session() as sess:
saver = tf.train.Saver({'Nom de nœud du modèle que vous souhaitez charger':Tf avec ce nom de nœud.Variable variable, ...})
saver.restore(sess, 'path/to/checkpoint')
Cependant, il dit "Comment puis-je créer un dictionnaire de noms de nœuds et de variables à transmettre à tf.train.Saver!", Donc dans ce cas
variables = tf.trainable_variables()
restore_variables = {}
for v in variables:
if 'Espace de noms du modèle' in v.name:
restore_variables[v.name] = v
Ce faisant, vous pouvez extraire uniquement un nœud spécifique du nœud que vous utilisez actuellement et le mettre dans le dictionnaire.
Lors de l'apprentissage, : 0
peut être donné sans permission, donc dans ce cas
fixed_name = v.name[:-2]
restore_variables[fixed_name] = v
En répondant comme, il devient lisible.
Tensorflow.python.tools.inspect_checkpoint.print_tensors_in_checkpoint_file
est utile si vous voulez voir comment chaque nom de variable est stocké dans chaque point de contrôle lorsqu'il est pré-appris.
import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file(file_name='path/to/checkpoint', tensor_name='', all_tensors=False)
# beta1_power (DT_FLOAT) []
# beta2_power (DT_FLOAT) []
# cae/conv0/convolution2d/biases (DT_FLOAT) [64]
# cae/conv0/convolution2d/biases/Adam (DT_FLOAT) [64]
# cae/conv0/convolution2d/biases/Adam_1 (DT_FLOAT) [64]
# cae/conv0/convolution2d/weights (DT_FLOAT) [7,7,3,64]
# cae/conv0/convolution2d/weights/Adam (DT_FLOAT) [7,7,3,64]
# cae/conv0/convolution2d/weights/Adam_1 (DT_FLOAT) [7,7,3,64]
# cae/conv1/convolution2d/biases (DT_FLOAT) [32]
# cae/conv1/convolution2d/biases/Adam (DT_FLOAT) [32]
# cae/conv1/convolution2d/biases/Adam_1 (DT_FLOAT) [32]
# cae/conv1/convolution2d/weights (DT_FLOAT) [5,5,64,32]
# cae/conv1/convolution2d/weights/Adam (DT_FLOAT) [5,5,64,32]
# cae/conv1/convolution2d/weights/Adam_1 (DT_FLOAT) [5,5,64,32]
J'ai évoqué l'article suivant. https://blog.metaflow.fr/tensorflow-saving-restoring-and-mixing-multiple-models-c4c94d5d7125
Recommended Posts