cp: feat: support truncated importance sampling (1348) into r0.4.0#1400
cp: feat: support truncated importance sampling (1348) into r0.4.0#1400
feat: support truncated importance sampling (1348) into r0.4.0#1400Conversation
Signed-off-by: Yuki Huang <yukih@nvidia.com> Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
📝 WalkthroughWalkthroughThis pull request adds support for truncated importance sampling ratio configuration across the GRPO loss function. A new optional field Changes
Sequence Diagram(s)sequenceDiagram
participant Config as Configuration
participant Loss as ClippedPGLossFn
participant Compute as Loss Computation
Config->>Loss: __init__(cfg with truncated_importance_sampling_ratio)
Loss->>Loss: Validate: if ratio set, then use_importance_sampling_correction must be True
Loss->>Loss: Validate: if ratio set, must be positive
Loss->>Loss: Store truncated_importance_sampling_ratio
Compute->>Loss: forward(actor_importance_weights, ...)
Loss->>Compute: Calculate raw importance weights
alt truncated_importance_sampling_ratio is set
Compute->>Compute: Clamp weights to ratio (truncate)
else ratio not set
Compute->>Compute: Use raw weights
end
Compute->>Loss: Return truncated/raw weights
Loss->>Loss: Compute clipped loss with processed weights
Loss->>Compute: Return loss
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes The changes span multiple heterogeneous areas: configuration schema extensions, loss function validation logic, test refactoring with shared base configs, and new test coverage. While individual changes are straightforward, the breadth across config files, implementation logic, and multiple test suites—combined with the need to verify validation constraints and test coverage—requires moderate review attention. Possibly related PRs
Suggested labels
Suggested reviewers
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 3
🧹 Nitpick comments (1)
tests/unit/algorithms/test_loss_functions.py (1)
447-447: Consider using deepcopy for consistency.While these tests don't modify
cfgand thus direct assignment is currently safe, usingdeepcopy(basic_pg_loss_test_config)would align with the pattern used throughout the file and prevent potential issues if these tests are modified in the future to mutate the config.Apply this pattern for consistency:
- cfg = basic_pg_loss_test_config + cfg = deepcopy(basic_pg_loss_test_config)Also applies to: 1155-1155
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
examples/configs/grpo_math_1B.yaml(1 hunks)examples/configs/vlm_grpo_3B.yaml(1 hunks)examples/configs/vlm_grpo_3B_megatron.yaml(1 hunks)nemo_rl/algorithms/loss_functions.py(4 hunks)tests/unit/algorithms/test_grpo.py(1 hunks)tests/unit/algorithms/test_loss_functions.py(23 hunks)tests/unit/algorithms/test_sequence_packing_gradients.py(1 hunks)tests/unit/models/policy/test_dtensor_worker.py(1 hunks)tests/unit/models/policy/test_megatron_worker.py(4 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
examples/configs/*.yaml
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
examples/configs/*.yaml: Exemplar configs under examples/configs/.yaml must include documented defaults
When adding a new config key, reflect its recommended default in exemplar YAMLs under examples/configs/.yaml
Files:
examples/configs/vlm_grpo_3B.yamlexamples/configs/vlm_grpo_3B_megatron.yamlexamples/configs/grpo_math_1B.yaml
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
tests/unit/models/policy/test_dtensor_worker.pynemo_rl/algorithms/loss_functions.pytests/unit/algorithms/test_sequence_packing_gradients.pytests/unit/algorithms/test_grpo.pytests/unit/models/policy/test_megatron_worker.pytests/unit/algorithms/test_loss_functions.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
nemo_rl/algorithms/loss_functions.py
🧬 Code graph analysis (2)
tests/unit/models/policy/test_megatron_worker.py (1)
nemo_rl/algorithms/loss_functions.py (2)
ClippedPGLossConfig(38-51)NLLLoss(379-455)
tests/unit/algorithms/test_loss_functions.py (1)
nemo_rl/algorithms/loss_functions.py (2)
ClippedPGLossConfig(38-51)ClippedPGLossFn(67-376)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Lint check
- GitHub Check: Lint check
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
- GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (14)
tests/unit/models/policy/test_dtensor_worker.py (1)
677-678: LGTM!The test configuration correctly includes the new fields
truncated_importance_sampling_ratioandsequence_level_importance_ratioswith appropriate default values.tests/unit/algorithms/test_sequence_packing_gradients.py (1)
136-137: LGTM!Test configuration properly updated to include the new loss function configuration fields.
tests/unit/algorithms/test_grpo.py (1)
618-619: LGTM!Test configuration consistently updated with the new configuration fields.
nemo_rl/algorithms/loss_functions.py (4)
45-45: LGTM!The new configuration field is properly typed and integrated into the
ClippedPGLossConfigTypedDict.
117-119: LGTM!The configuration value is correctly stored as an instance variable.
132-138: LGTM!The validation logic correctly enforces:
- TIS requires importance sampling correction to be enabled
- The ratio must be positive when provided
294-299: LGTM!The truncated importance sampling implementation correctly clamps the importance weights to the specified maximum ratio. The one-sided clamping (max only, no min) appears intentional for this TIS formulation.
tests/unit/models/policy/test_megatron_worker.py (4)
23-28: LGTM!Appropriate imports added to support the typed configuration and refactored test setup.
37-48: LGTM!Excellent refactoring to use a shared, typed configuration object. This reduces code duplication and ensures consistency across tests. The configuration includes all required fields with appropriate defaults, including the new
truncated_importance_sampling_ratioandsequence_level_importance_ratiosfields.
808-808: LGTM!Test now uses the shared configuration object, improving maintainability.
1704-1704: LGTM!Consistent use of the shared configuration object across all test cases.
tests/unit/algorithms/test_loss_functions.py (3)
15-15: LGTM: Imports support refactoring and type safety.The addition of
deepcopyenables safe test isolation by creating independent config copies, andClippedPGLossConfigprovides proper type hints for the test configuration.Also applies to: 21-21
30-41: LGTM: Excellent refactoring with base config.Introducing a typed base configuration reduces duplication and improves maintainability across all tests. The inclusion of
truncated_importance_sampling_ratio: Noneproperly disables TIS by default, and individual tests can enable it by creating a deepcopy and modifying the field.
893-1022: LGTM: Comprehensive test coverage for truncated importance sampling.This parameterized test thoroughly validates the TIS feature for both token-level and sequence-level importance ratios. The detailed hand calculations with intermediate assertions ensure correctness:
- Token-level mode: importance weights are correctly truncated from
[0.6065, 1.6487, 0.8187]to[0.6065, 0.8, 0.8]- Sequence-level mode: importance weight is correctly truncated from
0.8187to0.8- Expected losses are accurately computed for both modes
The test properly enables
use_importance_sampling_correctionas required by the TIS feature and includes clear comments explaining the expected behavior.
#1400) Signed-off-by: Yuki Huang <yukih@nvidia.com> Signed-off-by: NeMo Bot <nemo-bot@nvidia.com> Co-authored-by: Yuki Huang <yukih@nvidia.com>
beep boop [🤖]: Hi @yuki-97 👋,
Summary by CodeRabbit
New Features
Tests