Zeichnen Sie ein CNN-Diagramm in Python

ConvNet Drawer (hinzugefügt am 4. Januar 2018)

** Die folgenden Tools wurden veröffentlicht und wir empfehlen, sie zu verwenden. ** ** ** Ich habe ein Tool erstellt, das die Architektur veranschaulicht, wenn das Faltungsnetzwerk wie Keras definiert ist

Wenn Sie ein Modell in einer Notation wie dem sequentiellen Modell von Keras definieren, haben wir ein Tool erstellt, das die Architektur gut veranschaulicht. Es kann sich um eine abhängige Bibliothek handeln, da es sich um ein Tool handelt, das nur Text ausgibt. https://github.com/yu4u/convnet-drawer

Überblick

Verwenden Sie Python + Pydot + Graphviz, um ein Diagramm der CNN-Architektur zu zeichnen. Meine Motivation war es, auf https://github.com/jettan/tikz_cnn zu schauen und ein ähnliches Diagramm in Python anstelle von TeX zu zeichnen.

Vorbereitung

Installieren Sie pydotplus und graphviz. Ich benutze Conda, aber ich denke, Pip ist in Ordnung (nicht überprüft).

conda install -c conda-forge pydotplus
conda install graphviz

Bereiten Sie eine entsprechende Punktdatei vor, laden Sie sie mit pydotplus, speichern Sie das Bild und zeigen Sie das Bild an. (Das Bild wird auf Jupyter angezeigt. Bitte bearbeiten Sie es entsprechend.)

drawCNN.py


import pydotplus
from IPython.display import Image

graph = pydotplus.graphviz.graph_from_dot_file('dot/pytorchainer.dot')
graph.write_png('img/pytorchainer.png')
Image(graph.create_png())

pytorchainer.dot


digraph G {

Python [shape=box]
Torch

Chainer -> "Chainer v2"
Chainer -> ChainerMN

Python -> PyTorch
Torch -> PyTorch
Chainer -> PyTorch

PyTorch -> PyTorChainer
"Chainer v2" -> PyTorChainer
ChainerMN -> PyTorChainer

Diese Figur ist Fiktion.[shape=plaintext]

}

pytorchainer.png Jetzt können Sie die Punktdatei aus Python zeichnen.

Die technischen Daten der Punktsprache und von PyDot Plus finden Sie im Folgenden. Zusammenfassung zum Zeichnen von Grafiken in Graphviz- und Punktsprachen PyDotPlus API Reference

CNN Zeichnung

Zeichnen wir nun ein Diagramm der CNN-Architektur. Das heißt, Sie müssen nur Ebenen (und Pfeile) hinzufügen, die in Punktsprache geschrieben sind. Unten tanzt die magische Zahl für die Positionsanpassung, aber bitte vergib mir.

drawCNN.py


class CNNDot():
    def __init__(self):
        self.layer_id = 0
        self.arrow_id = 0

    def get_layer_str(self, size, channels, xoffset=0.0, yoffset=0.0, fillcolor='white', caption=''):
        width = size * 0.5
        height = size
        x = xoffset
        y = height * 0.5 + yoffset
        x_caption = x - width * 0.25
        y_caption = -y - 0.7
        
        layer_str = """
          layer{} [
              shape=polygon, sides=4, skew=-2, orientation=90,
              label="", style=filled, fixedsize=true, fillcolor="{}",
              width={}, height={}, pos="{},{}!"
          ]
        """.format(self.layer_id, fillcolor, width, height, x, y)

        if caption != '':
            layer_str += """
              layer_caption{} [
                  shape=plaintext, label="{}", fixedsize=true, fontsize=24,
                  pos="{},{}!"
              ]
            """.format(self.layer_id, caption, x_caption, y_caption)

        self.layer_id += 1
        return layer_str

    def get_arrow_str(self, xmin, ymin, xmax, ymax):
        arrow_str = """
            arrow{0}_tail [
                shape=none, label="", fixedsize=true, width=0, height=0,
                pos="{1},{2}!"
            ]
            arrow{0}_head [
                shape=none, label="", fixedsize=true, width=0, height=0,
                pos="{3},{4}!"
            ]
            arrow{0}_tail -> arrow{0}_head
        """.format(self.arrow_id, xmin, ymin, xmax, ymax)
        self.arrow_id += 1
        return arrow_str
        
