
[bug fix] Fix 3 issues when using Gemma4 MTP #26026

Merged
kpham-sgl merged 4 commits into sgl-project:main from pyc96:pyc/fix/gemma4-assistant-mtp-regressions
May 23, 2026

Conversation

@pyc96
Collaborator

@pyc96 pyc96 commented May 22, 2026

Modifications

  • pp_group is missing in the MTP class because the class skips Gemma4CausalLM.__init__:
     File "transformers/modeling_utils.py", line 1395, in post_init
         self.init_weights()
     File "transformers/modeling_utils.py", line 3152, in init_weights
         self.tie_weights(recompute_mapping=False)
     File "gemma4_causal.py", line 1223, in tie_weights
         if self.pp_group.world_size > 1:
     File "torch/nn/modules/module.py", line 1968, in __getattr__
         raise AttributeError(...)
     AttributeError: 'Gemma4AssistantForCausalLM' object has no attribute 'pp_group'
  • num_experts can be None or missing for dense subclasses such as the MTP model:
     File "gemma4_mtp.py", line 380, in load_weights
         result = super().load_weights(remap_assistant_weights())
     File "gemma4_causal.py", line 1323, in load_weights
         per_expert_params_mapping = FusedMoE.make_expert_params_mapping(
     File "fused_moe_triton/layer.py", line 1140, in make_expert_params_mapping
         for expert_id in range(num_experts)
                           ^^^^^^^^^^^^^^^^^^
     TypeError: 'NoneType' object cannot be interpreted as an integer
  • The current version of flashinfer doesn't support bf16 checkpoints, so the automatic selection of the MoE runner backend should not pick flashinfer when running in bf16.
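
For readers who want to see the first two failure modes in isolation, here is a minimal sketch in plain Python. It does not use SGLang or PyTorch, and every name in it (`CausalLM`, `BrokenAssistant`, `FixedAssistant`, `expert_ids`) is a hypothetical stand-in rather than the code touched by this PR. It only illustrates the general pattern: a subclass that bypasses the parent constructor loses attributes the parent methods rely on, and `range(None)` raises `TypeError` unless the expert count is guarded. The actual diff may resolve both issues differently.

```python
class CausalLM:
    """Hypothetical stand-in for a parent model class."""

    def __init__(self, num_experts=None):
        self.pp_group = "pp-group-handle"  # attribute that parent methods rely on
        self.num_experts = num_experts     # None for dense (non-MoE) models

    def tie_weights(self):
        # Fails with AttributeError if __init__ never ran.
        return self.pp_group


class BrokenAssistant(CausalLM):
    def __init__(self):
        # Bypasses CausalLM.__init__, so pp_group and num_experts are never set.
        pass


class FixedAssistant(CausalLM):
    def __init__(self):
        # One possible remedy: make sure the parent attributes exist.
        super().__init__(num_experts=None)


def expert_ids(model):
    """Return the expert ids to map, or an empty list for dense models."""
    n = getattr(model, "num_experts", None)  # tolerate a missing attribute
    return list(range(n)) if n else []       # tolerate None or 0


if __name__ == "__main__":
    try:
        BrokenAssistant().tie_weights()
    except AttributeError as exc:
        print("broken:", exc)
    print("fixed:", FixedAssistant().tie_weights())
    print("dense expert ids:", expert_ids(FixedAssistant()))
```

The guard in `expert_ids` treats a missing attribute, `None`, and `0` the same way, which is a design choice of this sketch: a dense model simply contributes no per-expert mappings.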

Tests

Verified by running #24552 on top of this PR. Before this PR, the server crashed during init.

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

CI States

Latest PR Test (Base): ✅ Run #26266130108
Latest PR Test (Extra): ❌ Run #26266130011

@pyc96 pyc96 marked this pull request as ready for review May 22, 2026 00:33
@pyc96 pyc96 requested a review from kpham-sgl as a code owner May 22, 2026 00:33
@pyc96
Collaborator Author

pyc96 commented May 22, 2026

/tag-and-rerun-ci

@pyc96 pyc96 changed the title from "[bug fix] Fix 2 issues when using Gemma4 MTP" to "[bug fix] Fix 3 issues when using Gemma4 MTP" May 22, 2026
@kpham-sgl kpham-sgl self-assigned this May 22, 2026
Collaborator

@kpham-sgl kpham-sgl left a comment

LGTM, thanks for the bug fix

@kpham-sgl kpham-sgl merged commit 89ff2bc into sgl-project:main May 23, 2026
118 of 130 checks passed
Shunkangz pushed a commit to Shunkangz/sglang that referenced this pull request May 27, 2026
pyc96 added a commit to pyc96/sglang that referenced this pull request May 27, 2026
Gemma4MoE.routing_function previously ran five per-layer steps as separate GPU kernels:

  torch.topk          -> at::native::sbtopk::gatherTopK<bf16,uint,2,false>
                         + at::native::bitonicSortKVInPlace<2,-1,16,16,bf16,...>
  softmax             -> at::native::cunn_SoftMaxForward<4,float,...>
  per_expert_scale[]  -> at::native::index_elementwise_kernel<bf16,...>
  topk_weights * ...  -> at::native::elementwise_kernel<MulFunctor<bf16>>
  cast to fp32        -> at::native::elementwise_kernel<copy>

torch.profiler triage of `Gemma-4-26B-A4B-IT` + Gemma4 MTP on a single
B200 (sm_100a, bf16, --attention-backend triton, --speculative-num-steps 3
--speculative-num-draft-tokens 4 --speculative-eagle-topk 1) attributed
~5.8% of decode GPU time to these split kernels.  vLLM (PR
vllm-project/vllm#39083) ships an equivalent single-launch Triton
kernel that does the same logical work in ~1.1% of its decode GPU time.

This commit ports the algorithm to SGLang:

* New `_gemma4_routing_kernel` + `gemma4_fused_routing` in
  python/sglang/srt/layers/gemma4_fused_ops.py.  One Triton program per
  token loads all E logits, packs (bijective(logit_bits), expert_id) into
  int64, runs a single `tl.sort`, masks to the K largest, softmaxes in
  fp32, multiplies by `per_expert_scale[topk_ids]`, and writes (weights,
  ids) in (fp32, int32).  num_warps=1 because Gemma4 E=128 fits in a warp.
* `Gemma4MoE.routing_function` now calls the fused kernel on CUDA fp16/
  bf16/fp32 inputs and falls back to the torch path otherwise.  Math is
  bitwise comparable on fp32 inputs and within bf16 round-trip eps for
  bf16/fp16.
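
The routing steps listed above can be sketched on the CPU in plain Python. This is an illustrative reference only, not the Triton kernel: the function names are hypothetical, the order-preserving bit mapping shown is the common sign-flip trick (the commit message does not say which bijection the kernel uses), the tie-break toward the lower expert id is an assumption, and the softmax is taken over the K selected logits, which is how the description reads. Python integers are unbounded, so the packed key here is not limited to a signed int64 as it would be in the kernel.

```python
import math
import struct


def sortable_key(x: float) -> int:
    """Map a float32 value to an unsigned 32-bit int with the same ordering."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Negative floats: flip every bit. Non-negative floats: set the sign bit.
    return bits ^ 0xFFFFFFFF if bits & 0x80000000 else bits | 0x80000000


def fused_routing_reference(logits, per_expert_scale, k):
    """Top-k, softmax over the selected logits, then per-expert scaling."""
    num_experts = len(logits)
    # Pack (key, expert) into one integer so a single sort orders by logit.
    # Storing (num_experts - 1 - expert) in the low bits makes ties favor
    # the lower expert id under a descending sort (assumed tie-break rule).
    packed = [
        (sortable_key(v) << 32) | (num_experts - 1 - e)
        for e, v in enumerate(logits)
    ]
    packed.sort(reverse=True)
    topk_ids = [num_experts - 1 - (p & 0xFFFFFFFF) for p in packed[:k]]

    # Numerically stable softmax restricted to the K selected logits.
    selected = [logits[e] for e in topk_ids]
    peak = max(selected)
    exps = [math.exp(v - peak) for v in selected]
    total = sum(exps)

    weights = [x / total * per_expert_scale[e] for x, e in zip(exps, topk_ids)]
    return weights, topk_ids


if __name__ == "__main__":
    w, ids = fused_routing_reference(
        logits=[0.5, -1.0, 2.0, 2.0, 0.0],
        per_expert_scale=[1.0, 1.0, 2.0, 0.5, 1.0],
        k=2,
    )
    print(ids, w)
```

Packing the key and the expert id into one integer is what lets a single sort replace the separate top-k and gather steps; the remaining work is elementwise on K values.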

Real-model results on 1x B200 (host venv SGLang, baseline = PR sgl-project#26026
head + the 3 launch-blocking fixes):

  workload                       baseline       this patch     delta
  chat        random 1000/1000   2729.30 tok/s  2880.94 tok/s  +5.6%
  summariz.   random 8000/1000   1060.98 tok/s  1108.42 tok/s  +4.5%
  chat        median TPOT (ms)   21.11          20.70          -1.9%
  chat        accept length      2.75           2.80           +1.8%

MMLU @ 500 random questions (seed 0, temp 0): 0.708 vs vLLM 0.710 -- no
quality regression.

Tests: test/srt/layers/test_gemma4_fused_routing.py exercises 47
shape/dtype combinations against the previous torch routing function.

Provenance: algorithm follows vLLM `_gemma4_routing_kernel` (apache-2.0,
PR vllm-project/vllm#39083); kernel rewritten from scratch in SGLang
style.

Co-authored-by: Claude
@pyc96 pyc96 mentioned this pull request May 28, 2026
7 tasks

2 participants