# Source code for openddi.trainer.BaseTrainer

import copy
import time
import os
import numpy as np
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from typing import Dict, Any, Tuple, Optional


class BaseTrainer(ABC):
    """
    Base Trainer class that provides common functionality for training neural networks.

    Supports both multiclass and multilabel classification tasks. Subclasses
    supply the loss function, the metric computation, and the per-epoch
    train/evaluate steps; this class drives the overall loop.

    Features:
        - Device management (CPU/GPU)
        - Mixed precision training with GradScaler
        - Memory tracking
        - Time tracking
        - Common training loop structure
        - Result logging and saving
    """

    def __init__(self, args, logger, dataset, model, optimizer):
        """
        Initialize the BaseTrainer.

        Args:
            args: Configuration arguments. Read attributes are ``device``
                (optional), ``matrix`` (optional), ``out_file`` (optional)
                and ``epochs``.
            logger: Logger instance for logging
            dataset: Dataset object containing data loaders
                (``val_loader`` and ``test_loader`` are used here)
            model: Neural network model to train
            optimizer: Optimizer for training
        """
        self.args = args
        self.logger = logger
        self.dataset = dataset
        self.model = model
        self.optimizer = optimizer

        # Initialize tracking variables
        self.time0 = time.time()
        self.device = self._setup_device()

        # Enable TF32 acceleration (e.g. on A100 GPUs)
        self._setup_tf32()

    def _setup_device(self) -> torch.device:
        """
        Setup and return the appropriate device for training.

        An explicit ``args.device`` wins; otherwise CUDA is used when
        available, else the CPU.

        Returns:
            torch.device: Device to use for training
        """
        device_name = getattr(self.args, 'device', None)
        if device_name:
            return torch.device(device_name)
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _setup_tf32(self):
        """Enable TF32 acceleration if the installed torch exposes the flags."""
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        except (AttributeError, RuntimeError):
            # Best effort only: the flags may be missing on this torch build
            # or refuse to be set; training proceeds without TF32 either way.
            pass

    def _setup_scaler(self) -> torch.cuda.amp.GradScaler:
        """
        Setup GradScaler for mixed precision training.

        The scaler is only enabled when training on a CUDA device.

        Returns:
            torch.cuda.amp.GradScaler: Configured scaler
        """
        return torch.cuda.amp.GradScaler(enabled=(self.device.type == 'cuda'))

    def _move_data_to_device(self):
        """Move the dataset's ``data_o`` object (if present) to the training device."""
        if hasattr(self.dataset, 'data_o'):
            self.dataset.data_o = self.dataset.data_o.to(self.device)

    def _get_memory_usage(self) -> float:
        """
        Get the peak GPU memory allocated so far, in MB.

        This reports ``torch.cuda.max_memory_allocated()`` (a peak, not the
        instantaneous usage) and 0.0 when CUDA is unavailable.

        Returns:
            float: Peak allocated memory in megabytes
        """
        if torch.cuda.is_available():
            return torch.cuda.max_memory_allocated() / (1024 ** 2)
        return 0.0

    def _get_elapsed_time(self) -> float:
        """
        Get elapsed time since the trainer was constructed.

        Returns:
            float: Elapsed time in seconds
        """
        return time.time() - self.time0

    def _log_training_progress(self, epoch: int, train_metrics: Dict[str, float],
                               val_metrics: Dict[str, float]):
        """
        Log training progress for current epoch.

        Args:
            epoch: Current epoch number (zero-based; printed one-based)
            train_metrics: Dictionary of training metrics
            val_metrics: Dictionary of validation metrics
        """
        train_str = " | ".join(f"Train {k}={v:.4f}" for k, v in train_metrics.items())
        val_str = " | ".join(f"Val {k}={v:.4f}" for k, v in val_metrics.items())
        print(f"Epoch {epoch+1:02d}: {train_str} | {val_str}")
        print(f"Memory: {self._get_memory_usage():.2f} MB | Time: {self._get_elapsed_time():.3f}s")

    def _save_results(self, test_metrics: Dict[str, float], model_name: str = "Model"):
        """
        Append test results to the output file.

        The target is ``args.out_file`` (default ``result.txt``); its parent
        directory is created when missing.

        Args:
            test_metrics: Dictionary of test metrics
            model_name: Name of the model for logging
        """
        out_path = getattr(self.args, 'out_file', 'result.txt')
        os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)

        metrics_str = " ".join(f"{v:.4f}" for v in test_metrics.values())
        time_str = f"{self._get_elapsed_time():.3f}"
        memory_str = f"{self._get_memory_usage():.2f}"

        # Explicit encoding keeps the result file independent of the locale.
        with open(out_path, "a", encoding="utf-8") as f:
            f.write(f"{model_name} {metrics_str} time:{time_str} memory:{memory_str}\n")

    def _prepare_batch_data(self, batch_data, task_type: str = 'multiclass') -> torch.Tensor:
        """
        Prepare and convert batch labels to the appropriate tensor format.

        Args:
            batch_data: Batch data from data loader; labels are read from index 2
            task_type: Type of task ('multiclass' or 'multilabel')

        Returns:
            torch.Tensor: Labels on the training device, ``long`` for
            multiclass and ``float32`` for multilabel

        Raises:
            ValueError: If ``task_type`` is not supported
        """
        labels = batch_data[2]
        if task_type == 'multiclass':
            if isinstance(labels, torch.Tensor):
                # Tensors need no NumPy round-trip (which would also fail for
                # tensors that do not live on the CPU).
                return torch.as_tensor(labels, dtype=torch.long, device=self.device)
            return torch.as_tensor(np.array(labels), dtype=torch.long, device=self.device)
        if task_type == 'multilabel':
            return torch.as_tensor(labels, dtype=torch.float32, device=self.device)
        raise ValueError(f"Unsupported task type: {task_type}")

    @abstractmethod
    def _get_loss_function(self, task_type: str):
        """
        Get the appropriate loss function for the task.

        Args:
            task_type: Type of task ('multiclass' or 'multilabel')

        Returns:
            Loss function
        """
        pass

    @abstractmethod
    def _compute_metrics(self, y_true: np.ndarray, y_logits: np.ndarray,
                         task_type: str) -> Dict[str, float]:
        """
        Compute evaluation metrics for predictions.

        Args:
            y_true: Ground truth labels
            y_logits: Model logits
            task_type: Type of task ('multiclass' or 'multilabel')

        Returns:
            Dictionary of computed metrics
        """
        pass

    @abstractmethod
    def _train_epoch(self, epoch: int, loss_fct, scaler: torch.cuda.amp.GradScaler,
                     task_type: str) -> Tuple[Dict[str, float], float]:
        """
        Train for one epoch.

        Args:
            epoch: Current epoch number
            loss_fct: Loss function
            scaler: GradScaler for mixed precision
            task_type: Type of task

        Returns:
            Tuple of (metrics_dict, average_loss)
        """
        pass

    @abstractmethod
    def _evaluate(self, loader, loss_fct, task_type: str) -> Tuple[Dict[str, float], float]:
        """
        Evaluate the model on given data loader.

        Args:
            loader: Data loader for evaluation
            loss_fct: Loss function
            task_type: Type of task

        Returns:
            Tuple of (metrics_dict, average_loss)
        """
        pass

    def train(self):
        """
        Main training method.

        Moves the dataset to the device, then picks the multilabel loop when
        ``args.matrix`` is 'multilabel' or 'twosides' and the multiclass loop
        otherwise.
        """
        # Move data to device
        self._move_data_to_device()

        # Determine task type based on matrix argument
        if getattr(self.args, 'matrix', None) in ['multilabel', 'twosides']:
            self._train_multilabel()
        else:
            self._train_multiclass()

    def _run_training(self, task_type: str):
        """
        Shared training loop used by both task types.

        Runs ``args.epochs`` rounds of train + validation, then evaluates on
        the test loader, prints the test metrics and appends them to the
        result file.

        Args:
            task_type: Type of task ('multiclass' or 'multilabel')
        """
        print(f'Start Training ({task_type})...')

        # Setup training components
        loss_fct = self._get_loss_function(task_type)
        scaler = self._setup_scaler()

        # Training loop
        for epoch in range(self.args.epochs):
            # Train one epoch
            train_metrics, train_loss = self._train_epoch(epoch, loss_fct, scaler, task_type)

            # Validate
            val_metrics, val_loss = self._evaluate(self.dataset.val_loader, loss_fct, task_type)

            # Log progress
            self._log_training_progress(
                epoch,
                {'Loss': train_loss, **train_metrics},
                {'Loss': val_loss, **val_metrics},
            )

        # Final test evaluation (the test loss itself is not reported)
        self.model.eval()
        test_metrics, _ = self._evaluate(self.dataset.test_loader, loss_fct, task_type)

        # Print and save results
        metrics_str = " | ".join(f"{k}={v:.4f}" for k, v in test_metrics.items())
        print(f"[Model] Test {metrics_str}")
        self._save_results(test_metrics, "Model")

    def _train_multiclass(self):
        """Training loop for multiclass classification."""
        self._run_training('multiclass')

    def _train_multilabel(self):
        """Training loop for multilabel classification."""
        self._run_training('multilabel')

    # Legacy method names for backward compatibility
    def train_binary(self):
        """Legacy method - redirects to main train method."""
        self.train()

    def train_multi(self):
        """Legacy method - redirects to main train method."""
        self.train()