# Source code for openddi.models.gognn.gognn

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGPooling, NNConv
from torch_geometric.data import Data, Batch
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.utils import to_undirected


class GoGNN(torch.nn.Module):
    """GoGNN model for DDI prediction using graph neural networks.

    The model works on two levels:

    * Molecular level: every molecule graph is passed through three
      ``GCNConv`` + ``SAGPooling`` stages; after each stage a max/mean
      readout is taken, and the three readouts are concatenated into one
      vector of size ``6 * nhid`` per molecule.
    * DDI level: the per-molecule vectors are the node features of the
      drug-drug interaction graph, which is processed by one ``GCNConv``
      layer. Each DDI edge is then classified from the concatenation of its
      two endpoint embeddings.

    Both multiclass and multilabel classification are supported (see
    :meth:`loss`).

    Args:
        args: Configuration object. ``args.device`` is read here and in
            :meth:`forward`; ``args.matrix`` is read (optionally) in
            :meth:`loss`.
        num_features: Input feature dimension of the molecular graphs.
        nhid: Hidden dimension for the molecular GNN.
        ddi_nhid: Hidden dimension for the DDI GNN.
        pooling_ratio: Ratio passed to the ``SAGPooling`` layers.
        dropout_ratio: Dropout rate applied to the molecule embeddings.
        num_rel: Number of relation types (size of the classifier output).
    """

    def __init__(self, args, num_features, nhid, ddi_nhid, pooling_ratio,
                 dropout_ratio, num_rel):
        super().__init__()
        self.args = args
        self.num_features = num_features
        self.nhid = nhid
        self.ddi_nhid = ddi_nhid
        self.pooling_ratio = pooling_ratio
        self.dropout_ratio = dropout_ratio
        self.num_rel = num_rel

        device = args.device

        # Molecular encoder: three conv + pooling stages.
        self.conv1 = GCNConv(self.num_features, self.nhid).to(device)
        self.pool1 = SAGPooling(self.nhid, ratio=self.pooling_ratio).to(device)
        self.conv2 = GCNConv(self.nhid, self.nhid).to(device)
        self.pool2 = SAGPooling(self.nhid, ratio=self.pooling_ratio).to(device)
        self.conv3 = GCNConv(self.nhid, self.nhid).to(device)
        self.pool3 = SAGPooling(self.nhid, ratio=self.pooling_ratio).to(device)

        # DDI-graph convolution. The input width is 6 * nhid because each of
        # the three stages contributes a (max, mean) readout of width 2 * nhid.
        self.conv_noattn = GCNConv(6 * self.nhid, self.ddi_nhid).to(device)

        # Dropout and edge classifier for supervised training.
        self.dropout = nn.Dropout(self.dropout_ratio)
        # The classifier is placed on the same device as the GNN layers:
        # forward() feeds it tensors living on args.device, so leaving it on
        # the default device would fail unless the caller moved the model.
        self.edge_classifier = nn.Sequential(
            nn.Linear(2 * self.ddi_nhid, self.ddi_nhid),
            nn.ReLU(),
            nn.Linear(self.ddi_nhid, self.num_rel),
        ).to(device)

    def _conv_pool_readout(self, conv, pool, x, edge_index, edge_weight, batch):
        """Run one conv -> ReLU -> SAGPooling stage and compute its readout.

        Returns:
            tuple: The pooled ``(x, edge_index, edge_weight, batch)`` followed
            by the per-graph readout, i.e. global max pooling concatenated
            with global mean pooling along the feature dimension.
        """
        x = F.relu(conv(x, edge_index, edge_weight))
        # SAGPooling returns: x, edge_index, edge_attr, batch, perm, score.
        x, edge_index, edge_weight, batch, _, _ = pool(
            x, edge_index, edge_weight, batch
        )
        readout = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)
        return x, edge_index, edge_weight, batch, readout

    def forward(self, data):
        """Forward pass of the GoGNN model.

        Args:
            data: Tuple ``(data_list, ddi_edge_index, ...)``; any elements
                after the first two are ignored.

                - data_list: List of molecular graph ``Data`` objects, or an
                  already batched object exposing ``x``, ``edge_index``,
                  ``batch`` and optionally ``edge_attr``.
                - ddi_edge_index: Edge indices of the DDI graph. They index
                  into the molecules of ``data_list``.

        Returns:
            torch.Tensor: Logits for DDI edge classification, one row per
            column of ``ddi_edge_index``.
        """
        data_list, ddi_edge_index, *_ = data
        device = self.args.device

        # The Batch object is built here rather than in the dataset to avoid
        # pickling errors in DataLoader workers.
        if isinstance(data_list, list):
            batched_data = Batch.from_data_list(data_list)
        else:
            batched_data = data_list

        ddi_edge_index = ddi_edge_index.to(device)

        # Unpack batched data.
        x = batched_data.x.to(device)
        edge_index = batched_data.edge_index.to(device)
        batch = batched_data.batch.to(device)
        # Edge attributes are optional: the conv and pooling layers accept
        # None, so graphs without edge_attr must not crash here.
        edge_attr = getattr(batched_data, "edge_attr", None)
        edge_weight = edge_attr.to(device) if edge_attr is not None else None

        # All molecules of the batch are processed in parallel; one readout
        # is collected after every stage.
        readouts = []
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x, edge_index, edge_weight, batch, readout = self._conv_pool_readout(
                conv, pool, x, edge_index, edge_weight, batch
            )
            readouts.append(readout)

        # Concatenate the per-stage readouts into one embedding per molecule.
        modular_feature = torch.cat(readouts, dim=1)
        modular_feature = self.dropout(modular_feature)

        # One GCN layer on the DDI graph (no edge attributes).
        node_emb = F.relu(self.conv_noattn(modular_feature, ddi_edge_index))

        # Edge classification logits for the given ddi_edge_index.
        source, target = ddi_edge_index
        edge_feat = torch.cat([node_emb[source], node_emb[target]], dim=1)
        return self.edge_classifier(edge_feat)

    def loss(self, logits, labels):
        """Compute the supervised loss for DDI edge classification.

        The task type is read from ``self.args.matrix`` and defaults to
        ``'multiclass'`` when that attribute is missing.

        Args:
            logits: Model predictions.
            labels: Ground truth labels.

        Returns:
            torch.Tensor: Scalar loss value (mean reduction).

        Note:
            Binary cross-entropy with logits is used for the ``'multilabel'``
            and ``'twosides'`` tasks (labels are cast to float); cross-entropy
            is used otherwise (labels are cast to long).
        """
        task = getattr(self.args, 'matrix', 'multiclass')
        if task in ('multilabel', 'twosides'):
            return F.binary_cross_entropy_with_logits(logits, labels.float())
        return F.cross_entropy(logits, labels.long())