-
Notifications
You must be signed in to change notification settings - Fork 41
[Feature] Truncate based on the sign of advantage after clipping #340
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -20,7 +20,29 @@ def __init__( | |||||
| clip_range_low: Optional[float] = None, | ||||||
| clip_range_high: Optional[float] = None, | ||||||
| loss_agg_mode: Optional[str] = "token-mean", | ||||||
| truncate_adv_pos_is: bool = False, | ||||||
| truncate_adv_neg_is: bool = False, | ||||||
| truncate_is_range_low: Optional[float] = 0.0, | ||||||
| truncate_is_range_high: Optional[float] = 2.0, | ||||||
| ) -> None: | ||||||
| """ | ||||||
| Initialize PPO policy loss function. | ||||||
|
|
||||||
| Args: | ||||||
| backend: Backend framework (default: "verl") | ||||||
| clip_range: Symmetric clipping range for PPO | ||||||
| clip_range_low: Lower bound for clipping (1.0 - clip_range_low) | ||||||
| clip_range_high: Upper bound for clipping (1.0 + clip_range_high) | ||||||
| loss_agg_mode: Loss aggregation mode (default: "token-mean") | ||||||
| truncate_adv_pos_is: Whether to truncate large importance sampling ratios | ||||||
| when advantage is positive to handle calculation discrepancies between | ||||||
| rollout and training engines | ||||||
| truncate_adv_neg_is: Whether to truncate large importance sampling ratios | ||||||
| when advantage is negative to handle calculation discrepancies between | ||||||
| rollout and training engines | ||||||
| truncate_is_range_low: Lower bound for IS ratio truncation (default: 0.0) | ||||||
| truncate_is_range_high: Upper bound for IS ratio truncation (default: 2.0) | ||||||
| """ | ||||||
| super().__init__(backend=backend) | ||||||
| if clip_range_low is None: | ||||||
| self.clip_range_low = clip_range | ||||||
|
|
@@ -34,6 +56,32 @@ def __init__( | |||||
| assert self.clip_range_high is not None, "clip_range_high must be specified." | ||||||
| self.loss_agg_mode = loss_agg_mode | ||||||
|
|
||||||
| # Truncate large IS configuration | ||||||
| self.truncate_adv_pos_is = truncate_adv_pos_is | ||||||
| self.truncate_adv_neg_is = truncate_adv_neg_is | ||||||
| if truncate_adv_pos_is: | ||||||
| self.truncate_is_range_low = truncate_is_range_low | ||||||
| assert ( | ||||||
| self.truncate_is_range_low is not None | ||||||
| ), "truncate_is_range_low must be specified." | ||||||
| assert ( | ||||||
| self.truncate_is_range_low >= 0.0 | ||||||
| ), "truncate_is_range_low must be non-negative." | ||||||
| assert (self.truncate_is_range_low < 1.0-self.clip_range_low | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| ), "truncate_is_range_low must be less than 1.0 - clip_range_low." | ||||||
| if truncate_adv_neg_is: | ||||||
| self.truncate_is_range_high = truncate_is_range_high | ||||||
| assert ( | ||||||
| self.truncate_is_range_high is not None | ||||||
| ), "truncate_is_range_high must be specified." | ||||||
| assert ( | ||||||
| self.truncate_is_range_high > 1.0+self.clip_range_high | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| ), "truncate_is_range_high must be greater than clip_range_high + 1.0." | ||||||
| if truncate_adv_pos_is and truncate_adv_neg_is: | ||||||
| assert ( | ||||||
| self.truncate_is_range_high > self.truncate_is_range_low | ||||||
| ), "truncate_is_range_high must be greater than truncate_is_range_low." | ||||||
|
|
||||||
| def __call__( # type: ignore | ||||||
| self, | ||||||
| logprob: torch.Tensor, | ||||||
|
|
@@ -46,25 +94,64 @@ def __call__( # type: ignore | |||||
| ratio = torch.exp(negative_approx_kl) | ||||||
| ppo_kl = masked_mean(-negative_approx_kl, action_mask) | ||||||
|
|
||||||
| pg_losses = -advantages * ratio | ||||||
| # First clipping by clip_range, and calculate pg_clipfrac | ||||||
| pg_losses1 = -advantages * ratio | ||||||
| pg_losses2 = -advantages * torch.clamp( | ||||||
| ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high # type: ignore | ||||||
| ) | ||||||
| pg_losses_clip = torch.maximum(pg_losses1, pg_losses2) | ||||||
| pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses1).float(), action_mask) | ||||||
|
|
||||||
| # After clipped by clip_range, further truncate IS ratios if enabled | ||||||
| # This helps stabilize training when there are calculation discrepancies between | ||||||
| # rollout and training engines, especially for small probabilities | ||||||
| pg_truncfrac_pos, pg_truncfrac_neg = 0.0, 0.0 | ||||||
| pg_losses_trunc = pg_losses_clip | ||||||
|
|
||||||
| # Add IS truncation for positive advantages | ||||||
| if self.truncate_adv_pos_is: | ||||||
| pg_losses_pos_trunc = -advantages * self.truncate_is_range_low | ||||||
| pg_truncfrac_pos = masked_mean( | ||||||
| torch.lt(pg_losses_pos_trunc, pg_losses_trunc) * (advantages > 0).float(), | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| action_mask, | ||||||
| ) | ||||||
| pg_losses_pos = torch.minimum(pg_losses_trunc, pg_losses_pos_trunc) | ||||||
| pg_losses_trunc = torch.where(advantages > 0, pg_losses_pos, pg_losses_trunc) | ||||||
|
|
||||||
| # Add IS truncation for negative advantages | ||||||
| if self.truncate_adv_neg_is: | ||||||
| pg_losses_neg_trunc = -advantages * self.truncate_is_range_high | ||||||
| pg_truncfrac_neg = masked_mean( | ||||||
| torch.lt(pg_losses_neg_trunc, pg_losses_trunc) * (advantages < 0).float(), | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| action_mask, | ||||||
| ) | ||||||
| pg_losses_neg = torch.minimum(pg_losses_trunc, pg_losses_neg_trunc) | ||||||
| pg_losses_trunc = torch.where(advantages < 0, pg_losses_neg, pg_losses_trunc) | ||||||
|
|
||||||
| pg_loss = masked_loss( | ||||||
| torch.max(pg_losses, pg_losses2), action_mask, loss_agg_mode=self.loss_agg_mode | ||||||
| pg_losses_trunc, action_mask, loss_agg_mode=self.loss_agg_mode | ||||||
| ) | ||||||
| pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask) | ||||||
| metrics = { | ||||||
| "pg_clipfrac": pg_clipfrac.detach().item(), | ||||||
| "ppo_kl": ppo_kl.detach().item(), | ||||||
| "pg_loss": pg_loss.detach().item(), | ||||||
| } | ||||||
|
|
||||||
| # Add IS truncation metrics if enabled | ||||||
| if self.truncate_adv_pos_is: | ||||||
| metrics["is_truncate_frac_pos"] = pg_truncfrac_pos.detach().item() | ||||||
| if self.truncate_adv_neg_is: | ||||||
| metrics["is_truncate_frac_neg"] = pg_truncfrac_neg.detach().item() | ||||||
|
|
||||||
| return pg_loss, metrics | ||||||
|
|
||||||
| @classmethod | ||||||
| def default_args(cls) -> Dict: | ||||||
| return { | ||||||
| "clip_range": 0.2, | ||||||
| "loss_agg_mode": "token-mean", | ||||||
| "truncate_adv_pos_is": False, | ||||||
| "truncate_adv_neg_is": False, | ||||||
| "truncate_is_range_low": 0.0, | ||||||
| "truncate_is_range_high": 2.0, | ||||||
| } | ||||||
Uh oh!
There was an error while loading. Please reload this page.