Source code for openddi.models.TIGER

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class TIGER(nn.Module):
    """TIGER (implementation version): graph encoding + pair representation + MLP.

    Nodes are encoded by a two-layer GAT, each queried pair ``(i, j)`` is turned
    into ``[h_i, h_j, |h_i - h_j|, h_i * h_j]`` and classified by a three-layer
    MLP that returns raw logits.

    The constructor signature is kept compatible with the existing
    model_manager: ``(feature, hidden1, hidden2, num_relations, num_classes,
    dropout)``.

    Args:
        feature: Input node-feature dimension.
        hidden1: Requested node-embedding width. The GAT layers use 4 heads
            with concatenation, so the effective width is
            ``4 * (hidden1 // 4)`` (exposed as ``node_dim``).
        hidden2: Requested MLP hidden width; the MLP uses
            ``max(hidden2, 128)``.
        num_relations: Accepted for signature compatibility only; unused here.
        num_classes: Number of output classes (logit columns).
        dropout: Dropout probability used after the first GAT layer and inside
            the MLP. Defaults to 0.3.

    Raises:
        ValueError: If ``hidden1`` is smaller than the number of attention
            heads, which would give zero-width heads.
    """

    def __init__(self, feature: int, hidden1: int, hidden2: int,
                 num_relations: int, num_classes: int, dropout: float = 0.3):
        super().__init__()
        self.feature = int(feature)
        self.hid1 = int(hidden1)
        self.hid2 = int(hidden2)
        self.num_classes = int(num_classes)
        self.dropout = float(dropout)

        # ---- Graph encoding: two-layer GAT ----
        heads1, heads2 = 4, 4
        head_dim1 = self._head_dim(self.hid1, heads1)
        head_dim2 = self._head_dim(self.hid1, heads2)
        # With concat=True a GAT layer emits heads * out_channels features.
        # Use that real width (not hid1) for the next layer so a hidden1 that
        # is not a multiple of the head count no longer breaks forward().
        gnn1_out = heads1 * head_dim1
        self.node_dim = heads2 * head_dim2
        self.gnn1 = GATConv(self.feature, head_dim1, heads=heads1, concat=True)
        self.gnn2 = GATConv(gnn1_out, head_dim2, heads=heads2, concat=True)

        # ---- Pair representation: concat + abs diff + hadamard ----
        pair_in_dim = 4 * self.node_dim
        hid = max(self.hid2, 128)
        self.mlp = nn.Sequential(
            nn.Linear(pair_in_dim, hid),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(hid, hid),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(hid, self.num_classes),  # direct output logits
        )

        # The graph (dataset.data_graph) can be bound once so the Trainer does
        # not have to pass it on every step. Registered as buffers so the
        # cached tensors move together with the module.
        self.register_buffer("x_cache", None)
        self.register_buffer("edge_index_cache", None)

    @staticmethod
    def _head_dim(hidden: int, heads: int) -> int:
        """Return the per-head output width for a concatenating GAT layer.

        Args:
            hidden: Requested total width of the layer output.
            heads: Number of attention heads.

        Returns:
            ``hidden // heads``.

        Raises:
            ValueError: If the result would be smaller than 1.
        """
        head_dim = hidden // heads
        if head_dim < 1:
            raise ValueError(
                f"hidden1 must be at least {heads} (got {hidden})"
            )
        return head_dim

    @staticmethod
    def _to_index(values, device) -> torch.Tensor:
        """Convert a batch of node indices to a ``torch.long`` tensor on ``device``.

        Args:
            values: A tensor or any iterable of integer indices.
            device: Target device for the returned tensor.

        Returns:
            Index tensor of dtype ``torch.long`` located on ``device``.
        """
        if isinstance(values, torch.Tensor):
            # Already a tensor (e.g. produced by a DataLoader collate):
            # convert in one step instead of unpacking it element by element.
            return values.to(device=device, dtype=torch.long)
        return torch.as_tensor(list(values), dtype=torch.long, device=device)

    def bind_graph(self, data_graph):
        """Cache a graph on the model so ``forward`` can be called with ``None``.

        Args:
            data_graph: Object exposing ``x`` (node features) and
                ``edge_index`` (graph connectivity).
        """
        self.x_cache = data_graph.x
        self.edge_index_cache = data_graph.edge_index

    def _encode_nodes(self, x, edge_index):
        """Encode all nodes with the two GAT layers.

        Applies ``gnn1`` -> ELU -> dropout -> ``gnn2``; no activation is
        applied after the second layer.

        Args:
            x: Node features.
            edge_index: Graph edge indices.

        Returns:
            torch.Tensor: Encoded node representations.
        """
        h = self.gnn1(x, edge_index)
        h = F.elu(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.gnn2(h, edge_index)
        return h

    @staticmethod
    def _pair_features(h, i_idx, j_idx):
        """Build pair features from node representations.

        Args:
            h: Node representations.
            i_idx: Indices of the first drug in each pair.
            j_idx: Indices of the second drug in each pair.

        Returns:
            torch.Tensor: ``[h_i, h_j, |h_i - h_j|, h_i * h_j]`` concatenated
            along the last dimension (four times the width of ``h``).
        """
        hi = h[i_idx]
        hj = h[j_idx]
        return torch.cat([hi, hj, torch.abs(hi - hj), hi * hj], dim=-1)

    def forward(self, graph_or_none, idx_batch):
        """Compute class logits for a batch of node pairs.

        Args:
            graph_or_none: Object exposing ``x`` and ``edge_index`` (e.g. a
                torch_geometric ``Data``), or ``None`` to use the graph cached
                by ``bind_graph``.
            idx_batch: ``(i_idx, j_idx, y)`` - compatible with
                Base_multi_dataset / Base_multilabel_dataset. Only the first
                two entries are read here.

        Returns:
            torch.Tensor: Output logits ``[B, K]``.

        Raises:
            AssertionError: If no graph was passed in and none has been bound.
        """
        if graph_or_none is not None:
            x, edge_index = graph_or_none.x, graph_or_none.edge_index
        else:
            x, edge_index = self.x_cache, self.edge_index_cache

        if x is None or edge_index is None:
            # Raised explicitly (rather than via `assert`) so the check is not
            # stripped under `python -O`; exception type and message are kept.
            # Message reads: "graph not bound or passed in".
            raise AssertionError("graph 未绑定或传入")

        h = self._encode_nodes(x, edge_index)
        i_idx = self._to_index(idx_batch[0], h.device)
        j_idx = self._to_index(idx_batch[1], h.device)

        pair_feat = self._pair_features(h, i_idx, j_idx)  # [B, 4 * node_dim]
        logits = self.mlp(pair_feat)  # [B, K]
        return logits