Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ void run(Data& data, void* stream) {
int const numBlocksCoop = smCount - 8;

int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
if (data.mPtrTopKIds == nullptr) {
if (data.mPtrScores != nullptr) {
FLASHINFER_CHECK(data.mNumExperts >= MaxSupportedTopExperts,
"Routing kernel expects %d to be at most #experts %d", MaxSupportedTopExperts,
data.mNumExperts);
Expand Down
104 changes: 81 additions & 23 deletions tests/moe/test_trtllm_gen_routed_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .test_trtllm_gen_fused_moe import (
FP8BlockScaleMoe,
QuantMode,
routing_reference_no_aux,
routing_reference_renormalize,
routing_reference_renormalize_naive,
routing_reference_topk,
Expand Down Expand Up @@ -405,37 +406,38 @@ def test_trtllm_gen_fp8_routed_fused_moe(
assert mismatch_pct < 10, f"Mismatch percentage is {mismatch_pct:.2f}%"


@pytest.mark.parametrize("num_tokens", [8, 64])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [8, 16])
@pytest.mark.parametrize("top_k", [2, 4])
@pytest.mark.parametrize(
"routing_method_type",
[
RoutingMethodType.Renormalize,
],
)
def test_trtllm_gen_bf16_routed_fused_moe(
def run_bf16_routed_equivalence_test(
num_tokens: int,
hidden_size: int,
intermediate_size: int,
top_k: int,
num_experts: int,
routing_method_type: RoutingMethodType,
routed_scaling_factor: float | None = None,
routing_bias: torch.Tensor | None = None,
n_group: int | None = None,
topk_group: int | None = None,
seed: int = 42,
):
"""Test Bf16 scale routed MoE matches standard routing."""
compute_capability = get_compute_capability(torch.device(device="cuda"))
if compute_capability[0] not in [10]:
pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
torch.manual_seed(42)
torch.manual_seed(seed)
device = torch.device("cuda:0")
enable_pdl = device_support_pdl(device)

# Generate random routing logits for reference
routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
torch.bfloat16
)
if routing_method_type == RoutingMethodType.DeepSeekV3:
routing_logits = torch.randn(num_tokens, num_experts, device=device).to(
torch.float
)
if routing_bias is None:
routing_bias = torch.randn(num_experts, device=device, dtype=torch.bfloat16)
else:
routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
torch.bfloat16
)

# Generate random hidden states in FP8
hidden_states = (
Expand Down Expand Up @@ -464,18 +466,18 @@ def test_trtllm_gen_bf16_routed_fused_moe(
# Run reference with routing_logits
reference_output = trtllm_bf16_moe(
routing_logits=routing_logits,
routing_bias=None,
routing_bias=routing_bias,
hidden_states=hidden_states,
gemm1_weights=gemm1_weights,
gemm2_weights=gemm2_weights,
num_experts=num_experts,
top_k=top_k,
n_group=None,
topk_group=None,
n_group=n_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=0,
local_num_experts=num_experts,
routed_scaling_factor=None,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=routing_method_type.value,
use_shuffled_weight=True,
weight_layout=WeightLayout.BlockMajorK,
Expand All @@ -496,6 +498,18 @@ def test_trtllm_gen_bf16_routed_fused_moe(
permute_info, expert_weights_ref = routing_reference_topk(
routing_logits, top_k, num_experts, 8
)
elif routing_method_type == RoutingMethodType.DeepSeekV3:
permute_info, expert_weights_ref = routing_reference_no_aux(
routing_logits,
routing_bias,
top_k,
n_group,
topk_group,
routed_scaling_factor,
8,
)
else:
raise NotImplementedError(f"Unsupported routing method: {routing_method_type}")
topk_ids = permute_info["topKIndices"].to(torch.int32)
expert_weights = expert_weights_ref.view(num_tokens, num_experts)[
torch.arange(num_tokens, device=device).unsqueeze(1), topk_ids
Expand All @@ -515,12 +529,12 @@ def test_trtllm_gen_bf16_routed_fused_moe(
gemm2_weights=gemm2_weights,
num_experts=num_experts,
top_k=top_k,
n_group=None,
topk_group=None,
n_group=n_group,
topk_group=topk_group,
intermediate_size=intermediate_size,
local_expert_offset=0,
local_num_experts=num_experts,
routed_scaling_factor=None,
routed_scaling_factor=routed_scaling_factor,
routing_method_type=routing_method_type.value,
use_shuffled_weight=True,
weight_layout=WeightLayout.BlockMajorK,
Expand Down Expand Up @@ -700,3 +714,47 @@ def test_trtllm_gen_fp8_mxfp8_routed_activation_parity(activation_type: int):
close = torch.isclose(output_ref, output_routed, atol=1e-2, rtol=1e-2)
mismatch_pct = (~close).float().mean().item() * 100
assert mismatch_pct < 10, f"Mismatch percentage is {mismatch_pct:.2f}%"


@pytest.mark.parametrize("num_tokens", [8, 64])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [8, 16])
@pytest.mark.parametrize("top_k", [2, 4])
@pytest.mark.parametrize(
    "routing_method_type",
    [
        RoutingMethodType.Renormalize,
    ],
)
def test_trtllm_gen_bf16_routed_fused_moe(
    num_tokens: int,
    hidden_size: int,
    intermediate_size: int,
    top_k: int,
    num_experts: int,
    routing_method_type: RoutingMethodType,
):
    """BF16 routed-MoE vs. reference routing, for Renormalize routing.

    Thin pytest wrapper: all of the actual setup, kernel launch, and
    output comparison live in ``run_bf16_routed_equivalence_test``.
    """
    # Collect the sweep parameters and forward them unchanged; the helper's
    # remaining arguments (bias, groups, scaling factor, seed) keep their
    # defaults for this routing method.
    sweep_kwargs = {
        "num_tokens": num_tokens,
        "hidden_size": hidden_size,
        "intermediate_size": intermediate_size,
        "top_k": top_k,
        "num_experts": num_experts,
        "routing_method_type": routing_method_type,
    }
    run_bf16_routed_equivalence_test(**sweep_kwargs)


@pytest.mark.parametrize(
    "num_tokens, hidden_size, intermediate_size, top_k, num_experts, n_group, topk_group",
    [
        (1, 128, 128, 1, 32, 1, 1),  # original single hardcoded case
        (64, 256, 256, 4, 64, 1, 1),  # larger batch, still a single group
        (8, 1024, 512, 8, 256, 8, 4),  # DeepSeek-V3-like grouped config
    ],
)
def test_trtllm_gen_bf16_routed_fused_moe_deepseek_routing(
    num_tokens: int,
    hidden_size: int,
    intermediate_size: int,
    top_k: int,
    num_experts: int,
    n_group: int,
    topk_group: int,
):
    """BF16 routed-MoE vs. reference routing, for DeepSeekV3 (no-aux) routing.

    Parameterized over token count, problem sizes, and — importantly —
    ``n_group``/``topk_group`` so both the trivial (single-group) and the
    grouped top-k selection paths are exercised, instead of one hardcoded
    configuration.
    """
    run_bf16_routed_equivalence_test(
        num_tokens=num_tokens,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        top_k=top_k,
        num_experts=num_experts,
        routing_method_type=RoutingMethodType.DeepSeekV3,
        # DeepSeekV3 routing applies a scaling factor to the routed weights;
        # keep it non-trivial (!= 1.0) so a dropped factor would be caught.
        routed_scaling_factor=2.0,
        n_group=n_group,
        topk_group=topk_group,
        seed=0,
    )
Comment on lines +748 to +760
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This test for DeepSeek routing uses a single set of hardcoded parameters. To improve test coverage and ensure the fix is robust across different scenarios, consider parameterizing this test to cover a wider range of inputs. This is especially important for n_group and topk_group to test different grouping strategies.

@pytest.mark.parametrize(
    "num_tokens, hidden_size, intermediate_size, top_k, num_experts, n_group, topk_group",
    [
        (1, 128, 128, 1, 32, 1, 1),  # Original case
        (64, 256, 256, 4, 64, 1, 1),
        (8, 1024, 512, 8, 256, 8, 4),  # DSv3-like config
    ],
)
def test_trtllm_gen_bf16_routed_fused_moe_deepseek_routing(
    num_tokens: int,
    hidden_size: int,
    intermediate_size: int,
    top_k: int,
    num_experts: int,
    n_group: int,
    topk_group: int,
):
    run_bf16_routed_equivalence_test(
        num_tokens=num_tokens,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        top_k=top_k,
        num_experts=num_experts,
        routing_method_type=RoutingMethodType.DeepSeekV3,
        routed_scaling_factor=2.0,
        n_group=n_group,
        topk_group=topk_group,
        seed=0,
    )

Loading