[PYTHON] Sélectionnez les variables requises dans TensorFlow et enregistrez / restaurez

Lors de la division du processus d'apprentissage dans TensorFlow, la variable save / restore est requise, qui est prise en charge par la classe tf.train.Saver dans TensorFlow. Si l'échelle du modèle est petite, vous pouvez enregistrer / restaurer toutes les variables utilisées, mais si le modèle est grand, vous souhaiterez enregistrer / restaurer uniquement les variables dont vous avez réellement besoin.

Dans cet article, nous confirmerons la méthode de sauvegarde / restauration des variables en utilisant la classification numérique manuscrite MNIST comme exemple. (L'environnement est Python 2.7.11, tensorflow 0.8.0.)

Ajouter trainable = True aux variables requises

Il existe différentes situations possibles en fonction du contenu du programme et de ce qui est nécessaire. Le moyen le plus simple est de sauvegarder la totalité de la variable utilisée (variable de classe tf.Variable).


chkpt_file = '../MNIST_data/mnist_cnn.ckpt'

# Create the model
def inference(x, y_, keep_prob, phase_train):
(Omis)
Construction d'un modèle de réseau, etc.
    
    return loss, accuracy, y_pred

if __name__ == '__main__':
(Omis)
    loss, accuracy, y_pred = inference(x, y_, 
                                         keep_prob, phase_train)
    #                                    
    #Opération d'économie avant d'entrer dans la session(ops)Est défini sans argument
    #
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init) 
               
        if restore_call:
            # Restore variables from disk.
            saver.restore(sess, chkpt_file) 

        if TASK == 'train':
            print('\n Training...')
            for i in range(5001):
            
