[PYTHON] Implementieren Sie den GraphConvLayer von DeepChem in der benutzerdefinierten Ebene von PyTorch

Einführung

Ich habe DeepChems GraphConvLayer mit einer benutzerdefinierten Schicht Pytorch implementiert.

Umgebung

Quelle

Letztes Mal Ich habe mit dem erstellten DataSet- und DataLorder-Set einen Mini-Batch herausgenommen und versucht, ihn GraphConv zuzuführen und auszugeben.

import torch
from torch.utils import data
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.mol_graphs import ConvMol
import torch.nn as nn
import numpy as np


class GraphConv(nn.Module):

    def __init__(self,
               in_channel,
               out_channel,
               min_deg=0,
               max_deg=10,
               activation=lambda x: x
               ):

        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.min_degree = min_deg
        self.max_degree = max_deg

        num_deg = 2 * self.max_degree + (1 - self.min_degree)

        self.W_list = [
            nn.Parameter(torch.Tensor(
                np.random.normal(size=(in_channel, out_channel))).double())
            for k in range(num_deg)]

        self.b_list = [
            nn.Parameter(torch.Tensor(np.zeros(out_channel)).double()) for k in range(num_deg)]

    def forward(self, atom_features, deg_slice, deg_adj_lists):

        #print("deg_adj_list")
        print(deg_adj_lists)

        W = iter(self.W_list)
        b = iter(self.b_list)

        # Sum all neighbors using adjacency matrix
        deg_summed = self.sum_neigh(atom_features, deg_adj_lists)

        # Get collection of modified atom features
        new_rel_atoms_collection = (self.max_degree + 1 - self.min_degree) * [None]

        for deg in range(1, self.max_degree + 1):
            # Obtain relevant atoms for this degree
            rel_atoms = deg_summed[deg - 1]

            # Get self atoms
            begin = deg_slice[deg - self.min_degree, 0]
            size = deg_slice[deg - self.min_degree, 1]

            self_atoms = torch.narrow(atom_features, 0, int(begin), int(size))

            # Apply hidden affine to relevant atoms and append
            rel_out = torch.matmul(rel_atoms, next(W)) + next(b)
            self_out = torch.matmul(self_atoms, next(W)) + next(b)

            out = rel_out + self_out
            new_rel_atoms_collection[deg - self.min_degree] = out

        # Determine the min_deg=0 case
        if self.min_degree == 0:
            deg = 0

            begin = deg_slice[deg - self.min_degree, 0]
            size = deg_slice[deg - self.min_degree, 1]
            self_atoms = torch.narrow(atom_features, 0, int(begin), int(size))

            # Only use the self layer
            out = torch.matmul(self_atoms, next(W)) + next(b)

            new_rel_atoms_collection[deg - self.min_degree] = out

        # Combine all atoms back into the list
        #print(new_rel_atoms_collection)
        atom_features = torch.cat(new_rel_atoms_collection, 0)

        return atom_features


    def sum_neigh(self, atoms, deg_adj_lists):
        """Store the summed atoms by degree"""
        deg_summed = self.max_degree * [None]

        for deg in range(1, self.max_degree + 1):
            index = torch.tensor(deg_adj_lists[deg - 1], dtype=torch.int64)
            gathered_atoms = atoms[index]

            # Sum along neighbors as well as self, and store
            summed_atoms = torch.sum(gathered_atoms, 1)
            deg_summed[deg - 1] = summed_atoms

        return deg_summed


class GCNDataset(data.Dataset):

    def __init__(self, smiles_list, label_list):
        self.smiles_list = smiles_list
        self.label_list = label_list

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, index):
        return self.smiles_list[index], self.label_list[index]


def gcn_collate_fn(batch):
    from rdkit import Chem
    cmf = ConvMolFeaturizer()

    mols = []
    labels = []

    for sample, label in batch:
        mols.append(Chem.MolFromSmiles(sample))
        labels.append(torch.tensor(label))

    conv_mols = cmf.featurize(mols)
    multiConvMol = ConvMol.agglomerate_mols(conv_mols)

    atom_feature = torch.tensor(multiConvMol.get_atom_features(), dtype=torch.float64)
    deg_slice = torch.tensor(multiConvMol.deg_slice, dtype=torch.float64)
    membership = torch.tensor(multiConvMol.membership, dtype=torch.float64)
    deg_adj_lists = []

    for i in range(1, len(multiConvMol.get_deg_adjacency_lists())):
        deg_adj_lists.append(multiConvMol.get_deg_adjacency_lists()[i])

    return atom_feature, deg_slice, membership, deg_adj_lists,  labels


def main():
    dataset = GCNDataset(["CCC", "CCCC", "CCCCC"], [1, 0, 1])
    dataloader = data.DataLoader(dataset, batch_size=3, shuffle=False, collate_fn =gcn_collate_fn)

    model = GraphConv(75, 20)
    for atom_feature, deg_slice, membership, deg_adj_lists, labels in dataloader:
        print("atom_feature")
        print(atom_feature)
        print("deg_slice")
        print(deg_slice)
        print("membership")
        print(membership)
        print("result")
        print(model(atom_feature, deg_slice, deg_adj_lists))

if __name__ == "__main__":
    main()