cnndot = CNNDot()
# layers
graph_data_main = cnndot.get_layer_str(3.0, 0, -1.00, fillcolor='gray') # input
graph_data_main += cnndot.get_layer_str(3.0, 0, 0.00, caption='conv') # encoder begin
graph_data_main += cnndot.get_layer_str(3.0, 0, 0.50)
graph_data_main += cnndot.get_layer_str(2.5, 0, 1.25, caption='conv')
graph_data_main += cnndot.get_layer_str(2.5, 0, 1.75)
graph_data_main += cnndot.get_layer_str(2.0, 0, 2.50, caption='conv')
graph_data_main += cnndot.get_layer_str(2.0, 0, 3.00)
graph_data_main += cnndot.get_layer_str(1.5, 0, 3.75, caption='conv')
graph_data_main += cnndot.get_layer_str(1.5, 0, 4.25)
graph_data_main += cnndot.get_layer_str(1.0, 0, 5.00, caption='conv')
graph_data_main += cnndot.get_layer_str(1.0, 0, 5.50)
graph_data_main += cnndot.get_layer_str(1.0, 0, 6.25, caption='deconv') # decoder begin
graph_data_main += cnndot.get_layer_str(1.0, 0, 6.75)
graph_data_main += cnndot.get_layer_str(1.5, 0, 7.50, caption='deconv')
graph_data_main += cnndot.get_layer_str(1.5, 0, 8.00)
graph_data_main += cnndot.get_layer_str(2.0, 0, 8.75)
graph_data_main += cnndot.get_layer_str(2.0, 0, 9.25)
graph_data_main += cnndot.get_layer_str(2.5, 0, 10.00)
graph_data_main += cnndot.get_layer_str(2.5, 0, 10.50)
graph_data_main += cnndot.get_layer_str(3.0, 0, 11.25)
graph_data_main += cnndot.get_layer_str(3.0, 0, 11.75)
graph_data_main += cnndot.get_layer_str(3.0, 0, 12.75, fillcolor='#FF8080') # output

# arrows
graph_data_main += cnndot.get_arrow_str(0.50, 3.0*1.2, 11.25-0.22, 3.0*1.2)
graph_data_main += cnndot.get_arrow_str(1.75, 2.5*1.2, 10.00-0.20, 2.5*1.2)
graph_data_main += cnndot.get_arrow_str(3.00, 2.0*1.2, 8.75-0.18, 2.0*1.2)
graph_data_main += cnndot.get_arrow_str(4.25, 1.5*1.2, 7.50-0.16, 1.5*1.2)
graph_data_main += cnndot.get_arrow_str(5.50, 1.0*1.2, 6.25-0.14, 1.0*1.2)

graph_data_setting = 'graph[ layout = neato, size="16,8"]'
graph_data = 'digraph G {{ \n{}\n{}\n }}'.format(graph_data_setting, graph_data_main)
graph = pydotplus.graphviz.graph_from_dot_data(graph_data)

# save and show image
graph.write_png('img/encoder-decoder.png')
Image(graph.create_png())

Für diesen Code sehen Sie ungefähr Folgendes: (Es ist eine Spezifikation, dass jede Schicht dünn ist. Wenn Sie die Seiten dehnen, sollten Sie in der Lage sein, einen quadratischen Körper zu zeichnen.)

encoder-decoder.png

Impressionen

Inception V3 Zeichnung (2017/4/30 Nachtrag)

