import os
import re
import pandas as pd
import numpy as np
import torch
import random
import argparse
from collections import defaultdict
from typing import Dict, List, Tuple, Optional, Union
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
# =============================================================================
# Data Loading and Preprocessing Module
# =============================================================================
class DataLoadingModule:
    """
    Data loading and preprocessing module.

    Reads embedding ``.pt`` files and paired-interaction CSV files and performs
    basic cleaning and label remapping.
    """

    @staticmethod
    def _load_embedding_file(path: str) -> Dict[str, np.ndarray]:
        """
        Load one embedding file into an ``{id: float32 vector}`` dictionary.

        Args:
            path: Path to a ``.pt`` file containing a dict of vectors.

        Returns:
            Dictionary mapping stringified keys to float32 numpy arrays.

        Raises:
            FileNotFoundError: If ``path`` is not a file.
            ValueError: If the file does not hold a dict, or the dict is empty.
        """
        if not os.path.isfile(path):
            raise FileNotFoundError(f"未找到嵌入文件: {path}")
        # map_location='cpu' lets files saved from GPU tensors load on CPU-only hosts.
        data = torch.load(path, map_location='cpu')
        if not isinstance(data, dict):
            raise ValueError(f"期望 pt 文件中是 dict,但得到的是 {type(data)}")
        if not data:
            # An empty dict has no vector to take the dimension from.
            raise ValueError(f"嵌入文件为空: {path}")
        id2vec = {}
        for key, value in data.items():
            if isinstance(value, torch.Tensor):
                value = value.detach().cpu().numpy()
            id2vec[str(key)] = np.asarray(value, dtype=np.float32)
        return id2vec

    @staticmethod
    def read_id_embedding_pt(embedding_paths: Union[str, List[str]]) -> Tuple[Dict[str, np.ndarray], int]:
        """
        Read one or more modality embedding files and concatenate vectors per ID.

        Args:
            embedding_paths: Path or list of paths to embedding files. All files
                must contain exactly the same ID set.

        Returns:
            Tuple containing:
                - id2vec: Dictionary mapping IDs to concatenated embedding vectors
                - total_dim: Sum of the per-file vector dimensions

        Raises:
            FileNotFoundError: If a path is not a file.
            ValueError: If a file is not a non-empty dict, or ID sets differ.
        """
        if isinstance(embedding_paths, str):
            embedding_paths = [embedding_paths]
        id2vec: Dict[str, np.ndarray] = {}
        total_dim = 0
        for path in embedding_paths:
            current = DataLoadingModule._load_embedding_file(path)
            # The dimension is taken from the first vector of the file.
            total_dim += int(next(iter(current.values())).shape[0])
            if not id2vec:
                id2vec = current
                continue
            if set(id2vec) != set(current):
                raise ValueError(f"嵌入文件 {path} 的 ID 集与之前的嵌入不一致")
            for id_ in id2vec:
                id2vec[id_] = np.concatenate([id2vec[id_], current[id_]], axis=0)
        return id2vec, total_dim

    @staticmethod
    def read_id_embedding_pt_split(embedding_paths: List[str]) -> Tuple[Dict[int, Dict[str, np.ndarray]], List[int]]:
        """
        Read several modality embedding files and keep them separate.

        Args:
            embedding_paths: List of paths to embedding files.

        Returns:
            Tuple containing:
                - modal2id2vec: Dictionary mapping modality index (position in
                  ``embedding_paths``) to its id2vec dictionary
                - dims: Embedding dimension of each modality, in the same order

        Raises:
            FileNotFoundError: If a path is not a file.
            ValueError: If a file is not a non-empty dict, or ID sets differ
                across modalities.
        """
        modal2id2vec: Dict[int, Dict[str, np.ndarray]] = {}
        dims: List[int] = []
        for idx, path in enumerate(embedding_paths):
            current = DataLoadingModule._load_embedding_file(path)
            dims.append(int(next(iter(current.values())).shape[0]))
            modal2id2vec[idx] = current
        # Every modality must cover exactly the same IDs.
        if len({frozenset(id2vec) for id2vec in modal2id2vec.values()}) > 1:
            raise ValueError("不同模态的嵌入文件中 ID 集不一致")
        return modal2id2vec, dims

    @staticmethod
    def _select_pair_columns(df: pd.DataFrame, matrix_path: str) -> pd.DataFrame:
        """
        Return a copy of the id1/id2/ddi columns of ``df``, renamed to lower case.

        Column names are matched case-insensitively and ignore surrounding
        whitespace. Previously the check was case-insensitive but the access
        was not, so a header such as ``ID1`` crashed after passing the check.

        Raises:
            KeyError: If one of the required columns is missing.
        """
        wanted = ('id1', 'id2', 'ddi')
        cols_lower = {str(c).strip().lower(): c for c in df.columns}
        for k in wanted:
            if k not in cols_lower:
                raise KeyError(f"文件 {matrix_path} 缺少列:{k}")
        out = df[[cols_lower[k] for k in wanted]].copy()
        out.columns = list(wanted)
        return out

    @staticmethod
    def read_multi_pairs_and_remap(matrix_path: str) -> Tuple[pd.DataFrame, int]:
        """
        Read multi-class paired data and remap labels to ``0..C-1``.

        Rows with a missing id1/id2 or a non-numeric ddi are dropped. ddi values
        are cast to int and remapped in ascending order of their raw value.

        Args:
            matrix_path: Path to the multi-class CSV file.

        Returns:
            Tuple containing:
                - df: DataFrame with columns id1, id2 (str) and ddi (int)
                - num_relations: Number of distinct relation types

        Raises:
            FileNotFoundError: If ``matrix_path`` is not a file.
            KeyError: If a required column is missing.
        """
        if not os.path.isfile(matrix_path):
            raise FileNotFoundError(f"未找到多分类文件: {matrix_path}")
        df = DataLoadingModule._select_pair_columns(pd.read_csv(matrix_path, dtype=str), matrix_path)
        # Drop missing IDs before astype(str), which would turn NaN into the string 'nan'.
        df = df.dropna(subset=['id1', 'id2'])
        df2 = pd.DataFrame({
            'id1': df['id1'].astype(str),
            'id2': df['id2'].astype(str),
            'ddi_raw': pd.to_numeric(df['ddi'], errors='coerce'),
        })
        df2 = df2.dropna(subset=['ddi_raw']).copy()
        df2['ddi_raw'] = df2['ddi_raw'].astype(int)
        # Remap to contiguous labels based on the sorted unique raw values.
        unique_raw = np.sort(df2['ddi_raw'].unique())
        raw2new = {raw: i for i, raw in enumerate(unique_raw)}
        df2['ddi'] = df2['ddi_raw'].map(raw2new).astype(int)
        return df2[['id1', 'id2', 'ddi']], len(unique_raw)

    @staticmethod
    def read_multilabel_pairs_and_remap(matrix_path: str) -> Tuple[pd.DataFrame, int]:
        """
        Read multi-label paired data.

        The ddi column holds a comma-separated list of integers per row, for
        example ``"1,0,1"``. It is parsed into a float32 vector. Rows with any
        missing field are dropped.

        Args:
            matrix_path: Path to the multi-label CSV file.

        Returns:
            Tuple containing:
                - df: DataFrame with columns id1, id2 (str) and ddi (float32 arrays)
                - num_ddi: Label-vector length (0 when no valid rows remain)

        Raises:
            FileNotFoundError: If ``matrix_path`` is not a file.
            KeyError: If a required column is missing.
            ValueError: If rows have label vectors of different lengths.
        """
        if not os.path.isfile(matrix_path):
            raise FileNotFoundError(f"未找到多标签分类文件 {matrix_path}")
        df = DataLoadingModule._select_pair_columns(pd.read_csv(matrix_path, dtype=str), matrix_path)
        # Drop missing values before parsing; previously NaN reached int('nan') and crashed.
        df = df.dropna(subset=['id1', 'id2', 'ddi'])
        ddi = df['ddi'].apply(
            lambda s: np.array([int(tok) for tok in str(s).split(',')], dtype=np.float32)
        )
        df2 = pd.DataFrame({'id1': df['id1'].astype(str), 'id2': df['id2'].astype(str), 'ddi': ddi})
        lengths = {len(v) for v in df2['ddi']}
        if len(lengths) > 1:
            raise ValueError(f"文件 {matrix_path} 中 ddi 标签长度不一致: {sorted(lengths)}")
        num_ddi = lengths.pop() if lengths else 0
        return df2[['id1', 'id2', 'ddi']], num_ddi
# =============================================================================
# Feature Processing and Noise Injection Module
# =============================================================================
class FeatureProcessingModule:
    """
    Feature processing and noise injection module.

    Builds the node feature matrix from embeddings, standardizes it per
    column, and optionally applies Gaussian noise and random feature dropout.
    """

    @staticmethod
    def _lookup_embedding(drug_id: str, id2vec: Dict[str, np.ndarray]) -> Optional[np.ndarray]:
        """Return the vector for ``drug_id``, trying exact, upper- then lower-case keys, or None."""
        for key in (drug_id, drug_id.upper(), drug_id.lower()):
            if key in id2vec:
                return id2vec[key]
        return None

    @staticmethod
    def build_feature_matrix(drug_list: List[str], id2vec: Dict[str, np.ndarray],
                             emb_dim: int, args: argparse.Namespace) -> torch.Tensor:
        """
        Build the drug feature matrix, one row per entry of ``drug_list``.

        Drugs without an embedding get an all-zero row, and their count is
        printed. The matrix is standardized column-wise. Gaussian noise and
        sparse dropout are applied when ``args.noise_std`` /
        ``args.sparse_drop_rate`` are positive.

        Args:
            drug_list: Drug IDs; row ``i`` of the result belongs to ``drug_list[i]``.
            id2vec: Dictionary mapping IDs to embedding vectors.
            emb_dim: Embedding dimension (used for zero rows and the empty case).
            args: Configuration; optional attributes ``noise_std`` and
                ``sparse_drop_rate``.

        Returns:
            float32 tensor of shape ``[len(drug_list), emb_dim]``.
        """
        feats, miss = [], 0
        for d in drug_list:
            v = FeatureProcessingModule._lookup_embedding(d, id2vec)
            if v is None:
                miss += 1
                v = np.zeros(emb_dim, dtype=np.float32)
            feats.append(v)
        if miss:
            print(f"缺少嵌入的药物数: {miss}/{len(drug_list)}")
        if not feats:
            # Keep a 2-D shape so downstream code can read x.shape[1].
            return torch.zeros((0, emb_dim), dtype=torch.float32)
        feats = np.asarray(feats, dtype=np.float32)
        feats = FeatureProcessingModule.normalize_features(feats)
        noise_std = float(getattr(args, 'noise_std', 0.0) or 0.0)
        if noise_std > 0:
            feats = FeatureProcessingModule.add_gaussian_noise(feats, noise_std)
        drop_rate = float(getattr(args, 'sparse_drop_rate', 0.0) or 0.0)
        if drop_rate > 0:
            feats = FeatureProcessingModule.sparse_dropout(feats, drop_rate)
        return torch.tensor(feats, dtype=torch.float32)

    @staticmethod
    def normalize_features(feats: np.ndarray) -> np.ndarray:
        """
        Standardize each column to zero mean and unit variance.

        The 1e-8 term prevents division by zero. A constant column therefore
        maps to all zeros.

        Args:
            feats: Feature matrix of shape ``[N, D]``.

        Returns:
            Standardized matrix. An empty input (N == 0) is returned as a
            copy, which avoids mean-of-empty warnings.
        """
        if feats.shape[0] == 0:
            return feats.copy()
        return (feats - feats.mean(axis=0)) / (feats.std(axis=0) + 1e-8)

    @staticmethod
    def add_gaussian_noise(feats: np.ndarray, noise_std: float,
                           rng: Optional[np.random.RandomState] = None) -> np.ndarray:
        """
        Add zero-mean Gaussian noise.

        Args:
            feats: Original feature matrix.
            noise_std: Noise standard deviation.
            rng: Optional seeded RandomState; the global ``np.random`` is used when None.

        Returns:
            New feature matrix with noise added.
        """
        source = np.random if rng is None else rng
        noise = source.normal(0, float(noise_std), feats.shape).astype(np.float32)
        return feats + noise

    @staticmethod
    def sparse_dropout(feats: np.ndarray, drop_rate: float,
                       rng: Optional[np.random.RandomState] = None) -> np.ndarray:
        """
        Zero each entry independently with probability ``drop_rate``.

        Surviving entries are not rescaled.

        Args:
            feats: Original feature matrix.
            drop_rate: Probability of zeroing an entry.
            rng: Optional seeded RandomState; the global ``np.random`` is used when None.

        Returns:
            New sparsified feature matrix.
        """
        source = np.random if rng is None else rng
        mask = (source.rand(*feats.shape) > drop_rate).astype(np.float32)
        return feats * mask
