# Source code for openddi.models.tiger.tiger

# This part is adapted from:
# Source: https://github.com/Blair1213/TIGER 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_max_pool as gmp, global_add_pool as gap,global_mean_pool as gep,global_sort_pool
from torch_geometric.utils import dropout_adj
from torch.nn import BCEWithLogitsLoss, Linear
import math
from torch.nn import Linear, Sequential, ReLU, BatchNorm1d as BN
from torch_geometric.utils import degree
from .GraphTransformer import GraphTransformer
import os

class NodeFeatures(torch.nn.Module):
    """
    Encode nodes as the sum of a feature embedding and a degree embedding.

    Args:
        degree: Size of the degree embedding table (number of distinct
            degree values that can be looked up).
        feature_num: Input feature width when ``type == 'graph'``; otherwise
            the number of rows of the node embedding table.
        embedding_dim: Output embedding dimension.
        layer: Layer count forwarded to ``init_params`` for scaling the
            initial Linear weights (default=2).
        type: ``'graph'`` selects a Linear node encoder; any other value
            selects an Embedding lookup (default='graph').
    """

    def __init__(self, degree, feature_num, embedding_dim, layer=2, type='graph'):
        super().__init__()

        # 'graph' inputs carry feature vectors of width ``feature_num``;
        # every other type is treated as index-valued and looked up in a table.
        has_feature_vectors = type == 'graph'
        self.node_encoder = (
            Linear(feature_num, embedding_dim)
            if has_feature_vectors
            else torch.nn.Embedding(feature_num, embedding_dim)
        )

        # Maps each degree value to an embedding; index 0 is the padding index.
        self.degree_encoder = torch.nn.Embedding(degree, embedding_dim, padding_idx=0)

        def _initialise(module):
            init_params(module, layers=layer)

        self.apply(_initialise)

    def reset_parameters(self):
        """
        Reset both encoders through their own ``reset_parameters`` methods.

        NOTE(review): this does not re-run ``init_params`` as ``__init__``
        does, so the resulting weights may follow a different distribution
        than at construction time -- confirm this is intended.
        """
        for encoder in (self.node_encoder, self.degree_encoder):
            encoder.reset_parameters()

    def forward(self, data):
        """
        Encode node features and add the degree embedding.

        Args:
            data: Graph data object exposing ``x`` and ``edge_index``.

        Returns:
            torch.Tensor: Encoded node features, one row per row of ``data.x``.
        """
        features = data.x
        _, second_row = data.edge_index

        # Count how often each node index occurs in the second row of
        # ``edge_index``.
        # NOTE(review): counts >= the degree table size would index out of
        # range in ``degree_encoder`` -- confirm callers bound the degree.
        node_degree = degree(second_row, features.size(0), dtype=features.dtype)

        encoded = self.node_encoder(features)
        encoded += self.degree_encoder(node_degree.long())
        return encoded
