
cp: feat: support truncated importance sampling (1348) into r0.4.0 #1400

Merged
terrykong merged 1 commit into r0.4.0 from cherry-pick-1348-r0.4.0 on Oct 21, 2025

Conversation


@chtruong814 chtruong814 commented Oct 21, 2025

beep boop [🤖]: Hi @yuki-97 👋,

we've cherry picked #1348 into r0.4.0 for you! 🚀

Please review and approve this cherry pick at your convenience!

Summary by CodeRabbit

  • New Features

    • Added truncated importance sampling ratio configuration option to loss function settings, enabling refined control over importance weight clamping during policy gradient optimization.
  • Tests

    • Added comprehensive test coverage for truncated importance sampling functionality and refactored test configuration for improved maintainability.

Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>

coderabbitai bot commented Oct 21, 2025

Walkthrough

This pull request adds a configurable truncated importance sampling (TIS) ratio to the GRPO loss function. A new optional field, truncated_importance_sampling_ratio, is added to ClippedPGLossConfig, validated and applied in ClippedPGLossFn, and propagated through the example configuration files and tests.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Configuration Files: examples/configs/grpo_math_1B.yaml, examples/configs/vlm_grpo_3B.yaml, examples/configs/vlm_grpo_3B_megatron.yaml | Added truncated_importance_sampling_ratio: null field under loss_fn block in each config file. |
| Loss Function Implementation: nemo_rl/algorithms/loss_functions.py | Added truncated_importance_sampling_ratio field to ClippedPGLossConfig. Implemented validation ensuring the field is used only with use_importance_sampling_correction=True and must be positive when set. Applied truncated importance sampling clamping to actor_importance_weights_expanded during loss computation. |
| Core Unit Tests: tests/unit/algorithms/test_loss_functions.py | Introduced basic_pg_loss_test_config: ClippedPGLossConfig base configuration. Refactored multiple test cases to reuse and modify the base config via deepcopy(). Added new test test_clipped_pg_loss_on_policy_truncated_importance_sampling parameterized by sequence_level_importance_ratios. |
| Test Configuration Updates: tests/unit/algorithms/test_grpo.py, tests/unit/algorithms/test_sequence_packing_gradients.py, tests/unit/models/policy/test_dtensor_worker.py | Added truncated_importance_sampling_ratio: None and sequence_level_importance_ratios: False to loss function configuration dictionaries in test setups. |
| Policy Worker Tests: tests/unit/models/policy/test_megatron_worker.py | Imported ClippedPGLossConfig and created basic_pg_loss_test_config variable. Replaced inline loss configuration dictionaries with reusable test config. |

Sequence Diagram(s)

sequenceDiagram
    participant Config as Configuration
    participant Loss as ClippedPGLossFn
    participant Compute as Loss Computation
    
    Config->>Loss: __init__(cfg with truncated_importance_sampling_ratio)
    Loss->>Loss: Validate: if ratio set, then use_importance_sampling_correction must be True
    Loss->>Loss: Validate: if ratio set, must be positive
    Loss->>Loss: Store truncated_importance_sampling_ratio
    
    Compute->>Loss: forward(actor_importance_weights, ...)
    Loss->>Compute: Calculate raw importance weights
    
    alt truncated_importance_sampling_ratio is set
        Compute->>Compute: Clamp weights to ratio (truncate)
    else ratio not set
        Compute->>Compute: Use raw weights
    end
    
    Compute->>Loss: Return truncated/raw weights
    Loss->>Loss: Compute clipped loss with processed weights
    Loss->>Compute: Return loss

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

The changes span multiple heterogeneous areas: configuration schema extensions, loss function validation logic, test refactoring with shared base configs, and new test coverage. While individual changes are straightforward, the breadth across config files, implementation logic, and multiple test suites—combined with the need to verify validation constraints and test coverage—requires moderate review attention.

Possibly related PRs

Suggested labels

r0.4.0

Suggested reviewers

  • yuki-97
  • terrykong

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Test Results For Major Changes | ⚠️ Warning | This PR introduces truncated importance sampling as a significant feature affecting loss computation and convergence behavior in the GRPO algorithm. While the underlying feature was previously implemented in PR #1348 (as confirmed by git history), the current cherry-pick PR description contains only an automated message ("beep boop 🤖: we've cherry picked #1348...") without any documentation of test results, performance verification, or regression analysis. The instruction requires that PRs with major changes affecting numerics or convergence should include test results or information demonstrating no regression in the PR description itself. Although unit tests have been added to the codebase, the PR description does not document that these tests pass or provide any evidence of testing on the r0.4.0 branch. | Update the PR description to include test results documenting that the feature passes all existing and new tests without regression. Since this is a cherry-pick of PR #1348, the description could reference the original PR's test results or include a summary stating that all tests pass on the r0.4.0 branch. Additionally, address the outstanding review comments requiring documentation of the truncated_importance_sampling_ratio configuration field in the example YAML files with inline comments explaining its purpose and constraints. |
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title Check | ✅ Passed | The pull request title "cp: feat: support truncated importance sampling (1348) into r0.4.0" clearly communicates the primary change in the changeset. The core feature described—"support truncated importance sampling"—directly matches the main modifications across the pull request: adding a new truncated_importance_sampling_ratio configuration field to multiple YAML configs and implementing validation and post-processing logic in the loss functions module, along with corresponding test updates. While the title includes cherry-pick metadata ("cp:" and "into r0.4.0"), the central message remains specific and unambiguous about what feature is being added. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 81.82% which is sufficient. The required threshold is 80.00%. |

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (1)
tests/unit/algorithms/test_loss_functions.py (1)

447-447: Consider using deepcopy for consistency.

While these tests don't modify cfg and thus direct assignment is currently safe, using deepcopy(basic_pg_loss_test_config) would align with the pattern used throughout the file and prevent potential issues if these tests are modified in the future to mutate the config.

Apply this pattern for consistency:

-    cfg = basic_pg_loss_test_config
+    cfg = deepcopy(basic_pg_loss_test_config)

Also applies to: 1155-1155
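
To illustrate why the nitpick matters, here is a minimal sketch of the aliasing hazard. The dictionary below is a hypothetical stand-in for the shared test config; only the three field names are taken from this PR, and the values are assumptions for demonstration.

```python
from copy import deepcopy

# Hypothetical stand-in for the shared test config (field names from the PR,
# values assumed for illustration).
basic_pg_loss_test_config = {
    "use_importance_sampling_correction": False,
    "truncated_importance_sampling_ratio": None,
    "sequence_level_importance_ratios": False,
}

# Direct assignment only creates a second name for the same dict, so a
# mutation made in one test leaks into the shared config.
aliased_cfg = basic_pg_loss_test_config
aliased_cfg["truncated_importance_sampling_ratio"] = 0.8
leaked = basic_pg_loss_test_config["truncated_importance_sampling_ratio"]

# Restore the shared config, then repeat with an independent copy.
basic_pg_loss_test_config["truncated_importance_sampling_ratio"] = None
copied_cfg = deepcopy(basic_pg_loss_test_config)
copied_cfg["truncated_importance_sampling_ratio"] = 0.8
isolated = basic_pg_loss_test_config["truncated_importance_sampling_ratio"]
```

With the alias, the shared config ends up holding 0.8; with deepcopy it stays None, so later tests still see the default.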

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a06941c and f7f2ace.

📒 Files selected for processing (9)
  • examples/configs/grpo_math_1B.yaml (1 hunks)
  • examples/configs/vlm_grpo_3B.yaml (1 hunks)
  • examples/configs/vlm_grpo_3B_megatron.yaml (1 hunks)
  • nemo_rl/algorithms/loss_functions.py (4 hunks)
  • tests/unit/algorithms/test_grpo.py (1 hunks)
  • tests/unit/algorithms/test_loss_functions.py (23 hunks)
  • tests/unit/algorithms/test_sequence_packing_gradients.py (1 hunks)
  • tests/unit/models/policy/test_dtensor_worker.py (1 hunks)
  • tests/unit/models/policy/test_megatron_worker.py (4 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
examples/configs/*.yaml

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

examples/configs/*.yaml: Exemplar configs under examples/configs/*.yaml must include documented defaults
When adding a new config key, reflect its recommended default in exemplar YAMLs under examples/configs/*.yaml

Files:

  • examples/configs/vlm_grpo_3B.yaml
  • examples/configs/vlm_grpo_3B_megatron.yaml
  • examples/configs/grpo_math_1B.yaml
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • tests/unit/models/policy/test_dtensor_worker.py
  • nemo_rl/algorithms/loss_functions.py
  • tests/unit/algorithms/test_sequence_packing_gradients.py
  • tests/unit/algorithms/test_grpo.py
  • tests/unit/models/policy/test_megatron_worker.py
  • tests/unit/algorithms/test_loss_functions.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)

Files:

  • nemo_rl/algorithms/loss_functions.py
🧬 Code graph analysis (2)
tests/unit/models/policy/test_megatron_worker.py (1)
nemo_rl/algorithms/loss_functions.py (2)
  • ClippedPGLossConfig (38-51)
  • NLLLoss (379-455)
tests/unit/algorithms/test_loss_functions.py (1)
nemo_rl/algorithms/loss_functions.py (2)
  • ClippedPGLossConfig (38-51)
  • ClippedPGLossFn (67-376)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
  • GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (14)
tests/unit/models/policy/test_dtensor_worker.py (1)

677-678: LGTM!

The test configuration correctly includes the new fields truncated_importance_sampling_ratio and sequence_level_importance_ratios with appropriate default values.

tests/unit/algorithms/test_sequence_packing_gradients.py (1)

136-137: LGTM!

Test configuration properly updated to include the new loss function configuration fields.

tests/unit/algorithms/test_grpo.py (1)

618-619: LGTM!

Test configuration consistently updated with the new configuration fields.

nemo_rl/algorithms/loss_functions.py (4)

45-45: LGTM!

The new configuration field is properly typed and integrated into the ClippedPGLossConfig TypedDict.


117-119: LGTM!

The configuration value is correctly stored as an instance variable.


132-138: LGTM!

The validation logic correctly enforces:

  1. TIS requires importance sampling correction to be enabled
  2. The ratio must be positive when provided
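
The two rules above can be sketched as a standalone check. This is an assumed shape, not the code in loss_functions.py: the function name validate_tis_config is hypothetical, and the real implementation may use assertions or different messages.

```python
def validate_tis_config(cfg: dict) -> None:
    """Hypothetical helper mirroring the two validation rules described above."""
    ratio = cfg["truncated_importance_sampling_ratio"]
    if ratio is None:
        # TIS is disabled, so there is nothing to validate.
        return
    if not cfg["use_importance_sampling_correction"]:
        raise ValueError(
            "truncated_importance_sampling_ratio requires "
            "use_importance_sampling_correction=True"
        )
    if ratio <= 0:
        raise ValueError("truncated_importance_sampling_ratio must be positive")
```

A config with the ratio set to 0.8 and importance sampling correction enabled passes silently; setting the ratio while the correction is disabled, or using a non-positive ratio, raises ValueError.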

294-299: LGTM!

The truncated importance sampling implementation correctly clamps the importance weights to the specified maximum ratio. The one-sided clamping (max only, no min) appears intentional for this TIS formulation.
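
For readers unfamiliar with one-sided clamping, here is a minimal pure-Python sketch. The function name is hypothetical and plain lists stand in for tensors; in a tensor library the equivalent would be a max-only clamp such as PyTorch's torch.clamp(weights, max=max_ratio).

```python
def truncate_importance_weights(weights, max_ratio):
    """Cap each weight at max_ratio; small weights are left untouched.

    Hypothetical helper for illustration, not the code in loss_functions.py.
    """
    if max_ratio is None:
        # Truncation disabled: return the raw weights unchanged.
        return list(weights)
    return [min(w, max_ratio) for w in weights]


raw_weights = [0.5, 2.0, 1.0]
capped = truncate_importance_weights(raw_weights, 1.0)
unchanged = truncate_importance_weights(raw_weights, None)
```

Only the weight above the cap (2.0) is changed. There is no lower bound, so weights below the cap pass through as-is.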

tests/unit/models/policy/test_megatron_worker.py (4)

23-28: LGTM!

Appropriate imports added to support the typed configuration and refactored test setup.


37-48: LGTM!

Excellent refactoring to use a shared, typed configuration object. This reduces code duplication and ensures consistency across tests. The configuration includes all required fields with appropriate defaults, including the new truncated_importance_sampling_ratio and sequence_level_importance_ratios fields.


808-808: LGTM!

Test now uses the shared configuration object, improving maintainability.


1704-1704: LGTM!

Consistent use of the shared configuration object across all test cases.

tests/unit/algorithms/test_loss_functions.py (3)

15-15: LGTM: Imports support refactoring and type safety.

The addition of deepcopy enables safe test isolation by creating independent config copies, and ClippedPGLossConfig provides proper type hints for the test configuration.

Also applies to: 21-21


30-41: LGTM: Excellent refactoring with base config.

Introducing a typed base configuration reduces duplication and improves maintainability across all tests. The inclusion of truncated_importance_sampling_ratio: None properly disables TIS by default, and individual tests can enable it by creating a deepcopy and modifying the field.


893-1022: LGTM: Comprehensive test coverage for truncated importance sampling.

This parameterized test thoroughly validates the TIS feature for both token-level and sequence-level importance ratios. The detailed hand calculations with intermediate assertions ensure correctness:

  • Token-level mode: importance weights are correctly truncated from [0.6065, 1.6487, 0.8187] to [0.6065, 0.8, 0.8]
  • Sequence-level mode: importance weight is correctly truncated from 0.8187 to 0.8
  • Expected losses are accurately computed for both modes

The test properly enables use_importance_sampling_correction as required by the TIS feature and includes clear comments explaining the expected behavior.
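
The numbers quoted above can be reproduced with a short calculation. The per-token log-ratios below are an assumption chosen because their exponentials match the quoted weights, and summing them for the sequence-level weight is likewise an assumption that happens to reproduce 0.8187; the actual test may derive its inputs differently.

```python
import math

# Assumed per-token log-ratios: exp(-0.5), exp(0.5), exp(-0.2) match the
# raw weights quoted in the review.
log_ratios = [-0.5, 0.5, -0.2]
max_ratio = 0.8  # truncation threshold implied by the quoted results

# Token-level mode: one weight per token, each capped at max_ratio.
token_weights = [math.exp(x) for x in log_ratios]
token_truncated = [min(w, max_ratio) for w in token_weights]

# Sequence-level mode (assumed): a single weight from the summed log-ratios.
seq_weight = math.exp(sum(log_ratios))
seq_truncated = min(seq_weight, max_ratio)

print([round(w, 4) for w in token_weights])    # [0.6065, 1.6487, 0.8187]
print([round(w, 4) for w in token_truncated])  # [0.6065, 0.8, 0.8]
print(round(seq_weight, 4), seq_truncated)     # 0.8187 0.8
```

Under these assumptions, the first token's weight is already below 0.8 and is unchanged, while the second and third are capped.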

@terrykong terrykong added the CI:L1 (Run doctests, unit tests, and functional tests) label Oct 21, 2025
@terrykong terrykong enabled auto-merge (squash) October 21, 2025 16:26
@terrykong terrykong merged commit 5514d1e into r0.4.0 Oct 21, 2025
68 of 71 checks passed
@terrykong terrykong deleted the cherry-pick-1348-r0.4.0 branch October 21, 2025 20:59
terrykong pushed a commit that referenced this pull request Nov 19, 2025
#1400)

Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
Co-authored-by: Yuki Huang <yukih@nvidia.com>

Labels

cherry-pick, CI:L1 (Run doctests, unit tests, and functional tests), Run CICD
