[PYTHON] [TensorFlow 2] So überprüfen Sie den Inhalt von Tensor im Diagrammmodus

Einführung

Bei Verwendung einer selbst erstellten Funktion für die Verlustfunktion während des Trainings oder des Konvertierungsteils des Datensatzes in TensorFlow (2.x) habe ich versucht zu bestätigen, dass der Wert wie erwartet ist. Print ( Aufruf) gibt den Wert möglicherweise nicht aus. Sie könnten einen Debugger wie "tfdbg" verwenden, aber hier ist eine einfachere Möglichkeit, das sogenannte "Print-Debugging" durchzuführen.

Für tfdbgVerwenden Sie tfdbg, um Keras nan und inf - Qiita zu zerstören

Überprüfungsumgebung

Im folgenden Code wird angenommen, dass ↓ geschrieben ist.

import tensorflow as tf 

Verhalten, wenn print () auf Tensor ausgeführt wird

Eager Execution ist jetzt die Standardeinstellung in TensorFlow 2.x. Solange Sie "Tensor" mit dem Interpreter zusammenstellen, können Sie einfach "print ()" ausführen, um den Wert von "Tensor" anzuzeigen. Sie können den Wert auch als "ndarray" mit ".numpy ()" erhalten.

Wenn Sie den Wert sehen können

x = tf.constant(10)
y = tf.constant(20)
z = x + y
print(z)         # tf.Tensor(30, shape=(), dtype=int32)
print(z.numpy()) # 30

Wenn der Wert nicht bestätigt werden kann (1)

Da es langsam ist, den Wert von "Tensor" zum Zeitpunkt des Lernens / der Bewertung nacheinander auszuwerten, ist es möglich, die im Grafikmodus zu verarbeitende Funktion durch Hinzufügen des Dekorators "@ tf.function" zu definieren. ** Im Diagrammmodus definierte Vorgänge werden zusammen als Berechnungsdiagramm (außerhalb von Python) verarbeitet, sodass Sie die Werte im Berechnungsprozess nicht sehen können. ** ** **

@tf.function
def dot(x, y):
    tmp = x * y
    print(tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# Tensor("mul:0", shape=(2,), dtype=int32)
# tf.Tensor(5, shape=(), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32)

Sie können das Berechnungsergebnis außerhalb von dot () ausgeben, aber Sie können den Wert darin nicht sehen. Selbst wenn Sie "dot ()" mehrmals aufrufen, wird "print ()" im Inneren zum Zeitpunkt der Analyse des Diagramms grundsätzlich nur einmal ausgeführt.

In diesem Fall können Sie natürlich die @ tf.function () entfernen, um den Wert zu sehen.

def dot(x, y):
    tmp = x * y
    print(tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# tf.Tensor([3 2], shape=(2,), dtype=int32)
# tf.Tensor(5, shape=(), dtype=int32)
# tf.Tensor([0 4], shape=(2,), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32)

Wenn der Wert nicht bestätigt werden kann (2)

Es kann implizit im Grafikmodus ausgeführt werden, z. B. wenn map () für tf.data.Dataset verarbeitet wird oder wenn Sie Ihre eigene Verlustfunktion in Keras verwenden. In diesem Fall können Sie den Wert in der Mitte mit print () nicht sehen, auch wenn ** @ tf.function nicht hinzugefügt wird. ** ** **

def fourth_power(x):
    z = x * x
    print(z)
    z = z * z
    return z

ds = tf.data.Dataset.range(10).map(fourth_power)
for i in ds:
    print(i)

# Tensor("mul:0", shape=(), dtype=int64)
# tf.Tensor(0, shape=(), dtype=int64)
# tf.Tensor(1, shape=(), dtype=int64)
# tf.Tensor(16, shape=(), dtype=int64)
# :

tf.print()

Verwenden Sie tf.print (), um den Wert von Tensor im Prozess anzuzeigen, der im Grafikmodus ausgeführt wird. tf.print | TensorFlow Core v2.1.0

Sie können die Werte wie folgt anzeigen oder "tf.shape ()" verwenden, um die Abmessungen und die Größe von "Tensor" anzuzeigen. Bitte verwenden Sie es auch zum Debuggen, wenn Sie wütend werden, wenn die Abmessungen und Größen aus irgendeinem Grund in Ihrer eigenen Funktion nicht übereinstimmen.

@tf.function
def dot(x, y):
    tmp = x * y
    tf.print(tmp)
    tf.print(tf.shape(tmp))
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# [3 2] 
# [2]
# tf.Tensor(5, shape=(), dtype=int32)
# [0 4]
# [2]
# tf.Tensor(4, shape=(), dtype=int32)

Wenn Sie "Dataset.map ()" debuggen möchten, können Sie "tf.print ()" verwenden.

def fourth_power(x):
    z = x * x
    tf.print(z)
    z = z * z
    return z

ds = tf.data.Dataset.range(10).map(fourth_power)
for i in ds:
    print(i)
# 0
# tf.Tensor(0, shape=(), dtype=int64)
# 1
# tf.Tensor(1, shape=(), dtype=int64)
# 4
# tf.Tensor(16, shape=(), dtype=int64)
# :

Beispiel für tf.print () funktioniert nicht gut

Wenn Sie dann an nichts denken und den Inhalt von Tensor mit tf.print () ausgeben sollten, ist dies nicht der Fall. tf.print () wird ausgeführt, wenn der Prozess tatsächlich als berechneter Graph ausgeführt wird. Mit anderen Worten, wenn Sie feststellen, dass die Typen und Dimensionen zum Zeitpunkt der Diagrammanalyse nicht übereinstimmen, ohne eine ** Berechnung durchzuführen, und ein Fehler auftritt, wird die Verarbeitung von "tf.print ()" nicht ausgeführt. ** Es ist kompliziert ...

@tf.function
def add(x, y):
    z = x + y
    tf.print(z) #Dieser Druck wird nicht ausgeführt
    return z

x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
ret = add(x, y)
# ValueError: Dimensions must be equal, but are 2 and 3 for 'add' (op: 'AddV2') with input shapes: [2], [3].

In einem solchen Fall ist es besser, gewöhnliches print () zu verwenden, um zu überprüfen, ob die Daten der gewünschten Form übergeben werden.

@tf.function
def add(x, y):
    print(x) #Dieser Druck wird während der Analyse des Berechnungsdiagramms ausgeführt
    print(y)
    z = x + y
    return z

x = tf.constant([1, 2])
y = tf.constant([3, 4, 5])
ret = add(x, y)
# Tensor("x:0", shape=(2,), dtype=int32)
# Tensor("y:0", shape=(3,), dtype=int32)
# ValueError: Dimensions must be equal, but are 2 and 3 for 'add' (op: 'AddV2') with input shapes: [2], [3].

So aktualisieren Sie Variablen in Ihrer eigenen Funktion

Zählen Sie beispielsweise, wie oft Ihre eigene Funktion im Diagrammmodus aufgerufen wurde, und zeigen Sie die Anzahl der Aufrufe an.

Schlechtes Beispiel

count = 0

@tf.function
def dot(x, y):
    global count
    tmp = x * y
    count += 1
    tf.print(count, tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# 1 [3 2]
# tf.Tensor(5, shape=(), dtype=int32)
# 1 [0 4]
# tf.Tensor(4, shape=(), dtype=int32)

"1" wird auch beim zweiten Mal angezeigt. Dies liegt daran, dass count + = 1 als Python-Code während der Diagrammanalyse nur einmal ausgeführt wird.

Ein Beispiel, das gut funktioniert

Die richtige Antwort ist, "tf.Variable ()" und "assign_add ()" wie unten gezeigt zu verwenden. tf.Variable | TensorFlow Core v2.1.0

count = tf.Variable(0)

@tf.function
def dot(x, y):
    tmp = x * y
    count.assign_add(1)
    tf.print(count, tmp)
    return tf.reduce_sum(tmp)

x = tf.constant([1, 2])
y = tf.constant([3, 1])
w = tf.constant([0, 2])
print(dot(x, y))
print(dot(x, w))
# 1 [3 2]
# tf.Tensor(5, shape=(), dtype=int32)
# 2 [0 4]
# tf.Tensor(4, shape=(), dtype=int32)

Referenzartikel

Leistungsverbesserung mit tf.function | TensorFlow Core (Offizielles Dokument)

Recommended Posts

[TensorFlow 2] So überprüfen Sie den Inhalt von Tensor im Diagrammmodus
So überprüfen Sie die Version von Django
So überprüfen Sie die Speichergröße einer Variablen in Python
So überprüfen Sie die Speichergröße eines Wörterbuchs in Python
Die erste künstliche Intelligenz. So überprüfen Sie die installierte Version von Tensorflow.
So überprüfen Sie anhand des Hashwerts, ob der Inhalt des Wörterbuchs in Python identisch ist
So vergleichen Sie, ob der Inhalt der Objekte in scipy.sparse.csr_matrix identisch ist
Wie man eine Benutzergruppe mit Slack-Benachrichtigung erwähnt, wie man die ID einer Benutzergruppe überprüft
So installieren Sie das Deep Learning Framework Tensorflow 1.0 in der Windows Anaconda-Umgebung
So überprüfen Sie in Python, ob sich eines der Elemente einer Liste in einer anderen Liste befindet
[Ubuntu] So löschen Sie den gesamten Inhalt des Verzeichnisses
So finden Sie die optimale Anzahl von Clustern für k-means
So sehen Sie den Inhalt der ipynb-Datei des Jupyter-Notizbuchs
So verbinden Sie den Inhalt der Liste mit einer Zeichenfolge
So führen Sie TensorFlow 1.0-Code in 2.0 aus
[Super einfach! ] So zeigen Sie den Inhalt von Wörterbüchern und Listen einschließlich Japanisch in Python an
So geben Sie eine unendliche Anzahl von Toleranzen in der Überprüfung der numerischen Argumentvalidierung von argparse an
So bestimmen Sie die Existenz eines Selenelements in Python
Wie Sie die interne Struktur eines Objekts in Python kennen
So ändern Sie die Farbe nur der mit Tkinter gedrückten Taste
[EC2] So installieren Sie Chrome und den Inhalt jedes Befehls
Geben Sie den Inhalt von ~ .xlsx im Ordner mit Python in HTML aus
So ermitteln Sie die Scheitelpunktkoordinaten eines Features in ArcPy
Erstellen Sie eine Funktion, um den Inhalt der Datenbank in Go abzurufen
Überprüfen Sie das Verhalten des Zerstörers in Python
So überprüfen Sie die Version von opencv mit Python
Grundlagen von PyTorch (1) - Verwendung von Tensor-
So überprüfen Sie die lokale GAE über den iPhone-Browser im selben LAN
[Einführung in Python] So sortieren Sie den Inhalt einer Liste effizient mit Listensortierung
Ich habe ein Programm erstellt, um die Größe einer Datei mit Python zu überprüfen
Ich habe versucht, den Höhenwert von DTM in einem Diagramm anzuzeigen
Kopieren und Einfügen des Inhalts eines Blattes im JSON-Format mit einer Google-Tabelle (mithilfe von Google Colab)
So berechnen Sie die Volatilität einer Marke
Verwendung der C-Bibliothek in Python
So finden Sie den Bereich des Boronoi-Diagramms
Die Ungenauigkeit von Tensorflow war auf log (0) zurückzuführen.
Zusammenfassung zum Importieren von Dateien in Python 3
So führen Sie CNN in 1 Systemnotation mit Tensorflow 2 aus
Verwendung der Grafikzeichnungsbibliothek Bokeh
Zusammenfassung der Verwendung von MNIST mit Python
So überprüfen / extrahieren Sie Dateien im RPM-Paket
So erhalten Sie die Dateien im Ordner [Python]
[TF] So erstellen Sie Tensorflow in einer Proxy-Umgebung
So übergeben Sie das Ergebnis der Ausführung eines Shell-Befehls in einer Liste in Python
So identifizieren Sie die Zugriffsquelle in der generischen Klassenansicht von Django eindeutig
So zählen Sie die Anzahl der Elemente in Django und geben sie in die Vorlage aus
Ein Memorandum zur Ausführung des Befehls! Sudo magic in Jupyter Notebook
So ermitteln Sie den Koeffizienten der ungefähren Kurve, die in Python durch die Scheitelpunkte verläuft
So ändern Sie das Erscheinungsbild nicht ausgewählter Fremdschlüsselfelder in Djangos Modellformular
So stellen Sie die Schriftbreite des in pyenv eingegebenen Jupyter-Notizbuchs gleich
So erhalten Sie mit Python eine Liste der Dateien im selben Verzeichnis
Einfache Möglichkeit, die Quelle der Python-Module zu überprüfen
So rufen Sie den n-ten größten Wert in Python ab
Ich habe versucht, die in Python installierten Pakete grafisch darzustellen
Wie man die Portnummer des xinetd-Dienstes kennt
So erhalten Sie den Variablennamen selbst in Python
Ausführen des in Ansible Tower hinzugefügten Ansible-Moduls
Vorlage des Python-Skripts zum Lesen des Inhalts der Datei
Wie identifiziere ich das Element mit der geringsten Anzahl von Zeichen in einer Python-Liste?
Lassen Sie die Häkchen nach dem Dezimalpunkt in matplotlib weg