scxpand.util.plots


Functions

plot_roc_curve(labels, probs_pred[, ...])

Plot ROC curve for binary classification and calculate AUROC.

plot_roc_curves_per_strata(y_true, ...[, ...])

Plot ROC curves for each stratum in a grid of subplots and calculate AUROC scores.

scxpand.util.plots.plot_roc_curve(labels, probs_pred, show_plot=False, plot_save_dir=None, plot_name='roc_curve', title='Receiver Operating Characteristic (ROC) Curve')

Plot ROC curve for binary classification and calculate AUROC.

Creates a publication-ready plot of the ROC curve (true positive rate versus false positive rate) across all classification thresholds, and optionally saves it to disk.
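The ROC curve is obtained by sweeping the decision threshold over the predicted probabilities and recording the false and true positive rates at each step. The sketch below computes those points and the trapezoidal area under them in plain Python, to illustrate what the plot and the returned AUROC represent. It is not necessarily how `plot_roc_curve` computes them internally (this reference does not say whether it delegates to a library such as scikit-learn), and `roc_points` and `auc_trapezoid` are hypothetical helper names.

```python
from typing import Sequence


def roc_points(labels: Sequence[int], probs: Sequence[float]) -> list[tuple[float, float]]:
    """(FPR, TPR) pairs obtained by lowering the threshold past each distinct score.

    Labels are assumed to be 0 or 1.
    """
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("an ROC curve needs at least one positive and one negative label")
    # Highest scores first: each step admits more samples as predicted positive
    pairs = sorted(zip(probs, labels), key=lambda pair: pair[0], reverse=True)
    points = [(0.0, 0.0)]  # threshold above every score: nothing predicted positive
    tp = fp = 0
    i = 0
    while i < len(pairs):
        score = pairs[i][0]
        # Samples tied at the same score cross the threshold together
        while i < len(pairs) and pairs[i][0] == score:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / n_neg, tp / n_pos))
    return points


def auc_trapezoid(points: Sequence[tuple[float, float]]) -> float:
    """Area under a piecewise-linear curve via the trapezoidal rule."""
    return sum((x1 - x0) * (y0 + y1) / 2 for (x0, y0), (x1, y1) in zip(points, points[1:]))


pts = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(pts)                 # [(0.0, 0.0), (0.0, 0.5), (0.5, 0.5), (0.5, 1.0), (1.0, 1.0)]
print(auc_trapezoid(pts))  # 0.75
```

Grouping tied scores into a single step is what makes the trapezoidal area agree with the pairwise ranking definition of AUROC, in which ties count as one half.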

Parameters:
  • labels – True binary labels (0 or 1).

  • probs_pred – Predicted probabilities in [0, 1] from the model.

  • show_plot (bool (default: False)) – Whether to display the plot interactively.

  • plot_save_dir (Path | None (default: None)) – Directory in which to save the plot. If None, the plot is not saved.

  • plot_name (str (default: 'roc_curve')) – Filename for saved plot (without extension).

  • title (str (default: 'Receiver Operating Characteristic (ROC) Curve')) – Plot title text.

Return type:

float

Returns:

AUROC score (Area Under the ROC Curve).
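The returned AUROC has a standard ranking interpretation (stated here as general background, not taken from the scxpand source): it is the probability that a randomly drawn positive sample receives a higher predicted probability than a randomly drawn negative sample, with ties counted as one half. With n_+ positive and n_- negative samples, labels y_i, and predicted probabilities p_i:

```latex
\mathrm{AUROC} = \frac{1}{n_+ \, n_-} \sum_{i:\, y_i = 1} \; \sum_{j:\, y_j = 0}
  \left[ \mathbb{1}(p_i > p_j) + \tfrac{1}{2}\, \mathbb{1}(p_i = p_j) \right]
```

A value of 0.5 corresponds to chance-level ranking and 1.0 to perfect separation of the two classes.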

scxpand.util.plots.plot_roc_curves_per_strata(y_true, y_pred_prob, obs_df, strata_columns, show_plot=True, plot_save_dir=None, save_results=True, max_cols=2)

Plot ROC curves for each stratum in a grid of subplots and calculate AUROC scores.

Parameters:
  • y_true (ndarray) – True binary labels (0 or 1).

  • y_pred_prob (ndarray) – Predicted probabilities in [0, 1] from the model.

  • obs_df (DataFrame) – DataFrame of observation data that contains the stratification columns.

  • strata_columns (list[str]) – List of column names to use for stratification.

  • show_plot (bool (default: True)) – Whether to display the plots interactively.

  • plot_save_dir (Path | None (default: None)) – Directory in which to save the plots. If None, the plots are not saved.

  • save_results (bool (default: True)) – Whether to save the AUROC results to a JSON file.

  • max_cols (int (default: 2)) – Maximum number of columns in the subplot grid.

Return type:

dict[str, float]

Returns:

Dictionary mapping stratum names to AUROC scores.
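A minimal sketch of what per-stratum AUROC computation and grid layout involve, under several assumptions: each stratum is one value of one column in strata_columns, strata containing only one class are skipped (AUROC is undefined there), and a plain column-to-values dict stands in for obs_df. The helpers `pairwise_auroc`, `auroc_per_strata`, and `grid_shape` and the "column=value" key format are hypothetical; the actual stratum names, handling of single-class strata, and grid logic of plot_roc_curves_per_strata may differ.

```python
import math
from collections import defaultdict
from typing import Mapping, Sequence


def pairwise_auroc(labels: Sequence[int], probs: Sequence[float]) -> float:
    """AUROC as the fraction of (positive, negative) pairs ranked correctly; ties count 1/2."""
    pos = [p for p, y in zip(probs, labels) if y == 1]
    neg = [p for p, y in zip(probs, labels) if y == 0]
    wins = sum(1.0 if a > b else 0.5 if a == b else 0.0 for a in pos for b in neg)
    return wins / (len(pos) * len(neg))


def auroc_per_strata(
    y_true: Sequence[int],
    y_pred_prob: Sequence[float],
    obs: Mapping[str, Sequence[object]],  # hypothetical stand-in for obs_df: column -> values
    strata_columns: Sequence[str],
) -> dict[str, float]:
    results: dict[str, float] = {}
    for col in strata_columns:
        # Group row indices by the value they take in this column
        groups: dict[object, list[int]] = defaultdict(list)
        for idx, value in enumerate(obs[col]):
            groups[value].append(idx)
        for value, idxs in groups.items():
            labels = [y_true[i] for i in idxs]
            if len(set(labels)) < 2:
                continue  # assumed behavior: skip strata containing a single class
            probs = [y_pred_prob[i] for i in idxs]
            results[f"{col}={value}"] = pairwise_auroc(labels, probs)  # hypothetical key format
    return results


def grid_shape(n_plots: int, max_cols: int = 2) -> tuple[int, int]:
    """Rows and columns of a subplot grid holding n_plots panels, at most max_cols wide."""
    n_cols = max(1, min(n_plots, max_cols))
    n_rows = math.ceil(n_plots / n_cols)
    return n_rows, n_cols


y_true = [0, 1, 0, 1, 0, 1]
y_prob = [0.1, 0.9, 0.4, 0.3, 0.2, 0.8]
obs = {"tissue": ["blood", "blood", "blood", "blood", "tumor", "tumor"]}
print(auroc_per_strata(y_true, y_prob, obs, ["tissue"]))
# {'tissue=blood': 0.75, 'tissue=tumor': 1.0}
print(grid_shape(3, max_cols=2))  # (2, 2)
```

Capping the column count at max_cols and letting the number of rows grow keeps each panel a readable size as the number of strata increases.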