Trainers
BaseTrainer
- class openddi.trainer.BaseTrainer.BaseTrainer(args, logger, dataset, model, optimizer)
Bases: ABC
Base Trainer class that provides common functionality for training neural networks. Supports both multiclass and multilabel classification tasks.
Features:
- Device management (CPU/GPU)
- Mixed precision training with GradScaler
- Memory tracking
- Time tracking
- Common training loop structure
- Result logging and saving
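Because BaseTrainer is an abstract base class that owns the shared loop while each subclass supplies model-specific logic, the design follows the template-method pattern. The sketch below illustrates that pattern in plain Python. The classes `SketchTrainer` and `ToyTrainer`, the hooks `train_step` and `eval_step`, and the `history` record are all hypothetical and are not openddi internals. Device management, GradScaler mixed precision, and memory tracking are left out so the example stays framework-free.

```python
import time
from abc import ABC, abstractmethod


class SketchTrainer(ABC):
    """Hypothetical base trainer: owns the epoch loop, timing, and history."""

    def __init__(self, epochs):
        self.epochs = epochs
        self.history = []  # one record per epoch: loss, metric, seconds

    @abstractmethod
    def train_step(self, batch):
        """Run one optimization step on a batch and return its loss."""

    @abstractmethod
    def eval_step(self, batch):
        """Return a validation metric for a batch (higher is better)."""

    def fit(self, train_batches, val_batches):
        best = float("-inf")
        for epoch in range(self.epochs):
            start = time.perf_counter()  # time tracking per epoch
            losses = [self.train_step(b) for b in train_batches]
            scores = [self.eval_step(b) for b in val_batches]
            metric = sum(scores) / len(scores)
            best = max(best, metric)
            self.history.append({
                "epoch": epoch,
                "loss": sum(losses) / len(losses),
                "metric": metric,
                "seconds": time.perf_counter() - start,
            })
        return best


class ToyTrainer(SketchTrainer):
    """Toy subclass: the 'loss' shrinks every step, the metric is constant."""

    def __init__(self, epochs):
        super().__init__(epochs)
        self.steps = 0

    def train_step(self, batch):
        self.steps += 1
        return sum(batch) / self.steps

    def eval_step(self, batch):
        return 0.5


trainer = ToyTrainer(epochs=3)
best = trainer.fit(train_batches=[[1.0, 2.0]], val_batches=[[0.0]])
print(best, [round(r["loss"], 2) for r in trainer.history])  # 0.5 [3.0, 1.5, 1.0]
```

Subclasses never rewrite the loop, so shared bookkeeping such as timing or result logging only has to be implemented once in the base class.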
- __init__(args, logger, dataset, model, optimizer)
Initialize the BaseTrainer.
- Parameters:
args – Configuration arguments
logger – Logger instance for logging
dataset – Dataset object containing data loaders
model – Neural network model to train
optimizer – Optimizer for training
- train()
Main training entry point. Determines the task type (multiclass or multilabel) and starts the corresponding training routine.
- train_binary()
Legacy method; redirects to train().
- train_multi()
Legacy method; redirects to train().
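A minimal sketch of the dispatch described for train() and its two legacy aliases, assuming the task type is read from a configuration attribute. The attribute name `task_type`, the class `DispatchSketch`, and the private methods `_train_multiclass` / `_train_multilabel` are hypothetical.

```python
from types import SimpleNamespace


class DispatchSketch:
    """Hypothetical illustration of task-type dispatch with legacy aliases."""

    def __init__(self, args):
        self.args = args

    def train(self):
        # Pick the training routine from the configured task type.
        task = self.args.task_type
        if task == "multiclass":
            return self._train_multiclass()
        if task == "multilabel":
            return self._train_multilabel()
        raise ValueError(f"unsupported task type: {task!r}")

    # Legacy entry points kept for backward compatibility; both just call train().
    def train_binary(self):
        return self.train()

    def train_multi(self):
        return self.train()

    def _train_multiclass(self):
        return "multiclass loop"

    def _train_multilabel(self):
        return "multilabel loop"


trainer = DispatchSketch(SimpleNamespace(task_type="multilabel"))
print(trainer.train_binary())  # multilabel loop
```

Routing the legacy names through train() keeps old call sites working while leaving a single place where the task type is resolved.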
Unified_Trainer
- class openddi.trainer.Unified_Trainer.Unified_Trainer(args, logger, dataset, model, optimizer)
Bases: BaseTrainer
Unified Trainer built on the BaseTrainer framework. Supports both multiclass and multilabel classification tasks.
Features:
- Multiclass: CrossEntropyLoss with accuracy, F1, recall, and precision metrics
- Multilabel: BCEWithLogitsLoss with AUC and AP (average precision) metrics
- Memory and time tracking inherited from BaseTrainer
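The two loss choices differ mainly in label format. CrossEntropyLoss takes one class index per example and applies a softmax across classes, while BCEWithLogitsLoss takes a multi-hot vector and applies an independent sigmoid to each label. The sketch below reproduces both formulas in plain Python for a single example (mean reduction, no class weights) to make that contrast concrete. The function names are illustrative and are not part of openddi or PyTorch.

```python
import math


def cross_entropy(logits, target):
    """Softmax cross-entropy for one example; target is a class index."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def bce_with_logits(logits, targets):
    """Mean sigmoid binary cross-entropy over labels; targets is multi-hot."""
    total = 0.0
    for z, y in zip(logits, targets):
        # Stable form of -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))]
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


logits = [2.0, 0.5, -1.0]
print(cross_entropy(logits, 0))           # multiclass: class 0 is the true class
print(bce_with_logits(logits, [1, 1, 0]))  # multilabel: labels 0 and 1 are active
```

In the multiclass case raising one logit necessarily lowers the others' probabilities; in the multilabel case each label is scored on its own, which is why several interaction types can be predicted for the same drug pair.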
- __init__(args, logger, dataset, model, optimizer)
Initialize Unified Trainer.
- Parameters:
args – Configuration arguments
logger – Logger instance for logging
dataset – Dataset object containing data loaders
model – Neural network model to train
optimizer – Optimizer for training
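For reference, a plain-Python sketch of the metrics named for Unified_Trainer: accuracy and macro-averaged precision, recall, and F1 for multiclass predictions, and ROC AUC and average precision for one label column of a multilabel task. Whether Unified_Trainer macro-averages, and how it aggregates AUC and AP across labels, is not stated here, so those choices are assumptions. scikit-learn's `f1_score`, `roc_auc_score`, and `average_precision_score` provide production versions of these metrics.

```python
def accuracy(y_true, y_pred):
    """Fraction of examples whose predicted class matches the true class."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_prf(y_true, y_pred):
    """Macro-averaged precision, recall, F1 over classes in y_true or y_pred."""
    classes = sorted(set(y_true) | set(y_pred))
    p_sum = r_sum = f_sum = 0.0
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        p_sum, r_sum, f_sum = p_sum + prec, r_sum + rec, f_sum + f1
    n = len(classes)
    return p_sum / n, r_sum / n, f_sum / n


def roc_auc(labels, scores):
    """Probability a random positive outranks a random negative (ties count 1/2)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > q else 0.5 if p == q else 0.0 for p in pos for q in neg)
    return wins / (len(pos) * len(neg))


def average_precision(labels, scores):
    """Mean of precision@k taken at the rank of each positive (ties not handled)."""
    ranked = sorted(zip(scores, labels), key=lambda pair: -pair[0])
    hits, total = 0, 0.0
    for k, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            total += hits / k
    return total / hits if hits else 0.0
```

AUC only measures ranking quality, while AP weights the top of the ranking more heavily, which is why the two are often reported together for sparse multilabel targets.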
MRCGNN_Trainer
- class openddi.trainer.MRCGNN_Trainer.MRCGNN_Trainer(args, logger, dataset, model, optimizer)
Bases: BaseTrainer
MRCGNN Trainer built on the BaseTrainer framework, with MRCGNN-specific logic.
Features:
- Triple-loss training (main loss plus two auxiliary losses)
- BCEWithLogitsLoss for the auxiliary losses
- Support for multiple data objects (data_o, data_s, data_a)
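A minimal sketch of the triple-loss combination, assuming the total is the main loss plus the two auxiliary BCEWithLogits terms with equal weights. The weights `w1`/`w2` and the function names are hypothetical, and the actual MRCGNN_Trainer may weight or compute the terms differently. The usage line treats the auxiliary inputs as positive/negative discriminator scores labeled 1 and 0, which is likewise an assumption. Plain floats stand in for tensors so the arithmetic is easy to follow.

```python
import math


def bce_with_logits(logits, targets):
    """Mean sigmoid binary cross-entropy, as in BCEWithLogitsLoss(reduction='mean')."""
    return sum(
        max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
        for z, y in zip(logits, targets)
    ) / len(logits)


def triple_loss(main_loss, aux1_logits, aux1_targets, aux2_logits, aux2_targets,
                w1=1.0, w2=1.0):
    """Hypothetical combination: main loss plus two weighted auxiliary BCE terms."""
    aux1 = bce_with_logits(aux1_logits, aux1_targets)
    aux2 = bce_with_logits(aux2_logits, aux2_targets)
    return main_loss + w1 * aux1 + w2 * aux2


# Auxiliary inputs: scores for a "positive" and a "negative" sample (assumed 1/0 labels).
total = triple_loss(0.7, [3.0, -3.0], [1, 0], [0.0, 0.0], [1, 0])
print(total)
```

Because all three terms are summed into one scalar, a single backward pass propagates gradients from the main task and both auxiliary objectives into the shared encoder.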
- __init__(args, logger, dataset, model, optimizer)
Initialize MRCGNN Trainer.
- Parameters:
args – Configuration arguments
logger – Logger instance for logging
dataset – Dataset object containing data loaders
model – Neural network model to train
optimizer – Optimizer for training
ZeroDDI_Trainer
- class openddi.trainer.ZeroDDI_Trainer.ZeroDDI_Trainer(args, logger, dataset, model, optimizer=None)
Bases: BaseTrainer
ZeroDDI Trainer implementing the Dual-Modality Unified Alignment (DUA) loss:
Loss = λ1 · alignment + λ2 · uniformity(pair) + λ3 · uniformity(event)
Features:
- Dynamic prototype learning
- Uniformity regularization for pairs and events
- Support for both multiclass and multilabel classification
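The DUA loss above is a weighted sum of one alignment term and two uniformity terms. The sketch below assumes the alignment and uniformity definitions of Wang and Isola (2020) on L2-normalized vectors: alignment is the mean squared distance between matched drug-pair and event embeddings, and uniformity is the log of the mean Gaussian potential exp(-t·||u - v||²) over distinct pairs within one set. ZeroDDI's exact formulation, the temperature `t`, and the λ values are assumptions here, and all function names are hypothetical.

```python
import math


def normalize(v):
    """Scale a vector to unit L2 norm."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def sq_dist(u, v):
    """Squared Euclidean distance."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def alignment(pair_emb, event_emb):
    """Mean squared distance between each pair embedding and its matched event embedding."""
    return sum(sq_dist(normalize(p), normalize(e))
               for p, e in zip(pair_emb, event_emb)) / len(pair_emb)


def uniformity(emb, t=2.0):
    """Log of the mean potential exp(-t * ||u - v||^2) over distinct pairs."""
    z = [normalize(v) for v in emb]
    pots = [math.exp(-t * sq_dist(z[i], z[j]))
            for i in range(len(z)) for j in range(i + 1, len(z))]
    return math.log(sum(pots) / len(pots))


def dua_loss(pair_emb, event_emb, lam1=1.0, lam2=1.0, lam3=1.0):
    """Loss = λ1*alignment + λ2*uniformity(pair) + λ3*uniformity(event)."""
    return (lam1 * alignment(pair_emb, event_emb)
            + lam2 * uniformity(pair_emb)
            + lam3 * uniformity(event_emb))
```

Minimizing the alignment term pulls each drug pair toward its event embedding, while minimizing the (negative-valued) uniformity terms spreads pairs and events apart on the unit sphere so that distinct events stay distinguishable.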
- __init__(args, logger, dataset, model, optimizer=None)
Initialize ZeroDDI Trainer.
- Parameters:
args – Configuration arguments
logger – Logger instance for logging
dataset – Dataset object containing data loaders
model – Neural network model to train
optimizer – Optimizer for training (optional; defaults to None)
- train()
Overrides BaseTrainer.train() with ZeroDDI-specific logic.
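In prototype-based classification, each event class is represented by a prototype vector and an input is assigned to the class whose prototype is most similar; in a zero-shot setting, any event that has a prototype can be scored, including events absent from training. The sketch below assumes cosine similarity as the scoring function, which may differ from ZeroDDI's actual scoring, and the function names are hypothetical.

```python
import math


def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def predict_event(pair_emb, prototypes):
    """Return the index of the most similar prototype and all similarity scores."""
    scores = [cosine(pair_emb, proto) for proto in prototypes]
    best = max(range(len(scores)), key=scores.__getitem__)
    return best, scores


prototypes = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]  # one vector per event type
label, scores = predict_event([0.2, 0.9], prototypes)
print(label)  # 1
```

For a multilabel task the same similarity scores could be thresholded per event instead of taking the single arg-max.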
TIGER_Trainer
- class openddi.trainer.TIGER_Trainer.TIGER_Trainer(args, logger, dataset, model, optimizer)
Bases: BaseTrainer
Trainer class for the TIGER (Temporal Interaction Graph Embedding and Reasoning) model.
Handles both multiclass and multilabel classification tasks.
- __init__(args, logger, dataset, model, optimizer)
Initialize the TIGER trainer.
- Parameters:
args – Command line arguments or configuration object
logger – Logger instance for logging
dataset – Dataset object containing train/val/test loaders
model – Model to train
optimizer – Optimizer for training
GoGNN_Trainer
- class openddi.trainer.GoGNN_Trainer.GoGNN_Trainer(args, logger, dataset, model, optimizer)
Bases: BaseTrainer
Trainer class for Graph of Graphs Neural Network (GoGNN) models.
Handles both multiclass and multilabel classification tasks, with mixed precision training.
- __init__(args, logger, dataset, model, optimizer)
Initialize the GoGNN trainer.
- Parameters:
args – Command line arguments or configuration object
logger – Logger instance for logging
dataset – Dataset object containing train/val/test loaders
model – Model to train
optimizer – Optimizer for training
MUFFIN_Trainer
- class openddi.trainer.MUFFIN_Trainer.MUFFIN_Trainer(args, logger, dataset, model, optimizer)
Bases: BaseTrainer
Trainer class for the MUFFIN (Multi-modal Fusion Framework for Interaction Networks) model.
Handles both multiclass and multilabel classification tasks using pre-loaded embeddings.
- __init__(args, logger, dataset, model, optimizer)
Initialize the MUFFIN trainer.
- Parameters:
args – Command line arguments or configuration object
logger – Logger instance for logging
dataset – Dataset object containing train/val/test loaders and pre-computed embeddings
model – Model to train
optimizer – Optimizer for training
MVA_Trainer
- class openddi.trainer.MVA_Trainer.MVA_Trainer(args, logger, dataset, model, optimizer)
Bases: BaseTrainer
Trainer class for the MVA (Multi-View Attention) model.
Handles both multiclass and multilabel classification tasks, with progress tracking.
- __init__(args, logger, dataset, model, optimizer)
Initialize the MVA trainer.
- Parameters:
args – Command line arguments or configuration object
logger – Logger instance for logging
dataset – Dataset object containing train/val/test loaders
model – Model to train
optimizer – Optimizer for training