Refactor and optimize Gemma4 attention QKV proj and norm #25461

Open: pyc96 wants to merge 4 commits into sgl-project:main from pyc96:gemma-op

Conversation

@pyc96
Collaborator

@pyc96 pyc96 commented May 16, 2026

Motivation

Modifications

  • Fuse the kernel for layers where k=v.
  • The largest gain comes from the trtllm_mha kernel. This PR also tries to validate trtllm_mha.
  • Note that the trtllm_mha kernel only supports 314; a larger spec setup will cause a missing cubin error.

Accuracy Tests

python -m sglang.launch_server \
    --host 0.0.0.0 \
    --port 18000 \
    --model-path google/gemma-4-31B-it \
    --tp 1 \
    --enable-metrics \
    --decode-log-interval 1 \
    --enable-cache-report \
    --model-loader-extra-config '{"enable_multithread_load":true,"num_threads":64}' \
    --revision main \
    --served-model-name google/gemma-4-31B-it \
    --context-length 65536 \
    --mem-fraction-static 0.85 \
    --chunked-prefill-size 8192 \
    --max-prefill-tokens 8192 \
    --cuda-graph-max-bs 32 \
    --max-running-requests 32 \
    --speculative-algorithm NEXTN \
    --speculative-draft-model-path google/gemma-4-31B-it-assistant \
    --speculative-num-steps 3 \
    --speculative-num-draft-tokens 4 \
    --speculative-eagle-topk 1 \
    [--attention-backend=trtllm_mha]

MMLU - 31B

  • Baseline before this PR: Average accuracy: 0.709
  • This PR with triton attn backend: Average accuracy: 0.709
  • This PR with trtllm_mha: Average accuracy: 0.711

MMLU - E4B

  • Before: Average accuracy: 0.588
  • After: Average accuracy: 0.587

Speed Tests and Profiling

============ Serving Benchmark Result ============
Backend:                                 vllm-chat
Traffic request rate:                    1.0
Max request concurrency:                 256
Successful requests:                     64
Benchmark duration (s):                  62.15
Total input tokens:                      752156
Total input text tokens:                 752156
Total generated tokens:                  51200
Total generated tokens (retokenized):    51197
Request throughput (req/s):              1.03
Input token throughput (tok/s):          12101.35
Output token throughput (tok/s):         823.75
Peak output token throughput (tok/s):    656.00
Peak concurrent requests:                20
Total token throughput (tok/s):          12925.10
Concurrency:                             9.42
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   9148.54
Median E2E Latency (ms):                 9541.21
P90 E2E Latency (ms):                    11752.22
P99 E2E Latency (ms):                    12821.44
---------------Time to First Token----------------
Mean TTFT (ms):                          835.47
Median TTFT (ms):                        498.64
P99 TTFT (ms):                           3355.32
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.40
Median TPOT (ms):                        10.53
P99 TPOT (ms):                           15.11
---------------Inter-Token Latency----------------
Mean ITL (ms):                           36.81
Median ITL (ms):                         22.28
P95 ITL (ms):                            27.47
P99 ITL (ms):                            651.13
Max ITL (ms):                            3059.00
==================================================

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

CI States

Latest PR Test (Base): ❌ Run #26273335941
Latest PR Test (Extra): ❌ Run #26273335837

@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

Comment thread python/sglang/srt/models/gemma4_causal.py
Comment thread python/sglang/srt/layers/gemma4_fused_ops.py Outdated

@pyc96
Collaborator Author

pyc96 commented May 18, 2026

/tag-and-rerun-ci

pyc96 changed the title from "[WIP] Optimize Gemma 4 Prefill" to "Refactor and optimize Gemma4 attention QKV proj and norm" on May 18, 2026
Collaborator

@kpham-sgl kpham-sgl left a comment

Some tests to run locally

  • Nightly CI for Gemma 4 w/o MTP in test_vlms_mmmu_eval.py
  • #24552 MTP CI and nightly tests (merging soon)

Comment on lines +273 to +274
use_k_eq_v: bool = False,
kv_shared: bool = False,
Collaborator

Let's represent this with a single enum, PROJ_AND_NORM_MODE or something similar

Collaborator

Tentatively three modes: Q-only, QK-only, and QKV-full. Then add a comment that Q-only corresponds to KV-sharing layers, QK-only to k_eq_v layers, etc.
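
One possible shape for the suggested enum is sketched below. The names `ProjAndNormMode` and `resolve_mode` are hypothetical and do not come from the PR; the sketch also assumes that `kv_shared` takes precedence when both flags are set, which the review comments do not state.

```python
from enum import Enum


class ProjAndNormMode(Enum):
    """Hypothetical replacement for the use_k_eq_v / kv_shared booleans."""

    Q_ONLY = "q_only"      # KV-sharing layers: only Q is projected and normalised
    QK_ONLY = "qk_only"    # k_eq_v layers: V is derived from the same input as K
    QKV_FULL = "qkv_full"  # regular layers: Q, K and V are all projected


def resolve_mode(use_k_eq_v: bool = False, kv_shared: bool = False) -> ProjAndNormMode:
    """Map the two legacy booleans onto a single mode.

    Assumption: kv_shared wins if both flags are True.
    """
    if kv_shared:
        return ProjAndNormMode.Q_ONLY
    if use_k_eq_v:
        return ProjAndNormMode.QK_ONLY
    return ProjAndNormMode.QKV_FULL
```

A single enum removes the ambiguous fourth combination (both flags set) that two independent booleans allow.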

Comment thread python/sglang/srt/layers/gemma4_fused_ops.py Outdated
Comment thread python/sglang/srt/models/gemma4_causal.py Outdated
Comment thread python/sglang/srt/models/gemma4_causal.py Outdated
pyc96 added 4 commits May 22, 2026 06:26
For attention_k_eq_v full-attention layers (10/60 in Gemma-4-31B), the
forward used to run two independent per-token RMSNorm launches:

    q = self.q_norm(q)                      # one Triton launch
    k, v = fused_kv_norm(k_raw, k_weight)   # second Triton launch

Both kernels are bandwidth-bound and visit the same set of tokens, so the
launch overhead and memory round-trip are pure waste at small decode batch
sizes (decode batch <= 32 here, x10 layers x 5 spec steps per group =
~50 saved launches per group).

This change adds a single combined kernel gemma_q_keqv_rmsnorm that:
  * normalises Q in-place head-by-head (same math as gemma_qkv_rmsnorm),
  * computes one shared rrms per (token, head) for the K=V input and
    writes both K (= norm * k_weight) and V (= norm) outputs.

The Gemma4Attention.use_k_eq_v branch now calls this fused op on the CUDA
fast path; the original q_norm + fused_kv_norm sequence is kept as a
fallback for non-CUDA / non-canonical (scale_shift != 0 or v_norm.with_scale)
configurations and for the is_kv_shared_layer path which only normalises Q.
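
The arithmetic that the combined kernel is described as performing can be written out as a small pure-Python reference. This is a sketch of the math only, not the Triton kernel: the function names are hypothetical, `eps=1e-6` is an assumed value, and Q is assumed to be scaled by a `q_weight` vector after normalisation.

```python
import math


def rms_normalize(head, eps=1e-6):
    """Scale one head vector by its reciprocal RMS (rrms)."""
    rrms = 1.0 / math.sqrt(sum(x * x for x in head) / len(head) + eps)
    return [x * rrms for x in head]


def q_keqv_rmsnorm_reference(q_heads, kv_heads, q_weight, k_weight, eps=1e-6):
    """Reference math for one token: normalise Q, then emit K and V from one input.

    q_heads / kv_heads are lists of per-head vectors for a single token.
    """
    # Q: per-head RMSNorm followed by the (assumed) q_norm weight.
    q_out = [
        [n * w for n, w in zip(rms_normalize(head, eps), q_weight)]
        for head in q_heads
    ]
    k_out, v_out = [], []
    for head in kv_heads:
        normed = rms_normalize(head, eps)  # one shared rrms per (token, head)
        k_out.append([n * w for n, w in zip(normed, k_weight)])  # K = norm * k_weight
        v_out.append(normed)                                     # V = norm (no scale)
    return q_out, k_out, v_out
```

Because K and V share the same normalised vector, the K=V input is read and reduced only once per head, which is where the saved launch and memory round-trip come from.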

Measured impact (gemma-4-31B-it + NEXTN spec, gsp 5k+5k+800 @ qps=1, 1xB200):

  Metric              | baseline (main+fix) | fused_kv only | + q_keqv
  --------------------+---------------------+---------------+---------
  Total tok/s         |             5160.6  |       5278.4  |  5383.0
  Output tok/s        |              328.9  |        336.4  |   343.1
  Median TPOT (ms)    |              44.87  |        42.67  |   42.26
  Mean TPOT (ms)      |              45.95  |        46.18  |   43.65

Numerical equivalence verified vs. the reference q_norm + fused_kv_norm
sequence across head_dim in {128, 256, 512}, num_q_heads/kv_heads
{32/16, 16/8}, and M up to 1024 (max bf16 abs diff <= 2e-2, expected
bf16 rounding noise; relative error well within 1%).
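
The 2e-2 bound can be sanity-checked against bf16 spacing. bf16 keeps 8 significant bits, so values in [4, 8) are spaced 2^-5 apart and rounding alone can move a value by up to 2^-6 ≈ 0.0156. The sketch below emulates bf16 rounding in plain Python; it assumes finite inputs of moderate magnitude and is not the test used in the PR.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even); finite inputs only."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                  # rounding bias, ties to even
    bits &= 0xFFFF0000                                   # keep the top 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def max_abs_diff(a, b) -> float:
    """Largest element-wise absolute difference between two sequences."""
    return max(abs(x - y) for x, y in zip(a, b))


# Sweep values with magnitude below 8 and measure the pure rounding error.
values = [i * 0.0137 for i in range(-500, 501)]
rounded = [to_bf16(v) for v in values]
worst = max_abs_diff(values, rounded)
print(f"worst bf16 rounding error: {worst:.4f}")
```

For outputs of magnitude below 8, a max absolute difference at or under 2e-2 is therefore consistent with bf16 rounding alone.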
@pyc96
Collaborator Author

pyc96 commented May 22, 2026

This PR now needs #26026 merged first

Test 1 — MMMU eval (Gemma 4 only)

test/registered/eval/test_vlms_mmmu_eval.py filtered to the 3 Gemma 4 entries, run on CUDA_VISIBLE_DEVICES=2,3.

| Model | Score (thr) | Latency (thr) | MoE runner | Status |
| --- | --- | --- | --- | --- |
| google/gemma-4-E4B-it | 0.2802 (≥ 0.26) | 6.77 s (≤ 15.0) | n/a (dense) | |
| google/gemma-4-26B-A4B-it (TP=2) | 0.2802 (≥ 0.27) | 7.53 s (≤ 22.3) | triton (auto, BF16 fix verified) | |
| google/gemma-4-31B-it (TP=2) | 0.2901 (≥ 0.28) | 7.45 s (≤ 25.5) | n/a (dense) | |

Result: Ran 1 test in 329.649s — OK

Test 2 — Frozen-KV E4B MTP smoke test (from PR #24552)

test/registered/spec/test_frozen_kv_mtp.py run on CUDA_VISIBLE_DEVICES=0.
Model: google/gemma-4-E4B-it + google/gemma-4-E4B-it-assistant, GSM8K-200.

| Variant | Score (thr) | avg_spec_accept_length | Status |
| --- | --- | --- | --- |
| topk=1 | 0.7250 (≥ 0.65) | 3.28 | |
| topk=3 | 0.7350 (≥ 0.65) | 3.80 | |

Result: Ran 1 test in 113.238s — OK
