---
license: cc-by-sa-4.0
datasets:
- Smith42/galaxies
- Smith42/galaxies_metadata
- Smith42/galaxies_embeddings
tags:
- astronomy
- images
- huggingscience
- science
---
# astroPTv2.0: a Large Observation Model for Astronomy

This repository holds the model files for the astroPT project. The code to run inference with these models is here: [https://github.com/smith42/astropt](https://github.com/smith42/astropt)

The fully trained models (pretrained on 8.6 million galaxies) are in the `astropt` directory, in folders labelled with the model parameter count. Unlike the older models, which were trained on the "image" column in [smith42/galaxies](https://huggingface.co/datasets/Smith42/galaxies), these models are trained on the cropped galaxies from the "image_crop" column. Those galaxies were cropped and zoomed before uploading so that they take up the majority of each image.

We get some promising scaling on this new dataset, see below:
[Figure: scaling law]
## Usage

To use these models in anger you can `pip install astropt` and run the following code:

```python
from astropt.model_utils import load_astropt
from astropt.local_datasets import GalaxyImageDataset
from datasets import load_dataset  # for Smith42/galaxies
import torch
import numpy as np
from functools import partial
from torch.utils.data import DataLoader
from torchvision import transforms

# boilerplate to preprocess galaxy images
def normalise(x):
    std, mean = torch.std_mean(x, dim=1, keepdim=True)
    return (x - mean) / (std + 1e-8)

def data_transforms():
    return transforms.Compose([transforms.Lambda(normalise)])

def _process_galaxy_wrapper(idx, func):
    """This function ensures that the image is tokenised in the same way as
    the pre-trained model is expecting"""
    galaxy = func(
        torch.from_numpy(np.array(idx["image"]).swapaxes(0, 2)).to(float)
    ).to(torch.float)
    galaxy_positions = torch.arange(0, len(galaxy), dtype=torch.long)
    return {
        "images": galaxy,
        "images_positions": galaxy_positions,
    }

# for 095M parameter model, 015M and 850M models are also available:
model = load_astropt("Smith42/astroPT_v2.0", path="astropt/095M")
galproc = GalaxyImageDataset(
    None,
    spiral=True,
    transform={"images": data_transforms()},
    modality_registry=model.modality_registry
)
ds = (
    load_dataset("Smith42/galaxies", split="test", revision="v2.0", streaming=True)
    .select_columns("image")
    .map(partial(_process_galaxy_wrapper, func=galproc.process_galaxy))
    .with_format("torch")
)
dl = iter(DataLoader(ds, batch_size=128, num_workers=32))

zs = []
for B in dl:
    zs.append(model.generate_embeddings(B)["images"].detach().numpy())
zs = np.concatenate(zs)
# do cool stuff with zs...
```

## Updates and community

AstroPT is an open-to-all UniverseTBD project. Please join the [UniverseTBD](https://universetbd.org) Discord for updates: [https://discord.gg/MNEVegvfJq](https://discord.gg/MNEVegvfJq)
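
## Example: similarity search over embeddings

As one illustration of what could be done with the `zs` array built in the Usage section, the sketch below finds the galaxies whose embeddings are most similar to a chosen query galaxy. It is not part of the astroPT package: the helper `nearest_neighbours` is a hypothetical name, and the sketch assumes `zs` is a 2D NumPy array with one row per galaxy. Random vectors stand in for real embeddings so that the block runs on its own.

```python
import numpy as np

def nearest_neighbours(zs, query_index, k=5):
    """Return indices of the k rows of zs most cosine-similar to zs[query_index].

    Hypothetical helper for illustration; assumes zs has shape (n_galaxies, dim).
    """
    # L2-normalise each row so that a dot product equals cosine similarity
    norms = np.linalg.norm(zs, axis=1, keepdims=True)
    unit = zs / np.clip(norms, 1e-12, None)
    sims = unit @ unit[query_index]
    sims[query_index] = -np.inf  # exclude the query itself
    # sort by descending similarity and keep the top k
    return np.argsort(-sims)[:k]

# Stand-in data (an assumption): random vectors instead of real astroPT embeddings
rng = np.random.default_rng(0)
zs_demo = rng.normal(size=(100, 16)).astype(np.float32)
zs_demo[7] = zs_demo[3] + 0.01  # plant a near-duplicate of row 3

neighbours = nearest_neighbours(zs_demo, query_index=3, k=5)
print(neighbours)
```

With real embeddings, replace `zs_demo` with `zs` and map the returned indices back to the galaxies in the order they were streamed. Cosine similarity is only one possible choice of metric here.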