import copy
import time
import os
import random
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import (
roc_auc_score, average_precision_score, f1_score, accuracy_score,
recall_score, precision_score, precision_recall_curve, auc
)
from typing import Dict, Any, Tuple, Optional
from evaluate.evaluate import _metrics_from_logits, plot_metrics, _metrics_from_logits_multilabel
# Enable TF32 on Ampere-or-newer GPUs (e.g. A100) to speed up float32 matmuls
# and cuDNN convolutions. NOTE: TF32 rounds operands to a 10-bit mantissa, so
# this trades a small amount of numerical precision for speed; turn these
# flags off if full float32 precision is required.
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
except (AttributeError, RuntimeError):
    # Best-effort only: PyTorch builds that do not expose these flags, or that
    # refuse to change them, simply run without TF32.
    pass
from trainer.BaseTrainer import BaseTrainer
[docs]
class MVA_Trainer(BaseTrainer):
"""
Trainer class for MVA (Multi-View Attention) model.
Handles both multiclass and multilabel classification tasks with progress tracking.
"""
[docs]
def __init__(self, args, logger, dataset, model, optimizer):
    """
    Set up the MVA trainer.

    Every argument is forwarded unchanged to ``BaseTrainer.__init__``; the
    only extra state kept here is the construction timestamp.

    Args:
        args: Command line arguments or configuration object
        logger: Logger instance for logging
        dataset: Dataset object containing train/val/test loaders
        model: Model to train
        optimizer: Optimizer for training
    """
    super().__init__(args, logger, dataset, model, optimizer)
    # Wall-clock time (seconds since the epoch) at which the trainer was built.
    created_at = time.time()
    self.time = created_at
def _prepare_batch_data(self, batch_data, task_type: str = 'multiclass') -> torch.Tensor:
    """
    Extract the labels of a batch and return them as a tensor on ``self.device``.

    The labels are read from position 5 of ``batch_data``. They may be a
    ``torch.Tensor`` (on any device), a NumPy array, or a (nested) sequence
    of numbers/arrays.

    Args:
        batch_data: Batch data from dataloader (indexable; item 5 holds the labels)
        task_type: Type of task ('multiclass' or 'multilabel')

    Returns:
        ``torch.long`` labels for 'multiclass', ``torch.float32`` labels for
        'multilabel', placed on ``self.device``.

    Raises:
        ValueError: If ``task_type`` is neither 'multiclass' nor 'multilabel'.
    """
    labels = batch_data[5]

    if task_type == 'multiclass':
        target_dtype = torch.long
    elif task_type == 'multilabel':
        target_dtype = torch.float32
    else:
        raise ValueError(f"Unsupported task type: {task_type}")

    if torch.is_tensor(labels):
        # Already a tensor: move/cast directly. Going through NumPy would fail
        # for tensors that live on a CUDA device.
        return labels.to(device=self.device, dtype=target_dtype)

    # Build one contiguous array first; handing torch a Python list of arrays
    # is slow and can fail for lists of multi-element arrays.
    return torch.as_tensor(np.asarray(labels), dtype=target_dtype, device=self.device)
def _train_multiclass(self):
    """
    Training loop for multiclass classification tasks.

    Trains for ``self.args.epochs`` epochs, validating after each one, then
    evaluates on the test loader and saves the test metrics.
    """
    self._run_mva_training('multiclass')

def _train_multilabel(self):
    """
    Training loop for multilabel classification tasks.

    Trains for ``self.args.epochs`` epochs, validating after each one, then
    evaluates on the test loader and saves the test metrics.
    """
    self._run_mva_training('multilabel')

def _run_mva_training(self, task_type: str) -> None:
    """
    Shared train / validate / test driver used by both task-specific loops.

    Args:
        task_type: Type of task ('multiclass' or 'multilabel'); passed through
            to ``_train_epoch`` and ``_evaluate``.
    """
    print(f'Start Training ({task_type})...')
    scaler = self._setup_scaler()

    for epoch in range(self.args.epochs):
        # Train one epoch
        train_metrics, train_loss = self._train_epoch(epoch, scaler, task_type)
        # Validate
        val_metrics, val_loss = self._evaluate(self.dataset.val_loader, task_type)
        # Log progress
        self._log_training_progress(
            epoch,
            {'Loss': train_loss, **train_metrics},
            {'Loss': val_loss, **val_metrics},
        )

    # Final test evaluation (runs once, after the last epoch)
    self.model.eval()
    test_metrics, _test_loss = self._evaluate(self.dataset.test_loader, task_type)

    # Print and save results
    metrics_str = " | ".join(f"{k}={v:.4f}" for k, v in test_metrics.items())
    print(f"[Model] Test {metrics_str}")
    self._save_results(test_metrics, "Model")
