[PYTHON] Ich möchte mich mit Backpropagation (tf.custom_gradient) (Tensorflow) selbst verwenden.

Normales Schreiben bei Verwendung von custom_graident
@tf.custom_gradient
def gradient_reversal(x):
  y = x
  def grad(dy):
    return - dy
  return y, grad

#Bei Verwendung im Modell
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()

    def call(self, x):
        return gradient_reversal(x)
class MyModel2(tf.keras.Model):
    def __init__(self):
        super(MyModel2, self).__init__()
        self.alpha = self.add_weight(name="alpha", initializer=tf.keras.initializers.Ones())

    @tf.custom_gradient
    def forward(self, x):
        y = self.alpha * x

        def backward(w, variables=None):
            with tf.GradientTape() as tape:
                tape.watch(w)
                z = - self.alpha * w

            grads = tape.gradient(z, [w])
            return z, grads

        return y, backward

    def call(self, x):
        return self.forward(x)
TypeError: If using @custom_gradient with a function that uses variables, then grad_fn must accept a keyword argument 'variables'.

Verifizierungs-Schlüssel

import tensorflow as tf


optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)


class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.alpha = self.add_weight(name="alpha", initializer=tf.keras.initializers.Ones())

    @tf.custom_gradient
    def forward(self, x):
        y = self.alpha * x
        tf.print("forward")
        tf.print("  y: ", y)

        def backward(w, variables=None):
            z = self.alpha * w
            tf.print("backward")
            tf.print("  z: ", z)
            tf.print("  variables: ", variables)
            return z, variables

        return y, backward

    def call(self, x):
        return self.forward(x)


class MyModel2(tf.keras.Model):
    def __init__(self):
        super(MyModel2, self).__init__()
        self.alpha = self.add_weight(name="alpha", initializer=tf.keras.initializers.Ones())

    @tf.custom_gradient
    def forward(self, x):
        y = self.alpha * x
        tf.print("forward")
        tf.print("  y: ", y)

        def backward(w, variables=None):
            with tf.GradientTape() as tape:
                tape.watch(w)
                z = - self.alpha * w

            grads = tape.gradient(z, [w])

            tf.print("backward")
            tf.print("  z: ", z)
            tf.print("  variables: ", variables)
            tf.print("  alpha: ", self.alpha)
            tf.print("  grads: ", grads)
            return z, grads

        return y, backward

    def call(self, x):
        return self.forward(x)


for model in [MyModel(), MyModel2()]:
    print()
    print()
    print()
    print(model.name)
    for i in range(10):
        with tf.GradientTape() as tape:
            x = tf.Variable(1.0, tf.float32)
            y = model(x)

        grads = tape.gradient(y, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        tf.print("step")
        tf.print("  y:", y)
        tf.print("  grads:", grads)
        print()

Recommended Posts

Ich möchte mich mit Backpropagation (tf.custom_gradient) (Tensorflow) selbst verwenden.
Ich möchte R-Datensatz mit Python verwenden
Ich möchte in der Einschlussnotation drucken
Ich möchte ein Glas aus Python verwenden
Ich möchte Linux auf dem Mac verwenden
Ich möchte IPython Qt Console verwenden
Ich möchte Matplotlib in PySimpleGUI einbetten
DQN mit TensorFlow implementiert (ich wollte ...)
Ich möchte die Django Debug Toolbar in Ajax-Anwendungen verwenden
Ich möchte Dunnetts Test in Python machen
Ich möchte MATLAB feval mit Python verwenden
Ich möchte Datetime.now in Djangos Test reparieren
Ich möchte mit Python ein Fenster erstellen
Ich möchte DB-Informationen in einer Liste speichern
Ich möchte verschachtelte Dicts in Python zusammenführen
Ich möchte Temporäres Verzeichnis mit Python2 verwenden
Ich möchte Ceres Solver aus Python verwenden
Ich möchte -inf nicht mit np.log verwenden
Ich möchte ip vrf mit SONiC verwenden
[Ich möchte Bilder mit Tensorflow klassifizieren] (2) Lassen Sie uns Bilder klassifizieren
Ich möchte die Aktivierungsfunktion Mish verwenden
Ich möchte den Fortschritt in Python anzeigen!
Ich möchte Python in der Umgebung von pyenv + pipenv unter Windows 10 verwenden
Ich möchte eine Variable in einen Python-String einbetten
Ich möchte mit einem Knopf am Kolben übergehen
Ich möchte in Python schreiben! (2) Schreiben wir einen Test
Auch mit JavaScript möchte ich Python `range ()` sehen!
Ich möchte eine Datei mit Python zufällig testen
Ich möchte mit einem Roboter in Python arbeiten.
Ich möchte in Python schreiben! (3) Verwenden Sie Mock
[TensorFlow] Ich möchte Fenster mit Ragged Tensor verarbeiten
Ich habe versucht zusammenzufassen, wie man Pandas von Python benutzt
Ich möchte OpenJDK 11 mit Ubuntu Linux 18.04 LTS / 18.10 verwenden
Ich möchte am Ende etwas mit Python machen
Ich möchte Strings in Kotlin wie Python manipulieren!
Ich möchte eine Python-Datenquelle in Re: Dash verwenden, um Abfrageergebnisse zu erhalten
Ich möchte SUDOKU lösen
[TensorFlow] Ich möchte die Indizierung für Ragged Tensor beherrschen
Ich möchte das neueste gcc verwenden, auch wenn ich keine Sudo-Berechtigungen habe! !!
Ich möchte R-Funktionen einfach mit ipython notebook verwenden
Ich möchte eine Spalte mit NA in R einfach löschen
Ich möchte so etwas wie Uniq in Python sortieren
Ich möchte nur die SudachiPy-Normalisierungsverarbeitung verwenden
[Python] Ich möchte die Option -h mit argparse verwenden
Ich möchte eine virtuelle Umgebung mit Jupyter Notebook verwenden!
[Django] Ich möchte mich nach einer neuen Registrierung automatisch anmelden
Ich möchte den Wörterbuchtyp in der Liste eindeutig machen
[Einführung in Pytorch] Ich möchte Sätze in Nachrichtenartikeln generieren
Ich möchte eindeutige Werte in einem Array oder Tupel zählen
Ich möchte die gültigen Zahlen im Numpy-Array ausrichten
Ich möchte Python mit VS-Code ausführen können
Ich möchte eine schöne Ergänzung zu input () in Python hinzufügen
Ich möchte VS Code und Spyder ohne Anakonda verwenden! !! !!
Ich wollte den AWS-Schlüssel nicht in das Programm schreiben
[Für diejenigen, die TPU verwenden möchten] Ich habe versucht, die Tensorflow Object Detection API 2 zu verwenden
Ich möchte Shortcut-Übersetzungen wie die DeepL-App auch unter Linux verwenden
Verwendung von Klassen in Theano
[Linux] Ich möchte das Datum wissen, an dem sich der Benutzer angemeldet hat
Mock in Python-Wie man Mox benutzt
Ich möchte APG4b mit Python lösen (nur 4.01 und 4.04 in Kapitel 4)