[Bugfix][Async] Fix async spec decoding with hybrid models #38556

Merged

MatthewBonanni merged 11 commits into vllm-project:main on Mar 31, 2026
Conversation
…re_next_token_ids_padded)

When async scheduling is enabled (zero-bubble spec decoding, PR vllm-project#32951), `optimistic_seq_lens_cpu = num_computed_tokens + num_scheduled_tokens` is passed to `prepare_next_token_ids_padded` as `seq_lens_cpu`. This value is inflated relative to the actual committed `output_token_ids` because `_prepare_inputs` appends -1 placeholder slots optimistically.

The backup token lookup calls `request.get_token_id(seq_lens_cpu[i])`, where `seq_lens_cpu[i]` points one past the end of the committed tokens, causing `get_token_id()` to return -1 (placeholder). The drafter then receives -1 as its next input token, which corrupts its hidden state and degrades the draft acceptance rate, causing the Nemotron-3-Super-120B BF16 GSM8K score to drop from ~0.93 to ~0.74.

Fix: use `(num_tokens_no_spec[i] - 1)`, the index of the last committed output token, for the backup token lookup in both `EagleProposer` (`eagle.py`) and `ExtractHiddenStatesProposer` (`extract_hidden_states.py`). `num_tokens_no_spec` is set to `request.num_tokens` before the optimistic extend, so it always points to a valid token slot.

Fixes: vllm-project#38098

Signed-off-by: SandishKumarHN <sandishkumarhn@gmail.com>
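The off-by-one described above can be illustrated with a minimal, self-contained sketch. The `Request` class and field names here are simplified stand-ins for illustration, not vLLM's actual API; the only behavior mirrored from the commit message is that an out-of-range index reads back as -1 while `num_tokens_no_spec - 1` always lands on the last committed token.

```python
# Hypothetical sketch of the indexing bug; Request is a simplified stand-in.
class Request:
    def __init__(self, prompt_token_ids, output_token_ids):
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids = output_token_ids

    @property
    def num_tokens(self):
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    def get_token_id(self, idx):
        # Mirrors the described behavior: out-of-range slots read as -1.
        all_tokens = self.prompt_token_ids + self.output_token_ids
        return all_tokens[idx] if idx < len(all_tokens) else -1

# 4 prompt tokens, 2 committed output tokens.
req = Request([10, 11, 12, 13], [20, 21])
num_tokens_no_spec = req.num_tokens  # recorded before the optimistic extend
seq_len = num_tokens_no_spec         # valid indices are 0..num_tokens-1

buggy_backup = req.get_token_id(seq_len)                 # -1, the placeholder
fixed_backup = req.get_token_id(num_tokens_no_spec - 1)  # 21, last committed token
```

Note that `seq_len == num_tokens` is already one past the last valid index even before async inflation adds placeholder slots, which is why the fix indexes with `num_tokens_no_spec - 1` rather than clamping the sequence length.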
Per Gemini review, the original name was misleading: the buggy code was always off-by-one, not just when async inflation was present. Rename to test_buggy_code_was_always_off_by_one and update the docstring to clearly explain that seq_len (= num_tokens) is always out of range for get_token_id(). Signed-off-by: SandishKumarHN <sandishkumarhn@gmail.com>
Contributor
Code Review
This pull request addresses issues with async scheduling in speculative decoding. It updates eagle.py and extract_hidden_states.py to use num_tokens_no_spec - 1 instead of sequence lengths to correctly identify the last committed token, preventing errors caused by inflated sequence lengths from async-scheduling placeholders. Additionally, gpu_model_runner.py is updated to correctly map num_accepted_tokens using prev_positions when async scheduling is enabled, accounting for index reordering by condense(). I have no feedback to provide.
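The second fix the review summarizes, remapping `num_accepted_tokens` through `prev_positions` after `condense()` reorders batch slots, can be sketched as follows. This is an illustrative toy, assuming a `condense()` that compacts active requests and records each surviving slot's previous index; the names are stand-ins, not vLLM's internals.

```python
# Hypothetical illustration of remapping per-request values after a
# condense()-style reordering; names are illustrative, not vLLM's code.
def condense(active_mask, values):
    """Compact `values` to active slots, returning the condensed list and,
    for each new slot, the index it occupied before condensing."""
    condensed, prev_positions = [], []
    for old_idx, keep in enumerate(active_mask):
        if keep:
            condensed.append(values[old_idx])
            prev_positions.append(old_idx)
    return condensed, prev_positions

num_accepted_tokens = [3, 1, 2, 4]  # indexed by pre-condense batch slot
active = [True, False, True, True]  # request in slot 1 has finished
_, prev_positions = condense(active, num_accepted_tokens)

# Map the values copied from the previous batch to the new slot order,
# so each condensed slot reads the value for the request it now holds.
remapped = [num_accepted_tokens[p] for p in prev_positions]  # [3, 2, 4]
```

Without the `prev_positions` indirection, slot `i` after condensing would naively read `num_accepted_tokens[i]`, picking up the value for whatever request previously occupied that slot.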
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Force-pushed from e1e41ab to 46fa724
NickLucche (Collaborator) reviewed on Mar 31, 2026:
I think we want to test this @ZhanqiuHu
khluu pushed a commit referencing this pull request on Apr 1, 2026:
Signed-off-by: SandishKumarHN <sandishkumarhn@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: SandishKumarHN <sandishkumarhn@gmail.com> (cherry picked from commit 757068d)
EricccYang pushed a commit to EricccYang/vllm referencing this pull request on Apr 1, 2026:
…ect#38556) Signed-off-by: SandishKumarHN <sandishkumarhn@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: SandishKumarHN <sandishkumarhn@gmail.com> Signed-off-by: EricccYang <yangyang4991@gmail.com>
co-authored by @SandishKumarHN
FIX: #38098
Purpose
Incorporates two fixes:
Fix 1
Posted earlier as #38419, incorporated into this PR.
In async mode, `seq_lens_cpu` is inflated by optimistic draft token placeholders. When `prepare_next_token_ids_padded` uses this inflated value to call `get_token_id()`, it reads past the end of the committed tokens and returns -1. Use `num_tokens_no_spec - 1` (the actual last committed token position) instead of `seq_lens_cpu` for computing backup token indices.

Fix 2

In async mode, `condense()` copies `num_accepted_tokens_cpu` values while the GPU→CPU async copy from the previous batch is still in flight. This results in stale values being propagated to reordered indices, corrupting Mamba hidden states.

Test Plan
LM Eval Large Models (H200)
Test Result
main: Fails
PR: Passes