# Source code for openddi.trainer.GoGNN_Trainer

import copy
from tqdm import tqdm
import time
import os
import random
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from typing import Dict, Any, Tuple, Optional
from evaluate.evaluate import _metrics_from_logits, plot_metrics, _metrics_from_logits_multilabel

# Enable TF32 matmul/cuDNN kernels for faster training on Ampere-or-newer GPUs
# (e.g. A100). NOTE: TF32 keeps the float32 exponent range but rounds
# matmul/convolution inputs to a 10-bit mantissa, so results can differ
# slightly from full-precision FP32 -- it is a speed/precision trade-off.
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
except (AttributeError, RuntimeError):
    # Best-effort only: PyTorch builds that do not expose (or refuse to set)
    # these switches simply keep their default matmul precision.
    pass
from trainer.BaseTrainer import BaseTrainer


class GoGNN_Trainer(BaseTrainer):
    """
    Trainer class for Graph of Graphs Neural Network (GoGNN) models.

    Handles both multiclass and multilabel classification tasks with mixed
    precision training.

    Besides what is defined here, the class relies on ``self.args``,
    ``self.dataset``, ``self.model``, ``self.optimizer``, ``self.device``,
    ``self._setup_scaler``, ``self._log_training_progress`` and
    ``self._save_results``, none of which are defined in this class
    (presumably provided by ``BaseTrainer`` -- confirm there).
    """

    def __init__(self, args, logger, dataset, model, optimizer):
        """
        Initialize the GoGNN trainer.

        Args:
            args: Command line arguments or configuration object
            logger: Logger instance for logging
            dataset: Dataset object containing train/val/test loaders
            model: Model to train
            optimizer: Optimizer for training
        """
        super().__init__(args, logger, dataset, model, optimizer)
        self.time = time.time()
        self._debug_done = False  # Debug flag for first batch only

    def _prepare_batch_data(self, batch_data, task_type: str = 'multiclass') -> torch.Tensor:
        """
        Prepare batch labels for training/evaluation.

        The labels are read from ``batch_data[3]`` and moved to
        ``self.device``.

        Args:
            batch_data: Batch data from dataloader
            task_type: Type of task ('multiclass' or 'multilabel')

        Returns:
            Prepared labels as torch tensor (``torch.long`` for multiclass,
            ``torch.float32`` for multilabel)

        Raises:
            ValueError: If ``task_type`` is not a supported task.
        """
        if task_type == 'multiclass':
            dtype = torch.long
        elif task_type == 'multilabel':
            dtype = torch.float32
        else:
            raise ValueError(f"Unsupported task type: {task_type}")

        labels = batch_data[3]
        if isinstance(labels, torch.Tensor):
            # Already a tensor (possibly on the GPU): convert in place rather
            # than round-tripping through NumPy, which fails for CUDA tensors.
            return labels.to(device=self.device, dtype=dtype)
        if task_type == 'multiclass':
            # Going through np.array first collapses list-like labels into a
            # single array before the tensor conversion.
            labels = np.array(labels)
        return torch.as_tensor(labels, dtype=dtype, device=self.device)

    def _train_multiclass(self):
        """
        Training loop for multiclass classification tasks.
        """
        self._run_training('multiclass')

    def _train_multilabel(self):
        """
        Training loop for multilabel classification tasks.
        """
        self._run_training('multilabel')

    def _run_training(self, task_type: str) -> None:
        """
        Run the full train / validate / test cycle for one task type.

        Trains for ``self.args.epochs`` epochs, validating and logging after
        each one, then evaluates once on the test loader, prints the test
        metrics and hands them to ``self._save_results``.

        Args:
            task_type: Type of task ('multiclass' or 'multilabel')
        """
        print(f'Start Training ({task_type})...')
        scaler = self._setup_scaler()

        for epoch in range(self.args.epochs):
            # Train one epoch
            train_metrics, train_loss = self._train_epoch(epoch, scaler, task_type)
            # Validate
            val_metrics, val_loss = self._evaluate(self.dataset.val_loader, task_type)
            # Log progress
            self._log_training_progress(
                epoch,
                {'Loss': train_loss, **train_metrics},
                {'Loss': val_loss, **val_metrics},
            )

        # Final test evaluation (_evaluate switches the model to eval mode).
        test_metrics, _test_loss = self._evaluate(self.dataset.test_loader, task_type)

        # Print and save results
        metrics_str = " | ".join(f"{k}={v:.4f}" for k, v in test_metrics.items())
        print(f"[Model] Test {metrics_str}")
        self._save_results(test_metrics, "Model")

    def _compute_metrics(self, y_true: np.ndarray, y_logits: np.ndarray,
                         task_type: str) -> Dict[str, float]:
        """
        Compute evaluation metrics for predictions.

        Args:
            y_true: Ground truth labels
            y_logits: Model logits
            task_type: Type of task ('multiclass' or 'multilabel'); any value
                other than 'multilabel' is scored as multiclass

        Returns:
            Dictionary of computed metrics: ``{'AUC'}`` for multilabel,
            ``{'Accuracy', 'F1', 'Recall', 'Precision'}`` otherwise
        """
        if task_type == 'multilabel':
            # The helper also returns a second value (named `ap` in the
            # original code); only AUC is reported so the metric keys seen by
            # logging and result saving stay unchanged.
            auc, _ap = _metrics_from_logits_multilabel(y_true, y_logits)
            return {'AUC': auc}

        acc, f1, rec, pre = _metrics_from_logits(y_true, y_logits)
        return {'Accuracy': acc, 'F1': f1, 'Recall': rec, 'Precision': pre}

    def _forward_with_loss(self, inp, labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the model on one batch and compute its loss under autocast.

        Autocast is only enabled when ``self.device`` is a CUDA device.

        Args:
            inp: One batch as yielded by the data loader
            labels: Labels prepared by ``_prepare_batch_data``

        Returns:
            Tuple of (model output, loss tensor)
        """
        with torch.cuda.amp.autocast(enabled=(self.device.type == 'cuda')):
            output = self.model(inp)
            loss = self.model.loss(output, labels)
        return output, loss

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """
        Detach a tensor and convert it to a NumPy array on the CPU.

        Half-precision tensors (as produced under autocast) are upcast to
        float32 first: the upcast is lossless, avoids fp16 overflow in
        downstream metric code, and NumPy has no bfloat16 dtype at all.
        """
        tensor = tensor.detach()
        if tensor.dtype in (torch.float16, torch.bfloat16):
            tensor = tensor.float()
        return tensor.cpu().numpy()

    def _summarize_epoch(self, label_chunks, logit_chunks, loss_sum: float,
                         num_batches: int, task_type: str) -> Tuple[Dict[str, float], float]:
        """
        Turn per-batch labels/logits and a summed loss into epoch results.

        Args:
            label_chunks: List of per-batch label arrays
            logit_chunks: List of per-batch logit arrays
            loss_sum: Sum of the per-batch loss values
            num_batches: Number of batches that were processed
            task_type: Type of task ('multiclass' or 'multilabel')

        Returns:
            Tuple of (metrics dictionary, average loss)

        Raises:
            ValueError: If no batches were processed, since there is nothing
                to compute metrics on.
        """
        if num_batches == 0:
            raise ValueError("Cannot compute metrics: the data loader yielded no batches.")

        y_true = np.concatenate(label_chunks, axis=0)
        y_logits = np.concatenate(logit_chunks, axis=0)
        metrics = self._compute_metrics(y_true, y_logits, task_type)
        return metrics, loss_sum / num_batches

    def _train_epoch(self, epoch: int, scaler: torch.cuda.amp.GradScaler,
                     task_type: str) -> Tuple[Dict[str, float], float]:
        """
        Train model for one epoch.

        Args:
            epoch: Current epoch number (not used inside this method)
            scaler: Gradient scaler for mixed precision training
            task_type: Type of task ('multiclass' or 'multilabel')

        Returns:
            Tuple of (metrics dictionary, average loss)

        Raises:
            ValueError: If the training loader yields no batches.
        """
        self.model.train()
        loss_sum = 0.0
        num_batches = 0
        label_chunks = []
        logit_chunks = []

        for inp in tqdm(self.dataset.train_loader,
                        desc="Training",      # Progress bar prefix
                        leave=True,           # Keep progress bar after completion
                        dynamic_ncols=True):
            labels = self._prepare_batch_data(inp, task_type)
            self.optimizer.zero_grad(set_to_none=True)

            output, loss_train = self._forward_with_loss(inp, labels)

            # Scale the loss before backward, then let the scaler drive the
            # optimizer step and update its own scale factor.
            scaler.scale(loss_train).backward()
            scaler.step(self.optimizer)
            scaler.update()

            loss_sum += loss_train.item()
            num_batches += 1
            label_chunks.append(self._to_numpy(labels))
            logit_chunks.append(self._to_numpy(output))

        return self._summarize_epoch(label_chunks, logit_chunks, loss_sum,
                                     num_batches, task_type)

    def _evaluate(self, loader, task_type: str) -> Tuple[Dict[str, float], float]:
        """
        Evaluate model on given data loader.

        Puts the model in eval mode and runs without gradient tracking.

        Args:
            loader: Data loader for evaluation
            task_type: Type of task ('multiclass' or 'multilabel')

        Returns:
            Tuple of (metrics dictionary, average loss)

        Raises:
            ValueError: If ``loader`` yields no batches.
        """
        self.model.eval()
        loss_sum = 0.0
        num_batches = 0
        label_chunks = []
        logit_chunks = []

        with torch.no_grad():
            for inp in loader:
                labels = self._prepare_batch_data(inp, task_type)
                output, loss = self._forward_with_loss(inp, labels)

                loss_sum += loss.item()
                num_batches += 1
                label_chunks.append(self._to_numpy(labels))
                logit_chunks.append(self._to_numpy(output))

        return self._summarize_epoch(label_chunks, logit_chunks, loss_sum,
                                     num_batches, task_type)