Skip to content

Enable trtllm_mha as gemma4 default attn backend.#25006

Merged
kpham-sgl merged 2 commits into
sgl-project:mainfrom
wenscarl:Gemma4_trtllm_attn
May 17, 2026
Merged

Enable trtllm_mha as gemma4 default attn backend.#25006
kpham-sgl merged 2 commits into
sgl-project:mainfrom
wenscarl:Gemma4_trtllm_attn

Conversation

@wenscarl
Copy link
Copy Markdown
Collaborator

@wenscarl wenscarl commented May 11, 2026

Summary

Enable trtllm_mha as the default attention backend for Gemma4 on SM100.

When --attention-backend is not specified for Gemma4ForConditionalGeneration, SGLang now selects:

  • trtllm_mha on SM100
  • triton otherwise

This keeps the existing non-SM100 behavior unchanged while enabling the Blackwell-optimized MHA backend for Gemma4 by default.

Benchmark

Same server flags otherwise, comparing triton vs trtllm_mha.

Server

sglang serve --model-path google/gemma-4-31B-it \
    --reasoning-parser gemma4 \
    --tool-call-parser gemma4 \
    --mem-fraction-static 0.9 \
    --host 0.0.0.0 --port 30000 --tp-size 4

Benchmark Commands

Latency, text:

python3 -m sglang.bench_serving --backend sglang \
  --host 0.0.0.0 --port 30000 \
  --dataset-name random --num-prompts 10 --max-concurrency 1

Latency, image:

python3 -m sglang.bench_serving --backend sglang-oai-chat \
  --host 0.0.0.0 --port 30000 \
  --dataset-name image --image-count 2 --image-resolution 720p \
  --random-input-len 128 --random-output-len 1024 \
  --num-prompts 10 --max-concurrency 1

Throughput, text:

python3 -m sglang.bench_serving --backend sglang \
  --host 0.0.0.0 --port 30000 \
  --dataset-name random --num-prompts 1000 --max-concurrency 100

Throughput, image:

python3 -m sglang.bench_serving --backend sglang-oai-chat \
  --host 0.0.0.0 --port 30000 \
  --dataset-name image --image-count 2 --image-resolution 720p \
  --random-input-len 128 --random-output-len 1024 \
  --num-prompts 1000 --max-concurrency 100

Note: the throughput-image benchmark had 999/1000 successful requests with trtllm_mha; one request was silently dropped with no client-side error logged, possibly due to a request abort.

Latency

concurrency=1, 10 prompts

Metric triton trtllm_mha Delta
Text Duration (s) 37.29 32.75 -12.2%
Text Output tok/s 113.2 128.9 +13.9%
Text Mean TTFT (ms) 72.05 67.04 -7.0%
Text Mean TPOT (ms) 8.55 7.60 -11.1%
Text Median ITL (ms) 8.84 7.63 -13.7%
Image Duration (s) 38.45 33.76 -12.2%
Image Output tok/s 109.8 125.0 +13.9%
Image Mean TTFT (ms) 182.55 179.90 -1.5%
Image Mean TPOT (ms) 8.62 7.57 -12.2%

Throughput

concurrency=100, 1000 prompts

Metric triton trtllm_mha Delta
Text Duration (s) 144.45 117.92 -18.4%
Text Req/s 6.92 8.48 +22.5%
Text Output tok/s 3536.6 4332.3 +22.5%
Text Total tok/s 7086.9 8681.5 +22.5%
Text Mean E2E (ms) 13794 11221 -18.7%
Text Mean TPOT (ms) 27.02 21.88 -19.0%
Text P99 TPOT (ms) 37.47 29.89 -20.2%
Image Successful 1000 999 -1 req
Image Duration (s) 249.45 228.88 -8.2%
Image Req/s 4.01 4.36 +8.7%
Image Output tok/s 2048.0 2231.6 +9.0%
Image Mean E2E (ms) 24286 22351 -8.0%
Image Mean TPOT (ms) 46.02 42.10 -8.5%
Image Mean TTFT (ms) 1270.5 1317.7 +3.7%

CI States

