[Model Runner V2] Rebuild attn metadata between draft decode steps #41162
Conversation
Code Review
This pull request refactors the EagleSpeculator to support multi-step decoding with CUDA Graphs, introducing a multi_step_decode method and splitting the draft generation logic. It also includes optimizations for metadata building and fixes typos in kernel naming. Feedback focuses on two critical areas: first, the reassignment of attn_metadata within the decode loop is incompatible with FULL CUDA graph execution because it fails to update the captured static inputs; second, moving hidden state updates from the fused Triton kernel to a separate Tensor.copy_ call introduces additional kernel launch overhead, which may negatively impact latency in speculative decoding.
```python
# Update the inputs for the next step.
update_eagle_draft_inputs(
    num_reqs,
    self.decode_draft_tokens[:num_reqs],
    self.input_buffers,
    self.max_model_len,
)
```
Shouldn't we skip this for the very last decode step?
The issue is that we can't pass the current draft step into generate_draft under FULL cudagraph, so it has no way of knowing when to skip.
While this results in one extra update_eagle_draft_inputs call after the last draft token, we at least don't pay any kernel launch overhead for it. If we instead moved the update out of generate_draft into the loop and skipped it on the last step, we'd pay kernel launch overhead on every step except the last.
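To make the trade-off concrete, here is a minimal sketch (illustrative only, not vLLM's actual code; `draft_step` and `update_inputs` are stand-in names) of why keeping the update inside the captured region is the cheaper option:

```python
import torch

assert torch.cuda.is_available()
device = torch.device("cuda")

# Static buffers that the captured graph reads and writes.
input_ids = torch.zeros(8, dtype=torch.long, device=device)
draft_tokens = torch.zeros(8, dtype=torch.long, device=device)

def draft_step():
    # Stand-in for the draft model forward pass + sampling.
    draft_tokens.copy_(input_ids + 1)

def update_inputs():
    # Stand-in for update_eagle_draft_inputs: feed tokens back as inputs.
    input_ids.copy_(draft_tokens)

# CUDA graphs need a warmup run on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    draft_step()
    update_inputs()
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    draft_step()
    update_inputs()  # captured: the last replay does one redundant update,
                     # but it costs no extra kernel launch.

for _ in range(3):   # num_speculative_tokens draft steps
    g.replay()       # exactly one launch per step

# The alternative -- calling update_inputs() in this Python loop and skipping
# it on the last step -- saves the redundant update but pays a separate
# kernel launch on every step except the last.
```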
Was able to put the current draft step in a scalar tensor, and use it during the CG-captured region. No more temporary buffers are needed. cc: @WoosukKwon
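A minimal sketch of that pattern (my reconstruction, not the PR's exact code): the step index lives in a one-element CUDA tensor that the captured kernels read, and is advanced in-place so each replay targets the next column of the output buffer.

```python
import torch

assert torch.cuda.is_available()
device = torch.device("cuda")
num_reqs, num_steps = 4, 3

# A one-element tensor instead of a Python int: an int would be frozen into
# the capture, but a device tensor can change between replays.
current_draft_step = torch.zeros(1, dtype=torch.long, device=device)
draft_tokens = torch.full((num_reqs, num_steps), -1, dtype=torch.long, device=device)
sampled = torch.arange(num_reqs, device=device)

def one_step():
    # Write this step's tokens into the column selected by the tensor index,
    # then advance the counter in-place; both ops are graph-capturable.
    draft_tokens.index_copy_(1, current_draft_step, sampled.unsqueeze(1))
    current_draft_step.add_(1)

# Warmup on a side stream, then reset state before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    one_step()
torch.cuda.current_stream().wait_stream(s)
current_draft_step.zero_()
draft_tokens.fill_(-1)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    one_step()

for _ in range(num_steps):
    g.replay()  # each replay writes the next column of draft_tokens

torch.cuda.synchronize()
assert (draft_tokens == sampled.unsqueeze(1)).all()
```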
Context
Investigating DSV4 "invalid memory access" crash that happens when MTP > 2. I believe that crash is caused by not rebuilding the attention metadata in between draft decode steps. DSV4's attention metadata builders have position-dependent state that must be updated whenever the position is advanced. For example:
vllm/v1/attention/backends/mla/sparse_swa.py (lines 293 to 307 at e9f8f31)
During draft decoding, we only update the sequence lengths, positions, and slot mappings, but not the metadata properties dependent on those:
vllm/v1/worker/gpu/spec_decode/eagle/speculator.py (lines 316 to 327 at e9f8f31)
Any attention backend with builders that have position-dependent state can potentially be affected by this bug.
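As a toy illustration (not the sparse_swa builder itself; `WINDOW` and `window_start` are made up for the example), here is how metadata derived from positions goes stale when only the positions are advanced:

```python
import torch

WINDOW = 4  # hypothetical sliding-window size

def build_metadata(seq_lens: torch.Tensor) -> dict:
    # window_start is *derived* from positions: the first token index each
    # request may attend to. This is the kind of position-dependent state
    # a builder computes at build time.
    return {
        "seq_lens": seq_lens.clone(),
        "window_start": (seq_lens - WINDOW).clamp_min(0),
    }

meta = build_metadata(torch.tensor([5, 9]))

# A draft decode step advances positions but only patches seq_lens in-place...
meta["seq_lens"] += 1

# ...so the derived field no longer matches what the kernel assumes, and a
# kernel trusting it can read the wrong (or an out-of-bounds) KV range.
expected = (meta["seq_lens"] - WINDOW).clamp_min(0)
print(meta["window_start"].tolist(), "vs expected", expected.tolist())
# [1, 5] vs expected [2, 6]
```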
This PR
Fixes the above issue by rebuilding the attention metadata between draft generation steps in the multi-step decode loop. Because attention metadata building is not guaranteed to be a cudagraph-compatible operation, we must capture the single-step draft generation and replay it in each iteration of the loop. The captured draft generation routine consists of the following operations, in order:
- `run_model` - Draft model forward pass to get hidden states.
- `compute_logits` - Computes logits from hidden states.
- `_sample_draft` - Samples a draft token using either gumbel or greedy sampling (depending on what was configured).
- `update_eagle_draft_inputs` - Updates the output draft tokens (`self.draft_tokens`) for the current step, and updates the inputs for the next draft step, such as input ids, hidden states, positions, etc.

Because cudagraph prevents me from passing an integer step idx to determine which draft tokens/logits column to write the outputs to, I had to introduce a scalar tensor called `self.current_draft_step`. That single-element tensor is used by `update_eagle_draft_inputs` to write the output draft token ids to the output buffer (`self.draft_tokens`).

Updating the correct column (step) of `self.draft_logits` was a bit trickier. I couldn't just pass the step into the sample method like I did before, because step is now a scalar tensor instead of an int. Attempting to do so would trigger advanced indexing, returning a new copy of the tensor, not a slice. So I had to add support to `gumbel_sample` for accepting a column tensor (`output_processed_logits_col`) indicating which column in the logits tensor to write to.
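A small sketch of the indexing pitfall (shapes and names are illustrative, not the PR's): indexing with a Python int yields a view that can be written through, while indexing with a tensor triggers advanced indexing and returns a copy, so an in-place write via `index_copy_` is needed instead.

```python
import torch

num_steps, num_reqs, vocab = 3, 2, 5
draft_logits = torch.zeros(num_steps, num_reqs, vocab)

step_int = 1
step_tensor = torch.tensor([1])

# Indexing with a Python int returns a view: writes land in draft_logits.
draft_logits[step_int].fill_(7.0)
assert draft_logits[1].eq(7.0).all()

# Indexing with a (scalar) tensor triggers advanced indexing and returns a
# copy: writes to it are silently lost.
col = draft_logits[step_tensor[0]]
col.fill_(9.0)
assert not draft_logits[1].eq(9.0).any()

# The fix, conceptually: pass the step index as a tensor and scatter into
# the output in-place (what output_processed_logits_col enables in
# gumbel_sample).
sampled = torch.full((1, num_reqs, vocab), 3.0)
draft_logits.index_copy_(0, step_tensor, sampled)
assert draft_logits[1].eq(3.0).all()
```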
Test Plan
Before, the following server and bench commands would result in a crash with an IMA error. Now, with the changes in this PR, the server runs to completion.
Benchmarks
I measured the effect of this PR on performance to make sure it doesn't introduce any significant regressions. All benchmarks were run with 3 speculative tokens, using the following benchmark command:
NOTE: GLM 4.7 Flash with 3 speculative tokens failed on main, likely due to the issue that this PR fixes.
GLM 4.7 Flash
Base: zai-org/GLM-4.7-Flash
Draft: MTP
Temperature = 0.0
Temperature = 1.0
GPT-OSS 20B
Base: openai/gpt-oss-20b
Draft: RedHatAI/gpt-oss-20b-speculator.eagle3
Temperature = 0.0
Temperature = 1.0
Llama3 8B
Base: meta-llama/Meta-Llama-3-8B-Instruct
Draft: yuhuili/EAGLE-LLaMA3-Instruct-8B
Temperature = 0.0
Temperature = 1.0
Qwen3 8B
Base: Qwen/Qwen3-8B
Draft: RedHatAI/Qwen3-8B-speculator.eagle3
Temperature = 0.0
Temperature = 1.0
Future Work
It's unfortunate that we now lose the full cudagraph captured loop for the last N-1 draft tokens, but these changes were necessary for correctness. It may be possible to preserve the fully captured loop by introducing a new method to the `AttentionMetadataBuilder` API, similar to `update_block_table`, that allows updating the attention metadata in-place, assuming single-token decodes and that only the sequence position has changed since the last rebuild. With those assumptions, one could implement cudagraph-safe updates on a per-backend basis.
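A hypothetical sketch of what that could look like; `ToyMetadataBuilder`, `advance_positions`, and the window fields below are all assumptions for illustration, not an existing vLLM API:

```python
import torch
from dataclasses import dataclass

@dataclass
class ToyMetadata:
    # Persistent device buffers, so a captured graph keeps valid addresses.
    seq_lens: torch.Tensor
    window_start: torch.Tensor

class ToyMetadataBuilder:
    WINDOW = 4  # hypothetical sliding-window size

    def build(self, seq_lens: torch.Tensor) -> ToyMetadata:
        # Full (not necessarily cudagraph-safe) rebuild.
        return ToyMetadata(
            seq_lens=seq_lens.clone(),
            window_start=(seq_lens - self.WINDOW).clamp_min(0),
        )

    def advance_positions(self, meta: ToyMetadata) -> None:
        # Hypothetical cudagraph-safe path, valid only for single-token
        # decodes where nothing but the position moved since the last full
        # rebuild: mutate the existing buffers in-place, no host control flow.
        meta.seq_lens.add_(1)
        meta.window_start.copy_((meta.seq_lens - self.WINDOW).clamp_min(0))

builder = ToyMetadataBuilder()
meta = builder.build(torch.tensor([5, 9]))
builder.advance_positions(meta)  # buffers keep their addresses for replay
```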