Es unterscheidet sich vom obigen Code, aber ich habe versucht, ein Keras-Modell (InceptionV3) zu zeichnen. Das Quadrat wird mit svg write gezeichnet und eingefügt.

inceptionv3.png

Recommended Posts

Zeichnen Sie ein CNN-Diagramm in Python
Zeichne ein Herz in Python
Zeichnen Sie eine Streudiagrammmatrix mit Python
Mit Python Teil 2 ein Herz zeichnen (SymPy Edition)
Zeichnen Sie mit graphviz eine Baumstruktur in Python 3
Zeichnen Sie ein Diagramm mit Python
Zeichnen Sie in Python ein Diagramm einer quadratischen Funktion
[Python] Wie zeichnet man mit Matplotlib ein Histogramm?
Erstellen Sie eine Funktion in Python
Erstellen Sie ein Wörterbuch in Python
Zeichnen Sie die Festplatte von Poancare in Python
Zeichnen Sie "Farn programmgesteuert zeichnen" in Python
Erstellen Sie ein Lesezeichen in Python
Zeichne die Yin-Funktion in Python
[Python] Zeichnen Sie ein Qiita-Tag-Beziehungsdiagramm mit NetworkX
Zeichnen Sie Sinuswellen mit Blender Python
Wahrscheinlich in einer Nishiki-Schlange (Originaltitel: Vielleicht in Python)
[Python] Verwalten Sie Funktionen in einer Liste
Erstellen Sie einen DI-Container mit Python
Zeichnen Sie Knoten interaktiv mit Plotly (Python)
ABC166 in Python A ~ C Problem
Schreiben Sie A * (A-Stern) -Algorithmen in Python
Erstellen Sie eine Binärdatei in Python
Löse ABC036 A ~ C mit Python
Schreiben Sie ein Kreisdiagramm in Python
Schreiben Sie das Vim-Plugin in Python
Schreiben Sie eine Suche mit Tiefenpriorität in Python
Implementierung eines einfachen Algorithmus in Python 2
Löse ABC037 A ~ C mit Python
Führen Sie einen einfachen Algorithmus in Python aus
Erstellen Sie eine zufällige Zeichenfolge in Python
Beim Schreiben eines Programms in Python
Zeichnen Sie eine Aquarellillusion mit Kantenerkennung in Python3 und openCV3
Spiralbuch in Python! Python mit einem Spiralbuch! (Kapitel 14 ~)
Löse ABC175 A, B, C mit Python
Verwenden Sie print in Python2 lambda expression
Ein einfacher HTTP-Client, der in Python implementiert ist
Führen Sie eine nicht rekursive Euler-Tour in Python durch
Ich habe ein Pay-Management-Programm in Python erstellt!
Schreiben Sie den Test in die Python-Dokumentzeichenfolge
Versuchen Sie, ein SYN-Paket in Python zu senden
Versuchen Sie, eine einfache Animation in Python zu zeichnen
Erstellen Sie eine einfache GUI-App in Python
Zeichne mit PyCall ein Herz in Ruby
Zeichnen Sie Nozomi Sasaki in Excel mit Python
Erstellen Sie ein Beziehungsdiagramm von Python-Modulen
[Python] [Windows] Machen Sie eine Bildschirmaufnahme mit Python
Führen Sie den Python-Interpreter im Skript aus
Wie bekomme ich Stacktrace in Python?
Schreiben Sie ein Caesar-Verschlüsselungsprogramm in Python
Hash in Perl ist ein Wörterbuch in Python
Scraping von Websites mit JavaScript in Python
Schreiben Sie eine einfache Giermethode in Python
Starten Sie eine Flask-App in Python Anywhere
Holen Sie sich ein Zeichen für Conoha mit Python
[GPS] Erstellen Sie eine kml-Datei mit Python
Schreiben Sie ein einfaches Vim-Plugin in Python 3
Generieren Sie eine Klasse aus einer Zeichenfolge in Python