# Source code for openddi.models.MMDGDTI

# models/MMDGDTI.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv  # Can be replaced with GATConv/GCNConv

class ModalAttnFusion(nn.Module):
    """
    Multi-modal attention fusion.

    The input is a concatenation of M modality blocks x = [x_1 | ... | x_M]
    along the last dimension, block m occupying columns splits[m]:splits[m+1].
    For each modality m:

        p_m = ReLU(proj_m(x_m))      # projection to out_dim
        s_m = scor_m(x_m)            # scalar score, w_m^T x_m + b_m

    The scores are softmax-normalised across modalities,
    alpha = softmax([s_1, ..., s_M]), and the fused representation is
    h = sum_m alpha_m * p_m, followed by dropout.
    """

    def __init__(self, modal_dims, out_dim, dropout=0.0):
        """
        Initialize ModalAttnFusion.

        Args:
            modal_dims: Width d_m of each modality block, in concatenation order.
            out_dim: Width of the fused output.
            dropout: Dropout probability applied to the fused output.

        Raises:
            ValueError: If ``modal_dims`` is empty.
        """
        super().__init__()
        self.modal_dims = list(map(int, modal_dims))
        if not self.modal_dims:
            raise ValueError("modal_dims must contain at least one modality")
        self.out_dim = int(out_dim)
        self.dropout = float(dropout)
        # One projection head and one scalar scoring head per modality.
        self.proj = nn.ModuleList([nn.Linear(d, self.out_dim) for d in self.modal_dims])
        self.scor = nn.ModuleList([nn.Linear(d, 1) for d in self.modal_dims])

    def forward(self, x, splits):
        """
        Fuse the modality blocks of ``x``.

        Args:
            x: Tensor of shape [..., sum(d_m)]; modalities are laid out along
                the last dimension (typically [N, sum(d_m)]).
            splits: Block boundaries [0, d_1, d_1 + d_2, ..., sum(d_m)], given
                as a 1-D tensor or any sequence of ints.

        Returns:
            Tuple ``(h, A)``: fused features of shape [..., out_dim] (after
            dropout) and attention weights of shape [..., M] (returned for
            inspection only).

        Raises:
            ValueError: If ``splits`` does not describe exactly one block per
                modality with the widths given at construction.
        """
        # Plain Python ints for slicing, independent of the device `splits` lives on.
        if isinstance(splits, torch.Tensor):
            bounds = [int(b) for b in splits.tolist()]
        else:
            bounds = [int(b) for b in splits]
        if len(bounds) - 1 != len(self.modal_dims):
            raise ValueError(
                f"splits must have {len(self.modal_dims) + 1} boundaries "
                f"(one block per modality), got {len(bounds)}"
            )

        feats, scores = [], []
        for (lo, hi), width, pj, sc in zip(
            zip(bounds[:-1], bounds[1:]), self.modal_dims, self.proj, self.scor
        ):
            if hi - lo != width:
                raise ValueError(
                    f"modality block [{lo}:{hi}] has width {hi - lo}, expected {width}"
                )
            xm = x[..., lo:hi]                 # [..., d_m]
            feats.append(F.relu(pj(xm)))       # [..., out]
            scores.append(sc(xm))              # [..., 1]

        A = torch.softmax(torch.cat(scores, dim=-1), dim=-1)   # [..., M]
        H = torch.stack(feats, dim=-2)                          # [..., M, out]
        h = (A.unsqueeze(-1) * H).sum(dim=-2)                   # [..., out]
        return F.dropout(h, p=self.dropout, training=self.training), A
