[Python] Implementing DeepChem's GraphGatherLayer as a custom PyTorch layer


Following GraphConvLayer and GraphPoolLayer, I implemented DeepChem's GraphGatherLayer as a custom PyTorch layer.



I ported DeepChem's GraphGatherLayer to PyTorch and chained it after the previously ported layers: the output of GraphConvLayer is passed through GraphPoolLayer and then into the new GraphGatherLayer.

import torch
from torch.utils import data
from deepchem.feat.graph_features import ConvMolFeaturizer
from deepchem.feat.mol_graphs import ConvMol
import torch.nn as nn
import numpy as np
from torch_scatter import scatter_max

def unsorted_segment_sum(data, segment_ids, num_segments):
    """Sum the slices of ``data`` that share a segment id.

    PyTorch counterpart of TensorFlow's ``tf.math.unsorted_segment_sum``.

    Args:
        data: Tensor whose first dimension is indexed by ``segment_ids``.
        segment_ids: int64 tensor, either 1-D with ``data.shape[0]`` entries
            or already the same shape as ``data``.
        num_segments: Number of output rows. Segments that receive no
            entries stay zero.

    Returns:
        Tensor of shape ``(num_segments, *data.shape[1:])`` with the same
        dtype and device as ``data``.
    """
    if segment_ids.dim() == 1:
        # Broadcast each row's id across the trailing dims so scatter_add can
        # use it element-wise (expand is a view, no repeated copy needed).
        view_shape = (segment_ids.shape[0],) + (1,) * (data.dim() - 1)
        segment_ids = segment_ids.view(view_shape).expand_as(data)

    shape = (num_segments,) + tuple(data.shape[1:])
    # Accumulate directly in data's dtype/device. Going through float32 would
    # truncate float64 inputs (and large int64 values) before summing.
    out = torch.zeros(shape, dtype=data.dtype, device=data.device)
    return out.scatter_add(0, segment_ids, data)

