Source code for openddi.data.MUFFIN_dataset

from data.BaseDataset import BaseDataset
import os
import argparse
import gc
import torch
from tqdm import tqdm
import numpy as np
import json
import copy
from utils import *

from torch.utils.data import Dataset, DataLoader

import random
import numpy as np

class MuffinDataset(Dataset):
    """
    Map-style dataset over drug pairs and their labels for the MUFFIN model.

    Args:
        pairs: Indexable collection of drug pairs; entry ``i`` holds (u_id, v_id).
        labels: Indexable collection of labels; entry ``i`` belongs to pair ``i``.
    """
    def __init__(self, pairs, labels):
        self.pairs = pairs
        self.labels = labels

    def __len__(self):
        """
        Number of samples, defined by the length of the pair collection.
        """
        return len(self.pairs)

    def __getitem__(self, idx):
        """
        Fetch one sample.

        Args:
            idx: Position of the sample.

        Returns:
            tuple: (u_id, v_id, label) stored at position ``idx``.
        """
        pair = self.pairs[idx]
        u_id, v_id = pair[0], pair[1]
        return u_id, v_id, self.labels[idx]

def make_collate_fn(entity_embed, structure_embed, task_type='multiclass', device='cpu'):
    """
    Build the collate function used to batch MUFFIN samples.

    Args:
        entity_embed: Entity embedding matrix (accepted but not used by the
            returned function).
        structure_embed: Structure embedding matrix (accepted but not used by
            the returned function).
        task_type: 'multiclass' yields int64 labels; any other value yields
            float32 labels.
        device: Target device (accepted but not used by the returned function;
            the produced tensors are not moved).

    Returns:
        function: Collate function suitable for ``DataLoader(collate_fn=...)``.
    """
    # The embeddings are deliberately left out of the batches: the collate
    # function emits None placeholders so that large matrices are not pickled
    # when DataLoader workers are used. The Trainer is expected to inject the
    # (on-device) embeddings into each batch afterwards.
    label_dtype = torch.long if task_type == 'multiclass' else torch.float32

    def _collate(samples):
        """
        Turn a list of samples into one MUFFIN batch.

        Args:
            samples: List of (u, v, label) tuples.

        Returns:
            list: [None, None, edge_index, labels], the layout expected by
            MUFFIN.forward (batch_data = [entity_embed, structure_embed,
            edge_index, labels]).
        """
        heads, tails, targets = zip(*samples)

        # edge_index has shape [2, batch_size]; row 0 holds the u indices and
        # row 1 the v indices, both pointing into the embedding matrices.
        edge_index = torch.tensor([heads, tails], dtype=torch.long)
        labels = torch.tensor(targets, dtype=label_dtype)

        return [None, None, edge_index, labels]
    return _collate

class MUFFIN_dataset(BaseDataset):
    """
    MUFFIN dataset class for multi-modal DDI prediction.

    Features:
        - Dual embedding support (entity and structure embeddings)
        - Pre-trained embedding integration
        - Supports both multiclass and multilabel classification
    """

    def __init__(self, args: argparse.ArgumentParser):
        """
        Initialize MUFFIN dataset.

        Args:
            args: Argument parser with configuration parameters.
        """
        super().__init__(args)
        self.args = args
        # Embedding widths and matrices are filled in by load_data().
        self.entity_dim = 0
        self.structure_dim = 0
        self.entity_pre_embed = None
        self.structure_pre_embed = None

    def load_data(self, val_ratio: float = 0.1, test_ratio: float = 0.2):
        """
        Main data loading method.

        Loads the pre-trained entity and structure embeddings, obtains the
        train/val/test splits, aligns both embedding sets to one shared index
        space and builds the three DataLoaders. On return the attributes
        ``entity_dim``, ``structure_dim``, ``entity_pre_embed``,
        ``structure_pre_embed``, ``train_loader``, ``val_loader`` and
        ``test_loader`` are set.

        Args:
            val_ratio: Validation set ratio.
            test_ratio: Test set ratio.
        """
        super().load_data()

        # 1. Load pretrained embeddings from the .pt files (id -> embedding).
        entity_path = self.args.embedding_map['drkg']
        structure_path = self.args.embedding_map['smiles']

        print(f"Loading entity embeddings from {entity_path}")
        # read_id_embedding_pt returns (id2vec_dict, dim)
        entity_id2vec, self.entity_dim = self.data_loader.read_id_embedding_pt(entity_path)

        print(f"Loading structure embeddings from {structure_path}")
        structure_id2vec, self.structure_dim = self.data_loader.read_id_embedding_pt(structure_path)

        # 2. Get the data splits; return_original_ids=True requests the
        #    original (string) drug IDs rather than indices.
        splits = self.build_pairs_labels_splits(
            val_ratio=val_ratio,
            test_ratio=test_ratio,
            return_original_ids=True,
        )

        # 3. Build the ID -> index mapping from every drug ID that occurs in
        #    any split, so both embedding matrices share one index space.
        all_pairs = np.concatenate([splits['train'][0], splits['val'][0], splits['test'][0]])
        unique_ids = sorted(set(all_pairs.flatten()))
        id2idx = {id_str: i for i, id_str in enumerate(unique_ids)}
        num_drugs = len(unique_ids)

        def build_embed_matrix(id2vec, dim):
            """
            Stack per-ID vectors into a (num_drugs, dim) float32 matrix.

            IDs without a pre-trained vector keep their all-zero row.

            Returns:
                tuple: (matrix, number_of_ids_without_a_vector).
            """
            matrix = torch.zeros((num_drugs, dim), dtype=torch.float32)
            missing = 0
            for id_str, idx in id2idx.items():
                if id_str in id2vec:
                    matrix[idx] = torch.from_numpy(id2vec[id_str])
                else:
                    missing += 1
            return matrix, missing

        self.entity_pre_embed, missing_entity = build_embed_matrix(entity_id2vec, self.entity_dim)
        self.structure_pre_embed, missing_structure = build_embed_matrix(structure_id2vec, self.structure_dim)

        # Surface coverage gaps: drugs without a vector are fed to the model
        # as zero rows, which is worth knowing about.
        print(f"Drugs without entity embedding: {missing_entity}/{num_drugs}")
        print(f"Drugs without structure embedding: {missing_structure}/{num_drugs}")

        # 4. Convert ID pairs to index pairs.
        def process_split(split_name):
            """
            Convert string ID pairs to index pairs for a given split.

            Args:
                split_name: Name of the split ('train', 'val', 'test').

            Returns:
                tuple: (index_pairs, labels) for the split.
            """
            pairs, labels = splits[split_name]
            idx_pairs = []
            valid_labels = []
            for (u, v), label in zip(pairs, labels):
                # Defensive filter: keep only pairs whose both IDs are mapped.
                if u in id2idx and v in id2idx:
                    idx_pairs.append([id2idx[u], id2idx[v]])
                    valid_labels.append(label)
            return np.array(idx_pairs), np.array(valid_labels)

        train_pairs, train_labels = process_split('train')
        val_pairs, val_labels = process_split('val')
        test_pairs, test_labels = process_split('test')

        # 5. Build DataLoaders.
        dt_flag = 'multilabel' if self.args.matrix in ['multilabel', 'twosides'] else 'multiclass'
        collate_fn = make_collate_fn(
            self.entity_pre_embed,
            self.structure_pre_embed,
            task_type=dt_flag,
            device=self.args.device,
        )

        num_workers = int(getattr(self.args, 'workers', 0))
        loader_kwargs = {
            'batch_size': self.args.batch,
            'num_workers': num_workers,
            'collate_fn': collate_fn,
            'drop_last': False,
        }
        # prefetch_factor may only be passed to DataLoader when worker
        # processes are in use.
        if num_workers > 0:
            loader_kwargs['prefetch_factor'] = 1

        # Only the training data is shuffled; validation and test keep a
        # fixed order so evaluation passes are reproducible.
        self.train_loader = DataLoader(
            MuffinDataset(train_pairs, train_labels), shuffle=True, **loader_kwargs
        )
        self.val_loader = DataLoader(
            MuffinDataset(val_pairs, val_labels), shuffle=False, **loader_kwargs
        )
        self.test_loader = DataLoader(
            MuffinDataset(test_pairs, test_labels), shuffle=False, **loader_kwargs
        )