# Source code for openddi.data.TIGER_dataset

from __future__ import print_function
from data.BaseDataset import BaseDataset
import os
import argparse
import gc
import torch
from tqdm import tqdm
import numpy as np
import json
import copy
from utils import *
from torch_geometric.utils import degree, subgraph
from torch_geometric.data import InMemoryDataset, Batch
from torch_geometric import data as DATA
from torch.utils.data import Dataset, DataLoader
import networkx as nx
from rdkit import Chem
import random
import numpy as np


def deepwalk_walk_wrapper(class_instance, walk_length, start_node):
    """
    Module-level wrapper around ``BasicWalker.deepwalk_walk``.

    A plain function (rather than a bound method) was presumably meant to be
    handed to a process pool (see the commented-out ``pool.apply_async`` call
    in ``BasicWalker.simulate_walks``).

    Args:
        class_instance: Instance of BasicWalker class.
        walk_length: Length of the random walk.
        start_node: Starting node for the walk.

    Returns:
        list: The walk produced by ``class_instance.deepwalk_walk``.
    """
    # Previously the walk was computed and discarded (the wrapper returned None).
    return class_instance.deepwalk_walk(walk_length, start_node)


class BasicWalker:
    """
    Uniform (DeepWalk-style) random walker.

    Args:
        G: Graph object; only ``G.neighbors(node)`` is used.
        start_nodes: Iterable of nodes from which walks are launched.
        workers: Stored for API compatibility; walks run sequentially.
    """
    def __init__(self, G, start_nodes, workers):
        self.G = G
        self.workers = workers
        self.start_nodes = start_nodes

    def deepwalk_walk(self, walk_length, start_node):
        '''
        Return one uniform random walk that begins at ``start_node``.

        The walk holds at most ``walk_length`` nodes (always at least the start
        node) and ends early on reaching a node without neighbours.
        '''
        path = [start_node]
        while len(path) < walk_length:
            neighbours = list(self.G.neighbors(path[-1]))
            if not neighbours:
                break
            path.append(random.choice(neighbours))
        return path

    def simulate_walks(self, num_walks, walk_length):
        '''
        Run ``num_walks`` rounds of walks from every start node.

        Note: the walks are concatenated and de-duplicated, so the result is a
        flat list of the distinct nodes visited (in set order), not a list of
        walks.
        '''
        visited = []
        for _ in range(num_walks):
            for origin in self.start_nodes:
                visited += self.deepwalk_walk(walk_length=walk_length, start_node=origin)
        return list(set(visited))


class Walker:
    """
    Node2Vec walker with biased (second-order) random walks.

    Args:
        G: Either a wrapper exposing ``G`` (the NetworkX graph), ``node_size``
            and ``look_up_dict`` attributes, or a NetworkX graph directly (then
            ``node_size`` and ``look_up_dict`` are derived from it).
        p: Return parameter; the weight of stepping back to the previous node
            is divided by ``p``.
        q: In-out parameter; the weight of stepping to a node that is not a
            neighbour of the previous node is divided by ``q``.
        workers: Number of workers (stored only; walks run sequentially).

    Edges without a ``'weight'`` attribute count as weight 1.0.
    ``preprocess_transition_probs`` must be called before walking.
    """
    def __init__(self, G, p, q, workers):
        if hasattr(G, "G"):
            # Wrapper object (the original interface).
            self.G = G.G
            self.node_size = G.node_size
            self.look_up_dict = G.look_up_dict
        else:
            # Bare NetworkX graph (e.g. as handed over by Node2vec).
            self.G = G
            self.node_size = G.number_of_nodes()
            self.look_up_dict = {node: idx for idx, node in enumerate(G.nodes())}
        self.p = p
        self.q = q
        self.workers = workers

    @staticmethod
    def _edge_weight(G, u, v):
        """Return the ``'weight'`` attribute of edge (u, v), defaulting to 1.0."""
        return G[u][v].get('weight', 1.0)

    def node2vec_walk(self, walk_length, start_node):
        '''
        Simulate a biased random walk of at most ``walk_length`` nodes that
        starts at ``start_node``; stops early at a node without neighbours.
        '''
        G = self.G
        alias_nodes = self.alias_nodes
        alias_edges = self.alias_edges

        walk = [start_node]
        while len(walk) < walk_length:
            cur = walk[-1]
            # Same neighbour order as used when building the alias tables.
            cur_nbrs = list(G.neighbors(cur))
            if not cur_nbrs:
                break
            if len(walk) == 1:
                # First step: first-order transition from the start node.
                idx = alias_draw(alias_nodes[cur][0], alias_nodes[cur][1])
            else:
                # Later steps: second-order transition conditioned on (prev, cur).
                pos = (walk[-2], cur)
                idx = alias_draw(alias_edges[pos][0], alias_edges[pos][1])
            walk.append(cur_nbrs[idx])

        return walk

    def simulate_walks(self, num_walks, walk_length):
        '''
        Run ``num_walks`` rounds of walks, each round starting one walk from
        every node in shuffled order; returns the list of walks.
        '''
        G = self.G
        walks = []
        nodes = list(G.nodes())
        print('Walk iteration:')
        for walk_iter in range(num_walks):
            print(str(walk_iter+1), '/', str(num_walks))
            random.shuffle(nodes)
            for node in nodes:
                walks.append(self.node2vec_walk(
                    walk_length=walk_length, start_node=node))

        return walks

    def get_alias_edge(self, src, dst):
        '''
        Build the alias tables for stepping from ``dst`` having arrived from
        ``src``, applying the node2vec p/q biases.
        '''
        G = self.G

        unnormalized_probs = []
        for dst_nbr in G.neighbors(dst):
            weight = self._edge_weight(G, dst, dst_nbr)
            if dst_nbr == src:
                unnormalized_probs.append(weight / self.p)
            elif G.has_edge(dst_nbr, src):
                unnormalized_probs.append(weight)
            else:
                unnormalized_probs.append(weight / self.q)
        norm_const = sum(unnormalized_probs)
        normalized_probs = [float(u_prob) / norm_const for u_prob in unnormalized_probs]

        return alias_setup(normalized_probs)

    def preprocess_transition_probs(self):
        '''
        Precompute the alias tables used by ``node2vec_walk``: one per node
        (first step) and one per traversable directed edge (later steps).
        '''
        G = self.G

        alias_nodes = {}
        for node in G.nodes():
            unnormalized_probs = [self._edge_weight(G, node, nbr)
                                  for nbr in G.neighbors(node)]
            norm_const = sum(unnormalized_probs)
            normalized_probs = [float(u_prob) / norm_const for u_prob in unnormalized_probs]
            alias_nodes[node] = alias_setup(normalized_probs)

        alias_edges = {}
        undirected = not G.is_directed()
        for src, dst in G.edges():
            alias_edges[(src, dst)] = self.get_alias_edge(src, dst)
            if undirected:
                # G.edges() lists an undirected edge once, but a walk may cross
                # it in either direction; without this, node2vec_walk raises
                # KeyError on (dst, src).
                alias_edges[(dst, src)] = self.get_alias_edge(dst, src)

        self.alias_nodes = alias_nodes
        self.alias_edges = alias_edges


