# Source code for openddi.models.SumGNN

from typing import Optional

import torch
import torch.nn as nn
from torch_geometric.nn import RGCNConv

# Names exported by ``from <this module> import *``.
__all__ = [
    "SumGNN",
]

class SumGNN(nn.Module):
    """SumGNN model with two RGCN layers.

    Node features pass through two relational graph convolutions
    (``RGCNConv``), each followed by ReLU and dropout. For a node pair
    ``(i, j)`` the pair representation is ``concat(h1_i, h1_j, h2_i, h2_j)``,
    which a single linear layer maps to output logits.

    An input projection ``in_proj`` (Linear -> ReLU -> Dropout) can reduce the
    raw multimodal features to ``proj_dim`` before message passing. It is
    opt-in via ``use_proj=True``. By default the raw features are fed to the
    first RGCN layer, as in earlier versions of this class.

    Args:
        feature (int): Original multimodal feature dimension (e.g., 2860).
        hidden1 (int): Output dimension of the first RGCN layer.
        hidden2 (int): Output dimension of the second RGCN layer.
        num_relations (int): Number of relation types in the graph; values
            below 1 are clamped to 1.
        num_classes (int): Number of output classes.
        dropout (float): Dropout rate, default 0.3.
        proj_dim (int, optional): Width of the input projection. Defaults to
            ``hidden1 * 4`` clamped to the range [128, 512].
        num_bases (int, optional): Number of bases for the RGCN weight
            decomposition. Defaults to ``min(num_relations, 2)``.
        use_proj (bool, keyword-only): If True, apply ``in_proj`` to the node
            features and size the first RGCN layer for ``proj_dim`` inputs.
            Default False (raw ``feature``-dim inputs).
    """

    def __init__(
        self,
        feature: int,
        hidden1: int,
        hidden2: int,
        num_relations: int,
        num_classes: int,
        dropout: float = 0.3,
        proj_dim: Optional[int] = None,
        num_bases: Optional[int] = None,
        *,
        use_proj: bool = False,
    ):
        super().__init__()
        self.feature = int(feature)  # Original multimodal dimension
        self.hidden1 = int(hidden1)  # First RGCN output dimension
        self.hidden2 = int(hidden2)  # Second RGCN output dimension
        self.num_rel = max(1, int(num_relations))  # At least one relation type
        self.num_classes = int(num_classes)
        self.dropout = float(dropout)
        self.use_proj = bool(use_proj)

        # ---- Optional input dimension reduction (feature -> proj_dim) ----
        if proj_dim is None:
            # hidden1 * 4, clamped to [128, 512].
            proj_dim = max(128, min(512, self.hidden1 * 4))
        self.proj_dim = int(proj_dim)
        # Built unconditionally so saved state_dicts keep their ``in_proj.*``
        # keys whatever ``use_proj`` is; when ``use_proj`` is False these
        # parameters are never used in ``forward``.
        self.in_proj = nn.Sequential(
            nn.Linear(self.feature, self.proj_dim),
            nn.ReLU(),
            nn.Dropout(self.dropout),
        )

        # ---- RGCN basis decomposition of the per-relation weights ----
        if num_bases is None:
            # At most two bases; a single-relation graph gets exactly one.
            num_bases = min(self.num_rel, 2)
        self.num_bases = int(num_bases)

        gnn_in = self.proj_dim if self.use_proj else self.feature
        self.gnn1 = RGCNConv(
            gnn_in, self.hidden1,
            num_relations=self.num_rel, num_bases=self.num_bases,
        )
        self.gnn2 = RGCNConv(
            self.hidden1, self.hidden2,
            num_relations=self.num_rel, num_bases=self.num_bases,
        )
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.drop1 = nn.Dropout(self.dropout)
        self.drop2 = nn.Dropout(self.dropout)

        # Pair representation: [h1_i, h1_j, h2_i, h2_j]
        self.fc = nn.Linear(2 * self.hidden1 + 2 * self.hidden2, self.num_classes)

    def drug_feat(self, emb):
        """Store ``emb`` as ``self.drugfeat``.

        Kept only for compatibility with older calling code; ``forward`` does
        not read this attribute.
        """
        self.drugfeat = emb

    @staticmethod
    def _as_index(idx, device):
        """Convert one side of ``idx_batch`` to a long tensor on ``device``."""
        if isinstance(idx, torch.Tensor):
            # Convert directly instead of ``list(tensor)``, which would
            # iterate element by element and create one 0-d tensor per index.
            return idx.to(device=device, dtype=torch.long)
        return torch.as_tensor(list(idx), dtype=torch.long, device=device)

    def forward(self, data_o, idx_batch):
        """Compute logits for a batch of node pairs.

        Args:
            data_o: Graph object exposing ``x`` (node features),
                ``edge_index`` and ``edge_type``; the latter two are passed
                unchanged to both RGCN layers.
            idx_batch: Pair of index collections. ``idx_batch[0]`` holds the
                first node of each pair and ``idx_batch[1]`` the second. Each
                side may be a tensor or any iterable of ints, and both sides
                must have the same length.

        Returns:
            torch.Tensor: Logits of shape ``(len(idx_batch[0]), num_classes)``.
        """
        x, edge_index, edge_type = data_o.x, data_o.edge_index, data_o.edge_type
        a_idx = self._as_index(idx_batch[0], x.device)
        b_idx = self._as_index(idx_batch[1], x.device)

        # getattr: whole-module pickles created before ``use_proj`` existed
        # do not carry the attribute; treat them as "no projection".
        if getattr(self, "use_proj", False):
            x = self.in_proj(x)

        h1 = self.drop1(self.relu1(self.gnn1(x, edge_index, edge_type)))
        h2 = self.drop2(self.relu2(self.gnn2(h1, edge_index, edge_type)))

        pair = torch.cat([h1[a_idx], h1[b_idx], h2[a_idx], h2[b_idx]], dim=1)
        return self.fc(pair)