diff --git a/pyproject.toml b/pyproject.toml index f55dd9308bd5..5c87de018c10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -170,6 +170,7 @@ eles = "eles" datas = "datas" ser = "ser" ure = "ure" +VALU = "VALU" # Walsh-Hadamard Transform wht = "wht" WHT = "WHT" diff --git a/tests/kernels/moe/test_gemma4router.py b/tests/kernels/moe/test_gemma4router.py new file mode 100644 index 000000000000..ba69d6927495 --- /dev/null +++ b/tests/kernels/moe/test_gemma4router.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch + +from vllm.model_executor.models.gemma4 import ( + gemma4_fused_routing_kernel_triton, + gemma4_routing_function_torch, +) + + +def sort_by_id(w, ids): + order = ids.argsort(dim=-1) + return w.gather(1, order), ids.gather(1, order) + + +# Gemma4 MoE Model has context length of 250K +# the minus 1 is to ensure that edge cases are tested +@pytest.mark.parametrize("num_tokens", [1, 2, 2048, 250000]) +@pytest.mark.parametrize("num_experts", [128]) # gemma4 moe experts +@pytest.mark.parametrize("topk", [8]) # gemma4 topk +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32]) +def test_gemma4_routing_kernel_triton( + num_tokens: int, + num_experts: int, + topk: int, + dtype: torch.dtype, +): + torch.manual_seed(0) + + gating = torch.randn(num_tokens, num_experts, dtype=dtype, device="cuda") + scales = torch.rand(num_experts, dtype=torch.float32, device="cuda") + + ref_w, ref_ids = gemma4_routing_function_torch(gating, topk, scales) + tri_w, tri_ids = gemma4_fused_routing_kernel_triton(gating, topk, scales) + + # Sort by expert id — to remove tie-breaking differences + ref_ws, ref_is = sort_by_id(ref_w, ref_ids) + tri_ws, tri_is = sort_by_id(tri_w, tri_ids) + + ids_match = (ref_is == tri_is).all().item() + weights_match = torch.allclose(ref_ws, tri_ws, atol=1e-2, rtol=1e-2) + all_match = ids_match and weights_match + max_err = (ref_ws - tri_ws).abs().max().item() + print( + f"T={num_tokens:5d} E={num_experts:4d} K={topk} " + f"{str(dtype).split('.')[-1]:7s} ids={ids_match} max_Δweight={max_err:.2e}" + ) + if not all_match: + bad = (ref_is != tri_is).any(dim=-1).nonzero(as_tuple=True)[0] + if len(bad): + r = bad[0].item() + print( + f" first bad row {r}: ref_ids={ref_ids[r].tolist()} " + f"tri_ids={tri_ids[r].tolist()}" + ) + assert all_match diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py index 06189540090d..d166a9df38ac 100644 --- a/vllm/model_executor/models/gemma4.py +++ b/vllm/model_executor/models/gemma4.py @@ -57,7 +57,9 @@ default_weight_loader, maybe_remap_kv_scale_name, ) +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.triton_utils import tl, triton from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata from .interfaces import ( @@ -79,6 +81,120 @@ logger = init_logger(__name__) +@triton.jit +def _gemma4_routing_kernel( + gating_ptr, + per_expert_scale_ptr, + topk_weights_ptr, + topk_ids_ptr, + E: tl.constexpr, + K: tl.constexpr, + BLOCK_E: tl.constexpr, +): + pid = tl.program_id(0) + offs_e = tl.arange(0, BLOCK_E) + valid = offs_e < E + + logits = tl.load( + gating_ptr + pid * E + offs_e, + mask=valid, + other=-float("inf"), + ).to(tl.float32) + + max_l = tl.max(logits, axis=0) + + # Float32 → ascending-sortable bijection + MIN32 = -2147483648 + logit_bits = logits.to(tl.int32, bitcast=True) + sign_b = logit_bits >> 31 + key = tl.where(sign_b == 0, logit_bits ^ -1, logit_bits ^ MIN32) + key = tl.where(valid, key, 0x7FFFFFFF) + sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF + packed = (sk64 << 32) | offs_e.to(tl.int64) + sorted_p = tl.sort(packed, descending=False) + + # Vectorized extraction of ALL sorted elements — no K-loop, no cross-lane reductions + all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32) + all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32) + + # Inverse bijection: recover original logit bits + sign_k = all_keys >> 31 + all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32) + all_logits = all_bits.to(tl.float32, bitcast=True) + + # Compute raw_exp for ALL BLOCK_E elements — vectorized, ~2 VALU clocks + all_raw_exp = tl.math.exp2((all_logits - max_l) * 1.4426950408889634) + + # Sum only top-K for renorm — ONE masked reduction + top_mask = offs_e < K + renorm_raw = tl.sum(tl.where(top_mask, all_raw_exp, 0.0), axis=0) + renorm_raw = tl.where(renorm_raw > 0.0, renorm_raw, 1.0) + inv_renorm = 1.0 / renorm_raw + + # Load scales for top-K only (masked gather; scale array is tiny → L1 cached) + all_scales = tl.load( + per_expert_scale_ptr + all_ids.to(tl.int64), + mask=top_mask, + other=1.0, + ).to(tl.float32) + + # Final weights: vectorized multiply (only top-K will be stored) + all_weights = (all_raw_exp * inv_renorm * all_scales).to(tl.float32) + + # Write results with TWO masked stores — replaces K × 2 serial scalar stores + base_off = pid * K + offs_e + tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask) + tl.store(topk_weights_ptr + base_off, all_weights, mask=top_mask) + + +def gemma4_fused_routing_kernel_triton( + gating_output: torch.Tensor, + topk: int, + per_expert_scale: torch.Tensor, + num_warps: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + gating_output = gating_output.contiguous() + per_expert_scale = per_expert_scale.contiguous() + T, E = gating_output.shape + weights = torch.empty(T, topk, dtype=torch.float32, device=gating_output.device) + ids = torch.empty(T, topk, dtype=torch.int32, device=gating_output.device) + BLOCK_E = triton.next_power_of_2(E) + _gemma4_routing_kernel[(T,)]( + gating_output, + per_expert_scale, + weights, + ids, + E, + topk, + BLOCK_E, + num_warps=num_warps, + ) + return weights, ids + + +def gemma4_routing_function_torch( + gating_output: torch.Tensor, + topk: int, + per_expert_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + _, topk_ids = torch.topk(gating_output, k=topk, dim=-1) + router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1) + indicator = torch.nn.functional.one_hot( + topk_ids, num_classes=gating_output.size(-1) + ).sum(dim=-2) + gate_weights = indicator * router_probabilities + renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True) + renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0) + dispatch_weights = gate_weights / renorm_factor + + topk_weights = dispatch_weights.gather(1, topk_ids) + + # Fold per_expert_scale into routing weights + expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype) + topk_weights = topk_weights * expert_scales + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + def _get_text_config(config): """Dereference text_config if config is a nested Gemma4Config. @@ -216,22 +332,12 @@ def routing_function( topk: int, renormalize: bool, ) -> tuple[torch.Tensor, torch.Tensor]: - _, topk_ids = torch.topk(gating_output, k=topk, dim=-1) - router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1) - indicator = torch.nn.functional.one_hot( - topk_ids, num_classes=gating_output.size(-1) - ).sum(dim=-2) - gate_weights = indicator * router_probabilities - renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True) - renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0) - dispatch_weights = gate_weights / renorm_factor - - topk_weights = dispatch_weights.gather(1, topk_ids) - - # Fold per_expert_scale into routing weights - expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype) - topk_weights = topk_weights * expert_scales - return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + if current_platform.is_cuda_alike() or current_platform.is_xpu(): + return gemma4_fused_routing_kernel_triton( + gating_output, topk, per_expert_scale + ) + + return gemma4_routing_function_torch(gating_output, topk, per_expert_scale) # FusedMoE experts with custom Gemma4 routing self.experts = FusedMoE(