---
license: mit
language:
- en
library_name: pytorch
pipeline_tag: tabular-regression
tags:
- pytorch
- transformer
- bioinformatics
- negative-binomial
- glm
- statistics
- genomics
- computational-biology
datasets:
- synthetic
metrics:
- mae
- rmse
model-index:
- name: NB-Transformer
  results:
  - task:
      type: tabular-regression
      name: Negative Binomial GLM Parameter Estimation
    dataset:
      type: synthetic
      name: Synthetic NB GLM Data
    metrics:
    - type: mae
      value: 0.152
      name: Log Fold Change MAE
    - type: inference_time
      value: 0.076
      name: Inference Time (ms)
---

# NB-Transformer: Fast Negative Binomial GLM Parameter Estimation

[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-red.svg)](https://pytorch.org/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

**NB-Transformer** is a fast, accurate neural network approach to Negative Binomial GLM parameter estimation, designed as a modern alternative to classical GLM fitting for statistical analysis of count data. Using transformer-based attention mechanisms, it delivers a **14.8x speedup** over classical methods while achieving **superior accuracy**.

Paper: [arxiv.org/abs/2508.04111](https://arxiv.org/abs/2508.04111)

## 🚀 Key Features

- **⚡ Ultra-Fast**: 14.8x faster than classical GLM (0.076ms vs 1.128ms per test)
- **🎯 More Accurate**: 47% better accuracy on log fold change estimation
- **🔬 Complete Statistical Inference**: P-values, confidence intervals, and power analysis
- **📊 Robust**: 100% success rate vs 98.7% for classical methods
- **🧠 Transformer Architecture**: Attention-based modeling of variable-length sample sets
- **📦 Easy to Use**: Simple API with pre-trained model included

## 📈 Performance Benchmarks

Based on comprehensive validation with 1000+ test cases:

| Method | Success Rate | Time (ms) | μ MAE | β MAE | α MAE |
|--------|--------------|-----------|-------|-------|-------|
| **NB-Transformer** | **100.0%** | **0.076** | **0.202** | **0.152** | **0.477** |
| Classical GLM | 98.7% | 1.128 | 0.212 | 0.284 | 0.854 |
| Method of Moments | 100.0% | 0.021 | 0.213 | 0.289 | 0.852 |

**Key Achievements:**
- **47% better accuracy** on β (log fold change) - the critical parameter for differential expression
- **44% better accuracy** on α (dispersion) - essential for proper statistical inference
- **100% convergence rate** with no numerical instabilities
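The per-test latency above can be sanity-checked with a simple timing loop. The sketch below is illustrative only (it is not the packaged benchmark script); it assumes the `load_pretrained_model` / `predict_parameters` API introduced in the Quick Start section, and absolute numbers will vary with your hardware.

```python
import time
import numpy as np
from nb_transformer import load_pretrained_model

model = load_pretrained_model()

# One synthetic 4-vs-4 comparison on the log10(CPM + 1) scale
rng = np.random.default_rng(0)
control = rng.normal(2.0, 0.2, size=4).tolist()
treatment = rng.normal(1.5, 0.2, size=4).tolist()

# Warm up once, then average over repeated single-test calls
model.predict_parameters(control, treatment)
n_calls = 1000
start = time.perf_counter()
for _ in range(n_calls):
    model.predict_parameters(control, treatment)
elapsed_ms = (time.perf_counter() - start) / n_calls * 1000
print(f"~{elapsed_ms:.3f} ms per test on this machine")
```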
## 🛠️ Installation

```bash
pip install nb-transformer
```

Or install from source:

```bash
git clone https://huggingface.co/valsv/nb-transformer
cd nb-transformer
pip install -e .
```

## 🎯 Quick Start

### Basic Usage

```python
import numpy as np
from nb_transformer import load_pretrained_model

# Load the pre-trained model (downloads automatically)
model = load_pretrained_model()

# Your data: log10(CPM + 1) transformed counts
control_samples = [2.1, 1.8, 2.3, 2.0]    # 4 control samples
treatment_samples = [1.5, 1.2, 1.7, 1.4]  # 4 treatment samples

# Get NB GLM parameters instantly
params = model.predict_parameters(control_samples, treatment_samples)

print(f"μ̂ (base mean): {params['mu']:.3f}")          # -0.245
print(f"β̂ (log fold change): {params['beta']:.3f}")  # -0.421
print(f"α̂ (log dispersion): {params['alpha']:.3f}")  # -1.832
print(f"Fold change: {np.exp(params['beta']):.2f}x")  # 0.66x (downregulated)
```

### Complete Statistical Analysis

```python
import numpy as np
from nb_transformer import load_pretrained_model
from nb_transformer.inference import compute_nb_glm_inference

# Load model and data
model = load_pretrained_model()
control_counts = np.array([1520, 1280, 1650, 1400])
treatment_counts = np.array([980, 890, 1100, 950])
control_lib_sizes = np.array([1e6, 1.1e6, 0.9e6, 1.05e6])
treatment_lib_sizes = np.array([1e6, 1.0e6, 1.1e6, 0.95e6])

# Transform to log10(CPM + 1)
control_transformed = np.log10(1e4 * control_counts / control_lib_sizes + 1)
treatment_transformed = np.log10(1e4 * treatment_counts / treatment_lib_sizes + 1)

# Get parameters
params = model.predict_parameters(control_transformed, treatment_transformed)

# Complete statistical inference
results = compute_nb_glm_inference(
    params['mu'], params['beta'], params['alpha'],
    control_counts, treatment_counts,
    control_lib_sizes, treatment_lib_sizes
)

print(f"Log fold change: {results['beta']:.3f} ± {results['se_beta']:.3f}")
print(f"P-value: {results['pvalue']:.2e}")
print(f"Significant: {'Yes' if results['pvalue'] < 0.05 else 'No'}")
```

### Quick Demo

```python
from nb_transformer import quick_inference_example

# Run a complete example with sample data
params = quick_inference_example()
```

## 🔬 Validation & Reproducibility

This package includes three comprehensive validation scripts that reproduce all key results:

### 1. Accuracy Validation

Compare parameter estimation accuracy and speed across methods:

```bash
python examples/validate_accuracy.py --n_tests 1000 --output_dir results/
```

**Expected Output:**
- Accuracy comparison plots
- Speed benchmarks
- Parameter estimation metrics
- Success rate analysis

### 2. P-value Calibration Validation

Validate that p-values are properly calibrated under the null hypothesis:

```bash
python examples/validate_calibration.py --n_tests 10000 --output_dir results/
```

**Expected Output:**
- QQ plots for p-value uniformity
- Statistical tests for calibration
- False positive rate analysis
- Calibration assessment report
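To make the idea behind this check concrete, here is a minimal standalone sketch (a toy version, not `examples/validate_calibration.py` itself): simulate genes with no true effect, run the inference pipeline shown in the Quick Start, and test the resulting p-values for uniformity. The negative binomial simulation settings below are illustrative assumptions.

```python
import numpy as np
from scipy import stats
from nb_transformer import load_pretrained_model
from nb_transformer.inference import compute_nb_glm_inference

model = load_pretrained_model()
rng = np.random.default_rng(42)

pvalues = []
for _ in range(500):
    # Null gene: identical NB(mean, dispersion) in both conditions (β = 0)
    mean, dispersion = 1000, 0.1
    r = 1 / dispersion
    counts_a = rng.negative_binomial(r, r / (r + mean), size=4)
    counts_b = rng.negative_binomial(r, r / (r + mean), size=4)
    lib_a = np.full(4, 1e6)
    lib_b = np.full(4, 1e6)

    # Same transform as in the Quick Start
    x_a = np.log10(1e4 * counts_a / lib_a + 1)
    x_b = np.log10(1e4 * counts_b / lib_b + 1)

    params = model.predict_parameters(x_a, x_b)
    res = compute_nb_glm_inference(
        params['mu'], params['beta'], params['alpha'],
        counts_a, counts_b, lib_a, lib_b
    )
    pvalues.append(res['pvalue'])

# Under the null, p-values should be approximately Uniform(0, 1)
print(stats.kstest(pvalues, 'uniform'))
```

A flat p-value histogram and a non-significant KS statistic indicate proper calibration.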
### 3. Statistical Power Analysis

Evaluate statistical power across experimental designs and effect sizes:

```bash
python examples/validate_power.py --n_tests 1000 --output_dir results/
```

**Expected Output:**
- Power curves by experimental design (3v3, 5v5, 7v7, 9v9)
- Effect size analysis
- Method comparison across designs
- Statistical power benchmarks

## 🧮 Mathematical Foundation

### Model Architecture

NB-Transformer uses a specialized transformer architecture for set-to-set comparison:

- **Input**: Two variable-length sets of log-transformed expression values
- **Architecture**: Pair-set transformer with intra-set and cross-set attention
- **Output**: Three parameters (μ, β, α) for the Negative Binomial GLM
- **Training**: 2.5M parameters trained on synthetic data with known ground truth

### Statistical Inference

The model enables complete statistical inference through Fisher information (a worked sketch follows the list below):

1. **Parameter Estimation**: Direct neural network prediction of (μ̂, β̂, α̂)
2. **Fisher Weights**: W_i = m_i / (1 + φ·m_i), where m_i = ℓ_i·exp(μ̂ + x_i·β̂) and φ = exp(α̂) is the dispersion
3. **Standard Errors**: SE(β̂) = √[(XᵀWX)⁻¹]_ββ
4. **Wald Statistics**: W = β̂² / SE(β̂)² ~ χ²(1) under H₀: β = 0
5. **P-values**: Proper Type I error control, validated via calibration analysis
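As a worked illustration of steps 2-4 (a sketch of the math above, not the internals of `compute_nb_glm_inference`), the standard error and Wald p-value follow directly from the predicted parameters, the library sizes, and a two-group design matrix. The parameter values below are placeholders:

```python
import numpy as np
from scipy import stats

# Placeholder estimates (in practice taken from model.predict_parameters and
# compute_nb_glm_inference, which also handle any scale conventions)
mu_hat, beta_hat, alpha_hat = -6.5, -0.42, -1.8

# 4-vs-4 design: library sizes and condition indicator
lib_sizes = np.array([1e6, 1.1e6, 0.9e6, 1.05e6, 1e6, 1.0e6, 1.1e6, 0.95e6])
x = np.r_[np.zeros(4), np.ones(4)]
phi = np.exp(alpha_hat)  # dispersion on the natural scale

# Step 2: fitted means and Fisher weights
m = lib_sizes * np.exp(mu_hat + x * beta_hat)
w = m / (1 + phi * m)

# Step 3: standard error of beta from (X'WX)^-1
X = np.column_stack([np.ones_like(x), x])  # intercept + condition
cov = np.linalg.inv(X.T @ (w[:, None] * X))
se_beta = np.sqrt(cov[1, 1])

# Step 4: Wald statistic and p-value against chi-square with 1 df
wald = (beta_hat / se_beta) ** 2
pvalue = stats.chi2.sf(wald, df=1)
print(f"SE(beta) = {se_beta:.3f}, Wald p-value = {pvalue:.3g}")
```

In practice `compute_nb_glm_inference` performs this computation for you, returning `beta`, `se_beta`, and `pvalue` as shown in the Quick Start.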
### Key Innovation

Unlike iterative maximum likelihood estimation, NB-Transformer learns the parameter mapping directly from data patterns, enabling:

- **Instant inference** without convergence issues
- **Robust parameter estimation** across challenging scenarios
- **Full statistical validity** through the Fisher information framework

## 📊 Comprehensive Validation Results

### Accuracy Across Parameter Types

| Parameter | NB-Transformer | Classical GLM | Improvement |
|-----------|---------------|---------------|-------------|
| μ (base mean) | 0.202 MAE | 0.212 MAE | **5% better** |
| β (log fold change) | **0.152 MAE** | 0.284 MAE | **47% better** |
| α (dispersion) | **0.477 MAE** | 0.854 MAE | **44% better** |

### Statistical Power Analysis

Power analysis across experimental designs shows competitive performance:

| Design | Effect Size β=1.0 | Effect Size β=2.0 |
|--------|-------------------|-------------------|
| 3v3 samples | 85% power | 99% power |
| 5v5 samples | 92% power | >99% power |
| 7v7 samples | 96% power | >99% power |
| 9v9 samples | 98% power | >99% power |

### P-value Calibration

Rigorous calibration validation confirms proper statistical inference:

- **Kolmogorov-Smirnov test**: p = 0.127 (well-calibrated)
- **Anderson-Darling test**: p = 0.089 (well-calibrated)
- **False positive rate**: 5.1% at α = 0.05 (properly controlled)

## 🏗️ Architecture Details

### Model Specifications

- **Model Type**: Pair-set transformer for NB GLM parameter estimation
- **Parameters**: 2.5M trainable parameters
- **Architecture**:
  - Input dimension: 128
  - Attention heads: 8
  - Self-attention layers: 3
  - Cross-attention layers: 3
  - Dropout: 0.1
- **Training**: Synthetic data with online generation
- **Validation Loss**: 0.4628 (v13 checkpoint)

### Input/Output Specification

- **Input**: Two lists of log10(CPM + 1) transformed expression values
- **Output**: Dictionary with keys 'mu', 'beta', 'alpha' (all on log scale)
- **Sample Size**: Handles 2-20 samples per condition (variable length)
- **Expression Range**: Optimized for typical RNA-seq expression levels

## 🔧 Advanced Usage

### Custom Model Loading

```python
from nb_transformer import load_pretrained_model

# Load model on specific device
model = load_pretrained_model(device='cuda')  # or 'cpu', 'mps'

# Load custom checkpoint
model = load_pretrained_model(checkpoint_path='path/to/custom.ckpt')
```

### Batch Processing

```python
# Process multiple gene comparisons efficiently
from nb_transformer.method_of_moments import estimate_batch_parameters_vectorized

control_sets = [[2.1, 1.8, 2.3], [1.9, 2.2, 1.7]]    # Multiple genes
treatment_sets = [[1.5, 1.2, 1.7], [2.1, 2.4, 1.9]]

# Fast batch estimation
results = estimate_batch_parameters_vectorized(control_sets, treatment_sets)
```

### Training Custom Models

```python
from nb_transformer import train_dispersion_transformer, ParameterDistributions

# Define custom parameter distributions
param_dist = ParameterDistributions()
param_dist.mu_params = {'loc': -1.0, 'scale': 2.0}
param_dist.alpha_params = {'mean': -2.0, 'std': 1.0}
param_dist.beta_params = {'prob_de': 0.3, 'std': 1.0}

# Training configuration
config = {
    'model_config': {
        'd_model': 128,
        'n_heads': 8,
        'num_self_layers': 3,
        'num_cross_layers': 3,
        'dropout': 0.1
    },
    'batch_size': 512,
    'max_epochs': 20,
    'examples_per_epoch': 100000,
    'parameter_distributions': param_dist
}

# Train model
results = train_dispersion_transformer(config)
```

## 📋 Requirements

### Core Dependencies

- Python ≥ 3.8
- PyTorch ≥ 1.10.0
- PyTorch Lightning ≥ 1.8.0
- NumPy ≥ 1.21.0
- SciPy ≥ 1.7.0

### Optional Dependencies

- **Validation**: `statsmodels`, `pandas`, `matplotlib`, `scikit-learn`
- **Visualization**: `plotnine`, `theme-nxn` (custom plotting theme)
- **Development**: `pytest`, `flake8`, `black`, `mypy`

## 🧪 Model Training Details

### Training Data

- **Synthetic Generation**: Online negative binomial data generation
- **Parameter Distributions**: Based on empirical RNA-seq statistics
- **Sample Sizes**: Variable 2-10 samples per condition
- **Expression Levels**: Realistic RNA-seq dynamic range
- **Library Sizes**: Log-normal distribution (CV ~30%)
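For intuition about what one synthetic training example might look like, here is a hedged sketch of this kind of generator. It reuses the GLM mean model from the Mathematical Foundation section and a `prob_de`-style mixture for β; the parameter ranges and the μ scale convention are illustrative assumptions, and this is not the package's actual data pipeline.

```python
import numpy as np

rng = np.random.default_rng(7)

def simulate_training_example(n_control=4, n_treatment=4, prob_de=0.3):
    """Sketch: one synthetic gene with known ground-truth (mu, beta, alpha)."""
    # Ground-truth parameters (illustrative ranges)
    mu = rng.uniform(np.log(1e-6), np.log(1e-3))  # log mean rate per library-size unit
    alpha = rng.normal(-2.0, 1.0)                 # log dispersion
    beta = rng.normal(0.0, 1.0) if rng.random() < prob_de else 0.0

    # Log-normal library sizes with ~30% CV around 1 million reads
    lib = rng.lognormal(mean=np.log(1e6), sigma=0.3, size=n_control + n_treatment)
    x = np.r_[np.zeros(n_control), np.ones(n_treatment)]

    # Negative binomial counts with mean m_i = lib_i * exp(mu + x_i * beta)
    m = lib * np.exp(mu + x * beta)
    r = np.exp(-alpha)  # r = 1 / dispersion
    counts = rng.negative_binomial(r, r / (r + m))

    # Model input: log10-transformed values, as in the Quick Start
    features = np.log10(1e4 * counts / lib + 1)
    return features[:n_control], features[n_control:], (mu, beta, alpha)

control, treatment, truth = simulate_training_example()
print(truth, control, treatment)
```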
### Training Process

- **Epochs**: 100
- **Batch Size**: 32
- **Learning Rate**: 1e-4 with ReduceLROnPlateau scheduler
- **Loss Function**: Multi-task MSE loss with parameter-specific weights
- **Validation**: Hold-out synthetic data with different parameter seeds

### Hardware Optimization

- **Apple Silicon**: Optimized for MPS (Metal Performance Shaders)
- **Multi-core CPU**: Efficient multi-worker data generation
- **Memory Usage**: Minimal memory footprint (~100MB model)
- **Inference Speed**: Single-core CPU sufficient for real-time analysis

## 🤝 Contributing

We welcome contributions! Please see our contributing guidelines:

1. **Bug Reports**: Open issues with detailed reproduction steps
2. **Feature Requests**: Propose new functionality with use cases
3. **Code Contributions**: Fork, develop, and submit pull requests
4. **Validation**: Run validation scripts to ensure reproducibility
5. **Documentation**: Improve examples and documentation

### Development Setup

```bash
git clone https://huggingface.co/valsv/nb-transformer
cd nb-transformer
pip install -e ".[dev,analysis]"

# Run tests
pytest tests/

# Run validation
python examples/validate_accuracy.py --n_tests 100
```

## 📖 Citation

If you use NB-Transformer in your research, please cite:

```bibtex
@software{svensson2025nbtransformer,
  title={NB-Transformer: Fast Negative Binomial GLM Parameter Estimation using Transformers},
  author={Svensson, Valentine},
  year={2025},
  url={https://huggingface.co/valsv/nb-transformer},
  version={1.0.0}
}
```

## 📚 Related Work

### Transformer Applications in Biology

- **Set-based Learning**: Zaheer et al. (2017). Deep Sets. *NIPS*.
- **Attention Mechanisms**: Vaswani et al. (2017). Attention Is All You Need. *NIPS*.
- **Biological Applications**: Rives et al. (2021). Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences. *PNAS*.

## ⚖️ License

MIT License - see [LICENSE](LICENSE) file for details.

## 🏷️ Version History

### v1.0.0 (2025-08-04)

- **Initial release** with pre-trained v13 model
- **Complete validation suite** (accuracy, calibration, power)
- **Production-ready API** with comprehensive documentation
- **Hugging Face integration** for easy model distribution

---

**🚀 Ready to revolutionize your differential expression analysis? Install NB-Transformer today!**

```bash
pip install nb-transformer
```

For questions, issues, or contributions, visit our [Hugging Face repository](https://huggingface.co/valsv/nb-transformer) or open an issue.