scxpand.autoencoders.ae_losses
Functions
- compute_total_autoencoder_loss – Compute the total loss as the sum of reconstruction loss, classification loss, and categorical losses.
- create_autoencoder_recon_loss_function – Create a reconstruction loss function based on parameters.
- Check whether the soft loss should be used at the given epoch.
Classes
- class scxpand.autoencoders.ae_losses.MSELoss(eps=1e-08)
- class scxpand.autoencoders.ae_losses.NB(use_masking=False, eps=1e-08)
- __init__(use_masking=False, eps=1e-08)
Negative binomial loss.
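The NB losses on this page take a mean `mu` and a dispersion `theta`. Assuming the mean/inverse-dispersion parametrization that is common in single-cell count autoencoders (an assumption; this reference does not spell the parametrization out), the probability mass function and the per-entry negative log-likelihood would be:

$$\mathrm{NB}(k; \mu, \theta) = \frac{\Gamma(k + \theta)}{\Gamma(\theta)\, k!} \left(\frac{\theta}{\theta + \mu}\right)^{\theta} \left(\frac{\mu}{\theta + \mu}\right)^{k}$$

$$-\log \mathrm{NB}(k; \mu, \theta) = \log\Gamma(\theta) + \log\Gamma(k + 1) - \log\Gamma(k + \theta) - \theta \log\frac{\theta}{\theta + \mu} - k \log\frac{\mu}{\theta + \mu}$$

Under this parametrization the mean is μ and the variance is μ + μ²/θ, so a larger θ means less overdispersion.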
- class scxpand.autoencoders.ae_losses.ZINBLoss(eps=1e-08)
- __init__(eps=1e-08)
Zero-inflated negative binomial loss.
- Parameters:
eps (float, default: 1e-08) – Small constant for numerical stability.
- nb_loss(x_genes_true, mu, theta)
Negative binomial negative log-likelihood without zero-inflation.
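A minimal scalar sketch of an NB negative log-likelihood, assuming the mean/dispersion parametrization and that `eps` is added inside the logarithms and log-gamma terms to keep them finite. The names `nb_nll` and `mean_nb_nll` are hypothetical, plain floats stand in for the `[batch_size, n_genes]` tensors, and mean reduction over entries is an assumption rather than the documented behaviour of `nb_loss`.

```python
import math


def nb_nll(x, mu, theta, eps=1e-8):
    """-log NB(x; mu, theta) for one observed count x (hypothetical helper).

    Assumes mean mu and inverse dispersion theta; eps keeps every log finite
    when mu or theta is close to zero.
    """
    log_theta_mu_eps = math.log(theta + mu + eps)
    return (
        math.lgamma(theta + eps)
        + math.lgamma(x + 1.0)
        - math.lgamma(x + theta + eps)
        - theta * (math.log(theta + eps) - log_theta_mu_eps)
        - x * (math.log(mu + eps) - log_theta_mu_eps)
    )


def mean_nb_nll(xs, mus, thetas, eps=1e-8):
    """Average the per-entry NLL over flattened cells x genes (assumed reduction)."""
    losses = [nb_nll(x, m, t, eps) for x, m, t in zip(xs, mus, thetas)]
    return sum(losses) / len(losses)


# Sanity check: exp(-NLL) summed over counts 0..199 should be very close to 1.
total_prob = sum(math.exp(-nb_nll(k, mu=3.0, theta=2.0)) for k in range(200))
print(total_prob)
```

Working in log space with `math.lgamma` avoids overflowing factorials for large counts, which is why NB losses are usually written this way.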
- zinb_loss(x_genes_true, mu, theta, pi)
Zero-inflated negative binomial negative log-likelihood.
For ZINB, the probability mass function is:
$$P(Y = 0) = \pi + (1 - \pi)\,\mathrm{NB}(0; \mu, \theta)$$
$$P(Y = k) = (1 - \pi)\,\mathrm{NB}(k; \mu, \theta), \quad k > 0$$
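The two cases of the ZINB probability mass function translate directly into a branch on whether the observed count is zero. This is a scalar sketch, not the library's vectorised implementation: `nb_log_prob` and `zinb_nll` are hypothetical names, the NB term assumes the mean/dispersion parametrization, and the placement of `eps` is an assumption.

```python
import math


def nb_log_prob(x, mu, theta, eps=1e-8):
    """log NB(x; mu, theta), assuming mean mu and inverse dispersion theta."""
    log_theta_mu_eps = math.log(theta + mu + eps)
    return (
        math.lgamma(x + theta + eps)
        - math.lgamma(theta + eps)
        - math.lgamma(x + 1.0)
        + theta * (math.log(theta + eps) - log_theta_mu_eps)
        + x * (math.log(mu + eps) - log_theta_mu_eps)
    )


def zinb_nll(x, mu, theta, pi, eps=1e-8):
    """ZINB negative log-likelihood for one count (hypothetical helper).

    A zero can come from the zero-inflation component (probability pi) or
    from the NB component; a positive count only from the NB component.
    """
    log_nb = nb_log_prob(x, mu, theta, eps)
    if x == 0:
        # P(Y = 0) = pi + (1 - pi) * NB(0; mu, theta)
        return -math.log(pi + (1.0 - pi) * math.exp(log_nb) + eps)
    # P(Y = k) = (1 - pi) * NB(k; mu, theta) for k > 0
    return -(math.log(1.0 - pi + eps) + log_nb)


# With pi = 0 the ZINB likelihood reduces to the plain NB likelihood.
print(zinb_nll(0, 2.0, 1.5, pi=0.0), -nb_log_prob(0, 2.0, 1.5))
# Zero-inflation makes an observed zero more likely, so its NLL is lower.
print(zinb_nll(0, 2.0, 1.5, pi=0.3) < zinb_nll(0, 2.0, 1.5, pi=0.0))
```

In a tensor implementation the `if` would typically become an element-wise select (for example `torch.where`) over the zero mask, so both branches are computed and the appropriate one is kept per entry.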
- scxpand.autoencoders.ae_losses.compute_total_autoencoder_loss(*, x_genes_true, mu, pi, theta, latent_vec, class_logit, y_true, y_soft_gt, recon_loss_fn, bce_loss, prm, epoch, categorical_logits=None, categorical_targets=None)
Compute the total loss as the sum of reconstruction loss, classification loss, and categorical losses.
- Parameters:
x_genes_true (Tensor) – Observed counts, shape [batch_size, n_genes]
mu (Tensor) – Predicted means, shape [batch_size, n_genes]
pi (Tensor | None) – Zero-inflation probabilities, shape [batch_size, n_genes] (None for MSE/NB)
theta (Tensor | None) – Dispersion parameters, shape [batch_size, n_genes] (None for MSE)
latent_vec (Tensor) – Latent vector from the encoder, shape [batch_size, latent_dim]
class_logit (Tensor) – Classification logits, shape [batch_size]
y_true (Tensor) – True binary labels, shape [batch_size]
y_soft_gt (Tensor | None) – Soft binary labels in the range [0, 1], shape [batch_size] (None if not using soft loss)
recon_loss_fn (ZINBLoss | NB | MSELoss) – Reconstruction loss function instance (ZINBLoss, NB, or MSELoss)
bce_loss (Module) – Binary cross-entropy loss module
prm (AutoEncoderParams) – AutoEncoderParams containing all loss weights and regularization parameters
epoch (int) – Current training epoch (used for soft loss scheduling)
categorical_logits (dict[str, Tensor] | None, default: None) – Dict of logits for categorical features, each of shape [batch_size, n_classes]
categorical_targets (dict[str, Tensor] | None, default: None) – Dict of target indices for categorical features, each of shape [batch_size]
- Returns:
total_loss (torch.Tensor) – sum of reconstruction, classification, and categorical losses
recon_loss (torch.Tensor) – reconstruction loss
cls_loss (torch.Tensor) – classification loss
l1_loss (torch.Tensor) – L1 regularization loss
cat_loss (torch.Tensor) – sum of categorical classification losses
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
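The sketch below shows one plausible way the documented pieces could combine into the five returned values. It is illustrative only: plain Python lists stand in for tensors, the reconstruction loss is passed in precomputed, `SketchLossWeights` is a hypothetical stand-in for `AutoEncoderParams` (its field names, including the `soft_loss_start_epoch` threshold used to mimic soft-loss scheduling, are invented), batch-mean reduction is assumed, and, following the description of `total_loss` literally, the L1 term is reported but not added to the total.

```python
import math
from dataclasses import dataclass


@dataclass
class SketchLossWeights:
    # Hypothetical stand-in for AutoEncoderParams; field names are invented.
    cls_weight: float = 1.0
    cat_weight: float = 1.0
    soft_loss_start_epoch: int = 5  # hypothetical soft-loss schedule


def bce_with_logits(logit, target):
    """Numerically stable BCE on a raw logit; target may be soft, in [0, 1]."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def cross_entropy(logits, target_index):
    """Softmax cross-entropy for one sample, using the log-sum-exp trick."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target_index]


def total_loss_sketch(recon_loss, class_logit, y_true, y_soft_gt, latent_vec,
                      w, epoch, categorical_logits=None, categorical_targets=None):
    n = len(y_true)
    # Assumed schedule: switch to soft labels once the threshold epoch is reached.
    use_soft = y_soft_gt is not None and epoch >= w.soft_loss_start_epoch
    targets = y_soft_gt if use_soft else y_true
    cls_loss = sum(bce_with_logits(z, y) for z, y in zip(class_logit, targets)) / n
    # L1 on latent activations; reported separately, not added to the total here.
    l1_loss = sum(abs(v) for row in latent_vec for v in row) / n
    cat_loss = 0.0
    for name, logits in (categorical_logits or {}).items():
        cat_loss += sum(cross_entropy(row, t)
                        for row, t in zip(logits, categorical_targets[name])) / n
    total = recon_loss + w.cls_weight * cls_loss + w.cat_weight * cat_loss
    return total, recon_loss, cls_loss, l1_loss, cat_loss


total, recon, cls, l1, cat = total_loss_sketch(
    recon_loss=0.5,
    class_logit=[0.0, 0.0],
    y_true=[1.0, 0.0],
    y_soft_gt=None,
    latent_vec=[[1.0, -1.0], [0.5, 0.5]],
    w=SketchLossWeights(),
    epoch=0,
    categorical_logits={"tissue": [[0.0, 0.0], [0.0, 0.0]]},
    categorical_targets={"tissue": [0, 1]},
)
print(total, cls, l1, cat)
```

Returning the individual terms alongside the total, as the documented return value does, makes it easy to log each component separately during training.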
- scxpand.autoencoders.ae_losses.create_autoencoder_recon_loss_function(prm)
Create a reconstruction loss function based on the given parameters.
- Return type: