ScaleRL: Add CISPO Loss #4495
Changes from 6 commits
44577a1
13512d2
0ced03b
ecd2c22
b8de22d
7472f9e
4e9e47f
442943a
6b30fb0
3de5ce8
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -232,6 +232,28 @@ trainer = PAPOTrainer( | |||||
| ) | ||||||
| ``` | ||||||
|
|
||||||
| ### The Art of Scaling Reinforcement Learning | ||||||
|
|
||||||
| **📜 Paper**: https://huggingface.co/papers/2510.13786 | ||||||
|
|
||||||
| A systematic study that defines a framework for analyzing and predicting reinforcement learning scaling in large language models, identifies key design choices that affect compute efficiency, and proposes a best-practice recipe called ScaleRL. | ||||||
|
|
||||||
| You can partially reproduce the ScaleRL recipe using the `GRPOTrainer` with the following configs: | ||||||
|
Member
not needed for this PR, but it would be neat to have a list of what's supported and what's not. If useful, these are some light reading notes
Collaborator
Author
Oh that's a nice list! I might have a PR for |
||||||
|
|
||||||
| ```python | ||||||
| from trl import GRPOConfig | ||||||
|
|
||||||
| config = GRPOConfig( | ||||||
| loss_type="cispo", | ||||||
| epsilon_high=5, | ||||||
|
Member
nit, it's a float, so I think it's better:
Suggested change
|
||||||
| num_completions=16, | ||||||
| scale_rewards="batch", | ||||||
| cast_lm_head_to_fp32=True | ||||||
| ) | ||||||
| ``` | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
| ## Direct Policy Optimization | ||||||
|
|
||||||
| Papers relating to the [`DPOTrainer`] | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -201,6 +201,9 @@ class GRPOConfig(TrainingArguments): | |||||
| batch. Note that normalization is performed over the local batch only, so results may slightly vary | ||||||
| depending on the local batch size, despite a constant effective batch size. When using | ||||||
| `per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. | ||||||
| - `"cispo"`: Clips the importance sampling weights instead of the advantage scaled importance weights. The clipped weights | ||||||
|
Member
can you just make sure that the line length <= 120
Collaborator
Author
Done |
||||||
| are then multiplied with the advantages and policy model's log probs. Individual token losses are aggregated by | ||||||
| normalizing with the number of active tokens in the global accumulated batch. This method was introduced in the MiniMax-M1 paper. | ||||||
|
Member
Suggested change
|
||||||
| mask_truncated_completions (`bool`, *optional*, defaults to `False`): | ||||||
| When enabled, truncated completions are excluded from the loss calculation, preventing them from being | ||||||
| incorrectly penalized and introducing noise during training. According to the | ||||||
|
|
@@ -582,6 +585,9 @@ class GRPOConfig(TrainingArguments): | |||||
| "Note that normalization is performed over the local batch only, so results may slightly vary depending " | ||||||
| "on the local batch size, despite a constant effective batch size. When using " | ||||||
| "`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss." | ||||||
| "'cispo': Clips the importance sampling weights instead of the advantage scaled importance weights. The clipped weights " | ||||||
| "are then multiplied with the advantages and policy model's log probs. Individual token losses are aggregated by " | ||||||
| "normalizing with the number of active tokens in the global accumulated batch. This method was introduced in the MiniMax-M1 paper." | ||||||
| }, | ||||||
| ) | ||||||
| mask_truncated_completions: bool = field( | ||||||
|
|
||||||
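As a reading aid, the `"cispo"` description in the docstring above can be written as a single objective. This is a sketch inferred from the docstring and the `_compute_loss` change in this PR, not a formula quoted from the MiniMax-M1 paper. The notation is an assumption: $r_{i,t}$ is the importance sampling ratio for token $t$ of completion $o_i$, $\hat{A}_i$ is the advantage, $\operatorname{sg}$ is stop-gradient (`.detach()`), and $N$ is the number of active tokens in the global accumulated batch. Only the upper clamp that appears in this diff is shown.

```latex
r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})},
\qquad
\mathcal{L}_{\text{CISPO}}(\theta) = -\frac{1}{N} \sum_{i} \sum_{t=1}^{|o_i|}
\operatorname{sg}\!\big(\min(r_{i,t}(\theta),\, \epsilon_{\text{high}})\big)\,
\hat{A}_i \, \log \pi_\theta(o_{i,t} \mid q, o_{i,<t})
```

Because the clamped ratio sits inside the stop-gradient, the gradient reaches the parameters only through $\log \pi_\theta$, so every active token contributes a gradient term whose weight is bounded by $\epsilon_{\text{high}}$.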
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -1816,19 +1816,25 @@ def _compute_loss(self, model, inputs): | |||||
| f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " | ||||||
| "and 'sequence'." | ||||||
| ) | ||||||
|
|
||||||
| coef_1 = torch.exp(log_importance_weights) | ||||||
|
|
||||||
| # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on | ||||||
| # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) | ||||||
| if self.loss_type == "cispo": | ||||||
| clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach() | ||||||
|
Member
maybe in the documentation of |
||||||
| per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps | ||||||
|
|
||||||
| coef_1 = torch.exp(log_importance_weights) | ||||||
| coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) | ||||||
| else: | ||||||
| coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) | ||||||
| # Two-sided clipping | ||||||
| if self.args.delta is not None: | ||||||
| coef_1 = torch.clamp(coef_1, max=self.args.delta) | ||||||
|
|
||||||
| # Two-sided clipping | ||||||
| if self.args.delta is not None: | ||||||
| coef_1 = torch.clamp(coef_1, max=self.args.delta) | ||||||
| per_token_loss1 = coef_1 * advantages.unsqueeze(1) | ||||||
| per_token_loss2 = coef_2 * advantages.unsqueeze(1) | ||||||
| per_token_loss = -torch.min(per_token_loss1, per_token_loss2) | ||||||
|
|
||||||
| per_token_loss1 = coef_1 * advantages.unsqueeze(1) | ||||||
| per_token_loss2 = coef_2 * advantages.unsqueeze(1) | ||||||
| per_token_loss = -torch.min(per_token_loss1, per_token_loss2) | ||||||
| if entropy_mask is not None: | ||||||
| per_token_loss = per_token_loss * entropy_mask | ||||||
|
|
||||||
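To make the difference between the two branches in this hunk concrete, here is a minimal pure-Python sketch (no autograd) that computes the coefficient multiplying the gradient of a token's log-probability under each loss. The function names are hypothetical and not part of TRL; the `delta` two-sided clipping and the entropy mask are deliberately left out. It assumes a token-level ratio `exp(logp - old_logp)`.

```python
import math


def cispo_grad_coef(logp, old_logp, advantage, epsilon_high):
    """Coefficient on d(logp)/d(theta) for the CISPO per-token loss.

    loss = -detach(min(ratio, epsilon_high)) * advantage * logp,
    so the clamped ratio acts as a constant weight.
    """
    ratio = math.exp(logp - old_logp)
    clamped = min(ratio, epsilon_high)  # upper clamp only, as in the diff
    return -clamped * advantage


def clipped_surrogate_grad_coef(logp, old_logp, advantage, epsilon_low, epsilon_high):
    """Coefficient on d(logp)/d(theta) for loss = -min(r * A, clip(r) * A)."""
    ratio = math.exp(logp - old_logp)
    clipped = max(1 - epsilon_low, min(ratio, 1 + epsilon_high))
    if ratio * advantage <= clipped * advantage:
        # The unclipped term is selected; d(ratio)/d(logp) = ratio.
        return -ratio * advantage
    # The clipped term is selected and is constant in logp: no gradient.
    return 0.0


# A token whose probability grew roughly 10x, with a positive advantage.
old_logp = -2.0
logp = old_logp + math.log(10.0)
cispo_coef = cispo_grad_coef(logp, old_logp, advantage=1.0, epsilon_high=5.0)
ppo_coef = clipped_surrogate_grad_coef(
    logp, old_logp, advantage=1.0, epsilon_low=0.2, epsilon_high=0.2
)
```

In this example the clipped surrogate drops the token from the gradient entirely, while the CISPO branch keeps it with a weight capped at `epsilon_high`.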
|
|
@@ -1847,7 +1853,7 @@ def _compute_loss(self, model, inputs): | |||||
| elif self.loss_type == "dr_grpo": | ||||||
| loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) | ||||||
| loss = loss / self.current_gradient_accumulation_steps | ||||||
| elif self.loss_type == "dapo": | ||||||
| elif self.loss_type in ["cispo", "dapo"]: | ||||||
| normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes | ||||||
| loss = (per_token_loss * completion_mask).sum() / normalizer | ||||||
| else: | ||||||
|
|
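The normalization shared by `"cispo"` and `"dapo"` in this hunk can be sketched with plain Python lists. The helper name below is hypothetical, and the sketch assumes `num_items_in_batch` counts active completion tokens across all processes, which is how the docstring in this PR describes it.

```python
def global_token_normalized_loss(per_token_loss, completion_mask, num_items_in_batch, num_processes):
    """Sum masked token losses on this process and divide by the per-process
    share of the global active-token count."""
    masked_sum = sum(
        loss * mask
        for loss_row, mask_row in zip(per_token_loss, completion_mask)
        for loss, mask in zip(loss_row, mask_row)
    )
    normalizer = num_items_in_batch / num_processes
    return masked_sum / normalizer


# Two completions on this process; 3 active tokens locally, 14 active tokens globally.
per_token_loss = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
completion_mask = [[1, 1, 0], [1, 0, 0]]
loss = global_token_normalized_loss(
    per_token_loss, completion_mask, num_items_in_batch=14, num_processes=2
)
```

Dividing by the global count rather than the local one means a process holding short completions does not up-weight its tokens relative to a process holding long ones.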
@@ -1871,23 +1877,30 @@ def masked_batch_mean(x): | |||||
| mean_entropy = masked_batch_mean(entropies) | ||||||
| self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) | ||||||
|
|
||||||
| # Compute the clipped probability ratios | ||||||
| is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) | ||||||
| is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) | ||||||
| is_region_clipped = is_low_clipped | is_high_clipped | ||||||
|
|
||||||
| low_clip = masked_batch_mean(is_low_clipped.float()) | ||||||
| high_clip = masked_batch_mean(is_high_clipped.float()) | ||||||
| clip_ratio = masked_batch_mean(is_region_clipped.float()) | ||||||
|
|
||||||
| gathered_low_clip = self.accelerator.gather(low_clip) | ||||||
| self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) | ||||||
| self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) | ||||||
| gathered_high_clip = self.accelerator.gather(high_clip) | ||||||
| self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) | ||||||
| self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) | ||||||
| gathered_clip_ratio = self.accelerator.gather(clip_ratio) | ||||||
| self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) | ||||||
| if self.loss_type != "cispo": | ||||||
|
Member
nit, again (explicit better than implicit
Suggested change
|
||||||
| # Compute the clipped probability ratios | ||||||
| is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) | ||||||
| is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) | ||||||
| is_region_clipped = is_low_clipped | is_high_clipped | ||||||
|
|
||||||
| low_clip = masked_batch_mean(is_low_clipped.float()) | ||||||
| high_clip = masked_batch_mean(is_high_clipped.float()) | ||||||
| clip_ratio = masked_batch_mean(is_region_clipped.float()) | ||||||
|
|
||||||
| gathered_low_clip = self.accelerator.gather(low_clip) | ||||||
| self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) | ||||||
| self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) | ||||||
| gathered_high_clip = self.accelerator.gather(high_clip) | ||||||
| self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) | ||||||
| self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) | ||||||
| gathered_clip_ratio = self.accelerator.gather(clip_ratio) | ||||||
| self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) | ||||||
| else: | ||||||
|
Member
Suggested change
|
||||||
| is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages.unsqueeze(1) > 0) | ||||||
| cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) | ||||||
| gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio) | ||||||
| self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item()) | ||||||
|
|
||||||
| return loss | ||||||
|
|
||||||
| def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None): | ||||||
|
|
||||||
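The `cispo_clip_ratio` metric added in this hunk can be illustrated without tensors. The sketch below assumes `masked_batch_mean` averages over active (unmasked) tokens only; the helper name `cispo_clip_fraction` is hypothetical and the denominator guard is an assumption, not a copy of the trainer's implementation.

```python
def cispo_clip_fraction(ratios, advantages, completion_mask, epsilon_high):
    """Fraction of active tokens with ratio > epsilon_high and a positive advantage."""
    clipped = 0
    active = 0
    for ratio_row, advantage, mask_row in zip(ratios, advantages, completion_mask):
        for ratio, mask in zip(ratio_row, mask_row):
            if not mask:
                continue  # padded or truncated positions are ignored
            active += 1
            if ratio > epsilon_high and advantage > 0:
                clipped += 1
    return clipped / max(active, 1)  # guard against an all-masked batch


ratios = [[6.0, 1.0, 7.0], [8.0, 0.5, 1.0]]
advantages = [1.0, -1.0]  # one advantage per completion
completion_mask = [[1, 1, 0], [1, 1, 1]]
fraction = cispo_clip_fraction(ratios, advantages, completion_mask, epsilon_high=5.0)
```

Only the first token of the first completion counts here: the third token exceeds the threshold but is masked, and the second completion has a negative advantage.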