[Spec] FrozenKVMTP fold assistant seed into captured draft graph#25539

Open
kpham-sgl wants to merge 3 commits into main from kp/gemma4-mtp-remove-seed

@kpham-sgl kpham-sgl commented May 17, 2026

Motivation

Frozen-KV MTP ran a one-token eager assistant "seed" forward before the
captured recurrent draft loop. Seed and recurrent iters share the same
seq_lens - 1 rope position against frozen target KV, so splitting them
just costs an extra launch + sync per decode.
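
A minimal sketch of the idea, in plain Python with a counter standing in for separately launched regions. All names here (`toy_draft_step`, `draft_split`, `draft_folded`, `LaunchCounter`) are hypothetical and not taken from this PR. It only illustrates that when the seed uses the same position as the recurrent iterations, it can run as iteration 0 of the same loop and yield identical draft tokens with one fewer launch.

```python
from dataclasses import dataclass


@dataclass
class LaunchCounter:
    # Counts separately launched regions (an eager call or one graph replay).
    launches: int = 0


def toy_draft_step(token: int, position: int) -> int:
    """Hypothetical stand-in for one assistant forward at a fixed rope position."""
    return (token * 31 + position) % 1000


def draft_split(bonus_token, seq_len, num_steps, counter):
    """Old shape: eager seed forward, then a captured loop for the rest."""
    position = seq_len - 1  # shared by the seed and every recurrent iter
    counter.launches += 1   # eager seed launch
    token = toy_draft_step(bonus_token, position)
    drafts = [token]
    if num_steps > 1:
        counter.launches += 1  # one replay covers the remaining iters
        for _ in range(num_steps - 1):
            token = toy_draft_step(token, position)
            drafts.append(token)
    return drafts


def draft_folded(bonus_token, seq_len, num_steps, counter):
    """New shape: the seed is simply iter 0 of the captured loop."""
    position = seq_len - 1
    counter.launches += 1  # a single replay covers seed + recurrent iters
    token = bonus_token
    drafts = []
    for _ in range(num_steps):
        token = toy_draft_step(token, position)
        drafts.append(token)
    return drafts
```

In this toy model the two functions return the same drafts for any `num_steps >= 1`, and the folded version always records exactly one launch.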

Modifications

  • frozen_kv_mtp_worker.py: _run_assistant_seed_step only stashes
    bonus_tokens / hidden_states now; draft_forward runs the seed
    as iter 0 of the recurrent loop (topk>1 replicate/slice inline).
    Attn init no longer gated on num_steps > 1. Drops the stale
    topk == 1 shortcut in _init_draft_attn_backend.
  • frozen_kv_mtp_cuda_graph_runner.py: adds bonus_tokens to input
    buffers; stops copying topk_p/topk_index (now produced by the
    captured seed iter). Wraps _replay() in a record_function span.
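
The buffer change can be pictured with a toy runner. This is a plain-Python sketch, not the real `frozen_kv_mtp_cuda_graph_runner.py`: `ToyGraphRunner`, its list-backed buffers, and `toy_seed_iter` are hypothetical. It assumes the usual captured-graph constraint that inputs are copied in place into preallocated buffers before replay, while anything the captured region computes itself (here `topk_index`) needs no host-side copy.

```python
def toy_seed_iter(token, topk):
    """Hypothetical stand-in for the seed iter's top-k selection."""
    return [(token + k) % 100 for k in range(topk)]


class ToyGraphRunner:
    def __init__(self, max_bs, topk):
        self.topk = topk
        # Preallocated input buffer; after this change bonus_tokens is an input.
        self.bonus_tokens = [0] * max_bs
        # Output buffer, written by the captured seed iter rather than copied in.
        self.topk_index = [[0] * topk for _ in range(max_bs)]

    def _captured(self, bs):
        # Stand-in for the captured region: reads inputs, writes outputs in place.
        for i in range(bs):
            self.topk_index[i][:] = toy_seed_iter(self.bonus_tokens[i], self.topk)

    def replay(self, bonus_tokens):
        bs = len(bonus_tokens)
        # In-place copy: the buffer object itself must not be replaced.
        self.bonus_tokens[:bs] = bonus_tokens
        self._captured(bs)
        return [row[:] for row in self.topk_index[:bs]]
```

Note that `replay` takes no `topk_p` or `topk_index` argument: in this sketch those values only ever flow out of the captured region.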

No behavior change on verify / accept / KV-write paths.

Accuracy Tests

Test with #24552 for the 31B model

topk  num_draft_tokens  GSM8K score  threshold  avg_spec_accept_length  result
1     6                 0.8150       0.7750     4.4767                  PASS
3     12                0.8000       0.7750     5.0259                  PASS

Speed Tests and Profiling

Before: [image: Screenshot 2026-05-17 at 11.13.04 AM]

After: [image: Screenshot 2026-05-17 at 11.13.22 AM]

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

CI States

Latest PR Test (Base): ✅ Run #26061044934
Latest PR Test (Extra): ⚠️ Not enabled -- add run-ci-extra label to opt in.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors the Frozen-KV MTP (Multi-Token Prediction) implementation by moving the assistant seed forward pass into the captured CUDA graph within draft_forward. Key changes include adding bonus_tokens to the input buffers, updating the worker to stash seed inputs for the graph runner, and enforcing the use of the Triton attention backend. Feedback suggests improving the API by removing unused arguments from the _run_assistant_seed_step signature and optimizing performance by using torch.empty instead of torch.zeros for tensors that are immediately overwritten.

Comment on lines 352 to +357
          mm_input_embeds: Optional[torch.Tensor] = None,
          draft_input: Optional[FrozenKVMTPDraftInput] = None,
      ) -> None:
-         """Run the one-token assistant seed step against frozen target KV."""
+         """Stash seed inputs on ``batch.spec_info``; the forward runs inside
+         the captured draft graph (see ``draft_forward``'s seed iter)."""
+         del seq_lens_cpu, mm_input_embeds, draft_input

medium

The arguments seq_lens_cpu, mm_input_embeds, and draft_input are no longer used in _run_assistant_seed_step because the assistant forward pass has been moved into the captured draft graph (executed within draft_forward). While they are currently being explicitly deleted to avoid unused variable warnings, it would be cleaner to remove them from the function signature and update the callers (forward_draft_extend and forward_draft_extend_after_decode) accordingly to maintain a clean API.
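
A sketch of what that cleanup could look like, using simplified stand-ins rather than the real worker. `ToyBatch`, `ToySpecInfo`, and `stash_seed_inputs` are hypothetical names, and the real method's remaining parameters are not shown in this review excerpt, so the ones below are assumptions.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToySpecInfo:
    bonus_tokens: Optional[List[int]] = None
    hidden_states: Optional[List[float]] = None


@dataclass
class ToyBatch:
    spec_info: ToySpecInfo = field(default_factory=ToySpecInfo)


def stash_seed_inputs(
    batch: ToyBatch,
    bonus_tokens: List[int],
    hidden_states: List[float],
) -> None:
    """Stash seed inputs only; no unused parameters remain to `del`."""
    batch.spec_info.bonus_tokens = list(bonus_tokens)
    batch.spec_info.hidden_states = list(hidden_states)


# Callers then pass only what the stash step actually reads.
batch = ToyBatch()
stash_seed_inputs(batch, bonus_tokens=[3, 8], hidden_states=[0.5, 0.25])
```

The trade-off is a wider diff, since every caller has to change, in exchange for a signature that matches what the function does.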

Comment on lines +376 to +381
stashed.topk_p = torch.zeros(
(bs, self.topk), device=device, dtype=torch.float32
)
stashed.topk_index = torch.zeros(
(bs, self.topk), device=device, dtype=torch.int64
)

medium

Since topk_p and topk_index are only initialized here for compatibility with filter_batch/merge_batch and are overwritten by the captured seed iteration in draft_forward, you can use torch.empty instead of torch.zeros to avoid unnecessary zero-initialization overhead on the GPU.

Suggested change
- stashed.topk_p = torch.zeros(
-     (bs, self.topk), device=device, dtype=torch.float32
- )
- stashed.topk_index = torch.zeros(
-     (bs, self.topk), device=device, dtype=torch.int64
- )
+ stashed.topk_p = torch.empty(
+     (bs, self.topk), device=device, dtype=torch.float32
+ )
+ stashed.topk_index = torch.empty(
+     (bs, self.topk), device=device, dtype=torch.int64
+ )

@kpham-sgl

/tag-and-rerun-ci

kpham-sgl and others added 3 commits May 18, 2026 21:15
@kpham-sgl kpham-sgl force-pushed the kp/gemma4-mtp-remove-seed branch from 4d9203d to ee7c80c Compare May 18, 2026 21:18