# Source code for openddi.pipeline

# pipeline.py
import os
from inspect import signature
import torch

# Only used when probing dimensions from CSV
try:
    import pandas as pd
except Exception:
    pd = None

from models.model_manager import model_manager

from trainer.MRCGNN_Trainer import MRCGNN_Trainer
from trainer.ZeroDDI_Trainer import ZeroDDI_Trainer
from trainer.Unified_Trainer import Unified_Trainer
from trainer.TIGER_Trainer import TIGER_Trainer
from trainer.GoGNN_Trainer import GoGNN_Trainer
from trainer.MUFFIN_Trainer import MUFFIN_Trainer
from trainer.MVA_Trainer import MVA_Trainer



class Pipeline:
    """
    Pipeline class for managing model training workflow.

    Handles model initialization, trainer selection, and feature dimension
    validation.

    Args:
        args: Configuration arguments
        logger: Logger instance for logging (may be falsy, in which case
            messages are printed to stdout instead)
        dataset: Dataset object containing data loaders
        model: Neural network model to train
        optimizer: Optimizer for training
    """

    def __init__(self, args, logger, dataset, model, optimizer):
        self.args = args
        self.logger = logger
        self.dataset = dataset
        self.model = model
        self.optimizer = optimizer

        # Key: ensure args.features / args.dimensions are consistent with the
        # data; this may replace self.model / self.optimizer, so it has to run
        # before the trainer is constructed.
        self._ensure_modal_dims_and_model()

        self.trainer_mapping = {
            "MRCGNN": MRCGNN_Trainer,
            "GOGNN": Unified_Trainer,  # Original mapping preserved
            "ZeroDDI": ZeroDDI_Trainer,
            "DDIMDL": Unified_Trainer,
            "ConvLSTM": Unified_Trainer,
            "MVA": Unified_Trainer,
            "MUFFIN": Unified_Trainer,
            "TIGER": Unified_Trainer,
            "DeepDDI": Unified_Trainer,
            "DDKG": Unified_Trainer,
            "SumGNN": Unified_Trainer,
            "KGNN": Unified_Trainer,
            "LaGAT": Unified_Trainer,
            "PHGLDDI": Unified_Trainer,
            "MMDGDTI": Unified_Trainer,
            "DSNDDI": Unified_Trainer,
            "ExDDI": Unified_Trainer,
            "MIRACLE": Unified_Trainer,
            "CASTER": Unified_Trainer,
            "MKGFENN": Unified_Trainer,
        }
        self.trainer = self.load_trainer()

    def run(self):
        """
        Execute the main training pipeline.

        Calls ``self.trainer.train()`` when ``args.task`` equals
        ``'train_xxxx'``; any other task value does nothing.
        """
        # NOTE(review): 'train_xxxx' looks like a placeholder task name --
        # confirm against the CLI definition before changing it.
        if self.args.task == 'train_xxxx':
            self.trainer.train()

    def load_trainer(self):
        """
        Load the appropriate trainer based on model type.

        Returns:
            Trainer instance for the specified model

        Raises:
            KeyError: If ``args.model`` has no entry in ``trainer_mapping``.
        """
        model_name = self.args.model
        try:
            trainer_cls = self.trainer_mapping[model_name]
        except KeyError:
            # Same exception type as a plain dict lookup, but with a message
            # that tells the user which names are accepted.
            supported = ", ".join(sorted(self.trainer_mapping))
            raise KeyError(
                f"No trainer registered for model {model_name!r}; "
                f"supported models: {supported}"
            ) from None
        return trainer_cls(self.args, self.logger, self.dataset, self.model, self.optimizer)

    # ---------------------------
    # Internal utilities
    # ---------------------------
    def _log(self, message):
        """Send ``message`` to the logger at INFO level, or print it if no logger is set."""
        if self.logger:
            self.logger.info(message)
        else:
            print(message)

    def _ensure_modal_dims_and_model(self):
        """
        1) Infer modal dimensions from dataset or embedding paths, write back
           to args.features / args.dimensions.
        2) If current model construction requires features and doesn't match
           args.features, rebuild model and optimizer.
        """
        # 1) Write back modal dimensions
        features = None

        # Priority: modal dimensions exposed by dataset (if stored during data loading)
        for attr in ("modal_dims", "feature_dims", "modality_dims"):
            md = getattr(self.dataset, attr, None)
            if isinstance(md, (list, tuple)) and len(md) > 0:
                features = [int(d) for d in md]
                break

        # Secondary: probe dimensions from --embedding_path
        # (only reads the first row / one sample, so it is cheap)
        if features is None:
            paths = getattr(self.args, "embedding_path", None)
            if paths:
                features = self._probe_modal_dims_from_paths(paths)

        # If dimensions found, write back to args
        if features:
            self.args.features = [int(d) for d in features]
            self.args.dimensions = int(sum(self.args.features))
            self._log(f"[Pipeline] modal features = {self.args.features} (sum={self.args.dimensions})")

        # 2) Check model construction signature, rebuild if necessary
        cls = type(self.model)
        want = set(signature(cls.__init__).parameters.keys())
        need_features = ("features" in want) and ("feature" not in want)  # e.g., DDIMDL type
        need_feature = ("feature" in want)  # e.g., most models

        # Is current instance consistent with args?
        ok = True
        if need_features:
            ok = (
                hasattr(self.args, "features")
                and isinstance(self.args.features, (list, tuple))
                and len(self.args.features) > 0
            )
            # If the model exposes its features, further compare them
            if ok and hasattr(self.model, "features"):
                try:
                    cur = list(getattr(self.model, "features"))
                    ok = [int(d) for d in cur] == [int(d) for d in self.args.features]
                except (TypeError, ValueError, RuntimeError):
                    # Best effort: model.features is not a list of ints (it may
                    # be something unrelated that merely shares the name), so
                    # the comparison is skipped and `ok` keeps its value.
                    pass
        elif need_feature:
            # A missing, None or non-numeric `dimensions` counts as "not ok"
            # and triggers a rebuild instead of crashing in int().
            dims_total = getattr(self.args, "dimensions", None)
            try:
                ok = dims_total is not None and int(dims_total) > 0
            except (TypeError, ValueError):
                ok = False

        if not ok:
            self._log("[Pipeline] Rebuilding model because input feature dims changed/missing.")

            # Use model_manager to rebuild model with latest args
            man = model_manager(self.args)
            self.model = man.load_model()

            # Also rebuild optimizer (if you override later in main, this is fine)
            self.optimizer = torch.optim.Adam(
                self.model.parameters(),
                lr=float(getattr(self.args, "lr", 1e-3)),
                weight_decay=float(getattr(self.args, "weight_decay", 5e-4)),
            )

    def _probe_modal_dims_from_paths(self, paths):
        """
        Quickly probe modal dimensions from embedding paths (.pt/.csv).

        Args:
            paths: List of file paths to embedding files. A single path
                (``str`` or ``os.PathLike``) is also accepted and treated as a
                one-element list.

        Returns:
            list: List of dimension sizes for each modality, in the order of
            ``paths``

        Raises:
            RuntimeError: If a ``.csv`` path is given but pandas is not installed.
            ValueError: If a file has an unsupported suffix, or a ``.pt`` file
                is empty or its first entry is a scalar.
        """
        # Iterating a bare string would walk over its characters, so wrap it.
        if isinstance(paths, (str, os.PathLike)):
            paths = [paths]

        dims = []
        for p in paths:
            ext = os.path.splitext(p)[1].lower()
            if ext == ".pt":
                # NOTE(review): torch.load unpickles the file -- only point
                # this at embedding files from a trusted source.
                store = torch.load(p, map_location="cpu")
                # Expected format: {id: vector}
                if len(store) == 0:
                    raise ValueError(f"Embedding file contains no entries: {p}")
                sample = next(iter(store.values()))
                # Read the shape directly; converting a tensor to NumPy first
                # is unnecessary and fails for dtypes NumPy lacks (bfloat16).
                if hasattr(sample, "shape"):
                    shape = tuple(sample.shape)
                else:
                    shape = tuple(torch.as_tensor(sample).shape)
                if not shape:
                    raise ValueError(f"Embedding entry is a scalar, expected a vector: {p}")
                dims.append(int(shape[-1]))
            elif ext == ".csv":
                if pd is None:
                    raise RuntimeError("需要 pandas 来从 CSV 探测维度,请安装 pandas。")
                df = pd.read_csv(p, nrows=1)  # Only read 1 row to get column count
                dims.append(int(df.shape[1] - 1))  # Subtract first column (id)
            else:
                raise ValueError(f"不支持的嵌入文件后缀: {p}")
        return dims