# Low Precision Training Methods

Accelerate provides integrations for training in lower precision on supported hardware through the `TransformersEngine`, `MS-AMP`, and `torchao` packages. This guide covers which hardware is supported, how to configure your [Accelerator](/docs/accelerate/v1.13.0/en/package_reference/accelerator#accelerate.Accelerator) to use the low precision methods, and what you can expect when training.

## What training on FP8 means

For the details of training in FP8 with PyTorch and Accelerate, and why it can be difficult, see the [concept guide](../concept_guides/low_precision_training). In essence, rather than training in BF16, some (or all) aspects of training a model are performed using 8 bits instead of 16. The challenge is doing so without degrading final performance.

This is only enabled on specific NVIDIA hardware, namely:

* Consumer graphics cards newer than the 3000 series (such as the 4090)
* Hopper-based GPU architectures (such as the `H100` and `H200`)
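
As a rough way to check for supported hardware programmatically, the sketch below maps a CUDA compute capability to FP8 support. It assumes that FP8 tensor cores start at compute capability 8.9 (Ada Lovelace, e.g. the 4090) and 9.0 (Hopper). That threshold and the helper name `supports_fp8` are illustrative assumptions, not part of Accelerate.

```python
def supports_fp8(capability):
    """Return True if a (major, minor) CUDA compute capability is assumed to have FP8 tensor cores.

    Assumption: FP8 hardware support begins at 8.9 (Ada Lovelace) and 9.0 (Hopper).
    """
    return tuple(capability) >= (8, 9)


# Hypothetical usage with PyTorch (only works if torch and a CUDA device are present):
# import torch
# print(supports_fp8(torch.cuda.get_device_capability()))

print(supports_fp8((8, 6)))  # RTX 3090 (Ampere)
print(supports_fp8((8, 9)))  # RTX 4090 (Ada Lovelace)
print(supports_fp8((9, 0)))  # H100 (Hopper)
```

The backends themselves remain the authority on what is supported, so treat this only as a quick pre-flight check.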

Training in FP8 results in some reduction in memory use (the memory needed for some parts of training is cut in half), and larger models that can replace certain layers with FP8-enabled ones *should* also see an increase in throughput.
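
To make the "cut in half" claim concrete, here is a small back-of-the-envelope calculation. The 4096 x 4096 layer size is a hypothetical example, and the numbers only cover the tensors that are actually stored in FP8, not the whole training footprint.

```python
def tensor_bytes(num_elements, bits_per_element):
    """Bytes needed to store `num_elements` values at the given bit width."""
    return num_elements * bits_per_element // 8


# Hypothetical example: the weight of one 4096 x 4096 linear layer.
num_weights = 4096 * 4096

bf16_bytes = tensor_bytes(num_weights, 16)
fp8_bytes = tensor_bytes(num_weights, 8)

print(f"BF16: {bf16_bytes / 2**20:.0f} MiB")  # BF16: 32 MiB
print(f"FP8:  {fp8_bytes / 2**20:.0f} MiB")   # FP8:  16 MiB
```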

## Configuring the Accelerator

Currently two actively maintained backends for FP8 are supported (`TransformersEngine` and `torchao`), each with different capabilities and configurations. A legacy `MS-AMP` backend also exists but is no longer recommended (see [below](#configuring-ms-amp) for details).

Both use the same core API. Pass `mixed_precision="fp8"` to the [Accelerator](/docs/accelerate/v1.13.0/en/package_reference/accelerator#accelerate.Accelerator), select it during `accelerate config` when prompted about mixed precision, or set it in the `mixed_precision` key of your `config.yaml` file:

```{python}
from accelerate import Accelerator
accelerator = Accelerator(mixed_precision="fp8")
```

To specify a backend (and customize other parts of the FP8 mixed precision setup), use one of the `RecipeKwargs` dataclasses such as `utils.AORecipeKwargs`, `utils.TERecipeKwargs`, or `utils.MSAMPRecipeKwargs`. You can also specify it in your config `yaml` or during `accelerate launch`. We recommend using `TransformersEngine` or `torchao` for new projects:

```{python}
from accelerate import Accelerator
from accelerate.utils import TERecipeKwargs, AORecipeKwargs
# Use TransformersEngine
kwargs = [TERecipeKwargs()]
# Or to use torchao
# kwargs = [AORecipeKwargs()]
accelerator = Accelerator(mixed_precision="fp8", kwarg_handlers=kwargs)
```

```{yaml}
mixed_precision: fp8
fp8_config:
  amax_compute_algo: max
  amax_history_len: 1024
  backend: TE
  fp8_format: HYBRID
  interval: 1
  margin: 0
  override_linear_precision: (false, false, false)
  use_autocast_during_eval: false
```
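
For intuition about the `amax_history_len`, `amax_compute_algo`, and `margin` keys, the toy model below sketches how a delayed-scaling recipe can turn a history of absolute maxima into a scaling factor. It is a simplified illustration, not TransformersEngine's actual implementation: the class name `DelayedScaler` is hypothetical, the formula `fp8_max / amax / 2**margin` is an assumed simplification, and 448 is the largest finite value of the FP8 E4M3 format.

```python
from collections import deque

E4M3_MAX = 448.0  # largest finite value in the FP8 E4M3 format


class DelayedScaler:
    """Toy model of delayed scaling: the scale is derived from recent absolute maxima."""

    def __init__(self, amax_history_len=1024, margin=0, fp8_max=E4M3_MAX):
        self.history = deque(maxlen=amax_history_len)
        self.margin = margin
        self.fp8_max = fp8_max

    def update(self, values):
        # Record the absolute maximum seen in this step.
        self.history.append(max(abs(v) for v in values))

    def scale(self):
        # amax_compute_algo="max": use the largest amax in the history window.
        amax = max(self.history, default=0.0)
        if amax == 0.0:
            return 1.0
        return self.fp8_max / amax / (2 ** self.margin)


scaler = DelayedScaler(amax_history_len=2)
scaler.update([0.5, -2.0])
scaler.update([1.0, -7.0])
print(scaler.scale())  # 448 / 7 = 64.0

scaler.update([0.25])
scaler.update([0.5])   # history is now [0.25, 0.5]; the 7.0 has been evicted
print(scaler.scale())  # 448 / 0.5 = 896.0
```

In this sketch, a longer history makes the scale react more slowly to outliers leaving the window, while a larger `margin` leaves more headroom by shrinking the scale.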

## Configuring MS-AMP

**⚠️ Deprecated / Unmaintained:** MS-AMP is no longer actively maintained by Microsoft. The [MS-AMP repository](https://github.com/Azure/MS-AMP) has not received updates since 2023 and has known compatibility issues:

- Requires CUDA 11.x (does not support CUDA 12.x+)
- Requires older NCCL versions incompatible with recent PyTorch releases
- Does not support recent PyTorch versions (2.2+)

**We strongly recommend using [`TransformersEngine`](#configuring-transformersengine) or [`torchao`](#configuring-torchao) instead for all new and existing FP8 training workflows.** Both are actively maintained and support modern CUDA/PyTorch versions. Native PyTorch FP8 support via `torchao` is particularly promising as a vendor-neutral solution.

The MS-AMP backend is retained in Accelerate for legacy compatibility but may be removed in a future release.

`MS-AMP` has a single configuration argument: the optimization level. 

Currently two levels of optimization are supported in the Accelerate integration, `"O1"` and `"O2"` (using the letter 'o', not zero). 

* `"O1"` will cast the weight gradients and `all_reduce` communications to 8-bit, while the rest is done in 16-bit. This reduces overall GPU memory usage and speeds up communication.
* `"O2"` will also cast first-order optimizer states to 8-bit, while the second-order states are kept in FP16. (Currently only the `Adam` optimizer is supported.) This level aims to minimize final accuracy degradation while saving the most memory.
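
As a rough illustration of what `"O2"` can save on optimizer state alone, the arithmetic below compares Adam's two per-parameter states at different bit widths. The 1B-parameter model and the FP32 baseline are assumptions made for the example; actual savings depend on the model and setup.

```python
def adam_state_bytes(num_params, first_moment_bits, second_moment_bits):
    """Bytes used by Adam's two per-parameter states at the given bit widths."""
    return num_params * (first_moment_bits + second_moment_bits) // 8


num_params = 1_000_000_000  # hypothetical 1B-parameter model

fp32_states = adam_state_bytes(num_params, 32, 32)  # assumed baseline: both states in FP32
o2_states = adam_state_bytes(num_params, 8, 16)     # "O2": 8-bit first-order, FP16 second-order

print(f"FP32 Adam states: {fp32_states / 1e9:.0f} GB")  # 8 GB
print(f"O2 Adam states:   {o2_states / 1e9:.0f} GB")    # 3 GB
```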

To specify an optimization level, pass it to `FP8RecipeKwargs` by setting the `optimization_level` argument:

```{python}
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs
kwargs = [FP8RecipeKwargs(backend="msamp", optimization_level="O2")]
accelerator = Accelerator(mixed_precision="fp8", kwarg_handlers=kwargs)
```

Or during `accelerate launch` via `--fp8_backend=msamp --fp8_opt_level=O2`.

Similarly, this can be set in your `config.yaml`:

```{yaml}
mixed_precision: fp8
fp8_config:
    backend: MSAMP
    opt_level: O2
```

## Configuring TransformersEngine

TransformersEngine has many options for customizing how and which FP8 calculations are performed. A full list of supported arguments and what they mean is available in [NVIDIA's documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html); they are also restated in the `FP8RecipeKwargs` docstring for your convenience.

Accelerate tries to set sensible defaults, but exploring and tweaking the various parameters yourself may lead to better performance.

To use it, specify `backend="te"` and modify any of the arguments you want as part of your kwarg handler:

```{python}
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs
kwargs = [FP8RecipeKwargs(backend="te", ...)]
accelerator = Accelerator(mixed_precision="fp8", kwarg_handlers=kwargs)
```

Or during `accelerate launch` via `--fp8_backend=te ...`. Use `accelerate launch --fp8_backend=te -h` to see relevant arguments.

Similarly, this can be set in your `config.yaml`:

```{yaml}
mixed_precision: fp8
fp8_config:
    amax_compute_algo: max
    amax_history_len: 1024
    backend: TE
    fp8_format: HYBRID
    interval: 1
    margin: 0
    override_linear_precision: (false, false, false)
    use_autocast_during_eval: false
```
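
If you launch training from a script, the `accelerate launch` flags mentioned in this section can be assembled programmatically. The sketch below only builds the command. The helper `build_launch_command` and the script name `train.py` are hypothetical, and any additional backend-specific flags should be taken from the `-h` output rather than from this example.

```python
import shlex


def build_launch_command(script, backend, extra_flags=()):
    """Assemble an `accelerate launch` command that selects an FP8 backend."""
    return [
        "accelerate",
        "launch",
        "--mixed_precision=fp8",
        f"--fp8_backend={backend}",
        *extra_flags,
        script,
    ]


# `train.py` is a hypothetical training script.
cmd = build_launch_command("train.py", "te")
print(shlex.join(cmd))

# To actually run it (requires Accelerate and supported hardware):
# import subprocess
# subprocess.run(cmd, check=True)
```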

## Configuring `torchao`

`torchao` is a [PyTorch-driven](https://github.com/pytorch/ao/tree/main/torchao/float8) hackable FP8 backend that aims to be more approachable than the prior two engines. One of the core differences between `ao` and the prior two is that, for numerical stability, it has generally been found better to keep the first *and* last layers of the model at regular precision (be it FP32 or BF16) and quantize the other layers down to FP8. As a result, a config for `ao` looks a bit different:

> Note: this API is experimental and is subject to change

```{python}
from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs, TorchDynamoPlugin, FullyShardedDataParallelPlugin
from torchao.float8 import Float8LinearConfig

fsdp2_plugin = FullyShardedDataParallelPlugin(
  fsdp_version=2,
  cpu_ram_efficient_loading=False, # CPU RAM efficient loading CANNOT work with fp8 torchao
  fsdp_auto_wrap_policy="TRANSFORMER_BASED_WRAP",
)
dynamo_plugin = TorchDynamoPlugin(
  backend="inductor",
  use_regional_compilation=True,
)
fp8_config = Float8LinearConfig(
  enable_fsdp_float8_all_gather=True, # Use FP8 all_gather in FSDP2
  pad_inner_dim=True,
)
kwargs = [AORecipeKwargs(
  config=fp8_config
)]
accelerator = Accelerator(
  mixed_precision="fp8",
  fsdp_plugin=fsdp2_plugin,
  dynamo_plugin=dynamo_plugin,
  kwarg_handlers=kwargs,
)
```

Or during `accelerate launch` via `--fp8_backend=ao ...`. Use `accelerate launch --fp8_backend=ao -h` to see relevant arguments.

Similarly, this can be set in `config.yaml`:

```{yaml}
mixed_precision: fp8
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: false
  fsdp_version: 2
fp8_config:
  backend: AO
  pad_inner_dim: true
  enable_fsdp_float8_all_gather: true
dynamo_config:
  dynamo_backend: INDUCTOR
  dynamo_use_regional_compilation: true
```

To learn more about the specific parameters to be used, please see the official `torchao` repo.
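
The "keep the first and last layers at regular precision" rule described in this section can be pictured with a few lines of plain Python. This only illustrates the selection logic: the function `select_fp8_layers` and the layer names are hypothetical, and it does not show how Accelerate or `torchao` implement the filtering internally.

```python
def select_fp8_layers(linear_layer_names):
    """Return the layer names to convert to FP8, skipping the first and last linear layers."""
    if len(linear_layer_names) <= 2:
        # Nothing is left to convert once the first and last layers are excluded.
        return []
    return list(linear_layer_names[1:-1])


# Hypothetical layer names, in the order they appear in the model.
layers = ["embed_proj", "block0.mlp", "block1.mlp", "lm_head"]
print(select_fp8_layers(layers))  # ['block0.mlp', 'block1.mlp']
```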

## Example Zoo

The Accelerate repo contains examples showcasing FP8 training both with Accelerate and with the underlying implementation directly.
Currently, there are scripts showcasing:

* Single GPU
* Distributed Data Parallelism (Multi-GPU)
* Fully Sharded Data Parallelism
* DeepSpeed ZeRO 1 through 3

Find out more [here](https://github.com/huggingface/accelerate/tree/main/benchmarks/fp8)

## Further Reading

To learn more about training in FP8 please check out the following resources:

* [Our concept guide](../concept_guides/low_precision_training), which goes into more detail about TransformersEngine, torchao, and MS-AMP
* [The `transformers-engine` documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html)
* [The `torchao` documentation](https://github.com/pytorch/ao/tree/main/torchao/float8)
* [The `MS-AMP` documentation](https://azure.github.io/MS-AMP/docs/) (⚠️ no longer maintained)

