diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 92a40b009d2..70f7166a842 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -239,6 +239,38 @@ training_args = GRPOConfig( For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods). + +#### Dealing with the Training-Inference Mismatch +While vLLM greatly accelerates inference, it also decouples the inference engine from the training engine. In theory these engines are mathematically identical, in practice however they can produce different outputs due to precision effects and hardware specific optimizations. This divergence reflects the different optimization objectives of the two systems. This divergence reflects the distinct optimization goals of the two systems. Inference engines aim to maximize sampling throughput, typically measured in tokens per second, while maintaining acceptable sampling fidelity. Training frameworks instead focus on numerical stability and precision for gradient computation, often using higher precision formats like FP32 for master weights and optimizer states. These differing priorities and constraints introduce an inevitable, albeit subtle, mismatch between training and inference. + +This mismatch leads to a biased gradient update which has been observed to destabilize training ([[1]](https://fengyao.notion.site/off-policy-rl)[[2]](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda)[[3]](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/#true-on-policy-rl)[[4]](https://arxiv.org/abs/2510.26788)[[5]](https://arxiv.org/abs/2510.18855)). For simplicity, consider the REINFORCE policy gradient: + +$$ +\nabla_\theta \mathcal{J}(x,\theta) += \mathbb{E}_{y \sim \pi^\text{train}(\cdot \mid x,\theta)} +\left[ \nabla_\theta \log \pi^\text{train}(y \mid x,\theta) \cdot R(x,y) \right] +$$ + +Here \\( x \\) denotes prompts sampled from some data distribution, and \\( \pi^\text{train} \\) is the policy implemented by the training engine. With vLLM in the loop we obtain a separate inference policy \\( \pi^\text{inference} \\), so the effective policy gradient becomes + +$$ +\nabla_\theta \mathcal{J}_{\text{biased}}(x,\theta) += \mathbb{E}_{y \sim \pi^\text{inference}(\cdot \mid x,\theta)} +\left[ \nabla_\theta \log \pi^\text{train}(y \mid x,\theta) \cdot R(x,y) \right]. +$$ + +This turns an otherwise on policy RL problem into an off policy one. + +The standard way to correct for this distribution shift is **importance sampling (IS)**. We provide two IS variants: [Truncated Importance Sampling (TIS)](paper_index#truncated-importance-sampling) and [Masked Importance Sampling (MIS)](paper_index#masked-importance-sampling). Both variants can be applied either at the token level or at the sequence level.Let \\( \rho \\) denote the importance weight, for example \\( \rho_t \\) per token or \\( \rho_{\text{seq}} \\) per sequence. Under TIS, ratios larger than `vllm_importance_sampling_cap` are clipped, + +$$ +\rho \leftarrow \min(\rho, C). +$$ + +Under MIS, ratios larger than `vllm_importance_sampling_cap` are set to zero, so those samples do not contribute to the gradient. In other words, large ratio samples are downweighted under TIS and discarded under MIS. The configuration flag `vllm_importance_sampling_mode` chooses both the IS variant (masking or truncation) and the granularity (token level or sequence level). + +Importance sampling is the principled algorithmic response to the training–inference mismatch. However, there are also more direct approaches that attempt to reduce the mismatch between the two engines themselves. Most of these are engineering solutions. For example, [MiniMax M1 uses an FP32 language model head](https://arxiv.org/abs/2506.13585) in the inference engine. Thinking Machines has explored [deterministic inference kernels](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/), although this comes with a significant efficiency cost. vLLM has shown [bitwise consistent policies](https://blog.vllm.ai/2025/11/10/bitwise-consistent-train-inference.html) by building on the batch invariant deterministic kernels from Thinking Machines, but as of November 2025 there remains a substantial throughput penalty relative to standard vLLM inference. + ### GRPO at scale: train a 70B+ Model on multiple nodes When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include: diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 70955566f7c..b1c2eb1a1fb 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -226,7 +226,7 @@ $$ } $$ -where \\( C \\) is a hyper-parameter. In TRL, TIS is implemented for GRPO, and enabled by default when vLLM is used for generation (`use_vllm=True`) +where \\( C \\) is a hyper-parameter. TIS is implemented in GRPO, and is enabled by selecting a `vllm_importance_sampling_mode` variant that includes the term `truncate`, such as `"sequence_truncate"` or `"token_truncate"`. ```python from trl import GRPOConfig @@ -235,10 +235,87 @@ training_args = GRPOConfig( ... use_vllm=True, vllm_importance_sampling_correction=True, # default True + vllm_importance_sampling_mode="sequence_truncate", # or "token_truncate" vllm_importance_sampling_cap=2.0, # hyper-parameter C ) ``` +### Masked Importance Sampling + +**📰 Blog**: https://ringtech.notion.site/icepop + +**📰 Blog**: https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda + +Masked Importance Sampling (MIS) addresses the same issue as [Truncated Importance Sampling](#truncated-importance-sampling) but replaces clipping with masking. MIS takes a more decisive stance by discarding updates whose discrepancy exceeds a threshold \\( C \\). We apply upper-side masking, so any ratio above \\( C \\) is removed from the update. + + +$$ +\small{ +\mathbb{E}_{a\sim\textcolor{red}{\pi_{\text{inference}}}(\theta_{\mathrm{old}})} +\Bigl[ +\underbrace{\mathbf{1}\left[ +\frac{\pi_{\text{training}}(a, \theta_{\mathrm{old}})} +{\pi_{\text{inference}}(a, \theta_{\mathrm{old}})} +\le C +\right] +\cdot +\frac{\pi_{\text{training}}(a, \theta_{\mathrm{old}})} +{\pi_{\text{inference}}(a, \theta_{\mathrm{old}})}}_{\text{masked importance ratio}} \cdot +\nabla_\theta +\min\Bigl( +\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})}\,\hat A, +\;\mathrm{clip}\bigl(\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})},\,1-\epsilon,\,1+\epsilon\bigr)\,\hat A +\Bigr) +\Bigr] +} +$$ + +MIS is implemented for GRPO, and is enabled by selecting a `vllm_importance_sampling_mode` variant that includes the term `"mask"`, such as `"sequence_mask"` or `"token_mask"`. + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + ... + use_vllm=True, + vllm_importance_sampling_correction=True, # default True + vllm_importance_sampling_mode="sequence_mask", # or "token_mask" + vllm_importance_sampling_cap=2.0, # hyper-parameter C +) +``` + +### Sequence-level Importance Sampling + +**📰 Blog**: https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda + +The theoretically principled way to correct for the training-inference distribution shift is importance sampling, as introduced in the two papers above [Truncated Importance Sampling](#truncated-importance-sampling) and [Masked Importance Sampling](#masked-importance-sampling). However, the choice of formulation is crucial for keeping the gradient unbiased and ensuring stable training. + +This work shows that sequence-level importance sampling is the sound approach for addressing the training–inference mismatch. Although token-level importance sampling achieves lower variance than a sequence-level ratio, it introduces bias and is therefore argued to be unsuitable for autoregressive models. The token-level gradient estimator is + +$$ +\mathbb{E}_{x\sim\mathcal{D},\, y\sim \pi^{\text{inference}}_\theta(\cdot|x)} +\Bigg[ + R(x,y)\,\cdot\, + \sum_{t=0}^{|y|-1} + \frac{\pi^{\text{training}}_\theta(y_t\,|\,x, y_{ - vllm_importance_sampling_correction (`bool`, *optional*, defaults to `True`): - Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and recomputed - logprobs. [Your Efficient RL Framework Secretly Brings You Off-Policy RL - Training](https://fengyao.notion.site/off-policy-rl) highlights that using a separate generation framework - (such as vLLM) can introduce off-policy effects due to subtle implementation differences between generation - and training backends. TIS is proposed as a remedy for this issue. - vllm_importance_sampling_cap (`float`, *optional*, defaults to `2.0`): - Truncation parameter C for Truncated Importance Sampling (TIS). This sets an upper bound on the importance - sampling ratio, improving training stability. + Whether to apply Importance Sampling (IS) to correct for the mismatch between vLLM + completion logprobs and recomputed training logprobs. If set to `False`, no IS is applied + regardless of `vllm_importance_sampling_mode`. When `True`, the selected mode determines + how the IS ratios are computed and constrained. + vllm_importance_sampling_mode (`str`, *optional*, defaults to `"sequence_mask"`): + Specifies how Importance Sampling is performed when `vllm_importance_sampling_correction=True`. Possible + values are: + + - `"token_truncate"`: Token-level truncated IS (default). Per-token ratios are clipped from above at C. + - `"token_mask"`: Token-level masked IS. Per-token ratios above C are set to zero. + - `"sequence_truncate"`: Sequence-level truncated IS. A single sequence ratio is clipped from above at + C and applied to all tokens in the sequence. + - `"sequence_mask"`: Sequence-level masked IS. Sequences with ratios above C are masked out. + vllm_importance_sampling_cap (`float`, *optional*, defaults to `3.0`): + Importance sampling cap C used by `vllm_importance_sampling_mode`. For `*_truncate` modes, + importance ratios are clipped from above at C. For `*_mask` modes, ratios larger than C + are set to zero. > Parameters that control the logging @@ -676,18 +684,32 @@ class GRPOConfig(TrainingArguments): vllm_importance_sampling_correction: bool = field( default=True, metadata={ - "help": "Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and " - "recomputed logprobs. Your Efficient RL Framework Secretly Brings You Off-Policy RL " - "Training highlights that using a separate generation framework (such as vLLM) can introduce off-policy " - "effects due to subtle implementation differences between generation and training backends. TIS is " - "proposed as a remedy for this issue." + "help": "Whether to apply Importance Sampling (IS) to correct for the mismatch between vLLM " + "completion logprobs and recomputed training logprobs. If set to `False`, no IS is applied " + "regardless of `vllm_importance_sampling_mode`. When `True`, the selected mode determines how " + "IS ratios are computed and constrained." + }, + ) + + vllm_importance_sampling_mode: str = field( + default="sequence_mask", + metadata={ + "help": "Specifies how Importance Sampling (IS) is performed when " + "vllm_importance_sampling_correction=True. Modes are defined along two orthogonal " + "dimensions: (1) constraint, which determines how to handle ratios above " + "vllm_importance_sampling_cap (C)—either truncation (clip from above, ρ ← min(ρ, C)) or " + "masking (set ratios above C to zero); and (2) granularity, which determines whether " + "ratios are computed per token or as a single sequence-level ratio applied to all tokens. " + "Supported options are: 'token_truncate', 'token_mask', 'sequence_truncate', and " + "'sequence_mask'." }, ) + vllm_importance_sampling_cap: float = field( - default=2.0, + default=3.0, metadata={ - "help": "Truncation parameter C for Truncated Importance Sampling (TIS). This sets an upper bound on the " - "importance sampling ratio, improving training stability." + "help": "Importance sampling cap C used by `vllm_importance_sampling_mode`. For '*_truncate' modes, " + "ratios are clipped from above at C. For '*_mask' modes, ratios larger than C are set to zero." }, ) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 56d1a895dcb..995bfdc17ab 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -392,6 +392,7 @@ def __init__( self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction + self.vllm_importance_sampling_mode = args.vllm_importance_sampling_mode self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap self.use_liger_kernel = args.use_liger_kernel self.loss_type = args.loss_type @@ -1576,10 +1577,33 @@ def _generate_and_score_completions( # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch if self.use_vllm and self.vllm_importance_sampling_correction: - importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps) - importance_sampling_ratio = torch.clamp( - importance_sampling_ratio, max=self.vllm_importance_sampling_cap - ) + per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * completion_mask + + sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"] + if sequence_level_is: + per_sequence_logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True) + logps_diff = per_sequence_logps_diff + else: + logps_diff = per_token_logps_diff + + vllm_importance_sampling_ratio = torch.exp(logps_diff) + + # vllm_importance_sampling_ratio.shape: + # token_* modes: (B, T) (per-token ratio) + # sequence_* modes: (B, 1) (per-sequence ratio) + + if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]: + vllm_importance_sampling_ratio = torch.clamp( + vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap + ) + elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]: + vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill( + vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0 + ) + else: + raise ValueError( + f"Unknown vLLM importance sampling level: {self.vllm_importance_sampling_mode}. Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'." + ) # Compute the per-token log probabilities for the reference model if self.beta != 0.0: @@ -1702,7 +1726,11 @@ def _generate_and_score_completions( self.accelerator.gather(max_delta).max().item() ) - flat_is_ratio = importance_sampling_ratio[completion_mask.bool()] + if sequence_level_is: + flat_is_ratio = vllm_importance_sampling_ratio.flatten() + else: + flat_is_ratio = vllm_importance_sampling_ratio[completion_mask.bool()] + min_importance_sampling_ratio = ( torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) ) @@ -1733,7 +1761,7 @@ def _generate_and_score_completions( if old_per_token_logps is not None: output["old_per_token_logps"] = old_per_token_logps if self.use_vllm and self.vllm_importance_sampling_correction: - output["importance_sampling_ratio"] = importance_sampling_ratio + output["importance_sampling_ratio"] = vllm_importance_sampling_ratio if ref_per_token_logps is not None: output["ref_per_token_logps"] = ref_per_token_logps if "pixel_values" in forward_kwargs: