
[Spec] Add trtllm_mha support for Gemma 4 MTP draft attention backend#25545

Open: kpham-sgl wants to merge 1 commit into kp/gemma4-mtp-remove-seed from kp/gemma4-mtp-trtllm_mha-backend


@kpham-sgl (Collaborator) commented May 17, 2026

Motivation

Provide a faster draft attention backend. Note: some known issues are being tracked in flashinfer-ai/flashinfer#3343.

Modifications

Add trtllm_mha as a supported draft attention backend for Gemma 4 MTP, following #25006. This PR is stacked on top of #25539.

Accuracy Tests

Verified with Gemma-4-31B-it on GSM8K (5-shot, 200 examples). Server config: target=trtllm_mha and draft=trtllm_mha, with --speculative-num-steps=3, --speculative-eagle-topk=1, and --speculative-num-draft-tokens=4.

| metric | value |
| --- | --- |
| GSM8K score | 0.800 |
| avg_spec_accept_length | 2.536 / 4 (≈ 63% of max) |
| latency | 18.66 s |
| output throughput | 1294 tok/s |
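
As a quick sanity check on the "≈ 63% of max" figure, the ratio can be recomputed from the reported numbers. This is a minimal sketch: the helper name `accept_ratio` is hypothetical and not part of sglang, and it assumes the maximum accept length equals `--speculative-num-draft-tokens` (4 in this run).

```python
def accept_ratio(avg_accept_length: float, num_draft_tokens: int) -> float:
    """Return the average accept length as a fraction of the maximum.

    Hypothetical helper for illustration only; assumes the maximum accept
    length per verification step equals the number of draft tokens.
    """
    if num_draft_tokens <= 0:
        raise ValueError("num_draft_tokens must be positive")
    return avg_accept_length / num_draft_tokens


# Values copied from the results table in this PR description.
ratio = accept_ratio(2.536, 4)
print(f"accept ratio: {ratio:.1%}")
```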

Speed Tests and Profiling

Same run as in Accuracy Tests; see the latency and output throughput figures reported there.

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

CI States

Latest PR Test (Base): ❌ Missing run-ci label -- add it to run CI tests.
Latest PR Test (Extra): ❌ Blocked -- run-ci is required first.

@kpham-sgl changed the base branch from main to kp/gemma4-mtp-remove-seed, May 17, 2026 21:57
@gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request refactors the Frozen-KV MTP implementation by moving the assistant seed step's forward pass into the captured draft graph, necessitating updates to the CUDA graph runner and draft_forward logic. It also adds support for the trtllm_mha attention backend and includes NVTX profiling spans for the draft loop. Feedback recommends refactoring the _run_assistant_seed_step signature to remove parameters that became redundant after the logic move.

I am having trouble creating individual review comments. Click here to see my feedback.

python/sglang/srt/speculative/frozen_kv_mtp_worker.py (374)

medium

The arguments seq_lens_cpu, mm_input_embeds, and draft_input are now unused in _run_assistant_seed_step because the forward pass logic has been moved into the captured graph in draft_forward. While using del prevents unused variable warnings, it would be cleaner to eventually refactor the function signature and its callers to remove these redundant parameters.
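
The trade-off described in this comment can be sketched as follows. This is an illustration only, not the code in frozen_kv_mtp_worker.py: the function names `seed_step_legacy` and `seed_step_trimmed` and the dictionary-based `batch` are hypothetical stand-ins, and only the parameter names are taken from the review comment.

```python
from typing import Any, Dict, Optional


def seed_step_legacy(
    batch: Dict[str, Any],
    seq_lens_cpu: Optional[Any] = None,
    mm_input_embeds: Optional[Any] = None,
    draft_input: Optional[Any] = None,
) -> Dict[str, Any]:
    """Hypothetical version that keeps the old signature.

    The three extra arguments are accepted so existing callers keep working,
    then discarded explicitly so linters do not flag them as unused.
    """
    del seq_lens_cpu, mm_input_embeds, draft_input
    return {"num_requests": len(batch["req_ids"])}


def seed_step_trimmed(batch: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical refactored version that takes only what it uses."""
    return {"num_requests": len(batch["req_ids"])}


batch = {"req_ids": [0, 1, 2]}
# Both variants do the same work; only the call sites differ.
legacy_result = seed_step_legacy(batch, seq_lens_cpu=[5, 7, 9])
trimmed_result = seed_step_trimmed(batch)
print(legacy_result == trimmed_result)
```

Keeping the old signature avoids touching callers in the same change, while the trimmed form makes it obvious which inputs the step depends on.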

@kpham-sgl force-pushed the kp/gemma4-mtp-remove-seed branch from 903a193 to 4d9203d, May 17, 2026 22:14
@kpham-sgl force-pushed the kp/gemma4-mtp-trtllm_mha-backend branch from e38b916 to ac4a6b1, May 17, 2026 22:14
@kpham-sgl force-pushed the kp/gemma4-mtp-remove-seed branch from 4d9203d to ee7c80c, May 18, 2026 21:18
@kpham-sgl force-pushed the kp/gemma4-mtp-trtllm_mha-backend branch from ac4a6b1 to 9b97724, May 18, 2026 21:18