# Source code for openddi.models.MKGFENN

# models/MKGFENN.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import RGCNConv

class FENN(nn.Module):
    """
    Feature-level fusion with a gating mechanism.

    For each modality m:
        h_m = ReLU(W_m x_m),    g_m = sigmoid(wg_m^T x_m + b_m)
    Fusion:
        h = sum_m g_m * h_m
    """

    def __init__(self, modal_dims, out_dim, dropout=0.0):
        """
        Initialize the FENN module.

        Args:
            modal_dims: Iterable with the feature dimension of each modality,
                in the order the modalities are concatenated in the input.
            out_dim: Output dimension shared by all modality projections.
            dropout: Dropout rate applied to the fused representation.

        Raises:
            ValueError: If ``modal_dims`` is empty.
        """
        super().__init__()
        self.modal_dims = list(map(int, modal_dims))
        if not self.modal_dims:
            raise ValueError("FENN requires at least one modality in modal_dims")
        self.out_dim = int(out_dim)
        self.dropout = float(dropout)
        # One projection W_m and one scalar gate (wg_m, b_m) per modality.
        self.proj = nn.ModuleList([nn.Linear(d, self.out_dim) for d in self.modal_dims])
        self.gate = nn.ModuleList([nn.Linear(d, 1) for d in self.modal_dims])

    def forward(self, x, splits):
        """
        Fuse the modalities of ``x`` into one representation.

        Args:
            x: Input tensor of shape [N, sum(d_m)] (modalities concatenated on dim 1).
            splits: Column boundaries [0, d1, d1+d2, ..., sum(d_m)], given either
                as a 1-D tensor or as a sequence of ints.

        Returns:
            Tensor of shape [N, out_dim]: the gated sum of the per-modality
            projections, with dropout applied in training mode.

        Raises:
            ValueError: If ``splits`` does not define exactly one column range
                per modality.
        """
        # Convert once to Python ints. Slicing with 0-d tensors would call
        # __index__ (a device sync) for every boundary.
        bounds = splits.tolist() if torch.is_tensor(splits) else [int(b) for b in splits]
        # zip() below would silently drop modalities on a length mismatch.
        if len(bounds) != len(self.proj) + 1:
            raise ValueError(
                f"splits must contain {len(self.proj) + 1} boundaries "
                f"(number of modalities + 1), got {len(bounds)}"
            )
        parts = []
        for (lo, hi), proj, gate in zip(zip(bounds[:-1], bounds[1:]), self.proj, self.gate):
            xm = x[:, lo:hi]                  # [N, d_m]
            hm = F.relu(proj(xm))             # [N, out]
            gm = torch.sigmoid(gate(xm))      # [N, 1], broadcast over features
            parts.append(hm * gm)
        h = torch.stack(parts, dim=0).sum(0)  # [N, out]
        return F.dropout(h, p=self.dropout, training=self.training)
