[PYTHON] [TF] Speichern und Laden von Tensorflow-Trainingsparametern

Verwenden Sie ** tf.train.Saver **, um die in Tensorflow gelernten Parameter zu speichern und zu laden.

sparen

Verwenden Sie zum Speichern die Methode ** save ** der erstellten Speicherklasse.

python


saver = tf.train.Saver()

Einige Verarbeitung
 
#sparen
saver.save(sess, "model.ckpt")

Das Speichern kann am Ende des Lernens oder in der Mitte des Lernens erfolgen.

Lesen

Verwenden Sie beim Lesen die Methode ** restore ** der erstellten Speicherklasse. Wir brauchen eine Sitzung, also laden Sie sie nach dem Erstellen der Sitzung. Erstellen Sie bei der Ausführung unter ipython eine Sitzung mit tf.InteractiveSession (), normalerweise tf.Session ().

python


sess=tf.InteractiveSession()

saver.restore(sess, "model.ckpt")

Der Status des tatsächlichen Speicherns und Ladens wird unten angezeigt.

Der Fluss ist wie folgt.

    1. Modellieren
  1. Lernen
    1. Speichern Sie die Parameter für einen späteren Vergleich in einer anderen Variablen
  2. Speichern Sie die Parameter in der Datei 5.Session Close
  3. Sitzung erstellen
  4. Initialisierung (Dies ist zunächst nicht erforderlich. Sie wurde absichtlich zum Vergleich initialisiert.)
  5. Vergleichen Sie mit gespeicherten Parametern (Dies ist anders, da es vor einer Zeit initialisiert wurde.)
  6. Parameter aus Datei lesen
  7. Vergleiche mit gespeicherten Parametern (dies stimmt überein)
  8. Lernen

TF_SaveAndRestoreModel-20-1-html.png

Code

python


# # import

# In[1]:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data


# # load dataset

# In[2]:

mnist = input_data.read_data_sets("./data/mnist/", one_hot=True) 


# # build model

# In[3]:

def mlp(x, output_dim, reuse=False):
        
    w1 = tf.get_variable("w1", [x.get_shape()[1], 1024], initializer=tf.random_normal_initializer())
    b1 = tf.get_variable("b1", [1024], initializer=tf.constant_initializer(0.0))
    w2 = tf.get_variable("w2", [1024, output_dim], initializer=tf.random_normal_initializer())
    b2 = tf.get_variable("b2", [output_dim], initializer=tf.constant_initializer(0.0))
    
    fc1 = tf.nn.relu(tf.matmul(x, w1) + b1)
    fc2 = tf.matmul(fc1, w2) + b2

    return fc2, [w1, b1, w2, b2]

def slp(x, output_dim):
    w1 = tf.get_variable("w1", [x.get_shape()[1], output_dim], initializer=tf.random_normal_initializer())
    b1 = tf.get_variable("b1", [output_dim], initializer=tf.constant_initializer(0.0))
    
    fc1 = tf.nn.relu(tf.matmul(x, w1) + b1)
    return fc1, [w1, b1]

n_batch = 32
n_label = 10
n_train = 10000
imagesize = 28
learning_rate = 0.5

x_node = tf.placeholder(tf.float32, shape=(n_batch, imagesize*imagesize))
y_node = tf.placeholder(tf.float32, shape=(n_batch, n_label))

with tf.variable_scope("MLP") as scope:
    out_m, theta_m = mlp(x_node, n_label)
           
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(out_m, y_node))
opt  = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
tr_pred = tf.nn.softmax(out_m)

test_data = mnist.test.images
test_labels = mnist.test.labels
tx = tf.constant(test_data)
ty_ = tf.constant(test_labels)

with tf.variable_scope("MLP") as scope:
    scope.reuse_variables()
    ty, _ = mlp(tx, n_label)
    
te_pred = tf.nn.softmax(ty) 


# In[4]:

def accuracy(y, y_):
    return 100.0 * np.sum(np.argmax(y, 1) == np.argmax(y_, 1)) / y.shape[0]


# In[5]:

saver = tf.train.Saver()

sess=tf.InteractiveSession()

init = tf.initialize_all_variables()
sess.run(init)


# In[6]:

for step in xrange(n_train):
    bx, by = mnist.train.next_batch(n_batch)
    feed_dict = {x_node: bx, y_node: by}
    _, _loss, _tr_pred = sess.run([opt, loss, tr_pred], feed_dict=feed_dict)
    if step % 500 == 0:
        _accuracy = accuracy(_tr_pred, by)
        print 'step = %d, loss=%.2f, accuracy=%.2f' % (step, _loss, _accuracy)

print 'test accuracy=%.2f' % accuracy(te_pred.eval(), test_labels)


# In[8]:

old_theta_m = [ p.eval() for p in theta_m] # for comparing


# In[9]:

saver.save(sess, "model.ckpt")


# In[10]:

sess.close()


# In[11]:

sess=tf.InteractiveSession()

# for clear
init = tf.initialize_all_variables()
sess.run(init)


# In[12]:

for prm, prm_o in zip(theta_m, old_theta_m):
    p1 = prm.eval()
    p2 = prm_o
    print np.sum(p1 != p2) 


# In[13]:

saver.restore(sess, "model.ckpt")


# In[14]:

for prm, prm_o in zip(theta_m, old_theta_m):
    p1 = prm.eval()
    p2 = prm_o
    print np.sum(p1 != p2) 


# In[15]:

print 'test accuracy=%.2f' % accuracy(te_pred.eval(), test_labels)


# In[16]:

