# GRPO Trainer

[![model badge](https://img.shields.io/badge/All_models-GRPO-blue)](https://huggingface.co/models?other=grpo,trl)

## Overview

TRL supports the GRPO Trainer for training language models, as described in the paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300) by [Zhihong Shao](https://huggingface.co/syhia), [Peiyi Wang](https://huggingface.co/peiyiwang89), [Qihao Zhu](https://huggingface.co/zqh11), Runxin Xu, [Junxiao Song](https://huggingface.co/haha-point), Mingchuan Zhang, Y. K. Li, Y. Wu, [Daya Guo](https://huggingface.co/guoday).

The abstract from the paper is the following:

> Mathematical reasoning poses a significant challenge for language models due to its complex and structured nature. In this paper, we introduce DeepSeekMath 7B, which continues pre-training DeepSeek-Coder-Base-v1.5 7B with 120B math-related tokens sourced from Common Crawl, together with natural language and code data. DeepSeekMath 7B has achieved an impressive score of 51.7% on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. Self-consistency over 64 samples from DeepSeekMath 7B achieves 60.9% on MATH. The mathematical reasoning capability of DeepSeekMath is attributed to two key factors: First, we harness the significant potential of publicly available web data through a meticulously engineered data selection pipeline. Second, we introduce Group Relative Policy Optimization (GRPO), a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO.

This post-training method was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec).

## Quick start

This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [DeepMath-103K dataset](https://huggingface.co/datasets/trl-lib/DeepMath-103K).

Below is the script to train the model.

```python
# train_grpo.py
from datasets import load_dataset
from trl import GRPOTrainer
from trl.rewards import accuracy_reward

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    train_dataset=dataset,
)
trainer.train()
```

Execute the script using the following command:

```bash
accelerate launch train_grpo.py
```

Distributed across 8 GPUs, the training takes approximately 1 day.

![GRPO curves](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_curves.png)

## Looking deeper into the GRPO method

GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind the GRPO objective is to maximize the advantage of the generated completions while ensuring that the model remains close to the reference policy. GRPO can be broken down into four main steps: **generating completions**, **computing the advantage**, **estimating the KL divergence**, and **computing the loss**.

![GRPO visual](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_visual.png)

### Generating completions

At each training step, we sample a batch of prompts and generate a set of  \\( G \\) completions for each prompt (denoted as  \\( o_i \\)).

### Computing the advantage

For each of the  \\( G \\) sequences, we compute the reward using a reward model or reward function. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows:

$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$

This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**.

> [!TIP]
> It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that scaling by  \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [GRPOConfig](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOConfig).
> Note that turning off std-based scaling also removes variance normalization, so update magnitudes depend directly on the raw reward scale and batch composition.

> [!TIP]
> As shown in [Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO)](https://huggingface.co/papers/2508.08221), calculating the mean at the local (group) level and the standard deviation at the global (batch) level enables more robust reward shaping. You can use this scaling strategy by setting `scale_rewards="batch"` in [GRPOConfig](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOConfig).
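To make the normalization concrete, here is a minimal NumPy sketch of the group-relative advantage for a flat batch of rewards in which the \\( G \\) completions of each prompt are stored contiguously. The function name and its `scale` argument are illustrative only: they mirror the three options discussed above (per-group std, batch std, no std scaling) but are not the library's API. The sketch uses NumPy's population standard deviation and a small epsilon to avoid division by zero; the library's exact estimator and epsilon may differ.

```python
import numpy as np


def group_relative_advantages(rewards, num_generations, scale="group", eps=1e-4):
    """Compute GRPO-style advantages (hypothetical helper, for illustration).

    rewards: flat sequence of num_prompts * num_generations rewards, with the
        completions of each prompt stored contiguously.
    scale: "group" -> divide by the per-group std (the formula above),
           "batch" -> subtract the group mean but divide by the batch std,
           "none"  -> only subtract the group mean (no std scaling).
    """
    r = np.asarray(rewards, dtype=np.float64).reshape(-1, num_generations)
    centered = r - r.mean(axis=1, keepdims=True)  # the mean is always taken per group
    if scale == "none":
        return centered.reshape(-1)
    if scale == "group":
        std = r.std(axis=1, keepdims=True)  # one std per prompt group
    elif scale == "batch":
        std = r.std()  # a single std over the whole batch
    else:
        raise ValueError(f"unknown scale: {scale!r}")
    return (centered / (std + eps)).reshape(-1)


rewards = [1, 0, 1, 0, 1, 1, 1, 1]  # two prompts, G = 4 completions each
print(group_relative_advantages(rewards, num_generations=4))
```

With these rewards, the first group gets advantages close to ±1, while the second group, whose completions all receive the same reward, gets zero advantage and therefore contributes no learning signal.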

### Estimating the KL divergence

KL divergence is estimated using the approximator introduced by [Schulman et al. (2020)](http://joschu.net/blog/kl-approx.html). The approximator is defined as follows:

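$$
\mathbb{D}_{\text{KL}}\left[\pi_\theta \|\pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1,
$$

### Computing the loss

The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:

$$
\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,<t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
$$

where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through the KL divergence.

> [!TIP]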
> Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we don't scale by  \\( \frac{1}{|o_i|} \\) because it was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that this introduces a response-level length bias. More details in [loss types](#loss-types).

> [!TIP]
> Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we use  \\( \beta = 0.0 \\) by default, meaning that the KL divergence term is not used. This choice is motivated by several recent studies (e.g., [Open-Reasoner-Zero: An Open Source Approach to Scaling Up Reinforcement Learning on the Base Model](https://huggingface.co/papers/2503.24290)) which have shown that the KL divergence term is not essential for training with GRPO. As a result, it has become common practice to exclude it (e.g. [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783), [DAPO: An Open-Source LLM Reinforcement Learning System at Scale](https://huggingface.co/papers/2503.14476)). If you wish to include the KL divergence term, you can set `beta` in [GRPOConfig](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOConfig) to a non-zero value.
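The KL estimator can be sketched per token from log-probabilities, which is how these quantities are usually available in practice. This NumPy helper is a hypothetical illustration, not the library's implementation; it assumes you already have the log-probabilities of the sampled tokens under the current policy and the reference model. The estimate is never negative, is zero when both models assign the same probability to a token, and is only needed when `beta` is non-zero.

```python
import numpy as np


def k3_kl_per_token(logp_policy, logp_ref):
    """Per-token KL estimate: ratio - log(ratio) - 1, with ratio = pi_ref / pi_theta.

    logp_policy, logp_ref: log-probabilities of the sampled tokens under the
    current policy and the reference model (arrays of the same shape).
    """
    log_ratio = np.asarray(logp_ref, dtype=np.float64) - np.asarray(logp_policy, dtype=np.float64)
    return np.exp(log_ratio) - log_ratio - 1.0


logp_policy = np.log([0.5, 0.2, 0.9])
logp_ref = np.log([0.25, 0.2, 0.6])
print(k3_kl_per_token(logp_policy, logp_ref))
```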

In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), which can be set with `num_iterations` in [GRPOConfig](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOConfig)) by leveraging the **clipped surrogate objective**:

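$$
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
$$

where \\( \text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) \\) bounds the policy ratio between \\( 1 - \epsilon \\) and \\( 1 + \epsilon \\), so that a single update cannot move the policy too far from the policy \\( \pi_{\theta_{\text{old}}} \\) that generated the completions. When \\( \mu = 1 \\), the clipped surrogate objective reduces to the loss above.

The SAPO loss replaces this hard clipping with a smooth, temperature-controlled gate, where the temperature depends on the sign of the advantage:

$$
\tau_{i,t} =
\begin{cases}
\tau_{\text{pos}}, & \text{if } \hat{A}_{i,t} > 0 \\
\tau_{\text{neg}}, & \text{otherwise}
\end{cases}
$$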

The SAPO authors recommend using asymmetric temperatures, \\( \tau_{\text{neg}} > \tau_{\text{pos}} \\) (defaults are \\( \tau_{\text{pos}}=1.0, \tau_{\text{neg}}=1.05 \\)). This penalizes the model more strictly for "bad" actions (negative advantage) to prevent instability, while being more permissive with "good" actions.

To use this formulation, set `loss_type="sapo"` in the [GRPOConfig](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOConfig).
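For the clipped surrogate objective described earlier in this section (not the SAPO gate), the value of the loss for one group can be sketched in NumPy as follows. The function is hypothetical and only computes the loss value, whereas a trainer computes it in an autograd framework so that gradients flow through \\( \pi_\theta \\). The separate `eps_low`/`eps_high` arguments follow the \\( \epsilon_\mathrm{low} \\)/\\( \epsilon_\mathrm{high} \\) notation used in the logged metrics; setting them equal recovers the symmetric \\( \epsilon \\) of the formula. The KL term uses the same estimator as above and is skipped when `beta=0.0`.

```python
import numpy as np


def clipped_grpo_loss(logp, logp_old, advantages, mask, eps_low=0.2, eps_high=0.2, beta=0.0, logp_ref=None):
    """Value of the clipped GRPO loss for one group of G completions (hypothetical helper).

    logp, logp_old: (G, T) log-probs of the sampled tokens under the current and old policies.
    advantages: (G,) one advantage per completion, shared by all of its tokens.
    mask: (G, T) with 1 for completion tokens and 0 for padding.
    """
    logp = np.asarray(logp, dtype=np.float64)
    mask = np.asarray(mask, dtype=np.float64)
    adv = np.asarray(advantages, dtype=np.float64)[:, None]  # broadcast over tokens
    ratio = np.exp(logp - np.asarray(logp_old, dtype=np.float64))
    unclipped = ratio * adv
    clipped = np.clip(ratio, 1.0 - eps_low, 1.0 + eps_high) * adv
    per_token = np.minimum(unclipped, clipped)  # pessimistic (lower) bound
    if beta != 0.0:
        log_r = np.asarray(logp_ref, dtype=np.float64) - logp  # log(pi_ref / pi_theta)
        per_token = per_token - beta * (np.exp(log_r) - log_r - 1.0)
    # Normalize by the total number of completion tokens in the group
    return -(per_token * mask).sum() / mask.sum()


# Two one-token completions: the first became 2x more likely, the second 2x less likely
loss = clipped_grpo_loss(
    logp=np.log([[0.5], [0.25]]),
    logp_old=np.log([[0.25], [0.5]]),
    advantages=[1.0, -1.0],
    mask=[[1], [1]],
)
print(loss)
```

Here clipping caps the first completion's contribution at \\( 1 + \epsilon_\mathrm{high} \\) times its advantage, and for the second completion (negative advantage) the `min` keeps the more pessimistic clipped value.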

## Logged metrics

While training and evaluating, we record the following metrics:

- `num_tokens`: The total number of tokens processed so far, including both prompts and completions. When using tools, only non-tool tokens are counted.
- `step_time`: The average time (in seconds) taken per training step (including generation).
- `completions/mean_length`: The average length of generated completions. When using tools, only non-tool tokens are counted.
- `completions/min_length`: The minimum length of generated completions. When using tools, only non-tool tokens are counted.
- `completions/max_length`: The maximum length of generated completions. When using tools, only non-tool tokens are counted.
- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS. When using tools, only non-tool tokens are counted.
- `completions/min_terminated_length`: The minimum length of generated completions that terminate with EOS. When using tools, only non-tool tokens are counted.
- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS. When using tools, only non-tool tokens are counted.
- `completions/clipped_ratio`: The ratio of truncated (clipped) completions.
- `reward/{reward_func_name}/mean`: The average reward from a specific reward function.
- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
- `reward`: The overall average reward after summing rewards across functions (unweighted).
- `reward_std`: The standard deviation of summed rewards across functions (unweighted), computed over the full batch.
- `frac_reward_zero_std`: The fraction of samples in the generation batch with a reward std of zero, implying there is little diversity for that prompt (all answers are correct or all are incorrect).
- `entropy`: Average entropy of token predictions across generated completions. (If `mask_truncated_completions=True`, tokens from masked sequences are excluded.)
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
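- `clip_ratio/region_mean`: The ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities where the GRPO objective is clipped to stay within the trust region: \\( \text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \quad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})} \\).
- `clip_ratio/low_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\).
- `clip_ratio/low_min`: The minimum ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\).
- `clip_ratio/high_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\).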
- `clip_ratio/high_max`: The maximum ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region:  \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\).
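As an illustration of two of these definitions, the sketch below computes `frac_reward_zero_std` and the clip fractions from raw arrays. These helpers are hypothetical and follow the descriptions above literally; the library may apply additional conditions (for example, on the sign of the advantage) and aggregates values across processes, so treat the results as approximations of what is logged.

```python
import numpy as np


def frac_reward_zero_std(rewards, num_generations):
    """Fraction of prompt groups whose rewards are all identical (std == 0)."""
    groups = np.asarray(rewards, dtype=np.float64).reshape(-1, num_generations)
    return float(np.mean(groups.std(axis=1) == 0.0))


def clip_fractions(ratio, mask, eps_low=0.2, eps_high=0.2):
    """Fractions of unmasked tokens whose ratio falls below/above the trust region."""
    r = np.asarray(ratio, dtype=np.float64)[np.asarray(mask, dtype=bool)]
    low = r < 1.0 - eps_low
    high = r > 1.0 + eps_high
    return {"low": float(low.mean()), "high": float(high.mean()), "region": float((low | high).mean())}


print(frac_reward_zero_std([1, 1, 0, 1], num_generations=2))
print(clip_fractions([[0.5, 1.0, 1.5], [1.1, 0.0, 0.0]], [[1, 1, 1], [1, 0, 0]]))
```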

## Customization

### Speed up training with vLLM-powered generation

Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with:

```shell
pip install trl[vllm]
```

We support two ways of using vLLM during training: **server mode** and **colocate mode**.

> [!TIP]
> By default, Truncated Importance Sampling is activated for vLLM generation to address the generation-training mismatch that occurs when using different frameworks. This can be turned off by setting `vllm_importance_sampling_correction=False`. For more information, see [Truncated Importance Sampling](paper_index#truncated-importance-sampling).

#### Option 1: Colocate mode

In this mode, vLLM runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs. This is the default mode.

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    ...,
    use_vllm=True,  # vllm_mode="colocate" by default
)
```

#### Option 2: Server mode

In this mode, vLLM runs in a separate process (on separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference.

1. **Start the vLLM server**:

   ```bash
   trl vllm-serve --model <model_name>
   ```

2. **Enable server mode in your training script**:

   ```python
   from trl import GRPOConfig

   training_args = GRPOConfig(
       ...,
       use_vllm=True,
       vllm_mode="server",
   )
   ```

> [!WARNING]
> Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable.

> [!TIP]
> Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [GRPOConfig](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOConfig) to avoid underutilization or out-of-memory errors.
>
> We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-memory) to help estimate the recommended GPU memory utilization based on your model configuration and experiment settings. Use it to get a recommended value for `vllm_gpu_memory_utilization`.
>
> If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability.
>
> If you still get out-of-memory errors, set `vllm_enable_sleep_mode=True` to offload the vLLM parameters and cache during the optimization step. For more information, see [Reducing Memory Usage with vLLM Sleep Mode](reducing_memory_usage#vllm-sleep-mode).

> [!TIP]
> By default, GRPO uses `MASTER_ADDR=localhost` and `MASTER_PORT=12345` for vLLM, but you can override these values by setting the environment variables accordingly.

For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).

#### Dealing with the Training-Inference Mismatch

While vLLM greatly accelerates inference, it also decouples the inference engine from the training engine. In theory, these engines are mathematically identical; in practice, however, they can produce different outputs due to precision effects and hardware-specific optimizations. This divergence reflects the distinct optimization goals of the two systems. Inference engines aim to maximize sampling throughput, typically measured in tokens per second, while maintaining acceptable sampling fidelity. Training frameworks instead focus on numerical stability and precision for gradient computation, often using higher-precision formats such as FP32 for master weights and optimizer states. These differing priorities and constraints introduce an inevitable, albeit subtle, mismatch between training and inference.

This mismatch leads to a biased gradient update which has been observed to destabilize training ([[1]](https://fengyao.notion.site/off-policy-rl)[[2]](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda)[[3]](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/#true-on-policy-rl)[[4]](https://huggingface.co/papers/2510.26788)[[5]](https://huggingface.co/papers/2510.18855)). For simplicity, consider the REINFORCE policy gradient:

$$
\nabla_\theta \mathcal{J}(x,\theta)
= \mathbb{E}_{y \sim \pi^\text{train}(\cdot \mid x,\theta)}
\left[ \nabla_\theta \log \pi^\text{train}(y \mid x,\theta) \cdot R(x,y) \right]
$$

Here  \\( x \\) denotes prompts sampled from some data distribution, and  \\( \pi^\text{train} \\) is the policy implemented by the training engine. With vLLM in the loop we obtain a separate inference policy  \\( \pi^\text{inference} \\), so the effective policy gradient becomes

$$
\nabla_\theta \mathcal{J}_{\text{biased}}(x,\theta)
= \mathbb{E}_{y \sim \pi^\text{inference}(\cdot \mid x,\theta)}
\left[ \nabla_\theta \log \pi^\text{train}(y \mid x,\theta) \cdot R(x,y) \right].
$$

This turns an otherwise on-policy RL problem into an off-policy one.
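A toy example makes the shift visible. Below, two hypothetical next-token distributions stand in for \\( \pi^\text{train} \\) and \\( \pi^\text{inference} \\) over a three-token vocabulary (the numbers are made up for illustration, and exact expectations are used instead of samples). Averaging the reward under the inference distribution gives a different value than under the training policy, while reweighting each outcome by \\( \pi^\text{train} / \pi^\text{inference} \\), the idea behind importance sampling, recovers the on-policy value.

```python
import numpy as np

# Hypothetical distributions over a 3-token vocabulary and a reward per token
p_train = np.array([0.5, 0.3, 0.2])
p_infer = np.array([0.4, 0.4, 0.2])
reward = np.array([1.0, 0.0, 2.0])

on_policy = np.sum(p_train * reward)  # expectation under the training policy
off_policy = np.sum(p_infer * reward)  # expectation under the inference engine's policy
rho = p_train / p_infer  # importance weights
corrected = np.sum(p_infer * rho * reward)  # reweighted expectation

print(on_policy, off_policy, corrected)
```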

The standard way to correct for this distribution shift is **importance sampling (IS)**. We provide two IS variants: [Truncated Importance Sampling (TIS)](paper_index#truncated-importance-sampling) and [Masked Importance Sampling (MIS)](paper_index#masked-importance-sampling). Both variants can be applied either at the token level or at the sequence level. Let \\( \rho \\) denote the importance weight, that is, the ratio \\( \pi^\text{train} / \pi^\text{inference} \\) of the sampled output, for example \\( \rho_t \\) per token or \\( \rho_{\text{seq}} \\) per sequence. Under TIS, ratios larger than the cap \\( C \\) (set with `vllm_importance_sampling_cap`) are clipped:

$$
\rho \leftarrow \min(\rho, C).
$$

Under MIS, ratios larger than `vllm_importance_sampling_cap` are set to zero, so those samples do not contribute to the gradient. In other words, large ratio samples are downweighted under TIS and discarded under MIS. The configuration flag `vllm_importance_sampling_mode` chooses both the IS variant (masking or truncation) and the granularity (token level or sequence level).
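The two variants and granularities can be sketched as follows. This NumPy helper is hypothetical: its `mode` and `level` arguments only mirror the choices described above and are not the values accepted by `vllm_importance_sampling_mode`, and the cap of 2.0 is an arbitrary example rather than the default of `vllm_importance_sampling_cap`. The sequence-level weight is taken here as the product of the token ratios, which is one common choice; the library may aggregate differently. The resulting weights would multiply the per-token loss terms.

```python
import numpy as np


def importance_weights(logp_train, logp_infer, mask, cap=2.0, mode="truncate", level="token"):
    """Importance weights rho = pi_train / pi_inference for the sampled tokens (hypothetical helper).

    logp_train, logp_infer, mask: arrays of shape (B, T); mask is 1 for completion tokens.
    mode:  "truncate" (TIS: rho <- min(rho, cap)) or "mask" (MIS: rho <- 0 where rho > cap).
    level: "token" (one weight per token) or "sequence" (one weight per completion,
           the product of its token ratios, shared by all of its tokens).
    """
    mask = np.asarray(mask, dtype=np.float64)
    log_rho = (np.asarray(logp_train, dtype=np.float64) - np.asarray(logp_infer, dtype=np.float64)) * mask
    if level == "sequence":
        log_rho = np.broadcast_to(log_rho.sum(axis=1, keepdims=True), log_rho.shape)
    elif level != "token":
        raise ValueError(f"unknown level: {level!r}")
    rho = np.exp(log_rho)
    if mode == "truncate":
        rho = np.minimum(rho, cap)  # large ratios are downweighted
    elif mode == "mask":
        rho = np.where(rho > cap, 0.0, rho)  # large ratios are discarded
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return rho * mask  # padding positions get weight 0


logp_train = np.log([[0.5, 0.9]])
logp_infer = np.log([[0.1, 0.9]])  # the first token is 5x more likely under the trainer
mask = [[1, 1]]
for level in ("token", "sequence"):
    for mode in ("truncate", "mask"):
        print(level, mode, importance_weights(logp_train, logp_infer, mask, mode=mode, level=level))
```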

Importance sampling is the principled algorithmic response to the training–inference mismatch. However, there are also more direct approaches that attempt to reduce the mismatch between the two engines themselves. Most of these are engineering solutions. For example, [MiniMax M1 uses an FP32 language model head](https://huggingface.co/papers/2506.13585) in the inference engine. Thinking Machines has explored [deterministic inference kernels](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/), although this comes with a significant efficiency cost. vLLM has shown [bitwise consistent policies](https://blog.vllm.ai/2025/11/10/bitwise-consistent-train-inference.html) by building on the batch invariant deterministic kernels from Thinking Machines, but as of November 2025 there remains a substantial throughput penalty relative to standard vLLM inference.

### GRPO at scale: train a 70B+ Model on multiple nodes

When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:

- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such models. For more details, see [DeepSpeed Integration](deepspeed_integration).
- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training).
- **vLLM**: See the previous section on how to use vLLM to speed up generation.

Below is an example SLURM script to train a 70B model with GRPO on multiple nodes. This script trains a model on 4 nodes and uses the 5th node for vLLM-powered generation.

```sh
#!/bin/bash
#SBATCH --nodes=5
#SBATCH --gres=gpu:8

# Get the list of allocated nodes
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))

# Assign the first 4 nodes for training and the 5th node for vLLM
TRAIN_NODES=$(IFS=,; echo "${NODELIST[*]:0:4}")  # Nodes 0, 1, 2, 3 for training (comma-separated for srun)
VLLM_NODE="${NODELIST[4]}"  # Node 4 for vLLM

# Run training on the first 4 nodes (Group 1).
# SLURM_PROCID is escaped so that each srun task expands its own rank, not the batch script.
srun --nodes=4 --ntasks=4 --nodelist="$TRAIN_NODES" bash -c "accelerate launch \
     --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
     --num_processes 32 \
     --num_machines 4 \
     --main_process_ip ${NODELIST[0]} \
     --machine_rank \$SLURM_PROCID \
     --rdzv_backend c10d \
     train_grpo.py \
     --vllm_server_host $VLLM_NODE" &

# Run vLLM server on the 5th node (Group 2)
srun --nodes=1 --ntasks=1 --nodelist="$VLLM_NODE" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 &

wait
```

```python
# train_grpo.py
import argparse

from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
from trl.rewards import accuracy_reward

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP")
    args = parser.parse_args()

    dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

    training_args = GRPOConfig(
        per_device_train_batch_size=4,
        use_vllm=True,
        vllm_mode="server",
        vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."),  # from ip-X-X-X-X to X.X.X.X
    )

    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-72B",
        args=training_args,
        reward_funcs=accuracy_reward,
        train_dataset=dataset
    )
    trainer.train()

if __name__ == "__main__":
    main()
```

### Using a custom reward function

The [GRPOTrainer](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOTrainer) supports using custom reward functions instead of dense reward models. Reward functions can be either synchronous Python callables or asynchronous `async def` coroutines. When you provide multiple asynchronous reward functions, they are awaited concurrently (via `asyncio.gather`) so that their latencies overlap.

To ensure compatibility, your reward function must satisfy the following requirements:

1. **Input arguments**:
   - The function must accept the following as keyword arguments:
     - `prompts` (contains the prompts),
     - `completions` (contains the generated completions),
     - `completion_ids` (contains the tokenized completions),
     - `trainer_state` ([TrainerState](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/callback#transformers.TrainerState)): The current state of the trainer. This can be used to implement dynamic reward functions, such as curriculum learning, where the reward is adjusted based on the training progress.
     - `log_extra`: a callable `log_extra(column: str, values: list)` to add extra columns to the completions table. See Example 6. In distributed training, it's important that all processes log the same set of keys.
     - `log_metric`: a callable `log_metric(name: str, value: float)` to log scalar metrics as plots alongside `kl`, `entropy`, etc. See Example 6. In distributed training, it's important that all processes log the same set of keys.
     - All column names (except `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.

     The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
   - Depending on the dataset format, the input will vary:
     - For [standard format](dataset_formats#standard), `prompts` and `completions` will be lists of strings.
     - For [conversational format](dataset_formats#conversational), `prompts` and `completions` will be lists of message dictionaries.

2. **Return value**: The function must return a list of floats. Each float represents the reward corresponding to a single completion.

#### Example 1: Reward longer completions

Below is an example of a reward function for a standard format that rewards longer completions:

```python
def reward_func(completion_ids, **kwargs):
    """Reward function that assigns higher scores to longer completions (in terms of token count)."""
    return [float(len(ids)) for ids in completion_ids]
```

You can test it as follows:

```python
>>> prompts = ["The sky is", "The sun is"]  # not used in the reward function, but the trainer will pass it
>>> completions = [" blue.", " in the sky."]  # not used in the reward function, but the trainer will pass it
>>> completion_ids = [[6303, 13], [304, 279, 12884, 13]]
>>> reward_func(prompts=prompts, completions=completions, completion_ids=completion_ids)
[2.0, 4.0]
```

#### Example 1.1: Reward longer completions (based on the number of characters)

Same as the previous example, but this time the reward function is based on the number of characters instead of tokens.

```python
def reward_func(completions, **kwargs):
    """Reward function that assigns higher scores to longer completions (in terms of character count)."""
    return [float(len(completion)) for completion in completions]
```

You can test it as follows:

```python
>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> completion_ids = [[6303, 13], [304, 279, 12884, 13]]  # not used in the reward function, but the trainer will pass it
>>> reward_func(prompts=prompts, completions=completions, completion_ids=completion_ids)
[6.0, 12.0]
```

#### Example 2: Reward completions with a specific format

Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the _format reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
It is designed for a conversational format, where prompts and completions consist of structured messages.

```python
import re

def format_reward_func(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
```

You can test this function as follows:

```python
>>> prompts = [
...     [{"role": "user", "content": "What is the result of (1 + 2) * 4?"}],
...     [{"role": "user", "content": "What is the result of (3 + 1) * 2?"}],
... ]
>>> completions = [
...     [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
...     [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
... ]
>>> format_reward_func(prompts=prompts, completions=completions)
[1.0, 0.0]
```

#### Example 3: Reward completions based on a reference

Below is an example of a reward function that checks if the completion is correct. This example is inspired by the _accuracy reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
This example is designed for [standard format](dataset_formats#standard), where the dataset contains a column named `ground_truth`.

```python
import re

def reward_func(completions, ground_truth, **kwargs):
    # Regular expression to capture content inside \boxed{}
    matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
    contents = [match.group(1) if match else "" for match in matches]
    # Reward 1 if the content is the same as the ground truth, 0 otherwise
    return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]
```

You can test this function as follows:

```python
>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."]
>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."]
>>> ground_truth = ["2", "5"]
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
[1.0, 0.0]
```

#### Example 4: Multi-task reward functions

Below is an example of using multiple reward functions in the [GRPOTrainer](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOTrainer). In this example, we define two task-specific reward functions: `math_reward_func` and `coding_reward_func`. The `math_reward_func` rewards math problems based on their correctness, while the `coding_reward_func` rewards coding problems based on whether the solution works.

```python
from datasets import Dataset
from trl import GRPOTrainer

# Define a dataset that contains both math and coding problems
dataset = Dataset.from_list(
    [
        {"prompt": "What is 2+2?", "task": "math"},
        {"prompt": "Write a function that returns the sum of two numbers.", "task": "code"},
        {"prompt": "What is 3*4?", "task": "math"},
        {"prompt": "Write a function that returns the product of two numbers.", "task": "code"},
    ]
)

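# Hypothetical placeholder checkers for this example; replace them with real verification logic
def check_math_solution(prompt, completion):
    expected = {"What is 2+2?": "4", "What is 3*4?": "12"}
    answer = expected.get(prompt)
    return answer is not None and answer in completion

def test_code_solution(prompt, completion):
    # Toy heuristic: does the completion define a function at all?
    return "def " in completion

# Math-specific reward function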
def math_reward_func(prompts, completions, task, **kwargs):
    rewards = []
    for prompt, completion, t in zip(prompts, completions, task):
        if t == "math":
            # Calculate math-specific reward
            correct = check_math_solution(prompt, completion)
            reward = 1.0 if correct else -1.0
            rewards.append(reward)
        else:
            # Return None for non-math tasks
            rewards.append(None)
    return rewards

# Coding-specific reward function
def coding_reward_func(prompts, completions, task, **kwargs):
    rewards = []
    for prompt, completion, t in zip(prompts, completions, task):
        if t == "code":  # must match the "task" values used in the dataset
            # Calculate coding-specific reward
            works = test_code_solution(prompt, completion)
            reward = 1.0 if works else -1.0
            rewards.append(reward)
        else:
            # Return None for non-coding tasks
            rewards.append(None)
    return rewards

# Use both task-specific reward functions
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=[math_reward_func, coding_reward_func],
    train_dataset=dataset,
)

trainer.train()
```

In this example, `math_reward_func` and `coding_reward_func` are designed to work with a mixed dataset that contains both math and coding problems. The `task` column in the dataset determines which reward function applies to each problem. When a reward function does not apply to a sample, it returns `None`; the [GRPOTrainer](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOTrainer) ignores these `None` rewards and only considers the rewards returned by the relevant functions. This allows the [GRPOTrainer](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOTrainer) to handle multiple reward functions with different applicability.

#### Example 5: Asynchronous reward functions

Custom reward functions can also be defined as `async def` coroutines. This is useful if your reward depends on slow I/O (for example, calling a remote service). When you pass multiple async reward functions, [GRPOTrainer](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOTrainer) executes them concurrently so their latency overlaps.

Below is a minimal example of an async reward function that simulates an I/O-bound operation:

```python
import asyncio

async def async_reward_func(prompts, completions, **kwargs):
    # Simulate an I/O-bound call (e.g., HTTP request, database lookup)
    await asyncio.sleep(0.01)
    # Simple toy reward: 1.0 if the completion is non-empty, else 0.0
    return [1.0 if completion else 0.0 for completion in completions]
```
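To see why concurrency helps, the standalone sketch below (outside the trainer, with two hypothetical async reward functions that each simulate 0.2 s of I/O) awaits both with `asyncio.gather`. Because the waits overlap, the total time is close to 0.2 s rather than 0.4 s. This only illustrates the concurrency described above; it is not the trainer's internal code.

```python
import asyncio
import time


async def length_reward(prompts, completions, **kwargs):
    await asyncio.sleep(0.2)  # simulated remote call
    return [float(len(completion)) for completion in completions]


async def non_empty_reward(prompts, completions, **kwargs):
    await asyncio.sleep(0.2)  # simulated remote call
    return [1.0 if completion else 0.0 for completion in completions]


async def score_all(reward_funcs, **inputs):
    # Start every reward function at once and wait for all of them
    return await asyncio.gather(*(func(**inputs) for func in reward_funcs))


start = time.perf_counter()
results = asyncio.run(
    score_all([length_reward, non_empty_reward], prompts=["The sky is", "The sun is"], completions=[" blue.", ""])
)
elapsed = time.perf_counter() - start
print(results, f"{elapsed:.2f}s")
```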

#### Example 6: Logging extra columns and metrics

Below is an example of a reward function that logs extra columns to the completions table and scalar metrics as plots.

```python
import re

def reward_func(completions, ground_truth, log_extra=None, log_metric=None, **kwargs):
    extracted = [re.search(r"\\boxed\{(.*?)\}", c) for c in completions]
    extracted = [m.group(1) if m else None for m in extracted]
    rewards = [1.0 if e == gt else 0.0 for e, gt in zip(extracted, ground_truth)]

    if log_extra:
        log_extra("golden_answer", list(ground_truth))
        log_extra("extracted_answer", [e or "[none]" for e in extracted])

    if log_metric:
        log_metric("accuracy", sum(rewards) / len(rewards))

    return rewards
```
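Because `log_extra` and `log_metric` are ordinary keyword arguments, you can unit-test such a reward function outside the trainer by passing simple recording callbacks. The sketch below repeats the function so it is self-contained; `record_column`, `record_metric`, and the two dictionaries are hypothetical stand-ins for whatever the trainer does with these callbacks.

```python
import re

def reward_func(completions, ground_truth, log_extra=None, log_metric=None, **kwargs):
    extracted = [re.search(r"\\boxed\{(.*?)\}", c) for c in completions]
    extracted = [m.group(1) if m else None for m in extracted]
    rewards = [1.0 if e == gt else 0.0 for e, gt in zip(extracted, ground_truth)]
    if log_extra:
        log_extra("golden_answer", list(ground_truth))
        log_extra("extracted_answer", [e or "[none]" for e in extracted])
    if log_metric:
        log_metric("accuracy", sum(rewards) / len(rewards))
    return rewards

# Recording stand-ins for the callbacks the trainer would pass in
logged_columns: dict[str, list] = {}
logged_metrics: dict[str, float] = {}

def record_column(name, values):
    logged_columns[name] = values

def record_metric(name, value):
    logged_metrics[name] = value

rewards = reward_func(
    completions=[r"The result is \boxed{4}.", "I am not sure."],
    ground_truth=["4", "9"],
    log_extra=record_column,
    log_metric=record_metric,
)
print(rewards)         # [1.0, 0.0]
print(logged_columns)  # {'golden_answer': ['4', '9'], 'extracted_answer': ['4', '[none]']}
print(logged_metrics)  # {'accuracy': 0.5}
```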

#### Passing the reward function to the trainer

To use your custom reward function, pass it to the [GRPOTrainer](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOTrainer) as follows:

```python
from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=reward_func,
    ...,
)
```

You can pass several reward functions as a list; this list may include both synchronous and asynchronous functions:

```python
from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=[reward_func, async_reward_func1, async_reward_func2],
    ...,
)
```

The total reward is computed as the sum of the rewards from each function, or as a weighted sum if `reward_weights` is provided in the config.
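As a worked example of the weighted sum, suppose two reward functions score a batch of two completions and the config sets `reward_weights=[1.0, 0.5]`. The helper `total_rewards` below is a hypothetical illustration of the arithmetic only; it assumes that omitting the weights is equivalent to giving every function a weight of 1.0, which matches the plain sum described above.

```python
def total_rewards(per_func_rewards, reward_weights=None):
    # per_func_rewards[i][j] is the reward from function i for completion j
    n_funcs = len(per_func_rewards)
    weights = reward_weights if reward_weights is not None else [1.0] * n_funcs
    if len(weights) != n_funcs:
        raise ValueError("Expected one weight per reward function")
    # Weighted sum over functions, computed separately for each completion
    return [
        sum(w * r for w, r in zip(weights, sample))
        for sample in zip(*per_func_rewards)
    ]

accuracy = [1.0, 0.0]  # function 1: correctness
fmt = [1.0, 1.0]       # function 2: formatting
print(total_rewards([accuracy, fmt]))              # [2.0, 1.0]
print(total_rewards([accuracy, fmt], [1.0, 0.5]))  # [1.5, 0.5]
```

Down-weighting the formatting reward halves its contribution, so a well-formatted but wrong answer scores 0.5 instead of 1.0.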

Note that [GRPOTrainer](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOTrainer) supports mixing reward functions of different types, such as custom functions and reward models given by model ID. See the `reward_funcs` parameter documentation for more details.

### Rapid Experimentation for GRPO

RapidFire AI is an open-source experimentation engine that sits on top of TRL and lets you launch multiple GRPO configurations at once, even on a single GPU. Instead of trying configurations sequentially, RapidFire lets you **see all their learning curves earlier, stop underperforming runs, and clone promising ones with new settings in flight** without restarting. For more information, see [RapidFire AI Integration](rapidfire_integration).

## Agent Training

GRPO supports **agent training** through the `tools` argument in [GRPOTrainer](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOTrainer).
This parameter expects a list of Python functions (sync or async) that define the tools available to the agent:

```python
from trl import GRPOTrainer

trainer = GRPOTrainer(
    tools=[tool1, tool2],
    ...,
)
```

Each tool must be a standard Python function with **type-hinted arguments and return types**, along with a **Google-style docstring** describing its purpose, arguments, and return value.
For more details, see the [Passing tools guide](https://huggingface.co/docs/transformers/en/chat_extras#passing-tools).

Example:

```python
from trl import GRPOTrainer

def multiply(a: int, b: int) -> int:
    """
    Multiplies two integers.

    Args:
        a: The first integer.
        b: The second integer.

    Returns:
        The product of the two integers.
    """
    return a * b

async def async_add(a: int, b: int) -> int:
    """
    Asynchronously adds two integers.

    Args:
        a: The first integer.
        b: The second integer.

    Returns:
        The sum of the two integers.
    """
    return a + b

trainer = GRPOTrainer(
    tools=[multiply, async_add],
    ...,
)
```

You can also provide tools through `environment_factory`. In this mode, [GRPOTrainer](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOTrainer) creates one environment instance per rollout and exposes the environment's public methods as tools.

> [!IMPORTANT]
> `environment_factory` requires `transformers>=5.2.0`.

The following is a minimal example of using `environment_factory` to define a simple environment with an `increment` method, which is exposed as a tool to the agent:

```python
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

instructions = [f"Increment the counter by {i}." for i in range(1, 7)]
dataset = Dataset.from_dict({"prompt": [[{"role": "user", "content": instruction}] for instruction in instructions]})

def reward_func(environments, **kwargs):  # dummy reward: the reward is the current value of the counter
    return [environment.counter for environment in environments]

class IncrementEnv:
    def reset(self, **kwargs) -> str | None:  # required; receives sampled row fields as kwargs (e.g., `prompt`)
        self.counter = 0
        return "Counter reset to 0.\n"

    def increment(self, step: int) -> int:  # the other public methods of the environment are exposed as tools
        """
        Increment the internal counter.

        Args:
            step: Value to add to the counter.

        Returns:
            The updated counter value.
        """
        self.counter += step
        return self.counter

trainer = GRPOTrainer(
    model="Qwen/Qwen3-0.6B",
    args=GRPOConfig(chat_template_kwargs={"enable_thinking": False}),
    train_dataset=dataset,
    reward_funcs=reward_func,
    environment_factory=IncrementEnv,
)
trainer.train()
```

`reset` can return either `None` or a string. In GRPO, when it returns a string, that string is appended to the last user message before generation.
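Because the environment is a plain Python class, you can exercise it by hand before training. The sketch below simulates one rollout: it creates a fresh environment, calls `reset` with the row's `prompt`, appends the returned string to the last user message, applies a few tool calls the way a model might, and scores the result with the same `reward_func`. The `simulate_rollout` helper is hypothetical, and the plain string concatenation is an assumption about how the appended text is joined; this is not the trainer's internal loop.

```python
from __future__ import annotations

import copy

class IncrementEnv:
    def reset(self, **kwargs) -> str | None:
        self.counter = 0
        return "Counter reset to 0.\n"

    def increment(self, step: int) -> int:
        """
        Increment the internal counter.

        Args:
            step: Value to add to the counter.

        Returns:
            The updated counter value.
        """
        self.counter += step
        return self.counter

def reward_func(environments, **kwargs):
    return [environment.counter for environment in environments]

def simulate_rollout(prompt: list[dict], tool_steps: list[int]):
    """Hypothetical helper: one hand-driven rollout, not the trainer's own loop."""
    env = IncrementEnv()              # a fresh instance per rollout
    messages = copy.deepcopy(prompt)  # leave the dataset row untouched
    observation = env.reset(prompt=prompt)
    if observation is not None:
        # Assumption: the reset string is concatenated onto the last user message
        messages[-1]["content"] += observation
    for step in tool_steps:           # stand-in for tool calls the model would emit
        env.increment(step)
    return env, messages

prompt = [{"role": "user", "content": "Increment the counter by 3."}]
env, messages = simulate_rollout(prompt, tool_steps=[1, 2])
print(reward_func([env]))  # [3]
```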

### Supported Models

Tested with:

- [**Qwen3**](https://huggingface.co/collections/Qwen/qwen3) — e.g., `Qwen/Qwen3-0.6B`
- [**Qwen3.5**](https://huggingface.co/collections/Qwen/qwen35) — e.g., `Qwen/Qwen3.5-2B`

> [!TIP]
> Compatibility with all LLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes.

### Quick Start

Use [grpo\_agent.py](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_agent.py) to fine-tune an LLM for agentic workflows.

```bash
accelerate launch \
  --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
  examples/scripts/grpo_agent.py \
  --model_name_or_path Qwen/Qwen3-0.6B \
  ...
```

## Vision-Language Model (VLM) Training

GRPO supports training Vision-Language Models (VLMs) on multimodal datasets containing both text and images.

### Supported Models

Tested with:

- **Gemma3** — e.g., `google/gemma-3-4b-it`
- **LLaVA-NeXT** — e.g., `llava-hf/llava-v1.6-mistral-7b-hf`
- **Qwen2-VL** — e.g., `Qwen/Qwen2-VL-2B-Instruct`
- **Qwen2.5-VL** — e.g., `Qwen/Qwen2.5-VL-3B-Instruct`
- **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct`
  
> [!TIP]
> Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes.

### Quick Start

Use [grpo\_vlm.py](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) to fine-tune a VLM. Example command for training on [`lmms-lab/multimodal-open-r1-8k-verified`](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified):

```bash
accelerate launch \
  --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
  examples/scripts/grpo_vlm.py \
  --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
  --output_dir grpo-Qwen2.5-VL-3B-Instruct \
  --learning_rate 1e-5 \
  --dtype bfloat16 \
  --max_completion_length 1024 \
  --use_vllm \
  --vllm_mode colocate \
  --use_peft \
  --lora_target_modules q_proj v_proj \
  --log_completions
```

### Configuration Tips

- Use LoRA on vision-language projection layers
- Enable 4-bit quantization to reduce memory usage
- VLMs are memory-intensive — start with smaller batch sizes
- Most models are compatible with vLLM (`server` and `colocate` modes)

### Dataset Format

Each training sample should include:

- `prompt`: Text formatted via the processor's chat template
- `image`/`images`: PIL Image or list of PIL Images

The trainer automatically handles image-to-tensor conversion via the model’s image processor.

## GRPOTrainer[[trl.GRPOTrainer]]

#### trl.GRPOTrainer[[trl.GRPOTrainer]]

[Source](https://github.com/huggingface/trl/blob/v1.0.0rc1/trl/trainer/grpo_trainer.py#L130)

Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language
Models](https://huggingface.co/papers/2402.03300).

Example:

```python
from trl import GRPOTrainer
from trl.rewards import accuracy_reward
from datasets import load_dataset

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    train_dataset=dataset,
)
trainer.train()
```

**Parameters:**

model (`str` or [PreTrainedModel](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/model#transformers.PreTrainedModel) or `PeftModel`) : Model to be trained. Can be either:
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a path to a *directory* containing model weights saved using [save_pretrained](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/model#transformers.PreTrainedModel.save_pretrained), e.g., `'./my_model_directory/'`. The model is loaded using the `from_pretrained` method of the model class derived from the model config, with the keyword arguments in `args.model_init_kwargs`.
- A [PreTrainedModel](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/model#transformers.PreTrainedModel) object. Only causal language models are supported.
- A `PeftModel` object. Only causal language models are supported.

reward_funcs (`RewardFunc | list[RewardFunc]`) : Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward functions with the prompts and completions and sum the rewards. Can be either:
- A single reward function, which can be:
  - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a path to a *directory* containing model weights saved using [save_pretrained](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/model#transformers.PreTrainedModel.save_pretrained), e.g., `'./my_model_directory/'`. The model is loaded using [from_pretrained](https://huggingface.co/docs/transformers/v5.3.0/en/model_doc/auto#transformers.AutoModelForSequenceClassification.from_pretrained) with `num_labels=1` and the keyword arguments in `args.model_init_kwargs`.
  - A [PreTrainedModel](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/model#transformers.PreTrainedModel) object: Only sequence classification models are supported.
  - A custom reward function: The function is provided with the prompts and the generated completions, plus any additional columns in the dataset. It should return a list of rewards. Custom reward functions can be either synchronous or asynchronous and can also return `None` when the reward is not applicable to those samples. This is useful for multi-task training where different reward functions apply to different types of samples. When a reward function returns `None` for a sample, that reward function is excluded from the reward calculation for that sample. For more details, see [Using a custom reward function](#using-a-custom-reward-function). The trainer's state, an instance of [TrainerState](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/callback#transformers.TrainerState), is also passed to the reward function; to access it, add a `trainer_state` argument to the reward function's signature.
- A list of reward functions, where each item can independently be any of the above types. Mixing different types within the list (e.g., a string model ID and a custom reward function) is allowed.

args ([GRPOConfig](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOConfig), *optional*) : Configuration for this trainer. If `None`, a default configuration is used.

train_dataset ([Dataset](https://huggingface.co/docs/datasets/v4.8.3/en/package_reference/main_classes#datasets.Dataset) or [IterableDataset](https://huggingface.co/docs/datasets/v4.8.3/en/package_reference/main_classes#datasets.IterableDataset)) : Dataset to use for training. It must include a column `"prompt"`. Additional columns are forwarded to custom reward functions as keyword arguments (see `reward_funcs`). The format of the samples can be either:
- [Standard](dataset_formats#standard): Each sample contains plain text.
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role and content).

eval_dataset ([Dataset](https://huggingface.co/docs/datasets/v4.8.3/en/package_reference/main_classes#datasets.Dataset), [IterableDataset](https://huggingface.co/docs/datasets/v4.8.3/en/package_reference/main_classes#datasets.IterableDataset) or `dict[str, Dataset | IterableDataset]`) : Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.

processing_class ([PreTrainedTokenizerBase](https://huggingface.co/docs/transformers/v5.3.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase), [ProcessorMixin](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/processors#transformers.ProcessorMixin), *optional*) : Processing class used to process the data. The padding side must be set to "left". If `None`, the processing class is loaded from the model's name with [from_pretrained](https://huggingface.co/docs/transformers/v5.3.0/en/model_doc/auto#transformers.AutoProcessor.from_pretrained). A padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token, `tokenizer.eos_token` is used as the default.

reward_processing_classes ([PreTrainedTokenizerBase](https://huggingface.co/docs/transformers/v5.3.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase) or `list[PreTrainedTokenizerBase]`, *optional*) : Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
- A single processing class: Used when `reward_funcs` contains only one reward function.
- A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.

If set to `None`, or if an element of the list corresponding to a [PreTrainedModel](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/model#transformers.PreTrainedModel) is `None`, the tokenizer for the model is automatically loaded using [from_pretrained](https://huggingface.co/docs/transformers/v5.3.0/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained). For elements in `reward_funcs` that are custom reward functions (not [PreTrainedModel](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/model#transformers.PreTrainedModel)), the corresponding entries in `reward_processing_classes` are ignored.

callbacks (list of [TrainerCallback](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/callback#transformers.TrainerCallback), *optional*) : List of callbacks to customize the training loop. These are added to the list of default callbacks detailed [here](https://huggingface.co/docs/transformers/main_classes/callback). To remove one of the default callbacks, use the [remove_callback](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/trainer#transformers.Trainer.remove_callback) method.

optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`) : A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your model and a scheduler given by [get_linear_schedule_with_warmup](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/optimizer_schedules#transformers.get_linear_schedule_with_warmup) controlled by `args`.

peft_config (`PeftConfig`, *optional*) : PEFT configuration used to wrap the model. If `None`, the model is not wrapped.

tools (list of `Callable`, *optional*) : A list of callable tool functions (sync or async) that the model can invoke during generation. Each tool should be a standard Python function with properly type-hinted arguments and return values, and a Google-style docstring describing its purpose, arguments, and return value. For more details, see: https://huggingface.co/docs/transformers/en/chat_extras#passing-tools. The model uses the function's name, type hints, and docstring to determine how to call it. Ensure that the model's chat template supports tool use and that it has been fine-tuned for tool calling.

rollout_func (`RolloutFunc`, *optional*) : Function to use for generating completions. It receives the list of prompts allocated to the current process and the trainer instance. It must return a dict with `"prompt_ids"`, `"completion_ids"`, and `"logprobs"` fields, and can optionally return `"logprob_token_ids"` (same shape as `"logprobs"`). Any other fields are forwarded to the reward functions. The function receives the raw per-process prompt slice with no duplication; it is responsible for returning the correct number of completions per prompt (see `num_generations` / `num_generations_eval` on the trainer). This feature is experimental and may change or be removed at any time without prior notice.

environment_factory (`EnvironmentFactory`, *optional*) : A callable that creates and returns an environment instance. The environment class should define methods that can be invoked as tools during generation. Each method should comply with the same requirements as the `tools` described above. If `environment_factory` is provided, an instance of the environment is created for each generation in the batch, allowing for parallel and independent interactions. The environment must also implement a callable `reset` method that can be used to reset state between generations. The `reset` method should return either `None` or a string: when it returns a string, that string is appended to the last user message before generation. This feature is experimental and may change or be removed at any time without prior notice.

#### train[[trl.GRPOTrainer.train]]

[Source](https://github.com/huggingface/trl/blob/v1.0.0rc1/transformers/trainer.py#L1322)

`train(resume_from_checkpoint: str | bool | None = None, trial: optuna.Trial | dict[str, Any] | None = None, ignore_keys_for_eval: list[str] | None = None)`

Main training entry point.

**Parameters:**

resume_from_checkpoint (`str` or `bool`, *optional*) : If a `str`, local path to a saved checkpoint as saved by a previous instance of `Trainer`. If a `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance of `Trainer`. If present, training will resume from the model/optimizer/scheduler states loaded here.

trial (`optuna.Trial` or `dict[str, Any]`, *optional*) : The trial run or the hyperparameter dictionary for hyperparameter search.

ignore_keys_for_eval (`list[str]`, *optional*) : A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.

**Returns:**

`~trainer_utils.TrainOutput`

Object containing the global step count, training loss, and metrics.

#### save_model[[trl.GRPOTrainer.save_model]]

[Source](https://github.com/huggingface/trl/blob/v1.0.0rc1/transformers/trainer.py#L3739)

Saves the model so that it can be reloaded with `from_pretrained()`.

Only the main process saves the model.

#### push_to_hub[[trl.GRPOTrainer.push_to_hub]]

[Source](https://github.com/huggingface/trl/blob/v1.0.0rc1/transformers/trainer.py#L3986)

Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`.

**Parameters:**

commit_message (`str`, *optional*, defaults to `"End of training"`) : Message to commit while pushing.

blocking (`bool`, *optional*, defaults to `True`) : Whether the function should return only when the `git push` has finished.

token (`str`, *optional*, defaults to `None`) : Token with write permission. If provided, it overrides the token set in the Trainer's original args.

revision (`str`, *optional*) : The git revision to commit from. Defaults to the head of the "main" branch.

kwargs (`dict[str, Any]`, *optional*) : Additional keyword arguments passed along to `~Trainer.create_model_card`.

**Returns:**

The URL of the repository where the model was pushed if `blocking=True`, or a `Future` object tracking the
progress of the commit if `blocking=False`.

## GRPOConfig[[trl.GRPOConfig]]

#### trl.GRPOConfig[[trl.GRPOConfig]]

[Source](https://github.com/huggingface/trl/blob/v1.0.0rc1/trl/trainer/grpo_config.py#L22)

Configuration class for the [GRPOTrainer](/docs/trl/v1.0.0rc1/en/grpo_trainer#trl.GRPOTrainer).

This class includes only the parameters that are specific to GRPO training. For a full list of training arguments,
please refer to the [TrainingArguments](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/trainer#transformers.TrainingArguments) documentation. Note that default values in this class may
differ from those in [TrainingArguments](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/trainer#transformers.TrainingArguments).

Using [HfArgumentParser](https://huggingface.co/docs/transformers/v5.3.0/en/internal/trainer_utils#transformers.HfArgumentParser) we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.

> [!NOTE]
> These parameters have default values different from [TrainingArguments](https://huggingface.co/docs/transformers/v5.3.0/en/main_classes/trainer#transformers.TrainingArguments):
> - `logging_steps`: Defaults to `10` instead of `500`.
> - `gradient_checkpointing`: Defaults to `True` instead of `False`.
> - `bf16`: Defaults to `True` if `fp16` is not set, instead of `False`.
> - `learning_rate`: Defaults to `1e-6` instead of `5e-5`.