def alias_setup(probs):
    '''
    Compute utility lists for non-uniform sampling from discrete distributions
    (alias method).
    Refer to https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
    for details

    Args:
        probs: Sequence of K probabilities, expected to sum to 1.

    Returns:
        tuple: (J, q) where J is an int32 array of alias indices and q a
        float32 array of acceptance thresholds, both of length K; pass them
        to ``alias_draw``.
    '''
    K = len(probs)
    q = np.zeros(K, dtype=np.float32)
    J = np.zeros(K, dtype=np.int32)

    smaller = []
    larger = []
    for kk, prob in enumerate(probs):
        q[kk] = K * prob
        if q[kk] < 1.0:
            smaller.append(kk)
        else:
            larger.append(kk)

    while smaller and larger:
        small = smaller.pop()
        large = larger.pop()

        J[small] = large
        q[large] = q[large] + q[small] - 1.0
        if q[large] < 1.0:
            smaller.append(large)
        else:
            larger.append(large)

    # In exact arithmetic every leftover column has q == 1; any deviation is
    # round-off. Left as-is, a leftover q slightly below 1 would redirect that
    # mass to J[i], which was never assigned and points at column 0.
    for leftover in smaller + larger:
        q[leftover] = 1.0

    return J, q


def alias_draw(J, q):
    '''
    Sample one index from the alias tables produced by ``alias_setup``.

    A column is chosen uniformly at random; it is kept with probability
    ``q[column]``, otherwise its alias ``J[column]`` is returned.
    '''
    column = int(np.floor(len(J) * np.random.rand()))
    keep = np.random.rand() < q[column]
    return column if keep else J[column]

class Node2vec(object):
    """
    Generate random walks with DeepWalk (uniform) or Node2Vec (biased) walkers.

    Args:
        start_nodes: Nodes to launch walks from. Used by the DeepWalk path
            only; the Node2Vec path walks from every node of ``graph``.
        graph: NetworkX graph, or a wrapper exposing ``G``, ``node_size`` and
            ``look_up_dict`` as read by ``Walker``.
        path_length: Length of each random walk.
        num_paths: Number of walks per start node.
        p: Return parameter (default=1.0); unused when ``dw`` is True.
        q: In-out parameter (default=1.0); unused when ``dw`` is True.
        dw: Whether to use DeepWalk instead of Node2Vec (default=False).
        **kwargs: Optional ``workers`` (default 1), forwarded to the walker.

    Attributes:
        walks: Output of the walker's ``simulate_walks``. For DeepWalk this is
            a flat list of distinct visited nodes; for Node2Vec a list of walks.
    """

    def __init__(self, start_nodes, graph, path_length, num_paths, p=1.0, q=1.0, dw=False, **kwargs):
        workers = kwargs.get("workers", 1)

        self.graph = graph
        if dw:
            # DeepWalk = uniform walks (node2vec with p = q = 1).
            self.walker = BasicWalker(graph, start_nodes, workers=workers)
        else:
            # Walker reads G / node_size / look_up_dict, so wrap a bare graph
            # instead of failing with AttributeError.
            self.walker = Walker(
                self._as_walker_graph(graph), p=p, q=q, workers=workers)
            print("Preprocess transition probs...")
            self.walker.preprocess_transition_probs()
        self.walks = self.walker.simulate_walks(
            num_walks=num_paths, walk_length=path_length)

    @staticmethod
    def _as_walker_graph(graph):
        """Return ``graph`` in the ``G``/``node_size``/``look_up_dict`` shape Walker expects."""
        if hasattr(graph, "G"):
            return graph
        from types import SimpleNamespace
        return SimpleNamespace(
            G=graph,
            node_size=graph.number_of_nodes(),
            look_up_dict={node: idx for idx, node in enumerate(graph.nodes())},
        )

    def get_walks(self):
        """
        Get the generated random walks.

        Returns:
            list: The walker output stored in ``self.walks``.
        """
        return self.walks

