Verwenden Sie tensorflow.train.Saver, um Tensorflow-Variablen in einer Datei zu speichern. Alle Variablen in der Sitzung werden mit der im Lernprogramm beschriebenen Methode gespeichert. Um nur eine bestimmte Variable zu speichern / wiederherzustellen, geben Sie der Initialisierungsfunktion von tensorflow.train.Saver eine Liste der Variablen, auf die Sie in einem Wörterbuchtyp abzielen möchten.
Dies ermöglicht es, Variablen einzeln aus mehreren Dateien zu lesen.
save.py
import tensorflow as tf
def get_particular_variables(name):
return {v.name: v for v in tf.all_variables() if v.name.find(name) >= 0}
def define_variables(var0_value, var1_value, var2_value):
var0 = tf.Variable([var0_value])
with tf.variable_scope('foo'):
var1 = tf.Variable([var1_value])
with tf.variable_scope('bar'):
var2 = tf.Variable([var2_value])
return var0, var1, var2
sess = tf.InteractiveSession()
# defines variables
var0, var1, var2 = define_variables(0.0, 0.0, 0.0)
# saving only variables whose name includes foo
saver = tf.train.Saver(get_particular_variables('foo'))
# initializing all of variables
sess.run(tf.initialize_all_variables())
print var0.eval(), var1.eval(), var2.eval()
# saving into file
saver.save(sess, './bar_val')
restore.py
import tensorflow as tf
def get_particular_variables(name):
return {v.name: v for v in tf.all_variables() if v.name.find(name) >= 0}
def define_variables(var0_value, var1_value, var2_value):
var0 = tf.Variable([var0_value])
with tf.variable_scope('foo'):
var1 = tf.Variable([var1_value])
with tf.variable_scope('bar'):
var2 = tf.Variable([var2_value])
return var0, var1, var2
sess = tf.InteractiveSession()
# defines variables
var0, var1, var2 = define_variables(1.0, 1.0, 1.0)
# restoring only variables whole name includes foo
saver = tf.train.Saver(get_particular_variables('foo'))
# initializing all of variables
sess.run(tf.initialize_all_variables())
print 'before restoring: ', var0.eval(), var1.eval(), var2.eval()
# restoring variable from file
saver.restore(sess, './bar_val')
print 'after restoring only var in foo: ', var0.eval(), var1.eval(), var2.eval()
Bei dieser Methode müssen jedoch lange Namen und die Hierarchie der Namespaces berücksichtigt werden. Zum Beispiel
variable | name-of-variable |
---|---|
var0 | Variable:0 |
var1 | foo/Variable:0 |
var2 | foo/bar/Variable:0 |
var3 | foobar/Variable:0 |
In einem solchen Fall gibt die Ausführung von get_particular_variables ('foo') oben var1, var2 und var3 zurück. Auf diese Weise werden abhängig von den Suchbedingungen zusätzliche Variablen gespeichert, die beim Wiederherstellen unerwartete Fehler verursachen können.
Recommended Posts