class MKGFENN(nn.Module):
    """
    MKG-FENN baseline model.

    Architecture:
        - FENN (multi-modal fusion) -> RGCN -> Pair-MLP -> logits
        - Signature construction consistent with model_manager
    """

    # Buffers holding the graph bound via ``bind_graph``.
    _GRAPH_CACHE_NAMES = ("x_cache", "edge_index_cache", "edge_type_cache")

    def __init__(self, feature: int, hidden1: int, hidden2: int, num_relations: int,
                 num_classes: int, dropout: float = 0.5):
        """
        Initialize the MKGFENN model.

        Args:
            feature: Input feature dimension (sum of all modality dimensions).
            hidden1: Output dimension of FENN and of the first RGCN layer.
            hidden2: Output dimension of the second RGCN layer.
            num_relations: Number of relation (edge) types.
            num_classes: Number of output classes.
            dropout: Dropout rate used in FENN, between RGCN layers and in the MLP.
        """
        super().__init__()
        self.feature = int(feature)
        self.hid1 = int(hidden1)
        self.hid2 = int(hidden2)
        self.num_relations = int(num_relations)
        self.num_classes = int(num_classes)
        self.dropout = float(dropout)

        # ---- Multi-modal splitting (single modality by default; see set_modal_splits) ----
        self.modal_dims = [self.feature]
        self.register_buffer("splits", torch.tensor([0, self.feature], dtype=torch.long))
        self.fenn = FENN(self.modal_dims, out_dim=self.hid1, dropout=self.dropout)

        # ---- Graph encoding: two-layer RGCN using edge_type ----
        self.rgcn1 = RGCNConv(self.hid1, self.hid1, num_relations=self.num_relations)
        self.rgcn2 = RGCNConv(self.hid1, self.hid2, num_relations=self.num_relations)

        # ---- Pair -> classification ----
        pair_in = 4 * self.hid2  # [h_i, h_j, |h_i - h_j|, h_i * h_j]
        mid = max(self.hid2, 128)
        self.mlp = nn.Sequential(
            nn.Linear(pair_in, mid), nn.ReLU(), nn.Dropout(self.dropout),
            nn.Linear(mid, mid), nn.ReLU(), nn.Dropout(self.dropout),
            nn.Linear(mid, self.num_classes),
        )

        # ---- Graph caching ----
        # The caches are buffers so that model.to(device) moves the bound graph.
        # They are non-persistent so the graph is not serialized into every
        # state_dict/checkpoint. Persistent caches also made strict-loading a
        # bound checkpoint into an unbound model fail with unexpected keys.
        for name in self._GRAPH_CACHE_NAMES:
            self.register_buffer(name, None, persistent=False)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Drop legacy graph-cache entries before the standard buffer/parameter loading."""
        # Checkpoints written while the caches were persistent may still contain
        # them. Ignoring them keeps strict loading of those checkpoints working.
        for name in self._GRAPH_CACHE_NAMES:
            state_dict.pop(prefix + name, None)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)

    def set_modal_splits(self, split_str: str):
        """
        Set the per-modality dimensions at runtime and rebuild the FENN module.

        The order must match the order in which the modalities are concatenated
        in the input features (e.g. the CSV columns).

        Args:
            split_str: Comma-separated modality dimensions, e.g. "1024,768,256,128".
                An empty string leaves the model unchanged.

        Raises:
            ValueError: If a dimension is not positive, or if the dimensions do
                not sum to ``feature``.

        Note:
            The FENN weights are re-initialized; call this before training or
            before loading a checkpoint that was trained with the same splits.
        """
        if not split_str:
            return
        dims = [int(s) for s in split_str.split(',') if s.strip()]
        if any(d <= 0 for d in dims):
            raise ValueError(f"modal_splits must be positive integers, got {dims}")
        if sum(dims) != self.feature:
            raise ValueError(f"modal_splits 之和需等于 feature={self.feature}, got {sum(dims)}")
        self.modal_dims = dims

        # Rebuild FENN on the same device/dtype as the module it replaces. Also
        # keep its train/eval mode, so dropout does not fire after model.eval().
        ref = next(self.fenn.parameters())
        fenn = FENN(self.modal_dims, out_dim=self.hid1, dropout=self.dropout)
        self.fenn = fenn.to(device=ref.device, dtype=ref.dtype).train(self.training)

        # Cumulative boundaries [0, d1, d1+d2, ..., feature].
        self.splits = torch.cumsum(
            torch.tensor([0, *dims], dtype=torch.long, device=ref.device), dim=0
        )

    def bind_graph(self, data_graph):
        """
        Cache a graph on the model so that ``forward`` can be called with ``None``.

        Args:
            data_graph: Object exposing ``x`` and ``edge_index`` and, optionally,
                ``edge_type`` (stored as None when absent).
        """
        self.x_cache = data_graph.x
        self.edge_index_cache = data_graph.edge_index
        self.edge_type_cache = getattr(data_graph, "edge_type", None)

    def _encode_nodes(self, x, edge_index, edge_type):
        """
        Encode nodes through FENN fusion followed by two RGCN layers.

        Args:
            x: Node features [N, feature].
            edge_index: Graph edge indices.
            edge_type: Relation type of each edge.

        Returns:
            Node representations of shape [N, hid2].
        """
        # Hand FENN plain int boundaries. The buffer then never has to be moved
        # (or mutated) to x's device during the forward pass.
        h0 = self.fenn(x, self.splits.tolist())  # [N, hid1]
        h = self.rgcn1(h0, edge_index, edge_type)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.rgcn2(h, edge_index, edge_type)  # [N, hid2]

    @staticmethod
    def _pair_features(h, i_idx, j_idx):
        """
        Build pair features from node embeddings.

        Args:
            h: Node embeddings [N, D].
            i_idx: Indices of the first node of each pair.
            j_idx: Indices of the second node of each pair.

        Returns:
            Tensor [B, 4*D] = concat(h_i, h_j, |h_i - h_j|, h_i * h_j).
        """
        hi, hj = h[i_idx], h[j_idx]
        return torch.cat([hi, hj, torch.abs(hi - hj), hi * hj], dim=-1)

    @staticmethod
    def _as_index(idx, device):
        """Convert a batch of node indices to a long tensor on ``device``."""
        if torch.is_tensor(idx):
            # Move directly. list() over a tensor would build 0-d tensors and
            # convert each one separately.
            return idx.to(device=device, dtype=torch.long)
        return torch.as_tensor(list(idx), dtype=torch.long, device=device)

    def forward(self, graph_or_none, idx_batch):
        """
        Forward pass of the MKGFENN model.

        Args:
            graph_or_none: None to use the graph cached by ``bind_graph``, or an
                object exposing ``x``, ``edge_index`` and ``edge_type``.
            idx_batch: Sequence (i_idx, j_idx, y); only the first two entries
                (node indices of each pair) are used here.

        Returns:
            Logits of shape [B, num_classes].

        Raises:
            ValueError: If no graph is available, or if it has no ``edge_type``.
        """
        if graph_or_none is not None:
            x = graph_or_none.x
            edge_index = graph_or_none.edge_index
            edge_type = getattr(graph_or_none, "edge_type", None)
        else:
            x, edge_index, edge_type = self.x_cache, self.edge_index_cache, self.edge_type_cache
        if x is None or edge_index is None:
            raise ValueError("graph 未绑定或传入")
        if edge_type is None:
            raise ValueError("MKG-FENN 需要 edge_type(关系类型)")

        h = self._encode_nodes(x, edge_index, edge_type)
        i_idx = self._as_index(idx_batch[0], h.device)
        j_idx = self._as_index(idx_batch[1], h.device)
        pf = self._pair_features(h, i_idx, j_idx)  # [B, 4*hid2]
        return self.mlp(pf)  # [B, K]