---
library_name: pytorch
pipeline_tag: tabular-classification
base_model: Prior-Labs/TabPFN-v2-clf
license: other
license_name: prior-labs-license-v1.1
license_link: LICENSE
tags:
- pytorch
- tabpfn
- species-distribution-modeling
- ecology
- biodiversity
- tabular-classification
- finetuned
metrics:
- roc_auc
- average_precision
---

# TabPFN-SDM: TabPFN Finetuned for Species Distribution Modeling

**Built with PriorLabs-TabPFN**

Finetuned [TabPFN](https://huggingface.co/Prior-Labs/TabPFN-v2-clf) models for binary species distribution modeling (SDM). These models were adapted from the pretrained TabPFN foundation model to improve performance on the presence/absence classification tasks common in ecology and conservation biology.

## Model Variants

| File | Description | Epochs | Val ROC-AUC | Val PR-AUC |
|------|-------------|--------|-------------|------------|
| `tabpfn-sdm-nonspatial.pt` | Standard (non-spatial) train/test evaluation | 100 | 0.747 | 0.261 |
| `tabpfn-sdm-spatial.pt` | Spatially separated train/test evaluation | 50 | 0.653 | 0.144 |

The **non-spatial** variant is trained and evaluated using standard random train/test splits. The **spatial** variant uses a 10 km buffer to spatially separate training and test data; this is more realistic for ecological applications but makes the prediction task harder.

## Training Details

### Two-Step Finetuning

Each model was finetuned in two steps:

1. **Step 1 (Cross-Validation):** Finetuned on CV-generated train/test splits from the training data. Learning rate: 1e-5. Goal: adapt TabPFN to general SDM classification patterns.
2. **Step 2 (Benchmark):** Further finetuned on the actual benchmark train/test splits. Learning rate: 1e-6. Goal: refine the model for the specific evaluation setup.
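
Both steps feed the model sub-batches built with the smart sub-batching scheme listed under Training Procedure: every presence is kept in each sub-batch, the absences are partitioned among sub-batches, and each sub-batch holds at most 1,500 rows. The plain-Python sketch below illustrates that partitioning. It is not the project's code: `make_sub_batches` is a hypothetical helper, and spreading absences as evenly as possible, shuffling with the card's seed, and raising an error when the presences alone exceed the cap are assumptions.

```python
import math
import random


def make_sub_batches(y, max_train_size=1500, seed=32639):
    """Hypothetical helper: split row indices into presence-preserving sub-batches.

    Assumes binary 0/1 labels. Every presence (y == 1) appears in every
    sub-batch; every absence (y == 0) appears in exactly one sub-batch.
    """
    presences = [i for i, label in enumerate(y) if label == 1]
    absences = [i for i, label in enumerate(y) if label == 0]

    # Rows left for absences once all presences are included.
    slots = max_train_size - len(presences)
    if slots <= 0:
        # Assumption: this sketch does not subsample presences.
        raise ValueError("presences alone exceed max_train_size")

    # Shuffle absences reproducibly before partitioning them.
    rng = random.Random(seed)
    rng.shuffle(absences)

    # Fewest sub-batches that fit all absences, then spread them evenly:
    # the first `extra` sub-batches receive one additional absence.
    n_batches = max(1, math.ceil(len(absences) / slots))
    base, extra = divmod(len(absences), n_batches)

    batches, start = [], 0
    for b in range(n_batches):
        size = base + (1 if b < extra else 0)
        batches.append(presences + absences[start:start + size])
        start += size
    return batches


if __name__ == "__main__":
    # Toy example: 100 presences and 4,000 absences.
    y = [1] * 100 + [0] * 4000
    batches = make_sub_batches(y, max_train_size=1500)
    print(len(batches), [len(b) for b in batches])  # → 3 [1434, 1433, 1433]
```

With a 1,500-row cap and 100 presences, each sub-batch has room for 1,400 absences, so the 4,000 absences are split across three sub-batches while every presence appears in all three. The ensemble inference described under Usage relies on the same kind of partition and then averages the predictions across sub-batches.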
### Training Data

- **226 species** across **6 geographic regions**: AWT (Australia), NSW (New South Wales), CAN (Canada), NZ (New Zealand), SA (South Africa), SWI (Switzerland)
- Species split: 62.5% finetuning, 7.5% validation, 30% held-out test
- Data from the `disdat` R package (standardized presence-only and presence-absence datasets)
- Features include environmental covariates (7-12 per region) with native categorical variable support

### Training Procedure

- **Base model:** TabPFN v2 classifier (`Prior-Labs/TabPFN-v2-clf`)
- **Optimizer:** AdamW
- **LR schedule:** OneCycleLR
- **Training approach:** Iterative R-Python epoch loop with smart sub-batching
  - R generates cross-validation data (parallelized across 20 cores)
  - Python trains one epoch on the GPU
  - Smart sub-batching: keeps ALL presences (rare, valuable) in every sub-batch and partitions absences across sub-batches
- **Max training samples per sub-batch:** 1,500
- **Random seed:** 32639

### Hardware

- GPU: NVIDIA RTX 4000 Ada Generation (20 GB VRAM)
- Peak memory: ~14 GB during training
- Training time: ~188 hours total (Step 1 + Step 2) for the non-spatial variant; ~178 hours for the spatial variant

## Usage

### Python

```python
import torch
from tabpfn import TabPFNClassifier
from huggingface_hub import hf_hub_download

# Download the finetuned checkpoint from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="rdinnager/tabpfn-sdm-finetuned",
    filename="tabpfn-sdm-nonspatial.pt"
)

# Build a TabPFN classifier, then load the finetuned weights into it
device = "cuda" if torch.cuda.is_available() else "cpu"
clf = TabPFNClassifier(
    ignore_pretraining_limits=True,
    device=device,
    n_estimators=8,
    random_state=32639
)
clf._initialize_model_variables()
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
clf.models_[0].load_state_dict(checkpoint["model_state_dict"])
clf.models_[0].eval()

# Fit on training data and predict.
# X_train, y_train, X_test are your own data (covariates and 0/1 presence labels).
clf.fit(X_train, y_train)
probs = clf.predict_proba(X_test)[:, 1]  # Probability of presence
```

### R (via reticulate)

The companion R code provides inference wrappers.
See the [TabPFN-SDM GitHub repository](https://github.com/rdinnager/TabPFN-SDM) for the full pipeline.

```r
source("R/run_tabpfn_finetuned.R")

result <- run_tabpfn_finetuned_ensemble(
  train_dat = train_data,
  test_dat = test_data,
  model_path = "tabpfn-sdm-nonspatial.pt",
  max_train_size = 1000L,
  n_estimators = 8L
)
# result$test contains y_truth and y_pred columns
```

### Ensemble Inference (Recommended)

For best results, use ensemble inference, which matches the training procedure:

- Keeps ALL presences in every sub-batch
- Partitions absences across sub-batches so that every absence is used (100% data usage)
- Averages predictions across sub-batches

This is implemented in `run_tabpfn_finetuned_ensemble()` (R) and can be replicated in Python using the utilities in the GitHub repository.

## Checkpoint Contents

Each `.pt` file is a PyTorch checkpoint dictionary containing:

| Key | Description |
|-----|-------------|
| `model_state_dict` | Finetuned TabPFN model weights |
| `config` | Training configuration (hyperparameters, data settings) |
| `history` | Training history (loss, metrics per epoch) |
| `step1_path` | Path to the Step 1 model used as initialization |

## Limitations

- Models were trained on species from 6 specific geographic regions and may not generalize to other regions without additional finetuning
- Performance varies by species; species with very few presences may have lower predictive performance
- The spatial variant has lower validation metrics because spatial extrapolation is an inherently harder task
- Requires the `tabpfn` Python package (v2) for inference
- Categorical variable handling requires proper setup (see project documentation)

## Citation

If you use these models, please cite the accompanying paper:

```bibtex
@article{dinnage2026niche,
  title={A Niche in the Machine: The Promise of AI Foundation Models for Species Distribution Modeling},
  author={Dinnage, Russell and Warren, Dan L.},
  year={2026},
  doi={10.32942/X2VQ10},
  journal={EcoEvoRxiv},
  url={https://ecoevorxiv.org/repository/view/11797/}
}
```

Please also cite the original TabPFN paper:

```bibtex
@article{hollmann2025tabpfn,
  title={Accurate Predictions on Small Data with a Tabular Foundation Model},
  author={Hollmann, Noah and M{\"u}ller, Samuel and Purucker, Lennart and Krishnakumar, Arjun and K{\"o}rfer, Max and Hoo, Shi Bin and Schirrmeister, Robin Tibor and Hutter, Frank},
  journal={Nature},
  year={2025},
  publisher={Nature Publishing Group}
}
```

## License

These finetuned model weights are distributed under the [Prior Labs License v1.1](LICENSE), consistent with the base TabPFN model license. See the LICENSE file for full terms.

## Links

- [TabPFN-SDM GitHub Repository](https://github.com/rdinnager/TabPFN-SDM)
- [Base TabPFN Model](https://huggingface.co/Prior-Labs/TabPFN-v2-clf)
- [TabPFN GitHub](https://github.com/PriorLabs/TabPFN)