Abstract
The Llama-3.1-8B-Instruct-w16a8-tw model is a Turkish legal instruction-tuned variant of Meta’s Llama-3.1-8B-Instruct, trained with the Float8 tensorwise quantization recipe. Developed within the “FSDP2 with Float8 Precision for Faster Training” initiative, it evaluates how coarse-grained FP8 scaling affects training efficiency and downstream legal reasoning performance relative to a BF16 baseline. In the tensorwise setup, scaling is applied at the granularity of the whole tensor: every element of a weight matrix (equivalently, all of its rows) shares a single dynamic scaling factor. Inputs and weights are cast to FP8-E4M3FN, while gradient outputs use FP8-E5M2, following TorchAO’s default mixed-precision configuration. Model weights are kept in BF16 to ensure stable updates, while activations and intermediate tensors are handled in FP8 for efficiency. Trained on the newmindai/EuroHPC-Legal dataset (multi-domain Turkish legal Q/A), the model achieved a ~12.11% speedup over the BF16 baseline while keeping convergence behavior closely aligned with BF16 training. GPU utilization also increased significantly under FP8 tensorwise, demonstrating the efficiency benefits of coarse-grained FP8 casting with FSDP2 on H100 hardware.
Experiment Context
This model was trained with the default Float8 recipe, "tensorwise". The recipe sets the granularity to TENSORWISE for the input, weight, and gradient-output cast configurations and the ScalingType to DYNAMIC, so each tensor (e.g., an entire weight matrix) shares a single dynamically computed scaling factor. The dtype is float8_e4m3fn for the input and weight cast configurations and float8_e5m2 for the gradient-output cast configuration. e4m3 denotes a floating-point format with 4 exponent bits and 3 mantissa bits; the fn suffix marks the finite-only variant, which has no encodings for infinity and instead uses those bit patterns for additional finite values. e5m2 uses 5 exponent bits and 2 mantissa bits, trading precision for the wider dynamic range that gradients need. In both formats the leading bit is the sign bit.
```python
from torchao.float8 import (
    convert_to_float8_training,
    Float8LinearConfig,
)

# Select the tensorwise recipe and swap eligible nn.Linear modules
# for their Float8 training equivalents.
config = Float8LinearConfig.from_recipe_name("tensorwise")
model = convert_to_float8_training(model, config=config)
```
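To make the numeric formats concrete, the snippet below prints the dynamic ranges of the two FP8 dtypes and the resolved recipe configuration. This is an illustrative check rather than part of the training code; it only assumes a PyTorch build with FP8 dtypes and the `torchao.float8` API shown above.

```python
import torch
from torchao.float8 import Float8LinearConfig

# FP8-E4M3FN: more mantissa bits, narrower range -> used for inputs and weights.
# FP8-E5M2:   wider range, coarser mantissa      -> used for gradient outputs.
for dt in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dt)
    print(f"{dt}: max={info.max}, smallest_normal={info.tiny}, eps={info.eps}")

# Inspect the tensorwise recipe: per-tensor dynamic scaling for the
# input, weight, and grad_output cast configurations.
config = Float8LinearConfig.from_recipe_name("tensorwise")
print(config)
```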
Base Model Technical Specifications
- Parameters: 8 Billion
- Architecture Family: Llama 3.1
- Maximum Position Embeddings: 131,072
- Attention Heads: 32 (`num_attention_heads`)
- Key-Value Heads: 8 (`num_key_value_heads`)
- Hidden Layers: 32 (`num_hidden_layers`)
- Hidden Size: 4,096 (`hidden_size`)
- Intermediate Size: 14,336
- Vocabulary Size: 128,256
- Precision: bfloat16
- RoPE Scaling: type `llama3`, factor = 8.0
- RMS Norm Epsilon: 1e-05
- Activation: SiLU
Training Methodology
Training Configuration
- Model: `meta-llama/Llama-3.1-8B-Instruct`
- Sequence Length: 4,096 (`seq_len`)
- Epochs: 1
- Per-Device Micro Batch Size: 2
- Gradient Accumulation: 4
- GPUs: 4 (via `CUDA_VISIBLE_DEVICES=0,1,2,3`)
- dtype: `bf16` && `fp8=false`
- Weights: bfloat16
- Activations: bfloat16
- Optimizer: AdamW (a configuration sketch is shown after this list)
- Learning Rate: 2e-5
- Weight Decay: 0.01
- Betas: (0.9, 0.95)
- Epsilon: 1e-8
- LR Scheduler: Cosine; warmup = 10% (`warmup_ratio=0.1`), also `warmup_steps=100`
- Max Grad Norm: 1.0
- Gradient Checkpointing: Enabled
- Evaluation: every 5 steps (`eval_steps=5`, `eval_samples=1000`)
- Checkpointing: every 10 steps; keep last 5; select best by `eval_loss`
- Logging: every step to file; Weights & Biases in offline mode
- Seed: 100
- Distributed Training: `torch.distributed.run` (single node, multi-GPU)
- FSDP2 (Optimized Fully Sharded Data Parallel)
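As a sketch of how the optimizer and schedule listed above fit together (assuming `transformers`' `get_cosine_schedule_with_warmup`; `total_steps` is a hypothetical placeholder, since the card does not publish the training script):

```python
import torch
from transformers import get_cosine_schedule_with_warmup

# `model` is the (FP8-converted) model; total_steps is a placeholder value
# whose real value depends on dataset size, batch size, and accumulation.
total_steps = 1000  # hypothetical

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,
    weight_decay=0.01,
    betas=(0.9, 0.95),
    eps=1e-8,
)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,          # warmup_steps=100 (~10% of total steps)
    num_training_steps=total_steps,
)

# Inside the training loop, gradients are clipped to max_grad_norm=1.0:
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```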
Setups
- Precision: Half-precision bfloat16 used as the data type and for computation.
- Hardware: HPC (EuroHPC/BSC-class) node with 4 × NVIDIA H100 GPUs.
- Framework: PyTorch with `torchrun` for distributed training (see the launch sketch below).
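A minimal sketch of the distributed setup described above, assuming PyTorch's FSDP2 `fully_shard` API and a hypothetical entry script `train.py` (the actual training script is not published with this card):

```python
# Launched on a single node with 4 GPUs, e.g.:
#   CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train.py
import torch
import torch.distributed as dist
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

# `model` is the HF LlamaForCausalLM loaded beforehand (assumption).
# Shard each transformer block, then the root module (FSDP2 per-module sharding),
# keeping parameters and gradient reductions in bfloat16 as the card states.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)
for layer in model.model.layers:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)
```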
Dependencies
| Package | Version |
|---|---|
| Transformers | 4.57.1 |
| torch | 2.9.0+cu128 |
| accelerate | 0.14.1 |
| datasets | 4.3.0 |
| huggingface-hub | 0.36.0 |
| tensorboard | 2.20.0 |
| tensorboard-data-server | 0.7.2 |
| wandb | 0.22.1 |
Performance Evaluation
Models trained on a single node with the BF16 baseline and FP8 recipes:
Figures (plots not reproduced here):
- Loss metric results, memory allocation, and GPU utilization for the w16a16 vs. w16a8 tensorwise recipes
- Loss metric results, memory allocation, and GPU utilization for the w16a8 recipes
Loss Analysis
| Model | Max Loss (train) | Min Loss (train) | Avg Loss (train) | Final Loss (train) | ± Std (train) | Max Loss (val) | Min Loss (val) | Avg Loss (val) | Final Loss (val) | ± Std (val) |
|---|---|---|---|---|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct_w16a16 | 3.1462 | 0.5710 | 0.8048 | 0.6374 | 0.2716 | 1.0517 | 0.8335 | 0.8876 | 0.8335 | 0.0678 |
| Llama-3.1-8B-Instruct-w16a8-tw | 3.1983 | 0.5759 | 0.8113 | 0.6419 | 0.2756 | 1.0566 | 0.8390 | 0.8925 | 0.8391 | 0.0675 |
| Llama-3.1-8B-Instruct_w16a8_4nodes_rw | 3.1682 | 0.5740 | 0.8118 | 0.6431 | 0.2746 | 1.0613 | 0.8394 | 0.8937 | 0.8394 | 0.0688 |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 3.1837 | 0.5763 | 0.8116 | 0.6420 | 0.2751 | 1.0599 | 0.8391 | 0.8933 | 0.8391 | 0.0685 |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 3.1983 | 0.5747 | 0.8115 | 0.6446 | 0.2758 | 1.0562 | 0.8384 | 0.8923 | 0.8384 | 0.0677 |
Training Time Analysis
| Model | Training Time (mins) | Memory Allocated (avg %) | GPU Utilization (avg %) | Speedup vs BF16 |
|---|---|---|---|---|
| Llama-3.1-8B-Instruct_w16a16 | 138.75267 | 74.4189 | 56.6059 | baseline |
| Llama-3.1-8B-Instruct-w16a8-tw | 123.75267 | 68.8982 | 97.5364 | 12.11% |
| Llama-3.1-8B-Instruct_w16a8_rw | 115.75364 | 69.6132 | 97.7689 | 19.87% |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 109.00364 | 69.4806 | 97.3312 | 27.33% |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 64.00328 | 68.8982 | 95.5661 | 116.82% |
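The card does not state how the speedup column is computed; a reading that reproduces the listed values to within a few hundredths of a percentage point (an assumption on our part) is the relative wall-clock gain, speedup ≈ t_bf16 / t_model − 1, e.g. 138.75 / 123.75 − 1 ≈ 12.1% for the tensorwise run.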
Implementation
GPU & Memory Usage Profiling
Training was profiled with the PyTorch Profiler. To visualize the traces:
- Install `tensorboard` and `tensorboard-data-server` at the versions listed in the Dependencies section.
- Launch TensorBoard with the command below:

```bash
tensorboard --logdir="./Llama-3.1-8B-Instruct_tensorwise" --port="6006"
```
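The profiling code itself is not included in the card; below is a minimal sketch of how such traces are typically produced with `torch.profiler`. The output directory name is taken from the command above; `train_dataloader` and `run_training_step` are hypothetical placeholders.

```python
import torch
from torch.profiler import (
    ProfilerActivity,
    profile,
    schedule,
    tensorboard_trace_handler,
)

# Trace a few training steps and write TensorBoard-compatible traces
# to the same directory used in the command above.
prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3, repeat=1),
    on_trace_ready=tensorboard_trace_handler("./Llama-3.1-8B-Instruct_tensorwise"),
    profile_memory=True,
    record_shapes=True,
)

prof.start()
for step, batch in enumerate(train_dataloader):  # hypothetical dataloader
    run_training_step(batch)                     # hypothetical training-step helper
    prof.step()
    if step >= 10:
        break
prof.stop()
```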
Usage
Note: the final model has been saved in bfloat16 format. For inference, load the model in bfloat16 or float16 as shown below:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "newmindai/Llama-3.1-8B-Instruct-w16a8-tw"
dtype = torch.bfloat16  # or torch.float16

tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto",
)

# Turkish legal prompt: "Question: Under the Personal Data Protection Law,
# in which cases is explicit consent not required? Answer:"
prompt = "Soru: Kişisel Verilerin Korunması Kanunu uyarınca hangi durumlarda açık rıza aranmaz? Cevap:"
inputs = tok(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
    )
print(tok.decode(out[0], skip_special_tokens=True))
```
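Since the base model is instruction-tuned, the Llama 3.1 chat template can also be used instead of the raw "Soru:/Cevap:" prompt above; whether this matches the fine-tuning format is an assumption, so treat the following as an optional sketch.

```python
# Reuses `tok` and `model` from the snippet above.
messages = [
    {"role": "user", "content": "Kişisel Verilerin Korunması Kanunu uyarınca hangi durumlarda açık rıza aranmaz?"}
]
input_ids = tok.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

out = model.generate(input_ids, max_new_tokens=256, do_sample=False)
print(tok.decode(out[0], skip_special_tokens=True))
```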
Ethical Considerations and Disclaimers
- Research & development purposes only; not a substitute for professional legal counsel.
- Users must ensure compliance with data protection and sector regulations.
- Potential biases may exist in domain data and model outputs.
Model & Data Card Metadata
- Total Parameters: 8,030,261,248
- Serialized Size (approx.): 16,060,522,496 bytes
- Config precision: bfloat16
- RoPE: llama3 scaling, factor 8.0
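For reference, the serialized size corresponds to 2 bytes per parameter in bfloat16: 8,030,261,248 × 2 = 16,060,522,496 bytes (≈ 16.06 GB).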
References and Citations
Base Model
```bibtex
@misc{meta_llama31_8b_instruct,
  title={Llama 3.1 8B Instruct},
  author={Meta AI},
  year={2024},
  howpublished={\url{https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct}}
}
```
Training Dataset
```bibtex
@misc{euro_hpc_legal,
  title={EuroHPC-Legal},
  author={newmindai},
  year={2025},
  howpublished={\url{https://huggingface.co/datasets/newmindai/EuroHPC-Legal}}
}
```