I wanted to get started with deep learning on chemical compounds, so I decided to dig into DeepChem's GraphConvModel and reimplement it in Keras. As a first step, I printed out how DeepChem builds the model in Keras, using the summary() method of the model object.
I added a call to model.summary() around line 624 of the file that defines the GraphConvModel class, then built a prediction model on some suitable data.

```
<anaconda3>/envs/deepchem/lib/python3.7/site-packages/deepchem/models/graph_conv.py
```

```python
# summary() prints the table itself and returns None,
# so wrapping it in print() would only add a trailing "None"
model.summary()
```

This prints the summary below. I have read the paper and understand the general idea, but DeepChem's implementation differs somewhat from the paper, so I will analyze it step by step from here on.
```
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 75)]         0
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 2)]          0
__________________________________________________________________________________________________
input_3 (InputLayer)            [(None,)]            0
__________________________________________________________________________________________________
input_6 (InputLayer)            [(None, 1)]          0
__________________________________________________________________________________________________
input_7 (InputLayer)            [(None, 2)]          0
__________________________________________________________________________________________________
input_8 (InputLayer)            [(None, 3)]          0
__________________________________________________________________________________________________
input_9 (InputLayer)            [(None, 4)]          0
__________________________________________________________________________________________________
input_10 (InputLayer)           [(None, 5)]          0
__________________________________________________________________________________________________
input_11 (InputLayer)           [(None, 6)]          0
__________________________________________________________________________________________________
input_12 (InputLayer)           [(None, 7)]          0
__________________________________________________________________________________________________
input_13 (InputLayer)           [(None, 8)]          0
__________________________________________________________________________________________________
input_14 (InputLayer)           [(None, 9)]          0
__________________________________________________________________________________________________
input_15 (InputLayer)           [(None, 10)]         0
__________________________________________________________________________________________________
input_16 (InputLayer)           [(None, 11)]         0
__________________________________________________________________________________________________
graph_conv (GraphConv)          (None, 64)           102144      input_1[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_6[0][0]
                                                                 input_7[0][0]
                                                                 input_8[0][0]
                                                                 input_9[0][0]
                                                                 input_10[0][0]
                                                                 input_11[0][0]
                                                                 input_12[0][0]
                                                                 input_13[0][0]
                                                                 input_14[0][0]
                                                                 input_15[0][0]
                                                                 input_16[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 64)           256         graph_conv[0][0]
__________________________________________________________________________________________________
graph_pool (GraphPool)          (None, 64)           0           batch_normalization[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_6[0][0]
                                                                 input_7[0][0]
                                                                 input_8[0][0]
                                                                 input_9[0][0]
                                                                 input_10[0][0]
                                                                 input_11[0][0]
                                                                 input_12[0][0]
                                                                 input_13[0][0]
                                                                 input_14[0][0]
                                                                 input_15[0][0]
                                                                 input_16[0][0]
__________________________________________________________________________________________________
graph_conv_1 (GraphConv)        (None, 64)           87360       graph_pool[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_6[0][0]
                                                                 input_7[0][0]
                                                                 input_8[0][0]
                                                                 input_9[0][0]
                                                                 input_10[0][0]
                                                                 input_11[0][0]
                                                                 input_12[0][0]
                                                                 input_13[0][0]
                                                                 input_14[0][0]
                                                                 input_15[0][0]
                                                                 input_16[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 64)           256         graph_conv_1[0][0]
__________________________________________________________________________________________________
graph_pool_1 (GraphPool)        (None, 64)           0           batch_normalization_1[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_6[0][0]
                                                                 input_7[0][0]
                                                                 input_8[0][0]
                                                                 input_9[0][0]
                                                                 input_10[0][0]
                                                                 input_11[0][0]
                                                                 input_12[0][0]
                                                                 input_13[0][0]
                                                                 input_14[0][0]
                                                                 input_15[0][0]
                                                                 input_16[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 128)          8320        graph_pool_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 128)          512         dense[0][0]
__________________________________________________________________________________________________
graph_gather (GraphGather)      (64, 256)            0           batch_normalization_2[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_6[0][0]
                                                                 input_7[0][0]
                                                                 input_8[0][0]
                                                                 input_9[0][0]
                                                                 input_10[0][0]
                                                                 input_11[0][0]
                                                                 input_12[0][0]
                                                                 input_13[0][0]
                                                                 input_14[0][0]
                                                                 input_15[0][0]
                                                                 input_16[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (64, 2)              514         graph_gather[0][0]
__________________________________________________________________________________________________
reshape (Reshape)               (64, 1, 2)           0           dense_1[0][0]
__________________________________________________________________________________________________
input_4 (InputLayer)            [(None,)]            0
__________________________________________________________________________________________________
trim_graph_output (TrimGraphOut (None, 1, 2)         0           reshape[0][0]
                                                                 input_4[0][0]
__________________________________________________________________________________________________
input_5 (InputLayer)            [(None,)]            0
__________________________________________________________________________________________________
softmax (Softmax)               (None, 1, 2)         0           trim_graph_output[0][0]
==================================================================================================
Total params: 199,362
Trainable params: 198,850
Non-trainable params: 512
__________________________________________________________________________________________________
```
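As a first sanity check, the parameter counts in this summary can be reproduced by hand. The sketch below is my own: the helper names are hypothetical and not part of DeepChem. It rests on one assumption from my reading of DeepChem's GraphConv layer, namely that it keeps a separate weight matrix and bias for each degree slot, with num_deg = 2 * max_degree + (1 - min_degree). I assume max_degree = 10 and min_degree = 0, which would fit the eleven inputs input_6 to input_16 (widths 1 to 11) if those are the per-degree adjacency inputs.

```python
# Reproduce the parameter counts of the GraphConvModel summary by hand.
# ASSUMPTION: DeepChem's GraphConv keeps one (W, b) pair per degree slot,
# num_deg = 2 * max_degree + (1 - min_degree), with max_degree=10, min_degree=0.
# Helper names are hypothetical, not part of DeepChem.


def graph_conv_params(in_ch, out_ch, max_degree=10, min_degree=0):
    """(trainable, non_trainable) for one GraphConv layer."""
    num_deg = 2 * max_degree + (1 - min_degree)  # 21 slots with the defaults
    return num_deg * (in_ch * out_ch + out_ch), 0  # W: in x out, b: out


def dense_params(n_in, n_out):
    """(trainable, non_trainable) for a Dense layer with bias."""
    return n_in * n_out + n_out, 0


def batchnorm_params(ch):
    """gamma/beta are trainable; moving mean/variance are not."""
    return 2 * ch, 2 * ch


# Only layers with weights; GraphPool, GraphGather, Reshape,
# TrimGraphOutput and Softmax contribute 0 parameters.
LAYERS = [
    ("graph_conv", graph_conv_params(75, 64)),
    ("batch_normalization", batchnorm_params(64)),
    ("graph_conv_1", graph_conv_params(64, 64)),
    ("batch_normalization_1", batchnorm_params(64)),
    ("dense", dense_params(64, 128)),
    ("batch_normalization_2", batchnorm_params(128)),
    ("dense_1", dense_params(256, 2)),  # GraphGather output is 256 wide
]


def summarize(layers):
    """Return (total, trainable, non_trainable) as Keras reports them."""
    trainable = sum(counts[0] for _, counts in layers)
    non_trainable = sum(counts[1] for _, counts in layers)
    return trainable + non_trainable, trainable, non_trainable


if __name__ == "__main__":
    for name, (t, n) in LAYERS:
        print(f"{name:<24}{t + n:>9,}")
    total, trainable, non_trainable = summarize(LAYERS)
    print(f"Total params: {total:,}")
    print(f"Trainable params: {trainable:,}")
    print(f"Non-trainable params: {non_trainable:,}")
```

Under these assumptions every count matches the summary: 21 × (75 × 64 + 64) = 102,144 for the first GraphConv and 21 × (64 × 64 + 64) = 87,360 for the second. The 512 non-trainable parameters are exactly the BatchNormalization moving means and variances, 2 × (64 + 64 + 128), and the script's last three lines print 199,362, 198,850 and 512.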