---
license: cc-by-sa-4.0
datasets:
- Smith42/galaxies
- Smith42/galaxies_metadata
- Smith42/galaxies_embeddings
tags:
- astronomy
- images
- huggingscience
- science
---
# astroPTv2.0: a Large Observation Model for Astronomy
Here we have the model files for the astroPT project. The code to run inference with these models is available here:
[https://github.com/smith42/astropt](https://github.com/smith42/astropt)

You will find the fully trained models (pretrained on 8.6 million galaxies) in the `astropt` directory, in folders labelled with each model's parameter count.

Unlike the older models, which were trained on the "image" column in [smith42/galaxies](https://huggingface.co/datasets/Smith42/galaxies), these models are trained on the cropped galaxies in the "image_crop" column. Those galaxies were cropped and zoomed before upload so that each galaxy takes up most of its image.

We see some promising scaling behaviour on this new dataset (see below).
## Usage
To use these models in anger you can `pip install astropt` and run the following code:
```python
from astropt.model_utils import load_astropt
from astropt.local_datasets import GalaxyImageDataset
from datasets import load_dataset # for Smith42/galaxies
import torch
import numpy as np
from functools import partial
from torch.utils.data import DataLoader
from torchvision import transforms
# boilerplate to preprocess galaxy images
def normalise(x):
    # zero-mean, unit-variance standardisation along dim 1
    std, mean = torch.std_mean(x, dim=1, keepdim=True)
    return (x - mean) / (std + 1e-8)

def data_transforms():
    return transforms.Compose([transforms.Lambda(normalise)])

def _process_galaxy_wrapper(example, func):
    """Tokenise the image in the same way the pretrained model expects."""
    # (H, W, C) image array -> channels-first tensor via swapaxes(0, 2)
    galaxy = func(
        torch.from_numpy(np.array(example["image"]).swapaxes(0, 2)).to(float)
    ).to(torch.float)
    # one position index per image token
    galaxy_positions = torch.arange(0, len(galaxy), dtype=torch.long)
    return {
        "images": galaxy,
        "images_positions": galaxy_positions,
    }
# for 095M parameter model, 015M and 850M models are also available:
model = load_astropt("Smith42/astroPT_v2.0", path="astropt/095M")
galproc = GalaxyImageDataset(
None,
spiral=True,
transform={"images": data_transforms()},
modality_registry=model.modality_registry
)
ds = (
load_dataset("Smith42/galaxies", split="test", revision="v2.0", streaming=True)
.select_columns("image")
.map(partial(_process_galaxy_wrapper, func=galproc.process_galaxy))
.with_format("torch")
)
dl = DataLoader(ds, batch_size=128, num_workers=32)
zs = []
with torch.no_grad():  # inference only, so skip building the autograd graph
    for B in dl:
        zs.append(model.generate_embeddings(B)["images"].detach().cpu().numpy())
zs = np.concatenate(zs)
# do cool stuff with zs...
```
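The `galproc.process_galaxy` call above hides the tokenisation step. As a rough, numpy-only illustration of what a ViT-style preprocessing step of this kind involves, the sketch below splits an image into non-overlapping patch tokens, standardises each token the way `normalise` does along `dim=1` (including `torch.std_mean`'s default unbiased estimate), and reorders the tokens into a spiral that starts near the image centre. The patch size of 16, normalising after patching, the (H, W, C) axis order, and the exact spiral direction are all our assumptions about what `spiral=True` and the transform do; the helpers `patchify`, `normalise_tokens` and `spiral_order` are hypothetical and are not astropt's implementation, so always use `galproc.process_galaxy` for anything you feed to the pretrained models.

```python
import numpy as np


def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an (H, W, C) image into non-overlapping patch tokens in raster order.

    Returns an array of shape (n_tokens, patch_size * patch_size * C).
    """
    h, w, c = image.shape
    if h % patch_size or w % patch_size:
        raise ValueError("image height and width must be multiples of patch_size")
    gh, gw = h // patch_size, w // patch_size
    patches = image.reshape(gh, patch_size, gw, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)  # (grid row, grid col, py, px, c)
    return patches.reshape(gh * gw, patch_size * patch_size * c)


def normalise_tokens(tokens: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Standardise each token (row) to zero mean and unit std, mirroring `normalise` with dim=1."""
    mean = tokens.mean(axis=1, keepdims=True)
    std = tokens.std(axis=1, ddof=1, keepdims=True)  # unbiased, like torch.std_mean's default
    return (tokens - mean) / (std + eps)


def spiral_order(grid_h: int, grid_w: int) -> np.ndarray:
    """Hypothetical spiral ordering of a grid_h x grid_w patch grid, starting near the centre.

    Walks the grid clockwise from the outer edge inwards, then reverses the walk.
    """
    idx = np.arange(grid_h * grid_w).reshape(grid_h, grid_w)
    order = []
    top, bottom, left, right = 0, grid_h - 1, 0, grid_w - 1
    while top <= bottom and left <= right:
        order.extend(idx[top, left:right + 1])  # along the top row, left to right
        top += 1
        order.extend(idx[top:bottom + 1, right])  # down the right column
        right -= 1
        if top <= bottom:
            order.extend(idx[bottom, left:right + 1][::-1])  # along the bottom row, right to left
            bottom -= 1
        if left <= right:
            order.extend(idx[top:bottom + 1, left][::-1])  # up the left column
            left += 1
    return np.array(order[::-1])


# Example: a random 64x64 RGB stand-in image with 16x16 patches gives a 4x4 grid of 16 tokens.
rng = np.random.default_rng(0)
image = rng.random((64, 64, 3))
tokens = normalise_tokens(patchify(image, patch_size=16))
tokens = tokens[spiral_order(4, 4)]
positions = np.arange(len(tokens))  # analogous to "images_positions" above
print(tokens.shape, positions[:4])
```

Starting the sequence at the centre means the first tokens an autoregressive model sees are the ones most likely to contain the galaxy itself, which is presumably the motivation for spiral ordering on cropped, centred images.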
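The usage code leaves `zs` open-ended. One common use of galaxy embeddings is similarity search: finding the galaxies whose embeddings lie closest to a query galaxy. Below is a minimal numpy sketch using cosine similarity. It assumes only that `zs` is a 2-D array with one row per galaxy; the helper name `cosine_neighbours` and the random stand-in array are ours, for illustration.

```python
import numpy as np


def cosine_neighbours(zs: np.ndarray, query_idx: int, k: int = 5) -> np.ndarray:
    """Indices of the k rows of zs most cosine-similar to zs[query_idx], excluding the query."""
    unit = zs / (np.linalg.norm(zs, axis=1, keepdims=True) + 1e-12)  # L2-normalise each row
    sims = unit @ unit[query_idx]  # cosine similarity of every row to the query
    sims[query_idx] = -np.inf  # never return the query itself
    return np.argsort(-sims)[:k]  # highest similarity first


# Stand-in for the real embeddings: 1000 galaxies with 64-dimensional embeddings.
rng = np.random.default_rng(0)
zs = rng.normal(size=(1000, 64)).astype(np.float32)
print(cosine_neighbours(zs, query_idx=0, k=5))
```

Because the rows are L2-normalised first, the dot product equals the cosine similarity. This brute-force search touches every row per query; for millions of galaxies an approximate nearest-neighbour index would scale better.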
## Updates and community
AstroPT is an open-to-all UniverseTBD project. Please join the [UniverseTBD](https://universetbd.org) Discord for updates: [https://discord.gg/MNEVegvfJq](https://discord.gg/MNEVegvfJq)