58 changes: 58 additions & 0 deletions tests/algorithm/policy_loss_test.py
@@ -108,3 +108,61 @@ def test_mix_policy_loss(self):
self.assertTrue(torch.allclose(torch.tensor(metrics["usual/pg_loss"]), pg_loss))
self.assertTrue(torch.allclose(torch.tensor(metrics["expert/sft_loss"]), sft_loss))
self.assertTrue(torch.allclose(torch.tensor(metrics["loss"]), mix_loss))

def test_ppo_policy_loss_with_truncate_adv_pos_is(self):
"""Test PPO policy loss with IS truncation enabled for positive advantages (small ratios floored)."""
policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
policy_loss_fn_args = policy_loss_fn_cls.default_args()
# Truncate small IS when advantage is positive
policy_loss_fn_args["truncate_adv_neg_is"] = False
policy_loss_fn_args["truncate_adv_pos_is"] = True
policy_loss_fn_args["truncate_is_range_low"] = 0.5
policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)

# Expected values with IS truncation enabled when advantage is positive
ppo_loss_truncated = torch.tensor(0.28531503677368164)
pg_clipfrac = torch.tensor(0.3541666567325592)
ppo_kl = torch.tensor(-0.21663446724414825)
is_truncate_frac_pos_expected = torch.tensor(0.02083333395421505)

self.assertTrue(torch.allclose(loss, ppo_loss_truncated))
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_truncated))
# Check that IS truncation metric is present and has expected value
self.assertIn("is_truncate_frac_pos", metrics)
self.assertTrue(
torch.allclose(torch.tensor(metrics["is_truncate_frac_pos"]), is_truncate_frac_pos_expected)
)
self.assertGreaterEqual(metrics["is_truncate_frac_pos"], 0.0)
self.assertLessEqual(metrics["is_truncate_frac_pos"], 1.0)

def test_ppo_policy_loss_with_truncate_adv_neg_is(self):
"""Test PPO policy loss with IS truncation enabled for negative advantages (large ratios capped)."""
policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
policy_loss_fn_args = policy_loss_fn_cls.default_args()
# Truncate large IS when advantage is negative
policy_loss_fn_args["truncate_adv_pos_is"] = False
policy_loss_fn_args["truncate_adv_neg_is"] = True
policy_loss_fn_args["truncate_is_range_high"] = 2.0
policy_loss_fn = policy_loss_fn_cls(**policy_loss_fn_args)
loss, metrics = policy_loss_fn(log_prob=self.logprob, **self.input_data.batch)

# Expected values with IS truncation enabled when advantage is negative
ppo_loss_truncated = torch.tensor(0.2230827361345291)
pg_clipfrac = torch.tensor(0.3541666567325592)
ppo_kl = torch.tensor(-0.21663446724414825)
is_truncate_frac_neg_expected = torch.tensor(0.1041666641831398)

self.assertTrue(torch.allclose(loss, ppo_loss_truncated))
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_clipfrac"]), pg_clipfrac))
self.assertTrue(torch.allclose(torch.tensor(metrics["ppo_kl"]), ppo_kl))
self.assertTrue(torch.allclose(torch.tensor(metrics["pg_loss"]), ppo_loss_truncated))
# Check that IS truncation metric is present and has expected value
self.assertIn("is_truncate_frac_neg", metrics)
self.assertTrue(
torch.allclose(torch.tensor(metrics["is_truncate_frac_neg"]), is_truncate_frac_neg_expected)
)
self.assertGreaterEqual(metrics["is_truncate_frac_neg"], 0.0)
self.assertLessEqual(metrics["is_truncate_frac_neg"], 1.0)
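
For readers who want to sanity-check the new expectations, here is a minimal, self-contained sketch of the clip-then-truncate logic these tests exercise. The local masked_mean helper and the toy ratio/advantages tensors are illustrative stand-ins, not the project's utilities or the fixture behind the expected values above.

import torch

def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Mean over positions where mask == 1.
    return (x * mask).sum() / mask.sum()

# Toy values, chosen only to show the mechanics.
ratio = torch.tensor([0.3, 0.9, 1.5, 2.8])         # importance sampling ratios
advantages = torch.tensor([1.0, -1.0, 1.0, -1.0])  # per-token advantages
action_mask = torch.ones_like(ratio)

clip_low, clip_high = 0.2, 0.2     # standard PPO clip range
trunc_low, trunc_high = 0.5, 2.0   # values used by the new tests

# Standard PPO clipping (dual form).
pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - clip_low, 1.0 + clip_high)
pg_losses_clip = torch.maximum(pg_losses1, pg_losses2)

# Extra truncation from this PR: floor small ratios when the advantage is positive ...
pg_losses_trunc = torch.where(
    advantages > 0,
    torch.minimum(pg_losses_clip, -advantages * trunc_low),
    pg_losses_clip,
)
# ... and cap large ratios when the advantage is negative.
pg_losses_trunc = torch.where(
    advantages < 0,
    torch.minimum(pg_losses_trunc, -advantages * trunc_high),
    pg_losses_trunc,
)

print(masked_mean(pg_losses_clip, action_mask))   # clipped-only loss
print(masked_mean(pg_losses_trunc, action_mask))  # loss with IS truncation applied

With these toy numbers, the positive-advantage token with ratio 0.3 is floored to an effective ratio of 0.5, and the negative-advantage token with ratio 2.8 is capped at 2.0.
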
93 changes: 90 additions & 3 deletions trinity/algorithm/policy_loss_fn/ppo_policy_loss.py
@@ -20,7 +20,29 @@ def __init__(
clip_range_low: Optional[float] = None,
clip_range_high: Optional[float] = None,
loss_agg_mode: Optional[str] = "token-mean",
truncate_adv_pos_is: bool = False,
truncate_adv_neg_is: bool = False,
truncate_is_range_low: Optional[float] = 0.0,
truncate_is_range_high: Optional[float] = 2.0,
) -> None:
"""
Initialize PPO policy loss function.

Args:
backend: Backend framework (default: "verl")
clip_range: Symmetric clipping range for PPO
clip_range_low: Lower bound for clipping (1.0 - clip_range_low)
clip_range_high: Upper bound for clipping (1.0 + clip_range_high)
loss_agg_mode: Loss aggregation mode (default: "token-mean")
truncate_adv_pos_is: Whether to truncate (floor) small importance sampling
ratios, below truncate_is_range_low, when the advantage is positive, to
handle calculation discrepancies between rollout and training engines
truncate_adv_neg_is: Whether to truncate (cap) large importance sampling
ratios, above truncate_is_range_high, when the advantage is negative, to
handle calculation discrepancies between rollout and training engines
truncate_is_range_low: Lower bound for IS ratio truncation (default: 0.0)
truncate_is_range_high: Upper bound for IS ratio truncation (default: 2.0)
"""
super().__init__(backend=backend)
if clip_range_low is None:
self.clip_range_low = clip_range
@@ -34,6 +56,32 @@ def __init__(
assert self.clip_range_high is not None, "clip_range_high must be specified."
self.loss_agg_mode = loss_agg_mode

# Truncate large IS configuration
self.truncate_adv_pos_is = truncate_adv_pos_is
self.truncate_adv_neg_is = truncate_adv_neg_is
if truncate_adv_pos_is:
self.truncate_is_range_low = truncate_is_range_low
assert (
self.truncate_is_range_low is not None
), "truncate_is_range_low must be specified."
assert (
self.truncate_is_range_low >= 0.0
), "truncate_is_range_low must be non-negative."
assert (self.truncate_is_range_low < 1.0-self.clip_range_low
Collaborator suggested change:
- assert (self.truncate_is_range_low < 1.0-self.clip_range_low
+ assert (self.truncate_is_range_low < 1.0 - self.clip_range_low
), "truncate_is_range_low must be less than 1.0 - clip_range_low."
if truncate_adv_neg_is:
self.truncate_is_range_high = truncate_is_range_high
assert (
self.truncate_is_range_high is not None
), "truncate_is_range_high must be specified."
assert (
self.truncate_is_range_high > 1.0+self.clip_range_high
Collaborator suggested change:
- self.truncate_is_range_high > 1.0+self.clip_range_high
+ self.truncate_is_range_high > 1.0 + self.clip_range_high
), "truncate_is_range_high must be greater than clip_range_high + 1.0."
if truncate_adv_pos_is and truncate_adv_neg_is:
assert (
self.truncate_is_range_high > self.truncate_is_range_low
), "truncate_is_range_high must be greater than truncate_is_range_low."

def __call__( # type: ignore
self,
logprob: torch.Tensor,
@@ -46,25 +94,64 @@ def __call__(  # type: ignore
ratio = torch.exp(negative_approx_kl)
ppo_kl = masked_mean(-negative_approx_kl, action_mask)

- pg_losses = -advantages * ratio
# First clip by clip_range and compute pg_clipfrac
pg_losses1 = -advantages * ratio
pg_losses2 = -advantages * torch.clamp(
ratio, 1.0 - self.clip_range_low, 1.0 + self.clip_range_high # type: ignore
)
pg_losses_clip = torch.maximum(pg_losses1, pg_losses2)
pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses1).float(), action_mask)

# After clipping by clip_range, further truncate the IS ratios if enabled.
# This helps stabilize training when there are calculation discrepancies between
# the rollout and training engines, especially for small probabilities.
pg_truncfrac_pos, pg_truncfrac_neg = 0.0, 0.0
pg_losses_trunc = pg_losses_clip

# Add IS truncation for positive advantages
if self.truncate_adv_pos_is:
pg_losses_pos_trunc = -advantages * self.truncate_is_range_low
pg_truncfrac_pos = masked_mean(
torch.lt(pg_losses_pos_trunc, pg_losses_trunc) * (advantages > 0).float(),
Collaborator suggested change:
- torch.lt(pg_losses_pos_trunc, pg_losses_trunc) * (advantages > 0).float(),
+ torch.lt(pg_losses_pos_trunc, pg_losses_trunc).float() * (advantages > 0),
action_mask,
)
pg_losses_pos = torch.minimum(pg_losses_trunc, pg_losses_pos_trunc)
pg_losses_trunc = torch.where(advantages > 0, pg_losses_pos, pg_losses_trunc)

# Add IS truncation for negative advantages
if self.truncate_adv_neg_is:
pg_losses_neg_trunc = -advantages * self.truncate_is_range_high
pg_truncfrac_neg = masked_mean(
torch.lt(pg_losses_neg_trunc, pg_losses_trunc) * (advantages < 0).float(),
Collaborator suggested change:
- torch.lt(pg_losses_neg_trunc, pg_losses_trunc) * (advantages < 0).float(),
+ torch.lt(pg_losses_neg_trunc, pg_losses_trunc).float() * (advantages < 0),
action_mask,
)
pg_losses_neg = torch.minimum(pg_losses_trunc, pg_losses_neg_trunc)
pg_losses_trunc = torch.where(advantages < 0, pg_losses_neg, pg_losses_trunc)

pg_loss = masked_loss(
- torch.max(pg_losses, pg_losses2), action_mask, loss_agg_mode=self.loss_agg_mode
pg_losses_trunc, action_mask, loss_agg_mode=self.loss_agg_mode
)
- pg_clipfrac = masked_mean(torch.gt(pg_losses2, pg_losses).float(), action_mask)
metrics = {
"pg_clipfrac": pg_clipfrac.detach().item(),
"ppo_kl": ppo_kl.detach().item(),
"pg_loss": pg_loss.detach().item(),
}

# Add IS truncation metrics if enabled
if self.truncate_adv_pos_is:
metrics["is_truncate_frac_pos"] = pg_truncfrac_pos.detach().item()
if self.truncate_adv_neg_is:
metrics["is_truncate_frac_neg"] = pg_truncfrac_neg.detach().item()

return pg_loss, metrics

@classmethod
def default_args(cls) -> Dict:
return {
"clip_range": 0.2,
"loss_agg_mode": "token-mean",
"truncate_adv_pos_is": False,
"truncate_adv_neg_is": False,
"truncate_is_range_low": 0.0,
"truncate_is_range_high": 2.0,
}
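
Finally, a hedged usage sketch of how both truncation options might be enabled together, following the registry pattern from the tests above. The import path is an assumption (only the POLICY_LOSS_FN registry and the "ppo" key appear in this diff), log_prob and batch are placeholders for real rollout data, and the constructor requires truncate_is_range_high > truncate_is_range_low when both flags are set.

# Hypothetical import path; the tests above only show the POLICY_LOSS_FN registry name.
from trinity.algorithm import POLICY_LOSS_FN

policy_loss_fn_cls = POLICY_LOSS_FN.get("ppo")
args = policy_loss_fn_cls.default_args()
args.update(
    truncate_adv_pos_is=True,    # floor IS ratios below truncate_is_range_low when A > 0
    truncate_adv_neg_is=True,    # cap IS ratios above truncate_is_range_high when A < 0
    truncate_is_range_low=0.5,
    truncate_is_range_high=2.0,
)
policy_loss_fn = policy_loss_fn_cls(**args)
# loss, metrics = policy_loss_fn(log_prob=log_prob, **batch)
# With both flags set, metrics additionally contain "is_truncate_frac_pos" and
# "is_truncate_frac_neg", the masked fractions of tokens whose ratios were truncated.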