[PYTHON] Versuchen Sie es mit tf.metrics

Auslösen

Als ich versuchte, tf.metrics.accuracy zu verwenden, war ich beunruhigt darüber, dass es zwei Rückgabewerte gab (Genauigkeit, Aktualisierung \ _op) und die Werte nicht die normale korrekte Antwortrate waren. Gleiches gilt für tf.metrics.recall und tf.metrics.precision. Es scheint, dass es im Moment fast keine japanischen Artikel darüber gibt, also habe ich mir vorerst eine Notiz gemacht.

Verhalten von tf.metrics

Wie der Name schon sagt, werden verschiedene Metriken berechnet, einschließlich der richtigen Antwortrate.

Wenn Sie jedoch nur den Namen sehen,

# labels:1-dimensionaler Tensor mit korrektem Antwortetikett
# predictions:Voraussichtlicher eindimensionaler Tensor

accuracy, update_op = tf.metrics.accuracy(labels, predictions)
accuracy = tf.reduce_mean(tf.cast(predictions == labels, tf.float32))

Sie würden erwarten, dass diese beiden Genauigkeiten den gleichen Wert haben. Was denkst du über update_op?

Zusammenfassend verhält sich tf.metrics.accuracy so, als ob es alle früheren Werte enthält. (Tatsächlich bleiben die Gesamtzahl der richtigen Antworten in der Vergangenheit und die Anzahl der Daten erhalten, und es wird nur die "Gesamtzahl" verwendet.)

Das heißt, wenn Sie in der ersten Epoche alle Fragen richtig beantwortet haben und in der zweiten Epoche alle Fragen falsch waren (und wenn die Stapelgröße jeder Epoche immer gleich ist), beträgt die erste Genauigkeit 1,00 und die zweite Genauigkeit 0,50. Es wird. Wenn alle Fragen in der dritten Epoche richtig beantwortet wurden, beträgt die dritte Genauigkeit etwa 0,67.

Es scheint, dass viele Menschen über dieses Verhalten verwirrt sind, selbst wenn Sie sich [Tensorflow-Probleme] ansehen (https://github.com/tensorflow/tensorflow/issues/9498). Es gibt Meinungen wie "Es ist nicht intuitiv" und "Ich denke, tf.metrics.streaming \ _accuracy ist ein besserer Name für diese Funktion."

Ein Befragter sagte übrigens

Und das. Ich verstehe, ich habe mich in sie verliebt. Es scheint sicherlich bequem.

Verwendung von tf.metrics

tf.metrics hat zwei Rückgabewerte. Genauigkeit und Aktualisierung \ _op.

Wenn Sie update \ _op aufrufen, wird die richtige Antwortrate aktualisiert. Die Genauigkeit enthält die zuletzt berechnete korrekte Antwortrate (Anfangswert ist 0).

Kurz gesagt, es sieht so aus.

import numpy as np
import tensorflow as tf

labels = tf.placeholder(tf.float32, [None])
predictions = tf.placeholder(tf.float32, [None])
accuracy, update_op = tf.metrics.accuracy(labels, predictions)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    print(sess.run(accuracy))  #Anfangswert 0

    #Erstes Mal(Alle Fragen richtig)
    sess.run(update_op, feed_dict={
        labels: np.array([1, 1, 1]),
        predictions: np.array([1, 1, 1])
    })
    print(sess.run(accuracy))  # 3 / 3 = 1

    #Zweites Mal(Alle Fragen falsch)
    sess.run(update_op, feed_dict={
        labels: np.array([0, 0, 0]),
        predictions: np.array([1, 1, 1])
    })
    print(sess.run(accuracy))  # 3 / 6 = 0.5

    #Drittes Mal(Alle Fragen richtig)
    sess.run(update_op, feed_dict={
        labels: np.array([1, 1, 1]),
        predictions: np.array([1, 1, 1])
    })
    print(sess.run(accuracy))  # 6 / 9 =Über 0.67

Implementierung mit tf.metrics

Ich weiß nicht, ob das gut ist, aber es sieht zum Beispiel so aus. Bitte lassen Sie mich wissen, ob es einen anderen guten Weg gibt.

def create_metrics(labels, predictions, register_to_summary=True):
    update_op, metrics_op = {}, {}

    # accuracy, recall,tf zur Präzisionsberechnung.Verwenden Sie Metriken
    for key, func in zip(('accuracy', 'recall', 'precision'),
                         (tf.metrics.accuracy, tf.metrics.recall, tf.metrics.precision)):
        metrics_op[key], update_op[key] = func(labels, predictions, name=key)

    # f1_Die Punktzahl wird von Ihnen selbst berechnet
    metrics_op['f1_score'] = tf.divide(
        2 * metrics_op['precision'] * metrics_op['recall'],
        metrics_op['precision'] + metrics_op['recall'] + 1e-8,
        name='f1_score'
    )  # 1e-8 ist ein Nullteilungsmaß

    entire_update_op = tf.group(*update_op.values())

    if register_to_summary:  #Später tf.summary.merge_all()tun können
        for k, v in metrics_op.items():
            tf.summary.scalar(k, v)

    return metrics_op, entire_update_op

metrics_op, entire_update_op = create_metrics(labels, predictions)
merged = tf.summary.merge_all()

Was ich sagen und tun möchte, ist kurz gesagt

darüber.

Bemerkungen

Übrigens sind diese Metriken lokale Variablen, keine globalen Variablen.

local_init_op = tf.local_variables_initializer()
sess.run(local_init_op)

müssen es tun.

Recommended Posts

Versuchen Sie es mit tf.metrics
Versuchen Sie es mit Docker-Py
Versuchen Sie es mit einem Ausstecher
Versuchen Sie es mit PDFMiner
Versuchen Sie es mit Selen
Versuchen Sie es mit scipy
Versuchen Sie es mit pandas.DataFrame
Versuchen Sie es mit matplotlib
Versuchen Sie es mit PyODE
[Azure] Versuchen Sie, Azure-Funktionen zu verwenden
Versuchen Sie es jetzt mit virtualenv
Versuchen Sie es mit W & B.
Versuchen Sie es mit Django templates.html
[Kaggle] Versuchen Sie es mit LGBM
Versuchen Sie es mit Pythons Tkinter
Versuchen Sie es mit Tweepy [Python2.7]
Versuchen Sie es mit Pytorchs collate_fn
Versuchen Sie, PythonTex mit Texpad zu verwenden.
[Python] Versuchen Sie, Tkinters Leinwand zu verwenden
Versuchen Sie es mit Jupyters Docker-Image
Versuchen Sie es mit Scikit-Learn (1) - K-Clustering nach Durchschnittsmethode
Versuchen Sie die Funktionsoptimierung mit Hyperopt
Versuchen Sie es mit matplotlib mit PyCharm
Versuchen Sie es mit Azure Logic Apps
Versuchen Sie es mit Kubernetes Client -Python-
[Kaggle] Versuchen Sie es mit xg boost
Versuchen Sie es mit der Twitter-API
Versuchen Sie es mit OpenCV unter Windows
Versuchen Sie, Jupyter Notebook dynamisch zu verwenden
Versuchen Sie es mit AWS SageMaker Studio
Versuchen Sie, automatisch mit Selen zu twittern.
Versuchen Sie es mit SQLAlchemy + MySQL (Teil 1)
Versuchen Sie es mit der Twitter-API
Versuchen Sie es mit SQLAlchemy + MySQL (Teil 2)
Versuchen Sie es mit der Vorlagenfunktion von Django
Versuchen Sie es mit der PeeringDB 2.0-API
Versuchen Sie es mit der Entwurfsfunktion von Pelican
Versuchen Sie es mit pytest-Overview und Samples-
Versuchen Sie es mit Folium mit Anakonda
Versuchen Sie es mit der Admin-API von Janus Gateway
[Statistik] [R] Versuchen Sie, die Teilungspunktregression zu verwenden.
Versuchen Sie es mit Spyder, das in Anaconda enthalten ist
Versuchen Sie es mit Designmustern (Exporter Edition)
Versuchen Sie es mit Pillow auf iPython (Teil 1)
Versuchen Sie es mit Pillow auf iPython (Teil 2)
Versuchen Sie es mit der Pleasant-API (Python / FastAPI).
Versuchen Sie es mit LevelDB mit Python (plyvel)
Versuchen Sie, Nagios mit pynag zu konfigurieren
Versuchen Sie, die Remote-Debugging-Funktion von PyCharm zu verwenden
Versuchen Sie es mit ArUco mit Raspberry Pi
Versuchen Sie es mit billigem LiDAR (Camsense X1)
[Sakura-Mietserver] Versuchen Sie es mit einer Flasche.
Versuchen Sie es mit Pillow auf iPython (Teil 3).
Stärkung des Lernens 8 Versuchen Sie, die Chainer-Benutzeroberfläche zu verwenden
Versuchen Sie, Statistiken mit e-Stat abzurufen
Versuchen Sie es mit der Aktions-API von Python argparse
Versuchen Sie es mit dem Python Cmd-Modul
Versuchen Sie, Pythons networkx mit AtCoder zu verwenden
Versuchen Sie es mit LeapMotion mit Python
Versuchen Sie es mit der handgeschriebenen Zeichenerkennung (OCR) von GCP.
Versuchen Sie es mit Amazon DynamoDB von Python