Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 66 additions & 4 deletions python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@
if _is_cuda:
from sgl_kernel import moe_fused_gate

try:
from flashinfer.fused_moe import fused_topk_deepseek
except ImportError:
fused_topk_deepseek = None

try:
from sgl_kernel import kimi_k2_moe_fused_gate
except ImportError as e:
Expand Down Expand Up @@ -732,12 +737,68 @@ def biased_grouped_topk_gpu(
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
# TODO: moe_fused_gate kernel is not supported for num_fused_shared_experts > 0 now.

num_tokens = gating_output.shape[0]
num_experts = gating_output.shape[1]
experts_per_group = (
num_experts // num_expert_group if num_expert_group else num_experts
)

if (
_is_cuda
and gating_output.shape[1] // num_expert_group
<= 32 # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
and is_power_of_two(correction_bias.shape[0])
and fused_topk_deepseek is not None
and num_fused_shared_experts == 0
and is_power_of_two(num_experts)
# flashinfer constraints
and topk <= 8
and topk_group <= num_expert_group
and topk_group * num_expert_group >= topk
and (
(experts_per_group <= 32 and experts_per_group * topk_group <= 128)
if num_expert_group > 1
else num_experts <= 384
)
):
# Pre-allocate output tensors (flashinfer mutates them in-place)
topk_weights = torch.empty(
(num_tokens, topk), dtype=torch.float32, device=gating_output.device
)
topk_ids = torch.empty(
(num_tokens, topk), dtype=torch.int32, device=gating_output.device
)

# flashinfer always applies the scaling_factor internally
scaling_factor = 1.0
if routed_scaling_factor is not None and apply_routed_scaling_factor_on_output:
scaling_factor = routed_scaling_factor

# flashinfer's fused_topk_deepseek
fused_topk_deepseek(
gating_output.to(dtype=torch.float32),
correction_bias,
num_expert_group,
topk_group,
topk,
scaling_factor,
topk_weights,
topk_ids,
True,
)

if (expert_location_dispatch_info is not None) or (
num_token_non_padded is not None
):
topk_ids = _biased_grouped_topk_postprocess(
topk_ids, expert_location_dispatch_info, num_token_non_padded
)
return topk_weights, topk_ids

elif (
_is_cuda
and num_fused_shared_experts == 0
# moe_fused_gate kernel ensures that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
and experts_per_group <= 32
and is_power_of_two(num_experts)
):
topk_weights, topk_ids = moe_fused_gate(
gating_output.to(dtype=torch.float32),
Expand All @@ -757,6 +818,7 @@ def biased_grouped_topk_gpu(
topk_ids, expert_location_dispatch_info, num_token_non_padded
)
return topk_weights, topk_ids

elif _use_aiter:
assert not apply_routed_scaling_factor_on_output, "Not implemented"
token = gating_output.shape[0]
Expand Down
97 changes: 97 additions & 0 deletions test/registered/kernels/test_fused_topk_deepseek.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import pytest
import torch

from sglang.srt.layers.moe.topk import biased_grouped_topk_gpu, biased_grouped_topk_impl
from sglang.test.ci.ci_register import register_cuda_ci

register_cuda_ci(est_time=2, suite="nightly-1-gpu", nightly=True)


@pytest.mark.parametrize(
"seq_length",
list(range(1, 10))
+ [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536],
)
@pytest.mark.parametrize(
"params",
[
(128, 4, 2, 4), # 128 experts configuration
(256, 8, 4, 8), # DeepSeek V3 config - most important to test
(64, 2, 2, 4), # Smaller configuration
],
)
@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [False, True])
def test_fused_topk_deepseek(seq_length, params, apply_routed_scaling_factor_on_output):
"""
Test the fused_topk_deepseek code path in biased_grouped_topk_gpu.
"""
num_experts, num_expert_group, topk_group, topk = params
dtype = torch.float32

torch.manual_seed(seq_length)
hidden_states = torch.randn(seq_length, 128, dtype=dtype, device="cuda")
gating_output = torch.randn(seq_length, num_experts, dtype=dtype, device="cuda")
correction_bias = torch.randn(num_experts, dtype=dtype, device="cuda")

routed_scaling_factor = 2.5 if apply_routed_scaling_factor_on_output else None

# Fused implementation (uses fused_topk_deepseek when conditions are met)
output, indices = biased_grouped_topk_gpu(
hidden_states,
gating_output,
correction_bias,
topk=topk,
renormalize=True,
num_expert_group=num_expert_group,
topk_group=topk_group,
num_fused_shared_experts=0,
routed_scaling_factor=routed_scaling_factor,
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
)

# Reference implementation (pure PyTorch)
ref_output, ref_indices = biased_grouped_topk_impl(
hidden_states,
gating_output,
correction_bias,
topk=topk,
renormalize=True,
num_expert_group=num_expert_group,
topk_group=topk_group,
num_fused_shared_experts=0,
routed_scaling_factor=routed_scaling_factor,
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
)

# Check 1: Row-wise sums should match (invariant to tie-breaking)
output_sum = output.sum(dim=-1)
ref_output_sum = ref_output.sum(dim=-1)
sum_check = torch.allclose(output_sum, ref_output_sum, rtol=1e-03, atol=1e-04)

# Check 2: Scatter-based comparison with allowance for tie-breaking
res = torch.zeros(seq_length, num_experts, dtype=torch.float32, device="cuda")
ref = torch.zeros(seq_length, num_experts, dtype=torch.float32, device="cuda")

res.scatter_(1, indices.long(), output)
ref.scatter_(1, ref_indices.long(), ref_output)

diff = torch.abs(ref - res)
atol = (
5e-03
if (seq_length >= 4096 and apply_routed_scaling_factor_on_output)
else 1e-03
)
num_large_diffs = (diff > atol).sum().item()

# Allow a small number of differences for tie-breaking situations
max_allowed_diffs = max(16, seq_length // 500)
scatter_check = num_large_diffs <= max_allowed_diffs

assert sum_check and scatter_check, (
f"Output mismatch at seq_length {seq_length}, params {params}, "
f"apply_routed_scaling_factor_on_output {apply_routed_scaling_factor_on_output}"
)


if __name__ == "__main__":
pytest.main([__file__])
10 changes: 2 additions & 8 deletions test/srt/test_deepseek_v3_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,7 @@ def test_a_gsm8k(
f"{avg_spec_accept_length=:.2f}\n"
)
self.assertGreater(metrics["accuracy"], 0.935)
if is_in_amd_ci():
self.assertGreater(avg_spec_accept_length, 2.8)
else:
self.assertGreater(avg_spec_accept_length, 2.9)
self.assertGreater(avg_spec_accept_length, 2.8)

def test_bs_1_speed(self):
args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048)
Expand All @@ -99,10 +96,7 @@ def test_bs_1_speed(self):
f"{acc_length=:.2f}\n"
f"{speed=:.2f} token/s\n"
)
if is_in_amd_ci():
self.assertGreater(acc_length, 2.8)
else:
self.assertGreater(acc_length, 2.9)
self.assertGreater(acc_length, 2.8)
if is_in_amd_ci():
self.assertGreater(speed, 15)
else:
Expand Down
Loading