scxpand.mlp.mlp_trainer
Functions
| Function | Description |
|---|---|
| run_mlp_inference | Run inference using a trained MLP model. |
| run_trainer | Runs the training loop for the multi-layer perceptron (MLP) model. |
- scxpand.mlp.mlp_trainer.run_mlp_inference(data_path=None, data_format=None, eval_row_inds=None, model=None, device=None, batch_size=1024, num_workers=0, adata=None)
Run inference using a trained MLP model.
Input data can be supplied either as data_path or as adata (an in-memory AnnData object).
- Returns:
pred_prob – predicted probabilities of the positive class, an array of shape [N]
- Return type:
np.ndarray
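run_mlp_inference returns a 1-D array of positive-class probabilities, one value per cell. A minimal sketch of turning such an array into hard 0/1 labels, assuming a 0.5 decision threshold; the threshold, the helper name probs_to_labels, and the sample array are hypothetical illustrations, not scxpand defaults or outputs:

```python
import numpy as np


def probs_to_labels(pred_prob: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Convert positive-class probabilities (shape [N]) to 0/1 labels.

    The 0.5 default is an assumption for illustration; pick a threshold
    that matches your own precision/recall trade-off.
    """
    pred_prob = np.asarray(pred_prob, dtype=float)
    if pred_prob.ndim != 1:
        raise ValueError(f"expected shape [N], got {pred_prob.shape}")
    # A cell is labelled positive when its probability reaches the threshold.
    return (pred_prob >= threshold).astype(int)


# Hypothetical stand-in for the array returned by run_mlp_inference.
pred_prob = np.array([0.91, 0.12, 0.50, 0.33])
labels = probs_to_labels(pred_prob)
print(labels)  # [1 0 1 0]
```

Because the output keeps the [N] shape, labels[i] corresponds to the same cell as pred_prob[i].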
- scxpand.mlp.mlp_trainer.run_trainer(data_path, data_format, row_inds_train, row_inds_dev, save_path, prm, device, trial=None, score_metric='harmonic_avg/AUROC', resume=False, num_workers=0)
Runs the training loop for the multi-layer perceptron (MLP) model.
- Parameters:
  - data_format (DataFormat) – DataFormat object containing preprocessing parameters
  - row_inds_train (ndarray) – cell indices of the training data (in the full dataset)
  - row_inds_dev (ndarray) – cell indices of the validation data (in the full dataset)
  - save_path (Path) – path to save results
  - prm (MLPParam) – MLPParam object containing model parameters
  - device (str) – device to train on ('cuda', 'mps', or 'cpu')
  - trial (Trial | None, default: None) – Optuna trial for hyperparameter optimization
  - score_metric (str, default: 'harmonic_avg/AUROC') – metric to use for model selection
  - resume (bool, default: False) – whether to resume training from a checkpoint
  - num_workers (int, default: 0) – number of worker processes for data loading
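run_trainer takes the training and validation sets as arrays of cell indices into the full dataset. A minimal sketch of building a disjoint random split with NumPy; the helper name make_train_dev_split, the 20% validation fraction, the seed, and n_cells are hypothetical, and real splits may need to respect grouping (for example by patient or sample), which this sketch does not do:

```python
import numpy as np


def make_train_dev_split(n_cells: int, dev_frac: float = 0.2, seed: int = 0):
    """Return disjoint, sorted index arrays (row_inds_train, row_inds_dev).

    Indices refer to rows (cells) of the full dataset, matching the
    meaning of row_inds_train / row_inds_dev in run_trainer.
    """
    if not 0.0 < dev_frac < 1.0:
        raise ValueError("dev_frac must be in (0, 1)")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_cells)          # shuffle all cell indices
    n_dev = int(round(n_cells * dev_frac))   # size of the validation set
    row_inds_dev = np.sort(perm[:n_dev])
    row_inds_train = np.sort(perm[n_dev:])
    return row_inds_train, row_inds_dev


row_inds_train, row_inds_dev = make_train_dev_split(n_cells=10_000)
print(len(row_inds_train), len(row_inds_dev))  # 8000 2000
```

The two arrays never share a cell and together cover every row, so no validation cell leaks into training; fixing the seed keeps the split reproducible when training is resumed.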