# Vocabularies for RDKit bond attributes. The position of a bond type in
# e_map['bond_type'] is used directly as its relation id by
# single_smile_to_graph, so the order of these entries must not change.
# NOTE(review): the order appears to mirror RDKit's Chem.BondType enum;
# confirm before adding entries.
e_map = dict(
    bond_type=[
        'UNSPECIFIED', 'SINGLE', 'DOUBLE', 'TRIPLE',
        'QUADRUPLE', 'QUINTUPLE', 'HEXTUPLE',
        'ONEANDAHALF', 'TWOANDAHALF', 'THREEANDAHALF',
        'FOURANDAHALF', 'FIVEANDAHALF',
        'AROMATIC', 'IONIC', 'HYDROGEN', 'THREECENTER',
        'DATIVEONE', 'DATIVE', 'DATIVEL', 'DATIVER',
        'OTHER', 'ZERO',
    ],
    stereo=[
        'STEREONONE', 'STEREOANY', 'STEREOZ',
        'STEREOE', 'STEREOCIS', 'STEREOTRANS',
    ],
    is_conjugated=[False, True],
)

# mol atom feature for mol graph
# One-hot vocabularies for atom descriptors. The last entry of each is the
# catch-all bucket that one_of_k_encoding_unk uses for unseen values.
_ATOM_SYMBOLS = ('C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As',
                 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se',
                 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr',
                 'Pt', 'Hg', 'Pb', 'X')  # 44 entries
_ATOM_COUNT_BUCKETS = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)  # 11 entries


def atom_features(atom):
    """
    Extract atom features for molecular graph.

    Args:
        atom: RDKit atom object.

    Returns:
        tuple: (feature_vector, degree). ``feature_vector`` is a 67-element
        numpy bool array: 44 symbol one-hot + 11 total-H-count one-hot +
        11 implicit-valence one-hot + 1 aromatic flag. ``degree`` is
        ``atom.GetDegree()``.
    """
    features = (
        one_of_k_encoding_unk(atom.GetSymbol(), _ATOM_SYMBOLS)
        + one_of_k_encoding_unk(atom.GetTotalNumHs(), _ATOM_COUNT_BUCKETS)
        + one_of_k_encoding_unk(atom.GetImplicitValence(), _ATOM_COUNT_BUCKETS)
        + [atom.GetIsAromatic()]
    )
    return np.array(features), atom.GetDegree()

def one_of_k_encoding_unk(x, allowable_set):
    '''
    One-hot encode ``x`` over ``allowable_set`` as a list of booleans.

    A value outside the set is encoded as the set's last element (the
    "unknown" bucket).
    '''
    target = x if x in allowable_set else allowable_set[-1]
    return [target == candidate for candidate in allowable_set]

def single_smile_to_graph(smile):
    """
    Convert SMILES string to molecular graph representation.
    
    Args:
        smile: SMILES string.
        
    Returns:
        tuple: (c_size, features, edge_index, rel_index, s_edge_index, s_value, s_rel, max_degree)
    """
    mol = Chem.MolFromSmiles(smile)
    c_size = mol.GetNumAtoms()

    features = []
    degrees = []
    for atom in mol.GetAtoms():
        feature, degree = atom_features(atom)
        features.append((feature / sum(feature)).tolist())
        degrees.append(degree)

    mol_index = []  ##begin, end, rel
    for bond in mol.GetBonds():
        mol_index.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), e_map['bond_type'].index(str(bond.GetBondType()))])
        mol_index.append([bond.GetEndAtomIdx(), bond.GetBeginAtomIdx(), e_map['bond_type'].index(str(bond.GetBondType()))])

    if len(mol_index) == 0:
        return 0, 0, 0, 0, 0, 0, 0, 0

    mol_index = np.array(sorted(mol_index))
    mol_edge_index = mol_index[:,:2]
    mol_rel_index = mol_index[:,2]

    # The shortest path should be calculated at this location
    s_edge_index_value = calculate_shortest_path(mol_edge_index)
    s_edge_index = s_edge_index_value[:, :2]
    s_value = s_edge_index_value[:, 2]
    s_rel = s_value
    s_rel[np.where(s_value == 1)] = mol_rel_index  # Map directly connected relationships to the original edge relationships
    s_rel[np.where(s_value != 1)] += 23

    assert len(s_edge_index) == len(s_value)
    assert len(s_edge_index) == len(s_rel)

    # c_size: Number of atoms
    # features: The characteristics of each atom c_size * 67
    # edge_index: The edges connecting atoms n_edges * 2
    return c_size, features, mol_edge_index.tolist(), mol_rel_index.tolist(), s_edge_index.tolist(), s_value.tolist(), s_rel.tolist(), max(degrees)

def calculate_shortest_path(edge_index):
    """
    Calculate shortest path distances between all node pairs.
    
    Args:
        edge_index: Edge index array of shape (n_edges, 2).
        
    Returns:
        np.array: Array of shape (n_pairs, 3) with [node_i, node_j, distance].
    """
    s_edge_index_value = []

    g = nx.DiGraph()
    g.add_edges_from(edge_index.tolist())

    paths = nx.all_pairs_shortest_path_length(g)
    for node_i, node_ij in paths:
        for node_j, length_ij in node_ij.items():
            s_edge_index_value.append([node_i, node_j, length_ij])

    s_edge_index_value.sort()

    return np.array(s_edge_index_value)

