scxpand.autoencoders.ae_trainer

Functions

calculate_training_metrics(prob_pred, ...[, ...])

Calculate training metrics for logging.

run_ae_inference(model, batch_size[, ...])

Runs autoencoder inference.

run_ae_trainer(data_path, data_format, ...)

scxpand.autoencoders.ae_trainer.calculate_training_metrics(prob_pred, y_true, loss_outputs, optimizer, prm, threshold=0.5)

Calculate training metrics for logging.

Parameters:
  • prob_pred (ndarray) – Predicted probabilities from model output

  • y_true (ndarray) – True binary labels

  • loss_outputs (tuple[Tensor, ...]) – Tuple containing (loss, recon_loss, cls_loss, l1_loss, cat_loss)

  • optimizer (Optimizer) – Optimizer from which the current learning rate is read

  • prm (AutoEncoderParams) – AutoEncoder parameters, used to check whether categorical losses apply

  • threshold (float (default: 0.5)) – Classification threshold for binary predictions

Return type:

dict[str, float]

Returns:

Dictionary containing calculated metrics

scxpand.autoencoders.ae_trainer.run_ae_inference(model, batch_size, data_path=None, data_format=None, eval_row_inds=None, device=None, num_workers=0, adata=None)

Runs autoencoder inference. The input data is supplied either as a file path (data_path) or as an in-memory AnnData object (adata).

Return type:

ndarray
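The either/or input contract and the batch_size and eval_row_inds parameters can be illustrated with a minimal sketch of batched inference. It assumes a plain NumPy matrix X in place of adata.X, a hypothetical predict_fn callable in place of the autoencoder model, and that supplying both or neither data source is an error; none of these names come from scxpand, and the real function may load, batch and move data to device differently.

```python
from typing import Callable, List, Optional, Sequence

import numpy as np


def sketch_batched_inference(
    predict_fn: Callable[[np.ndarray], np.ndarray],
    batch_size: int,
    X: Optional[np.ndarray] = None,
    data_path: Optional[str] = None,
    eval_row_inds: Optional[Sequence[int]] = None,
) -> np.ndarray:
    """Hypothetical illustration only; not the scxpand implementation."""
    # Assumption: exactly one data source is given (file path or in-memory matrix).
    if (X is None) == (data_path is None):
        raise ValueError("Provide exactly one of data_path or X.")
    if X is None:
        # Reading from data_path is not sketched here.
        raise NotImplementedError("File-based loading is outside this sketch.")

    # Evaluate only the requested rows; default to every row.
    rows = np.arange(X.shape[0]) if eval_row_inds is None else np.asarray(eval_row_inds)

    outputs: List[np.ndarray] = []
    for start in range(0, len(rows), batch_size):
        batch = X[rows[start:start + batch_size]]
        outputs.append(np.asarray(predict_fn(batch)))
    if not outputs:
        raise ValueError("No rows selected for inference.")
    # Stack the per-batch outputs into one array, in the requested row order.
    return np.concatenate(outputs, axis=0)


# Stand-in "model": sums each row's features into a single score.
X = np.arange(12, dtype=float).reshape(6, 2)
scores = sketch_batched_inference(
    lambda batch: batch.sum(axis=1, keepdims=True),
    batch_size=2,
    X=X,
    eval_row_inds=[0, 2, 5],
)
print(scores.ravel())
```

With rows 0, 2 and 5 selected and batch_size=2, the sketch runs two batches and returns the scores 1, 9 and 21 in the order the rows were requested.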

scxpand.autoencoders.ae_trainer.run_ae_trainer(data_path, data_format, row_inds_train, row_inds_dev, save_path, prm, device, trial=None, score_metric='harmonic_avg/AUROC', resume=False, num_workers=0)

Return type:

BaseAutoencoder