diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py
index 134635c05a..d284352ac3 100644
--- a/tests/algorithm/policy_loss_test.py
+++ b/tests/algorithm/policy_loss_test.py
@@ -108,3 +108,99 @@ def test_mix_policy_loss(self):
         self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss))
         self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
         self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))
+
+    def test_ppo_policy_loss_with_truncate_adv_pos_is(self):
+        """Test PPO policy loss with IS truncation enabled for positive advantages."""
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        # Truncate small IS when advantage is positive
+        policy_loss_fn_args["truncate_adv_neg_is"] = False
+        policy_loss_fn_args["truncate_adv_pos_is"] = True
+        policy_loss_fn_args["truncate_is_range_low"] = 0.5
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+
+        # Expected values with IS truncation enabled when advantage is positive
+        ppo_loss_truncated = torch.tensor(0.28531503677368164)
+        pg_clipfrac = torch.tensor(0.3541666567325592)
+        ppo_kl = torch.tensor(-0.21663446724414825)
+        is_truncate_frac_pos_expected = torch.tensor(0.02083333395421505)
+
+        self.assertTrue(torch.allclose(loss, ppo_loss_truncated))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_truncated))
+        # Check that the IS truncation metric is present and has the expected value
+        self.assertIn("is_truncate_frac_pos", metrics)
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["is_truncate_frac_pos"]), is_truncate_frac_pos_expected)
+        )
+        self.assertGreaterEqual(metrics["is_truncate_frac_pos"], 0.0)
+        self.assertLessEqual(metrics["is_truncate_frac_pos"], 1.0)
+
+    def test_ppo_policy_loss_with_truncate_adv_neg_is(self):
+        """Test PPO policy loss with IS truncation enabled for negative advantages."""
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        # Truncate large IS when advantage is negative
+        policy_loss_fn_args["truncate_adv_pos_is"] = False
+        policy_loss_fn_args["truncate_adv_neg_is"] = True
+        policy_loss_fn_args["truncate_is_range_high"] = 2.0
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+
+        # Expected values with IS truncation enabled when advantage is negative
+        ppo_loss_truncated = torch.tensor(0.2230827361345291)
+        pg_clipfrac = torch.tensor(0.3541666567325592)
+        ppo_kl = torch.tensor(-0.21663446724414825)
+        is_truncate_frac_neg_expected = torch.tensor(0.1041666641831398)
+
+        self.assertTrue(torch.allclose(loss, ppo_loss_truncated))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_truncated))
+        # Check that the IS truncation metric is present and has the expected value
+        self.assertIn("is_truncate_frac_neg", metrics)
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["is_truncate_frac_neg"]), is_truncate_frac_neg_expected)
+        )
+        self.assertGreaterEqual(metrics["is_truncate_frac_neg"], 0.0)
+        self.assertLessEqual(metrics["is_truncate_frac_neg"], 1.0)
+
+    def test_ppo_policy_loss_with_truncate_adv_both_is(self):
+        """Test PPO policy loss with IS truncation enabled for both positive and negative advantages."""
+        policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
+        policy_loss_fn_args = policy_loss_fn_cls.default_args()
+        # Truncate small IS when advantage is positive and large IS when advantage is negative
+        policy_loss_fn_args["truncate_adv_pos_is"] = True
+        policy_loss_fn_args["truncate_is_range_low"] = 0.5
+        policy_loss_fn_args["truncate_adv_neg_is"] = True
+        policy_loss_fn_args["truncate_is_range_high"] = 2.0
+        policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
+        loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)
+
+        # Expected values with IS truncation enabled for both positive and negative advantages
+        # ppo_loss_truncated = ppo_loss_adv_pos_truncated + ppo_loss_adv_neg_truncated - ppo_loss_untruncated
+        ppo_loss_truncated = torch.tensor(0.2227930873632431)
+        pg_clipfrac = torch.tensor(0.3541666567325592)
+        ppo_kl = torch.tensor(-0.21663446724414825)
+        is_truncate_frac_pos_expected = torch.tensor(0.02083333395421505)
+        is_truncate_frac_neg_expected = torch.tensor(0.1041666641831398)
+
+        self.assertTrue(torch.allclose(loss, ppo_loss_truncated))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
+        self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_truncated))
+        # Check that both IS truncation metrics are present and have the expected values
+        self.assertIn("is_truncate_frac_pos", metrics)
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["is_truncate_frac_pos"]), is_truncate_frac_pos_expected)
+        )
+        self.assertIn("is_truncate_frac_neg", metrics)
+        self.assertTrue(
+            torch.allclose(torch.tensor(metrics["is_truncate_frac_neg"]), is_truncate_frac_neg_expected)
+        )
+        self.assertGreaterEqual(metrics["is_truncate_frac_pos"], 0.0)
+        self.assertLessEqual(metrics["is_truncate_frac_pos"], 1.0)
+        self.assertGreaterEqual(metrics["is_truncate_frac_neg"], 0.0)
+        self.assertLessEqual(metrics["is_truncate_frac_neg"], 1.0)
diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
index 9c9bbaf2a5..b8cce8c028 100644
--- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -20,7 +20,29 @@ def __init__(
         self,
         backend: str = "verl",
         clip_range: Optional[float] = None,
         clip_range_low: Optional[float] = None,
         clip_range_high: Optional[float] = None,
         loss_agg_mode: Optional[str] = "token-mean",
+        truncate_adv_pos_is: bool = False,
+        truncate_adv_neg_is: bool = False,
+        truncate_is_range_low: Optional[float] = 0.0,
+        truncate_is_range_high: Optional[float] = 2.0,
     ) -> None:
+        """
+        Initialize the PPO policy loss function.
+
+        Args:
+            backend: Backend framework (default: "verl")
+            clip_range: Symmetric clipping range for PPO
+            clip_range_low: Lower bound for clipping (1.0 - clip_range_low)
+            clip_range_high: Upper bound for clipping (1.0 + clip_range_high)
+            loss_agg_mode: Loss aggregation mode (default: "token-mean")
+            truncate_adv_pos_is: Whether to truncate small importance sampling ratios
+                when the advantage is positive, to handle calculation discrepancies
+                between the rollout and training engines
+            truncate_adv_neg_is: Whether to truncate large importance sampling ratios
+                when the advantage is negative, to handle calculation discrepancies
+                between the rollout and training engines
+            truncate_is_range_low: Lower bound for IS ratio truncation (default: 0.0)
+            truncate_is_range_high: Upper bound for IS ratio truncation (default: 2.0)
+        """
         super().__init__(backend=backend)
         if clip_range_low is None:
             self.clip_range_low = clip_range
@@ -34,6 +56,32 @@ def __init__(
         assert self.clip_range_high is not None, "clip_range_high must be specified."
         self.loss_agg_mode = loss_agg_mode
 
+        # IS truncation configuration
+        self.truncate_adv_pos_is = truncate_adv_pos_is
+        self.truncate_adv_neg_is = truncate_adv_neg_is
+        if truncate_adv_pos_is:
+            self.truncate_is_range_low = truncate_is_range_low
+            assert (
+                self.truncate_is_range_low is not None
+            ), "truncate_is_range_low must be specified."
+            assert (
+                self.truncate_is_range_low >= 0.0
+            ), "truncate_is_range_low must be non-negative."
+            assert (self.truncate_is_range_low < 1.0 - self.clip_range_low
+            ), "truncate_is_range_low must be less than 1.0 - clip_range_low."
+        if truncate_adv_neg_is:
+            self.truncate_is_range_high = truncate_is_range_high
+            assert (
+                self.truncate_is_range_high is not None
+            ), "truncate_is_range_high must be specified."
+            assert (
+                self.truncate_is_range_high > 1.0 + self.clip_range_high
+            ), "truncate_is_range_high must be greater than 1.0 + clip_range_high."
+        if truncate_adv_pos_is and truncate_adv_neg_is:
+            assert (
+                self.truncate_is_range_high > self.truncate_is_range_low
+            ), "truncate_is_range_high must be greater than truncate_is_range_low."
+
     def __call__(  # type: ignore
         self,
         logprob: torch.Tensor,
@@ -46,20 +94,55 @@ def __call__(  # type: ignore
         ratio = torch.exp(negative_approx_kl)
         ppo_kl = masked_mean(-negative_approx_kl, action_mask)
 
-        pg_losses = -advantages * ratio
+        # First clip the IS ratio by clip_range and compute pg_clipfrac
+        pg_losses1 = -advantages * ratio
         pg_losses2 = -advantages * torch.clamp(
             ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high  # type: ignore
         )
+        pg_losses_clip = torch.maximum(pg_losses1, pg_losses2)
+        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses1).float(), action_mask)
+
+        # After clipping by clip_range, further truncate IS ratios if enabled.
+        # This helps stabilize training when there are calculation discrepancies between
+        # the rollout and training engines, especially for small probabilities.
+        pg_truncfrac_pos, pg_truncfrac_neg = 0.0, 0.0
+        pg_losses_trunc = pg_losses_clip
+
+        # Truncate small IS ratios for positive advantages
+        if self.truncate_adv_pos_is:
+            pg_losses_pos_trunc = -advantages * self.truncate_is_range_low
+            pg_truncfrac_pos = masked_mean(
+                torch.lt(pg_losses_pos_trunc, pg_losses_trunc) * (advantages > 0).float(),
+                action_mask,
+            )
+            pg_losses_pos = torch.minimum(pg_losses_trunc, pg_losses_pos_trunc)
+            pg_losses_trunc = torch.where(advantages > 0, pg_losses_pos, pg_losses_trunc)
+
+        # Truncate large IS ratios for negative advantages
+        if self.truncate_adv_neg_is:
+            pg_losses_neg_trunc = -advantages * self.truncate_is_range_high
+            pg_truncfrac_neg = masked_mean(
+                torch.lt(pg_losses_neg_trunc, pg_losses_trunc) * (advantages < 0).float(),
+                action_mask,
+            )
+            pg_losses_neg = torch.minimum(pg_losses_trunc, pg_losses_neg_trunc)
+            pg_losses_trunc = torch.where(advantages < 0, pg_losses_neg, pg_losses_trunc)
         pg_loss = masked_loss(
-            torch.max(pg_losses, pg_losses2), action_mask, loss_agg_mode=self.loss_agg_mode
+            pg_losses_trunc, action_mask, loss_agg_mode=self.loss_agg_mode
         )
-        pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask)
 
         metrics = {
             "pg_clipfrac": pg_clipfrac.detach().item(),
             "ppo_kl": ppo_kl.detach().item(),
             "pg_loss": pg_loss.detach().item(),
         }
+
+        # Add IS truncation metrics if enabled
+        if self.truncate_adv_pos_is:
+            metrics["is_truncate_frac_pos"] = pg_truncfrac_pos.detach().item()
+        if self.truncate_adv_neg_is:
+            metrics["is_truncate_frac_neg"] = pg_truncfrac_neg.detach().item()
+
         return pg_loss, metrics
 
     @classmethod
@@ -67,4 +150,8 @@ def default_args(cls) -> Dict:
         return {
             "clip_range": 0.2,
             "loss_agg_mode": "token-mean",
+            "truncate_adv_pos_is": False,
+            "truncate_adv_neg_is": False,
+            "truncate_is_range_low": 0.0,
+            "truncate_is_range_high": 2.0,
         }
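For reference, a minimal standalone sketch (not part of the patch; the helper name, tensor shapes, and values are illustrative) of what the new options do: after the usual PPO clipping, the effective IS ratio is bounded below by truncate_is_range_low for tokens with positive advantage and bounded above by truncate_is_range_high for tokens with negative advantage.

import torch

def truncated_pg_loss(ratio, advantages, clip_low=0.8, clip_high=1.2, trunc_low=0.5, trunc_high=2.0):
    # Standard PPO clipped objective, per token.
    losses = torch.maximum(-advantages * ratio, -advantages * ratio.clamp(clip_low, clip_high))
    # Extra truncation: lower-bound the effective ratio for positive advantages,
    # upper-bound it for negative advantages.
    losses = torch.where(advantages > 0, torch.minimum(losses, -advantages * trunc_low), losses)
    losses = torch.where(advantages < 0, torch.minimum(losses, -advantages * trunc_high), losses)
    return losses.mean()

ratio = torch.tensor([0.2, 3.0])
advantages = torch.tensor([1.0, -1.0])
print(truncated_pg_loss(ratio, advantages))  # tensor(0.7500): ratio 0.2 acts as 0.5, ratio 3.0 acts as 2.0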