Abstract

This model was trained as part of our study comparing FSDP2 with bfloat16 precision against FSDP2 with FP8 mixed precision (bf16-fp8). We used meta-llama/Llama-3.1-8B-Instruct as the base model, loaded with torch_dtype=bfloat16. For FP8 + FSDP2 compatibility, the model was wrapped per layer instead of as a whole model, which avoids dimension-misalignment issues. During the forward and backward passes, float8 variants are used with the default tensorwise quantization scaling recipe, and we set pad_inner_dim=True so that inner dimensions are automatically padded to be divisible by 16, as required for FP8.

# FP8 conversion with torchao, followed by per-layer FSDP2 sharding
from torchao.float8 import (
    convert_to_float8_training,
    Float8LinearConfig,
    precompute_float8_dynamic_scale_for_fsdp,  # called once per optimizer step (see below)
)
from torch.distributed.fsdp import fully_shard

config = Float8LinearConfig(
    pad_inner_dim=True,                  # pad inner dims to multiples of 16, required by FP8 GEMMs
    enable_fsdp_float8_all_gather=True,  # all-gather weights in float8 under FSDP2
)
# Swap nn.Linear modules for Float8Linear (default tensorwise scaling recipe)
model = convert_to_float8_training(model, config=config)

if use_fp8:
    # Wrap per layer instead of the whole model to avoid dimension misalignment
    for layer in model.model.layers:
        fully_shard(layer, **fsdp_kwargs)
    fully_shard(model.model.embed_tokens, **fsdp_kwargs)
    fully_shard(model.lm_head, **fsdp_kwargs)
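
precompute_float8_dynamic_scale_for_fsdp is imported above but only takes effect if it is called inside the training loop; with enable_fsdp_float8_all_gather=True, torchao recommends invoking it after each optimizer step so the next all-gather uses fresh scales. A minimal sketch of where such a call would go, assuming an HF-style model that returns .loss and placeholder dataloader/optimizer names (not the exact training script):

# Illustrative training-step skeleton (names are placeholders)
for batch in dataloader:
    loss = model(**batch).loss   # forward pass; Float8Linear layers run their matmuls in float8
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Refresh the dynamic float8 scales once per step, after the weights
    # have been updated, so the next FSDP2 all-gather uses up-to-date scales.
    precompute_float8_dynamic_scale_for_fsdp(model)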

Base Model Technical Specifications

  • Parameters: 8 Billion
  • Architecture Family: Llama 3.1
  • Maximum Position Embeddings: 131,072
  • Attention Heads: 32 (num_attention_heads)
  • Key-Value Heads: 8 (num_key_value_heads)
  • Hidden Layers: 32 (num_hidden_layers)
  • Hidden Size: 4,096 (hidden_size)
  • Intermediate Size: 14,336
  • Vocabulary Size: 128,256
  • Precision: bfloat16
  • RoPE Scaling: type llama3, factor = 8.0
  • RMS Norm Epsilon: 1e-05
  • Activation: SiLU
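
These values can be read directly from the base model's configuration; a quick check, assuming access to the gated meta-llama repository:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
print(cfg.num_hidden_layers, cfg.num_attention_heads, cfg.num_key_value_heads)  # 32 32 8
print(cfg.hidden_size, cfg.intermediate_size, cfg.vocab_size)                   # 4096 14336 128256
print(cfg.max_position_embeddings, cfg.rms_norm_eps, cfg.hidden_act)            # 131072 1e-05 silu
print(cfg.rope_scaling)   # {'rope_type': 'llama3', 'factor': 8.0, ...}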

Training Methodology

Training Configuration

  • Model: meta-llama/Llama-3.1-8B-Instruct
  • Sequence Length: 4,096 (seq_len)
  • Epochs: 2
  • Per-Device Micro Batch Size: 8
  • Gradient Accumulation: 8
  • GPUs: 4 (via CUDA_VISIBLE_DEVICES=0,1,2,3)
  • dtype: bfloat16 with FP8 mixed precision (fp8=true)
    • Weights: bfloat16
    • Activations: float8
  • Optimizer: AdamW
    • Learning Rate: 2e-5
    • Weight Decay: 0.01
    • Betas: (0.9, 0.95)
    • Epsilon: 1e-8
  • LR Scheduler: Cosine; warmup = 10% (warmup_ratio=0.1), with warmup_steps=100 also set (see the sketch after this list)
  • Max Grad Norm: 1.0
  • Gradient Checkpointing: Enabled
  • Checkpointing: every 10 steps; keep last 5; select best by eval_loss
  • Logging: every step to file; Weights & Biases in offline mode
  • Seed: 100
  • Distributed Training: torch.distributed.run (4 nodes, multi-GPU)
    • FSDP2 (Optimized Fully Sharded Data Parallel)
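
A minimal sketch of the optimizer and scheduler described above, using the transformers get_cosine_schedule_with_warmup helper; the total-step count is a placeholder, and how warmup_ratio=0.1 interacts with warmup_steps=100 is an assumption, not taken from the training script:

import torch
from transformers import get_cosine_schedule_with_warmup

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.01,
)

total_steps = 1000  # placeholder; in practice derived from dataloader length, grad accumulation, and epochs
warmup_steps = min(100, int(0.1 * total_steps))  # assumption: warmup_steps caps the 10% warmup ratio
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)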

Setups

  • Precision: half-precision bfloat16 used as the data type and for computation.
  • Hardware: HPC (EuroHPC/BSC-class), 4 nodes, each with 4 × NVIDIA H100 GPUs.
  • Framework: PyTorch with torchrun for distributed training.

Dependencies

| Package                 | Version     |
|-------------------------|-------------|
| transformers            | 4.57.1      |
| torch                   | 2.9.0+cu128 |
| accelerate              | 0.14.1      |
| datasets                | 4.3.0       |
| huggingface-hub         | 0.36.0      |
| tensorboard             | 2.20.0      |
| tensorboard-data-server | 0.7.2       |
| wandb                   | 0.22.1      |

Job Details

| Model | Job ID | Runtime (mins) | Nodes | GPUs | Node-hours | GPU-hours | Micro-batch | Batch size | Gradient accumulation | Total batch size |
|---|---|---|---|---|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct_w16a8_rw | 31768103 | 115.75 | 1 | 4 | 1.929 | 7.716 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 31837629 | 109.00 | 1 | 4 | 1.816 | 7.266 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 31768031 | 64.00 | 4 | 4 | 1.066 | 4.266 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a16-tw | 31768074 | 138.75 | 1 | 4 | 0.858 | 3.433 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a8-1node-bs8 | 31768093 | 123.75 | 1 | 4 | 0.788 | 3.151 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a16-4nodes-bs32 | 31478433 | 31.75 | 4 | 4 | 2.117 | 8.467 | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a8-4nodes-bs32 | 31478468 | 39.75 | 4 | 4 | 2.650 | 10.600 | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs32 | 31476914 | 22.00 | 8 | 4 | 2.933 | 11.733 | 4 | 4 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs32 | 31476844 | 23.50 | 8 | 4 | 3.133 | 12.533 | 4 | 4 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs64 | 31476914 | 22.00 | 8 | 4 | 2.933 | 11.733 | 4 | 8 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs64 | 31476844 | 23.50 | 8 | 4 | 3.133 | 12.533 | 4 | 8 | 8 | 1024 |
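
For most runs, the total batch size column is consistent with micro-batch size × total GPUs × gradient-accumulation steps (taking 4 GPUs per node); a quick check against some of the configurations above:

def total_batch_size(micro_batch, nodes, gpus_per_node, grad_accum):
    # Effective global batch size per optimizer step
    return micro_batch * nodes * gpus_per_node * grad_accum

print(total_batch_size(2, 1, 4, 4))   # 32   -> matches the 1-node runs
print(total_batch_size(4, 4, 4, 8))   # 512  -> matches the 4-node bs32 runs
print(total_batch_size(4, 8, 4, 8))   # 1024 -> matches the 8-node runs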

All models were trained on 1-node, 4-node, and 8-node configurations, in both bf16 and bf16-fp8 precision settings.

Figures: perplexity, accuracy, and loss curves, plus memory allocation and GPU utilization, for the bf16 and bf16-fp8 configurations.
| Model | Batch Size | Max Loss (train) | Min Loss (train) | Avg Loss (train) | Final Loss (train) | ± Std (train) | Max Loss (val) | Min Loss (val) | Avg Loss (val) | Final Loss (val) ± Std (val) |
|---|---|---|---|---|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct-w16a8-rw | 8 | 3.1682 | 0.5740 | 0.8118 | 0.6431 | 0.2746 | 1.0613 | 0.8394 | 0.8937 | 0.8394 |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 8 | 3.1837 | 0.5763 | 0.8116 | 0.6420 | 0.2751 | 1.0599 | 0.8391 | 0.8933 | 0.8391 |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 8 | 3.1983 | 0.5747 | 0.8115 | 0.6446 | 0.2758 | 1.0562 | 0.8384 | 0.8923 | 0.8384 |
| Llama-3.1-8B-Instruct-w16a16-tw | 8 | 3.1235 | 0.7203 | 0.9750 | 0.3344 | 0.7612 | 1.9113 | 0.8907 | 0.9831 | 0.1897 |
| Llama-3.1-8B-Instruct-w16a8-1node-bs8 | 8 | 3.1661 | 0.7261 | 0.9804 | 0.3374 | 0.7672 | 1.9230 | 0.8948 | 0.9867 | 0.1906 |
| Llama-3.1-8B-Instruct-w16a16-4nodes-bs32 | 32 | 3.2452 | 0.7414 | 0.9665 | 0.4844 | 0.7504 | 1.0538 | 0.8382 | 0.8844 | 0.0725 |
| Llama-3.1-8B-Instruct-w16a8-4nodes-bs32 | 32 | 3.2840 | 0.7478 | 0.9748 | 0.4905 | 0.7581 | 1.0701 | 0.8430 | 0.8922 | 0.0764 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs32 | 32 | 3.2311 | 0.8448 | 1.1856 | 0.6434 | 0.8448 | 1.0257 | 0.8977 | 0.9460 | 0.0568 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs32 | 32 | 3.3003 | 0.8473 | 1.1866 | 0.6481 | 0.8473 | 1.0203 | 0.8992 | 0.9445 | 0.0539 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs64 | 64 | 3.2311 | 0.8448 | 1.1856 | 0.6434 | 0.8448 | 1.0257 | 0.8977 | 0.9460 | 0.0568 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs64 | 64 | 3.3003 | 0.8473 | 1.1866 | 0.6481 | 0.8473 | 1.0203 | 0.8992 | 0.9445 | 0.0539 |
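
The perplexity curves referenced above follow from the cross-entropy loss, assuming the standard definition perplexity = exp(loss); for example, for the minimum validation losses of the two 4-node bs32 runs:

import math

# perplexity = exp(cross-entropy loss)
print(math.exp(0.8382))   # ≈ 2.31 (w16a16-4nodes-bs32, min validation loss)
print(math.exp(0.8430))   # ≈ 2.32 (w16a8-4nodes-bs32, min validation loss)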

Implementation

Usage

Note: the final model is saved in bfloat16 format. For inference, load it in bfloat16 or float16 as shown below:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "newmindai/Llama-3.1-8B-Instruct-w16a8-4nodes-bs64"
dtype = torch.bfloat16  # checkpoint is stored in bfloat16; float16 also works for inference

tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto",
)

# Turkish legal prompt: "Question: Under the Personal Data Protection Law, in which
# cases is explicit consent not required? Answer:"
prompt = "Soru: Kişisel Verilerin Korunması Kanunu uyarınca hangi durumlarda açık rıza aranmaz? Cevap:"
inputs = tok(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,  # greedy decoding
    )
print(tok.decode(out[0], skip_special_tokens=True))

Ethical Considerations and Disclaimers

  • Research & development purposes only; not a substitute for professional legal counsel.
  • Users must ensure compliance with data protection and sector regulations.
  • Potential biases may exist in domain data and model outputs.

Model & Data Card Metadata

  • Total Parameters: 8,030,261,248
  • Serialized Size (approx.): 16,060,522,496 bytes
  • Config precision: bfloat16
  • RoPE: llama3 scaling, factor 8.0
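
The serialized size is consistent with storing each parameter in 2 bytes (bfloat16); a quick check:

n_params = 8_030_261_248
print(n_params * 2)   # 16,060,522,496 bytes, matching the serialized size above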

References and Citations

Base Model

@misc{meta_llama31_8b_instruct,
  title={Llama 3.1 8B Instruct},
  author={Meta AI},
  year={2024},
  howpublished={\url{https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct}}
}

Training Dataset

@misc{euro_hpc_legal,
  title={EuroHPC-Legal},
  author={newmindai},
  year={2025},
  howpublished={\url{https://huggingface.co/datasets/newmindai/EuroHPC-Legal}}
}