# Source code for openddi.models.tiger.GraphTransformer
# -*- coding: utf-8 -*-
# @Time : 2023/5/30 下午4:14
# @Author : xiaorui su
# @Email : suxiaorui19@mails.ucas.edu.cn
# @File : GraphTransformer.py
# @Software : PyCharm
import os
import sys
BASEDIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASEDIR)
import torch
import math
from torch.nn import TransformerEncoderLayer, TransformerEncoder, BCEWithLogitsLoss
from torch_geometric.nn import GCNConv, SAGEConv, GCN2Conv, GATConv, ECConv, global_mean_pool, GINConv
from torch.nn import Linear, Sequential, ReLU
from torch_geometric.nn.conv import MessagePassing
# [docs]
class GraphTransformerEncode(torch.nn.Module):
    """
    Single pre-norm encoder layer of the graph transformer.

    The layer runs LayerNorm -> multi-head graph attention -> dropout and adds
    the result to its input, then runs LayerNorm -> feed-forward -> dropout and
    adds that to the intermediate result.

    Args:
        num_heads: Number of attention heads.
        in_dim: Input (and output) feature dimension.
        dim_forward: Hidden dimension of the feed-forward network.
        rel_encoder: Relation encoder module handed to the attention block.
        spatial_encoder: Spatial encoder module handed to the attention block.
        dropout: Dropout rate applied after attention and after the feed-forward.
    """

    def __init__(self, num_heads, in_dim, dim_forward, rel_encoder, spatial_encoder, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.in_dim = in_dim
        self.dim_forward = dim_forward

        # Sub-modules are created in a fixed order (feed-forward, attention,
        # norms, dropouts) so seeded initialisation stays reproducible.
        expand = Linear(self.in_dim, self.dim_forward)
        activation = ReLU()
        project = Linear(self.dim_forward, self.in_dim)
        self.ffn = Sequential(expand, activation, project)

        self.multiHeadAttention = MultiheadAttention(
            dim_model=self.in_dim,
            num_heads=self.num_heads,
            rel_encoder=rel_encoder,
            spatial_encoder=spatial_encoder,
        )

        self.layernorm1 = torch.nn.LayerNorm(normalized_shape=in_dim, eps=1e-6)
        self.layernorm2 = torch.nn.LayerNorm(normalized_shape=in_dim, eps=1e-6)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)

    def reset_parameters(self):
        """
        Re-initialise the feed-forward layers, the attention block and both norms.
        """
        resettable = (
            self.ffn[0],
            self.ffn[2],
            self.multiHeadAttention,
            self.layernorm1,
            self.layernorm2,
        )
        for module in resettable:
            module.reset_parameters()

    def forward(self, feature, sp_edge_index, sp_value, edge_rel):
        """
        Apply the attention sub-layer followed by the feed-forward sub-layer.

        Args:
            feature: Node features.
            sp_edge_index: Shortest path edge indices.
            sp_value: Shortest path distances.
            edge_rel: Edge relation types.

        Returns:
            tuple: (output features, attention weights)
        """
        # Attention sub-layer: normalise first, add the residual afterwards.
        attended, attn_weight = self.multiHeadAttention(
            self.layernorm1(feature), sp_edge_index, sp_value, edge_rel
        )
        after_attention = self.dropout1(attended) + feature

        # Feed-forward sub-layer with its own residual connection.
        transformed = self.dropout2(self.ffn(self.layernorm2(after_attention)))
        return after_attention + transformed, attn_weight
# [docs]
class SpatialEncoding(torch.nn.Module):
    """
    Spatial encoding module that maps one distance value to one scalar bias.

    The encoder is a two-layer MLP (1 -> dim_model -> 1) with a ReLU after
    each linear layer, so the produced bias is never negative.

    Args:
        dim_model: Hidden dimension of the encoding MLP.
    """

    def __init__(self, dim_model):
        super(SpatialEncoding, self).__init__()
        self.dim = dim_model
        self.fnn = Sequential(
            Linear(1, dim_model),
            ReLU(),
            Linear(dim_model, 1),
            ReLU()
        )

    def reset_parameters(self):
        """
        Reset the parameters of both linear layers.
        """
        self.fnn[0].reset_parameters()
        self.fnn[2].reset_parameters()

    def forward(self, lap):
        """
        Encode spatial distances.

        Args:
            lap: Shortest path distances; any numeric dtype is accepted.

        Returns:
            torch.Tensor: Spatial encoding with a trailing dimension of size 1.
        """
        lap_ = torch.unsqueeze(lap, dim=-1)  # [..., 1]
        # Hop counts are often stored as integer tensors, which Linear
        # rejects; cast to the layer's parameter dtype (a no-op when the
        # input already has that dtype).
        lap_ = lap_.to(self.fnn[0].weight.dtype)
        return self.fnn(lap_)
