Skip to content

[TRTLLM-10382][feat] Support rejection sampling in one-model spec dec#11001

Open
ziyixiong-nv wants to merge 1 commit intoNVIDIA:mainfrom
ziyixiong-nv:dev-fxiong-rejection-sampling
Open

[TRTLLM-10382][feat] Support rejection sampling in one-model spec dec#11001
ziyixiong-nv wants to merge 1 commit intoNVIDIA:mainfrom
ziyixiong-nv:dev-fxiong-rejection-sampling

Conversation

@ziyixiong-nv
Copy link
Copy Markdown
Collaborator

@ziyixiong-nv ziyixiong-nv commented Jan 26, 2026

Summary by CodeRabbit

  • New Features

    • Added rejection sampling support for draft token selection during speculative decoding with Eagle3 and MTP models.
  • Configuration

    • New use_rejection_sampling option enables rejection sampling-based draft token acceptance in speculative decoding (requires allow_advanced_sampling).

✏️ Tip: You can customize this high-level summary in your review settings.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@ziyixiong-nv ziyixiong-nv requested review from a team as code owners January 26, 2026 09:57
@ziyixiong-nv ziyixiong-nv requested a review from hchings January 26, 2026 09:57
@ziyixiong-nv ziyixiong-nv force-pushed the dev-fxiong-rejection-sampling branch from e19ea8b to 41a04a5 Compare January 26, 2026 10:00
@ziyixiong-nv
Copy link
Copy Markdown
Collaborator Author

/bot run

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Jan 26, 2026

📝 Walkthrough

Walkthrough

This PR introduces rejection sampling support for draft token selection in speculative decoding. It adds infrastructure for capturing draft logits, computing draft probabilities, and routing sampling through a rejection-sampling-aware path when enabled across Eagle3 and MTP models.

Changes

Cohort / File(s) Summary
Core rejection sampling interface
tensorrt_llm/_torch/speculative/interface.py
Expands SpecMetadata with new fields (use_rejection_sampling, draft_probs, draft_probs_vocab_size). Introduces helper methods: _can_use_rejection_sampling, _sample_and_accept_draft_tokens_rejection, and _compute_and_store_draft_probs to enable rejection sampling workflow.
Model-specific implementations
tensorrt_llm/_torch/speculative/eagle3.py, tensorrt_llm/_torch/speculative/mtp.py
Adds logit capture during draft generation (draft_logits_list). Conditionally computes and stores draft probabilities when rejection sampling is enabled. Routes sampling through rejection-sampling code path via _sample_and_accept_draft_tokens_rejection when applicable.
Sampling utilities
tensorrt_llm/_torch/speculative/one_model_sampler.py
Introduces compute_probs_from_logits to apply temperature and top-k/top-p masking. Adds rejection_sampling_one_model wrapper for CUDA-graph-compatible rejection sampling using chain_speculative_sampling.
Configuration propagation
tensorrt_llm/_torch/speculative/utils.py, tensorrt_llm/llmapi/llm_args.py
Propagates use_rejection_sampling and allow_advanced_sampling flags through MTPSpecMetadata and Eagle3OneModelSpecMetadata construction. Adds use_rejection_sampling field to DecodingBaseConfig.
Test coverage
tests/integration/defs/accuracy/test_llm_api_pytorch.py, tests/unittest/_torch/speculative/test_eagle3.py
Adds use_rejection_sampling parameterization to Eagle3 integration tests. Updates CDL sampling tests with use_cuda_graph parameterization and enables advanced sampling flags in test configurations.

Sequence Diagram(s)

sequenceDiagram
    actor Forward as Forward Pass (Eagle3/MTP)
    participant Draft as Draft Model
    participant LogitCapture as Logit Capture
    participant ProbCompute as Probability Computation
    participant Sampling as Rejection Sampling
    participant Output as Output Tokens
    
    Forward->>Draft: Generate draft tokens per step
    Draft-->>LogitCapture: Emit draft logits
    LogitCapture->>LogitCapture: Accumulate in draft_logits_list<br/>(when use_rejection_sampling=True)
    
    Forward->>ProbCompute: Call _compute_and_store_draft_probs<br/>(after draft loop)
    ProbCompute->>ProbCompute: Apply temperature & top-k/top-p<br/>Compute softmax → draft_probs
    ProbCompute->>ProbCompute: Store in spec_metadata.draft_probs
    
    Forward->>Sampling: sample_and_accept_draft_tokens
    alt use_rejection_sampling enabled
        Sampling->>Sampling: Reshape draft_probs to<br/>[batch_size, max_draft_len, vocab_size]
        Sampling->>Sampling: _sample_and_accept_draft_tokens_rejection<br/>(target_probs vs draft_probs)
        Sampling->>Sampling: Perform probabilistic acceptance
    else fallback to base path
        Sampling->>Sampling: Existing sampling logic
    end
    
    Sampling-->>Output: Accepted tokens & counts
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • mikeiovine
  • yweng0828
  • achartier
🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 75.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Description check ⚠️ Warning The pull request description is incomplete. The author left all required template sections empty, providing no description of the changes, test coverage information, or confirmation of the PR checklist items beyond a single checkbox mark. Complete the PR description by filling in: (1) Description section explaining what rejection sampling changes were made and why, (2) Test Coverage section listing relevant tests, and (3) explicit confirmation of PR checklist items.
✅ Passed checks (1 passed)
Check name Status Explanation
Title check ✅ Passed The PR title clearly and concisely summarizes the main feature: adding rejection sampling support to one-model speculative decoding, with proper ticket reference and type annotation.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Tip

Try Coding Plans. Let us write the prompt for your AI agent so you can ship faster (with fewer bugs).
Share your feedback on Discord.


Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/llmapi/llm_args.py (1)

689-696: Enforce allow_advanced_sampling when rejection sampling is enabled.

The rejection-sampling path expects per-token temperatures/top-k/top-p tensors; without allow_advanced_sampling=True, those stay unset and later slicing will fail. Add a validator to fail fast when the config is invalid.

✅ Proposed fix
 class DecodingBaseConfig(StrictBaseModel):
@@
     # Prototype. If true, allows non-greedy sampling when speculation is used. Only applicable
     # to 1-model code paths; non-greedy sampling is always enabled on 2-model paths.
     allow_advanced_sampling: bool = False
 
     # If true, uses rejection sampling for draft token acceptance instead of strict token equality.
     # Rejection sampling provides lossless acceleration that exactly matches the target model's
     # distribution. Requires allow_advanced_sampling=True. Only applicable to 1-model code paths.
     use_rejection_sampling: bool = False
+
+    `@model_validator`(mode="after")
+    def _validate_rejection_sampling(self):
+        if self.use_rejection_sampling and not self.allow_advanced_sampling:
+            raise ValueError(
+                "use_rejection_sampling requires allow_advanced_sampling=True")
+        return self

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #33574 [ run ] triggered by Bot. Commit: 41a04a5

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #33574 [ run ] completed with state FAILURE. Commit: 41a04a5
/LLM/main/L0_MergeRequest_PR pipeline #25898 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@ziyixiong-nv
Copy link
Copy Markdown
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #33584 [ run ] triggered by Bot. Commit: 86e7b19

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #33584 [ run ] completed with state FAILURE. Commit: 86e7b19
/LLM/main/L0_MergeRequest_PR pipeline #25905 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@ziyixiong-nv
Copy link
Copy Markdown
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #33608 [ run ] triggered by Bot. Commit: 86e7b19

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #33608 [ run ] completed with state FAILURE. Commit: 86e7b19
/LLM/main/L0_MergeRequest_PR pipeline #25926 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@ziyixiong-nv
Copy link
Copy Markdown
Collaborator Author

/bot run

