** Die folgenden Tools wurden veröffentlicht und wir empfehlen, sie zu verwenden. ** ** ** Ich habe ein Tool erstellt, das die Architektur veranschaulicht, wenn das Faltungsnetzwerk wie Keras definiert ist
Wenn Sie ein Modell in einer Notation wie dem sequentiellen Modell von Keras definieren, haben wir ein Tool erstellt, das die Architektur gut veranschaulicht. Es kann sich um eine abhängige Bibliothek handeln, da es sich um ein Tool handelt, das nur Text ausgibt. https://github.com/yu4u/convnet-drawer
Verwenden Sie Python + Pydot + Graphviz, um ein Diagramm der CNN-Architektur zu zeichnen. Meine Motivation war es, auf https://github.com/jettan/tikz_cnn zu schauen und ein ähnliches Diagramm in Python anstelle von TeX zu zeichnen.
Installieren Sie pydotplus und graphviz. Ich benutze Conda, aber ich denke, Pip ist in Ordnung (nicht überprüft).
conda install -c conda-forge pydotplus
conda install graphviz
Bereiten Sie eine entsprechende Punktdatei vor, laden Sie sie mit pydotplus, speichern Sie das Bild und zeigen Sie das Bild an. (Das Bild wird auf Jupyter angezeigt. Bitte bearbeiten Sie es entsprechend.)
drawCNN.py
import pydotplus
from IPython.display import Image
graph = pydotplus.graphviz.graph_from_dot_file('dot/pytorchainer.dot')
graph.write_png('img/pytorchainer.png')
Image(graph.create_png())
pytorchainer.dot
digraph G {
Python [shape=box]
Torch
Chainer -> "Chainer v2"
Chainer -> ChainerMN
Python -> PyTorch
Torch -> PyTorch
Chainer -> PyTorch
PyTorch -> PyTorChainer
"Chainer v2" -> PyTorChainer
ChainerMN -> PyTorChainer
Diese Figur ist Fiktion.[shape=plaintext]
}
Jetzt können Sie die Punktdatei aus Python zeichnen.
Die technischen Daten der Punktsprache und von PyDot Plus finden Sie im Folgenden. Zusammenfassung zum Zeichnen von Grafiken in Graphviz- und Punktsprachen PyDotPlus API Reference
Zeichnen wir nun ein Diagramm der CNN-Architektur. Das heißt, Sie müssen nur Ebenen (und Pfeile) hinzufügen, die in Punktsprache geschrieben sind. Unten tanzt die magische Zahl für die Positionsanpassung, aber bitte vergib mir.
drawCNN.py
class CNNDot():
def __init__(self):
self.layer_id = 0
self.arrow_id = 0
def get_layer_str(self, size, channels, xoffset=0.0, yoffset=0.0, fillcolor='white', caption=''):
width = size * 0.5
height = size
x = xoffset
y = height * 0.5 + yoffset
x_caption = x - width * 0.25
y_caption = -y - 0.7
layer_str = """
layer{} [
shape=polygon, sides=4, skew=-2, orientation=90,
label="", style=filled, fixedsize=true, fillcolor="{}",
width={}, height={}, pos="{},{}!"
]
""".format(self.layer_id, fillcolor, width, height, x, y)
if caption != '':
layer_str += """
layer_caption{} [
shape=plaintext, label="{}", fixedsize=true, fontsize=24,
pos="{},{}!"
]
""".format(self.layer_id, caption, x_caption, y_caption)
self.layer_id += 1
return layer_str
def get_arrow_str(self, xmin, ymin, xmax, ymax):
arrow_str = """
arrow{0}_tail [
shape=none, label="", fixedsize=true, width=0, height=0,
pos="{1},{2}!"
]
arrow{0}_head [
shape=none, label="", fixedsize=true, width=0, height=0,
pos="{3},{4}!"
]
arrow{0}_tail -> arrow{0}_head
""".format(self.arrow_id, xmin, ymin, xmax, ymax)
self.arrow_id += 1
return arrow_str
cnndot = CNNDot()
# layers
graph_data_main = cnndot.get_layer_str(3.0, 0, -1.00, fillcolor='gray') # input
graph_data_main += cnndot.get_layer_str(3.0, 0, 0.00, caption='conv') # encoder begin
graph_data_main += cnndot.get_layer_str(3.0, 0, 0.50)
graph_data_main += cnndot.get_layer_str(2.5, 0, 1.25, caption='conv')
graph_data_main += cnndot.get_layer_str(2.5, 0, 1.75)
graph_data_main += cnndot.get_layer_str(2.0, 0, 2.50, caption='conv')
graph_data_main += cnndot.get_layer_str(2.0, 0, 3.00)
graph_data_main += cnndot.get_layer_str(1.5, 0, 3.75, caption='conv')
graph_data_main += cnndot.get_layer_str(1.5, 0, 4.25)
graph_data_main += cnndot.get_layer_str(1.0, 0, 5.00, caption='conv')
graph_data_main += cnndot.get_layer_str(1.0, 0, 5.50)
graph_data_main += cnndot.get_layer_str(1.0, 0, 6.25, caption='deconv') # decoder begin
graph_data_main += cnndot.get_layer_str(1.0, 0, 6.75)
graph_data_main += cnndot.get_layer_str(1.5, 0, 7.50, caption='deconv')
graph_data_main += cnndot.get_layer_str(1.5, 0, 8.00)
graph_data_main += cnndot.get_layer_str(2.0, 0, 8.75)
graph_data_main += cnndot.get_layer_str(2.0, 0, 9.25)
graph_data_main += cnndot.get_layer_str(2.5, 0, 10.00)
graph_data_main += cnndot.get_layer_str(2.5, 0, 10.50)
graph_data_main += cnndot.get_layer_str(3.0, 0, 11.25)
graph_data_main += cnndot.get_layer_str(3.0, 0, 11.75)
graph_data_main += cnndot.get_layer_str(3.0, 0, 12.75, fillcolor='#FF8080') # output
# arrows
graph_data_main += cnndot.get_arrow_str(0.50, 3.0*1.2, 11.25-0.22, 3.0*1.2)
graph_data_main += cnndot.get_arrow_str(1.75, 2.5*1.2, 10.00-0.20, 2.5*1.2)
graph_data_main += cnndot.get_arrow_str(3.00, 2.0*1.2, 8.75-0.18, 2.0*1.2)
graph_data_main += cnndot.get_arrow_str(4.25, 1.5*1.2, 7.50-0.16, 1.5*1.2)
graph_data_main += cnndot.get_arrow_str(5.50, 1.0*1.2, 6.25-0.14, 1.0*1.2)
graph_data_setting = 'graph[ layout = neato, size="16,8"]'
graph_data = 'digraph G {{ \n{}\n{}\n }}'.format(graph_data_setting, graph_data_main)
graph = pydotplus.graphviz.graph_from_dot_data(graph_data)
# save and show image
graph.write_png('img/encoder-decoder.png')
Image(graph.create_png())
Für diesen Code sehen Sie ungefähr Folgendes: (Es ist eine Spezifikation, dass jede Schicht dünn ist. Wenn Sie die Seiten dehnen, sollten Sie in der Lage sein, einen quadratischen Körper zu zeichnen.)
Es unterscheidet sich vom obigen Code, aber ich habe versucht, ein Keras-Modell (InceptionV3) zu zeichnen. Das Quadrat wird mit svg write gezeichnet und eingefügt.
Recommended Posts