[PYTHON] Hacking DeepChem's GraphConvModel with summary()

Introduction

I wanted to get started with deep learning on chemical compounds, so I decided to hack DeepChem's GraphConvModel and reimplement it in Keras. As a first step, I printed the model's Keras architecture using the summary() method of the model object.

Environment

Python 3.7 in an Anaconda environment named deepchem (as the site-packages path below shows).

Method

Add a call to model.summary() at line 624 of the file that defines the GraphConvModel class (path below), then build a prediction model on some suitable data so that the summary gets printed.

<anaconda3>/envs/deepchem/lib/python3.7/site-packages/deepchem/models/graph_conv.py


    # summary() prints the table itself and returns None, so wrapping it in print() is unnecessary
    model.summary()

Result

The output looks like this. I have read the paper and have a general idea of how the model works, but DeepChem's implementation differs a little from the paper, so I plan to analyze it in more detail from here on.

Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 75)]         0
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 2)]          0
__________________________________________________________________________________________________
input_3 (InputLayer)            [(None,)]            0
__________________________________________________________________________________________________
input_6 (InputLayer)            [(None, 1)]          0
__________________________________________________________________________________________________
input_7 (InputLayer)            [(None, 2)]          0
__________________________________________________________________________________________________
input_8 (InputLayer)            [(None, 3)]          0
__________________________________________________________________________________________________
input_9 (InputLayer)            [(None, 4)]          0
__________________________________________________________________________________________________
input_10 (InputLayer)           [(None, 5)]          0
__________________________________________________________________________________________________
input_11 (InputLayer)           [(None, 6)]          0
__________________________________________________________________________________________________
input_12 (InputLayer)           [(None, 7)]          0
__________________________________________________________________________________________________
input_13 (InputLayer)           [(None, 8)]          0
__________________________________________________________________________________________________
input_14 (InputLayer)           [(None, 9)]          0
__________________________________________________________________________________________________
input_15 (InputLayer)           [(None, 10)]         0
__________________________________________________________________________________________________
input_16 (InputLayer)           [(None, 11)]         0
__________________________________________________________________________________________________
graph_conv (GraphConv)          (None, 64)           102144      input_1[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_6[0][0]
                                                                 input_7[0][0]
                                                                 input_8[0][0]
                                                                 input_9[0][0]
                                                                 input_10[0][0]
                                                                 input_11[0][0]
                                                                 input_12[0][0]
                                                                 input_13[0][0]
                                                                 input_14[0][0]
                                                                 input_15[0][0]
                                                                 input_16[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 64)           256         graph_conv[0][0]
__________________________________________________________________________________________________
graph_pool (GraphPool)          (None, 64)           0           batch_normalization[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_6[0][0]
                                                                 input_7[0][0]
                                                                 input_8[0][0]
                                                                 input_9[0][0]
                                                                 input_10[0][0]
                                                                 input_11[0][0]
                                                                 input_12[0][0]
                                                                 input_13[0][0]
                                                                 input_14[0][0]
                                                                 input_15[0][0]
                                                                 input_16[0][0]
__________________________________________________________________________________________________
graph_conv_1 (GraphConv)        (None, 64)           87360       graph_pool[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_6[0][0]
                                                                 input_7[0][0]
                                                                 input_8[0][0]
                                                                 input_9[0][0]
                                                                 input_10[0][0]
                                                                 input_11[0][0]
                                                                 input_12[0][0]
                                                                 input_13[0][0]
                                                                 input_14[0][0]
                                                                 input_15[0][0]
                                                                 input_16[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 64)           256         graph_conv_1[0][0]
__________________________________________________________________________________________________
graph_pool_1 (GraphPool)        (None, 64)           0           batch_normalization_1[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_6[0][0]
                                                                 input_7[0][0]
                                                                 input_8[0][0]
                                                                 input_9[0][0]
                                                                 input_10[0][0]
                                                                 input_11[0][0]
                                                                 input_12[0][0]
                                                                 input_13[0][0]
                                                                 input_14[0][0]
                                                                 input_15[0][0]
                                                                 input_16[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 128)          8320        graph_pool_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 128)          512         dense[0][0]
__________________________________________________________________________________________________
graph_gather (GraphGather)      (64, 256)            0           batch_normalization_2[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_6[0][0]
                                                                 input_7[0][0]
                                                                 input_8[0][0]
                                                                 input_9[0][0]
                                                                 input_10[0][0]
                                                                 input_11[0][0]
                                                                 input_12[0][0]
                                                                 input_13[0][0]
                                                                 input_14[0][0]
                                                                 input_15[0][0]
                                                                 input_16[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (64, 2)              514         graph_gather[0][0]
__________________________________________________________________________________________________
reshape (Reshape)               (64, 1, 2)           0           dense_1[0][0]
__________________________________________________________________________________________________
input_4 (InputLayer)            [(None,)]            0
__________________________________________________________________________________________________
trim_graph_output (TrimGraphOut (None, 1, 2)         0           reshape[0][0]
                                                                 input_4[0][0]
__________________________________________________________________________________________________
input_5 (InputLayer)            [(None,)]            0
__________________________________________________________________________________________________
softmax (Softmax)               (None, 1, 2)         0           trim_graph_output[0][0]
==================================================================================================
Total params: 199,362
Trainable params: 198,850
Non-trainable params: 512
__________________________________________________________________________________________________
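As a sanity check on the table above, the parameter counts can be reproduced by hand. The sketch below is my own arithmetic, not DeepChem code. It assumes that each GraphConv layer holds 21 (weight, bias) pairs, one per degree slot (2 * max_deg + 1 with min_deg=0 and max_deg=10), and that GraphGather concatenates a per-molecule sum and max of the 128 atom features, which would explain its output width of 256. Under those assumptions the totals come out identical to the printed summary.

```python
# Hand-check of the parameter counts printed by model.summary().
# Assumptions (not taken from DeepChem source): GraphConv keeps one dense-style
# (W, b) pair per degree slot, with 2 * max_deg + (1 - min_deg) slots.
from typing import Tuple


def dense_params(n_in: int, n_out: int) -> int:
    """Weight matrix plus bias vector of a fully connected layer."""
    return n_in * n_out + n_out


def graph_conv_params(n_in: int, n_out: int, min_deg: int = 0, max_deg: int = 10) -> int:
    """One (W, b) pair per degree slot (assumed layout; 21 slots by default)."""
    n_slots = 2 * max_deg + (1 - min_deg)
    return n_slots * dense_params(n_in, n_out)


def batch_norm_params(n_features: int) -> Tuple[int, int]:
    """(trainable, non-trainable): gamma/beta vs. moving mean/moving variance."""
    return 2 * n_features, 2 * n_features


layers = {
    "graph_conv": graph_conv_params(75, 64),      # 102144
    "batch_normalization": sum(batch_norm_params(64)),
    "graph_conv_1": graph_conv_params(64, 64),    # 87360
    "batch_normalization_1": sum(batch_norm_params(64)),
    "dense": dense_params(64, 128),               # 8320
    "batch_normalization_2": sum(batch_norm_params(128)),
    # GraphGather output is 256 wide, assumed to be sum and max of the
    # 128 atom features concatenated (2 * 128), so dense_1 sees 256 inputs.
    "dense_1": dense_params(2 * 128, 2),          # 514
}

total = sum(layers.values())
non_trainable = sum(batch_norm_params(n)[1] for n in (64, 64, 128))
trainable = total - non_trainable

print(f"Total params: {total:,}")
print(f"Trainable params: {trainable:,}")
print(f"Non-trainable params: {non_trainable:,}")
```

This prints the same three lines as the bottom of the summary. The moving mean and variance of the three BatchNormalization layers account for all 512 non-trainable parameters, while GraphPool, GraphGather, Reshape, TrimGraphOutput and Softmax contribute none.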

