From 0b4c42d724b58f4d822f517a0680f24612224fb4 Mon Sep 17 00:00:00 2001 From: Leon Ericsson Date: Sun, 16 Nov 2025 13:42:22 +0100 Subject: [PATCH 01/10] wip mis tis seq --- trl/trainer/grpo_config.py | 41 +++++++++++++++++++++++++++++++++++++ trl/trainer/grpo_trainer.py | 26 +++++++++++++++++++---- 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 2d97d67bd8e..918f33a9b8b 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -246,6 +246,35 @@ class GRPOConfig(TrainingArguments): Training](https://fengyao.notion.site/off-policy-rl) highlights that using a separate generation framework (such as vLLM) can introduce off-policy effects due to subtle implementation differences between generation and training backends. TIS is proposed as a remedy for this issue. + + vllm_importance_sampling_mode (`str`, *optional*, defaults to `"token_truncate"`): + Controls how importance sampling (IS) is applied to correct for the mismatch between vLLM + completion logprobs and recomputed training logprobs when using a separate generation backend. + IS reweights the policy gradient to account for off-policy effects. The mode is defined by two orthogonal choices: + + * Constraint: how to handle importance ratios above `vllm_importance_sampling_cap` (C): + - truncation: clip ratios from above, ρ ← min(ρ, C), as in + [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl) + - masking: set ratios above C to zero, so those contributions do not affect the gradient, as in + [When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda) + + * Granularity: the level at which ratios are computed: + - token: per-token ratios ρ_t + - sequence: a single ratio ρ_seq per sequence, applied to all tokens + + Supported options are: + - `"none"`: + Disable importance sampling. + - `"token_truncate"`: + Token-level truncated IS (default). Per-token ratios are clipped from above at C. + - `"token_mask"`: + Token-level masked IS. Per-token ratios above C are set to zero. + - `"sequence_truncate"`: + Sequence-level truncated IS. A single sequence ratio is clipped from above at C and + applied to all tokens in the sequence. + - `"sequence_mask"`: + Sequence-level masked IS. Sequences with ratios above C are masked out. + vllm_importance_sampling_cap (`float`, *optional*, defaults to `2.0`): Truncation parameter C for Truncated Importance Sampling (TIS). This sets an upper bound on the importance sampling ratio, improving training stability. @@ -672,6 +701,18 @@ class GRPOConfig(TrainingArguments): "proposed as a remedy for this issue." }, ) + + vllm_importance_sampling_correction: str = field( + default="sequence_mask", + metadata={ + "help": "Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and " + "recomputed logprobs. Your Efficient RL Framework Secretly Brings You Off-Policy RL " + "Training highlights that using a separate generation framework (such as vLLM) can introduce off-policy " + "effects due to subtle implementation differences between generation and training backends. TIS is " + "proposed as a remedy for this issue." + }, + ) + vllm_importance_sampling_cap: float = field( default=2.0, metadata={ diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ff7674214a3..b79913671ee 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1577,10 +1577,28 @@ def _generate_and_score_completions( # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch if self.use_vllm and self.vllm_importance_sampling_correction: - importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps) - importance_sampling_ratio = torch.clamp( - importance_sampling_ratio, max=self.vllm_importance_sampling_cap - ) + token_diff = old_per_token_logps - sampling_per_token_logps + + if self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"]: + # sequence level importance sampling + importance_sampling_ratio = token_diff.sum(dim=-1) + + ### from here the importance_sampling_ratio shape depends on the method, for sequence it is ... and for token it is ... + + importance_sampling_ratio = torch.exp(importance_sampling_ratio) + + if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]: + importance_sampling_ratio = torch.clamp( + importance_sampling_ratio, max=self.vllm_importance_sampling_cap + ) + elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]: + importance_sampling_ratio = importance_sampling_ratio.masked_fill( + importance_sampling_ratio > self.vllm_importance_sampling_cap, 0.0 + ) + else: + raise ValueError( + f"Unknown vllm importance sampling level: {self.vllm_importance_sampling_mode}. Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'." + ) # Compute the per-token log probabilities for the reference model if self.beta != 0.0: From 2e06bf65bbf67fd5b663dafa012b4a03fe3da57e Mon Sep 17 00:00:00 2001 From: Leon Ericsson Date: Sun, 16 Nov 2025 13:53:15 +0100 Subject: [PATCH 02/10] doc strings --- trl/trainer/grpo_config.py | 57 +++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 918f33a9b8b..34b25f2e87d 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -241,17 +241,15 @@ class GRPOConfig(TrainingArguments): vllm_importance_sampling_correction (`bool`, *optional*, defaults to `True`): - Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and recomputed - logprobs. [Your Efficient RL Framework Secretly Brings You Off-Policy RL - Training](https://fengyao.notion.site/off-policy-rl) highlights that using a separate generation framework - (such as vLLM) can introduce off-policy effects due to subtle implementation differences between generation - and training backends. TIS is proposed as a remedy for this issue. + Whether to apply Importance Sampling (IS) to correct for the mismatch between vLLM + completion logprobs and recomputed training logprobs. If set to `False`, no IS is applied + regardless of `vllm_importance_sampling_mode`. When `True`, the selected mode determines + how the IS ratios are computed and constrained. - vllm_importance_sampling_mode (`str`, *optional*, defaults to `"token_truncate"`): - Controls how importance sampling (IS) is applied to correct for the mismatch between vLLM - completion logprobs and recomputed training logprobs when using a separate generation backend. - IS reweights the policy gradient to account for off-policy effects. The mode is defined by two orthogonal choices: + vllm_importance_sampling_mode (`str`, *optional*, defaults to `"sequence_mask"`): + Specifies how Importance Sampling is performed when `vllm_importance_sampling_correction=True`. + The mode is defined along two orthogonal dimensions: * Constraint: how to handle importance ratios above `vllm_importance_sampling_cap` (C): - truncation: clip ratios from above, ρ ← min(ρ, C), as in [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl) @@ -263,8 +261,6 @@ class GRPOConfig(TrainingArguments): - sequence: a single ratio ρ_seq per sequence, applied to all tokens Supported options are: - - `"none"`: - Disable importance sampling. - `"token_truncate"`: Token-level truncated IS (default). Per-token ratios are clipped from above at C. - `"token_mask"`: @@ -276,8 +272,9 @@ class GRPOConfig(TrainingArguments): Sequence-level masked IS. Sequences with ratios above C are masked out. vllm_importance_sampling_cap (`float`, *optional*, defaults to `2.0`): - Truncation parameter C for Truncated Importance Sampling (TIS). This sets an upper bound on the importance - sampling ratio, improving training stability. + Importance sampling cap C used by `vllm_importance_sampling_mode`. For `*_truncate` modes, + importance ratios are clipped from above at C. For `*_mask` modes, ratios larger than C + are set to zero. > Parameters that control the logging @@ -694,30 +691,38 @@ class GRPOConfig(TrainingArguments): vllm_importance_sampling_correction: bool = field( default=True, metadata={ - "help": "Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and " - "recomputed logprobs. Your Efficient RL Framework Secretly Brings You Off-Policy RL " - "Training highlights that using a separate generation framework (such as vLLM) can introduce off-policy " - "effects due to subtle implementation differences between generation and training backends. TIS is " - "proposed as a remedy for this issue." + "help": ( + "Whether to apply Importance Sampling (IS) to correct for the mismatch between vLLM " + "completion logprobs and recomputed training logprobs. If set to `False`, no IS is applied " + "regardless of `vllm_importance_sampling_mode`. When `True`, the selected mode determines how " + "IS ratios are computed and constrained." + ) }, ) - vllm_importance_sampling_correction: str = field( + vllm_importance_sampling_mode: str = field( default="sequence_mask", metadata={ - "help": "Whether to apply Truncated Importance Sampling (TIS) between vLLM completion logprobs and " - "recomputed logprobs. Your Efficient RL Framework Secretly Brings You Off-Policy RL " - "Training highlights that using a separate generation framework (such as vLLM) can introduce off-policy " - "effects due to subtle implementation differences between generation and training backends. TIS is " - "proposed as a remedy for this issue." + "help": ( + "Specifies how Importance Sampling (IS) is performed when " + "vllm_importance_sampling_correction=True. Modes are defined along two orthogonal " + "dimensions: (1) constraint, which determines how to handle ratios above " + "vllm_importance_sampling_cap (C)—either truncation (clip from above, ρ ← min(ρ, C)) or " + "masking (set ratios above C to zero); and (2) granularity, which determines whether " + "ratios are computed per token or as a single sequence-level ratio applied to all tokens. " + "Supported options are: 'token_truncate', 'token_mask', 'sequence_truncate', and " + "'sequence_mask'." + ) }, ) vllm_importance_sampling_cap: float = field( default=2.0, metadata={ - "help": "Truncation parameter C for Truncated Importance Sampling (TIS). This sets an upper bound on the " - "importance sampling ratio, improving training stability." + "help": ( + "Importance sampling cap C used by `vllm_importance_sampling_mode`. For '*_truncate' modes, " + "ratios are clipped from above at C. For '*_mask' modes, ratios larger than C are set to zero." + ) }, ) From bcf5c5874af21bf449b96cdd5301bd9d7a6e8cf1 Mon Sep 17 00:00:00 2001 From: Leon Date: Sun, 16 Nov 2025 15:19:19 +0100 Subject: [PATCH 03/10] verified --- trl/trainer/grpo_trainer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index b79913671ee..599b539b324 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -397,6 +397,7 @@ def __init__( self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode self.vllm_importance_sampling_correction = args.vllm_importance_sampling_correction + self.vllm_importance_sampling_mode = args.vllm_importance_sampling_mode self.vllm_importance_sampling_cap = args.vllm_importance_sampling_cap self.use_liger_kernel = args.use_liger_kernel self.loss_type = args.loss_type @@ -1577,11 +1578,11 @@ def _generate_and_score_completions( # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch if self.use_vllm and self.vllm_importance_sampling_correction: - token_diff = old_per_token_logps - sampling_per_token_logps + token_logps_diff = old_per_token_logps - sampling_per_token_logps + token_logps_diff *= completion_mask if self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"]: - # sequence level importance sampling - importance_sampling_ratio = token_diff.sum(dim=-1) + importance_sampling_ratio = token_logps_diff.sum(dim=-1, keepdim=True) ### from here the importance_sampling_ratio shape depends on the method, for sequence it is ... and for token it is ... @@ -1720,7 +1721,7 @@ def _generate_and_score_completions( self.accelerator.gather(max_delta).max().item() ) - flat_is_ratio = importance_sampling_ratio[completion_mask.bool()] + flat_is_ratio = token_logps_diff[completion_mask.bool()] min_importance_sampling_ratio = ( torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) ) From 80d85f20ef1b7398affa1ea3c5c3eaa3b84c1461 Mon Sep 17 00:00:00 2001 From: Leon Date: Sun, 16 Nov 2025 15:32:27 +0100 Subject: [PATCH 04/10] cleanup --- trl/trainer/grpo_trainer.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 599b539b324..3b649ba5d6b 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1582,19 +1582,20 @@ def _generate_and_score_completions( token_logps_diff *= completion_mask if self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"]: - importance_sampling_ratio = token_logps_diff.sum(dim=-1, keepdim=True) + sequence_logps_diff = token_logps_diff.sum(dim=-1, keepdim=True) - ### from here the importance_sampling_ratio shape depends on the method, for sequence it is ... and for token it is ... + vllm_importance_sampling_ratio = torch.exp(sequence_logps_diff) - importance_sampling_ratio = torch.exp(importance_sampling_ratio) + # From here, vllm_importance_sampling_ratio's shape depends on + # vllm_importance_sampling_mode: "token_*" level: (B, T); "sequence_*" level: (B, 1) if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]: - importance_sampling_ratio = torch.clamp( - importance_sampling_ratio, max=self.vllm_importance_sampling_cap + vllm_importance_sampling_ratio = torch.clamp( + vllm_importance_sampling_ratio, max=self.vllm_importance_sampling_cap ) elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]: - importance_sampling_ratio = importance_sampling_ratio.masked_fill( - importance_sampling_ratio > self.vllm_importance_sampling_cap, 0.0 + vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill( + vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, 0.0 ) else: raise ValueError( @@ -1752,7 +1753,7 @@ def _generate_and_score_completions( if old_per_token_logps is not None: output["old_per_token_logps"] = old_per_token_logps if self.use_vllm and self.vllm_importance_sampling_correction: - output["importance_sampling_ratio"] = importance_sampling_ratio + output["vllm_importance_sampling_ratio"] = vllm_importance_sampling_ratio if ref_per_token_logps is not None: output["ref_per_token_logps"] = ref_per_token_logps if "pixel_values" in forward_kwargs: @@ -1906,7 +1907,7 @@ def _compute_loss(self, model, inputs): per_token_loss = per_token_loss * entropy_mask if self.use_vllm and self.vllm_importance_sampling_correction: - per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] + per_token_loss = per_token_loss * inputs["vllm_importance_sampling_ratio"] if self.beta != 0.0: per_token_loss = per_token_loss + self.beta * per_token_kl From 19672cb3881be2c560f44fde336f8932a7ba98ad Mon Sep 17 00:00:00 2001 From: Leon Date: Sat, 22 Nov 2025 14:49:51 +0100 Subject: [PATCH 05/10] fix IS ratio logs + review nits --- trl/trainer/grpo_trainer.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 3b649ba5d6b..debaf1758b8 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1578,16 +1578,20 @@ def _generate_and_score_completions( # Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch if self.use_vllm and self.vllm_importance_sampling_correction: - token_logps_diff = old_per_token_logps - sampling_per_token_logps - token_logps_diff *= completion_mask + per_token_logps_diff = (old_per_token_logps - sampling_per_token_logps) * completion_mask - if self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"]: - sequence_logps_diff = token_logps_diff.sum(dim=-1, keepdim=True) + sequence_level_is = self.vllm_importance_sampling_mode in ["sequence_mask", "sequence_truncate"] + if sequence_level_is: + per_sequence_logps_diff = per_token_logps_diff.sum(dim=-1, keepdim=True) + logps_diff = per_sequence_logps_diff + else: + logps_diff = per_token_logps_diff - vllm_importance_sampling_ratio = torch.exp(sequence_logps_diff) + vllm_importance_sampling_ratio = torch.exp(logps_diff) - # From here, vllm_importance_sampling_ratio's shape depends on - # vllm_importance_sampling_mode: "token_*" level: (B, T); "sequence_*" level: (B, 1) + # vllm_importance_sampling_ratio.shape: + # token_* modes: (B, T) (per-token ratio) + # sequence_* modes: (B, 1) (per-sequence ratio) if self.vllm_importance_sampling_mode in ["sequence_truncate", "token_truncate"]: vllm_importance_sampling_ratio = torch.clamp( @@ -1595,11 +1599,11 @@ def _generate_and_score_completions( ) elif self.vllm_importance_sampling_mode in ["sequence_mask", "token_mask"]: vllm_importance_sampling_ratio = vllm_importance_sampling_ratio.masked_fill( - vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, 0.0 + vllm_importance_sampling_ratio > self.vllm_importance_sampling_cap, value=0.0 ) else: raise ValueError( - f"Unknown vllm importance sampling level: {self.vllm_importance_sampling_mode}. Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'." + f"Unknown vLLM importance sampling level: {self.vllm_importance_sampling_mode}. Possible values are 'token_truncate', 'token_mask', 'sequence_truncate', and 'sequence_mask'." ) # Compute the per-token log probabilities for the reference model @@ -1722,7 +1726,11 @@ def _generate_and_score_completions( self.accelerator.gather(max_delta).max().item() ) - flat_is_ratio = token_logps_diff[completion_mask.bool()] + if sequence_level_is: + flat_is_ratio = vllm_importance_sampling_ratio.flatten() + else: + flat_is_ratio = vllm_importance_sampling_ratio[completion_mask.bool()] + min_importance_sampling_ratio = ( torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) ) @@ -1753,7 +1761,7 @@ def _generate_and_score_completions( if old_per_token_logps is not None: output["old_per_token_logps"] = old_per_token_logps if self.use_vllm and self.vllm_importance_sampling_correction: - output["vllm_importance_sampling_ratio"] = vllm_importance_sampling_ratio + output["importance_sampling_ratio"] = vllm_importance_sampling_ratio if ref_per_token_logps is not None: output["ref_per_token_logps"] = ref_per_token_logps if "pixel_values" in forward_kwargs: @@ -1907,7 +1915,7 @@ def _compute_loss(self, model, inputs): per_token_loss = per_token_loss * entropy_mask if self.use_vllm and self.vllm_importance_sampling_correction: - per_token_loss = per_token_loss * inputs["vllm_importance_sampling_ratio"] + per_token_loss = per_token_loss * inputs["importance_sampling_ratio"] if self.beta != 0.0: per_token_loss = per_token_loss + self.beta * per_token_kl From 4fd28af5df764f37d6acf051cfe7eb17d3c9977b Mon Sep 17 00:00:00 2001 From: Leon Date: Sat, 22 Nov 2025 17:15:27 +0100 Subject: [PATCH 06/10] update docs --- docs/source/grpo_trainer.md | 33 +++++++++++++++ docs/source/paper_index.md | 80 ++++++++++++++++++++++++++++++++++++- trl/trainer/grpo_config.py | 34 ++++------------ 3 files changed, 120 insertions(+), 27 deletions(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index bdc132e4115..c1a4d5d8a06 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -246,6 +246,39 @@ training_args = GRPOConfig( For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods). + +#### Dealing with the Training-Inference Mismatch +While vLLM greatly accelerates inference, it also decouples the inference engine from the training engine. In theory these engines are mathematically identical, in practice however they can produce different outputs due to precision effects and hardware specific optimizations. This divergence reflects the different optimization objectives of the two systems. This divergence reflects the distinct optimization goals of the two systems. Inference engines aim to maximize sampling throughput, typically measured in tokens per second, while maintaining acceptable sampling fidelity. Training frameworks instead focus on numerical stability and precision for gradient computation, often using higher precision formats like FP32 for master weights and optimizer states. These differing priorities and constraints introduce an inevitable, albeit subtle, mismatch between training and inference. + +This mismatch leads to a biased gradient update which has been observed to destabilize training ([[1]](https://fengyao.notion.site/off-policy-rl)[[2]](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda)[[3]](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/#true-on-policy-rl)[[4]](https://arxiv.org/abs/2510.26788)[[5]](https://arxiv.org/abs/2510.18855)). For simplicity, consider the REINFORCE policy gradient: + +$$ +\nabla_\theta \mathcal{J}(x,\theta) += \mathbb{E}*{y \sim \pi^\text{train}(\cdot \mid x,\theta)} +\left[ \nabla*\theta \log \pi^\text{train}(y \mid x,\theta) \cdot R(x,y) \right] +$$ + +Here (x) denotes prompts sampled from some data distribution, and (\pi^\text{train}) is the policy implemented by the training engine. With vLLM in the loop we obtain a separate inference policy (\pi^\text{inference}), so the effective policy gradient becomes + +$$ +\nabla_\theta \mathcal{J}*{\text{biased}}(x,\theta) += \mathbb{E}*{y \sim \pi^\text{inference}(\cdot \mid x,\theta)} +\left[ \nabla_\theta \log \pi^\text{train}(y \mid x,\theta) \cdot R(x,y) \right]. +$$ + +This turns an otherwise on policy RL problem into an off policy one. + +The standard way to correct for this distribution shift is **importance sampling (IS)**. We provide two IS variants: Truncated Importance Sampling (TIS) and Masked Importance Sampling (MIS). Both variants can be applied either at the token level or at the sequence level.Let (\rho) denote the importance weight, for example (\rho_t) per token or (\rho_{\text{seq}}) per sequence. Under TIS, ratios larger than `vllm_importance_sampling_cap` are clipped, + +$$ +\rho \leftarrow \min(\rho, C). +$$ + +Under MIS, ratios larger than `vllm_importance_sampling_cap` are set to zero, so those samples do not contribute to the gradient. In other words, large ratio samples are downweighted under TIS and discarded under MIS. The configuration flag `vllm_importance_sampling_mode` chooses both the IS variant (masking or truncation) and the granularity (token level or sequence level). + +Importance sampling is the principled algorithmic response to the training–inference mismatch. However, there are also more direct approaches that attempt to reduce the mismatch between the two engines themselves. Most of these are engineering solutions. For example, [MiniMax M1 uses an FP32 language model head](https://arxiv.org/abs/2506.13585) in the inference engine. Thinking Machines has explored [deterministic inference kernels](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/), although this comes with a significant efficiency cost. vLLM has shown [bitwise consistent policies](https://blog.vllm.ai/2025/11/10/bitwise-consistent-train-inference.html) by building on the batch invariant deterministic kernels from Thinking Machines, but as of November 2025 there remains a substantial throughput penalty relative to standard vLLM inference. + + ### GRPO at scale: train a 70B+ Model on multiple nodes When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include: diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index bdc41263013..c7dd14b8d11 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -188,7 +188,7 @@ $$ } $$ -where \\( C \\) is a hyper-parameter. In TRL, TIS is implemented for GRPO, and enabled by default when vLLM is used for generation (`use_vllm=True`) +where \\( C \\) is a hyper-parameter. TIS is implemented in GRPO, and is enabled by selecting a `vllm_importance_sampling_mode` variant that includes the term `truncate`, such as `sequence_truncate` or `token_truncate`. ```python from trl import GRPOConfig @@ -197,10 +197,88 @@ training_args = GRPOConfig( ... use_vllm=True, vllm_importance_sampling_correction=True, # default True + vllm_importance_sampling_mode="sequence_truncate", # or "token_truncate" vllm_importance_sampling_cap=2.0, # hyper-parameter C ) ``` +### Masked Importance Sampling + +**📰 Blog**: https://ringtech.notion.site/icepop +**📰 Blog**: https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda + +Masked Importance Sampling (MIS) addresses the same issue as [Truncated Importance Sampling](#truncated-importance-sampling) but replaces clipping with masking. MIS takes a more decisive stance by discarding updates whose discrepancy exceeds a threshold +\\( C \\). We apply upper-side masking, so any ratio above \\( C \\) is removed from the update. + + +$$ +\small{ +\mathbb{E}_{a\sim\textcolor{red}{\pi_{\text{inference}}}(\theta_{\mathrm{old}})} +\Bigl[ +\underbrace{\mathbf{1}\left[ +\frac{\pi_{\text{training}}(a, \theta_{\mathrm{old}})} +{\pi_{\text{inference}}(a, \theta_{\mathrm{old}})} +\le C +\right] +\cdot +\frac{\pi_{\text{training}}(a, \theta_{\mathrm{old}})} +{\pi_{\text{inference}}(a, \theta_{\mathrm{old}})}}_{\text{masked importance ratio}} \cdot +\nabla_\theta +\min\Bigl( +\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})}\,\hat A, +\;\mathrm{clip}\bigl(\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})},\,1-\epsilon,\,1+\epsilon\bigr)\,\hat A +\Bigr) +\Bigr] +} +$$ + +MIS is implemented for GRPO, and is enabled by selecting a `vllm_importance_sampling_mode` variant that includes the term `mask`, such as `sequence_mask` or `token_mask`. + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + ... + use_vllm=True, + vllm_importance_sampling_correction=True, # default True + vllm_importance_sampling_mode="sequence_mask", # or "token_mask" + vllm_importance_sampling_cap=2.0, # hyper-parameter C +) +``` + +### Sequence-level Importance Sampling + +**📰 Blog**: https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda + +The theoretically principled way to correct for the training-inference distribution shift is importance sampling, as introduced in the two papers above [Truncated Importance Sampling](#truncated-importance-sampling) and [Masked Importance Sampling](). However, the choice of formulation is crucial for keeping the gradient unbiased and ensuring stable training. + +This work shows that sequence-level importance sampling is the sound approach for addressing the training–inference mismatch. Although token-level importance sampling achieves lower variance than a sequence-level ratio, it introduces bias and is therefore argued to be unsuitable for autoregressive models. The token-level gradient estimator is + +$$ +\mathbb{E}_{x\sim\mathcal{D},\, y\sim \pi^{\text{inference}}_\theta(\cdot|x)} +\Bigg[ + R(x,y)\,\cdot\, + \sum_{t=0}^{|y|-1} + \frac{\pi^{\text{training}}_\theta(y_t\,|\,x, y_{ - vllm_importance_sampling_correction (`bool`, *optional*, defaults to `True`): Whether to apply Importance Sampling (IS) to correct for the mismatch between vLLM completion logprobs and recomputed training logprobs. If set to `False`, no IS is applied regardless of `vllm_importance_sampling_mode`. When `True`, the selected mode determines how the IS ratios are computed and constrained. - vllm_importance_sampling_mode (`str`, *optional*, defaults to `"sequence_mask"`): - Specifies how Importance Sampling is performed when `vllm_importance_sampling_correction=True`. - - The mode is defined along two orthogonal dimensions: - * Constraint: how to handle importance ratios above `vllm_importance_sampling_cap` (C): - - truncation: clip ratios from above, ρ ← min(ρ, C), as in - [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl) - - masking: set ratios above C to zero, so those contributions do not affect the gradient, as in - [When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda) - - * Granularity: the level at which ratios are computed: - - token: per-token ratios ρ_t - - sequence: a single ratio ρ_seq per sequence, applied to all tokens - - Supported options are: - - `"token_truncate"`: - Token-level truncated IS (default). Per-token ratios are clipped from above at C. - - `"token_mask"`: - Token-level masked IS. Per-token ratios above C are set to zero. - - `"sequence_truncate"`: - Sequence-level truncated IS. A single sequence ratio is clipped from above at C and - applied to all tokens in the sequence. - - `"sequence_mask"`: - Sequence-level masked IS. Sequences with ratios above C are masked out. - + Specifies how Importance Sampling is performed when `vllm_importance_sampling_correction=True`. Possible + values are: + + - `"token_truncate"`: Token-level truncated IS (default). Per-token ratios are clipped from above at C. + - `"token_mask"`: Token-level masked IS. Per-token ratios above C are set to zero. + - `"sequence_truncate"`: Sequence-level truncated IS. A single sequence ratio is clipped from above at + C and applied to all tokens in the sequence. + - `"sequence_mask"`: Sequence-level masked IS. Sequences with ratios above C are masked out. vllm_importance_sampling_cap (`float`, *optional*, defaults to `2.0`): Importance sampling cap C used by `vllm_importance_sampling_mode`. For `*_truncate` modes, importance ratios are clipped from above at C. For `*_mask` modes, ratios larger than C From bffd9b8be7a7f91a17159927d5483e5bbd8c97e0 Mon Sep 17 00:00:00 2001 From: Leon Date: Sat, 22 Nov 2025 17:33:58 +0100 Subject: [PATCH 07/10] doc nits --- docs/source/grpo_trainer.md | 12 ++++++------ docs/source/paper_index.md | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index c1a4d5d8a06..517baee2e95 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -254,21 +254,21 @@ This mismatch leads to a biased gradient update which has been observed to desta $$ \nabla_\theta \mathcal{J}(x,\theta) -= \mathbb{E}*{y \sim \pi^\text{train}(\cdot \mid x,\theta)} -\left[ \nabla*\theta \log \pi^\text{train}(y \mid x,\theta) \cdot R(x,y) \right] += \mathbb{E}_{y \sim \pi^\text{train}(\cdot \mid x,\theta)} +\left[ \nabla_\theta \log \pi^\text{train}(y \mid x,\theta) \cdot R(x,y) \right] $$ -Here (x) denotes prompts sampled from some data distribution, and (\pi^\text{train}) is the policy implemented by the training engine. With vLLM in the loop we obtain a separate inference policy (\pi^\text{inference}), so the effective policy gradient becomes +Here \\( x \\) denotes prompts sampled from some data distribution, and \\( \pi^\text{train} \\) is the policy implemented by the training engine. With vLLM in the loop we obtain a separate inference policy \\( \pi^\text{inference} \\), so the effective policy gradient becomes $$ -\nabla_\theta \mathcal{J}*{\text{biased}}(x,\theta) -= \mathbb{E}*{y \sim \pi^\text{inference}(\cdot \mid x,\theta)} +\nabla_\theta \mathcal{J}_{\text{biased}}(x,\theta) += \mathbb{E}_{y \sim \pi^\text{inference}(\cdot \mid x,\theta)} \left[ \nabla_\theta \log \pi^\text{train}(y \mid x,\theta) \cdot R(x,y) \right]. $$ This turns an otherwise on policy RL problem into an off policy one. -The standard way to correct for this distribution shift is **importance sampling (IS)**. We provide two IS variants: Truncated Importance Sampling (TIS) and Masked Importance Sampling (MIS). Both variants can be applied either at the token level or at the sequence level.Let (\rho) denote the importance weight, for example (\rho_t) per token or (\rho_{\text{seq}}) per sequence. Under TIS, ratios larger than `vllm_importance_sampling_cap` are clipped, +The standard way to correct for this distribution shift is **importance sampling (IS)**. We provide two IS variants: [Truncated Importance Sampling (TIS)](paper_index#truncated-importance-sampling) and [Masked Importance Sampling (MIS)](paper_index#masked-importance-sampling). Both variants can be applied either at the token level or at the sequence level.Let (\rho) denote the importance weight, for example \\( \rho_t \\) per token or \\( \rho_{\text{seq}} \\) per sequence. Under TIS, ratios larger than `vllm_importance_sampling_cap` are clipped, $$ \rho \leftarrow \min(\rho, C). diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index c7dd14b8d11..78f74a24c49 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -188,7 +188,7 @@ $$ } $$ -where \\( C \\) is a hyper-parameter. TIS is implemented in GRPO, and is enabled by selecting a `vllm_importance_sampling_mode` variant that includes the term `truncate`, such as `sequence_truncate` or `token_truncate`. +where \\( C \\) is a hyper-parameter. TIS is implemented in GRPO, and is enabled by selecting a `vllm_importance_sampling_mode` variant that includes the term `truncate`, such as `"sequence_truncate"` or `"token_truncate"`. ```python from trl import GRPOConfig @@ -232,7 +232,7 @@ $$ } $$ -MIS is implemented for GRPO, and is enabled by selecting a `vllm_importance_sampling_mode` variant that includes the term `mask`, such as `sequence_mask` or `token_mask`. +MIS is implemented for GRPO, and is enabled by selecting a `vllm_importance_sampling_mode` variant that includes the term `"mask"`, such as `"sequence_mask"` or `"token_mask"`. ```python from trl import GRPOConfig @@ -250,7 +250,7 @@ training_args = GRPOConfig( **📰 Blog**: https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda -The theoretically principled way to correct for the training-inference distribution shift is importance sampling, as introduced in the two papers above [Truncated Importance Sampling](#truncated-importance-sampling) and [Masked Importance Sampling](). However, the choice of formulation is crucial for keeping the gradient unbiased and ensuring stable training. +The theoretically principled way to correct for the training-inference distribution shift is importance sampling, as introduced in the two papers above [Truncated Importance Sampling](#truncated-importance-sampling) and [Masked Importance Sampling](#masked-importance-sampling). However, the choice of formulation is crucial for keeping the gradient unbiased and ensuring stable training. This work shows that sequence-level importance sampling is the sound approach for addressing the training–inference mismatch. Although token-level importance sampling achieves lower variance than a sequence-level ratio, it introduces bias and is therefore argued to be unsuitable for autoregressive models. The token-level gradient estimator is @@ -265,7 +265,7 @@ $$ \Bigg] $$ -The correct, unbiased policy gradient estimator applies a single importance ratio over the entire generated sequence (trajectory) $y$, The Sequence-Level IS estimator looks like: +The correct, unbiased policy gradient estimator applies a single importance ratio over the entire generated sequence (trajectory) \\( y \\), The Sequence-Level IS estimator looks like: $$ \mathbb{E}_{x\sim\mathcal{D},\, y\sim \pi^{\text{inference}}_\theta(\cdot|x)} @@ -277,7 +277,7 @@ $$ \Bigg] $$ -TRL exposes the Importance Sampling granularity level through the `vllm_importance_sampling_mode` configuration parameter where `sequence_*` modes implement a sequence-level importance sampling ratio and `token_*` a per-token ratio. +TRL exposes the Importance Sampling granularity level through the `vllm_importance_sampling_mode` configuration parameter where `"sequence_*"` modes implement a sequence-level importance sampling ratio and `"token_*"` a per-token ratio. ### Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning From 9a1cb7bbce1c03a69431bc58fbc59865b07b70ff Mon Sep 17 00:00:00 2001 From: Leon Date: Sat, 22 Nov 2025 17:39:59 +0100 Subject: [PATCH 08/10] doc nits --- docs/source/grpo_trainer.md | 5 ++--- docs/source/paper_index.md | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 517baee2e95..1cf41a0b5ea 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -258,7 +258,7 @@ $$ \left[ \nabla_\theta \log \pi^\text{train}(y \mid x,\theta) \cdot R(x,y) \right] $$ -Here \\( x \\) denotes prompts sampled from some data distribution, and \\( \pi^\text{train} \\) is the policy implemented by the training engine. With vLLM in the loop we obtain a separate inference policy \\( \pi^\text{inference} \\), so the effective policy gradient becomes +Here \\( x \\) denotes prompts sampled from some data distribution, and \\( \pi^\text{train} \\) is the policy implemented by the training engine. With vLLM in the loop we obtain a separate inference policy \\( \pi^\text{inference} \\), so the effective policy gradient becomes $$ \nabla_\theta \mathcal{J}_{\text{biased}}(x,\theta) @@ -268,7 +268,7 @@ $$ This turns an otherwise on policy RL problem into an off policy one. -The standard way to correct for this distribution shift is **importance sampling (IS)**. We provide two IS variants: [Truncated Importance Sampling (TIS)](paper_index#truncated-importance-sampling) and [Masked Importance Sampling (MIS)](paper_index#masked-importance-sampling). Both variants can be applied either at the token level or at the sequence level.Let (\rho) denote the importance weight, for example \\( \rho_t \\) per token or \\( \rho_{\text{seq}} \\) per sequence. Under TIS, ratios larger than `vllm_importance_sampling_cap` are clipped, +The standard way to correct for this distribution shift is **importance sampling (IS)**. We provide two IS variants: [Truncated Importance Sampling (TIS)](paper_index#truncated-importance-sampling) and [Masked Importance Sampling (MIS)](paper_index#masked-importance-sampling). Both variants can be applied either at the token level or at the sequence level.Let \\( \rho \\) denote the importance weight, for example \\( \rho_t \\) per token or \\( \rho_{\text{seq}} \\) per sequence. Under TIS, ratios larger than `vllm_importance_sampling_cap` are clipped, $$ \rho \leftarrow \min(\rho, C). @@ -278,7 +278,6 @@ Under MIS, ratios larger than `vllm_importance_sampling_cap` are set to zero, so Importance sampling is the principled algorithmic response to the training–inference mismatch. However, there are also more direct approaches that attempt to reduce the mismatch between the two engines themselves. Most of these are engineering solutions. For example, [MiniMax M1 uses an FP32 language model head](https://arxiv.org/abs/2506.13585) in the inference engine. Thinking Machines has explored [deterministic inference kernels](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/), although this comes with a significant efficiency cost. vLLM has shown [bitwise consistent policies](https://blog.vllm.ai/2025/11/10/bitwise-consistent-train-inference.html) by building on the batch invariant deterministic kernels from Thinking Machines, but as of November 2025 there remains a substantial throughput penalty relative to standard vLLM inference. - ### GRPO at scale: train a 70B+ Model on multiple nodes When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include: diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 78f74a24c49..5ea4628f2f7 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -207,8 +207,7 @@ training_args = GRPOConfig( **📰 Blog**: https://ringtech.notion.site/icepop **📰 Blog**: https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda -Masked Importance Sampling (MIS) addresses the same issue as [Truncated Importance Sampling](#truncated-importance-sampling) but replaces clipping with masking. MIS takes a more decisive stance by discarding updates whose discrepancy exceeds a threshold -\\( C \\). We apply upper-side masking, so any ratio above \\( C \\) is removed from the update. +Masked Importance Sampling (MIS) addresses the same issue as [Truncated Importance Sampling](#truncated-importance-sampling) but replaces clipping with masking. MIS takes a more decisive stance by discarding updates whose discrepancy exceeds a threshold \\( C \\). We apply upper-side masking, so any ratio above \\( C \\) is removed from the update. $$ From 5ef70afa0e86954a843c0f1cb6c348ab95d727d4 Mon Sep 17 00:00:00 2001 From: Leon Date: Sun, 23 Nov 2025 11:19:06 +0100 Subject: [PATCH 09/10] review nits --- docs/source/paper_index.md | 4 ++-- trl/trainer/grpo_config.py | 34 ++++++++++++++-------------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 5ea4628f2f7..1fc5691fd90 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -205,6 +205,7 @@ training_args = GRPOConfig( ### Masked Importance Sampling **📰 Blog**: https://ringtech.notion.site/icepop + **📰 Blog**: https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda Masked Importance Sampling (MIS) addresses the same issue as [Truncated Importance Sampling](#truncated-importance-sampling) but replaces clipping with masking. MIS takes a more decisive stance by discarding updates whose discrepancy exceeds a threshold \\( C \\). We apply upper-side masking, so any ratio above \\( C \\) is removed from the update. @@ -263,8 +264,7 @@ $$ \,\nabla_\theta \log \pi^{\text{training}}_\theta(y_t\,|\,x, y_{ Date: Tue, 25 Nov 2025 08:25:51 +0100 Subject: [PATCH 10/10] IS cap default 3 --- trl/trainer/grpo_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 2e96578e86c..5510417d02d 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -253,7 +253,7 @@ class GRPOConfig(TrainingArguments): - `"sequence_truncate"`: Sequence-level truncated IS. A single sequence ratio is clipped from above at C and applied to all tokens in the sequence. - `"sequence_mask"`: Sequence-level masked IS. Sequences with ratios above C are masked out. - vllm_importance_sampling_cap (`float`, *optional*, defaults to `2.0`): + vllm_importance_sampling_cap (`float`, *optional*, defaults to `3.0`): Importance sampling cap C used by `vllm_importance_sampling_mode`. For `*_truncate` modes, importance ratios are clipped from above at C. For `*_mask` modes, ratios larger than C are set to zero. @@ -695,7 +695,7 @@ class GRPOConfig(TrainingArguments): ) vllm_importance_sampling_cap: float = field( - default=2.0, + default=3.0, metadata={ "help": "Importance sampling cap C used by `vllm_importance_sampling_mode`. For '*_truncate' modes, " "ratios are clipped from above at C. For '*_mask' modes, ratios larger than C are set to zero."