From 44577a1d76b8aaa9d0905a2ca3fa7e9fa96c6357 Mon Sep 17 00:00:00 2001 From: pramodith Date: Thu, 6 Nov 2025 15:00:29 +0000 Subject: [PATCH 01/10] CISPO Loss --- tests/test_grpo_trainer.py | 2 +- trl/trainer/grpo_config.py | 6 ++++ trl/trainer/grpo_trainer.py | 63 ++++++++++++++++++++++--------------- 3 files changed, 44 insertions(+), 27 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 88d2579a69d..a3bf80a641b 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -167,7 +167,7 @@ def test_training(self, config_name): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." - @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo"]) + @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo"]) def test_training_loss_types(self, loss_type): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 6001a8dc524..e0e87322d53 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -201,6 +201,9 @@ class GRPOConfig(TrainingArguments): batch. Note that normalization is performed over the local batch only, so results may slightly vary depending on the local batch size, despite a constant effective batch size. When using `per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. + - `"cispo"`: Clips the importance sampling weights instead of the advantage scaled importance weights. The clipped weights + are then multiplied with the advantages and policy model's log probs. Individual token losses are aggregated by + normalizing with the number of active tokens in the global accumulated batch. This method was introduced in the MiniMax-M1 paper. mask_truncated_completions (`bool`, *optional*, defaults to `False`): When enabled, truncated completions are excluded from the loss calculation, preventing them from being incorrectly penalized and introducing noise during training. According to the @@ -582,6 +585,9 @@ class GRPOConfig(TrainingArguments): "Note that normalization is performed over the local batch only, so results may slightly vary depending " "on the local batch size, despite a constant effective batch size. When using " "`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss." + "'cispo': Clips the importance sampling weights instead of the advantage scaled importance weights. The clipped weights " + "are then multiplied with the advantages and policy model's log probs. Individual token losses are aggregated by " + "normalizing with the number of active tokens in the global accumulated batch. This method was introduced in the MiniMax-M1 paper." }, ) mask_truncated_completions: bool = field( diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index fe29c78b92a..341473975ed 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1818,17 +1818,21 @@ def _compute_loss(self, model, inputs): ) # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) + if self.loss_type == "cispo": + coef_1 = torch.min(log_importance_weights, 1 + self.epsilon_high, dim=-1).detach() + per_token_loss = coef_1 * advantages.unsqueeze(1) * per_token_logps - coef_1 = torch.exp(log_importance_weights) - coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + else: + coef_1 = torch.exp(log_importance_weights) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + # Two-sided clipping + if self.args.delta is not None: + coef_1 = torch.clamp(coef_1, max=self.args.delta) - # Two-sided clipping - if self.args.delta is not None: - coef_1 = torch.clamp(coef_1, max=self.args.delta) + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) - per_token_loss1 = coef_1 * advantages.unsqueeze(1) - per_token_loss2 = coef_2 * advantages.unsqueeze(1) - per_token_loss = -torch.min(per_token_loss1, per_token_loss2) if entropy_mask is not None: per_token_loss = per_token_loss * entropy_mask @@ -1847,7 +1851,7 @@ def _compute_loss(self, model, inputs): elif self.loss_type == "dr_grpo": loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) loss = loss / self.current_gradient_accumulation_steps - elif self.loss_type == "dapo": + elif self.loss_type in ["cispo", "dapo"]: normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes loss = (per_token_loss * completion_mask).sum() / normalizer else: @@ -1871,23 +1875,30 @@ def masked_batch_mean(x): mean_entropy = masked_batch_mean(entropies) self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) - # Compute the clipped probability ratios - is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) - is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) - is_region_clipped = is_low_clipped | is_high_clipped - - low_clip = masked_batch_mean(is_low_clipped.float()) - high_clip = masked_batch_mean(is_high_clipped.float()) - clip_ratio = masked_batch_mean(is_region_clipped.float()) - - gathered_low_clip = self.accelerator.gather(low_clip) - self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) - self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) - gathered_high_clip = self.accelerator.gather(high_clip) - self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) - self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) - gathered_clip_ratio = self.accelerator.gather(clip_ratio) - self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) + if self.loss_type != "cispo": + # Compute the clipped probability ratios + is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_region_clipped = is_low_clipped | is_high_clipped + + low_clip = masked_batch_mean(is_low_clipped.float()) + high_clip = masked_batch_mean(is_high_clipped.float()) + clip_ratio = masked_batch_mean(is_region_clipped.float()) + + gathered_low_clip = self.accelerator.gather(low_clip) + self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) + gathered_high_clip = self.accelerator.gather(high_clip) + self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) + gathered_clip_ratio = self.accelerator.gather(clip_ratio) + self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) + else: + is_cispo_clipped = (log_importance_weights > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) + cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) + gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio) + self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item()) + return loss def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None): From 13512d25f1034a3036971d09b7273a25ed387b2c Mon Sep 17 00:00:00 2001 From: pramodith Date: Thu, 6 Nov 2025 15:12:02 +0000 Subject: [PATCH 02/10] fix coef_1 calculation --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 341473975ed..9faf9accc1d 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1819,7 +1819,7 @@ def _compute_loss(self, model, inputs): # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) if self.loss_type == "cispo": - coef_1 = torch.min(log_importance_weights, 1 + self.epsilon_high, dim=-1).detach() + coef_1 = torch.clamp(torch.exp(log_importance_weights), max=self.epsilon_high).detach() per_token_loss = coef_1 * advantages.unsqueeze(1) * per_token_logps else: From 0ced03ba8ae9c3748b3ac67babd4c36978a30d50 Mon Sep 17 00:00:00 2001 From: pramodith Date: Thu, 6 Nov 2025 15:31:21 +0000 Subject: [PATCH 03/10] Add paper index. --- docs/source/paper_index.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 6467548d8ea..4e7ab86dcd5 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -232,6 +232,28 @@ trainer = PAPOTrainer( ) ``` +### The Art of Scaling Reinforcement Learning + +**📜 Paper**: https://huggingface.co/papers/2510.13786 + +A systematic study that defines a framework for analyzing and predicting reinforcement learning scaling in large language models, identifies key design choices that affect compute efficiency and propose a best-practice recipe called ScaleRL. + +You can parially reproduce the ScaleRL recipe using the `GRPOTrainer` with the following configs: + +```python +from trl import GPROConfig + +config = GRPOConfig( + loss_type="cispo", + epsilon_high=5, + num_completions=16, + scale_rewards="batch", + cast_lm_head_to_fp32=True +) +``` + + + ## Direct Policy Optimization Papers relating to the [`DPOTrainer`] From ecd2c228dc1ce51897cd3bb50214b67dde04275a Mon Sep 17 00:00:00 2001 From: pramodith Date: Thu, 6 Nov 2025 15:45:04 +0000 Subject: [PATCH 04/10] bugs --- docs/source/paper_index.md | 4 ++-- trl/trainer/grpo_trainer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 4e7ab86dcd5..3633b509e7c 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -238,10 +238,10 @@ trainer = PAPOTrainer( A systematic study that defines a framework for analyzing and predicting reinforcement learning scaling in large language models, identifies key design choices that affect compute efficiency and propose a best-practice recipe called ScaleRL. -You can parially reproduce the ScaleRL recipe using the `GRPOTrainer` with the following configs: +You can partially reproduce the ScaleRL recipe using the `GRPOTrainer` with the following configs: ```python -from trl import GPROConfig +from trl import GRPOConfig config = GRPOConfig( loss_type="cispo", diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 9faf9accc1d..ee6c630bab2 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1820,7 +1820,7 @@ def _compute_loss(self, model, inputs): # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) if self.loss_type == "cispo": coef_1 = torch.clamp(torch.exp(log_importance_weights), max=self.epsilon_high).detach() - per_token_loss = coef_1 * advantages.unsqueeze(1) * per_token_logps + per_token_loss = -coef_1 * advantages.unsqueeze(1) * per_token_logps else: coef_1 = torch.exp(log_importance_weights) @@ -1894,7 +1894,7 @@ def masked_batch_mean(x): gathered_clip_ratio = self.accelerator.gather(clip_ratio) self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) else: - is_cispo_clipped = (log_importance_weights > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_cispo_clipped = (torch.exp(log_importance_weights) > self.epsilon_high) & (advantages.unsqueeze(1) > 0) cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio) self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item()) From b8de22da6be5c34d2d454e68749ed4cb24912c0e Mon Sep 17 00:00:00 2001 From: pramodith Date: Thu, 6 Nov 2025 15:48:30 +0000 Subject: [PATCH 05/10] clean up. --- trl/trainer/grpo_trainer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ee6c630bab2..e3603474377 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1816,14 +1816,16 @@ def _compute_loss(self, model, inputs): f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " "and 'sequence'." ) + + coef_1 = torch.exp(log_importance_weights) + # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) if self.loss_type == "cispo": - coef_1 = torch.clamp(torch.exp(log_importance_weights), max=self.epsilon_high).detach() - per_token_loss = -coef_1 * advantages.unsqueeze(1) * per_token_logps + clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach() + per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps else: - coef_1 = torch.exp(log_importance_weights) coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) # Two-sided clipping if self.args.delta is not None: @@ -1894,7 +1896,7 @@ def masked_batch_mean(x): gathered_clip_ratio = self.accelerator.gather(clip_ratio) self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) else: - is_cispo_clipped = (torch.exp(log_importance_weights) > self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages.unsqueeze(1) > 0) cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio) self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item()) From 7472f9efbee1a339b0d0610a785e0ef26c1bba8e Mon Sep 17 00:00:00 2001 From: pramodith Date: Thu, 6 Nov 2025 15:49:40 +0000 Subject: [PATCH 06/10] precommit --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index e3603474377..1b86804ae0e 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1816,7 +1816,7 @@ def _compute_loss(self, model, inputs): f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " "and 'sequence'." ) - + coef_1 = torch.exp(log_importance_weights) # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on From 4e9e47fc48e61a8f8587b0318a8200a3b58e488d Mon Sep 17 00:00:00 2001 From: pramodith Date: Thu, 6 Nov 2025 18:40:33 +0000 Subject: [PATCH 07/10] Update docs. --- docs/source/paper_index.md | 4 ++-- trl/trainer/grpo_config.py | 22 +++++++++++++++------- trl/trainer/grpo_trainer.py | 4 ++-- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 3633b509e7c..b5ac187f744 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -238,14 +238,14 @@ trainer = PAPOTrainer( A systematic study that defines a framework for analyzing and predicting reinforcement learning scaling in large language models, identifies key design choices that affect compute efficiency and propose a best-practice recipe called ScaleRL. -You can partially reproduce the ScaleRL recipe using the `GRPOTrainer` with the following configs: +You can partially reproduce the ScaleRL recipe using the [`GRPOTrainer`] with the following configs: ```python from trl import GRPOConfig config = GRPOConfig( loss_type="cispo", - epsilon_high=5, + epsilon_high=5.0, num_completions=16, scale_rewards="batch", cast_lm_head_to_fp32=True diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index e0e87322d53..41383a3bc3c 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -166,6 +166,8 @@ class GRPOConfig(TrainingArguments): epsilon_high (`float`, *optional*): Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`. + When used with `loss_type='cispo`, this corresponds to the ε_max param specified in the + [ScaleRL paper](https://arxiv.org/pdf/2510.13786). importance_sampling_level (`str`, *optional*, defaults to `"token"`): Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"` keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the @@ -201,9 +203,11 @@ class GRPOConfig(TrainingArguments): batch. Note that normalization is performed over the local batch only, so results may slightly vary depending on the local batch size, despite a constant effective batch size. When using `per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. - - `"cispo"`: Clips the importance sampling weights instead of the advantage scaled importance weights. The clipped weights - are then multiplied with the advantages and policy model's log probs. Individual token losses are aggregated by - normalizing with the number of active tokens in the global accumulated batch. This method was introduced in the MiniMax-M1 paper. + - `"cispo"`: Clips the importance sampling weights instead of the advantage scaled importance weights. + The clipped weights are then multiplied with the advantages and policy model's log probs. + Individual token losses are aggregated by normalizing with the number of active tokens in + the global accumulated batch. This method was introduced in the + [MiniMax-M1 paper][MiniMax-M1 paper](https://huggingface.co/papers/2506.13585). mask_truncated_completions (`bool`, *optional*, defaults to `False`): When enabled, truncated completions are excluded from the loss calculation, preventing them from being incorrectly penalized and introducing noise during training. According to the @@ -536,7 +540,9 @@ class GRPOConfig(TrainingArguments): default=None, metadata={ "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the " - "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`." + "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`. " + "When used with `loss_type='cispo`, this corresponds to the ε_max param specified in the" + "[ScaleRL paper](https://arxiv.org/pdf/2510.13786)." }, ) importance_sampling_level: str = field( @@ -585,9 +591,11 @@ class GRPOConfig(TrainingArguments): "Note that normalization is performed over the local batch only, so results may slightly vary depending " "on the local batch size, despite a constant effective batch size. When using " "`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss." - "'cispo': Clips the importance sampling weights instead of the advantage scaled importance weights. The clipped weights " - "are then multiplied with the advantages and policy model's log probs. Individual token losses are aggregated by " - "normalizing with the number of active tokens in the global accumulated batch. This method was introduced in the MiniMax-M1 paper." + "'cispo': Clips the importance sampling weights instead of the advantage scaled importance weights. " + "The clipped weights are then multiplied with the advantages and policy model's log probs. " + "Individual token losses are aggregated by normalizing with the number of active tokens in " + "the global accumulated batch. This method was introduced in the " + "[MiniMax-M1 paper][MiniMax-M1 paper](https://huggingface.co/papers/2506.13585)." }, ) mask_truncated_completions: bool = field( diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 1b86804ae0e..bddef14adfc 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1821,7 +1821,7 @@ def _compute_loss(self, model, inputs): # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) - if self.loss_type == "cispo": + if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]: clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach() per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps @@ -1895,7 +1895,7 @@ def masked_batch_mean(x): self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) gathered_clip_ratio = self.accelerator.gather(clip_ratio) self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) - else: + elif self.loss_type == "cispo": is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages.unsqueeze(1) > 0) cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio) From 442943a8898808d0972fd5063bb258d37d69a686 Mon Sep 17 00:00:00 2001 From: pramodith Date: Thu, 6 Nov 2025 18:43:46 +0000 Subject: [PATCH 08/10] recommended value --- trl/trainer/grpo_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 41383a3bc3c..ee5011855c0 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -167,7 +167,7 @@ class GRPOConfig(TrainingArguments): Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`. When used with `loss_type='cispo`, this corresponds to the ε_max param specified in the - [ScaleRL paper](https://arxiv.org/pdf/2510.13786). + [ScaleRL paper](https://arxiv.org/pdf/2510.13786) and the recommended value is `5.0`. importance_sampling_level (`str`, *optional*, defaults to `"token"`): Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"` keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the @@ -542,7 +542,7 @@ class GRPOConfig(TrainingArguments): "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the " "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`. " "When used with `loss_type='cispo`, this corresponds to the ε_max param specified in the" - "[ScaleRL paper](https://arxiv.org/pdf/2510.13786)." + "[ScaleRL paper](https://arxiv.org/pdf/2510.13786) and the recommended value is `5.0`." }, ) importance_sampling_level: str = field( From 6b30fb0b35a0bd8b3426f41965e9e05fb3d82ae9 Mon Sep 17 00:00:00 2001 From: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com> Date: Thu, 6 Nov 2025 21:38:56 +0000 Subject: [PATCH 09/10] nits --- trl/trainer/grpo_config.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index ee5011855c0..4950859b215 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -166,8 +166,8 @@ class GRPOConfig(TrainingArguments): epsilon_high (`float`, *optional*): Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`. - When used with `loss_type='cispo`, this corresponds to the ε_max param specified in the - [ScaleRL paper](https://arxiv.org/pdf/2510.13786) and the recommended value is `5.0`. + When used with `loss_type="cispo"`, this corresponds to the ε_max param specified in the [ScaleRL + paper](https://arxiv.org/pdf/2510.13786) and the recommended value is `5.0`. importance_sampling_level (`str`, *optional*, defaults to `"token"`): Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"` keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the @@ -203,11 +203,10 @@ class GRPOConfig(TrainingArguments): batch. Note that normalization is performed over the local batch only, so results may slightly vary depending on the local batch size, despite a constant effective batch size. When using `per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. - - `"cispo"`: Clips the importance sampling weights instead of the advantage scaled importance weights. - The clipped weights are then multiplied with the advantages and policy model's log probs. - Individual token losses are aggregated by normalizing with the number of active tokens in - the global accumulated batch. This method was introduced in the - [MiniMax-M1 paper][MiniMax-M1 paper](https://huggingface.co/papers/2506.13585). + - `"cispo"`: Clips the importance sampling weights instead of the advantage scaled importance weights. The + clipped weights are then multiplied with the advantages and policy model's log probs. Individual token + losses are aggregated by normalizing with the number of active tokens in the global accumulated batch. + This method was introduced in the [MiniMax-M1 paper](https://huggingface.co/papers/2506.13585). mask_truncated_completions (`bool`, *optional*, defaults to `False`): When enabled, truncated completions are excluded from the loss calculation, preventing them from being incorrectly penalized and introducing noise during training. According to the @@ -542,7 +541,7 @@ class GRPOConfig(TrainingArguments): "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the " "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`. " "When used with `loss_type='cispo`, this corresponds to the ε_max param specified in the" - "[ScaleRL paper](https://arxiv.org/pdf/2510.13786) and the recommended value is `5.0`." + "[ScaleRL paper]https://huggingface.co/papers/2510.13786) and the recommended value is `5.0`." }, ) importance_sampling_level: str = field( @@ -595,7 +594,7 @@ class GRPOConfig(TrainingArguments): "The clipped weights are then multiplied with the advantages and policy model's log probs. " "Individual token losses are aggregated by normalizing with the number of active tokens in " "the global accumulated batch. This method was introduced in the " - "[MiniMax-M1 paper][MiniMax-M1 paper](https://huggingface.co/papers/2506.13585)." + "[MiniMax-M1 paper](https://huggingface.co/papers/2506.13585)." }, ) mask_truncated_completions: bool = field( From 3de5ce85067adca284190ad9eff02bf5b8251a34 Mon Sep 17 00:00:00 2001 From: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com> Date: Thu, 6 Nov 2025 21:40:49 +0000 Subject: [PATCH 10/10] nits --- trl/trainer/grpo_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 4950859b215..260a73ab164 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -166,7 +166,7 @@ class GRPOConfig(TrainingArguments): epsilon_high (`float`, *optional*): Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`. - When used with `loss_type="cispo"`, this corresponds to the ε_max param specified in the [ScaleRL + When used with `loss_type='cispo'`, this corresponds to the ε_max param specified in the [ScaleRL paper](https://arxiv.org/pdf/2510.13786) and the recommended value is `5.0`. importance_sampling_level (`str`, *optional*, defaults to `"token"`): Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"` @@ -540,7 +540,7 @@ class GRPOConfig(TrainingArguments): metadata={ "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the " "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`. " - "When used with `loss_type='cispo`, this corresponds to the ε_max param specified in the" + "When used with `loss_type='cispo'`, this corresponds to the ε_max param specified in the" "[ScaleRL paper]https://huggingface.co/papers/2510.13786) and the recommended value is `5.0`." }, )