(Processus d'apprentissage)
   
        # Save the variables to disk.Enfin écrire sur le disque
        if TASK == 'train':
            save_path = saver.save(sess, chkpt_file)
            print("Model saved in file: %s" % save_path)

Comme mentionné ci-dessus, vous pouvez définir des opérations (ops) en utilisant "tf.train.Saver ()" sans argument et enregistrer la totalité de la variable tf.Variable avec sa méthode save. (Remarque. Ceux définis dans tf.placeholder ne sont pas applicables.)

Cependant, dans la plupart des modèles simples de réseaux de neurones, le ** poids w ** et le ** biais b ** de chaque unité sont suffisants. La méthode recommandée ici est d'ajouter le drapeau entraînable dans la définition de variable.

Voici une définition de classe de la couche de convolution et un exemple de classe de la couche entièrement connectée.

class Convolution2D(object):
    '''
      constructor's args:
          input     : input image (2D matrix)
          input_siz ; input image size
          in_ch     : number of incoming image channel
          out_ch    : number of outgoing image channel
          patch_siz : filter(patch) size
          weights   : (if input) (weights, bias)
    '''
    def __init__(self, input, input_siz, in_ch, out_ch, patch_siz, activation='relu'):
        self.input = input      
        self.rows = input_siz[0]
        self.cols = input_siz[1]
        self.in_ch = in_ch
        self.activation = activation
        
        wshape = [patch_siz[0], patch_siz[1], in_ch, out_ch]
        
        w_cv = tf.Variable(tf.truncated_normal(wshape, stddev=0.1), 
                            trainable=True)
        b_cv = tf.Variable(tf.constant(0.1, shape=[out_ch]), 
                            trainable=True)
        
        self.w = w_cv
        self.b = b_cv
        self.params = [self.w, self.b]
        
(Omis)

# Full-connected Layer   
class FullConnected(object):
    def __init__(self, input, n_in, n_out):
        self.input = input
    
        w_h = tf.Variable(tf.truncated_normal([n_in,n_out],
                          mean=0.0, stddev=0.05), trainable=True)
        b_h = tf.Variable(tf.zeros([n_out]), trainable=True)
     
        self.w = w_h
        self.b = b_h
        self.params = [self.w, self.b]
    
(Omis)

Lors de la déclaration des variables (w_cv, b_cv) et (w_h, b_h) qui correspondent au poids et au biais, «trainable = True» (entraînable) est ajouté. Avec cet effort, vous ne pourrez collecter que les variables entraînables plus tard.

if __name__ == '__main__':
(Omis)
    vars_to_train = tf.trainable_variables()
    
    if os.path.exists(chkpt_file) == False:
        restore_call = False
        init = tf.initialize_all_variables()

    else:
        restore_call = True
        vars_all = tf.all_variables()
        vars_to_init = list(set(vars_all) - set(vars_to_train))
        init = tf.initialize_variables(vars_to_init)
          
    saver = tf.train.Saver(vars_to_train)

    with tf.Session() as sess:
(Omis)
    

Le point dans le code ci-dessus est Nous venons de collecter les variables déclarées avec trainable = True avec ** tf.trainable_variables () ** et toutes les variables déclarées avec ** tf.all_variables () **. L'image de l'ensemble des variables est la suivante.

tf_vars_1.png

Dans le premier processus, seules les variables avec apprentissage sont enregistrées et dans le deuxième processus et les suivants, ces variables enregistrées sont restaurées et utilisées. Cependant, le flux est que les variables qui ne sont pas enregistrées / restaurées sont initialisées (même après la deuxième fois).

Ici, comparez les tailles des fichiers enregistrés.

-rw-rw-r--1 52404005 31 mai 09:54 mnist_cnn.all_vars
-rw-rw-r--1 13100491 22 mai 09:15 mnist_cnn.trainable

Ceci n'est qu'un exemple, mais le fichier enregistré peut être aussi petit que 1/4. (Le fichier ci-dessus a été renommé de mnist_cnn.ckpt.)

Comment collecter des variables à l'aide d'espaces de noms

Ensuite, nous allons introduire une autre méthode, une méthode de collecte de variables à l'aide de l'espace de noms variable et de sauvegarde / restauration. Dans TensorFlow, afin de visualiser Graph (configuration du modèle), je pense que la construction de Graph peut se poursuivre lors de la définition d'un espace de noms, mais cet espace de noms peut être utilisé pour collecter les variables définies.

L'exemple suivant est une méthode de collecte des variables utilisées dans la normalisation par lots ajoutée à la couche de convolution.

def batch_norm(x, n_out, phase_train):
    with tf.variable_scope('bn'):

(Traitement de normalisation par lots, divers)

    return normed
    
#Créer la pièce de construction du modèle de modèle, lot ci-dessus_norm()Appelle
def inference(x, y_, keep_prob, phase_train):
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    
    with tf.variable_scope('conv_1'):
        conv1 = Convolution2D(x, (28, 28), 1, 32, (5, 5), activation='none')
        conv1_bn = batch_norm(conv1.output(), 32, phase_train)
        conv1_out = tf.nn.relu(conv1_bn)
           
        pool1 = MaxPooling2D(conv1_out)
        pool1_out = pool1.output()
    
    with tf.variable_scope('conv_2'):
        conv2 = Convolution2D(pool1_out, (28, 28), 32, 64, (5, 5), 
                                                          activation='none')
        conv2_bn = batch_norm(conv2.output(), 64, phase_train)
        conv2_out = tf.nn.relu(conv2_bn)
           
        pool2 = MaxPooling2D(conv2_out)
        pool2_out = pool2.output()    
        pool2_flat = tf.reshape(pool2_out, [-1, 7*7*64])
    
(Autres couches, omis)
    
    return loss, accuracy, y_pred
 
#Traitement principal
if __name__ == '__main__':
    TASK = 'train'    # 'train' or 'test'
    
    # Variables
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])
    keep_prob = tf.placeholder(tf.float32)
    phase_train = tf.placeholder(tf.bool, name='phase_train')
    
    loss, accuracy, y_pred = inference(x, y_, 
                                         keep_prob, phase_train)

    # Train
    lr = 0.01
    train_step = tf.train.AdagradOptimizer(lr).minimize(loss)
    vars_to_train = tf.trainable_variables()    # option-1
    vars_for_bn1 = tf.get_collection(tf.GraphKeys.VARIABLES, scope='conv_1/bn')
    vars_for_bn2 = tf.get_collection(tf.GraphKeys.VARIABLES, scope='conv_2/bn')
    vars_to_train = list(set(vars_to_train).union(set(vars_for_bn1)))
    vars_to_train = list(set(vars_to_train).union(set(vars_for_bn2)))
    
    if TASK == 'train':
        restore_call = False
        init = tf.initialize_all_variables()
    elif TASK == 'test':
        restore_call = True
        vars_all = tf.all_variables()
        vars_to_init = list(set(vars_all) - set(vars_to_train))
        init = tf.initialize_variables(vars_to_init)  # option-1
        # init = tf.initialize_all_variables()    option-2
    else:
        print('Check task switch.')
          
    saver = tf.train.Saver(vars_to_train) 

    with tf.Session() as sess:
(Ci-après, le contenu de la session TensorFlow)

Ici, il y a la couche de convolution 1 (conv_1) et la couche de convolution 2 (conv_2), et une normalisation par lots est effectuée pour chacune. Les variables utilisées ici sont collectées par ** tf.get_collection () ** comme suit. ing.

vars_for_bn1 = tf.get_collection(tf.GraphKeys.VARIABLES, scope='conv_1/bn')
vars_for_bn2 = tf.get_collection(tf.GraphKeys.VARIABLES, scope='conv_2/bn')

Dans cet exemple, inference () définit les espaces de noms ** conv_1 ** et ** conv_2 **, et en cela, batch_norm () avec l'espace de noms ** bn ** est appelé. , Les espaces de noms ci-dessus (imbriqués) conv_1 / bn, conv_2 / bn.

Après cela, l'ensemble des variables est organisé et divisé en "choses à sauvegarder" et "choses à ne pas sauvegarder (initialisées après la prochaine fois)". (Comme le code est difficile à lire, je vais joindre un diagramme pour vous aider à le comprendre.)

tf_vars_2.png

Les variables nécessaires sont sauvegardées en passant la partie colorée dans la figure ci-dessus comme vars_to_train à tf.train.Saver (). Lorsque la normalisation par lots a été introduite "dans une boîte noire", un bogue s'est produit parce que les variables nécessaires n'étaient pas sauvegardées, mais en utilisant la méthode ci-dessus, le bogue pouvait être corrigé.

Enfin, vérifiez à nouveau la taille du fichier.

-rw-rw-r--1 52404005 31 mai 09:54 mnist_cnn.all_vars
-rw-rw-r--1 13105573 31 mai 10:05 mnist_cnn.ckpt
-rw-rw-r--1 13100491 22 mai 09:15 mnist_cnn.trainable

Le haut est le cas où toutes les variables sont stockées, environ 52 Mo, le second est le cas où l'entraînement et l'espace de noms ci-dessus sont utilisés ensemble, environ 13 Mo, et le troisième est le cas où seul l'entraînement est utilisé ( L'opération comprend des bogues), soit environ 13 Mo. Nous pensons que réduire la taille du fichier est plus efficace pour réduire le temps d’E / S disque que pour réduire l’utilisation du disque.

(Je vais télécharger le code final sur Gist. Le voici.)

Références (site Web)

Recommended Posts

Sélectionnez les variables requises dans TensorFlow et enregistrez / restaurez
12. Enregistrez la première colonne dans col1.txt et la deuxième colonne dans col2.txt
Comprendre l'espace de noms TensorFlow et les variables partagées principales
Examiner la relation entre TensorFlow et Keras pendant la période de transition
Clipping et normalisation dans TensorFlow
Définissez les variables d'environnement requises pour PySide (Qt4) et PyQt (Qt5)
Recherchez le pandas.DataFrame avec une variable et obtenez la ligne correspondante.
Enregistrez le fichier binaire en Python
[Python3] Enregistrez la matrice de moyenne et de covariance dans json avec les pandas
Rechercher le nom et les données d'une variable libre dans un objet fonction
Trouvez-le dans la file d'attente et modifiez-le
Créez une API REST à l'aide du modèle appris dans Lobe et TensorFlow Serving.
Enregistrez l'ID de canal spécifié dans le texte et chargez-le au prochain démarrage
Variables Python et types de données appris avec la chimio-automatique
[python] Différence entre variable et self. Variable dans la classe
À propos de la différence entre "==" et "is" en python
Vérification des méthodes et des variables à l'aide de la bibliothèque voir
Lorsque l'axe et l'étiquette se chevauchent dans matplotlib
Boucle les variables en même temps dans le modèle
Accédez aux variables définies dans le script depuis REPL
Extrayez les informations de paroles dans le fichier MP3 / MP4 et enregistrez-les dans le fichier de paroles (* .lrc) pour Sony Walkman.
Je veux remplacer les variables dans le fichier de modèle python et le produire en masse dans un autre fichier