# Source code for openddi.models.GOGNN

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGPooling, NNConv, RGCNConv
from torch_geometric.data import Data
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp


class GOGNN(nn.Module):
    """GOGNN model for drug-drug interaction (DDI) prediction.

    A single relational graph convolution (``RGCNConv``) encodes every node
    of the input graph into a fixed-width embedding. For each requested node
    pair, the two node embeddings are concatenated and fed to an MLP that
    emits one logit per class.
    """

    # Output width of the RGCN encoder. The classifier consumes two of these
    # embeddings concatenated, so its input width is twice this value.
    _EMBED_DIM = 256

    def __init__(self, feature: int, hidden1: int, hidden2: int,
                 num_relations: int, num_classes: int,
                 dropout: float = 0.5, pooling_ratio: float = 0.5):
        """Initialize the GOGNN model.

        Args:
            feature: Input node-feature dimension (encoder input width).
            hidden1: Stored as ``nhid``; not used to size any layer here
                (the encoder and classifier widths are fixed).
            hidden2: Stored as ``ddi_nhid``; not used to size any layer here.
            num_relations: Number of relation types given to ``RGCNConv``.
            num_classes: Number of output classes (width of the logits).
            dropout: Stored as ``dropout_ratio``. NOTE: the classifier's
                ``nn.Dropout`` layers use a fixed ``p=0.1`` and do not read
                this value.
            pooling_ratio: Stored as ``pooling_ratio``; not used by this
                class.
        """
        super().__init__()
        self.num_relations = int(num_relations)
        self.num_classes = int(num_classes)
        self.feature = int(feature)
        # Kept equal to the number of classes, as in the original code.
        self.num_edge_features = int(num_classes)
        self.nhid = int(hidden1)
        self.ddi_nhid = int(hidden2)
        self.pooling_ratio = pooling_ratio
        self.dropout_ratio = dropout

        # Graph encoder: node features -> _EMBED_DIM-wide embeddings.
        self.conv = RGCNConv(
            self.feature, self._EMBED_DIM, num_relations=self.num_relations
        )

        # Pair classifier over the concatenated (a, b) embeddings.
        self.mlp = nn.Sequential(
            nn.Linear(2 * self._EMBED_DIM, 256),
            nn.ELU(),
            nn.Dropout(p=0.1),
            nn.Linear(256, 128),
            nn.ELU(),
            nn.Dropout(p=0.1),
            nn.Linear(128, self.num_classes),  # output width = number of labels
        )

    @staticmethod
    def _to_index(values, device):
        """Return ``values`` as a ``torch.long`` index tensor on ``device``."""
        if isinstance(values, torch.Tensor):
            # Convert in one call; iterating a tensor element by element
            # yields one 0-d tensor per index, which is needlessly slow.
            return values.to(device=device, dtype=torch.long)
        return torch.as_tensor(list(values), dtype=torch.long, device=device)

    def forward(self, data_o, idx):
        """Compute class logits for a batch of node pairs.

        Args:
            data_o: Graph object exposing ``x`` (node features),
                ``edge_index`` and ``edge_type``; all three are passed to
                the RGCN encoder unchanged.
            idx: Indexable pair where ``idx[0]`` holds the first node index
                of every pair and ``idx[1]`` the second. Each may be a
                tensor or any iterable of integers.

        Returns:
            Logits for the given pairs, ``[batch_size, num_classes]``.
        """
        x_o, edge_index, e_type = data_o.x, data_o.edge_index, data_o.edge_type

        # Index tensors must live on the same device as the node features.
        a_idx = self._to_index(idx[0], x_o.device)
        b_idx = self._to_index(idx[1], x_o.device)

        # Encode every node once, then gather the rows for each pair.
        x = F.relu(self.conv(x_o, edge_index, e_type))
        pair_vec = torch.cat([x[a_idx], x[b_idx]], dim=1)

        return self.mlp(pair_vec)  # [batch_size, num_classes]