Will there be an "fp8_input_scaled" version for the Dev model?

#21
by jool454 - opened

Thanks for everything.

The official Lightricks model should already be that, though it's a full checkpoint:

https://huggingface.co/Lightricks/LTX-2.3-fp8/tree/main

Edit: I deleted the fp8 model from my repo since it duplicated the one here, but I'll leave the extraction code below for reference.

import sys
import json
import torch
from safetensors.torch import safe_open, save_file

def cut_safetensors(input_path, output_path):
    """Copy the ``model.diffusion_model.*`` tensors of a checkpoint to a new file.

    Reads ``input_path`` and writes to ``output_path`` a safetensors file that
    contains only the tensors whose names start with ``model.diffusion_model.``.
    The file-level metadata is carried over with these changes:

    * the ``vae``, ``audio_vae`` and ``vocoder`` sections are removed from the
      JSON-encoded ``config`` metadata entry, when that entry exists;
    * the ``_quantization_metadata`` entry is dropped. Instead, for every
      copied ``<layer>.weight`` tensor whose ``<layer>`` name appears in that
      entry's ``layers`` mapping, a ``<layer>.comfy_quant`` uint8 tensor is
      added. It holds the layer's quantization info serialized as UTF-8 JSON.

    Args:
        input_path: Path of the source ``.safetensors`` checkpoint.
        output_path: Path the extracted ``.safetensors`` file is written to.
    """
    prefix = "model.diffusion_model."
    weight_suffix = ".weight"
    dropped_components = ("vae", "audio_vae", "vocoder")

    with safe_open(input_path, framework="pt", device="cpu") as f:
        # metadata() may return None when the header carries no metadata;
        # copy it so the reader's own dict is not mutated.
        metadata = dict(f.metadata() or {})

        if "config" in metadata:
            config = json.loads(metadata["config"])
            for component in dropped_components:
                config.pop(component, None)
            metadata["config"] = json.dumps(config)

        # pop() instead of del: the default below already allows this entry
        # to be missing, and an unconditional del would raise KeyError then.
        quant_meta = json.loads(
            metadata.pop("_quantization_metadata", '{"layers": {}}')
        )
        quant_layers = quant_meta.get("layers", {})

        new_state_dict = {}
        for key in f.keys():
            if not key.startswith(prefix):
                continue
            new_state_dict[key] = f.get_tensor(key)

            # Quantization records are looked up by the weight key minus its
            # trailing ".weight". Strip only that suffix: str.replace would
            # also rewrite ".weight" inside other names
            # (e.g. "x.weight_scale" -> "x_scale").
            if not key.endswith(weight_suffix):
                continue
            base_key = key[: -len(weight_suffix)]
            if base_key not in quant_layers:
                continue

            # Store the layer's quantization info as a UTF-8 JSON byte tensor.
            json_bytes = json.dumps(quant_layers[base_key]).encode("utf-8")
            new_state_dict[f"{base_key}.comfy_quant"] = torch.tensor(
                list(json_bytes), dtype=torch.uint8
            )

        save_file(new_state_dict, output_path, metadata=metadata)

if __name__ == "__main__":
    # Parse argv only when run as a script, so importing this module does not
    # fail (or read unrelated argv) the way a module-level unpack would.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} INPUT.safetensors OUTPUT.safetensors")
    input_path, output_path = sys.argv[1:3]
    cut_safetensors(input_path, output_path)

I've added my extraction here as well.

Kijai changed discussion status to closed

Sign up or log in to comment