[PYTHON] Hack GraphConvModel implémenté dans DeepChem avec résumé

introduction

Je voulais commencer le Deep Learning avec un composé, j'ai donc décidé de pirater GraphConvModel de DeepChem et de l'implémenter dans Keras. Donc, tout d'abord, j'ai décidé de sortir ce qui est implémenté dans Keras par la méthode summary de l'objet modèle.

environnement

Méthode

Placez model.summary () sur la ligne 624 du fichier qui définit la classe GraphConvModel et créez un modèle de prédiction avec les données appropriées.

<anaconda3>/envs/deepchem/lib/python3.7/site-packages/deepchem/models/graph_conv.py


    print(model.summary())

résultat

Comme ça. J'ai lu le papier et j'ai une idée générale, mais Deep Chem est un peu différent du papier, et je vais l'analyser à partir de maintenant.

Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 75)]         0
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 2)]          0
__________________________________________________________________________________________________
input_3 (InputLayer)            [(None,)]            0
__________________________________________________________________________________________________
input_6 (InputLayer)            [(None, 1)]          0
__________________________________________________________________________________________________
input_7 (InputLayer)            [(None, 2)]          0
__________________________________________________________________________________________________
input_8 (InputLayer)            [(None, 3)]          0
__________________________________________________________________________________________________
input_9 (InputLayer)            [(None, 4)]          0
__________________________________________________________________________________________________
input_10 (InputLayer)           [(None, 5)]          0
__________________________________________________________________________________________________
input_11 (InputLayer)           [(None, 6)]          0
__________________________________________________________________________________________________
input_12 (InputLayer)           [(None, 7)]          0
__________________________________________________________________________________________________
input_13 (InputLayer)           [(None, 8)]          0
__________________________________________________________________________________________________
input_14 (InputLayer)           [(None, 9)]          0
__________________________________________________________________________________________________
input_15 (InputLayer)           [(None, 10)]         0
__________________________________________________________________________________________________
input_16 (InputLayer)           [(None, 11)]         0
__________________________________________________________________________________________________
graph_conv (GraphConv)          (None, 64)           102144      input_1[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_6[0][0]
                                                                 input_7[0][0]
                                                                 input_8[0][0]
                                                                 input_9[0][0]
                                                                 input_10[0][0]
                                                                 input_11[0][0]
                                                                 input_12[0][0]
                                                                 input_13[0][0]
                                                                 input_14[0][0]
                                                                 input_15[0][0]
                                                                 input_16[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 64)           256         graph_conv[0][0]
__________________________________________________________________________________________________
graph_pool (GraphPool)          (None, 64)           0           batch_normalization[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_6[0][0]
                                                                 input_7[0][0]
                                                                 input_8[0][0]
                                                                 input_9[0][0]
                                                                 input_10[0][0]
                                                                 input_11[0][0]
                                                                 input_12[0][0]
                                                                 input_13[0][0]
                                                                 input_14[0][0]
                                                                 input_15[0][0]
                                                                 input_16[0][0]
__________________________________________________________________________________________________
graph_conv_1 (GraphConv)        (None, 64)           87360       graph_pool[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_6[0][0]
                                                                 input_7[0][0]
                                                                 input_8[0][0]
                                                                 input_9[0][0]
                                                                 input_10[0][0]
                                                                 input_11[0][0]
                                                                 input_12[0][0]
                                                                 input_13[0][0]
                                                                 input_14[0][0]
                                                                 input_15[0][0]
                                                                 input_16[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 64)           256         graph_conv_1[0][0]
__________________________________________________________________________________________________
graph_pool_1 (GraphPool)        (None, 64)           0           batch_normalization_1[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_6[0][0]
                                                                 input_7[0][0]
                                                                 input_8[0][0]
                                                                 input_9[0][0]
                                                                 input_10[0][0]
                                                                 input_11[0][0]
                                                                 input_12[0][0]
                                                                 input_13[0][0]
                                                                 input_14[0][0]
                                                                 input_15[0][0]
                                                                 input_16[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 128)          8320        graph_pool_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 128)          512         dense[0][0]
__________________________________________________________________________________________________
graph_gather (GraphGather)      (64, 256)            0           batch_normalization_2[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_6[0][0]
                                                                 input_7[0][0]
                                                                 input_8[0][0]
                                                                 input_9[0][0]
                                                                 input_10[0][0]
                                                                 input_11[0][0]
                                                                 input_12[0][0]
                                                                 input_13[0][0]
                                                                 input_14[0][0]
                                                                 input_15[0][0]
                                                                 input_16[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (64, 2)              514         graph_gather[0][0]
__________________________________________________________________________________________________
reshape (Reshape)               (64, 1, 2)           0           dense_1[0][0]
__________________________________________________________________________________________________
input_4 (InputLayer)            [(None,)]            0
__________________________________________________________________________________________________
trim_graph_output (TrimGraphOut (None, 1, 2)         0           reshape[0][0]
                                                                 input_4[0][0]
__________________________________________________________________________________________________
input_5 (InputLayer)            [(None,)]            0
__________________________________________________________________________________________________
softmax (Softmax)               (None, 1, 2)         0           trim_graph_output[0][0]
==================================================================================================
Total params: 199,362
Trainable params: 198,850
Non-trainable params: 512
__________________________________________________________________________________________________


Recommended Posts

Hack GraphConvModel implémenté dans DeepChem avec résumé
Implémentation de SimRank en Python
Implémentation hard-swish avec Keras
Implémentation de Shiritori en Python
[Pour les débutants] Résumé de l'entrée standard en Python (avec explication)