@ziyixiong-nv ziyixiong-nv force-pushed the dev-fxiong-rejection-sampling branch from 86e7b19 to 5264e8a Compare January 27, 2026 01:44
@ziyixiong-nv
Copy link
Copy Markdown
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #33646 [ run ] triggered by Bot. Commit: 5264e8a

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #33647 [ run ] triggered by Bot. Commit: 5264e8a

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #33647 [ run ] completed with state SUCCESS. Commit: 5264e8a
/LLM/main/L0_MergeRequest_PR pipeline #25959 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@ziyixiong-nv ziyixiong-nv force-pushed the dev-fxiong-rejection-sampling branch from 5264e8a to e351950 Compare January 27, 2026 05:24
@ziyixiong-nv
Copy link
Copy Markdown
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #33679 [ run ] triggered by Bot. Commit: e351950

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #33679 [ run ] completed with state SUCCESS. Commit: e351950
/LLM/main/L0_MergeRequest_PR pipeline #25985 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@ziyixiong-nv ziyixiong-nv force-pushed the dev-fxiong-rejection-sampling branch from e351950 to 37bd7cb Compare January 27, 2026 08:14
@ziyixiong-nv
Copy link
Copy Markdown
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #33703 [ run ] triggered by Bot. Commit: 37bd7cb

@ziyixiong-nv ziyixiong-nv force-pushed the dev-fxiong-rejection-sampling branch 2 times, most recently from b590ec3 to 03682c2 Compare March 6, 2026 01:03
@ziyixiong-nv
Copy link
Copy Markdown
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #37931 [ run ] triggered by Bot. Commit: 03682c2 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #37931 [ run ] completed with state SUCCESS. Commit: 03682c2
/LLM/main/L0_MergeRequest_PR pipeline #29376 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@ziyixiong-nv
Copy link
Copy Markdown
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #37973 [ run ] triggered by Bot. Commit: 03682c2 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #37973 [ run ] completed with state FAILURE. Commit: 03682c2
/LLM/main/L0_MergeRequest_PR pipeline #29409 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@ziyixiong-nv
Copy link
Copy Markdown
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #37983 [ run ] triggered by Bot. Commit: 03682c2 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #37983 [ run ] completed with state SUCCESS. Commit: 03682c2
/LLM/main/L0_MergeRequest_PR pipeline #29417 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@ziyixiong-nv
Copy link
Copy Markdown
Collaborator Author

/bot run

@ziyixiong-nv ziyixiong-nv force-pushed the dev-fxiong-rejection-sampling branch from 03682c2 to ca9b7f0 Compare March 7, 2026 23:57
@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38113 [ run ] triggered by Bot. Commit: ca9b7f0 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38113 [ run ] completed with state FAILURE. Commit: ca9b7f0
/LLM/main/L0_MergeRequest_PR pipeline #29524 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@ziyixiong-nv
Copy link
Copy Markdown
Collaborator Author

/bot run

@ziyixiong-nv ziyixiong-nv force-pushed the dev-fxiong-rejection-sampling branch from ca9b7f0 to fa740bf Compare March 9, 2026 19:31
@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38320 [ run ] triggered by Bot. Commit: fa740bf Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #38320 [ run ] completed with state SUCCESS. Commit: fa740bf
/LLM/main/L0_MergeRequest_PR pipeline #29696 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@mikeiovine mikeiovine force-pushed the dev-fxiong-rejection-sampling branch from fa740bf to 9aad2c7 Compare March 17, 2026 18:27
@mikeiovine
Copy link
Copy Markdown
Collaborator

/bot run

Copy link
Copy Markdown
Collaborator

@QiJune QiJune left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #39305 [ run ] triggered by Bot. Commit: 9aad2c7 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #39305 [ run ] completed with state SUCCESS. Commit: 9aad2c7
/LLM/main/L0_MergeRequest_PR pipeline #30554 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
@mikeiovine mikeiovine force-pushed the dev-fxiong-rejection-sampling branch from 9aad2c7 to 0a55552 Compare March 18, 2026 15:15
@mikeiovine
Copy link
Copy Markdown
Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #39473 [ run ] triggered by Bot. Commit: 0a55552 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #39473 [ run ] completed with state SUCCESS. Commit: 0a55552
/LLM/main/L0_MergeRequest_PR pipeline #30698 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants