Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ eles = "eles"
datas = "datas"
ser = "ser"
ure = "ure"
VALU = "VALU"
# Walsh-Hadamard Transform
wht = "wht"
WHT = "WHT"
Expand Down
57 changes: 57 additions & 0 deletions tests/kernels/moe/test_gemma4router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.model_executor.models.gemma4 import (
gemma4_fused_routing_kernel_triton,
gemma4_routing_function_torch,
)


def sort_by_id(w, ids):
    """Reorder routing weights and expert ids by ascending expert id.

    Two top-k implementations may emit the same experts in a different
    order (e.g. different tie-breaking), so both outputs are put into a
    canonical order before they are compared.

    Args:
        w: Routing weights; same shape as ``ids``.
        ids: Selected expert ids; sorted along the last dimension.

    Returns:
        ``(sorted_w, sorted_ids)`` where each slice along the last dimension
        is ordered by ascending expert id and ``sorted_w`` is permuted the
        same way.
    """
    order = ids.argsort(dim=-1)
    # Gather along the same dimension that was sorted so the helper is
    # correct for any rank, not only for 2-D (tokens, topk) inputs.
    return w.gather(-1, order), ids.gather(-1, order)


# The Gemma4 MoE model has a context length of 250K tokens, so 250000 covers
# the largest batch; 1 and 2 cover the smallest ones.
@pytest.mark.parametrize("num_tokens", [1, 2, 2048, 250000])
@pytest.mark.parametrize("num_experts", [128]) # gemma4 moe experts
@pytest.mark.parametrize("topk", [8]) # gemma4 topk
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_gemma4_routing_kernel_triton(
num_tokens: int,
num_experts: int,
topk: int,
dtype: torch.dtype,
):
torch.manual_seed(0)

gating = torch.randn(num_tokens, num_experts, dtype=dtype, device="cuda")
scales = torch.rand(num_experts, dtype=torch.float32, device="cuda")

ref_w, ref_ids = gemma4_routing_function_torch(gating, topk, scales)
tri_w, tri_ids = gemma4_fused_routing_kernel_triton(gating, topk, scales)

# Sort by expert id — to remove tie-breaking differences
ref_ws, ref_is = sort_by_id(ref_w, ref_ids)
tri_ws, tri_is = sort_by_id(tri_w, tri_ids)

ids_match = (ref_is == tri_is).all().item()
weights_match = torch.allclose(ref_ws, tri_ws, atol=1e-2, rtol=1e-2)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the atol here too large?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the routing function, especially in bf16, has a larger numerical divergence.

I am currently following the threshold used in existing unit tests of the fused routing methods:

topk_weights_ref.to(torch.float32), topk_weights, atol=1e-2, rtol=1e-2

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bfloat16

T= 1 E= 128 K=8 bfloat16 ids=True max_Δweight=6.34e-04
T= 2 E= 128 K=8 bfloat16 ids=True max_Δweight=6.34e-04
T= 2048 E= 128 K=8 bfloat16 ids=True max_Δweight=3.40e-03
T=250000 E= 128 K=8 bfloat16 ids=True max_Δweight=3.89e-03

float16

T= 1 E= 128 K=8 float16 ids=True max_Δweight=6.76e-05
T= 2 E= 128 K=8 float16 ids=True max_Δweight=6.76e-05
T= 2048 E= 128 K=8 float16 ids=True max_Δweight=2.81e-04
T=250000 E= 128 K=8 float16 ids=True max_Δweight=7.93e-04

float32

T= 1 E= 128 K=8 float32 ids=True max_Δweight=7.45e-09
T= 2 E= 128 K=8 float32 ids=True max_Δweight=1.49e-08
T= 2048 E= 128 K=8 float32 ids=True max_Δweight=5.96e-08
T=250000 E= 128 K=8 float32 ids=True max_Δweight=1.19e-07

