[PYTHON] Partially read parameters in old TensorFlow 1.x

Thankfully, it has been a long time since TensorFlow 2.x was released. We used to struggle a lot with v1.x, but v2.x brought many nice features, such as a decorator that easily converts a function into a TF graph, and developers can now train and evaluate models with little effort.

However, some people are stuck on v1.x because of compatibility issues with v2.x (that would be me, as it happens). Now that all the updated documentation covers 2.x, material on 1.x is hard to find, and getting the right information is difficult. As of this writing (17:10 on 2020/11/4), it took me a while to find out how to actually load parameters partially in v1.x. If you simply use ResNet, MobileNet, or another published model as is, you will not have much trouble, because you can load the entire computation graph. But when, for example, you want to use a pre-trained ResNet as the image encoder in front of your own network, you need to load only part of the parameters. For those who will need to do this in v1.x in the future (again, that would be me), I am recording how to load parameters partially, as the title says.

What to do

In principle, a partial restore can be done with the following code.

...
with tf.Session() as sess:
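    # Keys are variable names as stored in the checkpoint;
    # values are the matching tf.Variable objects in the current graph.
    saver = tf.train.Saver({'<variable name in checkpoint>': <tf.Variable with that name>, ...})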
    saver.restore(sess, 'path/to/checkpoint')

However, you may well ask, "How do I build the dictionary of names and variables to pass to tf.train.Saver?" In that case:

variables = tf.trainable_variables()
restore_variables = {}
for v in variables:
    if 'Model namespace' in v.name:
        restore_variables[v.name] = v

This picks out only the variables under a specific namespace from the variables in the current graph and puts them in the dictionary.

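Note that v.name carries a ':0' suffix that TensorFlow appends on its own (for example 'scope/weights:0'), while the checkpoint stores the name without it, so strip the suffix as follows: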

fixed_name = v.name[:-2]
restore_variables[fixed_name] = v

With this fix, the checkpoint can be loaded.
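The two snippets above can be combined into one small helper. The following is only a minimal sketch: build_restore_map and FakeVariable are names made up for this illustration, and FakeVariable is a stand-in for tf.Variable so that the sketch runs without TensorFlow installed. It also uses startswith instead of the `in` check above, on the assumption that the namespace is a prefix of the variable name.

```python
from collections import namedtuple


def build_restore_map(variables, scope):
    """Map checkpoint names to the variables whose name starts with `scope`."""
    restore_map = {}
    for v in variables:
        if v.name.startswith(scope):
            # 'cae/conv0/weights:0' -> 'cae/conv0/weights'
            # (the checkpoint key has no ':0' output index).
            checkpoint_name = v.name.rsplit(':', 1)[0]
            restore_map[checkpoint_name] = v
    return restore_map


# Hypothetical stand-in for tf.Variable; only the .name attribute is used.
FakeVariable = namedtuple('FakeVariable', ['name'])

variables = [
    FakeVariable('cae/conv0/convolution2d/weights:0'),
    FakeVariable('cae/conv0/convolution2d/biases:0'),
    FakeVariable('my_head/dense/weights:0'),  # self-made part, not restored
]

restore_map = build_restore_map(variables, scope='cae/')
print(sorted(restore_map))
```

In actual TF 1.x code, you would presumably pass tf.trainable_variables() as variables and hand the resulting dictionary to tf.train.Saver.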

Supplement

tensorflow.python.tools.inspect_checkpoint.print_tensors_in_checkpoint_file is useful when you want to check under which names the variables were stored in a pre-trained checkpoint.

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file(file_name='path/to/checkpoint', tensor_name='', all_tensors=False)

# beta1_power (DT_FLOAT) []
# beta2_power (DT_FLOAT) []
# cae/conv0/convolution2d/biases (DT_FLOAT) [64]
# cae/conv0/convolution2d/biases/Adam (DT_FLOAT) [64]
# cae/conv0/convolution2d/biases/Adam_1 (DT_FLOAT) [64]
# cae/conv0/convolution2d/weights (DT_FLOAT) [7,7,3,64]
# cae/conv0/convolution2d/weights/Adam (DT_FLOAT) [7,7,3,64]
# cae/conv0/convolution2d/weights/Adam_1 (DT_FLOAT) [7,7,3,64]
# cae/conv1/convolution2d/biases (DT_FLOAT) [32]
# cae/conv1/convolution2d/biases/Adam (DT_FLOAT) [32]
# cae/conv1/convolution2d/biases/Adam_1 (DT_FLOAT) [32]
# cae/conv1/convolution2d/weights (DT_FLOAT) [5,5,64,32]
# cae/conv1/convolution2d/weights/Adam (DT_FLOAT) [5,5,64,32]
# cae/conv1/convolution2d/weights/Adam_1 (DT_FLOAT) [5,5,64,32]
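A listing like this can also be used to check the dictionary keys before calling restore. The following is a rough sketch, assuming the listing has been copied into a string in exactly the "name (dtype) [shape]" format shown above; parse_listing and missing_keys are hypothetical helper names, not part of TensorFlow.

```python
import re

# A few lines copied from the listing above (without the leading '# ').
LISTING = """\
beta1_power (DT_FLOAT) []
cae/conv0/convolution2d/biases (DT_FLOAT) [64]
cae/conv0/convolution2d/biases/Adam (DT_FLOAT) [64]
cae/conv0/convolution2d/weights (DT_FLOAT) [7,7,3,64]
cae/conv0/convolution2d/weights/Adam_1 (DT_FLOAT) [7,7,3,64]
"""

# name, dtype in parentheses, comma-separated shape in brackets
LINE_RE = re.compile(r'^(\S+) \((\w+)\) \[([\d,]*)\]$')


def parse_listing(text):
    """Return {variable name: shape as a list of ints} for each listing line."""
    shapes = {}
    for line in text.splitlines():
        match = LINE_RE.match(line.strip())
        if match:
            name, _dtype, dims = match.groups()
            shapes[name] = [int(d) for d in dims.split(',')] if dims else []
    return shapes


def missing_keys(restore_names, checkpoint_shapes):
    """Return the names that are not stored in the checkpoint listing."""
    return sorted(n for n in restore_names if n not in checkpoint_shapes)


checkpoint_shapes = parse_listing(LISTING)
wanted = [
    'cae/conv0/convolution2d/weights',
    'cae/conv0/convolution2d/weights:0',  # forgot to strip ':0'
]
print(missing_keys(wanted, checkpoint_shapes))
```

Any name reported as missing, such as one that still carries the ':0' suffix, would not be found in the checkpoint when restoring.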

Reference

I referred to the following article. https://blog.metaflow.fr/tensorflow-saving-restoring-and-mixing-multiple-models-c4c94d5d7125
