Refactor and optimize Gemma4 attention QKV proj and norm#25461
Conversation
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
|
/tag-and-rerun-ci |
| use_k_eq_v: bool = False, | ||
| kv_shared: bool = False, |
There was a problem hiding this comment.
Let's represent this with a single enum, PROJ_AND_NORM_MODE or something similar
There was a problem hiding this comment.
Tentatively 3 modes Q-only, QK-only, and QKV-full and then add comment that Q-only corresponds to KV sharing layers, QK-only to k_eq_v layers etc
For attention_k_eq_v full-attention layers (10/60 in Gemma-4-31B), the
forward used to run two independent per-token RMSNorm launches:
q = self.q_norm(q) # one Triton launch
k, v = fused_kv_norm(k_raw, k_weight) # second Triton launch
Both kernels are bandwidth-bound and visit the same set of tokens, so the
launch overhead and memory round-trip are pure waste at small decode batch
sizes (decode batch <= 32 here, x10 layers x 5 spec steps per group =
~50 saved launches per group).
This change adds a single combined kernel gemma_q_keqv_rmsnorm that:
* normalises Q in-place head-by-head (same math as gemma_qkv_rmsnorm),
* computes one shared rrms per (token, head) for the K=V input and
writes both K (= norm * k_weight) and V (= norm) outputs.
The Gemma4Attention.use_k_eq_v branch now calls this fused op on the CUDA
fast path; the original q_norm + fused_kv_norm sequence is kept as a
fallback for non-CUDA / non-canonical (scale_shift != 0 or v_norm.with_scale)
configurations and for the is_kv_shared_layer path which only normalises Q.
Measured impact (gemma-4-31B-it + NEXTN spec, gsp 5k+5k+800 @ qps=1, 1xB200):
Metric | baseline (main+fix) | fused_kv only | + q_keqv
-------------------+---------------------+---------------+---------
Total tok/s | 5160.6 | 5278.4 | 5383.0
Output tok/s | 328.9 | 336.4 | 343.1
Median TPOT (ms) | 44.87 | 42.67 | 42.26
Mean TPOT (ms) | 45.95 | 46.18 | 43.65
Numerical equivalence verified vs. the reference q_norm + fused_kv_norm
sequence across head_dim in {128, 256, 512}, num_q_heads/kv_heads
{32/16, 16/8}, and M up to 1024 (max bf16 abs diff <= 2e-2, expected
bf16 rounding noise; relative error well within 1%).
|
This PR now needs #26026 merged first Test 1 — MMMU eval (Gemma 4 only)
Result: Test 2 — Frozen-KV E4B MTP smoke test (from PR #24552)
Result: |
Motivation
Modifications
Accuracy Tests
MMLU - 31B
Average accuracy: 0.709Average accuracy: 0.709Average accuracy: 0.711MMLU - E4B
Speed Tests and Profiling
Checklist
Review and Merge Process
/tag-and-rerun-ci,/tag-run-ci-label,/rerun-failed-ciCI States
Latest PR Test (Base): ❌ Run #26273335941
Latest PR Test (Extra): ❌ Run #26273335837