Data Splitting Strategy
Note
This document explains how scXpand splits datasets into training, validation, and test sets while maintaining biological integrity and preventing data leakage.
Overview
scXpand uses two distinct splitting strategies depending on the evaluation level:
- Train-Validation Split: patient-level splitting with stratification by cancer type
- Test Data Split: study-level separation, where test data comes from entirely different studies and no stratification is applied
This dual approach ensures both proper model selection during training and realistic evaluation of generalization across different study contexts.
Multi-Level Data Splitting Strategy
Key Principles
Train-Validation Split (Patient-Level)
- Patient-Level Splitting:
  - Cells from the same patient share genetic background, treatment history, and disease progression.
  - Random cell-level splitting would leak data: training and validation sets would contain cells from the same patients, which inflates validation scores.
  - Patient-level splitting therefore gives more realistic performance estimates for model selection.
- Implementation:
  - The splitting algorithm operates on unique patient identifiers (combining study and patient information) rather than on individual cells, so each patient appears in exactly one of the training and validation sets.
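To make the leakage argument concrete, here is a toy sketch on synthetic data (10 made-up patients with 50 cells each; no scXpand code is involved). It compares a naive cell-level shuffle with a patient-level split and counts how many patients end up on both sides.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data: 10 patients with 50 cells each (one patient ID per cell)
cell_patients = np.repeat([f"study1:P{i:03d}" for i in range(10)], 50)
n_cells = cell_patients.size

# Naive cell-level split: shuffle cells and send 20% to validation
perm = rng.permutation(n_cells)
n_dev = int(0.2 * n_cells)
cell_dev, cell_train = perm[:n_dev], perm[n_dev:]
leaked = set(cell_patients[cell_train]) & set(cell_patients[cell_dev])
print(f"cell-level split: {len(leaked)} of 10 patients in both sets")

# Patient-level split: shuffle patients, then move each patient's cells together
patients = np.unique(cell_patients)
dev_patients = rng.permutation(patients)[: int(0.2 * patients.size)]
is_dev = np.isin(cell_patients, dev_patients)
train_idx, dev_idx = np.flatnonzero(~is_dev), np.flatnonzero(is_dev)
shared = set(cell_patients[train_idx]) & set(cell_patients[dev_idx])
print(f"patient-level split: {len(shared)} of 10 patients in both sets")
```

With 50 cells per patient, a 20% cell-level shuffle practically always puts patients on both sides, while the patient-level split keeps each patient on one side by construction.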
Test Data Split (Across Studies)
- Study-Level Separation:
  - Test data comes from entirely different studies than those used for training and validation.
  - No stratification by cancer type or any other variable is applied.
  - This provides a direct assessment of how well the model generalizes across experimental contexts.
- Purpose:
  - Evaluate model performance on completely unseen study populations.
  - Test robustness to study-specific batch effects and methodological differences.
  - Simulate real-world deployment, where models encounter new study contexts.
Stratified Splitting (Train-Validation Only)
- Preserved Distributions (train-validation split only):
  - Cancer Type Distribution: the stratification key; training and validation sets keep similar proportions of cancer types.
  - Tissue Type Distribution: tissue types are balanced across splits where possible; they are monitored and reported rather than used as a stratification key.
  - Expansion Label Distribution: the balance of expanded vs. non-expanded cells is monitored so that both sets keep a similar ratio.
- Benefits:
  - Prevents bias toward specific cancer types in either set.
  - Keeps the validation set representative of the overall patient population across studies.
  - Ensures rare cancer types are still represented in the validation set during model selection.
Note: The test data split does not use stratification, so model performance is evaluated on the naturally occurring distributions of new studies.
Implementation Details
Train-Validation Split Implementation
Patient Identifier Generation
scXpand creates unique patient identifiers by combining study and patient information:
```python
from scxpand.data_util.data_splitter import get_patient_identifiers

# Generate composite patient IDs (one per cell) from the cell metadata
patient_identifiers = get_patient_identifiers(adata.obs)
# Format: "study_name:patient_id"
# Example: ["study1:P001", "study1:P002", "study2:P003"]
```
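The composite key matters because patient labels such as P001 are often reused across studies. A minimal pandas sketch of the idea, assuming adata.obs has study and patient columns; make_patient_identifiers is a hypothetical stand-in, not the actual get_patient_identifiers implementation:

```python
import pandas as pd


def make_patient_identifiers(obs: pd.DataFrame) -> pd.Series:
    """Hypothetical helper: one 'study:patient' string per cell."""
    return obs["study"].astype(str) + ":" + obs["patient"].astype(str)


# Two studies that both use the label "P001" for different people
obs = pd.DataFrame({
    "study": ["study1", "study1", "study2", "study2"],
    "patient": ["P001", "P002", "P001", "P001"],
})
ids = make_patient_identifiers(obs)
print(ids.tolist())  # ['study1:P001', 'study1:P002', 'study2:P001', 'study2:P001']
print(obs["patient"].nunique(), ids.nunique())  # 2 3
```

Without the study prefix, the two distinct P001 patients would be merged into one and could not be separated cleanly between training and validation.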
- Required Metadata Columns (in adata.obs):
  - study: Study or dataset identifier
  - patient: Patient identifier within each study
  - cancer_type: Cancer type, used for stratification
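Because the split cannot run without these columns, it can help to validate adata.obs up front. The checker below is hypothetical (check_required_columns is not part of scXpand); besides column presence, it enforces one cancer type per patient, which the patient-level stratification relies on.

```python
import pandas as pd

REQUIRED_COLUMNS = ("study", "patient", "cancer_type")


def check_required_columns(obs: pd.DataFrame) -> None:
    """Hypothetical pre-flight check for patient-level stratified splitting."""
    missing = [col for col in REQUIRED_COLUMNS if col not in obs.columns]
    if missing:
        raise KeyError(f"adata.obs is missing required columns: {missing}")
    with_nulls = [col for col in REQUIRED_COLUMNS if obs[col].isna().any()]
    if with_nulls:
        raise ValueError(f"columns with missing values: {with_nulls}")
    # Stratification needs exactly one cancer type per (study, patient) pair
    types_per_patient = obs.groupby(["study", "patient"], observed=True)["cancer_type"].nunique()
    conflicting = types_per_patient[types_per_patient > 1]
    if not conflicting.empty:
        raise ValueError(f"patients with several cancer types: {conflicting.index.tolist()}")


obs = pd.DataFrame({
    "study": ["s1", "s1", "s2"],
    "patient": ["P1", "P1", "P1"],
    "cancer_type": ["LUAD", "LUAD", "BRCA"],
})
check_required_columns(obs)  # passes: each (study, patient) has a single cancer type
```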
Core Splitting Algorithm
The split_data() function implements patient-aware splitting:
```python
from scxpand.data_util.data_splitter import split_data

# Split data by patients
train_indices, dev_indices = split_data(
    adata=adata,
    dev_ratio=0.2,  # 20% of patients go to validation
    random_seed=42,  # Reproducible splits
    save_path=results_dir,  # Save patient ID lists
)
```
Algorithm Steps:
1. Patient Enumeration: Extract unique patient identifiers.
2. Cancer Type Mapping: Map each patient to their cancer type.
3. Stratified Split: Use scikit-learn’s stratified splitting on patients.
4. Cell Index Generation: Map the patient-level split back to cell-level indices.
5. Quality Validation: Verify that distributions are preserved.
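The five steps can be sketched end to end with pandas and scikit-learn. This is a simplified, hypothetical re-implementation for illustration only: patient_stratified_split is not scXpand's split_data(), which may differ in details such as saving patient ID lists. It runs on a synthetic obs table.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


def patient_stratified_split(obs: pd.DataFrame, dev_ratio: float = 0.2, random_seed: int = 42):
    """Hypothetical sketch of the five steps (not scXpand's split_data)."""
    # 1. Patient enumeration: composite "study:patient" ID for every cell
    cell_ids = (obs["study"].astype(str) + ":" + obs["patient"].astype(str)).to_numpy(dtype=str)
    # 2. Cancer type mapping: one cancer type per patient (index sorted by patient ID)
    patient_cancer = pd.Series(obs["cancer_type"].to_numpy(), index=cell_ids).groupby(level=0).first()
    # 3. Stratified split on patients, not on cells
    train_p, dev_p = train_test_split(
        patient_cancer.index.to_numpy(),
        test_size=dev_ratio,
        stratify=patient_cancer.to_numpy(),
        random_state=random_seed,
    )
    # 4. Cell index generation: positional row indices for each side
    is_dev = np.isin(cell_ids, dev_p)
    train_idx, dev_idx = np.flatnonzero(~is_dev), np.flatnonzero(is_dev)
    # 5. Quality validation: disjoint patients and every cell assigned exactly once
    assert set(cell_ids[train_idx]).isdisjoint(cell_ids[dev_idx])
    assert train_idx.size + dev_idx.size == len(obs)
    return train_idx, dev_idx


# Synthetic metadata: 20 patients in two studies, 3 cells each, two cancer types
obs = pd.DataFrame({
    "study": np.repeat(["s1", "s2"], 30),
    "patient": np.repeat([f"P{i:02d}" for i in range(20)], 3),
    "cancer_type": np.repeat(["LUAD", "BRCA"] * 10, 3),
})
train_idx, dev_idx = patient_stratified_split(obs, dev_ratio=0.2, random_seed=42)
print(obs.iloc[dev_idx]["cancer_type"].value_counts().to_dict())
```

With 10 patients per cancer type and dev_ratio=0.2, the validation set receives 4 patients, 2 of each cancer type.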
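Stratification Process
The splitting uses scikit-learn’s train_test_split with stratification:
```python
from sklearn.model_selection import train_test_split

# Inputs computed in the earlier steps:
#   unique_patient_ids: one "study:patient" identifier per patient
#   cancer_types_per_patient: cancer type of each patient, aligned with unique_patient_ids
# Stratify by cancer type at the patient level
train_patients, dev_patients = train_test_split(
    unique_patient_ids,
    test_size=dev_ratio,
    stratify=cancer_types_per_patient,  # One cancer type per patient
    random_state=random_seed,
)
```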
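One practical caveat of stratifying at the patient level: train_test_split raises a ValueError when any stratum has fewer than two members, or when the validation or training side would have fewer samples than there are strata. A cancer type represented by a single patient therefore breaks the call. The sketch shows the failure and one common workaround, pooling rare labels into a shared "other" stratum. The pooling is an assumption for illustration (this document does not say how scXpand handles such cases), and pool_rare_labels is hypothetical.

```python
import numpy as np
from sklearn.model_selection import train_test_split


def pool_rare_labels(labels, min_count: int = 2, other: str = "other") -> np.ndarray:
    """Hypothetical helper: replace labels seen fewer than min_count times with `other`."""
    labels = np.asarray(labels, dtype=object)
    values, counts = np.unique(labels, return_counts=True)
    rare = set(values[counts < min_count])
    return np.array([other if label in rare else label for label in labels], dtype=object)


# 12 patients; GBM and SKCM are each represented by a single patient
unique_patient_ids = np.array([f"s1:P{i:02d}" for i in range(12)])
cancer_types_per_patient = np.array(["LUAD"] * 5 + ["BRCA"] * 5 + ["GBM", "SKCM"])

try:
    train_test_split(unique_patient_ids, test_size=0.25,
                     stratify=cancer_types_per_patient, random_state=42)
    raw_split_failed = False
except ValueError as err:  # the least populated class has only 1 member
    raw_split_failed = True
    print("stratified split failed:", err)

# Pooling GBM and SKCM yields an "other" stratum with two patients
strata = pool_rare_labels(cancer_types_per_patient)
train_patients, dev_patients = train_test_split(
    unique_patient_ids, test_size=0.25, stratify=strata, random_state=42
)
```

Pooling keeps the split stratified for common cancer types while still allowing rare ones into either set; the trade-off is that rare types are no longer individually balanced.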
- Stratification Variables (train-validation split only):
  - Primary: Cancer type (ensures balanced representation)
  - Secondary: Tissue type and expansion status (monitored and reported)
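Monitoring the secondary variables amounts to comparing per-split proportions. A minimal pandas sketch, assuming hypothetical obs columns named tissue and expansion (the real column names are not given here); split_distribution_report is illustrative, not an scXpand function:

```python
import pandas as pd


def split_distribution_report(obs: pd.DataFrame, train_idx, dev_idx, column: str) -> pd.DataFrame:
    """Proportion of each category of `column` in the train and dev splits."""
    report = pd.DataFrame({
        "train": obs.iloc[train_idx][column].value_counts(normalize=True),
        "dev": obs.iloc[dev_idx][column].value_counts(normalize=True),
    }).fillna(0.0)  # a category absent from one split gets proportion 0
    report["abs_diff"] = (report["train"] - report["dev"]).abs()
    return report


# Hypothetical column names: "tissue" and "expansion"
obs = pd.DataFrame({
    "tissue": ["tumor", "tumor", "blood", "tumor", "tumor", "blood"],
    "expansion": ["expanded", "non-expanded", "non-expanded", "expanded", "non-expanded", "expanded"],
})
train_idx, dev_idx = [0, 1, 2, 3], [4, 5]
for col in ("tissue", "expansion"):
    print(split_distribution_report(obs, train_idx, dev_idx, col), end="\n\n")
```

The abs_diff column gives a quick per-category gap; a large gap flags a variable that cancer-type stratification alone did not balance.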
Test Data Split Implementation
For test evaluation, data comes from studies that are completely separate from those used in training and validation:
```python
# Test data workflow (conceptual: load_test_studies and evaluate_model are illustrative names)
# Training studies: ["study_A", "study_B", "study_C"]
# Test studies: ["study_D", "study_E"]

# No stratification applied: use the natural distribution of the test studies
test_data = load_test_studies(["study_D", "study_E"])

# Evaluate trained model on test data
test_results = evaluate_model(model, test_data)
```
- Key Differences from Train-Validation Split:
  - No patient-level splitting is needed (entire studies are held out).
  - No stratification by cancer type or other variables.
  - Evaluation reflects the natural distribution of new study contexts.
  - Tests true generalization across different experimental settings.
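A study-level holdout can be expressed as a membership mask on the study column. A hypothetical sketch (split_by_study is illustrative, not an scXpand function) that also guards against misspelled study names; nothing is stratified, so the test set keeps whatever cancer-type mix the held-out studies happen to have:

```python
import numpy as np
import pandas as pd


def split_by_study(obs: pd.DataFrame, test_studies):
    """Hypothetical helper: row indices for train/validation studies vs held-out test studies."""
    missing = set(test_studies) - set(obs["study"])
    if missing:
        raise ValueError(f"unknown test studies: {sorted(missing)}")
    is_test = obs["study"].isin(list(test_studies)).to_numpy()
    return np.flatnonzero(~is_test), np.flatnonzero(is_test)


obs = pd.DataFrame({
    "study": ["study_A", "study_B", "study_C", "study_D", "study_E", "study_D"],
    "cancer_type": ["LUAD", "BRCA", "LUAD", "SKCM", "SKCM", "BRCA"],
})
trainval_idx, test_idx = split_by_study(obs, ["study_D", "study_E"])
# No stratification: the test set keeps the cancer types these studies contain
print(obs.iloc[test_idx]["cancer_type"].value_counts().to_dict())
```

The train/validation indices returned here would then go through the patient-level split, while the test indices are left untouched.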
Reproducibility
Deterministic Splitting
For the same input data, the splitting process is fully deterministic when using a fixed random seed:
```python
# Reproducible splits across runs
train_indices, dev_indices = split_data(
    adata=adata,
    dev_ratio=0.2,
    random_seed=42,  # Fixed seed ensures identical splits
)
```
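A fixed seed only guarantees identical splits if the sequence handed to the splitter is identical too. One way to make the patient list independent of the cell order in the AnnData file is to sort it, which np.unique does. A hypothetical sketch (split_patients, with stratification omitted for brevity); this document does not state how split_data orders patients internally:

```python
import numpy as np
from sklearn.model_selection import train_test_split


def split_patients(cell_patient_ids, dev_ratio: float = 0.2, seed: int = 42):
    """Hypothetical helper; stratification is omitted for brevity."""
    # np.unique returns sorted IDs, so the result does not depend on cell order
    patients = np.unique(cell_patient_ids)
    train_p, dev_p = train_test_split(patients, test_size=dev_ratio, random_state=seed)
    return set(train_p), set(dev_p)


cells = np.repeat([f"s1:P{i:02d}" for i in range(20)], 5)
reordered = np.random.default_rng(7).permutation(cells)  # same cells, new row order

run_1 = split_patients(cells)
run_2 = split_patients(cells)
run_3 = split_patients(reordered)
print(run_1 == run_2, run_1 == run_3)  # True True
```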
- Saved Artifacts:
  - train_patient_ids.csv: List of training patient identifiers
  - dev_patient_ids.csv: List of validation patient identifiers
  - data_splits.npz: NumPy arrays of cell indices for fast loading
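A minimal round-trip sketch for these artifacts. The file names come from the list of saved artifacts; the CSV column name patient_id and the npz keys train and dev are assumptions, since the actual file layout is not documented here, and save_split / load_split are hypothetical helpers.

```python
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd


def save_split(out_dir, train_patients, dev_patients, train_idx, dev_idx) -> None:
    out = Path(out_dir)
    # Column name "patient_id" and npz keys "train"/"dev" are assumptions
    pd.DataFrame({"patient_id": train_patients}).to_csv(out / "train_patient_ids.csv", index=False)
    pd.DataFrame({"patient_id": dev_patients}).to_csv(out / "dev_patient_ids.csv", index=False)
    np.savez(out / "data_splits.npz", train=np.asarray(train_idx), dev=np.asarray(dev_idx))


def load_split(out_dir):
    out = Path(out_dir)
    train_p = pd.read_csv(out / "train_patient_ids.csv")["patient_id"].to_numpy()
    dev_p = pd.read_csv(out / "dev_patient_ids.csv")["patient_id"].to_numpy()
    with np.load(out / "data_splits.npz") as splits:
        return train_p, dev_p, splits["train"], splits["dev"]


with tempfile.TemporaryDirectory() as tmp:
    save_split(tmp, ["study1:P001", "study2:P003"], ["study1:P002"], [0, 1, 4], [2, 3])
    train_p, dev_p, train_idx, dev_idx = load_split(tmp)
print(train_p.tolist(), dev_idx.tolist())  # ['study1:P001', 'study2:P003'] [2, 3]
```

Storing both patient IDs and cell indices is a deliberate redundancy: the indices load fastest, while the patient IDs remain valid even if the cell order in the data file changes.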
Resumable Workflows
Patient ID lists are saved so that the same split can be reused across runs:
```python
import numpy as np
import pandas as pd

from scxpand.data_util.data_splitter import get_patient_identifiers

# Load existing splits (patient ID lists saved during splitting)
train_patients = pd.read_csv("results/train_patient_ids.csv").values.flatten()
dev_patients = pd.read_csv("results/dev_patient_ids.csv").values.flatten()

# Reconstruct cell indices: keep the cells whose patient ID is in each list
patient_identifiers = get_patient_identifiers(adata.obs)
train_indices = np.where(np.isin(patient_identifiers, train_patients))[0]
dev_indices = np.where(np.isin(patient_identifiers, dev_patients))[0]
```
Integration with Training Pipeline
Train-Validation Integration
Data splitting for training and validation is integrated automatically into training data preparation:
```python
from scxpand.data_util.prepare_data_for_train import prepare_data_for_training

# Prepare data with automatic splitting
bundle = prepare_data_for_training(
    data_path="data.h5ad",
    dev_ratio=0.2,  # Validation split ratio
    rand_seed=42,  # Reproducible splits
    save_dir="results/",  # Output directory
)

# Access split results
train_indices = bundle.row_inds_train
dev_indices = bundle.row_inds_dev
data_format = bundle.data_format
```
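Before building datasets from the returned indices, a cheap sanity check can confirm that no cell or patient appears in both splits. A hypothetical helper (assert_no_patient_leakage is not part of scXpand); in practice, the per-cell IDs could come from get_patient_identifiers(adata.obs):

```python
import numpy as np


def assert_no_patient_leakage(cell_patient_ids, train_idx, dev_idx) -> None:
    """Hypothetical check: raise if a cell or a patient appears in both splits."""
    ids = np.asarray(cell_patient_ids)
    train_idx, dev_idx = np.asarray(train_idx), np.asarray(dev_idx)
    shared_cells = np.intersect1d(train_idx, dev_idx)
    if shared_cells.size:
        raise ValueError(f"{shared_cells.size} cells are in both splits")
    shared_patients = np.intersect1d(ids[train_idx], ids[dev_idx])
    if shared_patients.size:
        raise ValueError(f"patients in both splits: {shared_patients.tolist()}")


ids = ["s1:P1", "s1:P1", "s1:P2", "s2:P1"]
assert_no_patient_leakage(ids, [0, 1, 2], [3])  # passes: s2:P1 is a different patient
```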
Dataset Creation
Split indices are used to create training and validation datasets:
```python
from scxpand.data_util.dataset import CellsDataset

# Create training dataset
train_dataset = CellsDataset(
    data_format=data_format,
    row_inds=train_indices,  # Only training cells
    is_train=True,
    data_path="data.h5ad",
)

# Create validation dataset
dev_dataset = CellsDataset(
    data_format=data_format,
    row_inds=dev_indices,  # Only validation cells
    is_train=False,
    data_path="data.h5ad",
)
```