scxpand.util.train_util

Functions

check_early_stopping(current_score, ...)

get_lr_scheduler(optimizer, ...)

get_optimizer(model, optimizer_params)

Get an optimizer that applies weight decay selectively.

report_to_optuna_and_handle_pruning(trial, ...)

Report current score to Optuna trial and handle pruning, with duplicate prevention.

update_lr_scheduler(lr_scheduler, ...)

Update the learning rate scheduler.

scxpand.util.train_util.check_early_stopping(current_score, log_manager, patience_counter, patience_limit, epoch)
Return type:

tuple[int, bool]
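This page does not document the stopping rule. The following is a minimal sketch of the conventional patience-based rule. It assumes a higher score is better and that the returned tuple is (updated patience counter, should-stop flag). The function name and the `best_score` parameter are hypothetical. `best_score` stands in for whatever `log_manager` tracks in scxpand.

```python
def check_early_stopping_sketch(
    current_score: float,
    best_score: float,  # hypothetical: stands in for state kept by log_manager
    patience_counter: int,
    patience_limit: int,
) -> tuple[int, bool]:
    """Hypothetical patience-based early stopping (higher score is better)."""
    if current_score > best_score:
        # Improvement: reset the count of epochs without progress.
        return 0, False
    # No improvement: count this epoch and stop once the limit is reached.
    patience_counter += 1
    return patience_counter, patience_counter >= patience_limit


# Usage: stop after three consecutive epochs without a new best score.
best, counter, stop = float("-inf"), 0, False
for epoch, score in enumerate([0.5, 0.6, 0.6, 0.55, 0.58]):
    counter, stop = check_early_stopping_sketch(score, best, counter, patience_limit=3)
    best = max(best, score)
    if stop:
        break
```

A score equal to the current best counts as "no improvement" here. A real implementation might instead require a minimum improvement margin.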

scxpand.util.train_util.get_lr_scheduler(optimizer, lr_scheduler_params, n_epochs, train_loader, init_learning_rate)
Return type:

LRScheduler | None

scxpand.util.train_util.get_optimizer(model, optimizer_params)

Get an optimizer that applies weight decay selectively.

Excludes LayerNorm weights and biases from weight decay, following common practice, so that only the remaining weights are regularized.

Return type:

Optimizer
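Selective weight decay is usually implemented by splitting parameters into two optimizer parameter groups. The following is a minimal sketch, not scxpand's code. The helper name and the keyword list are hypothetical. It assumes name-based matching in the style of `model.named_parameters()` output, which only catches LayerNorm modules whose attribute names contain the listed keywords. The parameters themselves can be any objects, which keeps the sketch runnable without PyTorch.

```python
from collections.abc import Iterable
from typing import Any


def split_weight_decay_groups(
    named_parameters: Iterable[tuple[str, Any]],
    weight_decay: float,
    # Hypothetical keyword list; scxpand may detect LayerNorm parameters differently.
    no_decay_keywords: tuple[str, ...] = ("bias", "LayerNorm.weight", "layer_norm.weight"),
) -> list[dict[str, Any]]:
    """Split parameters into a decayed group and a no-decay group by name."""
    decay, no_decay = [], []
    for name, param in named_parameters:
        if any(key in name for key in no_decay_keywords):
            no_decay.append(param)  # biases and LayerNorm weights: no decay
        else:
            decay.append(param)  # all other weights: regular weight decay
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Usage with placeholder strings instead of tensors.
params = [("encoder.weight", "W"), ("encoder.bias", "b"), ("LayerNorm.weight", "g")]
groups = split_weight_decay_groups(params, weight_decay=0.01)
```

With PyTorch, `torch.optim` optimizers accept a list of per-group option dicts. The result can therefore be passed directly, e.g. `torch.optim.AdamW(split_weight_decay_groups(model.named_parameters(), 0.01), lr=1e-3)`.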

scxpand.util.train_util.report_to_optuna_and_handle_pruning(trial, current_score, epoch)

Report current score to Optuna trial and handle pruning, with duplicate prevention.

This function prevents duplicate epoch reports that can occur when resuming from checkpoints, ensuring the Optuna dashboard shows accurate progress.

Parameters:
  • trial (Trial | None) – The Optuna trial object (or None).

  • current_score (float) – The current score to report.

  • epoch (int) – The current epoch number.

Raises:

optuna.TrialPruned – If the trial should be pruned.

Return type:

None
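The page does not say how duplicates are detected. The following is a minimal sketch of one way to do it, using only the real Optuna `Trial` methods `report(value, step)`, `should_prune()`, `set_user_attr(key, value)`, and the `user_attrs` property. The function name and the `last_reported_epoch` attribute key are hypothetical. Storing the last reported epoch as a user attribute is an assumption, chosen because it persists in the study storage across a checkpoint resume. A stand-in exception lets the sketch run without Optuna installed.

```python
from typing import Any

try:
    from optuna import TrialPruned
except ImportError:  # stand-in so the sketch runs without Optuna installed

    class TrialPruned(Exception):
        """Placeholder for optuna.TrialPruned."""


LAST_EPOCH_KEY = "last_reported_epoch"  # hypothetical user-attribute name


def report_and_maybe_prune(trial: Any, current_score: float, epoch: int) -> None:
    """Report at most once per epoch, then raise TrialPruned if the pruner says so."""
    if trial is None:
        return  # not running inside an Optuna study
    last_epoch = trial.user_attrs.get(LAST_EPOCH_KEY, -1)
    if epoch <= last_epoch:
        return  # already reported, e.g. epoch replayed after a checkpoint resume
    trial.report(current_score, step=epoch)
    trial.set_user_attr(LAST_EPOCH_KEY, epoch)
    if trial.should_prune():
        raise TrialPruned(f"Trial pruned at epoch {epoch}")
```

In a training loop, this would be called once per epoch after validation. The `TrialPruned` exception is left to propagate to `study.optimize`, which records the trial as pruned.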

scxpand.util.train_util.update_lr_scheduler(lr_scheduler, lr_scheduler_params, score)

Update the learning rate scheduler.