[PYTHON] [TensorFlow 2] How to check the contents of Tensor in graph mode

Introduction

When you use a custom function in TensorFlow 2.x, for example as a loss function during training or as a transformation step for a dataset, you may want to check whether a value is what you expect. Calling print() there may not output the value. You could use a debugger such as tfdbg, but this article shows a simpler way to do so-called "print debugging".

For tfdbg, see "Use tfdbg to crush Keras nan and inf" on Qiita.

Verification environment

The code below assumes that the following import has been written.

import tensorflow as tf 

Behavior when print() is applied to a Tensor

Eager execution is the default in TensorFlow 2.x, so as long as you build Tensors interactively in the interpreter, you can just call print() to see a Tensor's value. You can also get the value as an ndarray with .numpy().

When the value can be seen

x = tf.constant(10)
y = tf.constant(20)
z = x + y
print(z)         # tf.Tensor(30, shape=(), dtype=int32)
print(z.numpy()) # 30

When the value cannot be seen (1)

Evaluating Tensor values one by one is slow during training and evaluation, so you can add the @tf.function decorator to have a function processed in graph mode. **Operations defined in graph mode are executed together as a computation graph (outside Python), so you cannot see intermediate values with print().**

@tf.function
def dot(x, y):
    tmp = x * y
    print(tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# Tensor("mul:0", shape=(2,), dtype=int32)
# tf.Tensor(5, shape=(), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32)

You can print the result outside dot(), but you cannot see the values inside it. Moreover, even if you call dot() multiple times, the print() inside basically runs only once, when the graph is traced (it runs again only if the function is retraced, for example for a new input shape or dtype).
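The "only once" behavior can be observed directly. The following is a minimal sketch, not part of the original article: trace_log is a hypothetical helper list that records every time the Python body of dot() actually runs.

```python
import tensorflow as tf

trace_log = []  # hypothetical helper: one entry per run of the Python body

@tf.function
def dot(x, y):
    # Python side effect: executed only while the graph is being traced
    trace_log.append((x.shape, x.dtype))
    return tf.reduce_sum(x * y)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])

dot(x, y)
dot(x, w)  # same shape and dtype: the traced graph is reused
traces_after_same_signature = len(trace_log)

# A new dtype (float32 instead of int32) triggers a second trace
dot(tf.constant([1.0, 2.0]), tf.constant([3.0, 1.0]))
traces_after_new_dtype = len(trace_log)

print(traces_after_same_signature, traces_after_new_dtype)
```

The first two calls share one trace, so any print() in the body would have fired once for them; the float call forces a retrace and would fire it again.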

Of course, if you remove @tf.function from dot(), the value becomes visible.

def dot(x, y):
    tmp = x * y
    print(tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# tf.Tensor([3 2], shape=(2,), dtype=int32)
# tf.Tensor(5, shape=(), dtype=int32)
# tf.Tensor([0 4], shape=(2,), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32)
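If editing the decorator in and out is inconvenient, a global switch may do the same job. This sketch assumes TensorFlow 2.3 or later, where tf.config.run_functions_eagerly() is available (earlier 2.x releases exposed it as tf.config.experimental_run_functions_eagerly()). The modes list is a hypothetical helper added only to record how the body ran.

```python
import tensorflow as tf

modes = []  # hypothetical helper: True if the body ran eagerly

@tf.function
def dot(x, y):
    tmp = x * y
    modes.append(tf.executing_eagerly())
    print(tmp)  # shows actual values only when the body runs eagerly
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])

tf.config.run_functions_eagerly(True)   # for debugging only: this is slow
eager_result = dot(x, y)

tf.config.run_functions_eagerly(False)  # back to normal graph execution
graph_result = dot(x, y)

print(modes)
```

Remember to turn the switch off again, since it disables graph mode for every tf.function in the process.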

When the value cannot be seen (2)

Code may also run in graph mode implicitly, for example a function passed to map() on a tf.data.Dataset or a custom loss function used in Keras. **In these cases you cannot see intermediate values with print() even without @tf.function.**

def fourth_power(x):
    z = x * x
    print(z)
    z = z * z
    return z

ds = tf.data.Dataset.range(10).map(fourth_power)
for i in ds:
    print(i)

# Tensor("mul:0", shape=(), dtype=int64)
# tf.Tensor(0, shape=(), dtype=int64)
# tf.Tensor(1, shape=(), dtype=int64)
# tf.Tensor(16, shape=(), dtype=int64)
# :

tf.print()

Use tf.print() to see the value of a Tensor in code running in graph mode. See: tf.print | TensorFlow Core v2.1.0

You can display values as shown below, and you can use tf.shape() to display the dimensions and size of a Tensor. This is also useful for debugging a custom function that raises an error because dimensions or sizes do not match.

@tf.function
def dot(x, y):
    tmp = x * y
    tf.print(tmp)
    tf.print(tf.shape(tmp))
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# [3 2] 
# [2]
# tf.Tensor(5, shape=(), dtype=int32)
# [0 4]
# [2]
# tf.Tensor(4, shape=(), dtype=int32)
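Two tf.print() arguments are worth knowing for larger tensors: summarize controls how many elements are shown, and output_stream controls where the text goes (the default is standard error, not standard output). The sketch below is an illustration under assumptions: it writes to a temporary file, a hypothetical location chosen only so the output can be read back and inspected.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical log location, used only so the output can be read back.
log_path = os.path.join(tempfile.mkdtemp(), "tf_print.log")
log_stream = "file://" + log_path

@tf.function
def show(x):
    # Default: long tensors are abbreviated with "..."
    tf.print("default:", x, output_stream=log_stream)
    # summarize=-1 prints every element
    tf.print("full:", x, summarize=-1, output_stream=log_stream)
    return tf.reduce_sum(x)

total = show(tf.range(10))

with open(log_path) as f:
    lines = f.read().splitlines()
print(lines)
```

Passing output_stream=sys.stdout instead (after import sys) sends the text to standard output, which can help when standard error is hidden or redirected.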

tf.print() also works when you want to debug Dataset.map().

def fourth_power(x):
    z = x * x
    tf.print(z)
    z = z * z
    return z

ds = tf.data.Dataset.range(10).map(fourth_power)
for i in ds:
    print(i)
# 0
# tf.Tensor(0, shape=(), dtype=int64)
# 1
# tf.Tensor(1, shape=(), dtype=int64)
# 4
# tf.Tensor(16, shape=(), dtype=int64)
# :
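The same approach should carry over to a custom Keras loss function, the other implicit graph-mode case mentioned earlier. This is a minimal sketch, not taken from the original article: the model, the random data, and the name mse_with_debug are all hypothetical, and python_calls is a helper list used to show that the Python body runs far fewer times than there are training steps.

```python
import numpy as np
import tensorflow as tf

python_calls = []  # hypothetical helper: counts runs of the Python body

def mse_with_debug(y_true, y_pred):
    python_calls.append(1)                   # Python side effect: tracing only
    err = y_true - y_pred
    tf.print("error shape:", tf.shape(err))  # runs on every training step
    return tf.reduce_mean(tf.square(err))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="sgd", loss=mse_with_debug)

x = np.random.rand(32, 3).astype("float32")
y = np.random.rand(32, 1).astype("float32")

# 32 samples / batch size 8 = 4 steps per epoch, 8 steps in total
history = model.fit(x, y, batch_size=8, epochs=2, verbose=0)
print(len(python_calls))
```

tf.print() reports the batch shape on each of the 8 steps, while the Python-level counter stays well below that.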

An example where tf.print() does not help

That said, you cannot simply use tf.print() everywhere without thinking. tf.print() runs when the code is actually executed as a computation graph. **In other words, if a type or shape mismatch is detected while the graph is being traced, before any computation runs, an error is raised and tf.print() is never executed.** It's complicated...

@tf.function
def add(x, y):
    z = x + y
    tf.print(z)  # This tf.print is never executed
    return z

x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
ret = add(x, y)
# ValueError: Dimensions must be equal, but are 2 and 3 for 'add' (op: 'AddV2') with input shapes: [2], [3].

In such a case, use the ordinary print() to check whether data of the expected shape is being passed in.

@tf.function
def add(x, y):
    print(x)  # This print runs while the computation graph is being traced
    print(y)
    z = x + y
    return z

x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
ret = add(x, y)
# Tensor("x:0", shape=(2,), dtype=int32)
# Tensor("y:0", shape=(3,), dtype=int32)
# ValueError: Dimensions must be equal, but are 2 and 3 for 'add' (op: 'AddV2') with input shapes: [2], [3].
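The two tools report different things: print() at trace time shows the static shape known to the graph, while tf.shape() yields the actual shape at run time. The sketch below illustrates the difference under an assumption not in the original article, namely an input_signature with an unknown dimension. The names length_of and static_shapes are hypothetical.

```python
import tensorflow as tf

static_shapes = []  # hypothetical helper: shapes seen while tracing

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.int32)])
def length_of(x):
    # Trace time: the length is unknown, so the static shape is [None]
    static_shapes.append(x.shape.as_list())
    print("static shape:", x.shape)
    # Run time: tf.shape() returns the real length of each input
    return tf.shape(x)[0]

n2 = length_of(tf.constant([1, 2]))
n3 = length_of(tf.constant([3, 4, 5]))
print(static_shapes, int(n2), int(n3))
```

Because the signature fixes a single trace, print() fires once with an unknown length, while the returned values differ per call.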

How to update variables in your own function

For example, suppose you want to count how many times your function has been called in graph mode and display that count.

Bad example

count = 0

@tf.function
def dot(x, y):
    global count
    tmp = x * y
    count += 1
    tf.print(count, tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# 1 [3 2]
# tf.Tensor(5, shape=(), dtype=int32)
# 1 [0 4]
# tf.Tensor(4, shape=(), dtype=int32)

"1" is displayed the second time as well. This is because count += 1, being Python code, is executed only once, while the graph is traced.

An example that works well

The correct approach is to use tf.Variable() and assign_add(), as shown below. See: tf.Variable | TensorFlow Core v2.1.0

count = tf.Variable(0)

@tf.function
def dot(x, y):
    tmp = x * y
    count.assign_add(1)
    tf.print(count, tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# 1 [3 2]
# tf.Tensor(5, shape=(), dtype=int32)
# 2 [0 4]
# tf.Tensor(4, shape=(), dtype=int32)
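A module-level variable works, but the counter can also live next to the function it counts. The following sketch is one possible arrangement, not the article's method: CountingDot is a hypothetical tf.Module subclass that creates the tf.Variable once in __init__ and updates it inside a tf.function.

```python
import tensorflow as tf

class CountingDot(tf.Module):  # hypothetical class, for illustration only
    def __init__(self):
        super().__init__()
        # Create the variable once, outside the traced function
        self.count = tf.Variable(0, trainable=False)

    @tf.function
    def __call__(self, x, y):
        self.count.assign_add(1)   # graph op: runs on every call
        tmp = x * y
        tf.print(self.count, tmp)
        return tf.reduce_sum(tmp)

dot = CountingDot()
x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])

r1 = dot(x, y)
r2 = dot(x, w)
calls = int(dot.count.numpy())
print(calls)
```

Each instance keeps its own count, which avoids sharing one global counter across unrelated functions.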

Reference article

Better performance with tf.function | TensorFlow Core (official documentation)
