scxpand.mlp.mlp_losses

Functions

compute_batch_loss(*, main_logit, y_soft_gt, ...)

Compute the complete batch loss including target selection and loss computation.

compute_total_nn_loss(*, main_logit, y_true, ...)

Compute the total loss as the sum of binary classification loss and categorical losses.

create_loss_function(*, prm, device)

Create and return the binary classification loss function.

should_use_soft_loss(*, epoch, prm)

Check if the soft loss should be used at the given epoch.

scxpand.mlp.mlp_losses.compute_batch_loss(*, main_logit, y_soft_gt, y_gt, categorical_logits, categorical_targets, loss_fn, prm, epoch)

Compute the complete batch loss including target selection and loss computation.

Parameters:
  • main_logit (Tensor) – Binary classification logits, shape [batch_size]

  • y_soft_gt (Tensor) – Soft binary labels with values in [0, 1], shape [batch_size]

  • y_gt (Tensor) – Hard binary labels with values in {0, 1}, shape [batch_size]

  • categorical_logits (dict[str, Tensor] | None) – Dict of logits for categorical features, each of shape [batch_size, n_classes]

  • categorical_targets (dict[str, Tensor] | None) – Dict of target class indices for categorical features, each of shape [batch_size]

  • loss_fn (Module) – Binary classification loss function (e.g., BCEWithLogitsLoss)

  • prm (MLPParam) – Parameter object containing all loss weights

  • epoch (int) – Current training epoch

Returns:

  • total_loss (Tensor) – Sum of the binary classification and categorical losses

  • bin_cls_loss (Tensor) – Binary classification loss

  • cat_loss (Tensor) – Sum of the categorical classification losses

Return type:

tuple[Tensor, Tensor, Tensor]
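The target-selection step can be pictured with a minimal sketch that relies only on what the parameter list states: soft labels lie in [0, 1], hard labels in {0, 1}, and the loss function is something like BCEWithLogitsLoss. The helper names and the boolean use_soft are hypothetical; in compute_batch_loss the choice is presumably derived from prm and epoch, and categorical losses are added on top of the binary term.

```python
import torch
import torch.nn as nn


def select_binary_target(
    y_soft_gt: torch.Tensor, y_gt: torch.Tensor, use_soft: bool
) -> torch.Tensor:
    # BCEWithLogitsLoss accepts probabilities in [0, 1] as targets,
    # so soft labels can be passed through unchanged.
    # Hard labels are cast to float because the loss expects a float target.
    return y_soft_gt if use_soft else y_gt.float()


def binary_batch_loss_sketch(
    main_logit: torch.Tensor,
    y_soft_gt: torch.Tensor,
    y_gt: torch.Tensor,
    loss_fn: nn.Module,
    use_soft: bool,
) -> torch.Tensor:
    # Hypothetical helper: `use_soft` stands in for the epoch-dependent
    # decision that the real function makes from `prm` and `epoch`.
    target = select_binary_target(y_soft_gt, y_gt, use_soft)
    return loss_fn(main_logit, target)


logit = torch.tensor([0.0, 2.0])
y_soft = torch.tensor([0.3, 0.9])
y_hard = torch.tensor([0, 1])
loss_fn = nn.BCEWithLogitsLoss()
hard_loss = binary_batch_loss_sketch(logit, y_soft, y_hard, loss_fn, use_soft=False)
soft_loss = binary_batch_loss_sketch(logit, y_soft, y_hard, loss_fn, use_soft=True)
```

Because the same loss module handles both target types, switching between soft and hard supervision only changes the target tensor, not the loss function.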

scxpand.mlp.mlp_losses.compute_total_nn_loss(*, main_logit, y_true, main_cls_loss_fn, prm, categorical_logits=None, categorical_targets=None)

Compute the total loss as the sum of binary classification loss and categorical losses.

Parameters:
  • main_logit (Tensor) – Binary classification logits, shape [batch_size]

  • y_true (Tensor) – True binary labels, shape [batch_size]

  • main_cls_loss_fn (Module) – Binary classification loss function (e.g., BCEWithLogitsLoss)

  • prm (MLPParam) – Parameter object containing all loss weights

  • categorical_logits (dict[str, Tensor] | None (default: None)) – Dict of logits for categorical features, each of shape [batch_size, n_classes]

  • categorical_targets (dict[str, Tensor] | None (default: None)) – Dict of target class indices for categorical features, each of shape [batch_size]

Returns:

  • total_loss (Tensor) – Sum of the binary classification and categorical losses

  • bin_cls_loss (Tensor) – Binary classification loss

  • cat_loss (Tensor) – Sum of the categorical classification losses

Return type:

tuple[Tensor, Tensor, Tensor]
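As a rough illustration of the sum described above, the sketch below computes a binary term plus one cross-entropy per categorical head and returns the same three values. Using F.cross_entropy per feature and a single cat_loss_weight are assumptions; the real weights live in prm and their names are not documented here. The feature name "tissue" is purely hypothetical.

```python
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def total_loss_sketch(
    main_logit: torch.Tensor,
    y_true: torch.Tensor,
    main_cls_loss_fn: nn.Module,
    cat_loss_weight: float = 1.0,  # hypothetical stand-in for a weight in prm
    categorical_logits: Optional[Dict[str, torch.Tensor]] = None,
    categorical_targets: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Binary term: logits and labels both have shape [batch_size].
    bin_cls_loss = main_cls_loss_fn(main_logit, y_true.float())

    # Categorical term: one cross-entropy per feature head, summed.
    # Each logits tensor is [batch_size, n_classes]; each target is [batch_size].
    cat_loss = torch.zeros((), device=main_logit.device)
    if categorical_logits is not None and categorical_targets is not None:
        for name, logits in categorical_logits.items():
            cat_loss = cat_loss + F.cross_entropy(logits, categorical_targets[name])

    total_loss = bin_cls_loss + cat_loss_weight * cat_loss
    return total_loss, bin_cls_loss, cat_loss


torch.manual_seed(0)
total, bin_loss, cat_loss = total_loss_sketch(
    torch.randn(4),
    torch.tensor([0, 1, 1, 0]),
    nn.BCEWithLogitsLoss(),
    categorical_logits={"tissue": torch.randn(4, 3)},  # hypothetical feature
    categorical_targets={"tissue": torch.tensor([0, 2, 1, 0])},
)
```

Returning the components alongside the total makes it easy to log the binary and categorical terms separately during training.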

scxpand.mlp.mlp_losses.create_loss_function(*, prm, device)

Create and return the binary classification loss function.

Return type:

Module
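A minimal sketch of such a factory, assuming the binary loss is BCEWithLogitsLoss (the example named elsewhere on this page). The pos_weight option is a hypothetical illustration of a setting that might come from prm; it also shows why a device argument matters, since BCEWithLogitsLoss stores pos_weight as a buffer that must sit on the same device as the logits.

```python
from typing import Optional

import torch
import torch.nn as nn


def create_loss_function_sketch(
    device: torch.device, pos_weight: Optional[float] = None
) -> nn.Module:
    # `pos_weight` is a hypothetical option. BCEWithLogitsLoss registers it
    # as a buffer, so .to(device) moves it together with the module.
    weight = None if pos_weight is None else torch.tensor([pos_weight])
    return nn.BCEWithLogitsLoss(pos_weight=weight).to(device)


loss_fn = create_loss_function_sketch(torch.device("cpu"), pos_weight=3.0)
```

With pos_weight set, errors on positive examples are scaled up, which is one common way to handle class imbalance in binary classification.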

scxpand.mlp.mlp_losses.should_use_soft_loss(*, epoch, prm)

Check if the soft loss should be used at the given epoch.

Return type:

bool
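One plausible shape for this check is sketched below, assuming soft labels switch on at a configured start epoch. SoftLossConfig and its soft_loss_start_epoch field are hypothetical stand-ins for whatever MLPParam actually stores, and the "from this epoch onward" rule is an assumption, not the documented behavior.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SoftLossConfig:
    # Hypothetical stand-in for the relevant fields of MLPParam.
    soft_loss_start_epoch: Optional[int] = None  # None disables soft loss


def should_use_soft_loss_sketch(*, epoch: int, prm: SoftLossConfig) -> bool:
    # Assumed rule: soft labels are used from `soft_loss_start_epoch` onward.
    if prm.soft_loss_start_epoch is None:
        return False
    return epoch >= prm.soft_loss_start_epoch


cfg = SoftLossConfig(soft_loss_start_epoch=3)
flags = [should_use_soft_loss_sketch(epoch=e, prm=cfg) for e in range(5)]
```

Like the documented functions, the sketch takes keyword-only arguments, which prevents epoch and prm from being swapped at the call site.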