Optimize Gemma4 H200 MoE and extend attention#26588
Open
BBuf wants to merge 9 commits into
Summary
- H200 MoE Triton configs for `E=128,N=704` normal and down projections.
- Hopper extend-attention block-size tuning for `Lq=129..256` to reduce TTFT/TPOT on Gemma4 prefill-heavy serving.
- ~~Add a Gemma4-specific Triton routing kernel that fuses top-k selection, top-k softmax, and per-expert scale into `gemma4_topk_softmax_scale`.~~ Reverted (commit `27cb94c45`) due to a BF16 precision regression on MoE routing; see "Reverted optimizations" below.
- ~~Replace `Gemma4Router.norm(x) * fused_scale` with a single fused `rmsnorm(x, fused_scale, eps)` call on the CUDA fast path.~~ Reverted (commit `0d3c9c694`) for the same reason.

Performance
Re-measured on `ion8-h200`, `sglang_bbuf`: single NVIDIA H200, TP=1, `google/gemma-4-26B-A4B-it`, BF16, `--attention-backend triton`, random 1024/256, request rate 8, max concurrency 64, 80 prompts. Baseline = latest `origin/main` (`376635c1e`); patched = PR head (`27cb94c45`, with c2/c3 reverted).

(Table: random 1024/256, 80 prompts; baseline `376635c1e` vs. patched `27cb94c45`; values not shown.)

Full benchmark details:

(Table: baseline `376635c1e` vs. patched `27cb94c45`; values not shown.)

The remaining speedup comes from the H200 MoE Triton configs, the Hopper extend-attention block-size tuning, and the small-batch QKV RMSNorm; the heavier routing-fusion optimizations were dropped to preserve accuracy on the MTP path.
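For readers unfamiliar with tuned MoE Triton configs, the sketch below shows one plausible way such a file is parsed and consumed. It assumes the common layout in which each key is a token-batch size and each value holds Triton launch parameters, and that the entry nearest to the actual batch size is chosen at run time. The JSON values, `load_moe_config`, and `pick_config` are hypothetical and are not taken from this PR.

```python
import json

# Hypothetical file contents: keys are token-batch sizes (M), values are
# Triton launch parameters. All numbers here are invented for illustration.
EXAMPLE_CONFIG_JSON = """
{
  "1":    {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128,
           "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
  "64":   {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64,
           "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64,
           "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4}
}
"""

REQUIRED_KEYS = {
    "BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K",
    "GROUP_SIZE_M", "num_warps", "num_stages",
}


def load_moe_config(text: str) -> dict[int, dict[str, int]]:
    """Parse a config file and check that every entry has the expected keys."""
    raw = json.loads(text)
    configs = {int(batch): params for batch, params in raw.items()}
    for batch, params in configs.items():
        missing = REQUIRED_KEYS - params.keys()
        if missing:
            raise ValueError(f"batch size {batch}: missing {sorted(missing)}")
    return configs


def pick_config(configs: dict[int, dict[str, int]], num_tokens: int) -> dict[str, int]:
    """Return the entry whose batch-size key is closest to num_tokens."""
    nearest = min(configs, key=lambda batch: abs(batch - num_tokens))
    return configs[nearest]


configs = load_moe_config(EXAMPLE_CONFIG_JSON)
print(pick_config(configs, 80))
```

A parse check of this kind only confirms that a file is well formed; whether the chosen block sizes are actually fast still has to come from benchmarking.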
Accuracy
Validated on `ion8-h200` via the registered Gemma4 MTP CI test (`test_gemma4_mtp_26b_a4b_extra::test_gsm8k_mtp`, TP=2, `--enable-deterministic-inference`, NEXTN spec decode, 200 GSM8K examples, 5-shot). With the two commits reverted, the GSM8K MTP score is back to baseline:

| Configuration | GSM8K score | `avg_accept_length` |
| --- | --- | --- |
| Baseline (`376635c1e`, all reverts = main) | … | … |
| Patched (`27cb94c45`, c1+c4 kept, c2+c3 reverted) | … | … |
| PR head before revert (`10ab189e3`, all 4 commits) | 0.360 | 4.475 |

Standalone GSM8K (TP=1, no MTP, 8-shot, 200 questions, no deterministic mode) shows the typical run-to-run variance (a swing of about 5 pp) but no systematic regression: patched and baseline land within noise of each other (patched: 0.41 / 0.495 across two runs; baseline: 0.465 / 0.445).
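As a rough sanity check on the "within noise" claim for the standalone runs, the sketch below estimates the sampling error of an accuracy measured on 200 questions. It assumes each question is an independent pass/fail trial (a binomial model), which is a simplification; `binomial_std_error` is a helper introduced here for illustration.

```python
import math


def binomial_std_error(accuracy: float, n_questions: int) -> float:
    """Standard error of an accuracy estimated from n independent questions."""
    return math.sqrt(accuracy * (1.0 - accuracy) / n_questions)


# Standalone GSM8K scores quoted above (200 questions per run).
patched = [0.41, 0.495]
baseline = [0.465, 0.445]
n = 200

# Use the mean of all four runs as the working accuracy estimate.
mean_acc = sum(patched + baseline) / 4
se = binomial_std_error(mean_acc, n)
# The difference of two independent runs has sqrt(2) times the per-run error.
se_diff = math.sqrt(2.0) * se

print(f"mean accuracy:                   {mean_acc:.3f}")
print(f"per-run standard error:          {se:.3f}")
print(f"std. error of a run difference:  {se_diff:.3f}")
print(f"patched spread:                  {max(patched) - min(patched):.3f}")
print(f"baseline spread:                 {max(baseline) - min(baseline):.3f}")
```

Under this model the per-run standard error is about 3.5 pp, so differences of several points between single 200-question runs are expected from sampling alone.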
Reverted optimizations
Bisecting `test_gemma4_mtp_26b_a4b_extra` on `ion8-h200` (TP=2, deterministic) covered the following configurations (per-configuration scores not shown):

- `fabb5b5ee` (QKV RMSNorm small batch)
- `d72d246a3` (Router RMSNorm fuse)
- `03826cdd9` (MoE topk fuse)
- `3398cb7af` (extend-attn + MoE configs)
- `d72d246a3` + `03826cdd9` together

Both reverted commits modify the Gemma4 MoE routing path in `gemma4_causal.py`:

- `d72d246a3`: `rmsnorm(x, fused_scale, eps)` fast path in `Gemma4Router.forward`. The math is equivalent to `(self.norm(x) with weight=1) * fused_scale`, but the BF16 accumulation order in the fused kernel diverges from the two-step path when `fused_scale ≈ hidden_size**-0.5 ≈ 0.022`.
- `03826cdd9`: `_gemma4_topk_softmax_scale_kernel`, which does topk + stable softmax + per-expert scale in one pass. It is equivalent to `softmax(topk_logits) * scale[topk_ids]` for `topk ≤ 8`, but its FP32-internal `exp(top_logit - top1) / sum_top_exp` ordering does not match `torch.nn.functional.softmax` bit-for-bit.

Either commit alone causes a GSM8K drop of roughly 0.02 to 0.035; together, the routing logits drift far enough to flip expert selections and widen the MTP draft/target distribution gap. The MTP `avg_accept_length` still looks healthy at 4.48, which masks the underlying quality regression.

To recover these speedups, the kernels need to be made bit-equivalent to the eager paths. For example, accept `weight=ones` in the fused router RMSNorm and apply `fused_scale` as a separate elementwise op, and have `_gemma4_topk_softmax_scale_kernel` mirror PyTorch's exact softmax floating-point order (compute max, subtract, exp, sum, divide, all in FP32 with the same reduction tree).

Validation
- `ion8-h200`, container `sglang_bbuf`.
- H200 MoE Triton configs (`E=128,N=704`) parse successfully.
- Extend attention: `Lq=128,192,256,288`.

CI States
Latest PR Test (Base): ❌ Run #26700275243
Latest PR Test (Extra): ✅ Run #26700275191
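The rounding-order issue described under "Reverted optimizations" for the router RMSNorm can be illustrated without a GPU. The sketch below is not the sglang kernel: it emulates bfloat16 rounding in pure Python and compares a two-step path (round the normalized value to BF16, then multiply by the scale and round again) with a fused path (multiply at higher precision, round once). `HIDDEN_SIZE = 2048` is an assumed value chosen so that `hidden_size**-0.5` is close to 0.022, and `to_bf16`, `two_step`, and `fused` are names introduced here.

```python
import struct

HIDDEN_SIZE = 2048  # assumption: gives hidden_size**-0.5 of about 0.022


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a bias so that truncating the low 16 bits rounds to nearest-even.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def two_step(x_normed: float, scale: float) -> float:
    """Eager-style path: the norm output is stored in BF16, then scaled."""
    return to_bf16(to_bf16(x_normed) * scale)


def fused(x_normed: float, scale: float) -> float:
    """Fused-style path: scale at higher precision, round to BF16 once."""
    return to_bf16(x_normed * scale)


scale = to_bf16(HIDDEN_SIZE ** -0.5)

mismatches = 0
max_rel_diff = 0.0
samples = [0.001 * i for i in range(1, 4001)]  # stand-ins for normalized values
for x in samples:
    a, b = two_step(x, scale), fused(x, scale)
    if a != b:
        mismatches += 1
        max_rel_diff = max(max_rel_diff, abs(a - b) / abs(b))

print(f"scale (bf16): {scale}")
print(f"mismatching outputs: {mismatches} of {len(samples)}")
print(f"largest relative difference: {max_rel_diff:.4f}")
```

Each individual difference is at most a few BF16 units in the last place, but router logits feed a top-k selection, so small differences near a tie can change which experts are chosen.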