I ran into an unexpected trap with tf.function even though I thought the only variable creation happened in the first model call (build), so here is a memo.

Environment: TensorFlow 2 (confirmed on the 2.5 nightly build).
import tensorflow as tf

@tf.function
def f(x):
    v = tf.Variable(1.0)  # tries to create a new variable on every call
    v.assign_add(x)
    return v

# assert_raises is the helper context manager from the TF guide;
# it catches the expected exception and prints the traceback below.
with assert_raises(ValueError):
    f(1.0)
Caught expected exception
  <class 'ValueError'>: in user code:

    <ipython-input-17-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:702 invalid_creator_scope
        "tf.function-decorated function tried to create "
Merely binding a plain tensor to a Python variable is fine, because the same memory area is recycled and reused every time the function is called.
@tf.function
def f(x):
    v = tf.ones((5, 5), dtype=tf.float32)  # a plain tensor, not a tf.Variable
    return v
# raises no error!
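To confirm (the calls here are my own, just to exercise the function): repeated invocations are safe, since tf.ones produces an ordinary tensor rather than stateful variable storage.

f(1.0)
f(2.0)  # the second call is fine too: nothing stateful is created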
Passing a tf.keras.models.Model as an argument to a tf.function causes an unexpected error.
import tensorflow as tf

@tf.function
def call(model1: tf.keras.models.Model, inputs: tf.Tensor):
    return model1(inputs)

if __name__ == '__main__':
    model1 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)
    ])
    model2 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)
    ])
    inputs = tf.ones((10, 10), dtype=tf.float32)

    call(model1, inputs)  # raises no error!
    call(model2, inputs)  # raises an error! "tf.function-decorated function tried to create variables on non-first call"
The point is that a Keras model builds its variables (such as the weight matrices) the first time it is called. So if you pass two unbuilt models to this call function, the second call tries to build the second model's weights inside the already-traced tf.function, and you hit the error above!
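One hedged alternative, assuming you know the input shape up front (model.build is standard Keras API; the shape here matches the example): build both models before the traced function ever runs, so that tracing creates no variables.

# Reusing the names from the snippet above, before any call() is made:
model1.build((None, 10))  # feature dimension 10 matches the example input
model2.build((None, 10))

call(model1, inputs)  # traces once
call(model2, inputs)  # no error: model2's variables already exist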
import tensorflow as tf

# @tf.function is not attached directly
def call(model1: tf.keras.models.Model, inputs: tf.Tensor):
    return model1(inputs)

if __name__ == '__main__':
    model1 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)
    ])
    model2 = tf.keras.Sequential([
        tf.keras.layers.Dense(16),
        tf.keras.layers.Dense(4)
    ])
    inputs = tf.ones((10, 10), dtype=tf.float32)

    # Functions are also objects in Python: generate a function dedicated to each model.
    model1_call = tf.function(call)
    model2_call = tf.function(call)

    model1_call(model1, inputs)
    model2_call(model2, inputs)  # no error: each model gets its own trace
The fix is to stop attaching @tf.function directly to the function definition and instead create a tf.function object dedicated to each model. Decorators and higher-order functions are metaprogramming concepts that are hard on Python beginners, so it is fine not to chase the reason too deeply. If you do want to understand why this works, I suggest reading up on Python decorators and higher-order functions.
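As a minimal sketch of that point (the function names here are mine, purely for illustration): the @tf.function decorator syntax is just ordinary function application, which is why tf.function(call) above behaves the same as decorating.

import tensorflow as tf

@tf.function
def double_a(x):
    return x * 2

def double_b(x):
    return x * 2
double_b = tf.function(double_b)  # exactly what the decorator syntax does

# Both are tf.function objects wrapping plain Python functions.
print(double_a(tf.constant(3)))  # tf.Tensor(6, shape=(), dtype=int32)
print(double_b(tf.constant(3)))  # tf.Tensor(6, shape=(), dtype=int32)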
Eager execution is far too slow for algorithms with many complicated steps on a GPU, so I want to use tf.function to pursue computational efficiency. But TF2 has only just been released, and tf.function is still such a black box that it is easy to stumble over bugs that make no sense.
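If you want a rough sense of whether tf.function actually pays off for your workload, a quick micro-benchmark along the lines of the one in the TF guide works (the layer and sizes here are arbitrary choices of mine):

import timeit
import tensorflow as tf

cell = tf.keras.layers.LSTMCell(10)

@tf.function
def traced_cell(inputs, state):
    return cell(inputs, state)

inputs = tf.zeros([10, 10])
state = [tf.zeros([10, 10])] * 2

cell(inputs, state)         # build the weights first
traced_cell(inputs, state)  # trace once, outside the timing loop

print("eager:      ", timeit.timeit(lambda: cell(inputs, state), number=100))
print("tf.function:", timeit.timeit(lambda: traced_cell(inputs, state), number=100))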
I will keep accumulating knowledge about TF2 and deep learning, so please follow along and give this an LGTM!