diff --git a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu
index 4ad38feb26..efd81af375 100644
--- a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu
+++ b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu
@@ -60,7 +60,7 @@ void run(Data& data, void* stream) {
   int const numBlocksCoop = smCount - 8;
 
   int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
-  if (data.mPtrTopKIds == nullptr) {
+  if (data.mPtrScores != nullptr) {
     FLASHINFER_CHECK(data.mNumExperts >= MaxSupportedTopExperts,
                      "Routing kernel expects %d to be at most #experts %d", MaxSupportedTopExperts,
                      data.mNumExperts);
diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py
index 49372e2173..fe33ed4137 100644
--- a/tests/moe/test_trtllm_gen_routed_fused_moe.py
+++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py
@@ -43,6 +43,7 @@
 from .test_trtllm_gen_fused_moe import (
     FP8BlockScaleMoe,
     QuantMode,
+    routing_reference_no_aux,
     routing_reference_renormalize,
     routing_reference_renormalize_naive,
     routing_reference_topk,
@@ -405,37 +406,38 @@
     assert mismatch_pct < 10, f"Mismatch percentage is {mismatch_pct:.2f}%"
 
 
-@pytest.mark.parametrize("num_tokens", [8, 64])
-@pytest.mark.parametrize("hidden_size", [1024, 2048])
-@pytest.mark.parametrize("intermediate_size", [1024, 2048])
-@pytest.mark.parametrize("num_experts", [8, 16])
-@pytest.mark.parametrize("top_k", [2, 4])
-@pytest.mark.parametrize(
-    "routing_method_type",
-    [
-        RoutingMethodType.Renormalize,
-    ],
-)
-def test_trtllm_gen_bf16_routed_fused_moe(
+def run_bf16_routed_equivalence_test(
     num_tokens: int,
     hidden_size: int,
     intermediate_size: int,
     top_k: int,
     num_experts: int,
     routing_method_type: RoutingMethodType,
+    routed_scaling_factor: float | None = None,
+    routing_bias: torch.Tensor | None = None,
+    n_group: int | None = None,
+    topk_group: int | None = None,
+    seed: int = 42,
 ):
     """Test Bf16 scale routed MoE matches standard routing."""
     compute_capability = get_compute_capability(torch.device(device="cuda"))
     if compute_capability[0] not in [10]:
         pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
-    torch.manual_seed(42)
+    torch.manual_seed(seed)
     device = torch.device("cuda:0")
     enable_pdl = device_support_pdl(device)
 
     # Generate random routing logits for reference
-    routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
-        torch.bfloat16
-    )
+    if routing_method_type == RoutingMethodType.DeepSeekV3:
+        routing_logits = torch.randn(num_tokens, num_experts, device=device).to(
+            torch.float
+        )
+        if routing_bias is None:
+            routing_bias = torch.randn(num_experts, device=device, dtype=torch.bfloat16)
+    else:
+        routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
+            torch.bfloat16
+        )
 
     # Generate random hidden states in FP8
     hidden_states = (
@@ -464,18 +466,18 @@ def test_trtllm_gen_bf16_routed_fused_moe(
     # Run reference with routing_logits
     reference_output = trtllm_bf16_moe(
         routing_logits=routing_logits,
-        routing_bias=None,
+        routing_bias=routing_bias,
         hidden_states=hidden_states,
         gemm1_weights=gemm1_weights,
         gemm2_weights=gemm2_weights,
         num_experts=num_experts,
         top_k=top_k,
-        n_group=None,
-        topk_group=None,
+        n_group=n_group,
+        topk_group=topk_group,
         intermediate_size=intermediate_size,
         local_expert_offset=0,
         local_num_experts=num_experts,
-        routed_scaling_factor=None,
+        routed_scaling_factor=routed_scaling_factor,
         routing_method_type=routing_method_type.value,
         use_shuffled_weight=True,
         weight_layout=WeightLayout.BlockMajorK,
@@ -496,6 +498,18 @@ def test_trtllm_gen_bf16_routed_fused_moe(
         permute_info, expert_weights_ref = routing_reference_topk(
             routing_logits, top_k, num_experts, 8
         )
+    elif routing_method_type == RoutingMethodType.DeepSeekV3:
+        permute_info, expert_weights_ref = routing_reference_no_aux(
+            routing_logits,
+            routing_bias,
+            top_k,
+            n_group,
+            topk_group,
+            routed_scaling_factor,
+            8,
+        )
+    else:
+        raise NotImplementedError(f"Unsupported routing method: {routing_method_type}")
     topk_ids = permute_info["topKIndices"].to(torch.int32)
     expert_weights = expert_weights_ref.view(num_tokens, num_experts)[
         torch.arange(num_tokens, device=device).unsqueeze(1), topk_ids
@@ -515,12 +529,12 @@
         gemm2_weights=gemm2_weights,
         num_experts=num_experts,
         top_k=top_k,
-        n_group=None,
-        topk_group=None,
+        n_group=n_group,
+        topk_group=topk_group,
         intermediate_size=intermediate_size,
         local_expert_offset=0,
         local_num_experts=num_experts,
-        routed_scaling_factor=None,
+        routed_scaling_factor=routed_scaling_factor,
         routing_method_type=routing_method_type.value,
         use_shuffled_weight=True,
         weight_layout=WeightLayout.BlockMajorK,
@@ -700,3 +714,47 @@
     close = torch.isclose(output_ref, output_routed, atol=1e-2, rtol=1e-2)
     mismatch_pct = (~close).float().mean().item() * 100
     assert mismatch_pct < 10, f"Mismatch percentage is {mismatch_pct:.2f}%"
+
+
+@pytest.mark.parametrize("num_tokens", [8, 64])
+@pytest.mark.parametrize("hidden_size", [1024, 2048])
+@pytest.mark.parametrize("intermediate_size", [1024, 2048])
+@pytest.mark.parametrize("num_experts", [8, 16])
+@pytest.mark.parametrize("top_k", [2, 4])
+@pytest.mark.parametrize(
+    "routing_method_type",
+    [
+        RoutingMethodType.Renormalize,
+    ],
+)
+def test_trtllm_gen_bf16_routed_fused_moe(
+    num_tokens: int,
+    hidden_size: int,
+    intermediate_size: int,
+    top_k: int,
+    num_experts: int,
+    routing_method_type: RoutingMethodType,
+):
+    run_bf16_routed_equivalence_test(
+        num_tokens=num_tokens,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+        top_k=top_k,
+        num_experts=num_experts,
+        routing_method_type=routing_method_type,
+    )
+
+
+def test_trtllm_gen_bf16_routed_fused_moe_deepseek_routing():
+    run_bf16_routed_equivalence_test(
+        num_tokens=1,
+        hidden_size=128,
+        intermediate_size=128,
+        top_k=1,
+        num_experts=32,
+        routing_method_type=RoutingMethodType.DeepSeekV3,
+        routed_scaling_factor=2.0,
+        n_group=1,
+        topk_group=1,
+        seed=0,
+    )