class MMDGDTI(nn.Module):
    """
    MMDG-DTI baseline, adapted for drug-drug interaction (DDI) prediction.

    Pipeline:
        node features [N, feature]
            -> ModalAttnFusion                         -> [N, hidden1]
            -> SAGEConv, ReLU, dropout, SAGEConv       -> [N, hidden2]
        pair (i, j) -> [h_i, h_j, |h_i - h_j|, h_i * h_j] -> [B, 4 * hidden2]
            -> MLP                                     -> logits [B, num_classes]

    The SAGE layers could be swapped for GATConv/GCNConv. The constructor
    signature is kept as-is because an external model_manager builds this
    model (not visible here).
    """

    def __init__(self, feature: int, hidden1: int, hidden2: int, num_relations: int,
                 num_classes: int, dropout: float = 0.3):
        """
        Initialize MMDGDTI model.

        Args:
            feature: Input node-feature dimension (sum of all modality widths).
            hidden1: Output width of the modal fusion / input width of the GNN.
            hidden2: Width of both SAGE layers' outputs.
            num_relations: Accepted for signature compatibility; not used.
            num_classes: Number of output logits per pair.
            dropout: Dropout probability used in fusion, GNN and MLP.
        """
        super().__init__()
        self.feature = int(feature)
        self.hid1 = int(hidden1)
        self.hid2 = int(hidden2)
        self.num_classes = int(num_classes)
        self.dropout = float(dropout)

        # Modal layout: defaults to one modality spanning all `feature` columns;
        # set_modal_splits() can replace it at runtime.
        self.modal_dims = [self.feature]
        self.register_buffer("splits", torch.tensor([0, self.feature], dtype=torch.long))

        # Multi-modal attention fusion.
        self.fuse = ModalAttnFusion(self.modal_dims, out_dim=self.hid1, dropout=self.dropout)

        # Two-layer GraphSAGE encoder.
        self.gnn1 = SAGEConv(self.hid1, self.hid2)
        self.gnn2 = SAGEConv(self.hid2, self.hid2)

        # Pair head on [h_i, h_j, |h_i - h_j|, h_i * h_j]. Only the last two
        # terms are symmetric in (i, j); the first two keep the pair order.
        pair_in = 4 * self.hid2
        mid = max(self.hid2, 128)
        self.mlp = nn.Sequential(
            nn.Linear(pair_in, mid), nn.ReLU(), nn.Dropout(self.dropout),
            nn.Linear(mid, mid), nn.ReLU(), nn.Dropout(self.dropout),
            nn.Linear(mid, self.num_classes),
        )

        # Bound graph (see bind_graph). Registered as buffers so .to(device)
        # moves them, but non-persistent: the graph is input data, not learned
        # state, and must not end up in (or be required by) state_dict().
        self.register_buffer("x_cache", None, persistent=False)
        self.register_buffer("edge_index_cache", None, persistent=False)

    def set_modal_splits(self, split_str: str):
        """
        Set the per-modality widths at runtime.

        Intended to be called during Trainer initialization, with widths in the
        same order as the feature columns were concatenated. Rebuilds the fusion
        layer (its previous weights are discarded) on the same device and dtype
        as the existing parameters.

        Args:
            split_str: Comma-separated modality widths, e.g. "128,64". Empty or
                None keeps the current layout.

        Raises:
            ValueError: If the widths do not sum to ``feature`` or any width is
                not positive.
        """
        if not split_str:  # keep current (single-modality) layout
            return
        dims = [int(s) for s in split_str.split(',') if s.strip()]
        if sum(dims) != self.feature:
            raise ValueError(f"modal_splits 之和需等于 feature={self.feature}, got {sum(dims)}")
        if any(d <= 0 for d in dims):
            raise ValueError(f"modal_splits entries must be positive, got {dims}")

        ref = next(self.parameters())
        self.modal_dims = dims
        self.fuse = ModalAttnFusion(
            self.modal_dims, out_dim=self.hid1, dropout=self.dropout
        ).to(device=ref.device, dtype=ref.dtype)

        # Cumulative boundaries [0, d1, d1+d2, ..., feature].
        bounds = [0]
        for d in self.modal_dims:
            bounds.append(bounds[-1] + d)
        self.splits = torch.tensor(bounds, dtype=torch.long, device=ref.device)

    def bind_graph(self, data_graph):
        """
        Cache a graph so forward() can be called with ``graph_or_none=None``.

        Args:
            data_graph: Object exposing ``.x`` (node features) and
                ``.edge_index``; tensors are moved to the model's device.
        """
        device = next(self.parameters()).device
        x, edge_index = data_graph.x, data_graph.edge_index
        self.x_cache = None if x is None else x.to(device)
        self.edge_index_cache = None if edge_index is None else edge_index.to(device)

    def _encode_nodes(self, X, edge_index):
        """
        Encode nodes with modal fusion followed by two SAGE layers.

        Args:
            X: Node features [N, feature].
            edge_index: Graph connectivity as accepted by SAGEConv.

        Returns:
            Node embeddings [N, hid2].
        """
        # Plain int boundaries: slicing works regardless of the buffer's device.
        h0, _ = self.fuse(X, self.splits.tolist())  # [N, hid1]
        h = F.relu(self.gnn1(h0, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.gnn2(h, edge_index)             # [N, hid2]

    @staticmethod
    def _pair_feat(h, i_idx, j_idx):
        """
        Build pair features [h_i, h_j, |h_i - h_j|, h_i * h_j].

        Args:
            h: Node embeddings [N, D].
            i_idx: Indices of the first node of each pair [B].
            j_idx: Indices of the second node of each pair [B].

        Returns:
            Pair features [B, 4 * D].
        """
        hi, hj = h[i_idx], h[j_idx]
        return torch.cat([hi, hj, torch.abs(hi - hj), hi * hj], dim=-1)

    @staticmethod
    def _as_index(idx, device):
        """Convert a tensor or iterable of node ids to a long tensor on `device`."""
        if isinstance(idx, torch.Tensor):
            # Direct conversion; avoids splitting the tensor into 0-d elements.
            return idx.to(device=device, dtype=torch.long)
        return torch.as_tensor(list(idx), dtype=torch.long, device=device)

    def forward(self, graph_or_none, idx_batch):
        """
        Forward pass of the MMDGDTI model.

        Args:
            graph_or_none: None to use the graph cached by bind_graph(), or an
                object exposing ``.x`` and ``.edge_index``.
            idx_batch: Sequence whose first two items are the first / second
                node indices of each pair; any further items (e.g. labels) are
                ignored.

        Returns:
            Logits [B, num_classes].

        Raises:
            RuntimeError: If no graph was passed and none is bound.
        """
        if graph_or_none is not None:
            X, edge_index = graph_or_none.x, graph_or_none.edge_index
        else:
            X, edge_index = self.x_cache, self.edge_index_cache
        if X is None or edge_index is None:
            raise RuntimeError("graph 未绑定或传入")

        h = self._encode_nodes(X, edge_index)          # [N, hid2]
        i_idx = self._as_index(idx_batch[0], h.device)
        j_idx = self._as_index(idx_batch[1], h.device)
        pf = self._pair_feat(h, i_idx, j_idx)          # [B, 4*hid2]
        return self.mlp(pf)                            # [B, K]