scxpand.autoencoders.ae_losses

Functions

compute_total_autoencoder_loss(*, ...[, ...])

Compute the total loss as the sum of reconstruction loss, classification loss, and categorical losses.

create_autoencoder_recon_loss_function(prm)

Create reconstruction loss function based on parameters.

should_use_soft_loss(*, epoch, prm)

Check if the soft loss should be used at the given epoch.

Classes

MSELoss([eps])

Mean squared error loss.

NB([use_masking, eps])

Negative binomial loss.

ZINBLoss([eps])

Zero-inflated negative binomial loss.

class scxpand.autoencoders.ae_losses.MSELoss(eps=1e-08)
__init__(eps=1e-08)

Mean squared error loss.

Parameters:

eps (float (default: 1e-08)) – Small constant for numerical stability

mse_loss(x_genes_true, x_pred)
Return type:

Tensor

class scxpand.autoencoders.ae_losses.NB(use_masking=False, eps=1e-08)
__init__(use_masking=False, eps=1e-08)

Negative binomial loss.

Parameters:
  • use_masking (bool (default: False)) – Whether to use masking for NaN values

  • eps (float (default: 1e-08)) – Small constant for numerical stability

loss(x_genes_true, mu, theta)

Negative binomial negative log-likelihood loss.

Parameters:
  • x_genes_true (Tensor) – Observed counts

  • mu (Tensor) – Expected means (μ)

  • theta (Tensor) – Dispersion parameter

Return type:

Tensor

Returns:

Mean negative log-likelihood across batch
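As a reference point, the per-entry NB negative log-likelihood can be sketched in plain Python, with lists of floats standing in for Tensors. The sketch assumes the mean/dispersion parameterization common in single-cell count autoencoders, where θ acts as an inverse-dispersion ("size") parameter so that Var(Y) = μ + μ²/θ; the library's exact parameterization, placement of eps, and reduction are not shown on this page. The names nb_log_prob and nb_nll are hypothetical, and use_masking is modeled here as skipping NaN counts before averaging.

```python
import math


def nb_log_prob(k: float, mu: float, theta: float, eps: float = 1e-8) -> float:
    """Hypothetical helper: log NB(k; mu, theta), assuming Var = mu + mu**2 / theta."""
    log_theta_mu = math.log(theta + mu + eps)
    return (
        math.lgamma(k + theta)
        - math.lgamma(theta)
        - math.lgamma(k + 1.0)
        + theta * (math.log(theta + eps) - log_theta_mu)  # theta * log(theta / (theta + mu))
        + k * (math.log(mu + eps) - log_theta_mu)  # k * log(mu / (theta + mu))
    )


def nb_nll(x, mu, theta, use_masking=False, eps=1e-8):
    """Hypothetical helper: mean NLL over entries; NaN counts are skipped when masking."""
    terms = [
        -nb_log_prob(k, m, t, eps)
        for k, m, t in zip(x, mu, theta)
        if not (use_masking and math.isnan(k))
    ]
    return sum(terms) / len(terms)


# NB(0; mu=2, theta=3) = (3 / 5) ** 3, i.e. approximately 0.216
print(math.exp(nb_log_prob(0.0, 2.0, 3.0)))
# The NaN entry is ignored, so only the zero count contributes.
print(nb_nll([0.0, float("nan")], [2.0, 2.0], [3.0, 3.0], use_masking=True))
```

Working in log space with lgamma avoids overflowing the Gamma functions for large counts, which is why count-based NLL losses are usually written this way.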

class scxpand.autoencoders.ae_losses.ZINBLoss(eps=1e-08)
__init__(eps=1e-08)

Zero-inflated negative binomial loss.

Parameters:

eps (float (default: 1e-08)) – Small constant for numerical stability

nb_loss(x_genes_true, mu, theta)

Negative binomial negative log-likelihood without zero-inflation.

Parameters:
  • x_genes_true (Tensor) – Observed counts

  • mu (Tensor) – Expected means (μ)

  • theta (Tensor) – Dispersion parameter

Return type:

Tensor

Returns:

Mean negative log-likelihood across batch

zinb_loss(x_genes_true, mu, theta, pi)

Zero-inflated negative binomial negative log-likelihood.

For ZINB, the probability mass function is:

$$
\begin{aligned}
P(Y = 0) &= \pi + (1 - \pi)\,\mathrm{NB}(0; \mu, \theta) \\
P(Y = k) &= (1 - \pi)\,\mathrm{NB}(k; \mu, \theta), \quad k > 0
\end{aligned}
$$

Parameters:
  • x_genes_true (Tensor) – Observed counts

  • mu (Tensor) – Expected means (μ)

  • theta (Tensor) – Dispersion parameter

  • pi (Tensor) – Zero-inflation parameter (probability of structural zero)

Returns:

nll – Negative log-likelihood across batch

Return type:

Tensor
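The zero-inflated case follows the mass function above: a zero can come from the structural-zero component (probability π) or from the NB component, while positive counts come only from the NB component. This plain-Python sketch reuses the same assumed NB parameterization (Var = μ + μ²/θ); zinb_log_prob and zinb_nll are hypothetical names, and averaging over entries is an assumption, since this page says only "across batch".

```python
import math


def nb_log_prob(k, mu, theta, eps=1e-8):
    """Hypothetical helper: log NB(k; mu, theta) with Var = mu + mu**2 / theta (assumed)."""
    log_theta_mu = math.log(theta + mu + eps)
    return (
        math.lgamma(k + theta) - math.lgamma(theta) - math.lgamma(k + 1.0)
        + theta * (math.log(theta + eps) - log_theta_mu)
        + k * (math.log(mu + eps) - log_theta_mu)
    )


def zinb_log_prob(k, mu, theta, pi, eps=1e-8):
    """Hypothetical helper: log ZINB(k), a structural zero with prob. pi, else an NB draw."""
    if k == 0:
        # P(Y = 0) = pi + (1 - pi) * NB(0; mu, theta)
        nb_zero = math.exp(nb_log_prob(0.0, mu, theta, eps))
        return math.log(pi + (1.0 - pi) * nb_zero + eps)
    # P(Y = k) = (1 - pi) * NB(k; mu, theta) for k > 0
    return math.log(1.0 - pi + eps) + nb_log_prob(k, mu, theta, eps)


def zinb_nll(x, mu, theta, pi, eps=1e-8):
    """Hypothetical helper: NLL averaged over entries (the reduction is an assumption)."""
    terms = [-zinb_log_prob(k, m, t, p, eps) for k, m, t, p in zip(x, mu, theta, pi)]
    return sum(terms) / len(terms)


# With pi = 0.5: P(Y = 0) = 0.5 + 0.5 * 0.216, i.e. approximately 0.608
print(math.exp(zinb_log_prob(0.0, 2.0, 3.0, 0.5)))
print(zinb_nll([0.0, 4.0], [2.0, 2.0], [3.0, 3.0], [0.5, 0.5]))
```

Setting π = 0 recovers the plain NB likelihood, which is a quick sanity check when comparing the two losses.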

scxpand.autoencoders.ae_losses.compute_total_autoencoder_loss(*, x_genes_true, mu, pi, theta, latent_vec, class_logit, y_true, y_soft_gt, recon_loss_fn, bce_loss, prm, epoch, categorical_logits=None, categorical_targets=None)

Compute the total loss as the sum of reconstruction loss, classification loss, and categorical losses.

Parameters:
  • x_genes_true (Tensor) – Observed counts, shape [batch_size, n_genes]

  • mu (Tensor) – Predicted means, shape [batch_size, n_genes]

  • pi (Tensor | None) – Zero-inflation probabilities, shape [batch_size, n_genes] (None for MSE/NB)

  • theta (Tensor | None) – Dispersion parameters, shape [batch_size, n_genes] (None for MSE)

  • latent_vec (Tensor) – Latent vector from encoder, shape [batch_size, latent_dim]

  • class_logit (Tensor) – Classification logits, shape [batch_size]

  • y_true (Tensor) – True binary labels, shape [batch_size]

  • y_soft_gt (Tensor | None) – Soft binary labels, shape [batch_size] in range [0, 1] (None if not using soft loss)

  • recon_loss_fn (ZINBLoss | NB | MSELoss) – Loss function instance (ZINBLoss, NB, or MSELoss)

  • bce_loss (Module) – Binary cross-entropy loss module

  • prm (AutoEncoderParams) – AutoEncoderParams containing all loss weights and regularization parameters

  • epoch (int) – Current training epoch (used for soft loss scheduling)

  • categorical_logits (dict[str, Tensor] | None (default: None)) – Dict of logits for categorical features, each shape [batch_size, n_classes]

  • categorical_targets (dict[str, Tensor] | None (default: None)) – Dict of target indices for categorical features, each shape [batch_size]

Returns:

  • total_loss (torch.Tensor) – Sum of reconstruction, classification, and categorical losses

  • recon_loss (torch.Tensor) – Reconstruction loss

  • cls_loss (torch.Tensor) – Classification loss

  • l1_loss (torch.Tensor) – L1 regularization loss

  • cat_loss (torch.Tensor) – Sum of categorical classification losses
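The docstring lists the ingredients of the total loss but not how they are weighted. A minimal sketch, with plain floats in place of Tensors and recon_loss passed in already computed, might combine them as below. Everything beyond the documented inputs and outputs is an assumption: the LossWeights fields are hypothetical stand-ins for whatever AutoEncoderParams holds, the use_soft flag stands in for the result of should_use_soft_loss, soft labels are assumed to replace the hard labels in the BCE term, the L1 term is taken as the mean absolute latent value, and it enters the total only through a weight that defaults to 0, so the default total matches the documented "sum of reconstruction, classification, and categorical losses".

```python
import math
from dataclasses import dataclass


@dataclass
class LossWeights:
    """Hypothetical stand-in for the loss weights held by AutoEncoderParams."""
    cls_weight: float = 1.0
    cat_weight: float = 1.0
    l1_weight: float = 0.0


def bce_with_logits(logit: float, target: float) -> float:
    # Stable form of -[t * log(sigmoid(z)) + (1 - t) * log(1 - sigmoid(z))]
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def cross_entropy(logits: list[float], target: int) -> float:
    # log-sum-exp with a max shift for stability, minus the target logit
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[target]


def total_loss_sketch(recon_loss, latent_vec, class_logit, y_true, y_soft_gt, use_soft, w,
                      categorical_logits=None, categorical_targets=None):
    # Assumed: soft labels replace the hard ones only while the soft loss is active.
    targets = y_soft_gt if use_soft and y_soft_gt is not None else y_true
    cls_loss = sum(bce_with_logits(z, t) for z, t in zip(class_logit, targets)) / len(class_logit)

    # Assumed L1 penalty: mean absolute value of the latent vectors.
    flat = [v for row in latent_vec for v in row]
    l1_loss = sum(abs(v) for v in flat) / len(flat)

    # One mean cross-entropy per categorical feature, summed across features.
    cat_loss = 0.0
    for name, logits_batch in (categorical_logits or {}).items():
        per_sample = [cross_entropy(z, t) for z, t in zip(logits_batch, categorical_targets[name])]
        cat_loss += sum(per_sample) / len(per_sample)

    total = recon_loss + w.cls_weight * cls_loss + w.cat_weight * cat_loss + w.l1_weight * l1_loss
    return total, recon_loss, cls_loss, l1_loss, cat_loss


total, recon, cls, l1, cat = total_loss_sketch(
    recon_loss=1.0,
    latent_vec=[[1.0, -1.0]],
    class_logit=[0.0],
    y_true=[1.0],
    y_soft_gt=None,
    use_soft=False,
    w=LossWeights(),
    categorical_logits={"batch": [[0.0, 0.0]]},
    categorical_targets={"batch": [0]},
)
print(total)  # 1 + 2 * log(2), approximately 2.3863
```

Returning the individual components next to the total makes it straightforward to log each term separately while optimizing the combined value.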

scxpand.autoencoders.ae_losses.create_autoencoder_recon_loss_function(prm)

Create reconstruction loss function based on parameters.

Return type:

Module

scxpand.autoencoders.ae_losses.should_use_soft_loss(*, epoch, prm)

Check if the soft loss should be used at the given epoch.

Return type:

bool