-
Notifications
You must be signed in to change notification settings - Fork 2.8k
ScaleRL: Add CISPO Loss #4495
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ScaleRL: Add CISPO Loss #4495
Changes from 8 commits
44577a1
13512d2
0ced03b
ecd2c22
b8de22d
7472f9e
4e9e47f
442943a
6b30fb0
3de5ce8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -166,6 +166,8 @@ class GRPOConfig(TrainingArguments): | |||||
| epsilon_high (`float`, *optional*): | ||||||
| Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound | ||||||
| specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`. | ||||||
| When used with `loss_type='cispo`, this corresponds to the ε_max param specified in the | ||||||
| [ScaleRL paper](https://arxiv.org/pdf/2510.13786) and the recommended value is `5.0`. | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| importance_sampling_level (`str`, *optional*, defaults to `"token"`): | ||||||
| Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"` | ||||||
| keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the | ||||||
|
|
@@ -201,6 +203,11 @@ class GRPOConfig(TrainingArguments): | |||||
| batch. Note that normalization is performed over the local batch only, so results may slightly vary | ||||||
| depending on the local batch size, despite a constant effective batch size. When using | ||||||
| `per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. | ||||||
| - `"cispo"`: Clips the importance sampling weights instead of the advantage scaled importance weights. | ||||||
| The clipped weights are then multiplied with the advantages and policy model's log probs. | ||||||
| Individual token losses are aggregated by normalizing with the number of active tokens in | ||||||
| the global accumulated batch. This method was introduced in the | ||||||
| [MiniMax-M1 paper][MiniMax-M1 paper](https://huggingface.co/papers/2506.13585). | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| mask_truncated_completions (`bool`, *optional*, defaults to `False`): | ||||||
| When enabled, truncated completions are excluded from the loss calculation, preventing them from being | ||||||
| incorrectly penalized and introducing noise during training. According to the | ||||||
|
|
@@ -533,7 +540,9 @@ class GRPOConfig(TrainingArguments): | |||||
| default=None, | ||||||
| metadata={ | ||||||
| "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the " | ||||||
| "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`." | ||||||
| "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`. " | ||||||
| "When used with `loss_type='cispo`, this corresponds to the ε_max param specified in the" | ||||||
| "[ScaleRL paper](https://arxiv.org/pdf/2510.13786) and the recommended value is `5.0`." | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| }, | ||||||
| ) | ||||||
| importance_sampling_level: str = field( | ||||||
|
|
@@ -582,6 +591,11 @@ class GRPOConfig(TrainingArguments): | |||||
| "Note that normalization is performed over the local batch only, so results may slightly vary depending " | ||||||
| "on the local batch size, despite a constant effective batch size. When using " | ||||||
| "`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss." | ||||||
| "'cispo': Clips the importance sampling weights instead of the advantage scaled importance weights. " | ||||||
| "The clipped weights are then multiplied with the advantages and policy model's log probs. " | ||||||
| "Individual token losses are aggregated by normalizing with the number of active tokens in " | ||||||
| "the global accumulated batch. This method was introduced in the " | ||||||
| "[MiniMax-M1 paper][MiniMax-M1 paper](https://huggingface.co/papers/2506.13585)." | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| }, | ||||||
| ) | ||||||
| mask_truncated_completions: bool = field( | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -1816,19 +1816,25 @@ def _compute_loss(self, model, inputs): | |||||
| f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " | ||||||
| "and 'sequence'." | ||||||
| ) | ||||||
|
|
||||||
| coef_1 = torch.exp(log_importance_weights) | ||||||
|
|
||||||
| # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on | ||||||
| # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1) | ||||||
| if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]: | ||||||
| clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach() | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe in the documentation of |
||||||
| per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps | ||||||
|
|
||||||
| coef_1 = torch.exp(log_importance_weights) | ||||||
| coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) | ||||||
| else: | ||||||
| coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) | ||||||
| # Two-sided clipping | ||||||
| if self.args.delta is not None: | ||||||
| coef_1 = torch.clamp(coef_1, max=self.args.delta) | ||||||
|
|
||||||
| # Two-sided clipping | ||||||
| if self.args.delta is not None: | ||||||
| coef_1 = torch.clamp(coef_1, max=self.args.delta) | ||||||
| per_token_loss1 = coef_1 * advantages.unsqueeze(1) | ||||||
| per_token_loss2 = coef_2 * advantages.unsqueeze(1) | ||||||
| per_token_loss = -torch.min(per_token_loss1, per_token_loss2) | ||||||
|
|
||||||
| per_token_loss1 = coef_1 * advantages.unsqueeze(1) | ||||||
| per_token_loss2 = coef_2 * advantages.unsqueeze(1) | ||||||
| per_token_loss = -torch.min(per_token_loss1, per_token_loss2) | ||||||
| if entropy_mask is not None: | ||||||
| per_token_loss = per_token_loss * entropy_mask | ||||||
|
|
||||||
|
|
@@ -1847,7 +1853,7 @@ def _compute_loss(self, model, inputs): | |||||
| elif self.loss_type == "dr_grpo": | ||||||
| loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) | ||||||
| loss = loss / self.current_gradient_accumulation_steps | ||||||
| elif self.loss_type == "dapo": | ||||||
| elif self.loss_type in ["cispo", "dapo"]: | ||||||
| normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes | ||||||
| loss = (per_token_loss * completion_mask).sum() / normalizer | ||||||
| else: | ||||||
|
|
@@ -1871,23 +1877,30 @@ def masked_batch_mean(x): | |||||
| mean_entropy = masked_batch_mean(entropies) | ||||||
| self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) | ||||||
|
|
||||||
| # Compute the clipped probability ratios | ||||||
| is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) | ||||||
| is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) | ||||||
| is_region_clipped = is_low_clipped | is_high_clipped | ||||||
|
|
||||||
| low_clip = masked_batch_mean(is_low_clipped.float()) | ||||||
| high_clip = masked_batch_mean(is_high_clipped.float()) | ||||||
| clip_ratio = masked_batch_mean(is_region_clipped.float()) | ||||||
|
|
||||||
| gathered_low_clip = self.accelerator.gather(low_clip) | ||||||
| self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) | ||||||
| self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) | ||||||
| gathered_high_clip = self.accelerator.gather(high_clip) | ||||||
| self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) | ||||||
| self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) | ||||||
| gathered_clip_ratio = self.accelerator.gather(clip_ratio) | ||||||
| self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) | ||||||
| if self.loss_type != "cispo": | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit, again (explicit better than implicit
Suggested change
|
||||||
| # Compute the clipped probability ratios | ||||||
| is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) | ||||||
| is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) | ||||||
| is_region_clipped = is_low_clipped | is_high_clipped | ||||||
|
|
||||||
| low_clip = masked_batch_mean(is_low_clipped.float()) | ||||||
| high_clip = masked_batch_mean(is_high_clipped.float()) | ||||||
| clip_ratio = masked_batch_mean(is_region_clipped.float()) | ||||||
|
|
||||||
| gathered_low_clip = self.accelerator.gather(low_clip) | ||||||
| self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) | ||||||
| self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) | ||||||
| gathered_high_clip = self.accelerator.gather(high_clip) | ||||||
| self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) | ||||||
| self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) | ||||||
| gathered_clip_ratio = self.accelerator.gather(clip_ratio) | ||||||
| self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) | ||||||
| elif self.loss_type == "cispo": | ||||||
| is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages.unsqueeze(1) > 0) | ||||||
| cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) | ||||||
| gathered_cispo_clip_ratio = self.accelerator.gather(cispo_clip_ratio) | ||||||
| self._metrics[mode]["cispo_clip_ratio"].append(gathered_cispo_clip_ratio.nanmean().item()) | ||||||
|
|
||||||
| return loss | ||||||
|
|
||||||
| def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None): | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.