# [docs]
class MultiheadAttention(MessagePassing):
    """
    Multi-head attention over shortest-path edges for the graph transformer.

    For every edge ``(row, col)`` in ``sp_edge_index`` a per-head score is
    built from ``q[col]`` and ``k[row]`` (both shifted by the relation
    embedding), scaled, biased by the spatial encoding and divided by a
    per-node normalizer. The values of the ``row`` nodes are then aggregated
    per head, weighted by these scores.

    Args:
        dim_model: Model dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        rel_encoder: Relation encoder; called on ``edge_rel`` in ``forward``.
        spatial_encoder: Spatial encoder; called on ``sp_value`` in ``forward``.
        **kwargs: Additional MessagePassing arguments (``aggr`` defaults to ``'add'``).
    """

    def __init__(self, dim_model, num_heads, rel_encoder, spatial_encoder, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)
        self.d_model = dim_model
        self.num_heads = num_heads
        self.rel_embedding = rel_encoder
        # NOTE: `rel_encoding` is not used by forward(); it is kept so that the
        # parameter set (state_dict keys, initialisation order) is unchanged.
        self.rel_encoding = Sequential(
            Linear(dim_model, 1),
            ReLU()
        )
        self.spatial_encoding = spatial_encoder
        assert dim_model % num_heads == 0, "dim_model must be divisible by num_heads"
        self.depth = self.d_model // num_heads
        self.wq = Linear(dim_model, dim_model)
        self.wk = Linear(dim_model, dim_model)
        self.wv = Linear(dim_model, dim_model)
        self.dense = Linear(dim_model, dim_model)

    def reset_parameters(self):
        """
        Reset model parameters, including the relation and spatial encoders
        that were passed in at construction time.
        """
        self.rel_embedding.reset_parameters()
        self.rel_encoding[0].reset_parameters()
        self.spatial_encoding.reset_parameters()
        self.wq.reset_parameters()
        self.wk.reset_parameters()
        self.wv.reset_parameters()
        self.dense.reset_parameters()

    def softmax_kernel_transformation(self, data, is_query, projection_matrix=None, numerical_stabilizer=0.000001):
        """
        Apply a softmax-kernel feature transformation to queries or keys.

        This helper is not called by ``forward`` in this class.

        Args:
            data: Input data.
            is_query: Whether input is query. Queries are stabilised with the
                maximum over the last dimension only; keys additionally take
                the maximum over the dimension three from the end.
            projection_matrix: Callable applied to the scaled data. Required,
                despite the ``None`` default.
            numerical_stabilizer: Small constant added after the exponential.

        Returns:
            torch.Tensor: Transformed data.

        Raises:
            TypeError: If ``projection_matrix`` is ``None``.
        """
        if projection_matrix is None:
            # Previously this failed with an opaque "'NoneType' object is not
            # callable"; same exception type, clearer message.
            raise TypeError("softmax_kernel_transformation requires a callable projection_matrix")

        data_normalizer = 1.0 / torch.sqrt(torch.sqrt(torch.tensor(data.shape[-1], dtype=torch.float32)))
        data = data_normalizer * data
        ratio = data_normalizer
        data_dash = projection_matrix(data)  # [node_num, dim]

        # Half squared norm of each (scaled) input vector, kept broadcastable.
        last_dim = len(data.shape) - 1
        diag_data = torch.sum(torch.square(data), dim=last_dim) / 2.0
        diag_data = torch.unsqueeze(diag_data, dim=last_dim)

        last_dims_t = len(data_dash.shape) - 1
        attention_dims_t = len(data_dash.shape) - 3

        # Subtract a maximum before exponentiating for numerical stability.
        stabilizer = torch.max(data_dash, dim=last_dims_t, keepdim=True)[0]
        if not is_query:
            stabilizer = torch.max(stabilizer, dim=attention_dims_t, keepdim=True)[0]
        return ratio * (torch.exp(data_dash - diag_data - stabilizer) + numerical_stabilizer)

    def denominator(self, qs, ks):
        """
        Compute the per-node, per-head attention normalizer.

        Args:
            qs: Query embeddings, [num_node, num_heads, depth].
            ks: Key embeddings, [num_node, num_heads, depth].

        Returns:
            torch.Tensor: Normalization denominator, [num_node, num_heads].
        """
        # Summing keys directly avoids a float32 ones-vector, which would
        # fail for keys of another dtype.
        ks_sum = ks.sum(dim=0)  # [num_heads, depth]; ks_sum refers to O_k in the paper
        return torch.einsum("nhm,hm->nh", qs, ks_sum)

    def message(self, x_j, edge_weight):
        """
        Weight each source-node value by the attention score of its edge.

        Without this override the default message passes ``x_j`` through
        unchanged and the ``edge_weight`` given to ``propagate`` is ignored.

        Args:
            x_j: Values gathered per edge for one head, [edge_nums, depth].
            edge_weight: Attention score per edge for that head, [edge_nums].

        Returns:
            torch.Tensor: Weighted messages, [edge_nums, depth].
        """
        return edge_weight.view(-1, 1) * x_j

    def forward(self, x, sp_edge_index, sp_value, edge_rel):
        """
        Forward pass for multi-head attention.

        Args:
            x: Node features, [nodes_num, dim_model].
            sp_edge_index: Shortest path edge indices, [2, edge_nums].
            sp_value: Shortest path distances.
            edge_rel: Edge relation types.

        Returns:
            tuple: (attention output [nodes_num, dim_model],
            attention weights [edge_nums, num_heads]). The weights are the
            raw scores divided by the normalizer; no softmax is applied.
        """
        num_nodes = x.shape[0]
        num_edges = sp_edge_index.shape[1]

        # presumably [edge_nums, dim_model] so it can be added to q[col] / k[row]
        rel_embedding = self.rel_embedding(edge_rel)
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x).view(num_nodes, self.num_heads, self.depth)  # [nodes_num, num_heads, depth]

        row, col = sp_edge_index
        # Relation-aware query/key per edge, split into heads.
        query_end = (q[col] + rel_embedding).view(num_edges, self.num_heads, self.depth)
        key_start = (k[row] + rel_embedding).view(num_edges, self.num_heads, self.depth)

        edge_attn_num = torch.einsum("ehd,ehd->eh", query_end, key_start)  # [edge_nums, num_heads]
        # NOTE(review): the last dimension here is num_heads, not depth, so
        # the scale is num_heads ** -0.25; kept as in the original.
        data_normalizer = 1.0 / torch.sqrt(torch.sqrt(torch.tensor(edge_attn_num.shape[-1], dtype=torch.float32)))
        edge_attn_bias = self.spatial_encoding(sp_value)
        edge_attn_num = edge_attn_num * data_normalizer + edge_attn_bias

        attn_normalizer = self.denominator(
            q.view(num_nodes, self.num_heads, self.depth),
            k.view(num_nodes, self.num_heads, self.depth),
        )
        edge_attn_dem = attn_normalizer[col]  # [edge_nums, num_heads]
        attention_weight = edge_attn_num / edge_attn_dem  # [edge_nums, num_heads]

        # Aggregate head by head so `propagate` always sees 2-D node features.
        outputs = [
            self.propagate(
                edge_index=sp_edge_index,
                x=v[:, head, :],
                edge_weight=attention_weight[:, head],
                size=None,
            )
            for head in range(self.num_heads)
        ]
        out = torch.cat(outputs, dim=-1)
        return self.dense(out), attention_weight
