[PYTHON] [TensorFlow] Résoudre "ValueError: la fonction décorée par tf.function a essayé de créer des variables lors d'un autre appel"

Tf.function avec Optimizer ou Model comme entrée meurt avec "essayé de créer des variables sur le non-premier appel"

Un mémorandum car j'ai été pris dans un piège inattendu si je pensais que je n'avais pas l'intention de déclarer une variable autre que le premier appel de modèle (build)

supposition

TensorFlow2 (je l'ai confirmé à 2,5 tous les soirs)

La création de tf.Variable dans tf.function est interdite

  1. tf.function fonctionne essentiellement dans la même zone de mémoire une fois sécurisée.
  2. tf.Variable alloue une nouvelle zone mémoire et crée une variable.
  3. Par conséquent, chaque fois qu'une nouvelle variable est créée, elle consomme de la nouvelle mémoire, ce qui entraîne une erreur.

@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)

Caught expected exception 
  <class 'ValueError'>: in user code:

    <ipython-input-17-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:702 invalid_creator_scope
        "tf.function-decorated function tried to create "

Ce peut être juste une déclaration d'une variable python. Parce que la même zone de mémoire est recyclée et utilisée chaque fois que la fonction est appelée.


@tf.function
def f(x):
    v = tf.ones((5, 5), dtype=tf.float32)
    return v
# raises no error!

Sujet principal

tf.function avec tf.keras.models.Model comme entrée provoque une erreur inattendue.

@tf.function
def call(model1: tf.keras.models.Model, inputs: tf.Tensor):
    return model1(inputs)


if __name__ == '__main__':
    model1 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)

    ])
    model2 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)
    ])
    inputs = tf.ones((10, 10), dtype=tf.float32)
    call(model1, inputs)  # raises no error!
    call(model2, inputs)  # raises an error! "tf.function-decorated function tried to create variables on non-first call"

Le fait est que le modèle keras construit une variable telle qu'une matrice de poids lors de son premier appel. Par conséquent, si vous passez deux modèles non construits comme arguments à cette fonction d'appel, vous obtiendrez un accident en essayant de construire le poids du deuxième modèle lors du deuxième appel! !!

Solution


# @tf.la fonction n'est pas directement attachée
def call(model1: tf.keras.models.Model, inputs: tf.Tensor):
    return model1(inputs)


if __name__ == '__main__':
    model1 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)

    ])
    model2 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)
    ])
    inputs = tf.ones((10, 10), dtype=tf.float32)
    #Les fonctions sont également des objets en python. Générez une fonction dédiée à chaque modèle.
    model1_call = tf.function(call)
    model2_call = tf.function(call)
    model1_call(model1, inputs)
    model2_call(model2, inputs)

tf.function arrête d'écrire directement dans la définition de fonction et crée un objet fonction dédié à chaque modèle. Les concepts de métaprogrammation tels que les décorateurs et les fonctions d'ordre supérieur sont difficiles pour les débutants en python, il est donc important de ne pas suivre la raison en profondeur.

Si vous voulez savoir la raison pour laquelle vous pouvez opter pour cela, vous devriez utiliser la fonction d'ordre supérieur du décorateur python, etc.

Sommaire

Eager Execution est trop lent pour en parler lors de l'écriture compliquée et de nombreuses étapes sur GPU, donc j'aimerais utiliser tf.function pour poursuivre l'efficacité du calcul, mais TF2 vient de sortir. Il est facile de tomber sur des bogues qui n'ont pas de sens car tf.function est trop blackbox.

Nous continuerons à accumuler beaucoup de connaissances sur tf2 et le deeplearning, alors suivez-nous avec LGTM!

Recommended Posts

[TensorFlow] Résoudre "ValueError: la fonction décorée par tf.function a essayé de créer des variables lors d'un autre appel"
Erreur de la fonction décorée par tf.function lors de la tentative de création de variables lors d'un autre appel. Dans tensorflow.keras