
[Spec Decode] Unified Parallel Drafting#32887

Merged
benchislett merged 24 commits into vllm-project:main from CentML:bchislett/unified-parallel-drafting
Feb 5, 2026

Conversation

@benchislett
Collaborator

@benchislett benchislett commented Jan 22, 2026

Purpose

This PR implements a single input-preparation kernel for draft-model support, and adds parallel drafting both with and without hidden states from the target model. With this, we now support AMD's PARD, which proposes parallel drafting for fine-tuned external draft models, and AWS's P-EAGLE, which implements parallel prediction for EAGLE3. Both are benchmarked as part of this PR effort.
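As a rough illustration of the parallel-drafting idea (a toy sketch, not vLLM's kernel — `PLACEHOLDER` and `toy_forward` are made up for this example): the drafter sees the context plus K placeholder tokens and emits all K draft tokens in one forward pass, instead of K sequential passes.

```python
# Toy illustration of parallel drafting vs. autoregressive drafting.
# The only thing being demonstrated is the forward-pass count: K for the
# autoregressive loop, 1 for the parallel variant.
PLACEHOLDER = -1

def toy_forward(tokens):
    """Stand-in for a drafter forward pass: one output per position."""
    toy_forward.calls += 1
    return [(t * 31 + i) % 100 for i, t in enumerate(tokens)]
toy_forward.calls = 0

def autoregressive_draft(context, k):
    seq = list(context)
    for _ in range(k):                 # K forward passes
        seq.append(toy_forward(seq)[-1])
    return seq[len(context):]

def parallel_draft(context, k):
    padded = list(context) + [PLACEHOLDER] * k
    outputs = toy_forward(padded)      # single forward pass
    return outputs[len(context):]      # K drafts read off the placeholder slots

ctx = [5, 9, 2]
toy_forward.calls = 0
autoregressive_draft(ctx, 4)
ar_calls = toy_forward.calls           # 4 forward passes
toy_forward.calls = 0
parallel_draft(ctx, 4)
par_calls = toy_forward.calls          # 1 forward pass
print(ar_calls, par_calls)             # 4 1
```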

Testing

E2E tests for parallel drafting and unit tests for the input preparation logic are all passing locally. Confirmed that E2E tests for draft models and EAGLE3 are also still passing locally.

Benchmarks

Benchmarks were conducted for AWS's 2-layer P-EAGLE on GPT-OSS 120B, with acceptance length (AL) calculated by averaging over each of the MT-Bench categories at 2048 max output tokens. The baseline is NVIDIA's short-context EAGLE3. I also compare AMD's PARD Llama 3.2 1B drafting for Llama 3.3 70B NVFP4, with the autoregressive draft model as a baseline. All benchmarks ran on 1xB200.

Best config for GPT-OSS at BS=1 is P-EAGLE with K=7, with ~560 output TPS, a speedup of 1.52x over baseline and 1.12x over EAGLE3 best config. At BS=8, P-EAGLE is optimal with K=3, a speedup of 1.34x over baseline and 1.07x over best EAGLE3.

Best config for Llama 3.3 70B-NVFP4 at BS=1 is PARD Llama-1B with K=11, with ~254 output TPS, a speedup of 3.10x over baseline and 1.61x over vanilla draft-model. At BS=8, PARD is optimal with K=7, a speedup of 2.87x over baseline and 1.61x over vanilla draft-model.

All data, with best-at-concurrency bolded for each model.
| Checkpoint | K | BS | AL | Median Iter Time (s) | Est. TPS |
| --- | --- | --- | --- | --- | --- |
| GPT-OSS 120B EAGLE3 | 0 | 1 | 1 | 0.0027 | 370 |
| GPT-OSS 120B EAGLE3 | 3 | 1 | 2.4015427 | 0.0048 | 500.3213958 |
| GPT-OSS 120B EAGLE3 | 4 | 1 | 2.5880982 | 0.0055 | 470.5633091 |
| GPT-OSS 120B EAGLE3 | 5 | 1 | 2.685875283 | 0.0061 | 440.3074234 |
| GPT-OSS 120B EAGLE3 | 7 | 1 | 2.776950087 | 0.007 | 396.7071553 |
| GPT-OSS 120B P-EAGLE | 0 | 1 | 1 | 0.0027 | 370 |
| GPT-OSS 120B P-EAGLE | 2 | 1 | 2.13727217 | 0.0042 | 508.8743262 |
| GPT-OSS 120B P-EAGLE | 3 | 1 | 2.3888 | 0.0043 | 555.5348837 |
| GPT-OSS 120B P-EAGLE | 4 | 1 | 2.545543 | 0.0046 | 553.378913 |
| GPT-OSS 120B P-EAGLE | 5 | 1 | 2.621487 | 0.0048 | 546.143125 |
| **GPT-OSS 120B P-EAGLE** | **7** | **1** | **2.6937** | **0.0048** | **561.1875** |
| Llama 70B-NVFP4 Draft 1B | 0 | 1 | 1 | 0.0122 | 81.96721311 |
| Llama 70B-NVFP4 Draft 1B | 3 | 1 | 2.964271208 | 0.0203 | 146.0232122 |
| Llama 70B-NVFP4 Draft 1B | 5 | 1 | 3.827371847 | 0.0243 | 157.5050143 |
| Llama 70B-NVFP4 Draft 1B | 7 | 1 | 4.408360614 | 0.0289 | 152.5384295 |
| Llama 70B-NVFP4 PARD 1B | 0 | 1 | 1 | 0.0122 | 81.96721311 |
| Llama 70B-NVFP4 PARD 1B | 3 | 1 | 2.72759473 | 0.0142 | 192.0841359 |
| Llama 70B-NVFP4 PARD 1B | 5 | 1 | 3.251206494 | 0.0141 | 230.5820208 |
| Llama 70B-NVFP4 PARD 1B | 7 | 1 | 3.568283773 | 0.0141 | 253.0697711 |
| **Llama 70B-NVFP4 PARD 1B** | **11** | **1** | **3.68784371** | **0.0145** | **254.334049** |
| Llama 70B-NVFP4 PARD 1B | 15 | 1 | 3.678693636 | 0.0145 | 253.7030094 |
| GPT-OSS 120B EAGLE3 | 0 | 8 | 1 | 0.0045 | 1777.777778 |
| GPT-OSS 120B EAGLE3 | 2 | 8 | 2.146052648 | 0.008 | 2146.052648 |
| GPT-OSS 120B EAGLE3 | 3 | 8 | 2.4015427 | 0.0086 | 2233.993209 |
| GPT-OSS 120B EAGLE3 | 4 | 8 | 2.5880982 | 0.01 | 2070.47856 |
| GPT-OSS 120B EAGLE3 | 5 | 8 | 2.685875283 | 0.0108 | 1989.537247 |
| GPT-OSS 120B EAGLE3 | 7 | 8 | 2.776950087 | 0.0117 | 1898.769291 |
| GPT-OSS 120B P-EAGLE | 0 | 8 | 1 | 0.0045 | 1777.777778 |
| GPT-OSS 120B P-EAGLE | 2 | 8 | 2.13727217 | 0.0080 | 2137.27217 |
| **GPT-OSS 120B P-EAGLE** | **3** | **8** | **2.3888** | **0.0080** | **2388.8** |
| GPT-OSS 120B P-EAGLE | 4 | 8 | 2.545543 | 0.009 | 2262.704889 |
| GPT-OSS 120B P-EAGLE | 5 | 8 | 2.621487 | 0.0094 | 2231.052766 |
| GPT-OSS 120B P-EAGLE | 7 | 8 | 2.6937 | 0.0095 | 2268.378947 |
| Llama 70B-NVFP4 Draft 1B | 0 | 8 | 1 | 0.0118 | 677.9661017 |
| Llama 70B-NVFP4 Draft 1B | 3 | 8 | 2.964271208 | 0.0208 | 1140.104311 |
| Llama 70B-NVFP4 Draft 1B | 5 | 8 | 3.827371847 | 0.0258 | 1186.781968 |
| Llama 70B-NVFP4 Draft 1B | 7 | 8 | 4.408360614 | 0.0299 | 1179.494479 |
| Llama 70B-NVFP4 PARD 1B | 0 | 8 | 1 | 0.012 | 666.6666667 |
| Llama 70B-NVFP4 PARD 1B | 3 | 8 | 2.72759473 | 0.0143 | 1525.927122 |
| Llama 70B-NVFP4 PARD 1B | 5 | 8 | 3.251206494 | 0.0147 | 1769.364078 |
| **Llama 70B-NVFP4 PARD 1B** | **7** | **8** | **3.568283773** | **0.0149** | **1915.857059** |
| Llama 70B-NVFP4 PARD 1B | 11 | 8 | 3.68784371 | 0.0156 | 1891.201903 |
| Llama 70B-NVFP4 PARD 1B | 15 | 8 | 3.678693636 | 0.0158 | 1862.629689 |
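The estimated-TPS column appears to follow AL × batch size ÷ median iteration time (accepted tokens per iteration across the batch, divided by iteration latency). A quick check against two table rows:

```python
# Sanity-check the Est. TPS column from the table above.
def est_tps(al, bs, iter_time):
    # accepted tokens per iteration (AL per sequence, times BS),
    # divided by the median iteration time in seconds
    return al * bs / iter_time

# GPT-OSS 120B P-EAGLE, K=7, BS=1 -> ~561 TPS
print(est_tps(2.6937, 1, 0.0048))        # 561.1875
# GPT-OSS 120B EAGLE3, K=2, BS=8 -> ~2146 TPS
print(est_tps(2.146052648, 8, 0.008))    # ~2146.052648
```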

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@mergify

mergify bot commented Jan 22, 2026

Documentation preview: https://vllm--32887.org.readthedocs.build/en/32887/

@mergify mergify bot added documentation Improvements or additions to documentation nvidia speculative-decoding labels Jan 22, 2026
@mergify mergify bot added the v1 label Jan 22, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a unified parallel drafting mechanism for speculative decoding, combining logic for EAGLE and other draft models. The changes are extensive, primarily refactoring the speculative decoding logic into a base proposer class and adding a new, complex Triton kernel for preparing inputs. While the overall refactoring appears sound, I've identified a potential critical issue in the new Triton kernel where a safeguard against out-of-bounds memory access is not being used, which could lead to memory corruption.
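For context on the masking concern raised above, here is a plain-Python illustration (not the PR's Triton kernel) of the guard pattern Triton kernels use: each lane computes an offset, and a boolean mask gates both the load and the store so lanes past the valid length never touch out-of-range memory (in Triton, `tl.load(ptr + offs, mask=mask)` / `tl.store(...)`).

```python
# Masked load/store sketch: BLOCK lanes, only the first n are valid.
# Lanes with mask == False never index into src or dst, which is the
# safeguard the review comment is about.
BLOCK = 8

def masked_copy(src, dst, n):
    offs = list(range(BLOCK))
    mask = [o < n for o in offs]
    # guarded load: out-of-range lanes get a fill value instead of reading
    vals = [src[o] if m else 0 for o, m in zip(offs, mask)]
    # guarded store: out-of-range lanes write nothing
    for o, m, v in zip(offs, mask, vals):
        if m:
            dst[o] = v

src = [10, 11, 12, 13, 14]   # only 5 valid elements
dst = [-1] * BLOCK
masked_copy(src, dst, len(src))
print(dst)                   # [10, 11, 12, 13, 14, -1, -1, -1]
```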

@mergify

mergify bot commented Jan 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 23, 2026
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@mergify mergify bot added the llama Related to Llama models label Jan 24, 2026
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@mergify mergify bot removed the needs-rebase label Jan 26, 2026
@mergify

mergify bot commented Jan 27, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 27, 2026
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@zihaoanllm
Contributor

zihaoanllm commented Jan 29, 2026

Hi @benchislett, thanks a lot for your great work! I tested the PARD integration in vLLM and compared its performance with the PARD repo example. Under the same configuration, the acceptance length is well aligned between the two runs (3.56 vs 3.50).

Below are the full benchmark results and the exact script I used for vLLM testing.

Results

| framework | target | draft | method | bmk | device | bs | k | baseline TPS | PARD TPS | speedup | accept length |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| PARD repo (transformers) | L3.1 8B | PARD | mt_bench | A100-40GB | 1 | 8 | 76.55 | 197.77 | 2.58 | 3.50 |
| vLLM | L3.1 8B | PARD | mt_bench | A100-40GB | 1 | 8 | 77.69 | 202.29 | 2.60 | 3.56 |
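The speedup column is simply PARD TPS divided by baseline TPS, which checks out against both rows:

```python
# Reproduce the speedup column from the results above.
repo_speedup = round(197.77 / 76.55, 2)   # PARD repo run
vllm_speedup = round(202.29 / 77.69, 2)   # vLLM run
print(repo_speedup, vllm_speedup)         # 2.58 2.6
```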

vLLM test script

```bash
# Start server
k=8
target=unsloth/Meta-Llama-3.1-8B-Instruct
draft=amd/PARD-Llama-3.2-1B

vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811 \
  --speculative-config '{"model": "'"$draft"'", "method": "draft_model", "num_speculative_tokens": '"$k"', "parallel_drafting": true}'

# Benchmark
MAX_CONCURRENCY=1
NUM_PROMPTS=80
vllm bench serve --port 8811 \
    --temperature 0 \
    --backend openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --dataset-path philschmid/mt-bench \
    --num-prompts ${NUM_PROMPTS} \
    --max-concurrency ${MAX_CONCURRENCY}
```

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@benchislett benchislett merged commit af3162d into vllm-project:main Feb 5, 2026
62 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Feb 5, 2026
@benchislett
Collaborator Author

For future readers, here is a link to AWS's P-EAGLE arxiv paper:
https://arxiv.org/pdf/2602.01469

LucasWilkinson added a commit that referenced this pull request Feb 12, 2026
When using MTP speculative decoding, the compile range is extended by
(multiplier * max_num_seqs), but the assertion in _dummy_run only checked
against max_num_batched_tokens, causing warmup to fail.

This was introduced in commit af3162d (Unified Parallel Drafting #32887)
which added MTP/Eagle to the compile range extension logic.

Fix: Use the maximum compile range split point as the upper bound when
available, instead of max_num_batched_tokens.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
LucasWilkinson added a commit that referenced this pull request Feb 13, 2026
Fix an assertion error during model warmup when using MTP speculative
decoding with parallel drafting. The issue occurred because the compile
range is extended for speculative decoding to accommodate drafter
batches, but the assertion in _dummy_run wasn't updated to match.

Root cause: PR #32887 added compile range extension in _set_compile_ranges
for speculative decoding. This causes warmup sizes to exceed
max_num_batched_tokens, triggering the assertion in _dummy_run.

Fix: Extend the assertion bound in _dummy_run to match the extended
compile range when parallel drafting is enabled.

Co-Authored-By: Claude <noreply@anthropic.com>

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
shaharmor98 pushed a commit to CentML/vllm that referenced this pull request Feb 16, 2026
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@hmellor hmellor mentioned this pull request Mar 4, 2026
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
```python
):
    max_num_queries_for_spec = (
        1
        + (2 if speculative_config.parallel_drafting else 1)
```

Why is it `1 + 1 + speculative_config.num_speculative_tokens` without parallel_drafting here? This looks like a big issue.

@benchislett (Collaborator, Author)

You are misunderstanding. It's 1 + 1 * num_spec_tokens. So it will be the same.
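A quick sketch of the formula under discussion (the function name and full expression are assumed here, since the diff excerpt above is truncated): the conditional is a multiplier on `num_speculative_tokens`, so the non-parallel case reduces to the familiar `1 + num_speculative_tokens`.

```python
# Assumed reconstruction of the query-count formula: the (2 if ... else 1)
# term multiplies num_speculative_tokens; it is not a standalone addend.
def max_num_queries_for_spec(num_speculative_tokens, parallel_drafting):
    return 1 + (2 if parallel_drafting else 1) * num_speculative_tokens

print(max_num_queries_for_spec(3, False))  # 4 == 1 + 1*3
print(max_num_queries_for_spec(3, True))   # 7 == 1 + 2*3
```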


Labels

documentation (Improvements or additions to documentation) · llama (Related to Llama models) · nvidia · performance (Performance-related issues) · ready (ONLY add when PR is ready to merge/full CI is needed) · speculative-decoding · v1

Projects

Status: Done
