feat: Add guided decoding passthrough to vLLM #827

ybgao-nvidia wants to merge 19 commits into main from
Conversation
Signed-off-by: Yubo Gao <yubog@nvidia.com>
wangshangsam
left a comment
A few nits, but overall LGTM!
@SahilJain314 wanna take another look (in case I missed anything)?
Co-authored-by: Shang Wang <samshang.wang@mail.utoronto.ca> Signed-off-by: Yubo Gao <yubog@nvidia.com>
@parthchadha can you take a quick look as well before merge?

@ybgao-nvidia I don't need to review code :) you can remove me from the list of reviewers. Thank you!
📝 Walkthrough

This PR adds guided decoding support to NeMo-RL's vLLM generation pipeline by introducing an optional guided decoding configuration.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant Rollout as Rollout Layer
    participant GenInterface as Generation Interface
    participant VllmGen as VllmGeneration
    participant VllmWorker as VllmWorker
    participant vLLM as vLLM Library
    Rollout->>GenInterface: generate_responses(data, guided_decoding_config)
    GenInterface->>VllmGen: generate(data, guided_decoding_config)
    VllmGen->>VllmWorker: generate(data, guided_decoding_config)
    activate VllmWorker
    VllmWorker->>VllmWorker: _get_vllm_guided_decoding_params(guided_decoding_config)
    VllmWorker->>VllmWorker: _build_sampling_params(..., guided_decoding_params)
    deactivate VllmWorker
    VllmWorker->>vLLM: generate_completion(sampling_params with guided_decoding)
    vLLM-->>VllmWorker: structured output (matches constraints)
    VllmWorker-->>VllmGen: BatchedDataDict
    VllmGen-->>GenInterface: BatchedDataDict
    GenInterface-->>Rollout: BatchedDataDict
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Areas requiring extra attention:
Pre-merge checks and finishing touches: ❌ Failed checks (2 warnings), ✅ Passed checks (4 passed)
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
nemo_rl/models/generation/interfaces.py (1)

143-157: Document the new `guided_decoding` config key.

Per NeMo-RL config guidelines, every new `TypedDict` key must document its purpose, valid values, and recommended default. `GenerationConfig` now exposes `guided_decoding`, but the class docstring still omits it, so downstream users won't know how to populate it. Please describe the field (e.g., that it accepts a `GuidedDecodingConfig` and defaults to `None`) alongside the other keys. Based on learnings.

nemo_rl/experience/rollouts.py (1)
599-608: Critical: Missing parameter forwarding breaks guided decoding.

The function accepts `guided_decoding_config` but doesn't forward it to `generate_responses_async`, breaking guided decoding for async single-sample rollouts. Apply this diff:

```diff
 updated_batch, generated_ids, gen_metrics = await generate_responses_async(
     policy_generation,
     generation_input_data,
     dummy_batch,
     tokenizer,
     input_lengths=input_lengths,
     include_logprobs=True,
     greedy=greedy,
+    guided_decoding_config=guided_decoding_config,
 )
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)

- nemo_rl/experience/rollouts.py (17 hunks)
- nemo_rl/models/generation/interfaces.py (4 hunks)
- nemo_rl/models/generation/vllm/vllm_generation.py (11 hunks)
- nemo_rl/models/generation/vllm/vllm_worker.py (9 hunks)
- nemo_rl/models/generation/vllm/vllm_worker_async.py (6 hunks)
- nemo_rl/models/policy/lm_policy.py (2 hunks)
- tests/unit/models/generation/test_vllm_generation.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:

- nemo_rl/models/policy/lm_policy.py
- tests/unit/models/generation/test_vllm_generation.py
- nemo_rl/experience/rollouts.py
- nemo_rl/models/generation/interfaces.py
- nemo_rl/models/generation/vllm/vllm_worker_async.py
- nemo_rl/models/generation/vllm/vllm_worker.py
- nemo_rl/models/generation/vllm/vllm_generation.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:

- nemo_rl/models/policy/lm_policy.py
- nemo_rl/experience/rollouts.py
- nemo_rl/models/generation/interfaces.py
- nemo_rl/models/generation/vllm/vllm_worker_async.py
- nemo_rl/models/generation/vllm/vllm_worker.py
- nemo_rl/models/generation/vllm/vllm_generation.py
🧠 Learnings (3)
📚 Learning: 2025-09-20T14:58:45.492Z
Learnt from: CR
PR: NVIDIA-NeMo/RL#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-09-20T14:58:45.492Z
Learning: Applies to nemo_rl/**/*.py : Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Applied to files:
nemo_rl/models/policy/lm_policy.py
📚 Learning: 2025-09-20T14:58:45.492Z
Learnt from: CR
PR: NVIDIA-NeMo/RL#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-09-20T14:58:45.492Z
Learning: Applies to nemo_rl/**/*.py : When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
Applied to files:
nemo_rl/models/generation/interfaces.py
📚 Learning: 2025-09-20T14:58:45.492Z
Learnt from: CR
PR: NVIDIA-NeMo/RL#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-09-20T14:58:45.492Z
Learning: Applies to nemo_rl/**/*.py : Express configuration optionality via TypedDict using typing.NotRequired
Applied to files:
nemo_rl/models/generation/interfaces.py
🧬 Code graph analysis (7)
nemo_rl/models/policy/lm_policy.py (2)
nemo_rl/models/generation/interfaces.py (3)
- GuidedDecodingConfig (118-139)
- GenerationDatumSpec (159-190)
- GenerationOutputSpec (193-237)

nemo_rl/distributed/batched_data_dict.py (1)
- BatchedDataDict (75-860)
tests/unit/models/generation/test_vllm_generation.py (4)
nemo_rl/models/generation/interfaces.py (2)
- GuidedDecodingConfig (118-139)
- generate (251-257)

tests/unit/environments/test_retriever.py (2)
- cluster (97-114)
- tokenizer (84-93)

nemo_rl/models/generation/vllm/vllm_generation.py (2)
- generate (428-480)
- shutdown (775-782)

nemo_rl/models/generation/vllm/vllm_worker.py (2)
- generate (457-588)
- shutdown (792-812)
nemo_rl/experience/rollouts.py (1)
nemo_rl/models/generation/interfaces.py (1)
- GuidedDecodingConfig (118-139)
nemo_rl/models/generation/interfaces.py (1)
nemo_rl/distributed/batched_data_dict.py (1)
- BatchedDataDict (75-860)
nemo_rl/models/generation/vllm/vllm_worker_async.py (2)
nemo_rl/models/generation/interfaces.py (1)
GuidedDecodingConfig(118-139)nemo_rl/models/generation/vllm/vllm_worker.py (1)
_get_vllm_guided_decoding_params(345-368)
nemo_rl/models/generation/vllm/vllm_worker.py (1)
nemo_rl/models/generation/interfaces.py (1)
- GuidedDecodingConfig (118-139)
nemo_rl/models/generation/vllm/vllm_generation.py (2)
nemo_rl/models/generation/interfaces.py (2)
- GuidedDecodingConfig (118-139)
- GenerationDatumSpec (159-190)

nemo_rl/models/generation/vllm/vllm_worker_async.py (1)
- generate_async (509-732)
🪛 Ruff (0.14.2)
nemo_rl/experience/rollouts.py
554-554: Unused function argument: guided_decoding_config
(ARG001)
nemo_rl/models/generation/vllm/vllm_worker.py
366-368: Avoid specifying long messages outside the exception class
(TRY003)
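The TRY003 hint above discourages long messages assembled at the raise site. One common fix (a sketch, not the actual NeMo-RL code; the exception name is hypothetical) moves the message into a dedicated exception class:

```python
# Sketch of the TRY003-compliant idiom: the message lives in the exception
# class, so the raise site stays short. Names here are hypothetical.
class UnsupportedGuidedDecodingError(ValueError):
    """Raised when a guided decoding config sets no supported constraint."""

    def __init__(self) -> None:
        super().__init__(
            "guided decoding config must set exactly one of: json, regex, choice"
        )


def check(constraints: dict) -> None:
    """Validate that at least one constraint was supplied."""
    if not constraints:
        raise UnsupportedGuidedDecodingError  # short raise site, no inline message
```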
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Check if PR branch is up to date
- GitHub Check: Lint check
- GitHub Check: Check submodule fast-forward / Check submodule fast-forward
- GitHub Check: Post submodule check comment / Comment on PR
- GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (11)
nemo_rl/experience/rollouts.py (5)
58-73: LGTM: Parameter correctly threaded through to generation interface.

The `guided_decoding_config` parameter is properly forwarded to `policy_generation.generate()`.

125-155: LGTM: Parameter correctly threaded through async generation.

The `guided_decoding_config` parameter is properly forwarded to `policy_generation.generate_async()`.

340-430: LGTM: Docstring updated and parameter correctly forwarded.

The docstring now documents the `guided_decoding_config` parameter (line 352), and the parameter is correctly forwarded to `generate_responses` at line 429.

625-688: LGTM: Docstring updated and parameter correctly forwarded.

The docstring documents the `guided_decoding_config` parameter (line 641), and the parameter is correctly forwarded to `async_generate_response_for_sample_turn` at line 687.

796-849: LGTM: Docstring updated and parameter correctly forwarded.

The docstring documents the `guided_decoding_config` parameter (line 811), and the parameter is correctly forwarded to `run_sample_multi_turn_rollout` at line 848.

nemo_rl/models/generation/vllm/vllm_generation.py (6)
19-41: LGTM: Proper use of TYPE_CHECKING for conditional imports.

The TYPE_CHECKING import pattern correctly avoids a runtime dependency on vLLM's `GuidedDecodingParams` while enabling type hints.

428-457: LGTM: Parameter correctly threaded to workers.

The `guided_decoding_config` parameter is properly forwarded to worker methods via `common_kwargs`.

482-514: LGTM: Parameter correctly threaded to workers.

The `guided_decoding_params` parameter is properly forwarded to worker methods via `common_kwargs`.

534-578: LGTM: Flexible parameter passing via kwargs.

Using `**kwargs` in the base method appropriately supports the different parameter names (`guided_decoding_config` vs `guided_decoding_params`) required by different callers.

664-692: LGTM: Parameter correctly forwarded to base method.

The `guided_decoding_params` parameter is properly forwarded to `_async_generate_base`.

694-722: LGTM: Parameter correctly forwarded to base method.

The `guided_decoding_config` parameter is properly forwarded to `_async_generate_base`.
terrykong
left a comment
small comment
@parthchadha to review
```python
vllm_config["max_new_tokens"] = 16
vllm_config["vllm_cfg"]["async_engine"] = False
vllm_config = configure_generation_config(vllm_config, tokenizer)
vllm_policy = VllmGeneration(cluster, vllm_config)
```
should we also test that the generation log probs also match our expectations: logprob=0 (1 in the linear domain) for the guided tokens?
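The property asked about here can be expressed as a small helper (a sketch; how a test would extract per-token logprobs and which positions the grammar forces is an assumption about the generation output format):

```python
# Sketch of the check terrykong suggests: when guided decoding leaves only
# one legal next token, that token's log-probability should be ~0
# (probability ~1 in the linear domain). Tolerance is an assumption.
import math


def assert_forced_tokens_have_zero_logprob(
    logprobs: list[float], forced_positions: list[int], atol: float = 1e-5
) -> None:
    """Assert that every grammar-forced token was sampled with probability ~1."""
    for pos in forced_positions:
        lp = logprobs[pos]
        assert math.isclose(lp, 0.0, abs_tol=atol), (
            f"token at {pos}: logprob={lp}, expected ~0 (prob ~1)"
        )
        assert math.isclose(math.exp(lp), 1.0, abs_tol=atol)
```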
Signed-off-by: root <root@pool0-01584.cm.cluster>
What does this PR do?

This PR adds an options passthrough to the vLLM generation policy to enable guided decoding.
Issues
This PR resolves #603.
Usage
This PR adds a backend-agnostic guided decoding config class, `nemo_rl.models.generation.interfaces.GuidedDecodingConfig` (i.e. it does not depend on vLLM, should a new generation backend be added in the future), where `policy` is any subclass of `GenerationInterface`, which includes `VllmGeneration`.

Before your PR is "Ready for review"
Pre checks:
Summary by CodeRabbit
New Features
Tests