[PYTHON] [TensorFlow] Resolve "ValueError: tf.function-decorated function tried to create variables on non-first call"

A tf.function that takes an Optimizer or a Model as an argument fails with "tried to create variables on non-first call"

I ran into an unexpected trap even though I did not intend to create any variables outside the model's first call (build), so this is a memo.

Premise

TensorFlow 2 (confirmed on 2.5 nightly)

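Creating a tf.Variable inside a tf.function is restricted: it is only allowed on the first call.

  1. tf.function traces the Python function into a graph once and reuses that graph on later calls.
  2. tf.Variable creates a new stateful object, with its own memory, every time it runs.
  3. If a new Variable were created on every call, each call would allocate fresh state, so tf.function raises an error when a later call tries to create one. A function that creates a new variable every time its body runs, like the one below, is rejected even on its very first call.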

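import tensorflow as tf


@tf.function
def f(x):
  v = tf.Variable(1.0)  # a brand-new variable every time the body runs
  v.assign_add(x)
  return v


try:
  f(1.0)
except ValueError as e:
  print("Caught expected exception\n  {}: {}".format(type(e), e))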

Caught expected exception 
  <class 'ValueError'>: in user code:

    <ipython-input-17-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:702 invalid_creator_scope
        "tf.function-decorated function tried to create "
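If a variable really is needed inside a tf.function, the TensorFlow tf.function guide creates it only on the first call and keeps it on an object, so later traces find it already there. A minimal sketch along those lines, assuming TF 2.x; `Accumulator` is a hypothetical name, not part of TensorFlow:

```python
import tensorflow as tf


class Accumulator(tf.Module):
    def __init__(self):
        super().__init__()
        self.v = None  # the variable does not exist until the first call

    @tf.function
    def __call__(self, x):
        if self.v is None:
            # This Python branch runs only while tracing the first call;
            # later traces see the existing variable and create nothing.
            self.v = tf.Variable(1.0)
        self.v.assign_add(x)
        return self.v.read_value()


acc = Accumulator()
first = acc(tf.constant(1.0))   # variable created, then 1.0 + 1.0
second = acc(tf.constant(1.0))  # same variable reused: 2.0 + 1.0
```

Because the variable lives on `acc` rather than inside the function body, the "create on first call only" rule is satisfied.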

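Simply binding a tensor to a Python variable, on the other hand, is fine. A tensor such as the result of tf.ones is not a stateful tf.Variable; it is just an op in the traced graph, and that graph is reused on every call, so nothing new is created on later calls.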


import tensorflow as tf


@tf.function
def f(x):
    v = tf.ones((5, 5), dtype=tf.float32)  # an ordinary tensor, not a tf.Variable
    return v


f(1.0)  # raises no error!
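One way to see that the traced graph is reused is a Python side effect, which runs only while tf.function traces the body, as the TensorFlow guide demonstrates with a list append. A small sketch, assuming TF 2.x; `traces` is a hypothetical name:

```python
import tensorflow as tf

traces = []  # Python-side log; appended to only while `f` is being traced


@tf.function
def f(x):
    traces.append(1)  # Python side effect: executed during tracing only
    v = tf.ones((5, 5), dtype=tf.float32) * x  # ordinary tensor, not a tf.Variable
    return v


y1 = f(tf.constant(1.0))
y2 = f(tf.constant(2.0))  # same dtype and shape: the traced graph is reused
```

The second call produces new values without running the Python body again, which is why creating plain tensors inside the function is harmless.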

Main subject

Passing a tf.keras.models.Model as an argument to a tf.function can cause an unexpected error.

import tensorflow as tf


@tf.function
def call(model: tf.keras.models.Model, inputs: tf.Tensor):
    return model(inputs)


if __name__ == '__main__':
    model1 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)
    ])
    model2 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)
    ])
    inputs = tf.ones((10, 10), dtype=tf.float32)
    call(model1, inputs)  # raises no error!
    call(model2, inputs)  # raises an error! "tf.function-decorated function tried to create variables on non-first call"

The point is that a Keras model creates its variables, such as weight matrices, lazily when it is called for the first time. So if you pass two unbuilt models to this call function, the second call retraces the function for model2 and tries to build model2's weights, which tf.function rejects because it is no longer the first call!
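Since the root cause is lazy building, another workaround is to build every model before it reaches the shared tf.function. This is a sketch, assuming TF 2.x and assuming one extra eager forward pass per model is acceptable: once the weights exist, retracing for model2 creates no variables.

```python
import tensorflow as tf


@tf.function
def call(model: tf.keras.models.Model, inputs: tf.Tensor):
    return model(inputs)


model1 = tf.keras.Sequential([tf.keras.layers.Dense(16), tf.keras.layers.Dense(4)])
model2 = tf.keras.Sequential([tf.keras.layers.Dense(16), tf.keras.layers.Dense(4)])
inputs = tf.ones((10, 10), dtype=tf.float32)

# Build both models eagerly first, so their weights already exist
# before the shared tf.function ever traces them.
model1(inputs)
model2(inputs)

out1 = call(model1, inputs)  # first trace: no variables are created
out2 = call(model2, inputs)  # retrace for model2: still no variables created
```

For models whose input shape is known in advance, calling `model.build(input_shape)` should serve the same purpose as the eager forward pass.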

Solution


import tensorflow as tf


# @tf.function is not attached directly
def call(model: tf.keras.models.Model, inputs: tf.Tensor):
    return model(inputs)


if __name__ == '__main__':
    model1 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)
    ])
    model2 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)
    ])
    inputs = tf.ones((10, 10), dtype=tf.float32)
    # Functions are also objects in Python. Generate a tf.function dedicated to each model.
    model1_call = tf.function(call)
    model2_call = tf.function(call)
    model1_call(model1, inputs)  # first call of model1_call: building model1 is allowed
    model2_call(model2, inputs)  # first call of model2_call: building model2 is allowed

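Instead of writing @tf.function directly on the function definition, create a separate tf.function object for each model. Each tf.function object keeps its own tracing state, so each model's weights are built during the first call of its own function, and the error no longer occurs. Decorators and higher-order functions are metaprogramming concepts that Python beginners often find difficult, so it is fine not to dig too deeply into the reason at first.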

If you want to understand why creating one tf.function object per model works, I recommend reading up on Python decorators and higher-order functions.
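To make the idea concrete without TensorFlow, here is a toy, pure-Python sketch. `toy_function` and `ToyModel` are hypothetical stand-ins and are NOT how TensorFlow implements tf.function: each wrapper remembers whether it has been called before and rejects a later call during which a model builds new state, mimicking the error. Sharing one wrapper reproduces the failure; wrapping `call` once per model, as in the solution, does not.

```python
def toy_function(fn):
    """Hypothetical stand-in for tf.function (NOT TensorFlow's implementation).

    Each wrapper returned here keeps its own "called before" flag and rejects
    any later call during which a model builds new weights.
    """
    called_before = False

    def wrapper(model, inputs):
        nonlocal called_before
        was_built = model.built
        result = fn(model, inputs)
        if called_before and not was_built:
            # The model built new state on a non-first call of this wrapper.
            raise ValueError("tried to create variables on non-first call")
        called_before = True
        return result

    return wrapper


class ToyModel:
    """Hypothetical model that builds its weight lazily on the first call."""

    def __init__(self):
        self.built = False
        self.weight = None

    def __call__(self, x):
        if not self.built:
            self.weight = 2.0  # stands in for creating tf.Variables
            self.built = True
        return x * self.weight


def call(model, inputs):
    return model(inputs)


# One shared wrapper: the second model is built on a non-first call.
shared_call = toy_function(call)
shared_call(ToyModel(), 1.0)
try:
    shared_call(ToyModel(), 1.0)
    shared_failed = False
except ValueError:
    shared_failed = True

# One wrapper per model: each model is built on its own wrapper's first call.
model1_call = toy_function(call)
model2_call = toy_function(call)
out1 = model1_call(ToyModel(), 1.0)
out2 = model2_call(ToyModel(), 3.0)
```

The key point is that `toy_function(call)` returns a new closure, with its own state, every time it is called; that is the higher-order-function behaviour the solution relies on.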

Summary

When writing complex algorithms with many steps on a GPU, eager execution is far too slow, so I want to use tf.function for computational efficiency. However, TF2 is still relatively new, and because tf.function is so much of a black box, it is easy to stumble on baffling errors like this one.

I will keep accumulating and sharing knowledge about TF2 and deep learning, so please follow me and give this post an LGTM!
