# Source code for openddi.data.MRCGNN_dataset
import os
import re
import pandas as pd
import numpy as np
import torch
import random
import argparse
from typing import Optional
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
from data.BaseDataset import BaseDataset
# [docs]
class MRCGNN_dataset(BaseDataset):
    """Dataset class for the MRCGNN model (multi-class and multi-label data).

    Extends ``BaseDataset`` with the two extra graph objects that MRCGNN uses
    for contrastive learning. Both are derived from ``self.data_o``, which is
    expected to be populated by ``BaseDataset.load_data`` (not visible here).

    Attributes set by :meth:`load_data`:
        data_s: ``Data`` with row-permuted node features and the original
            ``edge_index`` / ``edge_type``.
        data_a: ``Data`` with the original node features, a two-column
            label matrix ``y`` and the paired edge types.

    Args:
        args (argparse.Namespace): Parsed command line arguments containing
            configuration parameters. This class reads ``args.matrix``
            (through ``self.args``, presumably stored by ``BaseDataset``).
    """

    def __init__(self, args: argparse.Namespace):
        """Initialize the MRCGNN dataset.

        Args:
            args (argparse.Namespace): Parsed command line arguments
                containing configuration parameters.
        """
        super().__init__(args)
        # Filled in by load_data(); None until then.
        self.data_s = None
        self.data_a = None
        # Not assigned anywhere else in this class; kept as placeholders.
        self.predict_loader = None
        self.index_to_drug_id = None

    def load_data(self, val_ratio=0.1, test_ratio=0.2):
        """Load data with validation and test splits, then build MRCGNN extras.

        Args:
            val_ratio (float): Ratio of data to use for validation.
                Defaults to 0.1.
            test_ratio (float): Ratio of data to use for testing.
                Defaults to 0.2.
        """
        super().load_data(val_ratio, test_ratio)
        if self.args.matrix in ['multilabel', 'twosides']:
            self._load_multilabel_data_additional()
        else:
            self._load_multi_data_additional()

    def _load_multi_data_additional(self):
        """Build the contrastive-learning data objects for multi-class data.

        Every even-indexed edge type is taken and repeated twice, i.e. the
        type of the first edge of each consecutive pair is applied to both
        edges of that pair (assumes bidirectional edges are stored adjacently
        -- TODO confirm against BaseDataset).
        """
        edge_types = torch.as_tensor(self.data_o.edge_type).detach().cpu()
        # Vectorised equivalent of looping over range(0, len, 2) and emitting
        # [r, r] per step; avoids a Python-level loop over every edge.
        edge_types_pair = edge_types[::2].repeat_interleave(2).to(torch.int64)
        self._build_contrastive_data(edge_types_pair)

    def _load_multilabel_data_additional(self):
        """Build the contrastive-learning data objects for multi-label data.

        For multi-label data the paired edge types are all 0; only the number
        of edge pairs is taken from ``self.data_o.edge_type``.
        """
        # One [0, 0] pair per two edges, rounding up for an odd edge count.
        num_pairs = (len(self.data_o.edge_type) + 1) // 2
        edge_types_pair = torch.zeros(2 * num_pairs, dtype=torch.int64)
        self._build_contrastive_data(edge_types_pair)

    def _build_contrastive_data(self, edge_types_pair):
        """Create ``self.data_s`` and ``self.data_a`` from ``self.data_o``.

        Args:
            edge_types_pair (torch.Tensor): int64 edge types stored on
                ``self.data_a``.
        """
        # Corrupted node features: the rows of the original feature matrix
        # in a random order (uses NumPy's global random state).
        features_o = self.data_o.x.detach().cpu().numpy()
        num_drugs = features_o.shape[0]
        id_perm = np.random.permutation(num_drugs)
        x_a = torch.tensor(features_o[id_perm], dtype=torch.float)

        # (num_drugs, 2) label matrix: first column ones, second column zeros.
        y_a = torch.cat(
            (torch.ones(num_drugs, 1), torch.zeros(num_drugs, 1)), dim=1
        )

        # NOTE(review): edge_types_pair is not shuffled and data_a carries no
        # edge_index -- confirm the edge perturbation mentioned for MRCGNN is
        # performed by the consumer of data_a.
        self.data_s = Data(
            x=x_a,
            edge_index=self.data_o.edge_index,
            edge_type=self.data_o.edge_type,
        )
        self.data_a = Data(x=self.data_o.x, y=y_a, edge_type=edge_types_pair)