import torch
from torch import nn
import torch.nn.functional as F
import math
import copy
from torch.nn.parameter import Parameter
class MVA(nn.Module):
    """Multi-View Attention model for DDI prediction.

    Each drug of a pair is encoded through two views which are fused and
    decoded into relation logits:

    - a single graph-convolution step, ``adj @ (features @ weight) + bias``,
      over the zero-padded molecular graph,
    - a transformer encoder over the drug's token sequence,
    - attention-based feature fusion (``AFF``) of the two views.

    Supports both multiclass and multilabel classification (see ``loss``).

    Args:
        args: Configuration object; ``loss`` reads its optional ``matrix``
            attribute.
        gcn_in_features: Input feature dimension of the graph convolution.
        gcn_out_features: Output feature dimension of the graph convolution.
        num_rel: Number of relation types (width of the output logits).
        bias: Whether the graph convolution adds a bias (default=True).
    """

    # Every molecular graph is zero-padded to this many nodes so that the
    # flattened graph-convolution output has a fixed width.
    PADDING_SIZE = 600

    def __init__(self, args, gcn_in_features, gcn_out_features, num_rel, bias=True):
        super().__init__()
        # GCN parameters
        self.args = args
        self.num_rel = num_rel
        self.gcn_in_features = gcn_in_features
        self.gcn_out_features = gcn_out_features
        # Wrapping a tensor in Parameter registers it in the module's
        # parameter list automatically.
        self.weight = Parameter(
            torch.empty(gcn_in_features, gcn_out_features, dtype=torch.float32))
        if bias:
            self.bias = Parameter(torch.empty(gcn_out_features, dtype=torch.float32))
        else:
            # Keep a 'bias' attribute (set to None) so later checks work.
            self.register_parameter('bias', None)
        self.reset_parameters()

        self.fusionsize = 128
        self.max_d = 50
        self.input_dim_drug = 23532
        self.n_layer = 2
        self.emb_size = 384
        self.dropout_rate = 0
        # encoder
        self.hidden_size = 384
        self.intermediate_size = 1536
        self.num_attention_heads = 4  # 2 4 8
        self.attention_probs_dropout_prob = 0.1
        self.hidden_dropout_prob = 0.1
        # specialized embedding with positional one
        self.emb = Embeddings(self.input_dim_drug, self.emb_size, self.max_d, self.dropout_rate)
        # A single encoder is shared by both drugs of a pair (see forward).
        self.d_encoder = Encoder_MultipleLayers(self.n_layer, self.hidden_size, self.intermediate_size,
                                                self.num_attention_heads, self.attention_probs_dropout_prob,
                                                self.hidden_dropout_prob)
        self.fusion = AFF(self.fusionsize)
        # decoder
        self.decoder_trans_mpnn_cat = nn.Sequential(
            nn.Linear(self.fusionsize * 2, 64),
            nn.Dropout(0.1),
            nn.ReLU(True),
            nn.BatchNorm1d(64),
            nn.Linear(64, 32),
            nn.ReLU(True),
            # output layer
            nn.Linear(32, self.num_rel)
        )
        # Projects the flattened (max_d x hidden_size) sequence encoding.
        self.decoder_1 = nn.Sequential(
            nn.Linear(self.max_d * self.hidden_size, 512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            nn.Linear(512, 128)
        )
        self.flatten = nn.Flatten()
        # Projects the flattened (PADDING_SIZE x gcn_out_features) graph output.
        self.decoder_2 = nn.Sequential(
            nn.Linear(self.PADDING_SIZE * gcn_out_features, 512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            nn.Linear(512, 128)
        )
        # The modules below are not used by forward(); they are kept so the
        # set of registered parameters (and state_dict keys) is unchanged.
        self.decoder_3 = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(True),
            nn.BatchNorm1d(32),
            nn.Linear(32, 1)
        )
        self.query_proj = nn.Linear(256, 256 * 2, bias=False)
        self.key_proj = nn.Linear(256, 256 * 2, bias=False)
        self.value_proj = nn.Linear(256, 256 * 2, bias=False)
        self.output_proj = nn.Linear(256 * 2, 256, bias=False)

    def aggregate_message_1(self, nodes, node_neighbours, edges, mask):
        """
        Aggregate messages for first view (to be implemented).
        """
        raise NotImplementedError

    def aggregate_message_2(self, nodes, node_neighbours, edges, mask):
        """
        Aggregate messages for second view (to be implemented).
        """
        raise NotImplementedError

    # inputs are "batches" of shape (maximum number of nodes in batch, number of features)
    def update_1(self, nodes, messages):
        """
        Update node representations for first view (to be implemented).
        """
        raise NotImplementedError

    def update_2(self, nodes, messages):
        """
        Update node representations for second view (to be implemented).
        """
        raise NotImplementedError

    # inputs are "batches" of same shape as the nodes passed to update
    # node_mask is same shape as inputs and is 1 if elements corresponding exists, otherwise 0
    def readout_1(self, hidden_nodes, input_nodes, node_mask):
        """
        Readout function for first view (to be implemented).
        """
        raise NotImplementedError

    def readout_2(self, hidden_nodes, input_nodes, node_mask):
        """
        Readout function for second view (to be implemented).
        """
        raise NotImplementedError

    def readout(self, input_nodes, node_mask):
        """
        General readout function (to be implemented).
        """
        raise NotImplementedError

    def final_layer(self, out):
        """
        Final layer processing (to be implemented).
        """
        raise NotImplementedError

    def reset_parameters(self):
        """
        Reset the graph-convolution weight and bias with a uniform
        initialization in ``[-1/sqrt(out_features), 1/sqrt(out_features)]``.
        """
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def _graph_conv(self, fts, adjs):
        """Zero-pad one batch of graphs to PADDING_SIZE nodes and apply the
        graph convolution ``adj @ (fts @ weight) + bias``."""
        device = self.weight.device
        fts = fts.to(device)
        adjs = adjs.to(device)
        num_nodes = fts.size(1)
        fts_padded = torch.zeros(fts.size(0), self.PADDING_SIZE, self.gcn_in_features,
                                 dtype=torch.float32, device=device)
        adjs_padded = torch.zeros(adjs.size(0), self.PADDING_SIZE, self.PADDING_SIZE,
                                  dtype=torch.float32, device=device)
        fts_padded[:, :num_nodes, :] = fts
        adjs_padded[:, :num_nodes, :num_nodes] = adjs
        support = torch.matmul(fts_padded, self.weight)
        output = torch.matmul(adjs_padded, support)
        if self.bias is not None:
            output = output + self.bias
        return output

    def forward(self, data):
        """
        Forward pass of MVA model.

        Args:
            data: 10-tuple laid out as (node_tensor_1, adjacency_tensor_1,
                node_tensor_2, adjacency_tensor_2, num_size_tensor,
                target_tensor, d1_emb_tensor, d2_emb_tensor, mask_1_tensor,
                mask_2_tensor). The target and the two mask entries are
                not used here.

        Returns:
            torch.Tensor: Logits for DDI prediction, one row per sample and
            ``num_rel`` columns.
        """
        fts_1, adjs_1, fts_2, adjs_2, num_size, _, de_1, de_2, _, _ = data
        # Move every input onto the device that holds the model parameters,
        # so inputs and the padding buffers can never end up on different
        # devices.
        device = self.weight.device
        de_1 = de_1.to(device)
        de_2 = de_2.to(device)
        batch_size = num_size.size(0)

        # GCN encoder
        output_1 = self._graph_conv(fts_1, adjs_1)
        output_2 = self._graph_conv(fts_2, adjs_2)

        # Sequence encoder
        # NOTE(review): the additive attention mask is derived from the same
        # tensors that are fed to the embedding layer as token ids, while the
        # two mask entries of `data` are ignored. Behavior is kept as-is;
        # confirm that this is intended.
        ex_d_mask = (1.0 - de_1.unsqueeze(1).unsqueeze(2)) * -10000.0
        ex_p_mask = (1.0 - de_2.unsqueeze(1).unsqueeze(2)) * -10000.0
        d_emb = self.emb(de_1)  # batch x seq_length x embed_size
        p_emb = self.emb(de_2)
        # The encoder returns the hidden states of its last layer only.
        d_encoded_layers = self.d_encoder(d_emb.float(), ex_d_mask.float())
        p_encoded_layers = self.d_encoder(p_emb.float(), ex_p_mask.float())
        d1_trans_fts = d_encoded_layers.view(batch_size, -1)
        d2_trans_fts = p_encoded_layers.view(batch_size, -1)
        d1_trans_fts_layer1 = self.decoder_1(d1_trans_fts)
        d2_trans_fts_layer1 = self.decoder_1(d2_trans_fts)

        output_1 = self.decoder_2(self.flatten(output_1))
        output_2 = self.decoder_2(self.flatten(output_2))

        # Fuse the sequence view with the graph view for each drug, then
        # classify the concatenated pair representation.
        output1 = self.fusion(d1_trans_fts_layer1, output_1)
        output2 = self.fusion(d2_trans_fts_layer1, output_2)
        final_fts_cat = torch.cat((output1, output2), dim=1)
        return self.decoder_trans_mpnn_cat(final_fts_cat)

    def loss(self, logits, labels):
        """
        Compute supervised loss for DDI edge classification.

        Uses binary cross-entropy with logits when ``args.matrix`` is
        ``'multilabel'`` or ``'twosides'``; otherwise (including when the
        attribute is absent) uses cross-entropy.

        Args:
            logits: Model predictions.
            labels: Ground truth labels.

        Returns:
            torch.Tensor: Loss value.
        """
        task = getattr(self.args, 'matrix', 'multiclass')
        if task in ['multilabel', 'twosides']:
            return F.binary_cross_entropy_with_logits(logits, labels.float())
        return F.cross_entropy(logits, labels.long())
class AFF(nn.Module):
    """
    Attention Feature Fusion module for multi-view feature integration.

    Two feature vectors are blended element-wise as
    ``x * w + y * (1 - w)``, where the gate ``w`` is the sigmoid of the sum
    of a local and a global attention branch applied to ``x + y``.

    The convolution kernel width of the local branch and the pooled width
    of the global branch are fixed at 128, so the module is built for
    128-wide feature vectors; ``channels`` only determines the number of
    intermediate channels.

    Args:
        channels: Input channel dimension (default=128).
        r: Reduction ratio for intermediate channels (default=4).
    """

    def __init__(self, channels=128, r=4):
        super().__init__()
        inter_channels = int(channels // r)
        # Local branch: one convolution spanning the full 128-wide row.
        self.local_att = nn.Sequential(
            nn.Conv2d(1, inter_channels, kernel_size=(1, 128), stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, 1, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(1),
        )
        # Global branch: pooled to width 128, then point-wise convolutions.
        self.global_att = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 128)),
            nn.Conv2d(1, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, 1, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(1),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        """
        Forward pass for attention feature fusion.

        Args:
            x: First feature tensor, 2-D (batch, features).
            y: Second feature tensor, viewable to the same shape as ``x``.

        Returns:
            torch.Tensor: Fused feature tensor of shape (batch, features).
        """
        batch_size, feature_size = x.size()
        # Treat each feature vector as a 1 x feature_size single-channel image.
        x = x.view(batch_size, 1, 1, feature_size)
        y = y.view(batch_size, 1, 1, feature_size)
        xy = x + y
        # The local output broadcasts against the global output when summed.
        gate = self.sigmoid(self.local_att(xy) + self.global_att(xy))
        fused = x * gate + y * (1 - gate)
        return fused.reshape(batch_size, feature_size)
# Transformer sub-modules (layer norm, embeddings, attention, encoder layers)
class LayerNorm(nn.Module):
    """
    Layer normalization over the last dimension with a learnable scale
    (``gamma``) and shift (``beta``).

    Args:
        hidden_size: Hidden dimension size.
        variance_epsilon: Small epsilon added to the variance for numerical
            stability (default=1e-12).
    """

    def __init__(self, hidden_size, variance_epsilon=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        """
        Apply layer normalization.

        Args:
            x: Input tensor.

        Returns:
            torch.Tensor: Normalized tensor.
        """
        centered = x - x.mean(-1, keepdim=True)
        # Biased variance (mean of squared deviations) along the last axis.
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.gamma * normalized + self.beta
class Embeddings(nn.Module):
    """
    Construct embeddings from sequence tokens and positional information.

    Args:
        vocab_size: Vocabulary size.
        hidden_size: Hidden dimension size.
        max_position_size: Maximum sequence length.
        dropout_rate: Dropout rate.
    """

    def __init__(self, vocab_size, hidden_size, max_position_size, dropout_rate):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input_ids):
        """
        Create embeddings for input tokens.

        Args:
            input_ids: Token indices; cast to ``long`` before lookup. The
                second dimension is taken as the sequence length.

        Returns:
            torch.Tensor: Token embeddings plus position embeddings, after
            layer normalization and dropout.
        """
        token_ids = input_ids.long()
        # Positions 0..seq_length-1, repeated for every row of the batch.
        positions = torch.arange(token_ids.size(1), dtype=torch.long, device=token_ids.device)
        positions = positions.unsqueeze(0).expand_as(token_ids)
        summed = self.word_embeddings(token_ids) + self.position_embeddings(positions)
        return self.dropout(self.LayerNorm(summed))
class SelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention.

    Args:
        hidden_size: Hidden dimension size.
        num_attention_heads: Number of attention heads.
        attention_probs_dropout_prob: Dropout probability for attention.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of
            ``num_attention_heads``.
    """

    def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))
        self.num_attention_heads = num_attention_heads
        # Exact integer division; divisibility was checked above.
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """
        Reshape tensor for multi-head attention.

        Splits the last dimension into (heads, head_size) and moves the head
        axis in front of the sequence axis.

        Args:
            x: Input tensor; indexed below as (batch, seq, all_head_size).

        Returns:
            torch.Tensor: Tensor of shape (batch, heads, seq, head_size).
        """
        x = x.view(*x.shape[:-1], self.num_attention_heads, self.attention_head_size)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        """
        Compute self-attention.

        Args:
            hidden_states: Input hidden states.
            attention_mask: Additive mask, added to the raw attention scores
                before the softmax.

        Returns:
            torch.Tensor: Attention output with the heads merged back into
            the last dimension.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_scores = attention_scores + attention_mask
        # Normalize the attention scores to probabilities.
        attention_probs = F.softmax(attention_scores, dim=-1)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Back to (batch, seq, heads, head_size), then merge the head axes.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        return context_layer.view(*context_layer.shape[:-2], self.all_head_size)
class SelfOutput(nn.Module):
    """
    Output module for self-attention with residual connection.

    Args:
        hidden_size: Hidden dimension size.
        hidden_dropout_prob: Dropout probability.
    """

    def __init__(self, hidden_size, hidden_dropout_prob):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """
        Apply output transformation with residual connection.

        Args:
            hidden_states: Attention output.
            input_tensor: Original input, added back as the residual.

        Returns:
            torch.Tensor: Layer-normalized sum of the projected (and
            dropped-out) attention output and the residual.
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class Attention(nn.Module):
    """
    Complete attention module with self-attention and output projection.

    Args:
        hidden_size: Hidden dimension size.
        num_attention_heads: Number of attention heads.
        attention_probs_dropout_prob: Dropout probability for attention.
        hidden_dropout_prob: Dropout probability for hidden layers.
    """

    def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
        super().__init__()
        self.self = SelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
        self.output = SelfOutput(hidden_size, hidden_dropout_prob)

    def forward(self, input_tensor, attention_mask):
        """
        Apply attention mechanism.

        Args:
            input_tensor: Input tensor.
            attention_mask: Attention mask.

        Returns:
            torch.Tensor: Attention output.
        """
        # Self-attention first, then the output projection, which also adds
        # the residual connection back to the input.
        attended = self.self(input_tensor, attention_mask)
        return self.output(attended, input_tensor)
class Output(nn.Module):
    """
    Output module with residual connection.

    Args:
        intermediate_size: Intermediate dimension size.
        hidden_size: Output hidden dimension.
        hidden_dropout_prob: Dropout probability.
    """

    def __init__(self, intermediate_size, hidden_size, hidden_dropout_prob):
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """
        Apply output transformation with residual connection.

        Args:
            hidden_states: Intermediate tensor.
            input_tensor: Original input, added back as the residual.

        Returns:
            torch.Tensor: Layer-normalized sum of the projected (and
            dropped-out) intermediate tensor and the residual.
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class Encoder(nn.Module):
    """
    Transformer encoder block with attention and feed-forward layers.

    Args:
        hidden_size: Hidden dimension size.
        intermediate_size: Intermediate dimension size.
        num_attention_heads: Number of attention heads.
        attention_probs_dropout_prob: Dropout probability for attention.
        hidden_dropout_prob: Dropout probability for hidden layers.
    """

    def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
                 hidden_dropout_prob):
        super().__init__()
        self.attention = Attention(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob)
        # NOTE(review): `Intermediate` must be defined elsewhere in this
        # module; presumably it expands hidden_size to intermediate_size
        # -- confirm it is present.
        self.intermediate = Intermediate(hidden_size, intermediate_size)
        self.output = Output(intermediate_size, hidden_size, hidden_dropout_prob)

    def forward(self, hidden_states, attention_mask):
        """
        Apply transformer encoder block.

        Args:
            hidden_states: Input hidden states.
            attention_mask: Attention mask.

        Returns:
            torch.Tensor: Encoded tensor.
        """
        # Attention with its residual connection.
        attended = self.attention(hidden_states, attention_mask)
        # Expand to the intermediate width.
        expanded = self.intermediate(attended)
        # Project back to the hidden width, adding the attention output as
        # the residual.
        return self.output(expanded, attended)
class Encoder_MultipleLayers(nn.Module):
    """
    Multi-layer transformer encoder.

    One ``Encoder`` block is built and deep-copied ``n_layer`` times, so all
    layers start from identical initial weights.

    Args:
        n_layer: Number of encoder layers.
        hidden_size: Hidden dimension size.
        intermediate_size: Intermediate dimension size.
        num_attention_heads: Number of attention heads.
        attention_probs_dropout_prob: Dropout probability for attention.
        hidden_dropout_prob: Dropout probability for hidden layers.
    """

    def __init__(self, n_layer, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
                 hidden_dropout_prob):
        super().__init__()
        prototype = Encoder(hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
                            hidden_dropout_prob)
        stacked = []
        for _ in range(n_layer):
            stacked.append(copy.deepcopy(prototype))
        self.layer = nn.ModuleList(stacked)

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        """
        Apply multiple transformer encoder layers.

        Args:
            hidden_states: Input hidden states.
            attention_mask: Attention mask, passed unchanged to every layer.
            output_all_encoded_layers: Accepted but not used; only the
                output of the last layer is returned.

        Returns:
            torch.Tensor: Output of the last encoder layer.
        """
        encoded = hidden_states
        for block in self.layer:
            encoded = block(encoded, attention_mask)
        return encoded