Part of the Precision Labs collection: FP8 rowwise and BF16 tensorwise models with optimized recipes for large-scale training efficiency and convergence stability.
This model was trained as part of our study comparing FSDP2 with bfloat16 precision against FSDP2 with FP8 mixed precision (bf16-fp8).

We used meta-llama/Llama-3.1-8B-Instruct. The model was loaded with `torch_dtype=torch.bfloat16`, and for FP8 + FSDP2 compatibility it was wrapped per layer instead of as a whole model, which avoided dimension-misalignment issues during the forward and backward passes. Float8 variants of the linear layers were used with the default tensorwise quantization scaling recipe, and we set `pad_inner_dim` to automatically pad inner dimensions to be divisible by 16, as required for FP8:
```python
from torchao.float8 import (
    convert_to_float8_training,
    Float8LinearConfig,
    precompute_float8_dynamic_scale_for_fsdp,
)

# Default tensorwise scaling recipe; pad inner dims to multiples of 16 (required
# for FP8) and enable the float8 all-gather path used by FSDP2.
config = Float8LinearConfig(
    pad_inner_dim=True,
    enable_fsdp_float8_all_gather=True,
)

# Swap nn.Linear modules for their float8 training variants.
model = convert_to_float8_training(model, config=config)
```
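The runs reported below also include rowwise (`rw`) and rowwise with high-precision gradient-weight (`rw_with_gw_hp`) variants. As a minimal sketch, torchao also allows the scaling recipe to be selected by name; the exact recipe names and their availability depend on the installed torchao version, and this snippet is not taken verbatim from our training script:

```python
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Recipe names in torchao include "tensorwise" (the default), "rowwise", and
# "rowwise_with_gw_hp" (rowwise scaling with high-precision grad-weight).
rowwise_config = Float8LinearConfig.from_recipe_name("rowwise")
model = convert_to_float8_training(model, config=rowwise_config)
```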
With FSDP2, sharding is then applied per layer rather than once for the whole model:

```python
from torch.distributed.fsdp import fully_shard

if use_fp8:
    # Shard each decoder layer, the embeddings, and the LM head separately so
    # that FP8 conversion and FSDP2 compose without dimension misalignment.
    for layer in model.model.layers:
        fully_shard(layer, **fsdp_kwargs)
    fully_shard(model.model.embed_tokens, **fsdp_kwargs)
    fully_shard(model.lm_head, **fsdp_kwargs)
```
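`precompute_float8_dynamic_scale_for_fsdp` is imported above but its call site is not shown. When `enable_fsdp_float8_all_gather=True`, torchao's float8 workflow calls it once per iteration after the optimizer step, so the FP8 scales for the weight all-gather are computed in one batched pass. A minimal sketch of where it fits, assuming a standard training loop (the `train_step` helper, `batch`, and `optimizer` names are illustrative, not from our script):

```python
from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

def train_step(model, batch, optimizer):
    # Forward/backward in bf16 with float8 linear layers, then update weights.
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # With enable_fsdp_float8_all_gather=True, precompute the dynamic FP8 scales
    # for all float8 weights so the next step's all-gather can reuse them.
    precompute_float8_dynamic_scale_for_fsdp(model)
    return loss
```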
Training setup: meta-llama/Llama-3.1-8B-Instruct (config fields num_attention_heads, num_key_value_heads, num_hidden_layers, hidden_size; RoPE scaling llama3 with factor = 8.0), seq_len, CUDA_VISIBLE_DEVICES=0,1,2,3, bf16 && fp8=true, warmup_ratio=0.1 (also warmup_steps=100), eval_loss as the tracked metric, launched with torch.distributed.run (4 nodes, multi-GPU) / torchrun for distributed training.

Package versions:

| Package | Version |
|---|---|
| Transformers | 4.57.1 |
| torch | 2.9.0+cu128 |
| accelerate | 0.14.1 |
| datasets | 4.3.0 |
| huggingface-hub | 0.36.0 |
| tensorboard | 2.20.0 |
| tensorboard-data-server | 0.7.2 |
| wandb | 0.22.1 |
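A small sketch for checking that a local environment matches the pinned versions above; the expected versions are copied from the table, and the helper itself is only illustrative:

```python
from importlib.metadata import version, PackageNotFoundError

# Versions taken from the table above.
expected = {
    "transformers": "4.57.1",
    "torch": "2.9.0+cu128",
    "accelerate": "0.14.1",
    "datasets": "4.3.0",
    "huggingface-hub": "0.36.0",
    "tensorboard": "2.20.0",
    "wandb": "0.22.1",
}

for pkg, want in expected.items():
    try:
        have = version(pkg)
    except PackageNotFoundError:
        have = "not installed"
    marker = "" if have == want else "  <-- differs"
    print(f"{pkg}: {have} (expected {want}){marker}")
```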
| Model | Job ID | Runtime (mins) | Nodes | GPUs | Node-hour | GPU-hour | micro-batch | batch-size | gradient_accumulation | total_batch_size |
|---|---|---|---|---|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct_w16a8_rw | 31768103 | 115.75 | 1 | 4 | 1.929 | 7.716 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 31837629 | 109.00 | 1 | 4 | 1.816 | 7.266 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 31768031 | 64.00 | 4 | 4 | 1.066 | 4.266 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a16-tw | 31768074 | 138.75 | 1 | 4 | 0.858 | 3.433 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a8-1node-bs8 | 31768093 | 123.75 | 1 | 4 | 0.788 | 3.151 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a16-4nodes-bs32 | 31478433 | 31.75 | 4 | 4 | 2.117 | 8.467 | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a8-4nodes-bs32 | 31478468 | 39.75 | 4 | 4 | 2.650 | 10.600 | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs32 | 31476914 | 22.00 | 8 | 4 | 2.933 | 11.733 | 4 | 4 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs32 | 31476844 | 23.50 | 8 | 4 | 3.133 | 12.533 | 4 | 4 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs64 | 31476914 | 22.00 | 8 | 4 | 2.933 | 11.733 | 4 | 8 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs64 | 31476844 | 23.50 | 8 | 4 | 3.133 | 12.533 | 4 | 8 | 8 | 1024 |
| Model | Max Loss (train) | Min Loss (train) | Avg Loss (train) | Final Loss (train) | ± Std (train) | Max Loss (val) | Min Loss (val) | Avg Loss (val) | Final Loss (val) | ± Std (val) |
|---|---|---|---|---|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct-w16a8-rw | 8 | 3.1682 | 0.5740 | 0.8118 | 0.6431 | 0.2746 | 1.0613 | 0.8394 | 0.8937 | 0.8394 |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 8 | 3.1837 | 0.5763 | 0.8116 | 0.6420 | 0.2751 | 1.0599 | 0.8391 | 0.8933 | 0.8391 |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 8 | 3.1983 | 0.5747 | 0.8115 | 0.6446 | 0.2758 | 1.0562 | 0.8384 | 0.8923 | 0.8384 |
| Llama-3.1-8B-Instruct-w16a16-tw | 8 | 3.1235 | 0.7203 | 0.9750 | 0.3344 | 0.7612 | 1.9113 | 0.8907 | 0.9831 | 0.1897 |
| Llama-3.1-8B-Instruct-w16a8-1node-bs8 | 8 | 3.1661 | 0.7261 | 0.9804 | 0.3374 | 0.7672 | 1.9230 | 0.8948 | 0.9867 | 0.1906 |
| Llama-3.1-8B-Instruct-w16a16-4nodes-bs32 | 32 | 3.2452 | 0.7414 | 0.9665 | 0.4844 | 0.7504 | 1.0538 | 0.8382 | 0.8844 | 0.0725 |
| Llama-3.1-8B-Instruct-w16a8-4nodes-bs32 | 32 | 3.2840 | 0.7478 | 0.9748 | 0.4905 | 0.7581 | 1.0701 | 0.8430 | 0.8922 | 0.0764 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs32 | 32 | 3.2311 | 0.8448 | 1.1856 | 0.6434 | 0.8448 | 1.0257 | 0.8977 | 0.9460 | 0.0568 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs32 | 32 | 3.3003 | 0.8473 | 1.1866 | 0.6481 | 0.8473 | 1.0203 | 0.8992 | 0.9445 | 0.0539 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs64 | 64 | 3.2311 | 0.8448 | 1.1856 | 0.6434 | 0.8448 | 1.0257 | 0.8977 | 0.9460 | 0.0568 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs64 | 64 | 3.3003 | 0.8473 | 1.1866 | 0.6481 | 0.8473 | 1.0203 | 0.8992 | 0.9445 | 0.0539 |
Note: the final model is saved in bfloat16 format. For inference, load it in bfloat16 or float16 as shown below:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "Llama-3.1-8B-Instruct-w16a8-4nodes-bs64"
dtype = torch.bfloat16  # or torch.float16

tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto",
)

# "Question: Under the Personal Data Protection Law (KVKK), in which cases is
# explicit consent not required? Answer:"
prompt = "Soru: Kişisel Verilerin Korunması Kanunu uyarınca hangi durumlarda açık rıza aranmaz? Cevap:"
inputs = tok(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
    )

print(tok.decode(out[0], skip_special_tokens=True))
```
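Since the base checkpoint is an Instruct model, the prompt can also be passed through the tokenizer's chat template instead of as raw text. A minimal sketch, reusing `tok`, `model`, and `prompt` from the snippet above; whether chat-template prompting or the raw "Soru/Cevap" format matches the fine-tuning data better depends on how the dataset was formatted:

```python
# Format the prompt as a single-turn chat using the model's chat template.
messages = [{"role": "user", "content": prompt}]
chat_inputs = tok.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

with torch.no_grad():
    chat_out = model.generate(chat_inputs, max_new_tokens=256, do_sample=False)

print(tok.decode(chat_out[0], skip_special_tokens=True))
```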
```bibtex
@misc{meta_llama31_8b_instruct,
  title        = {Llama 3.1 8B Instruct},
  author       = {Meta AI},
  year         = {2024},
  howpublished = {\url{https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct}}
}

@misc{euro_hpc_legal,
  title        = {EuroHPC-Legal},
  author       = {newmindai},
  year         = {2025},
  howpublished = {\url{https://huggingface.co/datasets/newmindai/EuroHPC-Legal}}
}
```
Base model: meta-llama/Llama-3.1-8B