def smile_to_graph(datapath, ligands):
    """
    Convert SMILES strings to graph representations and cache to JSON.
    
    Args:
        datapath: Path to save/load cached graph data.
        ligands: Dictionary mapping drug IDs to SMILES strings.
        
    Returns:
        tuple: (smile_graph, max_rel, max_degree) where smile_graph is dictionary of graph data.
    """
    smile_graph = {}

    paths = datapath + "/mol_sp.json"

    if os.path.exists(paths):
        with open(paths, 'r') as f:
            smile_graph = json.load(f)
        max_rel = 0
        max_degree = 0
        for s in smile_graph.keys():
            max_rel = max(smile_graph[s][6]) if max(smile_graph[s][6]) > max_rel else max_rel
            max_degree = smile_graph[s][7] if smile_graph[s][7] > max_degree else max_degree

        return smile_graph, max_rel, max_degree

    smiles_max_node_degree = []
    num_rel_mol_update = 0
    invalid_smiles = []
    single_atom_or_empty = []
    for d, smi in ligands.items():
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # Unparsable SMILES: Using a placeholder empty graph (maintaining the 8-tuple structure)
            invalid_smiles.append(d)
            placeholder = (1, [[0 for _ in range(67)]], [[0, 0]], [0], [[0, 0]], [1], [1], 1)
            smile_graph[d] = placeholder
            smiles_max_node_degree.append(1)
            continue
        lg = Chem.MolToSmiles(mol)  # normalize SMILES
        c_size, features, edge_index, rel_index, s_edge_index, s_value, s_rel, deg = single_smile_to_graph(lg)
        if c_size == 0:  # Single atom / no edges: also use placeholder instead of skipping
            single_atom_or_empty.append(d)
            placeholder = (1, [[0 for _ in range(67)]], [[0, 0]], [0], [[0, 0]], [1], [1], 1)
            smile_graph[d] = placeholder
            smiles_max_node_degree.append(1)
            continue
        if len(s_value) > 0 and max(s_value) > num_rel_mol_update:
            num_rel_mol_update = max(s_value)
        smile_graph[d] = (c_size, features, edge_index, rel_index, s_edge_index, s_value, s_rel, deg)
        smiles_max_node_degree.append(deg)

    if invalid_smiles:
        print(f"[smile_to_graph] 占位无效 SMILES 数: {len(invalid_smiles)} 示例: {invalid_smiles[:8]}")
    if single_atom_or_empty:
        print(f"[smile_to_graph] 单原子/空图占位数: {len(single_atom_or_empty)} 示例: {single_atom_or_empty[:8]}")

    with open(paths, 'w') as f:
        json.dump(smile_graph, f)

    return smile_graph, num_rel_mol_update, max(smiles_max_node_degree) if smiles_max_node_degree else 0

def read_network(path):
    """
    Read knowledge graph network from a tab-separated triple file.

    The first line is treated as a header and skipped; blank lines are
    ignored. Only the first three columns (head, rel, tail) of each row are
    used, all parsed as integers.

    Args:
        path: Path to TSV file.

    Returns:
        tuple: (num_node, edge_index, rel_index, num_rel)
            num_node: the largest node id seen as head or tail (not a count).
            edge_index: list of [head, tail] pairs.
            rel_index: relation id per edge, aligned with edge_index.
            num_rel: largest relation id + 1.

    Raises:
        ValueError: if the file contains no triples.
    """
    edge_index = []
    rel_index = []

    with open(path, 'r') as f:
        next(f, None)  # header line
        for line in f:
            row = line.strip()
            if not row:
                continue
            head, rel, tail = row.split("\t")[:3]
            edge_index.append([int(head), int(tail)])
            rel_index.append(int(rel))

    if not edge_index:
        raise ValueError(f"read_network: no triples found in {path}")

    num_node = np.max(np.array(edge_index))
    num_rel = max(rel_index) + 1
    print(len(set(rel_index)))  # number of distinct relation types

    return num_node, edge_index, rel_index, num_rel

def read_smiles(path):
    """
    Load a drug-ID -> SMILES mapping.

    ``path`` may be the file itself or a directory containing
    ``id_smiles.csv``. The first line of the file is always skipped. Each
    remaining non-blank line is split once: on a comma when the line has a
    comma and no tab, otherwise on a tab. The following lines are ignored:
    lines without that separator, and lines whose ID equals ``'id'`` or
    contains ``'smiles'`` (case-insensitive). Only the first SMILES seen for
    an ID is kept. A missing file is reported on stdout and yields an empty
    dict.

    Args:
        path: Path to directory or file containing SMILES data.

    Returns:
        dict: Dictionary mapping drug IDs to SMILES strings.
    """
    file_path = os.path.join(path, 'id_smiles.csv') if os.path.isdir(path) else path

    mapping = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            for line_no, raw in enumerate(handle):
                if line_no == 0:
                    continue  # header row
                line = raw.strip()
                if not line:
                    continue
                delimiter = ',' if (',' in line and '\t' not in line) else '\t'
                drug_id, sep, smiles = line.partition(delimiter)
                if not sep:
                    continue
                drug_id, smiles = drug_id.strip(), smiles.strip()
                lowered = drug_id.lower()
                if lowered == 'id' or 'smiles' in lowered:
                    continue
                mapping.setdefault(drug_id, smiles)
    except FileNotFoundError:
        print("read_smiles: file not found:", file_path)
    return mapping

