metadata
library_name: pytorch
pipeline_tag: tabular-classification
base_model: Prior-Labs/TabPFN-v2-clf
license: other
license_name: prior-labs-license-v1.1
license_link: LICENSE
tags:
  - pytorch
  - tabpfn
  - species-distribution-modeling
  - ecology
  - biodiversity
  - tabular-classification
  - finetuned
metrics:
  - roc_auc
  - average_precision

TabPFN-SDM: TabPFN Finetuned for Species Distribution Modeling

Built with PriorLabs-TabPFN

Finetuned TabPFN models for binary species distribution modeling (SDM). These models were adapted from the pretrained TabPFN foundation model to improve performance on presence/absence classification tasks common in ecology and conservation biology.

Model Variants

| File | Description | Epochs | Val ROC-AUC | Val PR-AUC |
| --- | --- | --- | --- | --- |
| tabpfn-sdm-nonspatial.pt | Standard (non-spatial) train/test evaluation | 100 | 0.747 | 0.261 |
| tabpfn-sdm-spatial.pt | Spatially separated train/test evaluation | 50 | 0.653 | 0.144 |

The non-spatial variant is trained and evaluated on standard random train/test splits. The spatial variant separates training and test data with a 10 km buffer. This setup is more realistic for ecological applications but makes the prediction task harder, which is why its validation scores are lower.
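One common way to implement buffer-based separation is to drop every training point that lies within the buffer distance of any test point. The sketch below illustrates that idea only; it is not the project's code. It assumes coordinates are already in a projected CRS measured in metres, and `spatial_buffer_mask` is a hypothetical helper name.

```python
import numpy as np


def spatial_buffer_mask(train_xy, test_xy, buffer_m=10_000.0):
    """Return True for each training point farther than buffer_m from every test point.

    ASSUMPTION: train_xy and test_xy are (n, 2) arrays of projected
    coordinates in metres (not longitude/latitude).
    """
    train_xy = np.asarray(train_xy, dtype=float).reshape(-1, 2)
    test_xy = np.asarray(test_xy, dtype=float).reshape(-1, 2)
    if len(test_xy) == 0:
        # No test points: nothing to buffer against, keep every training point
        return np.ones(len(train_xy), dtype=bool)
    # Pairwise Euclidean distances, shape (n_train, n_test)
    diff = train_xy[:, None, :] - test_xy[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    # A training point survives only if its nearest test point is beyond the buffer
    return dist.min(axis=1) > buffer_m


# Example: the first two training points fall inside the 10 km buffer
train_xy = [[0.0, 0.0], [5_000.0, 0.0], [20_000.0, 0.0]]
test_xy = [[0.0, 1_000.0]]
keep = spatial_buffer_mask(train_xy, test_xy)
print(keep.tolist())  # [False, False, True]
```

The brute-force distance matrix is fine for a few thousand points; for larger datasets a spatial index such as a k-d tree would scale better.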

Training Details

Two-Step Finetuning

Each model was finetuned in two steps:

  1. Step 1 (cross-validation): finetuned on train/test splits generated by cross-validation within the training data. Learning rate: 1e-5. Goal: adapt TabPFN to general SDM classification patterns.
  2. Step 2 (benchmark): further finetuned, starting from the Step 1 weights, on the actual benchmark train/test splits. Learning rate: 1e-6. Goal: refine the model for the specific evaluation setup.
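The two-step schedule can be sketched in PyTorch as follows. This assumes the stated learning rates are the peak values of the AdamW + OneCycleLR schedule listed under Training Procedure, which the card does not state explicitly. A small `nn.Linear` model and random data stand in for TabPFN and the SDM splits, and `finetune_step` is a hypothetical helper, not part of the project's code.

```python
import copy

import torch
from torch import nn


def finetune_step(model, batches, max_lr, epochs):
    """Train `model` with AdamW and a OneCycleLR schedule that peaks at `max_lr`.

    Returns the model and the learning rate in effect after each optimizer step.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=max_lr, epochs=epochs, steps_per_epoch=len(batches)
    )
    loss_fn = nn.BCEWithLogitsLoss()  # binary presence/absence loss
    lrs = []
    model.train()
    for _ in range(epochs):
        for X, y in batches:
            optimizer.zero_grad()
            loss = loss_fn(model(X).squeeze(-1), y)
            loss.backward()
            optimizer.step()
            scheduler.step()  # OneCycleLR is stepped once per batch
            lrs.append(scheduler.get_last_lr()[0])
    return model, lrs


torch.manual_seed(32639)
# Toy stand-ins: a linear model instead of TabPFN, random data instead of SDM splits
X = torch.randn(32, 4)
y = (X[:, 0] > 0).float()
cv_batches = [(X[:16], y[:16]), (X[16:], y[16:])]  # Step 1: CV-generated splits
benchmark_batches = [(X[8:24], y[8:24])]           # Step 2: benchmark splits

# Step 1: adapt the base model, peak LR 1e-5
model = nn.Linear(4, 1)
model, lrs_step1 = finetune_step(model, cv_batches, max_lr=1e-5, epochs=5)
step1_state = copy.deepcopy(model.state_dict())

# Step 2: initialize from the Step 1 weights, peak LR 1e-6
model2 = nn.Linear(4, 1)
model2.load_state_dict(step1_state)
model2, lrs_step2 = finetune_step(model2, benchmark_batches, max_lr=1e-6, epochs=5)
```

Starting Step 2 from a saved Step 1 state mirrors the `step1_path` entry stored in each checkpoint; the tenfold lower peak rate keeps the second stage a gentle refinement.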

Training Data

  • 226 species across 6 geographic regions: AWT (Australian Wet Tropics), NSW (New South Wales, Australia), CAN (Canada), NZ (New Zealand), SA (South America), SWI (Switzerland)
  • Species split: 62.5% finetuning, 7.5% validation, 30% held-out test
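The species-level split can be illustrated with a short sketch. It assumes a plain random permutation of species IDs with the reported seed, rounding each fraction to whole species; the actual pipeline may stratify by region or round differently. `split_species` and the species IDs are hypothetical.

```python
import numpy as np


def split_species(species_ids, fractions=(0.625, 0.075, 0.30), seed=32639):
    """Randomly partition unique species IDs into finetuning, validation and test sets.

    ASSUMPTION: a plain random permutation with round-to-nearest set sizes;
    the real pipeline may stratify by region or round differently.
    """
    if not np.isclose(sum(fractions), 1.0):
        raise ValueError("fractions must sum to 1")
    ids = np.array(sorted(set(species_ids)))
    ids = np.random.default_rng(seed).permutation(ids)
    n_finetune = int(round(fractions[0] * len(ids)))
    n_val = int(round(fractions[1] * len(ids)))
    finetune = ids[:n_finetune]
    val = ids[n_finetune:n_finetune + n_val]
    test = ids[n_finetune + n_val:]  # remainder goes to the held-out test set
    return finetune, val, test


species = [f"sp{i:03d}" for i in range(226)]  # hypothetical species IDs
finetune, val, test = split_species(species)
print(len(finetune), len(val), len(test))  # 141 17 68
```

Splitting by species rather than by record means that no held-out test species is ever seen during finetuning.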
  • Data from the disdat R package (standardized presence-only and presence-absence datasets)
  • Features include environmental covariates (7-12 per region) with native categorical variable support

Training Procedure

  • Base model: TabPFN v2 classifier (Prior-Labs/TabPFN-v2-clf)
  • Optimizer: AdamW
  • LR Schedule: OneCycleLR
  • Training approach: Iterative R-Python epoch loop with smart sub-batching
    • R generates cross-validation data (parallelized with 20 cores)
    • Python trains one epoch on GPU
    • Smart sub-batching: every sub-batch keeps ALL presences (which are rare and therefore valuable), while the absences are partitioned across sub-batches
  • Max training samples per sub-batch: 1,500
  • Random seed: 32639
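The smart sub-batching rule can be sketched as follows, assuming labels are coded 1 (presence) and 0 (absence) and that absences are shuffled and split as evenly as possible. `smart_sub_batches` is a hypothetical name, and how the real pipeline handles species whose presences alone exceed the cap is not documented, so this sketch simply raises an error in that case.

```python
import math

import numpy as np


def smart_sub_batches(y, max_size=1500, seed=32639):
    """Split training rows into sub-batches that each contain every presence.

    Returns a list of integer index arrays into y (1 = presence, 0 = absence).
    """
    y = np.asarray(y)
    presences = np.flatnonzero(y == 1)
    absences = np.flatnonzero(y == 0)
    slots = max_size - len(presences)  # room left for absences in each sub-batch
    if slots <= 0:
        raise ValueError("presences alone exceed max_size; this sketch does not subsample them")
    absences = np.random.default_rng(seed).permutation(absences)
    n_batches = max(1, math.ceil(len(absences) / slots))
    # array_split spreads absences evenly, so no sub-batch exceeds max_size
    return [np.concatenate([presences, chunk]) for chunk in np.array_split(absences, n_batches)]


y = np.array([1] * 100 + [0] * 4000)  # 100 presences, 4000 absences
batches = smart_sub_batches(y, max_size=1500)
print([len(b) for b in batches])  # [1434, 1433, 1433]
```

Every absence appears in exactly one sub-batch, so no data is discarded, while each sub-batch sees the full set of presences.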

Hardware

  • GPU: NVIDIA RTX 4000 Ada Generation with 20 GB VRAM
  • Peak memory: ~14 GB during training
  • Training time: ~188 hours total (Step 1 + Step 2) for non-spatial; ~178 hours for spatial

Usage

Python

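import torch
from tabpfn import TabPFNClassifier
from huggingface_hub import hf_hub_download

# Download the finetuned checkpoint from the Hugging Face Hub
model_path = hf_hub_download(
    repo_id="rdinnager/tabpfn-sdm-finetuned",
    filename="tabpfn-sdm-nonspatial.pt"  # or "tabpfn-sdm-spatial.pt"
)

# Build the classifier, then load the finetuned weights into it
device = "cuda" if torch.cuda.is_available() else "cpu"
clf = TabPFNClassifier(
    ignore_pretraining_limits=True,
    device=device,
    n_estimators=8,
    random_state=32639
)
# _initialize_model_variables() is a private tabpfn method, so its behavior
# may change between tabpfn releases
clf._initialize_model_variables()

# weights_only=False lets the non-tensor entries (config, history) unpickle;
# unpickling can run arbitrary code, so only load checkpoints you trust
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
clf.models_[0].load_state_dict(checkpoint["model_state_dict"])
clf.models_[0].eval()

# X_train, y_train, X_test: your covariate tables and 0/1 presence labels
clf.fit(X_train, y_train)
probs = clf.predict_proba(X_test)[:, 1]  # Probability of presence (class 1)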
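To compare predictions against the validation ROC-AUC and PR-AUC reported in the Model Variants table, the two metrics can be computed as in this NumPy-only sketch. scikit-learn's `roc_auc_score` and `average_precision_score` are the standard implementations; the helper names here are hypothetical, and PR-AUC is assumed to mean average precision, matching the `average_precision` metric in the card metadata.

```python
import numpy as np


def roc_auc(y_true, scores):
    """ROC-AUC: probability a random presence outscores a random absence (ties count 1/2)."""
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true], scores[~y_true]
    if len(pos) == 0 or len(neg) == 0:
        raise ValueError("need at least one presence and one absence")
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return float((greater + 0.5 * ties) / (len(pos) * len(neg)))


def average_precision(y_true, scores):
    """Average precision: sum over thresholds of (recall gain) x precision."""
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    if not y_true.any():
        raise ValueError("need at least one presence")
    order = np.argsort(-scores, kind="mergesort")  # highest scores first
    y_sorted, s_sorted = y_true[order], scores[order]
    tp = np.cumsum(y_sorted)
    fp = np.cumsum(~y_sorted)
    # Keep one point per distinct threshold: the last index of each run of tied scores
    last = np.r_[np.flatnonzero(np.diff(s_sorted) != 0), len(s_sorted) - 1]
    tp, fp = tp[last], fp[last]
    precision = tp / (tp + fp)
    recall = tp / tp[-1]
    return float(np.sum(np.diff(np.r_[0.0, recall]) * precision))


y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(roc_auc(y_true, scores))                      # 0.75
print(round(average_precision(y_true, scores), 4))  # 0.8333
```

With real data, call `roc_auc(y_test, probs)` and `average_precision(y_test, probs)`, where `y_test` is a hypothetical vector of observed 0/1 labels for `X_test`.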

R (via reticulate)

The companion R code provides inference wrappers. See the TabPFN-SDM GitHub repository for the full pipeline.

source("R/run_tabpfn_finetuned.R")

result <- run_tabpfn_finetuned_ensemble(
  train_dat = train_data,
  test_dat = test_data,
  model_path = "tabpfn-sdm-nonspatial.pt",
  max_train_size = 1000L,
  n_estimators = 8L
)

# result$test contains y_truth and y_pred columns

Ensemble Inference (Recommended)

For best results, use ensemble inference, which mirrors the sub-batching used during training:

  • Keeps ALL presences in every sub-batch
  • Partitions absences across sub-batches so that every training record is used
  • Averages predictions across sub-batches

This is implemented in run_tabpfn_finetuned_ensemble() (R) or can be replicated in Python using the utilities in the GitHub repo.
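A Python sketch of the ensemble procedure described above follows. `ensemble_predict_proba` and the `make_clf` factory are hypothetical names: `make_clf` should return a fresh classifier with scikit-learn-style `fit`/`predict_proba`, for example a `TabPFNClassifier` with the finetuned weights loaded as in the Python usage example. Absences are shuffled and split evenly so each sub-batch stays within `max_train_size`; the default of 1000 copies the R example and is an assumption, not a recommendation.

```python
import math

import numpy as np


def ensemble_predict_proba(make_clf, X_train, y_train, X_test, max_train_size=1000, seed=32639):
    """Average P(presence) over sub-batches that each keep every presence.

    make_clf: zero-argument callable returning an unfitted classifier with
    fit(X, y) and predict_proba(X); column 1 is assumed to be the presence class.
    """
    X_train, y_train = np.asarray(X_train), np.asarray(y_train)
    presences = np.flatnonzero(y_train == 1)
    absences = np.random.default_rng(seed).permutation(np.flatnonzero(y_train == 0))
    slots = max_train_size - len(presences)
    if slots <= 0:
        raise ValueError("presences alone exceed max_train_size")
    n_batches = max(1, math.ceil(len(absences) / slots))
    preds = []
    for chunk in np.array_split(absences, n_batches):
        idx = np.concatenate([presences, chunk])  # all presences + one share of absences
        clf = make_clf()
        clf.fit(X_train[idx], y_train[idx])
        preds.append(clf.predict_proba(X_test)[:, 1])
    # Simple unweighted mean across sub-batches
    return np.mean(preds, axis=0)
```

Because each sub-batch is fit independently, the cost grows linearly with the number of sub-batches, which is set by how many absences a species has.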

Checkpoint Contents

Each .pt file is a PyTorch checkpoint dictionary containing:

| Key | Description |
| --- | --- |
| model_state_dict | Finetuned TabPFN model weights |
| config | Training configuration (hyperparameters, data settings) |
| history | Training history (loss, metrics per epoch) |
| step1_path | Path to the Step 1 model used as initialization |
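A checkpoint can be inspected before use. This sketch relies only on the keys listed in the table; the inner structure of `config` and `history` is not documented here, so it reports just their keys or types. `summarize_checkpoint` is a hypothetical helper.

```python
import math


def summarize_checkpoint(ckpt):
    """Summarize a TabPFN-SDM checkpoint dict loaded with torch.load.

    ASSUMPTION: only the keys listed in the table are relied on; the inner
    structure of `config` and `history` is not documented.
    """
    state = ckpt["model_state_dict"]
    # math.prod over each tensor's shape gives its element count (1 for scalars)
    n_elements = sum(math.prod(t.shape) for t in state.values())
    lines = [
        f"keys: {sorted(ckpt)}",
        f"tensors: {len(state)}, elements: {n_elements:,}",
    ]
    config = ckpt.get("config")
    if isinstance(config, dict):
        lines.append(f"config keys: {sorted(config)}")
    if "history" in ckpt:
        lines.append(f"history type: {type(ckpt['history']).__name__}")
    if "step1_path" in ckpt:
        lines.append(f"step1_path: {ckpt['step1_path']}")
    return "\n".join(lines)
```

For example, `ckpt = torch.load("tabpfn-sdm-nonspatial.pt", map_location="cpu", weights_only=False)` followed by `print(summarize_checkpoint(ckpt))` lists the stored keys and the total number of tensor elements in the finetuned weights.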

Limitations

  • Models were trained on species from 6 specific geographic regions and may not generalize to other regions without additional finetuning
  • Performance varies by species; species with very few presences may have lower accuracy
  • The spatial variant has lower validation metrics due to the inherently harder task of spatial extrapolation
  • Requires the tabpfn Python package (v2) for inference
  • Categorical variable handling requires proper setup (see project documentation)

Citation

If you use these models, please cite the accompanying paper:

@article{dinnage2026niche,
  title={A Niche in the Machine: The Promise of AI Foundation Models for Species Distribution Modeling},
  author={Dinnage, Russell and Warren, Dan L.},
  year={2026},
  doi={10.32942/X2VQ10},
  journal={EcoEvoRxiv},
  url={https://ecoevorxiv.org/repository/view/11797/}
}

Please also cite the original TabPFN paper:

@article{hollmann2025tabpfn,
  title={Accurate Predictions on Small Data with a Tabular Foundation Model},
  author={Hollmann, Noah and M{\"u}ller, Samuel and Purucker, Lennart and
          Krishnakumar, Arjun and K{\"o}rfer, Max and Hoo, Shi Bin and
          Schirrmeister, Robin Tibor and Hutter, Frank},
  journal={Nature},
  year={2025},
  publisher={Nature Publishing Group}
}

License

These finetuned model weights are distributed under the Prior Labs License v1.1, consistent with the base TabPFN model license. See the LICENSE file for full terms.

Links