class GraphConv(nn.Module):
    """PyTorch port of DeepChem's GraphConv layer.

    Rows of ``atom_features`` are assumed to be grouped by atom degree, with
    ``deg_slice[d - min_deg] = (start, count)`` giving the row range for
    degree ``d``. For every degree ``d >= 1`` the layer applies one affine map
    to the summed neighbour features and a second one to the atom's own
    features, then adds the two. Degree-0 atoms (only when ``min_deg == 0``)
    get just the self map. ``activation`` is applied to the concatenated
    result.

    Args:
        in_channel: Number of input features per atom.
        out_channel: Number of output features per atom.
        min_deg: Smallest degree indexed in ``deg_slice``.
        max_deg: Largest degree handled.
        activation: Callable applied to the output (identity by default).
    """

    def __init__(self,
                 in_channel,
                 out_channel,
                 min_deg=0,
                 max_deg=10,
                 activation=lambda x: x):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.min_degree = min_deg
        self.max_degree = max_deg
        self.activation = activation

        # Two weight sets (neighbour + self) per degree >= 1, plus one extra
        # self-only set for degree 0 when min_deg == 0.
        num_deg = 2 * self.max_degree + (1 - self.min_degree)

        # ParameterList (not a plain list) so the weights are registered:
        # visible to .parameters()/optimizers, saved in state_dict, moved by .to().
        self.W_list = nn.ParameterList([
            nn.Parameter(torch.tensor(
                np.random.normal(size=(in_channel, out_channel)),
                dtype=torch.float64))
            for _ in range(num_deg)])

        self.b_list = nn.ParameterList([
            nn.Parameter(torch.zeros(out_channel, dtype=torch.float64))
            for _ in range(num_deg)])

    def forward(self, atom_features, deg_slice, deg_adj_lists):
        """Convolve atom features over the degree-sorted graph.

        Args:
            atom_features: ``(n_atoms, in_channel)`` tensor (float64 to match
                the weights).
            deg_slice: Per-degree ``(start, count)`` table; entries are
                converted with ``int()``.
            deg_adj_lists: ``deg_adj_lists[d - 1]`` holds the neighbour indices
                of the degree-``d`` atoms, for ``d`` in ``1..max_deg``.

        Returns:
            ``(n_atoms, out_channel)`` tensor, rows in degree order.
        """
        # Weights are consumed in a fixed order: (neighbour, self) for each
        # degree 1..max_deg, then the degree-0 self weights last.
        W = iter(self.W_list)
        b = iter(self.b_list)

        # Sum all neighbours using the adjacency lists.
        deg_summed = self.sum_neigh(atom_features, deg_adj_lists)

        new_rel_atoms_collection = (self.max_degree + 1 - self.min_degree) * [None]

        for deg in range(1, self.max_degree + 1):
            rel_atoms = deg_summed[deg - 1]

            begin = int(deg_slice[deg - self.min_degree, 0])
            size = int(deg_slice[deg - self.min_degree, 1])
            self_atoms = torch.narrow(atom_features, 0, begin, size)

            rel_out = torch.matmul(rel_atoms, next(W)) + next(b)
            self_out = torch.matmul(self_atoms, next(W)) + next(b)
            new_rel_atoms_collection[deg - self.min_degree] = rel_out + self_out

        if self.min_degree == 0:
            # Degree-0 atoms have no neighbours: only the self map applies.
            begin = int(deg_slice[0, 0])
            size = int(deg_slice[0, 1])
            self_atoms = torch.narrow(atom_features, 0, begin, size)
            new_rel_atoms_collection[0] = torch.matmul(self_atoms, next(W)) + next(b)

        atom_features = torch.cat(new_rel_atoms_collection, 0)
        return self.activation(atom_features)

    def sum_neigh(self, atoms, deg_adj_lists):
        """Return, per degree 1..max_deg, the neighbour features summed per atom."""
        deg_summed = self.max_degree * [None]

        for deg in range(1, self.max_degree + 1):
            # as_tensor accepts numpy arrays or tensors; keep indices on atoms' device.
            index = torch.as_tensor(deg_adj_lists[deg - 1], dtype=torch.int64,
                                    device=atoms.device)
            gathered_atoms = atoms[index]
            # Sum over axis 1: one slot per adjacency-list entry of each atom.
            deg_summed[deg - 1] = torch.sum(gathered_atoms, 1)

        return deg_summed

class GraphPool(nn.Module):
    """PyTorch port of DeepChem's GraphPool layer.

    For each degree ``d >= 1``, every atom's features are replaced by the
    element-wise max over itself and its neighbours. Degree-0 atoms (only
    when ``min_degree == 0``) are passed through unchanged.

    Args:
        min_degree: Smallest degree indexed in ``deg_slice``.
        max_degree: Largest degree handled.
    """

    def __init__(self, min_degree=0, max_degree=10):
        super().__init__()
        self.min_degree = min_degree
        self.max_degree = max_degree

    def forward(self, atom_features, deg_slice, deg_adj_lists):
        """Max-pool each atom with its neighbours.

        Args:
            atom_features: ``(n_atoms, n_features)`` tensor, rows grouped by degree.
            deg_slice: Per-degree ``(start, count)`` table; entries converted with ``int()``.
            deg_adj_lists: ``deg_adj_lists[d - 1]`` holds the neighbour indices of
                the degree-``d`` atoms.

        Returns:
            ``(n_atoms, n_features)`` tensor, rows in degree order.
        """
        deg_maxed = (self.max_degree + 1 - self.min_degree) * [None]

        for deg in range(1, self.max_degree + 1):
            begin = int(deg_slice[deg - self.min_degree, 0])
            size = int(deg_slice[deg - self.min_degree, 1])
            # (count, 1, n_features) so the atom itself joins its neighbours on axis 1.
            self_atoms = torch.narrow(atom_features, 0, begin, size).unsqueeze(1)

            # Adjacency lists always start at degree 1, hence deg - 1.
            index = torch.as_tensor(deg_adj_lists[deg - 1], dtype=torch.int64,
                                    device=atom_features.device)
            gathered_atoms = torch.cat([self_atoms, atom_features[index]], 1)

            if gathered_atoms.shape[0] > 0:
                maxed_atoms = torch.max(gathered_atoms, 1)[0]
            else:
                # No atoms of this degree: an empty (0, n_features) block with the
                # input's dtype/device keeps torch.cat below well-formed.
                maxed_atoms = atom_features.new_zeros(
                    (0,) + tuple(atom_features.shape[1:]))

            deg_maxed[deg - self.min_degree] = maxed_atoms

        if self.min_degree == 0:
            # Degree-0 atoms have no neighbours; keep their features as-is.
            begin = int(deg_slice[0, 0])
            size = int(deg_slice[0, 1])
            deg_maxed[0] = torch.narrow(atom_features, 0, begin, size)

        return torch.cat(deg_maxed, 0)

