# Source code for openddi.trainer.MRCGNN_Trainer

# MRCGNN_Trainer.py
import copy
import time
import os
import random
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import (
    roc_auc_score, average_precision_score, f1_score, accuracy_score,
    recall_score, precision_score, precision_recall_curve, auc
)
from typing import Dict, Any, Tuple, Optional
from evaluate.evaluate import _metrics_from_logits, plot_metrics,_metrics_from_logits_multilabel

# Opt in to TF32 math for CUDA matrix multiplies and for cuDNN; intended as a
# speed-up on A100-class GPUs. This flips process-wide torch flags as an
# import-time side effect of loading this module.
# Best-effort on purpose: any error while setting the flags is ignored so that
# importing the trainer can never fail because of it.
try:
    setattr(torch.backends.cuda.matmul, "allow_tf32", True)
    setattr(torch.backends.cudnn, "allow_tf32", True)
except Exception:
    pass
from trainer.BaseTrainer import BaseTrainer


class MRCGNN_Trainer(BaseTrainer):
    """MRCGNN trainer built on the ``BaseTrainer`` framework.

    MRCGNN-specific behaviour added on top of the base training loop:

    - A weighted triple loss: the main task loss plus two auxiliary
      ``BCEWithLogitsLoss`` terms on the model's ``cla_os`` / ``cla_os_a``
      outputs, weighted by ``args.loss_ratio1``, ``args.loss_ratio2`` and
      ``args.loss_ratio3``.
    - The three graph objects ``data_o``, ``data_s`` and ``data_a`` held on
      the dataset are passed to every model call.
    """

    def __init__(self, args, logger, dataset, model, optimizer):
        """Initialize the MRCGNN trainer.

        Args:
            args: Configuration arguments. This class reads ``loss_ratio1``,
                ``loss_ratio2`` and ``loss_ratio3`` from it.
            logger: Logger instance, forwarded to ``BaseTrainer``.
            dataset: Dataset object. This class reads ``train_loader`` and
                the graph objects ``data_o``, ``data_s`` and ``data_a``.
            model: Neural network model to train.
            optimizer: Optimizer for training.
        """
        super().__init__(args, logger, dataset, model, optimizer)
        # Construction timestamp. Not read inside this class; presumably used
        # by BaseTrainer or callers for timing -- verify before removing.
        self.time = time.time()
        # Criterion shared by the two auxiliary heads (cla_os, cla_os_a).
        self.b_xent = nn.BCEWithLogitsLoss()

    def _get_loss_function(self, task_type: str):
        """Return the main-task loss function.

        Args:
            task_type: ``'multiclass'`` selects ``CrossEntropyLoss``. Any other
                value (e.g. ``'multilabel'``) selects ``BCEWithLogitsLoss``.

        Returns:
            A newly constructed loss module.
        """
        if task_type == 'multiclass':
            return nn.CrossEntropyLoss()
        return nn.BCEWithLogitsLoss()

    def _compute_metrics(self, y_true: np.ndarray, y_logits: np.ndarray,
                         task_type: str) -> Dict[str, float]:
        """Compute evaluation metrics from ground truth and logits.

        Args:
            y_true: Ground-truth labels.
            y_logits: Model logits.
            task_type: ``'multilabel'`` or ``'multiclass'``.

        Returns:
            For ``'multilabel'``: ``{'AUC', 'AP'}``. Otherwise:
            ``{'Accuracy', 'F1', 'Recall', 'Precision'}``.
        """
        if task_type == 'multilabel':
            # Local names avoid shadowing the module-level sklearn ``auc``.
            auc_score, ap_score = _metrics_from_logits_multilabel(y_true, y_logits)
            return {'AUC': auc_score, 'AP': ap_score}
        acc, f1, rec, pre = _metrics_from_logits(y_true, y_logits)
        return {'Accuracy': acc, 'F1': f1, 'Recall': rec, 'Precision': pre}

    # ------------------------------------------------------------------
    # Private helpers shared by _train_epoch and _evaluate
    # ------------------------------------------------------------------

    def _mrcgnn_autocast(self):
        """Return a mixed-precision context, enabled only when the device is CUDA."""
        # Same semantics as the deprecated torch.cuda.amp.autocast(enabled=...).
        return torch.autocast(device_type='cuda',
                              enabled=(self.device.type == 'cuda'))

    def _mrcgnn_aux_targets(self) -> torch.Tensor:
        """Return float targets for the two auxiliary heads, on the trainer device."""
        # Does not depend on the batch, so it is computed once per epoch.
        return self.dataset.data_a.y.float().to(self.device)

    def _mrcgnn_forward_loss(self, inp, labels, loss_fct, task_type: str,
                             aux_targets: torch.Tensor):
        """Run one MRCGNN forward pass and compute the weighted triple loss.

        Must be called inside the caller's autocast context.

        Returns:
            Tuple ``(output, loss)``, where ``output`` holds the main-task logits.
        """
        output, cla_os, cla_os_a, _ = self.model(
            data_o=self.dataset.data_o,
            data_s=self.dataset.data_s,
            data_a=self.dataset.data_a,
            idx=inp,
        )
        # CrossEntropyLoss needs integer class targets; BCE uses labels as-is.
        main_targets = labels.long() if task_type == 'multiclass' else labels
        loss1 = loss_fct(output, main_targets)
        loss2 = self.b_xent(cla_os, aux_targets)
        loss3 = self.b_xent(cla_os_a, aux_targets)
        loss = (self.args.loss_ratio1 * loss1
                + self.args.loss_ratio2 * loss2
                + self.args.loss_ratio3 * loss3)
        return output, loss

    @staticmethod
    def _mrcgnn_record_batch(labels, output, labels_acc: list, logits_acc: list,
                             task_type: str) -> None:
        """Append one batch's labels and logits (as numpy) to the accumulators."""
        if task_type == 'multiclass':
            labels_acc.extend(labels.detach().cpu().numpy().tolist())
        else:
            labels_acc.append(labels.detach().cpu().numpy())
        logits_acc.append(output.detach().cpu().numpy())

    @staticmethod
    def _mrcgnn_stack_outputs(labels_acc: list, logits_acc: list,
                              task_type: str) -> Tuple[np.ndarray, np.ndarray]:
        """Stack the accumulated per-batch labels and logits into arrays.

        Raises:
            ValueError: If no batches were recorded (empty data loader).
        """
        if not logits_acc:
            raise ValueError(
                "No batches were produced by the data loader; cannot compute metrics."
            )
        if task_type == 'multiclass':
            y_true = np.array(labels_acc)
        else:
            y_true = np.concatenate(labels_acc, axis=0)
        return y_true, np.concatenate(logits_acc, axis=0)

    # ------------------------------------------------------------------
    # BaseTrainer hooks
    # ------------------------------------------------------------------

    def _train_epoch(self, epoch: int, loss_fct, scaler: torch.cuda.amp.GradScaler,
                     task_type: str) -> Tuple[Dict[str, float], float]:
        """Train for one epoch over ``self.dataset.train_loader``.

        Args:
            epoch: Current epoch number (not used by this implementation).
            loss_fct: Main-task loss function.
            scaler: GradScaler for mixed precision.
            task_type: ``'multiclass'`` or ``'multilabel'``.

        Returns:
            Tuple ``(metrics_dict, average_loss)``.

        Raises:
            ValueError: If the training loader yields no batches.
        """
        self.model.train()
        aux_targets = self._mrcgnn_aux_targets()
        loss_sum = 0.0
        n_batches = 0
        labels_acc, logits_acc = [], []

        for inp in self.dataset.train_loader:
            labels = self._prepare_batch_data(inp, task_type)
            self.optimizer.zero_grad(set_to_none=True)
            with self._mrcgnn_autocast():
                output, loss_train = self._mrcgnn_forward_loss(
                    inp, labels, loss_fct, task_type, aux_targets)
            scaler.scale(loss_train).backward()
            scaler.step(self.optimizer)
            scaler.update()

            loss_sum += float(loss_train.item())
            n_batches += 1
            self._mrcgnn_record_batch(labels, output, labels_acc, logits_acc, task_type)

        y_true, y_logits = self._mrcgnn_stack_outputs(labels_acc, logits_acc, task_type)
        avg_loss = loss_sum / max(n_batches, 1)
        metrics = self._compute_metrics(y_true, y_logits, task_type)
        return metrics, avg_loss

    def _evaluate(self, loader, loss_fct, task_type: str) -> Tuple[Dict[str, float], float]:
        """Evaluate the model on ``loader`` without gradient tracking.

        Args:
            loader: Data loader for evaluation.
            loss_fct: Main-task loss function.
            task_type: ``'multiclass'`` or ``'multilabel'``.

        Returns:
            Tuple ``(metrics_dict, average_loss)``.

        Raises:
            ValueError: If ``loader`` yields no batches.
        """
        self.model.eval()
        loss_sum = 0.0
        n_batches = 0
        labels_acc, logits_acc = [], []

        with torch.no_grad():
            aux_targets = self._mrcgnn_aux_targets()
            for inp in loader:
                labels = self._prepare_batch_data(inp, task_type)
                with self._mrcgnn_autocast():
                    output, loss = self._mrcgnn_forward_loss(
                        inp, labels, loss_fct, task_type, aux_targets)
                loss_sum += float(loss.item())
                n_batches += 1
                self._mrcgnn_record_batch(labels, output, labels_acc, logits_acc, task_type)

        y_true, y_logits = self._mrcgnn_stack_outputs(labels_acc, logits_acc, task_type)
        avg_loss = loss_sum / max(n_batches, 1)
        metrics = self._compute_metrics(y_true, y_logits, task_type)
        return metrics, avg_loss

    def _move_data_to_device(self):
        """Move the MRCGNN graph objects and the model to ``self.device``."""
        for attr in ('data_o', 'data_s', 'data_a'):
            value = getattr(self.dataset, attr, None)
            if value is not None:
                setattr(self.dataset, attr, value.to(self.device))
        # nn.Module.to is a no-op when the model is already on the target
        # device. Calling it unconditionally also handles cuda:N index
        # mismatches and non-CUDA accelerators, and avoids StopIteration from
        # next(parameters()) on a model without parameters.
        self.model.to(self.device)