From 39dd2e42464e837dad912e89da16f9568bf9d197 Mon Sep 17 00:00:00 2001 From: szrlee Date: Mon, 27 Oct 2025 02:49:55 +0800 Subject: [PATCH 1/5] refactor(ppo): separate IS weight correction from rejection sampling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous implementation applied rejection by zeroing IS weights, which conflated two distinct mechanisms. This refactoring properly separates IS weight correction from rejection sampling to follow correct principles. This commit separates two mechanisms: IS Weights (rollout_is_weights): Always TRUE ratios π_train/π_rollout - Never zeroed, even for rejected samples - Preserved for policy gradient calculations Rejection Sampling (modified_response_mask): Applied via response_mask - Mask mode: Excludes tokens/sequences with outlier IS ratios - Veto: Excludes sequences with catastrophic tokens - Used for loss aggregation (excluded from denominator) This ensures: - Correct loss normalization (rejected samples excluded from denominator) - True IS ratios preserved for policy gradient calculations - Clear separation of concerns between IS correction and rejection Changes: - compute_rollout_importance_weights() now returns 3 values instead of 2 - Always update batch response_mask with rejection applied - Updated all tests to verify new behavior - Comprehensive documentation update with BibTeX citation Reference: When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch Liu, Li, Fu, Wang, Liu, Shen (2025) https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda --- docs/advance/rollout_is.md | 115 ++++++++++++++---- tests/trainer/ppo/test_rollout_is.py | 81 ++++++++++-- .../ppo/test_rollout_is_integration.py | 12 +- verl/trainer/ppo/mismatch_helper.py | 105 +++++++++------- verl/trainer/ppo/ray_trainer.py | 50 +++++--- 5 files changed, 267 insertions(+), 96 deletions(-) diff --git a/docs/advance/rollout_is.md 
b/docs/advance/rollout_is.md index 8f75c803a5c..5e070839354 100644 --- a/docs/advance/rollout_is.md +++ b/docs/advance/rollout_is.md @@ -1,14 +1,26 @@ # Rollout Importance Sampling -Last updated: 10/11/2025. +Last updated: 10/27/2025. This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation in verl. ## References -- [When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) +- [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda) - [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl) +### BibTeX Citation + +```bibtex +@misc{liu-li-2025, + title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, + url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, + author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, + year = {2025}, + month = september, +} +``` + ## Overview Rollout Importance Sampling corrects for distribution mismatch between: @@ -17,6 +29,24 @@ Rollout Importance Sampling corrects for distribution mismatch between: This mismatch can lead to biased gradient estimates and unstable training. Rollout IS applies importance sampling weights to correct these biases. +### Key Design Principle: Separation of IS Weights and Rejection Sampling + +**Important**: As of 10/27/2025, the implementation separates two mechanisms: + +1. 
**IS Weights** (`rollout_is_weights`): Always TRUE ratios π_train/π_rollout + - Never zeroed, even for rejected samples + - Preserves true importance ratios for policy gradient calculations + +2. **Rejection Sampling** (`modified_response_mask`): Applied via response_mask + - Mask mode: Excludes tokens/sequences with outlier IS ratios + - Veto: Excludes sequences with catastrophic tokens + - Used for loss aggregation (denominator calculation) + +This separation ensures: +- ✅ Correct loss normalization (rejected samples excluded from denominator) +- ✅ True IS ratios preserved (not zeroed for rejected samples) +- ✅ Padding positions still zeroed in weights (different from rejection) + ## Configuration ```yaml @@ -103,13 +133,21 @@ Aggregation level for IS weights: - `"geometric"`: Geometric mean (experimental) ### `algorithm.rollout_is_mode` (str) -Bounding mode: -- `"truncate"`: Cap weights at upper threshold only -- `"mask"`: Zero out weights outside [lower, upper] +Bounding mode for handling outlier IS weights: +- `"truncate"`: Clamp weights at upper threshold, no rejection (TIS) + - All samples used for training + - IS weights capped to prevent extreme importance ratios +- `"mask"`: Rejection sampling via response_mask (MIS) + - Rejects tokens/sequences with IS ratios outside [lower, upper] + - **Important**: Rejection applied to `response_mask`, NOT by zeroing IS weights + - IS weights remain as true ratios ### `algorithm.rollout_is_veto_threshold` (float) -Per-token veto threshold. If any token ratio < this, entire sequence is rejected. -Default: `1e-4` (ratio 10,000x off) +Per-token veto threshold for catastrophic outliers. 
+- If any token has ratio < this threshold, the entire sequence is rejected via `response_mask` +- Default: `1e-4` (detects ratios 10,000x off) +- **Important**: Veto applies rejection to `response_mask`, NOT by zeroing IS weights +- IS weights remain as true ratios even for vetoed sequences ## Usage @@ -159,12 +197,16 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear #### **Veto Mechanism Metrics** - **`rollout_is_veto_fraction`**: Fraction of sequences rejected by veto mechanism + - **Important**: Sequences are rejected via `response_mask=0`, NOT by zeroing IS weights + - IS weights remain as true ratios even for vetoed sequences + - Veto detects catastrophic tokens (ratio < veto_threshold, e.g., < 1e-4) - **Ideal value**: < 0.05 (less than 5% vetoed) - **Warning**: > 0.1 suggests policies are too different or numerical issues - **`rollout_is_catastrophic_token_fraction`**: Fraction of tokens below veto threshold - - Identifies problematic tokens before sequence-level veto - - **Warning**: > 0.01 indicates widespread distribution issues + - Identifies problematic tokens before sequence-level veto is applied + - Each catastrophic token causes its entire sequence to be rejected + - **Warning**: > 0.01 indicates widespread distribution issues or numerical instability #### **Threshold Exceedance Metrics** @@ -197,12 +239,16 @@ All metrics are prefixed with `mismatch/`. 
For example, `rollout_is_mean` appear #### **Masking Metrics** (mask mode only) -- **`rollout_is_masked_fraction`**: Fraction of tokens masked (set to zero) - - **Ideal value**: < 0.1 +- **`rollout_is_masked_fraction`**: Fraction of tokens rejected via response_mask + - **Important**: Tokens are rejected by setting `response_mask=0`, NOT by zeroing IS weights + - IS weights remain as true ratios (π_train/π_rollout) + - **Ideal value**: < 0.1 (less than 10% rejected) - **Warning**: > 0.3 means losing too much data -- **`rollout_is_seq_masked_fraction`**: Fraction of sequences with at least one masked token - - Shows sequence-level impact of masking +- **`rollout_is_seq_masked_fraction`**: Fraction of sequences with at least one rejected token + - Shows sequence-level impact of rejection sampling + - For token-level: a sequence is counted if ANY of its tokens is outside [lower, upper] (only those tokens are rejected) + - For sequence-level: all tokens have same weight, so entire sequence rejected or accepted #### **Distribution Mismatch Metrics** (Training vs Rollout Policy) @@ -253,22 +299,43 @@ All metrics are prefixed with `mismatch/`.
For example, `rollout_is_mean` appear # Metrics are returned from compute_rollout_importance_weights from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights -weights_proto, metrics = compute_rollout_importance_weights( +# NEW: Returns 3 values (weights, modified_response_mask, metrics) +weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights( old_log_prob=training_log_probs, # from training policy rollout_log_prob=rollout_log_probs, # from rollout policy response_mask=response_mask, rollout_is_level="token", - rollout_is_mode="truncate", + rollout_is_mode="mask", # Using mask mode for rejection sampling rollout_is_threshold=2.0, + rollout_is_threshold_lower=0.5, rollout_is_veto_threshold=1e-4, ) +# Extract IS weights (always true ratios, never zeroed) +is_weights = weights_proto.batch["rollout_is_weights"] + +# modified_response_mask has rejection applied +# - Tokens/sequences outside [0.5, 2.0] are masked to 0 +# - Sequences with catastrophic tokens are masked to 0 +# - IS weights remain as true ratios (NOT zeroed) + # All metrics have 'mismatch/' prefix print(f"Mean IS weight: {metrics['mismatch/rollout_is_mean']:.3f}") print(f"Effective sample size: {metrics['mismatch/rollout_is_eff_sample_size']:.3f}") print(f"Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.3f}") +print(f"Masked fraction: {metrics['mismatch/rollout_is_masked_fraction']:.3f}") print(f"KL divergence: {metrics['mismatch/mismatch_kl']:.3f}") +# Verify IS weights are true ratios (not zeroed) +print(f"\n✓ IS weights min: {is_weights[response_mask.bool()].min():.4f}") +print(f"✓ IS weights max: {is_weights[response_mask.bool()].max():.4f}") +print(f"✓ All IS weights > 0: {(is_weights[response_mask.bool()] > 0).all()}") + +# Check rejection via response_mask +rejected_tokens = (response_mask == 1) & (modified_response_mask == 0) +print(f"\n✓ Rejected {rejected_tokens.sum()} tokens via response_mask") +print(f"✓ IS weights for rejected 
tokens are NON-ZERO (true ratios)") + # Check for warning conditions if metrics['mismatch/rollout_is_mean'] < 0.5 or metrics['mismatch/rollout_is_mean'] > 2.0: print("⚠️ Warning: Mean IS weight far from 1.0, significant policy mismatch detected") @@ -288,14 +355,15 @@ for epoch in range(num_epochs): for batch_idx, batch in enumerate(dataloader): # ... rollout phase ... - # Compute IS weights and get metrics - weights_proto, metrics = compute_rollout_importance_weights( + # Compute IS weights and get metrics (NEW: 3 return values) + weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights( old_log_prob=batch.old_log_prob, rollout_log_prob=batch.rollout_log_prob, response_mask=batch.response_mask, rollout_is_level=config.rollout_is_level, rollout_is_mode=config.rollout_is_mode, rollout_is_threshold=config.rollout_is_threshold, + rollout_is_threshold_lower=config.rollout_is_threshold_lower, rollout_is_veto_threshold=config.rollout_is_veto_threshold, ) @@ -303,7 +371,10 @@ for epoch in range(num_epochs): for metric_name, metric_value in metrics.items(): logger.log_scalar(metric_name, metric_value, step=global_step) - # Use IS weights in training + # IMPORTANT: Update batch response_mask with rejection applied + batch.response_mask = modified_response_mask + + # Use IS weights in training (true ratios, never zeroed) is_weights = weights_proto.batch["rollout_is_weights"] # ... apply weights to policy gradient ... ``` @@ -349,8 +420,8 @@ def check_rollout_is_health(metrics, config): print("✅ Rollout IS metrics look healthy") return True -# Use in training -_, metrics = compute_rollout_importance_weights(...) +# Use in training (NEW: 3 return values) +_, _, metrics = compute_rollout_importance_weights(...) is_healthy = check_rollout_is_health(metrics, config) if not is_healthy: @@ -508,8 +579,8 @@ metrics_history = { # In training loop for step in range(num_steps): - # ... compute IS weights ... 
- _, metrics = compute_rollout_importance_weights(...) + # ... compute IS weights ... (NEW: 3 return values) + _, _, metrics = compute_rollout_importance_weights(...) # Store metrics for key in metrics_history.keys(): diff --git a/tests/trainer/ppo/test_rollout_is.py b/tests/trainer/ppo/test_rollout_is.py index 584dec5a8dd..7acc8cc9836 100644 --- a/tests/trainer/ppo/test_rollout_is.py +++ b/tests/trainer/ppo/test_rollout_is.py @@ -49,7 +49,7 @@ def test_basic_rollout_is(): # Test token-level truncate mode print("\n1. Testing token-level truncate mode...") - weights_proto, metrics = compute_rollout_importance_weights( + weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=eos_mask, @@ -71,7 +71,7 @@ def test_basic_rollout_is(): # Test sequence-level mode print("\n2. Testing sequence-level mode...") - weights_seq_proto, metrics_seq = compute_rollout_importance_weights( + weights_seq_proto, _, metrics_seq = compute_rollout_importance_weights( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=eos_mask, @@ -92,7 +92,7 @@ def test_basic_rollout_is(): # Test geometric mean mode print("\n3. 
Testing geometric mean mode...") - weights_geo_proto, metrics_geo = compute_rollout_importance_weights( + weights_geo_proto, _, metrics_geo = compute_rollout_importance_weights( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=eos_mask, @@ -116,7 +116,7 @@ def test_basic_rollout_is(): rollout_log_prob_veto[0, 2] = old_log_prob_veto[0, 2] + 15.0 # ratio ~= 3e-7 eos_mask_veto = torch.ones(2, 5, device=device) - weights_veto_proto, metrics_veto = compute_rollout_importance_weights( + weights_veto_proto, modified_response_mask_veto, metrics_veto = compute_rollout_importance_weights( old_log_prob=old_log_prob_veto, rollout_log_prob=rollout_log_prob_veto, response_mask=eos_mask_veto, @@ -128,14 +128,17 @@ def test_basic_rollout_is(): weights_veto = weights_veto_proto.batch["rollout_is_weights"] print(f" Veto fraction: {metrics_veto['mismatch/rollout_is_veto_fraction']:.4f}") - # Check that the sequence with catastrophic token has all weights zeroed - assert weights_veto[0].sum() == 0, "Sequence with catastrophic token should be vetoed" - assert weights_veto[1].sum() > 0, "Normal sequence should not be vetoed" + # KEY FIX: Veto is applied via response_mask, not by zeroing weights + # Check that weights are NON-ZERO (true ratios preserved) + assert weights_veto[0].sum() > 0, "Weights should be non-zero (true ratios preserved)" + # Check that response_mask has veto applied + assert modified_response_mask_veto[0].sum() == 0, "Vetoed sequence should have response_mask zeroed" + assert modified_response_mask_veto[1].sum() > 0, "Normal sequence should have response_mask unchanged" print(" ✓ Veto mechanism passed") # Test disabled IS (threshold=None) print("\n5. 
Testing disabled IS...") - weights_disabled, metrics_disabled = compute_rollout_importance_weights( + weights_disabled, modified_response_mask_disabled, metrics_disabled = compute_rollout_importance_weights( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=eos_mask, @@ -143,6 +146,7 @@ def test_basic_rollout_is(): ) assert weights_disabled is None, "Should return None when threshold is None" + assert torch.equal(modified_response_mask_disabled, eos_mask), "Should return original mask unchanged" assert len(metrics_disabled) == 0, "Should return empty metrics when disabled" print(" ✓ Disabled IS passed") @@ -160,7 +164,7 @@ def test_metrics_completeness(): rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.2 eos_mask = torch.ones(batch_size, seq_length, device=device) - _, metrics = compute_rollout_importance_weights( + _, _, metrics = compute_rollout_importance_weights( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=eos_mask, @@ -264,6 +268,64 @@ def test_mismatch_metrics(): print(" ✓ Mismatch metrics work without rollout log probs") +def test_mask_mode(): + """Test mask mode applies rejection via response_mask, keeps true IS weights.""" + print("\nTesting mask mode behavior...") + + batch_size = 2 + seq_length = 5 + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Sequence 0: ratio ≈ 0.37 (below 0.5, should be rejected) + # Sequence 1: ratio ≈ 1.65 (in [0.5, 2.0], should be accepted) + old_log_prob = torch.tensor([[-2.0] * seq_length, [-2.0] * seq_length], device=device) + rollout_log_prob = torch.tensor( + [ + [-1.0] * seq_length, # exp(-2.0 - (-1.0)) = exp(-1.0) ≈ 0.37 + [-2.5] * seq_length, # exp(-2.0 - (-2.5)) = exp(0.5) ≈ 1.65 + ], + device=device, + ) + response_mask = torch.ones(batch_size, seq_length, device=device) + + weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights( + old_log_prob=old_log_prob, + 
rollout_log_prob=rollout_log_prob, + response_mask=response_mask, + rollout_is_level="token", + rollout_is_mode="mask", + rollout_is_threshold=2.0, + rollout_is_threshold_lower=0.5, + rollout_is_veto_threshold=None, + ) + + weights = weights_proto.batch["rollout_is_weights"] + + # KEY FIX: Weights should be TRUE ratios (NOT zeroed) + assert torch.all(weights[0, :] > 0), "Weights should remain as true ratios (not zeroed)" + assert torch.allclose(weights[0, 0], torch.tensor(0.368, device=device), atol=0.01), ( + "First seq ratio should be ≈0.37" + ) + assert torch.allclose(weights[1, 0], torch.tensor(1.649, device=device), atol=0.01), ( + "Second seq ratio should be ≈1.65" + ) + + # Rejection should be applied via response_mask + assert torch.all(modified_response_mask[0, :] == 0), "First sequence should be rejected via mask" + assert torch.all(modified_response_mask[1, :] == 1), "Second sequence should be accepted" + + # Verify mask metrics exist + assert "mismatch/rollout_is_masked_fraction" in metrics + assert abs(metrics["mismatch/rollout_is_masked_fraction"] - 0.5) < 0.01, "Should reject 50% of tokens" + + print(f" First seq IS weight: {weights[0, 0]:.4f} (expected ≈0.37)") + print(f" Second seq IS weight: {weights[1, 0]:.4f} (expected ≈1.65)") + print(f" First seq mask: {modified_response_mask[0, 0]:.0f} (expected 0 - rejected)") + print(f" Second seq mask: {modified_response_mask[1, 0]:.0f} (expected 1 - accepted)") + print(f" Masked fraction: {metrics['mismatch/rollout_is_masked_fraction']:.2f}") + print(" ✓ Mask mode correctly separates IS weights from rejection") + + if __name__ == "__main__": print("=" * 60) print("Rollout Importance Sampling Test Suite") @@ -273,6 +335,7 @@ def test_mismatch_metrics(): test_basic_rollout_is() test_metrics_completeness() test_mismatch_metrics() + test_mask_mode() print("\n" + "=" * 60) print("ALL TESTS PASSED ✓") print("=" * 60) diff --git a/tests/trainer/ppo/test_rollout_is_integration.py 
b/tests/trainer/ppo/test_rollout_is_integration.py index abcbcb70502..b96fb77523f 100644 --- a/tests/trainer/ppo/test_rollout_is_integration.py +++ b/tests/trainer/ppo/test_rollout_is_integration.py @@ -61,7 +61,7 @@ def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is): This test simulates that workflow. """ # First compute IS weights (as trainer would do centrally) - rollout_is_weights_proto, _ = compute_rollout_importance_weights( + rollout_is_weights_proto, _, _ = compute_rollout_importance_weights( old_log_prob=sample_data["old_log_prob"], rollout_log_prob=sample_data["rollout_log_prob"], response_mask=sample_data["response_mask"], @@ -92,7 +92,7 @@ def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is): def test_rollout_is_weights_computation(self, sample_data): """Test rollout IS weights and metrics computation.""" - weights_proto, metrics = compute_rollout_importance_weights( + weights_proto, _, metrics = compute_rollout_importance_weights( old_log_prob=sample_data["old_log_prob"], rollout_log_prob=sample_data["rollout_log_prob"], response_mask=sample_data["response_mask"], @@ -120,7 +120,7 @@ def test_all_aggregation_levels(self, sample_data): levels = ["token", "sequence", "geometric"] for level in levels: - _, metrics = compute_rollout_importance_weights( + _, _, metrics = compute_rollout_importance_weights( old_log_prob=sample_data["old_log_prob"], rollout_log_prob=sample_data["rollout_log_prob"], response_mask=sample_data["response_mask"], @@ -136,7 +136,7 @@ def test_both_bounding_modes(self, sample_data): modes = ["truncate", "mask"] for mode in modes: - _, metrics = compute_rollout_importance_weights( + _, _, metrics = compute_rollout_importance_weights( old_log_prob=sample_data["old_log_prob"], rollout_log_prob=sample_data["rollout_log_prob"], response_mask=sample_data["response_mask"], @@ -175,7 +175,7 @@ def test_veto_mechanism(self): response_mask = torch.ones(batch_size, seq_length, 
device=device) - _, metrics = compute_rollout_importance_weights( + _, _, metrics = compute_rollout_importance_weights( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask, @@ -196,7 +196,7 @@ def test_metrics_only_mode(self, sample_data, config_with_rollout_is): but rollout_is=False (disables weight application to policy loss). """ # Compute IS weights (as trainer would do) - rollout_is_weights_proto, is_metrics = compute_rollout_importance_weights( + rollout_is_weights_proto, _, is_metrics = compute_rollout_importance_weights( old_log_prob=sample_data["old_log_prob"], rollout_log_prob=sample_data["rollout_log_prob"], response_mask=sample_data["response_mask"], diff --git a/verl/trainer/ppo/mismatch_helper.py b/verl/trainer/ppo/mismatch_helper.py index 78bb5d6accb..6b649d1b48d 100644 --- a/verl/trainer/ppo/mismatch_helper.py +++ b/verl/trainer/ppo/mismatch_helper.py @@ -53,45 +53,55 @@ def compute_rollout_importance_weights( rollout_is_threshold: Optional[float] = None, rollout_is_threshold_lower: Optional[float] = None, rollout_is_veto_threshold: Optional[float] = 1e-4, -) -> tuple[Optional[DataProto], dict[str, Any]]: - """Compute importance sampling weights and metrics for rollout-training mismatch correction. +) -> tuple[Optional[DataProto], torch.Tensor, dict[str, Any]]: + """Compute importance sampling weights and rejection mask for rollout-training mismatch. - This function handles the computation of importance sampling (IS) weights to correct - for the distribution mismatch between rollout policy and training policy. + This function computes IS weights to correct for distribution mismatch between rollout + and training policies, and applies rejection sampling for outliers. 
+ + Key Design: Separation of IS Weights and Rejection Sampling + - IS weights (rollout_is_weights): Always TRUE ratios π_train/π_rollout + Preserved for policy gradient calculations + - Response mask (modified_response_mask): Has rejection applied (mask mode + veto) + Used for loss aggregation to exclude rejected samples from training Reference: When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda - Memory-efficient implementation that prevents CUDA OOM by: - - Using log-space computation where possible - - Applying safety bounds to prevent numerical overflow - - Computing metrics without creating huge intermediate tensors + Memory-efficient implementation: + - Log-space computation to prevent overflow + - Safety bounds (exp(±20)) on all exponentiations + - Metrics computed without large intermediate tensors Args: - old_log_prob: Log probabilities from training policy (e.g., FSDP), shape (batch_size, seq_length) - rollout_log_prob: Log probabilities from rollout policy (e.g., vLLM), shape (batch_size, seq_length) - response_mask: Mask for valid tokens, shape (batch_size, seq_length) - rollout_is_level: Level of IS aggregation: - - "token": Per-token ratios (biased) - - "sequence": Product of ratios (unbiased) - - "geometric": Geometric mean of ratios (experimental) - rollout_is_mode: How to handle weights exceeding threshold: - - "truncate": Cap weights at upper_threshold only - - "mask": Zero out weights outside [lower_threshold, upper_threshold] - rollout_is_threshold: Upper threshold for IS weights - rollout_is_threshold_lower: Lower threshold for IS weights (mask mode only; if None, defaults to 1/upper) - rollout_is_veto_threshold: Per-token veto threshold. If any token ratio < this, zero entire sequence. - If None, veto mechanism is disabled. 
+ old_log_prob: Log probs from training policy (FSDP FP32), shape (batch_size, seq_length) + rollout_log_prob: Log probs from rollout policy (vLLM BF16), shape (batch_size, seq_length) + response_mask: Valid token mask (1=valid, 0=padding), shape (batch_size, seq_length) + rollout_is_level: IS weight aggregation level + - "token": Per-token ratios ρ_t = π_train(t)/π_rollout(t) (biased but low variance) + - "sequence": Sequence product ρ_seq = ∏ρ_t (unbiased but high variance) + - "geometric": Geometric mean ρ_geo = (∏ρ_t)^(1/T) (experimental trade-off) + rollout_is_mode: Treatment of outlier IS weights + - "truncate": Clamp weights at [1, upper], no rejection (TIS) + - "mask": Reject tokens/sequences outside [lower, upper] via response_mask (MIS/rejection sampling) + rollout_is_threshold: Upper threshold for IS weights (required, e.g., 2.0) + rollout_is_threshold_lower: Lower threshold for mask mode (if None, defaults to 1/upper) + rollout_is_veto_threshold: Catastrophic token threshold. If any token has ratio < this, + reject entire sequence. If None, veto disabled. Default 1e-4. Returns: - Tuple of (weights_proto, metrics) where: - weights_proto: DataProto containing IS weights with key "rollout_is_weights", - shape (batch_size, seq_length). Returns None if rollout_is_threshold is None. - metrics: Dictionary of IS statistics and mismatch metrics (KL, PPL, etc.), - all converted to scalars and prefixed with "mismatch/" + Tuple of (weights_proto, modified_response_mask, metrics): + weights_proto: DataProto with TRUE IS weights (never zeroed), key "rollout_is_weights", + shape (batch_size, seq_length). None if rollout_is_threshold is None. + modified_response_mask: Response mask with rejection applied: + - truncate mode: same as input (no rejection) + - mask mode: tokens outside [lower, upper] masked to 0 + - veto: sequences with catastrophic tokens masked to 0 + Shape (batch_size, seq_length). 
+ metrics: Dict of IS and mismatch metrics, all scalars with "mismatch/" prefix """ if rollout_is_threshold is None: - return None, {} + return None, response_mask, {} # Parse thresholds: if lower not specified, use 1/upper (reciprocal) upper_threshold = rollout_is_threshold @@ -179,38 +189,49 @@ def compute_rollout_importance_weights( SAFETY_BOUND=SAFETY_BOUND, ) - # Step 3: Apply truncation or masking based on mode + # Step 3: Apply outlier handling and rejection sampling + # Key design principle: IS weights and rejection are separate mechanisms + # - rollout_is_weights: Always TRUE ratios π_train/π_rollout (never zeroed) + # Preserved for policy gradient calculations + # - modified_response_mask: Has rejection applied (excludes outliers from training) + # Used for loss denominator: ensures rejected samples don't dilute gradients + if rollout_is_mode == "truncate": - # Truncate mode: only cap upper bound to prevent overweighting + # Truncated IS (TIS): clamp weights to prevent extreme importance ratios + # No rejection - all samples used for training rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold) + modified_response_mask = response_mask # Return unchanged elif rollout_is_mode == "mask": - # Mask mode: zero out weights outside [lower_threshold, upper_threshold] + # Masked IS (MIS): rejection sampling for outlier IS weights + # Reject tokens/sequences with IS ratios outside [lower, upper] mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold) mask = mask.float() - # Track MIS-specific metrics + # Compute rejection rate metrics metrics["rollout_is_masked_fraction"] = verl_F.masked_mean(1 - mask, response_mask) - - # Sequence-level masking fraction if rollout_is_level in ["sequence", "geometric"]: - # All tokens in a sequence have the same weight, so reuse mask + # Sequence-level: all tokens have same weight, check first token metrics["rollout_is_seq_masked_fraction"] = (1 - mask[:, 0]).mean() else: - # Check if 
any token in each sequence is masked + # Token-level: sequence rejected if ANY token is rejected seq_has_masked = verl_F.masked_sum(1 - mask, response_mask, axis=-1) > 0 metrics["rollout_is_seq_masked_fraction"] = seq_has_masked.float().mean() - rollout_is_weights = rollout_is_weights * mask + # Apply rejection via response_mask (NOT by zeroing IS weights) + modified_response_mask = response_mask * mask + # rollout_is_weights kept as true ratios (unchanged) else: raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'mask'.") - # Apply veto mask AFTER all thresholding - # This zeros out entire sequences that have any catastrophic token - rollout_is_weights = rollout_is_weights * veto_mask + # Apply veto: reject entire sequences with catastrophic tokens (ratio < veto_threshold) + # Veto is independent of mask mode - it applies to modified_response_mask after mask rejection + modified_response_mask = modified_response_mask * veto_mask + # rollout_is_weights still unchanged (true ratios preserved) - # Apply response_mask to ensure weights are 0 where mask is 0 + # Apply original response_mask to zero out padding positions in IS weights + # This is different from rejection - padding must be zeroed for correct aggregation rollout_is_weights = rollout_is_weights * response_mask # Wrap in DataProto for consistency with worker methods @@ -231,7 +252,7 @@ def compute_rollout_importance_weights( else: metrics_scalar[f"mismatch/{key}"] = value - return rollout_is_weights_proto, metrics_scalar + return rollout_is_weights_proto, modified_response_mask, metrics_scalar def compute_is_metrics( diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 7e9e6981f3c..0102b68e4a9 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -952,41 +952,57 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqle metrics.update(global_balance_stats) def 
compute_rollout_importance_weights_and_add_to_batch(self, batch: DataProto) -> tuple[DataProto, dict]: - """Compute rollout importance sampling weights and mismatch metrics, conditionally add weights to batch. + """Compute IS weights and apply rejection sampling for rollout-training mismatch. - This method computes IS weights to correct for distribution mismatch between - rollout policy and training policy. It always computes metrics when enabled, but - only adds weights to batch if algorithm.rollout_is is True. + Computes importance sampling weights to correct for distribution mismatch between + rollout and training policies. Applies rejection sampling (mask mode/veto) by + modifying response_mask. Always updates response_mask; conditionally adds IS weights. + + Key behavior: + - response_mask: ALWAYS updated with rejection (mask mode + veto excluded from training) + - rollout_is_weights: Added to batch ONLY if config.algorithm.rollout_is=True + + This separation ensures: + - Rejection works even when IS weights are disabled (rollout_is=False) + - Metrics can be monitored before enabling IS weight application Args: - batch: DataProto containing old_log_probs, rollout_log_probs, response_mask + batch: DataProto with old_log_probs, rollout_log_probs, response_mask Returns: - Tuple of (updated_batch, metrics) where: - - updated_batch: Batch with rollout_is_weights added (if rollout_is=True) - - metrics: Dictionary of IS and mismatch metrics (all with mismatch/ prefix) + Tuple of (updated_batch, metrics): + updated_batch: Batch with modified response_mask (always) and rollout_is_weights (if rollout_is=True) + metrics: Dict of IS and mismatch metrics, all with "mismatch/" prefix """ # Compute rollout IS weights if enabled and data is available - # rollout_is_threshold is the main on/off switch - if self.config.algorithm.rollout_is_threshold is not None and "rollout_log_probs" in batch.batch: - rollout_is_weights, rollout_is_metrics = compute_rollout_importance_weights( 
+ # rollout_is_threshold is the main on/off switch (None = disabled, float = enabled) + rollout_is_threshold = self.config.algorithm.get("rollout_is_threshold", None) + if rollout_is_threshold is not None and rollout_is_threshold > 0 and "rollout_log_probs" in batch.batch: + # Compute IS weights and get modified response_mask + rollout_is_weights, modified_response_mask, rollout_is_metrics = compute_rollout_importance_weights( old_log_prob=batch.batch["old_log_probs"], rollout_log_prob=batch.batch["rollout_log_probs"], response_mask=batch.batch["response_mask"], rollout_is_level=self.config.algorithm.rollout_is_level, rollout_is_mode=self.config.algorithm.rollout_is_mode, rollout_is_threshold=self.config.algorithm.rollout_is_threshold, - rollout_is_threshold_lower=self.config.algorithm.rollout_is_threshold_lower, - rollout_is_veto_threshold=self.config.algorithm.rollout_is_veto_threshold, + rollout_is_threshold_lower=self.config.algorithm.get("rollout_is_threshold_lower", None), + rollout_is_veto_threshold=self.config.algorithm.get("rollout_is_veto_threshold", 1e-4), ) - # Control: Should we apply weights to policy loss? 
- # True = add weights to batch (actor will apply them) - # False = don't add weights (metrics only, no loss modification) + # ALWAYS update response_mask with rejection (even if rollout_is=False) + # - Mask mode: tokens with outlier IS ratios excluded + # - Veto: sequences with catastrophic tokens excluded + # This ensures correct loss normalization (rejected samples not in denominator) + batch.batch["response_mask"] = modified_response_mask + + # Conditionally add IS weights based on rollout_is config flag + # - rollout_is=True: Enable IS weight correction in policy loss + # - rollout_is=False: Metrics-only mode (rejection still applied via mask) apply_weights = self.config.algorithm.get("rollout_is", False) if apply_weights: - # Add IS weights to batch for distribution to workers + # Add TRUE IS weights (never zeroed, always π_train/π_rollout) batch = batch.union(rollout_is_weights) return batch, rollout_is_metrics From a5aa743969eb42cf61863b393596ca7bc786ab70 Mon Sep 17 00:00:00 2001 From: szrlee Date: Mon, 27 Oct 2025 04:45:38 +0800 Subject: [PATCH 2/5] docs: clarify truncate mode and veto independence Fixed two documentation issues: 1. Truncate mode only clamps upper bound (not [1, upper]) 2. 
Veto applies independently of rollout_is_mode The previous documentation was misleading: - Stated 'no rejection' for truncate mode (veto can still reject) - Stated clamp at [1, upper] (only upper is clamped) Changes: - Clarified truncate only clamps max (no lower bound) - Emphasized veto applies in both truncate and mask modes - Updated docstring, docs, and in-code comments - Prevents silent data loss when using truncate mode --- docs/advance/rollout_is.md | 23 ++++++++++++++--------- verl/trainer/ppo/mismatch_helper.py | 13 +++++++------ 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/docs/advance/rollout_is.md b/docs/advance/rollout_is.md index 5e070839354..8b9f6072e12 100644 --- a/docs/advance/rollout_is.md +++ b/docs/advance/rollout_is.md @@ -1,14 +1,11 @@ # Rollout Importance Sampling +**Author:** [Yingru Li](https://richardli.xyz/) + Last updated: 10/27/2025. This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation in verl. 
-## References - -- [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda) -- [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl) - ### BibTeX Citation ```bibtex @@ -134,19 +131,22 @@ Aggregation level for IS weights: ### `algorithm.rollout_is_mode` (str) Bounding mode for handling outlier IS weights: -- `"truncate"`: Clamp weights at upper threshold, no rejection (TIS) - - All samples used for training - - IS weights capped to prevent extreme importance ratios +- `"truncate"`: Clamp weights at upper threshold only (TIS) + - No lower bound clamping or rejection for outlier ratios + - IS weights capped at upper threshold to prevent extreme importance ratios + - **Note**: Veto-based rejection can still occur (see `rollout_is_veto_threshold`) - `"mask"`: Rejection sampling via response_mask (MIS) - Rejects tokens/sequences with IS ratios outside [lower, upper] - **Important**: Rejection applied to `response_mask`, NOT by zeroing IS weights - IS weights remain as true ratios + - **Note**: Veto-based rejection also applies (independent mechanism) ### `algorithm.rollout_is_veto_threshold` (float) Per-token veto threshold for catastrophic outliers. 
- If any token has ratio < this threshold, the entire sequence is rejected via `response_mask` - Default: `1e-4` (detects ratios 10,000x off) -- **Important**: Veto applies rejection to `response_mask`, NOT by zeroing IS weights +- **Important**: Applied **independently** of `rollout_is_mode` (works in both truncate and mask modes) +- Veto applies rejection to `response_mask`, NOT by zeroing IS weights - IS weights remain as true ratios even for vetoed sequences ## Usage @@ -627,3 +627,8 @@ Rollout Importance Sampling provides: - ✅ Comprehensive metrics for monitoring - ✅ Flexibility for different scenarios - ✅ Memory-efficient computation + +## References + +- [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda) +- [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl) \ No newline at end of file diff --git a/verl/trainer/ppo/mismatch_helper.py b/verl/trainer/ppo/mismatch_helper.py index 6b649d1b48d..636e0508655 100644 --- a/verl/trainer/ppo/mismatch_helper.py +++ b/verl/trainer/ppo/mismatch_helper.py @@ -82,21 +82,22 @@ def compute_rollout_importance_weights( - "sequence": Sequence product ρ_seq = ∏ρ_t (unbiased but high variance) - "geometric": Geometric mean ρ_geo = (∏ρ_t)^(1/T) (experimental trade-off) rollout_is_mode: Treatment of outlier IS weights - - "truncate": Clamp weights at [1, upper], no rejection (TIS) + - "truncate": Clamp weights at upper threshold only. 
No rejection for outlier ratios, + but veto can still apply (TIS) - "mask": Reject tokens/sequences outside [lower, upper] via response_mask (MIS/rejection sampling) rollout_is_threshold: Upper threshold for IS weights (required, e.g., 2.0) rollout_is_threshold_lower: Lower threshold for mask mode (if None, defaults to 1/upper) rollout_is_veto_threshold: Catastrophic token threshold. If any token has ratio < this, - reject entire sequence. If None, veto disabled. Default 1e-4. + reject entire sequence. Applied independently of rollout_is_mode. If None, veto disabled. Default 1e-4. Returns: Tuple of (weights_proto, modified_response_mask, metrics): weights_proto: DataProto with TRUE IS weights (never zeroed), key "rollout_is_weights", shape (batch_size, seq_length). None if rollout_is_threshold is None. modified_response_mask: Response mask with rejection applied: - - truncate mode: same as input (no rejection) + - truncate mode: unchanged for outlier ratios, but veto rejection still applied - mask mode: tokens outside [lower, upper] masked to 0 - - veto: sequences with catastrophic tokens masked to 0 + - veto: sequences with catastrophic tokens masked to 0 (applied in both modes) Shape (batch_size, seq_length). 
metrics: Dict of IS and mismatch metrics, all scalars with "mismatch/" prefix """ @@ -198,9 +199,9 @@ def compute_rollout_importance_weights( if rollout_is_mode == "truncate": # Truncated IS (TIS): clamp weights to prevent extreme importance ratios - # No rejection - all samples used for training + # No rejection for outlier ratios (mask unchanged), but veto can still apply below rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold) - modified_response_mask = response_mask # Return unchanged + modified_response_mask = response_mask # Unchanged for outlier ratios elif rollout_is_mode == "mask": # Masked IS (MIS): rejection sampling for outlier IS weights From ab3e8af0aa25a355fe4646a1f9e70cc590969ed0 Mon Sep 17 00:00:00 2001 From: szrlee Date: Mon, 27 Oct 2025 16:02:37 +0800 Subject: [PATCH 3/5] fix(ppo): exclude fully masked sequences from seq-mean loss Fixed seq-mean-token-sum and seq-mean-token-mean modes to exclude fully masked sequences from denominator using masked_mean, and added epsilon to prevent division by zero. 
--- verl/trainer/ppo/core_algos.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index 9ec460aba2e..7c2bfd53673 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -788,10 +788,13 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str loss = verl_F.masked_mean(loss_mat, loss_mask) elif loss_agg_mode == "seq-mean-token-sum": seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum - loss = torch.mean(seq_losses) # seq-mean + seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float() # exclude fully masked sequences + loss = verl_F.masked_mean(seq_losses, seq_mask) # seq-mean elif loss_agg_mode == "seq-mean-token-mean": - seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean - loss = torch.mean(seq_losses) # seq-mean + seq_mask = torch.sum(loss_mask, dim=-1) # per-sequence token count + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8) # token-mean + seq_mask = (seq_mask > 0).float() # exclude fully masked sequences + loss = verl_F.masked_mean(seq_losses, seq_mask) # seq-mean elif loss_agg_mode == "seq-mean-token-sum-norm": seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) loss = torch.sum(seq_losses) / loss_mask.shape[-1] # The divisor From 10350b613de7625d759a48e582573749484ef861 Mon Sep 17 00:00:00 2001 From: szrlee Date: Mon, 27 Oct 2025 16:02:44 +0800 Subject: [PATCH 4/5] feat(rollout_is): set veto threshold default to None Changed rollout_is_veto_threshold default from 1e-4 to None, making the veto mechanism opt-in across 11 files (configs, runtime, docs). 
--- docs/advance/rollout_is.md | 10 ++++++---- docs/examples/config.rst | 6 +++--- examples/rollout_importance_sampling/README.md | 12 ++++++------ .../run_with_rollout_is.sh | 4 ++-- .../config/_generated_ppo_megatron_trainer.yaml | 2 +- verl/trainer/config/_generated_ppo_trainer.yaml | 2 +- verl/trainer/config/algorithm.py | 4 ++-- verl/trainer/config/ppo_megatron_trainer.yaml | 4 ++-- verl/trainer/config/ppo_trainer.yaml | 4 ++-- verl/trainer/ppo/mismatch_helper.py | 4 ++-- verl/trainer/ppo/ray_trainer.py | 2 +- 11 files changed, 28 insertions(+), 26 deletions(-) diff --git a/docs/advance/rollout_is.md b/docs/advance/rollout_is.md index 8b9f6072e12..2b3f54f9c05 100644 --- a/docs/advance/rollout_is.md +++ b/docs/advance/rollout_is.md @@ -56,7 +56,7 @@ algorithm: rollout_is_threshold_lower: null # Auto-reciprocal rollout_is_level: token rollout_is_mode: truncate - rollout_is_veto_threshold: 1e-4 + rollout_is_veto_threshold: null # Disable veto by default # REQUIRED: Enable log prob calculation actor_rollout_ref: @@ -141,10 +141,12 @@ Bounding mode for handling outlier IS weights: - IS weights remain as true ratios - **Note**: Veto-based rejection also applies (independent mechanism) -### `algorithm.rollout_is_veto_threshold` (float) +### `algorithm.rollout_is_veto_threshold` (float or None) Per-token veto threshold for catastrophic outliers. 
- If any token has ratio < this threshold, the entire sequence is rejected via `response_mask` -- Default: `1e-4` (detects ratios 10,000x off) +- Default: `None` (veto disabled by default) +- Recommended: `1e-4` to `1e-6` when enabled (catches extreme outliers like 10,000x off) +- Set to `None` to disable veto mechanism - **Important**: Applied **independently** of `rollout_is_mode` (works in both truncate and mask modes) - Veto applies rejection to `response_mask`, NOT by zeroing IS weights - IS weights remain as true ratios even for vetoed sequences @@ -308,7 +310,7 @@ weights_proto, modified_response_mask, metrics = compute_rollout_importance_weig rollout_is_mode="mask", # Using mask mode for rejection sampling rollout_is_threshold=2.0, rollout_is_threshold_lower=0.5, - rollout_is_veto_threshold=1e-4, + rollout_is_veto_threshold=1e-4, # Enable veto for catastrophic outliers ) # Extract IS weights (always true ratios, never zeroed) diff --git a/docs/examples/config.rst b/docs/examples/config.rst index 387f197f494..45cc8b14c60 100644 --- a/docs/examples/config.rst +++ b/docs/examples/config.rst @@ -133,7 +133,7 @@ Actor/Rollout/Reference Policy rollout_is_threshold_lower: null # Lower threshold (null = auto 1/upper) rollout_is_level: token # Aggregation: token/sequence/geometric rollout_is_mode: truncate # Bounding: truncate/mask - rollout_is_veto_threshold: 1e-4 # Catastrophic outlier threshold + rollout_is_veto_threshold: null # Catastrophic outlier threshold (null to disable) use_torch_compile: True # False to disable torch compile kl_loss_coef: 0.001 # for grpo kl_loss_type: low_var_kl # for grpo @@ -519,7 +519,7 @@ Algorithm rollout_is_threshold_lower: null rollout_is_level: token rollout_is_mode: truncate - rollout_is_veto_threshold: 1e-4 + rollout_is_veto_threshold: null # Disabled by default - ``gamma``: discount factor - ``lam``: Trade-off between bias and variance in the GAE estimator @@ -537,7 +537,7 @@ Algorithm - ``rollout_is_threshold_lower``: Lower 
threshold for IS weights. If ``null``, defaults to reciprocal of upper (1/upper). - ``rollout_is_level``: Aggregation level: ``token`` (biased), ``sequence`` (unbiased), or ``geometric`` (experimental). - ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``mask`` (zero outside bounds). -- ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is 1e-4. +- ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is null (disabled). Note: Rollout IS requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``. Trainer diff --git a/examples/rollout_importance_sampling/README.md b/examples/rollout_importance_sampling/README.md index 1d0ce66394e..7baf55ebf2e 100644 --- a/examples/rollout_importance_sampling/README.md +++ b/examples/rollout_importance_sampling/README.md @@ -61,7 +61,7 @@ bash examples/rollout_importance_sampling/run_with_rollout_is.sh - `rollout_is_threshold`: Upper threshold for IS weights (null = disabled, float = enabled). **Main on/off switch.** - `rollout_is`: Whether to apply weights to loss (true) or just compute metrics (false). Default: false. 
- `rollout_is_threshold_lower`: Lower threshold (null = auto 1/upper) -- `rollout_is_veto_threshold`: Catastrophic outlier threshold (default: 1e-4) +- `rollout_is_veto_threshold`: Catastrophic outlier threshold (default: null, disabled) ## Configuration Examples @@ -73,7 +73,7 @@ algorithm: rollout_is: true # Apply to loss rollout_is_level: token rollout_is_mode: truncate - rollout_is_veto_threshold: 1e-4 + rollout_is_veto_threshold: null # Disabled by default ``` ### Example 2: Metrics Only (No Weight Application) @@ -95,7 +95,7 @@ algorithm: rollout_is_threshold_lower: 0.9998 rollout_is_level: geometric rollout_is_mode: mask - rollout_is_veto_threshold: 1e-4 + rollout_is_veto_threshold: 1e-4 # Enable veto for this example ``` ### Example 4: Sequence-level with Truncate @@ -107,7 +107,7 @@ algorithm: rollout_is_threshold_lower: null # Auto-reciprocal: 0.2 rollout_is_level: sequence rollout_is_mode: truncate - rollout_is_veto_threshold: 1e-4 + rollout_is_veto_threshold: 1e-4 # Enable veto for this example ``` ### Example 5: Asymmetric Thresholds @@ -226,8 +226,8 @@ The veto mechanism zeros out entire sequences containing catastrophic outliers: - If any token has ratio < `rollout_is_veto_threshold`, the entire sequence is rejected - This prevents extreme outliers from dominating training -- Default threshold: 1e-4 (ratio 10,000x off) -- Set to `null` to disable: `rollout_is_veto_threshold: null` +- Default: `null` (disabled by default) +- Set to `1e-4` to enable (catches ratios 10,000x off) ## Examples diff --git a/examples/rollout_importance_sampling/run_with_rollout_is.sh b/examples/rollout_importance_sampling/run_with_rollout_is.sh index 93ebbe4fdc2..42c4a2a5981 100755 --- a/examples/rollout_importance_sampling/run_with_rollout_is.sh +++ b/examples/rollout_importance_sampling/run_with_rollout_is.sh @@ -24,8 +24,8 @@ rollout_is_level=token # Bounding mode: truncate (cap upper) | mask (zero outside bounds) rollout_is_mode=truncate -# Catastrophic outlier veto 
threshold -rollout_is_veto_threshold=1e-4 +# Catastrophic outlier veto threshold (set to null to disable, or e.g., 1e-4 to enable) +rollout_is_veto_threshold=null # ============================================================================== # Model and Data Configuration diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 5b8dc20c4b0..656cc0b731b 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -490,7 +490,7 @@ algorithm: rollout_is_threshold_lower: null rollout_is_level: token rollout_is_mode: truncate - rollout_is_veto_threshold: 0.0001 + rollout_is_veto_threshold: null rollout_is: false trainer: balance_batch: true diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index c9bccaf8fe0..3ec8e229fb0 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -476,7 +476,7 @@ algorithm: rollout_is_threshold_lower: null rollout_is_level: token rollout_is_mode: truncate - rollout_is_veto_threshold: 0.0001 + rollout_is_veto_threshold: null rollout_is: false trainer: balance_batch: true diff --git a/verl/trainer/config/algorithm.py b/verl/trainer/config/algorithm.py index 859ce90ff32..bd6763226f7 100644 --- a/verl/trainer/config/algorithm.py +++ b/verl/trainer/config/algorithm.py @@ -78,7 +78,7 @@ class AlgoConfig(BaseConfig): rollout_is_threshold_lower (Optional[float]): Lower threshold for IS weights. If None, defaults to 1/upper. rollout_is_level (str): Aggregation level: "token", "sequence", or "geometric". rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "mask" (zero outside bounds). - rollout_is_veto_threshold (float): Per-token veto threshold for catastrophic outliers. 
+ rollout_is_veto_threshold (float or None): Per-token veto threshold for catastrophic outliers. None to disable. rollout_is (bool): Whether to apply IS weights to policy loss. True = apply weights, False = compute metrics only (useful for monitoring before enabling correction). Default: False. """ @@ -99,7 +99,7 @@ class AlgoConfig(BaseConfig): rollout_is_threshold_lower: Optional[float] = None rollout_is_level: str = "token" rollout_is_mode: str = "truncate" - rollout_is_veto_threshold: Optional[float] = 1e-4 + rollout_is_veto_threshold: Optional[float] = None # Controls whether to apply IS weights to policy loss (only if rollout_is_threshold is set) # True = apply weights to loss, False = compute metrics only (no weight application) rollout_is: bool = False diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 8a23a3b72d4..fb5eef5d6fc 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -87,8 +87,8 @@ algorithm: # Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds) rollout_is_mode: truncate - # Per-token veto threshold for catastrophic outliers - rollout_is_veto_threshold: 1e-4 + # Per-token veto threshold for catastrophic outliers (null to disable) + rollout_is_veto_threshold: null # Whether to apply IS weights to policy loss # true = apply weights to loss, false = compute metrics only (no weight application) diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 72c97e4a7cd..d9d250c9a48 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -127,8 +127,8 @@ algorithm: # Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds) rollout_is_mode: truncate - # Per-token veto threshold for catastrophic outliers - rollout_is_veto_threshold: 1e-4 + # Per-token veto threshold for catastrophic outliers (null to disable) + 
rollout_is_veto_threshold: null # Whether to apply IS weights to policy loss # true = apply weights to loss, false = compute metrics only (no weight application) diff --git a/verl/trainer/ppo/mismatch_helper.py b/verl/trainer/ppo/mismatch_helper.py index 636e0508655..b70229c0030 100644 --- a/verl/trainer/ppo/mismatch_helper.py +++ b/verl/trainer/ppo/mismatch_helper.py @@ -52,7 +52,7 @@ def compute_rollout_importance_weights( rollout_is_mode: str = "truncate", rollout_is_threshold: Optional[float] = None, rollout_is_threshold_lower: Optional[float] = None, - rollout_is_veto_threshold: Optional[float] = 1e-4, + rollout_is_veto_threshold: Optional[float] = None, ) -> tuple[Optional[DataProto], torch.Tensor, dict[str, Any]]: """Compute importance sampling weights and rejection mask for rollout-training mismatch. @@ -88,7 +88,7 @@ def compute_rollout_importance_weights( rollout_is_threshold: Upper threshold for IS weights (required, e.g., 2.0) rollout_is_threshold_lower: Lower threshold for mask mode (if None, defaults to 1/upper) rollout_is_veto_threshold: Catastrophic token threshold. If any token has ratio < this, - reject entire sequence. Applied independently of rollout_is_mode. If None, veto disabled. Default 1e-4. + reject entire sequence. Applied independently of rollout_is_mode. If None, veto disabled. Default None. 
Returns: Tuple of (weights_proto, modified_response_mask, metrics): diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 0102b68e4a9..ef2abb5e920 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -987,7 +987,7 @@ def compute_rollout_importance_weights_and_add_to_batch(self, batch: DataProto) rollout_is_mode=self.config.algorithm.rollout_is_mode, rollout_is_threshold=self.config.algorithm.rollout_is_threshold, rollout_is_threshold_lower=self.config.algorithm.get("rollout_is_threshold_lower", None), - rollout_is_veto_threshold=self.config.algorithm.get("rollout_is_veto_threshold", 1e-4), + rollout_is_veto_threshold=self.config.algorithm.get("rollout_is_veto_threshold", None), ) # ALWAYS update response_mask with rejection (even if rollout_is=False) From 359c9661013fc6bd86cbd96ed989c77a07855bb4 Mon Sep 17 00:00:00 2001 From: szrlee Date: Mon, 27 Oct 2025 20:11:56 +0800 Subject: [PATCH 5/5] docs(rollout_is): clarify IS weight processing and operation modes Fix inaccurate documentation about IS weight processing: - IS weights are safety-bounded to [exp(-20), exp(20)], not "true ratios" - IS weights ARE zeroed at padding (not "never zeroed") - Truncate mode: safety-bounded + upper clamped - Mask mode: safety-bounded only (no threshold clamping) - Veto checks unclamped ratios before safety bounds Add "Operation Modes" section documenting independent control flags: - rollout_is_threshold: main on/off switch - rollout_is: controls IS weight application to loss - Rejection sampling (mask mode) applies regardless of rollout_is flag - Include mode combinations table and recommended workflow Update terminology throughout: - "safety-bounded ratios" replaces "true ratios" for mask mode - Update code comments in ray_trainer.py and test files --- docs/advance/rollout_is.md | 189 ++++++++++++++++++++++----- tests/trainer/ppo/test_rollout_is.py | 8 +- verl/trainer/ppo/mismatch_helper.py | 48 ++++--- 
verl/trainer/ppo/ray_trainer.py | 2 +- 4 files changed, 192 insertions(+), 55 deletions(-) diff --git a/docs/advance/rollout_is.md b/docs/advance/rollout_is.md index 2b3f54f9c05..dd282b58dbc 100644 --- a/docs/advance/rollout_is.md +++ b/docs/advance/rollout_is.md @@ -30,9 +30,15 @@ This mismatch can lead to biased gradient estimates and unstable training. Rollo **Important**: As of 10/27/2025, the implementation separates two mechanisms: -1. **IS Weights** (`rollout_is_weights`): Always TRUE ratios π_train/π_rollout - - Never zeroed, even for rejected samples - - Preserves true importance ratios for policy gradient calculations +1. **IS Weights** (`rollout_is_weights`): Ratios π_train/π_rollout with processing: + - **Safety-bounded** to [exp(-20), exp(20)] ≈ [2e-9, 5e8] to prevent overflow: + * Token level: Bounds per-token ratios + * Sequence level: Bounds product of ratios (broadcast to all tokens in sequence) + * Geometric level: Bounds geometric mean of ratios (broadcast to all tokens) + - **Truncate mode**: Upper clamped via .clamp(max=upper_threshold) + - **Mask mode**: Safety-bounded ratios preserved (no threshold clamping) + - **All modes**: Zeroed at padding positions (response_mask == 0) + - Used for policy gradient calculations 2. **Rejection Sampling** (`modified_response_mask`): Applied via response_mask - Mask mode: Excludes tokens/sequences with outlier IS ratios @@ -41,8 +47,9 @@ This mismatch can lead to biased gradient estimates and unstable training. 
Rollo This separation ensures: - ✅ Correct loss normalization (rejected samples excluded from denominator) -- ✅ True IS ratios preserved (not zeroed for rejected samples) -- ✅ Padding positions still zeroed in weights (different from rejection) +- ✅ Mode-specific weight processing (truncate: upper clamped, mask: safety-bounded only) +- ✅ Padding positions zeroed in weights (necessary for correct aggregation) +- ✅ Safety bounds always applied (prevent overflow in all modes) ## Configuration @@ -113,7 +120,9 @@ Key features: ### `algorithm.rollout_is` (bool) Whether to apply IS weights to policy loss. Default: `False` - `true` = apply weights to loss (full IS correction) -- `false` = compute metrics only (useful for monitoring before enabling) +- `false` = metrics only mode (no weight correction, but rejection still applies) + +**IMPORTANT**: This flag controls IS weight application, NOT rejection sampling. See "Operation Modes" below. **Recommended threshold ranges:** - Token level: 1.5 - 5.0 @@ -125,31 +134,120 @@ Lower threshold for IS weights. If `null`, defaults to 1/upper (reciprocal). 
### `algorithm.rollout_is_level` (str) Aggregation level for IS weights: -- `"token"`: Per-token ratios -- `"sequence"`: Product of ratios -- `"geometric"`: Geometric mean (experimental) +- `"token"`: Per-token ratios ρ_t = π_train(t)/π_rollout(t) + - Each token has its own IS weight + - Safety bound: each token's ratio bounded to [exp(-20), exp(20)] + - Biased estimator but low variance +- `"sequence"`: Product of ratios ρ_seq = ∏_t ρ_t for entire sequence + - All tokens in a sequence share the same IS weight (product of per-token ratios) + - Safety bound: product bounded to [exp(-20), exp(20)], then broadcast to all tokens + - Unbiased estimator but high variance +- `"geometric"`: Geometric mean ρ_geo = (∏_t ρ_t)^(1/T) (experimental) + - All tokens in a sequence share the same IS weight (geometric mean) + - Safety bound: geometric mean bounded to [exp(-20), exp(20)], then broadcast to all tokens + - Trade-off between bias and variance ### `algorithm.rollout_is_mode` (str) Bounding mode for handling outlier IS weights: - `"truncate"`: Clamp weights at upper threshold only (TIS) - No lower bound clamping or rejection for outlier ratios - - IS weights capped at upper threshold to prevent extreme importance ratios - - **Note**: Veto-based rejection can still occur (see `rollout_is_veto_threshold`) + - **IS weights modified**: Upper bound clamped via .clamp(max=upper_threshold) + - Lower bound remains at exp(-20) ≈ 2e-9 from safety bound + - **Note**: Veto-based rejection can still occur via response_mask (see `rollout_is_veto_threshold`) - `"mask"`: Rejection sampling via response_mask (MIS) - Rejects tokens/sequences with IS ratios outside [lower, upper] - - **Important**: Rejection applied to `response_mask`, NOT by zeroing IS weights - - IS weights remain as true ratios - - **Note**: Veto-based rejection also applies (independent mechanism) + - **Important**: Rejection applied to `response_mask`, NOT by modifying IS weights + - **IS weights**: Safety-bounded 
ratios preserved (no threshold clamping, rejection via mask) + - **Note**: Veto-based rejection also applies via response_mask (independent mechanism) ### `algorithm.rollout_is_veto_threshold` (float or None) Per-token veto threshold for catastrophic outliers. -- If any token has ratio < this threshold, the entire sequence is rejected via `response_mask` +- If any token has **unclamped** ratio < this threshold, the entire sequence is rejected via `response_mask` +- Veto checks the **true per-token ratio** π_train(t)/π_rollout(t) before any bounds are applied +- Applied for all levels (token, sequence, geometric) - always checks individual token ratios - Default: `None` (veto disabled by default) - Recommended: `1e-4` to `1e-6` when enabled (catches extreme outliers like 10,000x off) - Set to `None` to disable veto mechanism - **Important**: Applied **independently** of `rollout_is_mode` (works in both truncate and mask modes) -- Veto applies rejection to `response_mask`, NOT by zeroing IS weights -- IS weights remain as true ratios even for vetoed sequences +- Veto applies rejection to `response_mask`, NOT by modifying IS weights +- **IS weights unchanged by veto**: Already processed by mode (truncate: clamped, mask: safety-bounded) + +### Summary: How IS Weights are Processed + +The final IS weights go through multiple stages of processing: + +**Stage 1: Safety Bound (All Modes)** +- Token level: `exp(clamp(log_ratio, -20, 20))` per token → bounds each token to [2e-9, 5e8] +- Sequence level: `exp(clamp(sum(log_ratio), -20, 20))` → bounds product to [2e-9, 5e8], broadcast to all tokens +- Geometric level: `exp(clamp(mean(log_ratio), -20, 20))` → bounds geometric mean to [2e-9, 5e8], broadcast to all tokens + +**Stage 2: Threshold Processing (Mode-Dependent)** +- Truncate mode: `.clamp(max=upper_threshold)` → upper clamps weights to threshold +- Mask mode: No modification → weights remain as safety-bounded ratios + +**Stage 3: Padding (All Modes)** +- `weights * 
response_mask` → zeros out padding positions + +**Rejection Mechanisms (Modify response_mask, NOT weights)** +- Veto: Checks **unclamped per-token ratios** (before safety bound), rejects sequences via mask +- Outlier (mask mode only): Checks safety-bounded weights against [lower, upper], rejects via mask + +## Operation Modes + +The system has **two independent control flags** that combine to create different operation modes: + +1. **`rollout_is_threshold`**: Main on/off switch (None = disabled, float = enabled) +2. **`rollout_is`**: Apply IS weights to loss (True/False) + +### Mode Combinations + +| `rollout_is_threshold` | `rollout_is` | `rollout_is_mode` | Behavior | +|------------------------|--------------|-------------------|----------| +| `None` | any | any | **Disabled**: No computation, no metrics, no rejection | +| `2.0` | `False` | `truncate` | **Metrics only**: Compute weights & metrics, NO weight correction, NO rejection for outliers | +| `2.0` | `False` | `mask` | **Rejection only**: Compute weights & metrics, NO weight correction, YES rejection sampling | +| `2.0` | `True` | `truncate` | **Truncate mode**: Weight correction enabled, weights upper-clamped, NO rejection for outliers | +| `2.0` | `True` | `mask` | **Mask mode (full)**: Weight correction enabled, rejection sampling enabled | + +### Key Insights + +**Rejection sampling is ALWAYS applied when:** +- `rollout_is_threshold` is set (not None) +- AND `rollout_is_mode = "mask"` +- **Regardless of the `rollout_is` flag** + +This means: +- ✅ You can use **rejection sampling alone** without IS weight correction (`rollout_is=False, rollout_is_mode="mask"`) +- ✅ You can use **IS weights alone** without outlier rejection (`rollout_is=True, rollout_is_mode="truncate"`) +- ✅ You can use **both together** (`rollout_is=True, rollout_is_mode="mask"`) +- ✅ You can **monitor metrics only** without any correction or outlier rejection (`rollout_is=False, rollout_is_mode="truncate"`) + +**Veto rejection** (if 
enabled via `rollout_is_veto_threshold`) is applied **independently** in all modes where `rollout_is_threshold` is set. + +### Recommended Workflow + +1. **Start with metrics only** to understand the mismatch: + ```yaml + rollout_is_threshold: 2.0 + rollout_is: false + rollout_is_mode: truncate + ``` + Monitor `mismatch/rollout_is_mean`, `mismatch/mismatch_kl` to assess distribution mismatch. + +2. **Enable rejection sampling** if you see high outlier fractions: + ```yaml + rollout_is_threshold: 2.0 + rollout_is: false + rollout_is_mode: mask # Rejection now applies + ``` + This excludes outliers from the loss via `response_mask`; the remaining tokens are not reweighted by IS weights. + +3. **Enable full IS correction** once comfortable with metrics: + ```yaml + rollout_is_threshold: 2.0 + rollout_is: true + rollout_is_mode: mask # Both rejection and weight correction + ``` ## Usage @@ -183,9 +281,13 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear - **`rollout_is_min`**: Minimum IS weight observed - Shows the most underweighted token/sequence + - For sequence/geometric: computed from unclamped log-space ratios (true minimum) + - For token: computed from safety-bounded weights -- **`rollout_is_max`**: Maximum IS weight observed (before truncation/masking) +- **`rollout_is_max`**: Maximum IS weight observed - Shows the most overweighted token/sequence + - For sequence/geometric: computed from unclamped log-space ratios (true maximum before safety bound) + - For token: computed from safety-bounded weights (before threshold clamping) - Compare with `rollout_is_threshold` to see truncation impact #### **Effective Sample Size** @@ -199,14 +301,16 @@ All metrics are prefixed with `mismatch/`. 
For example, `rollout_is_mean` appear #### **Veto Mechanism Metrics** - **`rollout_is_veto_fraction`**: Fraction of sequences rejected by veto mechanism - - **Important**: Sequences are rejected via `response_mask=0`, NOT by zeroing IS weights - - IS weights remain as true ratios even for vetoed sequences - - Veto detects catastrophic tokens (ratio < veto_threshold, e.g., < 1e-4) + - **Important**: Sequences are rejected via `response_mask=0`, NOT by modifying IS weights + - **IS weights unchanged by veto**: Already processed by mode (truncate: clamped, mask: safety-bounded) + - Veto checks **unclamped per-token ratios** π_train(t)/π_rollout(t) (true ratios before safety bound) + - Detects catastrophic tokens (true ratio < veto_threshold, e.g., < 1e-4) - **Ideal value**: < 0.05 (less than 5% vetoed) - **Warning**: > 0.1 suggests policies are too different or numerical issues - **`rollout_is_catastrophic_token_fraction`**: Fraction of tokens below veto threshold - Identifies problematic tokens before sequence-level veto is applied + - Checks **unclamped per-token ratios** (true ratios, not safety-bounded) - Each catastrophic token causes its entire sequence to be rejected - **Warning**: > 0.01 indicates widespread distribution issues or numerical instability @@ -214,10 +318,14 @@ All metrics are prefixed with `mismatch/`. 
For example, `rollout_is_mean` appear - **`rollout_is_ratio_fraction_high`**: Fraction of weights exceeding upper threshold - Shows how often truncation/masking occurs on high end + - For sequence/geometric: computed from unclamped log-space ratios (true exceedance) + - For token: computed from safety-bounded weights (before threshold clamping) - **Ideal value**: < 0.1 (most weights within bounds) - **`rollout_is_ratio_fraction_low`**: Fraction of weights below lower threshold - Shows how often masking occurs on low end (mask mode only) + - For sequence/geometric: computed from unclamped log-space ratios (true exceedance) + - For token: computed from safety-bounded weights - **Ideal value**: < 0.1 #### **Sequence-Level Metrics** (for sequence/geometric modes) @@ -241,9 +349,9 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear #### **Masking Metrics** (mask mode only) -- **`rollout_is_masked_fraction`**: Fraction of tokens rejected via response_mask - - **Important**: Tokens are rejected by setting `response_mask=0`, NOT by zeroing IS weights - - IS weights remain as true ratios (π_train/π_rollout) +- **`rollout_is_masked_fraction`**: Fraction of tokens rejected via response_mask (mask mode only) + - **Important**: Tokens are rejected by setting `response_mask=0`, NOT by modifying IS weights + - **IS weights in mask mode**: Safety-bounded ratios preserved (no threshold clamping) - **Ideal value**: < 0.1 (less than 10% rejected) - **Warning**: > 0.3 means losing too much data @@ -313,13 +421,18 @@ weights_proto, modified_response_mask, metrics = compute_rollout_importance_weig rollout_is_veto_threshold=1e-4, # Enable veto for catastrophic outliers ) -# Extract IS weights (always true ratios, never zeroed) +# Extract IS weights (processed, zeroed at padding) is_weights = weights_proto.batch["rollout_is_weights"] -# modified_response_mask has rejection applied -# - Tokens/sequences outside [0.5, 2.0] are masked to 0 -# - Sequences with 
catastrophic tokens are masked to 0 -# - IS weights remain as true ratios (NOT zeroed) +# IS weights processing (mask mode with token level): +# 1. Safety-bounded: exp(clamp(log_ratio, -20, 20)) per token +# 2. Mask mode: no threshold clamping (safety-bounded ratios preserved) +# 3. Zeroed at padding positions + +# modified_response_mask has rejection applied: +# 1. Outlier rejection: tokens outside [0.5, 2.0] masked to 0 (mask mode) +# 2. Veto rejection: sequences with catastrophic tokens (ratio < 1e-4) masked to 0 +# Note: Veto checks unclamped per-token ratios, not the safety-bounded weights # All metrics have 'mismatch/' prefix print(f"Mean IS weight: {metrics['mismatch/rollout_is_mean']:.3f}") @@ -328,15 +441,18 @@ print(f"Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.3f}") print(f"Masked fraction: {metrics['mismatch/rollout_is_masked_fraction']:.3f}") print(f"KL divergence: {metrics['mismatch/mismatch_kl']:.3f}") -# Verify IS weights are true ratios (not zeroed) -print(f"\n✓ IS weights min: {is_weights[response_mask.bool()].min():.4f}") -print(f"✓ IS weights max: {is_weights[response_mask.bool()].max():.4f}") -print(f"✓ All IS weights > 0: {(is_weights[response_mask.bool()] > 0).all()}") +# Check IS weights for valid tokens (non-padding) +valid_weights = is_weights[response_mask.bool()] +print(f"\n✓ IS weights min (valid tokens): {valid_weights.min():.4f}") +print(f"✓ IS weights max (valid tokens): {valid_weights.max():.4f}") +print(f"✓ All valid IS weights > 0: {(valid_weights > 0).all()}") # Check rejection via response_mask rejected_tokens = (response_mask == 1) & (modified_response_mask == 0) print(f"\n✓ Rejected {rejected_tokens.sum()} tokens via response_mask") -print(f"✓ IS weights for rejected tokens are NON-ZERO (true ratios)") +print(f"✓ In mask mode: IS weights for rejected tokens are NON-ZERO (safety-bounded ratios)") +print("✓ In truncate mode: IS weights are upper-clamped to rollout_is_threshold") +print(f"✓ Both modes: IS weights 
safety-bounded to [exp(-20), exp(20)] ≈ [2e-9, 5e8]") # Check for warning conditions if metrics['mismatch/rollout_is_mean'] < 0.5 or metrics['mismatch/rollout_is_mean'] > 2.0: @@ -376,7 +492,10 @@ for epoch in range(num_epochs): # IMPORTANT: Update batch response_mask with rejection applied batch.response_mask = modified_response_mask - # Use IS weights in training (true ratios, never zeroed) + # Use IS weights in training (processed based on mode) + # Truncate mode: upper clamped to min(weight, upper_threshold) + # Mask mode: safety-bounded ratios preserved (no threshold clamping) + # Both modes: safety bounded to [exp(-20), exp(20)], zeroed at padding is_weights = weights_proto.batch["rollout_is_weights"] # ... apply weights to policy gradient ... ``` diff --git a/tests/trainer/ppo/test_rollout_is.py b/tests/trainer/ppo/test_rollout_is.py index 7acc8cc9836..9ae13f0eab0 100644 --- a/tests/trainer/ppo/test_rollout_is.py +++ b/tests/trainer/ppo/test_rollout_is.py @@ -129,8 +129,8 @@ def test_basic_rollout_is(): weights_veto = weights_veto_proto.batch["rollout_is_weights"] print(f" Veto fraction: {metrics_veto['mismatch/rollout_is_veto_fraction']:.4f}") # KEY FIX: Veto is applied via response_mask, not by zeroing weights - # Check that weights are NON-ZERO (true ratios preserved) - assert weights_veto[0].sum() > 0, "Weights should be non-zero (true ratios preserved)" + # Check that weights are NON-ZERO (safety-bounded ratios preserved, not zeroed) + assert weights_veto[0].sum() > 0, "Weights should be non-zero (not zeroed by veto)" # Check that response_mask has veto applied assert modified_response_mask_veto[0].sum() == 0, "Vetoed sequence should have response_mask zeroed" assert modified_response_mask_veto[1].sum() > 0, "Normal sequence should have response_mask unchanged" @@ -301,8 +301,8 @@ def test_mask_mode(): weights = weights_proto.batch["rollout_is_weights"] - # KEY FIX: Weights should be TRUE ratios (NOT zeroed) - assert torch.all(weights[0, :] > 0), 
"Weights should remain as true ratios (not zeroed)" + # KEY FIX: Weights should be safety-bounded ratios (NOT zeroed) + assert torch.all(weights[0, :] > 0), "Weights should remain as safety-bounded ratios (not zeroed)" assert torch.allclose(weights[0, 0], torch.tensor(0.368, device=device), atol=0.01), ( "First seq ratio should be ≈0.37" ) diff --git a/verl/trainer/ppo/mismatch_helper.py b/verl/trainer/ppo/mismatch_helper.py index b70229c0030..7b14383bd94 100644 --- a/verl/trainer/ppo/mismatch_helper.py +++ b/verl/trainer/ppo/mismatch_helper.py @@ -60,8 +60,15 @@ def compute_rollout_importance_weights( and training policies, and applies rejection sampling for outliers. Key Design: Separation of IS Weights and Rejection Sampling - - IS weights (rollout_is_weights): Always TRUE ratios π_train/π_rollout - Preserved for policy gradient calculations + - IS weights (rollout_is_weights): Ratios π_train/π_rollout with processing applied: + * Safety-bounded to prevent overflow: + - Token level: exp(clamp(log_ratio, -20, 20)) per token + - Sequence level: exp(clamp(sum(log_ratio), -20, 20)) broadcast to all tokens + - Geometric level: exp(clamp(mean(log_ratio), -20, 20)) broadcast to all tokens + * Truncate mode: upper clamped via .clamp(max=upper_threshold) + * Mask mode: safety-bounded ratios preserved (no threshold clamping) + * All modes: zeroed at padding positions + Used for policy gradient calculations - Response mask (modified_response_mask): Has rejection applied (mask mode + veto) Used for loss aggregation to exclude rejected samples from training @@ -92,8 +99,15 @@ def compute_rollout_importance_weights( Returns: Tuple of (weights_proto, modified_response_mask, metrics): - weights_proto: DataProto with TRUE IS weights (never zeroed), key "rollout_is_weights", - shape (batch_size, seq_length). None if rollout_is_threshold is None. + weights_proto: DataProto with processed IS weights, key "rollout_is_weights", + shape (batch_size, seq_length). 
Processing applied: + - Safety-bounded to [exp(-20), exp(20)] ≈ [2e-9, 5e8]: + * Token level: bounds per-token ratios + * Sequence/geometric level: bounds aggregated ratio (broadcast to all tokens) + - Truncate mode: upper clamped via .clamp(max=upper_threshold) + - Mask mode: safety-bounded ratios preserved (no threshold clamping) + - All modes: zeroed at padding positions (response_mask == 0) + None if rollout_is_threshold is None. modified_response_mask: Response mask with rejection applied: - truncate mode: unchanged for outlier ratios, but veto rejection still applied - mask mode: tokens outside [lower, upper] masked to 0 @@ -192,20 +206,24 @@ def compute_rollout_importance_weights( # Step 3: Apply outlier handling and rejection sampling # Key design principle: IS weights and rejection are separate mechanisms - # - rollout_is_weights: Always TRUE ratios π_train/π_rollout (never zeroed) - # Preserved for policy gradient calculations + # - rollout_is_weights: IS weight ratios with mode-specific processing + # * Truncate mode: upper clamped to prevent extreme values + # * Mask mode: safety-bounded ratios preserved (no threshold clamping, rejection via mask) + # Used for policy gradient calculations # - modified_response_mask: Has rejection applied (excludes outliers from training) # Used for loss denominator: ensures rejected samples don't dilute gradients if rollout_is_mode == "truncate": # Truncated IS (TIS): clamp weights to prevent extreme importance ratios - # No rejection for outlier ratios (mask unchanged), but veto can still apply below + # Weights are modified by clamping; no rejection via mask for outlier ratios + # Veto rejection (if enabled) will still be applied to modified_response_mask below rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold) - modified_response_mask = response_mask # Unchanged for outlier ratios + modified_response_mask = response_mask # Unchanged for outlier ratios (veto applied later) elif rollout_is_mode == 
"mask": # Masked IS (MIS): rejection sampling for outlier IS weights - # Reject tokens/sequences with IS ratios outside [lower, upper] + # Reject tokens/sequences with IS ratios outside [lower, upper] via response_mask + # IS weights themselves are NOT threshold-clamped (remain safety-bounded only) mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold) mask = mask.float() @@ -219,20 +237,20 @@ def compute_rollout_importance_weights( seq_has_masked = verl_F.masked_sum(1 - mask, response_mask, axis=-1) > 0 metrics["rollout_is_seq_masked_fraction"] = seq_has_masked.float().mean() - # Apply rejection via response_mask (NOT by zeroing IS weights) + # Apply rejection via response_mask (NOT by clamping IS weights) modified_response_mask = response_mask * mask - # rollout_is_weights kept as true ratios (unchanged) + # rollout_is_weights kept as safety-bounded ratios (no threshold clamping) else: raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'mask'.") # Apply veto: reject entire sequences with catastrophic tokens (ratio < veto_threshold) - # Veto is independent of mask mode - it applies to modified_response_mask after mask rejection + # Veto is independent of mode - it applies to modified_response_mask after mode-specific handling modified_response_mask = modified_response_mask * veto_mask - # rollout_is_weights still unchanged (true ratios preserved) + # Note: rollout_is_weights unaffected by veto (already clamped in truncate mode, or kept as-is in mask mode) - # Apply original response_mask to zero out padding positions in IS weights - # This is different from rejection - padding must be zeroed for correct aggregation + # Zero out padding positions in IS weights for correct aggregation + # This is different from rejection - padding must be zeroed regardless of mode rollout_is_weights = rollout_is_weights * response_mask # Wrap in DataProto for consistency with worker methods diff --git 
a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index ef2abb5e920..9f633730ab2 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -1002,7 +1002,7 @@ def compute_rollout_importance_weights_and_add_to_batch(self, batch: DataProto) apply_weights = self.config.algorithm.get("rollout_is", False) if apply_weights: - # Add TRUE IS weights (never zeroed, always π_train/π_rollout) + # Add IS weights (safety-bounded, mode-processed) to enable weight correction batch = batch.union(rollout_is_weights) return batch, rollout_is_metrics