scxpand.mlp.mlp_trainer


Functions

run_mlp_inference([data_path, data_format, ...])

Run inference using a trained MLP model.

run_trainer(data_path, data_format, ...[, ...])

Runs the training loop for the multi-layer perceptron (MLP) model.

scxpand.mlp.mlp_trainer.run_mlp_inference(data_path=None, data_format=None, eval_row_inds=None, model=None, device=None, batch_size=1024, num_workers=0, adata=None)

Run inference using a trained MLP model.

Accepts either data_path or adata (an AnnData object) as the input data source.

Returns:

pred_prob – predicted probabilities of the positive class, shape [N]

Return type:

np.ndarray
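run_mlp_inference returns one positive-class probability per evaluated cell. A common next step is turning those probabilities into hard labels. The snippet below is a minimal sketch of that step: probabilities_to_labels is a hypothetical helper (not part of scxpand), the 0.5 threshold is an illustrative assumption rather than a library default, and the input array is made up instead of real model output.

```python
import numpy as np


def probabilities_to_labels(pred_prob: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Turn positive-class probabilities (shape [N]) into 0/1 labels.

    Hypothetical helper; the 0.5 threshold is an assumption, not a scxpand default.
    """
    pred_prob = np.asarray(pred_prob, dtype=float)
    # The documented return value is 1-D, so reject anything else early.
    if pred_prob.ndim != 1:
        raise ValueError(f"expected shape [N], got {pred_prob.shape}")
    if np.any((pred_prob < 0) | (pred_prob > 1)):
        raise ValueError("probabilities must lie in [0, 1]")
    # Cells at or above the threshold are labeled as the positive class (1).
    return (pred_prob >= threshold).astype(np.int64)


# Made-up stand-in for the array returned by run_mlp_inference.
pred_prob = np.array([0.05, 0.62, 0.91, 0.48])
labels = probabilities_to_labels(pred_prob)
print(labels)  # [0 1 1 0]
```

In practice the threshold is often tuned on a validation set rather than fixed at 0.5.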

scxpand.mlp.mlp_trainer.run_trainer(data_path, data_format, row_inds_train, row_inds_dev, save_path, prm, device, trial=None, score_metric='harmonic_avg/AUROC', resume=False, num_workers=0)

Runs the training loop for the multi-layer perceptron (MLP) model.

Parameters:
  • data_path (str | Path) – Path to the AnnData file

  • data_format (DataFormat) – DataFormat object containing preprocessing parameters

  • row_inds_train (ndarray) – Cell indices of the training data (in the full dataset)

  • row_inds_dev (ndarray) – Cell indices of the validation data (in the full dataset)

  • save_path (Path) – Path where results are saved

  • prm (MLPParam) – MLPParam object containing model parameters

  • device (str) – Device to train on (‘cuda’, ‘mps’, or ‘cpu’)

  • trial (Trial | None (default: None)) – Optuna trial for hyperparameter optimization

  • score_metric (str (default: 'harmonic_avg/AUROC')) – Metric to use for model selection

  • resume (bool (default: False)) – Whether to resume training from a checkpoint

  • num_workers (int (default: 0)) – Number of worker processes for data loading
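row_inds_train and row_inds_dev are index arrays into the full dataset, and device is one of 'cuda', 'mps', or 'cpu'. The sketch below shows one way to prepare these arguments, under stated assumptions: split_row_inds and pick_device are hypothetical helpers, the random 80/20 split is illustrative and not necessarily how scxpand partitions data (a real study might split by patient or sample instead), and device detection assumes PyTorch as the backend, falling back to 'cpu' if it is not installed. The file paths are placeholders, and because data_format and prm must be built separately, the run_trainer call is left as a comment.

```python
from pathlib import Path

import numpy as np


def split_row_inds(n_cells: int, dev_fraction: float = 0.2, seed: int = 0):
    """Randomly split row indices 0..n_cells-1 into disjoint train/dev arrays.

    Hypothetical helper; illustrative only.
    """
    if not 0.0 < dev_fraction < 1.0:
        raise ValueError("dev_fraction must be in (0, 1)")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_cells)
    n_dev = int(round(n_cells * dev_fraction))
    # Sort so each array lists cells in their original dataset order.
    return np.sort(perm[n_dev:]), np.sort(perm[:n_dev])


def pick_device() -> str:
    """Return 'cuda', 'mps', or 'cpu' depending on what PyTorch reports."""
    try:
        import torch
    except ImportError:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps only exists in newer PyTorch builds.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


row_inds_train, row_inds_dev = split_row_inds(n_cells=10_000)
device = pick_device()
save_path = Path("results/mlp")  # placeholder output location

# With data_format and prm constructed elsewhere, training would be started as:
# run_trainer(data_path="data.h5ad", data_format=data_format,
#             row_inds_train=row_inds_train, row_inds_dev=row_inds_dev,
#             save_path=save_path, prm=prm, device=device)
```

Sorting the index arrays is a convenience for reproducibility and readability; it does not change which cells end up in each split.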