class GraphGather(nn.Module):
    """PyTorch port of DeepChem's GraphGather layer.

    Reduces atom features to one vector per molecule by concatenating the
    per-molecule sum and the per-molecule element-wise max.

    Args:
        batch_size: Number of molecules (output rows); must be greater than 1.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def forward(self, atom_features, membership):
        """Gather atoms into molecules.

        Args:
            atom_features: ``(n_atoms, n_features)`` tensor.
            membership: int64 tensor giving each atom's molecule index.

        Returns:
            ``(batch_size, 2 * n_features)`` tensor: ``[sum | max]`` per molecule.
        """
        assert self.batch_size > 1, "graph_gather requires batches larger than 1"

        sparse_reps = unsorted_segment_sum(atom_features, membership, self.batch_size)
        # dim_size pins the row count to batch_size so it always matches
        # sparse_reps, even when the highest molecule indices have no atoms
        # in `membership` (e.g. a short final batch).
        max_reps, _ = scatter_max(atom_features, membership, dim=0,
                                  dim_size=self.batch_size)
        return torch.cat([sparse_reps, max_reps], 1)

class GCNDataset(data.Dataset):
    """Map-style dataset of raw ``(smiles, label)`` pairs.

    Items are returned unfeaturized; conversion to graph tensors happens in
    the DataLoader's collate function.
    """

    def __init__(self, smiles_list, label_list):
        # Both sequences are kept as given (no copy) and indexed in parallel.
        self.smiles_list = smiles_list
        self.label_list = label_list

    def __len__(self):
        """Number of molecules, taken from the SMILES sequence."""
        smiles = self.smiles_list
        return len(smiles)

    def __getitem__(self, index):
        """Return the ``(smiles, label)`` pair stored at ``index``."""
        smiles = self.smiles_list[index]
        label = self.label_list[index]
        return smiles, label

def gcn_collate_fn(batch):
    """Featurize a batch of ``(smiles, label)`` pairs into one agglomerated ConvMol.

    Args:
        batch: Sequence of ``(smiles, label)`` pairs, e.g. items of ``GCNDataset``.

    Returns:
        Tuple ``(atom_feature, deg_slice, membership, deg_adj_lists, labels)``:
        float64 atom-feature tensor, float64 ``deg_slice`` table, int64
        per-atom molecule index, the adjacency arrays with entry 0 dropped
        (so ``deg_adj_lists[d - 1]`` lines up with the layers' degree-``d``
        indexing), and the labels in batch order.

    Raises:
        ValueError: If RDKit cannot parse one of the SMILES strings.
    """
    from rdkit import Chem
    cmf = ConvMolFeaturizer()

    mols = []
    labels = []

    for sample, label in batch:
        mol = Chem.MolFromSmiles(sample)
        if mol is None:
            # RDKit reports a parse failure by returning None; fail loudly here
            # instead of deep inside the featurizer.
            raise ValueError("could not parse SMILES: {!r}".format(sample))
        mols.append(mol)
        labels.append(label)

    conv_mols = cmf.featurize(mols)
    multiConvMol = ConvMol.agglomerate_mols(conv_mols)

    atom_feature = torch.tensor(multiConvMol.get_atom_features(), dtype=torch.float64)
    deg_slice = torch.tensor(multiConvMol.deg_slice, dtype=torch.float64)
    membership = torch.tensor(multiConvMol.membership, dtype=torch.int64)
    # Skip entry 0; the layers read deg_adj_lists[deg - 1] for deg >= 1.
    deg_adj_lists = list(multiConvMol.get_deg_adjacency_lists()[1:])

    return atom_feature, deg_slice, membership, deg_adj_lists, labels

def main():
    """Push a 3-molecule demo batch through GraphConv -> GraphPool -> GraphGather."""
    smiles = ["CCC", "CCCC", "CCCCC"]
    targets = [1, 0, 1]
    loader = data.DataLoader(GCNDataset(smiles, targets),
                             batch_size=3,
                             shuffle=False,
                             collate_fn=gcn_collate_fn)

    conv = GraphConv(75, 20)
    pool = GraphPool()
    gather = GraphGather(3)

    for batch in loader:
        atom_feature, deg_slice, membership, deg_adj_lists, _labels = batch
        conv_out = conv(atom_feature, deg_slice, deg_adj_lists)
        pool_out = pool(conv_out, deg_slice, deg_adj_lists)
        gather(pool_out, membership)

if __name__ == "__main__":
    main()

Done. The resulting shape is (number of molecules) × 40 — the 20 summed features concatenated with the 20 max-pooled features for each molecule — which shows that the atom features have been aggregated into per-molecule vectors. As usual, I like this white-box feel (my comments are the same every time, so I'll skip them). This time, porting TensorFlow's unsorted_segment_sum and unsorted_segment_max operations gave me a lot of trouble. Verification comes next.

tensor([[ 7.7457,  2.1970, 22.1151,  1.8238,  7.5860, 15.5079, -1.3865,  5.3634,
          0.3872, 24.7713, 30.9865, 13.0032,  5.8331, 12.8195,  9.2520, 16.4660,
         -8.8977, 10.5881, 16.8875,  3.6356,  2.5819,  0.7323,  7.3717,  0.6079,
          2.5287,  5.1693, -0.4622,  1.7878,  0.1291,  8.2571, 10.3288,  4.3344,
          1.9444,  4.2732,  3.0840,  5.4887, -2.9659,  3.5294,  5.6292,  1.2119],
        [12.4624, 16.9705, 26.8321,  4.3047, 17.4027, 23.3370, -1.8487,  7.1511,
          0.2538, 23.2520, 25.0874, 17.3375,  7.7775,  9.7369,  8.3362, 20.8373,
         -4.3081, 14.1175, 17.6781,  6.4011,  3.1156,  4.2426,  6.7080,  1.0762,
          4.3507,  5.8342, -0.4622,  1.7878,  0.0634,  5.8130,  6.2718,  4.3344,
          1.9444,  2.4342,  2.0840,  5.2093, -1.0770,  3.5294,  4.4195,  1.6003],
        [17.1790, 31.7441, 33.5401,  8.6282, 27.2195, 31.1660, -4.6301,  4.2145,
         -1.0452, 29.0650, 31.3592, 15.0395, 14.6857, 12.1711, 10.4202, 26.0466,
          3.5187, 10.4842, 22.0976,  9.1667,  3.6493,  7.7530,  6.7080,  2.1586,
          6.1727,  6.4992, -0.4622,  1.7878,  0.0634,  5.8130,  6.2718,  4.3344,
          3.5990,  2.4342,  2.0840,  5.2093,  1.8909,  3.5294,  4.4195,  1.9887]],
       dtype=torch.float64, grad_fn=<CatBackward>)


