Skip to content

feat: support truncated importance sampling#1348

Merged
terrykong merged 7 commits intomainfrom
yukih/tis
Oct 21, 2025
Merged

feat: support truncated importance sampling#1348
terrykong merged 7 commits intomainfrom
yukih/tis

Conversation

@yuki-97
Copy link
Contributor

@yuki-97 yuki-97 commented Oct 13, 2025

As title.

Summary by CodeRabbit

  • New Features
    • Added an optional truncated_importance_sampling_ratio in GRPO loss settings to cap per-token importance weights when importance-sampling correction is enabled, offering finer control over training stability.
    • Enabled this option in example GRPO configurations (default: null/disabled).
    • Added safeguards: the ratio must be positive, and is only valid with importance-sampling correction. Violations produce clear errors to guide correct usage.

@yuki-97 yuki-97 requested review from a team as code owners October 13, 2025 03:30
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Oct 13, 2025

📝 Walkthrough

Walkthrough

Introduces a new optional configuration field, truncated_importance_sampling_ratio, in GRPO YAML configs and wires it into ClippedPGLossFn. The loss now validates this field, restricts its use to token-level importance sampling with correction enabled, ensures positivity, and applies clipping to token-level importance weights when configured.

Changes

Cohort / File(s) Summary
GRPO YAML configs
examples/configs/grpo_math_1B.yaml, examples/configs/vlm_grpo_3B.yaml, examples/configs/vlm_grpo_3B_megatron.yaml
Add loss_fn.truncated_importance_sampling_ratio: null option; no other config changes.
Loss function logic
nemo_rl/algorithms/loss_functions.py
Extend ClippedPGLossConfig with optional truncated_importance_sampling_ratio. Validate usage (requires use_importance_sampling_correction=True, token-level only, positive value). In token-level IS path, clip actor_importance_weights_expanded by this ratio when provided.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant C as Config
  participant T as Trainer
  participant L as ClippedPGLossFn
  participant D as Data (rewards, logprobs)

  C->>T: Provide loss_fn config (incl. truncated_importance_sampling_ratio)
  T->>L: Initialize with cfg
  L->>L: Validate<br/>- if ratio set: require IS correction enabled<br/>- forbid sequence-level IS<br/>- ratio > 0

  T->>L: forward(old_logprobs, new_logprobs, rewards, flags)
  alt IS correction enabled
    alt Sequence-level IS
      L->>L: Use sequence-level ratios (no truncation)
    else Token-level IS
      L->>L: Compute token IS weights
      opt ratio provided
        Note over L: Clip token IS weights to truncated_importance_sampling_ratio
      end
      L->>L: Compute clipped PG loss
    end
  else
    L->>L: Compute loss without IS correction
  end
  L-->>T: loss
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Test Results For Major Changes ⚠️ Warning The PR introduces a new truncated importance sampling capability inside ClippedPGLossFn, which is a substantive numerical change to the reinforcement learning loss, yet the PR description contains only “As title” and provides no testing or validation details, violating the requirement that major or numerics-affecting changes document test results. Please update the PR description with the relevant test results or validation evidence demonstrating that the new truncated importance sampling logic behaves correctly and does not regress existing training performance or convergence.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title succinctly describes the primary feature added—support for truncated importance sampling—matching the code changes that introduce the truncated_importance_sampling_ratio configuration and associated logic.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch yukih/tis

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🧹 Nitpick comments (1)
nemo_rl/algorithms/loss_functions.py (1)

297-302: Reference public documentation instead of private Notion page.

The comment on line 297 links to https://fengyao.notion.site/off-policy-rl, which is a private page inaccessible to most developers. Consider referencing a public paper, arXiv preprint, or internal documentation instead.

Additionally, consider moving this truncation inside the if self.use_importance_sampling_correction: block (line 305) to make the dependency explicit, even though validation already prevents the invalid configuration:

 actor_importance_weights = actor_importance_weights_expanded
 del actor_importance_weights_expanded
 if self.use_importance_sampling_correction:
+    # TIS: Truncate importance weights if configured
+    if self.truncated_importance_sampling_ratio is not None:
+        actor_importance_weights = torch.clamp(
+            actor_importance_weights,
+            max=self.truncated_importance_sampling_ratio,
+        )
     importance_weights_to_use = actor_importance_weights
 else:

This would require removing the truncation from lines 297-302.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6d1d711 and 9445661.

📒 Files selected for processing (4)
  • examples/configs/grpo_math_1B.yaml (1 hunks)
  • examples/configs/vlm_grpo_3B.yaml (1 hunks)
  • examples/configs/vlm_grpo_3B_megatron.yaml (1 hunks)
  • nemo_rl/algorithms/loss_functions.py (4 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • nemo_rl/algorithms/loss_functions.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)

Files:

  • nemo_rl/algorithms/loss_functions.py
examples/configs/*.yaml

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

examples/configs/*.yaml: Exemplar configs under examples/configs/.yaml must include documented defaults
When adding a new config key, reflect its recommended default in exemplar YAMLs under examples/configs/
.yaml

Files:

  • examples/configs/grpo_math_1B.yaml
  • examples/configs/vlm_grpo_3B.yaml
  • examples/configs/vlm_grpo_3B_megatron.yaml
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Lint check
  • GitHub Check: Post automodel integration comment / Comment on PR
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (1)
nemo_rl/algorithms/loss_functions.py (1)

132-141: LGTM!

The validation logic is thorough and correctly enforces the constraints:

  • Only usable when importance sampling correction is enabled
  • Only for token-level (not sequence-level) importance sampling
  • Must be positive if provided

@yuki-97 yuki-97 added the CI:L1 Run doctests, unit tests, and functional tests label Oct 13, 2025
@yuki-97 yuki-97 requested a review from a team as a code owner October 13, 2025 06:49
@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Oct 13, 2025
@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Oct 13, 2025
@yuki-97 yuki-97 requested a review from terrykong October 13, 2025 10:12
@parthchadha
Copy link
Contributor

@yuki-97 did you see any benefits of truncated importance sampling? I tried this in a bunch of experiments and didn't see any benefits.

@yuki-97
Copy link
Contributor Author

yuki-97 commented Oct 14, 2025

@yuki-97 did you see any benefits of truncated importance sampling? I tried this in a bunch of experiments and didn't see any benefits.

@parthchadha yea, research team mentioned tis has a significant improvement in their setting.

this is their train reward curve, val curve has the same trend, grey is w/ tis and blue is w/o.
image

@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Oct 14, 2025
Copy link
Collaborator

@terrykong terrykong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

feature lgtm

@yuki-97 i understand there's some setting by the research team that shows benefit. OOC did you see any benefit with any of our existing configs

Signed-off-by: Yuki Huang <yukih@nvidia.com>

TIS only supports for token level

Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Yuki Huang <yukih@nvidia.com>
@yuki-97
Copy link
Contributor Author

yuki-97 commented Oct 18, 2025

feature lgtm

@yuki-97 i understand there's some setting by the research team that shows benefit. OOC did you see any benefit with any of our existing configs

@terrykong I tried some our existing configs, but unfortunately also didn't see any benefits.
Maybe just like what mentioned in https://fengyao.notion.site/off-policy-rl that "We also observed that, in cases where the probability difference is relatively small, introducing the additional Truncated Importance Sampling term cannot bring performance gain."

Also FYI, the setting of the research team uses DAPO, multiple training datasets, and some other tricks, not sure which is the key point to gain the benefit of using TIS.

@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Oct 18, 2025
Signed-off-by: Yuki Huang <yukih@nvidia.com>
@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Oct 18, 2025
@terrykong terrykong added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Oct 21, 2025
@terrykong terrykong enabled auto-merge (squash) October 21, 2025 05:04
@terrykong terrykong merged commit f2de476 into main Oct 21, 2025
72 of 74 checks passed
@terrykong terrykong deleted the yukih/tis branch October 21, 2025 09:49
chtruong814 pushed a commit that referenced this pull request Oct 21, 2025
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
lbliii pushed a commit that referenced this pull request Nov 3, 2025
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Lawrence Lane <llane@nvidia.com>
PrinsYin pushed a commit to PrinsYin/RL that referenced this pull request Nov 30, 2025
Signed-off-by: Yuki Huang <yukih@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI:L1 Run doctests, unit tests, and functional tests r0.4.0

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants