[PYTHON] Dessiner GoogLeNet avec draw_net de caffe

Dessiner GoogLeNet tel quel avec le draw_net de Caffe produit un graphe démesuré (voir par exemple ce graphe ou celui-ci : `/index.php?plugin=attach&refer=Caffe%20network&openfile=googlenet.gif`) ; j'ai donc simplifié le dessin.

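Réécrivez `caffe/python/caffe/draw.py` comme suit. Le correctif ci-dessous retire des étiquettes de convolution et de pooling les détails (taille du noyau, pas, padding). Il affiche les couches InnerProduct sous le libellé « full connect » et omet les nœuds Dropout et ReLU. Enfin, il supprime les étiquettes des arêtes reliant chaque couche à ses blobs de sortie.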


diff --git a/python/caffe/draw.py b/python/caffe/draw.py
index a002b60..2fb0606 100644
--- a/python/caffe/draw.py
+++ b/python/caffe/draw.py
@@ -75,31 +75,37 @@ def get_layer_label(layer, rankdir):
         separator = '\\n'
 
     if layer.type == 'Convolution' or layer.type == 'Deconvolution':
-        # Outer double quotes needed or else colon characters don't parse
-        # properly
-        node_label = '"%s%s(%s)%skernel size: %d%sstride: %d%spad: %d"' %\
-                     (layer.name,
-                      separator,
-                      layer.type,
-                      separator,
-                      layer.convolution_param.kernel_size,
-                      separator,
-                      layer.convolution_param.stride,
-                      separator,
-                      layer.convolution_param.pad)
-    elif layer.type == 'Pooling':
-        pooling_types_dict = get_pooling_types_dict()
-        node_label = '"%s%s(%s %s)%skernel size: %d%sstride: %d%spad: %d"' %\
-                     (layer.name,
-                      separator,
-                      pooling_types_dict[layer.pooling_param.pool],
-                      layer.type,
-                      separator,
-                      layer.pooling_param.kernel_size,
-                      separator,
-                      layer.pooling_param.stride,
-                      separator,
-                      layer.pooling_param.pad)
+        separator = '\\n'
+
+#     if layer.type == 'Convolution' or layer.type == 'Deconvolution':
+#         # Outer double quotes needed or else colon characters don't parse
+#         # properly
+#         node_label = '"%s%s(%s)%skernel size: %d%sstride: %d%spad: %d"' %\
+#                      (layer.name,
+#                       separator,
+#                       layer.type,
+#                       separator,
+#                       layer.convolution_param.kernel_size[0] if len(layer.convolution_param.kernel_size._values) else 1,
+#                       separator,
+#                       layer.convolution_param.stride[0] if len(layer.convolution_param.stride._values) else 1,
+#                       separator,
+#                       layer.convolution_param.pad[0] if len(layer.convolution_param.pad._values) else 0)
+#     elif layer.type == 'Pooling':
+#         pooling_types_dict = get_pooling_types_dict()
+#         node_label = '"%s%s(%s %s)%skernel size: %d%sstride: %d%spad: %d"' %\
+#                      (layer.name,
+#                       separator,
+#                       pooling_types_dict[layer.pooling_param.pool],
+#                       layer.type,
+#                       separator,
+#                       layer.pooling_param.kernel_size,
+#                       separator,
+#                       layer.pooling_param.stride,
+#                       separator,
+#                       layer.pooling_param.pad)
+#     else:
+    if layer.type == 'InnerProduct' or layer.type == '':
+        node_label = '"%s%s(%s)"' % (layer.name, '\\n', 'full connect')
     else:
         node_label = '"%s%s(%s)"' % (layer.name, separator, layer.type)
     return node_label
@@ -140,6 +146,7 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
     pydot_edges = []
     for layer in caffe_net.layer:
         node_label = get_layer_label(layer, rankdir)
+        if layer.type == 'Dropout' or layer.type == 'ReLU': continue
         node_name = "%s_%s" % (layer.name, layer.type)
         if (len(layer.bottom) == 1 and len(layer.top) == 1 and
            layer.bottom[0] == layer.top[0]):
@@ -159,10 +166,10 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
                                 'label': edge_label})
         for top_blob in layer.top:
             pydot_nodes[top_blob + '_blob'] = pydot.Node('%s' % (top_blob))
-            if label_edges:
-                edge_label = get_edge_label(layer)
-            else:
-                edge_label = '""'
+#            if label_edges:
+#                edge_label = get_edge_label(layer)
+#            else:
+            edge_label = '""'
             pydot_edges.append({'src': node_name,
                                 'dst': top_blob + '_blob',
                                 'label': edge_label})

Au préalable, ajoutez manuellement au fichier deploy.prototxt la couche fc (entièrement connectée) de train_val.prototxt, puis exécutez la commande suivante :

python ./caffe/python/draw_net.py deploy.prototxt googlenet-deploy.pdf --rankdir 'BT'

On obtient alors la figure suivante : googlenet-deploy-all.png

Recommended Posts

Dessiner GoogLeNet avec draw_net de caffe
Dessin en temps réel avec matplotlib
Dessiner avec Python Tinker
Méthode de dessin graphique avec matplotlib
Dessin graphique avec IPython Notebook
Dessin 3D avec SceneKit dans Pythonista