Merged
32 changes: 32 additions & 0 deletions docs/source/grpo_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,38 @@ training_args = GRPOConfig(

For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).


#### Dealing with the Training-Inference Mismatch
While vLLM greatly accelerates inference, it also decouples the inference engine from the training engine. In theory the two engines are mathematically identical; in practice, however, they can produce different outputs due to precision effects and hardware-specific optimizations. This divergence reflects the distinct optimization goals of the two systems. Inference engines aim to maximize sampling throughput, typically measured in tokens per second, while maintaining acceptable sampling fidelity. Training frameworks instead focus on numerical stability and precision for gradient computation, often using higher-precision formats like FP32 for master weights and optimizer states. These differing priorities and constraints introduce an inevitable, albeit subtle, mismatch between training and inference.

This mismatch leads to a biased gradient update which has been observed to destabilize training ([[1]](https://fengyao.notion.site/off-policy-rl)[[2]](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda)[[3]](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/#true-on-policy-rl)[[4]](https://arxiv.org/abs/2510.26788)[[5]](https://arxiv.org/abs/2510.18855)). For simplicity, consider the REINFORCE policy gradient:

$$
\nabla_\theta \mathcal{J}(x,\theta)
= \mathbb{E}_{y \sim \pi^\text{train}(\cdot \mid x,\theta)}
\left[ \nabla_\theta \log \pi^\text{train}(y \mid x,\theta) \cdot R(x,y) \right]
$$

Here \\( x \\) denotes prompts sampled from some data distribution, and \\( \pi^\text{train} \\) is the policy implemented by the training engine. With vLLM in the loop we obtain a separate inference policy \\( \pi^\text{inference} \\), so the effective policy gradient becomes

$$
\nabla_\theta \mathcal{J}_{\text{biased}}(x,\theta)
= \mathbb{E}_{y \sim \pi^\text{inference}(\cdot \mid x,\theta)}
\left[ \nabla_\theta \log \pi^\text{train}(y \mid x,\theta) \cdot R(x,y) \right].
$$

This turns an otherwise on-policy RL problem into an off-policy one.
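
As a toy illustration (not TRL code), the sketch below evaluates both expectations exactly for a three-action softmax policy. The logits, rewards, and the small logit perturbation standing in for the inference engine are made-up values, and `expected_gradient` is a hypothetical helper. Under these assumptions, the gradient computed on samples from \\( \pi^\text{inference} \\) differs from the on-policy gradient, and reweighting each sample by \\( \pi^\text{train} / \pi^\text{inference} \\) recovers it.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def grad_log_prob(probs, action):
    # d/d logit_k of log softmax(logits)[action] = 1[k == action] - probs[k]
    return [(1.0 if k == action else 0.0) - p for k, p in enumerate(probs)]


def expected_gradient(sampling_probs, train_probs, rewards, use_is=False):
    """Exact expectation of R(a) * grad log pi_train(a) over a ~ sampling_probs."""
    grad = [0.0] * len(train_probs)
    for a, q in enumerate(sampling_probs):
        # Optional importance weight pi_train(a) / pi_inference(a)
        weight = train_probs[a] / q if use_is else 1.0
        g = grad_log_prob(train_probs, a)
        for k in range(len(grad)):
            grad[k] += q * weight * rewards[a] * g[k]
    return grad


# Illustrative values only (assumptions, not measurements)
logits = [0.2, -0.1, 0.4]
rewards = [1.0, 0.0, 2.0]
pi_train = softmax(logits)
# The "inference engine" is modeled as the same policy with slightly perturbed logits
pi_inference = softmax([z + d for z, d in zip(logits, [0.05, -0.03, 0.0])])

on_policy = expected_gradient(pi_train, pi_train, rewards)
biased = expected_gradient(pi_inference, pi_train, rewards)
corrected = expected_gradient(pi_inference, pi_train, rewards, use_is=True)

print("on-policy:", on_policy)
print("biased:   ", biased)
print("corrected:", corrected)
```

In practice the expectation is estimated from a finite number of samples, so the reweighted estimate is only unbiased on average and its variance depends on how large the ratios can become.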

The standard way to correct for this distribution shift is **importance sampling (IS)**. We provide two IS variants: [Truncated Importance Sampling (TIS)](paper_index#truncated-importance-sampling) and [Masked Importance Sampling (MIS)](paper_index#masked-importance-sampling). Both variants can be applied either at the token level or at the sequence level. Let \\( \rho \\) denote the importance weight, for example \\( \rho_t \\) per token or \\( \rho_{\text{seq}} \\) per sequence, and let \\( C \\) denote the cap set by `vllm_importance_sampling_cap`. Under TIS, ratios larger than \\( C \\) are clipped,

$$
\rho \leftarrow \min(\rho, C).
$$

Under MIS, ratios larger than `vllm_importance_sampling_cap` are set to zero, so those samples do not contribute to the gradient. In other words, large ratio samples are downweighted under TIS and discarded under MIS. The configuration flag `vllm_importance_sampling_mode` chooses both the IS variant (masking or truncation) and the granularity (token level or sequence level).
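
The difference between the two constraints can be sketched in a few lines of plain Python. The helper `constrain_ratios` and the example ratios are hypothetical and only illustrate the rule described above; they are not part of the TRL API.

```python
def constrain_ratios(ratios, cap, mode):
    """Apply an upper-side constraint to a list of importance ratios.

    Hypothetical helper for illustration; `mode` is "truncate" or "mask".
    """
    if mode == "truncate":
        # TIS: clip from above, rho <- min(rho, C)
        return [min(r, cap) for r in ratios]
    if mode == "mask":
        # MIS: ratios above C are set to zero and contribute nothing
        return [0.0 if r > cap else r for r in ratios]
    raise ValueError(f"Unknown mode: {mode}")


ratios = [0.5, 1.0, 2.5, 4.0]  # made-up importance ratios
cap = 3.0

truncated = constrain_ratios(ratios, cap, "truncate")
masked = constrain_ratios(ratios, cap, "mask")

print(truncated)  # [0.5, 1.0, 2.5, 3.0]
print(masked)     # [0.5, 1.0, 2.5, 0.0]
```

Only the last ratio exceeds the cap: truncation keeps the sample with a reduced weight, while masking removes it from the update.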

Importance sampling is the principled algorithmic response to the training–inference mismatch. However, there are also more direct approaches that attempt to reduce the mismatch between the two engines themselves. Most of these are engineering solutions. For example, [MiniMax M1 uses an FP32 language model head](https://arxiv.org/abs/2506.13585) in the inference engine. Thinking Machines has explored [deterministic inference kernels](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/), although this comes with a significant efficiency cost. vLLM has shown [bitwise consistent policies](https://blog.vllm.ai/2025/11/10/bitwise-consistent-train-inference.html) by building on the batch invariant deterministic kernels from Thinking Machines, but as of November 2025 there remains a substantial throughput penalty relative to standard vLLM inference.

### GRPO at scale: train a 70B+ Model on multiple nodes

When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:
79 changes: 78 additions & 1 deletion docs/source/paper_index.md
Expand Up @@ -226,7 +226,7 @@ $$
}
$$

where \\( C \\) is a hyper-parameter. In TRL, TIS is implemented for GRPO, and enabled by default when vLLM is used for generation (`use_vllm=True`)
where \\( C \\) is a hyper-parameter. TIS is implemented in GRPO, and is enabled by selecting a `vllm_importance_sampling_mode` variant that includes the term `truncate`, such as `"sequence_truncate"` or `"token_truncate"`.

```python
from trl import GRPOConfig
Expand All @@ -235,10 +235,87 @@ training_args = GRPOConfig(
...
use_vllm=True,
vllm_importance_sampling_correction=True, # default True
vllm_importance_sampling_mode="sequence_truncate", # or "token_truncate"
vllm_importance_sampling_cap=2.0, # hyper-parameter C
)
```

### Masked Importance Sampling

**📰 Blog**: https://ringtech.notion.site/icepop

**📰 Blog**: https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda

Masked Importance Sampling (MIS) addresses the same issue as [Truncated Importance Sampling](#truncated-importance-sampling) but replaces clipping with masking. MIS takes a more decisive stance by discarding updates whose discrepancy exceeds a threshold \\( C \\). We apply upper-side masking, so any ratio above \\( C \\) is removed from the update.


$$
\small{
\mathbb{E}_{a\sim\textcolor{red}{\pi_{\text{inference}}}(\theta_{\mathrm{old}})}
\Bigl[
\underbrace{\mathbf{1}\left[
\frac{\pi_{\text{training}}(a, \theta_{\mathrm{old}})}
{\pi_{\text{inference}}(a, \theta_{\mathrm{old}})}
\le C
\right]
\cdot
\frac{\pi_{\text{training}}(a, \theta_{\mathrm{old}})}
{\pi_{\text{inference}}(a, \theta_{\mathrm{old}})}}_{\text{masked importance ratio}} \cdot
\nabla_\theta
\min\Bigl(
\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})}\,\hat A,
\;\mathrm{clip}\bigl(\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})},\,1-\epsilon,\,1+\epsilon\bigr)\,\hat A
\Bigr)
\Bigr]
}
$$

MIS is implemented for GRPO, and is enabled by selecting a `vllm_importance_sampling_mode` variant that includes the term `"mask"`, such as `"sequence_mask"` or `"token_mask"`.

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    ...
    use_vllm=True,
    vllm_importance_sampling_correction=True, # default True
    vllm_importance_sampling_mode="sequence_mask", # or "token_mask"
    vllm_importance_sampling_cap=2.0, # hyper-parameter C
)
```

### Sequence-level Importance Sampling

**📰 Blog**: https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda

The theoretically principled way to correct for the training-inference distribution shift is importance sampling, as introduced in the two entries above, [Truncated Importance Sampling](#truncated-importance-sampling) and [Masked Importance Sampling](#masked-importance-sampling). However, the choice of formulation is crucial for keeping the gradient unbiased and ensuring stable training.

This work shows that sequence-level importance sampling is the sound approach for addressing the training–inference mismatch. Although token-level importance sampling achieves lower variance than a sequence-level ratio, it introduces bias and is therefore argued to be unsuitable for autoregressive models. The token-level gradient estimator is

$$
\mathbb{E}_{x\sim\mathcal{D},\, y\sim \pi^{\text{inference}}_\theta(\cdot|x)}
\Bigg[
R(x,y)\,\cdot\,
\sum_{t=0}^{|y|-1}
\frac{\pi^{\text{training}}_\theta(y_t\,|\,x, y_{<t})}
{\pi^{\text{inference}}_\theta(y_t\,|\,x, y_{<t})}
\,\nabla_\theta \log \pi^{\text{training}}_\theta(y_t\,|\,x, y_{<t})
\Bigg]
$$

The correct, unbiased policy gradient estimator applies a single importance ratio over the entire generated sequence (trajectory) \\( y \\). The sequence-level IS estimator is:

$$
\mathbb{E}_{x\sim\mathcal{D},\, y\sim \pi^{\text{inference}}_\theta(\cdot|x)}
\Bigg[
\frac{\pi^{\text{training}}_\theta(y|x)}
{\pi^{\text{inference}}_\theta(y|x)}
\, R(x,y)\,
\nabla_\theta \log \pi^{\text{training}}_\theta(y|x)
\Bigg]
$$

TRL exposes the Importance Sampling granularity level through the `vllm_importance_sampling_mode` configuration parameter where `"sequence_*"` modes implement a sequence-level importance sampling ratio and `"token_*"` a per-token ratio.
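
The two granularities can be sketched in plain Python as follows. This mirrors the idea of exponentiating either the per-token or the summed log-probability difference, but `importance_ratios`, the log-probabilities, and the mask are hypothetical values chosen for illustration, not the TRL implementation.

```python
import math


def importance_ratios(train_logps, inference_logps, mask, level):
    """Return importance ratios for one completion.

    Hypothetical helper: `level` is "token" (one ratio per token) or
    "sequence" (a single ratio shared by all tokens). Padded positions
    (mask == 0) contribute a log-difference of zero.
    """
    diffs = [(t - s) * m for t, s, m in zip(train_logps, inference_logps, mask)]
    if level == "token":
        return [math.exp(d) for d in diffs]
    if level == "sequence":
        # Summing log-differences equals multiplying the per-token ratios
        return [math.exp(sum(diffs))]
    raise ValueError(f"Unknown level: {level}")


# Made-up per-token log-probabilities; the last position is padding
train_logps = [-1.0, -2.0, -0.5, -3.0]
inference_logps = [-1.1, -1.9, -0.7, -9.0]
mask = [1, 1, 1, 0]

token_ratios = importance_ratios(train_logps, inference_logps, mask, "token")
sequence_ratio = importance_ratios(train_logps, inference_logps, mask, "sequence")

print(token_ratios)
print(sequence_ratio)
```

Because the sequence-level ratio is the product of the per-token ratios, small per-token discrepancies compound over long completions, which is why a cap \\( C \\) matters more at the sequence level.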

### Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning

**📜 Paper**: https://huggingface.co/papers/2508.09726
56 changes: 39 additions & 17 deletions trl/trainer/grpo_config.py
Expand Up @@ -243,16 +243,24 @@ class GRPOConfig(TrainingArguments):
instead.

</Deprecated>

vllm_importance_sampling_correction (`bool`, *optional*, defaults to `True`):
Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and recomputed
logprobs. [Your Efficient RL Framework Secretly Brings You Off-Policy RL
Training](https://fengyao.notion.site/off-policy-rl) highlights that using a separate generation framework
(such as vLLM) can introduce off-policy effects due to subtle implementation differences between generation
and training backends. TIS is proposed as a remedy for this issue.
vllm_importance_sampling_cap (`float`, *optional*, defaults to `2.0`):
Truncation parameter C for Truncated Importance Sampling (TIS). This sets an upper bound on the importance
sampling ratio, improving training stability.
Whether to apply Importance Sampling (IS) to correct for the mismatch between vLLM
completion logprobs and recomputed training logprobs. If set to `False`, no IS is applied
regardless of `vllm_importance_sampling_mode`. When `True`, the selected mode determines
how the IS ratios are computed and constrained.
vllm_importance_sampling_mode (`str`, *optional*, defaults to `"sequence_mask"`):
Specifies how Importance Sampling is performed when `vllm_importance_sampling_correction=True`. Possible
values are:

- `"token_truncate"`: Token-level truncated IS. Per-token ratios are clipped from above at C.
- `"token_mask"`: Token-level masked IS. Per-token ratios above C are set to zero.
- `"sequence_truncate"`: Sequence-level truncated IS. A single sequence ratio is clipped from above at
C and applied to all tokens in the sequence.
- `"sequence_mask"`: Sequence-level masked IS (default). Sequences with ratios above C are masked out.
vllm_importance_sampling_cap (`float`, *optional*, defaults to `3.0`):
Importance sampling cap C used by `vllm_importance_sampling_mode`. For `*_truncate` modes,
importance ratios are clipped from above at C. For `*_mask` modes, ratios larger than C
are set to zero.

> Parameters that control the logging

Expand Down Expand Up @@ -676,18 +684,32 @@ class GRPOConfig(TrainingArguments):
vllm_importance_sampling_correction: bool = field(
default=True,
metadata={
"help": "Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and "
"recomputed logprobs. Your Efficient RL Framework Secretly Brings You Off-Policy RL "
"Training highlights that using a separate generation framework (such as vLLM) can introduce off-policy "
"effects due to subtle implementation differences between generation and training backends. TIS is "
"proposed as a remedy for this issue."
"help": "Whether to apply Importance Sampling (IS) to correct for the mismatch between vLLM "
"completion logprobs and recomputed training logprobs. If set to `False`, no IS is applied "
"regardless of `vllm_importance_sampling_mode`. When `True`, the selected mode determines how "
"IS ratios are computed and constrained."
},
)

vllm_importance_sampling_mode: str = field(
default="sequence_mask",
metadata={
"help": "Specifies how Importance Sampling (IS) is performed when "
"vllm_importance_sampling_correction=True. Modes are defined along two orthogonal "
"dimensions: (1) constraint, which determines how to handle ratios above "
"vllm_importance_sampling_cap (C)—either truncation (clip from above, ρ ← min(ρ, C)) or "
"masking (set ratios above C to zero); and (2) granularity, which determines whether "
"ratios are computed per token or as a single sequence-level ratio applied to all tokens. "
"Supported options are: 'token_truncate', 'token_mask', 'sequence_truncate', and "
"'sequence_mask'."
},
)

vllm_importance_sampling_cap: float = field(
default=2.0,
default=3.0,
metadata={
"help": "Truncation parameter C for Truncated Importance Sampling (TIS). This sets an upper bound on the "
"importance sampling ratio, improving training stability."
"help": "Importance sampling cap C used by `vllm_importance_sampling_mode`. For '*_truncate' modes, "
"ratios are clipped from above at C. For '*_mask' modes, ratios larger than C are set to zero."
},
)

40 changes: 34 additions & 6 deletions trl/trainer/grpo_trainer.py
Expand Up @@ -392,6 +392,7 @@ def __init__(
self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode
self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode
self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction
self.vllm_importance_sampling_mode = args.vllm_importance_sampling_mode
self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap
self.use_liger_kernel = args.use_liger_kernel
self.loss_type = args.loss_type
Expand Down Expand Up @@ -1576,10 +1577,33 @@ def _generate_and_score_completions(

# Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch
if self.use_vllm and self.vllm_importance_sampling_correction:
importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps)
importance_sampling_ratio = torch.clamp(
importance_sampling_ratio, max=self.vllm_importance_sampling_cap
)
per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * completion_mask

sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"]
if sequence_level_is:
per_sequence_logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True)
logps_diff = per_sequence_logps_diff
else:
logps_diff = per_token_logps_diff

vllm_importance_sampling_ratio = torch.exp(logps_diff)

# vllm_importance_sampling_ratio.shape:
# token_* modes: (B, T) (per-token ratio)
# sequence_* modes: (B, 1) (per-sequence ratio)

if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]:
vllm_importance_sampling_ratio = torch.clamp(
vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap
)
elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]:
vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill(
vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0
)
else:
raise ValueError(
f"Unknown vLLM importance sampling mode: {self.vllm_importance_sampling_mode}. Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'."
)

# Compute the per-token log probabilities for the reference model
if self.beta != 0.0:
Expand Down Expand Up @@ -1702,7 +1726,11 @@ def _generate_and_score_completions(
self.accelerator.gather(max_delta).max().item()
)

flat_is_ratio = importance_sampling_ratio[completion_mask.bool()]
if sequence_level_is:
flat_is_ratio = vllm_importance_sampling_ratio.flatten()
else:
flat_is_ratio = vllm_importance_sampling_ratio[completion_mask.bool()]

min_importance_sampling_ratio = (
torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
)
Expand Down Expand Up @@ -1733,7 +1761,7 @@ def _generate_and_score_completions(
if old_per_token_logps is not None:
output["old_per_token_logps"] = old_per_token_logps
if self.use_vllm and self.vllm_importance_sampling_correction:
output["importance_sampling_ratio"] = importance_sampling_ratio
output["importance_sampling_ratio"] = vllm_importance_sampling_ratio
if ref_per_token_logps is not None:
output["ref_per_token_logps"] = ref_per_token_logps
if "pixel_values" in forward_kwargs: