Trainers

BaseTrainer

class openddi.trainer.BaseTrainer.BaseTrainer(args, logger, dataset, model, optimizer)

Bases: ABC

Abstract base class that provides common functionality for training neural networks. Supports both multiclass and multilabel classification tasks.

Features:
  • Device management (CPU/GPU)
  • Mixed precision training with GradScaler
  • Memory tracking
  • Time tracking
  • Common training loop structure
  • Result logging and saving
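The "common training loop structure" listed above can be pictured as a template method: the base class owns the epoch loop and its bookkeeping, while subclasses supply the model-specific steps. The following is a hypothetical sketch, not the openddi implementation. TrainerSketch, train_epoch, evaluate, and fit are illustrative names, and tracemalloc stands in for the GPU memory tracking the real class presumably performs.

```python
import time
import tracemalloc
from abc import ABC, abstractmethod


class TrainerSketch(ABC):
    """Hypothetical BaseTrainer-style template: the base class owns the epoch
    loop, timing, and memory tracking; subclasses fill in the model steps."""

    def __init__(self, epochs):
        self.epochs = epochs
        self.history = []  # one result dict per epoch

    @abstractmethod
    def train_epoch(self, epoch):
        """Run one pass over the training data and return the mean loss."""

    @abstractmethod
    def evaluate(self, epoch):
        """Return a dict of validation metrics."""

    def fit(self):
        tracemalloc.start()  # stand-in for GPU memory tracking (assumption)
        try:
            for epoch in range(self.epochs):
                start = time.perf_counter()
                loss = self.train_epoch(epoch)
                metrics = self.evaluate(epoch)
                self.history.append({
                    "epoch": epoch,
                    "loss": loss,
                    "seconds": time.perf_counter() - start,
                    **metrics,
                })
            _, peak = tracemalloc.get_traced_memory()
        finally:
            tracemalloc.stop()
        return {"history": self.history, "peak_bytes": peak}


class ToyTrainer(TrainerSketch):
    """Toy subclass returning fake loss and accuracy values; no real model."""

    def train_epoch(self, epoch):
        return 1.0 / (epoch + 1)

    def evaluate(self, epoch):
        return {"acc": 0.5 + 0.1 * epoch}


result = ToyTrainer(epochs=3).fit()
```

Keeping the loop in the base class means each model-specific trainer only has to override the per-epoch hooks, which matches how the subclasses below inherit memory and time tracking.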

__init__(args, logger, dataset, model, optimizer)

Initialize the BaseTrainer.

Parameters:
  • args – Configuration arguments

  • logger – Logger instance for logging

  • dataset – Dataset object containing data loaders

  • model – Neural network model to train

  • optimizer – Optimizer for training

train()

Main training entry point. Determines the task type (multiclass or multilabel) and starts the corresponding training routine.

train_binary()

Legacy method; redirects to train().

train_multi()

Legacy method; redirects to train().
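A minimal sketch of how train() might dispatch on task type, and how the two legacy methods can simply delegate to it. DispatchSketch and the args.task field are assumptions for illustration; the real BaseTrainer may read the task type from a different attribute.

```python
from types import SimpleNamespace


class DispatchSketch:
    """Hypothetical illustration of train() choosing a loop by task type,
    with legacy entry points that defer to it."""

    def __init__(self, args):
        self.args = args

    def train(self):
        # `args.task` is an assumed field name, not a documented option.
        task = self.args.task
        if task == "multiclass":
            return self._train_multiclass()
        if task == "multilabel":
            return self._train_multilabel()
        raise ValueError(f"unsupported task type: {task!r}")

    def _train_multiclass(self):
        return "multiclass loop"

    def _train_multilabel(self):
        return "multilabel loop"

    # Legacy names kept for backward compatibility; both defer to train().
    def train_binary(self):
        return self.train()

    def train_multi(self):
        return self.train()


trainer = DispatchSketch(SimpleNamespace(task="multilabel"))
print(trainer.train_multi())  # multilabel loop
```

Because both legacy names route through train(), old call sites keep working while the task-type decision lives in one place.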

Unified_Trainer

class openddi.trainer.Unified_Trainer.Unified_Trainer(args, logger, dataset, model, optimizer)

Bases: BaseTrainer

Unified trainer built on the BaseTrainer framework. Supports both multiclass and multilabel classification tasks.

Features:
  • Multiclass: CrossEntropyLoss with accuracy, F1, recall, and precision metrics
  • Multilabel: BCEWithLogitsLoss with AUC and AP metrics
  • Memory and time tracking inherited from BaseTrainer
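To make the two loss choices concrete, the functions below reimplement in plain Python what torch.nn.CrossEntropyLoss and torch.nn.BCEWithLogitsLoss compute with their default mean reduction and no class weights. This is an illustrative sketch only; the trainer itself presumably calls the PyTorch modules directly.

```python
import math


def cross_entropy(logits, target):
    """Mean multiclass cross-entropy from raw logits (one row per sample,
    integer class index per target), like CrossEntropyLoss with 'mean'."""
    total = 0.0
    for row, y in zip(logits, target):
        m = max(row)  # shift by the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[y]
    return total / len(target)


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over every (sample, label) entry, using the
    stable form max(x, 0) - x*y + log(1 + exp(-|x|))."""
    terms = [
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for row_x, row_y in zip(logits, targets)
        for x, y in zip(row_x, row_y)
    ]
    return sum(terms) / len(terms)
```

The multiclass loss treats each sample as having exactly one correct class, whereas the multilabel loss scores every label independently, which is why the two task types need different losses and metrics.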

__init__(args, logger, dataset, model, optimizer)

Initialize the Unified Trainer.

Parameters:
  • args – Configuration arguments

  • logger – Logger instance for logging

  • dataset – Dataset object containing data loaders

  • model – Neural network model to train

  • optimizer – Optimizer for training
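The metrics named for Unified_Trainer can be sketched as follows. The documentation does not say how F1, recall, and precision are averaged across classes, so macro averaging is an assumption here, as is the tie-free ranking used for average precision (AP). AUC is omitted for brevity.

```python
def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted class equals the true class."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_prf(y_true, y_pred, num_classes):
    """Macro-averaged precision, recall, and F1 (unweighted mean over classes)."""
    p_sum = r_sum = f_sum = 0.0
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        p_sum, r_sum, f_sum = p_sum + prec, r_sum + rec, f_sum + f1
    return p_sum / num_classes, r_sum / num_classes, f_sum / num_classes


def average_precision(scores, labels):
    """AP for one label: mean precision at the rank of each positive.
    Assumes no tied scores."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits, total = 0, 0.0
    for rank, i in enumerate(order, start=1):
        if labels[i]:
            hits += 1
            total += hits / rank
    return total / hits if hits else 0.0
```

In a multilabel setting, AP would typically be computed per label and then averaged; that aggregation step is also an assumption about this trainer.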

MRCGNN_Trainer

class openddi.trainer.MRCGNN_Trainer.MRCGNN_Trainer(args, logger, dataset, model, optimizer)

Bases: BaseTrainer

MRCGNN trainer built on the BaseTrainer framework, with MRCGNN-specific logic.

Features:
  • Triple-loss training (main loss plus two auxiliary losses)
  • BCEWithLogitsLoss for the auxiliary losses
  • Support for multiple data objects (data_o, data_s, data_a)
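The triple-loss objective can be sketched as the main task loss plus two BCE-with-logits auxiliary terms. The function name triple_loss, the weights alpha and beta, and the assumption that each auxiliary head produces one logit per example with a 0/1 target are hypothetical; the documentation only states that the auxiliary losses use BCEWithLogitsLoss.

```python
import math


def bce_with_logits(logits, targets):
    """Mean BCE-with-logits over flat lists of logits and 0/1 targets."""
    terms = [
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for x, y in zip(logits, targets)
    ]
    return sum(terms) / len(terms)


def triple_loss(main_loss, aux1_logits, aux1_targets,
                aux2_logits, aux2_targets, alpha=1.0, beta=1.0):
    """Main loss plus two weighted auxiliary BCE terms.
    alpha and beta are hypothetical weights, not documented options."""
    aux1 = bce_with_logits(aux1_logits, aux1_targets)
    aux2 = bce_with_logits(aux2_logits, aux2_targets)
    return main_loss + alpha * aux1 + beta * aux2
```

With alpha = beta = 1.0 this reduces to a plain sum of the three losses; whether MRCGNN weights the auxiliary terms differently is not stated here.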

__init__(args, logger, dataset, model, optimizer)

Initialize the MRCGNN trainer.

Parameters:
  • args – Configuration arguments

  • logger – Logger instance for logging

  • dataset – Dataset object containing data loaders

  • model – Neural network model to train

  • optimizer – Optimizer for training

ZeroDDI_Trainer

class openddi.trainer.ZeroDDI_Trainer.ZeroDDI_Trainer(args, logger, dataset, model, optimizer=None)

Bases: BaseTrainer

ZeroDDI trainer. Optimizes the Dual-Modality Unified Alignment (DUA) loss:

  L_DUA = λ1 · alignment + λ2 · uniformity(pair) + λ3 · uniformity(event)

Features:
  • Dynamic prototype learning
  • Uniformity regularization for pairs and events
  • Support for both multiclass and multilabel classification
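The DUA loss combines an alignment term with two uniformity terms weighted by λ1, λ2, and λ3. The sketch below assumes the alignment and uniformity definitions of Wang and Isola (2020) on L2-normalized embeddings (mean squared distance for alignment; log of the mean Gaussian potential with t = 2 for uniformity), and assumes pair_emb[i] is matched to event_emb[i]. These are assumptions; ZeroDDI's exact definitions may differ.

```python
import math
from itertools import combinations


def normalize(v):
    """Scale a non-zero vector to unit L2 norm."""
    n = math.sqrt(sum(c * c for c in v))
    return [c / n for c in v]


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def alignment(xs, ys):
    """Mean squared distance between matched normalized embeddings (assumed)."""
    return sum(sq_dist(normalize(x), normalize(y)) for x, y in zip(xs, ys)) / len(xs)


def uniformity(xs, t=2.0):
    """log of the mean Gaussian potential over all distinct pairs (assumed).
    Needs at least two embeddings."""
    zs = [normalize(x) for x in xs]
    pots = [math.exp(-t * sq_dist(a, b)) for a, b in combinations(zs, 2)]
    return math.log(sum(pots) / len(pots))


def dua_loss(pair_emb, event_emb, lam1=1.0, lam2=1.0, lam3=1.0):
    """lam1*alignment + lam2*uniformity(pair) + lam3*uniformity(event)."""
    return (lam1 * alignment(pair_emb, event_emb)
            + lam2 * uniformity(pair_emb)
            + lam3 * uniformity(event_emb))
```

Under these definitions, alignment pulls each drug-pair embedding toward its matched event embedding, while the uniformity terms (more negative is better) spread pair and event embeddings over the unit sphere.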

__init__(args, logger, dataset, model, optimizer=None)

Initialize the ZeroDDI trainer.

Parameters:
  • args – Configuration arguments

  • logger – Logger instance for logging

  • dataset – Dataset object containing data loaders

  • model – Neural network model to train

  • optimizer – Optimizer for training (optional; defaults to None)

train()

Overrides BaseTrainer.train() with ZeroDDI-specific logic.

TIGER_Trainer

class openddi.trainer.TIGER_Trainer.TIGER_Trainer(args, logger, dataset, model, optimizer)

Bases: BaseTrainer

Trainer class for TIGER (Temporal Interaction Graph Embedding and Reasoning) model.

Handles both multiclass and multilabel classification tasks.

__init__(args, logger, dataset, model, optimizer)

Initialize the TIGER trainer.

Parameters:
  • args – Command line arguments or configuration object

  • logger – Logger instance for logging

  • dataset – Dataset object containing train/val/test loaders

  • model – Model to train

  • optimizer – Optimizer for training

GoGNN_Trainer

class openddi.trainer.GoGNN_Trainer.GoGNN_Trainer(args, logger, dataset, model, optimizer)

Bases: BaseTrainer

Trainer class for Graph of Graphs Neural Network (GoGNN) models.

Handles both multiclass and multilabel classification tasks with mixed precision training.

__init__(args, logger, dataset, model, optimizer)

Initialize the GoGNN trainer.

Parameters:
  • args – Command line arguments or configuration object

  • logger – Logger instance for logging

  • dataset – Dataset object containing train/val/test loaders

  • model – Model to train

  • optimizer – Optimizer for training

MUFFIN_Trainer

class openddi.trainer.MUFFIN_Trainer.MUFFIN_Trainer(args, logger, dataset, model, optimizer)

Bases: BaseTrainer

Trainer class for MUFFIN (Multi-modal Fusion Framework for Interaction Networks) model.

Handles both multiclass and multilabel classification tasks with pre-loaded embeddings.

__init__(args, logger, dataset, model, optimizer)

Initialize the MUFFIN trainer.

Parameters:
  • args – Command line arguments or configuration object

  • logger – Logger instance for logging

  • dataset – Dataset object containing train/val/test loaders and pre-computed embeddings

  • model – Model to train

  • optimizer – Optimizer for training

MVA_Trainer

class openddi.trainer.MVA_Trainer.MVA_Trainer(args, logger, dataset, model, optimizer)

Bases: BaseTrainer

Trainer class for MVA (Multi-View Attention) model.

Handles both multiclass and multilabel classification tasks with progress tracking.

__init__(args, logger, dataset, model, optimizer)

Initialize the MVA trainer.

Parameters:
  • args – Command line arguments or configuration object

  • logger – Logger instance for logging

  • dataset – Dataset object containing train/val/test loaders

  • model – Model to train

  • optimizer – Optimizer for training