# Source code for openddi.models.DeepDDI
# models/DeepDDI.py
import torch
import torch.nn as nn
import torch.nn.functional as F
# [docs]
class DeepDDI(nn.Module):
    """DeepDDI baseline model for drug-drug interaction prediction.

    Architecture:
        - No graph convolution; the node feature matrix X is used directly.
        - Pair representation:
          [x_i || x_j || absolute difference || element-wise product].
        - MLP -> logits (supports single/multi-label classification; no
          activation is applied to the output).
    """

    def __init__(self, feature: int, hidden1: int, hidden2: int,
                 num_relations: int, num_classes: int, dropout: float = 0.3):
        """Initialize the DeepDDI model.

        Args:
            feature: Input (per-node) feature dimension.
            hidden1: Requested width of the first hidden layer. The width
                actually used is ``max(hidden1, 256)``.
            hidden2: Width of the second hidden layer.
            num_relations: Number of relation types (accepted but unused in
                the current implementation).
            num_classes: Number of output classes (logits per pair).
            dropout: Dropout rate applied after each hidden activation.
        """
        super().__init__()
        self.feature = int(feature)
        self.hid1 = int(hidden1)
        self.hid2 = int(hidden2)
        self.num_classes = int(num_classes)
        self.dropout = float(dropout)

        # Four feature-sized chunks are concatenated per pair
        # (see _pair_features).
        pair_in = 4 * self.feature
        # The first hidden layer is never narrower than 256 units, so any
        # hidden1 below 256 is raised to 256. Left unchanged on purpose:
        # altering it would change the layer shapes of the model.
        hid = max(self.hid1, 256)

        self.mlp = nn.Sequential(
            nn.Linear(pair_in, hid),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(hid, self.hid2),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.hid2, self.num_classes),  # logits
        )

        # Cache for the node feature matrix X (from dataset.data_graph.x);
        # interface aligned with TIGER/ZeroDDI. Registered as a buffer, so
        # once bound it follows the module in .to()/.cuda() calls.
        self.register_buffer("x_cache", None)

    def bind_graph(self, data_graph):
        """Bind graph data to the model.

        Args:
            data_graph: Graph data object. Only its ``x`` attribute (the
                node feature matrix) is read; ``edge_index`` is not used.
        """
        # Assigning to the registered buffer name stores X in the buffer.
        self.x_cache = data_graph.x

    @staticmethod
    def _pair_features(X, i_idx, j_idx):
        """Generate pair features from node features.

        Args:
            X: Node feature matrix.
            i_idx: Indices of the first node of each pair.
            j_idx: Indices of the second node of each pair.

        Returns:
            The concatenation, along the last dimension, of ``x_i``,
            ``x_j``, ``|x_i - x_j|`` and ``x_i * x_j``.
        """
        xi = X[i_idx]
        xj = X[j_idx]
        return torch.cat([xi, xj, torch.abs(xi - xj), xi * xj], dim=-1)

    @staticmethod
    def _to_index(idx, device):
        """Convert a batch of node indices to a long tensor on ``device``."""
        if isinstance(idx, torch.Tensor):
            # Convert tensors directly: iterating a tensor element by element
            # is slow and forces a host sync per element for CUDA tensors.
            return idx.to(device=device, dtype=torch.long)
        return torch.as_tensor(list(idx), dtype=torch.long, device=device)

    def forward(self, graph_or_none, idx_batch):
        """Forward pass of the DeepDDI model.

        Args:
            graph_or_none: ``None`` (use the matrix bound via ``bind_graph``)
                or an object whose ``x`` attribute holds the node features.
            idx_batch: ``(i_idx, j_idx, y)`` batch; only the first two
                entries are read here. Each may be a tensor or any iterable
                of integer indices.

        Returns:
            logits: ``[B, K]`` output logits.

        Raises:
            AssertionError: If no node feature matrix is available.
        """
        X = graph_or_none.x if graph_or_none is not None else self.x_cache
        if X is None:
            # Raised explicitly instead of using `assert` so the check is
            # not stripped under `python -O`; the exception type is kept.
            raise AssertionError(
                "DeepDDI 需要节点特征表 X，请在 trainer 中先 bind_graph(dataset.data_graph)"
            )

        i_idx = self._to_index(idx_batch[0], X.device)
        j_idx = self._to_index(idx_batch[1], X.device)
        pair_feat = self._pair_features(X, i_idx, j_idx)  # [B, 4*feature]

        # X may live on a different device (e.g. bound after the model was
        # moved) or have a different dtype than the MLP weights; align the
        # gathered batch with the first layer. No-op when they already match.
        first_weight = self.mlp[0].weight
        pair_feat = pair_feat.to(device=first_weight.device,
                                 dtype=first_weight.dtype)

        return self.mlp(pair_feat)  # [B, K]