# Source code for openddi.models.MIRACLE
# models/MIRACLE.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, RGCNConv
# [docs]
class InteractionPredictor(nn.Module):
    """Two-layer scoring head used by the MIRACLE model.

    The input goes through a linear layer and a ReLU, then through a second
    linear layer. After each linear layer a separate learnable offset vector
    (``bl`` / ``bp``, initialised to zeros) is added on top of the layer's
    own output.
    """

    def __init__(self, dg, hidden, k):
        """Create the two linear layers and their additional offset vectors.

        Args:
            dg: Input dimension.
            hidden: Hidden layer dimension.
            k: Output dimension (number of classes).
        """
        super().__init__()
        # First stage: dg -> hidden, plus an extra zero-initialised offset.
        self.Wl = nn.Linear(dg, hidden)
        self.bl = nn.Parameter(torch.zeros(hidden))
        # Second stage: hidden -> k, plus an extra zero-initialised offset.
        self.Wp = nn.Linear(hidden, k)
        self.bp = nn.Parameter(torch.zeros(k))

    def forward(self, l):
        """Score the given pair representations.

        Args:
            l: Input tensor whose last dimension is ``dg``.

        Returns:
            Tensor of un-normalised scores whose last dimension is ``k``
            (no softmax/sigmoid is applied here).
        """
        pre_activation = self.Wl(l) + self.bl
        hidden_repr = F.relu(pre_activation)
        scores = self.Wp(hidden_repr) + self.bp
        return scores
# [docs]
class MIRACLE(nn.Module):
    """MIRACLE model for drug-drug interaction prediction.

    A GCN layer followed by an RGCN layer produces one embedding per node of
    the input graph. For every requested node pair the two embeddings are
    multiplied element-wise and the product is scored by an
    ``InteractionPredictor``.
    """

    def __init__(self, feature: int, hidden1: int, hidden2: int,
                 num_relations: int, num_classes: int, dropout: float = 0.3,
                 pooling_ratio: float = 0.5):
        """Initialize MIRACLE model.

        Args:
            feature: Input feature dimension.
            hidden1: First hidden layer dimension (output of the GCN layer).
            hidden2: Second hidden layer dimension (output of the RGCN layer).
            num_relations: Number of relation types for RGCN.
            num_classes: Number of output classes.
            dropout: Dropout rate applied between the GCN and RGCN layers.
            pooling_ratio: Pooling ratio parameter (stored, but unused in the
                current implementation).
        """
        super().__init__()
        self.num_relations = int(num_relations)
        self.num_classes = int(num_classes)
        self.feature_dim = int(feature)
        # Mirrors num_classes; not read anywhere else in this class.
        self.num_edge_features = int(num_classes)
        self.hidden1 = int(hidden1)
        self.hidden2 = int(hidden2)
        self.pooling_ratio = pooling_ratio
        self.dropout_ratio = dropout

        self.gnn1 = GCNConv(self.feature_dim, self.hidden1)
        self.gnn2 = RGCNConv(self.hidden1, self.hidden2, self.num_relations)
        self.dropout = nn.Dropout(self.dropout_ratio)
        self.predictor = InteractionPredictor(self.hidden2, self.hidden2, self.num_classes)

    @staticmethod
    def _as_index(indices, device):
        """Return ``indices`` as a ``torch.long`` tensor on ``device``."""
        if isinstance(indices, torch.Tensor):
            # Convert in one tensor op. Going through ``list(tensor)`` would
            # create one 0-d tensor per index in a Python loop and fails for
            # index tensors with more than one dimension.
            return indices.to(device=device, dtype=torch.long)
        # Generic iterables (lists, tuples, generators, ...) are materialised
        # first, as before.
        return torch.as_tensor(list(indices), dtype=torch.long, device=device)

    def forward(self, data_o, idx):
        """Forward pass of the MIRACLE model.

        Args:
            data_o: Graph data object providing ``x`` (node features),
                ``edge_index`` and ``edge_type``.
            idx: Indexable with two entries; ``idx[0]`` and ``idx[1]`` hold the
                first and second node index of every pair to score. Each entry
                may be a tensor or any iterable of integers.

        Returns:
            Predicted interaction scores, one row per node pair.
        """
        x_o, edge_index, e_type = data_o.x, data_o.edge_index, data_o.edge_type
        # Index tensors are placed on the same device as the node features.
        a_idx = self._as_index(idx[0], x_o.device)
        b_idx = self._as_index(idx[1], x_o.device)

        # Node encoder: GCN -> ReLU -> dropout -> RGCN.
        x = self.gnn1(x_o, edge_index)
        x = F.relu(x)
        x = self.dropout(x)
        D = self.gnn2(x, edge_index, e_type)

        # Pair representation: element-wise product of the two node embeddings.
        l_d = D[a_idx] * D[b_idx]
        return self.predictor(l_d)