# PHGLDDI.py (Fixed version: unified device, compatible with older HypergraphConv)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.pool import SAGPooling
from torch_geometric.nn import global_mean_pool
import torch.nn as nn
from torch_geometric.nn import GINConv
from torch_geometric.nn import GCNConv
class GCN_Bottom(nn.Module):
    """Bottom GCN module for hierarchical graph processing.

    Runs four stages of ``GCNConv -> Linear -> ReLU -> BatchNorm1d -> SAGPooling``
    (each SAGPooling constructed with ratio 0.5), then mean-pools the surviving
    node features per graph with ``global_mean_pool``.
    """

    def __init__(self, hidden=512, feature=300):
        """
        Initialize GCN_Bottom module.

        Args:
            hidden: Hidden dimension size.
            feature: Input feature dimension.
        """
        super().__init__()
        # Construction order is kept identical to the original so that parameter
        # initialisation consumes the RNG in the same sequence.
        self.conv1 = GCNConv(feature, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.conv3 = GCNConv(hidden, hidden)
        self.conv4 = GCNConv(hidden, hidden)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.bn2 = nn.BatchNorm1d(hidden)
        self.bn3 = nn.BatchNorm1d(hidden)
        self.bn4 = nn.BatchNorm1d(hidden)
        self.sag1 = SAGPooling(hidden, 0.5)
        self.sag2 = SAGPooling(hidden, 0.5)
        self.sag3 = SAGPooling(hidden, 0.5)
        self.sag4 = SAGPooling(hidden, 0.5)
        self.fc1 = nn.Linear(hidden, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc4 = nn.Linear(hidden, hidden)
        # NOTE(review): ``dropout`` is never applied in ``forward``; it is kept so
        # the module's attribute layout stays unchanged.
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        """
        Run the four conv/pool stages and mean-pool the result.

        Args:
            x: Node feature matrix; its last dimension must equal ``feature``.
            edge_index: Graph connectivity passed to GCNConv / SAGPooling.

        Returns:
            Tuple ``(pooled, edge_index)``. ``pooled`` is ``global_mean_pool`` over
            the node features kept by the last SAGPooling. ``edge_index`` is the
            edge index of that pooled graph.
        """
        # The submodules stay as individual attributes (state_dict keys are
        # unchanged); they are only grouped here to drive a single loop.
        stages = (
            (self.conv1, self.fc1, self.bn1, self.sag1),
            (self.conv2, self.fc2, self.bn2, self.sag2),
            (self.conv3, self.fc3, self.bn3, self.sag3),
            (self.conv4, self.fc4, self.bn4, self.sag4),
        )
        # The first stage passes batch=None, the same as omitting the argument.
        # Later stages reuse the batch vector produced by the previous pooling.
        batch = None
        for conv, fc, bn, sag in stages:
            x = conv(x, edge_index)
            x = bn(F.relu(fc(x)))
            # Elements 0, 1 and 3 of SAGPooling's output tuple are the pooled node
            # features, edge index and batch vector (same indices the original
            # code read).
            pooled = sag(x, edge_index, batch=batch)
            x, edge_index, batch = pooled[0], pooled[1], pooled[3]
        return global_mean_pool(x, batch), edge_index
class GIN_Top(torch.nn.Module):
    """Top GIN module for hierarchical graph processing.

    ``forward`` applies two GINConv layers, then ``Linear(hid, hidden)`` + ReLU,
    then dropout (p=0.3, active only in training mode).
    """

    def __init__(self, fea, hid, hidden=256, train_eps=True):
        """
        Initialize GIN_Top module.

        Args:
            fea: Input feature dimension.
            hid: Hidden dimension size of the GIN layers.
            hidden: Output hidden dimension size.
            train_eps: Whether to train the GINConv epsilon parameter.
        """
        super().__init__()
        self.train_eps = train_eps
        self.gin_conv1 = GINConv(self._gin_mlp(fea, hid), train_eps=self.train_eps)
        self.gin_conv2 = GINConv(self._gin_mlp(hid, hid), train_eps=self.train_eps)
        # NOTE(review): gin_conv3 maps ``hidden -> hidden`` and is not used in
        # ``forward``. gin_conv2 outputs width ``hid``, so gin_conv3 can only be
        # inserted right after gin_conv2 when ``hid == hidden``.
        self.gin_conv3 = GINConv(self._gin_mlp(hidden, hidden), train_eps=self.train_eps)
        self.lin1 = nn.Linear(hid, hidden)
        # NOTE(review): fc1 and fc2 are not used in ``forward``; they are kept so
        # the module's parameters (state_dict keys) are unchanged.
        self.fc1 = nn.Linear(2 * hidden, 1)
        self.fc2 = nn.Linear(hidden, 1)

    @staticmethod
    def _gin_mlp(in_dim, out_dim):
        """Build the Linear-ReLU-Linear-ReLU-BatchNorm1d net wrapped by each GINConv."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
            nn.ReLU(),
            nn.BatchNorm1d(out_dim),
        )

    def reset_parameters(self):
        """Reset the parameters of every learnable submodule exactly once."""
        # The original reset fc1 twice and skipped the registered gin_conv3.
        for module in (self.gin_conv1, self.gin_conv2, self.gin_conv3,
                       self.lin1, self.fc1, self.fc2):
            module.reset_parameters()

    def forward(self, x, edge_index):
        """
        Forward pass of GIN_Top.

        Args:
            x: Node features with last dimension ``fea``.
            edge_index: Graph edge indices.

        Returns:
            Node representations with last dimension ``hidden``.
        """
        x = self.gin_conv1(x, edge_index)
        x = self.gin_conv2(x, edge_index)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.3, training=self.training)
        return x
class PHGLDDI(nn.Module):
    """PHGLDDI model for hierarchical graph learning in drug-drug interaction prediction.

    Node embeddings are produced by ``GIN_Top``. Each (i, j) node pair is scored
    by ``fc`` applied to the element-wise product of the two embeddings.
    """

    def __init__(self, feature: int, hidden1: int, hidden2: int,
                 num_relations: int, num_classes: int):
        """
        Initialize PHGLDDI model.

        Args:
            feature: Input feature dimension.
            hidden1: First hidden layer dimension.
            hidden2: Second hidden layer dimension (width of the node embeddings).
            num_relations: Number of relation types (unused in current implementation).
            num_classes: Number of output classes.
        """
        super().__init__()
        # NOTE(review): BGNN is constructed (and owns parameters), but its call in
        # ``forward`` is disabled. It is kept so the state_dict keys are unchanged.
        self.BGNN = GCN_Bottom(hidden1, feature)
        self.TGNN = GIN_Top(feature, hidden1, hidden2)
        self.fc = nn.Linear(hidden2, num_classes)

    @staticmethod
    def _as_index_tensor(indices, device):
        """Return ``indices`` as a ``torch.long`` tensor on ``device``.

        Tensors are converted directly. This avoids ``list(tensor)``, which builds
        one Python object per element and forces a device sync per element for
        GPU tensors. Any other iterable goes through ``list`` as before.
        """
        if isinstance(indices, torch.Tensor):
            return indices.to(device=device, dtype=torch.long)
        return torch.as_tensor(list(indices), dtype=torch.long, device=device)

    def forward(self, graph_or_none, idx_batch):
        """
        Forward pass of the PHGLDDI model.

        Args:
            graph_or_none: Object exposing ``x`` (node features) and ``edge_index``.
                Despite the name, ``None`` is not handled here.
            idx_batch: Pair ``(i_indices, j_indices)`` of node indices (tensors or
                iterables of ints) selecting the node pairs to score.

        Returns:
            Output logits of shape ``(num_pairs, num_classes)`` for the given node pairs.
        """
        x, edge_index = graph_or_none.x, graph_or_none.edge_index
        # Keep the graph on the same device as the node features.
        edge_index = edge_index.to(x.device)
        final = self.TGNN(x, edge_index)
        i_idx = self._as_index_tensor(idx_batch[0], x.device)
        j_idx = self._as_index_tensor(idx_batch[1], x.device)
        # Element-wise product of the two node embeddings as the pair representation.
        pair = torch.mul(final[i_idx], final[j_idx])
        return self.fc(pair)