# =============================================================================
# Data Splitting and Sampling Module
# =============================================================================
class DataSplittingModule:
    """
    Data splitting and sampling module.

    Provides random and drug-entity-holdout splits, label-noise injection and
    per-class sparse sampling. The methods return new arrays and do not modify
    the arrays passed in.
    """

    @staticmethod
    def _drug_set(triples: np.ndarray) -> set:
        """Return the set of drug indices appearing in columns 0/1 of ``triples``."""
        if len(triples) == 0:
            return set()
        return set(np.unique(triples[:, :2].reshape(-1)).tolist())

    @staticmethod
    def _report_drug_overlap(tag: str, train: np.ndarray, val: np.ndarray, test: np.ndarray) -> None:
        """Print drug counts per split and the train/val overlap with the test drugs."""
        train_drugs = DataSplittingModule._drug_set(train)
        val_drugs = DataSplittingModule._drug_set(val)
        test_drugs = DataSplittingModule._drug_set(test)
        print(f"{tag} 训练药物数: {len(train_drugs)}, 验证药物数: {len(val_drugs)}, 测试药物数: {len(test_drugs)}")
        print(f"{tag} 训练-测试药物交集: {len(train_drugs & test_drugs)}, 验证-测试药物交集: {len(val_drugs & test_drugs)}")

    @staticmethod
    def split_data(triples: np.ndarray, val_ratio: float = 0.1,
                   test_ratio: float = 0.2, random_seed: int = 1) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Randomly split rows into test / validation / train, in that order.

        Args:
            triples: Triple data, one row per sample.
            val_ratio: Validation set ratio (of all samples, floored).
            test_ratio: Test set ratio (of all samples, floored).
            random_seed: Random seed.

        Returns:
            Tuple of (train_data, val_data, test_data).
        """
        rng = np.random.RandomState(random_seed)
        # Indexing by a permutation gives the same order as rng.shuffle(triples)
        # (same RNG draws) without shuffling the caller's array in place.
        shuffled = triples[rng.permutation(len(triples))]
        n_total = len(shuffled)
        n_test = int(n_total * test_ratio)
        n_val = int(n_total * val_ratio)
        test_data = shuffled[:n_test]
        val_data = shuffled[n_test:n_test + n_val]
        train_data = shuffled[n_test + n_val:]
        print(f"训练集: {train_data.shape}, 验证集: {val_data.shape}, 测试集: {test_data.shape}")
        return train_data, val_data, test_data

    @staticmethod
    def _entity_holdout_split_indices(triples: np.ndarray, val_ratio: float = 0.1,
                                      test_ratio: float = 0.2, random_seed: int = 1) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Compute split indices by holding out whole drugs for the test set.

        Drugs are visited in a seeded random order. Every sample that involves
        a visited drug goes to the test set, until at least
        ``ceil(N * test_ratio)`` test samples exist. Held-out drugs therefore
        never appear in train/validation. The remaining samples are split
        randomly, with ``val_ratio`` applied to them.

        Args:
            triples: Rows of (h, t, r); only h and t are used.
            val_ratio: Validation ratio of the non-test samples.
            test_ratio: Target test proportion of all samples. With 0, the test
                set is empty; any positive value selects at least one sample.
            random_seed: Random seed.

        Returns:
            Tuple of int64 index arrays (train_idx, val_idx, test_idx).
        """
        n_total = len(triples)
        if n_total == 0:
            return np.array([], dtype=np.int64), np.array([], dtype=np.int64), np.array([], dtype=np.int64)
        rng = np.random.RandomState(random_seed)
        test_ratio = float(test_ratio)
        target_test = int(np.ceil(n_total * test_ratio))
        if test_ratio > 0:
            target_test = max(1, target_test)
        # Sample indices associated with each drug.
        drug_to_indices: Dict[int, set] = defaultdict(set)
        for idx, (h, t, _) in enumerate(triples):
            drug_to_indices[int(h)].add(idx)
            drug_to_indices[int(t)].add(idx)
        drug_ids = np.array(list(drug_to_indices.keys()))
        # Always consume this shuffle so the later permutation is seed-stable.
        rng.shuffle(drug_ids)
        test_indices: set = set()
        if target_test > 0:
            for d in drug_ids:
                test_indices.update(drug_to_indices[int(d)])
                if len(test_indices) >= target_test:
                    break
        test_idx = np.array(sorted(test_indices), dtype=np.int64)
        remaining_mask = np.ones(n_total, dtype=bool)
        remaining_mask[test_idx] = False
        remaining_indices = np.nonzero(remaining_mask)[0].astype(np.int64)
        order = rng.permutation(len(remaining_indices))
        n_val = int(len(remaining_indices) * float(val_ratio))
        val_idx = remaining_indices[order[:n_val]]
        train_idx = remaining_indices[order[n_val:]]
        return train_idx, val_idx, test_idx

    @staticmethod
    def split_data_generalization(triples: np.ndarray, val_ratio: float = 0.1,
                                  test_ratio: float = 0.2, random_seed: int = 1) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Split by drug entity so that held-out test drugs are absent from train/validation.

        Prints split shapes, drug counts and drug overlaps with the test set.

        Returns:
            Tuple of (train_data, val_data, test_data).
        """
        train_idx, val_idx, test_idx = DataSplittingModule._entity_holdout_split_indices(
            triples, val_ratio, test_ratio, random_seed
        )
        train_data = triples[train_idx]
        val_data = triples[val_idx]
        test_data = triples[test_idx]
        print(f"[general] 训练集: {train_data.shape}, 验证集: {val_data.shape}, 测试集: {test_data.shape}")
        DataSplittingModule._report_drug_overlap("[general]", train_data, val_data, test_data)
        return train_data, val_data, test_data

    @staticmethod
    def split_multilabel_data(triples: np.ndarray, labels: np.ndarray, val_ratio: float = 0.1,
                              test_ratio: float = 0.2, random_seed: int = 1) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Randomly split a multi-label dataset (triples and aligned label rows).

        Returns:
            (train_triples, train_labels, val_triples, val_labels,
            test_triples, test_labels).
        """
        rng = np.random.RandomState(random_seed)
        idx = rng.permutation(len(triples))
        n_total = len(triples)
        n_test = int(n_total * test_ratio)
        n_val = int(n_total * val_ratio)
        test_idx, val_idx, train_idx = idx[:n_test], idx[n_test:n_test + n_val], idx[n_test + n_val:]
        return (triples[train_idx], labels[train_idx],
                triples[val_idx], labels[val_idx],
                triples[test_idx], labels[test_idx])

    @staticmethod
    def split_multilabel_data_generalization(triples: np.ndarray, labels: np.ndarray, val_ratio: float = 0.1,
                                             test_ratio: float = 0.2, random_seed: int = 1) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Drug-entity holdout split for multi-label data (see ``_entity_holdout_split_indices``).

        Returns:
            (train_triples, train_labels, val_triples, val_labels,
            test_triples, test_labels).
        """
        train_idx, val_idx, test_idx = DataSplittingModule._entity_holdout_split_indices(
            triples, val_ratio, test_ratio, random_seed
        )
        train_triples, train_labels = triples[train_idx], labels[train_idx]
        val_triples, val_labels = triples[val_idx], labels[val_idx]
        test_triples, test_labels = triples[test_idx], labels[test_idx]
        print(f"[general] multilabel 训练集: {train_triples.shape}, 验证集: {val_triples.shape}, 测试集: {test_triples.shape}")
        DataSplittingModule._report_drug_overlap("[general] multilabel", train_triples, val_triples, test_triples)
        return train_triples, train_labels, val_triples, val_labels, test_triples, test_labels

    @staticmethod
    def add_label_noise_multiclass(train_data: np.ndarray, num_classes: int,
                                   noise_ratio: float, random_seed: int = 1) -> np.ndarray:
        """
        Replace the label (column 2) of a random fraction of rows with a different class.

        Args:
            train_data: Training triples (h, t, label).
            num_classes: Number of classes.
            noise_ratio: Fraction of rows to corrupt (clamped to all rows).
            random_seed: Random seed.

        Returns:
            A noisy copy of ``train_data``. The input itself is returned when
            no noise is applied.
        """
        if noise_ratio <= 0 or len(train_data) == 0:
            return train_data
        if num_classes < 2:
            # There is no alternative class to flip to.
            print(f"类别数不足 2,跳过标签噪声: {num_classes}")
            return train_data
        rng = np.random.RandomState(random_seed)
        noisy = train_data.copy()
        n_train = len(noisy)
        flip_n = min(n_train, int(n_train * float(noise_ratio)))
        idx_sel = rng.choice(n_train, size=flip_n, replace=False)
        for ii in idx_sel:
            y0 = noisy[ii, 2]
            cand = [c for c in range(num_classes) if c != y0]
            noisy[ii, 2] = rng.choice(cand)
        print(f"训练集标签噪声比例: {noise_ratio}, 影响样本数: {len(idx_sel)}")
        return noisy

    @staticmethod
    def add_label_noise_multilabel(labels: np.ndarray, noise_ratio: float,
                                   flip_per_label: int = 50, random_seed: int = 1) -> np.ndarray:
        """
        Flip random 0/1 entries in a random fraction of label rows.

        Args:
            labels: Label matrix of shape ``[N, C]``.
            noise_ratio: Fraction of rows to corrupt (clamped to all rows).
            flip_per_label: Number of entries flipped per selected row
                (clamped to ``C``).
            random_seed: Random seed.

        Returns:
            A noisy copy of ``labels``. The input itself is returned when
            ``noise_ratio <= 0``.
        """
        if noise_ratio <= 0:
            return labels
        rng = np.random.RandomState(random_seed)
        noisy = np.array(labels, copy=True)
        n_train = len(noisy)
        flip_n = min(n_train, int(n_train * float(noise_ratio)))
        idx_sel = rng.choice(n_train, size=flip_n, replace=False)
        num_classes = noisy.shape[1]
        n_flip = min(int(flip_per_label), num_classes)
        for row in idx_sel:
            # `row` is already a sample index; the old code indexed idx_sel[row] by mistake.
            cols = rng.choice(num_classes, size=n_flip, replace=False)
            noisy[row, cols] = 1.0 - noisy[row, cols]
        print(f"训练集标签噪声比例: {noise_ratio}, 影响样本数: {len(idx_sel)}, 每样本翻转标签数: {n_flip}")
        return noisy

    @staticmethod
    def sparse_sampling_multiclass(train_data: np.ndarray, num_classes: int,
                                   sparse_sample_rate: float, random_seed: int = 1) -> np.ndarray:
        """
        Drop a fraction ``sparse_sample_rate`` of the samples of every class.

        Each class keeps ``floor(count * (1 - rate))`` samples, and at least
        one. Rows whose label is outside ``[0, num_classes)`` are dropped. Kept
        rows stay in their original relative order.

        Args:
            train_data: Training triples (h, t, label).
            num_classes: Number of classes.
            sparse_sample_rate: Fraction to drop; must be in (0, 1) when positive.
            random_seed: Random seed.

        Returns:
            Sampled training data (the input itself when the rate is <= 0).

        Raises:
            ValueError: If ``sparse_sample_rate`` is >= 1.
        """
        if sparse_sample_rate <= 0:
            return train_data
        if not (0.0 < sparse_sample_rate < 1.0):
            raise ValueError("sparse_sample_rate 必须在 (0, 1) 范围内")
        rng = np.random.RandomState(random_seed)
        labels = train_data[:, 2].astype(int)
        label_counts = np.bincount(labels, minlength=num_classes)
        print(f"采样前各标签频率: {label_counts}")
        keep_counts = np.maximum((label_counts * (1.0 - sparse_sample_rate)).astype(int), 1)
        kept: List[np.ndarray] = []
        for label in range(num_classes):
            label_indices = np.flatnonzero(labels == label)
            n_keep = keep_counts[label]
            if n_keep >= len(label_indices):
                kept.append(label_indices)
            else:
                kept.append(rng.choice(label_indices, size=n_keep, replace=False))
        keep_idx = np.sort(np.concatenate(kept)) if kept else np.array([], dtype=np.int64)
        train_data = train_data[keep_idx]
        print(f"采样后训练集大小: {train_data.shape}")
        print(f"采样后各标签频率: {np.bincount(train_data[:, 2].astype(int), minlength=num_classes)}")
        return train_data
# =============================================================================
# Graph Construction Module
# =============================================================================
class GraphConstructionModule:
    """
    Graph construction module.

    Builds bidirectional edge lists from training triples. The output is
    ``edge_index`` of shape ``[2, 2E]`` and ``edge_type`` of shape ``[2E]``,
    for either a multi-relation or a single-relation graph.
    """

    @staticmethod
    def _subsample_edges(edges: np.ndarray, network_ratio: float, random_seed: int) -> np.ndarray:
        """
        Keep a seeded random fraction ``network_ratio`` of the rows of ``edges``.

        Ratios outside ``(0, 1]`` are treated as 1.0. When any edges exist, at
        least one is kept.
        """
        n_edges = len(edges)
        use_ratio = float(network_ratio)
        if use_ratio <= 0 or use_ratio > 1:
            use_ratio = 1.0
        if use_ratio >= 1.0:
            print(f"[graph] using all {n_edges} edges for RGCN graph.")
            return edges
        keep = min(n_edges, int(max(1, round(n_edges * use_ratio))))
        sel = np.random.RandomState(random_seed).permutation(n_edges)[:keep]
        print(f"[graph] using edge_ratio={use_ratio} -> {keep}/{n_edges} edges for RGCN graph.")
        return edges[sel]

    @staticmethod
    def _bidirectional(src: np.ndarray, dst: np.ndarray, rel: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Emit each edge in both directions, interleaved per edge (i->j, then j->i).

        The edge order matches the original append loop. An empty input gives
        ``edge_index`` of shape ``(2, 0)``.
        """
        src = np.asarray(src).astype(np.int64)
        dst = np.asarray(dst).astype(np.int64)
        n = len(src)
        edge_index = np.empty((2, 2 * n), dtype=np.int64)
        edge_index[0, 0::2] = src
        edge_index[1, 0::2] = dst
        edge_index[0, 1::2] = dst
        edge_index[1, 1::2] = src
        edge_type = np.repeat(np.asarray(rel).astype(np.int64), 2)
        return torch.from_numpy(edge_index), torch.from_numpy(edge_type)

    @staticmethod
    def _prepare_edges(triples: np.ndarray, network_ratio: float, random_seed: int) -> np.ndarray:
        """Subsample the triples and give an empty result a 2-D ``(0, 3)`` shape so column slicing works."""
        edges = GraphConstructionModule._subsample_edges(np.asarray(triples), network_ratio, random_seed)
        if edges.size == 0:
            edges = edges.reshape(0, 3)
        return edges

    @staticmethod
    def build_multigraph(train_data: np.ndarray, network_ratio: float = 1.0,
                         random_seed: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build a multi-relation graph whose edge types are column 2 of the triples.

        Args:
            train_data: Training triples (h, t, r).
            network_ratio: Fraction of triples used as edges.
            random_seed: Seed for edge subsampling.

        Returns:
            Tuple of (edge_index ``[2, 2E]`` long, edge_type ``[2E]`` long).
        """
        edges = GraphConstructionModule._prepare_edges(train_data, network_ratio, random_seed)
        return GraphConstructionModule._bidirectional(edges[:, 0], edges[:, 1], edges[:, 2])

    @staticmethod
    def build_single_relation_graph(train_triples: np.ndarray, network_ratio: float = 1.0,
                                    random_seed: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build a single-relation graph; every edge type is 0.

        Args:
            train_triples: Training triples (h, t, _); column 2 is ignored.
            network_ratio: Fraction of triples used as edges.
            random_seed: Seed for edge subsampling.

        Returns:
            Tuple of (edge_index ``[2, 2E]`` long, edge_type ``[2E]`` zeros).
        """
        edges = GraphConstructionModule._prepare_edges(train_triples, network_ratio, random_seed)
        zeros = np.zeros(len(edges), dtype=np.int64)
        return GraphConstructionModule._bidirectional(edges[:, 0], edges[:, 1], zeros)
# =============================================================================
# DataLoader Creation Module
# =============================================================================
class DataLoaderCreationModule:
    """
    DataLoader creation module.

    Builds the shared DataLoader keyword arguments and wraps train/val/test
    arrays into loaders. Only the training loader shuffles.
    """

    @staticmethod
    def create_dataloader_config(args: argparse.Namespace) -> Dict:
        """
        Build the keyword arguments shared by every DataLoader.

        Args:
            args: Configuration; ``batch`` is required and ``workers``
                (default 0) is optional.

        Returns:
            Dict of DataLoader keyword arguments. ``prefetch_factor`` is set
            only when worker processes are used.
        """
        worker_count = int(getattr(args, 'workers', 0))
        config = dict(
            batch_size=args.batch,
            shuffle=False,
            num_workers=worker_count,
            drop_last=False,
            pin_memory=False,
            persistent_workers=False,
        )
        # DataLoader accepts prefetch_factor only together with worker processes.
        if worker_count > 0:
            config['prefetch_factor'] = 1
        return config

    @staticmethod
    def create_multiclass_dataloaders(train_data: np.ndarray, val_data: np.ndarray,
                                      test_data: np.ndarray, args: argparse.Namespace) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """
        Wrap multi-class triples into DataLoaders.

        Args:
            train_data: Training triples.
            val_data: Validation triples.
            test_data: Test triples.
            args: Configuration forwarded to ``create_dataloader_config``.

        Returns:
            Tuple of (train_loader, val_loader, test_loader).
        """
        base = DataLoaderCreationModule.create_dataloader_config(args)
        train_kwargs = dict(base, shuffle=True)
        return (
            DataLoader(BaseMultiDataset(train_data), **train_kwargs),
            DataLoader(BaseMultiDataset(val_data), **base),
            DataLoader(BaseMultiDataset(test_data), **base),
        )

    @staticmethod
    def create_multilabel_dataloaders(train_triples: np.ndarray, train_labels: np.ndarray,
                                      val_triples: np.ndarray, val_labels: np.ndarray,
                                      test_triples: np.ndarray, test_labels: np.ndarray,
                                      args: argparse.Namespace) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """
        Wrap multi-label triples and their label rows into DataLoaders.

        Args:
            train_triples, train_labels: Training split.
            val_triples, val_labels: Validation split.
            test_triples, test_labels: Test split.
            args: Configuration forwarded to ``create_dataloader_config``.

        Returns:
            Tuple of (train_loader, val_loader, test_loader).
        """
        base = DataLoaderCreationModule.create_dataloader_config(args)
        train_kwargs = dict(base, shuffle=True)
        return (
            DataLoader(BaseMultiLabelDataset(train_triples, train_labels), **train_kwargs),
            DataLoader(BaseMultiLabelDataset(val_triples, val_labels), **base),
            DataLoader(BaseMultiLabelDataset(test_triples, test_labels), **base),
        )
# =============================================================================
# Base Dataset Classes
# =============================================================================
class BaseMultiDataset(Dataset):
    """
    Multi-class dataset over a triple array.

    Item ``i`` is ``(triple[i, 0], triple[i, 1], triple[i, 2])``.
    """

    def __init__(self, triple: np.ndarray):
        # Keep each column as a separate view so items can be picked by index.
        self.entity1, self.entity2, self.relationtype = (triple[:, col] for col in range(3))

    def __len__(self):
        return len(self.relationtype)

    def __getitem__(self, index):
        head = self.entity1[index]
        tail = self.entity2[index]
        rel = self.relationtype[index]
        return (head, tail, rel)
class BaseMultiLabelDataset(Dataset):
    """
    Multi-label dataset over a triple array plus optional label rows.

    With ``labels``, item ``i`` is ``(triple[i, 0], triple[i, 1], labels[i])``.
    Without labels, the third element is the constant ``0``. In that case
    ``relationtype`` still stores ``triple[:, 2]``, but ``__getitem__`` does
    not return it.
    """

    def __init__(self, triple: np.ndarray, labels: Optional[np.ndarray] = None):
        self.entity1 = triple[:, 0]
        self.entity2 = triple[:, 1]
        self.labels = labels
        if labels is None:
            self.relationtype = triple[:, 2]
        else:
            # With explicit labels the relation column is replaced by zeros.
            self.relationtype = np.zeros(len(triple), dtype=np.int64)

    def __len__(self):
        return len(self.entity1)

    def __getitem__(self, index):
        target = 0 if self.labels is None else self.labels[index]
        return (self.entity1[index], self.entity2[index], target)
# =============================================================================
# Main BaseDataset Class
# =============================================================================
class BaseDataset:
    """
    Base dataset class integrating all functional modules.

    Provides one data-processing interface for multi-class and multi-label
    tasks. After ``load_data`` the instance exposes ``data_o`` (node features
    plus the training graph) and the train/val/test DataLoaders. It also sets
    ``args.num_classes`` and ``args.dimensions``.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.data_o: Data = None
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None
        # Functional modules (classes of static methods).
        self.data_loader = DataLoadingModule()
        self.feature_processor = FeatureProcessingModule()
        self.data_splitter = DataSplittingModule()
        self.graph_builder = GraphConstructionModule()
        self.dataloader_creator = DataLoaderCreationModule()

    def _seed(self):
        """Return ``args.seed`` (default 1); used for every random step of the pipeline."""
        return getattr(self.args, 'seed', 1)

    @staticmethod
    def _drug_list(pairs_df: pd.DataFrame) -> List[str]:
        """Sorted union of the IDs in ``id1``/``id2``; a drug's position is its node index."""
        return sorted(set(pairs_df['id1']).union(set(pairs_df['id2'])))

    def _build_node_features(self, pairs_df: pd.DataFrame, id2vec: Dict[str, np.ndarray],
                             emb_dim: int) -> Tuple[torch.Tensor, Dict[str, int]]:
        """Build the node feature matrix and the ID->index map, and set ``args.dimensions``."""
        drug_list = self._drug_list(pairs_df)
        x = self.feature_processor.build_feature_matrix(drug_list, id2vec, emb_dim, self.args)
        self.args.dimensions = int(x.shape[1])
        return x, {d: i for i, d in enumerate(drug_list)}

    def load_data(self, val_ratio: float = 0.1, test_ratio: float = 0.2):
        """
        Main entry point for loading data.

        ``args.matrix`` of 'multilabel' or 'twosides' selects the multi-label
        pipeline; any other value selects the multi-class one.

        Args:
            val_ratio: Validation set ratio.
            test_ratio: Test set ratio.
        """
        if self.args.matrix in ['multilabel', 'twosides']:
            self._load_multilabel_data(val_ratio, test_ratio)
        else:
            self._load_multiclass_data(val_ratio, test_ratio)

    def _load_multiclass_data(self, val_ratio: float = 0.1, test_ratio: float = 0.2):
        """Load multi-class data, split it, and build loaders plus the relation graph."""
        print("=== 开始加载多分类数据 ===")
        seed = self._seed()
        # 1. Read node embeddings
        id2vec, emb_dim = self.data_loader.read_id_embedding_pt(self.args.embedding_path)
        print(f"嵌入维数: {emb_dim}")
        # 2. Read pairs and remap labels to 0..C-1
        pairs_df, num_relations = self.data_loader.read_multi_pairs_and_remap(self.args.matrix_path)
        print(f"DDI 类型数量: {num_relations}")
        self.args.num_classes = int(num_relations)
        # 3. Build feature matrix
        x, drug_id_to_index = self._build_node_features(pairs_df, id2vec, emb_dim)
        # 4. Build (h, t, r) triples; reshape keeps an empty file 2-D for the splitters.
        triples = np.asarray(
            [(drug_id_to_index[h], drug_id_to_index[t], int(r))
             for h, t, r in zip(pairs_df['id1'], pairs_df['id2'], pairs_df['ddi'])],
            dtype=np.int64
        ).reshape(-1, 3)
        # 5. Split (entity holdout by default)
        if getattr(self.args, 'general', True):
            train_data, val_data, test_data = self.data_splitter.split_data_generalization(
                triples, val_ratio, test_ratio, seed
            )
        else:
            train_data, val_data, test_data = self.data_splitter.split_data(
                triples, val_ratio, test_ratio, seed
            )
        # 6. Label noise; args.seed is forwarded (previously it was always 1).
        if getattr(self.args, 'noise_ratio', 0.0) > 0:
            train_data = self.data_splitter.add_label_noise_multiclass(
                train_data, self.args.num_classes, self.args.noise_ratio, seed
            )
        # 7. Sparse sampling
        if getattr(self.args, 'sparse_sample_rate', 0.0) > 0:
            train_data = self.data_splitter.sparse_sampling_multiclass(
                train_data, self.args.num_classes, self.args.sparse_sample_rate, seed
            )
        # 8. DataLoaders
        self.train_loader, self.val_loader, self.test_loader = self.dataloader_creator.create_multiclass_dataloaders(
            train_data, val_data, test_data, self.args
        )
        # 9. Graph from the (possibly noised/sampled) training triples only
        edge_index, edge_type = self.graph_builder.build_multigraph(
            train_data, getattr(self.args, 'network_ratio', 1.0), seed
        )
        self.data_o = Data(x=x, edge_index=edge_index, edge_type=edge_type)
        print("=== 多分类数据加载完成 ===")

    def _load_multilabel_data(self, val_ratio: float = 0.1, test_ratio: float = 0.2):
        """Load multi-label data, split it, and build loaders plus a single-relation graph."""
        print("=== 开始加载多标签数据 ===")
        seed = self._seed()
        # 1. Read node embeddings
        id2vec, emb_dim = self.data_loader.read_id_embedding_pt(self.args.embedding_path)
        print(f"嵌入维数: {emb_dim}")
        # 2. Read pairs and label vectors
        pairs_df, num_relations = self.data_loader.read_multilabel_pairs_and_remap(self.args.matrix_path)
        print(f"DDI 类型数量: {num_relations}")
        self.args.num_classes = int(num_relations)
        # 3. Build feature matrix
        x, drug_id_to_index = self._build_node_features(pairs_df, id2vec, emb_dim)
        # 4. Triples with a dummy relation 0 and a separate label matrix [N, C]
        triples = np.asarray(
            [(drug_id_to_index[h], drug_id_to_index[t], 0)
             for h, t in zip(pairs_df['id1'], pairs_df['id2'])],
            dtype=np.int64
        ).reshape(-1, 3)
        labels = np.stack(pairs_df['ddi'].values).astype(np.float32)
        # 5. Split (entity holdout by default)
        if getattr(self.args, 'general', True):
            train_triples, train_labels, val_triples, val_labels, test_triples, test_labels = self.data_splitter.split_multilabel_data_generalization(
                triples, labels, val_ratio, test_ratio, seed
            )
        else:
            train_triples, train_labels, val_triples, val_labels, test_triples, test_labels = self.data_splitter.split_multilabel_data(
                triples, labels, val_ratio, test_ratio, seed
            )
        # 6. Label noise; args.seed is forwarded (previously it was always 1).
        if getattr(self.args, 'noise_ratio', 0.0) > 0:
            train_labels = self.data_splitter.add_label_noise_multilabel(
                train_labels, self.args.noise_ratio, getattr(self.args, 'flip_per_label', 50), seed
            )
        # 7. DataLoaders
        self.train_loader, self.val_loader, self.test_loader = self.dataloader_creator.create_multilabel_dataloaders(
            train_triples, train_labels, val_triples, val_labels, test_triples, test_labels, self.args
        )
        # 8. Single-relation graph from the training triples
        edge_index, edge_type = self.graph_builder.build_single_relation_graph(
            train_triples, getattr(self.args, 'network_ratio', 1.0), seed
        )
        self.data_o = Data(x=x, edge_index=edge_index, edge_type=edge_type)
        print("=== 多标签数据加载完成 ===")

    @staticmethod
    def _loader_size(loader) -> int:
        """Number of samples in ``loader.dataset``, or 0 when no loader exists."""
        return len(loader.dataset) if loader is not None else 0

    def get_data_stats(self) -> Dict:
        """Return node, edge, feature, class and split-size counts (0 for anything not loaded)."""
        data = self.data_o
        return {
            'num_nodes': data.x.shape[0] if data is not None else 0,
            'num_edges': data.edge_index.shape[1] if data is not None else 0,
            'feature_dim': getattr(self.args, 'dimensions', 0),
            'num_classes': getattr(self.args, 'num_classes', 0),
            'train_size': self._loader_size(self.train_loader),
            'val_size': self._loader_size(self.val_loader),
            'test_size': self._loader_size(self.test_loader),
        }

    # -------------------------------------------------------------------------
    # TIGER reuse: Returns the original ID pairs and label assignments, used to directly construct x and y.
    # -------------------------------------------------------------------------
    def build_pairs_labels_splits(self, val_ratio: float = 0.1, test_ratio: float = 0.2,
                                  random_seed: Optional[int] = None,
                                  return_original_ids: bool = True,
                                  entity_holdout: bool = False) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
        """
        Read ``args.matrix_path`` and return (pairs, labels) for train/val/test.

        By default the split is random. For the same seed and ratios it
        matches ``DataSplittingModule.split_data`` / ``split_multilabel_data``,
        i.e. the ``args.general == False`` pipeline. Pass ``entity_holdout=True``
        to get the drug-entity holdout split used when ``args.general`` is
        True; ``load_data`` uses that split by default.

        - Multi-class tasks: labels are int64 of shape [N]
        - Multi-label tasks (``args.matrix`` in 'multilabel'/'twosides'):
          labels are float32 of shape [N, C]

        Args:
            val_ratio: Validation set ratio.
            test_ratio: Test set ratio.
            random_seed: Random seed (defaults to ``args.seed`` or 1).
            return_original_ids: Return string ID pairs when True, otherwise
                index pairs based on the sorted drug list.
            entity_holdout: Use the drug-entity holdout split instead of the
                random split.

        Returns:
            dict: {'train': (pairs, labels), 'val': (pairs, labels), 'test': (pairs, labels)}
        """
        seed = int(getattr(self.args, 'seed', 1)) if random_seed is None else int(random_seed)
        # Read pairs and labels with the parser that matches the task type.
        if self.args.matrix in ['multilabel', 'twosides']:
            pairs_df, _ = self.data_loader.read_multilabel_pairs_and_remap(self.args.matrix_path)
            labels_all = np.stack(pairs_df['ddi'].values).astype(np.float32)
        else:
            pairs_df, _ = self.data_loader.read_multi_pairs_and_remap(self.args.matrix_path)
            labels_all = pairs_df['ddi'].to_numpy(dtype=np.int64)
        pairs_orig = pairs_df[['id1', 'id2']].to_numpy(dtype=object)
        pairs_idx = None
        if entity_holdout or not return_original_ids:
            # Same indexing as load_data: position in the sorted drug list.
            drug_id_to_index = {d: i for i, d in enumerate(self._drug_list(pairs_df))}
            pairs_idx = np.asarray([(drug_id_to_index[h], drug_id_to_index[t])
                                    for h, t in pairs_orig], dtype=np.int64).reshape(-1, 2)
        if entity_holdout:
            # The holdout only reads (h, t); the relation column is a placeholder.
            triples = np.concatenate([pairs_idx, np.zeros((len(pairs_idx), 1), dtype=np.int64)], axis=1)
            train_idx, val_idx, test_idx = self.data_splitter._entity_holdout_split_indices(
                triples, val_ratio, test_ratio, seed
            )
        else:
            rng = np.random.RandomState(seed)
            n_total = len(pairs_df)
            perm = rng.permutation(n_total)
            n_test = int(n_total * float(test_ratio))
            n_val = int(n_total * float(val_ratio))
            test_idx = perm[:n_test]
            val_idx = perm[n_test:n_test + n_val]
            train_idx = perm[n_test + n_val:]
        source = pairs_orig if return_original_ids else pairs_idx
        return {
            'train': (source[train_idx], labels_all[train_idx]),
            'val': (source[val_idx], labels_all[val_idx]),
            'test': (source[test_idx], labels_all[test_idx]),
        }