# [docs]
class GraphTransformer(torch.nn.Module):
    """
    Graph Transformer model producing graph-level or node-level representations.

    Args:
        layer_num: Layer count setting; ``layer_num - 1`` encoder layers are
            built (default=3).
        embedding_dim: Embedding dimension (default=64).
        num_heads: Number of attention heads (default=4).
        num_rel: Number of relation types (default=10).
        dropout: Dropout rate (default=0.2).
        type: Type of representation ('graph' or 'node') (default='graph').
    """

    # `type` selects graph-level versus node-level representation learning.
    def __init__(self, layer_num=3, embedding_dim=64, num_heads=4, num_rel=10, dropout=0.2, type='graph'):
        super().__init__()
        self.type = type
        # Both encoders are shared by every encoder layer (weight sharing).
        self.rel_encoder = torch.nn.Embedding(num_rel, embedding_dim)
        self.spatial_encoder = SpatialEncoding(embedding_dim)
        layers = [
            GraphTransformerEncode(
                num_heads=num_heads,
                in_dim=embedding_dim,
                dim_forward=embedding_dim * 2,
                rel_encoder=self.rel_encoder,
                spatial_encoder=self.spatial_encoder,
                dropout=dropout,
            )
            for _ in range(layer_num - 1)
        ]
        self.encoder = torch.nn.ModuleList(layers)

    def reset_parameters(self):
        """
        Reset the parameters of every encoder layer.
        """
        for layer in self.encoder:
            layer.reset_parameters()

    def forward(self, feature, data):
        """
        Run all encoder layers and build the requested representation.

        Args:
            feature: Input node features.
            data: Graph data object; ``sp_edge_index``, ``sp_value``,
                ``sp_edge_rel``, ``batch`` and ``to_data_list()`` are read,
                plus ``id`` when ``type`` is not 'graph'.

        Returns:
            tuple: (representation, per-graph node embeddings, per-layer
            attention weights)
        """
        # Attention weights are computed along the edge index, then node
        # values are aggregated according to those weights.
        hidden = feature
        attn_layer = []
        for layer in self.encoder:
            hidden, layer_attn = layer(hidden, data.sp_edge_index, data.sp_value, data.sp_edge_rel)
            attn_layer.append(layer_attn)

        # Node embeddings of each graph in the batch, [node_count, emb_dim] each.
        sub_representation = [
            hidden[(data.batch == graph_idx).nonzero().flatten()]
            for graph_idx, _ in enumerate(data.to_data_list())
        ]

        if self.type == 'graph':
            # Mean-pool node embeddings into one vector per graph.
            representation = global_mean_pool(hidden, batch=data.batch)
        else:
            # Keep only the nodes flagged by a non-zero entry in `data.id`.
            representation = hidden[data.id.nonzero().flatten()]
        return representation, sub_representation, attn_layer
# NOTE: For node-level representations, the outputs of the individual layers may need to be concatenated before the final mutual-information maximization. Whether to optimize at the layer level is still open; either way, the objective is ultimately defined on the node and graph representations.