Source code for openddi.models.KGNN

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch_geometric.nn import RGCNConv


class KGNN(nn.Module):
    """Two-layer RGCN encoder followed by a linear classifier over node pairs.

    Node features are passed through two ``RGCNConv`` layers (ReLU and
    dropout in between). The resulting embeddings of the two nodes of each
    requested pair are concatenated and mapped to ``num_classes`` logits.
    """

    def __init__(self, feature: int, hidden1: int, hidden2: int,
                 num_relations: int, num_classes: int,
                 dropout: float = 0.3) -> None:
        """Build the layers.

        Args:
            feature: Input node-feature dimension.
            hidden1: Output dimension of the first RGCN layer (and of ``fc``).
            hidden2: Output dimension of the second RGCN layer.
            num_relations: Number of relation types passed to both RGCN layers.
            num_classes: Number of output logits per node pair.
            dropout: Dropout probability applied between the RGCN layers.
        """
        super().__init__()
        self.num_relations = int(num_relations)
        self.num_classes = int(num_classes)
        self.hidden1 = int(hidden1)
        self.hidden2 = int(hidden2)
        self.feature = int(feature)
        self.dropout = dropout

        # NOTE: ``fc`` does not contribute to the output of ``forward``. It is
        # kept so the module's parameter set (and therefore its state_dict
        # keys) stays the same for previously saved checkpoints.
        self.fc = nn.Linear(self.feature, self.hidden1)
        self.rgcn1 = RGCNConv(self.feature, self.hidden1,
                              num_relations=self.num_relations)
        self.rgcn2 = RGCNConv(self.hidden1, self.hidden2,
                              num_relations=self.num_relations)
        # The classifier consumes the concatenation of two node embeddings.
        self.linear = nn.Linear(2 * self.hidden2, self.num_classes)

    @staticmethod
    def _as_index(values, device) -> torch.Tensor:
        """Return ``values`` as a ``torch.long`` index tensor on ``device``."""
        if isinstance(values, torch.Tensor):
            # Convert directly instead of splitting into 0-d tensors first.
            return values.to(device=device, dtype=torch.long)
        return torch.as_tensor(list(values), dtype=torch.long, device=device)

    def forward(self, data_o, idx) -> torch.Tensor:
        """Compute logits for the requested node pairs.

        Args:
            data_o: Graph object exposing ``x``, ``edge_index`` and
                ``edge_type`` attributes.
            idx: Indexable pair; ``idx[0]`` and ``idx[1]`` hold the node
                indices of the first and second element of every pair.

        Returns:
            The output of the final linear layer, one row per node pair.
        """
        x_o, edge_index, edge_type = data_o.x, data_o.edge_index, data_o.edge_type
        a_idx = self._as_index(idx[0], x_o.device)
        b_idx = self._as_index(idx[1], x_o.device)

        hidden = F.relu(self.rgcn1(x_o, edge_index, edge_type))
        # F.dropout defaults to training=True; tie it to the module mode so
        # that dropout is disabled after ``model.eval()``.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        node_emb = self.rgcn2(hidden, edge_index, edge_type)

        pair_emb = torch.cat([node_emb[a_idx], node_emb[b_idx]], dim=1)
        return self.linear(pair_emb)