[Spec] FrozenKVMTP fold assistant seed into captured draft graph#25539
[Spec] FrozenKVMTP fold assistant seed into captured draft graph#25539kpham-sgl wants to merge 3 commits into
FrozenKVMTP fold assistant seed into captured draft graph#25539Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the Frozen-KV MTP (Multi-Token Prediction) implementation by moving the assistant seed forward pass into the captured CUDA graph within draft_forward. Key changes include adding bonus_tokens to the input buffers, updating the worker to stash seed inputs for the graph runner, and enforcing the use of the Triton attention backend. Feedback suggests improving the API by removing unused arguments from the _run_assistant_seed_step signature and optimizing performance by using torch.empty instead of torch.zeros for tensors that are immediately overwritten.
| mm_input_embeds: Optional[torch.Tensor] = None, | ||
| draft_input: Optional[FrozenKVMTPDraftInput] = None, | ||
| ) -> None: | ||
| """Run the one-token assistant seed step against frozen target KV.""" | ||
| """Stash seed inputs on ``batch.spec_info``; the forward runs inside | ||
| the captured draft graph (see ``draft_forward``'s seed iter).""" | ||
| del seq_lens_cpu, mm_input_embeds, draft_input |
There was a problem hiding this comment.
The arguments seq_lens_cpu, mm_input_embeds, and draft_input are no longer used in _run_assistant_seed_step because the assistant forward pass has been moved into the captured draft graph (executed within draft_forward). While they are currently being explicitly deleted to avoid unused variable warnings, it would be cleaner to remove them from the function signature and update the callers (forward_draft_extend and forward_draft_extend_after_decode) accordingly to maintain a clean API.
| stashed.topk_p = torch.zeros( | ||
| (bs, self.topk), device=device, dtype=torch.float32 | ||
| ) | ||
| stashed.topk_index = torch.zeros( | ||
| (bs, self.topk), device=device, dtype=torch.int64 | ||
| ) |
There was a problem hiding this comment.
Since topk_p and topk_index are only initialized here for compatibility with filter_batch/merge_batch and are overwritten by the captured seed iteration in draft_forward, you can use torch.empty instead of torch.zeros to avoid unnecessary zero-initialization overhead on the GPU.
| stashed.topk_p = torch.zeros( | |
| (bs, self.topk), device=device, dtype=torch.float32 | |
| ) | |
| stashed.topk_index = torch.zeros( | |
| (bs, self.topk), device=device, dtype=torch.int64 | |
| ) | |
| stashed.topk_p = torch.empty( | |
| (bs, self.topk), device=device, dtype=torch.float32 | |
| ) | |
| stashed.topk_index = torch.empty( | |
| (bs, self.topk), device=device, dtype=torch.int64 | |
| ) |
|
/tag-and-rerun-ci |
903a193 to
4d9203d
Compare
Co-authored-by: Cursor <cursoragent@cursor.com>
4d9203d to
ee7c80c
Compare
Motivation
Frozen-KV MTP ran a one-token eager assistant "seed" forward before the
captured recurrent draft loop. Seed and recurrent iters share the same
seq_lens - 1rope position against frozen target KV, so splitting themjust costs an extra launch + sync per decode.
Modifications
frozen_kv_mtp_worker.py:_run_assistant_seed_steponly stashesbonus_tokens/hidden_statesnow;draft_forwardruns the seedas iter 0 of the recurrent loop (topk>1 replicate/slice inline).
Attn init no longer gated on
num_steps > 1. Drops the staletopk == 1shortcut in_init_draft_attn_backend.frozen_kv_mtp_cuda_graph_runner.py: addsbonus_tokensto inputbuffers; stops copying
topk_p/topk_index(now produced by thecaptured seed iter). Wraps
_replay()in arecord_functionspan.No behavior change on verify / accept / KV-write paths.
Accuracy Tests
Test with #24552 for the 31B model
Speed Tests and Profiling
Before

After

Checklist
Review and Merge Process
/tag-and-rerun-ci,/tag-run-ci-label,/rerun-failed-ciCI States
Latest PR Test (Base): ✅ Run #26061044934⚠️ Not enabled -- add
Latest PR Test (Extra):
run-ci-extralabel to opt in.