diff --git a/docs/advance/rollout_is.md b/docs/advance/rollout_is.md
index 8f75c803a5c..dd282b58dbc 100644
--- a/docs/advance/rollout_is.md
+++ b/docs/advance/rollout_is.md
@@ -1,13 +1,22 @@
 # Rollout Importance Sampling
 
-Last updated: 10/11/2025.
+**Author:** [Yingru Li](https://richardli.xyz/)
+
+Last updated: 10/27/2025.
 
 This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation in verl.
 
-## References
+### BibTeX Citation
 
-- [When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda)
-- [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl)
+```bibtex
+@misc{liu-li-2025,
+  title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch},
+  url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda},
+  author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen},
+  year = {2025},
+  month = sep,
+}
+```
 
 ## Overview
 
@@ -17,6 +26,31 @@ Rollout Importance Sampling corrects for distribution mismatch between:
 This mismatch can lead to biased gradient estimates and unstable training. Rollout IS applies importance sampling weights to correct these biases.
 
+### Key Design Principle: Separation of IS Weights and Rejection Sampling
+
+**Important**: As of 10/27/2025, the implementation separates two mechanisms:
+
+1. 
**IS Weights** (`rollout_is_weights`): Ratios π_train/π_rollout with processing: + - **Safety-bounded** to [exp(-20), exp(20)] ≈ [2e-9, 5e8] to prevent overflow: + * Token level: Bounds per-token ratios + * Sequence level: Bounds product of ratios (broadcast to all tokens in sequence) + * Geometric level: Bounds geometric mean of ratios (broadcast to all tokens) + - **Truncate mode**: Upper clamped via .clamp(max=upper_threshold) + - **Mask mode**: Safety-bounded ratios preserved (no threshold clamping) + - **All modes**: Zeroed at padding positions (response_mask == 0) + - Used for policy gradient calculations + +2. **Rejection Sampling** (`modified_response_mask`): Applied via response_mask + - Mask mode: Excludes tokens/sequences with outlier IS ratios + - Veto: Excludes sequences with catastrophic tokens + - Used for loss aggregation (denominator calculation) + +This separation ensures: +- ✅ Correct loss normalization (rejected samples excluded from denominator) +- ✅ Mode-specific weight processing (truncate: upper clamped, mask: safety-bounded only) +- ✅ Padding positions zeroed in weights (necessary for correct aggregation) +- ✅ Safety bounds always applied (prevent overflow in all modes) + ## Configuration ```yaml @@ -29,7 +63,7 @@ algorithm: rollout_is_threshold_lower: null # Auto-reciprocal rollout_is_level: token rollout_is_mode: truncate - rollout_is_veto_threshold: 1e-4 + rollout_is_veto_threshold: null # Disable veto by default # REQUIRED: Enable log prob calculation actor_rollout_ref: @@ -86,7 +120,9 @@ Key features: ### `algorithm.rollout_is` (bool) Whether to apply IS weights to policy loss. Default: `False` - `true` = apply weights to loss (full IS correction) -- `false` = compute metrics only (useful for monitoring before enabling) +- `false` = metrics only mode (no weight correction, but rejection still applies) + +**IMPORTANT**: This flag controls IS weight application, NOT rejection sampling. See "Operation Modes" below. 
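The separation of IS weights from rejection sampling described above can be sketched in a few lines of PyTorch. This is an illustrative, hedged sketch, not the verl implementation: the function name `compute_is_weights` and its signature are hypothetical, and it covers only the stages documented here (safety bound, mode-dependent thresholding, padding, veto via the mask).

```python
# Hypothetical sketch of the documented weight/rejection separation.
# `compute_is_weights` is illustrative only, NOT the verl API.
import torch


def compute_is_weights(old_log_prob, rollout_log_prob, response_mask,
                       level="token", mode="truncate",
                       upper=2.0, lower=0.5, veto_threshold=None):
    # log(pi_train / pi_rollout), zeroed at padding positions
    log_ratio = (old_log_prob - rollout_log_prob) * response_mask

    if level == "token":
        log_w = log_ratio
    elif level == "sequence":
        # product of per-token ratios, broadcast to every token
        log_w = log_ratio.sum(dim=-1, keepdim=True).expand_as(log_ratio)
    else:  # "geometric"
        n = response_mask.sum(dim=-1, keepdim=True).clamp(min=1)
        log_w = (log_ratio.sum(dim=-1, keepdim=True) / n).expand_as(log_ratio)

    # Stage 1: safety bound in log space -> weights in [exp(-20), exp(20)]
    weights = torch.exp(torch.clamp(log_w, -20.0, 20.0))

    new_mask = response_mask.clone()
    if mode == "truncate":
        # Stage 2 (truncate): clamp the upper end only; no rejection
        weights = weights.clamp(max=upper)
    else:
        # mask mode: reject outliers via the response mask;
        # weights themselves stay as safety-bounded ratios
        outlier = (weights < lower) | (weights > upper)
        new_mask = new_mask * (~outlier).float()

    # Veto: checks UNCLAMPED per-token ratios, rejects whole sequences via mask
    if veto_threshold is not None:
        catastrophic = (torch.exp(log_ratio) < veto_threshold) & response_mask.bool()
        vetoed = catastrophic.any(dim=-1, keepdim=True)
        new_mask = new_mask * (~vetoed).float()

    # Stage 3: zero weights at padding positions
    weights = weights * response_mask
    return weights, new_mask
```

With `old_log_prob = [[0, 0, 1]]` and `rollout_log_prob = [[0, 0, 0]]`, the third token's ratio is e ≈ 2.72: truncate mode clamps its weight to the upper threshold 2.0, while mask mode keeps the weight at ≈ 2.72 but zeroes that position in the returned mask.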
**Recommended threshold ranges:** - Token level: 1.5 - 5.0 @@ -98,18 +134,120 @@ Lower threshold for IS weights. If `null`, defaults to 1/upper (reciprocal). ### `algorithm.rollout_is_level` (str) Aggregation level for IS weights: -- `"token"`: Per-token ratios -- `"sequence"`: Product of ratios -- `"geometric"`: Geometric mean (experimental) +- `"token"`: Per-token ratios ρ_t = π_train(t)/π_rollout(t) + - Each token has its own IS weight + - Safety bound: each token's ratio bounded to [exp(-20), exp(20)] + - Biased estimator but low variance +- `"sequence"`: Product of ratios ρ_seq = ∏_t ρ_t for entire sequence + - All tokens in a sequence share the same IS weight (product of per-token ratios) + - Safety bound: product bounded to [exp(-20), exp(20)], then broadcast to all tokens + - Unbiased estimator but high variance +- `"geometric"`: Geometric mean ρ_geo = (∏_t ρ_t)^(1/T) (experimental) + - All tokens in a sequence share the same IS weight (geometric mean) + - Safety bound: geometric mean bounded to [exp(-20), exp(20)], then broadcast to all tokens + - Trade-off between bias and variance ### `algorithm.rollout_is_mode` (str) -Bounding mode: -- `"truncate"`: Cap weights at upper threshold only -- `"mask"`: Zero out weights outside [lower, upper] - -### `algorithm.rollout_is_veto_threshold` (float) -Per-token veto threshold. If any token ratio < this, entire sequence is rejected. 
-Default: `1e-4` (ratio 10,000x off) +Bounding mode for handling outlier IS weights: +- `"truncate"`: Clamp weights at upper threshold only (TIS) + - No lower bound clamping or rejection for outlier ratios + - **IS weights modified**: Upper bound clamped via .clamp(max=upper_threshold) + - Lower bound remains at exp(-20) ≈ 2e-9 from safety bound + - **Note**: Veto-based rejection can still occur via response_mask (see `rollout_is_veto_threshold`) +- `"mask"`: Rejection sampling via response_mask (MIS) + - Rejects tokens/sequences with IS ratios outside [lower, upper] + - **Important**: Rejection applied to `response_mask`, NOT by modifying IS weights + - **IS weights**: Safety-bounded ratios preserved (no threshold clamping, rejection via mask) + - **Note**: Veto-based rejection also applies via response_mask (independent mechanism) + +### `algorithm.rollout_is_veto_threshold` (float or None) +Per-token veto threshold for catastrophic outliers. +- If any token has **unclamped** ratio < this threshold, the entire sequence is rejected via `response_mask` +- Veto checks the **true per-token ratio** π_train(t)/π_rollout(t) before any bounds are applied +- Applied for all levels (token, sequence, geometric) - always checks individual token ratios +- Default: `None` (veto disabled by default) +- Recommended: `1e-4` to `1e-6` when enabled (catches extreme outliers like 10,000x off) +- Set to `None` to disable veto mechanism +- **Important**: Applied **independently** of `rollout_is_mode` (works in both truncate and mask modes) +- Veto applies rejection to `response_mask`, NOT by modifying IS weights +- **IS weights unchanged by veto**: Already processed by mode (truncate: clamped, mask: safety-bounded) + +### Summary: How IS Weights are Processed + +The final IS weights go through multiple stages of processing: + +**Stage 1: Safety Bound (All Modes)** +- Token level: `exp(clamp(log_ratio, -20, 20))` per token → bounds each token to [2e-9, 5e8] +- Sequence level: 
`exp(clamp(sum(log_ratio), -20, 20))` → bounds product to [2e-9, 5e8], broadcast to all tokens +- Geometric level: `exp(clamp(mean(log_ratio), -20, 20))` → bounds geometric mean to [2e-9, 5e8], broadcast to all tokens + +**Stage 2: Threshold Processing (Mode-Dependent)** +- Truncate mode: `.clamp(max=upper_threshold)` → upper clamps weights to threshold +- Mask mode: No modification → weights remain as safety-bounded ratios + +**Stage 3: Padding (All Modes)** +- `weights * response_mask` → zeros out padding positions + +**Rejection Mechanisms (Modify response_mask, NOT weights)** +- Veto: Checks **unclamped per-token ratios** (before safety bound), rejects sequences via mask +- Outlier (mask mode only): Checks safety-bounded weights against [lower, upper], rejects via mask + +## Operation Modes + +The system has **two independent control flags** that combine to create different operation modes: + +1. **`rollout_is_threshold`**: Main on/off switch (None = disabled, float = enabled) +2. **`rollout_is`**: Apply IS weights to loss (True/False) + +### Mode Combinations + +| `rollout_is_threshold` | `rollout_is` | `rollout_is_mode` | Behavior | +|------------------------|--------------|-------------------|----------| +| `None` | any | any | **Disabled**: No computation, no metrics, no rejection | +| `2.0` | `False` | `truncate` | **Metrics only**: Compute weights & metrics, NO weight correction, NO rejection for outliers | +| `2.0` | `False` | `mask` | **Rejection only**: Compute weights & metrics, NO weight correction, YES rejection sampling | +| `2.0` | `True` | `truncate` | **Truncate mode**: Weight correction enabled, weights upper-clamped, NO rejection for outliers | +| `2.0` | `True` | `mask` | **Mask mode (full)**: Weight correction enabled, rejection sampling enabled | + +### Key Insights + +**Rejection sampling is ALWAYS applied when:** +- `rollout_is_threshold` is set (not None) +- AND `rollout_is_mode = "mask"` +- **Regardless of the `rollout_is` flag** + 
+This means: +- ✅ You can use **rejection sampling alone** without IS weight correction (`rollout_is=False, rollout_is_mode="mask"`) +- ✅ You can use **IS weights alone** without outlier rejection (`rollout_is=True, rollout_is_mode="truncate"`) +- ✅ You can use **both together** (`rollout_is=True, rollout_is_mode="mask"`) +- ✅ You can **monitor metrics only** without any correction or outlier rejection (`rollout_is=False, rollout_is_mode="truncate"`) + +**Veto rejection** (if enabled via `rollout_is_veto_threshold`) is applied **independently** in all modes where `rollout_is_threshold` is set. + +### Recommended Workflow + +1. **Start with metrics only** to understand the mismatch: + ```yaml + rollout_is_threshold: 2.0 + rollout_is: false + rollout_is_mode: truncate + ``` + Monitor `mismatch/rollout_is_mean`, `mismatch/mismatch_kl` to assess distribution mismatch. + +2. **Enable rejection sampling** if you see high outlier fractions: + ```yaml + rollout_is_threshold: 2.0 + rollout_is: false + rollout_is_mode: mask # Rejection now applies + ``` + This excludes outliers from training without modifying gradients. + +3. **Enable full IS correction** once comfortable with metrics: + ```yaml + rollout_is_threshold: 2.0 + rollout_is: true + rollout_is_mode: mask # Both rejection and weight correction + ``` ## Usage @@ -143,9 +281,13 @@ All metrics are prefixed with `mismatch/`. 
For example, `rollout_is_mean` appear - **`rollout_is_min`**: Minimum IS weight observed - Shows the most underweighted token/sequence + - For sequence/geometric: computed from unclamped log-space ratios (true minimum) + - For token: computed from safety-bounded weights -- **`rollout_is_max`**: Maximum IS weight observed (before truncation/masking) +- **`rollout_is_max`**: Maximum IS weight observed - Shows the most overweighted token/sequence + - For sequence/geometric: computed from unclamped log-space ratios (true maximum before safety bound) + - For token: computed from safety-bounded weights (before threshold clamping) - Compare with `rollout_is_threshold` to see truncation impact #### **Effective Sample Size** @@ -159,21 +301,31 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear #### **Veto Mechanism Metrics** - **`rollout_is_veto_fraction`**: Fraction of sequences rejected by veto mechanism + - **Important**: Sequences are rejected via `response_mask=0`, NOT by modifying IS weights + - **IS weights unchanged by veto**: Already processed by mode (truncate: clamped, mask: safety-bounded) + - Veto checks **unclamped per-token ratios** π_train(t)/π_rollout(t) (true ratios before safety bound) + - Detects catastrophic tokens (true ratio < veto_threshold, e.g., < 1e-4) - **Ideal value**: < 0.05 (less than 5% vetoed) - **Warning**: > 0.1 suggests policies are too different or numerical issues - **`rollout_is_catastrophic_token_fraction`**: Fraction of tokens below veto threshold - - Identifies problematic tokens before sequence-level veto - - **Warning**: > 0.01 indicates widespread distribution issues + - Identifies problematic tokens before sequence-level veto is applied + - Checks **unclamped per-token ratios** (true ratios, not safety-bounded) + - Each catastrophic token causes its entire sequence to be rejected + - **Warning**: > 0.01 indicates widespread distribution issues or numerical instability #### **Threshold Exceedance 
Metrics** - **`rollout_is_ratio_fraction_high`**: Fraction of weights exceeding upper threshold - Shows how often truncation/masking occurs on high end + - For sequence/geometric: computed from unclamped log-space ratios (true exceedance) + - For token: computed from safety-bounded weights (before threshold clamping) - **Ideal value**: < 0.1 (most weights within bounds) - **`rollout_is_ratio_fraction_low`**: Fraction of weights below lower threshold - Shows how often masking occurs on low end (mask mode only) + - For sequence/geometric: computed from unclamped log-space ratios (true exceedance) + - For token: computed from safety-bounded weights - **Ideal value**: < 0.1 #### **Sequence-Level Metrics** (for sequence/geometric modes) @@ -197,12 +349,16 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear #### **Masking Metrics** (mask mode only) -- **`rollout_is_masked_fraction`**: Fraction of tokens masked (set to zero) - - **Ideal value**: < 0.1 +- **`rollout_is_masked_fraction`**: Fraction of tokens rejected via response_mask (mask mode only) + - **Important**: Tokens are rejected by setting `response_mask=0`, NOT by modifying IS weights + - **IS weights in mask mode**: Safety-bounded ratios preserved (no threshold clamping) + - **Ideal value**: < 0.1 (less than 10% rejected) - **Warning**: > 0.3 means losing too much data -- **`rollout_is_seq_masked_fraction`**: Fraction of sequences with at least one masked token - - Shows sequence-level impact of masking +- **`rollout_is_seq_masked_fraction`**: Fraction of sequences with at least one rejected token + - Shows sequence-level impact of rejection sampling + - For token-level: sequence rejected if ANY token is outside [lower, upper] + - For sequence-level: all tokens have same weight, so entire sequence rejected or accepted #### **Distribution Mismatch Metrics** (Training vs Rollout Policy) @@ -253,22 +409,51 @@ All metrics are prefixed with `mismatch/`. 
For example, `rollout_is_mean` appear # Metrics are returned from compute_rollout_importance_weights from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights -weights_proto, metrics = compute_rollout_importance_weights( +# NEW: Returns 3 values (weights, modified_response_mask, metrics) +weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights( old_log_prob=training_log_probs, # from training policy rollout_log_prob=rollout_log_probs, # from rollout policy response_mask=response_mask, rollout_is_level="token", - rollout_is_mode="truncate", + rollout_is_mode="mask", # Using mask mode for rejection sampling rollout_is_threshold=2.0, - rollout_is_veto_threshold=1e-4, + rollout_is_threshold_lower=0.5, + rollout_is_veto_threshold=1e-4, # Enable veto for catastrophic outliers ) +# Extract IS weights (processed, zeroed at padding) +is_weights = weights_proto.batch["rollout_is_weights"] + +# IS weights processing (mask mode with token level): +# 1. Safety-bounded: exp(clamp(log_ratio, -20, 20)) per token +# 2. Mask mode: no threshold clamping (safety-bounded ratios preserved) +# 3. Zeroed at padding positions + +# modified_response_mask has rejection applied: +# 1. Outlier rejection: tokens outside [0.5, 2.0] masked to 0 (mask mode) +# 2. 
Veto rejection: sequences with catastrophic tokens (ratio < 1e-4) masked to 0
+# Note: Veto checks unclamped per-token ratios, not the safety-bounded weights
+
 # All metrics have 'mismatch/' prefix
 print(f"Mean IS weight: {metrics['mismatch/rollout_is_mean']:.3f}")
 print(f"Effective sample size: {metrics['mismatch/rollout_is_eff_sample_size']:.3f}")
 print(f"Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.3f}")
+print(f"Masked fraction: {metrics['mismatch/rollout_is_masked_fraction']:.3f}")
 print(f"KL divergence: {metrics['mismatch/mismatch_kl']:.3f}")
+
+# Check IS weights for valid tokens (non-padding)
+valid_weights = is_weights[response_mask.bool()]
+print(f"\n✓ IS weights min (valid tokens): {valid_weights.min():.4f}")
+print(f"✓ IS weights max (valid tokens): {valid_weights.max():.4f}")
+print(f"✓ All valid IS weights > 0: {(valid_weights > 0).all()}")
+
+# Check rejection via response_mask
+rejected_tokens = (response_mask == 1) & (modified_response_mask == 0)
+print(f"\n✓ Rejected {rejected_tokens.sum()} tokens via response_mask")
+print("✓ In mask mode: IS weights for rejected tokens are NON-ZERO (safety-bounded ratios)")
+print("✓ In truncate mode: IS weights upper clamped to the upper threshold (2.0 in this example)")
+print("✓ Both modes: IS weights safety-bounded to [exp(-20), exp(20)] ≈ [2e-9, 5e8]")
+
 # Check for warning conditions
 if metrics['mismatch/rollout_is_mean'] < 0.5 or metrics['mismatch/rollout_is_mean'] > 2.0:
     print("⚠️ Warning: Mean IS weight far from 1.0, significant policy mismatch detected")
@@ -288,14 +473,15 @@ for epoch in range(num_epochs):
     for batch_idx, batch in enumerate(dataloader):
         # ... rollout phase ... 
- # Compute IS weights and get metrics - weights_proto, metrics = compute_rollout_importance_weights( + # Compute IS weights and get metrics (NEW: 3 return values) + weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights( old_log_prob=batch.old_log_prob, rollout_log_prob=batch.rollout_log_prob, response_mask=batch.response_mask, rollout_is_level=config.rollout_is_level, rollout_is_mode=config.rollout_is_mode, rollout_is_threshold=config.rollout_is_threshold, + rollout_is_threshold_lower=config.rollout_is_threshold_lower, rollout_is_veto_threshold=config.rollout_is_veto_threshold, ) @@ -303,7 +489,13 @@ for epoch in range(num_epochs): for metric_name, metric_value in metrics.items(): logger.log_scalar(metric_name, metric_value, step=global_step) - # Use IS weights in training + # IMPORTANT: Update batch response_mask with rejection applied + batch.response_mask = modified_response_mask + + # Use IS weights in training (processed based on mode) + # Truncate mode: upper clamped to min(weight, upper_threshold) + # Mask mode: safety-bounded ratios preserved (no threshold clamping) + # Both modes: safety bounded to [exp(-20), exp(20)], zeroed at padding is_weights = weights_proto.batch["rollout_is_weights"] # ... apply weights to policy gradient ... ``` @@ -349,8 +541,8 @@ def check_rollout_is_health(metrics, config): print("✅ Rollout IS metrics look healthy") return True -# Use in training -_, metrics = compute_rollout_importance_weights(...) +# Use in training (NEW: 3 return values) +_, _, metrics = compute_rollout_importance_weights(...) is_healthy = check_rollout_is_health(metrics, config) if not is_healthy: @@ -508,8 +700,8 @@ metrics_history = { # In training loop for step in range(num_steps): - # ... compute IS weights ... - _, metrics = compute_rollout_importance_weights(...) + # ... compute IS weights ... (NEW: 3 return values) + _, _, metrics = compute_rollout_importance_weights(...) 
# Store metrics for key in metrics_history.keys(): @@ -556,3 +748,8 @@ Rollout Importance Sampling provides: - ✅ Comprehensive metrics for monitoring - ✅ Flexibility for different scenarios - ✅ Memory-efficient computation + +## References + +- [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda) +- [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl) \ No newline at end of file diff --git a/docs/examples/config.rst b/docs/examples/config.rst index 387f197f494..45cc8b14c60 100644 --- a/docs/examples/config.rst +++ b/docs/examples/config.rst @@ -133,7 +133,7 @@ Actor/Rollout/Reference Policy rollout_is_threshold_lower: null # Lower threshold (null = auto 1/upper) rollout_is_level: token # Aggregation: token/sequence/geometric rollout_is_mode: truncate # Bounding: truncate/mask - rollout_is_veto_threshold: 1e-4 # Catastrophic outlier threshold + rollout_is_veto_threshold: null # Catastrophic outlier threshold (null to disable) use_torch_compile: True # False to disable torch compile kl_loss_coef: 0.001 # for grpo kl_loss_type: low_var_kl # for grpo @@ -519,7 +519,7 @@ Algorithm rollout_is_threshold_lower: null rollout_is_level: token rollout_is_mode: truncate - rollout_is_veto_threshold: 1e-4 + rollout_is_veto_threshold: null # Disabled by default - ``gamma``: discount factor - ``lam``: Trade-off between bias and variance in the GAE estimator @@ -537,7 +537,7 @@ Algorithm - ``rollout_is_threshold_lower``: Lower threshold for IS weights. If ``null``, defaults to reciprocal of upper (1/upper). - ``rollout_is_level``: Aggregation level: ``token`` (biased), ``sequence`` (unbiased), or ``geometric`` (experimental). - ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``mask`` (zero outside bounds). 
-- ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is 1e-4. +- ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is null (disabled). Note: Rollout IS requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``. Trainer diff --git a/examples/rollout_importance_sampling/README.md b/examples/rollout_importance_sampling/README.md index 1d0ce66394e..7baf55ebf2e 100644 --- a/examples/rollout_importance_sampling/README.md +++ b/examples/rollout_importance_sampling/README.md @@ -61,7 +61,7 @@ bash examples/rollout_importance_sampling/run_with_rollout_is.sh - `rollout_is_threshold`: Upper threshold for IS weights (null = disabled, float = enabled). **Main on/off switch.** - `rollout_is`: Whether to apply weights to loss (true) or just compute metrics (false). Default: false. - `rollout_is_threshold_lower`: Lower threshold (null = auto 1/upper) -- `rollout_is_veto_threshold`: Catastrophic outlier threshold (default: 1e-4) +- `rollout_is_veto_threshold`: Catastrophic outlier threshold (default: null, disabled) ## Configuration Examples @@ -73,7 +73,7 @@ algorithm: rollout_is: true # Apply to loss rollout_is_level: token rollout_is_mode: truncate - rollout_is_veto_threshold: 1e-4 + rollout_is_veto_threshold: null # Disabled by default ``` ### Example 2: Metrics Only (No Weight Application) @@ -95,7 +95,7 @@ algorithm: rollout_is_threshold_lower: 0.9998 rollout_is_level: geometric rollout_is_mode: mask - rollout_is_veto_threshold: 1e-4 + rollout_is_veto_threshold: 1e-4 # Enable veto for this example ``` ### Example 4: Sequence-level with Truncate @@ -107,7 +107,7 @@ algorithm: rollout_is_threshold_lower: null # Auto-reciprocal: 0.2 rollout_is_level: sequence rollout_is_mode: truncate - rollout_is_veto_threshold: 1e-4 + rollout_is_veto_threshold: 1e-4 # Enable veto for this example ``` ### Example 5: Asymmetric Thresholds @@ -226,8 +226,8 @@ The veto mechanism zeros out entire 
sequences containing catastrophic outliers: - If any token has ratio < `rollout_is_veto_threshold`, the entire sequence is rejected - This prevents extreme outliers from dominating training -- Default threshold: 1e-4 (ratio 10,000x off) -- Set to `null` to disable: `rollout_is_veto_threshold: null` +- Default: `null` (disabled by default) +- Set to `1e-4` to enable (catches ratios 10,000x off) ## Examples diff --git a/examples/rollout_importance_sampling/run_with_rollout_is.sh b/examples/rollout_importance_sampling/run_with_rollout_is.sh index 93ebbe4fdc2..42c4a2a5981 100755 --- a/examples/rollout_importance_sampling/run_with_rollout_is.sh +++ b/examples/rollout_importance_sampling/run_with_rollout_is.sh @@ -24,8 +24,8 @@ rollout_is_level=token # Bounding mode: truncate (cap upper) | mask (zero outside bounds) rollout_is_mode=truncate -# Catastrophic outlier veto threshold -rollout_is_veto_threshold=1e-4 +# Catastrophic outlier veto threshold (set to null to disable, or e.g., 1e-4 to enable) +rollout_is_veto_threshold=null # ============================================================================== # Model and Data Configuration diff --git a/tests/trainer/ppo/test_rollout_is.py b/tests/trainer/ppo/test_rollout_is.py index 584dec5a8dd..9ae13f0eab0 100644 --- a/tests/trainer/ppo/test_rollout_is.py +++ b/tests/trainer/ppo/test_rollout_is.py @@ -49,7 +49,7 @@ def test_basic_rollout_is(): # Test token-level truncate mode print("\n1. Testing token-level truncate mode...") - weights_proto, metrics = compute_rollout_importance_weights( + weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=eos_mask, @@ -71,7 +71,7 @@ def test_basic_rollout_is(): # Test sequence-level mode print("\n2. 
Testing sequence-level mode...") - weights_seq_proto, metrics_seq = compute_rollout_importance_weights( + weights_seq_proto, _, metrics_seq = compute_rollout_importance_weights( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=eos_mask, @@ -92,7 +92,7 @@ def test_basic_rollout_is(): # Test geometric mean mode print("\n3. Testing geometric mean mode...") - weights_geo_proto, metrics_geo = compute_rollout_importance_weights( + weights_geo_proto, _, metrics_geo = compute_rollout_importance_weights( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=eos_mask, @@ -116,7 +116,7 @@ def test_basic_rollout_is(): rollout_log_prob_veto[0, 2] = old_log_prob_veto[0, 2] + 15.0 # ratio ~= 3e-7 eos_mask_veto = torch.ones(2, 5, device=device) - weights_veto_proto, metrics_veto = compute_rollout_importance_weights( + weights_veto_proto, modified_response_mask_veto, metrics_veto = compute_rollout_importance_weights( old_log_prob=old_log_prob_veto, rollout_log_prob=rollout_log_prob_veto, response_mask=eos_mask_veto, @@ -128,14 +128,17 @@ def test_basic_rollout_is(): weights_veto = weights_veto_proto.batch["rollout_is_weights"] print(f" Veto fraction: {metrics_veto['mismatch/rollout_is_veto_fraction']:.4f}") - # Check that the sequence with catastrophic token has all weights zeroed - assert weights_veto[0].sum() == 0, "Sequence with catastrophic token should be vetoed" - assert weights_veto[1].sum() > 0, "Normal sequence should not be vetoed" + # KEY FIX: Veto is applied via response_mask, not by zeroing weights + # Check that weights are NON-ZERO (safety-bounded ratios preserved, not zeroed) + assert weights_veto[0].sum() > 0, "Weights should be non-zero (not zeroed by veto)" + # Check that response_mask has veto applied + assert modified_response_mask_veto[0].sum() == 0, "Vetoed sequence should have response_mask zeroed" + assert modified_response_mask_veto[1].sum() > 0, "Normal sequence should have response_mask unchanged" print(" 
✓ Veto mechanism passed") # Test disabled IS (threshold=None) print("\n5. Testing disabled IS...") - weights_disabled, metrics_disabled = compute_rollout_importance_weights( + weights_disabled, modified_response_mask_disabled, metrics_disabled = compute_rollout_importance_weights( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=eos_mask, @@ -143,6 +146,7 @@ def test_basic_rollout_is(): ) assert weights_disabled is None, "Should return None when threshold is None" + assert torch.equal(modified_response_mask_disabled, eos_mask), "Should return original mask unchanged" assert len(metrics_disabled) == 0, "Should return empty metrics when disabled" print(" ✓ Disabled IS passed") @@ -160,7 +164,7 @@ def test_metrics_completeness(): rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.2 eos_mask = torch.ones(batch_size, seq_length, device=device) - _, metrics = compute_rollout_importance_weights( + _, _, metrics = compute_rollout_importance_weights( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=eos_mask, @@ -264,6 +268,64 @@ def test_mismatch_metrics(): print(" ✓ Mismatch metrics work without rollout log probs") +def test_mask_mode(): + """Test mask mode applies rejection via response_mask, keeps true IS weights.""" + print("\nTesting mask mode behavior...") + + batch_size = 2 + seq_length = 5 + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Sequence 0: ratio ≈ 0.37 (below 0.5, should be rejected) + # Sequence 1: ratio ≈ 1.65 (in [0.5, 2.0], should be accepted) + old_log_prob = torch.tensor([[-2.0] * seq_length, [-2.0] * seq_length], device=device) + rollout_log_prob = torch.tensor( + [ + [-1.0] * seq_length, # exp(-2.0 - (-1.0)) = exp(-1.0) ≈ 0.37 + [-2.5] * seq_length, # exp(-2.0 - (-2.5)) = exp(0.5) ≈ 1.65 + ], + device=device, + ) + response_mask = torch.ones(batch_size, seq_length, device=device) + + weights_proto, modified_response_mask, metrics = 
compute_rollout_importance_weights( + old_log_prob=old_log_prob, + rollout_log_prob=rollout_log_prob, + response_mask=response_mask, + rollout_is_level="token", + rollout_is_mode="mask", + rollout_is_threshold=2.0, + rollout_is_threshold_lower=0.5, + rollout_is_veto_threshold=None, + ) + + weights = weights_proto.batch["rollout_is_weights"] + + # KEY FIX: Weights should be safety-bounded ratios (NOT zeroed) + assert torch.all(weights[0, :] > 0), "Weights should remain as safety-bounded ratios (not zeroed)" + assert torch.allclose(weights[0, 0], torch.tensor(0.368, device=device), atol=0.01), ( + "First seq ratio should be ≈0.37" + ) + assert torch.allclose(weights[1, 0], torch.tensor(1.649, device=device), atol=0.01), ( + "Second seq ratio should be ≈1.65" + ) + + # Rejection should be applied via response_mask + assert torch.all(modified_response_mask[0, :] == 0), "First sequence should be rejected via mask" + assert torch.all(modified_response_mask[1, :] == 1), "Second sequence should be accepted" + + # Verify mask metrics exist + assert "mismatch/rollout_is_masked_fraction" in metrics + assert abs(metrics["mismatch/rollout_is_masked_fraction"] - 0.5) < 0.01, "Should reject 50% of tokens" + + print(f" First seq IS weight: {weights[0, 0]:.4f} (expected ≈0.37)") + print(f" Second seq IS weight: {weights[1, 0]:.4f} (expected ≈1.65)") + print(f" First seq mask: {modified_response_mask[0, 0]:.0f} (expected 0 - rejected)") + print(f" Second seq mask: {modified_response_mask[1, 0]:.0f} (expected 1 - accepted)") + print(f" Masked fraction: {metrics['mismatch/rollout_is_masked_fraction']:.2f}") + print(" ✓ Mask mode correctly separates IS weights from rejection") + + if __name__ == "__main__": print("=" * 60) print("Rollout Importance Sampling Test Suite") @@ -273,6 +335,7 @@ def test_mismatch_metrics(): test_basic_rollout_is() test_metrics_completeness() test_mismatch_metrics() + test_mask_mode() print("\n" + "=" * 60) print("ALL TESTS PASSED ✓") print("=" * 60) diff 
--git a/tests/trainer/ppo/test_rollout_is_integration.py b/tests/trainer/ppo/test_rollout_is_integration.py index abcbcb70502..b96fb77523f 100644 --- a/tests/trainer/ppo/test_rollout_is_integration.py +++ b/tests/trainer/ppo/test_rollout_is_integration.py @@ -61,7 +61,7 @@ def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is): This test simulates that workflow. """ # First compute IS weights (as trainer would do centrally) - rollout_is_weights_proto, _ = compute_rollout_importance_weights( + rollout_is_weights_proto, _, _ = compute_rollout_importance_weights( old_log_prob=sample_data["old_log_prob"], rollout_log_prob=sample_data["rollout_log_prob"], response_mask=sample_data["response_mask"], @@ -92,7 +92,7 @@ def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is): def test_rollout_is_weights_computation(self, sample_data): """Test rollout IS weights and metrics computation.""" - weights_proto, metrics = compute_rollout_importance_weights( + weights_proto, _, metrics = compute_rollout_importance_weights( old_log_prob=sample_data["old_log_prob"], rollout_log_prob=sample_data["rollout_log_prob"], response_mask=sample_data["response_mask"], @@ -120,7 +120,7 @@ def test_all_aggregation_levels(self, sample_data): levels = ["token", "sequence", "geometric"] for level in levels: - _, metrics = compute_rollout_importance_weights( + _, _, metrics = compute_rollout_importance_weights( old_log_prob=sample_data["old_log_prob"], rollout_log_prob=sample_data["rollout_log_prob"], response_mask=sample_data["response_mask"], @@ -136,7 +136,7 @@ def test_both_bounding_modes(self, sample_data): modes = ["truncate", "mask"] for mode in modes: - _, metrics = compute_rollout_importance_weights( + _, _, metrics = compute_rollout_importance_weights( old_log_prob=sample_data["old_log_prob"], rollout_log_prob=sample_data["rollout_log_prob"], response_mask=sample_data["response_mask"], @@ -175,7 +175,7 @@ def test_veto_mechanism(self): 
response_mask = torch.ones(batch_size, seq_length, device=device) - _, metrics = compute_rollout_importance_weights( + _, _, metrics = compute_rollout_importance_weights( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask, @@ -196,7 +196,7 @@ def test_metrics_only_mode(self, sample_data, config_with_rollout_is): but rollout_is=False (disables weight application to policy loss). """ # Compute IS weights (as trainer would do) - rollout_is_weights_proto, is_metrics = compute_rollout_importance_weights( + rollout_is_weights_proto, _, is_metrics = compute_rollout_importance_weights( old_log_prob=sample_data["old_log_prob"], rollout_log_prob=sample_data["rollout_log_prob"], response_mask=sample_data["response_mask"], diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 5b8dc20c4b0..656cc0b731b 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -490,7 +490,7 @@ algorithm: rollout_is_threshold_lower: null rollout_is_level: token rollout_is_mode: truncate - rollout_is_veto_threshold: 0.0001 + rollout_is_veto_threshold: null rollout_is: false trainer: balance_batch: true diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index c9bccaf8fe0..3ec8e229fb0 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -476,7 +476,7 @@ algorithm: rollout_is_threshold_lower: null rollout_is_level: token rollout_is_mode: truncate - rollout_is_veto_threshold: 0.0001 + rollout_is_veto_threshold: null rollout_is: false trainer: balance_batch: true diff --git a/verl/trainer/config/algorithm.py b/verl/trainer/config/algorithm.py index 859ce90ff32..bd6763226f7 100644 --- a/verl/trainer/config/algorithm.py +++ b/verl/trainer/config/algorithm.py @@ -78,7 +78,7 @@ class 
AlgoConfig(BaseConfig): rollout_is_threshold_lower (Optional[float]): Lower threshold for IS weights. If None, defaults to 1/upper. rollout_is_level (str): Aggregation level: "token", "sequence", or "geometric". rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "mask" (zero outside bounds). - rollout_is_veto_threshold (float): Per-token veto threshold for catastrophic outliers. + rollout_is_veto_threshold (float or None): Per-token veto threshold for catastrophic outliers. None to disable. rollout_is (bool): Whether to apply IS weights to policy loss. True = apply weights, False = compute metrics only (useful for monitoring before enabling correction). Default: False. """ @@ -99,7 +99,7 @@ class AlgoConfig(BaseConfig): rollout_is_threshold_lower: Optional[float] = None rollout_is_level: str = "token" rollout_is_mode: str = "truncate" - rollout_is_veto_threshold: Optional[float] = 1e-4 + rollout_is_veto_threshold: Optional[float] = None # Controls whether to apply IS weights to policy loss (only if rollout_is_threshold is set) # True = apply weights to loss, False = compute metrics only (no weight application) rollout_is: bool = False diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 8a23a3b72d4..fb5eef5d6fc 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -87,8 +87,8 @@ algorithm: # Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds) rollout_is_mode: truncate - # Per-token veto threshold for catastrophic outliers - rollout_is_veto_threshold: 1e-4 + # Per-token veto threshold for catastrophic outliers (null to disable) + rollout_is_veto_threshold: null # Whether to apply IS weights to policy loss # true = apply weights to loss, false = compute metrics only (no weight application) diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 72c97e4a7cd..d9d250c9a48 
100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -127,8 +127,8 @@ algorithm: # Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds) rollout_is_mode: truncate - # Per-token veto threshold for catastrophic outliers - rollout_is_veto_threshold: 1e-4 + # Per-token veto threshold for catastrophic outliers (null to disable) + rollout_is_veto_threshold: null # Whether to apply IS weights to policy loss # true = apply weights to loss, false = compute metrics only (no weight application) diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index 9ec460aba2e..7c2bfd53673 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -788,10 +788,13 @@ def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str loss = verl_F.masked_mean(loss_mat, loss_mask) elif loss_agg_mode == "seq-mean-token-sum": seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum - loss = torch.mean(seq_losses) # seq-mean + seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float() # exclude fully masked sequences + loss = verl_F.masked_mean(seq_losses, seq_mask) # seq-mean elif loss_agg_mode == "seq-mean-token-mean": - seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / torch.sum(loss_mask, dim=-1) # token-mean - loss = torch.mean(seq_losses) # seq-mean + token_counts = torch.sum(loss_mask, dim=-1) # per-sequence valid token count + seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (token_counts + 1e-8) # token-mean + seq_mask = (token_counts > 0).float() # exclude fully masked sequences + loss = verl_F.masked_mean(seq_losses, seq_mask) # seq-mean elif loss_agg_mode == "seq-mean-token-sum-norm": seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) loss = torch.sum(seq_losses) / loss_mask.shape[-1] # The divisor diff --git a/verl/trainer/ppo/mismatch_helper.py b/verl/trainer/ppo/mismatch_helper.py index 78bb5d6accb..7b14383bd94 100644 --- a/verl/trainer/ppo/mismatch_helper.py
+++ b/verl/trainer/ppo/mismatch_helper.py @@ -52,46 +52,71 @@ def compute_rollout_importance_weights( rollout_is_mode: str = "truncate", rollout_is_threshold: Optional[float] = None, rollout_is_threshold_lower: Optional[float] = None, - rollout_is_veto_threshold: Optional[float] = 1e-4, -) -> tuple[Optional[DataProto], dict[str, Any]]: - """Compute importance sampling weights and metrics for rollout-training mismatch correction. - - This function handles the computation of importance sampling (IS) weights to correct - for the distribution mismatch between rollout policy and training policy. + rollout_is_veto_threshold: Optional[float] = None, +) -> tuple[Optional[DataProto], torch.Tensor, dict[str, Any]]: + """Compute importance sampling weights and rejection mask for rollout-training mismatch. + + This function computes IS weights to correct for distribution mismatch between rollout + and training policies, and applies rejection sampling for outliers. + + Key Design: Separation of IS Weights and Rejection Sampling + - IS weights (rollout_is_weights): Ratios π_train/π_rollout with processing applied: + * Safety-bounded to prevent overflow: + - Token level: exp(clamp(log_ratio, -20, 20)) per token + - Sequence level: exp(clamp(sum(log_ratio), -20, 20)) broadcast to all tokens + - Geometric level: exp(clamp(mean(log_ratio), -20, 20)) broadcast to all tokens + * Truncate mode: upper clamped via .clamp(max=upper_threshold) + * Mask mode: safety-bounded ratios preserved (no threshold clamping) + * All modes: zeroed at padding positions + Used for policy gradient calculations + - Response mask (modified_response_mask): Has rejection applied (mask mode + veto) + Used for loss aggregation to exclude rejected samples from training Reference: When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda - Memory-efficient implementation that prevents CUDA OOM by: - - Using log-space computation where possible - - Applying 
safety bounds to prevent numerical overflow - - Computing metrics without creating huge intermediate tensors + Memory-efficient implementation: + - Log-space computation to prevent overflow + - Safety bounds (exp(±20)) on all exponentiations + - Metrics computed without large intermediate tensors Args: - old_log_prob: Log probabilities from training policy (e.g., FSDP), shape (batch_size, seq_length) - rollout_log_prob: Log probabilities from rollout policy (e.g., vLLM), shape (batch_size, seq_length) - response_mask: Mask for valid tokens, shape (batch_size, seq_length) - rollout_is_level: Level of IS aggregation: - - "token": Per-token ratios (biased) - - "sequence": Product of ratios (unbiased) - - "geometric": Geometric mean of ratios (experimental) - rollout_is_mode: How to handle weights exceeding threshold: - - "truncate": Cap weights at upper_threshold only - - "mask": Zero out weights outside [lower_threshold, upper_threshold] - rollout_is_threshold: Upper threshold for IS weights - rollout_is_threshold_lower: Lower threshold for IS weights (mask mode only; if None, defaults to 1/upper) - rollout_is_veto_threshold: Per-token veto threshold. If any token ratio < this, zero entire sequence. - If None, veto mechanism is disabled. + old_log_prob: Log probs from training policy (FSDP FP32), shape (batch_size, seq_length) + rollout_log_prob: Log probs from rollout policy (vLLM BF16), shape (batch_size, seq_length) + response_mask: Valid token mask (1=valid, 0=padding), shape (batch_size, seq_length) + rollout_is_level: IS weight aggregation level + - "token": Per-token ratios ρ_t = π_train(t)/π_rollout(t) (biased but low variance) + - "sequence": Sequence product ρ_seq = ∏ρ_t (unbiased but high variance) + - "geometric": Geometric mean ρ_geo = (∏ρ_t)^(1/T) (experimental trade-off) + rollout_is_mode: Treatment of outlier IS weights + - "truncate": Clamp weights at upper threshold only. 
No rejection for outlier ratios, + but veto can still apply (TIS) + - "mask": Reject tokens/sequences outside [lower, upper] via response_mask (MIS/rejection sampling) + rollout_is_threshold: Upper threshold for IS weights (required, e.g., 2.0) + rollout_is_threshold_lower: Lower threshold for mask mode (if None, defaults to 1/upper) + rollout_is_veto_threshold: Catastrophic token threshold. If any token has ratio < this, + reject entire sequence. Applied independently of rollout_is_mode. If None, veto disabled. Default None. Returns: - Tuple of (weights_proto, metrics) where: - weights_proto: DataProto containing IS weights with key "rollout_is_weights", - shape (batch_size, seq_length). Returns None if rollout_is_threshold is None. - metrics: Dictionary of IS statistics and mismatch metrics (KL, PPL, etc.), - all converted to scalars and prefixed with "mismatch/" + Tuple of (weights_proto, modified_response_mask, metrics): + weights_proto: DataProto with processed IS weights, key "rollout_is_weights", + shape (batch_size, seq_length). Processing applied: + - Safety-bounded to [exp(-20), exp(20)] ≈ [2e-9, 5e8]: + * Token level: bounds per-token ratios + * Sequence/geometric level: bounds aggregated ratio (broadcast to all tokens) + - Truncate mode: upper clamped via .clamp(max=upper_threshold) + - Mask mode: safety-bounded ratios preserved (no threshold clamping) + - All modes: zeroed at padding positions (response_mask == 0) + None if rollout_is_threshold is None. + modified_response_mask: Response mask with rejection applied: + - truncate mode: unchanged for outlier ratios, but veto rejection still applied + - mask mode: tokens outside [lower, upper] masked to 0 + - veto: sequences with catastrophic tokens masked to 0 (applied in both modes) + Shape (batch_size, seq_length). 
+ metrics: Dict of IS and mismatch metrics, all scalars with "mismatch/" prefix """ if rollout_is_threshold is None: - return None, {} + return None, response_mask, {} # Parse thresholds: if lower not specified, use 1/upper (reciprocal) upper_threshold = rollout_is_threshold @@ -179,38 +204,53 @@ def compute_rollout_importance_weights( SAFETY_BOUND=SAFETY_BOUND, ) - # Step 3: Apply truncation or masking based on mode + # Step 3: Apply outlier handling and rejection sampling + # Key design principle: IS weights and rejection are separate mechanisms + # - rollout_is_weights: IS weight ratios with mode-specific processing + # * Truncate mode: upper clamped to prevent extreme values + # * Mask mode: safety-bounded ratios preserved (no threshold clamping, rejection via mask) + # Used for policy gradient calculations + # - modified_response_mask: Has rejection applied (excludes outliers from training) + # Used for loss denominator: ensures rejected samples don't dilute gradients + if rollout_is_mode == "truncate": - # Truncate mode: only cap upper bound to prevent overweighting + # Truncated IS (TIS): clamp weights to prevent extreme importance ratios + # Weights are modified by clamping; no rejection via mask for outlier ratios + # Veto rejection (if enabled) will still be applied to modified_response_mask below rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold) + modified_response_mask = response_mask # Unchanged for outlier ratios (veto applied later) elif rollout_is_mode == "mask": - # Mask mode: zero out weights outside [lower_threshold, upper_threshold] + # Masked IS (MIS): rejection sampling for outlier IS weights + # Reject tokens/sequences with IS ratios outside [lower, upper] via response_mask + # IS weights themselves are NOT threshold-clamped (remain safety-bounded only) mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold) mask = mask.float() - # Track MIS-specific metrics + # Compute rejection rate metrics 
metrics["rollout_is_masked_fraction"] = verl_F.masked_mean(1 - mask, response_mask) - - # Sequence-level masking fraction if rollout_is_level in ["sequence", "geometric"]: - # All tokens in a sequence have the same weight, so reuse mask + # Sequence-level: all tokens have same weight, check first token metrics["rollout_is_seq_masked_fraction"] = (1 - mask[:, 0]).mean() else: - # Check if any token in each sequence is masked + # Token-level: sequence rejected if ANY token is rejected seq_has_masked = verl_F.masked_sum(1 - mask, response_mask, axis=-1) > 0 metrics["rollout_is_seq_masked_fraction"] = seq_has_masked.float().mean() - rollout_is_weights = rollout_is_weights * mask + # Apply rejection via response_mask (NOT by clamping IS weights) + modified_response_mask = response_mask * mask + # rollout_is_weights kept as safety-bounded ratios (no threshold clamping) else: raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'mask'.") - # Apply veto mask AFTER all thresholding - # This zeros out entire sequences that have any catastrophic token - rollout_is_weights = rollout_is_weights * veto_mask + # Apply veto: reject entire sequences with catastrophic tokens (ratio < veto_threshold) + # Veto is independent of mode - it applies to modified_response_mask after mode-specific handling + modified_response_mask = modified_response_mask * veto_mask + # Note: rollout_is_weights unaffected by veto (already clamped in truncate mode, or kept as-is in mask mode) - # Apply response_mask to ensure weights are 0 where mask is 0 + # Zero out padding positions in IS weights for correct aggregation + # This is different from rejection - padding must be zeroed regardless of mode rollout_is_weights = rollout_is_weights * response_mask # Wrap in DataProto for consistency with worker methods @@ -231,7 +271,7 @@ def compute_rollout_importance_weights( else: metrics_scalar[f"mismatch/{key}"] = value - return rollout_is_weights_proto, metrics_scalar + return 
rollout_is_weights_proto, modified_response_mask, metrics_scalar def compute_is_metrics( diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 7e9e6981f3c..9f633730ab2 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -952,41 +952,57 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqle metrics.update(global_balance_stats) def compute_rollout_importance_weights_and_add_to_batch(self, batch: DataProto) -> tuple[DataProto, dict]: - """Compute rollout importance sampling weights and mismatch metrics, conditionally add weights to batch. + """Compute IS weights and apply rejection sampling for rollout-training mismatch. - This method computes IS weights to correct for distribution mismatch between - rollout policy and training policy. It always computes metrics when enabled, but - only adds weights to batch if algorithm.rollout_is is True. + Computes importance sampling weights to correct for distribution mismatch between + rollout and training policies. Applies rejection sampling (mask mode/veto) by + modifying response_mask. Always updates response_mask; conditionally adds IS weights. 
+ + Key behavior: + - response_mask: ALWAYS updated with rejection (tokens rejected by mask mode or veto are excluded from training) + - rollout_is_weights: Added to batch ONLY if config.algorithm.rollout_is=True + + This separation ensures: + - Rejection works even when IS weights are disabled (rollout_is=False) + - Metrics can be monitored before enabling IS weight application Args: - batch: DataProto containing old_log_probs, rollout_log_probs, response_mask + batch: DataProto with old_log_probs, rollout_log_probs, response_mask Returns: - Tuple of (updated_batch, metrics) where: - - updated_batch: Batch with rollout_is_weights added (if rollout_is=True) - - metrics: Dictionary of IS and mismatch metrics (all with mismatch/ prefix) + Tuple of (updated_batch, metrics): + updated_batch: Batch with modified response_mask (always) and rollout_is_weights (if rollout_is=True) + metrics: Dict of IS and mismatch metrics, all with "mismatch/" prefix """ # Compute rollout IS weights if enabled and data is available - # rollout_is_threshold is the main on/off switch - if self.config.algorithm.rollout_is_threshold is not None and "rollout_log_probs" in batch.batch: - rollout_is_weights, rollout_is_metrics = compute_rollout_importance_weights( + # rollout_is_threshold is the main on/off switch (None = disabled, float = enabled) + rollout_is_threshold = self.config.algorithm.get("rollout_is_threshold", None) + if rollout_is_threshold is not None and rollout_is_threshold > 0 and "rollout_log_probs" in batch.batch: + # Compute IS weights and get modified response_mask + rollout_is_weights, modified_response_mask, rollout_is_metrics = compute_rollout_importance_weights( old_log_prob=batch.batch["old_log_probs"], rollout_log_prob=batch.batch["rollout_log_probs"], response_mask=batch.batch["response_mask"], rollout_is_level=self.config.algorithm.rollout_is_level, rollout_is_mode=self.config.algorithm.rollout_is_mode, rollout_is_threshold=self.config.algorithm.rollout_is_threshold, -
rollout_is_threshold_lower=self.config.algorithm.rollout_is_threshold_lower, - rollout_is_veto_threshold=self.config.algorithm.rollout_is_veto_threshold, + rollout_is_threshold_lower=self.config.algorithm.get("rollout_is_threshold_lower", None), + rollout_is_veto_threshold=self.config.algorithm.get("rollout_is_veto_threshold", None), ) - # Control: Should we apply weights to policy loss? - # True = add weights to batch (actor will apply them) - # False = don't add weights (metrics only, no loss modification) + # ALWAYS update response_mask with rejection (even if rollout_is=False) + # - Mask mode: tokens with outlier IS ratios excluded + # - Veto: sequences with catastrophic tokens excluded + # This ensures correct loss normalization (rejected samples not in denominator) + batch.batch["response_mask"] = modified_response_mask + + # Conditionally add IS weights based on rollout_is config flag + # - rollout_is=True: Enable IS weight correction in policy loss + # - rollout_is=False: Metrics-only mode (rejection still applied via mask) apply_weights = self.config.algorithm.get("rollout_is", False) if apply_weights: - # Add IS weights to batch for distribution to workers + # Add IS weights (safety-bounded, mode-processed) to enable weight correction batch = batch.union(rollout_is_weights) return batch, rollout_is_metrics
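For reference, the separation of IS weights and rejection sampling introduced in `mismatch_helper.py` can be sketched in plain PyTorch. This is a simplified, token-level-only illustration under stated assumptions — `rollout_is_sketch` is a hypothetical helper, not the verl API, and it omits sequence/geometric aggregation, the veto mechanism, and metrics:

```python
import torch

def rollout_is_sketch(old_log_prob, rollout_log_prob, response_mask,
                      upper=2.0, lower=0.5, mode="mask"):
    # Hypothetical token-level sketch of the design in this diff:
    # IS weights stay safety-bounded ratios; rejection goes through the mask.
    log_ratio = old_log_prob - rollout_log_prob
    weights = torch.exp(log_ratio.clamp(-20.0, 20.0))   # safety bound ~[2e-9, 5e8]
    if mode == "truncate":
        weights = weights.clamp(max=upper)              # clamp weights, no rejection
        modified_mask = response_mask
    else:  # "mask": reject outliers via the mask; weights are NOT threshold-clamped
        keep = ((weights >= lower) & (weights <= upper)).float()
        modified_mask = response_mask * keep
    weights = weights * response_mask                   # zero padding positions
    return weights, modified_mask
```

With log-ratios of -1 and +0.5, this reproduces the values asserted in the mask-mode unit test above: weights ≈ 0.368 and ≈ 1.649, with the first sequence rejected via the mask rather than by zeroing its weights.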
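The `agg_loss` change matters because rejection sampling can leave some sequences fully masked, and averaging their zero losses in would dilute the gradient. A minimal standalone sketch of the intent behind the "seq-mean-token-mean" fix (using an explicit masked average in place of `verl_F.masked_mean`; assumptions, not the verl implementation):

```python
import torch

def seq_mean_token_mean(loss_mat, loss_mask):
    # Sketch of the fixed "seq-mean-token-mean" aggregation: sequences whose
    # tokens are all masked (e.g. rejected by rollout IS) are excluded from
    # the sequence mean instead of contributing spurious zeros.
    token_counts = loss_mask.sum(dim=-1)                        # per-sequence token count
    seq_losses = (loss_mat * loss_mask).sum(dim=-1) / (token_counts + 1e-8)
    seq_keep = (token_counts > 0).float()                       # drop fully masked sequences
    return (seq_losses * seq_keep).sum() / seq_keep.sum().clamp(min=1.0)
```

Here a batch of two sequences where the second is fully rejected yields the token mean of the first sequence alone, which is exactly the "rejected samples not in denominator" guarantee the diff describes.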