From 4385d6aa133acbf0cb3b50f84259ee9c3748dfcb Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 6 Apr 2026 17:05:07 +0800 Subject: [PATCH 01/11] add new gemma4 routing kernel Signed-off-by: tjtanaa --- vllm/model_executor/models/gemma4.py | 106 +++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index edb533134995..7eae4635e86f 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -65,10 +65,112 @@ make_layers, maybe_prefix, ) +import triton +import triton.language as tl + +from vllm.platforms import current_platform logger = init_logger(__name__) +@triton.jit +def _gemma4_routing_kernel( + gating_ptr, + per_expert_scale_ptr, + topk_weights_ptr, + topk_ids_ptr, + E: tl.constexpr, + K: tl.constexpr, + BLOCK_E: tl.constexpr, +): + """ + Sort-v3: eliminates the K serial masked-sum reductions after sort. + + Vs sort_v2: instead of K independent tl.sum(tl.where(offs_e==i, sorted_p, 0)) + reductions (K LDS ops), extract all BLOCK_E elements vectorized in one pass: + • Compute exp for ALL BLOCK_E sorted elements (2 VALU clocks, vectorized) + • Sum only top-K for renorm with ONE masked tl.sum (1 LDS op) + • Load scales for top-K with ONE masked gather (→ all in L1 after 1st token) + • Write ids + weights with TWO masked tl.store (no K-loop) + + ISA savings: ~5-6 total LDS ops (vs 12 in sort_v2), trading 1 extra VALU clock + for the additional BLOCK_E-K exp ops. + """ + pid = tl.program_id(0) + offs_e = tl.arange(0, BLOCK_E) + valid = offs_e < E + + logits = tl.load( + gating_ptr + pid * E + offs_e, + mask=valid, + other=-float("inf"), + ).to(tl.float32) + + max_l = tl.max(logits, axis=0) + + # Float32 → ascending-sortable bijection (same as sort_v2) + MIN32 = -2147483648 + logit_bits = logits.to(tl.int32, bitcast=True) + sign_b = logit_bits >> 31 + key = tl.where(sign_b == 0, logit_bits ^ -1, logit_bits ^ MIN32) + key = tl.where(valid, key, 0x7FFFFFFF) + sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF + packed = (sk64 << 32) | offs_e.to(tl.int64) + sorted_p = tl.sort(packed, descending=False) + + # Vectorized extraction of ALL sorted elements — no K-loop, no cross-lane reductions + all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32) + all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32) + + # Inverse bijection: recover original logit bits + sign_k = all_keys >> 31 + all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32) + all_logits = all_bits.to(tl.float32, bitcast=True) + + # Compute raw_exp for ALL BLOCK_E elements — vectorized, ~2 VALU clocks + all_raw_exp = tl.math.exp2((all_logits - max_l) * 1.4426950408889634) + + # Sum only top-K for renorm — ONE masked reduction (vs K serial reductions in sort_v2) + top_mask = offs_e < K + renorm_raw = tl.sum(tl.where(top_mask, all_raw_exp, 0.0), axis=0) + renorm_raw = tl.where(renorm_raw > 0.0, renorm_raw, 1.0) + inv_renorm = 1.0 / renorm_raw + + # Load scales for top-K only (masked gather; scale array is tiny → L1 cached) + all_scales = tl.load( + per_expert_scale_ptr + all_ids.to(tl.int64), + mask=top_mask, + other=1.0, + ).to(tl.float32) + + # Final weights: vectorized multiply (only top-K will be stored) + all_weights = (all_raw_exp * inv_renorm * all_scales).to(tl.float32) + + # Write results with TWO masked stores — replaces K × 2 serial scalar stores + base_off = pid * K + offs_e + tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask) + tl.store(topk_weights_ptr + base_off, all_weights, mask=top_mask) + + +def gemma4_routing_kernel( + gating_output: torch.Tensor, + topk: int, + per_expert_scale: torch.Tensor, + num_warps: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + """Sort-v3: sort + vectorized extraction. See _routing_kernel_sort_v3.""" + gating_output = gating_output.contiguous() + per_expert_scale = per_expert_scale.contiguous() + T, E = gating_output.shape + weights = torch.empty(T, topk, dtype=torch.float32, device=gating_output.device) + ids = torch.empty(T, topk, dtype=torch.int32, device=gating_output.device) + BLOCK_E = triton.next_power_of_2(E) + _gemma4_routing_kernel[(T,)]( + gating_output, per_expert_scale, weights, ids, + E, topk, BLOCK_E, num_warps=num_warps, + ) + return weights, ids + def _get_text_config(config): """Dereference text_config if config is a nested Gemma4Config. @@ -206,6 +308,10 @@ def routing_function( topk: int, renormalize: bool, ) -> tuple[torch.Tensor, torch.Tensor]: + + if (current_platform.is_cuda_alike() and current_platform.is_xpu()): + return gemma4_routing_kernel(gating_output, topk, per_expert_scale) + _, topk_ids = torch.topk(gating_output, k=topk, dim=-1) router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1) indicator = torch.nn.functional.one_hot( From 850e10c383a3e594b18733236e6607ebe2d2b4b1 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 6 Apr 2026 17:06:10 +0800 Subject: [PATCH 02/11] remove comments Signed-off-by: tjtanaa --- vllm/model_executor/models/gemma4.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index 7eae4635e86f..753b41041bc6 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -83,19 +83,6 @@ def _gemma4_routing_kernel( K: tl.constexpr, BLOCK_E: tl.constexpr, ): - """ - Sort-v3: eliminates the K serial masked-sum reductions after sort. - - Vs sort_v2: instead of K independent tl.sum(tl.where(offs_e==i, sorted_p, 0)) - reductions (K LDS ops), extract all BLOCK_E elements vectorized in one pass: - • Compute exp for ALL BLOCK_E sorted elements (2 VALU clocks, vectorized) - • Sum only top-K for renorm with ONE masked tl.sum (1 LDS op) - • Load scales for top-K with ONE masked gather (→ all in L1 after 1st token) - • Write ids + weights with TWO masked tl.store (no K-loop) - - ISA savings: ~5-6 total LDS ops (vs 12 in sort_v2), trading 1 extra VALU clock - for the additional BLOCK_E-K exp ops. - """ pid = tl.program_id(0) offs_e = tl.arange(0, BLOCK_E) valid = offs_e < E From 02fe095665416ac998af6ec3998758b55a9aad35 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 6 Apr 2026 17:52:46 +0800 Subject: [PATCH 03/11] add gemma4 router test Signed-off-by: tjtanaa --- pyproject.toml | 1 + tests/kernels/moe/test_gemma4router.py | 63 ++++++++++++++ vllm/model_executor/models/gemma4.py | 110 ++++++++++++++----------- 3 files changed, 126 insertions(+), 48 deletions(-) create mode 100644 tests/kernels/moe/test_gemma4router.py diff --git a/pyproject.toml b/pyproject.toml index fad8c8c687a1..c2f3dd33aebf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -169,6 +169,7 @@ eles = "eles" datas = "datas" ser = "ser" ure = "ure" +VALU = "VALU" [tool.uv] no-build-isolation-package = ["torch"] diff --git a/tests/kernels/moe/test_gemma4router.py b/tests/kernels/moe/test_gemma4router.py new file mode 100644 index 000000000000..4fde720b4aae --- /dev/null +++ b/tests/kernels/moe/test_gemma4router.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from vllm.model_executor.models.gemma4 import ( + gemma4_fused_routing_kernel_triton, + gemma4_routing_function_torch, +) + +NUM_TOKENS = [] + +# Gemma4 Moe Model has context length of 250K +# the minus 1 is to ensure that edge cases are tested +for t in range(1, 19): + tlen = 2**t + tlen_minus1 = tlen - 1 + NUM_TOKENS.extend([tlen_minus1, tlen]) + + +def sort_by_id(w, ids): + order = ids.argsort(dim=-1) + return w.gather(1, order), ids.gather(1, order) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_experts", [128]) # gemma4 moe experts +@pytest.mark.parametrize("topk", [8]) # gemma4 topk +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32]) +def test_gemma4_routing_kernel_triton( + num_tokens: int, + num_experts: int, + topk: int, + dtype: torch.dtype, +): + torch.manual_seed(0) + + gating = torch.randn(num_tokens, num_experts, dtype=dtype, device="cuda") + scales = torch.rand(num_experts, dtype=torch.float32, device="cuda") + + ref_w, ref_ids = gemma4_routing_function_torch(gating, topk, scales) + tri_w, tri_ids = gemma4_fused_routing_kernel_triton(gating, topk, scales) + + # Sort by expert id — to remove tie-breaking differences + ref_ws, ref_is = sort_by_id(ref_w, ref_ids) + tri_ws, tri_is = sort_by_id(tri_w, tri_ids) + + ids_ok = (ref_is == tri_is).all().item() + weights_ok = torch.allclose(ref_ws, tri_ws, atol=1e-4, rtol=1e-4) + ok = ids_ok and weights_ok + max_err = (ref_ws - tri_ws).abs().max().item() + print( + f"T={num_tokens:5d} E={num_experts:4d} K={topk} " + f"{str(dtype).split('.')[-1]:7s} ids={ids_ok} max_Δweight={max_err:.2e}" + ) + if not ok: + bad = (ref_is != tri_is).any(dim=-1).nonzero(as_tuple=True)[0] + if len(bad): + r = bad[0].item() + print( + f" first bad row {r}: ref_ids={ref_ids[r].tolist()} " + f"tri_ids={tri_ids[r].tolist()}" + ) diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index 753b41041bc6..35113c8ab869 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -55,7 +55,9 @@ default_weight_loader, maybe_remap_kv_scale_name, ) +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.triton_utils import tl, triton from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import ( @@ -65,10 +67,6 @@ make_layers, maybe_prefix, ) -import triton -import triton.language as tl - -from vllm.platforms import current_platform logger = init_logger(__name__) @@ -79,13 +77,13 @@ def _gemma4_routing_kernel( per_expert_scale_ptr, topk_weights_ptr, topk_ids_ptr, - E: tl.constexpr, - K: tl.constexpr, + E: tl.constexpr, + K: tl.constexpr, BLOCK_E: tl.constexpr, ): - pid = tl.program_id(0) + pid = tl.program_id(0) offs_e = tl.arange(0, BLOCK_E) - valid = offs_e < E + valid = offs_e < E logits = tl.load( gating_ptr + pid * E + offs_e, @@ -95,30 +93,30 @@ def _gemma4_routing_kernel( max_l = tl.max(logits, axis=0) - # Float32 → ascending-sortable bijection (same as sort_v2) - MIN32 = -2147483648 + # Float32 → ascending-sortable bijection + MIN32 = -2147483648 logit_bits = logits.to(tl.int32, bitcast=True) - sign_b = logit_bits >> 31 - key = tl.where(sign_b == 0, logit_bits ^ -1, logit_bits ^ MIN32) - key = tl.where(valid, key, 0x7FFFFFFF) - sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF - packed = (sk64 << 32) | offs_e.to(tl.int64) - sorted_p = tl.sort(packed, descending=False) + sign_b = logit_bits >> 31 + key = tl.where(sign_b == 0, logit_bits ^ -1, logit_bits ^ MIN32) + key = tl.where(valid, key, 0x7FFFFFFF) + sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF + packed = (sk64 << 32) | offs_e.to(tl.int64) + sorted_p = tl.sort(packed, descending=False) # Vectorized extraction of ALL sorted elements — no K-loop, no cross-lane reductions - all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32) - all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32) + all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32) + all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32) # Inverse bijection: recover original logit bits - sign_k = all_keys >> 31 - all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32) + sign_k = all_keys >> 31 + all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32) all_logits = all_bits.to(tl.float32, bitcast=True) # Compute raw_exp for ALL BLOCK_E elements — vectorized, ~2 VALU clocks all_raw_exp = tl.math.exp2((all_logits - max_l) * 1.4426950408889634) - # Sum only top-K for renorm — ONE masked reduction (vs K serial reductions in sort_v2) - top_mask = offs_e < K + # Sum only top-K for renorm — ONE masked reduction + top_mask = offs_e < K renorm_raw = tl.sum(tl.where(top_mask, all_raw_exp, 0.0), axis=0) renorm_raw = tl.where(renorm_raw > 0.0, renorm_raw, 1.0) inv_renorm = 1.0 / renorm_raw @@ -135,29 +133,59 @@ def _gemma4_routing_kernel( # Write results with TWO masked stores — replaces K × 2 serial scalar stores base_off = pid * K + offs_e - tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask) + tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask) tl.store(topk_weights_ptr + base_off, all_weights, mask=top_mask) -def gemma4_routing_kernel( +def gemma4_fused_routing_kernel_triton( gating_output: torch.Tensor, topk: int, per_expert_scale: torch.Tensor, num_warps: int = 1, ) -> tuple[torch.Tensor, torch.Tensor]: """Sort-v3: sort + vectorized extraction. See _routing_kernel_sort_v3.""" - gating_output = gating_output.contiguous() + gating_output = gating_output.contiguous() per_expert_scale = per_expert_scale.contiguous() - T, E = gating_output.shape + T, E = gating_output.shape weights = torch.empty(T, topk, dtype=torch.float32, device=gating_output.device) - ids = torch.empty(T, topk, dtype=torch.int32, device=gating_output.device) + ids = torch.empty(T, topk, dtype=torch.int32, device=gating_output.device) BLOCK_E = triton.next_power_of_2(E) _gemma4_routing_kernel[(T,)]( - gating_output, per_expert_scale, weights, ids, - E, topk, BLOCK_E, num_warps=num_warps, + gating_output, + per_expert_scale, + weights, + ids, + E, + topk, + BLOCK_E, + num_warps=num_warps, ) return weights, ids + +def gemma4_routing_function_torch( + gating_output: torch.Tensor, + topk: int, + per_expert_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + _, topk_ids = torch.topk(gating_output, k=topk, dim=-1) + router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1) + indicator = torch.nn.functional.one_hot( + topk_ids, num_classes=gating_output.size(-1) + ).sum(dim=-2) + gate_weights = indicator * router_probabilities + renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True) + renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0) + dispatch_weights = gate_weights / renorm_factor + + topk_weights = dispatch_weights.gather(1, topk_ids) + + # Fold per_expert_scale into routing weights + expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype) + topk_weights = topk_weights * expert_scales + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + def _get_text_config(config): """Dereference text_config if config is a nested Gemma4Config. @@ -295,26 +323,12 @@ def routing_function( topk: int, renormalize: bool, ) -> tuple[torch.Tensor, torch.Tensor]: + if current_platform.is_cuda_alike() and current_platform.is_xpu(): + return gemma4_fused_routing_kernel_triton( + gating_output, topk, per_expert_scale + ) - if (current_platform.is_cuda_alike() and current_platform.is_xpu()): - return gemma4_routing_kernel(gating_output, topk, per_expert_scale) - - _, topk_ids = torch.topk(gating_output, k=topk, dim=-1) - router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1) - indicator = torch.nn.functional.one_hot( - topk_ids, num_classes=gating_output.size(-1) - ).sum(dim=-2) - gate_weights = indicator * router_probabilities - renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True) - renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0) - dispatch_weights = gate_weights / renorm_factor - - topk_weights = dispatch_weights.gather(1, topk_ids) - - # Fold per_expert_scale into routing weights - expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype) - topk_weights = topk_weights * expert_scales - return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + return gemma4_routing_function_torch(gating_output, topk, per_expert_scale) # FusedMoE experts with custom Gemma4 routing self.experts = FusedMoE( From beb1b56545ab2be6aa9d29bd5577a60ed7f720db Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 6 Apr 2026 21:16:43 +0800 Subject: [PATCH 04/11] change to or Signed-off-by: tjtanaa --- vllm/model_executor/models/gemma4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index 35113c8ab869..65a15a668817 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -323,7 +323,7 @@ def routing_function( topk: int, renormalize: bool, ) -> tuple[torch.Tensor, torch.Tensor]: - if current_platform.is_cuda_alike() and current_platform.is_xpu(): + if current_platform.is_cuda_alike() or current_platform.is_xpu(): return gemma4_fused_routing_kernel_triton( gating_output, topk, per_expert_scale ) From d14e399ebc1ab640baaf7864daecae4317269667 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 6 Apr 2026 23:14:23 +0800 Subject: [PATCH 05/11] remove xpu support for now Signed-off-by: tjtanaa --- vllm/model_executor/models/gemma4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index 65a15a668817..9b806bbc14f0 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -323,7 +323,7 @@ def routing_function( topk: int, renormalize: bool, ) -> tuple[torch.Tensor, torch.Tensor]: - if current_platform.is_cuda_alike() or current_platform.is_xpu(): + if current_platform.is_cuda_alike(): return gemma4_fused_routing_kernel_triton( gating_output, topk, per_expert_scale ) From 671871e2c01c58fc6a844a70787d3ef626d94ea4 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 6 Apr 2026 23:34:16 +0800 Subject: [PATCH 06/11] remove legacy comment Signed-off-by: tjtanaa --- vllm/model_executor/models/gemma4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index 4b2181582183..77c68475dc13 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -146,7 +146,6 @@ def gemma4_fused_routing_kernel_triton( per_expert_scale: torch.Tensor, num_warps: int = 1, ) -> tuple[torch.Tensor, torch.Tensor]: - """Sort-v3: sort + vectorized extraction. See _routing_kernel_sort_v3.""" gating_output = gating_output.contiguous() per_expert_scale = per_expert_scale.contiguous() T, E = gating_output.shape From 7e862f5ffa4b45d5894d4698d6d8730137c33698 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 6 Apr 2026 23:46:01 +0800 Subject: [PATCH 07/11] update atol rtol to match other router Signed-off-by: tjtanaa --- tests/kernels/moe/test_gemma4router.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_gemma4router.py b/tests/kernels/moe/test_gemma4router.py index 4fde720b4aae..2ce048c26f03 100644 --- a/tests/kernels/moe/test_gemma4router.py +++ b/tests/kernels/moe/test_gemma4router.py @@ -46,7 +46,7 @@ def test_gemma4_routing_kernel_triton( tri_ws, tri_is = sort_by_id(tri_w, tri_ids) ids_ok = (ref_is == tri_is).all().item() - weights_ok = torch.allclose(ref_ws, tri_ws, atol=1e-4, rtol=1e-4) + weights_ok = torch.allclose(ref_ws, tri_ws, atol=1e-2, rtol=1e-2) ok = ids_ok and weights_ok max_err = (ref_ws - tri_ws).abs().max().item() print( @@ -61,3 +61,4 @@ def test_gemma4_routing_kernel_triton( f" first bad row {r}: ref_ids={ref_ids[r].tolist()} " f"tri_ids={tri_ids[r].tolist()}" ) + assert ok From 4b528c6e014f6096e35ec73e0705f49d8af0a768 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Mon, 6 Apr 2026 23:47:15 +0800 Subject: [PATCH 08/11] clean up Signed-off-by: tjtanaa --- tests/kernels/moe/test_gemma4router.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/kernels/moe/test_gemma4router.py b/tests/kernels/moe/test_gemma4router.py index 2ce048c26f03..4306d3c3d1cd 100644 --- a/tests/kernels/moe/test_gemma4router.py +++ b/tests/kernels/moe/test_gemma4router.py @@ -45,15 +45,15 @@ def test_gemma4_routing_kernel_triton( ref_ws, ref_is = sort_by_id(ref_w, ref_ids) tri_ws, tri_is = sort_by_id(tri_w, tri_ids) - ids_ok = (ref_is == tri_is).all().item() - weights_ok = torch.allclose(ref_ws, tri_ws, atol=1e-2, rtol=1e-2) - ok = ids_ok and weights_ok + ids_match = (ref_is == tri_is).all().item() + weights_match = torch.allclose(ref_ws, tri_ws, atol=1e-2, rtol=1e-2) + all_match = ids_match and weights_match max_err = (ref_ws - tri_ws).abs().max().item() print( f"T={num_tokens:5d} E={num_experts:4d} K={topk} " - f"{str(dtype).split('.')[-1]:7s} ids={ids_ok} max_Δweight={max_err:.2e}" + f"{str(dtype).split('.')[-1]:7s} ids={ids_match} max_Δweight={max_err:.2e}" ) - if not ok: + if not all_match: bad = (ref_is != tri_is).any(dim=-1).nonzero(as_tuple=True)[0] if len(bad): r = bad[0].item() @@ -61,4 +61,4 @@ def test_gemma4_routing_kernel_triton( f" first bad row {r}: ref_ids={ref_ids[r].tolist()} " f"tri_ids={tri_ids[r].tolist()}" ) - assert ok + assert all_match From b5fe2179ba1d3778328389293aae1dff9085a5ab Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 7 Apr 2026 00:15:38 +0800 Subject: [PATCH 09/11] reduce test cases Signed-off-by: tjtanaa --- tests/kernels/moe/test_gemma4router.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/kernels/moe/test_gemma4router.py b/tests/kernels/moe/test_gemma4router.py index 4306d3c3d1cd..d4cbcc2be729 100644 --- a/tests/kernels/moe/test_gemma4router.py +++ b/tests/kernels/moe/test_gemma4router.py @@ -8,22 +8,15 @@ gemma4_routing_function_torch, ) -NUM_TOKENS = [] - -# Gemma4 Moe Model has context length of 250K -# the minus 1 is to ensure that edge cases are tested -for t in range(1, 19): - tlen = 2**t - tlen_minus1 = tlen - 1 - NUM_TOKENS.extend([tlen_minus1, tlen]) - def sort_by_id(w, ids): order = ids.argsort(dim=-1) return w.gather(1, order), ids.gather(1, order) -@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +# Gemma4 Moe Model has context length of 250K +# the minus 1 is to ensure that edge cases are tested +@pytest.mark.parametrize("num_tokens", [1, 2, 2048, 250000]) @pytest.mark.parametrize("num_experts", [128]) # gemma4 moe experts @pytest.mark.parametrize("topk", [8]) # gemma4 topk @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32]) From 48f6a7105532dd3036b85fe380ddd062de6cfda3 Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Tue, 7 Apr 2026 22:01:47 +0800 Subject: [PATCH 10/11] enable triton kernel on xpu Signed-off-by: tjtanaa --- vllm/model_executor/models/gemma4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index 77c68475dc13..2d636c10a526 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -325,7 +325,7 @@ def routing_function( topk: int, renormalize: bool, ) -> tuple[torch.Tensor, torch.Tensor]: - if current_platform.is_cuda_alike(): + if current_platform.is_cuda_alike() or current_platform.is_xpu(): return gemma4_fused_routing_kernel_triton( gating_output, topk, per_expert_scale ) From 59312b2dad9efc889a35660260fe5b410d832c2c Mon Sep 17 00:00:00 2001 From: tjtanaa Date: Sun, 19 Apr 2026 08:12:48 +0000 Subject: [PATCH 11/11] fix nits from reviewer Signed-off-by: tjtanaa --- tests/kernels/moe/test_gemma4router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_gemma4router.py b/tests/kernels/moe/test_gemma4router.py index d4cbcc2be729..ba69d6927495 100644 --- a/tests/kernels/moe/test_gemma4router.py +++ b/tests/kernels/moe/test_gemma4router.py @@ -14,7 +14,7 @@ def sort_by_id(w, ids): return w.gather(1, order), ids.gather(1, order) -# Gemma4 Moe Model has context length of 250K +# Gemma4 MoE Model has context length of 250K # the minus 1 is to ensure that edge cases are tested @pytest.mark.parametrize("num_tokens", [1, 2, 2048, 250000]) @pytest.mark.parametrize("num_experts", [128]) # gemma4 moe experts