# models/CASTER.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class CASTER(nn.Module):
    """
    CASTER baseline (tensor decomposition / bilinear scoring style).

    Architecture:
      - X -> drug embedding e_i via a two-layer MLP encoder.
      - Each output class k learns a diagonal relation vector r_k.
      - logits_k(i, j) = <(e_i ⊙ r_k), e_j> + b_k   (DistMult-style score)

    The constructor signature matches the one used by model_manager:
      (feature, hidden1, hidden2, num_relations, num_classes, dropout)
    where ``hidden2`` is the final embedding dimension d and the relation
    table has one row per output class (K = num_classes).
    """

    def __init__(self, feature: int, hidden1: int, hidden2: int,
                 num_relations: int, num_classes: int, dropout: float = 0.3):
        """
        Initialize CASTER model.

        Args:
            feature: Input node-feature dimension.
            hidden1: Hidden width of the encoder MLP.
            hidden2: Final embedding dimension d.
            num_relations: Number of relation types. Recorded for
                introspection only; the relation table is sized by
                ``num_classes`` because forward() emits one logit per class.
            num_classes: Number of output classes K.
            dropout: Dropout rate applied inside the encoder.
        """
        super().__init__()
        self.feature = int(feature)
        self.hid1 = int(hidden1)
        self.dim = int(hidden2)  # embedding dimension d
        self.num_relations = int(num_relations)
        self.num_classes = int(num_classes)
        self.dropout = float(dropout)

        # Drug embedding e from node features X (two-layer MLP, closer to the
        # "feature coupling decomposition" idea).
        self.enc = nn.Sequential(
            nn.Linear(self.feature, self.hid1),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.hid1, self.dim),
        )
        # Diagonal "transformation" vector r_k for each class k
        # (CASTER/DistMult idea): [K, d].
        self.rel = nn.Parameter(torch.randn(self.num_classes, self.dim))
        # NOTE: bias_e is not read by forward(). It is kept, in its original
        # registration position, so parameter order and state_dict keys stay
        # the same as in earlier versions of this model.
        self.bias_e = nn.Parameter(torch.zeros(self.dim))
        # Per-class logit bias: [K].
        self.bias_r = nn.Parameter(torch.zeros(self.num_classes))
        # Cached node feature matrix X (same bind_graph convention as the
        # other baselines). As a buffer it follows .to()/.cuda() once set.
        # NOTE(review): this is a persistent buffer, so a bound X is written
        # into state_dict(); loading such a checkpoint into an unbound model
        # with strict=True may report x_cache as an unexpected key. Consider
        # persistent=False if checkpoints should not carry the features.
        self.register_buffer("x_cache", None)

    def bind_graph(self, data_graph):
        """
        Bind graph data to the model.

        Only the node features ``data_graph.x`` are used; edges are ignored.
        A tensor is moved to the device of the model's parameters so a graph
        kept on CPU still works with a model that was already moved to GPU.

        Args:
            data_graph: Graph data object exposing node features as ``.x``.
        """
        x = data_graph.x
        if isinstance(x, torch.Tensor):
            # .to() returns the same tensor when it is already on that device.
            x = x.to(self.rel.device)
        self.x_cache = x

    def _node_embed(self, X):
        """
        Encode node features into embeddings.

        Args:
            X: Node feature matrix, [N, feature].
        Returns:
            Node embeddings, [N, d].
        """
        return self.enc(X)

    @staticmethod
    def _to_index(idx, device):
        """Convert a batch of node indices to a long tensor on ``device``."""
        if isinstance(idx, torch.Tensor):
            # Fast path: list(tensor) would iterate element by element (one
            # host sync per element on CUDA) before rebuilding the tensor.
            return idx.to(device=device, dtype=torch.long)
        # Lists, tuples, numpy arrays and generic iterables (e.g. generators).
        return torch.as_tensor(list(idx), dtype=torch.long, device=device)

    def forward(self, graph_or_none, idx_batch):
        """
        Score a batch of drug pairs against every class.

        Args:
            graph_or_none: None to use the features cached by bind_graph(),
                or a graph object exposing node features as ``.x``.
            idx_batch: (i_idx, j_idx, y) batch; only the two index sequences
                are read here, the labels are ignored.
        Returns:
            logits: [B, K] output logits.
        Raises:
            RuntimeError: if no node features are available.
        """
        X = graph_or_none.x if graph_or_none is not None else self.x_cache
        if X is None:
            # Explicit raise (not assert) so the check survives `python -O`.
            # Message: "CASTER needs node features X; call
            # bind_graph(dataset.data_graph) in the Trainer first."
            raise RuntimeError(
                "CASTER 需要节点特征 X,请先在 Trainer 中 bind_graph(dataset.data_graph)"
            )
        e = self._node_embed(X)  # [N, d]
        device = e.device
        i_idx = self._to_index(idx_batch[0], device)
        j_idx = self._to_index(idx_batch[1], device)
        e_i = e[i_idx]  # [B, d]
        e_j = e[j_idx]  # [B, d]
        # DistMult score for every class at once:
        #   sum_d e_i[b, d] * r_k[d] * e_j[b, d] == ((e_i * e_j) @ R^T)[b, k]
        # This avoids materialising a [B, K, d] intermediate.
        logits = (e_i * e_j) @ self.rel.t()  # [B, K]
        return logits + self.bias_r  # bias_r [K] broadcasts over the batch