for step in xrange(n_train):
    bx, by = mnist.train.next_batch(n_batch)
    feed_dict = {x_node: bx, y_node: by}
    _, _loss, _tr_pred = sess.run([opt, loss, tr_pred], feed_dict=feed_dict)
    if step % 500 == 0:
        _accuracy = accuracy(_tr_pred, by)
        print 'step = %d, loss=%.2f, accuracy=%.2f' % (step, _loss, _accuracy)

print 'test accuracy=%.2f' % accuracy(te_pred.eval(), test_labels)


# In[17]:

sess.close()


# In[ ]:

tf.reset_default_graph()

Recommended Posts

[TF] Speichern und Laden von Tensorflow-Trainingsparametern
[TF] Laden / Speichern von Modell und Parameter in Keras
So teilen und speichern Sie einen DataFrame
[TF] So erstellen Sie Tensorflow in einer Proxy-Umgebung
Freigeben von Ordnern für Docker und Windows mit Tensorflow
[How to!] Lerne und spiele Super Mario mit Tensorflow !!
[TensorFlow 2 / Keras] Ausführen des Lernens mit CTC Loss in Keras
[Google Colab] So unterbrechen Sie das Lernen und setzen es dann fort
TensorFlow-Lernmethode für Profis der freien Künste und Python-Anfänger
So zeichnen Sie interaktiv eine Pipeline für maschinelles Lernen mit scikit-learn und speichern sie in HTML
So installieren und verwenden Sie Tesseract-OCR
So installieren und konfigurieren Sie Amsel
Verwendung von .bash_profile und .bashrc
So installieren und verwenden Sie Graphviz
So konvertieren Sie das Tensorflow-Modell in Lite
Ich habe zusammengefasst, wie die Boot-Parameter von GRUB und GRUB2 geändert werden
So führen Sie TensorFlow 1.0-Code in 2.0 aus
Sammeln von Daten zum maschinellen Lernen
Lösen von Folienrätseln und 15 Rätseln
Coursera Machine Learning Challenge in Python: ex6 (Anpassen von SVM-Parametern)
Es ist sehr nützlich, Target in Luigi die Methoden save () und load () hinzuzufügen
[Linux] Unterteilen von Dateien und Ordnern
So verpacken und verteilen Sie Python-Skripte
Einführung in das maschinelle Lernen: Funktionsweise des Modells
scikit-learn Verwendung der Zusammenfassung (maschinelles Lernen)
So installieren und verwenden Sie pandas_datareader [Python]
Einführung in Deep Learning ~ Falten und Pooling ~
[TF] Verwendung von Tensorboard von Keras
So studieren Sie den Deep Learning G-Test
Python: Verwendung von Einheimischen () und Globalen ()
Verwendung von Tensorflow unter Docker-Umgebung
Verwendung von Python zip und Aufzählung
Verwendung ist und == in Python
Wie man Fabric installiert und wie man es benutzt
Wie schreibe ich pydoc und mehrzeilige Kommentare
So installieren Sie das Deep Learning Framework Tensorflow 1.0 in der Windows Anaconda-Umgebung
[Tensorflowjs_converter] So konvertieren Sie das Tensorflow-Modell in das Format Tensorflow.js
Konformität und Rückruf-Verständnis zur Bewertung der Klassifizierungsleistung ①-
So generieren Sie eine Sequenz in Python und C ++
So erstellen Sie erklärende Variablen und Zielfunktionen
[Python] Lesen von Daten aus CIFAR-10 und CIFAR-100
So führen Sie CNN in 1 Systemnotation mit Tensorflow 2 aus
So wechseln Sie zwischen Linux- und Mac-Shells
Einführung in Deep Learning ~ Lokalisierungs- und Verlustfunktion ~
[Python] Verwendung von Hash-Funktion und Taple.
[AWS / Lambda] Laden einer externen Python-Bibliothek
Tensorufuro, Tensafuro Immerhin welches (wie man Tensorflow liest)
Datenbereinigung Umgang mit fehlenden und Ausreißern
[TF] So geben Sie Variablen an, die mit Optimizer aktualisiert werden sollen
Jenkins ist weiterhin sicher zu bedienen! Referenzseite * [Zusammenfassung der Java-Unterstützung - Qiita] (https://qiita.com/nowokay/items/edb5c5df4dbfc4a99ffb "Zusammenfassung der Java-Unterstützung - Qiita") * [Offizielle Java 11-Version dieser Version Gegen eine Gebühr Unterstützung für das Oracle JDK. Die Erwartungen an kostenlosen langfristigen Support durch OpenJDK sind derzeit gering - Publickey] (https://www.publickey1.jp/blog/18/java_11oracle_jdkopenjdk.html "Die offizielle Java 11-Version wurde veröffentlicht. Der Support für Oracle JDK wird ab dieser Version bezahlt Es wird derzeit nicht erwartet, dass OpenJDK langfristig kostenlos unterstützt wird - Publickey ") Java, Jenkins, OpenJDK, Java8, Java11 Das Ranking der Ansichtsanzahl von Spotify wird aggregiert und in Excel gespeichert
So installieren Sie den Cascade-Detektor und wie verwenden Sie ihn
Erfahren Sie, wie Sie Bilder aus dem TensorFlow-Code aufblasen
So erzwingen Sie, dass TensorFlow 2.3.0 für CUDA11 + cuDNN8 erstellt wird
Aufteilen von Trainingsdaten für maschinelles Lernen in objektive Variablen und andere in Pandas
Schritte zum schnellen Erstellen einer umfassenden Lernumgebung auf einem Mac mit TensorFlow und OpenCV