diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py index 134635c05a..7f90f2d757 100644 --- a/tests/algorithm/policy_loss_test.py +++ b/tests/algorithm/policy_loss_test.py @@ -108,3 +108,32 @@ def test_mix_policy_loss(self): self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss)) self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss)) self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss)) + + def test_ppo_policy_loss_with_truncate_is(self): + """Test PPO policy loss with truncate large IS enabled.""" + policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo") + policy_loss_fn_args = policy_loss_fn_cls.default_args() + # Enable truncate large IS with default bounds [0.0, 2.0] + policy_loss_fn_args["truncate_large_is"] = True + policy_loss_fn_args["truncate_is_range_low"] = 0.0 + policy_loss_fn_args["truncate_is_range_high"] = 2.0 + policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args) + loss, metrics = policy_loss_fn(logprob=self.logprob, **self.input_data.batch) + + # Expected values with IS truncation enabled (range: [0.0, 2.0]) + ppo_loss_truncated = torch.tensor(0.2230827361345291) + pg_clipfrac_truncated = torch.tensor(0.3541666567325592) + ppo_kl_truncated = torch.tensor(-0.21663446724414825) + is_truncate_frac_expected = torch.tensor(0.2708333432674408) + + self.assertTrue(torch.allclose(loss, ppo_loss_truncated)) + self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac_truncated)) + self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl_truncated)) + self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_truncated)) + # Check that IS truncation metric is present and has expected value + self.assertIn("is_truncate_frac", metrics) + self.assertTrue( + torch.allclose(torch.tensor(metrics["is_truncate_frac"]), is_truncate_frac_expected) + ) + self.assertGreaterEqual(metrics["is_truncate_frac"], 0.0) + 
self.assertLessEqual(metrics["is_truncate_frac"], 1.0) diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py index 9c9bbaf2a5..b8cad22ce1 100644 --- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py @@ -20,7 +20,24 @@ def __init__( clip_range_low: Optional[float] = None, clip_range_high: Optional[float] = None, loss_agg_mode: Optional[str] = "token-mean", + truncate_large_is: bool = False, + truncate_is_range_low: Optional[float] = 0.0, + truncate_is_range_high: Optional[float] = 2.0, ) -> None: + """ + Initialize PPO policy loss function. + + Args: + backend: Backend framework (default: "verl") + clip_range: Symmetric clipping range for PPO + clip_range_low: Lower bound for clipping (1.0 - clip_range_low) + clip_range_high: Upper bound for clipping (1.0 + clip_range_high) + loss_agg_mode: Loss aggregation mode (default: "token-mean") + truncate_large_is: Whether to truncate large importance sampling ratios + to handle calculation discrepancies between rollout and training engines + truncate_is_range_low: Lower bound for IS ratio truncation (default: 0.0) + truncate_is_range_high: Upper bound for IS ratio truncation (default: 2.0) + """ super().__init__(backend=backend) if clip_range_low is None: self.clip_range_low = clip_range @@ -34,6 +51,22 @@ def __init__( assert self.clip_range_high is not None, "clip_range_high must be specified." self.loss_agg_mode = loss_agg_mode + # Truncate large IS configuration + self.truncate_large_is = truncate_large_is + if truncate_large_is: + self.truncate_is_range_low = truncate_is_range_low + self.truncate_is_range_high = truncate_is_range_high + assert ( + self.truncate_is_range_low is not None + ), "truncate_is_range_low must be specified." + assert ( + self.truncate_is_range_high is not None + ), "truncate_is_range_high must be specified." 
+ assert self.truncate_is_range_low >= 0.0, "truncate_is_range_low must be non-negative." + assert ( + self.truncate_is_range_high > self.truncate_is_range_low + ), "truncate_is_range_high must be greater than truncate_is_range_low." + def __call__( # type: ignore self, logprob: torch.Tensor, @@ -46,6 +79,18 @@ def __call__( # type: ignore ratio = torch.exp(negative_approx_kl) ppo_kl = masked_mean(-negative_approx_kl, action_mask) + # Truncate large IS ratios if enabled + # This helps stabilize training when there are calculation discrepancies between + # rollout and training engines, especially for small probabilities + if self.truncate_large_is: + # Track how often truncation occurs (before actually truncating) + # More efficient than cloning: directly check which values fall outside bounds + ratio_detached = ratio.detach() + is_truncate_frac = masked_mean( + (ratio_detached < self.truncate_is_range_low).float(), action_mask + ) + masked_mean((ratio_detached > self.truncate_is_range_high).float(), action_mask) + ratio = torch.clamp(ratio, self.truncate_is_range_low, self.truncate_is_range_high) + pg_losses = -advantages * ratio pg_losses2 = -advantages * torch.clamp( ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high # type: ignore @@ -60,6 +105,11 @@ def __call__( # type: ignore "ppo_kl": ppo_kl.detach().item(), "pg_loss": pg_loss.detach().item(), } + + # Add IS truncation metrics if enabled + if self.truncate_large_is: + metrics["is_truncate_frac"] = is_truncate_frac.detach().item() + return pg_loss, metrics @classmethod @@ -67,4 +117,7 @@ def default_args(cls) -> Dict: return { "clip_range": 0.2, "loss_agg_mode": "token-mean", + "truncate_large_is": False, + "truncate_is_range_low": 0.0, + "truncate_is_range_high": 2.0, }