
feat: add speculative decoding during post-training #1785

Merged
terrykong merged 9 commits into NVIDIA-NeMo:main from isomap:speculative-decoding-support
Feb 6, 2026

Conversation

@isomap
Contributor

@isomap isomap commented Jan 15, 2026

What does this PR do ?

Enable speculative decoding support in NeMo-RL using the vLLM backend during post-training (GRPO).

This PR integrates vLLM's speculative decoding capabilities into NeMo-RL, allowing for faster generation during the post-training phase. It includes necessary patches for vLLM to ensure correct metric collection and provides utility functions to track and report speculative decoding performance (e.g., acceptance rates) during training.

Key changes:

  • vLLM Patching: Adds a monkey patch to vllm.v1.engine.core_client to properly call post_step, which is essential for speculative decoding to function correctly in the v1 engine when VLLM_ENABLE_V1_MULTIPROCESSING=0. This is fixed upstream in vllm-project/vllm#30319 but not yet in a released version.
  • Metric Collection: Updates VllmGenerationWorker and VllmGeneration to collect speculative decoding counters (draft tokens, accepted tokens, etc.) from the underlying vLLM engine.
  • Aggregation Utilities: Adds helpers in nemo_rl/algorithms/utils.py to aggregate these metrics across multiple workers and compute derived metrics like "acceptance rate" and "draft efficiency".
  • Configuration: Automatically sets load_format="auto" in VllmConfig when speculative_config is detected, ensuring the model weights are loaded correctly for speculative execution.
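The aggregation helpers can be pictured with a small sketch. The function names below mirror the ones added to nemo_rl/algorithms/utils.py, but the counter names (num_draft_tokens, num_accepted_tokens) and the exact formulas are illustrative assumptions, not the PR's implementation:

```python
from collections import defaultdict

def aggregate_spec_decode_counters(worker_metrics):
    """Sum per-worker counter dicts into one total (sketch)."""
    totals = defaultdict(int)
    for metrics in worker_metrics:
        for name, value in (metrics or {}).items():
            totals[name] += value
    return dict(totals)

def compute_spec_decode_metrics(start, end):
    """Delta two counter snapshots and derive an acceptance rate (sketch)."""
    deltas = {name: end.get(name, 0) - start.get(name, 0) for name in end}
    out = {f"spec_decode/{name}": value for name, value in deltas.items()}
    draft = deltas.get("num_draft_tokens", 0)
    if draft > 0:
        out["spec_decode/acceptance_rate"] = deltas.get("num_accepted_tokens", 0) / draft
    return out
```

The division guard matters because a step with no speculative generation would otherwise divide by zero.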

Issues

List issues that this PR closes:
N/A

Usage

To enable speculative decoding, add a `speculative_config` (draft model and related parameters) to your `vllm_kwargs` configuration:

```yaml
policy:
  generation:
    backend: "vllm"
    vllm_kwargs:
      speculative_config:
        model: "nvidia/gpt-oss-120b-Eagle3-long-context"  # Example draft model
        num_speculative_tokens: 3
    # ... other config ...
```

Warning

Limitation: When using speculative decoding with vLLM < 0.12.0, generation log probabilities will be returned as 0. This means use_importance_sampling cannot be used. This is fixed in vllm-project/vllm#29223 and will be available in vLLM v0.12.0+.
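A caller that wants importance sampling could guard on the installed vLLM version. This is a hedged sketch (the helper name and the naive version parse are mine, not part of the PR); it encodes only the fact stated above that the logprob fix lands in v0.12.0:

```python
def spec_decode_logprobs_supported(vllm_version: str) -> bool:
    """True once the installed vLLM returns real logprobs under speculative
    decoding (fixed in v0.12.0 by vllm-project/vllm#29223); earlier versions
    return 0.0 logprobs."""
    # Naive parse: first three numeric components, e.g. "0.11.2" -> (0, 11, 2).
    parts = []
    for piece in vllm_version.split(".")[:3]:
        digits = "".join(ch for ch in piece if ch.isdigit())
        parts.append(int(digits or 0))
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts) >= (0, 12, 0)
```

With such a guard, a config validator could raise early when use_importance_sampling is enabled alongside speculative decoding on an older vLLM, instead of silently training on zeroed logprobs.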

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • Speculative decoding metrics will now appear in your training logs under the spec_decode/ prefix if enabled.
  • The vLLM patch can be removed once NeMo-RL upgrades to a vLLM version that includes vllm-project/vllm#30319.

Summary by CodeRabbit

  • New Features

    • Added speculative decoding metrics tracking and reporting during training to monitor draft generation, token acceptance rates, and decode efficiency.
    • Enhanced configuration handling to properly detect and support speculative decoding setups.
  • Improvements

    • Enabled metrics collection and aggregation across distributed workers for better performance visibility.


@isomap isomap requested review from a team as code owners January 15, 2026 22:21
@terrykong terrykong requested review from gshennvm and yfw January 15, 2026 22:23
Collaborator

@terrykong terrykong left a comment


thanks for the contribution @isomap !

@yfw @gshennvm to review

@coderabbitai
Contributor

coderabbitai bot commented Jan 15, 2026

📝 Walkthrough

This change introduces speculative decoding metrics instrumentation throughout the GRPO training pipeline and vLLM generation infrastructure. New utilities aggregate and compute speculative decoding metrics from worker groups, vLLM generation classes expose metric collection methods, and GRPO training captures counter snapshots before and after generation to track speculative decoding performance.
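The before/after snapshot pattern the walkthrough describes can be sketched generically (the class and counter names here are illustrative, not the PR's code):

```python
class CounterDelta:
    """Measure how monotonically increasing counters changed across a phase."""

    def __init__(self, read_counters):
        self._read = read_counters  # zero-arg callable returning {name: int}
        self._start = {}

    def begin(self):
        # Snapshot current counter values before the phase (e.g., generation).
        self._start = dict(self._read())

    def end(self):
        # Delta = current value minus the snapshot taken at begin().
        now = self._read()
        return {name: now[name] - self._start.get(name, 0) for name in now}
```

Using deltas rather than raw values keeps per-step metrics meaningful even though the underlying vLLM counters accumulate for the lifetime of the engine.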

Changes

  • Speculative Decoding Metrics Utilities (nemo_rl/algorithms/utils.py): Adds two new public functions: aggregate_spec_decode_counters (collects and sums spec_decode metrics from multiple workers) and compute_spec_decode_metrics (computes deltas and derived metrics between counter snapshots). Includes a defaultdict import.
  • GRPO Training Instrumentation (nemo_rl/algorithms/grpo.py): Imports the new metric utilities; captures spec counters before generation and after batch completion; computes and merges spec_metrics into training metrics at two instrumentation points.
  • vLLM Generation Metrics API (nemo_rl/models/generation/vllm/vllm_generation.py): Adds a get_metrics() method that collects speculative decoding metrics from DP rank 0 workers via RPC.
  • vLLM Worker Metrics & Patching (nemo_rl/models/generation/vllm/vllm_worker.py): Adds get_metrics() methods to both BaseVllmGenerationWorker and VllmGenerationWorker; introduces post_step patching logic for speculative decoding in InprocessClient.get_output; changes the disable_log_stats default from True to False.
  • Speculative Configuration Detection (nemo_rl/models/generation/__init__.py): Detects speculative_config in vllm_kwargs; extends the load_format decision logic to set "auto" if either is_eval or is_spec is true.

Sequence Diagram

```mermaid
sequenceDiagram
    participant GRPO as GRPO Training Loop
    participant PolicyGen as PolicyGeneration
    participant Worker as vLLM Worker
    participant Metrics as Metric Aggregation

    GRPO->>PolicyGen: policy_generation.get_metrics()
    PolicyGen->>Worker: RPC get_metrics() to rank 0 workers
    Worker-->>PolicyGen: return spec_counters dict
    PolicyGen-->>GRPO: aggregated worker_metrics list
    GRPO->>Metrics: spec_counters_start = aggregate_spec_decode_counters()

    Note over GRPO,Worker: Generation Phase
    GRPO->>PolicyGen: run generation step
    PolicyGen->>Worker: generate tokens with spec decode

    GRPO->>PolicyGen: policy_generation.get_metrics()
    PolicyGen->>Worker: RPC get_metrics() to rank 0 workers
    Worker-->>PolicyGen: return updated spec_counters dict
    PolicyGen-->>GRPO: aggregated worker_metrics list
    GRPO->>Metrics: spec_counters_end = aggregate_spec_decode_counters()

    GRPO->>Metrics: compute_spec_decode_metrics(start, end)
    Metrics-->>GRPO: spec_metrics (deltas, derived metrics)
    GRPO->>GRPO: merge spec_metrics into training metrics
```

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested labels

enhancement, Performance

Suggested reviewers

  • parthchadha
  • terrykong
🚥 Pre-merge checks (✅ 3 passed | ❌ 1 failed)

❌ Failed checks (1 warning)
  • Test Results For Major Changes — ⚠️ Warning: the PR introduces a major feature (speculative decoding) with incomplete tests, missing performance benchmarks, and undocumented impact of known vLLM compatibility issues on training correctness. Resolution: complete the tests and include results, provide performance benchmarks, document vLLM version requirements, and add guards for incompatible configurations.

✅ Passed checks (3 passed)
  • Title check — ✅ Passed: the title accurately describes the main change (adding speculative decoding support during post-training in the GRPO algorithm), reflected in the modifications across generation, utility, and algorithm files.
  • Docstring Coverage — ✅ Passed: docstring coverage is 92.31%, which meets the required threshold of 80.00%.
  • Description Check — ✅ Passed: check skipped; CodeRabbit's high-level summary is enabled.



Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@nemo_rl/algorithms/grpo.py`:
- Around lines 1152-1154: calls to policy_generation.get_metrics() can raise for backends (e.g., megatron) that don't implement get_metrics; guard those calls and default spec_metrics to {}. At every place that uses aggregate_spec_decode_counters(policy_generation.get_metrics()) (where spec_counters_start is assigned, and the other occurrences around the referenced lines), first check callable(getattr(policy_generation, "get_metrics", None)); if the method is present, call it and pass the result to aggregate_spec_decode_counters, otherwise pass an empty dict so spec_metrics/spec_counters_start defaults to {} and training won't break.
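The guard the bot describes amounts to a few lines. A hedged sketch (the helper name is mine, not the PR's code):

```python
def safe_spec_counters(policy_generation, aggregate):
    """Fetch spec-decode counters only from backends exposing get_metrics().

    Backends without the method (e.g., megatron) yield {} instead of raising.
    """
    get_metrics = getattr(policy_generation, "get_metrics", None)
    if not callable(get_metrics):
        return {}
    return aggregate(get_metrics())
```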
🧹 Nitpick comments (2)
nemo_rl/models/generation/__init__.py (1)

44-45: Guard against missing vllm_kwargs for custom configs.

If a user config omits vllm_kwargs (or sets it to None), this line will raise. A small defensive guard keeps behavior identical while avoiding a hard failure in edge configs.

♻️ Suggested guard

```diff
-        is_spec = "speculative_config" in config["vllm_kwargs"]
+        vllm_kwargs = config.get("vllm_kwargs") or {}
+        is_spec = "speculative_config" in vllm_kwargs
         config["vllm_cfg"]["load_format"] = "auto" if is_eval or is_spec else "dummy"
```
nemo_rl/models/generation/vllm/vllm_worker.py (1)

284-318: Make patch logging reflect whether it actually applied.

Right now the log says “Successfully patched…” even if the snippet wasn’t found (newer vLLM or already patched). Returning a boolean and logging accordingly avoids confusion during upgrades.

♻️ Suggested change

```diff
-        def _patch_vllm_speculative_decoding_post_step():
+        def _patch_vllm_speculative_decoding_post_step() -> bool:
@@
-            if new_snippet in content or old_snippet not in content:
-                return
+            if new_snippet in content or old_snippet not in content:
+                return False
@@
-            with open(file_to_patch, "w") as f:
-                f.write(content)
+            with open(file_to_patch, "w") as f:
+                f.write(content)
+            return True
@@
-        _patch_vllm_speculative_decoding_post_step()
-        logger.info("Successfully patched vllm speculative decoding post_step.")
+        if _patch_vllm_speculative_decoding_post_step():
+            logger.info("Successfully patched vllm speculative decoding post_step.")
+        else:
+            logger.info(
+                "Skipped vllm speculative decoding post_step patch (already patched or incompatible version)."
+            )
```

Also applies to: 325-326

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f0b1a91 and 7452999.

📒 Files selected for processing (5)
  • nemo_rl/algorithms/grpo.py
  • nemo_rl/algorithms/utils.py
  • nemo_rl/models/generation/__init__.py
  • nemo_rl/models/generation/vllm/vllm_generation.py
  • nemo_rl/models/generation/vllm/vllm_worker.py
🧬 Code graph analysis (2)
nemo_rl/models/generation/vllm/vllm_worker.py (1)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
  • get_metrics (384-397)
nemo_rl/algorithms/grpo.py (4)
nemo_rl/algorithms/utils.py (3)
  • aggregate_spec_decode_counters (775-810)
  • calculate_baseline_and_std_per_prompt (80-157)
  • compute_spec_decode_metrics (813-879)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
  • get_metrics (384-397)
nemo_rl/models/generation/vllm/vllm_worker.py (1)
  • get_metrics (526-546)
nemo_rl/data/packing/metrics.py (1)
  • update (52-91)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
  • GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (6)
nemo_rl/models/generation/vllm/vllm_generation.py (1)

384-398: Looks good.

nemo_rl/algorithms/grpo.py (1)

41-44: Imports look fine.

nemo_rl/models/generation/vllm/vllm_worker.py (2)

456-456: LGTM.


526-546: Nice addition for metrics visibility.

nemo_rl/algorithms/utils.py (2)

18-18: No issues here.


775-879: Spec‑decode aggregation utilities look solid.


Collaborator

@terrykong terrykong left a comment


thanks @isomap for the contribution! can you write a simple unit test to validate that the metrics are coming back as expected?

@isomap isomap requested a review from a team as a code owner January 30, 2026 01:09
Collaborator

@terrykong terrykong left a comment


very small comments, but ty for addressing the last review. i think after this we should be good to merge

@isomap isomap requested a review from terrykong January 31, 2026 05:33
@terrykong terrykong added the CI:L1 Run doctests, unit tests, and functional tests label Jan 31, 2026
@terrykong
Collaborator

lgtm @isomap

two final things:

  1. could you resolve the DCO by signing your commits
  2. could you merge in main or rebase your PR to satisfy the "Check if PR branch is up to date" check. It fails when the branch is not rooted on a recent enough main commit.

@isomap isomap force-pushed the speculative-decoding-support branch 2 times, most recently from 62b47ee to b972938 Compare February 1, 2026 00:34
@isomap isomap requested a review from terrykong February 1, 2026 00:45
@terrykong terrykong added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 1, 2026
@terrykong terrykong enabled auto-merge (squash) February 1, 2026 04:49
terrykong previously approved these changes Feb 1, 2026
@terrykong terrykong removed the CI:L1 Run doctests, unit tests, and functional tests label Feb 2, 2026
@terrykong
Collaborator

@isomap was infra issue with CI, restarted CI now

@terrykong
Collaborator

@isomap looks like a CI failure from the new test, does it run on your end?

```
2026-02-02T19:21:38.1841799Z =================================== FAILURES ===================================
2026-02-02T19:21:38.1842257Z ______________ test_vllm_speculative_decoding_patch_still_needed _______________
2026-02-02T19:21:38.1842580Z
2026-02-02T19:21:38.1842740Z     def test_vllm_speculative_decoding_patch_still_needed():
2026-02-02T19:21:38.1843189Z         # This test reminds to remove the vLLM patch when no longer needed.
2026-02-02T19:21:38.1843738Z         # The patch was fixed upstream: https://github.com/vllm-project/vllm/pull/30319
2026-02-02T19:21:38.1844306Z         # When this test fails, remove _patch_vllm_speculative_decoding_post_step()
2026-02-02T19:21:38.1844773Z         # from nemo_rl/models/generation/vllm/vllm_worker.py
2026-02-02T19:21:38.1845088Z         import os
2026-02-02T19:21:38.1845316Z         from importlib.util import find_spec
2026-02-02T19:21:38.1845583Z
2026-02-02T19:21:38.1845772Z         spec = find_spec("vllm")
2026-02-02T19:21:38.1846103Z >       base_dir = next(iter(spec.submodule_search_locations))
2026-02-02T19:21:38.1846609Z                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2026-02-02T19:21:38.1847043Z E       AttributeError: 'NoneType' object has no attribute 'submodule_search_locations'
```

https://github.com/NVIDIA-NeMo/RL/actions/runs/21556877547/job/62252156482?pr=1785
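The traceback arises because importlib.util.find_spec returns None when vllm isn't importable in the CI environment, and the next line dereferences None. One defensive shape (the helper name is mine; demonstrated on the stdlib json package rather than vllm):

```python
from importlib.util import find_spec

def package_base_dir(name):
    """Return a package's install directory, or None if it isn't importable
    or is a plain module (no submodule_search_locations)."""
    spec = find_spec(name)
    if spec is None or not spec.submodule_search_locations:
        return None
    return next(iter(spec.submodule_search_locations))
```

In the test itself, a None result would translate to a pytest.skip("vllm not installed") instead of an AttributeError.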

auto-merge was automatically disabled February 3, 2026 07:51

Head branch was pushed to by a user without write access

auto-merge was automatically disabled February 5, 2026 01:06

Head branch was pushed to by a user without write access

@isomap isomap requested a review from terrykong February 5, 2026 01:07
isomap and others added 9 commits February 4, 2026 17:08
Signed-off-by: hiso <hiso@nvidia.com>
Signed-off-by: hiso <hiso@nvidia.com>
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
Signed-off-by: Hayate Iso <hiso@nvidia.com>
Signed-off-by: Hayate Iso <hyate.iso@gmail.com>
…rics

Signed-off-by: Hayate Iso <hyate.iso@gmail.com>
Signed-off-by: Hayate Iso <hyate.iso@gmail.com>
Signed-off-by: Hayate Iso <hyate.iso@gmail.com>
Signed-off-by: Hayate Iso <hyate.iso@gmail.com>
Signed-off-by: hiso <hiso@nvidia.com>
@isomap isomap force-pushed the speculative-decoding-support branch from a436b65 to d3e5e05 Compare February 5, 2026 01:08
@terrykong terrykong enabled auto-merge (squash) February 6, 2026 08:09
@terrykong terrykong removed the CI:L1 Run doctests, unit tests, and functional tests label Feb 6, 2026
@yuki-97 yuki-97 added the CI:L1 Run doctests, unit tests, and functional tests label Feb 6, 2026
@terrykong terrykong merged commit f6585a6 into NVIDIA-NeMo:main Feb 6, 2026
41 of 42 checks passed
@coderabbitai coderabbitai bot mentioned this pull request Feb 10, 2026
4 tasks
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 12, 2026
Signed-off-by: hiso <hiso@nvidia.com>
Signed-off-by: Hayate Iso <hiso@nvidia.com>
Signed-off-by: Hayate Iso <hyate.iso@gmail.com>
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
@terrykong terrykong linked an issue Feb 13, 2026 that may be closed by this pull request
@coderabbitai coderabbitai bot mentioned this pull request Feb 20, 2026
4 tasks
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
seonjinn pushed a commit that referenced this pull request Mar 8, 2026
Signed-off-by: hiso <hiso@nvidia.com>
Signed-off-by: Hayate Iso <hiso@nvidia.com>
Signed-off-by: Hayate Iso <hyate.iso@gmail.com>
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
seonjinn pushed a commit that referenced this pull request Mar 8, 2026
seonjinn pushed a commit that referenced this pull request Mar 9, 2026

Labels

CI:L1 Run doctests, unit tests, and functional tests


Development

Successfully merging this pull request may close these issues.

[super-pr] disable_log_stats=false

3 participants