scxpand.mlp.mlp_losses
Functions
| compute_batch_loss | Compute the complete batch loss including target selection and loss computation. |
| compute_total_nn_loss | Compute the total loss as the sum of binary classification loss and categorical losses. |
| create_loss_function | Create and return the binary classification loss function. |
| | Check if the soft loss should be used at the given epoch. |
- scxpand.mlp.mlp_losses.compute_batch_loss(*, main_logit, y_soft_gt, y_gt, categorical_logits, categorical_targets, loss_fn, prm, epoch)
Compute the complete batch loss including target selection and loss computation.
- Parameters:
main_logit (Tensor) – Binary classification logits, shape [batch_size]
y_soft_gt (Tensor) – Soft binary labels, shape [batch_size], values in the range [0, 1]
y_gt (Tensor) – Hard binary labels, shape [batch_size], values in {0, 1}
categorical_logits (dict[str, Tensor] | None) – Dict of logits for categorical features, each of shape [batch_size, n_classes]
categorical_targets (dict[str, Tensor] | None) – Dict of target indices for categorical features, each of shape [batch_size]
loss_fn (Module) – Binary classification loss function (e.g., BCEWithLogitsLoss)
prm (MLPParam) – Parameter object containing all loss weights
epoch (int) – Current training epoch
- Returns:
total_loss (torch.Tensor) – Sum of binary classification and categorical losses
bin_cls_loss (torch.Tensor) – Binary classification loss
cat_loss (torch.Tensor) – Sum of categorical classification losses
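The "target selection" step is not spelled out on this page. As a rough illustration only, the sketch below switches from hard to soft labels once a threshold epoch is reached. The helper `select_targets` and the `soft_start_epoch` threshold are hypothetical and not part of scxpand; the actual rule is whatever the module's epoch check implements. Plain lists stand in for tensors so the example runs without PyTorch.

```python
def select_targets(*, y_soft_gt, y_gt, epoch, soft_start_epoch):
    """Pick the labels to train against for the given epoch.

    ASSUMPTION: soft labels are used from `soft_start_epoch` onward.
    `soft_start_epoch` is a hypothetical threshold, not an scxpand parameter.
    """
    use_soft = epoch >= soft_start_epoch
    return y_soft_gt if use_soft else y_gt


y_gt = [0.0, 1.0, 1.0]        # hard labels in {0, 1}
y_soft_gt = [0.1, 0.8, 0.95]  # soft labels in [0, 1]

# Before the assumed threshold the hard labels are selected, afterwards the soft ones.
early = select_targets(y_soft_gt=y_soft_gt, y_gt=y_gt, epoch=0, soft_start_epoch=5)
late = select_targets(y_soft_gt=y_soft_gt, y_gt=y_gt, epoch=7, soft_start_epoch=5)
```

Whichever labels are selected would then be passed on as the binary target for the loss computation.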
- scxpand.mlp.mlp_losses.compute_total_nn_loss(*, main_logit, y_true, main_cls_loss_fn, prm, categorical_logits=None, categorical_targets=None)
Compute the total loss as the sum of binary classification loss and categorical losses.
- Parameters:
main_logit (Tensor) – Binary classification logits, shape [batch_size]
y_true (Tensor) – True binary labels, shape [batch_size]
main_cls_loss_fn (Module) – Binary classification loss function (e.g., BCEWithLogitsLoss)
prm (MLPParam) – Parameter object containing all loss weights
categorical_logits (dict[str, Tensor] | None (default: None)) – Dict of logits for categorical features, each of shape [batch_size, n_classes]
categorical_targets (dict[str, Tensor] | None (default: None)) – Dict of target indices for categorical features, each of shape [batch_size]
- Returns:
total_loss (torch.Tensor) – Sum of binary classification and categorical losses
bin_cls_loss (torch.Tensor) – Binary classification loss
cat_loss (torch.Tensor) – Sum of categorical classification losses
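To make the arithmetic behind the three returned values concrete, here is a minimal plain-Python sketch that runs without PyTorch. It assumes the binary term is a mean BCE-with-logits loss and that each categorical term is a mean softmax cross-entropy. The function names, the single `cat_weight` factor (a stand-in for the weights held in `prm`), and the feature name `"tissue"` are all hypothetical and not taken from scxpand.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed from raw logits."""
    total = 0.0
    for z, y in zip(logits, targets):
        # Numerically stable form: max(z, 0) - z*y + log(1 + exp(-|z|))
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def cross_entropy(logit_rows, target_indices):
    """Mean softmax cross-entropy over rows of class logits."""
    total = 0.0
    for row, target in zip(logit_rows, target_indices):
        m = max(row)  # subtract the row maximum for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[target]
    return total / len(logit_rows)


def total_loss_sketch(*, main_logit, y_true, categorical_logits=None,
                      categorical_targets=None, cat_weight=1.0):
    """Return (total_loss, bin_cls_loss, cat_loss).

    ASSUMPTION: `cat_weight` is a hypothetical single weight applied to
    every categorical feature; the real weights come from `prm`.
    """
    bin_cls_loss = bce_with_logits(main_logit, y_true)
    cat_loss = 0.0
    if categorical_logits and categorical_targets:
        for name, rows in categorical_logits.items():
            cat_loss += cat_weight * cross_entropy(rows, categorical_targets[name])
    return bin_cls_loss + cat_loss, bin_cls_loss, cat_loss


# All-zero logits: each term reduces to log(2), so the total is 2 * log(2).
total, bin_loss, cat_loss = total_loss_sketch(
    main_logit=[0.0, 0.0],
    y_true=[1.0, 0.0],
    categorical_logits={"tissue": [[0.0, 0.0], [0.0, 0.0]]},  # hypothetical feature
    categorical_targets={"tissue": [0, 1]},
)
```

When no categorical logits or targets are supplied, the categorical term in this sketch is zero and the total equals the binary loss.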
- scxpand.mlp.mlp_losses.create_loss_function(*, prm, device)
Create and return the binary classification loss function.
- Return type: