[PYTHON] Try using tf.metrics

Motivation

When I tried to use tf.metrics.accuracy, I was puzzled that it returns two values (accuracy, update_op) and that the value was not the plain accuracy I expected. The same was true of tf.metrics.recall and tf.metrics.precision. There seem to be almost no Japanese articles about this at the moment, so I'm writing up a note for now.

Behavior of tf.metrics

As the name suggests, it computes various metrics, including accuracy.

However, going by the name alone, consider these two lines:

# labels: 1-D tensor of ground-truth labels
# predictions: 1-D tensor of predicted labels

accuracy, update_op = tf.metrics.accuracy(labels, predictions)
accuracy = tf.reduce_mean(tf.cast(tf.equal(predictions, labels), tf.float32))

You would expect these two accuracy values to be the same. And what is update_op for?

In short, tf.metrics.accuracy behaves as if it remembers all past values. (Internally, it keeps a running total of correct predictions and a running count of examples, and simply reports total ÷ count.)

That is, if every prediction is correct in the first epoch and every prediction is wrong in the second (assuming every epoch uses the same batch size), the first accuracy is 1.00 and the second is 0.50. If every prediction is correct again in the third epoch, the third accuracy is about 0.67.
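To make the "total ÷ count" bookkeeping concrete, here is a minimal pure-Python sketch. It is not TensorFlow's implementation; the class name StreamingAccuracy and the helper batch_accuracy are hypothetical, chosen only to mimic the running total and count described above. It reproduces the 1.00 / 0.50 / 0.67 sequence and also shows why the equal-batch-size caveat matters: with unequal batches, the streaming value differs from the average of per-batch accuracies.

```python
class StreamingAccuracy:
    """Hypothetical emulation of the running total/count kept by tf.metrics.accuracy."""

    def __init__(self):
        self.total = 0  # number of correct predictions seen so far
        self.count = 0  # number of examples seen so far

    def update(self, labels, predictions):
        self.total += sum(1 for y, p in zip(labels, predictions) if y == p)
        self.count += len(labels)
        return self.result()

    def result(self):
        # 0.0 before any data has been seen, matching the initial value of 0
        return self.total / self.count if self.count else 0.0


def batch_accuracy(labels, predictions):
    """Plain per-batch accuracy, i.e. what the reduce_mean expression computes."""
    return sum(1 for y, p in zip(labels, predictions) if y == p) / len(labels)


# The sequence from the text: all correct, all wrong, all correct.
acc = StreamingAccuracy()
history = [
    acc.update([1, 1, 1], [1, 1, 1]),  # 3 / 3
    acc.update([0, 0, 0], [1, 1, 1]),  # 3 / 6
    acc.update([1, 1, 1], [1, 1, 1]),  # 6 / 9
]
print([round(a, 2) for a in history])  # [1.0, 0.5, 0.67]

# With unequal batch sizes, the streaming value is NOT the mean of batch accuracies.
acc2 = StreamingAccuracy()
acc2.update([1], [1])              # 1 correct out of 1
acc2.update([0, 0, 0], [1, 1, 1])  # 0 correct out of 3
mean_of_batches = (batch_accuracy([1], [1]) + batch_accuracy([0, 0, 0], [1, 1, 1])) / 2
print(acc2.result(), mean_of_batches)  # 0.25 0.5
```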

Judging from the TensorFlow issues, many people are confused by this behavior, with comments such as "It's non-intuitive" and "I think tf.metrics.streaming_accuracy would be a better name for this function".

By the way, one respondent explained this behavior, and reading it, it made sense to me. It certainly seems convenient.

How to use tf.metrics

Each tf.metrics function returns two values: here, accuracy and update_op.

Running update_op updates the accuracy. accuracy holds the most recently computed accuracy (its initial value is 0).

In short, it looks like this.

import numpy as np
import tensorflow as tf

labels = tf.placeholder(tf.float32, [None])
predictions = tf.placeholder(tf.float32, [None])
accuracy, update_op = tf.metrics.accuracy(labels, predictions)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    print(sess.run(accuracy))  # initial value: 0

    # First update (all predictions correct)
    sess.run(update_op, feed_dict={
        labels: np.array([1, 1, 1]),
        predictions: np.array([1, 1, 1])
    })
    print(sess.run(accuracy))  # 3 / 3 = 1

    # Second update (all predictions wrong)
    sess.run(update_op, feed_dict={
        labels: np.array([0, 0, 0]),
        predictions: np.array([1, 1, 1])
    })
    print(sess.run(accuracy))  # 3 / 6 = 0.5

    # Third update (all predictions correct)
    sess.run(update_op, feed_dict={
        labels: np.array([1, 1, 1]),
        predictions: np.array([1, 1, 1])
    })
    print(sess.run(accuracy))  # 6 / 9, about 0.67
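The code above uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session). If you just want to reason about the (accuracy, update_op) pair without TensorFlow, it can be mimicked with two closures, as in the sketch below. The function make_streaming_accuracy is hypothetical. Two properties are modeled: reading the value has no side effects, and, as I understand the TF 1.x documentation, the value returned by running update_op matches the updated accuracy.

```python
def make_streaming_accuracy():
    """Hypothetical stand-in for the (accuracy, update_op) pair, without TensorFlow."""
    state = {"total": 0, "count": 0}  # plays the role of the metric's local variables

    def value():
        # Reading the value never changes the state; 0.0 before any update
        return state["total"] / state["count"] if state["count"] else 0.0

    def update(labels, predictions):
        state["total"] += sum(1 for y, p in zip(labels, predictions) if y == p)
        state["count"] += len(labels)
        # Like update_op, hand back the freshly updated value
        return value()

    return value, update


accuracy, update_op = make_streaming_accuracy()
print(accuracy())                       # 0.0
print(update_op([1, 0, 1], [1, 1, 1]))  # 2 / 3
print(accuracy(), accuracy())           # same value twice: reading does not update
```

The design point is that the metric value and the act of updating are separate handles, which is exactly why sess.run(accuracy) alone never moves the number.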

Implementation using tf.metrics

I'm not sure this is the best approach, but here is one example. Please let me know if there is a better way.

def create_metrics(labels, predictions, register_to_summary=True):
    update_op, metrics_op = {}, {}

    # Use tf.metrics to compute accuracy, recall, and precision
    for key, func in zip(('accuracy', 'recall', 'precision'),
                         (tf.metrics.accuracy, tf.metrics.recall, tf.metrics.precision)):
        metrics_op[key], update_op[key] = func(labels, predictions, name=key)

    # Compute f1_score ourselves from precision and recall
    metrics_op['f1_score'] = tf.divide(
        2 * metrics_op['precision'] * metrics_op['recall'],
        metrics_op['precision'] + metrics_op['recall'] + 1e-8,
        name='f1_score'
    )  # 1e-8 guards against division by zero

    entire_update_op = tf.group(*update_op.values())

    if register_to_summary:  # so that tf.summary.merge_all() can collect them later
        for k, v in metrics_op.items():
            tf.summary.scalar(k, v)

    return metrics_op, entire_update_op

metrics_op, entire_update_op = create_metrics(labels, predictions)
merged = tf.summary.merge_all()

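In short, the function computes accuracy, recall, and precision with tf.metrics; derives the F1 score from the streaming precision and recall; bundles all the update ops into a single op (entire_update_op) so that one sess.run call updates every metric; and, optionally, registers each metric as a scalar summary.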
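As far as I know, tf.metrics.precision and tf.metrics.recall are streaming in the same way as accuracy: they accumulate true-positive, false-positive, and false-negative counts across updates. The sketch below is a hypothetical pure-Python model of that bookkeeping (the class StreamingPRF is not a TensorFlow API) for binary 0/1 labels. It returns 0.0 when a denominator is zero, which is an assumption of this sketch, and computes F1 with the same formula and 1e-8 guard as create_metrics.

```python
class StreamingPRF:
    """Hypothetical streaming precision/recall/F1 for binary 0/1 labels."""

    def __init__(self):
        self.tp = 0  # true positives accumulated over all updates
        self.fp = 0  # false positives accumulated over all updates
        self.fn = 0  # false negatives accumulated over all updates

    def update(self, labels, predictions):
        for y, p in zip(labels, predictions):
            if p and y:
                self.tp += 1
            elif p and not y:
                self.fp += 1
            elif y and not p:
                self.fn += 1
            # true negatives affect neither precision nor recall

    def precision(self):
        denom = self.tp + self.fp
        return self.tp / denom if denom else 0.0

    def recall(self):
        denom = self.tp + self.fn
        return self.tp / denom if denom else 0.0

    def f1(self, eps=1e-8):
        p, r = self.precision(), self.recall()
        # Same formula as create_metrics, including the 1e-8 guard
        return 2 * p * r / (p + r + eps)


m = StreamingPRF()
m.update([1, 0, 1, 1], [1, 1, 0, 1])  # batch 1: tp=2, fp=1, fn=1
m.update([0, 0], [0, 1])              # batch 2: fp=1
print(m.precision(), m.recall(), round(m.f1(), 4))
```

Because the counts carry over between batches, the F1 score here is computed from the precision and recall over everything seen so far, not averaged per batch.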

Remarks

Note that the variables these metrics use are local variables, not global variables, so tf.global_variables_initializer() does not initialize them. You need to run:

local_init_op = tf.local_variables_initializer()
sess.run(local_init_op)
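Because the counters live in local variables, running tf.local_variables_initializer() again sets them back to zero, which is one way to get a per-epoch value instead of a value accumulated over the whole run (note that it resets every local variable in the graph, not only the metrics). The pure-Python sketch below models that pattern; ResettableAccuracy and its reset method are hypothetical stand-ins, with reset playing the role of re-running the local initializer at the start of each epoch.

```python
class ResettableAccuracy:
    """Hypothetical streaming accuracy whose reset() mimics re-running the local initializer."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Stand-in for sess.run(tf.local_variables_initializer())
        self.total = 0
        self.count = 0

    def update(self, labels, predictions):
        self.total += sum(1 for y, p in zip(labels, predictions) if y == p)
        self.count += len(labels)

    def result(self):
        return self.total / self.count if self.count else 0.0


# Each epoch is a list of (labels, predictions) batches.
epochs = [
    [([1, 1], [1, 1]), ([0, 0], [1, 1])],  # epoch 1: 2 of 4 correct
    [([1, 1], [1, 1]), ([1, 0], [1, 0])],  # epoch 2: 4 of 4 correct
]

acc = ResettableAccuracy()
per_epoch = []
for batches in epochs:
    acc.reset()  # start every epoch from zero
    for labels, predictions in batches:
        acc.update(labels, predictions)
    per_epoch.append(acc.result())
print(per_epoch)  # [0.5, 1.0]
```

Without the reset call, the second epoch would report 6 / 8 = 0.75, mixing in the first epoch's results.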
