scxpand.linear.linear_trainer

Linear model training components - consolidated trainer with all functionality.

Functions

run_linear_training(base_save_dir, prm, ...)

Run SGDClassifier model training and evaluation with support for logistic regression and SVM.

Classes

LinearBatchPredictor(dataset, dataloader)

Handles batch prediction for linear models.

LinearTrainLogger(base_save_dir[, trial])

Specialized logger for linear model training that extends the existing TrainLogger.

LinearTrainer(prm, base_save_dir)

Consolidated linear model trainer with data preparation, training, and evaluation.

TrainingSession(prm, score_metric)

Manages a single training session with state tracking.

class scxpand.linear.linear_trainer.LinearBatchPredictor(dataset, dataloader)

Handles batch prediction for linear models.

__init__(dataset, dataloader)
predict_all(model)

Predict probabilities for all samples in the dataset.

Return type:

ndarray

predict_batch(model, X_batch)

Predict probabilities for a single batch.

Return type:

ndarray

class scxpand.linear.linear_trainer.LinearTrainLogger(base_save_dir, trial=None)

Specialized logger for linear model training that extends the existing TrainLogger.

__init__(base_save_dir, trial=None)
init_linear_training(n_epochs, n_batches_per_epoch)

Initialize training parameters for linear models.

Return type:

None

log_epoch_end(epoch)

Log the end of an epoch with timing information.

Return type:

None

log_training_summary()

Log a training completion summary using the existing infrastructure.

Return type:

None

log_validation_metrics(epoch, dev_set_metrics, score_metric)

Log validation metrics with hierarchical display.

Return type:

None

update_best_score(score, epoch, metrics)

Update the best model score and metrics.

Return type:

None

class scxpand.linear.linear_trainer.LinearTrainer(prm, base_save_dir)

Consolidated linear model trainer with data preparation, training, and evaluation.

__init__(prm, base_save_dir)
evaluate_model(model, eval_dataset, eval_dataloader, train_logger, score_metric, epoch)

Evaluate the model on the validation set using a DataLoader.

Return type:

tuple[float, dict, ndarray]

finalize_training(model, eval_dataset, eval_dataloader, train_logger, trial, score_metric)

Finalize training by evaluating and saving the final model.

Return type:

dict

prepare_data_and_model(dev_ratio, data_path, num_workers=0)

Prepare data and initialize model for training.

Return type:

tuple[SGDClassifier, CellsDataset, DataLoader, CellsDataset, DataLoader]

run_training(dev_ratio=0.2, trial=None, score_metric='harmonic_avg/AUROC', data_path=None, num_workers=0)

Run the complete training process.

Return type:

dict[str, dict[str, float]]

train_epoch(model, train_dataloader, train_logger, classes, epoch)

Train the model for one epoch using a DataLoader, following scikit-learn SGD best practices.

Return type:

None

class scxpand.linear.linear_trainer.TrainingSession(prm, score_metric)

Manages a single training session with state tracking.

__init__(prm, score_metric)
check_early_stopping(current_score, epoch)

Check whether early stopping should be triggered and update the patience counter.

Return type:

bool

update_best_model(model, current_score, epoch, dev_set_metrics, logger)

Update the best model state if the current score is better.

Return type:

None

scxpand.linear.linear_trainer.run_linear_training(base_save_dir, prm, data_path, dev_ratio=0.2, trial=None, score_metric='harmonic_avg/AUROC', num_workers=0, resume=False)

Run SGDClassifier model training and evaluation with support for logistic regression and SVM.

Note: Unlike PyTorch models, linear models do not support resuming from checkpoints. The resume parameter is accepted only for API compatibility; resuming is not implemented.

Return type:

dict[str, dict[str, float]]