Source code for openddi.models.MUFFIN

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class MUFFIN(nn.Module):
    """MUFFIN model for drug-drug interaction prediction using convolutional networks.

    Every node (drug) feature vector is projected to ``entity_dim``, expanded
    into an ``entity_dim x entity_dim`` outer-product "image", and encoded by
    two convolution blocks plus two fully connected layers. For each requested
    pair, the two drug embeddings are concatenated and classified by a
    three-layer MLP.
    """

    def __init__(self, feature: int, hidden1: int, hidden2: int, num_relations: int,
                 num_classes: int, dropout: float = 0.3, entity_dim: int = 128):
        """Initialize MUFFIN model.

        Args:
            feature: Input feature dimension.
            hidden1: First hidden layer dimension.
            hidden2: Second hidden layer dimension.
            num_relations: Number of relation types (unused in current implementation).
            num_classes: Number of output classes.
            dropout: Accepted for API compatibility; currently unused (no dropout
                layer is applied anywhere in the model).
            entity_dim: Entity embedding dimension. Ignored (forced to 128) when
                ``num_classes > 150``. Must be at least 16 so that both
                5x5-conv + 2x2-pool stages produce a non-empty feature map.

        Raises:
            ValueError: If the effective ``entity_dim`` is too small for the
                convolution stack.
        """
        super().__init__()
        self.num_relations = int(num_relations)
        self.num_classes = int(num_classes)
        self.feature = int(feature)
        self.num_edge_features = int(num_classes)
        self.hidden1 = int(hidden1)
        self.hidden2 = int(hidden2)
        self.entity_dim = int(entity_dim)
        if self.num_classes > 150:
            # Preserved behaviour: large label spaces always use a
            # 128-dimensional entity embedding, overriding the argument.
            self.entity_dim = 128

        # Computed before any layer is built so an invalid size fails fast.
        conv_flat_dim = self._conv_flat_dim(self.entity_dim)

        # Not used in forward(); kept so the attribute remains available.
        self.activate = nn.ReLU()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(5, 5)),
            nn.BatchNorm2d(8),
            nn.MaxPool2d((2, 2)),
            nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=8, kernel_size=(5, 5)),
            nn.BatchNorm2d(8),
            nn.MaxPool2d((2, 2)),
            nn.ReLU())
        # Attribute name ("liner") kept as-is so existing state_dict keys still load.
        self.liner = nn.Linear(self.feature, self.entity_dim)
        # Input size derived from entity_dim; equals 29 * 29 * 8 for entity_dim=128.
        self.fc1 = nn.Sequential(
            nn.Linear(conv_flat_dim, self.entity_dim),
            nn.BatchNorm1d(self.entity_dim),
            nn.ReLU(True))
        self.fc2 = nn.Sequential(nn.Linear(self.entity_dim, self.entity_dim), nn.ReLU(True))
        self.layer1 = nn.Sequential(
            nn.Linear(2 * self.entity_dim, self.hidden1),
            nn.BatchNorm1d(self.hidden1),
            nn.ReLU(True))
        self.layer2 = nn.Sequential(
            nn.Linear(self.hidden1, self.hidden2),
            nn.BatchNorm1d(self.hidden2),
            nn.ReLU(True))
        self.layer3 = nn.Sequential(nn.Linear(self.hidden2, self.num_classes))

    @staticmethod
    def _conv_flat_dim(entity_dim: int) -> int:
        """Return the flattened size of ``conv2(conv1(x))`` for a (1, E, E) input.

        Mirrors the two stages defined in ``__init__``: a 5x5 convolution with
        no padding (side shrinks by 4), then 2x2 max pooling (side halves,
        rounding down), with 8 output channels.

        Raises:
            ValueError: If a stage would produce an empty feature map.
        """
        side = entity_dim
        for _ in range(2):
            conv_side = side - 4
            # MaxPool2d((2, 2)) needs at least a 2x2 input to emit any output.
            if conv_side < 2:
                raise ValueError(
                    f"entity_dim={entity_dim} is too small for the two 5x5 conv + "
                    f"2x2 pool stages (need >= 16)")
            side = conv_side // 2
        return 8 * side * side

    @staticmethod
    def _as_index(values, device) -> torch.Tensor:
        """Convert a sequence of node indices to a long tensor on ``device``.

        Tensors and numpy arrays are converted directly, which avoids iterating
        them element by element in Python. Any other iterable is materialised
        with ``list()``.
        """
        if isinstance(values, torch.Tensor):
            return values.to(device=device, dtype=torch.long)
        if isinstance(values, np.ndarray):
            return torch.as_tensor(values, dtype=torch.long, device=device)
        return torch.as_tensor(list(values), dtype=torch.long, device=device)

    def forward(self, data_o, idx) -> torch.Tensor:
        """Forward pass of the MUFFIN model.

        Args:
            data_o: Graph data object whose ``x`` attribute holds node features;
                its last dimension must equal ``feature`` (presumably shape
                (N, feature)).
            idx: Pair of index sequences; ``idx[0][i]`` and ``idx[1][i]`` are the
                two nodes of the i-th pair. Tensors, numpy arrays and other
                iterables of ints are accepted.

        Returns:
            Logits of shape (num_pairs, num_classes).
        """
        x_o = data_o.x
        device = x_o.device

        out1 = self.liner(x_o)                                   # (N, E)
        # Per-node outer product, treated as a single-channel E x E image.
        out1 = torch.bmm(out1.unsqueeze(2), out1.unsqueeze(1))   # (N, E, E)
        out1 = out1.unsqueeze(1)                                 # (N, 1, E, E)
        out1 = self.conv2(self.conv1(out1))
        out1 = out1.view(out1.size(0), -1)
        out1 = self.fc2(self.fc1(out1))                          # (N, E)

        a_idx = self._as_index(idx[0], device)
        b_idx = self._as_index(idx[1], device)
        drug_data = torch.cat([out1[a_idx], out1[b_idx]], dim=1)  # (num_pairs, 2E)

        x = self.layer1(drug_data)
        x = self.layer2(x)
        return self.layer3(x)