diff --git a/docs/guides/prorlv2.md b/docs/guides/prorlv2.md index c028e88764..795bb0d08a 100644 --- a/docs/guides/prorlv2.md +++ b/docs/guides/prorlv2.md @@ -7,7 +7,7 @@ ProRLv2 (as used in this repo) is best thought of as **GRPO and a bundle of stab - **DAPO dynamic sampling**: skip prompt-groups with zero reward variance - **Decoupled (asymmetric) clipping**: `ratio_clip_max > ratio_clip_min` - **Token-level policy gradient loss** -- **Importance sampling correction and TIS/CE-POP** (especially helpful for MoE/backend-mismatch scenarios) +- **Importance sampling correction and TIS/ICE-POP** (especially helpful for MoE/backend-mismatch scenarios) - **Reinforce++: Decoupled local/global advantage normalization** (`reinforce_plus_plus`) - **“Stop properly” penalty** for truncated responses @@ -151,9 +151,38 @@ loss_fn: - **`truncated_importance_sampling_type`**: - `"tis"`: clamp weights to `<= truncated_importance_sampling_ratio` - `"icepop"`: set weights outside \([min, max]\) to zero (filter outliers) + - `"seq-mask-tis"`: sequence-level geometric-mean mask + non-truncated token-level IS correction (see below) - **Implementation**: see `ClippedPGLossFn` init-time checks and logic in [`nemo_rl/algorithms/loss_functions.py`](../../nemo_rl/algorithms/loss_functions.py). +### Seq-mask-tis: Sequence-level Geometric-Mean Mask + +`seq-mask-tis` is an alternative to ICE-POP that operates at the **sequence level** instead of per-token: + +1. For each sequence, compute the **geometric mean** of per-token IS ratios: \(\text{geo\_mean}_i = \exp\!\bigl(\frac{1}{T_i}\sum_t \log \frac{\pi_{\text{train}}(a_t)}{\pi_{\text{gen}}(a_t)}\bigr)\) +2. **Mask out** entire sequences whose geometric mean falls outside \([min, max]\). +3. For retained sequences, apply the **non-truncated** (raw) token-level IS ratios to correct per-token gradients — no clamping, no per-token filtering. 
+ +Key differences from ICE-POP: + +| | ICE-POP | seq-mask-tis | +|---|---|---| +| Filtering granularity | per token | per sequence | +| IS correction weights | filtered (zeroed outside bounds) | raw / non-truncated | +| Reference bounds | min=0.5, max=5 | min=0.999, max=1.002 | + +```yaml +loss_fn: +  use_importance_sampling_correction: true +  truncated_importance_sampling_ratio: 1.002 +  truncated_importance_sampling_ratio_min: 0.999 +  truncated_importance_sampling_type: "seq-mask-tis" +``` + +TIS, ICE-POP, and seq-mask-tis all report the metric **`is_oob_ratio`** — the fraction of tokens above the max ratio (TIS), tokens filtered out (ICE-POP), or sequences filtered out (seq-mask-tis). + +- **Reference**: [When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda) + ## Full Example Config (Annotated) The ProRLv2 example config is intentionally small and relies on defaults from `grpo_math_1B.yaml`. 
@@ -201,5 +230,6 @@ In addition to task rewards/accuracy, a few stability signals are particularly u - **GRPO**: [Group Relative Policy Optimization](https://arxiv.org/abs/2402.03300) - **REINFORCE++**: [REINFORCE++](https://arxiv.org/abs/2501.03262) - **DLER (stop properly penalty explanation)**: [DLER](https://arxiv.org/pdf/2510.15110) +- **seq-mask-tis blog**: [When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda) - **[NeMo RL GRPO Guide](grpo.md)** - **[NeMo RL DAPO Guide](dapo.md)** diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index c61cb5f0ce..1a275146d2 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -45,9 +45,12 @@ class ClippedPGLossConfig(TypedDict): use_on_policy_kl_approximation: bool use_importance_sampling_correction: bool truncated_importance_sampling_ratio: float | None - # Type of truncated importance sampling: "tis" (clamp max) or "icepop" (filter [min, max]) + # Type of truncated importance sampling: + # "tis" – clamp IS weights to max + # "icepop" – zero out tokens with IS weight outside [min, max] + # "seq-mask-tis" – zero out sequences by geometric-mean IS ratio, non-truncated token IS correction truncated_importance_sampling_type: NotRequired[str | None] - # Lower bound for ICE-POP filtering (default 0.5) + # Lower bound for ICE-POP / seq-mask-tis filtering truncated_importance_sampling_ratio_min: NotRequired[float | None] token_level_loss: bool # If True, apply the off-policy importance-sampling correction at the @@ -138,11 +141,11 @@ def __init__(self, cfg: ClippedPGLossConfig): self.truncated_importance_sampling_ratio = cfg[ "truncated_importance_sampling_ratio" ] - # Type of truncated importance sampling: "tis" (clamp max) or "icepop" (filter [min, max]) + # Type 
of truncated importance sampling: "tis" | "icepop" | "seq-mask-tis" self.truncated_importance_sampling_type = cfg.get( "truncated_importance_sampling_type" ) - # Lower bound for ICE-POP filtering (default 0.5) + # Lower bound for ICE-POP / seq-mask-tis filtering self.truncated_importance_sampling_ratio_min = cfg.get( "truncated_importance_sampling_ratio_min" ) @@ -165,9 +168,19 @@ def __init__(self, cfg: ClippedPGLossConfig): assert self.truncated_importance_sampling_ratio > 0, ( "truncated_importance_sampling_ratio should be positive" ) - assert self.truncated_importance_sampling_type in ("tis", "icepop"), ( - f"truncated_importance_sampling_type must be 'tis' or 'icepop', got {self.truncated_importance_sampling_type}" + assert self.truncated_importance_sampling_type in ( + "tis", + "icepop", + "seq-mask-tis", + ), ( + f"truncated_importance_sampling_type must be 'tis', 'icepop', or 'seq-mask-tis', " + f"got {self.truncated_importance_sampling_type}" ) + if self.truncated_importance_sampling_type == "seq-mask-tis": + assert not self.sequence_level_importance_ratios, ( + "seq-mask-tis uses token-level IS correction with sequence-level masking, " + "and is incompatible with sequence_level_importance_ratios=True" + ) else: # Warn user that TIS-related parameters are ignored when truncated_importance_sampling_ratio is not set ignored_params = [] @@ -383,6 +396,7 @@ def __call__( # ------------------------------------------------------------- # Off-policy (actor) importance-sampling correction # ------------------------------------------------------------- + _is_filter_metrics: dict = {} # populated for tis / icepop / seq-mask-tis # See: docs/guides/grpo.md#importance-sampling-correction if self.sequence_level_importance_ratios: # importance weight w_i = exp(Σ_t (log π_actor − log π_behaviour)) @@ -401,29 +415,81 @@ def __call__( actor_importance_weights_expanded = torch.nan_to_num( actor_importance_weights_expanded, nan=0.0, posinf=0.0, neginf=0.0 ) - # Truncated Importance 
Sampling (TIS / ICE-POP) - # TIS: Simple clamp to max value - # ICE-POP: Filter out samples with importance weights outside [min, max] + # ---- Truncated Importance Sampling ---- + # "tis" – clamp IS weights to [0, max] + # "icepop" – zero out tokens whose IS weight ∉ [min, max] (ref bounds: 0.5–5) + # "seq-mask-tis" – zero out entire sequences whose geometric-mean + # IS ratio ∉ [min, max]; retained sequences keep + # raw (non-truncated) token-level IS weights (ref bounds: 0.999–1.002) + # Blog: https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda if self.truncated_importance_sampling_ratio is not None: if self.truncated_importance_sampling_type == "tis": - # TIS: Simple clamp to max value + token_in_bounds = ( + actor_importance_weights_expanded + <= self.truncated_importance_sampling_ratio + ) + _is_filter_metrics = { + "is_oob_ratio": 1.0 + - masked_mean( + token_in_bounds.float(), + mask, + global_normalization_factor=global_valid_toks, + ).item(), + } actor_importance_weights_expanded = torch.clamp( actor_importance_weights_expanded, max=self.truncated_importance_sampling_ratio, ) - elif self.truncated_importance_sampling_type == "icepop": # icepop - # ICE-POP: Filter out samples with importance weights outside [min, max] + elif self.truncated_importance_sampling_type == "icepop": + token_kept_mask = ( + actor_importance_weights_expanded + >= self.truncated_importance_sampling_ratio_min + ) & ( + actor_importance_weights_expanded + <= self.truncated_importance_sampling_ratio + ) + _is_filter_metrics = { + "is_oob_ratio": 1.0 + - masked_mean( + token_kept_mask.float(), + mask, + global_normalization_factor=global_valid_toks, + ).item(), + } actor_importance_weights_expanded = torch.where( + token_kept_mask, + actor_importance_weights_expanded, + torch.zeros_like(actor_importance_weights_expanded), + ) + elif self.truncated_importance_sampling_type == "seq-mask-tis": 
+ # geo_mean_i = exp( mean_t( log(π_prev / π_gen) ) ) + log_is_ratio = torch.nan_to_num( + prev_logprobs - generation_logprobs, + nan=0.0, + posinf=0.0, + neginf=0.0, + ) + seq_log_is_ratio_mean = masked_mean( + log_is_ratio, token_mask, dim=-1 + ) # [B] + seq_geomean_is_ratio = torch.exp(seq_log_is_ratio_mean).detach() # [B] + seq_kept_mask = ( ( - actor_importance_weights_expanded + seq_geomean_is_ratio >= self.truncated_importance_sampling_ratio_min ) - & ( - actor_importance_weights_expanded - <= self.truncated_importance_sampling_ratio - ), - actor_importance_weights_expanded, - torch.zeros_like(actor_importance_weights_expanded), + & (seq_geomean_is_ratio <= self.truncated_importance_sampling_ratio) + ).float() # [B] + _is_filter_metrics = { + "is_oob_ratio": 1.0 + - masked_mean( + seq_kept_mask, + sample_mask, + global_normalization_factor=global_valid_seqs, + ).item(), + } + actor_importance_weights_expanded = ( + actor_importance_weights_expanded * seq_kept_mask.unsqueeze(-1) ) else: raise ValueError( @@ -528,6 +594,7 @@ def __call__( "sampling_importance_ratio": sample_importance_ratio.item(), "num_valid_samples": sample_mask.sum().item(), "approx_entropy": seq_entropy_approx.item(), + **_is_filter_metrics, }, ) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 5be0e69c80..fbec4c8504 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -1122,6 +1122,106 @@ def test_clipped_pg_loss_on_policy_truncated_importance_sampling( torch.testing.assert_close(actual_loss, expected_loss, atol=1e-4, rtol=1e-3) +def test_clipped_pg_loss_icepop_importance_sampling(): + """Tests ClippedPGLossFn with ICE-POP truncated importance sampling. + + Uses reference bounds min=0.5, max=5. 
+ """ + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + cfg = deepcopy(basic_pg_loss_test_config) + cfg["use_importance_sampling_correction"] = True + cfg["truncated_importance_sampling_ratio"] = 5.0 # max (ref) + cfg["truncated_importance_sampling_type"] = "icepop" + cfg["truncated_importance_sampling_ratio_min"] = 0.5 # min (ref) + loss_fn = ClippedPGLossFn(cfg) + + # On-policy (curr = prev) → ratios = 1, clip_loss = -adv + prev_lp = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + # Token 1 has very stale gen logprob → IS weight > 5 → filtered by ICE-POP + gen_lp = torch.tensor([[-0.5, -3.5, -0.8]], device=device) + adv = torch.tensor([[1.0, -1.0, 2.0]], device=device) + + data["advantages"][0, 1:] = adv + data["prev_logprobs"][0, 1:] = prev_lp + data["generation_logprobs"][0, 1:] = gen_lp + + # IS weights = exp(prev-gen) = exp([-0.5, 2.5, -0.2]) ≈ [0.6065, 12.182, 0.8187] + # ICE-POP [0.5, 5]: keep=[T, F, T] (12.182 > 5 → zeroed) + iw = torch.exp(prev_lp - gen_lp) + filtered_iw = torch.where((iw >= 0.5) & (iw <= 5.0), iw, torch.zeros_like(iw)) + expected_loss = torch.mean(filtered_iw * (-adv)) + + dummy_logits = _create_exact_logits( + prev_lp, data["input_ids"], batch_size, seq_len, vocab_size, device + ) + actual_loss, _ = loss_fn( + dummy_logits, + data, + global_valid_seqs=torch.sum(data["sample_mask"]), + global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + ) + torch.testing.assert_close(actual_loss, expected_loss, atol=1e-4, rtol=1e-3) + + +def test_clipped_pg_loss_seq_mask_tis(): + """Tests ClippedPGLossFn with seq-mask-tis, including nan_to_num on -inf. + + Uses reference bounds min=0.999, max=1.002. 
+ """ + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, batch_size, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + cfg = deepcopy(basic_pg_loss_test_config) + cfg["use_importance_sampling_correction"] = True + cfg["truncated_importance_sampling_ratio"] = 1.002 # max (ref) + cfg["truncated_importance_sampling_type"] = "seq-mask-tis" + cfg["truncated_importance_sampling_ratio_min"] = 0.999 # min (ref) + loss_fn = ClippedPGLossFn(cfg) + + # On-policy (curr = prev), gen very close to prev + # geo_mean = exp(mean([0.0005]*3)) = exp(0.0005) ≈ 1.0005 → in [0.999, 1.002] → kept + prev_lp = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + gen_lp = torch.tensor([[-1.0005, -1.0005, -1.0005]], device=device) + adv = torch.tensor([[1.0, -1.0, 2.0]], device=device) + + data["advantages"][0, 1:] = adv + data["prev_logprobs"][0, 1:] = prev_lp + data["generation_logprobs"][0, 1:] = gen_lp + + iw = torch.exp(prev_lp - gen_lp) + expected_loss = torch.mean(iw * (-adv)) + + dummy_logits = _create_exact_logits( + prev_lp, data["input_ids"], batch_size, seq_len, vocab_size, device + ) + actual_loss, _ = loss_fn( + dummy_logits, + data, + global_valid_seqs=torch.sum(data["sample_mask"]), + global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + ) + torch.testing.assert_close(actual_loss, expected_loss, atol=1e-4, rtol=1e-3) + + # nan_to_num: inject -inf → loss must stay finite + data["generation_logprobs"][0, 2] = float("-inf") + actual_loss2, _ = loss_fn( + dummy_logits, + data, + global_valid_seqs=torch.sum(data["sample_mask"]), + global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + ) + assert not torch.isnan(actual_loss2), "Loss is NaN — nan_to_num fix not working" + assert not torch.isinf(actual_loss2), "Loss is inf — nan_to_num fix not working" + + def test_masked_mean_all_zeros(): """Test masked_mean function with all zeros mask.""" values = torch.tensor([1.0, 2.0, 3.0, 4.0])