Abstract

The Llama-3.1-8B-Instruct-w16a8-rw model is a domain-adapted Turkish legal instruction-tuned variant of Meta’s Llama-3.1-8B-Instruct, trained using the Float8 Rowwise (AXISWISE) quantization recipe. This model was developed within the “FSDP2 with Float8 Precision for Faster Training” project to evaluate how fine-grained FP8 scaling affects both training efficiency and downstream legal reasoning performance. During training, model weights were kept in BF16, while the inputs, weights, and gradient outputs were dynamically quantized to FP8-E4M3 using TorchAO’s rowwise configuration, where each row receives its own scaling factor. This finer granularity enabled higher GPU utilization and reduced training time, achieving a ~19.87% speedup over the BF16 baseline on H100 GPUs while maintaining stable convergence. The model was trained on the newmindai/EuroHPC-Legal dataset (multi-domain Q/A format) to improve reasoning quality across various subfields of Turkish law.

Experiment Context

This model was trained with the finer-grained Float8 Rowwise recipe. The recipe sets scaling_granularity to AXISWISE and scaling_type to DYNAMIC for each of the input, weight, and gradient-output cast configurations, so that each row of a tensor gets its own dynamically computed scaling factor instead of a single scaling factor for the entire tensor. The cast dtype is float8_e4m3, i.e. each 8-bit value has 1 sign bit, 4 exponent bits, and 3 mantissa bits.

from torchao.float8 import (
    convert_to_float8_training,
    Float8LinearConfig,
)

# Rowwise recipe: dynamic per-row (AXISWISE) scales for the input, weight,
# and grad_output casts, with float8_e4m3 as the low-precision dtype.
config = Float8LinearConfig.from_recipe_name("rowwise")
# Swap eligible nn.Linear modules for their Float8 training equivalents.
model = convert_to_float8_training(model, config=config)
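
For intuition only, here is a minimal sketch, not TorchAO’s actual implementation, of what dynamic per-row (AXISWISE) scaling to float8_e4m3 amounts to for a 2-D tensor; 448 is the largest finite value representable in that format, and the per-row scale is what distinguishes this recipe from a single per-tensor scale:

import torch

E4M3_MAX = 448.0  # largest finite float8_e4m3 value

def cast_rowwise_e4m3(x: torch.Tensor):
    # Illustrative per-row dynamic cast: one scaling factor per row (AXISWISE)
    # instead of a single scaling factor for the whole tensor (TENSORWISE).
    amax = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12)
    scale = E4M3_MAX / amax                              # one scale per row
    x_fp8 = (x.float() * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale                                  # scale is kept to dequantize

# Rows with very different magnitudes each use the full FP8 dynamic range.
w = torch.randn(4, 8, dtype=torch.bfloat16) * torch.tensor([[0.01], [0.1], [1.0], [10.0]])
w_fp8, scales = cast_rowwise_e4m3(w)
w_approx = w_fp8.float() / scales                        # approximate reconstruction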

Base Model Technical Specifications

  • Parameters: 8 Billion
  • Architecture Family: Llama 3.1
  • Maximum Position Embeddings: 131,072
  • Attention Heads: 32 (num_attention_heads)
  • Key-Value Heads: 8 (num_key_value_heads)
  • Hidden Layers: 32 (num_hidden_layers)
  • Hidden Size: 4,096 (hidden_size)
  • Intermediate Size: 14,336
  • Vocabulary Size: 128,256
  • Precision: bfloat16
  • RoPE Scaling: type llama3, factor = 8.0
  • RMS Norm Epsilon: 1e-05
  • Activation: SiLU
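
These figures can be cross-checked against the base model’s configuration on the Hugging Face Hub; a minimal check (assuming access to the gated meta-llama repository) might look like:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
print(cfg.num_attention_heads, cfg.num_key_value_heads)               # 32, 8
print(cfg.num_hidden_layers, cfg.hidden_size, cfg.intermediate_size)  # 32, 4096, 14336
print(cfg.max_position_embeddings, cfg.vocab_size)                    # 131072, 128256
print(cfg.rope_scaling, cfg.rms_norm_eps, cfg.hidden_act)             # llama3 / 8.0, 1e-05, silu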

Training Methodology

Training Configuration

  • Model: meta-llama/Llama-3.1-8B-Instruct
  • Sequence Length: 4,096 (seq_len)
  • Epochs: 1
  • Per-Device Micro Batch Size: 2
  • Gradient Accumulation: 4
  • GPUs: 4 (via CUDA_VISIBLE_DEVICES=0,1,2,3)
  • dtype: bf16 for high-precision storage (FP8 rowwise casts applied to linear-layer matmuls via TorchAO during training)
    • Weights: bfloat16
    • Activations: bfloat16
  • Optimizer: AdamW
    • Learning Rate: 2e-5
    • Weight Decay: 0.01
    • Betas: (0.9, 0.95)
    • Epsilon: 1e-8
  • LR Scheduler: Cosine with warmup; warmup_ratio=0.1 (10%), warmup_steps=100 (see the optimizer/scheduler sketch after this list)
  • Max Grad Norm: 1.0
  • Gradient Checkpointing: Enabled
  • Evaluation: every 5 steps (eval_steps=5, eval_samples=1000)
  • Checkpointing: every 10 steps; keep last 5; select best by eval_loss
  • Logging: every step to file; Weights & Biases in offline mode
  • Seed: 100
  • Distributed Training: torch.distributed.run (single node, multi-GPU)
    • FSDP2 (the fully_shard-based rewrite of Fully Sharded Data Parallel)
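
The full training script is not reproduced in this card; the sketch below shows, under the assumptions listed above, how the optimizer, cosine schedule, FP8 conversion, and FSDP2 sharding could be wired together (total_steps and the training loop are placeholders):

# Launched with e.g.: torchrun --nproc_per_node=4 train.py  (single node, 4 GPUs)
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
from transformers import AutoModelForCausalLM, get_cosine_schedule_with_warmup
from torchao.float8 import convert_to_float8_training, Float8LinearConfig

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", torch_dtype=torch.bfloat16
).to(torch.cuda.current_device())
model.gradient_checkpointing_enable()

# Convert linear layers to their Float8 equivalents (rowwise recipe) before
# sharding, so the FP8 modules are what FSDP2 shards.
convert_to_float8_training(model, config=Float8LinearConfig.from_recipe_name("rowwise"))

# FSDP2: shard each decoder layer, then the root module.
for layer in model.model.layers:
    fully_shard(layer)
fully_shard(model)

optimizer = torch.optim.AdamW(
    model.parameters(), lr=2e-5, weight_decay=0.01, betas=(0.9, 0.95), eps=1e-8
)
total_steps = 1000  # placeholder; in practice derived from the dataset size,
                    # micro batch size 2, gradient accumulation 4, and 1 epoch
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=total_steps
)
# Training loop (not shown): accumulate gradients over 4 micro-batches,
# clip the gradient norm at 1.0, then optimizer.step() and scheduler.step().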

Setups

  • Precision: bfloat16 used as the high-precision storage and compute dtype.
  • Hardware: HPC (EuroHPC/BSC-class) node with 4 × NVIDIA H100 GPUs.
  • Framework: PyTorch with torchrun for distributed training.

Dependencies

| Package | Version |
| --- | --- |
| transformers | 4.57.1 |
| torch | 2.9.0+cu128 |
| accelerate | 0.14.1 |
| datasets | 4.3.0 |
| huggingface-hub | 0.36.0 |
| tensorboard | 2.20.0 |
| tensorboard-data-server | 0.7.2 |
| wandb | 0.22.1 |

Performance Evaluation

Comparison of models trained on a single node with the FP8 recipes

Figures: loss metrics, memory allocation, and GPU utilization for the w16a16 tensorwise and w16a8 rowwise recipes.
Figures: loss metrics, memory allocation, and GPU utilization compared across the w16a8 recipes.

Loss Analysis

| Model | Max Loss (train) | Min Loss (train) | Avg Loss (train) | Final Loss (train) | Std (train) | Max Loss (val) | Min Loss (val) | Avg Loss (val) | Final Loss (val) | Std (val) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Llama-3.1-8B-Instruct_w16a16 | 3.1462 | 0.5710 | 0.8048 | 0.6374 | 0.2716 | 1.0517 | 0.8335 | 0.8876 | 0.8335 | 0.0678 |
| Llama-3.1-8B-Instruct-w16a8-tw | 3.1983 | 0.5759 | 0.8113 | 0.6419 | 0.2756 | 1.0566 | 0.8390 | 0.8925 | 0.8391 | 0.0675 |
| Llama-3.1-8B-Instruct_w16a8_4nodes_rw | 3.1682 | 0.5740 | 0.8118 | 0.6431 | 0.2746 | 1.0613 | 0.8394 | 0.8937 | 0.8394 | 0.0688 |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 3.1837 | 0.5763 | 0.8116 | 0.6420 | 0.2751 | 1.0599 | 0.8391 | 0.8933 | 0.8391 | 0.0685 |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 3.1983 | 0.5747 | 0.8115 | 0.6446 | 0.2758 | 1.0562 | 0.8384 | 0.8923 | 0.8384 | 0.0677 |

Training Time Analysis

| Model | Training Time (mins) | Memory Allocated (avg %) | GPU Utilization (avg %) | Speedup vs bf16 |
| --- | --- | --- | --- | --- |
| Llama-3.1-8B-Instruct_w16a16 | 138.75267 | 74.4189 | 56.6059 | – |
| Llama-3.1-8B-Instruct-w16a8-tw | 123.75267 | 68.8982 | 97.5364 | 12.11% |
| Llama-3.1-8B-Instruct_w16a8_rw | 115.75364 | 69.6132 | 97.7689 | 19.87% |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 109.00364 | 69.4806 | 97.3312 | 27.33% |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 64.00328 | 68.8982 | 95.5661 | 116.82% |

Implementation

GPU & Memory Usage Profiling

Training was profiled with the PyTorch profiler (torch.profiler).

  • Follow these steps to visualize the profiles:
    1. Install tensorboard and tensorboard-data-server at the versions listed in the Dependencies section.
    2. Visualize the PyTorch profiles by running the command below.
    tensorboard --logdir="./Llama-3.1-8B-Instruct_rowwise" --port="6006"
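
The profiler hookup itself is not shown in this card; a minimal sketch of collecting TensorBoard-compatible traces with torch.profiler (the output directory is assumed to match the one above, and train_dataloader / training_step are hypothetical placeholders) would be:

from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3, repeat=1),
    on_trace_ready=tensorboard_trace_handler("./Llama-3.1-8B-Instruct_rowwise"),
    profile_memory=True,   # enables the memory view in TensorBoard
    record_shapes=True,
)

with prof:
    for step, batch in enumerate(train_dataloader):  # train_dataloader: assumed to exist
        loss = training_step(batch)                  # training_step: hypothetical helper
        prof.step()                                  # advance the wait/warmup/active schedule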
    

Usage

Note: the final model is saved in bfloat16. For inference, load it in bfloat16 (or float16) as shown below:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "newmindai/Llama-3.1-8B-Instruct-w16a8-rw"
dtype = torch.bfloat16  # the checkpoint is stored in bfloat16

tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto",  # place weights on the available GPU(s) automatically
)

# "Question: Under the Personal Data Protection Law (KVKK), in which cases is
# explicit consent not required? Answer:"
prompt = "Soru: Kişisel Verilerin Korunması Kanunu uyarınca hangi durumlarda açık rıza aranmaz? Cevap:"
inputs = tok(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False
    )

print(tok.decode(out[0], skip_special_tokens=True))

Ethical Considerations and Disclaimers

  • Research & development purposes only; not a substitute for professional legal counsel.
  • Users must ensure compliance with data protection and sector regulations.
  • Potential biases may exist in domain data and model outputs.

Model & Data Card Metadata

  • Total Parameters: 8,030,261,248
  • Serialized Size (approx.): 16,060,522,496 bytes
  • Config precision: bfloat16
  • RoPE: llama3 scaling, factor 8.0
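
The serialized size follows directly from the parameter count at 2 bytes per bfloat16 parameter:

num_params = 8_030_261_248
print(num_params * 2)  # 16,060,522,496 bytes (~16.06 GB) in bfloat16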

References and Citations

Base Model

@misc{meta_llama31_8b_instruct,
  title={Llama 3.1 8B Instruct},
  author={Meta AI},
  year={2024},
  howpublished={\url{https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct}}
}

Training Dataset

@misc{euro_hpc_legal,
  title={EuroHPC-Legal},
  author={newmindai},
  year={2025},
  howpublished={\url{https://huggingface.co/datasets/newmindai/EuroHPC-Legal}}
}