Ergebnis

Ja, nicht. Derzeit scheint die resultierende Form die Anzahl der Atome x 20 Dimensionen zu sein (75 Dimensionen, die durch Faltung komprimiert wurden).

atom_feature
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         1., 0., 0.]], dtype=torch.float64)
deg_slice
tensor([[ 0.,  0.],
        [ 0.,  6.],
        [ 6.,  6.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.],
        [12.,  0.]], dtype=torch.float64)
membership
tensor([0., 0., 1., 1., 2., 2., 0., 1., 1., 2., 2., 2.], dtype=torch.float64)
result
tensor([[-0.2910,  2.2571,  1.6459, -4.0687, -3.3893,  4.3271,  1.5363,  1.2956,
         -1.1717,  0.8923, -0.9046, -3.9463,  4.2884, -3.5612, -9.7249,  1.9113,
          1.7882,  1.6279, -3.7770, -6.3691],
        [-0.2910,  2.2571,  1.6459, -4.0687, -3.3893,  4.3271,  1.5363,  1.2956,
         -1.1717,  0.8923, -0.9046, -3.9463,  4.2884, -3.5612, -9.7249,  1.9113,
          1.7882,  1.6279, -3.7770, -6.3691],
        [-0.2910,  2.2571,  1.6459, -4.0687, -3.3893,  4.3271,  1.5363,  1.2956,
         -1.1717,  0.8923, -0.9046, -3.9463,  4.2884, -3.5612, -9.7249,  1.9113,
          1.7882,  1.6279, -3.7770, -6.3691],
        [-0.2910,  2.2571,  1.6459, -4.0687, -3.3893,  4.3271,  1.5363,  1.2956,
         -1.1717,  0.8923, -0.9046, -3.9463,  4.2884, -3.5612, -9.7249,  1.9113,
          1.7882,  1.6279, -3.7770, -6.3691],
        [-0.2910,  2.2571,  1.6459, -4.0687, -3.3893,  4.3271,  1.5363,  1.2956,
         -1.1717,  0.8923, -0.9046, -3.9463,  4.2884, -3.5612, -9.7249,  1.9113,
          1.7882,  1.6279, -3.7770, -6.3691],
        [-0.2910,  2.2571,  1.6459, -4.0687, -3.3893,  4.3271,  1.5363,  1.2956,
         -1.1717,  0.8923, -0.9046, -3.9463,  4.2884, -3.5612, -9.7249,  1.9113,
          1.7882,  1.6279, -3.7770, -6.3691],
        [-1.6645,  6.3024,  0.6540, -0.7638,  5.3761, -6.3710, -0.3202,  1.3862,
          6.6121, -0.5707, -8.2441, -5.8404,  4.4354,  0.8659, -2.3474, -4.8642,
          8.3175,  0.1378, -4.6038, -3.9733],
        [-0.3320,  1.6265, -0.2117, -0.5792,  5.7710,  0.5828, -0.7252,  3.6408,
          7.6525, -0.3339, -6.1131, -2.3356,  3.6018,  1.5834, -2.7556, -4.1401,
          1.4335, -0.4723, -1.7117, -3.6721],
        [-0.3320,  1.6265, -0.2117, -0.5792,  5.7710,  0.5828, -0.7252,  3.6408,
          7.6525, -0.3339, -6.1131, -2.3356,  3.6018,  1.5834, -2.7556, -4.1401,
          1.4335, -0.4723, -1.7117, -3.6721],
        [-0.3320,  1.6265, -0.2117, -0.5792,  5.7710,  0.5828, -0.7252,  3.6408,
          7.6525, -0.3339, -6.1131, -2.3356,  3.6018,  1.5834, -2.7556, -4.1401,
          1.4335, -0.4723, -1.7117, -3.6721],
        [ 1.0006, -3.0494, -1.0774, -0.3946,  6.1658,  7.5366, -1.1302,  5.8955,
          8.6929, -0.0971, -3.9820,  1.1691,  2.7682,  2.3009, -3.1638, -3.4160,
         -5.4505, -1.0824,  1.1805, -3.3708],
        [-0.3320,  1.6265, -0.2117, -0.5792,  5.7710,  0.5828, -0.7252,  3.6408,
          7.6525, -0.3339, -6.1131, -2.3356,  3.6018,  1.5834, -2.7556, -4.1401,
          1.4335, -0.4723, -1.7117, -3.6721]], dtype=torch.float64,
       grad_fn=<CatBackward>)

Recommended Posts

Implementieren Sie den GraphConvLayer von DeepChem in der benutzerdefinierten Ebene von PyTorch
Implementieren Sie den GraphGatherLayer von DeepChem mit der benutzerdefinierten Ebene von PyTorch
Implementieren Sie den GraphPoolLayer von DeepChem in der benutzerdefinierten Ebene von PyTorch
Implementieren Sie einen benutzerdefinierten View Decorator mit Pyramid
Implementieren Sie ein benutzerdefiniertes Benutzermodell in Django
Implementieren Sie Custom Authorizer für die Firebase-Authentifizierung in Chalice
Implementieren Sie XENO mit Python
Benutzerdefinierte Sortierung in Python3
Implementieren Sie sum in Python
Implementieren Sie Traceroute in Python 3