Latest PR Test (Base): Run #25998857394
Latest PR Test (Extra): ⚠️ Not enabled — add run-ci-extra label to opt in.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@wenscarl wenscarl marked this pull request as ready for review May 11, 2026 19:45
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@wenscarl
Copy link
Copy Markdown
Collaborator Author

cc. @nvpohanh

@kpham-sgl kpham-sgl self-assigned this May 11, 2026
Copy link
Copy Markdown
Collaborator

@kpham-sgl kpham-sgl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the contribution! Glad to know flashinfer support headdim=512 now.

if k is None and v is None:
pool = forward_batch.token_to_kv_pool
cache_loc = forward_batch.out_cache_loc
if isinstance(pool, SWAKVPool) and pool.layers_mapping[layer.layer_id][1]:
cache_loc = pool.translate_loc_from_full_to_swa(cache_loc)
k_buffer, v_buffer = pool.get_kv_buffer(layer.layer_id)
k = k_buffer[cache_loc]
v = v_buffer[cache_loc]

For E2B and E4B variants with KV cache reuse we need this extra KV cache retrieval path. If you have time can you figure out how to add similar path to the flashinfer backend? If not can you guard this change to the bigger model (31B and 26B-A4B) only?

@pyc96
Copy link
Copy Markdown
Collaborator

pyc96 commented May 11, 2026

Curious does current SGL flahsinfer vesion support it? And does it work with NVFP4 ckpt?

@kpham-sgl
Copy link
Copy Markdown
Collaborator

Curious does current SGL flahsinfer vesion support it? And does it work with NVFP4 ckpt?

@wenscarl what flashinfer version was the trtllm_mha support for headdim=512 added in?

@wenscarl
Copy link
Copy Markdown
Collaborator Author

@wenscarl what flashinfer version was the trtllm_mha support for headdim=512 added in?

v0.6.10.post1

@wenscarl
Copy link
Copy Markdown
Collaborator Author

For E2B and E4B variants with KV cache reuse we need this extra KV cache retrieval path. If you have time can you figure out how to add similar path to the flashinfer backend? If not can you guard this change to the bigger model (31B and 26B-A4B) only?

trtllm_mha actually works for E2B/E4B as-is — verified with a run on E2B and it passes.
trtllm_mha is structurally different — the trtllm-gen kernel reads K/V directly from the paged KV cache via
page_table + get_kv_buffer(layer.layer_id), never as a separate per-token window. And the KV-share redirection is
already done in the model at gemma4_causal.py:325-327:

self.attn = RadixAttention(
    ...
    layer_id=(self.kv_shared_layer_index if self.is_kv_shared_layer else self.layer_id),
    ...
)

@wenscarl wenscarl requested a review from kpham-sgl May 11, 2026 23:02
@kpham-sgl
Copy link
Copy Markdown
Collaborator

For E2B and E4B variants with KV cache reuse we need this extra KV cache retrieval path. If you have time can you figure out how to add similar path to the flashinfer backend? If not can you guard this change to the bigger model (31B and 26B-A4B) only?

trtllm_mha actually works for E2B/E4B as-is — verified with a run on E2B and it passes. trtllm_mha is structurally different — the trtllm-gen kernel reads K/V directly from the paged KV cache via page_table + get_kv_buffer(layer.layer_id), never as a separate per-token window. And the KV-share redirection is already done in the model at gemma4_causal.py:325-327:

self.attn = RadixAttention(
    ...
    layer_id=(self.kv_shared_layer_index if self.is_kv_shared_layer else self.layer_id),
    ...
)

ohh right. Thanks for pointing this out!

@kpham-sgl
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@kpham-sgl
Copy link
Copy Markdown
Collaborator

Will merge once we upgrade flashinfer to v0.6.10.post1

@kpham-sgl
Copy link
Copy Markdown
Collaborator

/rerun-failed-ci

@kpham-sgl
Copy link
Copy Markdown
Collaborator

Accuracy checks in #25461 (comment)

@kpham-sgl kpham-sgl merged commit c67b287 into sgl-project:main May 17, 2026
113 of 124 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants