Source code for openddi.data.MVA_dataset

import numpy as np
import torch
import pandas as pd
from rdkit import Chem
from rdkit.Chem.rdchem import BondType
from torch.utils import data
from subword_nmt.apply_bpe import BPE
import codecs
import numpy as np
from rdkit import Chem
from torch.utils.data import Dataset, DataLoader
from data.BaseDataset import BaseDataset
import argparse
import os
from utils.config import CODE_DIR

# Directory (CODE_DIR/../datasets/ESPF) from which MolGraphDataset reads the
# BPE codes file and the sub-word index CSV.
ESPF_path = os.path.join(CODE_DIR,'..','datasets','ESPF')

def get_intervals(l):
    """
    Compute the positional weights used to fold feature indices into one id.

    The first weight is 1. Each later weight ``k`` is the previous weight
    multiplied by ``len(l[k]) + 1``.

    Args:
        l: List of lists (one reference list per feature).

    Returns:
        list: One integer weight per entry of ``l``; an empty list when ``l``
        is empty.
    """
    if len(l) == 0:
        # No features means no weights; indexing element 0 would fail here.
        return []
    intervals = [1]
    for reference in l[1:]:
        intervals.append((len(reference) + 1) * intervals[-1])
    return intervals

# Reference vocabularies for the integer atom-id encoding. The position of a
# value inside its list is the encoded feature index, so the order of the
# entries below must not be changed.
possible_atom_list = [
    'C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Mg', 'Na', 'Br', 'Fe', 'Ca', 'Cu',
    'Mc', 'Pd', 'Pb', 'K', 'I', 'Al', 'Ni', 'Mn'
]
possible_numH_list = [0, 1, 2, 3, 4]
possible_valence_list = [0, 1, 2, 3, 4, 5, 6]
possible_formal_charge_list = [-3, -2, -1, 0, 1, 2, 3]
possible_hybridization_list = [
    Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2
]
possible_number_radical_e_list = [0, 1, 2]
possible_chirality_list = ['R', 'S']

# NOTE(review): seven lists are registered here, but get_feature_list only
# fills six features (chirality is not extracted) -- confirm whether
# possible_chirality_list is meant to take part in the id computation.
reference_lists = [
    possible_atom_list, possible_numH_list, possible_valence_list,
    possible_formal_charge_list, possible_number_radical_e_list,
    possible_hybridization_list, possible_chirality_list
]

# Per-feature weights consumed by features_to_id; computed once at import time.
intervals = get_intervals(reference_lists)

def safe_index(l, e):
    """
    Gets the index of e in l, providing an index of len(l) if not found.

    Args:
        l: List to search in.
        e: Element to find.

    Returns:
        int: Index of element or len(l) if not found.
    """
    try:
        return l.index(e)
    except ValueError:
        # list.index signals "not present" with ValueError; any other error
        # (e.g. ``l`` is not a sequence) is a real bug and should surface.
        return len(l)

def get_feature_list(atom):
    """
    Map an atom's properties to positions in the module-level reference lists.

    Six properties are looked up, in this order: element symbol, total
    hydrogen count, implicit valence, formal charge, radical electron count
    and hybridization. A value missing from its reference list is encoded as
    the length of that list (see ``safe_index``).

    Args:
        atom: RDKit atom object.

    Returns:
        list: Six integer feature indices for the atom.
    """
    lookups = (
        (possible_atom_list, atom.GetSymbol()),
        (possible_numH_list, atom.GetTotalNumHs()),
        (possible_valence_list, atom.GetImplicitValence()),
        (possible_formal_charge_list, atom.GetFormalCharge()),
        (possible_number_radical_e_list, atom.GetNumRadicalElectrons()),
        (possible_hybridization_list, atom.GetHybridization()),
    )
    return [safe_index(reference, value) for reference, value in lookups]

def features_to_id(features, intervals):
    """
    Convert list of features into index using spacings provided in intervals.

    Features and intervals are paired position by position. When the two
    sequences differ in length, only the positions present in both take part,
    so a feature list shorter than ``intervals`` no longer raises IndexError.

    Args:
        features: List of feature indices.
        intervals: Per-feature weights (see ``get_intervals``).

    Returns:
        int: Weighted sum of the features plus one.
    """
    weighted_sum = sum(
        feature * weight for feature, weight in zip(features, intervals))
    # Shift by one so that 0 stays reserved for the "null" entry.
    return weighted_sum + 1

def atom_to_id(atom):
    """
    Encode an atom as a single integer id.

    The atom's feature indices (``get_feature_list``) are combined with the
    module-level ``intervals`` weights by ``features_to_id``.

    Args:
        atom: RDKit atom object.

    Returns:
        int: Integer id for the atom.
    """
    return features_to_id(get_feature_list(atom), intervals)

def one_of_k_encoding(x, allowable_set):
    """
    Strict one-hot encoding: ``x`` must be one of the allowed values.

    Args:
        x: Value to encode.
        allowable_set: Sequence of allowable values.

    Returns:
        list: One boolean per allowed value, True where it equals ``x``.

    Raises:
        Exception: If ``x`` is not contained in ``allowable_set``.
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise Exception("input {0} not in allowable set{1}:".format(
        x, allowable_set))

def one_of_k_encoding_unk(x, allowable_set):
    """
    Lenient one-hot encoding: unknown values fall back to the last entry.

    Args:
        x: Value to encode.
        allowable_set: Sequence of allowable values; its last element acts as
            the catch-all for any value not listed.

    Returns:
        list: One boolean per allowed value, True where it equals the
        (possibly substituted) value.
    """
    effective = x if x in allowable_set else allowable_set[-1]
    return [effective == candidate for candidate in allowable_set]

def atom_features(atom,
                  bool_id_feat=False,
                  explicit_H=False,
                  use_chirality=False):
    """
    Build the feature vector for a single atom.

    With the default flags the vector has 75 entries: a 44-way element
    one-hot (last slot 'Unknown'), an 11-way degree one-hot, a 7-way implicit
    valence one-hot, formal charge, radical electron count, a 5-way
    hybridization one-hot, the aromaticity flag and a 5-way total-hydrogen
    one-hot.

    Args:
        atom: RDKit atom object.
        bool_id_feat: If True, return a one-element array holding
            ``atom_to_id(atom)`` instead of the feature vector.
        explicit_H: If True, the total-hydrogen one-hot is left out.
        use_chirality: If True, append an R/S one-hot and the
            '_ChiralityPossible' flag (three extra entries).

    Returns:
        numpy.ndarray: 1-D array of atom features.

    Raises:
        Exception: If the atom degree is outside 0..10 (strict one-hot).
    """
    if bool_id_feat:
        return np.array([atom_to_id(atom)])

    symbol_part = one_of_k_encoding_unk(
        atom.GetSymbol(),
        [
            'C','N','O','S','F','Si','P','Cl','Br','Mg','Na','Ca','Fe','As','Al','I','B','V','K',
            'Tl','Yb','Sb','Sn','Ag','Pd','Co','Se','Ti','Zn','H','Li','Ge','Cu','Au','Ni','Cd',
            'In','Mn','Zr','Cr','Pt','Hg','Pb','Unknown'
        ])
    # Degree uses the strict encoder: an out-of-range degree is an error.
    degree_part = one_of_k_encoding(atom.GetDegree(), [0,1,2,3,4,5,6,7,8,9,10])
    valence_part = one_of_k_encoding_unk(atom.GetImplicitValence(), [0,1,2,3,4,5,6])
    hybridization_part = one_of_k_encoding_unk(atom.GetHybridization(), [
        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2
    ])

    results = symbol_part + degree_part + valence_part + \
        [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] + \
        hybridization_part + [atom.GetIsAromatic()]

    if not explicit_H:
        results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(), [0,1,2,3,4])

    if use_chirality:
        try:
            cip_code = atom.GetProp('_CIPCode')
        except KeyError:
            # No CIP label assigned to this atom: neither R nor S.
            chirality_part = [False, False]
        else:
            chirality_part = one_of_k_encoding_unk(cip_code, ['R','S'])
        results = results + chirality_part + [atom.HasProp('_ChiralityPossible')]

    return np.array(results)

def drug2emb_encoder(x, dbpe, words2idx_d, max_d=50):
    """
    Encode SMILES string to BPE embedding indices.

    The string is split into sub-words by ``dbpe``, each sub-word is mapped
    to its index, and the result is zero-padded or truncated to ``max_d``.

    Args:
        x: SMILES string.
        dbpe: BPE encoder object exposing ``process_line``.
        words2idx_d: Dictionary mapping words to indices.
        max_d: Fixed output length (defaults to 50, the previously
            hard-coded value).

    Returns:
        tuple: (indices, input_mask), both arrays of length ``max_d``; the
        mask is 1 on real tokens and 0 on padding.
    """
    tokens = dbpe.process_line(x).split()
    try:
        ids = np.asarray([words2idx_d[token] for token in tokens])
    except KeyError:
        # A single out-of-vocabulary sub-word collapses the whole sequence
        # to the one-element placeholder [1].
        ids = np.array([1])

    n_tokens = len(ids)
    if n_tokens < max_d:
        padding = max_d - n_tokens
        indices = np.pad(ids, (0, padding), 'constant', constant_values=0)
        input_mask = ([1] * n_tokens) + ([0] * padding)
    else:
        indices = ids[:max_d]
        input_mask = [1] * max_d
    return indices, np.asarray(input_mask)

class MolGraphDataset(Dataset):
    """
    Dataset class for molecular graph data with SMILES strings.

    Each item pairs two drugs and yields their graph features, adjacency
    matrices, BPE-encoded SMILES and the interaction label.

    Args:
        data_df: DataFrame containing drug pairs and labels; the columns
            'id1', 'id2' and 'ddi' are read.
        id2smiles: Dictionary mapping drug IDs to SMILES strings.
        prediction: If True labels are stored as float32, otherwise as int64.
    """
    def __init__(self, data_df, id2smiles, prediction=False):
        # Load BPE vocabulary for SMILES encoding. The codes file is only
        # needed while the encoder is built, so it is closed right after
        # (assumes BPE reads the codes eagerly in its constructor).
        vocab_path = os.path.join(ESPF_path,'drug_codes_chembl.txt')
        with codecs.open(vocab_path) as bpe_codes_drug:
            dbpe = BPE(bpe_codes_drug, merges=-1, separator='')

        # Sub-word -> integer index, taken from the row order of the CSV.
        sub_csv = pd.read_csv(os.path.join(ESPF_path,'subword_units_map_chembl.csv'))
        idx2word_d = sub_csv['index'].values
        words2idx_d = dict(zip(idx2word_d, range(0, len(idx2word_d))))
        self.dbpe = dbpe
        self.words2idx_d = words2idx_d

        self.smiles1 = []
        self.smiles2 = []
        self.targets = []

        label_type = float if prediction else int

        for _, row in data_df.iterrows():
            if pd.isna(row['id1']) or pd.isna(row['id2']):
                continue
            id1 = int(row['id1'])
            id2 = int(row['id2'])

            # Pairs with an unknown drug id are skipped.
            if id1 not in id2smiles or id2 not in id2smiles:
                continue
            s1 = id2smiles[id1]
            s2 = id2smiles[id2]

            # Pairs with a missing or empty SMILES are skipped as well.
            if pd.isna(s1) or pd.isna(s2) or s1 == '' or s2 == '':
                continue

            self.smiles1.append(s1)
            self.smiles2.append(s2)
            self.targets.append(label_type(row['ddi']))

        self.smiles1 = np.array(self.smiles1)
        self.smiles2 = np.array(self.smiles2)
        if prediction:
            self.targets = np.array(self.targets, dtype=np.float32)
        else:
            self.targets = np.array(self.targets, dtype=np.int64)

    def __getitem__(self, index):
        """
        Get a single sample from the dataset.

        Args:
            index: Index of the sample.

        Returns:
            tuple: ((fts1, adjs1), (fts2, adjs2), num_size, targets, d1, d2, mask_1, mask_2)
            where ``num_size`` is the atom count of the first drug.
        """
        fts1, adjs1 = smile_to_graph(self.smiles1[index])
        fts2, adjs2 = smile_to_graph(self.smiles2[index])

        # One feature row per atom, so the row count is the atom count of the
        # first drug; this avoids parsing the same SMILES a second time.
        num_size = fts1.shape[0]
        d1, mask_1 = drug2emb_encoder(self.smiles1[index], self.dbpe, self.words2idx_d)
        d2, mask_2 = drug2emb_encoder(self.smiles2[index], self.dbpe, self.words2idx_d)

        targets = self.targets[index]

        return (fts1, adjs1), (fts2, adjs2), num_size, targets, d1, d2, mask_1, mask_2

    def __len__(self):
        """
        Return the total number of samples in the dataset.
        """
        return len(self.smiles1)
    
def smile_to_graph(smile):
    """
    Convert SMILES string to molecular graph representation.

    Args:
        smile: SMILES string.

    Returns:
        tuple: (node_features, adjacency_matrix) for the molecule, with one
        row of node features per atom.

    Raises:
        ValueError: If RDKit cannot parse ``smile``.
    """
    molecule = Chem.MolFromSmiles(smile)
    if molecule is None:
        # MolFromSmiles reports a parse failure by returning None; fail with
        # the offending string instead of an AttributeError further down.
        raise ValueError("Could not parse SMILES string: {0!r}".format(smile))
    adjacency = Chem.rdmolops.GetAdjacencyMatrix(molecule)
    node_features = np.array([
        atom_features(molecule.GetAtomWithIdx(i))
        for i in range(molecule.GetNumAtoms())
    ])
    return node_features, adjacency

def molgraph_collate_fn(data):
    """
    Collate function for batching molecular graph data.

    Graphs are zero-padded to the largest node count in the batch (separately
    for the first and the second drug). Feature widths, sequence lengths and
    the label kind are taken from the first sample.

    Args:
        data: List of samples from MolGraphDataset.

    Returns:
        tuple: (node_tensor_1, adjacency_tensor_1, node_tensor_2,
        adjacency_tensor_2, num_size_tensor, target_tensor, d1_emb_tensor,
        d2_emb_tensor, mask_1_tensor, mask_2_tensor).
    """
    batch_size = len(data)
    (head_fts1, _), (head_fts2, _), head_num_size, head_target, head_d1, _, head_mask1, _ = data[0]

    max_nodes_1 = max(sample[0][0].shape[0] for sample in data)
    max_nodes_2 = max(sample[1][0].shape[0] for sample in data)

    feat_dim_1 = head_fts1.shape[1]
    feat_dim_2 = head_fts2.shape[1]
    emb_len = head_d1.shape[0]
    mask_len = head_mask1.shape[0]

    adjacency_tensor_1 = torch.zeros(batch_size, max_nodes_1, max_nodes_1)
    node_tensor_1 = torch.zeros(batch_size, max_nodes_1, feat_dim_1)
    adjacency_tensor_2 = torch.zeros(batch_size, max_nodes_2, max_nodes_2)
    node_tensor_2 = torch.zeros(batch_size, max_nodes_2, feat_dim_2)

    # NOTE(review): the width of this tensor is the atom count of the FIRST
    # sample, and each row is filled with that sample's own atom count --
    # confirm this is what the consumer expects.
    num_size_tensor = torch.zeros(batch_size, head_num_size)

    # Integer labels become a 1-D long tensor; anything else a float column.
    if isinstance(head_target, (int, np.integer)):
        target_tensor = torch.zeros(batch_size, dtype=torch.long)
    else:
        target_tensor = torch.zeros(batch_size, 1, dtype=torch.float32)

    d1_emb_tensor = torch.zeros(batch_size, emb_len)
    d2_emb_tensor = torch.zeros(batch_size, emb_len)
    mask_1_tensor = torch.zeros(batch_size, mask_len)
    mask_2_tensor = torch.zeros(batch_size, mask_len)

    for row, sample in enumerate(data):
        (fts1, adjs1), (fts2, adjs2), num_size, target, d1, d2, mask_1, mask_2 = sample

        size_1 = adjs1.shape[0]
        size_2 = adjs2.shape[0]

        num_size_tensor[row] = torch.tensor(num_size)

        # Copy each graph into the top-left corner of its padded slot.
        adjacency_tensor_1[row, :size_1, :size_1] = torch.Tensor(adjs1)
        node_tensor_1[row, :size_1, :] = torch.Tensor(fts1)
        adjacency_tensor_2[row, :size_2, :size_2] = torch.Tensor(adjs2)
        node_tensor_2[row, :size_2, :] = torch.Tensor(fts2)

        target_tensor[row] = torch.tensor(target)
        d1_emb_tensor[row] = torch.IntTensor(d1)
        d2_emb_tensor[row] = torch.IntTensor(d2)
        mask_1_tensor[row] = torch.tensor(mask_1)
        mask_2_tensor[row] = torch.tensor(mask_2)

    return (node_tensor_1, adjacency_tensor_1, node_tensor_2, adjacency_tensor_2,
            num_size_tensor, target_tensor, d1_emb_tensor, d2_emb_tensor,
            mask_1_tensor, mask_2_tensor)

class MVA_dataset(BaseDataset):
    """
    MVA dataset class for multi-view attention DDI prediction.

    Features:
        - Molecular graph representation from SMILES
        - BPE encoding for SMILES sequences
        - Atom feature extraction with multiple attribute types
        - Supports both multiclass and multilabel classification
    """

    def __init__(self, args: argparse.ArgumentParser):
        """
        Initialize MVA dataset.

        Args:
            args: Configuration object. NOTE(review): annotated as
                ArgumentParser, but the attributes ``oriSmiles_path``,
                ``matrix_path``, ``matrix`` and ``batch`` are read from it,
                which looks like a parsed namespace -- confirm.
        """
        super().__init__(args)
        self.args = args

    def load_data(self, val_ratio: float = 0.1, test_ratio: float = 0.2):
        """
        Main data loading method.

        Reads the SMILES table and the DDI table, shuffles the DDI rows with
        a fixed seed, splits them into train/validation/test and stores one
        DataLoader per split on ``self``.

        Args:
            val_ratio: Validation set ratio.
            test_ratio: Test set ratio.
        """
        super().load_data(val_ratio, test_ratio)

        smiles_path = self.args.oriSmiles_path
        ddi_path = self.args.matrix_path

        # Load SMILES mapping
        id_smiles_df = pd.read_csv(smiles_path)
        id2smiles = dict(zip(id_smiles_df['id1'], id_smiles_df['SMILES']))

        # Load DDI data
        ddi_df = pd.read_csv(ddi_path)

        # Shuffle and split (fixed seed keeps the split reproducible)
        ddi_df = ddi_df.sample(frac=1, random_state=1).reset_index(drop=True)
        n_total = len(ddi_df)
        n_val = int(n_total * val_ratio)
        n_test = int(n_total * test_ratio)
        n_train = n_total - n_val - n_test

        train_df = ddi_df.iloc[:n_train]
        val_df = ddi_df.iloc[n_train:n_train + n_val]
        test_df = ddi_df.iloc[n_train + n_val:]

        # Multilabel-style tasks use float labels, every other task integer
        # class labels.
        prediction = self.args.matrix in ['multilabel', 'twosides']
        train_dataset = MolGraphDataset(train_df, id2smiles, prediction=prediction)
        validation_dataset = MolGraphDataset(val_df, id2smiles, prediction=prediction)
        test_dataset = MolGraphDataset(test_df, id2smiles, prediction=prediction)

        self.train_loader = DataLoader(train_dataset, batch_size=self.args.batch,
                                       shuffle=True, collate_fn=molgraph_collate_fn)
        self.val_loader = DataLoader(validation_dataset, batch_size=self.args.batch,
                                     shuffle=False, collate_fn=molgraph_collate_fn)
        # NOTE(review): the test loader is shuffled while the validation
        # loader is not -- kept as in the original; confirm it is intended.
        self.test_loader = DataLoader(test_dataset, batch_size=self.args.batch,
                                      shuffle=True, collate_fn=molgraph_collate_fn)