Merged
Changes from 6 commits
22 changes: 22 additions & 0 deletions docs/source/paper_index.md
@@ -232,6 +232,28 @@ trainer = PAPOTrainer(
)
```

### The Art of Scaling Reinforcement Learning

**📜 Paper**: https://huggingface.co/papers/2510.13786

A systematic study that defines a framework for analyzing and predicting reinforcement learning scaling in large language models, identifies key design choices that affect compute efficiency, and proposes a best-practice recipe called ScaleRL.

You can partially reproduce the ScaleRL recipe using the `GRPOTrainer` with the following configs:

Member

Suggested change
You can partially reproduce the ScaleRL recipe using the `GRPOTrainer` with the following configs:
You can partially reproduce the ScaleRL recipe using the [`GRPOTrainer`] with the following configs:

Member

not needed for this PR, but it would be neat to have a list of what's supported and what's not. If useful, these are some light reading notes

baseline: GRPO with beta=0 and asymmetric clipping

scaling law applies after ∼1.5k GPU hours

3 stages:
1. Ablation: Test individual design choices on the baseline (3.5k–4k GPU-hours) to identify stable ones.
2. LOO Experiments: Combine stable choices into *ScaleRL* and run 16k GPU-hour leave-one-out tests to assess predictability (fit first 8k, extrapolate rest).
3. Scaling Demonstration: Validate ScaleRL predictability on larger, more complex setups (bigger batches, MoE models, multitask, longer sequences).

### Asynchronous RL Setup

**PPO-off-policy-k** is equivalent to `steps_per_generation` in TRL. -> faster convergence with `steps_per_generation=8` than with `steps_per_generation=1`
**PipelineRL-k** achieves true asynchronicity by allowing weight updates during rollouts. -> faster convergence than **PPO-off-policy-k**

Note that the asymptotic performance is the same for these methods.

### Algorithmic Choices

* loss type: DAPO/GSPO/CISPO: CISPO > GSPO >> DAPO (not yet in TRL, but PR opened by a contributor)
* precision fixes: fp32 for lm_head substantially improves final performance (`cast_lm_head_to_fp32=True` in TRL)
* loss aggregation (GRPO vs DAPO vs DrGRPO): DAPO-style works best (`loss_type="dapo"` in TRL)
* advantage normalization (prompt-level vs batch-level vs none): equivalent, choose batch-level (`scale_rewards="batch"` in TRL)
* batch definition: zero-variance filtering gives better final performance (not implemented in TRL!)
* data curriculum: "too easy" prompts are filtered out of future epochs (not yet in TRL)
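
The batch-level normalization and zero-variance filtering notes above can be sketched in plain Python. This is only an illustrative sketch, not TRL code: the function name `batch_scaled_advantages`, the `eps` constant, and the use of the population standard deviation are assumptions. It centers each reward on its prompt group's mean, scales by a batch-wide standard deviation, and flags groups whose rewards are all identical so they could be dropped.

```python
from statistics import mean, pstdev


def batch_scaled_advantages(rewards, group_size, eps=1e-4):
    """Return (advantages, keep) for a flat list of rewards.

    Rewards are assumed to be laid out group by group, with `group_size`
    completions per prompt (hypothetical layout for this sketch).
    """
    if group_size <= 0 or len(rewards) % group_size != 0:
        raise ValueError("len(rewards) must be a positive multiple of group_size")

    # Batch-level scale: one standard deviation shared by every group.
    batch_std = pstdev(rewards)

    advantages, keep = [], []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        group_mean = mean(group)
        # Zero-variance group: every completion got the same reward,
        # so all of its advantages are zero and it carries no signal.
        informative = max(group) != min(group)
        for reward in group:
            advantages.append((reward - group_mean) / (batch_std + eps))
            keep.append(informative)
    return advantages, keep


# Two prompts, four completions each; the first prompt is "too easy".
advantages, keep = batch_scaled_advantages([1, 1, 1, 1, 0, 1, 0, 1], group_size=4)
```

In this sketch the `keep` flags could be used to drop the first prompt's completions from the loss, which is one reading of what zero-variance filtering means here.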

Collaborator Author

Oh that's a nice list! I might have a PR for batch definition coming up, have a draft version that needs polishing.


```python
from trl import GRPOConfig

config = GRPOConfig(
    loss_type="cispo",
    epsilon_high=5,

Member

nit, it's a float, so I think it's better:

Suggested change
epsilon_high=5,
epsilon_high=5.0,

    num_generations=16,
    scale_rewards="batch",
    cast_lm_head_to_fp32=True
)
```



## Direct Policy Optimization

Papers relating to the [`DPOTrainer`]
2 changes: 1 addition & 1 deletion tests/test_grpo_trainer.py
@@ -167,7 +167,7 @@ def test_training(self, config_name):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

- @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo"])
+ @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo"])
def test_training_loss_types(self, loss_type):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

6 changes: 6 additions & 0 deletions trl/trainer/grpo_config.py
@@ -201,6 +201,9 @@ class GRPOConfig(TrainingArguments):
batch. Note that normalization is performed over the local batch only, so results may slightly vary
depending on the local batch size, despite a constant effective batch size. When using
`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss.
- `"cispo"`: Clips the importance sampling weights instead of the advantage scaled importance weights. The clipped weights

Member

can you just make sure that the line length <= 120

Collaborator Author

Done

are then multiplied with the advantages and policy model's log probs. Individual token losses are aggregated by
normalizing with the number of active tokens in the global accumulated batch. This method was introduced in the MiniMax-M1 paper.

Member

Suggested change
normalizing with the number of active tokens in the global accumulated batch. This method was introduced in the MiniMax-M1 paper.
normalizing with the number of active tokens in the global accumulated batch. This method was introduced in the [MiniMax-M1 paper](https://huggingface.co/papers/2506.13585).

mask_truncated_completions (`bool`, *optional*, defaults to `False`):
When enabled, truncated completions are excluded from the loss calculation, preventing them from being
incorrectly penalized and introducing noise during training. According to the
@@ -582,6 +585,9 @@ class GRPOConfig(TrainingArguments):
"Note that normalization is performed over the local batch only, so results may slightly vary depending "
"on the local batch size, despite a constant effective batch size. When using "
"`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss."
"'cispo': Clips the importance sampling weights instead of the advantage scaled importance weights. The clipped weights "
"are then multiplied with the advantages and policy model's log probs. Individual token losses are aggregated by "
"normalizing with the number of active tokens in the global accumulated batch. This method was introduced in the MiniMax-M1 paper."
},
)
mask_truncated_completions: bool = field(
65 changes: 39 additions & 26 deletions trl/trainer/grpo_trainer.py
@@ -1816,19 +1816,25 @@ def _compute_loss(self, model, inputs):
f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
"and 'sequence'."
)

coef_1 = torch.exp(log_importance_weights)

# From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
# importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
if self.loss_type == "cispo":
clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()

Member

maybe in the documentation of epsilon_high we can mention that this is the value used for epsilon_max when used with CISPO loss. and that the paper recommends =5.0
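
To make the difference between the two clipping conventions concrete, here is a hedged scalar sketch in plain Python. It is not the trainer code, and both function names are hypothetical. With CISPO, `epsilon_high` is the cap on the importance ratio itself, while the two-sided path clips the ratio to `[1 - epsilon_low, 1 + epsilon_high]`.

```python
import math


def cispo_token_loss(logp_new, logp_old, advantage, epsilon_high=5.0):
    """CISPO-style loss for one token (scalar illustration)."""
    ratio = math.exp(logp_new - logp_old)
    # The ratio is capped at epsilon_high directly (no "1 +" offset). In an
    # autograd setting this weight would be detached, i.e. treated as a constant.
    weight = min(ratio, epsilon_high)
    return -weight * advantage * logp_new


def two_sided_token_loss(logp_new, logp_old, advantage,
                         epsilon_low=0.2, epsilon_high=0.2):
    """PPO/GRPO-style clipped surrogate loss for one token (scalar illustration)."""
    ratio = math.exp(logp_new - logp_old)
    clipped = min(max(ratio, 1 - epsilon_low), 1 + epsilon_high)
    return -min(ratio * advantage, clipped * advantage)


# Ratio is exp(2), about 7.39: above the CISPO cap of 5.0 and far above 1 + 0.2.
cispo = cispo_token_loss(logp_new=-1.0, logp_old=-3.0, advantage=1.0)
two_sided = two_sided_token_loss(logp_new=-1.0, logp_old=-3.0, advantage=1.0)
```

The default values here are illustrative only; the 5.0 cap mirrors the value mentioned in this thread.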

per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps

coef_1 = torch.exp(log_importance_weights)
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
else:
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
# Two-sided clipping
if self.args.delta is not None:
coef_1 = torch.clamp(coef_1, max=self.args.delta)

# Two-sided clipping
if self.args.delta is not None:
coef_1 = torch.clamp(coef_1, max=self.args.delta)
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
if entropy_mask is not None:
per_token_loss = per_token_loss * entropy_mask

@@ -1847,7 +1853,7 @@ def _compute_loss(self, model, inputs):
elif self.loss_type == "dr_grpo":
loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
loss = loss / self.current_gradient_accumulation_steps
- elif self.loss_type == "dapo":
+ elif self.loss_type in ["cispo", "dapo"]:
normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
loss = (per_token_loss * completion_mask).sum() / normalizer
else:
@@ -1871,23 +1877,30 @@ def masked_batch_mean(x):
mean_entropy = masked_batch_mean(entropies)
self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())

# Compute the clipped probability ratios
is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
is_region_clipped = is_low_clipped | is_high_clipped

low_clip = masked_batch_mean(is_low_clipped.float())
high_clip = masked_batch_mean(is_high_clipped.float())
clip_ratio = masked_batch_mean(is_region_clipped.float())

gathered_low_clip = self.accelerator.gather(low_clip)
self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
gathered_high_clip = self.accelerator.gather(high_clip)
self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
gathered_clip_ratio = self.accelerator.gather(clip_ratio)
self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
if self.loss_type != "cispo":

Member

nit, again (explicit is better than implicit)

Suggested change
if self.loss_type != "cispo":
if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:

# Compute the clipped probability ratios
is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
is_region_clipped = is_low_clipped | is_high_clipped

low_clip = masked_batch_mean(is_low_clipped.float())
high_clip = masked_batch_mean(is_high_clipped.float())
clip_ratio = masked_batch_mean(is_region_clipped.float())

gathered_low_clip = self.accelerator.gather(low_clip)
self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
gathered_high_clip = self.accelerator.gather(high_clip)
self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
gathered_clip_ratio = self.accelerator.gather(clip_ratio)
self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
else:

Member

Suggested change
else:
elif self.loss_type == "cispo":

is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages.unsqueeze(1) > 0)
cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float())
gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio)
self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item())

return loss

def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None):