def read_interactions(path, drug_dict):
    """
    Read DDI interactions from a space-separated file.

    Each non-blank line is ``drug1 drug2 rel label`` (extra columns ignored).
    Lines whose drugs are not both in ``drug_dict`` are skipped. Positive
    (label > 0) and negative samples are counted over all kept lines, but
    only the first line for each (drug1, drug2) pair becomes an interaction.

    Args:
        path: Path to interactions file.
        drug_dict: Container of valid drug IDs (supports ``in``).

    Returns:
        tuple: (interactions_array, set_of_drugs_in_DDI) where
        interactions_array is an int array of [drug1, drug2, rel, label] rows.

    Raises:
        AssertionError: if the positive and negative counts differ (raised
            explicitly so the check survives ``python -O``).
    """
    interactions = []
    all_drug_in_ddi = set()
    seen_partners = {}  # drug1_id -> set of drug2_id already recorded
    positive_num = 0
    negative_num = 0
    with open(path, 'r') as f:
        for line in f:
            row = line.strip()
            if not row:
                continue
            drug1_id, drug2_id, rel, label = row.split(" ")[:4]
            if drug1_id not in drug_dict or drug2_id not in drug_dict:
                continue
            all_drug_in_ddi.add(drug1_id)
            all_drug_in_ddi.add(drug2_id)
            label_value = float(label)
            if label_value > 0:
                positive_num += 1
            else:
                negative_num += 1
            partners = seen_partners.setdefault(drug1_id, set())
            if drug2_id in partners:
                continue
            partners.add(drug2_id)
            # int(float(...)) also accepts labels written as "1.0".
            interactions.append([int(drug1_id), int(drug2_id), int(rel), int(label_value)])

    print(positive_num)
    print(negative_num)

    if negative_num != positive_num:
        raise AssertionError(
            f"read_interactions: unbalanced samples ({positive_num} positive vs {negative_num} negative)")

    return np.array(interactions, dtype=int), all_drug_in_ddi

def generate_node_subgraphs(dataset, drug_id, network_edge_index, network_rel_index, num_rel, args):
    """
    Generate subgraphs for drugs using random walk extraction.

    Args:
        dataset: Dataset directory; the random-walk cache is stored inside it.
        drug_id: Iterable of drug IDs.
        network_edge_index: Knowledge graph [head, tail] pairs.
        network_rel_index: Relation id per pair in ``network_edge_index``.
        num_rel: Number of relations in KG.
        args: Arguments object (currently unused).

    Returns:
        tuple: (subgraphs_dict, max_degree, max_relation_number) as returned
        by rwExtractor (1 walk per drug, walk length 32).
    """
    # int64 + reshape(-1, 2): edge indices must be long tensors, and an empty
    # edge list still becomes a [2, 0] tensor instead of breaking the unpack.
    edge_array = np.asarray(network_edge_index, dtype=np.int64).reshape(-1, 2)
    edge_index = torch.from_numpy(edge_array.T)  # [2, num_edges]
    rel_index = torch.from_numpy(np.asarray(network_rel_index, dtype=np.int64))

    # Forward edges followed by their reversed copies; rwExtractor duplicates
    # rel_index in the same order so relations stay aligned.
    undirected_edge_index = torch.cat((edge_index, edge_index.flip(0)), 1)

    paths = str(dataset) + "/"
    os.makedirs(paths, exist_ok=True)

    subgraphs, max_degree, max_rel_num = rwExtractor(drug_id, undirected_edge_index, rel_index, paths, num_rel,
                                                     sub_num=1, length=32)

    return subgraphs, max_degree, max_rel_num

