#!/usr/bin/env python3
"""Merge BAGEL EMA checkpoint into a standard inference checkpoint.

The repository ships two shards:

* ``ema.safetensors`` – EMA weights for the Mixture-of-Transformer stack,
  connector and ViT encoder described by ``llm_config.json`` /
  ``vit_config.json``.
* ``ae.safetensors`` – VAE weights referenced by
  ``model.safetensors.index.json``.

This script combines the two into a single ``model`` checkpoint that can be
used in place of the EMA file.  By default the script keeps the source files
untouched and writes a new ``model_from_ema.safetensors`` plus, optionally,
an accompanying index.
"""

from __future__ import annotations

import argparse
import json
from collections import OrderedDict
from pathlib import Path
from typing import Dict

import torch


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the merge tool.

    Returns:
        Parsed namespace with ``ema``, ``ae``, ``output`` and ``index``
        attributes (all :class:`pathlib.Path` except ``index`` which may be
        ``None``).
    """
    parser = argparse.ArgumentParser(
        description="Convert BAGEL EMA weights into a regular inference checkpoint."
    )
    parser.add_argument(
        "--ema",
        type=Path,
        default=Path("ema.safetensors"),
        help="Path to the EMA weights file (default: ema.safetensors).",
    )
    parser.add_argument(
        "--ae",
        type=Path,
        default=Path("ae.safetensors"),
        help="Path to the VAE weights file (default: ae.safetensors).",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("model_from_ema.safetensors"),
        help="Destination for the merged checkpoint.",
    )
    parser.add_argument(
        "--index",
        type=Path,
        default=None,
        help="Optional path for a Hugging Face style index JSON file.",
    )
    return parser.parse_args()


def _require_safetensors():
    """Import safetensors lazily and return ``(load_file, save_file)``.

    Centralises the import-error handling that was previously duplicated in
    :func:`load_safetensors` and :func:`save_safetensors`.

    Raises:
        RuntimeError: if the ``safetensors`` package is not installed.
    """
    try:
        from safetensors.torch import load_file, save_file
    except ImportError as exc:  # pragma: no cover - raises early when dependency missing
        raise RuntimeError(
            "safetensors is required. Install it with `pip install safetensors`."
        ) from exc
    return load_file, save_file


def load_safetensors(path: Path) -> Dict[str, torch.Tensor]:
    """Load a safetensors file into a name -> tensor mapping.

    Args:
        path: Location of the ``.safetensors`` file.

    Returns:
        Mapping of parameter names to tensors.

    Raises:
        RuntimeError: if the ``safetensors`` package is not installed.
        ValueError: if the file contains no tensors at all.
    """
    load_file, _ = _require_safetensors()
    tensors = load_file(str(path))
    if not tensors:
        raise ValueError(f"{path} does not contain any tensors.")
    return tensors


def save_safetensors(
    tensors: Dict[str, torch.Tensor], path: Path, *, metadata: Dict[str, str]
) -> None:
    """Write *tensors* to *path* in safetensors format.

    Args:
        tensors: Mapping of parameter names to tensors.
        path: Destination file.
        metadata: String-to-string metadata stored in the file header
            (safetensors requires str values, hence ``total_size`` is
            stringified by the caller).

    Raises:
        RuntimeError: if the ``safetensors`` package is not installed.
    """
    _, save_file = _require_safetensors()
    save_file(tensors, str(path), metadata=metadata)


def compute_total_size_bytes(tensors: Dict[str, torch.Tensor]) -> int:
    """Return the total payload size of *tensors* in bytes.

    Sums ``element_size() * nelement()`` over all tensors, i.e. the raw
    storage size excluding file-format overhead.
    """
    total = 0
    for tensor in tensors.values():
        total += tensor.element_size() * tensor.nelement()
    return total


def main() -> None:
    """Merge the EMA and VAE shards into a single checkpoint file.

    Raises:
        FileNotFoundError: if either input file is missing.
        ValueError: if the two shards share any parameter names (merging
            would silently let one overwrite the other).
    """
    args = parse_args()

    if not args.ema.is_file():
        raise FileNotFoundError(f"EMA weights not found: {args.ema}")
    if not args.ae.is_file():
        raise FileNotFoundError(f"VAE weights not found: {args.ae}")

    ema_state = load_safetensors(args.ema)
    ae_state = load_safetensors(args.ae)

    # Refuse to merge ambiguous checkpoints: a shared key means one shard
    # would clobber the other and the result would be silently wrong.
    overlap = set(ae_state.keys()) & set(ema_state.keys())
    if overlap:
        raise ValueError(
            f"Found {len(overlap)} overlapping parameter names between ae and ema files; "
            "please inspect your checkpoints before merging."
        )

    # Deterministic layout: VAE parameters first, then the EMA stack, each
    # group sorted by name.
    merged = OrderedDict()
    merged.update(sorted(ae_state.items()))
    merged.update(sorted(ema_state.items()))

    total_size = compute_total_size_bytes(merged)
    # safetensors metadata values must be strings.
    metadata = {"total_size": str(total_size)}
    save_safetensors(merged, args.output, metadata=metadata)

    if args.index:
        # Hugging Face style index: every weight maps to the single shard.
        weight_map = {key: args.output.name for key in merged.keys()}
        index_payload = {
            "metadata": {"total_size": total_size},
            "weight_map": weight_map,
        }
        # encoding="utf-8" is required because ensure_ascii=False may emit
        # non-ASCII characters; the platform default encoding (e.g. cp1252 on
        # Windows) could otherwise raise UnicodeEncodeError.
        args.index.write_text(
            json.dumps(index_payload, indent=4, ensure_ascii=False) + "\n",
            encoding="utf-8",
        )


if __name__ == "__main__":
    main()