[Spec] Add trtllm_mha support for Gemma 4 MTP draft attention backend#25545
[Spec] Add trtllm_mha support for Gemma 4 MTP draft attention backend#25545kpham-sgl wants to merge 1 commit into
trtllm_mha support for Gemma 4 MTP draft attention backend#25545Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the Frozen-KV MTP implementation by moving the assistant seed step's forward pass into the captured draft graph, necessitating updates to the CUDA graph runner and draft_forward logic. It also adds support for the trtllm_mha attention backend and includes NVTX profiling spans for the draft loop. Feedback recommends refactoring the _run_assistant_seed_step signature to remove parameters that became redundant after the logic move.
I am having trouble creating individual review comments. Click here to see my feedback.
python/sglang/srt/speculative/frozen_kv_mtp_worker.py (374)
The arguments seq_lens_cpu, mm_input_embeds, and draft_input are now unused in _run_assistant_seed_step because the forward pass logic has been moved into the captured graph in draft_forward. While using del prevents unused variable warnings, it would be cleaner to eventually refactor the function signature and its callers to remove these redundant parameters.
903a193 to
4d9203d
Compare
e38b916 to
ac4a6b1
Compare
4d9203d to
ee7c80c
Compare
ac4a6b1 to
9b97724
Compare
Motivation
Faster draft backend. Note: there are certain issues being tracked here flashinfer-ai/flashinfer#3343
Modifications
Add
trtllm_mhaas a viable backend for Gemma 4 MTP following #25006. Stacking on top of #25539Accuracy Tests
Verified with Gemma-4-31B-it on GSM8K, 5-shot, 200 examples. Server config is
target=trtllm_mha+draft=trtllm_mhaand--speculative-num-steps=3,--speculative-eagle-topk=1, and--speculative-num-draft-tokens=4Speed Tests and Profiling
Same as above
Checklist
Review and Merge Process
/tag-and-rerun-ci,/tag-run-ci-label,/rerun-failed-ciCI States
Latest PR Test (Base): ❌ Missing
run-cilabel -- add it to run CI tests.Latest PR Test (Extra): ❌ Blocked --
run-ciis required first.