def _compute_metrics(self, y_true: np.ndarray, y_logits: np.ndarray,
                     task_type: str) -> Dict[str, float]:
    """
    Turn ground-truth labels and raw logits into a dictionary of metrics.

    Args:
        y_true: Ground truth labels
        y_logits: Model logits
        task_type: Type of task; 'multilabel' selects the multilabel metrics,
            any other value is treated as multiclass

    Returns:
        ``{'AUC': ...}`` for multilabel tasks, otherwise
        ``{'Accuracy': ..., 'F1': ..., 'Recall': ..., 'Precision': ...}``
    """
    if task_type != 'multilabel':
        accuracy, f1, recall, precision = _metrics_from_logits(y_true, y_logits)
        return {
            'Accuracy': accuracy,
            'F1': f1,
            'Recall': recall,
            'Precision': precision,
        }

    # The multilabel helper returns two values; only the first (AUC) is
    # reported, the second is intentionally discarded.
    auc_value, _unused = _metrics_from_logits_multilabel(y_true, y_logits)
    return {'AUC': auc_value}
def _train_epoch(self, epoch: int, scaler: torch.cuda.amp.GradScaler,
                 task_type: str) -> Tuple[Dict[str, float], float]:
    """
    Train model for one epoch.

    Args:
        epoch: Current epoch number (not used inside this method; kept for
            interface compatibility)
        scaler: Gradient scaler for mixed precision training
        task_type: Type of task ('multiclass' or 'multilabel')

    Returns:
        Tuple of (metrics dictionary, average per-batch loss). If the training
        loader yields no batches, ``({}, 0.0)`` is returned.
    """
    self.model.train()

    use_amp = self.device.type == 'cuda'
    is_multiclass = task_type == 'multiclass'

    loss_total = 0.0
    num_batches = 0
    label_chunks = []
    logit_chunks = []

    for inp in tqdm(self.dataset.train_loader,
                    desc="Training",      # Progress bar prefix
                    leave=True,           # Keep progress bar after completion
                    dynamic_ncols=True):
        # Only the labels are converted/moved here; the raw batch is handed
        # to the model as-is.
        labels = self._prepare_batch_data(inp, task_type)

        self.optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = self.model(inp)
            loss_train = self.model.loss(output, labels.long() if is_multiclass else labels)

        scaler.scale(loss_train).backward()
        scaler.step(self.optimizer)
        scaler.update()

        loss_total += float(loss_train.item())
        num_batches += 1

        label_chunks.append(labels.detach().cpu().numpy())
        # Under autocast the logits can be half precision; cast to float32 so
        # the NumPy conversion is always supported and metrics are not computed
        # on reduced-precision values.
        logit_chunks.append(output.detach().float().cpu().numpy())

    if num_batches == 0:
        # np.concatenate raises on an empty list; nothing to score.
        return {}, 0.0

    # Compute metrics
    y_true_epoch = np.concatenate(label_chunks, axis=0)
    y_pred_logits_epoch = np.concatenate(logit_chunks, axis=0)
    avg_loss = loss_total / num_batches
    metrics = self._compute_metrics(y_true_epoch, y_pred_logits_epoch, task_type)
    return metrics, avg_loss
def _evaluate(self, loader, task_type: str) -> Tuple[Dict[str, float], float]:
    """
    Evaluate model on given data loader.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        loader: Data loader for evaluation
        task_type: Type of task ('multiclass' or 'multilabel')

    Returns:
        Tuple of (metrics dictionary, average per-batch loss). If ``loader``
        yields no batches, ``({}, 0.0)`` is returned.
    """
    self.model.eval()

    use_amp = self.device.type == 'cuda'
    is_multiclass = task_type == 'multiclass'

    loss_total = 0.0
    num_batches = 0
    label_chunks = []
    logit_chunks = []

    with torch.no_grad():
        for inp in loader:
            # Only the labels are converted/moved here; the raw batch is
            # handed to the model as-is.
            labels = self._prepare_batch_data(inp, task_type)

            with torch.cuda.amp.autocast(enabled=use_amp):
                output = self.model(inp)
                loss = self.model.loss(output, labels.long() if is_multiclass else labels)

            loss_total += float(loss.item())
            num_batches += 1

            label_chunks.append(labels.detach().cpu().numpy())
            # Under autocast the logits can be half precision; cast to float32
            # so the NumPy conversion is always supported and metrics are not
            # computed on reduced-precision values.
            logit_chunks.append(output.detach().float().cpu().numpy())

    if num_batches == 0:
        # np.concatenate raises on an empty list; nothing to score.
        return {}, 0.0

    # Compute metrics
    y_label_np = np.concatenate(label_chunks, axis=0)
    y_pred_logits_np = np.concatenate(logit_chunks, axis=0)
    avg_loss = loss_total / num_batches
    metrics = self._compute_metrics(y_label_np, y_pred_logits_np, task_type)
    return metrics, avg_loss