model: opt-125m
config: ModuleFqnToConfig with Float8DynamicActivationFloat8WeightConfig, Int4WeightOnlyConfig, and IntxWeightOnlyConfig
config version: 1
torchao version: 0.14.0.dev
Generate Quantized Model
import logging
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
# Surface INFO-level records on the console so that warnings and debug
# information emitted while quantizing and serializing are visible.
logging.basicConfig(
    format="%(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Explicitly opt in the libraries whose loggers might carry the
# serialization warnings.
for _logger_name in ("transformers", "torchao", "safetensors", "huggingface_hub"):
    logging.getLogger(_logger_name).setLevel(logging.INFO)
# Hub id of the base checkpoint used throughout the script.
model_id = "facebook/opt-125m"
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Int4WeightOnlyConfig,
    IntxWeightOnlyConfig,
    PerRow,
    PerAxis,
    ModuleFqnToConfig,
    Float8Tensor,
    Int4TilePackedTo4dTensor,
    IntxUnpackedToInt8Tensor,
)
# The three quantization recipes that are mixed inside one model.
float8dyn = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
int4wo = Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d")
intxwo = IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0))
# Maps module FQNs (exact names, or regexes marked with the "re:" prefix) to
# the config applied to that module.
# The regex keys are raw strings: `\.` is not a valid Python escape sequence,
# so a plain literal only works through deprecated invalid-escape handling.
# The resulting string values are unchanged.
qconfig_dict = {
    # highest priority
    "model.decoder.layers.3.self_attn.q_proj": int4wo,
    "model.decoder.layers.3.self_attn.k_proj": int4wo,
    "model.decoder.layers.3.self_attn.v_proj": int4wo,
    # vllm
    # NOTE(review): presumably the fused qkv_proj module name used by vLLM -- confirm.
    "model.decoder.layers.3.self_attn.qkv_proj": int4wo,
    r"re:model\.decoder\.layers\..+\.self_attn\.q_proj": float8dyn,
    r"re:model\.decoder\.layers\..+\.self_attn\.k_proj": float8dyn,
    r"re:model\.decoder\.layers\..+\.self_attn\.v_proj": float8dyn,
    # this should not take effect and we'll fall back to _default
    # since there is no full match (missing `j` at the end)
    r"re:model\.decoder\.layers\..+\.self_attn\.out_pro": float8dyn,
    # vllm
    r"re:model\.decoder\.layers\..+\.self_attn\.qkv_proj": float8dyn,
    "_default": intxwo,
}
quant_config = ModuleFqnToConfig(qconfig_dict)
quantization_config = TorchAoConfig(quant_type=quant_config)
# Load the base checkpoint, quantizing it on the fly with `quantization_config`.
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    # `dtype` replaces the deprecated `torch_dtype` keyword.
    dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
print("quantized model:", quantized_model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Verify that each attention projection carries the expected tensor subclass:
#   * q/k/v of layer 3   -> Int4TilePackedTo4dTensor (exact-FQN entries)
#   * q/k/v of the rest  -> Float8Tensor (regex entries)
#   * out_proj everywhere -> IntxUnpackedToInt8Tensor (`_default` entry)
# Iterating over the layers themselves checks every decoder layer without
# hard-coding the layer count.
for i, decoder_layer in enumerate(quantized_model.model.decoder.layers):
    self_attn = decoder_layer.self_attn
    expected_qkv_type = Int4TilePackedTo4dTensor if i == 3 else Float8Tensor
    for proj_name in ("q_proj", "k_proj", "v_proj"):
        weight = getattr(self_attn, proj_name).weight
        assert isinstance(weight, expected_qkv_type), (
            f"layer {i} {proj_name}: expected {expected_qkv_type.__name__}, "
            f"got {type(weight).__name__}"
        )
    assert isinstance(self_attn.out_proj.weight, IntxUnpackedToInt8Tensor), (
        f"layer {i} out_proj: expected IntxUnpackedToInt8Tensor, "
        f"got {type(self_attn.out_proj.weight).__name__}"
    )
# # # Push to hub
# The target repo is named after the last path component of the model id.
MODEL_NAME = model_id.rpartition("/")[2]
_repo_suffix = "ModuleFqnToConfig-v1-regex-0.14.0.dev"
save_to = f"torchao-testing/{MODEL_NAME}-{_repo_suffix}"
# The model is uploaded with safetensors serialization disabled; the
# tokenizer goes to the same repo.
quantized_model.push_to_hub(save_to, safe_serialization=False)
tokenizer.push_to_hub(save_to)
# Local alternative to pushing:
# quantized_model.save_pretrained(save_to, safe_serialization=False)
# tokenizer.save_pretrained(save_to)
# Manual Testing
prompt = "What are we having for dinner?"
print("Prompt:", prompt)
# Place the prompt on the model's own device: with device_map="auto" the
# placement is decided at load time rather than being fixed to "cuda".
inputs = tokenizer(
    prompt,
    return_tensors="pt",
).to(quantized_model.device)
# Greedy decoding (do_sample=False) is what makes the result deterministic.
# `temperature` only applies to sampling and has to be strictly positive, so
# temperature=0 is not a valid way to request determinism.
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128, do_sample=False)
correct_output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", correct_output_text[0][len(prompt) :])
# # # Load model from saved checkpoint
# No quantization_config is passed here; presumably the quantization settings
# are read back from the saved checkpoint -- TODO confirm.
reloaded_model = AutoModelForCausalLM.from_pretrained(
    save_to,
    device_map="cuda:0",
    # `dtype` replaces the deprecated `torch_dtype` keyword.
    dtype=torch.bfloat16,
    # quantization_config=quantization_config,
)
# Re-home the prompt in case the reloaded model sits on a different device
# than the freshly quantized one.
inputs = inputs.to(reloaded_model.device)
generated_ids = reloaded_model.generate(**inputs, max_new_tokens=128, do_sample=False)
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", output_text[0][len(prompt) :])
# The round trip through the hub must not change the generated text.
assert correct_output_text == output_text, (
    "reloaded checkpoint generated different text than the freshly quantized model"
)
Test Loading
from transformers import (
  AutoModelForCausalLM,
  AutoProcessor,
  AutoTokenizer,
  TorchAoConfig,
)
from torchao.quantization import Float8Tensor
from torchao.quantization import (
    Float8Tensor,
    Int4TilePackedTo4dTensor,
    IntxUnpackedToInt8Tensor,
)
import torch
# Settings for the load test of the published checkpoint.
model_name = "torchao-testing/opt-125m-ModuleFqnToConfig-v1-regex-0.14.0.dev"
device = "cuda"
input_text = "What are we having for dinner?"
max_new_tokens = 10
# Load the already-quantized checkpoint; no quantization config is supplied.
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.bfloat16,
    device_map=device,
)
# Check the weight tensor subclass of every attention projection in the
# first 12 decoder layers: q/k/v are Int4TilePackedTo4dTensor in layer 3 and
# Float8Tensor elsewhere, while out_proj is always IntxUnpackedToInt8Tensor.
for layer_idx in range(12):
    attention = quantized_model.model.decoder.layers[layer_idx].self_attn
    if layer_idx == 3:
        qkv_weight_type = Int4TilePackedTo4dTensor
    else:
        qkv_weight_type = Float8Tensor
    for projection in ("q_proj", "k_proj", "v_proj"):
        assert isinstance(getattr(attention, projection).weight, qkv_weight_type)
    assert isinstance(attention.out_proj.weight, IntxUnpackedToInt8Tensor)
# Tokenize the prompt, generate a short continuation and print the decoded
# text of the first (and only) returned sequence.
tokenizer = AutoTokenizer.from_pretrained(model_name)
encoded_prompt = tokenizer(input_text, return_tensors="pt")
encoded_prompt = encoded_prompt.to(device)
generated_sequences = quantized_model.generate(
    **encoded_prompt,
    max_new_tokens=max_new_tokens,
)
decoded_text = tokenizer.decode(generated_sequences[0], skip_special_tokens=True)
print(decoded_text)
Output:
What are we having for dinner?
A nice dinner with a friend.
I
- Downloads last month
- 120
	Inference Providers
	NEW
	
	
	This model isn't deployed by any Inference Provider.
	🙋
			
		Ask for provider support