diff --git a/training/__init__.py b/training/__init__.py index 50a697fa..c722aec9 100644 --- a/training/__init__.py +++ b/training/__init__.py @@ -2,8 +2,7 @@ Recipes (fork and customise): - recipes/rl_loop.py: GRPO (RL) training with pluggable policy - losses -- set ``policy_loss`` to ``"grpo"``, ``"dapo"``, or ``"gspo"``; - enable TIS on any loss with ``tis_enabled=True`` + losses -- set ``policy_loss`` to ``"grpo"``, ``"dapo"``, ``"gspo"``, or ``"cispo"`` - recipes/dpo_loop.py: DPO (preference) training - recipes/orpo_loop.py: ORPO (preference) training -- no reference model needed; combines SFT loss with odds-ratio preference loss diff --git a/training/examples/deepmath/train_deepmath.py b/training/examples/deepmath/train_deepmath.py index a820f8da..5d023699 100644 --- a/training/examples/deepmath/train_deepmath.py +++ b/training/examples/deepmath/train_deepmath.py @@ -244,8 +244,7 @@ def main(): epochs=args.epochs, max_rows=args.max_rows, prompt_groups_per_step=args.prompt_groups_per_step, - tis_enabled=True, - tis=ISConfig(clip_high=2.0, clip_low=0.0), + is_correction=ISConfig(tis_cap=2.0), router_replay=args.router_replay, router_replay_completion_only=args.router_replay, infra=InfraConfig( diff --git a/training/recipes/rl_loop.py b/training/recipes/rl_loop.py index b3e5430a..c645595a 100644 --- a/training/recipes/rl_loop.py +++ b/training/recipes/rl_loop.py @@ -47,7 +47,8 @@ load_jsonl_dataset, ) from fireworks.training.sdk.deployment import DeploymentSampler -from training.utils.rl import ISConfig, PromptGroup +from training.utils.rl import PromptGroup +from training.utils.rl.importance_sampling import ISConfig from fireworks.training.sdk.weight_syncer import WeightSyncer from training.utils.rl.pp import compute_pp_recommendation from training.utils.timer import timer, flush_timing @@ -100,11 +101,11 @@ class Config: policy_loss: str = "grpo" """``"grpo"``, ``"dapo"``, ``"gspo"``, or ``"cispo"``.""" - tis_enabled: bool = False - tis: ISConfig = field(default_factory=ISConfig) dapo: DAPOConfig = field(default_factory=DAPOConfig) gspo: GSPOConfig = field(default_factory=GSPOConfig) cispo: CISPOConfig = field(default_factory=CISPOConfig) + is_correction: ISConfig = field(default_factory=ISConfig) + """AReaL-style decoupled IS correction: PPO ratio + behavioral weight.""" infra: InfraConfig = field(default_factory=InfraConfig) deployment: DeployConfig = field(default_factory=DeployConfig) @@ -315,9 +316,9 @@ def main( adam_params = tinker.AdamParams(learning_rate=cfg.learning_rate, **DEFAULT_ADAM) loss_builder = build_loss_fn( policy_loss=cfg.policy_loss, kl_beta=cfg.kl_beta, - tis_enabled=cfg.tis_enabled, tis_config=cfg.tis, dapo_config=cfg.dapo, gspo_config=cfg.gspo, cispo_config=cfg.cispo, + is_config=cfg.is_correction, ) sample_kwargs: dict = dict( @@ -447,9 +448,15 @@ def train_step( logger.info("[step %d] ref_forward: done (%.1fs)", step + 1, _t.time() - t0) data, adv, ref_lp, prompt_lens, inf_lp = combine_prompt_groups(prompt_groups) + + t0 = _t.time() + prox_fwd = policy.forward(data, "cross_entropy") + prox_lp = [prox_fwd.loss_fn_outputs[i]["logprobs"].data for i in range(len(data))] + logger.info("[step %d] prox_forward: done (%.1fs)", step + 1, _t.time() - t0) + t0 = _t.time() fwd_bwd_result = policy.forward_backward_custom( - data, loss_builder(adv, ref_lp, prompt_lens, inf_lp), + data, loss_builder(adv, ref_lp, prompt_lens, inf_lp, prox_lp), ) logger.info("[step %d] fwd_bwd: done (%.1fs)", step + 1, _t.time() - t0) diff --git a/training/tests/e2e/test_grpo_e2e.py b/training/tests/e2e/test_grpo_e2e.py index d3491395..0badab30 100644 --- a/training/tests/e2e/test_grpo_e2e.py +++ b/training/tests/e2e/test_grpo_e2e.py @@ -67,8 +67,7 @@ def test_grpo_full_pipeline( max_rows=10, epochs=1, router_replay=True, - tis_enabled=True, - tis=ISConfig(clip_high=10.0), + is_correction=ISConfig(tis_cap=10.0), infra=InfraConfig( region=e2e_region, skip_validations=True, diff --git a/training/tests/e2e/test_grpo_resume_e2e.py b/training/tests/e2e/test_grpo_resume_e2e.py index 465026c5..0e14e93f 100644 --- a/training/tests/e2e/test_grpo_resume_e2e.py +++ b/training/tests/e2e/test_grpo_resume_e2e.py @@ -82,8 +82,7 @@ def test_grpo_resume_from_checkpoint( max_rows=8, epochs=1, router_replay=True, - tis_enabled=True, - tis=ISConfig(clip_high=10.0), + is_correction=ISConfig(tis_cap=10.0), infra=shared_infra, deployment=DeployConfig( deployment_id=deployment_id, @@ -121,8 +120,7 @@ def test_grpo_resume_from_checkpoint( max_rows=6, epochs=1, router_replay=True, - tis_enabled=True, - tis=ISConfig(clip_high=10.0), + is_correction=ISConfig(tis_cap=10.0), infra=shared_infra, deployment=DeployConfig( deployment_id=deployment_id, diff --git a/training/tests/test_defaults.py b/training/tests/test_defaults.py index e93dc88a..80270305 100644 --- a/training/tests/test_defaults.py +++ b/training/tests/test_defaults.py @@ -15,10 +15,12 @@ def test_grpo_max_completion_tokens(): assert Config().max_completion_tokens == 1024 -def test_tis_clip_high(): +def test_is_config_defaults(): from training.utils.rl.importance_sampling import ISConfig - assert ISConfig().clip_high == 2.0 + cfg = ISConfig() + assert cfg.eps_clip == 0.2 + assert cfg.tis_cap == 5.0 def test_cispo_config_defaults(): diff --git a/training/tests/unit/test_batched_losses.py b/training/tests/unit/test_batched_losses.py index 1581f12d..f67f86ab 100644 --- a/training/tests/unit/test_batched_losses.py +++ b/training/tests/unit/test_batched_losses.py @@ -1,9 +1,4 @@ -"""Tests for batched (multi-prompt) loss functions. - -Verifies that passing per-datum ``prompt_lens: List[int]`` produces the -same result as the single ``prompt_len: int`` path when all prompt lengths -are identical, and handles mixed prompt lengths correctly. -""" +"""Tests for loss function utilities and DPO batching.""" from __future__ import annotations @@ -11,12 +6,8 @@ import pytest from training.utils.losses import make_batch_dpo_loss_fn -from training.utils.rl.dapo import DAPOConfig, make_dapo_loss_fn -from training.utils.rl.grpo import make_grpo_loss_fn -from training.utils.rl.gspo import GSPOConfig, make_gspo_loss_fn from training.utils.rl.common import _normalize_prompt_lens from training.utils.rl.losses import build_loss_fn -from training.utils.rl.importance_sampling import ISConfig, make_tis_weights_fn def _make_dummy_logprobs(seq_len: int, seed: int = 0) -> torch.Tensor: @@ -24,6 +15,10 @@ def _make_dummy_logprobs(seq_len: int, seed: int = 0) -> torch.Tensor: return torch.randn(seq_len, requires_grad=True) +def _zeros(n: int) -> list[float]: + return [0.0] * n + + class TestNormalizePromptLens: def test_int_broadcasts(self): assert _normalize_prompt_lens(5, 3) == [5, 5, 5] @@ -39,168 +34,16 @@ def test_list_length_mismatch_raises(self): class TestLossBuilder: def test_rejects_unknown_policy_loss(self): builder = build_loss_fn(policy_loss="unknown", kl_beta=0.01) - with pytest.raises(ValueError, match="Unsupported policy_loss"): - builder([1.0], [[0.0] * 4], [2], [[0.0] * 4]) - - -class TestGRPOLossBatched: - def test_single_int_equals_list(self): - """Loss with prompt_len=10 should equal prompt_lens=[10, 10].""" - adv = [1.0, -0.5] - ref = [[0.1] * 20, [0.2] * 20] - inf = [[0.0] * 20, [0.0] * 20] - lp0 = _make_dummy_logprobs(20, seed=42) - lp1 = _make_dummy_logprobs(20, seed=43) - - fn_int = make_grpo_loss_fn(adv, ref, prompt_len=10, inf_logprobs=inf, kl_beta=0.01) - fn_list = make_grpo_loss_fn(adv, ref, prompt_len=[10, 10], inf_logprobs=inf, kl_beta=0.01) - - loss_int, met_int = fn_int([], [lp0.detach().requires_grad_(True), lp1.detach().requires_grad_(True)]) - loss_list, met_list = fn_list([], [lp0.detach().requires_grad_(True), lp1.detach().requires_grad_(True)]) - - assert torch.allclose(loss_int, loss_list, atol=1e-6) - assert abs(met_int["mean_kl"] - met_list["mean_kl"]) < 1e-6 - - def test_mixed_prompt_lens(self): - """Different prompt_lens per datum should use different response_start.""" - adv = [1.0, 1.0] - ref = [[0.0] * 10, [0.0] * 10] - inf = [[0.0] * 10] - lp = _make_dummy_logprobs(10, seed=0) - - fn_short = make_grpo_loss_fn(adv[:1], ref[:1], prompt_len=2, inf_logprobs=inf, kl_beta=0.0) - fn_long = make_grpo_loss_fn(adv[:1], ref[:1], prompt_len=8, inf_logprobs=inf, kl_beta=0.0) - - loss_short, _ = fn_short([], [lp.detach().requires_grad_(True)]) - loss_long, _ = fn_long([], [lp.detach().requires_grad_(True)]) - - assert loss_short.item() != pytest.approx(loss_long.item(), abs=1e-6), ( - "Different prompt_lens should produce different losses" - ) - - def test_gradient_equivalence(self): - """Gradients through batched call match sum of per-prompt calls.""" - adv_a = [1.0, -0.5] - adv_b = [0.3, 0.7] - ref_a = [[0.1] * 15, [0.2] * 15] - ref_b = [[0.3] * 15, [0.15] * 15] - inf_a = [[0.0] * 15, [0.0] * 15] - inf_b = [[0.0] * 15, [0.0] * 15] - prompt_len_a, prompt_len_b = 5, 8 - - lp = [_make_dummy_logprobs(15, seed=i) for i in range(4)] - - fn_a = make_grpo_loss_fn(adv_a, ref_a, prompt_len=prompt_len_a, inf_logprobs=inf_a, kl_beta=0.01) - fn_b = make_grpo_loss_fn(adv_b, ref_b, prompt_len=prompt_len_b, inf_logprobs=inf_b, kl_beta=0.01) - loss_a, _ = fn_a([], [lp[0], lp[1]]) - loss_b, _ = fn_b([], [lp[2], lp[3]]) - separate_loss = loss_a + loss_b - - combined_adv = adv_a + adv_b - combined_ref = ref_a + ref_b - combined_inf = inf_a + inf_b - combined_lens = [prompt_len_a] * 2 + [prompt_len_b] * 2 - fn_combined = make_grpo_loss_fn( - combined_adv, - combined_ref, - prompt_len=combined_lens, - inf_logprobs=combined_inf, - kl_beta=0.01, - ) - - lp_fresh = [_make_dummy_logprobs(15, seed=i) for i in range(4)] - combined_loss, _ = fn_combined([], lp_fresh) - - assert torch.allclose(separate_loss, combined_loss, atol=1e-5), ( - f"Combined loss {combined_loss.item()} != separate sum {separate_loss.item()}" - ) - - def test_reports_train_inference_metrics_when_inf_logprobs_present(self): - adv = [1.0] - ref = [[0.0] * 6] - inf_lp = [[0.0, -0.6, -0.3, -0.1, -0.2, -0.5]] - lp = torch.tensor([0.0, -0.5, -0.2, -0.1, -0.3, -0.4], requires_grad=True) - - fn = make_grpo_loss_fn( - adv, - ref, - prompt_len=2, - kl_beta=0.0, - inf_logprobs=inf_lp, - ) - _, metrics = fn([], [lp]) - - diff = lp.detach()[1:] - torch.tensor(inf_lp[0][1:], dtype=lp.dtype) - expected_diff = diff.abs().mean().item() - expected_kld = (torch.exp(diff) - 1.0 - diff).mean().item() - - assert metrics["inference_diff"] == pytest.approx(expected_diff) - assert metrics["inference_kld"] == pytest.approx(expected_kld) - - def test_dapo_reports_train_inference_metrics(self): - adv = [1.0] - ref = [[0.0] * 6] - inf_lp = [[0.0, -0.6, -0.3, -0.1, -0.2, -0.5]] - lp = torch.tensor([0.0, -0.5, -0.2, -0.1, -0.3, -0.4], requires_grad=True) - - fn = make_dapo_loss_fn( - advantages=adv, - ref_logprobs=ref, - inf_logprobs=inf_lp, - prompt_len=2, - dapo_config=DAPOConfig(), - ) - _, metrics = fn([], [lp]) - - assert "inference_diff" in metrics - assert "inference_kld" in metrics - - def test_gspo_reports_train_inference_metrics(self): - adv = [1.0] - ref = [[0.0] * 6] - inf_lp = [[0.0, -0.6, -0.3, -0.1, -0.2, -0.5]] - lp = torch.tensor([0.0, -0.5, -0.2, -0.1, -0.3, -0.4], requires_grad=True) - - fn = make_gspo_loss_fn( - advantages=adv, - ref_logprobs=ref, - inf_logprobs=inf_lp, - prompt_len=2, - gspo_config=GSPOConfig(), - ) - _, metrics = fn([], [lp]) - - assert "inference_diff" in metrics - assert "inference_kld" in metrics + builder([1.0], [_zeros(4)], [2], [_zeros(4)], [_zeros(4)]) class TestBatchDPOLoss: - """Tests for ``make_batch_dpo_loss_fn``.""" def _make_ref_logprobs(self, seq_len: int, seed: int = 0) -> list[float]: torch.manual_seed(seed) return torch.randn(seq_len).tolist() - def test_single_pair_produces_valid_loss(self): - """Batched loss with 1 pair should produce a valid scalar loss.""" - ref_c = self._make_ref_logprobs(10, seed=0) - ref_r = self._make_ref_logprobs(10, seed=1) - rs = 3 - beta = 0.1 - - lp_c = _make_dummy_logprobs(10, seed=10) - lp_r = _make_dummy_logprobs(10, seed=11) - - fn = make_batch_dpo_loss_fn([ref_c], [ref_r], [rs], beta) - loss, met = fn([], [lp_c.clone(), lp_r.clone()]) - - assert loss.dim() == 0 - assert met["batch_pairs"] == 1 - assert "dpo_loss" in met - assert "margin" in met - assert met["accuracy"] in (0.0, 1.0) - def test_multi_pair_averages_correctly(self): """Batched loss with 2 pairs == average of two single-pair calls.""" ref_c0 = self._make_ref_logprobs(8, seed=0) @@ -228,39 +71,11 @@ def test_multi_pair_averages_correctly(self): [], [lp_c0.clone(), lp_r0.clone(), lp_c1.clone(), lp_r1.clone()], ) - assert torch.allclose(expected_avg, loss_b, atol=1e-5), ( - f"Batched {loss_b.item()} != average {expected_avg.item()}" - ) + assert torch.allclose(expected_avg, loss_b, atol=1e-5) assert met_b["batch_pairs"] == 2 def test_wrong_logprobs_count_raises(self): - """Passing wrong number of logprobs should raise.""" fn = make_batch_dpo_loss_fn([[0.0]], [[0.0]], [0], 0.1) lp = _make_dummy_logprobs(1, seed=0) with pytest.raises(AssertionError, match="Expected 2 logprobs"): fn([], [lp]) - - def test_gradient_flows(self): - """Gradients should propagate through the batched loss.""" - ref_c = self._make_ref_logprobs(6, seed=0) - ref_r = self._make_ref_logprobs(6, seed=1) - fn = make_batch_dpo_loss_fn([ref_c], [ref_r], [2], 0.1) - - lp_c = torch.randn(6, requires_grad=True) - lp_r = torch.randn(6, requires_grad=True) - loss, _ = fn([], [lp_c, lp_r]) - loss.backward() - - assert lp_c.grad is not None - assert lp_r.grad is not None - - -class TestTISWeights: - def test_rejects_short_inference_logprobs(self): - weights_fn = make_tis_weights_fn( - inf_logprobs=[[0.0, -0.1]], - prompt_len=2, - tis_config=ISConfig(), - ) - with pytest.raises(ValueError, match="requires at least"): - weights_fn(torch.tensor([-0.5, -0.3, -0.2]), 0) diff --git a/training/tests/unit/test_cispo_loss.py b/training/tests/unit/test_cispo_loss.py index 576c2250..82701ad9 100644 --- a/training/tests/unit/test_cispo_loss.py +++ b/training/tests/unit/test_cispo_loss.py @@ -59,6 +59,7 @@ def _build_and_call( ref_logprobs=[ref_lp], inf_logprobs=[inf_lp], prompt_len=prompt_len, + prox_logprobs=[inf_lp], cispo_config=cfg, ) loss, metrics = fn([], [pi_logprobs]) @@ -104,8 +105,9 @@ def test_int_vs_list_prompt_len(self): lp0 = _make_logprobs(20, seed=42) lp1 = _make_logprobs(20, seed=43) - fn_int = make_cispo_loss_fn(adv, ref, inf, prompt_len=10) - fn_list = make_cispo_loss_fn(adv, ref, inf, prompt_len=[10, 10]) + prox = [[0.05] * 20, [0.15] * 20] + fn_int = make_cispo_loss_fn(adv, ref, inf, prompt_len=10, prox_logprobs=prox) + fn_list = make_cispo_loss_fn(adv, ref, inf, prompt_len=[10, 10], prox_logprobs=prox) loss_int, met_int = fn_int([], [lp0.detach().requires_grad_(True), lp1.detach().requires_grad_(True)]) loss_list, met_list = fn_list([], [lp0.detach().requires_grad_(True), lp1.detach().requires_grad_(True)]) @@ -120,8 +122,9 @@ def test_mixed_prompt_lens(self): inf = [[0.0] * 10] lp = _make_logprobs(10, seed=0) - fn_short = make_cispo_loss_fn(adv, ref, inf, prompt_len=2) - fn_long = make_cispo_loss_fn(adv, ref, inf, prompt_len=8) + prox = [[0.0] * 10] + fn_short = make_cispo_loss_fn(adv, ref, inf, prompt_len=2, prox_logprobs=prox) + fn_long = make_cispo_loss_fn(adv, ref, inf, prompt_len=8, prox_logprobs=prox) loss_short, _ = fn_short([], [lp.detach().requires_grad_(True)]) loss_long, _ = fn_long([], [lp.detach().requires_grad_(True)]) @@ -136,7 +139,8 @@ def test_gradient_flows(self): lp = _make_logprobs(5, seed=0) lp_grad = lp.detach().requires_grad_(True) - fn = make_cispo_loss_fn(adv, ref, inf, prompt_len=1) + prox = [[0.0] * 5] + fn = make_cispo_loss_fn(adv, ref, inf, prompt_len=1, prox_logprobs=prox) loss, _ = fn([], [lp_grad]) loss.backward() @@ -149,9 +153,10 @@ def test_reports_mean_kl(self): adv = [1.0] ref = [[0.0] * 5] inf = [[0.0] * 5] + prox = [[0.0] * 5] lp = _make_logprobs(5, seed=0).detach().requires_grad_(True) - fn = make_cispo_loss_fn(adv, ref, inf, prompt_len=1) + fn = make_cispo_loss_fn(adv, ref, inf, prompt_len=1, prox_logprobs=prox) _, metrics = fn([], [lp]) assert "mean_kl" in metrics @@ -162,9 +167,10 @@ def test_reports_inference_metrics(self): adv = [1.0] ref = [[0.0] * 5] inf = [[0.0] * 5] + prox = [[0.0] * 5] lp = _make_logprobs(5, seed=0).detach().requires_grad_(True) - fn = make_cispo_loss_fn(adv, ref, inf, prompt_len=1) + fn = make_cispo_loss_fn(adv, ref, inf, prompt_len=1, prox_logprobs=prox) _, metrics = fn([], [lp]) assert "inference_diff" in metrics @@ -175,9 +181,10 @@ def test_requires_inf_logprobs(self): adv = [1.0] ref = [[0.0] * 5] inf = [[]] + prox = [[0.0] * 5] lp = _make_logprobs(5, seed=0).detach().requires_grad_(True) - fn = make_cispo_loss_fn(adv, ref, inf, prompt_len=1) + fn = make_cispo_loss_fn(adv, ref, inf, prompt_len=1, prox_logprobs=prox) with pytest.raises(ValueError, match="CISPO requires inference logprobs"): fn([], [lp]) @@ -186,8 +193,9 @@ def test_requires_sufficient_inf_logprobs(self): adv = [1.0] ref = [[0.0] * 5] inf = [[0.0] * 2] + prox = [[0.0] * 5] lp = _make_logprobs(5, seed=0).detach().requires_grad_(True) - fn = make_cispo_loss_fn(adv, ref, inf, prompt_len=1) + fn = make_cispo_loss_fn(adv, ref, inf, prompt_len=1, prox_logprobs=prox) with pytest.raises(ValueError, match="CISPO requires at least"): fn([], [lp]) diff --git a/training/tests/unit/test_decoupled_is.py b/training/tests/unit/test_decoupled_is.py new file mode 100644 index 00000000..ff61ca3b --- /dev/null +++ b/training/tests/unit/test_decoupled_is.py @@ -0,0 +1,57 @@ +"""Tests for TIS (train-inference IS) weight computation.""" + +from __future__ import annotations + +import torch + +from training.utils.rl.importance_sampling import ( + ISConfig, + compute_tis_weight, +) + + +class TestComputeTISWeight: + def test_same_logprobs_gives_weight_one(self): + prox = torch.tensor([-0.5, -0.3, -0.8]) + inf = torch.tensor([-0.5, -0.3, -0.8]) + + weight, metrics = compute_tis_weight(prox, inf, ISConfig()) + + torch.testing.assert_close(weight, torch.ones(3)) + assert metrics["tis/clip_frac"] == 0.0 + + def test_clamps_at_cap(self): + prox = torch.tensor([-0.1, -0.1]) + inf = torch.tensor([-3.0, -3.0]) + + weight, _ = compute_tis_weight(prox, inf, ISConfig(tis_cap=2.0)) + + assert weight.max().item() <= 2.0 + 1e-6 + + +class TestSequenceLevelTIS: + def test_all_tokens_get_same_weight(self): + prox = torch.tensor([-0.5, -0.3, -0.8]) + inf = torch.tensor([-0.6, -0.4, -0.9]) + + weight, metrics = compute_tis_weight(prox, inf, ISConfig(tis_level="sequence")) + + assert weight[0] == weight[1] == weight[2] + assert "tis/seq_ratio" in metrics + + def test_geometric_mean_formula(self): + prox = torch.tensor([-0.2, -0.4]) + inf = torch.tensor([-0.5, -0.3]) + + weight, _ = compute_tis_weight(prox, inf, ISConfig(tis_level="sequence", tis_cap=100.0)) + + expected = torch.exp((prox - inf).mean()) + torch.testing.assert_close(weight[0], expected) + + def test_token_level_gives_different_weights(self): + prox = torch.tensor([-0.2, -0.8]) + inf = torch.tensor([-0.5, -0.3]) + + weight, _ = compute_tis_weight(prox, inf, ISConfig(tis_cap=100.0)) + + assert weight[0] != weight[1] diff --git a/training/utils/rl/__init__.py b/training/utils/rl/__init__.py index 8e0a514c..1719ca18 100644 --- a/training/utils/rl/__init__.py +++ b/training/utils/rl/__init__.py @@ -4,8 +4,8 @@ # Losses & algorithms "CISPOConfig", "DAPOConfig", - "GSPOConfig", "ISConfig", + "GSPOConfig", "PPBatchRecommendation", "PromptGroup", "build_r3_routing_matrices", @@ -14,7 +14,6 @@ "make_dapo_loss_fn", "make_grpo_loss_fn", "make_gspo_loss_fn", - "make_tis_weights_fn", # Training loop "DynamicFilterFn", "TrainStepFns", @@ -44,4 +43,4 @@ add_response_length_stats, ) from training.utils.rl.router_replay import build_r3_routing_matrices -from training.utils.rl.importance_sampling import ISConfig, make_tis_weights_fn +from training.utils.rl.importance_sampling import ISConfig diff --git a/training/utils/rl/cispo.py b/training/utils/rl/cispo.py index 23524cd6..bcd4dc40 100644 --- a/training/utils/rl/cispo.py +++ b/training/utils/rl/cispo.py @@ -1,38 +1,33 @@ -"""CISPO (Clipped Importance Sampling Policy Optimization) loss for GRPO training. +"""CISPO (Clipped Importance Sampling Policy Optimization) loss. -Clips importance sampling weights (the ratio pi/pi_old) rather than the -PPO surrogate objective. Tokens where the ratio has moved too far in a -destabilizing direction are masked out entirely, while all remaining -tokens contribute full gradients. This "always use all tokens" property -yields better sample efficiency than PPO/DAPO clipping empirically. +Clips importance sampling weights (the ratio pi/pi_prox) rather than +the PPO surrogate objective. Tokens where the ratio has moved too far +in a destabilizing direction are masked out entirely, while all +remaining tokens contribute full gradients. Behavioral IS weight +corrects for the train-inference gap. The masking rule (Eq. 7 in the MiniMax-M1 paper): M_{i,t} = 0 if A > 0 and r > 1 + eps_high (already-boosted token) M_{i,t} = 0 if A < 0 and r < 1 - eps_low (already-suppressed token) M_{i,t} = 1 otherwise -Per-token loss: M_{i,t} * (-r_{i,t} * A_i) - -TIS can be composed on top via ``tis_weights_fn`` for additional -train-inference mismatch correction. - Reference: https://arxiv.org/abs/2506.13585 (Section 3.1) - -Example:: - - Config(policy_loss="cispo", cispo=CISPOConfig(eps_low=0.2, eps_high=0.28)) - Config(policy_loss="cispo", tis_enabled=True) # CISPO + TIS """ from __future__ import annotations -from typing import Dict, List, Tuple, Union, Callable +from typing import Dict, List, Tuple, Union from dataclasses import dataclass import torch import tinker from training.utils.rl.common import _normalize_prompt_lens +from training.utils.rl.importance_sampling import ( + SAFETY_CLAMP, + ISConfig, + compute_tis_weight, +) @dataclass @@ -58,23 +53,15 @@ def make_cispo_loss_fn( ref_logprobs: List[List[float]], inf_logprobs: List[List[float]], prompt_len: Union[int, List[int]], + prox_logprobs: List[List[float]], cispo_config: CISPOConfig | None = None, - tis_weights_fn: Callable | None = None, -) -> Callable[[List[tinker.Datum], List[torch.Tensor]], Tuple[torch.Tensor, Dict[str, float]]]: - """Build a CISPO loss closure. - - Computes importance-sampled policy gradient with IS-weight clipping: - the ratio ``pi/pi_old`` is used directly (not clipped), but tokens - where the ratio violates the CISPO mask are zeroed out entirely. - - ``prompt_len`` may be a single int or a per-datum list for multi-prompt - batched calls. - - *inf_logprobs* is always required (rollout/old-policy logprobs for ratio). - Pass *tis_weights_fn* to apply additional TIS correction on top. - """ + is_config: ISConfig | None = None, +) -> ...: + """Build a CISPO loss closure with IS-weight masking and behavioral IS weight.""" if cispo_config is None: cispo_config = CISPOConfig() + if is_config is None: + is_config = ISConfig() prompt_lens = _normalize_prompt_lens(prompt_len, len(advantages)) def loss_fn( @@ -83,19 +70,20 @@ def loss_fn( ) -> Tuple[torch.Tensor, Dict[str, float]]: total_loss = torch.tensor(0.0, requires_grad=True) total_kl = 0.0 - total_rho = 0.0 total_inf_diff = 0.0 total_inf_kld = 0.0 inf_num_samples = 0 num_tokens = 0 mask_frac_sum = 0.0 + ppo_ratio_mean_sum = 0.0 mask_frac_count = 0 - agg_tis: Dict[str, float] = {} + tis_metrics_agg: Dict[str, float] = {} for i, pi_logprobs in enumerate(logprobs_list): adv = advantages[i] ref_lp = ref_logprobs[i] inf_lp = inf_logprobs[i] + prox_lp = prox_logprobs[i] response_start = max(0, prompt_lens[i] - 1) resp_pi = pi_logprobs[response_start:] @@ -105,10 +93,8 @@ def loss_fn( resp_ref = torch.tensor( [ref_lp[response_start + j] if (response_start + j) < len(ref_lp) else 0.0 for j in range(resp_len)], - dtype=resp_pi.dtype, - device=resp_pi.device, + dtype=resp_pi.dtype, device=resp_pi.device, ) - pi_detached = resp_pi.detach() if not inf_lp: @@ -123,25 +109,22 @@ def loss_fn( ) resp_inf = torch.tensor( - inf_lp[response_start : response_start + resp_len], - dtype=resp_pi.dtype, - device=resp_pi.device, + inf_lp[response_start:response_start + resp_len], + dtype=resp_pi.dtype, device=resp_pi.device, + ) + resp_prox = torch.tensor( + prox_lp[response_start:response_start + resp_len], + dtype=resp_pi.dtype, device=resp_pi.device, ) + inf_log_diff = pi_detached - resp_inf total_inf_diff += inf_log_diff.abs().mean().item() total_inf_kld += (torch.exp(inf_log_diff) - inf_log_diff - 1.0).mean().item() inf_num_samples += 1 - # Importance ratio: r = pi / pi_old = exp(log_pi - log_pi_old) - log_ratio = torch.clamp( - resp_pi - resp_inf, - min=-cispo_config.ratio_log_cap, - max=cispo_config.ratio_log_cap, - ) + log_ratio = torch.clamp(resp_pi - resp_prox, min=-cispo_config.ratio_log_cap, max=cispo_config.ratio_log_cap) ratio = torch.exp(log_ratio) - # CISPO mask (Eq. 7): zero out tokens that have already moved - # far enough in the direction the advantage pushes. ratio_detached = ratio.detach() if adv > 0: mask = (ratio_detached <= 1.0 + cispo_config.eps_high).float() @@ -149,36 +132,32 @@ def loss_fn( mask = (ratio_detached >= 1.0 - cispo_config.eps_low).float() else: mask = torch.ones_like(ratio_detached) - mask_frac_sum += 1.0 - mask.mean().item() + ppo_ratio_mean_sum += ratio.detach().mean().item() mask_frac_count += 1 - adv_t = torch.as_tensor(adv, dtype=resp_pi.dtype, device=resp_pi.device) - per_token_loss = mask * (-ratio * adv_t) + tis_weight, bm = compute_tis_weight(resp_prox, resp_inf, is_config) + for k, v in bm.items(): + tis_metrics_agg[k] = tis_metrics_agg.get(k, 0.0) + v - if tis_weights_fn: - weights, tis_metrics = tis_weights_fn(pi_detached, i) - per_token_loss = per_token_loss * weights - total_rho += weights.sum().item() - for k, v in tis_metrics.items(): - agg_tis[k] = agg_tis.get(k, 0.0) + v + adv_t = torch.as_tensor(adv, dtype=resp_pi.dtype, device=resp_pi.device) + per_token_loss = mask * (-ratio * adv_t) * tis_weight total_loss = total_loss + per_token_loss.sum() total_kl += (pi_detached - resp_ref).sum().item() num_tokens += resp_len + n_samples = max(len(logprobs_list), 1) metrics: Dict[str, float] = { "mean_kl": total_kl / num_tokens if num_tokens > 0 else 0.0, "cispo_mask_frac": mask_frac_sum / mask_frac_count if mask_frac_count > 0 else 0.0, + "ppo_ratio_mean": ppo_ratio_mean_sum / n_samples, } if inf_num_samples > 0: metrics["inference_diff"] = total_inf_diff / inf_num_samples metrics["inference_kld"] = total_inf_kld / inf_num_samples - if tis_weights_fn: - metrics["mean_importance_ratio"] = total_rho / num_tokens if num_tokens > 0 else 1.0 - n_samples = len(logprobs_list) or 1 - for k, v in agg_tis.items(): - metrics[k] = v / n_samples + for k, v in tis_metrics_agg.items(): + metrics[k] = v / n_samples return total_loss, metrics return loss_fn diff --git a/training/utils/rl/dapo.py b/training/utils/rl/dapo.py index a7eb5fb8..91a912c9 100644 --- a/training/utils/rl/dapo.py +++ b/training/utils/rl/dapo.py @@ -1,29 +1,26 @@ """DAPO (Dynamic Advantage Policy Optimization) loss for GRPO training. -Uses PPO-style clipped surrogate objective with asymmetric clipping bounds: -the lower bound (eps_clip) and upper bound (eps_clip_high) can differ. -No explicit KL penalty -- divergence is controlled solely via clipping. - -TIS can be composed on top via ``tis_weights_fn`` for additional -train-inference mismatch correction. +Uses PPO-style clipped surrogate objective with asymmetric clipping bounds +and behavioral IS weight correction. The PPO ratio is computed against +pre-computed proximal logprobs. Reference: https://arxiv.org/abs/2503.14476 - -Example:: - - Config(policy_loss="dapo", dapo=DAPOConfig(eps_clip=0.2, eps_clip_high=0.28)) - Config(policy_loss="dapo", tis_enabled=True) # DAPO + TIS """ from __future__ import annotations -from typing import Dict, List, Tuple, Union, Callable +from typing import Dict, List, Tuple, Union from dataclasses import dataclass import torch import tinker from training.utils.rl.common import _normalize_prompt_lens +from training.utils.rl.importance_sampling import ( + SAFETY_CLAMP, + ISConfig, + compute_tis_weight, +) @dataclass @@ -48,23 +45,15 @@ def make_dapo_loss_fn( ref_logprobs: List[List[float]], inf_logprobs: List[List[float]], prompt_len: Union[int, List[int]], + prox_logprobs: List[List[float]], dapo_config: DAPOConfig | None = None, - tis_weights_fn: Callable | None = None, -) -> Callable[[List[tinker.Datum], List[torch.Tensor]], Tuple[torch.Tensor, Dict[str, float]]]: - """Build a DAPO loss closure. - - Computes the PPO clipped surrogate objective with asymmetric bounds. - The importance ratio ``pi/pi_old`` is clipped to - ``[1 - eps_clip, 1 + eps_clip_high]``. - - ``prompt_len`` may be a single int or a per-datum list for multi-prompt - batched calls. - - *inf_logprobs* is always required (used for the PPO ratio). - Pass *tis_weights_fn* to apply additional TIS correction on top. - """ + is_config: ISConfig | None = None, +) -> ...: + """Build a DAPO loss closure with PPO-clipped ratio and behavioral IS weight.""" if dapo_config is None: dapo_config = DAPOConfig() + if is_config is None: + is_config = ISConfig() prompt_lens = _normalize_prompt_lens(prompt_len, len(advantages)) def loss_fn( @@ -73,19 +62,19 @@ def loss_fn( ) -> Tuple[torch.Tensor, Dict[str, float]]: total_loss = torch.tensor(0.0, requires_grad=True) total_kl = 0.0 - total_rho = 0.0 total_inf_diff = 0.0 total_inf_kld = 0.0 inf_num_samples = 0 num_tokens = 0 clip_frac_sum = 0.0 - clip_frac_count = 0 - agg_tis: Dict[str, float] = {} + ppo_ratio_mean_sum = 0.0 + tis_metrics_agg: Dict[str, float] = {} for i, pi_logprobs in enumerate(logprobs_list): adv = advantages[i] ref_lp = ref_logprobs[i] inf_lp = inf_logprobs[i] + prox_lp = prox_logprobs[i] response_start = max(0, prompt_lens[i] - 1) resp_pi = pi_logprobs[response_start:] @@ -95,47 +84,33 @@ def loss_fn( resp_ref = torch.tensor( [ref_lp[response_start + j] if (response_start + j) < len(ref_lp) else 0.0 for j in range(resp_len)], - dtype=resp_pi.dtype, - device=resp_pi.device, + dtype=resp_pi.dtype, device=resp_pi.device, ) - pi_detached = resp_pi.detach() - if not inf_lp: - raise ValueError( - f"DAPO requires inference logprobs for sample {i} but got empty list. " - f"Ensure logprobs=True is set when using policy_loss='dapo'." - ) - if len(inf_lp) < response_start + resp_len: - raise ValueError( - f"DAPO requires at least {response_start + resp_len} inference logprobs " - f"for sample {i}, got {len(inf_lp)}." - ) - resp_inf = torch.tensor( - inf_lp[response_start : response_start + resp_len], - dtype=resp_pi.dtype, - device=resp_pi.device, + inf_lp[response_start:response_start + resp_len], + dtype=resp_pi.dtype, device=resp_pi.device, + ) + resp_prox = torch.tensor( + prox_lp[response_start:response_start + resp_len], + dtype=resp_pi.dtype, device=resp_pi.device, ) + inf_log_diff = pi_detached - resp_inf total_inf_diff += inf_log_diff.abs().mean().item() total_inf_kld += (torch.exp(inf_log_diff) - inf_log_diff - 1.0).mean().item() inf_num_samples += 1 - # PPO clipped surrogate - log_ratio = torch.clamp( - resp_pi - resp_inf, - min=-dapo_config.ratio_log_cap, - max=dapo_config.ratio_log_cap, - ) + log_ratio = torch.clamp(resp_pi - resp_prox, min=-dapo_config.ratio_log_cap, max=dapo_config.ratio_log_cap) ratio = torch.exp(log_ratio) - clipped_ratio = torch.clamp( - ratio, - min=1.0 - dapo_config.eps_clip, - max=1.0 + dapo_config.eps_clip_high, - ) + clipped_ratio = torch.clamp(ratio, min=1.0 - dapo_config.eps_clip, max=1.0 + dapo_config.eps_clip_high) clip_frac_sum += (clipped_ratio != ratio).float().mean().item() - clip_frac_count += 1 + ppo_ratio_mean_sum += ratio.detach().mean().item() + + tis_weight, bm = compute_tis_weight(resp_prox, resp_inf, is_config) + for k, v in bm.items(): + tis_metrics_agg[k] = tis_metrics_agg.get(k, 0.0) + v adv_t = torch.as_tensor(adv, dtype=resp_pi.dtype, device=resp_pi.device) surr1 = -ratio * adv_t @@ -143,38 +118,29 @@ def loss_fn( clipped_surrogate = torch.maximum(surr1, surr2) if dapo_config.eps_clip_c is not None: if dapo_config.eps_clip_c <= 1.0: - raise ValueError( - f"DAPO dual-clip bound eps_clip_c must be > 1.0, got {dapo_config.eps_clip_c}." - ) + raise ValueError(f"DAPO dual-clip eps_clip_c must be > 1.0, got {dapo_config.eps_clip_c}.") surr3 = -dapo_config.eps_clip_c * adv_t lower_clipped = torch.minimum(surr3, clipped_surrogate) per_token_loss = torch.where(adv_t < 0, lower_clipped, clipped_surrogate) else: per_token_loss = clipped_surrogate - - if tis_weights_fn: - weights, tis_metrics = tis_weights_fn(pi_detached, i) - per_token_loss = per_token_loss * weights - total_rho += weights.sum().item() - for k, v in tis_metrics.items(): - agg_tis[k] = agg_tis.get(k, 0.0) + v + per_token_loss = per_token_loss * tis_weight total_loss = total_loss + per_token_loss.sum() total_kl += (pi_detached - resp_ref).sum().item() num_tokens += resp_len + n_samples = max(len(logprobs_list), 1) metrics: Dict[str, float] = { "mean_kl": total_kl / num_tokens if num_tokens > 0 else 0.0, - "dapo_clip_frac": clip_frac_sum / clip_frac_count if clip_frac_count > 0 else 0.0, + "dapo_clip_frac": clip_frac_sum / n_samples, + "ppo_ratio_mean": ppo_ratio_mean_sum / n_samples, } if inf_num_samples > 0: metrics["inference_diff"] = total_inf_diff / inf_num_samples metrics["inference_kld"] = total_inf_kld / inf_num_samples - if tis_weights_fn: - metrics["mean_importance_ratio"] = total_rho / num_tokens if num_tokens > 0 else 1.0 - n_samples = len(logprobs_list) or 1 - for k, v in agg_tis.items(): - metrics[k] = v / n_samples + for k, v in tis_metrics_agg.items(): + metrics[k] = v / n_samples return total_loss, metrics return loss_fn diff --git a/training/utils/rl/grpo.py b/training/utils/rl/grpo.py index bd06ca32..065b2e51 100644 --- a/training/utils/rl/grpo.py +++ b/training/utils/rl/grpo.py @@ -1,13 +1,24 @@ -"""GRPO (Group Relative Policy Optimization) loss for RL training.""" +"""GRPO (Group Relative Policy Optimization) loss for RL training. + +Uses PPO-style clipped surrogate objective with behavioral IS weight +correction. The PPO ratio is computed against pre-computed proximal +logprobs (from a forward pass before training), and the behavioral +weight corrects for the train-inference gap. +""" from __future__ import annotations -from typing import Dict, List, Tuple, Union, Callable +from typing import Dict, List, Tuple, Union import torch import tinker from training.utils.rl.common import _normalize_prompt_lens +from training.utils.rl.importance_sampling import ( + SAFETY_CLAMP, + ISConfig, + compute_tis_weight, +) def make_grpo_loss_fn( @@ -15,19 +26,21 @@ def make_grpo_loss_fn( ref_logprobs: List[List[float]], prompt_len: Union[int, List[int]], inf_logprobs: List[List[float]], + prox_logprobs: List[List[float]], kl_beta: float = 0.001, - tis_weights_fn: Callable | None = None, -) -> Callable[[List[tinker.Datum], List[torch.Tensor]], Tuple[torch.Tensor, Dict[str, float]]]: - """GRPO policy-gradient loss with KL penalty against a reference model. - - ``prompt_len`` may be a single int (all datums share the same prompt - length) or a per-datum list for multi-prompt batched calls. - ``inf_logprobs`` is required to compute train/inference divergence metrics. - - Pass *tis_weights_fn* (from :func:`make_tis_weights_fn`) to apply - TIS train-inference mismatch correction on top of the base loss. + is_config: ISConfig | None = None, +) -> ...: + """GRPO loss with PPO-clipped ratio and behavioral IS weight. + + ``prox_logprobs`` are pre-computed by a forward pass before training. + The PPO ratio ``exp(pi_theta - prox)`` is clipped by + ``ISConfig.eps_clip``. The behavioral weight + ``exp(prox - inf)`` corrects for train-inference mismatch. """ + if is_config is None: + is_config = ISConfig() prompt_lens = _normalize_prompt_lens(prompt_len, len(advantages)) + eps_high = is_config.eps_clip if is_config.eps_clip_high is None else is_config.eps_clip_high def loss_fn( data: List[tinker.Datum], @@ -35,17 +48,19 @@ def loss_fn( ) -> Tuple[torch.Tensor, Dict[str, float]]: total_loss = torch.tensor(0.0, requires_grad=True) total_kl = 0.0 - total_rho = 0.0 total_inf_diff = 0.0 total_inf_kld = 0.0 inf_num_samples = 0 num_tokens = 0 - agg_tis: Dict[str, float] = {} + clip_frac_sum = 0.0 + ppo_ratio_mean_sum = 0.0 + tis_metrics_agg: Dict[str, float] = {} for i, pi_logprobs in enumerate(logprobs_list): adv = advantages[i] ref_lp = ref_logprobs[i] inf_lp = inf_logprobs[i] + prox_lp = prox_logprobs[i] response_start = max(0, prompt_lens[i] - 1) resp_pi = pi_logprobs[response_start:] @@ -55,55 +70,56 @@ def loss_fn( resp_ref = torch.tensor( [ref_lp[response_start + j] if (response_start + j) < len(ref_lp) else 0.0 for j in range(resp_len)], - dtype=resp_pi.dtype, - device=resp_pi.device, + dtype=resp_pi.dtype, device=resp_pi.device, ) pi_detached = resp_pi.detach() - per_token_loss = (-adv + kl_beta) * resp_pi - - if tis_weights_fn: - weights, tis_metrics = tis_weights_fn(pi_detached, i) - per_token_loss = per_token_loss * weights - total_rho += weights.sum().item() - for k, v in tis_metrics.items(): - agg_tis[k] = agg_tis.get(k, 0.0) + v - - if not inf_lp: - raise ValueError( - f"GRPO requires inference logprobs for sample {i} but got empty list. " - f"Ensure logprobs=True is set." - ) - if len(inf_lp) < response_start + resp_len: - raise ValueError( - f"GRPO requires at least {response_start + resp_len} inference logprobs " - f"for sample {i}, got {len(inf_lp)}." - ) resp_inf = torch.tensor( - inf_lp[response_start : response_start + resp_len], - dtype=resp_pi.dtype, - device=resp_pi.device, + inf_lp[response_start:response_start + resp_len], + dtype=resp_pi.dtype, device=resp_pi.device, + ) + resp_prox = torch.tensor( + prox_lp[response_start:response_start + resp_len], + dtype=resp_pi.dtype, device=resp_pi.device, ) + inf_log_diff = pi_detached - resp_inf total_inf_diff += inf_log_diff.abs().mean().item() total_inf_kld += (torch.exp(inf_log_diff) - inf_log_diff - 1.0).mean().item() inf_num_samples += 1 + log_ratio = torch.clamp(resp_pi - resp_prox, min=-SAFETY_CLAMP, max=SAFETY_CLAMP) + ratio = torch.exp(log_ratio) + clipped_ratio = torch.clamp(ratio, min=1.0 - is_config.eps_clip, max=1.0 + eps_high) + clip_frac_sum += (clipped_ratio != ratio).float().mean().item() + ppo_ratio_mean_sum += ratio.detach().mean().item() + + tis_weight, bm = compute_tis_weight(resp_prox, resp_inf, is_config) + for k, v in bm.items(): + tis_metrics_agg[k] = tis_metrics_agg.get(k, 0.0) + v + + adv_t = torch.as_tensor(adv, dtype=resp_pi.dtype, device=resp_pi.device) + surr1 = -ratio * adv_t + surr2 = -clipped_ratio * adv_t + per_token_loss = torch.maximum(surr1, surr2) * tis_weight + kl_penalty = kl_beta * (pi_detached - resp_ref) + per_token_loss = per_token_loss + kl_penalty + total_loss = total_loss + per_token_loss.sum() total_kl += (pi_detached - resp_ref).sum().item() num_tokens += resp_len + n_samples = max(len(logprobs_list), 1) metrics: Dict[str, float] = { "mean_kl": total_kl / num_tokens if num_tokens > 0 else 0.0, + "ppo_clip_frac": clip_frac_sum / n_samples, + "ppo_ratio_mean": ppo_ratio_mean_sum / n_samples, } if inf_num_samples > 0: metrics["inference_diff"] = total_inf_diff / inf_num_samples metrics["inference_kld"] = total_inf_kld / inf_num_samples - if tis_weights_fn: - metrics["mean_importance_ratio"] = total_rho / num_tokens if num_tokens > 0 else 1.0 - n_samples = len(logprobs_list) or 1 - for k, v in agg_tis.items(): - metrics[k] = v / n_samples + for k, v in tis_metrics_agg.items(): + metrics[k] = v / n_samples return total_loss, metrics return loss_fn diff --git a/training/utils/rl/gspo.py b/training/utils/rl/gspo.py index ddaa8c64..51118fa1 100644 --- a/training/utils/rl/gspo.py +++ b/training/utils/rl/gspo.py @@ -1,26 +1,27 @@ """GSPO (Group Sequence Policy Optimization) loss for GRPO training. Implements PPO-style clipping with a **sequence-level importance ratio** -(geometric mean of per-token ratios), then broadcasts that ratio to tokens. -This matches common GSPO implementations in open-source RL training stacks. - -TIS can be composed on top via ``tis_weights_fn``. +(geometric mean of per-token ratios) against pre-computed proximal +logprobs, with behavioral IS weight correction. Example:: Config(policy_loss="gspo", gspo=GSPOConfig(clip_ratio=0.2)) - Config(policy_loss="gspo", tis_enabled=True) # GSPO + TIS """ from __future__ import annotations -from typing import Dict, List, Tuple, Union, Callable +from typing import Dict, List, Tuple, Union from dataclasses import dataclass import torch import tinker from training.utils.rl.common import _normalize_prompt_lens +from training.utils.rl.importance_sampling import ( + ISConfig, + compute_tis_weight, +) @dataclass @@ -43,24 +44,15 @@ def make_gspo_loss_fn( ref_logprobs: List[List[float]], inf_logprobs: List[List[float]], prompt_len: Union[int, List[int]], + prox_logprobs: List[List[float]], gspo_config: GSPOConfig | None = None, - tis_weights_fn: Callable | None = None, -) -> Callable[[List[tinker.Datum], List[torch.Tensor]], Tuple[torch.Tensor, Dict[str, float]]]: - """Build a GSPO loss closure. - - Uses sequence-level importance ratio: - ``r_seq = exp(mean_t(log pi_t - log pi_old_t))`` - followed by PPO clipping on that broadcasted ratio. - - ``prompt_len`` may be a single int or a per-datum list for multi-prompt - batched calls. - - ``inf_logprobs`` is required (rollout/old-policy logprobs for ratio). - - Pass *tis_weights_fn* to apply TIS correction on top. - """ + is_config: ISConfig | None = None, +) -> ...: + """Build a GSPO loss closure with sequence-level PPO ratio and behavioral IS weight.""" if gspo_config is None: gspo_config = GSPOConfig() + if is_config is None: + is_config = ISConfig() clip_low = gspo_config.clip_ratio if gspo_config.clip_ratio_low is None else gspo_config.clip_ratio_low clip_high = gspo_config.clip_ratio if gspo_config.clip_ratio_high is None else gspo_config.clip_ratio_high prompt_lens = _normalize_prompt_lens(prompt_len, len(advantages)) @@ -71,19 +63,19 @@ def loss_fn( ) -> Tuple[torch.Tensor, Dict[str, float]]: total_loss = torch.tensor(0.0, requires_grad=True) total_kl = 0.0 - total_rho = 0.0 total_inf_diff = 0.0 total_inf_kld = 0.0 inf_num_samples = 0 num_tokens = 0 clip_frac_sum = 0.0 - clip_frac_count = 0 - agg_tis: Dict[str, float] = {} + ppo_ratio_mean_sum = 0.0 + tis_metrics_agg: Dict[str, float] = {} for i, pi_logprobs in enumerate(logprobs_list): adv = advantages[i] ref_lp = ref_logprobs[i] inf_lp = inf_logprobs[i] + prox_lp = prox_logprobs[i] response_start = max(0, prompt_lens[i] - 1) resp_pi = pi_logprobs[response_start:] @@ -93,74 +85,58 @@ def loss_fn( resp_ref = torch.tensor( [ref_lp[response_start + j] if (response_start + j) < len(ref_lp) else 0.0 for j in range(resp_len)], - dtype=resp_pi.dtype, - device=resp_pi.device, + dtype=resp_pi.dtype, device=resp_pi.device, ) - pi_detached = resp_pi.detach() - if not inf_lp: - raise ValueError( - f"GSPO requires inference logprobs for sample {i} but got empty list. " - f"Ensure logprobs=True is set when using policy_loss='gspo'." - ) - if len(inf_lp) < response_start + resp_len: - raise ValueError( - f"GSPO requires at least {response_start + resp_len} inference logprobs " - f"for sample {i}, got {len(inf_lp)}." - ) resp_inf = torch.tensor( - inf_lp[response_start : response_start + resp_len], - dtype=resp_pi.dtype, - device=resp_pi.device, + inf_lp[response_start:response_start + resp_len], + dtype=resp_pi.dtype, device=resp_pi.device, + ) + resp_prox = torch.tensor( + prox_lp[response_start:response_start + resp_len], + dtype=resp_pi.dtype, device=resp_pi.device, ) + inf_log_diff = pi_detached - resp_inf total_inf_diff += inf_log_diff.abs().mean().item() total_inf_kld += (torch.exp(inf_log_diff) - inf_log_diff - 1.0).mean().item() inf_num_samples += 1 - log_ratio = resp_pi - resp_inf + log_ratio = resp_pi - resp_prox seq_log_ratio = log_ratio.mean() log_seq_ratio = resp_pi - resp_pi.detach() + seq_log_ratio.detach() log_seq_ratio = torch.clamp(log_seq_ratio, max=gspo_config.seq_ratio_log_cap) seq_ratio = torch.exp(log_seq_ratio) - clipped_seq_ratio = torch.clamp( - seq_ratio, - min=1.0 - clip_low, - max=1.0 + clip_high, - ) + clipped_seq_ratio = torch.clamp(seq_ratio, min=1.0 - clip_low, max=1.0 + clip_high) clip_frac_sum += (clipped_seq_ratio != seq_ratio).float().mean().item() - clip_frac_count += 1 + ppo_ratio_mean_sum += seq_ratio.detach().mean().item() + + tis_weight, bm = compute_tis_weight(resp_prox, resp_inf, is_config) + for k, v in bm.items(): + tis_metrics_agg[k] = tis_metrics_agg.get(k, 0.0) + v adv_t = torch.as_tensor(adv, dtype=resp_pi.dtype, device=resp_pi.device) surr1 = -seq_ratio * adv_t surr2 = -clipped_seq_ratio * adv_t - per_token_loss = torch.maximum(surr1, surr2) - - if tis_weights_fn: - weights, tis_metrics = tis_weights_fn(pi_detached, i) - per_token_loss = per_token_loss * weights - total_rho += weights.sum().item() - for k, v in tis_metrics.items(): - agg_tis[k] = agg_tis.get(k, 0.0) + v + per_token_loss = torch.maximum(surr1, surr2) * tis_weight total_loss = total_loss + per_token_loss.sum() total_kl += (pi_detached - resp_ref).sum().item() num_tokens += resp_len + n_samples = max(len(logprobs_list), 1) metrics: Dict[str, float] = { "mean_kl": total_kl / num_tokens if num_tokens > 0 else 0.0, - "gspo_clip_frac": clip_frac_sum / clip_frac_count if clip_frac_count > 0 else 0.0, + "gspo_clip_frac": clip_frac_sum / n_samples, + "ppo_ratio_mean": ppo_ratio_mean_sum / n_samples, } if inf_num_samples > 0: metrics["inference_diff"] = total_inf_diff / inf_num_samples metrics["inference_kld"] = total_inf_kld / inf_num_samples - if tis_weights_fn: - metrics["mean_importance_ratio"] = total_rho / num_tokens if num_tokens > 0 else 1.0 - n_samples = len(logprobs_list) or 1 - for k, v in agg_tis.items(): - metrics[k] = v / n_samples + for k, v in tis_metrics_agg.items(): + metrics[k] = v / n_samples return total_loss, metrics return loss_fn diff --git a/training/utils/rl/importance_sampling.py b/training/utils/rl/importance_sampling.py index 16db3140..0171b8bb 100644 --- a/training/utils/rl/importance_sampling.py +++ b/training/utils/rl/importance_sampling.py @@ -1,102 +1,64 @@ -"""Truncated Importance Sampling (TIS) for train-inference mismatch correction. +"""Importance sampling corrections for RL training. -Provides per-token importance weighting that can be composed with **any** -base policy loss (GRPO, DAPO, GSPO, etc.). The architecture follows -an orthogonal design: base loss computes per-token loss, TIS -multiplies by clipped importance weights, then the result is summed. +Two corrections applied to every loss function: -Usage:: +1. **PPO IS ratio** -- ``exp(pi_theta - pi_prox)`` with PPO-style + clipping. The proximal logprobs are pre-computed via a real forward + pass before the training loop. In a 1:1 on-policy loop the ratio + is 1 (no effect); it becomes non-trivial with off-policy or + multi-minibatch training. - Config(policy_loss="grpo", tis_enabled=True, tis=ISConfig(clip_high=10.0)) - Config(policy_loss="dapo", tis_enabled=True) - Config(policy_loss="gspo", tis_enabled=True) - -To write your own TIS, replace ``make_tis_weights_fn`` with a function -that returns the same ``(pi_detached, sample_idx) -> (weights, metrics)`` -signature and pass it to the loss via ``tis_weights_fn=``. +2. **TIS (Train-Inference IS) weight** -- ``exp(pi_prox - pi_old)`` + clamped at ``tis_cap``. Corrects for the numerical gap between the + training model and the inference deployment (FP8, quantization, + different parallelism, etc.). """ from __future__ import annotations -from typing import Dict, List, Tuple, Union, Callable from dataclasses import dataclass import torch -from training.utils.rl.common import _normalize_prompt_lens - SAFETY_CLAMP = 20.0 -"""Clamp log-ratio to [-SAFETY_CLAMP, SAFETY_CLAMP] before exp() to -prevent inf/NaN. Matches VERL's ``SAFETY_BOUND``.""" - -TISWeightsFn = Callable[ - [torch.Tensor, int], - Tuple[torch.Tensor, Dict[str, float]], -] -"""Per-sample TIS weights function: ``(pi_detached, sample_idx) -> (weights, metrics)``.""" @dataclass class ISConfig: - """TIS (Truncated Importance Sampling) configuration.""" - - clip_high: float = 2.0 - clip_low: float = 0.0 - - -def make_tis_weights_fn( - inf_logprobs: List[List[float]], - prompt_len: Union[int, List[int]], - tis_config: ISConfig | None = None, -) -> TISWeightsFn: - """Create a per-sample TIS weights function (vanilla clamped IS). - - Computes ``weights = clamp(exp(train_lp - rollout_lp), low, high)`` - per response token -- the same formula used in common open-source RL stacks. - - ``prompt_len`` may be a single int or a per-datum list for multi-prompt - batched calls. - - Returns a callable ``(pi_detached, sample_idx) -> (weights, metrics)`` - suitable for passing to any loss function's ``tis_weights_fn`` parameter. - """ - if tis_config is None: - tis_config = ISConfig() - prompt_lens = _normalize_prompt_lens(prompt_len, len(inf_logprobs)) - - def weights_fn( - pi_detached: torch.Tensor, - sample_idx: int, - ) -> Tuple[torch.Tensor, Dict[str, float]]: - inf_lp = inf_logprobs[sample_idx] - if not inf_lp: - raise ValueError( - f"TIS requires inference logprobs for sample {sample_idx} but got empty list. " - f"Ensure logprobs=True is set when tis_enabled=True." - ) - response_start = max(0, prompt_lens[sample_idx] - 1) - resp_len = len(pi_detached) - if len(inf_lp) < response_start + resp_len: - raise ValueError( - f"TIS requires at least {response_start + resp_len} inference logprobs " - f"for sample {sample_idx}, got {len(inf_lp)}." - ) - resp_inf = torch.tensor( - inf_lp[response_start : response_start + resp_len], - dtype=pi_detached.dtype, - device=pi_detached.device, - ) - - log_ratio = torch.clamp(pi_detached - resp_inf, min=-SAFETY_CLAMP, max=SAFETY_CLAMP) - rho = torch.exp(log_ratio) - weights = torch.clamp(rho, min=tis_config.clip_low, max=tis_config.clip_high) - clip_frac = (weights != rho).float().mean().item() - - metrics = { - "tis_mean_ratio": rho.mean().item(), - "tis_max_ratio": rho.max().item(), - "tis_clip_frac": clip_frac, - } - return weights, metrics - - return weights_fn + """Importance sampling correction configuration.""" + + eps_clip: float = 0.2 + """PPO clip epsilon for the off-policy ratio (used by GRPO).""" + eps_clip_high: float | None = None + """Asymmetric upper clip bound for GRPO.""" + tis_cap: float = 5.0 + """Upper clamp for the TIS weight.""" + tis_level: str = "token" + """'token': per-token IS weights. 'sequence': geometric mean + of per-token ratios, broadcast to all tokens.""" + + +def compute_tis_weight( + resp_prox: torch.Tensor, + resp_inf: torch.Tensor, + config: ISConfig, +) -> tuple[torch.Tensor, dict[str, float]]: + """Compute TIS weight: clamp(exp(prox - inf), max=tis_cap).""" + tis_log = torch.clamp(resp_prox - resp_inf, min=-SAFETY_CLAMP, max=SAFETY_CLAMP) + + if config.tis_level == "sequence": + tis_raw = torch.exp(tis_log.mean()).expand_as(tis_log) + else: + tis_raw = torch.exp(tis_log) + + tis_weight = torch.clamp(tis_raw, min=0.0, max=config.tis_cap) + clip_frac = (tis_weight != tis_raw).float().mean().item() + + metrics: dict[str, float] = { + "tis/weight_mean": tis_weight.mean().item(), + "tis/clip_frac": clip_frac, + } + if config.tis_level == "sequence": + metrics["tis/seq_ratio"] = tis_raw[0].item() if tis_raw.numel() > 0 else 1.0 + + return tis_weight, metrics diff --git a/training/utils/rl/losses.py b/training/utils/rl/losses.py index 6616715a..6fc01356 100644 --- a/training/utils/rl/losses.py +++ b/training/utils/rl/losses.py @@ -7,6 +7,8 @@ import tinker +from training.utils.rl.importance_sampling import ISConfig + @dataclass class PromptGroup: @@ -49,65 +51,62 @@ def combine_prompt_groups( return data, advantages, ref_logprobs, prompt_lens, inf_logprobs -LossFnBuilder = Callable[ - [List[float], List[List[float]], List[int], List[List[float]]], - Any, -] -"""Signature for the loss builder returned by ``build_loss_fn``. - -``(advantages, ref_logprobs, prompt_lens, inf_logprobs) -> loss_fn_value`` -""" +LossFnBuilder = Callable[..., Any] +"""Signature for the loss builder returned by ``build_loss_fn``.""" def build_loss_fn( policy_loss: str, kl_beta: float, - tis_enabled: bool = False, - tis_config: Any = None, dapo_config: Any = None, gspo_config: Any = None, cispo_config: Any = None, + is_config: ISConfig | None = None, ) -> LossFnBuilder: """Create a loss builder that dispatches to grpo/dapo/gspo/cispo. - Returns a callable that accepts (advantages, ref_logprobs, prompt_lens, - inf_logprobs) and returns a tinker loss_fn value. + Returns a callable: + (advantages, ref_logprobs, prompt_lens, inf_logprobs, prox_logprobs) -> loss_fn """ + if is_config is None: + is_config = ISConfig() + from training.utils.rl.dapo import make_dapo_loss_fn from training.utils.rl.grpo import make_grpo_loss_fn from training.utils.rl.gspo import make_gspo_loss_fn from training.utils.rl.cispo import make_cispo_loss_fn - from training.utils.rl.importance_sampling import make_tis_weights_fn def build( advantages: List[float], ref_logprobs: List[List[float]], prompt_lens: List[int], inf_logprobs: List[List[float]], + prox_logprobs: List[List[float]], ) -> Any: - tis_wf = None - if tis_enabled and tis_config is not None: - tis_wf = make_tis_weights_fn(inf_logprobs, prompt_lens, tis_config) - if policy_loss == "dapo": return make_dapo_loss_fn( advantages, ref_logprobs, inf_logprobs, - prompt_lens, dapo_config, tis_weights_fn=tis_wf, + prompt_lens, prox_logprobs, + dapo_config, is_config=is_config, ) if policy_loss == "gspo": return make_gspo_loss_fn( advantages, ref_logprobs, inf_logprobs, - prompt_lens, gspo_config, tis_weights_fn=tis_wf, + prompt_lens, prox_logprobs, + gspo_config, is_config=is_config, ) if policy_loss == "cispo": return make_cispo_loss_fn( advantages, ref_logprobs, inf_logprobs, - prompt_lens, cispo_config, tis_weights_fn=tis_wf, + prompt_lens, prox_logprobs, + cispo_config, is_config=is_config, ) if policy_loss == "grpo": return make_grpo_loss_fn( advantages, ref_logprobs, - prompt_lens, inf_logprobs=inf_logprobs, kl_beta=kl_beta, tis_weights_fn=tis_wf, + prompt_lens, inf_logprobs=inf_logprobs, + prox_logprobs=prox_logprobs, + kl_beta=kl_beta, is_config=is_config, ) raise ValueError( f"Unsupported policy_loss '{policy_loss}'. Expected one of: grpo, dapo, gspo, cispo."