Source code for openddi.data.Unified_dataset

import argparse
import numpy as np
from torch_geometric.data import Data
from data.BaseDataset import BaseDataset

class UnifiedDataset(BaseDataset):
    """Unified dataset class built on top of ``BaseDataset``.

    This subclass adds only a loading banner, an optional statistics
    printout and a noise-configuration summary. The actual loading is
    delegated to ``BaseDataset.load_data``.

    Features advertised for this dataset:

    - Unified reading of id -> embedding (supports multi-modal
      concatenation, see ``args.embedding_path``/``embedding_dir`` and
      ``--modality``).
    - Multi-class: uses real relationship types as ``edge_type`` (used by
      RGCN).
    - Multi-label: uses a single relationship graph (``edge_type=0``).
    - Supports feature Gaussian noise (``noise_std``) and label flip noise
      (``noise_ratio``).
    - DataLoader: ``pin_memory=False``, ``persistent_workers=False``;
      ``prefetch_factor=1`` when ``workers > 0``.
    - Supports sparse sampling (``sparse_sample_rate``) and sparse dropping
      (``sparse_drop_rate``).
    """

    def __init__(self, args: argparse.Namespace):
        """Initialize the unified dataset.

        Args:
            args: Namespace carrying the dataset configuration:

                - matrix: data type (``'multilabel'``, ``'twosides'`` or
                  another multi-class type).
                - embedding_path: embedding file path.
                - matrix_path: matrix data file path.
                - batch: batch size.
                - workers: number of worker processes (optional).
                - noise_std: feature Gaussian-noise std (optional).
                - noise_ratio: label-noise ratio (optional).
                - sparse_sample_rate: sparse sampling rate (optional).
                - sparse_drop_rate: sparse drop rate (optional).
                - network_ratio: fraction of graph edges used (optional).
                - flip_per_label: multi-label flip bits (optional,
                  default 50).
        """
        super().__init__(args)

    def load_data(
        self,
        val_ratio: float = 0.1,
        test_ratio: float = 0.2,
        *,
        verbose: bool = False,
    ) -> None:
        """Load the data and build the train/validation/test splits.

        Prints a short banner describing the configured inputs. It then
        delegates to ``BaseDataset.load_data``, which chooses the loading
        routine based on ``args.matrix``.

        Args:
            val_ratio: Fraction of samples used for validation.
                Defaults to 0.1.
            test_ratio: Fraction of samples used for testing.
                Defaults to 0.2.
            verbose: When True, also print dataset statistics after loading.
                Keyword-only. Defaults to False, which matches the previous
                behaviour where the statistics call was commented out.
        """
        print("=== Unified Dataset Loading ===")
        # Banner labels are Chinese: "data type", "embedding path",
        # "matrix path".
        print(f"数据类型: {self.args.matrix}")
        print(f"嵌入路径: {self.args.embedding_path}")
        print(f"矩阵路径: {self.args.matrix_path}")
        # The parent picks the concrete loader from ``args.matrix``.
        super().load_data(val_ratio, test_ratio)
        if verbose:
            self._print_dataset_info()

    def _print_dataset_info(self) -> None:
        """Print the statistics returned by ``self.get_data_stats()``.

        ``get_data_stats`` is not defined in this class (presumably it is
        provided by ``BaseDataset`` — verify). It must return a mapping that
        contains every key listed in ``rows`` below.
        """
        stats = self.get_data_stats()
        # (Chinese label, stats key). Labels in order: node count, edge
        # count, feature dimension, class count, train/val/test set sizes.
        rows = (
            ("节点数量", "num_nodes"),
            ("边数量", "num_edges"),
            ("特征维度", "feature_dim"),
            ("类别数量", "num_classes"),
            ("训练集大小", "train_size"),
            ("验证集大小", "val_size"),
            ("测试集大小", "test_size"),
        )
        # Header label: "dataset statistics".
        print("\n=== 数据集统计信息 ===")
        for label, key in rows:
            print(f"{label}: {stats[key]}")
        print("========================\n")

    def get_noise_config(self) -> dict:
        """Return the noise/sparsity settings read from ``self.args``.

        Optional attributes missing from ``args`` fall back to these
        defaults:

        - 0.0 for ``noise_std``, ``noise_ratio``, ``sparse_drop_rate`` and
          ``sparse_sample_rate``;
        - 1.0 for ``network_ratio``;
        - 50 for ``flip_per_label``.

        Returns:
            dict: Contains the keys ``'feature_noise_std'``,
            ``'label_noise_ratio'``, ``'sparse_drop_rate'``,
            ``'sparse_sample_rate'`` and ``'network_ratio'``. It also
            contains ``'flip_per_label'``, but only when ``args.matrix`` is
            ``'multilabel'`` or ``'twosides'``.
        """
        noise_config = {
            'feature_noise_std': getattr(self.args, 'noise_std', 0.0),
            'label_noise_ratio': getattr(self.args, 'noise_ratio', 0.0),
            'sparse_drop_rate': getattr(self.args, 'sparse_drop_rate', 0.0),
            'sparse_sample_rate': getattr(self.args, 'sparse_sample_rate', 0.0),
            'network_ratio': getattr(self.args, 'network_ratio', 1.0),
        }
        # Label-bit flipping only applies to the multi-label data types.
        if self.args.matrix in ('multilabel', 'twosides'):
            noise_config['flip_per_label'] = getattr(self.args, 'flip_per_label', 50)
        return noise_config
# Backward compatibility: the pre-rename public class name is kept importable.
class Unified_dataset(UnifiedDataset):
    """Backward-compatible name for ``UnifiedDataset``.

    Implemented as an empty subclass, so it adds no behaviour of its own.
    """