[PYTHON] Essayez d'utiliser tf.metrics

Déclencheur

Lorsque j'ai essayé d'utiliser tf.metrics.accuracy, j'étais troublé car il y avait deux valeurs de retour (précision, mise à jour \ _op) et les valeurs n'étaient pas le taux de réponse correct normal. La même chose était vraie pour tf.metrics.recall et tf.metrics.precision. Il semble qu'il n'y ait presque pas d'articles japonais à ce sujet pour le moment, j'ai donc pris une note pour le moment.

Comportement de tf.metrics

Comme son nom l'indique, il calcule diverses mesures, y compris le taux de réponse correct.

Cependant, si vous ne voyez que le nom,

# labels:Tenseur unidimensionnel avec étiquette de réponse correcte
# predictions:Tenseur unidimensionnel d'étiquette prédite

accuracy, update_op = tf.metrics.accuracy(labels, predictions)
accuracy = tf.reduce_mean(tf.cast(predictions == labels, tf.float32))

Vous vous attendez à ce que ces deux exactitudes aient la même valeur. Aussi, que pensez-vous de update_op?

En conclusion, tf.metrics.accuracy se comporte comme s'il contenait toutes les valeurs passées. (En fait, le nombre total de réponses correctes dans le passé et le nombre de données comptées sont conservés, et seul "total ÷ nombre" est utilisé).

Autrement dit, si vous avez répondu correctement à toutes les questions à la première époque et que toutes les questions étaient erronées à la deuxième époque (et si la taille du lot de chaque époque est toujours la même), la première précision est de 1,00 et la seconde de 0,50. Il devient. Si toutes les questions reçoivent une réponse correcte à la troisième époque, la troisième précision est d'environ 0,67.

Il semble que beaucoup de gens soient confus à propos de ce comportement même si vous regardez les problèmes de tensorflow. Il existe des opinions telles que «Ce n'est pas intuitif» et «Je pense que tf.metrics.streaming \ _accuracy est un meilleur nom pour cette fonction».

À propos, un répondant a dit

Et cela. Je vois, j'en suis tombé amoureux. Cela semble certainement pratique.

Comment utiliser tf.metrics

tf.metrics a deux valeurs de retour. précision et mise à jour \ _op.

L'appel de update \ _op mettra à jour le taux de réponse correct. la précision contient le dernier taux de réponse correcte calculé (la valeur initiale est 0).

Bref, ça ressemble à ça.

import numpy as np
import tensorflow as tf

labels = tf.placeholder(tf.float32, [None])
predictions = tf.placeholder(tf.float32, [None])
accuracy, update_op = tf.metrics.accuracy(labels, predictions)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    print(sess.run(accuracy))  #Valeur initiale 0

    #Première fois(Toutes les questions sont correctes)
    sess.run(update_op, feed_dict={
        labels: np.array([1, 1, 1]),
        predictions: np.array([1, 1, 1])
    })
    print(sess.run(accuracy))  # 3 / 3 = 1

    #Deuxième fois(Toutes les questions sont fausses)
    sess.run(update_op, feed_dict={
        labels: np.array([0, 0, 0]),
        predictions: np.array([1, 1, 1])
    })
    print(sess.run(accuracy))  # 3 / 6 = 0.5

    #Troisième fois(Toutes les questions sont correctes)
    sess.run(update_op, feed_dict={
        labels: np.array([1, 1, 1]),
        predictions: np.array([1, 1, 1])
    })
    print(sess.run(accuracy))  # 6 / 9 =Environ 0.67

Implémentation à l'aide de tf.metrics

Je ne sais pas si c'est bon, mais ça ressemble à ça, par exemple. S'il vous plaît laissez-moi savoir s'il existe un autre bon moyen.

def create_metrics(labels, predictions, register_to_summary=True):
    update_op, metrics_op = {}, {}

    # accuracy, recall,tf pour le calcul de précision.Utiliser des métriques
    for key, func in zip(('accuracy', 'recall', 'precision'),
                         (tf.metrics.accuracy, tf.metrics.recall, tf.metrics.precision)):
        metrics_op[key], update_op[key] = func(labels, predictions, name=key)

    # f1_le score est calculé par vous-même
    metrics_op['f1_score'] = tf.divide(
        2 * metrics_op['precision'] * metrics_op['recall'],
        metrics_op['precision'] + metrics_op['recall'] + 1e-8,
        name='f1_score'
    )  # 1e-8 est une mesure de division zéro

    entire_update_op = tf.group(*update_op.values())

    if register_to_summary:  #Plus tard tf.summary.merge_all()pouvoir faire
        for k, v in metrics_op.items():
            tf.summary.scalar(k, v)

    return metrics_op, entire_update_op

metrics_op, entire_update_op = create_metrics(labels, predictions)
merged = tf.summary.merge_all()

Ce que je veux dire et faire c'est, en bref

à propos de ça.

Remarques

À propos, ces métriques sont des variables locales, pas des variables globales.

local_init_op = tf.local_variables_initializer()
sess.run(local_init_op)

besoin de le faire.

Recommended Posts

Essayez d'utiliser tf.metrics
Essayez d'utiliser docker-py
Essayez d'utiliser Cookiecutter
Essayez d'utiliser PDFMiner
Essayez d'utiliser Selenium
Essayez d'utiliser scipy
Essayez d'utiliser pandas.DataFrame
Essayez d'utiliser matplotlib
Essayez d'utiliser PyODE
[Azure] Essayez d'utiliser Azure Functions
Essayez d'utiliser virtualenv maintenant
Essayez d'utiliser W & B
Essayez d'utiliser Django templates.html
[Kaggle] Essayez d'utiliser LGBM
Essayez d'utiliser Tkinter de Python
Essayez d'utiliser Tweepy [Python2.7]
Essayez d'utiliser collate_fn de Pytorch
Essayez d'utiliser PythonTex avec Texpad.
[Python] Essayez d'utiliser le canevas de Tkinter
Essayez d'utiliser l'image Docker de Jupyter
Essayez d'utiliser scikit-learn (1) - Clustering K par méthode moyenne
Essayez l'optimisation des fonctions à l'aide d'Hyperopt
Essayez d'utiliser matplotlib avec PyCharm
Essayez d'utiliser Azure Logic Apps
Essayez d'utiliser Kubernetes Client -Python-
[Kaggle] Essayez d'utiliser xg boost
Essayez d'utiliser l'API Twitter
Essayez d'utiliser OpenCV sur Windows
Essayez d'utiliser Jupyter Notebook de manière dynamique
Essayez d'utiliser AWS SageMaker Studio
Essayez de tweeter automatiquement en utilisant Selenium.
Essayez d'utiliser SQLAlchemy + MySQL (partie 1)
Essayez d'utiliser l'API Twitter
Essayez d'utiliser SQLAlchemy + MySQL (partie 2)
Essayez d'utiliser la fonction de modèle de Django
Essayez d'utiliser l'API PeeringDB 2.0
Essayez d'utiliser la fonction de brouillon de Pelican
Essayez d'utiliser pytest-Overview and Samples-
Essayez d'utiliser le folium avec anaconda
Essayez d'utiliser l'API Admin de la passerelle Janus
[Statistiques] [R] Essayez d'utiliser la régression par points de division.
Essayez d'utiliser Spyder inclus dans Anaconda
Essayez d'utiliser des modèles de conception (édition exportateur)
Essayez d'utiliser Pillow sur iPython (partie 1)
Essayez d'utiliser Pillow sur iPython (partie 2)
Essayez d'utiliser l'API de Pleasant (python / FastAPI)
Essayez d'utiliser LevelDB avec Python (plyvel)
Essayez d'utiliser pynag pour configurer Nagios
Essayez d'utiliser la fonction de débogage à distance de PyCharm
Essayez d'utiliser ArUco avec Raspberry Pi
Essayez d'utiliser LiDAR bon marché (Camsense X1)
[Serveur de location Sakura] Essayez d'utiliser flask.
Essayez d'utiliser Pillow sur iPython (partie 3)
Renforcer l'apprentissage 8 Essayez d'utiliser l'interface utilisateur de Chainer
Essayez d'obtenir des statistiques en utilisant e-Stat
Essayez d'utiliser l'API d'action de Python argparse
Essayez d'utiliser le module Python Cmd
Essayez d'utiliser le networkx de Python avec AtCoder
Essayez d'utiliser LeapMotion avec Python
Essayez d'utiliser la reconnaissance de caractères manuscrits (OCR) de GCP
Essayez d'utiliser Amazon DynamoDB à partir de Python