all_match = ids_match and weights_match
max_err = (ref_ws - tri_ws).abs().max().item()
print(
f"T={num_tokens:5d} E={num_experts:4d} K={topk} "
f"{str(dtype).split('.')[-1]:7s} ids={ids_match} max_Δweight={max_err:.2e}"
)
if not all_match:
bad = (ref_is != tri_is).any(dim=-1).nonzero(as_tuple=True)[0]
if len(bad):
r = bad[0].item()
print(
f" first bad row {r}: ref_ids={ref_ids[r].tolist()} "
f"tri_ids={tri_ids[r].tolist()}"
)
assert all_match
138 changes: 122 additions & 16 deletions vllm/model_executor/models/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata

from .interfaces import (
Expand All @@ -79,6 +81,120 @@
logger = init_logger(__name__)


@triton.jit
def _gemma4_routing_kernel(
    gating_ptr,
    per_expert_scale_ptr,
    topk_weights_ptr,
    topk_ids_ptr,
    E: tl.constexpr,
    K: tl.constexpr,
    BLOCK_E: tl.constexpr,
):
    """Fused top-K routing for one token row.

    Each program reads one row of ``E`` gating logits (rows are addressed as
    ``pid * E``, i.e. densely packed), sorts the row by descending logit,
    and writes for the first ``K`` sorted entries:

    * ``topk_ids[pid, k]``     - the expert index, and
    * ``topk_weights[pid, k]`` - ``exp(logit - max) / sum_topk(exp(logit - max))``
      multiplied by ``per_expert_scale[expert]``.

    Ties between equal logits are resolved in favour of the lower expert
    index because the index occupies the low bits of the sort key.

    ``BLOCK_E`` is the padded row width used for the vector operations; lanes
    with index ``>= E`` are padding and are forced to sort last.
    """
    # Promote the row index to int64 so that pid * E and pid * K cannot wrap
    # around in 32-bit arithmetic for very large token counts.
    pid = tl.program_id(0).to(tl.int64)
    offs_e = tl.arange(0, BLOCK_E)
    valid = offs_e < E

    logits = tl.load(
        gating_ptr + pid * E + offs_e,
        mask=valid,
        other=-float("inf"),
    ).to(tl.float32)

    max_l = tl.max(logits, axis=0)

    # Float32 -> int32 order-reversing bijection: interpreted as signed
    # integers, larger logits map to smaller keys, so an ascending sort of the
    # keys yields the logits in descending order.
    MIN32 = -2147483648
    logit_bits = logits.to(tl.int32, bitcast=True)
    sign_b = logit_bits >> 31
    key = tl.where(sign_b == 0, logit_bits ^ -1, logit_bits ^ MIN32)
    # Padding lanes get the largest signed key so they end up last.
    key = tl.where(valid, key, 0x7FFFFFFF)
    # Pack (key, expert index) into one int64: key in the high 32 bits (its
    # sign bit becomes the int64 sign bit), expert index in the low 32 bits.
    sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF
    packed = (sk64 << 32) | offs_e.to(tl.int64)
    sorted_p = tl.sort(packed, descending=False)

    # Unpack every sorted lane at once instead of looping over K.
    all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32)
    all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32)

    # Inverse bijection: recover the original logit bits from the keys.
    sign_k = all_keys >> 31
    all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32)
    all_logits = all_bits.to(tl.float32, bitcast=True)

    # exp(x) computed as exp2(x * log2(e)); 1.4426950408889634 == log2(e).
    all_raw_exp = tl.math.exp2((all_logits - max_l) * 1.4426950408889634)

    # After sorting, the top-K entries are simply the first K lanes.
    top_mask = offs_e < K
    # Renormalise over the selected experts only (one masked reduction).
    renorm_raw = tl.sum(tl.where(top_mask, all_raw_exp, 0.0), axis=0)
    renorm_raw = tl.where(renorm_raw > 0.0, renorm_raw, 1.0)
    inv_renorm = 1.0 / renorm_raw

    # Masked gather of the per-expert scales for the selected experts only.
    all_scales = tl.load(
        per_expert_scale_ptr + all_ids.to(tl.int64),
        mask=top_mask,
        other=1.0,
    ).to(tl.float32)

    # Final weights for every lane; only the first K lanes are stored.
    all_weights = (all_raw_exp * inv_renorm * all_scales).to(tl.float32)

    # Two masked vector stores write the K ids and K weights of this row.
    base_off = pid * K + offs_e
    tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask)
    tl.store(topk_weights_ptr + base_off, all_weights, mask=top_mask)