class TIGER(torch.nn.Module):
    """
    TIGER model for knowledge graph-enhanced DDI prediction.

    Each drug is encoded twice -- from its molecular graph and from its
    knowledge-graph subgraph -- by two ``GraphTransformer`` instances. The two
    views are fused by ``fc1``; the fused embeddings of a drug pair are
    concatenated and scored by ``fc2``. ``forward`` additionally computes
    auxiliary mutual-information losses and stores them in
    ``self.reserved_loss``, which ``loss`` adds to the supervised loss.

    Args:
        max_layer: Number of transformer layers (default=6).
        num_features_drug: Number of atom input features (default=78).
        num_nodes: Number of nodes in the knowledge graph (default=200).
        num_relations_mol: Number of molecular relation types (default=10).
        num_relations_graph: Number of graph relation types (default=10).
        output_dim: Embedding dimension (default=64).
        max_degree_graph: Degree table size for molecular graphs (default=100).
        max_degree_node: Degree table size for the knowledge graph (default=100).
        sub_coeff: Weight of the molecule-level MI loss; stored as
            ``self.mol_coeff`` (default=0.2).
        mi_coeff: Weight of the subgraph-level MI loss (default=0.5).
        dropout: Dropout rate (default=0.2).
        device: Device that ``forward`` moves its inputs to (default='cuda').
        num_rel: Output width of the classifier ``fc2``. Must be provided;
            the default ``None`` is not a valid layer size.
        args: Configuration object; ``loss`` reads ``args.matrix``.
    """

    def __init__(self, max_layer = 6, num_features_drug = 78, num_nodes = 200, num_relations_mol = 10, num_relations_graph = 10, output_dim=64, max_degree_graph=100, max_degree_node=100, sub_coeff = 0.2, mi_coeff = 0.5, dropout=0.2, device = 'cuda', num_rel = None, args=None):
        super(TIGER, self).__init__()

        print("TIGER Loaded")
        self.device = device
        self.args = args

        self.layers = max_layer
        self.num_features_drug = num_features_drug
        self.max_degree_graph = max_degree_graph
        self.max_degree_node = max_degree_node

        self.mol_coeff = sub_coeff
        self.mi_coeff = mi_coeff
        self.dropout = dropout

        self.mol_atom_feature = NodeFeatures(degree=max_degree_graph, feature_num=num_features_drug, embedding_dim=output_dim, type='graph')
        self.drug_node_feature = NodeFeatures(degree=max_degree_node, feature_num=num_nodes, embedding_dim=output_dim, type='node')

        # Representation-learning modules.
        self.mol_representation_learning = GraphTransformer(layer_num = max_layer, embedding_dim = output_dim, num_heads = 4, num_rel = num_relations_mol, dropout= dropout, type='graph')
        self.node_representation_learning = GraphTransformer(layer_num = max_layer, embedding_dim = output_dim, num_heads = 4, num_rel = num_relations_graph, dropout=dropout, type='node')
        # A single network implementation serves both views, with ``type``
        # selecting the kind of learning; alternatively this could be split
        # into two modules that share their common sub-modules.

        # Fuses the two views of one drug into a single embedding.
        self.fc1 = nn.Sequential(
            nn.Linear(output_dim*2, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, output_dim)
        )

        # Scores a concatenated drug pair over ``num_rel`` outputs.
        self.fc2 = nn.Sequential(
            nn.Linear(output_dim*2, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_rel)
        )

        self.disc = Discriminator(output_dim)
        self.b_xent = BCEWithLogitsLoss()
        # Auxiliary loss written by ``forward`` and consumed by ``loss``.
        self.reserved_loss = 0
        self.cdan_dim = output_dim * 2

    def to(self, device):
        """
        Move the model to ``device`` and remember it as the input device.

        Args:
            device: Target device.

        Returns:
            TIGER: ``self``, for chaining.
        """
        # Call ``to`` on each registered child so any custom ``to`` override
        # in a sub-module is still honoured.
        for child in self.children():
            child.to(device)
        # ``forward`` moves its inputs to ``self.device``; keep it in sync,
        # otherwise inputs and weights end up on different devices.
        self.device = device
        return self

    def loss(self, pred, label):
        """
        Compute the supervised loss plus the stored auxiliary loss.

        Args:
            pred: Model predictions (logits).
            label: Ground truth labels.

        Returns:
            torch.Tensor: BCE-with-logits loss when ``args.matrix`` is
            ``'multilabel'`` or ``'twosides'``, cross-entropy otherwise, in
            both cases with ``self.reserved_loss`` added.
        """
        if self.args.matrix in ['multilabel', 'twosides']:
            return nn.BCEWithLogitsLoss()(pred, label.float()) + self.reserved_loss
        else:
            return nn.CrossEntropyLoss()(pred, label.long()) + self.reserved_loss

    def reset_parameters(self):
        """
        Reset the feature encoders and both graph transformers.

        ``fc1``, ``fc2`` and ``disc`` are not reset by this method.
        """
        self.mol_atom_feature.reset_parameters()
        self.drug_node_feature.reset_parameters()
        self.mol_representation_learning.reset_parameters()
        self.node_representation_learning.reset_parameters()

    def forward(self, data):
        """
        Forward pass of the TIGER model.

        Side effect: overwrites ``self.reserved_loss`` with the weighted sum
        of the mutual-information losses for this batch.

        Args:
            data: Sequence holding
                (drug1_mol, drug1_subgraph, drug2_mol, drug2_subgraph).

        Returns:
            torch.Tensor: DDI prediction scores produced by ``fc2``.
        """
        drug1_mol, drug1_subgraph, drug2_mol, drug2_subgraph = data[0].to(self.device), data[1].to(self.device), data[2].to(self.device), data[3].to(self.device)

        mol1_atom_feature = self.mol_atom_feature(drug1_mol)
        mol2_atom_feature = self.mol_atom_feature(drug2_mol)

        drug1_node_feature = self.drug_node_feature(drug1_subgraph)
        drug2_node_feature = self.drug_node_feature(drug2_subgraph)

        mol1_graph_embedding, mol1_atom_embedding, mol1_attn = self.mol_representation_learning(mol1_atom_feature, drug1_mol)
        mol2_graph_embedding, mol2_atom_embedding, mol2_attn = self.mol_representation_learning(mol2_atom_feature, drug2_mol)

        drug1_node_embedding, drug1_sub_embedding, drug1_attn = self.node_representation_learning(drug1_node_feature, drug1_subgraph)
        drug2_node_embedding, drug2_sub_embedding, drug2_attn = self.node_representation_learning(drug2_node_feature, drug2_subgraph)

        # Fuse the knowledge-graph view and the molecular view of each drug.
        drug1_embedding = self.fc1(torch.cat([drug1_node_embedding, mol1_graph_embedding], dim=-1))
        drug2_embedding = self.fc1(torch.cat([drug2_node_embedding, mol2_graph_embedding], dim=-1))

        final_layer = torch.cat([drug1_embedding, drug2_embedding], dim=-1)
        score = self.fc2(final_layer)

        # Auxiliary MI losses between each fused drug embedding and its
        # atom-level / subgraph-level embeddings.
        loss_s_m = self.loss_MI(self.MI(drug1_embedding, mol1_atom_embedding)) + self.loss_MI(self.MI(drug2_embedding, mol2_atom_embedding))
        loss_s_d = self.loss_MI(self.MI(drug1_embedding, drug1_sub_embedding)) + self.loss_MI(self.MI(drug2_embedding, drug2_sub_embedding))
        self.reserved_loss = self.mol_coeff * loss_s_m + self.mi_coeff * loss_s_d

        return score

    def MI(self, graph_embeddings, sub_embeddings):
        """
        Build positive/negative pairs and score them with the discriminator.

        Every entry of ``sub_embeddings`` is paired with its own graph
        embedding (positive) and with the embedding of another graph taken
        from the batch in reversed order (negative).

        Args:
            graph_embeddings: One embedding row per graph in the batch.
            sub_embeddings: Sequence with one tensor per graph
                (assumed shape ``(k_i, dim)`` -- TODO confirm against
                ``GraphTransformer``).

        Returns:
            torch.Tensor: Discriminator logits, positives first.
        """
        batch_size = graph_embeddings.shape[0]
        # Negatives come from the batch in reversed order. The index is
        # created directly on the embeddings' device.
        idx = torch.arange(batch_size - 1, -1, -1, device=graph_embeddings.device)
        if batch_size > 2:
            # With an odd batch the middle element would map onto itself
            # under reversal; point it at its neighbour instead.
            idx[len(idx) // 2] = idx[len(idx) // 2 + 1]
        shuffle_embeddings = torch.index_select(graph_embeddings, 0, idx)

        c_0_list, c_1_list = [], []
        for c_0, c_1, sub in zip(graph_embeddings, shuffle_embeddings, sub_embeddings):
            c_0_list.append(c_0.expand_as(sub))  # positive
            c_1_list.append(c_1.expand_as(sub))  # negative

        c_0, c_1, sub = torch.cat(c_0_list), torch.cat(c_1_list), torch.cat(sub_embeddings)

        return self.disc(sub, c_0, c_1)

    def loss_MI(self, logits):
        """
        Compute the mutual-information loss from discriminator logits.

        The first half of ``logits`` is labelled 1 (positive pairs) and the
        second half 0 (negative pairs).

        Args:
            logits: Discriminator logits, positives followed by negatives.

        Returns:
            torch.Tensor: Binary cross-entropy-with-logits loss.
        """
        num_logits = logits.shape[0] // 2
        # Build the labels directly on the logits' device; no random draw is
        # needed just to obtain a shape.
        lbl = torch.cat(
            [
                torch.ones(num_logits, device=logits.device),
                torch.zeros(num_logits, device=logits.device),
            ],
            dim=0,
        )

        return self.b_xent(logits.view([1, -1]), lbl.view([1, -1]))

    def save(self, path):
        """
        Save the model's ``state_dict`` to ``<path>/<ClassName>.pt``.

        Args:
            path: Directory path for saving.

        Returns:
            str: Path to the saved model file.
        """
        save_path = os.path.join(path, self.__class__.__name__+'.pt')
        torch.save(self.state_dict(), save_path)
        return save_path
class Discriminator(nn.Module):
    """
    Bilinear discriminator used for mutual information estimation.

    Args:
        n_h: Hidden dimension size of both bilinear inputs.
    """

    def __init__(self, n_h):
        super().__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)
        # Run the custom initialiser over this module and its sub-modules.
        self.apply(self.weights_init)

    def weights_init(self, m):
        """
        Initialize discriminator weights.

        Bilinear layers get Xavier-uniform weights and a zero bias; every
        other module is left untouched.

        Args:
            m: Module to initialize.
        """
        if not isinstance(m, nn.Bilinear):
            return
        torch.nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)

    def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
        """
        Score positive and negative embeddings against a context.

        Args:
            c: Context embeddings.
            h_pl: Positive embeddings.
            h_mi: Negative embeddings.
            s_bias1: Optional bias added to the positive scores.
            s_bias2: Optional bias added to the negative scores.

        Returns:
            torch.Tensor: Positive scores followed by negative scores,
            concatenated along dimension 0.
        """
        positive_scores = self.f_k(h_pl, c)
        negative_scores = self.f_k(h_mi, c)

        if s_bias1 is not None:
            positive_scores += s_bias1
        if s_bias2 is not None:
            negative_scores += s_bias2

        return torch.cat((positive_scores, negative_scores), 0)
def init_params(module, layers=2):
    """
    Initialize the parameters of a single module in place.

    Linear layers get normally distributed weights with a standard deviation
    of ``0.02 / sqrt(layers)`` and a zero bias. Embedding layers get normally
    distributed weights with a standard deviation of ``0.02``. Any other
    module is left unchanged.

    Args:
        module: Module to initialize.
        layers: Number of layers used to scale the Linear std (default=2).
    """
    if isinstance(module, torch.nn.Linear):
        scaled_std = 0.02 / math.sqrt(layers)
        module.weight.data.normal_(mean=0.0, std=scaled_std)
        bias = module.bias
        if bias is not None:
            bias.data.zero_()
    elif isinstance(module, torch.nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)