diff --git a/docs/advance/fully_async.md b/docs/advance/fully_async.md index 314c9f324fb..a3400485b96 100644 --- a/docs/advance/fully_async.md +++ b/docs/advance/fully_async.md @@ -166,10 +166,10 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a During the training process, we observed that metrics and response lengths may become unstable in the later stages of training. To mitigate this issue, we can use - the [Rollout Importance Sampling](https://verl.readthedocs.io/en/latest/advance/rollout_is.html) - technique for importance sampling. To utilize Rollout Importance Sampling, we need to compute log_prob using + the [Rollout Correction](https://verl.readthedocs.io/en/latest/advance/rollout_corr.html) + technique for importance sampling and rejection sampling. To utilize Rollout Correction, we need to compute log_prob using the training engine, which requires enabling this switch. - Additionally, when compute_prox_log_prob and Rollout Importance Sampling are enabled under mode d + Additionally, when compute_prox_log_prob and Rollout Correction are enabled under mode d (async stream pipeline with partial rollout), our implementation approximates `Areal's Decoupled PPO`. ### Supported Modes diff --git a/docs/advance/rollout_corr.md b/docs/advance/rollout_corr.md new file mode 100644 index 00000000000..db945abbc8e --- /dev/null +++ b/docs/advance/rollout_corr.md @@ -0,0 +1,1065 @@ +# Rollout Correction + +**Author:** [Yingru Li](https://richardli.xyz/) + +Last updated: 10/30/2025. + +This document provides a comprehensive overview of the Rollout Correction implementation in verl. + +**Note on Naming**: This feature is called "Rollout Correction" to reflect the complete functionality: importance sampling (IS) weights, rejection sampling (RS), and veto mechanism. The internal variable `rollout_is_weights` retains its name as it specifically refers to the IS weights component. 
+ +### BibTeX Citation + +```bibtex +@misc{liu-li-2025, + title = {When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch}, + url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda}, + author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, + year = {2025}, + month = sep, +} +``` + +## Overview + +Rollout Correction provides a unified framework to handle **general off-policy problems** in RL training. Any scenario where the data collection distribution differs from the training distribution can benefit from these methods. + +**Common off-policy scenarios:** + +1. **Policy Mismatch** (Implementation Differences) + - Different precision: FP8 vs FP16 vs BF16 vs FP32 + - Different backends: vLLM vs SGLang vs FSDP vs Megatron + - Different implementations even with identical weights + +2. **Temporal Lag** (Model Staleness) + - Rollout uses older checkpoint while training has progressed + - Asynchronous rollout workers with stale parameters + - Common in distributed/async RL systems + +3. **Replay Buffers** + - Training on historical trajectories from earlier iterations + - Experience replay from different policy versions + - Data augmentation or resampling strategies + +4. **Off-Policy Algorithms** + - Behavioral cloning from expert demonstrations + - DAPO (data from auxiliary policies) + - Any algorithm using trajectories from a different policy + +5. **Data Quality Filtering** + - Reweighting or filtering collected data + - Preference learning with modified distributions + - Curriculum learning with distribution shifts + +These off-policy gaps can cause training instability and policy collapse. Rollout Correction uses importance sampling (IS) weights and rejection sampling (RS) to correct for any distribution shift between data collection and training. 
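The core quantity behind both mechanisms is the per-token ratio between training and rollout probabilities, recovered from their log-probs. As a minimal pure-Python sketch (the `token_is_ratios` helper is illustrative only, not the verl API):

```python
import math

def token_is_ratios(train_log_probs, rollout_log_probs, response_mask, bound=20.0):
    """Per-token IS ratios rho_t = pi_train(t) / pi_rollout(t), computed in log
    space, safety-bounded to [exp(-bound), exp(bound)], and zeroed at padding."""
    ratios = []
    for lp_train, lp_rollout, m in zip(train_log_probs, rollout_log_probs, response_mask):
        log_ratio = max(-bound, min(bound, lp_train - lp_rollout))
        ratios.append(math.exp(log_ratio) * m)
    return ratios

# Identical policies: every valid token gets ratio 1.0; padding stays 0.0.
log_p = [math.log(0.5), math.log(0.25), math.log(0.125)]
ratios = token_is_ratios(log_p, log_p, [1.0, 1.0, 0.0])  # -> [1.0, 1.0, 0.0]
```

An off-policy gap shows up as ratios drifting away from 1.0; the rest of this document describes how those ratios are turned into IS weights and rejection masks.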
+ +**Important Note on Common Implementation Mistakes:** + +Many LLM-RL implementations incorrectly apply PPO by **ignoring the actual rollout policy** π_rollout and assuming the training reference policy π_old is the behavior policy. This is mathematically incorrect when π_rollout ≠ π_old (which is typical in LLM-RL due to precision/backend differences between rollout and training). + +**This is not PPO's fault** - PPO itself is mathematically correct. The issue is the incorrect assumption that π_old = π_rollout in naive implementations. + +This critical implementation mistake that leads to RL training collapse was identified in the blog post ["When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch"](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda) and motivated the development of this rollout correction framework. + +**Mathematically correct approaches:** +- **Decoupled mode**: Three policies (π_rollout, π_old, π_θ) with IS correction from π_rollout to π_old +- **Bypass mode**: Two policies (π_rollout = π_old, π_θ) using actual rollout policy as PPO anchor +- **Pure IS mode**: Two policies (π_rollout, π_θ) with IS correction and no PPO clipping + +See [Mathematical Formulations](rollout_corr_math.md#321-incorrect-llm-rl-implementation-ppo-without-rollout-correction) for detailed explanation. + +### Key Design Principle: Separation of IS Weights and Rejection Sampling + +The implementation separates two mechanisms: + +1. 
**IS Weights** (`rollout_is_weights`): Policy ratios with processing (π_old/π_rollout in decoupled mode, π_θ/π_rollout in bypass/pure IS mode): + - **Safety-bounded** to [exp(-20), exp(20)] ≈ [2e-9, 5e8] to prevent overflow: + * Token level: Bounds per-token ratios + * Sequence level: Bounds product of ratios (broadcast to all tokens in sequence) + * Geometric level: Bounds geometric mean of ratios (broadcast to all tokens) + - **Truncate mode**: Upper clamped via .clamp(max=upper_threshold) + - **Mask mode**: Safety-bounded ratios preserved (no threshold clamping) + - **All modes**: Zeroed at padding positions (response_mask == 0) + - Used for policy gradient calculations + +2. **Rejection Sampling** (`modified_response_mask`): Applied via response_mask + - Mask mode: Excludes tokens/sequences with outlier IS ratios + - Veto: Excludes sequences with catastrophic tokens + - Used for loss aggregation (denominator calculation) + +This separation ensures: +- ✅ Correct loss normalization (rejected samples excluded from denominator) +- ✅ Mode-specific weight processing (truncate: upper clamped, mask: safety-bounded only) +- ✅ Padding positions zeroed in weights (for correct aggregation) +- ✅ Safety bounds always applied (prevent overflow in all modes) + +## Quick Start: Using Verified Presets + +**NEW**: We now provide typed configuration with verified presets for common scenarios. These presets have been validated with tens of thousands of GPU hours across various models and training scenarios. 
+ +### Python API + +```python +from verl.trainer.config.algorithm import RolloutCorrectionConfig + +# Token-level IS +config = RolloutCorrectionConfig.token_is() + +# Sequence-level IS +config = RolloutCorrectionConfig.seq_is() + +# Sequence IS + rejection sampling - alias: seq_mis() +config = RolloutCorrectionConfig.seq_is_rs() + +# Geometric IS + RS + Veto (maximum outlier sensitivity) +config = RolloutCorrectionConfig.geo_rs() + +# Performance mode: PPO with bypass +config = RolloutCorrectionConfig.ppo_is_bypass() + +# Advanced: Pure policy gradient with IS +config = RolloutCorrectionConfig.pure_is() + +# Metrics only (no correction) +config = RolloutCorrectionConfig.disabled() +``` + +### YAML Configuration (Advanced) + +For advanced customization or YAML-based configs: + +```yaml +algorithm: + rollout_correction: + rollout_is: token # IS weights: "token", "sequence", or null + rollout_is_threshold: 2.0 # Upper threshold for IS weights + rollout_rs: null # Rejection sampling: "token", "sequence", "geometric", or null + rollout_rs_threshold: null # RS upper threshold (required if rollout_rs is enabled) + rollout_rs_threshold_lower: null # RS lower threshold (auto-reciprocal if null) + rollout_token_veto_threshold: null # Per-token veto threshold (null = disabled) + bypass_old_logprob_for_rollout: false # Skip old_log_prob computation + use_pure_rollout_correction: false # Pure policy gradient with IS + +# REQUIRED: Enable log prob calculation +actor_rollout_ref: + rollout: + calculate_log_probs: true +``` + +## Files + +### **Core Implementation** + +- `verl/trainer/ppo/rollout_corr_helper.py` - Contains `compute_rollout_correction_and_rejection_mask()` and `compute_offpolicy_metrics()` +- `verl/trainer/ppo/core_algos.py` - Rollout Correction integration with PPO and pure IS mode (`compute_policy_loss_with_rollout_correction()`) +- `verl/trainer/ppo/ray_trainer.py` - Bypass mode implementation (skips `old_log_prob` computation) +- 
`verl/workers/actor/dp_actor.py` - Mode selection logic and metrics collection + +### **Configuration Files** + +- `verl/trainer/config/algorithm.py` - Rollout Correction parameters in `AlgoConfig` +- `verl/workers/config/actor.py` - Rollout Correction parameters in `ActorConfig` +- `verl/trainer/config/actor/actor.yaml` - Rollout Correction configuration section +- `verl/trainer/config/ppo_trainer.yaml` - Algorithm config with Rollout Correction + +### **Documentation** + +- `docs/examples/config.rst` - Configuration parameter descriptions + +### **Example Scripts** + +- `recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh` - DAPO example with Rollout Correction +- `examples/rollout_correction/README.md` - Comprehensive usage guide +- `examples/rollout_correction/run_with_rollout_corr.sh` - Basic example + +### **Tests** + +- `tests/trainer/ppo/test_rollout_corr.py` - Unit tests for IS/RS mechanisms +- `tests/trainer/ppo/test_rollout_corr_integration.py` - Integration tests + +## Configuration Parameters + +All parameters are under `algorithm.rollout_correction`: + +### `rollout_is` (str or null) +Importance sampling weights aggregation level: +- `null` = No IS weights computed (metrics-only mode) +- `"token"`: Per-token IS weights + - **Decoupled mode**: ρ_t = π_old(t)/π_rollout(t) + - **Bypass/Pure IS mode**: ρ_t = π_θ(t)/π_rollout(t) + - Independent truncation per token + - Typical threshold: 1.5 - 5.0 +- `"sequence"`: Per-sequence weight ρ_seq = ∏_t ρ_t + - Multiplicative aggregation across sequence + - Typical threshold: 2.0 - 10.0 + +All IS weights are safety-bounded to [exp(-20), exp(20)] ≈ [2e-9, 5e8] + +### `rollout_is_threshold` (float) +Upper threshold for IS weights. 
Default: `2.0` +- Used to clamp IS weights (not for rejection) +- Rejection is controlled by `rollout_rs` parameters + +### `rollout_rs` (str or null) +Rejection sampling aggregation level: +- `null` = No rejection sampling +- `"token"`: Reject individual tokens with outlier ratios +- `"sequence"`: Reject entire sequences with outlier ratios +- `"geometric"`: Geometric mean aggregation for rejection + - Typical threshold: 1.0002 - 1.001 + +### `rollout_rs_threshold` (float or null) +Upper threshold for rejection sampling. Default: `null` +- If `null`, uses `rollout_is_threshold` +- Tokens/sequences with ratios > threshold are masked out + +### `rollout_rs_threshold_lower` (float or null) +Lower threshold for rejection sampling. Default: `null` +- If `null`, uses reciprocal of upper threshold (1/upper) +- Tokens/sequences with ratios < threshold are masked out + +### `rollout_token_veto_threshold` (float or null) +Per-token veto for catastrophic outliers. Default: `null` +- Checks **unclamped per-token ratios** before safety bounds +- If ANY token has ratio < threshold, entire sequence is rejected +- Independent of `rollout_is` and `rollout_rs` settings +- Typical values: `1e-4` to `1e-6` when enabled +- Example: `1e-4` catches tokens 10,000x less likely + +## Preset Configuration Guide + +This section provides detailed guidance on choosing and using the verified presets for different scenarios. + +### 1. Token-level Importance Sampling + +**Theory:** Decoupled PPO with per-token truncated importance sampling. 
+ +**Configuration:** +```python +config = RolloutCorrectionConfig.token_is(threshold=2.0) +``` + +**Equivalent YAML:** +```yaml +algorithm: + rollout_correction: + rollout_is: token + rollout_is_threshold: 2.0 + rollout_rs: null +``` + +**Properties:** +- **Algorithm**: Decoupled PPO +- **Policies**: Three (π_rollout, π_old, π_θ) in decoupled mode +- **Double correction**: IS weights correct Drift 1 (rollout→old), PPO clips correct Drift 2 (old→current) + +### 2. Sequence-level Importance Sampling + +**Theory:** Decoupled PPO with sequence-level importance sampling. + +**Configuration:** +```python +config = RolloutCorrectionConfig.seq_is(threshold=2.0) +``` + +**Equivalent YAML:** +```yaml +algorithm: + rollout_correction: + rollout_is: sequence + rollout_is_threshold: 2.0 + rollout_rs: null +``` + +**Properties:** +- **Algorithm**: Decoupled PPO +- **Policies**: Three (π_rollout, π_old, π_θ) in decoupled mode +- **Sequence-level IS**: Uses product of all token ratios (broadcast to all tokens) + +**Note:** Sequence-level IS uses multiplicative aggregation. Typical thresholds: 5.0-10.0 (compared to token-level: 1.5-5.0). + +### 3. Sequence-level IS + Rejection Sampling + +**Theory:** Decoupled PPO combining sequence-level IS weighting with rejection sampling. + +**Alias:** `seq_mis(threshold)` + +**Configuration:** +```python +config = RolloutCorrectionConfig.seq_is_rs(is_threshold=2.0, rs_threshold=2.0) +# OR use alias with single threshold (sets rs_threshold_lower=0) +config = RolloutCorrectionConfig.seq_mis(threshold=2.0) +``` + +**Equivalent YAML:** +```yaml +algorithm: + rollout_correction: + rollout_is: sequence + rollout_is_threshold: 2.0 + rollout_rs: sequence + rollout_rs_threshold: 2.0 + rollout_rs_threshold_lower: 0.5 # Reciprocal of threshold +``` + +**Properties:** +- **Algorithm**: Decoupled PPO + rejection sampling +- **Policies**: Three (π_rollout, π_old, π_θ) in decoupled mode +- **Double mechanism**: IS reweighting + rejection filtering + +### 4. 
Geometric IS + RS + Veto (Maximum Sensitivity) + +**Theory:** Pure rejection sampling based on geometric mean of IS ratios. + +**Configuration:** +```python +config = RolloutCorrectionConfig.geo_rs(rs_threshold=1.001, veto_threshold=1e-4) +``` + +**Equivalent YAML:** +```yaml +algorithm: + rollout_correction: + rollout_is: null + rollout_rs: geometric + rollout_rs_threshold: 1.001 + rollout_rs_threshold_lower: 0.999 + rollout_token_veto_threshold: 1e-4 +``` + +**Properties:** +- **Algorithm**: Decoupled PPO + geometric rejection sampling +- **Policies**: Three (π_rollout, π_old, π_θ) in decoupled mode +- **No IS weights**: Pure rejection (no reweighting) +- **Extremely selective**: Requires near-perfect policy match + +**Note:** Geometric thresholds are typically very close to 1.0 (typical: 1.0001-1.001, ±0.01%-0.1%). Geometric mean is very sensitive - a threshold of 1.001 rejects sequences with average per-token deviation > 0.1%. + +### 5. PPO with IS Bypass + +**Theory:** PPO applied to off-policy data by using π_rollout as the PPO anchor (bypass mode). + +**Configuration:** +```python +config = RolloutCorrectionConfig.ppo_is_bypass(threshold=2.0) +``` + +**Equivalent YAML:** +```yaml +algorithm: + rollout_correction: + rollout_is: token + rollout_is_threshold: 2.0 + rollout_rs: null + bypass_old_logprob_for_rollout: true + use_pure_rollout_correction: false +``` + +**Properties:** +- **Algorithm**: PPO in bypass mode +- **Policies**: Two (π_rollout = π_old, π_θ) +- **Faster**: Skips `actor.compute_log_prob()` forward pass +- **PPO clipping**: Clips against π_rollout +- **Mathematically correct**: Uses actual behavior policy π_rollout as proximal policy (avoids common mistake of ignoring π_rollout) + +**Configuration requirement:** +- Set `actor_rollout_ref.rollout.calculate_log_probs: true` + +### 6. 
Pure IS (Off-Policy REINFORCE) + +**Configuration:** +```python +config = RolloutCorrectionConfig.pure_is(threshold=2.0) +``` + +**Theory:** Off-policy REINFORCE with sequence-level truncated importance sampling. + +**Properties:** +- **Algorithm**: Off-policy REINFORCE + IS +- **Policies**: Two (π_rollout, π_θ) +- **No PPO clipping**: Pure policy gradient +- **Always uses bypass mode**: No π_old computation +- **Fast**: Single forward pass for IS weights + +### Summary: How IS Weights are Processed + +The final IS weights go through multiple stages of processing: + +**Stage 1: Safety Bound (All Modes)** +- Token level: `exp(clamp(log_ratio, -20, 20))` per token → bounds each token to [2e-9, 5e8] +- Sequence level: `exp(clamp(sum(log_ratio), -20, 20))` → bounds product to [2e-9, 5e8], broadcast to all tokens +- Geometric level: `exp(clamp(mean(log_ratio), -20, 20))` → bounds geometric mean to [2e-9, 5e8], broadcast to all tokens + +**Stage 2: Threshold Processing (Mode-Dependent)** +- Truncate mode: `.clamp(max=upper_threshold)` → upper clamps weights to threshold +- Mask mode: No modification → weights remain as safety-bounded ratios + +**Stage 3: Padding (All Modes)** +- `weights * response_mask` → zeros out padding positions + +**Rejection Mechanisms (Modify response_mask, NOT weights)** +- Veto: Checks **unclamped per-token ratios** (before safety bound), rejects sequences via mask +- Outlier (mask mode only): Checks safety-bounded weights against [lower, upper], rejects via mask + +## Operation Modes + +The system has **two operating modes** for computing π_old, plus an additional algorithmic option: + +### Operating Modes and Configuration + +| Configuration | `bypass_old_logprob_for_rollout` | `use_pure_rollout_correction` | Operating Mode | Loss Function | Description | +|---------------|----------------------------------|------------------------------|----------------|---------------|-------------| +| **Decoupled** | `false` | `false` | Decoupled | PPO | 
Computes `old_log_prob` separately via `actor.compute_log_prob()` | +| **Bypass** | `true` | `false` | Bypass | PPO | Sets `old_log_prob = rollout_log_prob`, PPO clips against rollout policy | +| **Pure IS** | `true` | `true` | Bypass | Pure Policy Gradient | Bypass mode with pure IS loss (no PPO clipping) | + +**Operating Mode Descriptions:** + +**Decoupled Mode** (three policies: π_rollout, π_old, π_θ): +- Computes π_old separately at start of training epoch +- Requires extra forward pass via `actor.compute_log_prob()` +- Achieves batch size invariance +- Separately corrects Drift 1 (rollout→old) and Drift 2 (old→current) + +**Bypass Mode** (two policies: π_rollout = π_old, π_θ): +- Sets π_old = π_rollout (skips separate computation) +- Faster: No extra forward pass needed +- Uses π_rollout as both behavior policy and proximal policy +- Does not achieve batch size invariance +- Can be used with PPO clipping or pure policy gradient (Pure IS) + +### IS Weights and Rejection Sampling + +Within each training mode, you can independently control **two correction mechanisms**: + +1. **Importance Sampling (IS) weights**: Controlled by `rollout_is` parameter +2. 
**Rejection Sampling (RS)**: Controlled by `rollout_rs` parameter + +### Mode Combinations + +| `rollout_is` | `rollout_rs` | Behavior | +|--------------|--------------|----------| +| `null` | `null` | **Disabled**: No computation, no metrics, no rejection | +| `null` | `"token"`, `"sequence"`, or `"geometric"` | **Rejection only**: Compute metrics, NO weight correction, YES rejection sampling | +| `"token"` or `"sequence"` | `null` | **IS weights only**: Weight correction enabled, NO rejection sampling | +| `"token"` or `"sequence"` | `"token"`, `"sequence"`, or `"geometric"` | **Full correction**: Both weight correction and rejection sampling enabled | + +### Key Insights + +- ✅ You can use **rejection sampling alone** without IS weight correction (`rollout_is=null, rollout_rs="token"`) +- ✅ You can use **IS weights alone** without outlier rejection (`rollout_is="token", rollout_rs=null`) +- ✅ You can use **both together** (`rollout_is="token", rollout_rs="token"`) +- ✅ You can **monitor metrics only** without any correction by setting both to `null` but still providing rollout_log_probs + +**Veto rejection** (if enabled via `rollout_token_veto_threshold`) is applied **independently** of IS and RS settings. + +### Example Workflow + +1. **Start with metrics only** to understand the off-policy gap: + ```yaml + algorithm: + rollout_correction: + rollout_is: null + rollout_rs: null + ``` + Monitor `rollout_corr/rollout_is_mean`, `rollout_corr/kl` to assess off-policy gap. + +2. **Enable rejection sampling** if you see high outlier fractions: + ```yaml + algorithm: + rollout_correction: + rollout_is: null + rollout_rs: token + rollout_rs_threshold: 2.0 + ``` + This excludes outliers from training without modifying gradients. + +3. **Enable full IS correction** once comfortable with metrics: + ```yaml + algorithm: + rollout_correction: + rollout_is: token + rollout_is_threshold: 2.0 + rollout_rs: token + rollout_rs_threshold: 2.0 + ``` + +4. 
**Optional: Enable bypass mode** to save compute: + ```yaml + algorithm: + rollout_correction: + rollout_is: token + rollout_is_threshold: 2.0 + bypass_old_logprob_for_rollout: true # Skip old_log_prob computation + use_pure_rollout_correction: false # Use Bypass mode + ``` + **Benefits**: Skips expensive forward pass for `old_log_prob` computation + + **Trade-off**: PPO clips against rollout policy instead of true old policy + + **Alternative**: Set `use_pure_rollout_correction: true` for pure policy gradient with IS (no clipping) + +## Usage + +### Basic Setup + +```yaml +algorithm: + rollout_correction: + rollout_is: token # Enable IS weights at token level + rollout_is_threshold: 2.0 # Threshold for IS weights + rollout_rs: null # No rejection sampling + rollout_token_veto_threshold: null # No veto + +actor_rollout_ref: + rollout: + calculate_log_probs: true # Required! +``` + +### Metrics + +All metrics are prefixed with `rollout_corr/` in logs. For example, `rollout_is_mean` appears as `rollout_corr/rollout_is_mean`. 
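Several of the aggregate statistics in this section are simple functions of the IS weights themselves. For instance, the effective sample size is `1 / mean(w²)` over mean-normalized weights; a minimal sketch (the `ess_fraction` helper name is an assumption, not the verl implementation):

```python
def ess_fraction(is_ratios):
    """Effective sample size as a fraction of the batch: normalize the IS
    ratios to mean 1, then return 1 / mean(w^2). Equal ratios give 1.0;
    weight concentrated on a few samples drives the value toward 0."""
    mean = sum(is_ratios) / len(is_ratios)
    w = [r / mean for r in is_ratios]
    return 1.0 / (sum(x * x for x in w) / len(w))

print(ess_fraction([1.0, 1.0, 1.0, 1.0]))  # uniform weights -> 1.0
print(ess_fraction([4.0, 0.0, 0.0, 0.0]))  # all weight on 1 of 4 samples -> 0.25
```

A value well below 1.0 means the IS weights concentrate on a small subset of the batch, which is what `rollout_is_eff_sample_size` flags.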
+ +These metrics cover both: +- **Diagnostic metrics**: KL divergence, perplexity differences (measuring off-policy gap) +- **Correction statistics**: IS weights, rejection rates, veto stats (measuring correction applied) + +#### **Core IS Weight Metrics** + +- **`rollout_is_mean`**: Mean importance sampling weight across all valid tokens + - Value close to 1.0 indicates minimal off-policy gap + +- **`rollout_is_std`**: Standard deviation of IS weights + - Higher values indicate greater variance in IS weights + +- **`rollout_is_min`**: Minimum IS weight observed + - Shows the most underweighted token/sequence + - For sequence/geometric: computed from unclamped log-space ratios (true minimum) + - For token: computed from safety-bounded weights + +- **`rollout_is_max`**: Maximum IS weight observed + - Shows the most overweighted token/sequence + - For sequence/geometric: computed from unclamped log-space ratios (true maximum before safety bound) + - For token: computed from safety-bounded weights (before threshold clamping) + - Compare with `rollout_is_threshold` to see truncation impact + +#### **Effective Sample Size** + +- **`rollout_is_eff_sample_size`**: Effective sample size after IS weighting + - **Formula**: `1 / mean(weights²)` where weights are normalized + - **Range**: 0.0 to 1.0 (as fraction of original batch) + - Lower values indicate weight concentration on fewer samples + +#### **Veto Mechanism Metrics** + +- **`rollout_is_veto_fraction`**: Fraction of sequences rejected by veto mechanism + - **Important**: Sequences are rejected via `response_mask=0`, NOT by modifying IS weights + - **IS weights unchanged by veto**: Already processed by mode (truncate: clamped, mask: safety-bounded) + - Veto checks **unclamped per-token ratios** (true ratios before safety bound) + - Decoupled mode: π_old(t)/π_rollout(t) + - Bypass/Pure IS mode: π_θ(t)/π_rollout(t) + - Detects catastrophic tokens (true ratio < veto_threshold, e.g., < 1e-4) + +- 
**`rollout_is_catastrophic_token_fraction`**: Fraction of tokens below veto threshold + - Identifies problematic tokens before sequence-level veto is applied + - Checks **unclamped per-token ratios** (true ratios, not safety-bounded) + - Each catastrophic token causes its entire sequence to be rejected + +#### **Threshold Exceedance Metrics** + +- **`rollout_is_ratio_fraction_high`**: Fraction of weights exceeding upper threshold + - Shows how often truncation/masking occurs on high end + - For sequence/geometric: computed from unclamped log-space ratios (true exceedance) + - For token: computed from safety-bounded weights (before threshold clamping) + +- **`rollout_is_ratio_fraction_low`**: Fraction of weights below lower threshold + - Shows how often masking occurs on low end (mask mode only) + - For sequence/geometric: computed from unclamped log-space ratios (true exceedance) + - For token: computed from safety-bounded weights + +#### **Sequence-Level Metrics** (for sequence/geometric modes) + +- **`rollout_is_seq_mean`**: Mean IS weight at sequence level + - Should match `rollout_is_mean` for sequence-level aggregation + +- **`rollout_is_seq_std`**: Standard deviation of sequence-level IS weights + +- **`rollout_is_seq_min`**: Minimum sequence-level IS weight + +- **`rollout_is_seq_max`**: Maximum sequence-level IS weight + +- **`rollout_is_seq_max_deviation`**: Maximum absolute deviation from 1.0 at sequence level + - Shows worst-case sequence off-policy gap + +- **`rollout_is_seq_fraction_high`**: Fraction of sequences exceeding upper threshold + +- **`rollout_is_seq_fraction_low`**: Fraction of sequences below lower threshold + +#### **Masking Metrics** (mask mode only) + +- **`rollout_is_masked_fraction`**: Fraction of tokens rejected via response_mask (mask mode only) + - **Important**: Tokens are rejected by setting `response_mask=0`, NOT by modifying IS weights + - **IS weights in mask mode**: Safety-bounded ratios preserved (no threshold clamping) + +- 
**`rollout_is_seq_masked_fraction`**: Fraction of sequences with at least one rejected token + - Shows sequence-level impact of rejection sampling + - For token-level: sequence rejected if ANY token is outside [lower, upper] + - For sequence-level: all tokens have same weight, so entire sequence rejected or accepted + +#### **Off-Policy Diagnostic Metrics** (Training vs Rollout Policy) + +**Note on terminology:** These metrics use "training" to refer to the training reference policy and "rollout" to refer to π_rollout (the behavior policy used for data collection). +- **Decoupled mode**: "training" = π_old (computed at start of training epoch) +- **Bypass/Pure IS mode**: "training" = π_θ (current policy being trained) + +In bypass/pure IS mode, metrics measure the drift between π_θ and π_rollout directly. + +- **`training_ppl`**: Perplexity of training reference policy (π_old in decoupled mode, π_θ in bypass/pure IS mode) + - **Formula**: `exp(-mean(log_probs))` + - Lower values indicate higher model confidence + +- **`rollout_ppl`**: Perplexity of rollout policy π_rollout (e.g., vLLM BF16) + +- **`ppl_ratio`**: Ratio of training PPL to rollout PPL + - **Formula**: `exp(mean(log(training_ppl / rollout_ppl)))` + - **Meaning**: > 1.0 means training is less confident than rollout + +- **`training_log_ppl`**: Log perplexity of training policy + - Useful for identifying trends (linear scale) + +- **`rollout_log_ppl`**: Log perplexity of rollout policy + +- **`log_ppl_diff`**: Mean difference in log perplexities + - **Formula**: `mean(log_ppl_rollout - log_ppl_training)` + - Sign indicates which policy is more confident + +- **`log_ppl_abs_diff`**: Mean absolute log perplexity difference + - Magnitude of off-policy gap regardless of direction + +- **`log_ppl_diff_max`**: Maximum log perplexity difference across sequences + - Identifies worst-case sequence + +- **`log_ppl_diff_min`**: Minimum log perplexity difference across sequences + +- **`kl`**: KL divergence 
KL(π_rollout || π_training) + - **Formula**: `mean(log_prob_rollout - log_prob_training)` + - **Note**: Can be negative (rollout is less confident) + +- **`k3_kl`**: K3 KL estimator + - **Formula**: `mean(exp(log_ratio) - log_ratio - 1)` + - More stable for small KL values + - Always non-negative + +- **`chi2_token`**: Chi-squared divergence at token level + - **Formula**: `mean(ratio²) - 1` where ratio = π_training/π_rollout + - Measures second moment of IS weight distribution + - Always non-negative + +- **`chi2_seq`**: Chi-squared divergence at sequence level + - **Formula**: `mean((∏_t ratio_t)²) - 1` + - Sequence-level second moment of IS weights + - More sensitive than token-level chi-squared + +#### **Example: Accessing Metrics in Code** + +```python +# Metrics are returned from compute_rollout_correction_and_rejection_mask +from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_rejection_mask + +# Returns 3 values (weights, modified_response_mask, metrics) +weights_proto, modified_response_mask, metrics = compute_rollout_correction_and_rejection_mask( + old_log_prob=training_log_probs, # from training policy + rollout_log_prob=rollout_log_probs, # from rollout policy + response_mask=response_mask, + rollout_is="token", # Enable IS weights at token level + rollout_is_threshold=2.0, + rollout_rs="token", # Enable rejection sampling at token level + rollout_rs_threshold=2.0, + rollout_rs_threshold_lower=0.5, + rollout_token_veto_threshold=1e-4, # Enable veto for catastrophic outliers +) + +# Extract IS weights (processed, zeroed at padding) +is_weights = weights_proto.batch["rollout_is_weights"] + +# IS weights processing (with IS enabled at token level): +# 1. Safety-bounded: exp(clamp(log_ratio, -20, 20)) per token +# 2. Zeroed at padding positions +# Note: Not threshold-clamped since we're using rejection sampling (rollout_rs) + +# modified_response_mask has rejection applied (since rollout_rs="token"): +# 1. 
Outlier rejection: tokens outside [0.5, 2.0] masked to 0 +# 2. Veto rejection: sequences with catastrophic tokens (ratio < 1e-4) masked to 0 +# Note: Veto checks unclamped per-token ratios, not the safety-bounded weights + +# All metrics have 'rollout_corr/' prefix +print(f"Mean IS weight: {metrics['rollout_corr/rollout_is_mean']:.3f}") +print(f"Effective sample size: {metrics['rollout_corr/rollout_is_eff_sample_size']:.3f}") +print(f"Veto fraction: {metrics['rollout_corr/rollout_is_veto_fraction']:.3f}") +print(f"Masked fraction: {metrics['rollout_corr/rollout_is_masked_fraction']:.3f}") +print(f"KL divergence: {metrics['rollout_corr/kl']:.3f}") + +# Check IS weights for valid tokens (non-padding) +valid_weights = is_weights[response_mask.bool()] +print(f"\n✓ IS weights min (valid tokens): {valid_weights.min():.4f}") +print(f"✓ IS weights max (valid tokens): {valid_weights.max():.4f}") +print(f"✓ All valid IS weights > 0: {(valid_weights > 0).all()}") + +# Check rejection via response_mask +rejected_tokens = (response_mask == 1) & (modified_response_mask == 0) +print(f"\n✓ Rejected {rejected_tokens.sum()} tokens via response_mask") +print(f"✓ With rejection sampling (rollout_rs): tokens outside thresholds are masked") +print(f"✓ IS weights are always safety-bounded to [exp(-20), exp(20)] ≈ [2e-9, 5e8]") + +# Check for warning conditions +if metrics['rollout_corr/rollout_is_mean'] < 0.5 or metrics['rollout_corr/rollout_is_mean'] > 2.0: + print("⚠️ Warning: Mean IS weight far from 1.0, significant off-policy gap detected") + +if metrics['rollout_corr/rollout_is_eff_sample_size'] < 0.3: + print("⚠️ Warning: Low effective sample size, high weight concentration") + +if metrics['rollout_corr/rollout_is_veto_fraction'] > 0.1: + print("⚠️ Warning: High veto fraction, policies may be too different") +``` + +#### **Example: Monitoring Metrics During Training** + +```python +# In your training loop +for epoch in range(num_epochs): + for batch_idx, batch in 
enumerate(dataloader): + # ... rollout phase ... + + # Compute IS weights and get metrics + rollout_corr_config = config.algorithm.get("rollout_correction", None) + if rollout_corr_config is not None: + weights_proto, modified_response_mask, metrics = compute_rollout_correction_and_rejection_mask( + old_log_prob=batch.old_log_prob, + rollout_log_prob=batch.rollout_log_prob, + response_mask=batch.response_mask, + rollout_is=rollout_corr_config.get("rollout_is", None), + rollout_is_threshold=rollout_corr_config.get("rollout_is_threshold", 2.0), + rollout_rs=rollout_corr_config.get("rollout_rs", None), + rollout_rs_threshold=rollout_corr_config.get("rollout_rs_threshold", None), + rollout_rs_threshold_lower=rollout_corr_config.get("rollout_rs_threshold_lower", None), + rollout_token_veto_threshold=rollout_corr_config.get("rollout_token_veto_threshold", None), + ) + + # Log to tensorboard/wandb + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step=global_step) + + # IMPORTANT: Update batch response_mask with rejection applied + batch.response_mask = modified_response_mask + + # Use IS weights in training (always safety-bounded, zeroed at padding) + is_weights = weights_proto.batch["rollout_is_weights"] + # ... apply weights to policy gradient ... 
+``` + +#### **Example: Conditional Alerting Based on Metrics** + +```python +def check_rollout_correction_health(metrics, config): + """Check if Rollout Correction metrics indicate healthy training.""" + warnings = [] + + # Check mean IS weight + mean_weight = metrics['rollout_corr/rollout_is_mean'] + if mean_weight < 0.5 or mean_weight > 2.0: + warnings.append(f"Mean IS weight {mean_weight:.3f} is far from 1.0") + + # Check effective sample size + ess = metrics['rollout_corr/rollout_is_eff_sample_size'] + if ess < 0.3: + warnings.append(f"Effective sample size {ess:.3f} is too low") + + # Check veto fraction + veto_frac = metrics['rollout_corr/rollout_is_veto_fraction'] + if veto_frac > 0.1: + warnings.append(f"Veto fraction {veto_frac:.3f} is too high") + + # Check standard deviation + std = metrics['rollout_corr/rollout_is_std'] + if std > 1.0: + warnings.append(f"IS weight std {std:.3f} is too high") + + # Check KL divergence + kl = metrics['rollout_corr/kl'] + if abs(kl) > 0.1: + warnings.append(f"KL divergence {kl:.3f} indicates significant off-policy gap") + + # Check chi-squared divergence + if 'rollout_corr/chi2_token' in metrics: + chi2_token = metrics['rollout_corr/chi2_token'] + if chi2_token > 1.0: + warnings.append(f"Chi-squared divergence (token) {chi2_token:.3f} indicates severe distribution shift") + + if warnings: + print("⚠️ Rollout Correction Health Warnings:") + for warning in warnings: + print(f" - {warning}") + return False + else: + print("✅ Rollout Correction metrics look healthy") + return True + +# Use in training +_, _, metrics = compute_rollout_correction_and_rejection_mask(...) 
+is_healthy = check_rollout_correction_health(metrics, config) + +if not is_healthy: + # Consider adjusting config or investigating issues + print("Consider:") + print(" - Tightening rollout_is_threshold") + print(" - Switching to geometric aggregation level") + print(" - Checking if rollout and training policies are too different") +``` + +### Running Examples + +Start with the basic token-level truncate configuration: +```bash +bash examples/rollout_correction/run_with_rollout_corr.sh +``` + +Monitor metrics for 1-2 epochs before adjusting parameters. + +## Configuration Examples + +### Example 1: IS Weights Only (Token Level) +```yaml +algorithm: + rollout_correction: + rollout_is: token + rollout_is_threshold: 2.0 + rollout_rs: null # No rejection sampling +``` + +### Example 2: Rejection Sampling Only (No IS Weights) +```yaml +algorithm: + rollout_correction: + rollout_is: null # No IS weights + rollout_rs: token + rollout_rs_threshold: 2.0 + rollout_rs_threshold_lower: 0.5 +``` + +### Example 3: Both IS and RS (Geometric RS) +```yaml +algorithm: + rollout_correction: + rollout_is: token + rollout_is_threshold: 2.0 + rollout_rs: geometric + rollout_rs_threshold: 1.0002 + rollout_rs_threshold_lower: 0.9998 +``` + +### Example 4: Full Correction with Veto +```yaml +algorithm: + rollout_correction: + rollout_is: sequence + rollout_is_threshold: 2.0 + rollout_rs: token + rollout_rs_threshold: 2.0 + rollout_rs_threshold_lower: 0.5 + rollout_token_veto_threshold: 1e-4 # Veto catastrophic tokens +``` + +### Example 5: Bypass Mode +```yaml +algorithm: + rollout_correction: + rollout_is: token + rollout_is_threshold: 2.0 + rollout_rs: token + rollout_rs_threshold: 2.0 + bypass_old_logprob_for_rollout: true # Skip old_log_prob computation + use_pure_rollout_correction: false # Use bypass mode: PPO with rollout_log_prob as old_log_prob +``` +**Skips expensive `actor.compute_log_prob()` forward pass** + +### Example 6: Pure Policy Gradient Mode +```yaml +algorithm: + 
rollout_correction: + rollout_is: token # Explicit IS correction in loss + rollout_is_threshold: 2.0 + rollout_rs: null # Optional: can add rejection sampling + bypass_old_logprob_for_rollout: true # Required for pure mode + use_pure_rollout_correction: true # Use pure policy gradient with IS +``` +**No PPO clipping, pure policy gradient with IS correction** + +## Troubleshooting + +### Issue: High spread in IS weights +**Symptoms:** `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3 + +**Solutions:** +1. Switch from `sequence` to `geometric` level +2. Tighten thresholds +3. Verify rollout and training aren't too different + +### Issue: Too many sequences vetoed +**Symptoms:** `rollout_is_veto_fraction` > 0.1 + +**Solutions:** +1. Relax veto threshold in config: + ```yaml + algorithm: + rollout_correction: + rollout_token_veto_threshold: 1e-3 + ``` +2. Check for numerical issues in log prob computation +3. Verify policies aren't completely different + +### Issue: Mean IS weight far from 1.0 +**Symptoms:** `rollout_is_mean` < 0.5 or > 2.0 + +**Solutions:** +1. Verify `calculate_log_probs=True` is set +2. Check rollout_log_probs are correctly passed +3. 
Check for systematic distribution shift + +### Debugging: Visualizing Metrics + +**Example: Plot IS weight distribution** + +```python +import matplotlib.pyplot as plt +import numpy as np + +def plot_is_metrics(metrics_history): + """Plot rollout IS metrics over training steps.""" + fig, axes = plt.subplots(2, 3, figsize=(15, 10)) + + # Plot 1: Mean IS weight over time + axes[0, 0].plot(metrics_history['rollout_corr/rollout_is_mean']) + axes[0, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal') + axes[0, 0].set_title('Mean IS Weight') + axes[0, 0].set_xlabel('Step') + axes[0, 0].legend() + + # Plot 2: Effective sample size + axes[0, 1].plot(metrics_history['rollout_corr/rollout_is_eff_sample_size']) + axes[0, 1].axhline(y=0.5, color='g', linestyle='--', label='Good') + axes[0, 1].axhline(y=0.3, color='r', linestyle='--', label='Warning') + axes[0, 1].set_title('Effective Sample Size') + axes[0, 1].set_xlabel('Step') + axes[0, 1].legend() + + # Plot 3: Veto fraction + axes[0, 2].plot(metrics_history['rollout_corr/rollout_is_veto_fraction']) + axes[0, 2].axhline(y=0.1, color='r', linestyle='--', label='Warning') + axes[0, 2].set_title('Veto Fraction') + axes[0, 2].set_xlabel('Step') + axes[0, 2].legend() + + # Plot 4: KL divergence over time + axes[1, 0].plot(metrics_history['rollout_corr/kl'], label='KL') + axes[1, 0].plot(metrics_history['rollout_corr/k3_kl'], label='K3 KL') + axes[1, 0].axhline(y=0, color='g', linestyle='--', alpha=0.3) + axes[1, 0].set_title('KL Divergence') + axes[1, 0].set_xlabel('Step') + axes[1, 0].legend() + + # Plot 5: PPL ratio over time + axes[1, 1].plot(metrics_history['rollout_corr/ppl_ratio']) + axes[1, 1].axhline(y=1.0, color='r', linestyle='--', label='Ideal') + axes[1, 1].set_title('PPL Ratio (Training/Rollout)') + axes[1, 1].set_xlabel('Step') + axes[1, 1].legend() + + # Plot 6: Chi-squared divergence + if 'rollout_corr/chi2_token' in metrics_history: + axes[1, 2].plot(metrics_history['rollout_corr/chi2_token'], 
label='Token-level') + if 'rollout_corr/chi2_seq' in metrics_history: + axes[1, 2].plot(metrics_history['rollout_corr/chi2_seq'], label='Seq-level') + axes[1, 2].axhline(y=1.0, color='r', linestyle='--', label='Warning') + axes[1, 2].set_title('Chi-squared Divergence') + axes[1, 2].set_xlabel('Step') + axes[1, 2].legend() + else: + axes[1, 2].axis('off') + + plt.tight_layout() + plt.savefig('rollout_is_metrics.png', dpi=150) + print("Saved plot to rollout_is_metrics.png") +``` + +**Example: Metric collection during training** + +```python +# Collect metrics over time +metrics_history = { + 'rollout_corr/rollout_is_mean': [], + 'rollout_corr/rollout_is_eff_sample_size': [], + 'rollout_corr/rollout_is_veto_fraction': [], + 'rollout_corr/kl': [], + 'rollout_corr/k3_kl': [], + 'rollout_corr/ppl_ratio': [], + 'rollout_corr/chi2_token': [], + 'rollout_corr/chi2_seq': [], +} + +# In training loop +for step in range(num_steps): + # ... compute IS weights and rejection mask ... + _, _, metrics = compute_rollout_correction_and_rejection_mask(...) 
+ + # Store metrics + for key in metrics_history.keys(): + if key in metrics: + metrics_history[key].append(metrics[key]) + + # Plot every 100 steps + if step % 100 == 0: + plot_is_metrics(metrics_history) +``` + +## Performance Impact + +- **Memory overhead**: ~1% of model memory +- **Computational overhead**: 1-3% depending on level +- **Training stability**: Significantly improved when off-policy gap exists + + +## Testing + +Run the test suite to verify everything works: + +```bash +# Basic unit tests +python test_rollout_corr.py + +# Integration tests (if pytest is available) +pytest tests/trainer/ppo/test_rollout_corr_integration.py -v +``` + +Expected output: All tests pass ✓ + +## Additional Resources + +- **Implementation**: `verl/trainer/ppo/rollout_corr_helper.py` +- **Examples**: `examples/rollout_correction/` +- **DAPO Example**: `recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh` + +## Summary + +Rollout Correction provides a unified framework for handling general off-policy problems in RL: +- ✅ Corrects ANY distribution shift between data collection and training +- ✅ Supports diverse scenarios: policy mismatch, staleness, replay buffers, off-policy algorithms +- ✅ Numerical stability with safety bounds and rejection mechanisms +- ✅ Comprehensive diagnostics: KL, perplexity, χ² divergence +- ✅ Flexible methods from token-level (token_is) to sequence-level (seq_is_rs) +- ✅ Memory-efficient implementation + +## References + +- **[Mathematical Formulations](rollout_corr_math.md)** - Detailed mathematical theory and derivations for all rollout correction methods +- [When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda) +- [Your Efficient RL Framework Secretly Brings You Off-Policy RL Training](https://fengyao.notion.site/off-policy-rl) \ No newline at end of file diff 
--git a/docs/advance/rollout_corr_math.md b/docs/advance/rollout_corr_math.md new file mode 100644 index 00000000000..8e756a423b6 --- /dev/null +++ b/docs/advance/rollout_corr_math.md @@ -0,0 +1,585 @@ +# Mathematical Formulations of Rollout Correction Methods in `verl` + +**Author:** [Yingru Li](https://richardli.xyz) +**Last updated:** 2025-11-04 + +--- + +## Abstract + +This document provides the definitive mathematical formulations for rollout correction methods in `verl`, following the natural progression from **REINFORCE** to **PPO** to **Decoupled PPO**. + +Rollout correction provides a unified framework to handle **general off-policy problems** in RL training - any scenario where the data collection distribution differs from the training distribution. + +**Applicable scenarios include:** +- **Policy mismatch**: Different precision (FP8 vs FP16 vs BF16 vs FP32), different backends (vLLM vs SGLang vs FSDP vs Megatron) +- **Temporal lag**: Model staleness, asynchronous rollout workers +- **Replay buffers**: Training on historical trajectories from earlier policy versions +- **Off-policy algorithms**: Behavioral cloning, DAPO, expert demonstrations +- **Data filtering**: Reweighting, preference learning, curriculum learning + +--- + +## Table of Contents + +1. [Theoretical Foundation: From REINFORCE to Decoupled PPO](#1-theoretical-foundation-from-reinforce-to-decoupled-ppo) +2. [Implementation in verl: The Three-Policy Framework](#2-implementation-in-verl-the-three-policy-framework) +3. [Method Variants: Different Algorithmic Choices](#3-method-variants-different-algorithmic-choices) +4. [Safety Mechanisms and Rejection Sampling](#4-safety-mechanisms-and-rejection-sampling) +5. [Off-Policy Diagnostic Metrics](#5-off-policy-diagnostic-metrics) +6. [Summary and Decision Guide](#6-summary-and-decision-guide) +7. [Implementation References](#7-implementation-references) + +--- + +## 1. 
Theoretical Foundation: From REINFORCE to Decoupled PPO + +This section establishes the theoretical progression that `verl` implements. + +### 1.1 REINFORCE: Policy Gradient Baseline + +The REINFORCE algorithm ([Williams, 1992](https://doi.org/10.1007/BF00992696)) is the foundation of policy gradient methods. + +**Vanilla REINFORCE (On-Policy)** + +For trajectories $\tau = (s_0, a_0, s_1, a_1, \ldots, s_T, a_T)$ sampled from the current policy $\pi_\theta$, the policy gradient is: + +$$ +\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^T \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot A_t \right] +$$ + +where $A_t$ is the advantage function at timestep $t$. + +**Off-Policy REINFORCE** + +When trajectories are sampled from a different behavior policy $\mu$, we apply importance sampling over the **joint trajectory distribution**: + +$$ +\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \mu} \left[ \frac{P_{\pi_\theta}(\tau)}{P_\mu(\tau)} \sum_{t=0}^T \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot A_t \right] +$$ + +where the trajectory-level importance weight is: + +$$ +\frac{P_{\pi_\theta}(\tau)}{P_\mu(\tau)} = \frac{p(s_0) \prod_{t=0}^T \pi_\theta(a_t|s_t) p(s_{t+1}|s_t, a_t)}{p(s_0) \prod_{t=0}^T \mu(a_t|s_t) p(s_{t+1}|s_t, a_t)} = \prod_{t=0}^T \frac{\pi_\theta(a_t|s_t)}{\mu(a_t|s_t)} +$$ + +The transition dynamics $p(s_{t+1}|s_t, a_t)$ and initial state $p(s_0)$ cancel out, leaving only the product of per-step action probability ratios. + +**Key properties:** +- **Off-policy capable**: Can learn from any behavior policy via importance sampling +- **No trust region**: Policy updates not constrained + +**Implementation in verl:** The `pure_is` method implements off-policy REINFORCE with truncated importance sampling. 
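The cancellation above means the trajectory-level weight reduces to a product of per-step action-probability ratios, which should be accumulated in log space. A minimal NumPy sketch (the helper name is hypothetical, not part of verl):

```python
import numpy as np

# Hypothetical helper (illustration only, not verl's implementation):
# trajectory-level importance weight
#   P_pi(tau) / P_mu(tau) = prod_t pi_theta(a_t|s_t) / mu(a_t|s_t),
# computed by summing per-step log-probability differences.
def trajectory_is_weight(logp_theta, logp_mu):
    log_ratio = np.asarray(logp_theta) - np.asarray(logp_mu)
    return float(np.exp(log_ratio.sum()))
```

Summing log-ratios before exponentiating avoids the underflow/overflow that multiplying hundreds of per-token ratios directly would cause.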
+ +### 1.2 PPO: Adding Trust Region Control + +Proximal Policy Optimization ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347)) adds a clipped surrogate objective: + +$$ +L_{\text{PPO}}(\theta) = -\mathbb{E}_{(s,a) \sim \mu} \left[ \min\left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] +$$ + +where $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\mu(a_t|s_t)}$ and $\epsilon$ is the clip range (typically 0.2). + +**Key properties:** +- **Two policies**: $\mu$ (reference for clipping) and $\pi_\theta$ (being updated) +- **Trust region via clipping**: Limits policy update magnitude via ratio $r_t(\theta) = \frac{\pi_\theta}{\mu}$ + +### 1.3 Decoupled PPO: Achieving Batch Size Invariance + +Decoupled PPO ([Hilton et al., 2021](https://arxiv.org/abs/2110.00641)) solves PPO's batch size sensitivity by **decoupling two roles**: +1. **Proximal policy** $\pi_{\text{prox}}$: The anchor policy for PPO clipping (controls policy update size) +2. **Behavior policy** $\mu$: The policy that collected the data (for off-policy correction via importance sampling) + +**The problem**: Standard PPO controls policy update size via the ratio $\frac{\pi_\theta}{\pi_{\text{old}}}$, where $\pi_{\text{old}}$ is assumed to be both the proximal policy *and* the behavior policy. This coupling makes the algorithm sensitive to batch size because aggregating data from multiple workers or using replay buffers changes the effective behavior policy. 
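The coupling described above is visible in code: the same old-policy log-probabilities serve both as the clipping anchor and as the assumed behavior policy. A numeric sketch of the standard clipped surrogate from Section 1.2 (illustrative only, not verl's loss; the function name is hypothetical):

```python
import numpy as np

# Illustrative sketch (not verl's implementation) of the PPO clipped
# surrogate, over a batch of per-token values.
def ppo_clip_loss(logp_theta, logp_old, advantages, eps=0.2):
    ratio = np.exp(np.asarray(logp_theta) - np.asarray(logp_old))  # r_t(theta)
    adv = np.asarray(advantages)
    surrogate = np.minimum(ratio * adv, np.clip(ratio, 1.0 - eps, 1.0 + eps) * adv)
    # Negated because we minimize the loss to maximize the objective
    return float(-np.mean(surrogate))
```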
+ +**The solution**: Decouple these two roles, leading to a **three-policy formulation**: + +$$ +L_{\text{DecoupledPPO}}(\theta) = -\mathbb{E}_{(s,a) \sim \mu} \left[ w_t \cdot \min\left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] +$$ + +where: +- $w_t = \frac{\pi_{\text{prox}}(a_t|s_t)}{\mu(a_t|s_t)}$: Importance sampling weight (corrects for behavior policy $\mu$) +- $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{prox}}(a_t|s_t)}$: PPO ratio (controls policy update size against proximal policy $\pi_{\text{prox}}$) + +**Key properties**: By decoupling: +- **Batch size invariance**: Policy update control (via $\pi_{\text{prox}}$) is independent of data aggregation +- **Flexible behavior policy**: Any $\mu$ can be used (different workers, replay buffers, or stale checkpoints) +- **Stale data utilization**: Older trajectories can be corrected via importance sampling +- **Clipping preserved**: Clipping against $\pi_{\text{prox}}$ limits update magnitude + +**This is the algorithm that `verl` implements via its three-policy framework.** + +--- + +## 2. Implementation in verl: The Three-Policy Framework + +The `verl` library implements decoupled PPO using three distinct policies, each serving a specific role. + +### 2.1 Policy Roles and Notation + +**$\pi_{\text{rollout}}$ (Behavior Policy $\mu$)** +The policy used for data collection. This is the behavior distribution $\mu$ from theory. 
+ +- **When created**: During rollout/data collection phase +- **Purpose**: Generate trajectories for training +- **Common sources**: + - Policy mismatch: Same weights, different implementation (precision, backend) + - Temporal lag: Stale checkpoint from async workers + - Replay buffer: Historical data from earlier iterations + - Off-policy algorithms: Expert demonstrations, auxiliary policies (DAPO) + - Data filtering: Reweighted or filtered data +- **Fixed**: Frozen during training on a batch + +**$\pi_{\text{old}}$ (Proximal Policy $\pi_{\text{prox}}$)** +The reference policy for PPO clipping. This is the "proximal policy" from decoupled PPO theory. + +- **When created**: + - **Decoupled mode**: Computed at start of training epoch via `actor.compute_log_prob()` + - **Bypass mode**: Set equal to $\pi_{\text{rollout}}$ (skips separate computation) +- **Purpose**: + - Anchor point for PPO clipping (controls policy update size) + - When separate from $\pi_{\text{rollout}}$: Enables batch size invariance and efficient use of stale data +- **Fixed**: Frozen during all PPO update epochs on the same batch + +**$\pi_{\theta}$ (Current Policy)** +The policy being actively optimized during training. 
+ +- **Updated**: Every gradient step +- **Purpose**: The policy we're improving + +### 2.2 Operating Modes + +The three-policy framework can operate in two modes: + +**Decoupled Mode (Three Policies)** +- Computes $\pi_{\text{old}}$ separately at the start of each training epoch +- **Algorithm**: Full decoupled PPO with three policies (mathematically correct) +- **Properties**: Achieves batch size invariance; separately corrects Drift 1 (rollout→old) and Drift 2 (old→current) + +**Bypass Mode (Two Policies)** +- Sets $\pi_{\text{old}} = \pi_{\text{rollout}}$ (skips separate computation) +- **Algorithm**: Uses $\pi_{\text{rollout}}$ as both behavior policy and proximal policy (mathematically correct) +- **Key difference**: Proximal policy equals behavior policy, so no IS correction needed between them +- **Properties**: Faster (skips `actor.compute_log_prob()` call); does not achieve batch size invariance + +### 2.3 Two Distribution Shifts + +The three-policy framework handles two types of distribution drift: + +**Drift 1: $\pi_{\text{rollout}} \to \pi_{\text{old}}$ (Off-Policy Gap)** + +This is the distribution shift between the data collection policy and the training reference policy. + +- **Nature**: Ranges from negligible (same checkpoint, minor differences) to severe (replay buffers, expert data) +- **Correction**: Importance sampling weight $w_t = \frac{\pi_{\text{old}}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}$ +- **Optional**: Can be ignored (bypass mode) when negligible + +**Drift 2: $\pi_{\text{old}} \to \pi_{\theta}$ (Policy Update Drift)** + +This is the drift from policy parameter updates during training. 
+ +- **Nature**: Occurs as $\pi_\theta$ is updated via gradient descent +- **Correction**: PPO clipping on ratio $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)}$ +- **Universal**: Applies to both on-policy and off-policy training + +### 2.4 Notation Summary + +- $\pi_{\text{rollout}}$: Behavior policy (data collection) +- $\pi_{\text{old}}$: Proximal policy (PPO anchor) +- $\pi_{\theta}$: Current policy (being updated) +- $\rho_t = \frac{\pi_{\text{old}}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}$: Per-token IS ratio (corrects Drift 1) +- $r_t(\theta) = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)}$: PPO ratio (corrects Drift 2) +- $A_t$: Advantage at token $t$ +- $T$: Set of valid tokens in a sequence +- $C_{\text{IS}}$: Upper threshold for IS weights (e.g., 2.0) +- $C_{\text{RS-upper}}$: Upper threshold for RS mask (e.g., 2.0) +- $C_{\text{RS-lower}}$: Lower threshold for RS mask (typically $1/C_{\text{RS-upper}}$) +- $\epsilon$: PPO clip range (typically 0.2) + +--- + +## 3. Method Variants: Different Algorithmic Choices + +This section describes the different algorithmic variants available in `verl`, organized by their theoretical foundation. + +### 3.1 Off-Policy REINFORCE Methods + +These methods implement REINFORCE with importance sampling, without PPO clipping. + +#### 3.1.1 Pure IS (pure_is) + +**Theory:** Off-policy REINFORCE with sequence-level truncated importance sampling. 
+ +**Configuration:** +```python +RolloutCorrectionConfig.pure_is(threshold=2.0) +``` + +**Loss Function:** + +$$ +L_{\text{PureIS}}(\theta) = -\mathbb{E}_{(s,a) \sim \pi_{\text{rollout}}} \left[ w_{\text{seq}}(\theta) \cdot \sum_{t \in T} \log \pi_{\theta}(a_t|s_t) \cdot A_t \right] +$$ + +where: +- Sequence-level IS weight: $w_{\text{seq}}(\theta) = \min\left( \prod_{t \in T} \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}, C_{\text{IS}} \right)$ +- IS weight is **detached from gradient** (treated as constant) +- Direct comparison: $\pi_\theta$ to $\pi_{\text{rollout}}$ + +**Effective gradient:** + +$$ +\nabla_\theta L_{\text{PureIS}} = -\mathbb{E}_{(s,a) \sim \pi_{\text{rollout}}} \left[ \text{stopgrad}(w_{\text{seq}}(\theta)) \cdot \sum_{t \in T} \nabla_\theta \log \pi_{\theta}(a_t|s_t) \cdot A_t \right] +$$ + +**Properties:** +- **Algorithm**: Off-policy REINFORCE + IS +- **Policies**: Two ($\pi_{\text{rollout}}$, $\pi_\theta$) +- **No PPO clipping**: Pure policy gradient +- **Always uses bypass mode**: No $\pi_{\text{old}}$ computation +- **Fast**: Single forward pass for IS weights + +**Implementation:** `compute_policy_loss_with_rollout_correction()` in [core_algos.py](../../verl/trainer/ppo/core_algos.py#L1537-L1681) + +--- + +### 3.2 Two-Policy PPO Methods + +These methods use two policies without importance sampling between behavior and proximal policies. + +#### 3.2.1 Incorrect LLM-RL Implementation (PPO Without Rollout Correction) + +**Theory:** Naive LLM-RL implementation that incorrectly applies PPO by ignoring the actual rollout policy and assuming $\mu = \pi_{\text{old}}$. + +**Note:** This incorrect implementation pattern was identified in [Liu, Li, et al. (2025)](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda) as a key cause of training instability in LLM-RL systems, motivating the development of this rollout correction framework. 
+ +**Loss Function:** + +$$ +L_{\text{PPO}}(\theta) = -\mathbb{E}_t \left[ \min\left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] +$$ + +where $r_t(\theta) = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)}$. + +**Properties:** +- **Algorithm**: Common but incorrect LLM-RL implementation (mathematically wrong when $\pi_{\text{rollout}} \neq \pi_{\text{old}}$) +- **Policies**: Two ($\pi_{\text{old}}$, $\pi_\theta$) +- **Ignores $\pi_{\text{rollout}}$**: Uses $\pi_{\text{old}}$ as behavior policy instead of actual $\pi_{\text{rollout}}$ +- **Policy mismatch**: This is the typical case in LLM-RL - rollout uses different precision/backend/checkpoint than training, causing $\pi_{\text{rollout}} \neq \pi_{\text{old}}$ even with same model weights +- **Not PPO's fault**: PPO itself is correct; the issue is the incorrect assumption that $\pi_{\text{old}} = \pi_{\text{rollout}}$ in LLM-RL implementations + +**Implementation:** `compute_policy_loss()` in [core_algos.py](../../verl/trainer/ppo/core_algos.py#L812-L884) + +#### 3.2.2 PPO Bypass (ppo_is_bypass) + +**Theory:** Original PPO applied to off-policy data by using $\pi_{\text{rollout}}$ as the PPO anchor. + +**Configuration:** +```python +RolloutCorrectionConfig.ppo_is_bypass(threshold=2.0) +``` + +**Implementation:** When `bypass_old_logprob_for_rollout=True`, we set $\pi_{\text{old}} = \pi_{\text{rollout}}$: +- IS weight: $w_t = \frac{\pi_{\text{old}}}{\pi_{\text{rollout}}} = 1$ +- PPO ratio: $r_t(\theta) = \frac{\pi_{\theta}}{\pi_{\text{old}}} = \frac{\pi_{\theta}}{\pi_{\text{rollout}}}$ + +**Loss Function:** + +$$ +L_{\text{PPO-Bypass}}(\theta) = -\mathbb{E}_t \left[ \min\left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] +$$ + +where $r_t(\theta) = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}$ (clips against rollout policy). 
+ +**Properties:** +- **Algorithm**: PPO with $\pi_{\text{rollout}}$ as proximal policy (two policies) +- **Policies**: Two ($\pi_{\text{rollout}}$, $\pi_\theta$) +- **No IS correction needed**: Uses actual behavior policy $\pi_{\text{rollout}}$ as proximal policy (mathematically correct) +- **PPO clips against rollout**: Trust region relative to data collection policy +- **Fast**: Skips `actor.compute_log_prob()` call + +--- + +### 3.3 Decoupled PPO Methods + +These methods implement full decoupled PPO with three policies, combining importance sampling (for Drift 1) with PPO clipping (for Drift 2). + +#### 3.3.1 Token-Level IS (token_is) + +**Theory:** Decoupled PPO with per-token truncated importance sampling. + +**Configuration:** +```python +RolloutCorrectionConfig.token_is(threshold=2.0) +``` + +**Loss Function:** + +$$ +L_{\text{PPO+TIS}}(\theta) = -\mathbb{E}_t \left[ w_t \cdot \min\left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] +$$ + +where: +- Per-token IS weight: $w_t = \min(\rho_t, C_{\text{IS}}) = \min\left(\frac{\pi_{\text{old}}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}, C_{\text{IS}} \right)$ +- PPO ratio: $r_t(\theta) = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)}$ +- $\pi_{\text{old}}$ is computed at **start of training epoch** + +**Properties:** +- **Algorithm**: Decoupled PPO +- **Policies**: Three ($\pi_{\text{rollout}}$, $\pi_{\text{old}}$, $\pi_\theta$) in decoupled mode +- **Double correction**: IS weights correct Drift 1, PPO clips correct Drift 2 +- **Per-token truncation**: Stable IS weight computation + +**Implementation:** +- IS weights: `compute_rollout_correction_weights()` in [rollout_corr_helper.py](../../verl/trainer/ppo/rollout_corr_helper.py#L325-L402) +- Loss: `compute_policy_loss()` in [core_algos.py](../../verl/trainer/ppo/core_algos.py#L812-L884) + +#### 3.3.2 Sequence-Level IS (seq_is) + +**Theory:** Decoupled PPO with sequence-level importance sampling. 
+ +**Configuration:** +```python +RolloutCorrectionConfig.seq_is(threshold=2.0) +``` + +**Loss Function:** + +$$ +L_{\text{PPO+SeqIS}}(\theta) = -\mathbb{E}_t \left[ w_{\text{seq}} \cdot \min\left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] +$$ + +where: +- Sequence-level IS weight (broadcast to all tokens): +$$w_{\text{seq}} = \min\left( \prod_{t \in T} \rho_t, C_{\text{IS}} \right) = \min\left( \exp\left(\sum_{t \in T} \log \rho_t\right), C_{\text{IS}} \right)$$ +- PPO ratio: $r_t(\theta) = \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)}$ + +**Properties:** +- **Algorithm**: Decoupled PPO +- **Policies**: Three ($\pi_{\text{rollout}}$, $\pi_{\text{old}}$, $\pi_\theta$) in decoupled mode +- **Sequence-level IS**: Uses product of all token ratios + +#### 3.3.3 Mixed IS + Rejection Sampling (seq_is_rs / seq_mis) + +**Theory:** Decoupled PPO combining sequence-level IS weighting with rejection sampling. + +**Configuration:** +```python +RolloutCorrectionConfig.seq_is_rs( + is_threshold=2.0, + rs_threshold=2.0, + rs_threshold_lower=None, # defaults to 1/rs_threshold +) +``` + +**Loss Function:** + +$$ +L_{\text{PPO+MIS}}(\theta) = -\mathbb{E}_{t \mid \text{seq} \in \mathcal{A}} \left[ w_{\text{seq}} \cdot \min\left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right] +$$ + +where: +- IS weight: $w_{\text{seq}} = \min\left( \prod_{t \in T} \rho_t, C_{\text{IS}} \right)$ +- Acceptance set: $\mathcal{A} = \{ \text{seq} : C_{\text{RS-lower}} \leq \prod_{t \in T} \rho_t \leq C_{\text{RS-upper}} \}$ + +**Properties:** +- **Algorithm**: Decoupled PPO + rejection sampling +- **Double mechanism**: IS reweighting + rejection filtering +- **Lower effective sample size**: Rejects outlier sequences + +**Implementation:** `compute_rollout_rejection_mask()` in [rollout_corr_helper.py](../../verl/trainer/ppo/rollout_corr_helper.py#L80-L188) + +--- + +## 4. 
Safety Mechanisms and Rejection Sampling

### 4.1 Geometric Rejection Sampling (geo_rs)

**Theory:** Pure rejection sampling based on the geometric mean of IS ratios.

**Configuration:**
```python
RolloutCorrectionConfig.geo_rs(
    rs_threshold=1.001,  # Very tight threshold
    rs_threshold_lower=None,
    veto_threshold=1e-4,
)
```

**Loss Function:**

$$
L_{\text{GeoRS}}(\theta) = -\mathbb{E}_{t \mid \text{seq} \in \mathcal{A}_{\text{geo}} \cap \mathcal{A}_{\text{veto}}} \left[ \min\left( r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right]
$$

where:
- Geometric mean: $\rho_{\text{geo}} = \exp\left( \frac{1}{|T|} \sum_{t \in T} \log \rho_t \right) = \left(\prod_{t \in T} \rho_t\right)^{1/|T|}$
- Geometric acceptance: $\mathcal{A}_{\text{geo}} = \{ \text{seq} : C_{\text{RS-lower}} \leq \rho_{\text{geo}} \leq C_{\text{RS-upper}} \}$
- Veto acceptance: $\mathcal{A}_{\text{veto}} = \{ \text{seq} : \rho_t \geq C_{\text{veto}} \text{ for all } t \in T \}$

**Why tight thresholds?**
The geometric mean normalizes out sequence length: it equals the average per-token ratio, so even large cumulative drift maps to a value close to 1. For 100 tokens with $\rho_t = 1.01$ each:
- Arithmetic product: $\prod_{t=1}^{100} \rho_t = 1.01^{100} \approx 2.7$
- Geometric mean: $\left(1.01^{100}\right)^{1/100} = 1.01$

A threshold of 1.001 therefore rejects any sequence whose average per-token deviation exceeds 0.1%.

**Properties:**
- **No IS weights**: Pure rejection
- **Extremely selective**: Requires near-perfect policy match
- **High rejection rate**: Only suitable for very slight distribution shifts

### 4.2 Veto Mechanism

An independent safety layer that rejects sequences with catastrophically low token probabilities.
+ +**Configuration:** +```python +RolloutCorrectionConfig(..., rollout_token_veto_threshold=1e-4) +``` + +**Veto condition:** + +$$ +\text{Reject entire sequence if } \exists t \in T \text{ such that } \rho_t < C_{\text{veto}} +$$ + +**Purpose:** +- Prevents catastrophic updates from tokens with near-zero probability under $\pi_{\text{old}}$ +- Independent of IS/RS settings +- Typical values: $10^{-4}$ to $10^{-6}$ + +**Implementation:** [rollout_corr_helper.py](../../verl/trainer/ppo/rollout_corr_helper.py#L620-L640) + +--- + +## 5. Off-Policy Diagnostic Metrics + +These metrics quantify the severity of off-policy drift. + +**Note on notation:** Metrics use $\rho_t = \frac{\pi_{\text{old}}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}$. In bypass mode, $\pi_{\text{old}} = \pi_{\text{rollout}}$, so metrics measure rollout→current drift using $\rho_t = \frac{\pi_{\theta}}{\pi_{\text{rollout}}}$ instead. + +### 5.1 KL Divergence + +**Direct KL estimator:** + +$$ +\text{KL}(\pi_{\text{rollout}} \| \pi_{\text{old}}) = \mathbb{E}_{t \sim \pi_{\text{rollout}}} \left[ \log \pi_{\text{rollout}}(a_t|s_t) - \log \pi_{\text{old}}(a_t|s_t) \right] +$$ + +**K3 KL estimator** (alternative formulation): + +$$ +\text{KL}_{\text{K3}} = \mathbb{E}_{t \sim \pi_{\text{rollout}}} \left[ \rho_t - \log \rho_t - 1 \right] +$$ + +where $\rho_t = \frac{\pi_{\text{old}}(a_t|s_t)}{\pi_{\text{rollout}}(a_t|s_t)}$. 
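The two estimators can be computed side by side from per-token log-probabilities; a small sketch (hypothetical helper, not verl's implementation):

```python
import numpy as np

# Hypothetical helper (illustration only): direct and K3 estimators of
# KL(pi_rollout || pi_old), averaged over tokens sampled from pi_rollout.
def kl_estimators(logp_old, logp_rollout):
    log_rho = np.asarray(logp_old) - np.asarray(logp_rollout)  # log rho_t
    rho = np.exp(log_rho)
    kl_direct = float(np.mean(-log_rho))         # E[log pi_rollout - log pi_old]
    kl_k3 = float(np.mean(rho - log_rho - 1.0))  # non-negative term by term
    return kl_direct, kl_k3
```

On finite samples the direct estimator can be negative, while each K3 term $\rho_t - \log \rho_t - 1 \geq 0$, which is why K3 is more stable for small KL values.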
+ +### 5.2 Perplexity + +**Old policy perplexity:** + +$$ +\text{PPL}_{\text{old}} = \exp\left( -\frac{1}{|T|} \sum_{t \in T} \log \pi_{\text{old}}(a_t|s_t) \right) +$$ + +**Rollout policy perplexity:** + +$$ +\text{PPL}_{\text{rollout}} = \exp\left( -\frac{1}{|T|} \sum_{t \in T} \log \pi_{\text{rollout}}(a_t|s_t) \right) +$$ + +**PPL ratio** (inverse of geometric mean IS weight): + +$$ +\text{PPL}_{\text{ratio}} = \frac{\text{PPL}_{\text{old}}}{\text{PPL}_{\text{rollout}}} = \exp\left( -\frac{1}{|T|} \sum_{t \in T} \log \rho_t \right) = \left(\prod_{t \in T} \rho_t\right)^{-1/|T|} +$$ + +**Interpretation:** Values > 1 mean $\pi_{\text{old}}$ assigns lower probability than $\pi_{\text{rollout}}$ to the observed actions (distribution shift). + +### 5.3 Chi-squared Divergence + +Measures the second moment of the IS weight distribution. + +**Token-level:** + +$$ +\chi^2_{\text{token}} = \mathbb{E}_{t \sim \pi_{\text{rollout}}} \left[ \rho_t^2 \right] - 1 +$$ + +**Sequence-level:** + +$$ +\chi^2_{\text{seq}} = \mathbb{E}_{\text{seq} \sim \pi_{\text{rollout}}} \left[ \left(\prod_{t \in T} \rho_t\right)^2 \right] - 1 +$$ + +**Interpretation:** +- $\chi^2 = 0$: Policies are identical +- $\chi^2 > 0$: Higher values indicate more severe off-policy distribution shift + +**Implementation:** `compute_offpolicy_metrics()` in [rollout_corr_helper.py](../../verl/trainer/ppo/rollout_corr_helper.py#L670-L776) + +--- + +## 6. 
Summary and Decision Guide + +### 6.1 Method Summary Table + +| Method | Theory | Policies | PPO Clip | IS Correction | Correctness | Speed | +|--------|--------|----------|----------|---------------|-------------|-------| +| `pure_is` | Off-policy REINFORCE | 2 (rollout, θ) | ❌ | ✅ Seq-level | ✅ Correct | **Fast** | +| Naive LLM-RL | Incorrect PPO usage | 2 (old, θ) | ✅ | ❌ | ⚠️ Incorrect | Standard | +| `ppo_is_bypass` | PPO (rollout as prox) | 2 (rollout, θ) | ✅ | ❌ | ✅ Correct | **Fast** | +| `token_is` | Decoupled PPO | 3 (rollout, old, θ) | ✅ | ✅ Token-level | ✅ Correct | Standard | +| `seq_is` | Decoupled PPO | 3 (rollout, old, θ) | ✅ | ✅ Seq-level | ✅ Correct | Standard | +| `seq_is_rs` | Decoupled PPO + RS | 3 (rollout, old, θ) | ✅ | ✅ + Rejection | ✅ Correct | Standard | +| `geo_rs` | Decoupled PPO + Geo RS | 3 (rollout, old, θ) | ✅ | Rejection only | ✅ Correct | Standard | + +### 6.2 Method Characteristics by Scenario + +**Off-policy severity:** +- **Negligible** (same checkpoint, minor differences): `ppo_is_bypass` uses $\pi_{\text{rollout}}$ as proximal policy (mathematically correct); naive LLM-RL implementations use $\pi_{\text{old}}$ instead of $\pi_{\text{rollout}}$ (mathematically incorrect when $\pi_{\text{rollout}} \neq \pi_{\text{old}}$) +- **Moderate** (async workers, slight staleness): `token_is` provides per-token IS correction with separate proximal policy +- **Severe** (replay buffers, old data): `seq_is` and `seq_is_rs` provide sequence-level IS correction with optional rejection sampling + +**Algorithm properties:** +- **Batch size invariance**: Decoupled mode with three policies (`token_is`, `seq_is`) achieves batch size invariance +- **Computational efficiency**: Bypass mode (`ppo_is_bypass`) skips `old_log_prob` computation +- **Pure policy gradient**: `pure_is` implements off-policy REINFORCE without PPO clipping + +### 6.3 Decoupled Mode vs Bypass Mode + +**Decoupled mode** (computes `old_log_prob` separately): +- Implements full 
decoupled PPO with three policies (mathematically correct) +- Separately measures and corrects Drift 1 (rollout→old) and Drift 2 (old→current) +- Achieves batch size invariance and efficient stale data utilization +- Enables accurate off-policy metrics monitoring + +**Bypass mode** (sets $\pi_{\text{old}} = \pi_{\text{rollout}}$): +- Uses $\pi_{\text{rollout}}$ as both behavior policy and proximal policy (mathematically correct) +- Computational efficiency: Skips separate `old_log_prob` computation +- Does not achieve batch size invariance (proximal policy depends on data collection) + +--- + +## 7. Implementation References + +- **[Rollout Correction Usage Guide](rollout_corr.md)** - Practical configuration and troubleshooting +- **Config:** [verl/trainer/config/algorithm.py](../../verl/trainer/config/algorithm.py) +- **IS/RS Helper:** [verl/trainer/ppo/rollout_corr_helper.py](../../verl/trainer/ppo/rollout_corr_helper.py) +- **PPO Loss:** [verl/trainer/ppo/core_algos.py](../../verl/trainer/ppo/core_algos.py) +- **Tests:** [tests/trainer/ppo/test_rollout_corr.py](../../tests/trainer/ppo/test_rollout_corr.py) + +--- + +## References + +- **Williams, R. J. (1992).** "Simple statistical gradient-following algorithms for connectionist reinforcement learning." *Machine Learning*, 8(3-4), 229-256. https://doi.org/10.1007/BF00992696 +- **Schulman, J., Wolski, F., Dhariwal, P., Radford, A., & Klimov, O. (2017).** "Proximal policy optimization algorithms." *arXiv preprint arXiv:1707.06347.* https://arxiv.org/abs/1707.06347 +- **Hilton, J., Cobbe, K., & Schulman, J. (2021).** "Batch size-invariance for policy optimization." *arXiv preprint arXiv:2110.00641.* https://arxiv.org/abs/2110.00641 + - Introduced decoupled PPO: separating proximal policy (for controlling policy update size) from behavior policy (for off-policy correction) to achieve batch size invariance +- **Liu, J., Li, Y., et al. 
(2025).** "When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch" + - Blog post: https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda diff --git a/docs/advance/rollout_is.md b/docs/advance/rollout_is.md deleted file mode 100644 index dd282b58dbc..00000000000 --- a/docs/advance/rollout_is.md +++ /dev/null @@ -1,755 +0,0 @@ -# Rollout Importance Sampling - -**Author:** [Yingru Li](https://richardli.xyz/) - -Last updated: 10/27/2025. - -This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation in verl. - -### BibTeX Citation - -```bibtex -@misc{liu-li-2025, - title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, - url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda}, - author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, - year = {2025}, - month = september, -} -``` - -## Overview - -Rollout Importance Sampling corrects for distribution mismatch between: -- **Rollout policy**: e.g., vLLM with BFloat16 -- **Training policy**: e.g., FSDP with FP32 - -This mismatch can lead to biased gradient estimates and unstable training. Rollout IS applies importance sampling weights to correct these biases. - -### Key Design Principle: Separation of IS Weights and Rejection Sampling - -**Important**: As of 10/27/2025, the implementation separates two mechanisms: - -1. 
**IS Weights** (`rollout_is_weights`): Ratios π_train/π_rollout with processing: - - **Safety-bounded** to [exp(-20), exp(20)] ≈ [2e-9, 5e8] to prevent overflow: - * Token level: Bounds per-token ratios - * Sequence level: Bounds product of ratios (broadcast to all tokens in sequence) - * Geometric level: Bounds geometric mean of ratios (broadcast to all tokens) - - **Truncate mode**: Upper clamped via .clamp(max=upper_threshold) - - **Mask mode**: Safety-bounded ratios preserved (no threshold clamping) - - **All modes**: Zeroed at padding positions (response_mask == 0) - - Used for policy gradient calculations - -2. **Rejection Sampling** (`modified_response_mask`): Applied via response_mask - - Mask mode: Excludes tokens/sequences with outlier IS ratios - - Veto: Excludes sequences with catastrophic tokens - - Used for loss aggregation (denominator calculation) - -This separation ensures: -- ✅ Correct loss normalization (rejected samples excluded from denominator) -- ✅ Mode-specific weight processing (truncate: upper clamped, mask: safety-bounded only) -- ✅ Padding positions zeroed in weights (necessary for correct aggregation) -- ✅ Safety bounds always applied (prevent overflow in all modes) - -## Configuration - -```yaml -# Rollout IS configuration (all in algorithm config) -algorithm: - # Main control: set threshold to enable (null = disabled) - rollout_is_threshold: 2.0 - # Whether to apply weights to loss (default: false = metrics only) - rollout_is: true - rollout_is_threshold_lower: null # Auto-reciprocal - rollout_is_level: token - rollout_is_mode: truncate - rollout_is_veto_threshold: null # Disable veto by default - -# REQUIRED: Enable log prob calculation -actor_rollout_ref: - rollout: - calculate_log_probs: true -``` - -Key features: -- ✅ Three aggregation levels: token, sequence, geometric -- ✅ Two bounding modes: truncate, mask -- ✅ Dual threshold support (upper/lower) -- ✅ Veto mechanism for catastrophic outliers -- ✅ 30+ comprehensive metrics -- ✅ 
Log-space computation for numerical stability -- ✅ Memory-efficient implementation - -## Files - -### **Core Implementation** - -- `verl/trainer/ppo/mismatch_helper.py` - Contains `compute_rollout_importance_weights()` and `compute_is_metrics()` -- `verl/trainer/ppo/core_algos.py` - Rollout IS integration with PPO -- `verl/workers/actor/dp_actor.py` - Metrics collection and logging - -### **Configuration Files** - -- `verl/trainer/config/algorithm.py` - Rollout IS parameters in `AlgoConfig` -- `verl/workers/config/actor.py` - Rollout IS parameters in `ActorConfig` -- `verl/trainer/config/actor/actor.yaml` - Rollout IS configuration section -- `verl/trainer/config/ppo_trainer.yaml` - Algorithm config with rollout IS - -### **Documentation** - -- `docs/examples/config.rst` - Configuration parameter descriptions - -### **Example Scripts** - -- `recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh` - DAPO example with rollout IS -- `examples/rollout_importance_sampling/README.md` - Comprehensive usage guide -- `examples/rollout_importance_sampling/run_with_rollout_is.sh` - Basic example - -### **Tests** - -- `tests/trainer/ppo/test_rollout_is.py` - Unit tests -- `tests/trainer/ppo/test_rollout_is_integration.py` - Integration tests - -## Configuration Parameters - -### `algorithm.rollout_is_threshold` (float or null) -**Main on/off switch.** Upper threshold for IS weights. -- `null` = disabled (no computation, no metrics) -- `float` value (e.g., 2.0) = enabled (compute weights and metrics) - -### `algorithm.rollout_is` (bool) -Whether to apply IS weights to policy loss. Default: `False` -- `true` = apply weights to loss (full IS correction) -- `false` = metrics only mode (no weight correction, but rejection still applies) - -**IMPORTANT**: This flag controls IS weight application, NOT rejection sampling. See "Operation Modes" below. 
- -**Recommended threshold ranges:** -- Token level: 1.5 - 5.0 -- Sequence level: 2.0 - 10.0 -- Geometric level: 1.0002 - 1.001 - -### `algorithm.rollout_is_threshold_lower` (float or null) -Lower threshold for IS weights. If `null`, defaults to 1/upper (reciprocal). - -### `algorithm.rollout_is_level` (str) -Aggregation level for IS weights: -- `"token"`: Per-token ratios ρ_t = π_train(t)/π_rollout(t) - - Each token has its own IS weight - - Safety bound: each token's ratio bounded to [exp(-20), exp(20)] - - Biased estimator but low variance -- `"sequence"`: Product of ratios ρ_seq = ∏_t ρ_t for entire sequence - - All tokens in a sequence share the same IS weight (product of per-token ratios) - - Safety bound: product bounded to [exp(-20), exp(20)], then broadcast to all tokens - - Unbiased estimator but high variance -- `"geometric"`: Geometric mean ρ_geo = (∏_t ρ_t)^(1/T) (experimental) - - All tokens in a sequence share the same IS weight (geometric mean) - - Safety bound: geometric mean bounded to [exp(-20), exp(20)], then broadcast to all tokens - - Trade-off between bias and variance - -### `algorithm.rollout_is_mode` (str) -Bounding mode for handling outlier IS weights: -- `"truncate"`: Clamp weights at upper threshold only (TIS) - - No lower bound clamping or rejection for outlier ratios - - **IS weights modified**: Upper bound clamped via .clamp(max=upper_threshold) - - Lower bound remains at exp(-20) ≈ 2e-9 from safety bound - - **Note**: Veto-based rejection can still occur via response_mask (see `rollout_is_veto_threshold`) -- `"mask"`: Rejection sampling via response_mask (MIS) - - Rejects tokens/sequences with IS ratios outside [lower, upper] - - **Important**: Rejection applied to `response_mask`, NOT by modifying IS weights - - **IS weights**: Safety-bounded ratios preserved (no threshold clamping, rejection via mask) - - **Note**: Veto-based rejection also applies via response_mask (independent mechanism) - -### 
`algorithm.rollout_is_veto_threshold` (float or None) -Per-token veto threshold for catastrophic outliers. -- If any token has **unclamped** ratio < this threshold, the entire sequence is rejected via `response_mask` -- Veto checks the **true per-token ratio** π_train(t)/π_rollout(t) before any bounds are applied -- Applied for all levels (token, sequence, geometric) - always checks individual token ratios -- Default: `None` (veto disabled by default) -- Recommended: `1e-4` to `1e-6` when enabled (catches extreme outliers like 10,000x off) -- Set to `None` to disable veto mechanism -- **Important**: Applied **independently** of `rollout_is_mode` (works in both truncate and mask modes) -- Veto applies rejection to `response_mask`, NOT by modifying IS weights -- **IS weights unchanged by veto**: Already processed by mode (truncate: clamped, mask: safety-bounded) - -### Summary: How IS Weights are Processed - -The final IS weights go through multiple stages of processing: - -**Stage 1: Safety Bound (All Modes)** -- Token level: `exp(clamp(log_ratio, -20, 20))` per token → bounds each token to [2e-9, 5e8] -- Sequence level: `exp(clamp(sum(log_ratio), -20, 20))` → bounds product to [2e-9, 5e8], broadcast to all tokens -- Geometric level: `exp(clamp(mean(log_ratio), -20, 20))` → bounds geometric mean to [2e-9, 5e8], broadcast to all tokens - -**Stage 2: Threshold Processing (Mode-Dependent)** -- Truncate mode: `.clamp(max=upper_threshold)` → upper clamps weights to threshold -- Mask mode: No modification → weights remain as safety-bounded ratios - -**Stage 3: Padding (All Modes)** -- `weights * response_mask` → zeros out padding positions - -**Rejection Mechanisms (Modify response_mask, NOT weights)** -- Veto: Checks **unclamped per-token ratios** (before safety bound), rejects sequences via mask -- Outlier (mask mode only): Checks safety-bounded weights against [lower, upper], rejects via mask - -## Operation Modes - -The system has **two independent control flags** 
that combine to create different operation modes: - -1. **`rollout_is_threshold`**: Main on/off switch (None = disabled, float = enabled) -2. **`rollout_is`**: Apply IS weights to loss (True/False) - -### Mode Combinations - -| `rollout_is_threshold` | `rollout_is` | `rollout_is_mode` | Behavior | -|------------------------|--------------|-------------------|----------| -| `None` | any | any | **Disabled**: No computation, no metrics, no rejection | -| `2.0` | `False` | `truncate` | **Metrics only**: Compute weights & metrics, NO weight correction, NO rejection for outliers | -| `2.0` | `False` | `mask` | **Rejection only**: Compute weights & metrics, NO weight correction, YES rejection sampling | -| `2.0` | `True` | `truncate` | **Truncate mode**: Weight correction enabled, weights upper-clamped, NO rejection for outliers | -| `2.0` | `True` | `mask` | **Mask mode (full)**: Weight correction enabled, rejection sampling enabled | - -### Key Insights - -**Rejection sampling is ALWAYS applied when:** -- `rollout_is_threshold` is set (not None) -- AND `rollout_is_mode = "mask"` -- **Regardless of the `rollout_is` flag** - -This means: -- ✅ You can use **rejection sampling alone** without IS weight correction (`rollout_is=False, rollout_is_mode="mask"`) -- ✅ You can use **IS weights alone** without outlier rejection (`rollout_is=True, rollout_is_mode="truncate"`) -- ✅ You can use **both together** (`rollout_is=True, rollout_is_mode="mask"`) -- ✅ You can **monitor metrics only** without any correction or outlier rejection (`rollout_is=False, rollout_is_mode="truncate"`) - -**Veto rejection** (if enabled via `rollout_is_veto_threshold`) is applied **independently** in all modes where `rollout_is_threshold` is set. - -### Recommended Workflow - -1. 
**Start with metrics only** to understand the mismatch: - ```yaml - rollout_is_threshold: 2.0 - rollout_is: false - rollout_is_mode: truncate - ``` - Monitor `mismatch/rollout_is_mean`, `mismatch/mismatch_kl` to assess distribution mismatch. - -2. **Enable rejection sampling** if you see high outlier fractions: - ```yaml - rollout_is_threshold: 2.0 - rollout_is: false - rollout_is_mode: mask # Rejection now applies - ``` - This excludes outliers from training without modifying gradients. - -3. **Enable full IS correction** once comfortable with metrics: - ```yaml - rollout_is_threshold: 2.0 - rollout_is: true - rollout_is_mode: mask # Both rejection and weight correction - ``` - -## Usage - -### Basic Setup - -```yaml -algorithm: - rollout_is_threshold: 2.0 # Main control - rollout_is: true # Apply to loss (default: false) - rollout_is_level: token - rollout_is_mode: truncate - -actor_rollout_ref: - rollout: - calculate_log_probs: true # Required! -``` - -### Metrics - -All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appears as `mismatch/rollout_is_mean` in logs. 
- -#### **Core IS Weight Metrics** - -- **`rollout_is_mean`**: Mean importance sampling weight across all valid tokens - - **Ideal value**: Close to 1.0 (indicates minimal distribution mismatch) - - **Warning**: < 0.5 or > 2.0 suggests significant policy mismatch - -- **`rollout_is_std`**: Standard deviation of IS weights - - **Ideal value**: < 0.5 for stable training - - **Warning**: > 1.0 indicates high variance, may need tighter thresholds - -- **`rollout_is_min`**: Minimum IS weight observed - - Shows the most underweighted token/sequence - - For sequence/geometric: computed from unclamped log-space ratios (true minimum) - - For token: computed from safety-bounded weights - -- **`rollout_is_max`**: Maximum IS weight observed - - Shows the most overweighted token/sequence - - For sequence/geometric: computed from unclamped log-space ratios (true maximum before safety bound) - - For token: computed from safety-bounded weights (before threshold clamping) - - Compare with `rollout_is_threshold` to see truncation impact - -#### **Effective Sample Size** - -- **`rollout_is_eff_sample_size`**: Effective sample size after IS weighting - - **Formula**: `1 / mean(weights²)` where weights are normalized - - **Range**: 0.0 to 1.0 (as fraction of original batch) - - **Ideal value**: > 0.5 (retaining at least 50% effective samples) - - **Warning**: < 0.3 means high variance, losing too many effective samples - -#### **Veto Mechanism Metrics** - -- **`rollout_is_veto_fraction`**: Fraction of sequences rejected by veto mechanism - - **Important**: Sequences are rejected via `response_mask=0`, NOT by modifying IS weights - - **IS weights unchanged by veto**: Already processed by mode (truncate: clamped, mask: safety-bounded) - - Veto checks **unclamped per-token ratios** π_train(t)/π_rollout(t) (true ratios before safety bound) - - Detects catastrophic tokens (true ratio < veto_threshold, e.g., < 1e-4) - - **Ideal value**: < 0.05 (less than 5% vetoed) - - **Warning**: > 0.1 
suggests policies are too different or numerical issues - -- **`rollout_is_catastrophic_token_fraction`**: Fraction of tokens below veto threshold - - Identifies problematic tokens before sequence-level veto is applied - - Checks **unclamped per-token ratios** (true ratios, not safety-bounded) - - Each catastrophic token causes its entire sequence to be rejected - - **Warning**: > 0.01 indicates widespread distribution issues or numerical instability - -#### **Threshold Exceedance Metrics** - -- **`rollout_is_ratio_fraction_high`**: Fraction of weights exceeding upper threshold - - Shows how often truncation/masking occurs on high end - - For sequence/geometric: computed from unclamped log-space ratios (true exceedance) - - For token: computed from safety-bounded weights (before threshold clamping) - - **Ideal value**: < 0.1 (most weights within bounds) - -- **`rollout_is_ratio_fraction_low`**: Fraction of weights below lower threshold - - Shows how often masking occurs on low end (mask mode only) - - For sequence/geometric: computed from unclamped log-space ratios (true exceedance) - - For token: computed from safety-bounded weights - - **Ideal value**: < 0.1 - -#### **Sequence-Level Metrics** (for sequence/geometric modes) - -- **`rollout_is_seq_mean`**: Mean IS weight at sequence level - - Should match `rollout_is_mean` for sequence-level aggregation - -- **`rollout_is_seq_std`**: Standard deviation of sequence-level IS weights - -- **`rollout_is_seq_min`**: Minimum sequence-level IS weight - -- **`rollout_is_seq_max`**: Maximum sequence-level IS weight - -- **`rollout_is_seq_max_deviation`**: Maximum absolute deviation from 1.0 at sequence level - - **Ideal value**: < 1.0 - - Shows worst-case sequence mismatch - -- **`rollout_is_seq_fraction_high`**: Fraction of sequences exceeding upper threshold - -- **`rollout_is_seq_fraction_low`**: Fraction of sequences below lower threshold - -#### **Masking Metrics** (mask mode only) - -- 
**`rollout_is_masked_fraction`**: Fraction of tokens rejected via response_mask (mask mode only) - - **Important**: Tokens are rejected by setting `response_mask=0`, NOT by modifying IS weights - - **IS weights in mask mode**: Safety-bounded ratios preserved (no threshold clamping) - - **Ideal value**: < 0.1 (less than 10% rejected) - - **Warning**: > 0.3 means losing too much data - -- **`rollout_is_seq_masked_fraction`**: Fraction of sequences with at least one rejected token - - Shows sequence-level impact of rejection sampling - - For token-level: sequence rejected if ANY token is outside [lower, upper] - - For sequence-level: all tokens have same weight, so entire sequence rejected or accepted - -#### **Distribution Mismatch Metrics** (Training vs Rollout Policy) - -- **`mismatch_training_ppl`**: Perplexity of training policy (e.g., FSDP FP32) - - **Formula**: `exp(-mean(log_probs))` - - Lower is better (model is more confident) - -- **`mismatch_rollout_ppl`**: Perplexity of rollout policy (e.g., vLLM BF16) - - Should be close to `mismatch_training_ppl` if policies match well - -- **`mismatch_ppl_ratio`**: Ratio of training PPL to rollout PPL - - **Formula**: `exp(mean(log(training_ppl / rollout_ppl)))` - - **Ideal value**: Close to 1.0 - - **Meaning**: > 1.0 means training is less confident than rollout - -- **`mismatch_training_log_ppl`**: Log perplexity of training policy - - Useful for identifying trends (linear scale) - -- **`mismatch_rollout_log_ppl`**: Log perplexity of rollout policy - -- **`mismatch_log_ppl_diff`**: Mean difference in log perplexities - - **Formula**: `mean(log_ppl_rollout - log_ppl_training)` - - **Ideal value**: Close to 0.0 - - Sign indicates which policy is more confident - -- **`mismatch_log_ppl_abs_diff`**: Mean absolute log perplexity difference - - Magnitude of mismatch regardless of direction - -- **`mismatch_log_ppl_diff_max`**: Maximum log perplexity difference across sequences - - Identifies worst-case sequence - -- 
**`mismatch_log_ppl_diff_min`**: Minimum log perplexity difference across sequences - -- **`mismatch_kl`**: KL divergence KL(π_rollout || π_training) - - **Formula**: `mean(log_prob_rollout - log_prob_training)` - - **Ideal value**: Close to 0.0 (policies match) - - **Warning**: > 0.1 indicates significant mismatch - - **Note**: Can be negative (rollout is less confident) - -- **`mismatch_k3_kl`**: K3 KL estimator - - **Formula**: `mean(exp(log_ratio) - log_ratio - 1)` - - More stable for small KL values - - Always non-negative - -#### **Example: Accessing Metrics in Code** - -```python -# Metrics are returned from compute_rollout_importance_weights -from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights - -# NEW: Returns 3 values (weights, modified_response_mask, metrics) -weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights( - old_log_prob=training_log_probs, # from training policy - rollout_log_prob=rollout_log_probs, # from rollout policy - response_mask=response_mask, - rollout_is_level="token", - rollout_is_mode="mask", # Using mask mode for rejection sampling - rollout_is_threshold=2.0, - rollout_is_threshold_lower=0.5, - rollout_is_veto_threshold=1e-4, # Enable veto for catastrophic outliers -) - -# Extract IS weights (processed, zeroed at padding) -is_weights = weights_proto.batch["rollout_is_weights"] - -# IS weights processing (mask mode with token level): -# 1. Safety-bounded: exp(clamp(log_ratio, -20, 20)) per token -# 2. Mask mode: no threshold clamping (safety-bounded ratios preserved) -# 3. Zeroed at padding positions - -# modified_response_mask has rejection applied: -# 1. Outlier rejection: tokens outside [0.5, 2.0] masked to 0 (mask mode) -# 2. 
Veto rejection: sequences with catastrophic tokens (ratio < 1e-4) masked to 0 -# Note: Veto checks unclamped per-token ratios, not the safety-bounded weights - -# All metrics have 'mismatch/' prefix -print(f"Mean IS weight: {metrics['mismatch/rollout_is_mean']:.3f}") -print(f"Effective sample size: {metrics['mismatch/rollout_is_eff_sample_size']:.3f}") -print(f"Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.3f}") -print(f"Masked fraction: {metrics['mismatch/rollout_is_masked_fraction']:.3f}") -print(f"KL divergence: {metrics['mismatch/mismatch_kl']:.3f}") - -# Check IS weights for valid tokens (non-padding) -valid_weights = is_weights[response_mask.bool()] -print(f"\n✓ IS weights min (valid tokens): {valid_weights.min():.4f}") -print(f"✓ IS weights max (valid tokens): {valid_weights.max():.4f}") -print(f"✓ All valid IS weights > 0: {(valid_weights > 0).all()}") - -# Check rejection via response_mask -rejected_tokens = (response_mask == 1) & (modified_response_mask == 0) -print(f"\n✓ Rejected {rejected_tokens.sum()} tokens via response_mask") -print(f"✓ In mask mode: IS weights for rejected tokens are NON-ZERO (safety-bounded ratios)") -print(f"✓ In truncate mode: IS weights upper clamped to {rollout_is_threshold}") -print(f"✓ Both modes: IS weights safety-bounded to [exp(-20), exp(20)] ≈ [2e-9, 5e8]") - -# Check for warning conditions -if metrics['mismatch/rollout_is_mean'] < 0.5 or metrics['mismatch/rollout_is_mean'] > 2.0: - print("⚠️ Warning: Mean IS weight far from 1.0, significant policy mismatch detected") - -if metrics['mismatch/rollout_is_eff_sample_size'] < 0.3: - print("⚠️ Warning: Low effective sample size, high variance in IS weights") - -if metrics['mismatch/rollout_is_veto_fraction'] > 0.1: - print("⚠️ Warning: High veto fraction, policies may be too different") -``` - -#### **Example: Monitoring Metrics During Training** - -```python -# In your training loop -for epoch in range(num_epochs): - for batch_idx, batch in 
enumerate(dataloader): - # ... rollout phase ... - - # Compute IS weights and get metrics (NEW: 3 return values) - weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights( - old_log_prob=batch.old_log_prob, - rollout_log_prob=batch.rollout_log_prob, - response_mask=batch.response_mask, - rollout_is_level=config.rollout_is_level, - rollout_is_mode=config.rollout_is_mode, - rollout_is_threshold=config.rollout_is_threshold, - rollout_is_threshold_lower=config.rollout_is_threshold_lower, - rollout_is_veto_threshold=config.rollout_is_veto_threshold, - ) - - # Log to tensorboard/wandb - for metric_name, metric_value in metrics.items(): - logger.log_scalar(metric_name, metric_value, step=global_step) - - # IMPORTANT: Update batch response_mask with rejection applied - batch.response_mask = modified_response_mask - - # Use IS weights in training (processed based on mode) - # Truncate mode: upper clamped to min(weight, upper_threshold) - # Mask mode: safety-bounded ratios preserved (no threshold clamping) - # Both modes: safety bounded to [exp(-20), exp(20)], zeroed at padding - is_weights = weights_proto.batch["rollout_is_weights"] - # ... apply weights to policy gradient ... 
-``` - -#### **Example: Conditional Alerting Based on Metrics** - -```python -def check_rollout_is_health(metrics, config): - """Check if rollout IS metrics indicate healthy training.""" - warnings = [] - - # Check mean IS weight - mean_weight = metrics['mismatch/rollout_is_mean'] - if mean_weight < 0.5 or mean_weight > 2.0: - warnings.append(f"Mean IS weight {mean_weight:.3f} is far from 1.0") - - # Check effective sample size - ess = metrics['mismatch/rollout_is_eff_sample_size'] - if ess < 0.3: - warnings.append(f"Effective sample size {ess:.3f} is too low") - - # Check veto fraction - veto_frac = metrics['mismatch/rollout_is_veto_fraction'] - if veto_frac > 0.1: - warnings.append(f"Veto fraction {veto_frac:.3f} is too high") - - # Check variance - std = metrics['mismatch/rollout_is_std'] - if std > 1.0: - warnings.append(f"IS weight std {std:.3f} is too high") - - # Check KL divergence - kl = metrics['mismatch/mismatch_kl'] - if abs(kl) > 0.1: - warnings.append(f"KL divergence {kl:.3f} indicates significant mismatch") - - if warnings: - print("⚠️ Rollout IS Health Warnings:") - for warning in warnings: - print(f" - {warning}") - return False - else: - print("✅ Rollout IS metrics look healthy") - return True - -# Use in training (NEW: 3 return values) -_, _, metrics = compute_rollout_importance_weights(...) -is_healthy = check_rollout_is_health(metrics, config) - -if not is_healthy: - # Consider adjusting config or investigating issues - print("Consider:") - print(" - Tightening rollout_is_threshold") - print(" - Switching to geometric aggregation level") - print(" - Checking if rollout and training policies are too different") -``` - -### Running Examples - -Start with the basic token-level truncate configuration: -```bash -bash examples/rollout_importance_sampling/run_with_rollout_is.sh -``` - -Monitor metrics for 1-2 epochs before adjusting parameters. 
- -## Configuration Examples - -### Example 1: Full IS Correction -```yaml -algorithm: - rollout_is_threshold: 2.0 - rollout_is: true # Apply weights to loss - rollout_is_level: token - rollout_is_mode: truncate -``` - -### Example 2: Metrics Only (Monitoring Mode) -```yaml -algorithm: - rollout_is_threshold: 2.0 - rollout_is: false # Compute metrics, don't apply weights - rollout_is_level: token - rollout_is_mode: truncate -``` - -### Example 3: Geometric Mean with Mask -```yaml -algorithm: - rollout_is_threshold: 1.0002 - rollout_is: true - rollout_is_threshold_lower: 0.9998 - rollout_is_level: geometric - rollout_is_mode: mask -``` - -### Example 4: Asymmetric Thresholds -```yaml -algorithm: - rollout_is_threshold: 5.0 - rollout_is: true - rollout_is_threshold_lower: 0.8 - rollout_is_level: token - rollout_is_mode: mask -``` - -## Troubleshooting - -### Issue: High variance in IS weights -**Symptoms:** `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3 - -**Solutions:** -1. Switch from `sequence` to `geometric` level -2. Tighten thresholds -3. Verify rollout and training aren't too different - -### Issue: Too many sequences vetoed -**Symptoms:** `rollout_is_veto_fraction` > 0.1 - -**Solutions:** -1. Relax veto threshold: `rollout_is_veto_threshold: 1e-3` -2. Check for numerical issues in log prob computation -3. Verify policies aren't completely different - -### Issue: Mean IS weight far from 1.0 -**Symptoms:** `rollout_is_mean` < 0.5 or > 2.0 - -**Solutions:** -1. Verify `calculate_log_probs=True` is set -2. Check rollout_log_probs are correctly passed -3. 
Check for systematic bias - -### Debugging: Visualizing Metrics - -**Example: Plot IS weight distribution** - -```python -import matplotlib.pyplot as plt -import numpy as np - -def plot_is_metrics(metrics_history): - """Plot rollout IS metrics over training steps.""" - fig, axes = plt.subplots(2, 3, figsize=(15, 10)) - - # Plot 1: Mean IS weight over time - axes[0, 0].plot(metrics_history['mismatch/rollout_is_mean']) - axes[0, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal') - axes[0, 0].set_title('Mean IS Weight') - axes[0, 0].set_xlabel('Step') - axes[0, 0].legend() - - # Plot 2: Effective sample size - axes[0, 1].plot(metrics_history['mismatch/rollout_is_eff_sample_size']) - axes[0, 1].axhline(y=0.5, color='g', linestyle='--', label='Good') - axes[0, 1].axhline(y=0.3, color='r', linestyle='--', label='Warning') - axes[0, 1].set_title('Effective Sample Size') - axes[0, 1].set_xlabel('Step') - axes[0, 1].legend() - - # Plot 3: Veto fraction - axes[0, 2].plot(metrics_history['mismatch/rollout_is_veto_fraction']) - axes[0, 2].axhline(y=0.1, color='r', linestyle='--', label='Warning') - axes[0, 2].set_title('Veto Fraction') - axes[0, 2].set_xlabel('Step') - axes[0, 2].legend() - - # Plot 4: KL divergence over time - axes[1, 0].plot(metrics_history['mismatch/mismatch_kl'], label='KL') - axes[1, 0].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL') - axes[1, 0].axhline(y=0, color='g', linestyle='--', alpha=0.3) - axes[1, 0].set_title('KL Divergence') - axes[1, 0].set_xlabel('Step') - axes[1, 0].legend() - - # Plot 5: PPL ratio over time - axes[1, 1].plot(metrics_history['mismatch/mismatch_ppl_ratio']) - axes[1, 1].axhline(y=1.0, color='r', linestyle='--', label='Ideal') - axes[1, 1].set_title('PPL Ratio (Training/Rollout)') - axes[1, 1].set_xlabel('Step') - axes[1, 1].legend() - - # Hide unused subplot - axes[1, 2].axis('off') - - plt.tight_layout() - plt.savefig('rollout_is_metrics.png', dpi=150) - print("Saved plot to rollout_is_metrics.png") 
-``` - -**Example: Metric collection during training** - -```python -# Collect metrics over time -metrics_history = { - 'mismatch/rollout_is_mean': [], - 'mismatch/rollout_is_eff_sample_size': [], - 'mismatch/rollout_is_veto_fraction': [], - 'mismatch/mismatch_kl': [], - 'mismatch/mismatch_k3_kl': [], - 'mismatch/mismatch_ppl_ratio': [], -} - -# In training loop -for step in range(num_steps): - # ... compute IS weights ... (NEW: 3 return values) - _, _, metrics = compute_rollout_importance_weights(...) - - # Store metrics - for key in metrics_history.keys(): - if key in metrics: - metrics_history[key].append(metrics[key]) - - # Plot every 100 steps - if step % 100 == 0: - plot_is_metrics(metrics_history) -``` - -## Performance Impact - -- **Memory overhead**: ~1% of model memory -- **Computational overhead**: 1-3% depending on level -- **Training stability**: Significantly improved when mismatch exists - - -## Testing - -Run the test suite to verify everything works: - -```bash -# Basic unit tests -python test_rollout_is.py - -# Integration tests (if pytest is available) -pytest tests/trainer/ppo/test_rollout_is_integration.py -v -``` - -Expected output: All tests pass ✓ - -## Additional Resources - -- **Implementation**: `verl/trainer/ppo/mismatch_helper.py` -- **Examples**: `examples/rollout_importance_sampling/` -- **DAPO Example**: `recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh` - -## Summary - -Rollout Importance Sampling provides: -- ✅ Robust handling of distribution mismatch -- ✅ Numerical stability -- ✅ Comprehensive metrics for monitoring -- ✅ Flexibility for different scenarios -- ✅ Memory-efficient computation - -## References - -- [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda) -- [Your Efficient RL Framework Secretly Brings You Off-Policy RL 
Training](https://fengyao.notion.site/off-policy-rl) \ No newline at end of file diff --git a/docs/examples/config.rst b/docs/examples/config.rst index 71c96f3080e..385acae422c 100644 --- a/docs/examples/config.rst +++ b/docs/examples/config.rst @@ -127,13 +127,14 @@ Actor/Rollout/Reference Policy clip_ratio: 0.2 entropy_coeff: 0.0 use_kl_loss: False # True for GRPO - # Rollout Importance Sampling (corrects distribution mismatch between rollout and training) - rollout_is: False # Enable IS correction - rollout_is_threshold: null # Upper threshold for IS weights (null to disable) - rollout_is_threshold_lower: null # Lower threshold (null = auto 1/upper) - rollout_is_level: token # Aggregation: token/sequence/geometric - rollout_is_mode: truncate # Bounding: truncate/mask - rollout_is_veto_threshold: null # Catastrophic outlier threshold (null to disable) + # Rollout Correction (corrects distribution mismatch between rollout and training) + rollout_correction: + rollout_is: token # IS weights: token/sequence/null + rollout_is_threshold: 2.0 # Upper threshold for IS weights + rollout_rs: null # Rejection sampling: token/sequence/geometric/null + rollout_rs_threshold: null # RS upper threshold + rollout_rs_threshold_lower: null # RS lower threshold + rollout_token_veto_threshold: null # Per-token veto (null to disable) use_torch_compile: True # False to disable torch compile kl_loss_coef: 0.001 # for grpo kl_loss_type: low_var_kl # for grpo @@ -515,13 +516,14 @@ Algorithm kl_coef: 0.005 horizon: 10000 target_kl: 0.1 - # Rollout Importance Sampling - rollout_is: False - rollout_is_threshold: null - rollout_is_threshold_lower: null - rollout_is_level: token - rollout_is_mode: truncate - rollout_is_veto_threshold: null # Disabled by default + # Rollout Correction + rollout_correction: + rollout_is: null # IS weights: token/sequence/null + rollout_is_threshold: 2.0 # Upper threshold for IS weights + rollout_rs: null # Rejection sampling: token/sequence/geometric/null + 
rollout_rs_threshold: null # RS upper threshold + rollout_rs_threshold_lower: null # RS lower threshold + rollout_token_veto_threshold: null # Per-token veto (null to disable) - ``gamma``: discount factor - ``lam``: Trade-off between bias and variance in the GAE estimator @@ -536,13 +538,17 @@ Algorithm - ``type``: 'fixed' for FixedKLController and 'adaptive' for AdaptiveKLController. - ``horizon`` and ``target_kl``: See source code of AdaptiveKLController for details. -- ``rollout_is``: Whether to enable rollout importance sampling correction. Default is False. -- ``rollout_is_threshold``: Upper threshold for IS weights. Set to ``null`` to disable IS completely. -- ``rollout_is_threshold_lower``: Lower threshold for IS weights. If ``null``, defaults to reciprocal of upper (1/upper). -- ``rollout_is_level``: Aggregation level: ``token`` (biased), ``sequence`` (unbiased), or ``geometric`` (experimental). -- ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``mask`` (zero outside bounds). -- ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is null (disabled). - Note: Rollout IS requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``. +- ``rollout_correction``: Rollout Correction configuration (nested dict). Set to ``null`` to disable. + When enabled, contains: + + - ``rollout_is``: IS weights aggregation level: ``token``, ``sequence``, or ``null`` to disable IS weights. + - ``rollout_is_threshold``: Upper threshold for IS weights (e.g., 2.0). + - ``rollout_rs``: Rejection sampling mode: ``token``, ``sequence``, ``geometric``, or ``null`` to disable RS. + - ``rollout_rs_threshold``: RS upper threshold. + - ``rollout_rs_threshold_lower``: RS lower threshold (null = auto-reciprocal). + - ``rollout_token_veto_threshold``: Per-token veto threshold for catastrophic outliers (null = disabled). + + Note: Rollout Correction requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``. 
Trainer ~~~~~~~ diff --git a/docs/index.rst b/docs/index.rst index a984179dd0f..18b46044291 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -121,7 +121,8 @@ verl is fast with: examples/sandbox_fusion_example advance/rollout_trace.rst advance/rollout_skip.rst - advance/rollout_is.md + advance/rollout_corr.md + advance/rollout_corr_math.md advance/one_step_off advance/agent_loop advance/reward_loop diff --git a/examples/rollout_correction/README.md b/examples/rollout_correction/README.md new file mode 100644 index 00000000000..bf37fba8ebb --- /dev/null +++ b/examples/rollout_correction/README.md @@ -0,0 +1,253 @@ +# Rollout Correction Examples + +This directory contains examples and documentation for using Rollout Correction to address off-policy issues in RL training. + +**References:** +- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda +- Off-policy RL: https://fengyao.notion.site/off-policy-rl + +## Overview + +Rollout Correction addresses off-policy issues including: +1. **Policy mismatch**: Rollout policy (e.g., vLLM BFloat16) vs Training policy (e.g., FSDP FP32) +2. **Model staleness**: Training on trajectories from older policy checkpoints +3. **Distribution shifts**: Any distribution gap between data collection and training + +Rollout Correction uses importance sampling (IS) weights and rejection sampling (RS) to correct for these distribution shifts. 
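To make the IS half of this concrete, here is a minimal sketch (illustrative only — `token_is_weights`, its signature, and the default threshold are assumptions for exposition, not the verl API) of turning rollout and training log-probs into truncated per-token importance weights:

```python
import math

def token_is_weights(rollout_logps, train_logps, threshold=2.0):
    """Illustrative sketch (hypothetical helper, not the verl API):
    per-token importance weight pi_train / pi_rollout, computed in
    log space and truncated at `threshold` to cap the upper tail."""
    weights = []
    for lp_roll, lp_train in zip(rollout_logps, train_logps):
        ratio = math.exp(lp_train - lp_roll)   # per-token probability ratio
        weights.append(min(ratio, threshold))  # truncate: bound the upper side only
    return weights

# Matching log-probs give weight 1.0; the mismatched middle token
# (exp(1) ~ 2.72) is capped at the threshold of 2.0.
print(token_is_weights([-1.0, -2.0, -0.5], [-1.0, -1.0, -0.5]))  # [1.0, 2.0, 1.0]
```

Rejection sampling would instead drop (zero out) tokens or sequences whose ratio falls outside the configured bounds, and the veto discards a whole sequence if any single token ratio is catastrophically small.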
+ +## Quick Start + +### Basic Configuration + +```yaml +algorithm: + rollout_correction: + rollout_is: token # IS weights: token/sequence/null + rollout_is_threshold: 2.0 # Upper threshold + rollout_rs: null # Rejection sampling: token/sequence/geometric/null + rollout_rs_threshold: null + rollout_rs_threshold_lower: null + rollout_token_veto_threshold: null # Veto threshold + +# IMPORTANT: Must enable log prob calculation +actor_rollout_ref: + rollout: + calculate_log_probs: true +``` + +### Running the Example + +```bash +# Basic example with token-level truncate +bash examples/rollout_correction/run_with_rollout_corr.sh +``` + +## Configuration Options + +### IS Weights Aggregation Levels (`rollout_is`) + +| Level | Properties | Threshold Range | +|-------|-----------|-----------------| +| **token** | Per-token weighting | 1.5 - 5.0 | +| **sequence** | Per-sequence weighting | 2.0 - 10.0 | +| **null** | Disabled | N/A | + +### Rejection Sampling Modes (`rollout_rs`) + +| Mode | Behavior | Threshold Range | +|------|----------|-----------------| +| **token** | Per-token rejection | 1.5 - 5.0 | +| **sequence** | Per-sequence rejection | 2.0 - 10.0 | +| **geometric** | Geometric mean rejection | 1.0002 - 1.001 | +| **null** | Disabled | N/A | + +### Key Parameters + +- `rollout_is`: IS weights aggregation level (`token`, `sequence`, or `null`) +- `rollout_is_threshold`: Upper threshold for IS weights +- `rollout_rs`: Rejection sampling mode (`token`, `sequence`, `geometric`, or `null`) +- `rollout_rs_threshold`: RS upper threshold +- `rollout_rs_threshold_lower`: RS lower threshold (null = auto 1/upper) +- `rollout_token_veto_threshold`: Per-token catastrophic outlier veto threshold (null = disabled) + +## Configuration Examples + +### Example 1: IS Weights Only (Token-level) + +```yaml +algorithm: + rollout_correction: + rollout_is: token + rollout_is_threshold: 2.0 + rollout_rs: null # No rejection sampling +``` + +### Example 2: Rejection Sampling Only + 
+```yaml +algorithm: + rollout_correction: + rollout_is: null # No IS weights + rollout_rs: sequence + rollout_rs_threshold: 3.0 +``` + +### Example 3: Combined IS + RS + +```yaml +algorithm: + rollout_correction: + rollout_is: token + rollout_is_threshold: 2.0 + rollout_rs: token + rollout_rs_threshold: 2.0 +``` + +### Example 4: Geometric Mean RS with Veto + +```yaml +algorithm: + rollout_correction: + rollout_is: null + rollout_rs: geometric + rollout_rs_threshold: 1.0002 + rollout_rs_threshold_lower: 0.9998 + rollout_token_veto_threshold: 1e-4 +``` + +### Example 5: Full Configuration + +```yaml +algorithm: + rollout_correction: + rollout_is: sequence + rollout_is_threshold: 2.0 + rollout_rs: token + rollout_rs_threshold: 2.0 + rollout_rs_threshold_lower: 0.5 + rollout_token_veto_threshold: 1e-4 +``` + +## Monitoring Metrics + +Key metrics to watch (all prefixed with `rollout_corr/` in logs): + +### Health Indicators +- `rollout_is_mean`: Mean IS weight across sequences +- `rollout_is_eff_sample_size`: Effective sample size after weighting +- `rollout_is_veto_fraction`: Fraction of sequences vetoed + +### Distribution Metrics +- `rollout_is_max`, `rollout_is_min`: Weight extremes +- `rollout_is_std`: Standard deviation + +### Diagnostic Metrics +- `rollout_is_ratio_fraction_high`: Fraction exceeding upper threshold +- `rollout_is_ratio_fraction_low`: Fraction below lower threshold +- `rollout_is_catastrophic_token_fraction`: Catastrophic tokens detected + +### Mismatch Metrics (Training vs Rollout Policy) + +These metrics help diagnose the distribution mismatch between rollout and training policies: + +**Perplexity Metrics:** +- `training_ppl`: Perplexity of training policy +- `rollout_ppl`: Perplexity of rollout policy +- `ppl_ratio`: Ratio of training PPL to rollout PPL +- `log_ppl_diff`: Log perplexity difference + +**KL Divergence Metrics:** +- `kl`: KL divergence KL(π_rollout || π_training) +- `k3_kl`: K3 KL estimator + +## Troubleshooting + +### Issue: 
High Variance in IS Weights + +**Symptoms**: `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3 + +**Solutions**: +1. Switch from `sequence`-level to `token`-level IS weights, or use `geometric` rejection sampling instead +2. Tighten thresholds +3. Check if rollout and training are too different + +### Issue: Too Many Sequences Vetoed + +**Symptoms**: `rollout_is_veto_fraction` > 0.1 + +**Solutions**: +1. Relax the veto threshold in config (a smaller value vetoes fewer tokens, since the veto fires when a token ratio falls below it): + ```yaml + algorithm: + rollout_correction: + rollout_token_veto_threshold: 1e-5 + ``` +2. Check for numerical issues in log prob computation +3. Verify rollout and training policies aren't completely different + +### Issue: Mean IS Weight Far from 1.0 + +**Symptoms**: `rollout_is_mean` < 0.5 or > 2.0 + +**Solutions**: +1. Check that `calculate_log_probs=True` is set +2. Verify rollout_log_probs are correctly passed +3. Check for systematic bias in rollout vs training + +### Issue: Too Much Data Discarded (Rejection Sampling) + +**Symptoms**: a large fraction of tokens or sequences is rejected by RS + +**Solutions**: +1. Widen `rollout_rs_threshold` / `rollout_rs_threshold_lower` +2. Disable RS (`rollout_rs: null`) and rely on IS weights +3.
Use `geometric` RS for better stability + +## Performance Considerations + +### Memory Usage +- Rollout Correction adds minimal memory overhead (~1% of model memory) +- Log-space computation prevents numerical overflow + +### Computational Cost +- Token-level: ~1-2% overhead +- Sequence-level: ~2-3% overhead +- Geometric: ~2-3% overhead + +## Advanced Topics + +### Dual Thresholds + +For rejection sampling, specify both upper and lower bounds explicitly: + +```yaml +rollout_rs_threshold: 2.0 # Upper +rollout_rs_threshold_lower: 0.8 # Lower (set explicitly; need not equal 1/upper) +``` + +Or use auto-reciprocal: + +```yaml +rollout_rs_threshold: 2.0 # Upper = 2.0, Lower = 0.5 (auto) +rollout_rs_threshold_lower: null +``` + +### Veto Mechanism + +The veto mechanism zeros out entire sequences containing catastrophic outliers: + +- If any token has ratio < `rollout_token_veto_threshold`, the entire sequence is rejected +- This prevents extreme outliers from dominating training +- Default: `null` (disabled by default) +- Set to `1e-4` to enable (catches ratios 10,000x off) + +## Examples + +See the script in this directory: +- `run_with_rollout_corr.sh`: Basic example with token-level IS weights + +## References + +- Implementation: `verl/trainer/ppo/rollout_corr_helper.py` +- Core algorithm: `verl/trainer/ppo/core_algos.py` +- Paper: "Your Efficient RL Framework Secretly Brings You Off-Policy RL Training" diff --git a/examples/rollout_importance_sampling/run_with_rollout_is.sh b/examples/rollout_correction/run_with_rollout_corr.sh similarity index 58% rename from examples/rollout_importance_sampling/run_with_rollout_is.sh rename to examples/rollout_correction/run_with_rollout_corr.sh index 42c4a2a5981..04e0ab08943 100755 --- a/examples/rollout_importance_sampling/run_with_rollout_is.sh +++ b/examples/rollout_correction/run_with_rollout_corr.sh @@ -1,31 +1,24 @@ #!/usr/bin/env bash -# Example: Basic PPO training with Rollout Importance Sampling +# Example: Basic PPO training with Rollout Correction # This
demonstrates the standard setup for correcting distribution mismatch set -xeuo pipefail # ============================================================================== -# Rollout Importance Sampling Configuration +# Rollout Correction Configuration # ============================================================================== -# Main control: Upper threshold for IS weights (null = disabled, float = enabled) -rollout_is_threshold=2.0 +# Importance Sampling (IS) weights configuration +rollout_is="token" # "token", "sequence", or null to disable +rollout_is_threshold=2.0 # Upper threshold for IS weights -# Whether to apply IS weights to policy loss -# true = apply weights to loss, false = compute metrics only -rollout_is=true +# Rejection Sampling (RS) configuration +rollout_rs="null" # "token", "sequence", "geometric", or null to disable +rollout_rs_threshold="null" # RS upper threshold +rollout_rs_threshold_lower="null" # RS lower threshold -# Lower threshold (null = auto-reciprocal, i.e., 1/upper = 0.5) -rollout_is_threshold_lower=null - -# Aggregation level: token | sequence | geometric (experimental) -rollout_is_level=token - -# Bounding mode: truncate (cap upper) | mask (zero outside bounds) -rollout_is_mode=truncate - -# Catastrophic outlier veto threshold (set to null to disable, or e.g., 1e-4 to enable) -rollout_is_veto_threshold=null +# Veto mechanism (optional, independent of IS/RS) +rollout_token_veto_threshold="null" # Per-token veto threshold (null to disable) # ============================================================================== # Model and Data Configuration @@ -68,12 +61,12 @@ python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=${adv_estimator} \ algorithm.gamma=${gamma} \ algorithm.lam=${lam} \ - algorithm.rollout_is=${rollout_is} \ - algorithm.rollout_is_threshold=${rollout_is_threshold} \ - algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \ - algorithm.rollout_is_level=${rollout_is_level} \ - 
algorithm.rollout_is_mode=${rollout_is_mode} \ - algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold} \ + algorithm.rollout_correction.rollout_is=${rollout_is} \ + algorithm.rollout_correction.rollout_is_threshold=${rollout_is_threshold} \ + algorithm.rollout_correction.rollout_rs=${rollout_rs} \ + algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \ + algorithm.rollout_correction.rollout_rs_threshold_lower=${rollout_rs_threshold_lower} \ + algorithm.rollout_correction.rollout_token_veto_threshold=${rollout_token_veto_threshold} \ actor_rollout_ref.model.path="${MODEL_PATH}" \ actor_rollout_ref.actor.optim.lr=${learning_rate} \ actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \ @@ -81,19 +74,19 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.rollout.calculate_log_probs=True \ actor_rollout_ref.rollout.name=vllm \ trainer.logger='["console","wandb"]' \ - trainer.project_name="rollout_is_example" \ + trainer.project_name="rollout_corr_example" \ trainer.experiment_name="basic_token_truncate" \ trainer.total_epochs=10 echo "Training completed!" 
echo "" -echo "Rollout IS Configuration:" -echo " - Threshold: ${rollout_is_threshold}" -echo " - Apply to loss: ${rollout_is}" -echo " - Level: ${rollout_is_level}" -echo " - Mode: ${rollout_is_mode}" +echo "Rollout Correction Configuration:" +echo " - IS weights: ${rollout_is}" +echo " - IS threshold: ${rollout_is_threshold}" +echo " - RS mode: ${rollout_rs}" +echo " - Veto threshold: ${rollout_token_veto_threshold}" echo "" echo "Monitor these key metrics in wandb:" -echo " - mismatch/rollout_is_mean (should be ~1.0)" -echo " - mismatch/rollout_is_eff_sample_size (should be >0.5)" -echo " - mismatch/rollout_is_veto_fraction (should be <0.1)" +echo " - rollout_corr/rollout_is_mean (should be ~1.0)" +echo " - rollout_corr/rollout_is_eff_sample_size (should be >0.5)" +echo " - rollout_corr/rollout_is_veto_fraction (should be <0.1)" diff --git a/examples/rollout_importance_sampling/README.md b/examples/rollout_importance_sampling/README.md deleted file mode 100644 index 7baf55ebf2e..00000000000 --- a/examples/rollout_importance_sampling/README.md +++ /dev/null @@ -1,241 +0,0 @@ -# Rollout Importance Sampling (IS) Examples - -This directory contains examples and documentation for using Rollout Importance Sampling to correct distribution mismatch between rollout and training policies. - -**References:** -- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda -- Off-policy RL: https://fengyao.notion.site/off-policy-rl - -## Overview - -Rollout Importance Sampling corrects for distribution mismatch when: -1. **Rollout generation** uses one policy (e.g., vLLM with BFloat16) -2. **Training** uses another policy (e.g., FSDP with FP32) -3. 
This mismatch leads to biased gradient estimates - -## Quick Start - -### Basic Configuration - -```yaml -algorithm: - # Main control: set threshold to enable (null = disabled) - rollout_is_threshold: 2.0 - # Whether to apply weights to policy loss (true) or just compute metrics (false) - rollout_is: true - rollout_is_level: token - rollout_is_mode: truncate - -# IMPORTANT: Must enable log prob calculation -actor_rollout_ref: - rollout: - calculate_log_probs: true -``` - -### Running the Example - -```bash -# Basic example with token-level truncate -bash examples/rollout_importance_sampling/run_with_rollout_is.sh -``` - -## Configuration Options - -### Aggregation Levels (`rollout_is_level`) - -| Level | Properties | Threshold Range | -|-------|-----------|-----------------| -| **token** | Per-token | 1.5 - 5.0 | -| **sequence** | Per-sequence | 2.0 - 10.0 | -| **geometric** | Geometric mean | 1.0002 - 1.001 | - -### Bounding Modes (`rollout_is_mode`) - -| Mode | Behavior | -|------|----------| -| **truncate** | Cap weights at upper threshold only | -| **clip** | Zero out weights outside [lower, upper] | - -### Key Parameters - -- `rollout_is_threshold`: Upper threshold for IS weights (null = disabled, float = enabled). **Main on/off switch.** -- `rollout_is`: Whether to apply weights to loss (true) or just compute metrics (false). Default: false. 
-- `rollout_is_threshold_lower`: Lower threshold (null = auto 1/upper) -- `rollout_is_veto_threshold`: Catastrophic outlier threshold (default: null, disabled) - -## Configuration Examples - -### Example 1: Full IS Correction (Apply Weights) - -```yaml -algorithm: - rollout_is_threshold: 2.0 - rollout_is: true # Apply to loss - rollout_is_level: token - rollout_is_mode: truncate - rollout_is_veto_threshold: null # Disabled by default -``` - -### Example 2: Metrics Only (No Weight Application) - -```yaml -algorithm: - rollout_is_threshold: 2.0 - rollout_is: false # Compute metrics only, don't apply to loss - rollout_is_level: token - rollout_is_mode: truncate -``` - -### Example 3: Geometric Mean with Mask - -```yaml -algorithm: - rollout_is_threshold: 1.0002 - rollout_is: true - rollout_is_threshold_lower: 0.9998 - rollout_is_level: geometric - rollout_is_mode: mask - rollout_is_veto_threshold: 1e-4 # Enable veto for this example -``` - -### Example 4: Sequence-level with Truncate - -```yaml -algorithm: - rollout_is_threshold: 5.0 - rollout_is: true - rollout_is_threshold_lower: null # Auto-reciprocal: 0.2 - rollout_is_level: sequence - rollout_is_mode: truncate - rollout_is_veto_threshold: 1e-4 # Enable veto for this example -``` - -### Example 5: Asymmetric Thresholds - -```yaml -algorithm: - rollout_is_threshold: 5.0 - rollout_is: true - rollout_is_threshold_lower: 0.8 - rollout_is_level: token - rollout_is_mode: mask -``` - -## Monitoring Metrics - -Key metrics to watch (all prefixed with `mismatch/` in logs): - -### Health Indicators -- `rollout_is_mean`: Mean IS weight across sequences -- `rollout_is_eff_sample_size`: Effective sample size after weighting -- `rollout_is_veto_fraction`: Fraction of sequences vetoed - -### Distribution Metrics -- `rollout_is_max`, `rollout_is_min`: Weight extremes -- `rollout_is_std`: Standard deviation - -### Diagnostic Metrics -- `rollout_is_ratio_fraction_high`: Fraction exceeding upper threshold -- 
`rollout_is_ratio_fraction_low`: Fraction below lower threshold -- `rollout_is_catastrophic_token_fraction`: Catastrophic tokens detected - -### Mismatch Metrics (Training vs Rollout Policy) - -These metrics help diagnose the distribution mismatch between rollout and training policies: - -**Perplexity Metrics:** -- `mismatch_training_ppl`: Perplexity of training policy -- `mismatch_rollout_ppl`: Perplexity of rollout policy -- `mismatch_ppl_ratio`: Ratio of training PPL to rollout PPL -- `mismatch_log_ppl_diff`: Log perplexity difference - -**KL Divergence Metrics:** -- `mismatch_kl`: KL divergence KL(π_rollout || π_training) -- `mismatch_k3_kl`: K3 KL estimator - -## Troubleshooting - -### Issue: High Variance in IS Weights - -**Symptoms**: `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3 - -**Solutions**: -1. Switch from `sequence` to `geometric` level -2. Tighten thresholds -3. Check if rollout and training are too different - -### Issue: Too Many Sequences Vetoed - -**Symptoms**: `rollout_is_veto_fraction` > 0.1 - -**Solutions**: -1. Relax veto threshold: `rollout_is_veto_threshold: 1e-3` -2. Check for numerical issues in log prob computation -3. Verify rollout and training policies aren't completely different - -### Issue: Mean IS Weight Far from 1.0 - -**Symptoms**: `rollout_is_mean` < 0.5 or > 2.0 - -**Solutions**: -1. Check that `calculate_log_probs=True` is set -2. Verify rollout_log_probs are correctly passed -3. Check for systematic bias in rollout vs training - -### Issue: Too Much Data Discarded (Mask Mode) - -**Symptoms**: `rollout_is_masked_fraction` > 0.5 - -**Solutions**: -1. Widen thresholds -2. Switch to `truncate` mode -3. 
Use `geometric` level for better stability - -## Performance Considerations - -### Memory Usage -- Rollout IS adds minimal memory overhead (~1% of model memory) -- Log-space computation prevents numerical overflow - -### Computational Cost -- Token-level: ~1-2% overhead -- Sequence-level: ~2-3% overhead -- Geometric: ~2-3% overhead - -## Advanced Topics - -### Dual Thresholds - -Specify both upper and lower explicitly: - -```yaml -rollout_is_threshold: 2.0 # Upper -rollout_is_threshold_lower: 0.5 # Lower (not 1/2.0 = 0.5) -``` - -Or use auto-reciprocal: - -```yaml -rollout_is_threshold: 2.0 # Upper = 2.0, Lower = 0.5 (auto) -rollout_is_threshold_lower: null -``` - -### Veto Mechanism - -The veto mechanism zeros out entire sequences containing catastrophic outliers: - -- If any token has ratio < `rollout_is_veto_threshold`, the entire sequence is rejected -- This prevents extreme outliers from dominating training -- Default: `null` (disabled by default) -- Set to `1e-4` to enable (catches ratios 10,000x off) - -## Examples - -See the script in this directory: -- `run_with_rollout_is.sh`: Basic example with token-level truncate mode - -## References - -- Implementation: `verl/trainer/ppo/mismatch_helper.py` -- Core algorithm: `verl/trainer/ppo/core_algos.py` -- Paper: "Your Efficient RL Framework Secretly Brings You Off-Policy RL Training" diff --git a/recipe/dapo/dapo_ray_trainer.py b/recipe/dapo/dapo_ray_trainer.py index b1367e2b656..a6a06192374 100644 --- a/recipe/dapo/dapo_ray_trainer.py +++ b/recipe/dapo/dapo_ray_trainer.py @@ -316,10 +316,14 @@ def fit(self): values = self.critic_wg.compute_values(batch) batch = batch.union(values) - # Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer) - batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch) - # IS and mismatch metrics already have mismatch/ prefix - metrics.update(is_metrics) + # Compute rollout correction weights and off-policy metrics (inherited from 
RayPPOTrainer) + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch + + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) with marked_timer("adv", timing_raw, "brown"): # compute advantages, executed on the driver process diff --git a/recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh b/recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh similarity index 84% rename from recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh rename to recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh index bec96da90ed..53d8467a062 100644 --- a/recipe/dapo/run_dapo_qwen2.5_32b_rollout_is.sh +++ b/recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh @@ -1,13 +1,13 @@ #!/usr/bin/env bash set -xeuo pipefail -# Rollout Importance Sampling Example +# Rollout Correction Example # References: # - When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda # - Off-policy RL: https://fengyao.notion.site/off-policy-rl project_name='DAPO' -exp_name='DAPO-Qwen2.5-32B-RolloutIS' # Rollout Importance Sampling +exp_name='DAPO-Qwen2.5-32B-RolloutIS' # Rollout Correction adv_estimator=grpo @@ -16,13 +16,13 @@ kl_coef=0.0 use_kl_loss=False kl_loss_coef=0.0 -# Rollout Importance Sampling parameters -rollout_is=True +# Rollout Correction parameters +rollout_is=token rollout_is_threshold=2.0 -rollout_is_threshold_lower=null # No lower bound -rollout_is_level=token # token-level -rollout_is_mode=truncate # truncate mode -rollout_is_veto_threshold=null # No veto +rollout_rs=null +rollout_rs_threshold=null +rollout_rs_threshold_lower=null +rollout_token_veto_threshold=null clip_ratio_low=0.2 clip_ratio_high=0.28 @@ -70,16 +70,15 @@ offload=True gen_tp=4 -# Rollout 
Importance Sampling (corrects distribution mismatch between rollout and training) +# Rollout Correction (corrects distribution mismatch between rollout and training) # # Please note that server mode (agent loop) hasn't returned rollout_log_probs for now, -# so currently server mode is not supported for Rollout IS. +# so currently server mode is not supported for Rollout Correction. # -# Rollout IS parameters (configured at top of script): -# algorithm.rollout_is=True -# algorithm.rollout_is_threshold=2.0 # Upper threshold (can be tuned) -# algorithm.rollout_is_level=token # Aggregation level -# algorithm.rollout_is_mode=truncate # Bounding mode +# Rollout Correction parameters (configured at top of script): +# algorithm.rollout_correction.rollout_is=token +# algorithm.rollout_correction.rollout_is_threshold=2.0 +# algorithm.rollout_correction.rollout_rs=null # actor_rollout_ref.rollout.calculate_log_probs=True # Required! ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ @@ -124,12 +123,12 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \ actor_rollout_ref.actor.grad_clip=1.0 \ actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ - algorithm.rollout_is=${rollout_is} \ - algorithm.rollout_is_threshold=${rollout_is_threshold} \ - algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \ - algorithm.rollout_is_level=${rollout_is_level} \ - algorithm.rollout_is_mode=${rollout_is_mode} \ - algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold} \ + algorithm.rollout_correction.rollout_is=${rollout_is} \ + algorithm.rollout_correction.rollout_is_threshold=${rollout_is_threshold} \ + algorithm.rollout_correction.rollout_rs=${rollout_rs} \ + algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \ + algorithm.rollout_correction.rollout_rs_threshold_lower=${rollout_rs_threshold_lower} \ + 
algorithm.rollout_correction.rollout_token_veto_threshold=${rollout_token_veto_threshold} \ actor_rollout_ref.rollout.calculate_log_probs=True \ actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ diff --git a/recipe/fully_async_policy/ray_trainer.py b/recipe/fully_async_policy/ray_trainer.py index 9d216e05639..ba6c1448ae6 100644 --- a/recipe/fully_async_policy/ray_trainer.py +++ b/recipe/fully_async_policy/ray_trainer.py @@ -423,12 +423,16 @@ def compute_old_log_prob(batch): else: batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] - # Compute rollout importance sampling weights centrally (once per batch) - # This corrects for mismatch between rollout policy and training policy - # Also computes mismatch metrics (KL, PPL, etc.) - batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch) - # IS and mismatch metrics already have mismatch/ prefix - metrics.update(is_metrics) + # Compute rollout correction weights centrally (once per batch) + # This corrects for off-policy issues (policy mismatch, model staleness, etc.) + # Also computes off-policy diagnostic metrics (KL, PPL, etc.) 
+ from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch + + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) # compute advantages, executed on the driver process norm_adv_by_std_in_grpo = self.config.algorithm.get( diff --git a/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh index c72ea786c12..65079fdb2de 100644 --- a/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh +++ b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64_mis.sh @@ -76,13 +76,13 @@ trigger_parameter_sync_step=4 require_batches=4 partial_rollout=True -# Rollout Importance Sampling +# Rollout Correction +rollout_is=geometric rollout_is_threshold=1.001 -rollout_is=True -rollout_is_threshold_lower=0.99 -rollout_is_level=geometric -rollout_is_mode=mask -rollout_is_veto_threshold=1e-4 +rollout_rs=geometric +rollout_rs_threshold=1.001 +rollout_rs_threshold_lower=0.99 +rollout_token_veto_threshold=1e-4 python -m recipe.fully_async_policy.fully_async_main \ data.train_files="${TRAIN_FILE}" \ @@ -169,10 +169,10 @@ python -m recipe.fully_async_policy.fully_async_main \ async_training.partial_rollout="${partial_rollout}" \ async_training.use_rollout_log_probs=True \ async_training.compute_prox_log_prob=True \ - algorithm.rollout_is=${rollout_is} \ - algorithm.rollout_is_threshold=${rollout_is_threshold} \ - algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \ - algorithm.rollout_is_level=${rollout_is_level} \ - algorithm.rollout_is_mode=${rollout_is_mode} \ - algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold} + algorithm.rollout_correction.rollout_is=${rollout_is} \ + 
algorithm.rollout_correction.rollout_is_threshold=${rollout_is_threshold} \ + algorithm.rollout_correction.rollout_rs=${rollout_rs} \ + algorithm.rollout_correction.rollout_rs_threshold=${rollout_rs_threshold} \ + algorithm.rollout_correction.rollout_rs_threshold_lower=${rollout_rs_threshold_lower} \ + algorithm.rollout_correction.rollout_token_veto_threshold=${rollout_token_veto_threshold} diff --git a/recipe/one_step_off_policy/ray_trainer.py b/recipe/one_step_off_policy/ray_trainer.py index 6152580c689..3f5ec30ac17 100644 --- a/recipe/one_step_off_policy/ray_trainer.py +++ b/recipe/one_step_off_policy/ray_trainer.py @@ -577,10 +577,14 @@ def fit(self): else: batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] - # Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer) - batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch) - # IS and mismatch metrics already have mismatch/ prefix - metrics.update(is_metrics) + # Compute rollout correction weights and off-policy metrics (inherited from RayPPOTrainer) + from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_add_to_batch + + rollout_corr_config = self.config.algorithm.get("rollout_correction", None) + if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: + batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch) + # IS and off-policy metrics already have rollout_corr/ prefix + metrics.update(is_metrics) # compute advantages, executed on the driver process diff --git a/tests/trainer/ppo/test_rollout_is.py b/tests/trainer/ppo/test_rollout_corr.py similarity index 58% rename from tests/trainer/ppo/test_rollout_is.py rename to tests/trainer/ppo/test_rollout_corr.py index 9ae13f0eab0..3ac408e12b6 100644 --- a/tests/trainer/ppo/test_rollout_is.py +++ b/tests/trainer/ppo/test_rollout_corr.py @@ -13,30 +13,33 @@ # See the License for the specific language governing permissions and # 
limitations under the License. """ -Quick Sanity Test for Rollout Importance Sampling +Quick Sanity Test for Rollout Correction This is a standalone test script that can be run without pytest to quickly verify -the rollout IS implementation is working correctly. For comprehensive integration -tests, see: tests/trainer/ppo/test_rollout_is_integration.py +the rollout correction implementation is working correctly. For comprehensive integration +tests, see: tests/trainer/ppo/test_rollout_corr_integration.py Usage: - python test_rollout_is.py + python test_rollout_corr.py This tests: -- Basic rollout IS functionality (3 levels, 2 modes) -- Metrics completeness (32 total: 21 IS + 11 mismatch metrics) +- Basic rollout correction functionality (IS weights + rejection sampling) +- Metrics completeness (IS metrics + rejection metrics + off-policy metrics) - Veto mechanism - Edge cases """ import torch -from verl.trainer.ppo.mismatch_helper import compute_mismatch_metrics, compute_rollout_importance_weights +from verl.trainer.ppo.rollout_corr_helper import ( + compute_offpolicy_metrics, + compute_rollout_correction_and_rejection_mask, +) -def test_basic_rollout_is(): - """Test basic rollout IS functionality.""" - print("Testing basic rollout IS functionality...") +def test_basic_rollout_correction(): + """Test basic rollout correction functionality.""" + print("Testing basic rollout correction functionality...") # Create test data batch_size, seq_length = 4, 10 @@ -49,63 +52,63 @@ def test_basic_rollout_is(): # Test token-level truncate mode print("\n1. 
Testing token-level truncate mode...") - weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights( + weights_proto, modified_response_mask, metrics = compute_rollout_correction_and_rejection_mask( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=eos_mask, - rollout_is_level="token", - rollout_is_mode="truncate", + rollout_is="token", # Compute IS weights at token level rollout_is_threshold=2.0, - rollout_is_veto_threshold=1e-4, + rollout_rs=None, # No rejection sampling (truncate mode) + rollout_token_veto_threshold=1e-4, ) weights = weights_proto.batch["rollout_is_weights"] print(f" Weights shape: {weights.shape}") - print(f" Mean weight: {metrics['mismatch/rollout_is_mean']:.4f}") - print(f" Max weight: {metrics['mismatch/rollout_is_max']:.4f}") - print(f" Min weight: {metrics['mismatch/rollout_is_min']:.4f}") - print(f" Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.4f}") + print(f" Mean weight: {metrics['rollout_corr/rollout_is_mean']:.4f}") + print(f" Max weight: {metrics['rollout_corr/rollout_is_max']:.4f}") + print(f" Min weight: {metrics['rollout_corr/rollout_is_min']:.4f}") + print(f" Veto fraction: {metrics['rollout_corr/rollout_is_veto_fraction']:.4f}") assert weights.shape == old_log_prob.shape assert weights.max() <= 2.0, "Weights should be capped at threshold" print(" ✓ Token-level truncate mode passed") # Test sequence-level mode print("\n2. 
Testing sequence-level mode...") - weights_seq_proto, _, metrics_seq = compute_rollout_importance_weights( + weights_seq_proto, _, metrics_seq = compute_rollout_correction_and_rejection_mask( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=eos_mask, - rollout_is_level="sequence", - rollout_is_mode="truncate", + rollout_is="sequence", # Compute IS weights at sequence level rollout_is_threshold=5.0, - rollout_is_veto_threshold=1e-4, + rollout_rs=None, # No rejection sampling (truncate mode) + rollout_token_veto_threshold=1e-4, ) weights_seq = weights_seq_proto.batch["rollout_is_weights"] - print(f" Mean weight: {metrics_seq['mismatch/rollout_is_mean']:.4f}") - print(f" Effective sample size: {metrics_seq['mismatch/rollout_is_eff_sample_size']:.4f}") + print(f" Mean weight: {metrics_seq['rollout_corr/rollout_is_mean']:.4f}") + print(f" Effective sample size: {metrics_seq['rollout_corr/rollout_is_eff_sample_size']:.4f}") # Check that all tokens in a sequence have the same weight for i in range(batch_size): seq_weights = weights_seq[i, eos_mask[i].bool()] assert torch.allclose(seq_weights, seq_weights[0]), "All tokens in sequence should have same weight" print(" ✓ Sequence-level mode passed") - # Test geometric mean mode - print("\n3. Testing geometric mean mode...") - weights_geo_proto, _, metrics_geo = compute_rollout_importance_weights( + # Test geometric mean rejection sampling (mask mode) + print("\n3. 
Testing geometric mean rejection sampling...") + weights_geo_proto, modified_mask_geo, metrics_geo = compute_rollout_correction_and_rejection_mask( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=eos_mask, - rollout_is_level="geometric", - rollout_is_mode="mask", - rollout_is_threshold=1.5, - rollout_is_threshold_lower=0.5, - rollout_is_veto_threshold=1e-4, + rollout_is=None, # No IS weights (pure mask mode) + rollout_rs="geometric", # Rejection sampling with geometric mean + rollout_rs_threshold=1.5, + rollout_rs_threshold_lower=0.5, + rollout_token_veto_threshold=1e-4, ) - print(f" Mean weight: {metrics_geo['mismatch/rollout_is_mean']:.4f}") - print(f" Masked fraction: {metrics_geo['mismatch/rollout_is_masked_fraction']:.4f}") - print(" ✓ Geometric mean mode passed") + print(f" Masked fraction: {metrics_geo['rollout_corr/rollout_rs_masked_fraction']:.4f}") + print(f" Veto fraction: {metrics_geo['rollout_corr/rollout_is_veto_fraction']:.4f}") + print(" ✓ Geometric mean rejection sampling passed") # Test veto mechanism print("\n4. 
Testing veto mechanism...") @@ -116,18 +119,18 @@ def test_basic_rollout_is(): rollout_log_prob_veto[0, 2] = old_log_prob_veto[0, 2] + 15.0 # ratio ~= 3e-7 eos_mask_veto = torch.ones(2, 5, device=device) - weights_veto_proto, modified_response_mask_veto, metrics_veto = compute_rollout_importance_weights( + weights_veto_proto, modified_response_mask_veto, metrics_veto = compute_rollout_correction_and_rejection_mask( old_log_prob=old_log_prob_veto, rollout_log_prob=rollout_log_prob_veto, response_mask=eos_mask_veto, - rollout_is_level="token", - rollout_is_mode="truncate", + rollout_is="token", rollout_is_threshold=2.0, - rollout_is_veto_threshold=1e-4, + rollout_rs=None, + rollout_token_veto_threshold=1e-4, ) weights_veto = weights_veto_proto.batch["rollout_is_weights"] - print(f" Veto fraction: {metrics_veto['mismatch/rollout_is_veto_fraction']:.4f}") + print(f" Veto fraction: {metrics_veto['rollout_corr/rollout_is_veto_fraction']:.4f}") # KEY FIX: Veto is applied via response_mask, not by zeroing weights # Check that weights are NON-ZERO (safety-bounded ratios preserved, not zeroed) assert weights_veto[0].sum() > 0, "Weights should be non-zero (not zeroed by veto)" @@ -136,18 +139,21 @@ def test_basic_rollout_is(): assert modified_response_mask_veto[1].sum() > 0, "Normal sequence should have response_mask unchanged" print(" ✓ Veto mechanism passed") - # Test disabled IS (threshold=None) + # Test disabled IS (rollout_is=None, rollout_rs=None) print("\n5. 
Testing disabled IS...") - weights_disabled, modified_response_mask_disabled, metrics_disabled = compute_rollout_importance_weights( + weights_disabled, modified_response_mask_disabled, metrics_disabled = compute_rollout_correction_and_rejection_mask( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=eos_mask, - rollout_is_threshold=None, + rollout_is=None, + rollout_rs=None, + rollout_token_veto_threshold=None, ) - assert weights_disabled is None, "Should return None when threshold is None" + assert weights_disabled is None, "Should return None when IS is disabled" assert torch.equal(modified_response_mask_disabled, eos_mask), "Should return original mask unchanged" - assert len(metrics_disabled) == 0, "Should return empty metrics when disabled" + # Note: off-policy metrics are still computed even when IS/RS are disabled + assert "rollout_corr/kl" in metrics_disabled, "Should still compute off-policy metrics" print(" ✓ Disabled IS passed") print("\n✓ All tests passed!") @@ -164,44 +170,46 @@ def test_metrics_completeness(): rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.2 eos_mask = torch.ones(batch_size, seq_length, device=device) - _, _, metrics = compute_rollout_importance_weights( + _, _, metrics = compute_rollout_correction_and_rejection_mask( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=eos_mask, - rollout_is_level="token", - rollout_is_mode="truncate", + rollout_is="token", rollout_is_threshold=2.5, + rollout_rs=None, ) # Expected IS metrics expected_is_metrics = [ - "mismatch/rollout_is_mean", - "mismatch/rollout_is_max", - "mismatch/rollout_is_min", - "mismatch/rollout_is_std", - "mismatch/rollout_is_eff_sample_size", - "mismatch/rollout_is_veto_fraction", - "mismatch/rollout_is_catastrophic_token_fraction", - "mismatch/rollout_is_ratio_fraction_high", - "mismatch/rollout_is_ratio_fraction_low", + "rollout_corr/rollout_is_mean", + 
"rollout_corr/rollout_is_max", + "rollout_corr/rollout_is_min", + "rollout_corr/rollout_is_std", + "rollout_corr/rollout_is_eff_sample_size", + "rollout_corr/rollout_is_veto_fraction", + "rollout_corr/rollout_is_catastrophic_token_fraction", + "rollout_corr/rollout_is_ratio_fraction_high", + "rollout_corr/rollout_is_ratio_fraction_low", ] - # Expected mismatch/diagnostic metrics (also included now) - expected_mismatch_metrics = [ - "mismatch/mismatch_training_ppl", - "mismatch/mismatch_training_log_ppl", - "mismatch/mismatch_kl", - "mismatch/mismatch_k3_kl", - "mismatch/mismatch_rollout_ppl", - "mismatch/mismatch_rollout_log_ppl", - "mismatch/mismatch_log_ppl_diff", - "mismatch/mismatch_log_ppl_abs_diff", - "mismatch/mismatch_log_ppl_diff_max", - "mismatch/mismatch_log_ppl_diff_min", - "mismatch/mismatch_ppl_ratio", + # Expected off-policy diagnostic metrics (also included now) + expected_offpolicy_metrics = [ + "rollout_corr/training_ppl", + "rollout_corr/training_log_ppl", + "rollout_corr/kl", + "rollout_corr/k3_kl", + "rollout_corr/rollout_ppl", + "rollout_corr/rollout_log_ppl", + "rollout_corr/log_ppl_diff", + "rollout_corr/log_ppl_abs_diff", + "rollout_corr/log_ppl_diff_max", + "rollout_corr/log_ppl_diff_min", + "rollout_corr/ppl_ratio", + "rollout_corr/chi2_token", + "rollout_corr/chi2_seq", ] - expected_metrics = expected_is_metrics + expected_mismatch_metrics + expected_metrics = expected_is_metrics + expected_offpolicy_metrics missing_metrics = [m for m in expected_metrics if m not in metrics] if missing_metrics: @@ -213,9 +221,9 @@ def test_metrics_completeness(): return True -def test_mismatch_metrics(): - """Test mismatch metrics computation.""" - print("\nTesting mismatch metrics computation...") +def test_offpolicy_metrics(): + """Test off-policy metrics computation.""" + print("\nTesting off-policy metrics computation...") batch_size, seq_length = 4, 12 device = "cuda" if torch.cuda.is_available() else "cpu" @@ -226,46 +234,48 @@ def 
test_mismatch_metrics(): response_mask = torch.ones(batch_size, seq_length, device=device) # Test with rollout log probs - metrics = compute_mismatch_metrics( + metrics = compute_offpolicy_metrics( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask, ) expected_metrics = [ - "mismatch_training_ppl", - "mismatch_training_log_ppl", - "mismatch_kl", - "mismatch_k3_kl", - "mismatch_rollout_ppl", - "mismatch_rollout_log_ppl", - "mismatch_log_ppl_diff", - "mismatch_log_ppl_abs_diff", - "mismatch_log_ppl_diff_max", - "mismatch_log_ppl_diff_min", - "mismatch_ppl_ratio", + "training_ppl", + "training_log_ppl", + "kl", + "k3_kl", + "rollout_ppl", + "rollout_log_ppl", + "log_ppl_diff", + "log_ppl_abs_diff", + "log_ppl_diff_max", + "log_ppl_diff_min", + "ppl_ratio", + "chi2_token", + "chi2_seq", ] for metric in expected_metrics: assert metric in metrics, f"Missing metric: {metric}" - print(f" Training PPL: {metrics['mismatch_training_ppl']:.4f}") - print(f" Rollout PPL: {metrics['mismatch_rollout_ppl']:.4f}") - print(f" KL divergence: {metrics['mismatch_kl']:.6f}") - print(f" K3 KL: {metrics['mismatch_k3_kl']:.6f}") - print(f" PPL ratio: {metrics['mismatch_ppl_ratio']:.4f}") - print(f" ✓ All {len(expected_metrics)} mismatch metrics present") + print(f" Training PPL: {metrics['training_ppl']:.4f}") + print(f" Rollout PPL: {metrics['rollout_ppl']:.4f}") + print(f" KL divergence: {metrics['kl']:.6f}") + print(f" K3 KL: {metrics['k3_kl']:.6f}") + print(f" PPL ratio: {metrics['ppl_ratio']:.4f}") + print(f" ✓ All {len(expected_metrics)} off-policy metrics present") # Test without rollout log probs - metrics_no_rollout = compute_mismatch_metrics( + metrics_no_rollout = compute_offpolicy_metrics( old_log_prob=old_log_prob, rollout_log_prob=None, response_mask=response_mask, ) - assert "mismatch_training_ppl" in metrics_no_rollout - assert "mismatch_rollout_ppl" not in metrics_no_rollout - print(" ✓ Mismatch metrics work without rollout log probs") 
+ assert "training_ppl" in metrics_no_rollout + assert "rollout_ppl" not in metrics_no_rollout + print(" ✓ Off-policy metrics work without rollout log probs") def test_mask_mode(): @@ -288,15 +298,16 @@ def test_mask_mode(): ) response_mask = torch.ones(batch_size, seq_length, device=device) - weights_proto, modified_response_mask, metrics = compute_rollout_importance_weights( + weights_proto, modified_response_mask, metrics = compute_rollout_correction_and_rejection_mask( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask, - rollout_is_level="token", - rollout_is_mode="mask", + rollout_is="token", # Compute IS weights rollout_is_threshold=2.0, - rollout_is_threshold_lower=0.5, - rollout_is_veto_threshold=None, + rollout_rs="token", # Also apply rejection sampling (mask mode) + rollout_rs_threshold=2.0, + rollout_rs_threshold_lower=0.5, + rollout_token_veto_threshold=None, ) weights = weights_proto.batch["rollout_is_weights"] @@ -314,27 +325,27 @@ def test_mask_mode(): assert torch.all(modified_response_mask[0, :] == 0), "First sequence should be rejected via mask" assert torch.all(modified_response_mask[1, :] == 1), "Second sequence should be accepted" - # Verify mask metrics exist - assert "mismatch/rollout_is_masked_fraction" in metrics - assert abs(metrics["mismatch/rollout_is_masked_fraction"] - 0.5) < 0.01, "Should reject 50% of tokens" + # Verify rejection sampling metrics exist + assert "rollout_corr/rollout_rs_masked_fraction" in metrics, "Should have rollout_rs_masked_fraction metric" + assert abs(metrics["rollout_corr/rollout_rs_masked_fraction"] - 0.5) < 0.01, "Should reject 50% of tokens" print(f" First seq IS weight: {weights[0, 0]:.4f} (expected ≈0.37)") print(f" Second seq IS weight: {weights[1, 0]:.4f} (expected ≈1.65)") print(f" First seq mask: {modified_response_mask[0, 0]:.0f} (expected 0 - rejected)") print(f" Second seq mask: {modified_response_mask[1, 0]:.0f} (expected 1 - accepted)") - print(f" Masked 
fraction: {metrics['mismatch/rollout_is_masked_fraction']:.2f}") + print(f" Masked fraction: {metrics['rollout_corr/rollout_rs_masked_fraction']:.2f}") print(" ✓ Mask mode correctly separates IS weights from rejection") if __name__ == "__main__": print("=" * 60) - print("Rollout Importance Sampling Test Suite") + print("Rollout Correction Test Suite") print("=" * 60) try: - test_basic_rollout_is() + test_basic_rollout_correction() test_metrics_completeness() - test_mismatch_metrics() + test_offpolicy_metrics() test_mask_mode() print("\n" + "=" * 60) print("ALL TESTS PASSED ✓") diff --git a/tests/trainer/ppo/test_rollout_is_integration.py b/tests/trainer/ppo/test_rollout_corr_integration.py similarity index 67% rename from tests/trainer/ppo/test_rollout_is_integration.py rename to tests/trainer/ppo/test_rollout_corr_integration.py index b96fb77523f..4df924872b2 100644 --- a/tests/trainer/ppo/test_rollout_is_integration.py +++ b/tests/trainer/ppo/test_rollout_corr_integration.py @@ -11,18 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Integration tests for Rollout Importance Sampling.""" +"""Integration tests for Rollout Correction.""" import pytest import torch from verl.trainer.ppo.core_algos import compute_policy_loss_vanilla -from verl.trainer.ppo.mismatch_helper import compute_mismatch_metrics, compute_rollout_importance_weights +from verl.trainer.ppo.rollout_corr_helper import ( + compute_offpolicy_metrics, + compute_rollout_correction_and_rejection_mask, +) from verl.workers.config.actor import ActorConfig class TestRolloutISIntegration: - """Integration tests for Rollout IS with PPO.""" + """Integration tests for Rollout Correction with PPO.""" @pytest.fixture def sample_data(self): @@ -54,21 +57,21 @@ def config_with_rollout_is(self): return config def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is): - """Test that policy loss computation works with rollout IS weights. + """Test that policy loss computation works with rollout correction weights. Note: In production, IS weights are computed centrally in the trainer (before advantage computation) and passed to policy loss. This test simulates that workflow. 
""" # First compute IS weights (as trainer would do centrally) - rollout_is_weights_proto, _, _ = compute_rollout_importance_weights( + rollout_is_weights_proto, _, _ = compute_rollout_correction_and_rejection_mask( old_log_prob=sample_data["old_log_prob"], rollout_log_prob=sample_data["rollout_log_prob"], response_mask=sample_data["response_mask"], - rollout_is_level="token", - rollout_is_mode="truncate", + rollout_is="token", + rollout_rs=None, rollout_is_threshold=2.0, - rollout_is_veto_threshold=1e-4, + rollout_token_veto_threshold=1e-4, ) rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"] @@ -91,15 +94,15 @@ def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is): assert not torch.isinf(pg_loss) def test_rollout_is_weights_computation(self, sample_data): - """Test rollout IS weights and metrics computation.""" - weights_proto, _, metrics = compute_rollout_importance_weights( + """Test rollout correction weights and metrics computation.""" + weights_proto, _, metrics = compute_rollout_correction_and_rejection_mask( old_log_prob=sample_data["old_log_prob"], rollout_log_prob=sample_data["rollout_log_prob"], response_mask=sample_data["response_mask"], - rollout_is_level="token", - rollout_is_mode="truncate", + rollout_is="token", + rollout_rs=None, rollout_is_threshold=2.0, - rollout_is_veto_threshold=1e-4, + rollout_token_veto_threshold=1e-4, ) # Check weights @@ -113,54 +116,74 @@ def test_rollout_is_weights_computation(self, sample_data): # Check metrics are returned assert isinstance(metrics, dict) assert len(metrics) > 0 - assert "mismatch/rollout_is_mean" in metrics + assert "rollout_corr/rollout_is_mean" in metrics def test_all_aggregation_levels(self, sample_data): - """Test all three aggregation levels.""" - levels = ["token", "sequence", "geometric"] - - for level in levels: - _, _, metrics = compute_rollout_importance_weights( + """Test all aggregation levels (token, sequence for IS; geometric for RS).""" + 
# Test IS weight levels + is_levels = ["token", "sequence"] + for level in is_levels: + _, _, metrics = compute_rollout_correction_and_rejection_mask( old_log_prob=sample_data["old_log_prob"], rollout_log_prob=sample_data["rollout_log_prob"], response_mask=sample_data["response_mask"], - rollout_is_level=level, - rollout_is_mode="truncate", + rollout_is=level, rollout_is_threshold=2.0, + rollout_rs=None, ) + assert "rollout_corr/rollout_is_mean" in metrics - assert "mismatch/rollout_is_mean" in metrics + # Test rejection sampling with geometric level + _, _, metrics_geo = compute_rollout_correction_and_rejection_mask( + old_log_prob=sample_data["old_log_prob"], + rollout_log_prob=sample_data["rollout_log_prob"], + response_mask=sample_data["response_mask"], + rollout_is=None, + rollout_rs="geometric", + rollout_rs_threshold=2.0, + ) + assert "rollout_corr/rollout_rs_mean" in metrics_geo def test_both_bounding_modes(self, sample_data): """Test both truncate and mask modes.""" - modes = ["truncate", "mask"] - - for mode in modes: - _, _, metrics = compute_rollout_importance_weights( - old_log_prob=sample_data["old_log_prob"], - rollout_log_prob=sample_data["rollout_log_prob"], - response_mask=sample_data["response_mask"], - rollout_is_level="token", - rollout_is_mode=mode, - rollout_is_threshold=2.0, - rollout_is_threshold_lower=0.5, - ) + # Test truncate mode (IS weights only) + _, _, metrics_truncate = compute_rollout_correction_and_rejection_mask( + old_log_prob=sample_data["old_log_prob"], + rollout_log_prob=sample_data["rollout_log_prob"], + response_mask=sample_data["response_mask"], + rollout_is="token", + rollout_is_threshold=2.0, + rollout_rs=None, + ) + assert "rollout_corr/rollout_is_mean" in metrics_truncate - assert "mismatch/rollout_is_mean" in metrics + # Test mask mode (rejection sampling) + _, _, metrics_mask = compute_rollout_correction_and_rejection_mask( + old_log_prob=sample_data["old_log_prob"], + 
rollout_log_prob=sample_data["rollout_log_prob"], + response_mask=sample_data["response_mask"], + rollout_is="token", # Can also compute IS weights in mask mode + rollout_is_threshold=2.0, + rollout_rs="token", # Enable rejection sampling + rollout_rs_threshold=2.0, + rollout_rs_threshold_lower=0.5, + ) + assert "rollout_corr/rollout_is_mean" in metrics_mask + assert "rollout_corr/rollout_rs_mean" in metrics_mask - def test_mismatch_metrics(self, sample_data): - """Test mismatch diagnostic metrics computation.""" - metrics = compute_mismatch_metrics( + def test_offpolicy_metrics(self, sample_data): + """Test off-policy diagnostic metrics computation.""" + metrics = compute_offpolicy_metrics( old_log_prob=sample_data["old_log_prob"], rollout_log_prob=sample_data["rollout_log_prob"], response_mask=sample_data["response_mask"], ) # Check key metrics are present - assert "mismatch_training_ppl" in metrics - assert "mismatch_rollout_ppl" in metrics - assert "mismatch_kl" in metrics - assert isinstance(metrics["mismatch_kl"], float) + assert "training_ppl" in metrics + assert "rollout_ppl" in metrics + assert "kl" in metrics + assert isinstance(metrics["kl"], float) def test_veto_mechanism(self): """Test veto mechanism with catastrophic outliers.""" @@ -175,19 +198,19 @@ def test_veto_mechanism(self): response_mask = torch.ones(batch_size, seq_length, device=device) - _, _, metrics = compute_rollout_importance_weights( + _, _, metrics = compute_rollout_correction_and_rejection_mask( old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask, - rollout_is_level="token", - rollout_is_mode="truncate", + rollout_is="token", rollout_is_threshold=2.0, - rollout_is_veto_threshold=1e-4, + rollout_rs=None, + rollout_token_veto_threshold=1e-4, ) # Should have vetoed one sequence - assert metrics["mismatch/rollout_is_veto_fraction"] > 0 - assert metrics["mismatch/rollout_is_veto_fraction"] <= 1.0 + assert 
metrics["rollout_corr/rollout_is_veto_fraction"] > 0
+        assert metrics["rollout_corr/rollout_is_veto_fraction"] <= 1.0

     def test_metrics_only_mode(self, sample_data, config_with_rollout_is):
         """Test metrics-only mode: compute IS weights/metrics but don't apply to loss.
@@ -196,18 +219,18 @@ def test_metrics_only_mode(self, sample_data, config_with_rollout_is):
             but rollout_is=False (disables weight application to policy loss).
         """
         # Compute IS weights (as trainer would do)
-        rollout_is_weights_proto, _, is_metrics = compute_rollout_importance_weights(
+        rollout_is_weights_proto, _, is_metrics = compute_rollout_correction_and_rejection_mask(
             old_log_prob=sample_data["old_log_prob"],
             rollout_log_prob=sample_data["rollout_log_prob"],
             response_mask=sample_data["response_mask"],
-            rollout_is_level="token",
-            rollout_is_mode="truncate",
+            rollout_is="token",
             rollout_is_threshold=2.0,
+            rollout_rs=None,
         )

         # Metrics should be computed
         assert len(is_metrics) > 0
-        assert "mismatch/rollout_is_mean" in is_metrics
+        assert "rollout_corr/rollout_is_mean" in is_metrics

         # In metrics-only mode, we compute loss WITHOUT applying weights
         # (simulating rollout_is=False)
diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
index 43718c1f399..5881161f4a7 100644
--- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -466,10 +466,15 @@ reward_model:
     override_transformer_config: ${oc.select:actor_rollout_ref.actor.megatron.override_transformer_config,{}}
     use_mbridge: ${oc.select:actor_rollout_ref.actor.megatron.use_mbridge,False}
   load_weight: true
-custom_reward_function:
-  path: null
-  name: compute_score
 algorithm:
+  rollout_is: null
+  rollout_is_threshold: 2.0
+  rollout_rs: null
+  rollout_rs_threshold: null
+  rollout_rs_threshold_lower: null
+  rollout_token_veto_threshold: null
+  bypass_old_logprob_for_rollout: false
+  use_pure_rollout_correction: false
   _target_: verl.trainer.config.AlgoConfig
   gamma: 1.0
   lam: 1.0
@@ -487,12 +492,9 @@ algorithm:
   pf_ppo:
     reweight_method: pow
     weight_pow: 2.0
-  rollout_is_threshold: null
-  rollout_is_threshold_lower: null
-  rollout_is_level: token
-  rollout_is_mode: truncate
-  rollout_is_veto_threshold: null
-  rollout_is: false
+custom_reward_function:
+  path: null
+  name: compute_score
 trainer:
   balance_batch: true
   total_epochs: 30
diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml
index 8ca6d8f31b8..52318324cba 100644
--- a/verl/trainer/config/_generated_ppo_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_trainer.yaml
@@ -453,6 +453,16 @@ reward_model:
     save_path: ${oc.select:global_profiler.save_path,null}
     tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
   ulysses_sequence_parallel_size: 1
+'@algorithm':
+  rollout_correction:
+    rollout_is: null
+    rollout_is_threshold: 2.0
+    rollout_rs: null
+    rollout_rs_threshold: null
+    rollout_rs_threshold_lower: null
+    rollout_token_veto_threshold: null
+    bypass_old_logprob_for_rollout: false
+    use_pure_rollout_correction: false
 custom_reward_function:
   path: null
   name: compute_score
@@ -474,12 +484,6 @@ algorithm:
   pf_ppo:
     reweight_method: pow
     weight_pow: 2.0
-  rollout_is_threshold: null
-  rollout_is_threshold_lower: null
-  rollout_is_level: token
-  rollout_is_mode: truncate
-  rollout_is_veto_threshold: null
-  rollout_is: false
 trainer:
   balance_batch: true
   total_epochs: 30
diff --git a/verl/trainer/config/algorithm.py b/verl/trainer/config/algorithm.py
index bd6763226f7..9b79a3f3089 100644
--- a/verl/trainer/config/algorithm.py
+++ b/verl/trainer/config/algorithm.py
@@ -17,7 +17,7 @@
 from verl.base_config import BaseConfig

-__all__ = ["AlgoConfig", "FilterGroupsConfig", "KLControlConfig"]
+__all__ = ["AlgoConfig", "FilterGroupsConfig", "KLControlConfig", "RolloutCorrectionConfig"]


 @dataclass
@@ -56,6 +56,264 @@ class FilterGroupsConfig(BaseConfig):
     max_num_gen_batches: int = 0


+@dataclass
+class RolloutCorrectionConfig(BaseConfig):
+    """Configuration for Rollout Correction (addresses off-policy issues in RL training).
+
+    The inheritance from BaseConfig provides an omegaconf.DictConfig-like interface for a dataclass config.
+
+    Rollout Correction handles off-policiness from multiple sources:
+    1. Policy mismatch: Rollout policy (e.g., vLLM BF16) vs Training policy (e.g., FSDP FP32)
+    2. Model update staleness: Rollout data collected from older policy checkpoints
+    3. General off-policy scenarios: Any distribution shift between data collection and training
+
+    For more details, see:
+    "When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch"
+    https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
+
+    This typed config replaces the old dict-based approach and provides:
+    - Type safety and validation
+    - Clear documentation of all parameters
+    - Named factory methods for common presets (TIS, MIS, etc.)
+    - Sensible defaults
+
+    Args:
+        rollout_is (Optional[str]): IS weight aggregation level.
+            - None: No IS weights (metrics only)
+            - "token": Per-token IS weights (low variance, biased)
+            - "sequence": Per-sequence IS weights (unbiased, high variance)
+            Default: "sequence"
+
+        rollout_is_threshold (float): Upper threshold for IS weight truncation/rejection.
+            Typical range: 1.5-5.0 for token level, 2.0-10.0 for sequence level.
+            Default: 2.0
+
+        rollout_rs (Optional[str]): Rejection sampling aggregation level.
+            - None: No rejection sampling
+            - "token": Reject individual tokens with outlier ratios
+            - "sequence": Reject entire sequences with outlier ratios
+            - "geometric": Geometric mean aggregation (threshold: 1.0002-1.001)
+            Default: None (use IS weights without rejection)
+
+        rollout_rs_threshold (Optional[float]): Upper threshold for rejection sampling.
+            - If None and rollout_rs is enabled, uses rollout_is_threshold
+            - Tokens/sequences with ratio > threshold are masked out
+            Default: None (uses rollout_is_threshold when rollout_rs is enabled)
+
+        rollout_rs_threshold_lower (Optional[float]): Lower threshold for rejection sampling.
+            - If None, uses reciprocal of upper threshold (1/upper)
+            - Tokens/sequences with ratio < threshold are masked out
+            Default: None (auto-computed as reciprocal)
+
+        rollout_token_veto_threshold (Optional[float]): Per-token veto for catastrophic outliers.
+            - Checks unclamped per-token ratios before safety bounds
+            - If ANY token has ratio < threshold, entire sequence is rejected
+            - Independent of rollout_is and rollout_rs settings
+            - Typical values: 1e-4 to 1e-6 when enabled
+            Default: None (disabled)
+
+        bypass_old_logprob_for_rollout (bool): Skip old_log_prob computation.
+            - True: Reuse rollout_log_prob as old_log_prob (15-20% speedup)
+            - False: Compute old_log_prob via actor.compute_log_prob() (standard)
+            Default: False (standard mode)
+
+        use_pure_rollout_correction (bool): Use pure policy gradient with IS (no PPO clipping).
+            - Requires bypass_old_logprob_for_rollout=True
+            - True: Pure IS loss without clipping (higher variance, unbiased)
+            - False: PPO loss with IS correction (standard)
+            Default: False (PPO mode)
+
+    Example:
+        # Create with defaults
+        config = RolloutCorrectionConfig()
+
+        # Use presets
+        config = RolloutCorrectionConfig.token_is()   # Token-level IS
+        config = RolloutCorrectionConfig.seq_is_rs()  # Sequence-level IS + rejection sampling
+        config = RolloutCorrectionConfig.seq_is()     # Sequence-level IS
+
+    Reference:
+        Liu, Li, Fu, Wang, Liu, Shen (2025)
+        "When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch"
+        https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
+    """
+
+    rollout_is: Optional[str] = "sequence"
+    rollout_is_threshold: float = 2.0
+    rollout_rs: Optional[str] = None
+    rollout_rs_threshold: Optional[float] = None
+    rollout_rs_threshold_lower: Optional[float] = None
+    rollout_token_veto_threshold: Optional[float] = None
+    bypass_old_logprob_for_rollout: bool = False
+    use_pure_rollout_correction: bool = False
+
+    @classmethod
+    def token_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
+        """Token-level Truncated Importance Sampling.
+
+        IS weight correction at the token level.
+
+        Args:
+            threshold (float): Upper threshold for IS weights. Default: 2.0
+
+        Returns:
+            RolloutCorrectionConfig configured for token-level IS
+        """
+        return cls(rollout_is="token", rollout_is_threshold=threshold, rollout_rs=None)
+
+    @classmethod
+    def token_tis(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
+        """Alias for token_is()."""
+        return cls.token_is(threshold=threshold)
+
+    @classmethod
+    def seq_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
+        """Sequence-level Truncated Importance Sampling.
+
+        Args:
+            threshold (float): Upper threshold for IS weights. Default: 2.0
+
+        Returns:
+            RolloutCorrectionConfig configured for sequence-level IS
+        """
+        return cls(rollout_is="sequence", rollout_is_threshold=threshold, rollout_rs=None)
+
+    @classmethod
+    def seq_is_rs(
+        cls,
+        is_threshold: float = 2.0,
+        rs_threshold: float = 2.0,
+        rs_threshold_lower: Optional[float] = None,
+    ) -> "RolloutCorrectionConfig":
+        """Sequence-level IS with Rejection Sampling (MIS).
+
+        Sequence-level IS with sequence-level rejection sampling.
+        Rejects entire sequences based on the sequence-level IS weight.
+
+        Args:
+            is_threshold (float): Upper threshold for IS weights. Default: 2.0
+            rs_threshold (float): Upper threshold for rejection sampling. Default: 2.0
+            rs_threshold_lower (Optional[float]): Lower threshold for rejection sampling.
+                If None, auto-computed as reciprocal of rs_threshold. Default: None
+
+        Returns:
+            RolloutCorrectionConfig configured for sequence IS + RS
+        """
+        return cls(
+            rollout_is="sequence",
+            rollout_is_threshold=is_threshold,
+            rollout_rs="sequence",
+            rollout_rs_threshold=rs_threshold,
+            rollout_rs_threshold_lower=rs_threshold_lower,
+        )
+
+    @classmethod
+    def seq_mis(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
+        """Variant of seq_is_rs() with a single threshold and no lower rejection bound (rs_threshold_lower=0)."""
+        return cls.seq_is_rs(
+            is_threshold=threshold,
+            rs_threshold=threshold,
+            rs_threshold_lower=0,
+        )
+
+    @classmethod
+    def geo_rs(
+        cls,
+        rs_threshold: float = 1.001,
+        rs_threshold_lower: Optional[float] = None,
+        veto_threshold: float = 1e-4,
+    ) -> "RolloutCorrectionConfig":
+        """Geometric Rejection Sampling with Veto.
+
+        Uses the geometric mean for rejection sampling at the sequence level,
+        with an additional veto mechanism. The geometric mean is extremely sensitive to outliers,
+        requiring very tight thresholds close to 1.0.
+
+        Args:
+            rs_threshold (float): Geometric RS threshold (upper). Default: 1.001 (±0.1%)
+            rs_threshold_lower (Optional[float]): Geometric RS threshold (lower).
+                If None, auto-computed as reciprocal of rs_threshold. Default: None
+            veto_threshold (float): Per-token veto threshold. Default: 1e-4
+
+        Returns:
+            RolloutCorrectionConfig configured for geometric RS with veto
+        """
+        return cls(
+            rollout_is=None,
+            rollout_rs="geometric",
+            rollout_rs_threshold=rs_threshold,
+            rollout_rs_threshold_lower=rs_threshold_lower,
+            rollout_token_veto_threshold=veto_threshold,
+        )
+
+    @classmethod
+    def geo_mis(
+        cls,
+        rs_threshold: float = 1.001,
+        rs_threshold_lower: float = 0.999,
+        veto_threshold: float = 1e-4,
+    ) -> "RolloutCorrectionConfig":
+        """Variant of geo_rs() with an explicit lower threshold (default 0.999 instead of auto-reciprocal)."""
+        return cls.geo_rs(
+            rs_threshold=rs_threshold,
+            rs_threshold_lower=rs_threshold_lower,
+            veto_threshold=veto_threshold,
+        )
+
+    @classmethod
+    def ppo_is_bypass(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
+        """PPO with IS Correction in Bypass Mode.
+
+        Skips old_log_prob computation by reusing rollout_log_prob.
+        PPO clips against the rollout policy instead of the true old policy.
+
+        Args:
+            threshold (float): Upper threshold for IS weights. Default: 2.0
+
+        Returns:
+            RolloutCorrectionConfig configured for PPO_IS bypass mode
+        """
+        return cls(
+            rollout_is="token",
+            rollout_is_threshold=threshold,
+            rollout_rs=None,
+            bypass_old_logprob_for_rollout=True,
+            use_pure_rollout_correction=False,
+        )
+
+    @classmethod
+    def pure_is(cls, threshold: float = 2.0) -> "RolloutCorrectionConfig":
+        """Pure Policy Gradient with IS Correction.
+
+        Uses pure policy gradient loss with explicit IS correction.
+        No PPO clipping.
+
+        Args:
+            threshold (float): Upper threshold for IS weights. Default: 2.0
+
+        Returns:
+            RolloutCorrectionConfig configured for pure IS mode
+        """
+        return cls(
+            rollout_is="sequence",
+            rollout_is_threshold=threshold,
+            rollout_rs=None,
+            bypass_old_logprob_for_rollout=True,
+            use_pure_rollout_correction=True,
+        )
+
+    @classmethod
+    def disabled(cls) -> "RolloutCorrectionConfig":
+        """Disabled - Metrics Only Mode.
+
+        Computes and logs off-policy metrics without applying correction.
+
+        Returns:
+            RolloutCorrectionConfig with all correction disabled
+        """
+        return cls(rollout_is=None, rollout_rs=None)
+
+
 @dataclass
 class AlgoConfig(BaseConfig):
     """Configuration for the algorithm.
@@ -73,14 +331,17 @@ class AlgoConfig(BaseConfig):
         use_pf_ppo (bool): Whether to enable preference feedback PPO.
         pf_ppo (dict[str, Any]): Preference feedback PPO settings.
         filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy
-        rollout_is_threshold (Optional[float]): Upper threshold for IS weights. null = disabled,
-            float value = enabled (compute weights and metrics). This is the main on/off switch.
-        rollout_is_threshold_lower (Optional[float]): Lower threshold for IS weights. If None, defaults to 1/upper.
-        rollout_is_level (str): Aggregation level: "token", "sequence", or "geometric".
-        rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "mask" (zero outside bounds).
-        rollout_is_veto_threshold (float or None): Per-token veto threshold for catastrophic outliers. None to disable.
-        rollout_is (bool): Whether to apply IS weights to policy loss. True = apply weights,
-            False = compute metrics only (useful for monitoring before enabling correction). Default: False.
+        rollout_correction (Optional[RolloutCorrectionConfig]): Rollout Correction configuration.
+            Addresses off-policy issues from policy mismatch, model staleness, and general distribution shifts.
+
+            Set to None to disable entirely. Use factory methods for common presets:
+            - RolloutCorrectionConfig.token_is() - Token-level IS
+            - RolloutCorrectionConfig.seq_is_rs() - Sequence-level IS + rejection sampling
+            - RolloutCorrectionConfig.seq_is() - Sequence-level IS (unbiased estimator)
+            - RolloutCorrectionConfig.geo_rs() - Geometric RS with veto
+
+            For backward compatibility, you can still pass a dict, which will be converted to
+            RolloutCorrectionConfig automatically.
     """

     gamma: float = 1.0
@@ -93,13 +354,6 @@ class AlgoConfig(BaseConfig):
     use_pf_ppo: bool = False
     pf_ppo: dict[str, Any] = field(default_factory=dict)
     filter_groups: Optional[FilterGroupsConfig] = None
-    # Rollout Importance Sampling
-    # Controls computation of IS weights and mismatch metrics
-    rollout_is_threshold: Optional[float] = None  # null = disabled, float = enabled
-    rollout_is_threshold_lower: Optional[float] = None
-    rollout_is_level: str = "token"
-    rollout_is_mode: str = "truncate"
-    rollout_is_veto_threshold: Optional[float] = None
-    # Controls whether to apply IS weights to policy loss (only if rollout_is_threshold is set)
-    # True = apply weights to loss, False = compute metrics only (no weight application)
-    rollout_is: bool = False
+    # Rollout Correction: corrects off-policy issues (policy mismatch, model staleness, distribution shifts)
+    # Set to None to disable, use RolloutCorrectionConfig presets (e.g., .token_is(), .seq_is_rs()), or pass a dict
+    rollout_correction: Optional[RolloutCorrectionConfig] = None
diff --git a/verl/trainer/config/algorithm/rollout_correction.yaml b/verl/trainer/config/algorithm/rollout_correction.yaml
new file mode 100644
index 00000000000..5c53cf5ad3f
--- /dev/null
+++ b/verl/trainer/config/algorithm/rollout_correction.yaml
@@ -0,0 +1,10 @@
+# Rollout Correction: corrects distribution mismatch between rollout and training policies
+# Override via CLI: algorithm.rollout_correction.rollout_is=token
+rollout_is: null
+rollout_is_threshold: 2.0
+rollout_rs: null
+rollout_rs_threshold: null
+rollout_rs_threshold_lower: null
+rollout_token_veto_threshold: null
+bypass_old_logprob_for_rollout: false
+use_pure_rollout_correction: false
diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml
index fb5eef5d6fc..c126dc418ee 100644
--- a/verl/trainer/config/ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/ppo_megatron_trainer.yaml
@@ -15,6 +15,8 @@ defaults:
  - critic@critic:
megatron_critic
  # Reward model config.
  - reward_model@reward_model: megatron_reward_model
+  # Rollout correction config.
+  - algorithm@algorithm: rollout_correction
  - _self_

 actor_rollout_ref:
@@ -23,7 +25,6 @@ actor_rollout_ref:
   nccl_timeout: 600 # seconds, default is 10 minutes for torch, you can set it to a larger value if you have long-running operations like 32B or 72B model using megatron
   model:
-
     path: ~/models/deepseek-llm-7b-chat
     custom_chat_template: null
@@ -73,28 +74,6 @@ algorithm:
     reweight_method: pow # ["pow", "max_min", "max_random"]
     weight_pow: 2.0

-  # Rollout Importance Sampling: corrects distribution mismatch between rollout and training policies
-  # Main control: Upper threshold for IS weights (null = disabled, float = enabled)
-  # When enabled, computes IS weights and mismatch metrics (KL, PPL, etc.)
-  rollout_is_threshold: null
-
-  # Lower threshold for IS weights (null = auto-reciprocal of upper)
-  rollout_is_threshold_lower: null
-
-  # Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
-  rollout_is_level: token
-
-  # Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
-  rollout_is_mode: truncate
-
-  # Per-token veto threshold for catastrophic outliers (null to disable)
-  rollout_is_veto_threshold: null
-
-  # Whether to apply IS weights to policy loss
-  # true = apply weights to loss, false = compute metrics only (no weight application)
-  # Useful for monitoring mismatch before enabling correction
-  rollout_is: false
-
 trainer:
   balance_batch: True
   total_epochs: 30
@@ -191,7 +170,6 @@ global_profiler:

 # configs for TransferQueue
 transfer_queue:
-
   # Whether to enable transfer queue
   enable: False
diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index d9d250c9a48..baa55802410 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -30,6 +30,9 @@ defaults:
   # Reward model config.
   - reward_model@reward_model: dp_reward_model

+  # Rollout correction config.
+  - algorithm@@algorithm.rollout_correction: rollout_correction
+
   # load the reference default config, then apply the fields in the current yaml
   # self config override anything above
   - _self_
@@ -113,28 +116,6 @@ algorithm:
     # Power used for weight scaling in "pow" method
     weight_pow: 2.0

-  # Rollout Importance Sampling: corrects distribution mismatch between rollout and training policies
-  # Main control: Upper threshold for IS weights (null = disabled, float = enabled)
-  # When enabled, computes IS weights and mismatch metrics (KL, PPL, etc.)
-  rollout_is_threshold: null
-
-  # Lower threshold for IS weights (null = auto-reciprocal of upper)
-  rollout_is_threshold_lower: null
-
-  # Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
-  rollout_is_level: token
-
-  # Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
-  rollout_is_mode: truncate
-
-  # Per-token veto threshold for catastrophic outliers (null to disable)
-  rollout_is_veto_threshold: null
-
-  # Whether to apply IS weights to policy loss
-  # true = apply weights to loss, false = compute metrics only (no weight application)
-  # Useful for monitoring mismatch before enabling correction
-  rollout_is: false
-
 # config for the trainer
 trainer:
diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index 7c2bfd53673..300d58c1e14 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -962,7 +962,7 @@ def compute_policy_loss_vanilla(
     pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

-    # Apply rollout importance sampling weights if provided
+    # Apply rollout correction weights if provided
     if rollout_is_weights is not None:
         pg_losses = pg_losses * rollout_is_weights

@@ -1025,7 +1025,7 @@
     pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
     pg_losses = torch.maximum(pg_losses1, pg_losses2)

-    # Apply rollout importance sampling weights if provided
+    # Apply rollout correction weights if provided
     if rollout_is_weights is not None:
         pg_losses = pg_losses * rollout_is_weights

@@ -1066,7 +1066,7 @@
     """
     pg_losses = -log_prob * advantages

-    # Apply rollout importance sampling weights if provided
+    # Apply rollout correction weights if provided
     if rollout_is_weights is not None:
         pg_losses = pg_losses * rollout_is_weights

@@ -1165,7 +1165,7 @@
     pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr

-    # Apply rollout importance sampling weights if provided
+    # Apply rollout correction weights if provided
     if rollout_is_weights is not None:
         pg_losses = pg_losses * rollout_is_weights

@@ -1241,7 +1241,7 @@
         large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]
     ]

-    # Apply rollout importance sampling weights if provided
+    # Apply rollout correction weights if provided
     if rollout_is_weights is not None:
         pg_losses = pg_losses * rollout_is_weights

@@ -1312,11 +1312,11 @@
     advantage = (advantages * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)
     pg_losses = -advantage * ratio

-    # Apply rollout importance sampling weights if provided
+    # Apply rollout correction weights if provided
     # For geo_mean, IS weights are 2D (batch_size, seq_length) and need to be aggregated to sequence level
     if rollout_is_weights is not None:
         # Aggregate token-level weights to sequence level using geometric mean for consistency
-        # Note: rollout_is_weights is always 2D regardless of rollout_is_level
+        # Note: rollout_is_weights is always 2D regardless of aggregation mode
         seq_is_weights = torch.exp(
             (torch.log(rollout_is_weights + 1e-10) * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)
         )
@@ -1532,3 +1532,215 @@ def compute_weights(scores: torch.Tensor, reweight_method: str, weight_pow: floa
     resampled_data.meta_info = resampled_meta_info

     return resampled_data
+
+
+def compute_policy_loss_with_rollout_correction(
+    rollout_log_prob,
+    log_prob,
+    advantages,
+    eos_mask,
+    loss_agg_mode="seq-mean-token-sum",
+    loss_scale_factor=1.0,
+    rollout_is: Optional[str] = None,
+    rollout_is_threshold: float = 2.0,
+    rollout_rs: Optional[str] = None,
+    rollout_rs_threshold: Optional[float] = None,
+    rollout_rs_threshold_lower: Optional[float] = None,
+    rollout_token_veto_threshold: Optional[float] = None,
+):
+    """Compute policy loss with pure rollout correction (no PPO clipping).
+
+    This function implements policy gradient with importance sampling correction
+    for rollout-training policy mismatch, without PPO's clipping mechanism.
+
+    Mathematical formulation:
+        Without IS (rollout_is=None):
+            L = -E[log π(a|s) * A(s,a)]
+            Gradient: ∇_θ L = -E[∇log π(a|s) * A]  (standard REINFORCE)
+
+        With IS (rollout_is enabled):
+            L = -E_π_rollout[w * log π(a|s) * A(s,a)]
+            where w = π_current / π_rollout  (truncated IS weight)
+            Gradient: ∇_θ L = -E[w * ∇log π(a|s) * A]  (IS-corrected policy gradient)
+
+    Args:
+        rollout_log_prob: Log probabilities from rollout policy (e.g., vLLM BF16).
+            Shape: (batch_size, seq_length)
+        log_prob: Log probabilities from current training policy.
+            Shape: (batch_size, seq_length)
+        advantages: Advantage estimates for each token.
+            Shape: (batch_size, seq_length)
+        eos_mask: Mask indicating valid tokens (1 for valid, 0 for padding).
+            Shape: (batch_size, seq_length)
+        loss_agg_mode: Loss aggregation strategy (see agg_loss for details).
+        loss_scale_factor: Multiplicative scaling factor applied to final loss.
+        rollout_is: IS aggregation level ("token", "sequence", or None).
+        rollout_is_threshold: Upper threshold for truncating IS weights.
+        rollout_rs: Rejection sampling aggregation level (or None to disable).
+        rollout_rs_threshold: Upper threshold for rejection sampling.
+        rollout_rs_threshold_lower: Lower threshold for rejection sampling.
+        rollout_token_veto_threshold: Per-token veto threshold for catastrophic outliers.
+
+    Returns:
+        Tuple of (loss, clip_fraction, kl_divergence, clip_fraction_lower):
+            - loss: Policy gradient loss with IS correction
+            - clip_fraction: Always 0.0 (no clipping in this mode)
+            - kl_divergence: KL between current and rollout policy
+            - clip_fraction_lower: Always 0.0 (no clipping in this mode)
+        Note: Rollout correction metrics are computed internally but not returned.
+            Caller should compute them separately if needed.
+
+    Note:
+        Unlike compute_policy_loss (PPO), this function:
+        - Does NOT use PPO clipping (no old_log_prob needed)
+        - Directly applies IS correction computed from current vs rollout
+        - Computes IS/RS on-the-fly during training
+
+    Usage:
+        This function is called by the actor when:
+        - bypass_old_logprob_for_rollout=True (trainer uses rollout_log_prob as old_log_prob)
+        - use_pure_rollout_correction=True (actor uses this function instead of compute_policy_loss)
+
+    Example config:
+        algorithm:
+          rollout_correction:
+            bypass_old_logprob_for_rollout: true
+            use_pure_rollout_correction: true
+            rollout_is: "token"
+            rollout_is_threshold: 2.0
+            rollout_rs: "token"
+            rollout_rs_threshold: 2.0
+            rollout_rs_threshold_lower: 0.5
+
+    Performance:
+        - Memory: Saves ~1MB per batch (no old_log_prob storage)
+        - Speed: ~15-20% faster (skips actor.compute_log_prob())
+        - Variance: Higher than PPO (no clipping safety net)
+    """
+    # Import rollout correction helper
+    from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_rejection_mask
+
+    # Compute IS weights and rejection mask on-the-fly
+    rollout_is_weights_proto, modified_response_mask, rollout_metrics = compute_rollout_correction_and_rejection_mask(
+        old_log_prob=log_prob,  # Current policy
+        rollout_log_prob=rollout_log_prob,  # Rollout policy
+        response_mask=eos_mask,
+        rollout_is=rollout_is,
+        rollout_is_threshold=rollout_is_threshold,
+        rollout_rs=rollout_rs,
+        rollout_rs_threshold=rollout_rs_threshold,
+        rollout_rs_threshold_lower=rollout_rs_threshold_lower,
+        rollout_token_veto_threshold=rollout_token_veto_threshold,
+    )
+
+    # Extract weights tensor from DataProto (or None if disabled)
+    rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"] if rollout_is_weights_proto else None
+
+    # Apply rejection mask (if RS is enabled)
+    effective_mask = modified_response_mask if rollout_rs is not None else eos_mask
+
+    # Compute pure policy gradient loss with IS correction
+    # Standard REINFORCE: L = -E[log π(a|s) * A]
+    # With IS: L = -E[w * log π(a|s) * A] where w = π_current / π_rollout
+    #
+    # Note: rollout_is_weights already contains w = π_current / π_rollout
+    # So we apply it to the standard log-prob trick formula
+
+    if rollout_is_weights is not None:
+        # With IS correction: weight the log-prob trick by the IS weight
+        # w = exp(log_prob - rollout_log_prob).clamp(max=threshold)
+        # L = -E[w * log π * A]
+        # Gradient: ∇L = -E[w * ∇log π * A]
+        pg_losses = -advantages * log_prob * rollout_is_weights
+    else:
+        # No IS correction: standard REINFORCE with the log-prob trick
+        # L = -E[log π(a|s) * A]
+        # Gradient: ∇L = -E[∇log π * A]
+        pg_losses = -advantages * log_prob
+
+    # Aggregate loss (apply scale factor manually)
+    pg_loss = (
+        agg_loss(
+            loss_mat=pg_losses,
+            loss_mask=effective_mask,
+            loss_agg_mode=loss_agg_mode,
+        )
+        * loss_scale_factor
+    )
+
+    # Compute KL divergence between current and rollout policy
+    negative_approx_kl = log_prob - rollout_log_prob
+    kl_divergence = verl_F.masked_mean(-negative_approx_kl, effective_mask)
+
+    # No clipping in pure rollout correction mode
+    clip_fraction = torch.tensor(0.0)
+
+    # Return tuple matching compute_policy_loss signature: (loss, clip_fraction, kl, clip_fraction_lower)
+    # Note: Algorithm metrics (rollout_metrics) should be handled separately by caller
+    return pg_loss, clip_fraction, kl_divergence, clip_fraction
+
+
+@register_policy_loss("rollout_correction")
+def compute_policy_loss_rollout_correction_wrapper(
+    old_log_prob: torch.Tensor,
+    log_prob: torch.Tensor,
+    advantages: torch.Tensor,
+    response_mask: torch.Tensor,
+    loss_agg_mode: str = "token-mean",
+    config: Optional[DictConfig | AlgoConfig] = None,
+    rollout_is_weights: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Wrapper for compute_policy_loss_with_rollout_correction to match the PolicyLossFn interface.
+
+    This function is used when algorithm.rollout_correction.use_pure_rollout_correction=True.
+    In this mode, the trainer has already set old_log_prob=rollout_log_prob (bypass mode).
+
+    Args:
+        old_log_prob: In bypass mode, this is actually rollout_log_prob
+        log_prob: Current policy log probabilities
+        advantages: Advantage estimates
+        response_mask: Valid token mask
+        loss_agg_mode: Loss aggregation mode
+        config: Actor config containing rollout_correction settings
+        rollout_is_weights: Pre-computed IS weights (ignored, computed internally)
+
+    Returns:
+        Tuple of (loss, clip_fraction, kl, clip_fraction_lower)
+    """
+    assert config is not None, "config is required for rollout_correction loss mode"
+
+    # Extract rollout_correction config
+    # In ray_trainer, when use_pure_rollout_correction=True, the rollout_correction config
+    # is embedded in the actor config's policy_loss field
+    rollout_corr_config = config.policy_loss.get("rollout_correction", None) if hasattr(config, "policy_loss") else None
+
+    if rollout_corr_config is None:
+        raise ValueError(
+            "rollout_correction config not found in policy_loss. "
+            "When using loss_mode='rollout_correction', ensure rollout_correction config is passed."
+        )
+
+    # Extract parameters
+    rollout_is = rollout_corr_config.get("rollout_is", None)
+    rollout_is_threshold = rollout_corr_config.get("rollout_is_threshold", 2.0)
+    rollout_rs = rollout_corr_config.get("rollout_rs", None)
+    rollout_rs_threshold = rollout_corr_config.get("rollout_rs_threshold", None)
+    rollout_rs_threshold_lower = rollout_corr_config.get("rollout_rs_threshold_lower", None)
+    rollout_token_veto_threshold = rollout_corr_config.get("rollout_token_veto_threshold", None)
+
+    # Call the actual implementation
+    # In bypass mode, old_log_prob IS rollout_log_prob
+    return compute_policy_loss_with_rollout_correction(
+        rollout_log_prob=old_log_prob,  # This is rollout_log_prob in bypass mode
+        log_prob=log_prob,
+        advantages=advantages,
+        eos_mask=response_mask,
+        loss_agg_mode=loss_agg_mode,
+        loss_scale_factor=1.0,
+        rollout_is=rollout_is,
+        rollout_is_threshold=rollout_is_threshold,
+        rollout_rs=rollout_rs,
+        rollout_rs_threshold=rollout_rs_threshold,
+        rollout_rs_threshold_lower=rollout_rs_threshold_lower,
+        rollout_token_veto_threshold=rollout_token_veto_threshold,
+    )
diff --git a/verl/trainer/ppo/mismatch_helper.py b/verl/trainer/ppo/mismatch_helper.py
deleted file mode 100644
index 7b14383bd94..00000000000
--- a/verl/trainer/ppo/mismatch_helper.py
+++ /dev/null
@@ -1,488 +0,0 @@
-# Copyright 2025 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
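Stepping back from the diff for a moment: the IS-corrected REINFORCE objective that `compute_policy_loss_with_rollout_correction` implements above can be sketched in a few lines of standalone Python. This is a hypothetical, dependency-free mimic of the token-level path only (the name `pure_is_pg_loss` and the list-based inputs are inventions of this sketch, not verl code):

```python
import math

def pure_is_pg_loss(log_prob, rollout_log_prob, advantages, mask,
                    threshold=2.0, safety_bound=20.0):
    """Token-level truncated-IS REINFORCE loss with token-mean aggregation (sketch)."""
    total, n_valid = 0.0, 0
    for lp, rlp, adv, m in zip(log_prob, rollout_log_prob, advantages, mask):
        if not m:
            continue  # skip padding / rejected positions
        # Log-space ratio with a safety bound to avoid overflow in exp()
        log_ratio = max(-safety_bound, min(safety_bound, lp - rlp))
        # Truncated IS weight: w = min(pi_train / pi_rollout, threshold)
        w = min(math.exp(log_ratio), threshold)
        # REINFORCE term weighted by w: -w * log(pi) * A
        total += -w * lp * adv
        n_valid += 1
    return total / max(n_valid, 1)

# On-policy token (ratio = 1): contribution is -1 * (-1.0) * 1.0 = 1.0
print(pure_is_pg_loss([-1.0], [-1.0], [1.0], [1]))  # 1.0
```

In the actual PR the analogous computation is vectorized over `(batch_size, seq_length)` tensors and additionally returns rejection masks and metrics; the sketch only illustrates the safety bound, weight truncation, and masked aggregation.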
-""" -Rollout Importance Sampling (IS) Helper Module - -This module handles importance sampling weight computation for correcting -distribution mismatch between rollout policy (e.g., vLLM BFloat16) and -training policy (e.g., FSDP FP32). - -Key Features: -1. Three aggregation levels: token, sequence, geometric -2. Two handling modes: truncate, mask -3. Per-token veto mechanism for catastrophic outliers -4. Memory-efficient computation to prevent CUDA OOM -5. Comprehensive metrics tracking - -Usage Notes: -- compute_rollout_importance_weights() computes both IS weights and mismatch metrics -- Used in ray_trainer.py via compute_rollout_importance_weights_and_add_to_batch() -- Also used in dp_actor.py for distributed worker computations -- compute_mismatch_metrics() is called internally by compute_rollout_importance_weights() - -References: -- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda -- Off-policy RL: https://fengyao.notion.site/off-policy-rl -""" - -from typing import Any, Optional - -import torch - -import verl.utils.torch_functional as verl_F -from verl.protocol import DataProto - - -def compute_rollout_importance_weights( - old_log_prob: torch.Tensor, - rollout_log_prob: torch.Tensor, - response_mask: torch.Tensor, - rollout_is_level: str = "token", - rollout_is_mode: str = "truncate", - rollout_is_threshold: Optional[float] = None, - rollout_is_threshold_lower: Optional[float] = None, - rollout_is_veto_threshold: Optional[float] = None, -) -> tuple[Optional[DataProto], torch.Tensor, dict[str, Any]]: - """Compute importance sampling weights and rejection mask for rollout-training mismatch. - - This function computes IS weights to correct for distribution mismatch between rollout - and training policies, and applies rejection sampling for outliers. 
-
-    Key Design: Separation of IS Weights and Rejection Sampling
-
-    IS weights (rollout_is_weights): Ratios π_train/π_rollout with processing applied:
-        * Safety-bounded to prevent overflow:
-          - Token level: exp(clamp(log_ratio, -20, 20)) per token
-          - Sequence level: exp(clamp(sum(log_ratio), -20, 20)) broadcast to all tokens
-          - Geometric level: exp(clamp(mean(log_ratio), -20, 20)) broadcast to all tokens
-        * Truncate mode: upper clamped via .clamp(max=upper_threshold)
-        * Mask mode: safety-bounded ratios preserved (no threshold clamping)
-        * All modes: zeroed at padding positions
-        Used for policy gradient calculations
-
-    Response mask (modified_response_mask): Has rejection applied (mask mode + veto)
-        Used for loss aggregation to exclude rejected samples from training
-
-    Reference:
-        When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
-
-    Memory-efficient implementation:
-    - Log-space computation to prevent overflow
-    - Safety bounds (exp(±20)) on all exponentiations
-    - Metrics computed without large intermediate tensors
-
-    Args:
-        old_log_prob: Log probs from training policy (FSDP FP32), shape (batch_size, seq_length)
-        rollout_log_prob: Log probs from rollout policy (vLLM BF16), shape (batch_size, seq_length)
-        response_mask: Valid token mask (1=valid, 0=padding), shape (batch_size, seq_length)
-        rollout_is_level: IS weight aggregation level
-            - "token": Per-token ratios ρ_t = π_train(t)/π_rollout(t) (biased but low variance)
-            - "sequence": Sequence product ρ_seq = ∏ρ_t (unbiased but high variance)
-            - "geometric": Geometric mean ρ_geo = (∏ρ_t)^(1/T) (experimental trade-off)
-        rollout_is_mode: Treatment of outlier IS weights
-            - "truncate": Clamp weights at upper threshold only. No rejection for outlier ratios,
-              but veto can still apply (TIS)
-            - "mask": Reject tokens/sequences outside [lower, upper] via response_mask (MIS/rejection sampling)
-        rollout_is_threshold: Upper threshold for IS weights (required, e.g., 2.0)
-        rollout_is_threshold_lower: Lower threshold for mask mode (if None, defaults to 1/upper)
-        rollout_is_veto_threshold: Catastrophic token threshold. If any token has ratio < this,
-            reject entire sequence. Applied independently of rollout_is_mode. If None, veto disabled. Default None.
-
-    Returns:
-        Tuple of (weights_proto, modified_response_mask, metrics):
-            weights_proto: DataProto with processed IS weights, key "rollout_is_weights",
-                shape (batch_size, seq_length). Processing applied:
-                - Safety-bounded to [exp(-20), exp(20)] ≈ [2e-9, 5e8]:
-                  * Token level: bounds per-token ratios
-                  * Sequence/geometric level: bounds aggregated ratio (broadcast to all tokens)
-                - Truncate mode: upper clamped via .clamp(max=upper_threshold)
-                - Mask mode: safety-bounded ratios preserved (no threshold clamping)
-                - All modes: zeroed at padding positions (response_mask == 0)
-                None if rollout_is_threshold is None.
-            modified_response_mask: Response mask with rejection applied:
-                - truncate mode: unchanged for outlier ratios, but veto rejection still applied
-                - mask mode: tokens outside [lower, upper] masked to 0
-                - veto: sequences with catastrophic tokens masked to 0 (applied in both modes)
-                Shape (batch_size, seq_length).
-            metrics: Dict of IS and mismatch metrics, all scalars with "mismatch/" prefix
-    """
-    if rollout_is_threshold is None:
-        return None, response_mask, {}
-
-    # Parse thresholds: if lower not specified, use 1/upper (reciprocal)
-    upper_threshold = rollout_is_threshold
-    if rollout_is_threshold_lower is not None:
-        lower_threshold = rollout_is_threshold_lower
-    else:
-        # Default: lower = 1/upper (reciprocal)
-        lower_threshold = 1.0 / upper_threshold
-
-    # Step 1: Compute raw importance weights based on the specified level
-    log_ratio = old_log_prob - rollout_log_prob
-
-    # Pre-compute log thresholds
-    device = old_log_prob.device
-    log_threshold_upper = torch.log(torch.tensor(upper_threshold, device=device))
-    log_threshold_lower = torch.log(torch.tensor(lower_threshold, device=device))
-
-    # Safety bound to prevent numerical overflow (exp(20) ≈ 485M)
-    SAFETY_BOUND = 20.0
-
-    # Store unclamped values in log-space for accurate metrics
-    if rollout_is_level == "token":
-        # Token-level IS: π_train(a|s) / π_rollout(a|s) per token
-        log_ratio_for_metrics = log_ratio
-
-        # Apply safety bound to prevent overflow
-        log_ratio_safe = torch.clamp(log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND)
-        rollout_is_weights = torch.exp(log_ratio_safe)
-
-    elif rollout_is_level == "sequence":
-        # Sequence-level IS: π_train(y|x) / π_rollout(y|x) for entire sequence
-        # Product of token ratios: exp(Σ log(π_train/π_rollout))
-        log_ratio_sum = verl_F.masked_sum(log_ratio, response_mask, axis=-1).unsqueeze(-1)
-        log_ratio_for_metrics = log_ratio_sum  # Store for metrics
-
-        # Apply safety bound to prevent overflow
-        log_ratio_sum_safe = torch.clamp(log_ratio_sum, min=-SAFETY_BOUND, max=SAFETY_BOUND)
-        rollout_is_weights = torch.exp(log_ratio_sum_safe).expand_as(old_log_prob)
-
-    elif rollout_is_level == "geometric":
-        # Geometric mean IS: (∏ π_train/π_rollout)^(1/T)
-        # Equivalent to exp(mean(log(π_train/π_rollout)))
-        log_ratio_mean = verl_F.masked_mean(log_ratio, response_mask, axis=-1).unsqueeze(-1)
-        log_ratio_for_metrics = log_ratio_mean  # Store for metrics
-
-        # Geometric mean rarely explodes due to averaging, but apply safety bound anyway
-        log_ratio_mean_safe = torch.clamp(log_ratio_mean, min=-SAFETY_BOUND, max=SAFETY_BOUND)
-        rollout_is_weights = torch.exp(log_ratio_mean_safe).expand_as(old_log_prob)
-
-    else:
-        raise ValueError(f"Invalid rollout_is_level: {rollout_is_level}. Must be 'token', 'sequence', or 'geometric'.")
-
-    # Step 1.5: Apply per-token veto check in log space (memory efficient)
-    if rollout_is_veto_threshold is not None:
-        log_veto_threshold = torch.log(torch.tensor(rollout_is_veto_threshold, device=device))
-
-        # Check if any token ratio is below veto threshold (in log space)
-        # log(π_train/π_rollout) < log(veto_threshold) ⟺ π_train/π_rollout < veto_threshold
-        catastrophic_tokens = (log_ratio < log_veto_threshold) & response_mask.bool()
-
-        # For each sequence, check if it has any catastrophic token
-        # Use broadcasting instead of expand_as to save memory
-        has_catastrophic = catastrophic_tokens.any(dim=-1, keepdim=True)
-
-        # Create veto mask: 0 if sequence has catastrophic token, 1 otherwise
-        veto_mask = (~has_catastrophic).float()
-    else:
-        # No veto mechanism
-        catastrophic_tokens = torch.zeros_like(response_mask, dtype=torch.bool)
-        has_catastrophic = torch.zeros((old_log_prob.size(0), 1), dtype=torch.bool, device=device)
-        veto_mask = torch.ones((old_log_prob.size(0), 1), dtype=torch.float32, device=device)
-
-    # Step 2: Compute comprehensive metrics
-    metrics = compute_is_metrics(
-        rollout_is_weights=rollout_is_weights,
-        log_ratio_for_metrics=log_ratio_for_metrics,
-        response_mask=response_mask,
-        rollout_is_level=rollout_is_level,
-        rollout_is_threshold=upper_threshold,
-        rollout_is_threshold_lower=lower_threshold,
-        log_threshold_upper=log_threshold_upper,
-        log_threshold_lower=log_threshold_lower,
-        has_catastrophic=has_catastrophic,
-        catastrophic_tokens=catastrophic_tokens,
-        SAFETY_BOUND=SAFETY_BOUND,
-    )
-
-    # Step 3: Apply outlier handling and rejection sampling
-    # Key design principle: IS weights and rejection are separate mechanisms
-    # - rollout_is_weights: IS weight ratios with mode-specific processing
-    #   * Truncate mode: upper clamped to prevent extreme values
-    #   * Mask mode: safety-bounded ratios preserved (no threshold clamping, rejection via mask)
-    #   Used for policy gradient calculations
-    # - modified_response_mask: Has rejection applied (excludes outliers from training)
-    #   Used for loss denominator: ensures rejected samples don't dilute gradients
-
-    if rollout_is_mode == "truncate":
-        # Truncated IS (TIS): clamp weights to prevent extreme importance ratios
-        # Weights are modified by clamping; no rejection via mask for outlier ratios
-        # Veto rejection (if enabled) will still be applied to modified_response_mask below
-        rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold)
-        modified_response_mask = response_mask  # Unchanged for outlier ratios (veto applied later)
-
-    elif rollout_is_mode == "mask":
-        # Masked IS (MIS): rejection sampling for outlier IS weights
-        # Reject tokens/sequences with IS ratios outside [lower, upper] via response_mask
-        # IS weights themselves are NOT threshold-clamped (remain safety-bounded only)
-        mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold)
-        mask = mask.float()
-
-        # Compute rejection rate metrics
-        metrics["rollout_is_masked_fraction"] = verl_F.masked_mean(1 - mask, response_mask)
-        if rollout_is_level in ["sequence", "geometric"]:
-            # Sequence-level: all tokens have same weight, check first token
-            metrics["rollout_is_seq_masked_fraction"] = (1 - mask[:, 0]).mean()
-        else:
-            # Token-level: sequence rejected if ANY token is rejected
-            seq_has_masked = verl_F.masked_sum(1 - mask, response_mask, axis=-1) > 0
-            metrics["rollout_is_seq_masked_fraction"] = seq_has_masked.float().mean()
-
-        # Apply rejection via response_mask (NOT by clamping IS weights)
-        modified_response_mask = response_mask * mask
-        # rollout_is_weights kept as safety-bounded ratios (no threshold clamping)
-
-    else:
-        raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'mask'.")
-
-    # Apply veto: reject entire sequences with catastrophic tokens (ratio < veto_threshold)
-    # Veto is independent of mode - it applies to modified_response_mask after mode-specific handling
-    modified_response_mask = modified_response_mask * veto_mask
-    # Note: rollout_is_weights unaffected by veto (already clamped in truncate mode, or kept as-is in mask mode)
-
-    # Zero out padding positions in IS weights for correct aggregation
-    # This is different from rejection - padding must be zeroed regardless of mode
-    rollout_is_weights = rollout_is_weights * response_mask
-
-    # Wrap in DataProto for consistency with worker methods
-    rollout_is_weights_proto = DataProto.from_dict(tensors={"rollout_is_weights": rollout_is_weights})
-
-    # Compute mismatch metrics (KL, PPL, etc.) and merge with IS metrics
-    mismatch_metrics = compute_mismatch_metrics(
-        old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask
-    )
-    metrics.update(mismatch_metrics)
-
-    # Convert all tensor metrics to scalars for logging
-    # Note: No need to detach since old_log_prob and rollout_log_prob are computed with torch.no_grad()
-    metrics_scalar = {}
-    for key, value in metrics.items():
-        if isinstance(value, torch.Tensor):
-            metrics_scalar[f"mismatch/{key}"] = value.item()
-        else:
-            metrics_scalar[f"mismatch/{key}"] = value
-
-    return rollout_is_weights_proto, modified_response_mask, metrics_scalar
-
-
-def compute_is_metrics(
-    rollout_is_weights: torch.Tensor,
-    log_ratio_for_metrics: torch.Tensor,
-    response_mask: torch.Tensor,
-    rollout_is_level: str,
-    rollout_is_threshold: float,
-    rollout_is_threshold_lower: float,
-    log_threshold_upper: torch.Tensor,
-    log_threshold_lower: torch.Tensor,
-    has_catastrophic: torch.Tensor,
-    catastrophic_tokens: torch.Tensor,
-    SAFETY_BOUND: float,
-) -> dict[str, Any]:
-    """Compute comprehensive metrics for importance sampling weights.
-
-    Reference:
-        When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
-
-    This function computes metrics using a mix of true unclamped values (for max/min/fractions
-    in sequence/geometric mode via log-space) and safety-clamped values (for mean/std/ESS)
-    to balance accuracy with numerical stability and avoid overflow.
-    """
-    # Validate that we have at least one valid sample
-    assert response_mask.any(), "Expected at least one valid sample in response_mask"
-
-    metrics = {}
-    device = rollout_is_weights.device
-
-    # Track veto statistics
-    metrics["rollout_is_veto_fraction"] = has_catastrophic.float().mean()
-    metrics["rollout_is_catastrophic_token_fraction"] = verl_F.masked_mean(catastrophic_tokens.float(), response_mask)
-
-    # Compute metrics based on IS level
-    if rollout_is_level in ["sequence", "geometric"]:
-        # For sequence/geometric, compute true statistics from log-space
-        # This reflects the actual distribution before clamping
-
-        # True max/min in log space
-        log_max = log_ratio_for_metrics.max()
-        log_min = log_ratio_for_metrics.min()
-
-        # Convert to regular space with safety bound
-        metrics["rollout_is_max"] = torch.exp(torch.clamp(log_max, max=SAFETY_BOUND))
-        metrics["rollout_is_min"] = torch.exp(log_min)
-
-        # Mean uses clamped weights to avoid overflow
-        metrics["rollout_is_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask)
-
-        # Compute fraction exceeding threshold in log space (accurate)
-        exceeds_upper = log_ratio_for_metrics > log_threshold_upper
-        below_lower = log_ratio_for_metrics < log_threshold_lower
-
-        if rollout_is_level == "sequence":
-            # For sequence level, all tokens in a sequence have the same weight
-            metrics["rollout_is_ratio_fraction_high"] = exceeds_upper.float().mean()
-            metrics["rollout_is_ratio_fraction_low"] = below_lower.float().mean()
-        else:  # geometric
-            # Need to expand to match token dimensions
-            exceeds_upper_expanded = exceeds_upper.expand_as(response_mask)
-            below_lower_expanded = below_lower.expand_as(response_mask)
-            metrics["rollout_is_ratio_fraction_high"] = verl_F.masked_mean(
-                exceeds_upper_expanded.float(), response_mask
-            )
-            metrics["rollout_is_ratio_fraction_low"] = verl_F.masked_mean(below_lower_expanded.float(), response_mask)
-
-    else:
-        # Token-level: compute directly from weights
-        metrics["rollout_is_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask)
-
-        # Fraction exceeding thresholds
-        rollout_is_above_threshold = rollout_is_weights > rollout_is_threshold
-        rollout_is_below_threshold = rollout_is_weights < rollout_is_threshold_lower
-        metrics["rollout_is_ratio_fraction_high"] = verl_F.masked_mean(
-            rollout_is_above_threshold.float(), response_mask
-        )
-        metrics["rollout_is_ratio_fraction_low"] = verl_F.masked_mean(rollout_is_below_threshold.float(), response_mask)
-
-        # Max/min for token level
-        mask_bool = response_mask.bool()
-        metrics["rollout_is_max"] = rollout_is_weights.masked_fill(~mask_bool, float("-inf")).max()
-        metrics["rollout_is_min"] = rollout_is_weights.masked_fill(~mask_bool, float("inf")).min()
-
-    # Compute standard deviation using clamped weights to avoid overflow
-    mask_count = response_mask.sum()
-    if mask_count > 1:
-        # Use clamped weights for variance to avoid squaring huge values
-        weights_for_std = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)
-        # Use mean from clamped weights for consistency
-        mean_clamped = verl_F.masked_mean(weights_for_std, response_mask)
-        rollout_is_var = verl_F.masked_mean(weights_for_std.square(), response_mask) - mean_clamped.square()
-        metrics["rollout_is_std"] = torch.sqrt(torch.clamp(rollout_is_var, min=0.0))
-    else:
-        metrics["rollout_is_std"] = torch.tensor(0.0, device=device)
-
-    # Effective sample size (use clamped weights to avoid overflow)
-    weights_for_ess = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)
-    mean_for_ess = verl_F.masked_mean(weights_for_ess, response_mask)
-    is_weights_normalized = weights_for_ess / (mean_for_ess + 1e-8)
-    metrics["rollout_is_eff_sample_size"] = 1.0 / verl_F.masked_mean(is_weights_normalized.square(), response_mask)
-
-    # Per-sequence breakdown metrics
-    if rollout_is_weights.dim() > 1:
-        # Compute mean IS weight per sequence
-        seq_mean_weights = verl_F.masked_mean(rollout_is_weights, response_mask, axis=-1)
-
-        # Per-sequence statistics
-        metrics["rollout_is_seq_mean"] = seq_mean_weights.mean()
-        metrics["rollout_is_seq_std"] = (
-            seq_mean_weights.std() if seq_mean_weights.numel() > 1 else torch.tensor(0.0, device=device)
-        )
-        metrics["rollout_is_seq_max"] = seq_mean_weights.max()
-        metrics["rollout_is_seq_min"] = seq_mean_weights.min()
-
-        # Identify most problematic sequences
-        seq_deviation = (seq_mean_weights - 1.0).abs()
-        metrics["rollout_is_seq_max_deviation"] = seq_deviation.max()
-
-        # Fraction of sequences with high IS weights
-        metrics["rollout_is_seq_fraction_high"] = (seq_mean_weights > rollout_is_threshold).float().mean()
-        metrics["rollout_is_seq_fraction_low"] = (seq_mean_weights < rollout_is_threshold_lower).float().mean()
-
-    return metrics
-
-
-def compute_mismatch_metrics(
-    old_log_prob: torch.Tensor,
-    rollout_log_prob: Optional[torch.Tensor],
-    response_mask: torch.Tensor,
-) -> dict[str, Any]:
-    """Compute training-inference mismatch metrics (helper function).
-
-    This helper function operates on raw tensors and is used internally by:
-    - compute_rollout_importance_weights() in this module (automatically included)
-    - Tests (test_rollout_is.py, test_rollout_is_integration.py)
-
-    These metrics help diagnose the mismatch between the rollout policy (e.g., vLLM)
-    and the training policy (e.g., FSDP), which can cause training instability.
-
-    Key metrics:
-    - mismatch_kl: Direct KL divergence estimator KL(π_rollout || π_training)
-    - mismatch_k3_kl: K3 KL estimator for stability (more stable for small KL)
-    - training_ppl: Perplexity of training policy
-    - rollout_ppl: Perplexity of rollout policy
-    - log_ppl_diff: Difference in log perplexities
-    - ppl_ratio: Ratio of training PPL to rollout PPL
-
-    Args:
-        old_log_prob: Log probabilities from training policy, shape (batch_size, seq_length)
-        rollout_log_prob: Log probabilities from rollout policy, shape (batch_size, seq_length)
-        response_mask: Mask for valid tokens, shape (batch_size, seq_length)
-
-    Returns:
-        Dictionary of mismatch metrics (without prefix)
-
-    Reference:
-    - When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
-    """
-    # Validate that we have at least one valid token
-    assert response_mask.any(), "Expected at least one valid token in response_mask"
-
-    metrics = {}
-
-    # 1. Training policy perplexity (always available)
-    # Formula: exp(-1/|T| * Σ log π_training(y_t|y_<t))
0 and "rollout_log_probs" in batch.batch:
-            # Compute IS weights and get modified response_mask
-            rollout_is_weights, modified_response_mask, rollout_is_metrics = compute_rollout_importance_weights(
-                old_log_prob=batch.batch["old_log_probs"],
-                rollout_log_prob=batch.batch["rollout_log_probs"],
-                response_mask=batch.batch["response_mask"],
-                rollout_is_level=self.config.algorithm.rollout_is_level,
-                rollout_is_mode=self.config.algorithm.rollout_is_mode,
-                rollout_is_threshold=self.config.algorithm.rollout_is_threshold,
-                rollout_is_threshold_lower=self.config.algorithm.get("rollout_is_threshold_lower", None),
-                rollout_is_veto_threshold=self.config.algorithm.get("rollout_is_veto_threshold", None),
-            )
-
-            # ALWAYS update response_mask with rejection (even if rollout_is=False)
-            # - Mask mode: tokens with outlier IS ratios excluded
-            # - Veto: sequences with catastrophic tokens excluded
-            # This ensures correct loss normalization (rejected samples not in denominator)
-            batch.batch["response_mask"] = modified_response_mask
-
-            # Conditionally add IS weights based on rollout_is config flag
-            # - rollout_is=True: Enable IS weight correction in policy loss
-            # - rollout_is=False: Metrics-only mode (rejection still applied via mask)
-            apply_weights = self.config.algorithm.get("rollout_is", False)
-
-            if apply_weights:
-                # Add IS weights (safety-bounded, mode-processed) to enable weight correction
-                batch = batch.union(rollout_is_weights)
-
-            return batch, rollout_is_metrics
-
-        # Return unchanged batch and empty metrics if IS is disabled
-        return batch, {}
-
     def fit(self):
         """
         The training loop of PPO.
@@ -1161,23 +1101,38 @@ def fit(self):
                 else:
                     reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)

-                # recompute old_log_probs
-                with marked_timer("old_log_prob", timing_raw, color="blue"):
-                    old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
-                    entropys = old_log_prob.batch["entropys"]
-                    response_masks = batch.batch["response_mask"]
-                    loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
-                    entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
-                    old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
-                    metrics.update(old_log_prob_metrics)
-                    old_log_prob.batch.pop("entropys")
-                    batch = batch.union(old_log_prob)
+                from verl.trainer.ppo.rollout_corr_helper import (
+                    compute_rollout_correction_and_add_to_batch,
+                    maybe_apply_rollout_correction,
+                )
+
+                rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
+                need_recomputation = maybe_apply_rollout_correction(
+                    batch=batch,
+                    rollout_corr_config=rollout_corr_config,
+                    policy_loss_config=self.config.actor_rollout_ref.actor.policy_loss,
+                )
+                if need_recomputation:
+                    # LEGACY MODE: Compute old_log_probs from actor
+                    with marked_timer("old_log_prob", timing_raw, color="blue"):
+                        old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
+                        entropys = old_log_prob.batch["entropys"]
+                        response_masks = batch.batch["response_mask"]
+                        loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
+                        entropy_agg = agg_loss(
+                            loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode
+                        )
+                        old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
+                        metrics.update(old_log_prob_metrics)
+                        old_log_prob.batch.pop("entropys")
+                        batch = batch.union(old_log_prob)
+                        if "rollout_log_probs" in batch.batch.keys():
+                            # TODO: we may want to add diff of probs too.
+                            from verl.utils.debug.metrics import calculate_debug_metrics

-                if "rollout_log_probs" in batch.batch.keys():
-                    # TODO: we may want to add diff of probs too.
-                    from verl.utils.debug.metrics import calculate_debug_metrics
-                    metrics.update(calculate_debug_metrics(batch))
+                            metrics.update(calculate_debug_metrics(batch))
+                assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}'

                 if self.use_reference_policy:
                     # compute reference log_prob
@@ -1213,12 +1168,13 @@ def fit(self):
                 else:
                     batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

-                # Compute rollout importance sampling weights centrally (once per batch)
-                # This corrects for mismatch between rollout policy and training policy
-                # Also computes mismatch metrics (KL, PPL, etc.)
-                batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
-                # IS and mismatch metrics already have mismatch/ prefix
-                metrics.update(is_metrics)
+                # Compute rollout correction weights centrally (once per batch)
+                # This corrects for off-policy issues (policy mismatch, model staleness, etc.)
+                # Also computes off-policy diagnostic metrics (KL, PPL, etc.)
+                if rollout_corr_config is not None and "rollout_log_probs" in batch.batch:
+                    batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch)
+                    # IS and off-policy metrics already have rollout_corr/ prefix
+                    metrics.update(is_metrics)

                 # compute advantages, executed on the driver process
                 norm_adv_by_std_in_grpo = self.config.algorithm.get(
diff --git a/verl/trainer/ppo/rollout_corr_helper.py b/verl/trainer/ppo/rollout_corr_helper.py
new file mode 100644
index 00000000000..737246b6aef
--- /dev/null
+++ b/verl/trainer/ppo/rollout_corr_helper.py
@@ -0,0 +1,879 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Rollout Correction Helper Module
+
+This module provides a complete pipeline to address **off-policy issues** in RL training,
+including:
+1. Policy mismatch between rollout and training implementations (e.g., vLLM BFloat16 vs FSDP FP32)
+2. Model update staleness (training on trajectories from older checkpoints)
+3. General distribution shifts between data collection and training
+
+Its core capabilities include computing importance sampling (IS) weights,
+filtering outlier samples via rejection sampling (RS), and
+tracking metrics to diagnose and correct off-policy issues.
+
+## Core Capabilities
+1. **Multi-Granularity Aggregation**:
+   - Importance Sampling (IS): token-level and sequence-level.
+   - Rejection Sampling (RS): token-level and sequence/geometric
+     (sequence-level geometric mean) — supports flexible outlier filtering.
+2. **Catastrophic Outlier Veto**:
+   Independent per-token veto mechanism — fully reject sequences containing tokens
+   with extremely low IS weights (prevents catastrophic updates).
+3. **Memory-Efficient Design**:
+   - Log-space computations to avoid numerical overflow/underflow.
+   - Fixed safety bounds (exp(±20)) for stable exponentiation.
+   - Metrics calculated without large intermediate tensors (prevents CUDA OOM).
+4. **Comprehensive Metrics Tracking**:
+   - IS/RS statistics (mean/max/min, effective sample size ESS, rejection rate).
+   - Off-policy diagnostics (KL divergence, perplexity PPL, log PPL difference, χ² divergence).
+   - Sequence-level breakdowns (deviation from ideal weights, outlier fraction).
+
+
+## Key Interfaces & Usage
+- compute_rollout_correction_and_rejection_mask(): compute IS weights + rejection mask + veto.
+- compute_rollout_correction_weights(): only compute truncated IS weights (for variance
+  reduction, no outlier rejection).
+- compute_rollout_rejection_mask(): only filter outliers (for sample cleaning, no IS weight
+  computation).
+- compute_offpolicy_metrics(): called by core functions to calculate off-policy diagnostics
+  (KL/PPL/χ²) — no direct external calls needed.
+
+### Integration Notes
+- Used in `ray_trainer.py` via `compute_rollout_correction_and_add_to_batch()` (batch training pipeline).
+- Used in `dp_actor.py` for distributed worker computations (distributed training scenarios).
+- All functions support batch inputs and valid token masking (via `response_mask`).
+
+
+## References
+- "When Speed Kills Stability" (LLM training stability analysis): https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
+- Off-policy RL (theoretical basis for IS): https://fengyao.notion.site/off-policy-rl
+"""
+
+from typing import Any, Optional
+
+import torch
+
+import verl.utils.torch_functional as verl_F
+from verl.protocol import DataProto
+from verl.trainer.config.algorithm import RolloutCorrectionConfig
+from verl.workers.config.actor import PolicyLossConfig
+
+# Safety bound to prevent numerical overflow/underflow when exponentiating
+# exp(20) ≈ 485 million (upper limit for stable weights), exp(-20) ≈ 2e-9 (lower limit)
+SAFETY_BOUND = 20.0
+
+
+def compute_rollout_rejection_mask(
+    log_ratio: torch.Tensor,
+    response_mask: torch.Tensor,
+    rollout_rs: str = "token",
+    rollout_rs_threshold: Optional[float] = None,
+    rollout_rs_threshold_lower: Optional[float] = None,
+) -> tuple[torch.Tensor, dict[str, float]]:
+    """Compute rejection mask for outlier handling in off-policy RL training.
+
+    This function identifies and masks outlier tokens/sequences using precomputed log ratios
+    (log(π_train / π_rollout)). It supports multiple aggregation levels and uses log-space
+    computations for numerical stability.
+
+    Memory-efficient design:
+    - Log-space calculations to avoid overflow
+    - Fixed safety bounds on exponentiation
+    - Metrics computed without large intermediate tensors
+
+    Args:
+        log_ratio: Log ratio of training policy probability to rollout policy probability,
+            shape (batch_size, seq_length).
+        response_mask: Binary mask for valid tokens (1=valid, 0=padding),
+            shape (batch_size, seq_length).
+        rollout_rs: Rejection sampling aggregation level, must be one of:
+            - "token": Per-token outlier detection
+            - "sequence": Aggregate across entire sequence (product of token ratios)
+            - "geometric": Geometric mean across entire sequence
+        rollout_rs_threshold: Upper threshold for valid IS weights (required for outlier detection).
+        rollout_rs_threshold_lower: Lower threshold for valid IS weights. If None, defaults to 1/upper threshold.
+
+    Returns:
+        Tuple containing:
+            modified_response_mask: Response mask with outliers masked (0=rejected),
+                shape (batch_size, seq_length).
+            metrics: Dictionary of rejection sampling metrics (all scalars), including:
+                - rollout_rs_mean/max/min: Statistics of IS weights
+                - rollout_rs_ratio_fraction_high/low: Fraction of weights exceeding thresholds
+                - rollout_rs_masked_fraction: Fraction of tokens rejected (unified for all modes)
+                - rollout_rs_seq_masked_fraction: Fraction of sequences rejected (mode-dependent)
+    """
+    # Validate input parameters
+    valid_rs_levels = {"token", "sequence", "geometric"}
+    if rollout_rs not in valid_rs_levels:
+        raise ValueError(f"Invalid rollout_rs: {rollout_rs}. Must be one of {valid_rs_levels}.")
+    if rollout_rs_threshold is None:
+        raise ValueError("rollout_rs_threshold must be provided for rejection sampling.")
+
+    # Set default lower threshold if not specified (reciprocal of upper threshold)
+    upper_threshold = rollout_rs_threshold
+    lower_threshold = rollout_rs_threshold_lower if rollout_rs_threshold_lower is not None else 1.0 / upper_threshold
+
+    # Compute IS weights from log ratio (handles different aggregation levels)
+    if rollout_rs == "token":
+        # Per-token IS weight: exp(log(π_train/π_rollout)) with safety clamp
+        log_ratio_for_metrics: torch.Tensor = log_ratio
+        log_ratio_safe: torch.Tensor = torch.clamp(log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND)
+        rollout_is_weights: torch.Tensor = torch.exp(log_ratio_safe)
+
+    elif rollout_rs == "sequence":
+        # Sequence-level IS weight: product of token ratios (exp(sum(log ratios)))
+        log_ratio_sum: torch.Tensor = verl_F.masked_sum(log_ratio, response_mask, axis=-1).unsqueeze(
+            -1
+        )  # Shape: (batch_size, 1)
+        log_ratio_for_metrics = log_ratio_sum
+
+        log_ratio_sum_safe: torch.Tensor = torch.clamp(log_ratio_sum, min=-SAFETY_BOUND, max=SAFETY_BOUND)
+        rollout_is_weights = torch.exp(log_ratio_sum_safe).expand_as(log_ratio)  # Broadcast to (batch_size, seq_length)
+
+    elif rollout_rs == "geometric":
+        # Sequence-level geometric mean: exp(mean(log ratios))
+        log_ratio_mean: torch.Tensor = verl_F.masked_mean(log_ratio, response_mask, axis=-1).unsqueeze(
+            -1
+        )  # Shape: (batch_size, 1)
+        log_ratio_for_metrics = log_ratio_mean
+
+        log_ratio_mean_safe: torch.Tensor = torch.clamp(log_ratio_mean, min=-SAFETY_BOUND, max=SAFETY_BOUND)
+        rollout_is_weights = torch.exp(log_ratio_mean_safe).expand_as(log_ratio)
+
+    else:
+        raise ValueError(f"Unsupported rollout_rs: {rollout_rs}")
+
+    # Generate outlier mask: 1=valid (within [lower, upper] threshold), 0=outlier
+    mask: torch.Tensor = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold)
+    mask = mask.float()
+
+    # Compute rejection sampling metrics
+    metrics: dict[str, float] = compute_rs_metrics(
+        rollout_is_weights=rollout_is_weights,
+        log_ratio_for_metrics=log_ratio_for_metrics,
+        response_mask=response_mask,
+        rollout_rs=rollout_rs,
+        rollout_rs_threshold=upper_threshold,
+        rollout_rs_threshold_lower=lower_threshold,
+    )
+
+    # Track token-level and sequence-level rejection rates
+    # rollout_rs_masked_fraction: fraction of tokens rejected (unified for all modes)
+    metrics["rollout_rs_masked_fraction"] = verl_F.masked_mean(1 - mask, response_mask).item()
+
+    # rollout_rs_seq_masked_fraction: fraction of sequences rejected (mode-dependent)
+    if rollout_rs == "token":
+        # Token-level aggregation: sequence is rejected if any token is rejected
+        seq_has_masked: torch.Tensor = verl_F.masked_sum(1 - mask, response_mask, axis=-1) > 0
+        metrics["rollout_rs_seq_masked_fraction"] = seq_has_masked.float().mean().item()
+    else:
+        # Sequence-level aggregation: check first token's mask (all tokens in sequence have same mask)
+        metrics["rollout_rs_seq_masked_fraction"] = (1 - mask[:, 0]).mean().item()
+
+    # Apply rejection mask to original response mask
+    modified_response_mask: torch.Tensor = response_mask * mask
+
+    return modified_response_mask, metrics
+
+
+def compute_rs_metrics(
+    rollout_is_weights: torch.Tensor,
+    log_ratio_for_metrics: torch.Tensor,
+    response_mask: torch.Tensor,
+    rollout_rs: str,
+    rollout_rs_threshold: float,
+    rollout_rs_threshold_lower: float,
+) -> dict[str, float]:
+    """Compute comprehensive metrics for rejection sampling.
+
+    This function calculates statistics for IS weights used in rejection sampling,
+    balancing numerical stability (using clamped weights) and accuracy (using log-space
+    for threshold checks).
+
+    Args:
+        rollout_is_weights: Clamped IS weights (π_train / π_rollout),
+            shape (batch_size, seq_length).
+        log_ratio_for_metrics: Log ratio of training to rollout probabilities (unclamped),
+            shape varies by aggregation level.
+        response_mask: Binary mask for valid tokens (1=valid, 0=padding),
+            shape (batch_size, seq_length).
+        rollout_rs: Rejection sampling aggregation level (matches compute_rollout_rejection_mask).
+        rollout_rs_threshold: Upper threshold for valid IS weights.
+        rollout_rs_threshold_lower: Lower threshold for valid IS weights.
+
+    Returns:
+        Dictionary of rejection sampling metrics (all scalars).
+    """
+    if not response_mask.any():
+        raise ValueError("response_mask must contain at least one valid token (1).")
+
+    metrics: dict[str, float] = {}
+    device: torch.device = rollout_is_weights.device
+
+    # Precompute log thresholds for accurate threshold checks
+    log_threshold_upper: torch.Tensor = torch.log(torch.tensor(rollout_rs_threshold, device=device))
+    log_threshold_lower: torch.Tensor = torch.log(torch.tensor(rollout_rs_threshold_lower, device=device))
+
+    # Compute metrics based on aggregation level
+    if rollout_rs in ["sequence", "geometric"]:
+        # Sequence-level aggregation: use log-space for accurate max/min/threshold checks
+        # True max/min (unclamped) converted with safety bounds
+        log_max: torch.Tensor = log_ratio_for_metrics.max()
+        log_min: torch.Tensor = log_ratio_for_metrics.min()
+        metrics["rollout_rs_max"] = torch.exp(torch.clamp(log_max, max=SAFETY_BOUND)).item()
+        metrics["rollout_rs_min"] = torch.exp(log_min).item()
+
+        # Mean uses clamped weights to avoid overflow
+        metrics["rollout_rs_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask).item()
+
+        # Fraction of weights exceeding thresholds (log-space for accuracy)
+        # Both sequence and geometric modes operate at sequence level (batch_size, 1)
+        exceeds_upper: torch.Tensor = log_ratio_for_metrics > log_threshold_upper
+        below_lower: torch.Tensor = log_ratio_for_metrics < log_threshold_lower
+        metrics["rollout_rs_ratio_fraction_high"] = exceeds_upper.float().mean().item()
+ metrics["rollout_rs_ratio_fraction_low"] = below_lower.float().mean().item() + + else: # token-level + # Token-level aggregation: compute directly from clamped weights + metrics["rollout_rs_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask).item() + + # Fraction of tokens exceeding thresholds + rollout_is_above_threshold: torch.Tensor = rollout_is_weights > rollout_rs_threshold + rollout_is_below_threshold: torch.Tensor = rollout_is_weights < rollout_rs_threshold_lower + metrics["rollout_rs_ratio_fraction_high"] = verl_F.masked_mean( + rollout_is_above_threshold.float(), response_mask + ).item() + metrics["rollout_rs_ratio_fraction_low"] = verl_F.masked_mean( + rollout_is_below_threshold.float(), response_mask + ).item() + + # Max/min (mask out padding tokens first) + mask_bool: torch.Tensor = response_mask.bool() + metrics["rollout_rs_max"] = rollout_is_weights.masked_fill(~mask_bool, float("-inf")).max().item() + metrics["rollout_rs_min"] = rollout_is_weights.masked_fill(~mask_bool, float("inf")).min().item() + + # Compute standard deviation (using clamped weights for stability) + mask_count: torch.Tensor = response_mask.sum() + if mask_count > 1: + # Clamp weights to threshold range to avoid squaring extreme values + weights_for_std: torch.Tensor = rollout_is_weights.clamp( + min=rollout_rs_threshold_lower, max=rollout_rs_threshold + ) + mean_clamped: torch.Tensor = verl_F.masked_mean(weights_for_std, response_mask) + # Variance = E[X²] - (E[X])² (masked to valid tokens) + rollout_is_var: torch.Tensor = ( + verl_F.masked_mean(weights_for_std.square(), response_mask) - mean_clamped.square() + ) + metrics["rollout_rs_std"] = torch.sqrt(torch.clamp(rollout_is_var, min=0.0)).item() + else: + metrics["rollout_rs_std"] = 0.0 + + # Compute Effective Sample Size (ESS) for IS weights + # ESS = 1 / E[(w_i / E[w_i])²] (using clamped weights for stability) + weights_for_ess: torch.Tensor = rollout_is_weights.clamp(min=rollout_rs_threshold_lower, 
max=rollout_rs_threshold) + mean_for_ess: torch.Tensor = verl_F.masked_mean(weights_for_ess, response_mask) + is_weights_normalized: torch.Tensor = weights_for_ess / (mean_for_ess + 1e-8) # Avoid division by zero + metrics["rollout_rs_eff_sample_size"] = ( + 1.0 / verl_F.masked_mean(is_weights_normalized.square(), response_mask).item() + ) + + # Add sequence-level metrics if weights have batch dimension + if rollout_is_weights.dim() > 1: + # Mean weight per sequence (masked to valid tokens) + seq_mean_weights: torch.Tensor = verl_F.masked_mean(rollout_is_weights, response_mask, axis=-1) + + metrics["rollout_rs_seq_mean"] = seq_mean_weights.mean().item() + metrics["rollout_rs_seq_std"] = seq_mean_weights.std().item() if seq_mean_weights.numel() > 1 else 0.0 + metrics["rollout_rs_seq_max"] = seq_mean_weights.max().item() + metrics["rollout_rs_seq_min"] = seq_mean_weights.min().item() + + # Sequence deviation from ideal weight (1.0) + seq_deviation: torch.Tensor = (seq_mean_weights - 1.0).abs() + metrics["rollout_rs_seq_max_deviation"] = seq_deviation.max().item() + + # Fraction of sequences with extreme weights + metrics["rollout_rs_seq_fraction_high"] = (seq_mean_weights > rollout_rs_threshold).float().mean().item() + metrics["rollout_rs_seq_fraction_low"] = (seq_mean_weights < rollout_rs_threshold_lower).float().mean().item() + + return metrics + + +def compute_rollout_correction_weights( + log_ratio: torch.Tensor, + response_mask: torch.Tensor, + rollout_is: str = "token", + rollout_is_threshold: float = 2.0, +) -> tuple[torch.Tensor, dict[str, float]]: + """Compute importance sampling weights to correct for off-policy distribution shifts. + + This function calculates IS weights (π_train / π_rollout) using log ratios for numerical stability. + It supports multiple aggregation levels and truncates extreme weights to prevent training instability. 
+ + Key design: + - Log-space computations to avoid overflow + - Truncation of extreme weights (TIS: Truncated Importance Sampling) + - Metrics tracking for weight distribution analysis + + Args: + log_ratio: Log ratio of training policy probability to rollout policy probability, + shape (batch_size, seq_length). + response_mask: Binary mask for valid tokens (1=valid, 0=padding), + shape (batch_size, seq_length). + rollout_is: IS weight aggregation level, must be one of: + - "token": Per-token weights (biased, low variance) + - "sequence": Per-sequence weight (product of tokens; unbiased, high variance) + rollout_is_threshold: Upper threshold for truncating extreme weights (e.g., 2.0), + default 2.0. + + Returns: + Tuple containing: + rollout_is_weights: Truncated IS weights (masked to zero for padding tokens), + shape (batch_size, seq_length). + metrics: Dictionary of IS weight metrics (all scalars), including: + - rollout_is_mean/max/min: Statistic of truncated weights + - rollout_is_eff_sample_size: Effective sample size (ESS) + - rollout_is_seq_*: Sequence-level weight statistics + """ + # Validate input parameters + valid_is_levels = {"token", "sequence"} + if rollout_is not in valid_is_levels: + raise ValueError(f"Invalid rollout_is: {rollout_is}. 
Must be one of {valid_is_levels}.") + if rollout_is_threshold <= 0: + raise ValueError(f"rollout_is_threshold must be positive, got {rollout_is_threshold}.") + + # Compute IS weights from log ratio (handles different aggregation levels) + if rollout_is == "token": + # Per-token IS weight: exp(log(π_train/π_rollout)) with safety clamp + log_ratio_for_metrics: torch.Tensor = log_ratio + log_ratio_safe: torch.Tensor = torch.clamp(log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND) + rollout_is_weights: torch.Tensor = torch.exp(log_ratio_safe) + + elif rollout_is == "sequence": + # Sequence-level IS weight: product of token ratios (exp(sum(log ratios))) + log_ratio_sum: torch.Tensor = verl_F.masked_sum(log_ratio, response_mask, axis=-1).unsqueeze( + -1 + ) # Shape: (batch_size, 1) + log_ratio_for_metrics = log_ratio_sum + + log_ratio_sum_safe: torch.Tensor = torch.clamp(log_ratio_sum, min=-SAFETY_BOUND, max=SAFETY_BOUND) + rollout_is_weights = torch.exp(log_ratio_sum_safe).expand_as(log_ratio) # Broadcast to sequence length + + else: + raise ValueError(f"Unsupported rollout_is: {rollout_is}") + + # Zero out weights for padding tokens using response mask + rollout_is_weights = rollout_is_weights * response_mask + + # Compute IS weight metrics (BEFORE truncation to get accurate fraction_high/low) + metrics: dict[str, float] = compute_is_metrics( + rollout_is_weights=rollout_is_weights, + log_ratio_for_metrics=log_ratio_for_metrics, + response_mask=response_mask, + rollout_is=rollout_is, + rollout_is_threshold=rollout_is_threshold, + ) + + # Truncate extreme weights (TIS: Truncated Importance Sampling) + rollout_is_weights = rollout_is_weights.clamp(max=rollout_is_threshold) + + return rollout_is_weights, metrics + + +def compute_is_metrics( + rollout_is_weights: torch.Tensor, + log_ratio_for_metrics: torch.Tensor, + response_mask: torch.Tensor, + rollout_is: str, + rollout_is_threshold: float, +) -> dict[str, float]: + """Compute comprehensive metrics for truncated 
importance sampling weights. + + This function calculates statistics for truncated IS weights (TIS), using log-space + for accurate threshold checks and clamped weights for stable mean/std calculations. + + Args: + rollout_is_weights: Truncated IS weights (π_train / π_rollout), + shape (batch_size, seq_length). + log_ratio_for_metrics: Log ratio of training to rollout probabilities (unclamped), + shape varies by aggregation level. + response_mask: Binary mask for valid tokens (1=valid, 0=padding), + shape (batch_size, seq_length). + rollout_is: IS weight aggregation level (matches compute_rollout_correction_weights). + rollout_is_threshold: Upper threshold for truncated IS weights. + + Returns: + Dictionary of IS weight metrics (all scalars). + """ + if not response_mask.any(): + raise ValueError("response_mask must contain at least one valid token (1).") + + metrics: dict[str, float] = {} + device: torch.device = rollout_is_weights.device + # Default lower threshold (reciprocal of upper threshold) + rollout_is_threshold_lower: float = 1.0 / rollout_is_threshold + + # Precompute log thresholds for accurate checks + log_threshold_upper: torch.Tensor = torch.log(torch.tensor(rollout_is_threshold, device=device)) + log_threshold_lower: torch.Tensor = torch.log(torch.tensor(rollout_is_threshold_lower, device=device)) + + # Compute metrics based on aggregation level + if rollout_is == "sequence": + # Sequence-level aggregation: use log-space for unclamped stats + log_max: torch.Tensor = log_ratio_for_metrics.max() + log_min: torch.Tensor = log_ratio_for_metrics.min() + metrics["rollout_is_max"] = torch.exp(torch.clamp(log_max, max=SAFETY_BOUND)).item() + metrics["rollout_is_min"] = torch.exp(log_min).item() + + # Mean uses truncated weights to avoid overflow + metrics["rollout_is_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask).item() + + # Fraction of weights exceeding thresholds (log-space for accuracy) + exceeds_upper: torch.Tensor = 
log_ratio_for_metrics > log_threshold_upper + below_lower: torch.Tensor = log_ratio_for_metrics < log_threshold_lower + metrics["rollout_is_ratio_fraction_high"] = exceeds_upper.float().mean().item() + metrics["rollout_is_ratio_fraction_low"] = below_lower.float().mean().item() + + else: # token-level + # Token-level aggregation: compute directly from truncated weights + metrics["rollout_is_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask).item() + + # Fraction of tokens exceeding thresholds + rollout_is_above_threshold: torch.Tensor = rollout_is_weights > rollout_is_threshold + rollout_is_below_threshold: torch.Tensor = rollout_is_weights < rollout_is_threshold_lower + metrics["rollout_is_ratio_fraction_high"] = verl_F.masked_mean( + rollout_is_above_threshold.float(), response_mask + ).item() + metrics["rollout_is_ratio_fraction_low"] = verl_F.masked_mean( + rollout_is_below_threshold.float(), response_mask + ).item() + + # Max/min (mask out padding tokens) + mask_bool: torch.Tensor = response_mask.bool() + metrics["rollout_is_max"] = rollout_is_weights.masked_fill(~mask_bool, float("-inf")).max().item() + metrics["rollout_is_min"] = rollout_is_weights.masked_fill(~mask_bool, float("inf")).min().item() + + # Compute standard deviation (using clamped weights for stability) + mask_count: torch.Tensor = response_mask.sum() + if mask_count > 1: + weights_for_std: torch.Tensor = rollout_is_weights.clamp( + min=rollout_is_threshold_lower, max=rollout_is_threshold + ) + mean_clamped: torch.Tensor = verl_F.masked_mean(weights_for_std, response_mask) + rollout_is_var: torch.Tensor = ( + verl_F.masked_mean(weights_for_std.square(), response_mask) - mean_clamped.square() + ) + metrics["rollout_is_std"] = torch.sqrt(torch.clamp(rollout_is_var, min=0.0)).item() + else: + metrics["rollout_is_std"] = 0.0 + + # Compute Effective Sample Size (ESS) for truncated weights + weights_for_ess: torch.Tensor = rollout_is_weights.clamp(min=rollout_is_threshold_lower, 
max=rollout_is_threshold) + mean_for_ess: torch.Tensor = verl_F.masked_mean(weights_for_ess, response_mask) + is_weights_normalized: torch.Tensor = weights_for_ess / (mean_for_ess + 1e-8) # Avoid division by zero + metrics["rollout_is_eff_sample_size"] = ( + 1.0 / verl_F.masked_mean(is_weights_normalized.square(), response_mask).item() + ) + + # Add sequence-level metrics if weights have batch dimension + if rollout_is_weights.dim() > 1: + seq_mean_weights: torch.Tensor = verl_F.masked_mean(rollout_is_weights, response_mask, axis=-1) + + metrics["rollout_is_seq_mean"] = seq_mean_weights.mean().item() + metrics["rollout_is_seq_std"] = seq_mean_weights.std().item() if seq_mean_weights.numel() > 1 else 0.0 + metrics["rollout_is_seq_max"] = seq_mean_weights.max().item() + metrics["rollout_is_seq_min"] = seq_mean_weights.min().item() + + # Sequence deviation from ideal weight (1.0) + seq_deviation: torch.Tensor = (seq_mean_weights - 1.0).abs() + metrics["rollout_is_seq_max_deviation"] = seq_deviation.max().item() + + # Fraction of sequences with extreme weights + metrics["rollout_is_seq_fraction_high"] = (seq_mean_weights > rollout_is_threshold).float().mean().item() + metrics["rollout_is_seq_fraction_low"] = (seq_mean_weights < rollout_is_threshold_lower).float().mean().item() + + return metrics + + +def compute_rollout_correction_and_rejection_mask( + old_log_prob: torch.Tensor, + rollout_log_prob: torch.Tensor, + response_mask: torch.Tensor, + rollout_is: Optional[str] = None, + rollout_is_threshold: Optional[float] = 2.0, + rollout_rs: Optional[str] = None, + rollout_rs_threshold: Optional[float] = 2.0, + rollout_rs_threshold_lower: Optional[float] = None, + rollout_token_veto_threshold: Optional[float] = None, +) -> tuple[Optional[DataProto], torch.Tensor, dict[str, float]]: + """Unified interface for computing IS weights and rejection masks. + + This function combines IS weight calculation (truncated) and rejection sampling (masked) + into a single pipeline. 
It also applies a per-token veto for catastrophic outliers + (sequences with extremely low token ratios are fully rejected). + + Key design: + - Separation of IS weights (for variance reduction) and rejection masks (for sample filtering) + - Veto mechanism for catastrophic sequences (applied independently of other modes) + - Comprehensive metrics tracking for mismatch diagnosis + + Args: + old_log_prob: Log probabilities from the training policy (e.g., FSDP FP32), + shape (batch_size, seq_length). + rollout_log_prob: Log probabilities from the rollout policy (e.g., vLLM BF16), + shape (batch_size, seq_length). + response_mask: Binary mask for valid tokens (1=valid, 0=padding), + shape (batch_size, seq_length). + rollout_is: IS weight aggregation level (see compute_rollout_correction_weights for options). + Set to None to disable IS weight computation. + rollout_is_threshold: Upper threshold for truncated IS weights (used if rollout_is is set), + default 2.0. + rollout_rs: Rejection sampling aggregation level (see compute_rollout_rejection_mask for options). + Set to None to disable rejection sampling. + rollout_rs_threshold: Upper threshold for rejection sampling. Required if rollout_rs is enabled. + Default 2.0. + rollout_rs_threshold_lower: Lower threshold for rejection sampling (used if rollout_rs is set). + Defaults to 1/rollout_rs_threshold if None. + rollout_token_veto_threshold: Minimum allowed token-level IS weight. Sequences containing + any token below this threshold are fully rejected. Set to None to disable veto. + + Returns: + Tuple containing: + rollout_is_weights_proto: DataProto with IS weights (None if rollout_is is None), + key "rollout_is_weights", shape (batch_size, seq_length). + modified_response_mask: Response mask with rejection sampling and veto applied, + shape (batch_size, seq_length). 
+ metrics: Dictionary of all metrics (prefixed with "rollout_corr/"), including: + - IS weight statistics + - Rejection sampling rates + - Veto statistics + - Policy mismatch metrics (KL, PPL, etc.) + """ + # Validate input masks + if not response_mask.any(): + raise ValueError("response_mask must contain at least one valid token (1).") + if old_log_prob.shape != rollout_log_prob.shape: + raise ValueError( + f"old_log_prob shape {old_log_prob.shape} does not match rollout_log_prob shape {rollout_log_prob.shape}." + ) + if old_log_prob.shape != response_mask.shape: + raise ValueError( + f"log_prob shape {old_log_prob.shape} does not match response_mask shape {response_mask.shape}." + ) + + # Step 1: Compute log ratio (log(π_train / π_rollout)) + log_ratio: torch.Tensor = old_log_prob - rollout_log_prob + device: torch.device = log_ratio.device + metrics: dict[str, float] = {} + + # Step 2: Compute IS weights (if enabled) + rollout_is_weights: Optional[torch.Tensor] = None + if rollout_is is not None and rollout_is_threshold is not None: + rollout_is_weights, is_metrics = compute_rollout_correction_weights( + log_ratio=log_ratio, + response_mask=response_mask, + rollout_is=rollout_is, + rollout_is_threshold=rollout_is_threshold, + ) + metrics.update(is_metrics) + + # Step 3: Compute rejection mask (if enabled) + modified_response_mask: torch.Tensor = response_mask.clone() + if rollout_rs is not None: + if rollout_rs_threshold is None: + raise ValueError( + "rollout_rs_threshold must be explicitly provided when rollout_rs is enabled. " + "Set rollout_rs_threshold to the desired threshold value." 
+ ) + modified_response_mask, rs_metrics = compute_rollout_rejection_mask( + log_ratio=log_ratio, + response_mask=response_mask, + rollout_rs=rollout_rs, + rollout_rs_threshold=rollout_rs_threshold, + rollout_rs_threshold_lower=rollout_rs_threshold_lower, + ) + metrics.update(rs_metrics) + + # Step 4: Apply per-token veto (reject sequences with catastrophic tokens) + if rollout_token_veto_threshold is not None: + if rollout_token_veto_threshold <= 0: + raise ValueError(f"rollout_token_veto_threshold must be positive, got {rollout_token_veto_threshold}.") + + # Compute log threshold for numerical stability + log_veto_threshold: torch.Tensor = torch.log(torch.tensor(rollout_token_veto_threshold, device=device)) + # Identify catastrophic tokens (log ratio below threshold + valid mask) + catastrophic_tokens: torch.Tensor = (log_ratio < log_veto_threshold) & response_mask.bool() + # Check if sequence contains any catastrophic token + has_catastrophic: torch.Tensor = catastrophic_tokens.any(dim=-1, keepdim=True) + # Create veto mask (0=reject sequence, 1=keep) + veto_mask: torch.Tensor = (~has_catastrophic).float() + + # Track veto metrics + metrics["rollout_is_veto_fraction"] = has_catastrophic.float().mean().item() + metrics["rollout_is_catastrophic_token_fraction"] = verl_F.masked_mean( + catastrophic_tokens.float(), response_mask + ).item() + + # Apply veto to response mask (overrides previous rejection) + modified_response_mask = modified_response_mask * veto_mask + else: + # Add placeholder metrics if veto is disabled + metrics["rollout_is_veto_fraction"] = 0.0 + metrics["rollout_is_catastrophic_token_fraction"] = 0.0 + + # Step 5: Compute off-policy metrics (KL, PPL, χ², etc.) 
+ offpolicy_metrics: dict[str, float] = compute_offpolicy_metrics( + old_log_prob=old_log_prob, + rollout_log_prob=rollout_log_prob, + response_mask=response_mask, + ) + metrics.update(offpolicy_metrics) + + # Step 6: Add "rollout_corr/" prefix to all metrics for logging consistency + metrics_scalar: dict[str, float] = {} + for key, value in metrics.items(): + if isinstance(value, torch.Tensor): + metrics_scalar[f"rollout_corr/{key}"] = value.item() + else: + metrics_scalar[f"rollout_corr/{key}"] = value + + # Step 7: Wrap IS weights in DataProto for consistency with API + rollout_is_weights_proto: Optional[DataProto] = None + if rollout_is_weights is not None: + rollout_is_weights_proto = DataProto.from_dict(tensors={"rollout_is_weights": rollout_is_weights}) + + return rollout_is_weights_proto, modified_response_mask, metrics_scalar + + +def compute_offpolicy_metrics( + old_log_prob: torch.Tensor, + rollout_log_prob: Optional[torch.Tensor], + response_mask: torch.Tensor, +) -> dict[str, Any]: + """Compute off-policy diagnostic metrics (helper function). 
+
+    This helper function operates on raw tensors and is used internally by:
+    - compute_rollout_correction_and_rejection_mask() in this module (automatically included)
+    - Tests (test_rollout_corr.py, test_rollout_corr_integration.py)
+
+    These metrics help diagnose the off-policy gap between rollout and training policies,
+    which can arise from:
+    - Policy mismatch (e.g., vLLM BF16 vs FSDP FP32)
+    - Model staleness (training on trajectories from older checkpoints)
+    - General distribution shifts
+
+    Key metrics:
+    - kl: Direct KL divergence estimator KL(π_rollout || π_training)
+    - k3_kl: K3 KL estimator for stability (more stable for small KL)
+    - training_ppl: Perplexity of training policy
+    - rollout_ppl: Perplexity of rollout policy
+    - log_ppl_diff: Difference in log perplexities
+    - ppl_ratio: Ratio of training PPL to rollout PPL
+    - chi2_token: Token-level χ² divergence E[ρ²] - 1
+    - chi2_seq: Sequence-level χ² divergence E[(∏ρ_t)²] - 1
+
+    Args:
+        old_log_prob: Log probabilities from training policy, shape (batch_size, seq_length)
+        rollout_log_prob: Log probabilities from rollout policy, shape (batch_size, seq_length)
+        response_mask: Mask for valid tokens, shape (batch_size, seq_length)
+
+    Returns:
+        Dictionary of off-policy metrics (without prefix)
+    """
+    # Validate that we have at least one valid token
+    assert response_mask.any(), "Expected at least one valid token in response_mask"
+
+    metrics = {}
+
+    # 1. Training policy perplexity (always available)
+    # Formula: exp(-1/|T| * Σ log π_training(y_t|y_<t))
+
+    # ... (KL, k3_kl, perplexity, and χ² metric computations elided) ...
+
+    return metrics
+
+
+def compute_rollout_correction_and_add_to_batch(
+    batch: DataProto, rollout_corr_config
+) -> tuple[DataProto, dict]:
+    """Compute rollout correction weights and apply rejection sampling.
+
+    Computes importance sampling weights to correct for off-policy issues between
+    rollout and training policies. Applies rejection sampling by modifying response_mask.
+    Always updates response_mask; conditionally adds IS weights.
+
+    Key behavior:
+    - response_mask: ALWAYS updated with rejection (tokens dropped by veto + optional RS are excluded from training)
+    - rollout_is_weights: Added to batch ONLY if the rollout_is parameter is set
+
+    This separation ensures:
+    - Rejection works independently of IS weight application
+    - Metrics can be monitored before enabling IS weight correction
+
+    Args:
+        batch: DataProto with old_log_probs, rollout_log_probs, response_mask
+
+    Returns:
+        Tuple of (updated_batch, metrics):
+            updated_batch: Batch with modified response_mask (always) and rollout_is_weights (if enabled)
+            metrics: Dict of IS and off-policy metrics, all with "rollout_corr/" prefix
+
+    Note:
+        The implementation is adapted from szrlee.
+    """
+    # Get new API parameters directly from config
+    rollout_is = rollout_corr_config.get("rollout_is", None)
+    rollout_is_threshold = rollout_corr_config.get("rollout_is_threshold", 2.0)
+    rollout_rs = rollout_corr_config.get("rollout_rs", None)
+    rollout_rs_threshold = rollout_corr_config.get("rollout_rs_threshold", None)
+    rollout_rs_threshold_lower = rollout_corr_config.get("rollout_rs_threshold_lower", None)
+    rollout_token_veto_threshold = rollout_corr_config.get("rollout_token_veto_threshold", None)
+
+    # Compute IS weights and get modified response_mask
+    rollout_is_weights, modified_response_mask, rollout_corr_metrics = compute_rollout_correction_and_rejection_mask(
+        old_log_prob=batch.batch["old_log_probs"],
+        rollout_log_prob=batch.batch["rollout_log_probs"],
+        response_mask=batch.batch["response_mask"],
+        rollout_is=rollout_is,
+        rollout_is_threshold=rollout_is_threshold,
+        rollout_rs=rollout_rs,
+        rollout_rs_threshold=rollout_rs_threshold,
+        rollout_rs_threshold_lower=rollout_rs_threshold_lower,
+        rollout_token_veto_threshold=rollout_token_veto_threshold,
+    )
+
+    # ALWAYS update response_mask with rejection applied
+    batch.batch["response_mask"] = modified_response_mask
+
+    # Add IS weights to batch if computed
+    if rollout_is_weights is not None:
+        batch = 
batch.union(rollout_is_weights)
+
+    return batch, rollout_corr_metrics
+
+
+def maybe_apply_rollout_correction(
+    batch: DataProto,
+    rollout_corr_config: Optional[RolloutCorrectionConfig] = None,
+    policy_loss_config: Optional[PolicyLossConfig] = None,
+) -> bool:
+    """BYPASS MODE: use rollout_log_probs as old_log_probs.
+
+    Skips the expensive actor forward pass for old_log_prob computation.
+
+    Two sub-modes (controlled by use_pure_rollout_correction in the actor):
+    1. PPO_IS mode (use_pure_rollout_correction=False, default):
+       - Actor uses standard PPO with old_log_prob=rollout_log_prob
+       - PPO clips ratio = π_current / π_rollout (not π_current / π_old)
+
+    2. Pure rollout correction mode (use_pure_rollout_correction=True):
+       - Actor uses compute_policy_loss_with_rollout_correction()
+       - Pure policy gradient with IS correction (no PPO clipping)
+
+    Returns:
+        need_recomputation (bool): Whether recomputing log probs is needed.
+
+    Note:
+        The implementation is adapted from szrlee.
+    """
+    # Rollout correction mode selection
+    bypass_mode = rollout_corr_config.get("bypass_old_logprob_for_rollout", False) if rollout_corr_config else False
+
+    if bypass_mode:
+        if "rollout_log_probs" not in batch.batch:
+            raise ValueError(
+                "bypass_old_logprob_for_rollout=True requires rollout_log_probs in batch. "
+                "Ensure the rollout worker is configured with calculate_log_probs=true."
+ ) + + # Use rollout log probs as old log probs (zero-cost substitution) + batch.batch["old_log_probs"] = batch.batch["rollout_log_probs"] + # Check if pure rollout correction mode is enabled + use_pure_rollout_correction = rollout_corr_config.get("use_pure_rollout_correction", False) + + if use_pure_rollout_correction: + # Pure IS mode: Configure actor to use rollout_correction loss function + # This will use compute_policy_loss_with_rollout_correction (no PPO clipping) + policy_loss_config["loss_mode"] = "rollout_correction" + policy_loss_config["rollout_correction"] = rollout_corr_config + + return False + + return True diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 7dd531ad266..6e74de2b18b 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -439,7 +439,7 @@ def update_policy(self, data: DataProto): loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") # vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla - # Extract pre-computed rollout importance sampling weights if present + # Extract pre-computed rollout correction weights if present # Weights are computed centrally in trainer and added when algorithm.rollout_is=True rollout_is_weights = model_inputs.get("rollout_is_weights", None) diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index fbda5d0529c..fff696ba6b2 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -449,7 +449,7 @@ def loss_func(output, data, meta_info): policy_loss_fn = get_policy_loss_fn(loss_mode) - # Extract pre-computed rollout importance sampling weights if present + # Extract pre-computed rollout correction weights if present # Weights are computed centrally in trainer and added when algorithm.rollout_is=True rollout_is_weights = data.get("rollout_is_weights", None)