def rwExtractor(drug_id, edge_index, rel_index, shortest_paths, num_rel, sub_num, length):
    """
    Extract one knowledge-graph subgraph per drug by DeepWalk-style random
    walks, and annotate it with shortest-path relations.

    Results are cached as JSON at
    ``shortest_paths + "rw_num_<sub_num>_length_<length>sp.json"``. If that
    file exists it is loaded and returned as-is (it is not checked against
    ``drug_id``).

    Args:
        drug_id: Iterable of drug IDs; these become the keys of the returned
            dict. Each must be convertible with ``int()`` to a graph node id;
            otherwise, or when the node is absent from the graph, a
            placeholder entry is stored.
        edge_index: Edge index tensor of shape [2, num_edges]. Because
            ``rel_index`` is duplicated below, this is expected to hold the
            forward edges followed by their reversed copies.
        rel_index: Relation id per forward edge (1-D tensor).
        shortest_paths: Cache directory prefix; concatenated directly, so it
            must end with a path separator.
        num_rel: Number of KG relations; non-adjacent node pairs get relation
            id ``distance + num_rel``.
        sub_num: Number of walks per start node.
        length: Walk length.

    Returns:
        tuple: (subgraphs, max_degree, max_rel). ``subgraphs`` maps each drug
        ID to ``(subsets, sub_edge_index, sub_rel, mapping_list, s_edge_index,
        s_value, s_rel, node_degree)``; ``max_degree`` / ``max_rel`` are the
        maxima of ``node_degree`` / ``s_rel`` over all entries.

    Raises:
        ValueError: if ``drug_id`` is empty and no cache exists (``max`` of an
            empty list).
    """
    json_path = shortest_paths + "rw_num_" + str(sub_num) + "_length_" + str(length) + "sp.json"
    if os.path.exists(json_path):
        # Cached result: recompute the maxima from the stored entries.
        with open(json_path, 'r') as f:
            subgraphs = json.load(f)
            max_rel = 0
            max_degree = 0
            for s in subgraphs.keys():
                max_rel = max(subgraphs[s][6]) if max(subgraphs[s][6]) > max_rel else max_rel
                max_degree = subgraphs[s][7] if subgraphs[s][7] > max_degree else max_degree
        return subgraphs, max_degree, max_rel;

    # Walk graph built from all (forward + reverse) edges.
    my_graph = nx.Graph()
    my_graph.add_edges_from(edge_index.transpose(1,0).numpy().tolist())
    # Relation ids for the reversed half of edge_index mirror the forward half.
    undirected_rel_index = torch.cat((rel_index, rel_index), 0)

    num_rel_update = []
    max_degree = []
    subgraphs = {}
    for d in drug_id:
        # Convert the ID to an integer; if that fails, store a placeholder subgraph
        try:
            start_node = int(d)
        except Exception:
            # Placeholder: Minimal single-node graph
            placeholder_sub = ([0], [[0, 0]], [0], [True], [[0, 0]], [1], [1], 1)
            subgraphs[d] = placeholder_sub
            num_rel_update.append(1)
            max_degree.append(1)
            continue

        # If the starting node is not in the graph, also provide a placeholder subgraph to ensure no errors and allow training
        if not my_graph.has_node(start_node):
            placeholder_sub = ([start_node], [[0, 0]], [0], [True], [[0, 0]], [1], [1], 1)
            subgraphs[d] = placeholder_sub
            num_rel_update.append(1)
            max_degree.append(1)
            continue

        subsets = Node2vec(start_nodes=[start_node], graph=my_graph, path_length=length, num_paths=sub_num, workers=6, dw=True).get_walks() # flat list of distinct visited nodes
        # With dw=True, BasicWalker.simulate_walks concatenates the walks and de-duplicates them,
        # so `subsets` is the node set of the subgraph (it normally contains start_node).
        try:
            mapping_id = subsets.index(start_node)
        except ValueError:
            # In rare cases, the start_node is not in the returned list, but a placeholder is still provided
            placeholder_sub = ([start_node], [[0, 0]], [0], [True], [[0, 0]], [1], [1], 1)
            subgraphs[d] = placeholder_sub
            num_rel_update.append(1)
            max_degree.append(1)
            continue

        # Boolean mask over `subsets` marking the drug's own (start) node.
        mapping_list = [False for _ in range(len((subsets)))]
        mapping_list[mapping_id] = True

        # Induced subgraph over the full multi-edge list; relabel_nodes=True
        # renumbers nodes by their position in `subsets`.
        sub_edge_index, sub_rel_index = subgraph(subsets, edge_index, undirected_rel_index, relabel_nodes=True)
        row_sub, col_sub = sub_edge_index
        # Because this involves multi-relation, all edges must be added when adding subgraphs
        new_s_edge_index = sub_edge_index.transpose(1, 0).numpy().tolist()
        new_s_value = [1 for _ in range(len(new_s_edge_index))]
        new_s_rel = sub_rel_index.numpy().tolist()

        s_edge_index = new_s_edge_index.copy()
        s_value = new_s_value.copy()
        s_rel = new_s_rel.copy()

        edge_index_value = calculate_shortest_path(sub_edge_index.transpose(1, 0).numpy())
        sp_edge_index = edge_index_value[:, :2]
        sp_value = edge_index_value[:, 2]

        # Append every non-adjacent pair (including distance-0 self pairs) with
        # relation id distance + num_rel; distance-1 pairs are already present
        # above with their real KG relations.
        for i in range(len(sp_edge_index)):
            if sp_value[i] == 1:  # Also ensure all multi-relational edges are in the data
                continue
            else:
                s_edge_index.append(sp_edge_index[i].tolist())
                s_value.append(sp_value[i])
                s_rel.append(sp_value[i] + num_rel)

        assert len(s_edge_index) == len(s_value)
        assert len(s_edge_index) == len(s_rel)

        num_rel_update.append(int(np.max(s_rel)) if len(s_rel) > 0 else 1)
        # Largest in-degree (target column) within the relabelled subgraph.
        node_degree = torch.max(degree(col_sub)).item() if col_sub.numel() > 0 else 1
        max_degree.append(node_degree)

        subgraphs[d] = (subsets, new_s_edge_index, new_s_rel, mapping_list, s_edge_index, s_value, s_rel, node_degree)

    # s_value / s_rel may hold numpy integers, hence default=convert.
    with open(json_path, 'w') as f:
        json.dump(subgraphs, f, default=convert)

    return subgraphs, max(max_degree), max(num_rel_update)

def convert(o):
    """
    ``json.dump`` ``default=`` hook turning numpy values into plain Python.

    Handles every numpy integer width (not only int64; e.g. int32 arrays occur
    on some platforms), numpy floats, numpy booleans and arrays.

    Args:
        o: Object that ``json`` could not serialise natively.

    Returns:
        The equivalent ``int`` / ``float`` / ``bool`` / ``list``.

    Raises:
        TypeError: for any other object, as ``json`` expects from the hook.
    """
    if isinstance(o, np.integer):
        return int(o)
    if isinstance(o, np.floating):
        return float(o)
    if isinstance(o, np.bool_):
        return bool(o)
    if isinstance(o, np.ndarray):
        return o.tolist()
    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")

class DTADataset(InMemoryDataset):
    """
    Drug-pair dataset yielding a molecular graph and a KG subgraph per drug.

    Args:
        x: Array of drug-ID pairs; row ``i`` is ``(drug1_id, drug2_id)``.
        y: Array of labels aligned with ``x``.
        sub_graph: Mapping from drug ID (as ``str``) to the 8-tuple built by
            rwExtractor.
        smile_graph: Mapping from drug ID (as ``str``) to the 8-tuple built by
            smile_to_graph.
        dt: ``'multiclass'`` gives a LongTensor label of shape [1]; any other
            value gives ``torch.FloatTensor(y[i])``.
    """
    def __init__(self, x=None, y=None, sub_graph=None, smile_graph=None, dt=None):
        super().__init__()

        self.labels = y
        self.drug_ID = x
        self.sub_graph = sub_graph
        self.smile_graph = smile_graph
        self.dt = dt

    def read_drug_info(self, drug_id):
        """
        Build the molecular graph and the KG subgraph for one drug.

        Args:
            drug_id: Drug ID; both lookup tables are keyed by its ``str`` form.

        Returns:
            tuple: (data_mol, data_graph) PyTorch Geometric Data objects.
        """
        key = str(drug_id)
        (c_size, features, edge_index, rel_index,
         sp_edge_index, sp_value, sp_rel, _) = self.smile_graph[key]
        (subset, subgraph_edge_index, subgraph_rel, mapping_id,
         s_edge_index, s_value, s_rel, _) = self.sub_graph[key]

        if edge_index == 0:
            # A scalar 0 is single_smile_to_graph's "no bonds" result; use a
            # one-atom self-loop graph instead.
            c_size, features = 1, [[0] * 67]
            edge_index, rel_index = [[0, 0]], [0]
            sp_edge_index, sp_value, sp_rel = [[0, 0]], [1], [1]

        data_mol = DATA.Data(
            x=torch.Tensor(np.array(features)),
            edge_index=torch.LongTensor(edge_index).transpose(1, 0),
            rel_index=torch.Tensor(np.array(rel_index, dtype=int)),
            sp_edge_index=torch.LongTensor(sp_edge_index).transpose(1, 0),
            sp_value=torch.Tensor(np.array(sp_value, dtype=int)),
            sp_edge_rel=torch.LongTensor(np.array(sp_rel, dtype=int)),
        )
        data_mol['c_size'] = torch.LongTensor([c_size])

        data_graph = DATA.Data(
            x=torch.LongTensor(subset),
            edge_index=torch.LongTensor(subgraph_edge_index).transpose(1, 0),
            id=torch.LongTensor(np.array(mapping_id, dtype=bool)),
            rel_index=torch.Tensor(np.array(subgraph_rel, dtype=int)),
            sp_edge_index=torch.LongTensor(s_edge_index).transpose(1, 0),
            sp_value=torch.Tensor(np.array(s_value, dtype=int)),
            sp_edge_rel=torch.LongTensor(np.array(s_rel, dtype=int)),
        )

        return data_mol, data_graph

    def __len__(self):
        """Number of drug pairs."""
        return len(self.drug_ID)

    def __getitem__(self, idx):
        """
        Return ``(drug1_mol, drug1_subgraph, drug2_mol, drug2_subgraph, labels)``
        for pair ``idx``.
        """
        first_id, second_id = self.drug_ID[idx, 0], self.drug_ID[idx, 1]
        if self.dt == 'multiclass':
            labels = torch.LongTensor([self.labels[idx]])
        else:
            labels = torch.FloatTensor(self.labels[idx])

        first_mol, first_sub = self.read_drug_info(first_id)
        second_mol, second_sub = self.read_drug_info(second_id)

        return first_mol, first_sub, second_mol, second_sub, labels


def collate(data_list):
    """
    Collate DTADataset samples into a batch.

    Each sample is ``(mol1, sub1, mol2, sub2, label)``. The four graph slots
    are merged with ``Batch.from_data_list``; the labels are stacked and then
    squeezed on dim 1.

    Args:
        data_list: List of samples.

    Returns:
        tuple: Four batched graphs followed by the label tensor.
    """
    graph_batches = [
        Batch.from_data_list([sample[slot] for sample in data_list])
        for slot in range(4)
    ]
    labels = torch.stack([sample[4] for sample in data_list]).squeeze(1)
    return (*graph_batches, labels)