def gemma4_fused_routing_kernel_triton(
    gating_output: torch.Tensor,
    topk: int,
    per_expert_scale: torch.Tensor,
    num_warps: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the fused Gemma4 routing kernel on a batch of gating logits.

    Launches ``_gemma4_routing_kernel`` with one program per token row.

    Args:
        gating_output: 2-D tensor of router logits, unpacked as
            ``(num_tokens, num_experts)``.
        topk: Number of experts selected per token; must satisfy
            ``0 <= topk <= num_experts``.
        per_expert_scale: Per-expert scale folded into the routing weights;
            must hold at least ``num_experts`` elements.
        num_warps: ``num_warps`` launch option forwarded to Triton.

    Returns:
        ``(weights, ids)``: ``weights`` is ``float32`` and ``ids`` is
        ``int32``, both of shape ``(num_tokens, topk)`` and allocated on the
        device of ``gating_output``.

    Raises:
        ValueError: If ``topk`` is outside ``[0, num_experts]`` or
            ``per_expert_scale`` has fewer than ``num_experts`` elements.
    """
    # The kernel addresses rows as ``pid * E``, so rows must be densely packed.
    gating_output = gating_output.contiguous()
    per_expert_scale = per_expert_scale.contiguous()
    T, E = gating_output.shape

    # Without these checks the kernel would gather scales at padded-lane
    # indices (topk > E) or past the end of a too-short scale vector.
    if not 0 <= topk <= E:
        raise ValueError(
            f"topk must be in [0, {E}] (number of experts), got {topk}"
        )
    if per_expert_scale.numel() < E:
        raise ValueError(
            f"per_expert_scale has {per_expert_scale.numel()} elements, "
            f"expected at least {E}"
        )

    weights = torch.empty(T, topk, dtype=torch.float32, device=gating_output.device)
    ids = torch.empty(T, topk, dtype=torch.int32, device=gating_output.device)
    if T == 0:
        # Nothing to route; skip the launch entirely.
        return weights, ids

    # The kernel's vector width is E rounded up to the next power of two.
    BLOCK_E = triton.next_power_of_2(E)
    _gemma4_routing_kernel[(T,)](
        gating_output,
        per_expert_scale,
        weights,
        ids,
        E,
        topk,
        BLOCK_E,
        num_warps=num_warps,
    )
    return weights, ids


def gemma4_routing_function_torch(
    gating_output: torch.Tensor,
    topk: int,
    per_expert_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pure-PyTorch Gemma4 routing: top-k, renormalised softmax, expert scale.

    For every token the ``topk`` experts with the largest logits are selected.
    Their softmax probabilities are renormalised so that they sum to one and
    are then multiplied by the matching entries of ``per_expert_scale``.

    Args:
        gating_output: Router logits with experts on the last dimension;
            indexed here as ``(num_tokens, num_experts)``.
        topk: Number of experts to select per token.
        per_expert_scale: Per-expert scale, indexed by expert id.

    Returns:
        ``(topk_weights, topk_ids)`` with dtypes ``float32`` and ``int32`` and
        shape ``(num_tokens, topk)``.
    """
    _, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
    router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1)

    # Gather the selected probabilities directly. This replaces a one-hot
    # mask of shape (num_tokens, topk, num_experts) whose only purpose was to
    # zero the unselected experts before the very same entries were gathered.
    topk_probs = router_probabilities.gather(-1, topk_ids)

    # Renormalise over the selected experts; guard against a zero sum.
    renorm_factor = topk_probs.sum(dim=-1, keepdim=True)
    renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0)
    topk_weights = topk_probs / renorm_factor

    # Fold per_expert_scale into routing weights
    expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype)
    topk_weights = topk_weights * expert_scales
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


