Dessinez un diagramme CNN en Python

ConvNet Drawer (ajouté le 4 janvier 2018)

** Les outils suivants ont été publiés et nous vous recommandons de les utiliser. ** ** J'ai créé un outil qui illustre l'architecture lorsque le réseau de neurones convolutifs est défini comme Keras

Lorsque vous définissez un modèle dans une notation comme le modèle séquentiel de Keras, nous avons créé un outil qui illustre bien l'architecture. Il peut s'agir d'une bibliothèque dépendante car c'est un outil qui ne produit que du texte. https://github.com/yu4u/convnet-drawer

Aperçu

Utilisez Python + pydot + Graphviz pour dessiner un diagramme de l'architecture CNN. Ma motivation était de regarder https://github.com/jettan/tikz_cnn et de vouloir dessiner un diagramme similaire en Python au lieu de TeX.

Préparation

Installez pydotplus et graphviz. J'utilise conda, mais je pense que pip est correct (non vérifié).

conda install -c conda-forge pydotplus
conda install graphviz

Préparez un fichier de points approprié, chargez-le avec pydotplus, enregistrez l'image et affichez l'image. (L'image est affichée sur Jupyter. Veuillez la modifier si nécessaire.)

drawCNN.py


import pydotplus
from IPython.display import Image

graph = pydotplus.graphviz.graph_from_dot_file('dot/pytorchainer.dot')
graph.write_png('img/pytorchainer.png')
Image(graph.create_png())

pytorchainer.dot


digraph G {

Python [shape=box]
Torch

Chainer -> "Chainer v2"
Chainer -> ChainerMN

Python -> PyTorch
Torch -> PyTorch
Chainer -> PyTorch

PyTorch -> PyTorChainer
"Chainer v2" -> PyTorChainer
ChainerMN -> PyTorChainer

Ce chiffre est une fiction.[shape=plaintext]

}

pytorchainer.png Vous êtes maintenant prêt à dessiner le fichier dot de Python.

Veuillez vous référer à ce qui suit pour les spécifications du langage dot et de PyDot Plus. Résumé de la façon de dessiner des graphiques en graphviz et en langage à points PyDotPlus API Reference

Dessin CNN

Maintenant, dessinons un diagramme de l'architecture CNN. Cela dit, tout ce que vous avez à faire est d'ajouter des couches (et des flèches) écrites en langage par points. Ci-dessous, le chiffre magique pour l'ajustement de la position danse, mais pardonnez-moi s'il vous plaît.

drawCNN.py


class CNNDot():
    def __init__(self):
        self.layer_id = 0
        self.arrow_id = 0

    def get_layer_str(self, size, channels, xoffset=0.0, yoffset=0.0, fillcolor='white', caption=''):
        width = size * 0.5
        height = size
        x = xoffset
        y = height * 0.5 + yoffset
        x_caption = x - width * 0.25
        y_caption = -y - 0.7
        
        layer_str = """
          layer{} [
              shape=polygon, sides=4, skew=-2, orientation=90,
              label="", style=filled, fixedsize=true, fillcolor="{}",
              width={}, height={}, pos="{},{}!"
          ]
        """.format(self.layer_id, fillcolor, width, height, x, y)

        if caption != '':
            layer_str += """
              layer_caption{} [
                  shape=plaintext, label="{}", fixedsize=true, fontsize=24,
                  pos="{},{}!"
              ]
            """.format(self.layer_id, caption, x_caption, y_caption)

        self.layer_id += 1
        return layer_str

    def get_arrow_str(self, xmin, ymin, xmax, ymax):
        arrow_str = """
            arrow{0}_tail [
                shape=none, label="", fixedsize=true, width=0, height=0,
                pos="{1},{2}!"
            ]
            arrow{0}_head [
                shape=none, label="", fixedsize=true, width=0, height=0,
                pos="{3},{4}!"
            ]
            arrow{0}_tail -> arrow{0}_head
        """.format(self.arrow_id, xmin, ymin, xmax, ymax)
        self.arrow_id += 1
        return arrow_str
        
cnndot = CNNDot()
# layers
graph_data_main = cnndot.get_layer_str(3.0, 0, -1.00, fillcolor='gray') # input
graph_data_main += cnndot.get_layer_str(3.0, 0, 0.00, caption='conv') # encoder begin
graph_data_main += cnndot.get_layer_str(3.0, 0, 0.50)
graph_data_main += cnndot.get_layer_str(2.5, 0, 1.25, caption='conv')
graph_data_main += cnndot.get_layer_str(2.5, 0, 1.75)
graph_data_main += cnndot.get_layer_str(2.0, 0, 2.50, caption='conv')
graph_data_main += cnndot.get_layer_str(2.0, 0, 3.00)
graph_data_main += cnndot.get_layer_str(1.5, 0, 3.75, caption='conv')
graph_data_main += cnndot.get_layer_str(1.5, 0, 4.25)
graph_data_main += cnndot.get_layer_str(1.0, 0, 5.00, caption='conv')
graph_data_main += cnndot.get_layer_str(1.0, 0, 5.50)
graph_data_main += cnndot.get_layer_str(1.0, 0, 6.25, caption='deconv') # decoder begin
graph_data_main += cnndot.get_layer_str(1.0, 0, 6.75)
graph_data_main += cnndot.get_layer_str(1.5, 0, 7.50, caption='deconv')
graph_data_main += cnndot.get_layer_str(1.5, 0, 8.00)
graph_data_main += cnndot.get_layer_str(2.0, 0, 8.75)
graph_data_main += cnndot.get_layer_str(2.0, 0, 9.25)
graph_data_main += cnndot.get_layer_str(2.5, 0, 10.00)
graph_data_main += cnndot.get_layer_str(2.5, 0, 10.50)
graph_data_main += cnndot.get_layer_str(3.0, 0, 11.25)
graph_data_main += cnndot.get_layer_str(3.0, 0, 11.75)
graph_data_main += cnndot.get_layer_str(3.0, 0, 12.75, fillcolor='#FF8080') # output

# arrows
graph_data_main += cnndot.get_arrow_str(0.50, 3.0*1.2, 11.25-0.22, 3.0*1.2)
graph_data_main += cnndot.get_arrow_str(1.75, 2.5*1.2, 10.00-0.20, 2.5*1.2)
graph_data_main += cnndot.get_arrow_str(3.00, 2.0*1.2, 8.75-0.18, 2.0*1.2)
graph_data_main += cnndot.get_arrow_str(4.25, 1.5*1.2, 7.50-0.16, 1.5*1.2)
graph_data_main += cnndot.get_arrow_str(5.50, 1.0*1.2, 6.25-0.14, 1.0*1.2)

graph_data_setting = 'graph[ layout = neato, size="16,8"]'
graph_data = 'digraph G {{ \n{}\n{}\n }}'.format(graph_data_setting, graph_data_main)
graph = pydotplus.graphviz.graph_from_dot_data(graph_data)

# save and show image
graph.write_png('img/encoder-decoder.png')
Image(graph.create_png())

Pour ce code, vous verrez quelque chose comme ce qui suit: (C'est une spécification que chaque couche est mince. Si vous étirez les côtés, vous devriez être en mesure de dessiner un corps carré.)

encoder-decoder.png

Impressions

Dessin Inception V3 (PostScript 2017/4/30)

C'est différent du code ci-dessus, mais j'ai essayé de dessiner un modèle Keras (InceptionV3). Le carré est dessiné et collé avec une écriture svg.

inceptionv3.png

Recommended Posts

Dessinez un diagramme CNN en Python
Dessinez un cœur en Python
Dessinez une matrice de diagramme de dispersion avec python
Dessiner un cœur avec Python Partie 2 (SymPy Edition)
Dessinez une structure arborescente en Python 3 à l'aide de graphviz
Dessiner un graphique avec python
Dessiner un graphique d'une fonction quadratique en Python
[Python] Comment dessiner un histogramme avec Matplotlib
Créer une fonction en Python
Créer un dictionnaire en Python
Dessinez le disque de Poancare en Python
Dessiner "Dessiner une fougère par programme" en Python
Créer un bookmarklet en Python
Dessiner la fonction Yin en python
[Python] Dessinez un diagramme de relation de balises Qiita avec NetworkX
Dessinez des ondes sinusoïdales avec Blender Python
Probablement dans un serpent Nishiki (Titre original: Peut-être en Python)
[python] Gérer les fonctions dans une liste
Créer un conteneur DI avec Python
Dessinez des nœuds de manière interactive avec Plotly (Python)
ABC166 en Python A ~ C problème
Ecrire des algorithmes A * (A-star) en Python
Créer un fichier binaire en Python
Résoudre ABC036 A ~ C avec Python
Ecrire un graphique à secteurs en Python
Ecrire le plugin vim en Python
Écrire une recherche de priorité en profondeur en Python
Implémentation d'un algorithme simple en Python 2
Résoudre ABC037 A ~ C avec Python
Exécutez un algorithme simple en Python
Créer une chaîne aléatoire en Python
Lors de l'écriture d'un programme en Python
Dessinez une illusion d'aquarelle avec détection des contours en Python3 et openCV3
Livre en spirale en Python! Python avec un livre en spirale! (Chapitre 14 ~)
Résoudre ABC175 A, B, C avec Python
Utiliser l'impression dans l'expression lambda Python2
Un client HTTP simple implémenté en Python
Faites une visite Euler non récursive en Python
J'ai fait un programme de gestion de la paie en Python!
Ecrire le test dans la docstring python
Essayez d'envoyer un paquet SYN en Python
Essayez de dessiner une animation simple en Python
Créer une application GUI simple en Python
Dessinez un cœur en rubis avec PyCall
Dessinez Nozomi Sasaki dans Excel avec python
Créer un diagramme de relations des modules Python
[Python] [Windows] Faites une capture d'écran avec Python
Exécuter l'interpréteur Python dans le script
Comment obtenir stacktrace en python
Ecrire un programme de chiffrement Caesar en Python
Hash en Perl est un dictionnaire en Python
Scraping de sites Web à l'aide de JavaScript en Python
Ecrire une méthode de cupidité simple en Python
Lancer une application Flask dans Python Anywhere
Obtenez un jeton pour conoha avec python
[GPS] Créer un fichier kml avec Python
Ecrire un plugin Vim simple en Python 3
Générer une classe à partir d'une chaîne en Python