class TIGER_dataset(BaseDataset):
    """
    TIGER dataset class for knowledge graph-enhanced DDI prediction.

    Features:
        - Molecular graph representation from SMILES
        - Knowledge graph subgraph extraction via random walks
        - Dual graph representation (molecular + knowledge graph)
        - Supports both multiclass and multilabel classification
    """

    def __init__(self, args: argparse.ArgumentParser):
        """
        Initialize TIGER dataset.

        Args:
            args: Parsed configuration. ``load_data`` reads ``oridata_path``,
                ``matrix`` and ``batch``, plus ``seed`` when present.
        """
        super().__init__(args)
        self.args = args
        self.interactions = None
        self.labels = None
        self.smile_graph = None
        self.drug_subgraphs = None
        self.data_sta = None

    def load_data(self, val_ratio: float = 0.1, test_ratio: float = 0.2):
        """
        Build molecular graphs, KG subgraphs and the train/val/test loaders.

        SMILES are read with read_smiles and the KG from ``kgnet.tsv`` under
        ``args.oridata_path``. On return the instance holds ``train_loader``,
        ``val_loader``, ``test_loader``, ``data_sta``, ``smile_graph`` and
        ``drug_subgraphs``.

        Args:
            val_ratio: Validation set ratio.
            test_ratio: Test set ratio.
        """
        super().load_data(val_ratio, test_ratio)
        data_path = self.args.oridata_path
        ligands = read_smiles(data_path)

        # smiles to graphs
        print("load drug smiles graphs!!")
        smile_graph, num_rel_mol_update, max_smiles_degree = smile_to_graph(data_path, ligands)

        print("load networks !!")
        num_node, network_edge_index, network_rel_index, num_rel = read_network(os.path.join(data_path, "kgnet.tsv"))

        print("load DDI samples!!")
        # Pairs carry the original string IDs so they match the graph lookup keys.
        splits = self.build_pairs_labels_splits(val_ratio=val_ratio, test_ratio=test_ratio,
                                                random_seed=getattr(self.args, 'seed', 1),
                                                return_original_ids=True)
        train_pairs, train_labels = splits['train']
        val_pairs, val_labels = splits['val']
        test_pairs, test_labels = splits['test']

        # Every drug ID (as str) used by any split. Iterating the pairs
        # directly also works when a split is empty (np.concatenate does not).
        all_contained_drugs = {
            str(drug)
            for pairs in (train_pairs, val_pairs, test_pairs)
            for pair in pairs
            for drug in pair
        }
        # Sorted order makes the random-walk subgraphs (and the log below)
        # reproducible; set order of strings varies between interpreter runs.
        ordered_drugs = sorted(all_contained_drugs)

        # Placeholder molecular graphs for drugs missing from smile_graph, so
        # DTADataset lookups cannot raise KeyError (e.g. IDs like 'nan').
        placeholder_mol = (1, [[0 for _ in range(67)]], [[0, 0]], [0], [[0, 0]], [1], [1], 1)
        missing_smiles = []
        for did in ordered_drugs:
            if did not in smile_graph:
                smile_graph[did] = placeholder_mol
                missing_smiles.append(did)
        if len(missing_smiles) > 0:
            print(f"[TIGER_dataset] 为 {len(missing_smiles)} 个在 SMILES 映射中缺失的药物填充占位分子图。示例: {missing_smiles[:8]}")

        print("generate subgraphs!!")
        drug_subgraphs, max_subgraph_degree, num_rel_update = generate_node_subgraphs(
            data_path, ordered_drugs, network_edge_index, network_rel_index, num_rel, self.args)

        total_interactions = len(train_pairs) + len(val_pairs) + len(test_pairs)
        data_sta = {
            'num_nodes': int(num_node) + 1,
            'num_rel_mol': num_rel_mol_update + 1,
            'num_rel_graph': num_rel_update + 1,
            'num_interactions': int(total_interactions),
            'num_drugs_DDI': len(all_contained_drugs),
            'max_degree_graph': max_smiles_degree + 1,
            'max_degree_node': int(max_subgraph_degree) + 1,
        }
        print(data_sta)
        self.data_sta = data_sta
        self.smile_graph = smile_graph
        self.drug_subgraphs = drug_subgraphs

        def pairs_to_np(pairs):
            # (N, 2) object array of string IDs, indexed as [idx, 0/1] by DTADataset.
            return np.array([[p[0], p[1]] for p in pairs], dtype=object)

        train_x = pairs_to_np(train_pairs)
        val_x = pairs_to_np(val_pairs)
        test_x = pairs_to_np(test_pairs)

        is_multilabel = self.args.matrix in ['multilabel', 'twosides']
        if is_multilabel:
            # Multi-label: float32 label vectors.
            label_dtype = np.float32
        else:
            # Multi-class: single integer label per pair.
            label_dtype = np.int64
        train_y = np.array(train_labels, dtype=label_dtype)
        val_y = np.array(val_labels, dtype=label_dtype)
        test_y = np.array(test_labels, dtype=label_dtype)

        # dt_flag selects the label tensor built in DTADataset.__getitem__:
        # 'multiclass' -> LongTensor of shape [1]; 'multilabel' -> FloatTensor vector.
        dt_flag = 'multilabel' if is_multilabel else 'multiclass'
        train_data = DTADataset(x=train_x, y=train_y, sub_graph=drug_subgraphs, smile_graph=smile_graph, dt=dt_flag)
        val_data = DTADataset(x=val_x, y=val_y, sub_graph=drug_subgraphs, smile_graph=smile_graph, dt=dt_flag)
        test_data = DTADataset(x=test_x, y=test_y, sub_graph=drug_subgraphs, smile_graph=smile_graph, dt=dt_flag)

        self.train_loader = DataLoader(train_data, batch_size=self.args.batch, shuffle=True, collate_fn=collate)
        # Evaluation loaders keep split order: reproducible runs, and an empty
        # split no longer trips RandomSampler.
        self.val_loader = DataLoader(val_data, batch_size=self.args.batch, shuffle=False, collate_fn=collate)
        self.test_loader = DataLoader(test_data, batch_size=self.args.batch, shuffle=False, collate_fn=collate)