def _get_text_config(config):
"""Dereference text_config if config is a nested Gemma4Config.

Expand Down Expand Up @@ -216,22 +332,12 @@ def routing_function(
topk: int,
renormalize: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
_, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1)
indicator = torch.nn.functional.one_hot(
topk_ids, num_classes=gating_output.size(-1)
).sum(dim=-2)
gate_weights = indicator * router_probabilities
renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True)
renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0)
dispatch_weights = gate_weights / renorm_factor

topk_weights = dispatch_weights.gather(1, topk_ids)

# Fold per_expert_scale into routing weights
expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype)
topk_weights = topk_weights * expert_scales
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
if current_platform.is_cuda_alike() or current_platform.is_xpu():
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use something like custom op to dispatch?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The op itself is very model-specific, and all in-tree models benefit from the same Triton kernel, which is why this PR is shaped this way.

I will follow up with a vLLM IR Op as CustomOp is going to be replaced by vLLM IR Op.

return gemma4_fused_routing_kernel_triton(
gating_output, topk, per_expert_scale
)

return gemma4_routing_function_torch(gating_output, topk, per_expert_scale)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if you just add torch.compile to this function?

Copy link
Copy Markdown
Member Author

@tjtanaa tjtanaa Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me try. However, in the microbenchmark, I found that the triton kernel is still faster than the torch compile implementations.

GPU vs Eager vs compile-default vs compile-reduce-overhead vs compile-max-autotune
A100 12.68× 4.50× 4.97× 4.45×
H100 10.58× 4.51× 5.02× 4.41×
MI300X 14.37× 6.00× 7.24× 5.81×
B60 6.33× 1.08× 1.07× 1.05×

Copy link
Copy Markdown
Member Author

@tjtanaa tjtanaa Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have run the test with two different modes on MI300X

compiled_routing_function_torch = torch.compile(routing_function_torch)

and

compiled_routing_function_torch = torch.compile(routing_function_torch, mode='max-autotune-no-cudagraphs')

Metric Torch Torch Compile (default) Torch Compile (max-autotune) Triton Kernel
Benchmark Duration (s) 77.99 70.49 69.46 64.18
Request Throughput (req/s) 2.05 2.27 2.30 2.49
Output Token Throughput (tok/s) 1,050 1,162 1,179 1,276
Total Token Throughput (tok/s) 5,252 5,811 5,897 6,382
Mean TTFT (ms) 603 589 594 583
Mean TPOT (ms) 14.07 12.63 12.43 11.41
Mean ITL (ms) 14.07 12.63 12.43 11.41
GSM8K Accuracy (flexible) 43.06% 43.37% 45.26% 42.91%

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It matches the microbenchmark speed up trend

mi300xbenchmark.log

Additionally, the following are the microbenchmark results on A100 and H100:

a100benchmarkkernel.log

H100benchmarkkernel.log

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ProExpertProg This is the benchmark perf for H100

compiled_routing_function_torch = torch.compile(routing_function_torch, mode='max-autotune-no-cudagraphs')

vs Triton Kernel

Summary Statistics (Averaged across 3 runs)

Metric Torch Compile Triton Winner
Throughput
Request throughput (req/s) 2.43 2.47 ✅ Triton (+1.6%)
Output token throughput (tok/s) 1244.44 1266.30 ✅ Triton (+1.8%)
Total token throughput (tok/s) 6222.20 6331.49 ✅ Triton (+1.8%)
Peak output token throughput (tok/s) 1552.00 1584.00 ✅ Triton (+2.1%)
Latency
Mean TTFT (ms) 760.64 741.30 ✅ Triton (-2.5%)
Median TTFT (ms) 592.70 584.92 ✅ Triton (-1.3%)
P99 TTFT (ms) 3230.87 3174.35 ✅ Triton (-1.7%)
Mean TPOT (ms) 11.41 11.24 ✅ Triton (-1.5%)
Median TPOT (ms) 11.34 11.17 ✅ Triton (-1.5%)
P99 TPOT (ms) 12.57 12.64 ✅ Torch Compile (+0.6%)
Mean ITL (ms) 11.42 11.26 ✅ Triton (-1.4%)
Median ITL (ms) 10.82 10.63 ✅ Triton (-1.8%)
Duration
Benchmark duration (s) 65.98 64.90 ✅ Triton (-1.6%)


# FusedMoE experts with custom Gemma4 routing
self.experts = FusedMoE(
Expand Down
Loading