scxpand.data_util.dataset

Functions

apply_post_normalization_augmentations(X[, ...])

Apply post-normalization augmentations to input tensor.

apply_pre_normalization_augmentations(X[, ...])

Apply pre-normalization augmentations to input tensor.

cells_collate_fn(batch_indices, dataset)

Collate function that efficiently creates batches from the dataset using its batch transformation pipeline.

compute_categorical_targets_from_batch_obs(...)

Compute categorical targets directly from observation data using vectorized operations.

compute_soft_labels(obs_df, dataset_params)

Compute soft labels for the training data.

encode_categorical_features_batch(obs_df, ...)

Encode categorical features into one-hot vectors.

encode_categorical_value(value, mapping)

Encode a single categorical value to an index in the mapping.

get_dataloader_kwargs(num_workers, dataset)

Get common DataLoader keyword arguments.

Classes

CellsDataset(data_format[, row_inds, ...])

class scxpand.data_util.dataset.CellsDataset(data_format, row_inds=None, dataset_params=None, is_train=True, data_path=None, include_row_normalized_gene_counts=False, adata=None)

PyTorch Dataset for single-cell expression data with preprocessing pipeline.

Provides efficient batch loading with on-the-fly preprocessing, including normalization, log transformation, and z-score standardization. Supports both file-based access (an H5AD file via data_path) and in-memory access (an AnnData object via adata).

Parameters:
  • data_format (DataFormat) – DataFormat object containing preprocessing parameters.

  • row_inds (ndarray | None (default: None)) – Cell indices to include. If None, includes all cells.

  • dataset_params (DataAugmentParams | None (default: None)) – Data augmentation parameters. Only used during training.

  • is_train (bool (default: True)) – Whether this is training data (enables augmentation).

  • data_path (str | Path | None (default: None)) – Path to H5AD file. Required unless adata is provided.

  • include_row_normalized_gene_counts (bool (default: False)) – Include row-normalized gene counts in batches (useful for autoencoder training).

  • adata (AnnData | None (default: None)) – In-memory AnnData object. Alternative to data_path.
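The preprocessing steps described for CellsDataset can be illustrated with a minimal NumPy sketch. It is not scxpand's implementation: the name `preprocess_counts`, the `target_sum` default of 1e4, and the use of precomputed per-gene `gene_means` and `gene_stds` are assumptions standing in for whatever the DataFormat object actually holds, and the real class operates on PyTorch tensors.

```python
import numpy as np


def preprocess_counts(
    X: np.ndarray,
    gene_means: np.ndarray,
    gene_stds: np.ndarray,
    target_sum: float = 1e4,
    eps: float = 1e-8,
) -> np.ndarray:
    """Hypothetical sketch: row-normalize, log1p, then z-score each gene."""
    X = np.asarray(X, dtype=np.float64)
    # 1. Scale each cell (row) to sum to target_sum; all-zero rows stay zero.
    row_sums = X.sum(axis=1, keepdims=True)
    X = np.divide(X * target_sum, row_sums, out=np.zeros_like(X), where=row_sums > 0)
    # 2. Log-transform; log1p maps zero counts to zero.
    X = np.log1p(X)
    # 3. Standardize each gene (column) with statistics assumed to be
    #    precomputed on the training data; eps guards against zero variance.
    return (X - gene_means) / (gene_stds + eps)
```

Computing the z-score statistics once on the training split and reusing them for every batch keeps validation and test data on the same scale.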

open_adata(indices)

Context manager that yields (AnnData object, indices) for batch access.

File opening is delegated to a multiprocessing-safe utility function.
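A context manager of this shape can be sketched with `contextlib`. The sketch assumes the multiprocessing safety comes from opening the file inside the calling process on each use, rather than sharing one open handle across DataLoader workers. `open_adata_sketch` and its `opener` callable (which returns the opened object plus a close function) are hypothetical, chosen so the example runs without anndata installed.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator, Optional, Sequence, Tuple

import numpy as np

# Hypothetical opener type: takes a path, returns (opened object, close function).
Opener = Callable[[str], Tuple[Any, Callable[[], None]]]


@contextmanager
def open_adata_sketch(
    indices: Sequence[int],
    adata: Optional[Any] = None,
    data_path: Optional[str] = None,
    opener: Optional[Opener] = None,
) -> Iterator[Tuple[Any, np.ndarray]]:
    """Yield (data object, indices); open and close the file per use if needed."""
    idx = np.asarray(indices, dtype=np.int64)
    if adata is not None:
        # In-memory data: nothing to open or close.
        yield adata, idx
        return
    if data_path is None or opener is None:
        raise ValueError("provide adata, or data_path together with opener")
    # Open inside the current process, so no file handle crosses a fork.
    # A real opener might wrap anndata.read_h5ad(path, backed="r") and
    # return (adata, adata.file.close); that pairing is an assumption.
    obj, close = opener(data_path)
    try:
        yield obj, idx
    finally:
        close()
```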

transform_batch_data(X_raw, in_place=True)

Transform raw batch data according to data format requirements.

Parameters:
  • X_raw (Tensor) – Raw gene expression data tensor [batch_size, n_raw_genes]

  • in_place (bool (default: True)) – Whether to modify X_raw in-place when possible (faster)

Return type:

Tensor

Returns:

Transformed data tensor [batch_size, n_target_genes]
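Because the output has n_target_genes columns while the input has n_raw_genes, part of this transformation presumably aligns the raw gene columns to the model's gene list before normalization. The NumPy sketch below illustrates that step under this assumption; `build_gene_index`, `align_genes`, and zero-filling genes that are absent from the input are hypothetical.

```python
from typing import Sequence

import numpy as np


def build_gene_index(raw_genes: Sequence[str], target_genes: Sequence[str]) -> np.ndarray:
    """Hypothetical helper: column of each target gene in the raw data, or -1 if absent."""
    position = {gene: i for i, gene in enumerate(raw_genes)}
    return np.array([position.get(g, -1) for g in target_genes], dtype=np.int64)


def align_genes(X_raw: np.ndarray, gene_index: np.ndarray) -> np.ndarray:
    """Map [batch_size, n_raw_genes] to [batch_size, n_target_genes]."""
    present = gene_index >= 0
    out = np.zeros((X_raw.shape[0], gene_index.shape[0]), dtype=X_raw.dtype)
    # Gather the available columns in target order; absent genes stay zero (assumption).
    out[:, present] = X_raw[:, gene_index[present]]
    return out
```

Building the index once per dataset and reusing it for every batch keeps the per-batch cost to a single gather.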

scxpand.data_util.dataset.apply_post_normalization_augmentations(X, noise_std=0.0, generator=None)

Apply post-normalization augmentations to input tensor.

This augmentation adds Gaussian noise with standard deviation noise_std to normalized data.

Parameters:
  • X (Tensor) – Input tensor to augment (normalized data)

  • noise_std (float (default: 0.0)) – Standard deviation of Gaussian noise to add

  • generator (Generator | None (default: None)) – Optional PyTorch generator for reproducible randomness

Return type:

Tensor

Returns:

Augmented tensor with added noise
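A NumPy stand-in for the documented behavior is sketched below. The real function takes a PyTorch tensor and an optional torch.Generator; the hypothetical `rng` argument (a `numpy.random.Generator`) plays that role here. Zero-mean noise is an assumption, since the documentation specifies only the standard deviation, and `add_gaussian_noise` is a hypothetical name.

```python
from typing import Optional

import numpy as np


def add_gaussian_noise(
    X: np.ndarray, noise_std: float = 0.0, rng: Optional[np.random.Generator] = None
) -> np.ndarray:
    """Sketch: add zero-mean Gaussian noise (an assumption) to normalized data."""
    if noise_std <= 0:
        return X  # the default of 0.0 leaves the input unchanged
    rng = np.random.default_rng() if rng is None else rng
    return X + rng.normal(loc=0.0, scale=noise_std, size=X.shape)
```

Passing a seeded generator makes the augmentation reproducible, which is the purpose of the documented generator parameter.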

scxpand.data_util.dataset.apply_pre_normalization_augmentations(X, mask_rate=0.0, generator=None)

Apply pre-normalization augmentations to input tensor.

These augmentations simulate missing data and should be applied to raw counts.

Parameters:
  • X (Tensor) – Input tensor to augment (raw counts)

  • mask_rate (float (default: 0.0)) – Rate at which to mask values (set to 0) to simulate missing genes

  • generator (Generator | None (default: None)) – Optional PyTorch generator for reproducible randomness

Return type:

Tensor

Returns:

Augmented tensor with masked values
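Masking can be sketched the same way, assuming `mask_rate` is the independent per-entry probability of zeroing a count (the documentation says only "rate"). `mask_counts` and `rng` are hypothetical NumPy stand-ins for the tensor-based function and its torch.Generator.

```python
from typing import Optional

import numpy as np


def mask_counts(
    X: np.ndarray, mask_rate: float = 0.0, rng: Optional[np.random.Generator] = None
) -> np.ndarray:
    """Sketch: zero each entry independently with probability mask_rate (assumed)."""
    if mask_rate <= 0:
        return X
    rng = np.random.default_rng() if rng is None else rng
    keep = rng.random(X.shape) >= mask_rate  # True where the count survives
    return X * keep
```

If normalization rescales each cell by its total counts, a masked cell is then rescaled as if the zeroed genes had simply not been detected, which is why this augmentation belongs before normalization.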

scxpand.data_util.dataset.cells_collate_fn(batch_indices, dataset)

Collate function that efficiently creates batches from the dataset using its batch transformation pipeline.

Return type:

dict[str, Tensor]
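The signature (batch_indices, dataset) suggests an index-based batching pattern: the dataset yields bare indices, and the collate function performs one vectorized fetch and transform per batch instead of one per cell. The sketch below illustrates that pattern under this assumption; `ToyCellsDataset`, `index_collate`, the `log1p` placeholder transform, and the output keys "x" and "y" are all hypothetical.

```python
from typing import Dict, Sequence

import numpy as np


class ToyCellsDataset:
    """Hypothetical stand-in: an in-memory count matrix plus one label per cell."""

    def __init__(self, X, y):
        self.X = np.asarray(X, dtype=np.float32)
        self.y = np.asarray(y, dtype=np.float32)

    def __len__(self) -> int:
        return self.X.shape[0]

    def __getitem__(self, i: int) -> int:
        # Return only the index; the collate function does the real work.
        return i


def index_collate(batch_indices: Sequence[int], dataset: ToyCellsDataset) -> Dict[str, np.ndarray]:
    idx = np.asarray(batch_indices, dtype=np.int64)
    X = dataset.X[idx]  # one vectorized row fetch for the whole batch
    X = np.log1p(X)  # placeholder for the dataset's batch transformation
    return {"x": X, "y": dataset.y[idx]}
```

With PyTorch, such a function would typically be bound to its dataset with `functools.partial`, e.g. `DataLoader(ds, batch_size=256, collate_fn=partial(index_collate, dataset=ds))`, since DataLoader calls `collate_fn` with the list of items only.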

scxpand.data_util.dataset.compute_categorical_targets_from_batch_obs(dataset, batch_obs)

Compute categorical targets directly from observation data using vectorized operations.

Parameters:
  • dataset (CellsDataset) – The dataset containing category mappings and metadata

  • batch_obs (dict[str, ndarray]) – Dictionary of observation data for the current batch

Return type:

dict[str, Tensor]

Returns:

Dictionary mapping feature names to tensors containing categorical target indices
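One way to map category strings to target indices with vectorized operations is to look up each distinct value once and scatter the results back through the inverse index returned by `np.unique`. This is a hedged sketch, not the library's code; `categorical_targets` and the `missing_index=-1` convention for unseen categories are assumptions, and a plain dict of mappings stands in for the CellsDataset argument.

```python
from typing import Dict

import numpy as np


def categorical_targets(
    batch_obs: Dict[str, np.ndarray],
    mappings: Dict[str, Dict[str, int]],
    missing_index: int = -1,
) -> Dict[str, np.ndarray]:
    """Hypothetical sketch: category strings -> integer target indices per feature."""
    targets = {}
    for name, mapping in mappings.items():
        values = np.asarray(batch_obs[name]).astype(str)
        # Look up each distinct value once, then scatter back via the inverse index.
        uniques, inverse = np.unique(values, return_inverse=True)
        codes = np.array([mapping.get(u, missing_index) for u in uniques], dtype=np.int64)
        targets[name] = codes[inverse.reshape(-1)]
    return targets
```

The Python-level loop runs over distinct categories rather than cells, so its cost stays small even for large batches.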

scxpand.data_util.dataset.compute_soft_labels(obs_df, dataset_params)

Compute soft labels for the training data.

Parameters:
  • obs_df (DataFrame) – DataFrame containing observation data

  • dataset_params (DataAugmentParams) – Data augmentation parameters containing soft_loss_beta

Return type:

ndarray

Returns:

y_soft, a NumPy array of soft labels in the range [0, 1]; a cell with y_soft > 0.5 is treated as expanded.

scxpand.data_util.dataset.encode_categorical_features_batch(obs_df, categorical_features_types, categorical_mappings)

Encode categorical features into one-hot vectors.

Parameters:
  • obs_df (DataFrame) – DataFrame containing observation data

  • categorical_features_types (list[str]) – List of categorical feature names

  • categorical_mappings (dict[str, dict[str, int]]) – Dict mapping feature names to {category: index} dicts.

Return type:

ndarray

Returns:

2D NumPy array of shape (batch_size, total_categorical_vector_length)
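Concatenated one-hot encoding of several features can be sketched as follows, assuming each mapping's indices run from 0 to len(mapping) - 1 and that an unseen category leaves its feature's block all zero. `one_hot_concat` is hypothetical, and a dict of column sequences stands in for the obs DataFrame.

```python
from typing import Dict, List, Sequence

import numpy as np


def one_hot_concat(
    columns: Dict[str, Sequence],
    feature_names: List[str],
    mappings: Dict[str, Dict[str, int]],
    n_rows: int,
) -> np.ndarray:
    """Hypothetical sketch: one one-hot block per feature, concatenated column-wise."""
    widths = [len(mappings[name]) for name in feature_names]
    out = np.zeros((n_rows, sum(widths)), dtype=np.float32)
    offset = 0
    for name, width in zip(feature_names, widths):
        mapping = mappings[name]
        idx = np.array([mapping.get(str(v), -1) for v in columns[name]], dtype=np.int64)
        valid = idx >= 0  # unseen categories leave this block all zero (assumption)
        out[np.nonzero(valid)[0], offset + idx[valid]] = 1.0
        offset += width
    return out
```

Under these assumptions, total_categorical_vector_length is the sum of the mapping sizes, and each feature occupies a fixed column range of the output.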

scxpand.data_util.dataset.encode_categorical_value(value, mapping)

Encode a single categorical value to an index in the mapping.

Parameters:
  • value (str | float) – The value to encode

  • mapping (dict[str, int]) – Dictionary mapping string values to indices

Return type:

tuple[int, bool]

Returns:

Tuple of (index, valid) where valid is True if the mapping contains the value
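The (index, valid) contract can be illustrated with a small sketch. Treating float NaN as missing (the usual pandas marker for absent values in object columns) and returning a `fallback` index of 0 for invalid values are assumptions, and `encode_value` is a hypothetical name.

```python
import math
from typing import Dict, Tuple, Union


def encode_value(value: Union[str, float], mapping: Dict[str, int], fallback: int = 0) -> Tuple[int, bool]:
    """Hypothetical sketch of the (index, valid) contract."""
    # Treat float NaN (pandas' usual missing-value marker) as invalid.
    if isinstance(value, float) and math.isnan(value):
        return fallback, False
    key = str(value)
    if key in mapping:
        return mapping[key], True
    return fallback, False
```

Returning the validity flag alongside the index lets callers distinguish a genuine category at index 0 from the fallback.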

scxpand.data_util.dataset.get_dataloader_kwargs(num_workers, dataset)

Get common DataLoader keyword arguments.

Parameters:
  • num_workers (int) – Number of worker processes

  • dataset (CellsDataset) – Dataset to create loader for

Return type:

dict[str, object]

Returns:

Dictionary of common DataLoader arguments
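A plausible shape for such a helper is sketched below. The exact keys scxpand returns are not documented here, and this hypothetical `dataloader_kwargs_sketch` omits the dataset argument and assumes `pin_memory=True` as a default. It does encode one real PyTorch constraint: `persistent_workers=True` requires `num_workers > 0`, and recent PyTorch versions also reject an explicit `prefetch_factor` for single-process loading.

```python
from typing import Any, Dict


def dataloader_kwargs_sketch(
    num_workers: int, pin_memory: bool = True, prefetch_factor: int = 2
) -> Dict[str, Any]:
    """Hypothetical sketch of shared DataLoader keyword arguments."""
    kwargs: Dict[str, Any] = {"num_workers": num_workers, "pin_memory": pin_memory}
    if num_workers > 0:
        # PyTorch raises ValueError for persistent_workers=True with num_workers == 0,
        # and recent versions also reject a non-None prefetch_factor in that case.
        kwargs["persistent_workers"] = True
        kwargs["prefetch_factor"] = prefetch_factor
    return kwargs
```

The result would be unpacked into the loader, e.g. `DataLoader(ds, batch_size=256, **dataloader_kwargs_sketch(4))`.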