[BREAKING][algo] feat: Rollout Correction for General Off-Policy Problems#3984
[BREAKING][algo] feat: Rollout Correction for General Off-Policy Problems#3984ISEEKYAN merged 35 commits intoverl-project:mainfrom
Conversation
…d compatibility
Breaking change: Replace flat `algorithm.rollout_is_*` parameters with nested
`algorithm.rollout_correction` dict structure. Rename "Rollout Importance Sampling"
to "Rollout Correction" throughout codebase.
Changes:
- Replace algorithm.py flat parameters with rollout_correction: Optional[dict]
- Update ray_trainer to use nested config, remove backward compatibility
- Rename parameters: rollout_is_level→rollout_is, rollout_is_mode→rollout_rs,
rollout_is_veto_threshold→rollout_token_veto_threshold
- Rename files: rollout_is.md→rollout_corr.md, test_rollout_is.py→test_rollout_corr.py,
examples/rollout_importance_sampling/→examples/rollout_correction/
- Update all docs, examples, and recipes to use new nested structure
- Fix docs code examples to use config.algorithm.get("rollout_correction")
Migration required:
Old: algorithm.rollout_is_threshold: 2.0
New: algorithm.rollout_correction.rollout_is_threshold: 2.0
Files changed: 17 edited, 6 renamed (23 total)
Tests: 11/11 passing
There was a problem hiding this comment.
Code Review
This is a comprehensive refactoring that modernizes the configuration and terminology for rollout correction, moving from a flat structure to a more organized nested configuration under algorithm.rollout_correction. The breaking change is well-documented with a clear migration guide, and the separation of Importance Sampling (IS) and Rejection Sampling (RS) into distinct parameters (rollout_is and rollout_rs) greatly improves clarity and control. The code is cleaner, and the documentation updates across all 23 files are thorough. I found one critical issue where the implementation for the rollout_rs_threshold fallback is missing, which contradicts the documentation and will cause a runtime error for users relying on this feature. Once this is addressed, this will be an excellent pull request.
df0b980 to
ef7f8f5
Compare
|
Regarding the critical issue you identified: we've addressed it, but intentionally took a different approach than the suggested fallback. Here's our reasoning: Our Solution: Explicit Configuration (No Fallback)We removed the fallback behavior entirely and updated the documentation to match. This follows Python's "explicit is better than implicit" principle. Why no fallback to
Current Behavior: # ✅ Works - uses default 2.0
compute_rollout_correction_and_rejection_mask(
rollout_rs="token" # rollout_rs_threshold defaults to 2.0
)
# ✅ Works - uses explicit value
compute_rollout_correction_and_rejection_mask(
rollout_rs="token",
rollout_rs_threshold=5.0
)
# ❌ Raises clear error - catches misconfiguration
compute_rollout_correction_and_rejection_mask(
rollout_rs="token",
rollout_rs_threshold=None # ValueError with clear message
)What Changed (commit ef7f8f57):
The key distinction: we only raise an error when someone explicitly passes Does this approach work for your use case, or do you see a scenario where the fallback would be necessary? |
f4298d6 to
b27306c
Compare
|
@szrlee The new APIs are very powerful and general, but I would say it a little complicated for a new user to get started. We cannot expect our users use the APIs after reading and understand all the related papers, so it will be very crucial to provide more verified combinations of the parameters and their alias, such as So I think it will be very helpful for users to add docs about:
Thanks again for the incredible contributions! |
Yes, we should add a typed config class, e.g |
@ISEEKYAN @wuxibin89 this refactor (commit f8bd558) replaces dict-based config |
c1f6226 to
e0f84c4
Compare
3eecc51 to
828639b
Compare
828639b to
5426daa
Compare
| if rollout_rs == "token": | ||
| # Token-level aggregation: sequence is rejected if any token is rejected | ||
| seq_has_masked: torch.Tensor = verl_F.masked_sum(1 - mask, response_mask, axis=-1) > 0 | ||
| metrics["rollout_rs_seq_masked_fraction"] = seq_has_masked.float().mean().item() |
There was a problem hiding this comment.
rollout_rs == "token", so it should be rollout_rs_token_masked_fraction instread of rollout_rs_seq_masked_fraction? maybe we can use the same metric name rollout_rs_masked_fraction for both modes?
| if rollout_rs == "sequence": | ||
| # Sequence-level: all tokens in a sequence have the same weight | ||
| metrics["rollout_rs_ratio_fraction_high"] = exceeds_upper.float().mean().item() | ||
| metrics["rollout_rs_ratio_fraction_low"] = below_lower.float().mean().item() | ||
| else: # geometric | ||
| # Broadcast threshold checks to match token dimensions | ||
| exceeds_upper_expanded: torch.Tensor = exceeds_upper.expand_as(response_mask) | ||
| below_lower_expanded: torch.Tensor = below_lower.expand_as(response_mask) | ||
| metrics["rollout_rs_ratio_fraction_high"] = verl_F.masked_mean( | ||
| exceeds_upper_expanded.float(), response_mask | ||
| ).item() | ||
| metrics["rollout_rs_ratio_fraction_low"] = verl_F.masked_mean( | ||
| below_lower_expanded.float(), response_mask | ||
| ).item() |
There was a problem hiding this comment.
why do geo and seq modes have difference methods of calculating rollout_rs_ratio_fraction? I thought they are all sequence level, with only difference on how to calculate mismatch(mean vs sum on logp).
f3cb76a to
a264c85
Compare
… and Add Batch Normalization (#4070) ## Overview This PR fixes bugs, refactors configuration for semantic clarity, and adds batch normalization support to the rollout correction implementation introduced in PR #3984. --- ## Bug Fixes ### 1. Metrics Computation Running in Wrong Mode⚠️ **Problem**: Rollout correction metrics were computed in **bypass mode** instead of **decoupled mode**, making them meaningless. **Root Cause**: Incorrect condition at [ray_trainer.py:1177-1180](verl/trainer/ppo/ray_trainer.py#L1177) ```python # BEFORE (incorrect - runs in bypass mode) if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch) ``` ```python # AFTER (correct - runs in decoupled mode only) if (rollout_corr_config is not None and "rollout_log_probs" in batch.batch and not bypass_recomputing_logprobs): # Only in decoupled mode batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) ``` **Impact**: - IS weights and rejection sampling metrics are now computed only when meaningful (decoupled mode with 3 policies) - In bypass mode (2 policies), actor now correctly computes metrics from evolving π_θ vs π_rollout **Related Changes**: - Added clarifying comments in [ray_trainer.py:1104-1107](verl/trainer/ppo/ray_trainer.py#L1104) (operating mode selection) - Added clarifying comments in [ray_trainer.py:1175-1177](verl/trainer/ppo/ray_trainer.py#L1175) (metrics behavior) - Fixed actor metrics computation in [dp_actor.py](verl/workers/actor/dp_actor.py), [megatron_actor.py](verl/workers/actor/megatron_actor.py) --- ## Configuration Refactor (Semantic Clarity) ### 2. Variable Renaming Renamed config variables to accurately reflect their semantics: | Old Name | New Name | Rationale | |----------|----------|-----------| | `bypass_old_logprob_for_rollout` | `bypass_mode` | Directly describes the operating mode (2-policy vs 3-policy) | | `use_pure_rollout_correction` | `use_policy_gradient` | Reflects actual choice: policy gradient loss vs Q-function loss | **Before** ([algorithm.py @ 0ef0e05](https://github.com/volcengine/verl/blob/0ef0e05b/verl/trainer/config/algorithm.py)): ```python bypass_old_logprob_for_rollout: bool = False # Unclear what "bypass" means use_pure_rollout_correction: bool = False # "Pure" is vague ``` **After** ([algorithm.py @ HEAD](verl/trainer/config/algorithm.py#L157)): ```python bypass_mode: bool = False # Clear: bypass or decoupled mode use_policy_gradient: bool = False # Clear: PG or Q-function loss ``` **Files Updated**: - Core config: [algorithm.py](verl/trainer/config/algorithm.py), [rollout_correction.yaml](verl/trainer/config/algorithm/rollout_correction.yaml) - Implementation: [ray_trainer.py](verl/trainer/ppo/ray_trainer.py), [rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py), [core_algos.py](verl/trainer/ppo/core_algos.py) - Examples: [run_with_rollout_corr.sh](examples/rollout_correction/run_with_rollout_corr.sh), [run_dapo_qwen2.5_32b_rollout_corr.sh](recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh) - Generated configs: [_generated_ppo_trainer.yaml](verl/trainer/config/_generated_ppo_trainer.yaml), [_generated_ppo_megatron_trainer.yaml](verl/trainer/config/_generated_ppo_megatron_trainer.yaml) --- ## New Feature: Batch Normalization ### 3. IS Weight Batch Normalization **Added**: `rollout_is_batch_normalize` config parameter ([algorithm.py:159](verl/trainer/config/algorithm.py#L159)) ```python rollout_is_batch_normalize: bool = False ``` **Purpose**: - Normalizes importance sampling weights to have mean=1.0 within each batch - Aligns normalization scope with IS aggregation level (token/sequence/geometric) - Helps stabilize training when policy drift is large **Behavior**: - `True`: IS weights normalized so mean=1.0 per batch (reduces variance) - `False`: Raw truncated IS weights used (standard behavior, default) **Documentation**: - Mathematical formulation: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md) - Usage guide: [rollout_corr.md](docs/algo/rollout_corr.md) --- ## Documentation Overhaul ### 4. File Reorganization **Moved documentation to `docs/algo/`**: - `docs/advance/rollout_corr.md` → `docs/algo/rollout_corr.md` (+439 additions) - `docs/advance/rollout_corr_math.md` → `docs/algo/rollout_corr_math.md` (+459 additions) **Deleted redundant file**: - `examples/rollout_correction/README.md` (-253 lines) **Updated references**: - [docs/index.rst](docs/index.rst): Updated paths - [docs/advance/fully_async.md](docs/advance/fully_async.md): Updated cross-references ### 5. Preset Renaming for Clarity Renamed presets to clearly indicate operating mode: | Old Name | New Name | Operating Mode | Description | |----------|----------|----------------|-------------| | `token_is` | `decoupled_token_is` | Decoupled (3-policy) | Token-level IS weighting | | `seq_is` | `decoupled_seq_is` | Decoupled (3-policy) | Sequence-level IS weighting | | `geo_rs` | `decoupled_geo_rs` | Decoupled (3-policy) | Geometric rejection sampling | | `ppo_is_bypass` | `ppo_is_bypass` | Bypass (2-policy) | PPO with IS (unchanged) | | `pure_is` | `pg_is` | Bypass (2-policy) | Policy gradient + sequence IS | | N/A | `pg_rs` | Bypass (2-policy) | Policy gradient + geometric RS (new) | **Naming Convention**: - **Decoupled mode** presets: `decoupled_*` (requires old_log_prob computation) - **Bypass mode** presets: `pg_*` or `ppo_*` (skips old_log_prob computation) ### 6. Content Improvements **Cross-References**: - Added prominent links between [rollout_corr.md](docs/algo/rollout_corr.md) (usage guide) and [rollout_corr_math.md](docs/algo/rollout_corr_math.md) (mathematical foundations) **Clarified Loss Formulations**: - Changed examples from PPO to REINFORCE in [rollout_corr_math.md §3.3](docs/algo/rollout_corr_math.md) - **Rationale**: Separates IS weight mechanics from PPO clipping for clarity - Added note that REINFORCE examples can be combined with PPO clipping **New Sections**: - Dedicated batch normalization section: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md) - Improved operating mode explanations throughout --- ## Code Quality Improvements ### 7. Enhanced Comments and Documentation **Trainer Logic** ([ray_trainer.py](verl/trainer/ppo/ray_trainer.py)): - Lines 1104-1107: Operating mode selection logic - Lines 1175-1177: Metrics computation behavior explanation **Policy Loss** ([core_algos.py](verl/trainer/ppo/core_algos.py)): - Enhanced docstrings for `compute_policy_loss_with_rollout_correction` - Clarified when to use policy gradient vs Q-function loss **Actor Workers** ([dp_actor.py](verl/workers/actor/dp_actor.py), [megatron_actor.py](verl/workers/actor/megatron_actor.py)): - Added comments explaining bypass mode metrics computation ### 8. Code Simplification **Removed Unused Logic** ([rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py)): - Removed unnecessary config parameters from metrics computation - Removed unused IS weight processing logic - Simplified metrics calculation flow **Improved Variable Reuse**: - Reused `need_recomputation` variable instead of redundant bypass mode checks - Reduced code duplication --- ## Commit History <details> <summary>18 commits (click to expand)</summary> 1. `7c9e41da` - fix(rollout_corr): compute metrics in actor for bypass mode and fix trainer bugs 2. `96ae2be1` - docs(rollout_corr): move to algo/ and add pure_rs preset 3. `c0ea9bdc` - feat(rollout_corr): add batch normalization option for IS weights 4. `7de6c5f9` - docs(rollout_corr_math): use REINFORCE in aggregation loss examples for clarity 5. `2b34cfee` - refactor(rollout_corr): simplify metrics computation by removing unused config and IS weight logic 6. `0c42f85a` - docs(rollout_corr): add prominent cross-references between usage and math docs 7. `fef8a48f` - docs(rollout_corr_math): add dedicated section for batch normalization 8. `08cc9c7d` - fix: docstring of compute_policy_loss_with_rollout_correction 9. `437a4aba` - feat: reuse need_recomputation instead of bypass_mode 10. `5f9a53bf` - feat: improve comments 11. `b2f63709` - feat: improve comments 12. `79cdbf2f` - feat: refactor bypass_recomputing_logprobs 13. `62e32701` - feat(rollout_corr): align batch normalization with IS aggregation level 14. `b5c19ff7` - docs(rollout_corr): rename decoupled mode presets for clarity and update examples 15. `11f9aa05` - fix(rollout_corr): correct metrics computation to run in decoupled mode only 16. `58565cb0` - docs(rollout_corr): rename presets for clarity and consistency 17. `8bb1a0e0` - refactor(rollout_corr): rename config vars for semantic clarity 18. `6002c00c` - refactor(rollout_corr): update implementation to use renamed config variables </details> --- ## Summary This PR systematically improves the rollout correction implementation through three key areas: 1. **Bug Fixes**: Corrected metrics computation to run in the appropriate mode 2. **Semantic Clarity**: Renamed variables to accurately reflect their purpose (`bypass_mode`, `use_policy_gradient`) 3. **Feature Addition**: Added batch normalization option for IS weights with comprehensive documentation All changes maintain backward compatibility while significantly improving code clarity, correctness, and maintainability. --------- Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
…lems (verl-project#3984) ## Summary This PR introduces a comprehensive overhaul of the rollout correction system with typed configuration, mathematical documentation, and performance optimizations. If you find the PR useful, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = september, } ``` **⚠️ BREAKING CHANGE**: Removes backward compatibility. Users must migrate to typed config. --- ## What's New ### 1. Typed Configuration with Presets **Before (deprecated):** ```yaml algorithm: rollout_is: true rollout_is_threshold: 2.0 rollout_is_level: token ``` **After (Python - Recommended):** ```python from verl.trainer.config.algorithm import RolloutCorrectionConfig # Use validated presets config.algorithm.rollout_correction = RolloutCorrectionConfig.token_is() config.algorithm.rollout_correction = RolloutCorrectionConfig.seq_is_rs() ``` **After (YAML):** ```yaml algorithm: rollout_correction: rollout_is: token rollout_is_threshold: 2.0 ``` **10 validated presets:** - `token_is()` / `token_tis()` - Per-token IS - `seq_is()` - Sequence-level IS - `seq_is_rs()` / `seq_mis()` - Sequence IS + rejection sampling - `geo_rs()` / `geo_mis()` - Geometric RS + veto - `ppo_is_bypass()` - Bypass mode (performance) - `pure_is()` - Pure policy gradient (no PPO clipping) - `disabled()` - Metrics only ### 2. Mathematical Documentation New comprehensive document: `docs/advance/rollout_corr_math.md` (585 lines) **Theoretical foundation:** - REINFORCE → PPO → Decoupled PPO progression - Batch size invariance: Decoupling proximal policy from behavior policy - Three-policy framework: π_rollout, π_old, π_θ **Complete formulations for:** - Off-policy REINFORCE (`pure_is`) - Standard PPO and bypass mode - Decoupled PPO (`token_is`, `seq_is`, `seq_is_rs`) - Rejection sampling (`geo_rs`) **Diagnostic metrics:** - KL divergence (direct and K3 estimators) - Perplexity and perplexity ratio - χ² divergence (token and sequence level) **Quality:** - Objective technical descriptions - All formulas mathematically verified - Cross-document consistency validated ### 3. Training Modes | Mode | Config | Policies | Speed | Description | |------|--------|----------|-------|-------------| | **Standard** | `bypass=false, pure=false` | 3 | Standard | Full decoupled PPO with batch size invariance | | **Bypass** | `bypass=true, pure=false` | 2 | **Fast** | PPO clips against rollout (faster) | | **Pure IS** | `bypass=true, pure=true` | 2 | **Fast** | Off-policy REINFORCE without clipping | **Example:** ```python # Bypass mode for performance config = RolloutCorrectionConfig.ppo_is_bypass(threshold=2.0) # Pure IS for research config = RolloutCorrectionConfig.pure_is(threshold=2.0) ``` ### 4. Chi-Squared Divergence Metrics Quantify off-policy severity: ```python rollout_corr/chi2_token # E[ρ²] - 1 rollout_corr/chi2_seq # E[(∏ρ)²] - 1 ``` **Interpretation:** - χ² = 0: Perfect on-policy - χ² < 1: Low off-policiness, stable - χ² ≥ 10: High off-policiness, need correction **Cleanup:** - Removed `mismatch_` prefix - All metrics under `rollout_corr/` namespace ### 5. Bug Fix **Critical fix:** - `rollout_rs="token"` with `rollout_rs_threshold=None` silently failed - Now raises `ValueError` with clear error message --- ## Migration Guide ### Example 1: Basic Token-level IS ```python # Old (no longer works) config.algorithm.rollout_is = True config.algorithm.rollout_is_threshold = 2.0 config.algorithm.rollout_is_level = "token" # New config.algorithm.rollout_correction = RolloutCorrectionConfig.token_is(threshold=2.0) ``` ### Example 2: Sequence IS + Rejection Sampling ```python # Old (no longer works) config.algorithm.rollout_is_level = "sequence" config.algorithm.rollout_is_mode = "mask" # New config.algorithm.rollout_correction = RolloutCorrectionConfig.seq_is_rs( is_threshold=2.0, rs_threshold=2.0 ) ``` ### Example 3: Disable ```yaml # Old rollout_is: false # New rollout_correction: null ``` --- ## References Liu, Li, Fu, Wang, Liu, Shen (2025). *When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch*. [Blog](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) --------- Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
… and Add Batch Normalization (verl-project#4070) ## Overview This PR fixes bugs, refactors configuration for semantic clarity, and adds batch normalization support to the rollout correction implementation introduced in PR verl-project#3984. --- ## Bug Fixes ### 1. Metrics Computation Running in Wrong Mode⚠️ **Problem**: Rollout correction metrics were computed in **bypass mode** instead of **decoupled mode**, making them meaningless. **Root Cause**: Incorrect condition at [ray_trainer.py:1177-1180](verl/trainer/ppo/ray_trainer.py#L1177) ```python # BEFORE (incorrect - runs in bypass mode) if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch) ``` ```python # AFTER (correct - runs in decoupled mode only) if (rollout_corr_config is not None and "rollout_log_probs" in batch.batch and not bypass_recomputing_logprobs): # Only in decoupled mode batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) ``` **Impact**: - IS weights and rejection sampling metrics are now computed only when meaningful (decoupled mode with 3 policies) - In bypass mode (2 policies), actor now correctly computes metrics from evolving π_θ vs π_rollout **Related Changes**: - Added clarifying comments in [ray_trainer.py:1104-1107](verl/trainer/ppo/ray_trainer.py#L1104) (operating mode selection) - Added clarifying comments in [ray_trainer.py:1175-1177](verl/trainer/ppo/ray_trainer.py#L1175) (metrics behavior) - Fixed actor metrics computation in [dp_actor.py](verl/workers/actor/dp_actor.py), [megatron_actor.py](verl/workers/actor/megatron_actor.py) --- ## Configuration Refactor (Semantic Clarity) ### 2. Variable Renaming Renamed config variables to accurately reflect their semantics: | Old Name | New Name | Rationale | |----------|----------|-----------| | `bypass_old_logprob_for_rollout` | `bypass_mode` | Directly describes the operating mode (2-policy vs 3-policy) | | `use_pure_rollout_correction` | `use_policy_gradient` | Reflects actual choice: policy gradient loss vs Q-function loss | **Before** ([algorithm.py @ e8ad3cd](https://github.com/volcengine/verl/blob/e8ad3cdb/verl/trainer/config/algorithm.py)): ```python bypass_old_logprob_for_rollout: bool = False # Unclear what "bypass" means use_pure_rollout_correction: bool = False # "Pure" is vague ``` **After** ([algorithm.py @ HEAD](verl/trainer/config/algorithm.py#L157)): ```python bypass_mode: bool = False # Clear: bypass or decoupled mode use_policy_gradient: bool = False # Clear: PG or Q-function loss ``` **Files Updated**: - Core config: [algorithm.py](verl/trainer/config/algorithm.py), [rollout_correction.yaml](verl/trainer/config/algorithm/rollout_correction.yaml) - Implementation: [ray_trainer.py](verl/trainer/ppo/ray_trainer.py), [rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py), [core_algos.py](verl/trainer/ppo/core_algos.py) - Examples: [run_with_rollout_corr.sh](examples/rollout_correction/run_with_rollout_corr.sh), [run_dapo_qwen2.5_32b_rollout_corr.sh](recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh) - Generated configs: [_generated_ppo_trainer.yaml](verl/trainer/config/_generated_ppo_trainer.yaml), [_generated_ppo_megatron_trainer.yaml](verl/trainer/config/_generated_ppo_megatron_trainer.yaml) --- ## New Feature: Batch Normalization ### 3. IS Weight Batch Normalization **Added**: `rollout_is_batch_normalize` config parameter ([algorithm.py:159](verl/trainer/config/algorithm.py#L159)) ```python rollout_is_batch_normalize: bool = False ``` **Purpose**: - Normalizes importance sampling weights to have mean=1.0 within each batch - Aligns normalization scope with IS aggregation level (token/sequence/geometric) - Helps stabilize training when policy drift is large **Behavior**: - `True`: IS weights normalized so mean=1.0 per batch (reduces variance) - `False`: Raw truncated IS weights used (standard behavior, default) **Documentation**: - Mathematical formulation: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md) - Usage guide: [rollout_corr.md](docs/algo/rollout_corr.md) --- ## Documentation Overhaul ### 4. File Reorganization **Moved documentation to `docs/algo/`**: - `docs/advance/rollout_corr.md` → `docs/algo/rollout_corr.md` (+439 additions) - `docs/advance/rollout_corr_math.md` → `docs/algo/rollout_corr_math.md` (+459 additions) **Deleted redundant file**: - `examples/rollout_correction/README.md` (-253 lines) **Updated references**: - [docs/index.rst](docs/index.rst): Updated paths - [docs/advance/fully_async.md](docs/advance/fully_async.md): Updated cross-references ### 5. Preset Renaming for Clarity Renamed presets to clearly indicate operating mode: | Old Name | New Name | Operating Mode | Description | |----------|----------|----------------|-------------| | `token_is` | `decoupled_token_is` | Decoupled (3-policy) | Token-level IS weighting | | `seq_is` | `decoupled_seq_is` | Decoupled (3-policy) | Sequence-level IS weighting | | `geo_rs` | `decoupled_geo_rs` | Decoupled (3-policy) | Geometric rejection sampling | | `ppo_is_bypass` | `ppo_is_bypass` | Bypass (2-policy) | PPO with IS (unchanged) | | `pure_is` | `pg_is` | Bypass (2-policy) | Policy gradient + sequence IS | | N/A | `pg_rs` | Bypass (2-policy) | Policy gradient + geometric RS (new) | **Naming Convention**: - **Decoupled mode** presets: `decoupled_*` (requires old_log_prob computation) - **Bypass mode** presets: `pg_*` or `ppo_*` (skips old_log_prob computation) ### 6. Content Improvements **Cross-References**: - Added prominent links between [rollout_corr.md](docs/algo/rollout_corr.md) (usage guide) and [rollout_corr_math.md](docs/algo/rollout_corr_math.md) (mathematical foundations) **Clarified Loss Formulations**: - Changed examples from PPO to REINFORCE in [rollout_corr_math.md §3.3](docs/algo/rollout_corr_math.md) - **Rationale**: Separates IS weight mechanics from PPO clipping for clarity - Added note that REINFORCE examples can be combined with PPO clipping **New Sections**: - Dedicated batch normalization section: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md) - Improved operating mode explanations throughout --- ## Code Quality Improvements ### 7. Enhanced Comments and Documentation **Trainer Logic** ([ray_trainer.py](verl/trainer/ppo/ray_trainer.py)): - Lines 1104-1107: Operating mode selection logic - Lines 1175-1177: Metrics computation behavior explanation **Policy Loss** ([core_algos.py](verl/trainer/ppo/core_algos.py)): - Enhanced docstrings for `compute_policy_loss_with_rollout_correction` - Clarified when to use policy gradient vs Q-function loss **Actor Workers** ([dp_actor.py](verl/workers/actor/dp_actor.py), [megatron_actor.py](verl/workers/actor/megatron_actor.py)): - Added comments explaining bypass mode metrics computation ### 8. Code Simplification **Removed Unused Logic** ([rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py)): - Removed unnecessary config parameters from metrics computation - Removed unused IS weight processing logic - Simplified metrics calculation flow **Improved Variable Reuse**: - Reused `need_recomputation` variable instead of redundant bypass mode checks - Reduced code duplication --- ## Commit History <details> <summary>18 commits (click to expand)</summary> 1. `7c9e41da` - fix(rollout_corr): compute metrics in actor for bypass mode and fix trainer bugs 2. `96ae2be1` - docs(rollout_corr): move to algo/ and add pure_rs preset 3. `c0ea9bdc` - feat(rollout_corr): add batch normalization option for IS weights 4. `7de6c5f9` - docs(rollout_corr_math): use REINFORCE in aggregation loss examples for clarity 5. `2b34cfee` - refactor(rollout_corr): simplify metrics computation by removing unused config and IS weight logic 6. `0c42f85a` - docs(rollout_corr): add prominent cross-references between usage and math docs 7. `fef8a48f` - docs(rollout_corr_math): add dedicated section for batch normalization 8. `08cc9c7d` - fix: docstring of compute_policy_loss_with_rollout_correction 9. `437a4aba` - feat: reuse need_recomputation instead of bypass_mode 10. `5f9a53bf` - feat: improve comments 11. `b2f63709` - feat: improve comments 12. `79cdbf2f` - feat: refactor bypass_recomputing_logprobs 13. `62e32701` - feat(rollout_corr): align batch normalization with IS aggregation level 14. `b5c19ff7` - docs(rollout_corr): rename decoupled mode presets for clarity and update examples 15. `11f9aa05` - fix(rollout_corr): correct metrics computation to run in decoupled mode only 16. `58565cb0` - docs(rollout_corr): rename presets for clarity and consistency 17. `8bb1a0e0` - refactor(rollout_corr): rename config vars for semantic clarity 18. `6002c00c` - refactor(rollout_corr): update implementation to use renamed config variables </details> --- ## Summary This PR systematically improves the rollout correction implementation through three key areas: 1. **Bug Fixes**: Corrected metrics computation to run in the appropriate mode 2. **Semantic Clarity**: Renamed variables to accurately reflect their purpose (`bypass_mode`, `use_policy_gradient`) 3. **Feature Addition**: Added batch normalization option for IS weights with comprehensive documentation All changes maintain backward compatibility while significantly improving code clarity, correctness, and maintainability. --------- Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
…lems (verl-project#3984) ## Summary This PR introduces a comprehensive overhaul of the rollout correction system with typed configuration, mathematical documentation, and performance optimizations. If you find the PR useful, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = september, } ``` **⚠️ BREAKING CHANGE**: Removes backward compatibility. Users must migrate to typed config. --- ## What's New ### 1. Typed Configuration with Presets **Before (deprecated):** ```yaml algorithm: rollout_is: true rollout_is_threshold: 2.0 rollout_is_level: token ``` **After (Python - Recommended):** ```python from verl.trainer.config.algorithm import RolloutCorrectionConfig # Use validated presets config.algorithm.rollout_correction = RolloutCorrectionConfig.token_is() config.algorithm.rollout_correction = RolloutCorrectionConfig.seq_is_rs() ``` **After (YAML):** ```yaml algorithm: rollout_correction: rollout_is: token rollout_is_threshold: 2.0 ``` **10 validated presets:** - `token_is()` / `token_tis()` - Per-token IS - `seq_is()` - Sequence-level IS - `seq_is_rs()` / `seq_mis()` - Sequence IS + rejection sampling - `geo_rs()` / `geo_mis()` - Geometric RS + veto - `ppo_is_bypass()` - Bypass mode (performance) - `pure_is()` - Pure policy gradient (no PPO clipping) - `disabled()` - Metrics only ### 2. Mathematical Documentation New comprehensive document: `docs/advance/rollout_corr_math.md` (585 lines) **Theoretical foundation:** - REINFORCE → PPO → Decoupled PPO progression - Batch size invariance: Decoupling proximal policy from behavior policy - Three-policy framework: π_rollout, π_old, π_θ **Complete formulations for:** - Off-policy REINFORCE (`pure_is`) - Standard PPO and bypass mode - Decoupled PPO (`token_is`, `seq_is`, `seq_is_rs`) - Rejection sampling (`geo_rs`) **Diagnostic metrics:** - KL divergence (direct and K3 estimators) - Perplexity and perplexity ratio - χ² divergence (token and sequence level) **Quality:** - Objective technical descriptions - All formulas mathematically verified - Cross-document consistency validated ### 3. Training Modes | Mode | Config | Policies | Speed | Description | |------|--------|----------|-------|-------------| | **Standard** | `bypass=false, pure=false` | 3 | Standard | Full decoupled PPO with batch size invariance | | **Bypass** | `bypass=true, pure=false` | 2 | **Fast** | PPO clips against rollout (faster) | | **Pure IS** | `bypass=true, pure=true` | 2 | **Fast** | Off-policy REINFORCE without clipping | **Example:** ```python # Bypass mode for performance config = RolloutCorrectionConfig.ppo_is_bypass(threshold=2.0) # Pure IS for research config = RolloutCorrectionConfig.pure_is(threshold=2.0) ``` ### 4. Chi-Squared Divergence Metrics Quantify off-policy severity: ```python rollout_corr/chi2_token # E[ρ²] - 1 rollout_corr/chi2_seq # E[(∏ρ)²] - 1 ``` **Interpretation:** - χ² = 0: Perfect on-policy - χ² < 1: Low off-policiness, stable - χ² ≥ 10: High off-policiness, need correction **Cleanup:** - Removed `mismatch_` prefix - All metrics under `rollout_corr/` namespace ### 5. Bug Fix **Critical fix:** - `rollout_rs="token"` with `rollout_rs_threshold=None` silently failed - Now raises `ValueError` with clear error message --- ## Migration Guide ### Example 1: Basic Token-level IS ```python # Old (no longer works) config.algorithm.rollout_is = True config.algorithm.rollout_is_threshold = 2.0 config.algorithm.rollout_is_level = "token" # New config.algorithm.rollout_correction = RolloutCorrectionConfig.token_is(threshold=2.0) ``` ### Example 2: Sequence IS + Rejection Sampling ```python # Old (no longer works) config.algorithm.rollout_is_level = "sequence" config.algorithm.rollout_is_mode = "mask" # New config.algorithm.rollout_correction = RolloutCorrectionConfig.seq_is_rs( is_threshold=2.0, rs_threshold=2.0 ) ``` ### Example 3: Disable ```yaml # Old rollout_is: false # New rollout_correction: null ``` --- ## References Liu, Li, Fu, Wang, Liu, Shen (2025). *When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch*. [Blog](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) --------- Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
… and Add Batch Normalization (verl-project#4070) ## Overview This PR fixes bugs, refactors configuration for semantic clarity, and adds batch normalization support to the rollout correction implementation introduced in PR verl-project#3984. --- ## Bug Fixes ### 1. Metrics Computation Running in Wrong Mode⚠️ **Problem**: Rollout correction metrics were computed in **bypass mode** instead of **decoupled mode**, making them meaningless. **Root Cause**: Incorrect condition at [ray_trainer.py:1177-1180](verl/trainer/ppo/ray_trainer.py#L1177) ```python # BEFORE (incorrect - runs in bypass mode) if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch) ``` ```python # AFTER (correct - runs in decoupled mode only) if (rollout_corr_config is not None and "rollout_log_probs" in batch.batch and not bypass_recomputing_logprobs): # Only in decoupled mode batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) ``` **Impact**: - IS weights and rejection sampling metrics are now computed only when meaningful (decoupled mode with 3 policies) - In bypass mode (2 policies), actor now correctly computes metrics from evolving π_θ vs π_rollout **Related Changes**: - Added clarifying comments in [ray_trainer.py:1104-1107](verl/trainer/ppo/ray_trainer.py#L1104) (operating mode selection) - Added clarifying comments in [ray_trainer.py:1175-1177](verl/trainer/ppo/ray_trainer.py#L1175) (metrics behavior) - Fixed actor metrics computation in [dp_actor.py](verl/workers/actor/dp_actor.py), [megatron_actor.py](verl/workers/actor/megatron_actor.py) --- ## Configuration Refactor (Semantic Clarity) ### 2. Variable Renaming Renamed config variables to accurately reflect their semantics: | Old Name | New Name | Rationale | |----------|----------|-----------| | `bypass_old_logprob_for_rollout` | `bypass_mode` | Directly describes the operating mode (2-policy vs 3-policy) | | `use_pure_rollout_correction` | `use_policy_gradient` | Reflects actual choice: policy gradient loss vs Q-function loss | **Before** ([algorithm.py @ 0ef0e05](https://github.com/volcengine/verl/blob/0ef0e05b/verl/trainer/config/algorithm.py)): ```python bypass_old_logprob_for_rollout: bool = False # Unclear what "bypass" means use_pure_rollout_correction: bool = False # "Pure" is vague ``` **After** ([algorithm.py @ HEAD](verl/trainer/config/algorithm.py#L157)): ```python bypass_mode: bool = False # Clear: bypass or decoupled mode use_policy_gradient: bool = False # Clear: PG or Q-function loss ``` **Files Updated**: - Core config: [algorithm.py](verl/trainer/config/algorithm.py), [rollout_correction.yaml](verl/trainer/config/algorithm/rollout_correction.yaml) - Implementation: [ray_trainer.py](verl/trainer/ppo/ray_trainer.py), [rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py), [core_algos.py](verl/trainer/ppo/core_algos.py) - Examples: [run_with_rollout_corr.sh](examples/rollout_correction/run_with_rollout_corr.sh), [run_dapo_qwen2.5_32b_rollout_corr.sh](recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh) - Generated configs: [_generated_ppo_trainer.yaml](verl/trainer/config/_generated_ppo_trainer.yaml), [_generated_ppo_megatron_trainer.yaml](verl/trainer/config/_generated_ppo_megatron_trainer.yaml) --- ## New Feature: Batch Normalization ### 3. IS Weight Batch Normalization **Added**: `rollout_is_batch_normalize` config parameter ([algorithm.py:159](verl/trainer/config/algorithm.py#L159)) ```python rollout_is_batch_normalize: bool = False ``` **Purpose**: - Normalizes importance sampling weights to have mean=1.0 within each batch - Aligns normalization scope with IS aggregation level (token/sequence/geometric) - Helps stabilize training when policy drift is large **Behavior**: - `True`: IS weights normalized so mean=1.0 per batch (reduces variance) - `False`: Raw truncated IS weights used (standard behavior, default) **Documentation**: - Mathematical formulation: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md) - Usage guide: [rollout_corr.md](docs/algo/rollout_corr.md) --- ## Documentation Overhaul ### 4. File Reorganization **Moved documentation to `docs/algo/`**: - `docs/advance/rollout_corr.md` → `docs/algo/rollout_corr.md` (+439 additions) - `docs/advance/rollout_corr_math.md` → `docs/algo/rollout_corr_math.md` (+459 additions) **Deleted redundant file**: - `examples/rollout_correction/README.md` (-253 lines) **Updated references**: - [docs/index.rst](docs/index.rst): Updated paths - [docs/advance/fully_async.md](docs/advance/fully_async.md): Updated cross-references ### 5. Preset Renaming for Clarity Renamed presets to clearly indicate operating mode: | Old Name | New Name | Operating Mode | Description | |----------|----------|----------------|-------------| | `token_is` | `decoupled_token_is` | Decoupled (3-policy) | Token-level IS weighting | | `seq_is` | `decoupled_seq_is` | Decoupled (3-policy) | Sequence-level IS weighting | | `geo_rs` | `decoupled_geo_rs` | Decoupled (3-policy) | Geometric rejection sampling | | `ppo_is_bypass` | `ppo_is_bypass` | Bypass (2-policy) | PPO with IS (unchanged) | | `pure_is` | `pg_is` | Bypass (2-policy) | Policy gradient + sequence IS | | N/A | `pg_rs` | Bypass (2-policy) | Policy gradient + geometric RS (new) | **Naming Convention**: - **Decoupled mode** presets: `decoupled_*` (requires old_log_prob computation) - **Bypass mode** presets: `pg_*` or `ppo_*` (skips old_log_prob computation) ### 6. Content Improvements **Cross-References**: - Added prominent links between [rollout_corr.md](docs/algo/rollout_corr.md) (usage guide) and [rollout_corr_math.md](docs/algo/rollout_corr_math.md) (mathematical foundations) **Clarified Loss Formulations**: - Changed examples from PPO to REINFORCE in [rollout_corr_math.md §3.3](docs/algo/rollout_corr_math.md) - **Rationale**: Separates IS weight mechanics from PPO clipping for clarity - Added note that REINFORCE examples can be combined with PPO clipping **New Sections**: - Dedicated batch normalization section: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md) - Improved operating mode explanations throughout --- ## Code Quality Improvements ### 7. Enhanced Comments and Documentation **Trainer Logic** ([ray_trainer.py](verl/trainer/ppo/ray_trainer.py)): - Lines 1104-1107: Operating mode selection logic - Lines 1175-1177: Metrics computation behavior explanation **Policy Loss** ([core_algos.py](verl/trainer/ppo/core_algos.py)): - Enhanced docstrings for `compute_policy_loss_with_rollout_correction` - Clarified when to use policy gradient vs Q-function loss **Actor Workers** ([dp_actor.py](verl/workers/actor/dp_actor.py), [megatron_actor.py](verl/workers/actor/megatron_actor.py)): - Added comments explaining bypass mode metrics computation ### 8. Code Simplification **Removed Unused Logic** ([rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py)): - Removed unnecessary config parameters from metrics computation - Removed unused IS weight processing logic - Simplified metrics calculation flow **Improved Variable Reuse**: - Reused `need_recomputation` variable instead of redundant bypass mode checks - Reduced code duplication --- ## Commit History <details> <summary>18 commits (click to expand)</summary> 1. `7c9e41da` - fix(rollout_corr): compute metrics in actor for bypass mode and fix trainer bugs 2. `96ae2be1` - docs(rollout_corr): move to algo/ and add pure_rs preset 3. `c0ea9bdc` - feat(rollout_corr): add batch normalization option for IS weights 4. `7de6c5f9` - docs(rollout_corr_math): use REINFORCE in aggregation loss examples for clarity 5. `2b34cfee` - refactor(rollout_corr): simplify metrics computation by removing unused config and IS weight logic 6. `0c42f85a` - docs(rollout_corr): add prominent cross-references between usage and math docs 7. `fef8a48f` - docs(rollout_corr_math): add dedicated section for batch normalization 8. `08cc9c7d` - fix: docstring of compute_policy_loss_with_rollout_correction 9. `437a4aba` - feat: reuse need_recomputation instead of bypass_mode 10. `5f9a53bf` - feat: improve comments 11. `b2f63709` - feat: improve comments 12. `79cdbf2f` - feat: refactor bypass_recomputing_logprobs 13. `62e32701` - feat(rollout_corr): align batch normalization with IS aggregation level 14. `b5c19ff7` - docs(rollout_corr): rename decoupled mode presets for clarity and update examples 15. `11f9aa05` - fix(rollout_corr): correct metrics computation to run in decoupled mode only 16. `58565cb0` - docs(rollout_corr): rename presets for clarity and consistency 17. `8bb1a0e0` - refactor(rollout_corr): rename config vars for semantic clarity 18. `6002c00c` - refactor(rollout_corr): update implementation to use renamed config variables </details> --- ## Summary This PR systematically improves the rollout correction implementation through three key areas: 1. **Bug Fixes**: Corrected metrics computation to run in the appropriate mode 2. **Semantic Clarity**: Renamed variables to accurately reflect their purpose (`bypass_mode`, `use_policy_gradient`) 3. **Feature Addition**: Added batch normalization option for IS weights with comprehensive documentation All changes maintain backward compatibility while significantly improving code clarity, correctness, and maintainability. --------- Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
…lems (verl-project#3984) ## Summary This PR introduces a comprehensive overhaul of the rollout correction system with typed configuration, mathematical documentation, and performance optimizations. If you find the PR useful, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = september, } ``` **⚠️ BREAKING CHANGE**: Removes backward compatibility. Users must migrate to typed config. --- ## What's New ### 1. Typed Configuration with Presets **Before (deprecated):** ```yaml algorithm: rollout_is: true rollout_is_threshold: 2.0 rollout_is_level: token ``` **After (Python - Recommended):** ```python from verl.trainer.config.algorithm import RolloutCorrectionConfig # Use validated presets config.algorithm.rollout_correction = RolloutCorrectionConfig.token_is() config.algorithm.rollout_correction = RolloutCorrectionConfig.seq_is_rs() ``` **After (YAML):** ```yaml algorithm: rollout_correction: rollout_is: token rollout_is_threshold: 2.0 ``` **10 validated presets:** - `token_is()` / `token_tis()` - Per-token IS - `seq_is()` - Sequence-level IS - `seq_is_rs()` / `seq_mis()` - Sequence IS + rejection sampling - `geo_rs()` / `geo_mis()` - Geometric RS + veto - `ppo_is_bypass()` - Bypass mode (performance) - `pure_is()` - Pure policy gradient (no PPO clipping) - `disabled()` - Metrics only ### 2. Mathematical Documentation New comprehensive document: `docs/advance/rollout_corr_math.md` (585 lines) **Theoretical foundation:** - REINFORCE → PPO → Decoupled PPO progression - Batch size invariance: Decoupling proximal policy from behavior policy - Three-policy framework: π_rollout, π_old, π_θ **Complete formulations for:** - Off-policy REINFORCE (`pure_is`) - Standard PPO and bypass mode - Decoupled PPO (`token_is`, `seq_is`, `seq_is_rs`) - Rejection sampling (`geo_rs`) **Diagnostic metrics:** - KL divergence (direct and K3 estimators) - Perplexity and perplexity ratio - χ² divergence (token and sequence level) **Quality:** - Objective technical descriptions - All formulas mathematically verified - Cross-document consistency validated ### 3. Training Modes | Mode | Config | Policies | Speed | Description | |------|--------|----------|-------|-------------| | **Standard** | `bypass=false, pure=false` | 3 | Standard | Full decoupled PPO with batch size invariance | | **Bypass** | `bypass=true, pure=false` | 2 | **Fast** | PPO clips against rollout (faster) | | **Pure IS** | `bypass=true, pure=true` | 2 | **Fast** | Off-policy REINFORCE without clipping | **Example:** ```python # Bypass mode for performance config = RolloutCorrectionConfig.ppo_is_bypass(threshold=2.0) # Pure IS for research config = RolloutCorrectionConfig.pure_is(threshold=2.0) ``` ### 4. Chi-Squared Divergence Metrics Quantify off-policy severity: ```python rollout_corr/chi2_token # E[ρ²] - 1 rollout_corr/chi2_seq # E[(∏ρ)²] - 1 ``` **Interpretation:** - χ² = 0: Perfect on-policy - χ² < 1: Low off-policiness, stable - χ² ≥ 10: High off-policiness, need correction **Cleanup:** - Removed `mismatch_` prefix - All metrics under `rollout_corr/` namespace ### 5. Bug Fix **Critical fix:** - `rollout_rs="token"` with `rollout_rs_threshold=None` silently failed - Now raises `ValueError` with clear error message --- ## Migration Guide ### Example 1: Basic Token-level IS ```python # Old (no longer works) config.algorithm.rollout_is = True config.algorithm.rollout_is_threshold = 2.0 config.algorithm.rollout_is_level = "token" # New config.algorithm.rollout_correction = RolloutCorrectionConfig.token_is(threshold=2.0) ``` ### Example 2: Sequence IS + Rejection Sampling ```python # Old (no longer works) config.algorithm.rollout_is_level = "sequence" config.algorithm.rollout_is_mode = "mask" # New config.algorithm.rollout_correction = RolloutCorrectionConfig.seq_is_rs( is_threshold=2.0, rs_threshold=2.0 ) ``` ### Example 3: Disable ```yaml # Old rollout_is: false # New rollout_correction: null ``` --- ## References Liu, Li, Fu, Wang, Liu, Shen (2025). *When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch*. [Blog](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) --------- Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
… and Add Batch Normalization (verl-project#4070) ## Overview This PR fixes bugs, refactors configuration for semantic clarity, and adds batch normalization support to the rollout correction implementation introduced in PR verl-project#3984. --- ## Bug Fixes ### 1. Metrics Computation Running in Wrong Mode⚠️ **Problem**: Rollout correction metrics were computed in **bypass mode** instead of **decoupled mode**, making them meaningless. **Root Cause**: Incorrect condition at [ray_trainer.py:1177-1180](verl/trainer/ppo/ray_trainer.py#L1177) ```python # BEFORE (incorrect - runs in bypass mode) if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch) ``` ```python # AFTER (correct - runs in decoupled mode only) if (rollout_corr_config is not None and "rollout_log_probs" in batch.batch and not bypass_recomputing_logprobs): # Only in decoupled mode batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) ``` **Impact**: - IS weights and rejection sampling metrics are now computed only when meaningful (decoupled mode with 3 policies) - In bypass mode (2 policies), actor now correctly computes metrics from evolving π_θ vs π_rollout **Related Changes**: - Added clarifying comments in [ray_trainer.py:1104-1107](verl/trainer/ppo/ray_trainer.py#L1104) (operating mode selection) - Added clarifying comments in [ray_trainer.py:1175-1177](verl/trainer/ppo/ray_trainer.py#L1175) (metrics behavior) - Fixed actor metrics computation in [dp_actor.py](verl/workers/actor/dp_actor.py), [megatron_actor.py](verl/workers/actor/megatron_actor.py) --- ## Configuration Refactor (Semantic Clarity) ### 2. Variable Renaming Renamed config variables to accurately reflect their semantics: | Old Name | New Name | Rationale | |----------|----------|-----------| | `bypass_old_logprob_for_rollout` | `bypass_mode` | Directly describes the operating mode (2-policy vs 3-policy) | | `use_pure_rollout_correction` | `use_policy_gradient` | Reflects actual choice: policy gradient loss vs Q-function loss | **Before** ([algorithm.py @ 9cf6cf9](https://github.com/volcengine/verl/blob/9cf6cf93/verl/trainer/config/algorithm.py)): ```python bypass_old_logprob_for_rollout: bool = False # Unclear what "bypass" means use_pure_rollout_correction: bool = False # "Pure" is vague ``` **After** ([algorithm.py @ HEAD](verl/trainer/config/algorithm.py#L157)): ```python bypass_mode: bool = False # Clear: bypass or decoupled mode use_policy_gradient: bool = False # Clear: PG or Q-function loss ``` **Files Updated**: - Core config: [algorithm.py](verl/trainer/config/algorithm.py), [rollout_correction.yaml](verl/trainer/config/algorithm/rollout_correction.yaml) - Implementation: [ray_trainer.py](verl/trainer/ppo/ray_trainer.py), [rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py), [core_algos.py](verl/trainer/ppo/core_algos.py) - Examples: [run_with_rollout_corr.sh](examples/rollout_correction/run_with_rollout_corr.sh), [run_dapo_qwen2.5_32b_rollout_corr.sh](recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh) - Generated configs: [_generated_ppo_trainer.yaml](verl/trainer/config/_generated_ppo_trainer.yaml), [_generated_ppo_megatron_trainer.yaml](verl/trainer/config/_generated_ppo_megatron_trainer.yaml) --- ## New Feature: Batch Normalization ### 3. IS Weight Batch Normalization **Added**: `rollout_is_batch_normalize` config parameter ([algorithm.py:159](verl/trainer/config/algorithm.py#L159)) ```python rollout_is_batch_normalize: bool = False ``` **Purpose**: - Normalizes importance sampling weights to have mean=1.0 within each batch - Aligns normalization scope with IS aggregation level (token/sequence/geometric) - Helps stabilize training when policy drift is large **Behavior**: - `True`: IS weights normalized so mean=1.0 per batch (reduces variance) - `False`: Raw truncated IS weights used (standard behavior, default) **Documentation**: - Mathematical formulation: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md) - Usage guide: [rollout_corr.md](docs/algo/rollout_corr.md) --- ## Documentation Overhaul ### 4. File Reorganization **Moved documentation to `docs/algo/`**: - `docs/advance/rollout_corr.md` → `docs/algo/rollout_corr.md` (+439 additions) - `docs/advance/rollout_corr_math.md` → `docs/algo/rollout_corr_math.md` (+459 additions) **Deleted redundant file**: - `examples/rollout_correction/README.md` (-253 lines) **Updated references**: - [docs/index.rst](docs/index.rst): Updated paths - [docs/advance/fully_async.md](docs/advance/fully_async.md): Updated cross-references ### 5. Preset Renaming for Clarity Renamed presets to clearly indicate operating mode: | Old Name | New Name | Operating Mode | Description | |----------|----------|----------------|-------------| | `token_is` | `decoupled_token_is` | Decoupled (3-policy) | Token-level IS weighting | | `seq_is` | `decoupled_seq_is` | Decoupled (3-policy) | Sequence-level IS weighting | | `geo_rs` | `decoupled_geo_rs` | Decoupled (3-policy) | Geometric rejection sampling | | `ppo_is_bypass` | `ppo_is_bypass` | Bypass (2-policy) | PPO with IS (unchanged) | | `pure_is` | `pg_is` | Bypass (2-policy) | Policy gradient + sequence IS | | N/A | `pg_rs` | Bypass (2-policy) | Policy gradient + geometric RS (new) | **Naming Convention**: - **Decoupled mode** presets: `decoupled_*` (requires old_log_prob computation) - **Bypass mode** presets: `pg_*` or `ppo_*` (skips old_log_prob computation) ### 6. Content Improvements **Cross-References**: - Added prominent links between [rollout_corr.md](docs/algo/rollout_corr.md) (usage guide) and [rollout_corr_math.md](docs/algo/rollout_corr_math.md) (mathematical foundations) **Clarified Loss Formulations**: - Changed examples from PPO to REINFORCE in [rollout_corr_math.md §3.3](docs/algo/rollout_corr_math.md) - **Rationale**: Separates IS weight mechanics from PPO clipping for clarity - Added note that REINFORCE examples can be combined with PPO clipping **New Sections**: - Dedicated batch normalization section: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md) - Improved operating mode explanations throughout --- ## Code Quality Improvements ### 7. Enhanced Comments and Documentation **Trainer Logic** ([ray_trainer.py](verl/trainer/ppo/ray_trainer.py)): - Lines 1104-1107: Operating mode selection logic - Lines 1175-1177: Metrics computation behavior explanation **Policy Loss** ([core_algos.py](verl/trainer/ppo/core_algos.py)): - Enhanced docstrings for `compute_policy_loss_with_rollout_correction` - Clarified when to use policy gradient vs Q-function loss **Actor Workers** ([dp_actor.py](verl/workers/actor/dp_actor.py), [megatron_actor.py](verl/workers/actor/megatron_actor.py)): - Added comments explaining bypass mode metrics computation ### 8. Code Simplification **Removed Unused Logic** ([rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py)): - Removed unnecessary config parameters from metrics computation - Removed unused IS weight processing logic - Simplified metrics calculation flow **Improved Variable Reuse**: - Reused `need_recomputation` variable instead of redundant bypass mode checks - Reduced code duplication --- ## Commit History <details> <summary>18 commits (click to expand)</summary> 1. `7c9e41da` - fix(rollout_corr): compute metrics in actor for bypass mode and fix trainer bugs 2. `96ae2be1` - docs(rollout_corr): move to algo/ and add pure_rs preset 3. `c0ea9bdc` - feat(rollout_corr): add batch normalization option for IS weights 4. `7de6c5f9` - docs(rollout_corr_math): use REINFORCE in aggregation loss examples for clarity 5. `2b34cfee` - refactor(rollout_corr): simplify metrics computation by removing unused config and IS weight logic 6. `0c42f85a` - docs(rollout_corr): add prominent cross-references between usage and math docs 7. `fef8a48f` - docs(rollout_corr_math): add dedicated section for batch normalization 8. `08cc9c7d` - fix: docstring of compute_policy_loss_with_rollout_correction 9. `437a4aba` - feat: reuse need_recomputation instead of bypass_mode 10. `5f9a53bf` - feat: improve comments 11. `b2f63709` - feat: improve comments 12. `79cdbf2f` - feat: refactor bypass_recomputing_logprobs 13. `62e32701` - feat(rollout_corr): align batch normalization with IS aggregation level 14. `b5c19ff7` - docs(rollout_corr): rename decoupled mode presets for clarity and update examples 15. `11f9aa05` - fix(rollout_corr): correct metrics computation to run in decoupled mode only 16. `58565cb0` - docs(rollout_corr): rename presets for clarity and consistency 17. `8bb1a0e0` - refactor(rollout_corr): rename config vars for semantic clarity 18. `6002c00c` - refactor(rollout_corr): update implementation to use renamed config variables </details> --- ## Summary This PR systematically improves the rollout correction implementation through three key areas: 1. **Bug Fixes**: Corrected metrics computation to run in the appropriate mode 2. **Semantic Clarity**: Renamed variables to accurately reflect their purpose (`bypass_mode`, `use_policy_gradient`) 3. **Feature Addition**: Added batch normalization option for IS weights with comprehensive documentation All changes maintain backward compatibility while significantly improving code clarity, correctness, and maintainability. --------- Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
…lems (verl-project#3984) ## Summary This PR introduces a comprehensive overhaul of the rollout correction system with typed configuration, mathematical documentation, and performance optimizations. If you find the PR useful, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = september, } ``` **⚠️ BREAKING CHANGE**: Removes backward compatibility. Users must migrate to typed config. --- ## What's New ### 1. Typed Configuration with Presets **Before (deprecated):** ```yaml algorithm: rollout_is: true rollout_is_threshold: 2.0 rollout_is_level: token ``` **After (Python - Recommended):** ```python from verl.trainer.config.algorithm import RolloutCorrectionConfig # Use validated presets config.algorithm.rollout_correction = RolloutCorrectionConfig.token_is() config.algorithm.rollout_correction = RolloutCorrectionConfig.seq_is_rs() ``` **After (YAML):** ```yaml algorithm: rollout_correction: rollout_is: token rollout_is_threshold: 2.0 ``` **10 validated presets:** - `token_is()` / `token_tis()` - Per-token IS - `seq_is()` - Sequence-level IS - `seq_is_rs()` / `seq_mis()` - Sequence IS + rejection sampling - `geo_rs()` / `geo_mis()` - Geometric RS + veto - `ppo_is_bypass()` - Bypass mode (performance) - `pure_is()` - Pure policy gradient (no PPO clipping) - `disabled()` - Metrics only ### 2. Mathematical Documentation New comprehensive document: `docs/advance/rollout_corr_math.md` (585 lines) **Theoretical foundation:** - REINFORCE → PPO → Decoupled PPO progression - Batch size invariance: Decoupling proximal policy from behavior policy - Three-policy framework: π_rollout, π_old, π_θ **Complete formulations for:** - Off-policy REINFORCE (`pure_is`) - Standard PPO and bypass mode - Decoupled PPO (`token_is`, `seq_is`, `seq_is_rs`) - Rejection sampling (`geo_rs`) **Diagnostic metrics:** - KL divergence (direct and K3 estimators) - Perplexity and perplexity ratio - χ² divergence (token and sequence level) **Quality:** - Objective technical descriptions - All formulas mathematically verified - Cross-document consistency validated ### 3. Training Modes | Mode | Config | Policies | Speed | Description | |------|--------|----------|-------|-------------| | **Standard** | `bypass=false, pure=false` | 3 | Standard | Full decoupled PPO with batch size invariance | | **Bypass** | `bypass=true, pure=false` | 2 | **Fast** | PPO clips against rollout (faster) | | **Pure IS** | `bypass=true, pure=true` | 2 | **Fast** | Off-policy REINFORCE without clipping | **Example:** ```python # Bypass mode for performance config = RolloutCorrectionConfig.ppo_is_bypass(threshold=2.0) # Pure IS for research config = RolloutCorrectionConfig.pure_is(threshold=2.0) ``` ### 4. Chi-Squared Divergence Metrics Quantify off-policy severity: ```python rollout_corr/chi2_token # E[ρ²] - 1 rollout_corr/chi2_seq # E[(∏ρ)²] - 1 ``` **Interpretation:** - χ² = 0: Perfect on-policy - χ² < 1: Low off-policiness, stable - χ² ≥ 10: High off-policiness, need correction **Cleanup:** - Removed `mismatch_` prefix - All metrics under `rollout_corr/` namespace ### 5. Bug Fix **Critical fix:** - `rollout_rs="token"` with `rollout_rs_threshold=None` silently failed - Now raises `ValueError` with clear error message --- ## Migration Guide ### Example 1: Basic Token-level IS ```python # Old (no longer works) config.algorithm.rollout_is = True config.algorithm.rollout_is_threshold = 2.0 config.algorithm.rollout_is_level = "token" # New config.algorithm.rollout_correction = RolloutCorrectionConfig.token_is(threshold=2.0) ``` ### Example 2: Sequence IS + Rejection Sampling ```python # Old (no longer works) config.algorithm.rollout_is_level = "sequence" config.algorithm.rollout_is_mode = "mask" # New config.algorithm.rollout_correction = RolloutCorrectionConfig.seq_is_rs( is_threshold=2.0, rs_threshold=2.0 ) ``` ### Example 3: Disable ```yaml # Old rollout_is: false # New rollout_correction: null ``` --- ## References Liu, Li, Fu, Wang, Liu, Shen (2025). *When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch*. [Blog](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) --------- Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
… and Add Batch Normalization (verl-project#4070) ## Overview This PR fixes bugs, refactors configuration for semantic clarity, and adds batch normalization support to the rollout correction implementation introduced in PR verl-project#3984. --- ## Bug Fixes ### 1. Metrics Computation Running in Wrong Mode⚠️ **Problem**: Rollout correction metrics were computed in **bypass mode** instead of **decoupled mode**, making them meaningless. **Root Cause**: Incorrect condition at [ray_trainer.py:1177-1180](verl/trainer/ppo/ray_trainer.py#L1177) ```python # BEFORE (incorrect - runs in bypass mode) if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch) ``` ```python # AFTER (correct - runs in decoupled mode only) if (rollout_corr_config is not None and "rollout_log_probs" in batch.batch and not bypass_recomputing_logprobs): # Only in decoupled mode batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) ``` **Impact**: - IS weights and rejection sampling metrics are now computed only when meaningful (decoupled mode with 3 policies) - In bypass mode (2 policies), actor now correctly computes metrics from evolving π_θ vs π_rollout **Related Changes**: - Added clarifying comments in [ray_trainer.py:1104-1107](verl/trainer/ppo/ray_trainer.py#L1104) (operating mode selection) - Added clarifying comments in [ray_trainer.py:1175-1177](verl/trainer/ppo/ray_trainer.py#L1175) (metrics behavior) - Fixed actor metrics computation in [dp_actor.py](verl/workers/actor/dp_actor.py), [megatron_actor.py](verl/workers/actor/megatron_actor.py) --- ## Configuration Refactor (Semantic Clarity) ### 2. Variable Renaming Renamed config variables to accurately reflect their semantics: | Old Name | New Name | Rationale | |----------|----------|-----------| | `bypass_old_logprob_for_rollout` | `bypass_mode` | Directly describes the operating mode (2-policy vs 3-policy) | | `use_pure_rollout_correction` | `use_policy_gradient` | Reflects actual choice: policy gradient loss vs Q-function loss | **Before** ([algorithm.py @ 0ef0e05](https://github.com/volcengine/verl/blob/0ef0e05b/verl/trainer/config/algorithm.py)): ```python bypass_old_logprob_for_rollout: bool = False # Unclear what "bypass" means use_pure_rollout_correction: bool = False # "Pure" is vague ``` **After** ([algorithm.py @ HEAD](verl/trainer/config/algorithm.py#L157)): ```python bypass_mode: bool = False # Clear: bypass or decoupled mode use_policy_gradient: bool = False # Clear: PG or Q-function loss ``` **Files Updated**: - Core config: [algorithm.py](verl/trainer/config/algorithm.py), [rollout_correction.yaml](verl/trainer/config/algorithm/rollout_correction.yaml) - Implementation: [ray_trainer.py](verl/trainer/ppo/ray_trainer.py), [rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py), [core_algos.py](verl/trainer/ppo/core_algos.py) - Examples: [run_with_rollout_corr.sh](examples/rollout_correction/run_with_rollout_corr.sh), [run_dapo_qwen2.5_32b_rollout_corr.sh](recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh) - Generated configs: [_generated_ppo_trainer.yaml](verl/trainer/config/_generated_ppo_trainer.yaml), [_generated_ppo_megatron_trainer.yaml](verl/trainer/config/_generated_ppo_megatron_trainer.yaml) --- ## New Feature: Batch Normalization ### 3. IS Weight Batch Normalization **Added**: `rollout_is_batch_normalize` config parameter ([algorithm.py:159](verl/trainer/config/algorithm.py#L159)) ```python rollout_is_batch_normalize: bool = False ``` **Purpose**: - Normalizes importance sampling weights to have mean=1.0 within each batch - Aligns normalization scope with IS aggregation level (token/sequence/geometric) - Helps stabilize training when policy drift is large **Behavior**: - `True`: IS weights normalized so mean=1.0 per batch (reduces variance) - `False`: Raw truncated IS weights used (standard behavior, default) **Documentation**: - Mathematical formulation: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md) - Usage guide: [rollout_corr.md](docs/algo/rollout_corr.md) --- ## Documentation Overhaul ### 4. File Reorganization **Moved documentation to `docs/algo/`**: - `docs/advance/rollout_corr.md` → `docs/algo/rollout_corr.md` (+439 additions) - `docs/advance/rollout_corr_math.md` → `docs/algo/rollout_corr_math.md` (+459 additions) **Deleted redundant file**: - `examples/rollout_correction/README.md` (-253 lines) **Updated references**: - [docs/index.rst](docs/index.rst): Updated paths - [docs/advance/fully_async.md](docs/advance/fully_async.md): Updated cross-references ### 5. Preset Renaming for Clarity Renamed presets to clearly indicate operating mode: | Old Name | New Name | Operating Mode | Description | |----------|----------|----------------|-------------| | `token_is` | `decoupled_token_is` | Decoupled (3-policy) | Token-level IS weighting | | `seq_is` | `decoupled_seq_is` | Decoupled (3-policy) | Sequence-level IS weighting | | `geo_rs` | `decoupled_geo_rs` | Decoupled (3-policy) | Geometric rejection sampling | | `ppo_is_bypass` | `ppo_is_bypass` | Bypass (2-policy) | PPO with IS (unchanged) | | `pure_is` | `pg_is` | Bypass (2-policy) | Policy gradient + sequence IS | | N/A | `pg_rs` | Bypass (2-policy) | Policy gradient + geometric RS (new) | **Naming Convention**: - **Decoupled mode** presets: `decoupled_*` (requires old_log_prob computation) - **Bypass mode** presets: `pg_*` or `ppo_*` (skips old_log_prob computation) ### 6. Content Improvements **Cross-References**: - Added prominent links between [rollout_corr.md](docs/algo/rollout_corr.md) (usage guide) and [rollout_corr_math.md](docs/algo/rollout_corr_math.md) (mathematical foundations) **Clarified Loss Formulations**: - Changed examples from PPO to REINFORCE in [rollout_corr_math.md §3.3](docs/algo/rollout_corr_math.md) - **Rationale**: Separates IS weight mechanics from PPO clipping for clarity - Added note that REINFORCE examples can be combined with PPO clipping **New Sections**: - Dedicated batch normalization section: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md) - Improved operating mode explanations throughout --- ## Code Quality Improvements ### 7. Enhanced Comments and Documentation **Trainer Logic** ([ray_trainer.py](verl/trainer/ppo/ray_trainer.py)): - Lines 1104-1107: Operating mode selection logic - Lines 1175-1177: Metrics computation behavior explanation **Policy Loss** ([core_algos.py](verl/trainer/ppo/core_algos.py)): - Enhanced docstrings for `compute_policy_loss_with_rollout_correction` - Clarified when to use policy gradient vs Q-function loss **Actor Workers** ([dp_actor.py](verl/workers/actor/dp_actor.py), [megatron_actor.py](verl/workers/actor/megatron_actor.py)): - Added comments explaining bypass mode metrics computation ### 8. Code Simplification **Removed Unused Logic** ([rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py)): - Removed unnecessary config parameters from metrics computation - Removed unused IS weight processing logic - Simplified metrics calculation flow **Improved Variable Reuse**: - Reused `need_recomputation` variable instead of redundant bypass mode checks - Reduced code duplication --- ## Commit History <details> <summary>18 commits (click to expand)</summary> 1. `7c9e41da` - fix(rollout_corr): compute metrics in actor for bypass mode and fix trainer bugs 2. `96ae2be1` - docs(rollout_corr): move to algo/ and add pure_rs preset 3. `c0ea9bdc` - feat(rollout_corr): add batch normalization option for IS weights 4. `7de6c5f9` - docs(rollout_corr_math): use REINFORCE in aggregation loss examples for clarity 5. `2b34cfee` - refactor(rollout_corr): simplify metrics computation by removing unused config and IS weight logic 6. `0c42f85a` - docs(rollout_corr): add prominent cross-references between usage and math docs 7. `fef8a48f` - docs(rollout_corr_math): add dedicated section for batch normalization 8. `08cc9c7d` - fix: docstring of compute_policy_loss_with_rollout_correction 9. `437a4aba` - feat: reuse need_recomputation instead of bypass_mode 10. `5f9a53bf` - feat: improve comments 11. `b2f63709` - feat: improve comments 12. `79cdbf2f` - feat: refactor bypass_recomputing_logprobs 13. `62e32701` - feat(rollout_corr): align batch normalization with IS aggregation level 14. `b5c19ff7` - docs(rollout_corr): rename decoupled mode presets for clarity and update examples 15. `11f9aa05` - fix(rollout_corr): correct metrics computation to run in decoupled mode only 16. `58565cb0` - docs(rollout_corr): rename presets for clarity and consistency 17. `8bb1a0e0` - refactor(rollout_corr): rename config vars for semantic clarity 18. `6002c00c` - refactor(rollout_corr): update implementation to use renamed config variables </details> --- ## Summary This PR systematically improves the rollout correction implementation through three key areas: 1. **Bug Fixes**: Corrected metrics computation to run in the appropriate mode 2. **Semantic Clarity**: Renamed variables to accurately reflect their purpose (`bypass_mode`, `use_policy_gradient`) 3. **Feature Addition**: Added batch normalization option for IS weights with comprehensive documentation All changes maintain backward compatibility while significantly improving code clarity, correctness, and maintainability. --------- Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
|
Hello, I want to know why I get an error when using the following settings: The error message is: After testing, I found that the all-zero mask is the |
…lems (verl-project#3984) ## Summary This PR introduces a comprehensive overhaul of the rollout correction system with typed configuration, mathematical documentation, and performance optimizations. If you find the PR useful, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = september, } ``` **⚠️ BREAKING CHANGE**: Removes backward compatibility. Users must migrate to typed config. --- ## What's New ### 1. Typed Configuration with Presets **Before (deprecated):** ```yaml algorithm: rollout_is: true rollout_is_threshold: 2.0 rollout_is_level: token ``` **After (Python - Recommended):** ```python from verl.trainer.config.algorithm import RolloutCorrectionConfig # Use validated presets config.algorithm.rollout_correction = RolloutCorrectionConfig.token_is() config.algorithm.rollout_correction = RolloutCorrectionConfig.seq_is_rs() ``` **After (YAML):** ```yaml algorithm: rollout_correction: rollout_is: token rollout_is_threshold: 2.0 ``` **10 validated presets:** - `token_is()` / `token_tis()` - Per-token IS - `seq_is()` - Sequence-level IS - `seq_is_rs()` / `seq_mis()` - Sequence IS + rejection sampling - `geo_rs()` / `geo_mis()` - Geometric RS + veto - `ppo_is_bypass()` - Bypass mode (performance) - `pure_is()` - Pure policy gradient (no PPO clipping) - `disabled()` - Metrics only ### 2. Mathematical Documentation New comprehensive document: `docs/advance/rollout_corr_math.md` (585 lines) **Theoretical foundation:** - REINFORCE → PPO → Decoupled PPO progression - Batch size invariance: Decoupling proximal policy from behavior policy - Three-policy framework: π_rollout, π_old, π_θ **Complete formulations for:** - Off-policy REINFORCE (`pure_is`) - Standard PPO and bypass mode - Decoupled PPO (`token_is`, `seq_is`, `seq_is_rs`) - Rejection sampling (`geo_rs`) **Diagnostic metrics:** - KL divergence (direct and K3 estimators) - Perplexity and perplexity ratio - χ² divergence (token and sequence level) **Quality:** - Objective technical descriptions - All formulas mathematically verified - Cross-document consistency validated ### 3. Training Modes | Mode | Config | Policies | Speed | Description | |------|--------|----------|-------|-------------| | **Standard** | `bypass=false, pure=false` | 3 | Standard | Full decoupled PPO with batch size invariance | | **Bypass** | `bypass=true, pure=false` | 2 | **Fast** | PPO clips against rollout (faster) | | **Pure IS** | `bypass=true, pure=true` | 2 | **Fast** | Off-policy REINFORCE without clipping | **Example:** ```python # Bypass mode for performance config = RolloutCorrectionConfig.ppo_is_bypass(threshold=2.0) # Pure IS for research config = RolloutCorrectionConfig.pure_is(threshold=2.0) ``` ### 4. Chi-Squared Divergence Metrics Quantify off-policy severity: ```python rollout_corr/chi2_token # E[ρ²] - 1 rollout_corr/chi2_seq # E[(∏ρ)²] - 1 ``` **Interpretation:** - χ² = 0: Perfect on-policy - χ² < 1: Low off-policiness, stable - χ² ≥ 10: High off-policiness, need correction **Cleanup:** - Removed `mismatch_` prefix - All metrics under `rollout_corr/` namespace ### 5. Bug Fix **Critical fix:** - `rollout_rs="token"` with `rollout_rs_threshold=None` silently failed - Now raises `ValueError` with clear error message --- ## Migration Guide ### Example 1: Basic Token-level IS ```python # Old (no longer works) config.algorithm.rollout_is = True config.algorithm.rollout_is_threshold = 2.0 config.algorithm.rollout_is_level = "token" # New config.algorithm.rollout_correction = RolloutCorrectionConfig.token_is(threshold=2.0) ``` ### Example 2: Sequence IS + Rejection Sampling ```python # Old (no longer works) config.algorithm.rollout_is_level = "sequence" config.algorithm.rollout_is_mode = "mask" # New config.algorithm.rollout_correction = RolloutCorrectionConfig.seq_is_rs( is_threshold=2.0, rs_threshold=2.0 ) ``` ### Example 3: Disable ```yaml # Old rollout_is: false # New rollout_correction: null ``` --- ## References Liu, Li, Fu, Wang, Liu, Shen (2025). *When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch*. [Blog](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) --------- Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
… and Add Batch Normalization (verl-project#4070) ## Overview This PR fixes bugs, refactors configuration for semantic clarity, and adds batch normalization support to the rollout correction implementation introduced in PR verl-project#3984. --- ## Bug Fixes ### 1. Metrics Computation Running in Wrong Mode⚠️ **Problem**: Rollout correction metrics were computed in **bypass mode** instead of **decoupled mode**, making them meaningless. **Root Cause**: Incorrect condition at [ray_trainer.py:1177-1180](verl/trainer/ppo/ray_trainer.py#L1177) ```python # BEFORE (incorrect - runs in bypass mode) if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch) ``` ```python # AFTER (correct - runs in decoupled mode only) if (rollout_corr_config is not None and "rollout_log_probs" in batch.batch and not bypass_recomputing_logprobs): # Only in decoupled mode batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) ``` **Impact**: - IS weights and rejection sampling metrics are now computed only when meaningful (decoupled mode with 3 policies) - In bypass mode (2 policies), actor now correctly computes metrics from evolving π_θ vs π_rollout **Related Changes**: - Added clarifying comments in [ray_trainer.py:1104-1107](verl/trainer/ppo/ray_trainer.py#L1104) (operating mode selection) - Added clarifying comments in [ray_trainer.py:1175-1177](verl/trainer/ppo/ray_trainer.py#L1175) (metrics behavior) - Fixed actor metrics computation in [dp_actor.py](verl/workers/actor/dp_actor.py), [megatron_actor.py](verl/workers/actor/megatron_actor.py) --- ## Configuration Refactor (Semantic Clarity) ### 2. Variable Renaming Renamed config variables to accurately reflect their semantics: | Old Name | New Name | Rationale | |----------|----------|-----------| | `bypass_old_logprob_for_rollout` | `bypass_mode` | Directly describes the operating mode (2-policy vs 3-policy) | | `use_pure_rollout_correction` | `use_policy_gradient` | Reflects actual choice: policy gradient loss vs Q-function loss | **Before** ([algorithm.py @ 86dd84b](https://github.com/volcengine/verl/blob/86dd84b2/verl/trainer/config/algorithm.py)): ```python bypass_old_logprob_for_rollout: bool = False # Unclear what "bypass" means use_pure_rollout_correction: bool = False # "Pure" is vague ``` **After** ([algorithm.py @ HEAD](verl/trainer/config/algorithm.py#L157)): ```python bypass_mode: bool = False # Clear: bypass or decoupled mode use_policy_gradient: bool = False # Clear: PG or Q-function loss ``` **Files Updated**: - Core config: [algorithm.py](verl/trainer/config/algorithm.py), [rollout_correction.yaml](verl/trainer/config/algorithm/rollout_correction.yaml) - Implementation: [ray_trainer.py](verl/trainer/ppo/ray_trainer.py), [rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py), [core_algos.py](verl/trainer/ppo/core_algos.py) - Examples: [run_with_rollout_corr.sh](examples/rollout_correction/run_with_rollout_corr.sh), [run_dapo_qwen2.5_32b_rollout_corr.sh](recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh) - Generated configs: [_generated_ppo_trainer.yaml](verl/trainer/config/_generated_ppo_trainer.yaml), [_generated_ppo_megatron_trainer.yaml](verl/trainer/config/_generated_ppo_megatron_trainer.yaml) --- ## New Feature: Batch Normalization ### 3. IS Weight Batch Normalization **Added**: `rollout_is_batch_normalize` config parameter ([algorithm.py:159](verl/trainer/config/algorithm.py#L159)) ```python rollout_is_batch_normalize: bool = False ``` **Purpose**: - Normalizes importance sampling weights to have mean=1.0 within each batch - Aligns normalization scope with IS aggregation level (token/sequence/geometric) - Helps stabilize training when policy drift is large **Behavior**: - `True`: IS weights normalized so mean=1.0 per batch (reduces variance) - `False`: Raw truncated IS weights used (standard behavior, default) **Documentation**: - Mathematical formulation: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md) - Usage guide: [rollout_corr.md](docs/algo/rollout_corr.md) --- ## Documentation Overhaul ### 4. File Reorganization **Moved documentation to `docs/algo/`**: - `docs/advance/rollout_corr.md` → `docs/algo/rollout_corr.md` (+439 additions) - `docs/advance/rollout_corr_math.md` → `docs/algo/rollout_corr_math.md` (+459 additions) **Deleted redundant file**: - `examples/rollout_correction/README.md` (-253 lines) **Updated references**: - [docs/index.rst](docs/index.rst): Updated paths - [docs/advance/fully_async.md](docs/advance/fully_async.md): Updated cross-references ### 5. Preset Renaming for Clarity Renamed presets to clearly indicate operating mode: | Old Name | New Name | Operating Mode | Description | |----------|----------|----------------|-------------| | `token_is` | `decoupled_token_is` | Decoupled (3-policy) | Token-level IS weighting | | `seq_is` | `decoupled_seq_is` | Decoupled (3-policy) | Sequence-level IS weighting | | `geo_rs` | `decoupled_geo_rs` | Decoupled (3-policy) | Geometric rejection sampling | | `ppo_is_bypass` | `ppo_is_bypass` | Bypass (2-policy) | PPO with IS (unchanged) | | `pure_is` | `pg_is` | Bypass (2-policy) | Policy gradient + sequence IS | | N/A | `pg_rs` | Bypass (2-policy) | Policy gradient + geometric RS (new) | **Naming Convention**: - **Decoupled mode** presets: `decoupled_*` (requires old_log_prob computation) - **Bypass mode** presets: `pg_*` or `ppo_*` (skips old_log_prob computation) ### 6. Content Improvements **Cross-References**: - Added prominent links between [rollout_corr.md](docs/algo/rollout_corr.md) (usage guide) and [rollout_corr_math.md](docs/algo/rollout_corr_math.md) (mathematical foundations) **Clarified Loss Formulations**: - Changed examples from PPO to REINFORCE in [rollout_corr_math.md §3.3](docs/algo/rollout_corr_math.md) - **Rationale**: Separates IS weight mechanics from PPO clipping for clarity - Added note that REINFORCE examples can be combined with PPO clipping **New Sections**: - Dedicated batch normalization section: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md) - Improved operating mode explanations throughout --- ## Code Quality Improvements ### 7. Enhanced Comments and Documentation **Trainer Logic** ([ray_trainer.py](verl/trainer/ppo/ray_trainer.py)): - Lines 1104-1107: Operating mode selection logic - Lines 1175-1177: Metrics computation behavior explanation **Policy Loss** ([core_algos.py](verl/trainer/ppo/core_algos.py)): - Enhanced docstrings for `compute_policy_loss_with_rollout_correction` - Clarified when to use policy gradient vs Q-function loss **Actor Workers** ([dp_actor.py](verl/workers/actor/dp_actor.py), [megatron_actor.py](verl/workers/actor/megatron_actor.py)): - Added comments explaining bypass mode metrics computation ### 8. Code Simplification **Removed Unused Logic** ([rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py)): - Removed unnecessary config parameters from metrics computation - Removed unused IS weight processing logic - Simplified metrics calculation flow **Improved Variable Reuse**: - Reused `need_recomputation` variable instead of redundant bypass mode checks - Reduced code duplication --- ## Commit History <details> <summary>18 commits (click to expand)</summary> 1. `7c9e41da` - fix(rollout_corr): compute metrics in actor for bypass mode and fix trainer bugs 2. `96ae2be1` - docs(rollout_corr): move to algo/ and add pure_rs preset 3. `c0ea9bdc` - feat(rollout_corr): add batch normalization option for IS weights 4. `7de6c5f9` - docs(rollout_corr_math): use REINFORCE in aggregation loss examples for clarity 5. `2b34cfee` - refactor(rollout_corr): simplify metrics computation by removing unused config and IS weight logic 6. `0c42f85a` - docs(rollout_corr): add prominent cross-references between usage and math docs 7. `fef8a48f` - docs(rollout_corr_math): add dedicated section for batch normalization 8. `08cc9c7d` - fix: docstring of compute_policy_loss_with_rollout_correction 9. `437a4aba` - feat: reuse need_recomputation instead of bypass_mode 10. `5f9a53bf` - feat: improve comments 11. `b2f63709` - feat: improve comments 12. `79cdbf2f` - feat: refactor bypass_recomputing_logprobs 13. `62e32701` - feat(rollout_corr): align batch normalization with IS aggregation level 14. `b5c19ff7` - docs(rollout_corr): rename decoupled mode presets for clarity and update examples 15. `11f9aa05` - fix(rollout_corr): correct metrics computation to run in decoupled mode only 16. `58565cb0` - docs(rollout_corr): rename presets for clarity and consistency 17. `8bb1a0e0` - refactor(rollout_corr): rename config vars for semantic clarity 18. `6002c00c` - refactor(rollout_corr): update implementation to use renamed config variables </details> --- ## Summary This PR systematically improves the rollout correction implementation through three key areas: 1. **Bug Fixes**: Corrected metrics computation to run in the appropriate mode 2. **Semantic Clarity**: Renamed variables to accurately reflect their purpose (`bypass_mode`, `use_policy_gradient`) 3. **Feature Addition**: Added batch normalization option for IS weights with comprehensive documentation All changes maintain backward compatibility while significantly improving code clarity, correctness, and maintainability. --------- Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
…lems (verl-project#3984) ## Summary This PR introduces a comprehensive overhaul of the rollout correction system with typed configuration, mathematical documentation, and performance optimizations. If you find the PR useful, please consider citing: ```bibtex @misc{liu-li-2025, title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch}, url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda}, author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen}, year = {2025}, month = september, } ``` **⚠️ BREAKING CHANGE**: Removes backward compatibility. Users must migrate to typed config. --- ## What's New ### 1. Typed Configuration with Presets **Before (deprecated):** ```yaml algorithm: rollout_is: true rollout_is_threshold: 2.0 rollout_is_level: token ``` **After (Python - Recommended):** ```python from verl.trainer.config.algorithm import RolloutCorrectionConfig # Use validated presets config.algorithm.rollout_correction = RolloutCorrectionConfig.token_is() config.algorithm.rollout_correction = RolloutCorrectionConfig.seq_is_rs() ``` **After (YAML):** ```yaml algorithm: rollout_correction: rollout_is: token rollout_is_threshold: 2.0 ``` **10 validated presets:** - `token_is()` / `token_tis()` - Per-token IS - `seq_is()` - Sequence-level IS - `seq_is_rs()` / `seq_mis()` - Sequence IS + rejection sampling - `geo_rs()` / `geo_mis()` - Geometric RS + veto - `ppo_is_bypass()` - Bypass mode (performance) - `pure_is()` - Pure policy gradient (no PPO clipping) - `disabled()` - Metrics only ### 2. Mathematical Documentation New comprehensive document: `docs/advance/rollout_corr_math.md` (585 lines) **Theoretical foundation:** - REINFORCE → PPO → Decoupled PPO progression - Batch size invariance: Decoupling proximal policy from behavior policy - Three-policy framework: π_rollout, π_old, π_θ **Complete formulations for:** - Off-policy REINFORCE (`pure_is`) - Standard PPO and bypass mode - Decoupled PPO (`token_is`, `seq_is`, `seq_is_rs`) - Rejection sampling (`geo_rs`) **Diagnostic metrics:** - KL divergence (direct and K3 estimators) - Perplexity and perplexity ratio - χ² divergence (token and sequence level) **Quality:** - Objective technical descriptions - All formulas mathematically verified - Cross-document consistency validated ### 3. Training Modes | Mode | Config | Policies | Speed | Description | |------|--------|----------|-------|-------------| | **Standard** | `bypass=false, pure=false` | 3 | Standard | Full decoupled PPO with batch size invariance | | **Bypass** | `bypass=true, pure=false` | 2 | **Fast** | PPO clips against rollout (faster) | | **Pure IS** | `bypass=true, pure=true` | 2 | **Fast** | Off-policy REINFORCE without clipping | **Example:** ```python # Bypass mode for performance config = RolloutCorrectionConfig.ppo_is_bypass(threshold=2.0) # Pure IS for research config = RolloutCorrectionConfig.pure_is(threshold=2.0) ``` ### 4. Chi-Squared Divergence Metrics Quantify off-policy severity: ```python rollout_corr/chi2_token # E[ρ²] - 1 rollout_corr/chi2_seq # E[(∏ρ)²] - 1 ``` **Interpretation:** - χ² = 0: Perfect on-policy - χ² < 1: Low off-policiness, stable - χ² ≥ 10: High off-policiness, need correction **Cleanup:** - Removed `mismatch_` prefix - All metrics under `rollout_corr/` namespace ### 5. Bug Fix **Critical fix:** - `rollout_rs="token"` with `rollout_rs_threshold=None` silently failed - Now raises `ValueError` with clear error message --- ## Migration Guide ### Example 1: Basic Token-level IS ```python # Old (no longer works) config.algorithm.rollout_is = True config.algorithm.rollout_is_threshold = 2.0 config.algorithm.rollout_is_level = "token" # New config.algorithm.rollout_correction = RolloutCorrectionConfig.token_is(threshold=2.0) ``` ### Example 2: Sequence IS + Rejection Sampling ```python # Old (no longer works) config.algorithm.rollout_is_level = "sequence" config.algorithm.rollout_is_mode = "mask" # New config.algorithm.rollout_correction = RolloutCorrectionConfig.seq_is_rs( is_threshold=2.0, rs_threshold=2.0 ) ``` ### Example 3: Disable ```yaml # Old rollout_is: false # New rollout_correction: null ``` --- ## References Liu, Li, Fu, Wang, Liu, Shen (2025). *When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch*. [Blog](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda) --------- Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
… and Add Batch Normalization (verl-project#4070) ## Overview This PR fixes bugs, refactors configuration for semantic clarity, and adds batch normalization support to the rollout correction implementation introduced in PR verl-project#3984. --- ## Bug Fixes ### 1. Metrics Computation Running in Wrong Mode⚠️ **Problem**: Rollout correction metrics were computed in **bypass mode** instead of **decoupled mode**, making them meaningless. **Root Cause**: Incorrect condition at [ray_trainer.py:1177-1180](verl/trainer/ppo/ray_trainer.py#L1177) ```python # BEFORE (incorrect - runs in bypass mode) if rollout_corr_config is not None and "rollout_log_probs" in batch.batch: batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch) ``` ```python # AFTER (correct - runs in decoupled mode only) if (rollout_corr_config is not None and "rollout_log_probs" in batch.batch and not bypass_recomputing_logprobs): # Only in decoupled mode batch, is_metrics = compute_rollout_correction_and_add_to_batch(batch, rollout_corr_config) ``` **Impact**: - IS weights and rejection sampling metrics are now computed only when meaningful (decoupled mode with 3 policies) - In bypass mode (2 policies), actor now correctly computes metrics from evolving π_θ vs π_rollout **Related Changes**: - Added clarifying comments in [ray_trainer.py:1104-1107](verl/trainer/ppo/ray_trainer.py#L1104) (operating mode selection) - Added clarifying comments in [ray_trainer.py:1175-1177](verl/trainer/ppo/ray_trainer.py#L1175) (metrics behavior) - Fixed actor metrics computation in [dp_actor.py](verl/workers/actor/dp_actor.py), [megatron_actor.py](verl/workers/actor/megatron_actor.py) --- ## Configuration Refactor (Semantic Clarity) ### 2. Variable Renaming Renamed config variables to accurately reflect their semantics: | Old Name | New Name | Rationale | |----------|----------|-----------| | `bypass_old_logprob_for_rollout` | `bypass_mode` | Directly describes the operating mode (2-policy vs 3-policy) | | `use_pure_rollout_correction` | `use_policy_gradient` | Reflects actual choice: policy gradient loss vs Q-function loss | **Before** ([algorithm.py @ 9879bfb](https://github.com/volcengine/verl/blob/9879bfb6/verl/trainer/config/algorithm.py)): ```python bypass_old_logprob_for_rollout: bool = False # Unclear what "bypass" means use_pure_rollout_correction: bool = False # "Pure" is vague ``` **After** ([algorithm.py @ HEAD](verl/trainer/config/algorithm.py#L157)): ```python bypass_mode: bool = False # Clear: bypass or decoupled mode use_policy_gradient: bool = False # Clear: PG or Q-function loss ``` **Files Updated**: - Core config: [algorithm.py](verl/trainer/config/algorithm.py), [rollout_correction.yaml](verl/trainer/config/algorithm/rollout_correction.yaml) - Implementation: [ray_trainer.py](verl/trainer/ppo/ray_trainer.py), [rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py), [core_algos.py](verl/trainer/ppo/core_algos.py) - Examples: [run_with_rollout_corr.sh](examples/rollout_correction/run_with_rollout_corr.sh), [run_dapo_qwen2.5_32b_rollout_corr.sh](recipe/dapo/run_dapo_qwen2.5_32b_rollout_corr.sh) - Generated configs: [_generated_ppo_trainer.yaml](verl/trainer/config/_generated_ppo_trainer.yaml), [_generated_ppo_megatron_trainer.yaml](verl/trainer/config/_generated_ppo_megatron_trainer.yaml) --- ## New Feature: Batch Normalization ### 3. IS Weight Batch Normalization **Added**: `rollout_is_batch_normalize` config parameter ([algorithm.py:159](verl/trainer/config/algorithm.py#L159)) ```python rollout_is_batch_normalize: bool = False ``` **Purpose**: - Normalizes importance sampling weights to have mean=1.0 within each batch - Aligns normalization scope with IS aggregation level (token/sequence/geometric) - Helps stabilize training when policy drift is large **Behavior**: - `True`: IS weights normalized so mean=1.0 per batch (reduces variance) - `False`: Raw truncated IS weights used (standard behavior, default) **Documentation**: - Mathematical formulation: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md) - Usage guide: [rollout_corr.md](docs/algo/rollout_corr.md) --- ## Documentation Overhaul ### 4. File Reorganization **Moved documentation to `docs/algo/`**: - `docs/advance/rollout_corr.md` → `docs/algo/rollout_corr.md` (+439 additions) - `docs/advance/rollout_corr_math.md` → `docs/algo/rollout_corr_math.md` (+459 additions) **Deleted redundant file**: - `examples/rollout_correction/README.md` (-253 lines) **Updated references**: - [docs/index.rst](docs/index.rst): Updated paths - [docs/advance/fully_async.md](docs/advance/fully_async.md): Updated cross-references ### 5. Preset Renaming for Clarity Renamed presets to clearly indicate operating mode: | Old Name | New Name | Operating Mode | Description | |----------|----------|----------------|-------------| | `token_is` | `decoupled_token_is` | Decoupled (3-policy) | Token-level IS weighting | | `seq_is` | `decoupled_seq_is` | Decoupled (3-policy) | Sequence-level IS weighting | | `geo_rs` | `decoupled_geo_rs` | Decoupled (3-policy) | Geometric rejection sampling | | `ppo_is_bypass` | `ppo_is_bypass` | Bypass (2-policy) | PPO with IS (unchanged) | | `pure_is` | `pg_is` | Bypass (2-policy) | Policy gradient + sequence IS | | N/A | `pg_rs` | Bypass (2-policy) | Policy gradient + geometric RS (new) | **Naming Convention**: - **Decoupled mode** presets: `decoupled_*` (requires old_log_prob computation) - **Bypass mode** presets: `pg_*` or `ppo_*` (skips old_log_prob computation) ### 6. Content Improvements **Cross-References**: - Added prominent links between [rollout_corr.md](docs/algo/rollout_corr.md) (usage guide) and [rollout_corr_math.md](docs/algo/rollout_corr_math.md) (mathematical foundations) **Clarified Loss Formulations**: - Changed examples from PPO to REINFORCE in [rollout_corr_math.md §3.3](docs/algo/rollout_corr_math.md) - **Rationale**: Separates IS weight mechanics from PPO clipping for clarity - Added note that REINFORCE examples can be combined with PPO clipping **New Sections**: - Dedicated batch normalization section: [rollout_corr_math.md §3.4](docs/algo/rollout_corr_math.md) - Improved operating mode explanations throughout --- ## Code Quality Improvements ### 7. Enhanced Comments and Documentation **Trainer Logic** ([ray_trainer.py](verl/trainer/ppo/ray_trainer.py)): - Lines 1104-1107: Operating mode selection logic - Lines 1175-1177: Metrics computation behavior explanation **Policy Loss** ([core_algos.py](verl/trainer/ppo/core_algos.py)): - Enhanced docstrings for `compute_policy_loss_with_rollout_correction` - Clarified when to use policy gradient vs Q-function loss **Actor Workers** ([dp_actor.py](verl/workers/actor/dp_actor.py), [megatron_actor.py](verl/workers/actor/megatron_actor.py)): - Added comments explaining bypass mode metrics computation ### 8. Code Simplification **Removed Unused Logic** ([rollout_corr_helper.py](verl/trainer/ppo/rollout_corr_helper.py)): - Removed unnecessary config parameters from metrics computation - Removed unused IS weight processing logic - Simplified metrics calculation flow **Improved Variable Reuse**: - Reused `need_recomputation` variable instead of redundant bypass mode checks - Reduced code duplication --- ## Commit History <details> <summary>18 commits (click to expand)</summary> 1. `7c9e41da` - fix(rollout_corr): compute metrics in actor for bypass mode and fix trainer bugs 2. `96ae2be1` - docs(rollout_corr): move to algo/ and add pure_rs preset 3. `c0ea9bdc` - feat(rollout_corr): add batch normalization option for IS weights 4. `7de6c5f9` - docs(rollout_corr_math): use REINFORCE in aggregation loss examples for clarity 5. `2b34cfee` - refactor(rollout_corr): simplify metrics computation by removing unused config and IS weight logic 6. `0c42f85a` - docs(rollout_corr): add prominent cross-references between usage and math docs 7. `fef8a48f` - docs(rollout_corr_math): add dedicated section for batch normalization 8. `08cc9c7d` - fix: docstring of compute_policy_loss_with_rollout_correction 9. `437a4aba` - feat: reuse need_recomputation instead of bypass_mode 10. `5f9a53bf` - feat: improve comments 11. `b2f63709` - feat: improve comments 12. `79cdbf2f` - feat: refactor bypass_recomputing_logprobs 13. `62e32701` - feat(rollout_corr): align batch normalization with IS aggregation level 14. `b5c19ff7` - docs(rollout_corr): rename decoupled mode presets for clarity and update examples 15. `11f9aa05` - fix(rollout_corr): correct metrics computation to run in decoupled mode only 16. `58565cb0` - docs(rollout_corr): rename presets for clarity and consistency 17. `8bb1a0e0` - refactor(rollout_corr): rename config vars for semantic clarity 18. `6002c00c` - refactor(rollout_corr): update implementation to use renamed config variables </details> --- ## Summary This PR systematically improves the rollout correction implementation through three key areas: 1. **Bug Fixes**: Corrected metrics computation to run in the appropriate mode 2. **Semantic Clarity**: Renamed variables to accurately reflect their purpose (`bypass_mode`, `use_policy_gradient`) 3. **Feature Addition**: Added batch normalization option for IS weights with comprehensive documentation All changes maintain backward compatibility while significantly improving code clarity, correctness, and maintainability. --------- Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
|
File "verl/verl/trainer/ppo/rollout_corr_helper.py", line 754, in compute_offpolicy_metrics the same error,any one can help me ????? |
I think this situation is rare and data related, maybe you can increase the batch size to lower the probability of all invalid. Or you should add some function to skip this training step and redo the rollout until you have at least one valid data for training. |
Summary
This PR introduces a comprehensive overhaul of the rollout correction system with typed configuration, mathematical documentation, and performance optimizations.
If you find the PR useful, please consider citing:
What's New
1. Typed Configuration with Presets
Before (deprecated):
After (Python - Recommended):
After (YAML):
10 validated presets:
token_is()/token_tis()- Per-token ISseq_is()- Sequence-level ISseq_is_rs()/seq_mis()- Sequence IS + rejection samplinggeo_rs()/geo_mis()- Geometric RS + vetoppo_is_bypass()- Bypass mode (performance)pure_is()- Pure policy gradient (no PPO clipping)disabled()- Metrics onlyComplete document can be find on https://verl.readthedocs.io/en/latest/advance/rollout_corr.html.
YAML Configuration (Advanced)
For advanced customization or YAML-based configs:
2. Mathematical Documentation
New comprehensive document:
docs/advance/rollout_corr_math.md(585 lines)Theoretical foundation:
Complete formulations for:
pure_is)token_is,seq_is,seq_is_rs)geo_rs)Diagnostic metrics:
Quality:
3. Training Modes
bypass=false, pure=falsebypass=true, pure=falsebypass=true, pure=trueExample:
4. Chi-Squared Divergence Metrics
Quantify off-policy severity:
Interpretation:
Cleanup:
mismatch_prefixrollout_corr/namespace5. Bug Fix
Critical fix:
rollout_rs="token"withrollout_rs_threshold=Nonesilently failedValueErrorwith clear error messageMigration Guide
Example 1: Basic Token-level IS
Example 2: Sequence IS + Rejection Sampling
Example 3: Disable
References
Liu, Li, Fu, Wang, Liu, Shen (2025). When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch. Blog