# Example: run an OpenDDI pipeline end to end.

import torch
from openddi.utils.logger import create_logger
from openddi.parms_setting import settings, set_random_seed
from openddi.data.dataset_manager import dataset_manager
from openddi.models.model_manager import model_manager
from openddi.pipeline import Pipeline

# Seed the RNGs before anything else runs so results are reproducible.
# NOTE(review): the seed is hard-coded to 1; if `settings()` exposes a seed
# option it is not consulted here — confirm whether that is intended.
set_random_seed(1, deterministic=True)

# Parse command-line arguments into the shared `args` namespace.
args = settings()

# Create the run logger, configured from `args`.
logger = create_logger(args)

# Derive the CUDA flag from the requested device. Matching on the "cuda"
# prefix (instead of strict equality with 'cuda') also recognises indexed
# devices such as "cuda:0", for which the equality check yielded False.
# NOTE(review): this does not check `torch.cuda.is_available()`; presumably
# an unavailable GPU is handled (or fails) downstream — confirm.
args.cuda = str(args.device).startswith('cuda')

# Load the dataset. The manager instances use *_mgr names so they do not
# shadow the imported `dataset_manager` / `model_manager` factories.
dataset_mgr = dataset_manager(args)
ddi_dataset = dataset_mgr.load_dataset()
ddi_dataset.load_data()

# Build the model: when `args.origin` is set, the "origin" loader is used and
# is given the loaded dataset; otherwise the default loader is used.
model_mgr = model_manager(args)
if args.origin:
    model = model_mgr.load_origin_model(ddi_dataset)
else:
    model = model_mgr.load_model()

# Adam over all model parameters. `lr` and `weight_decay` are coerced with
# float() — presumably because they may arrive as strings (e.g. "1e-3") from
# the CLI/config; confirm against `settings()`.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=float(args.lr),
    weight_decay=float(args.weight_decay),
)

# Wire everything into the pipeline and run it.
ddi_pipeline = Pipeline(args, logger, ddi_dataset, model, optimizer)
ddi_pipeline.run()