diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py
index 122aad1cc6..8bcbe36381 100644
--- a/flashinfer/fused_moe/core.py
+++ b/flashinfer/fused_moe/core.py
@@ -955,6 +955,8 @@ def forward(
             hidden_states,
             *extra_inputs,
         ) = inputs
+        if kwargs.get("skip_routing", False):
+            routing_logits = None
         num_tokens = hidden_states.shape[0]
         extra_input_idx = 0
 
@@ -1836,7 +1838,12 @@ def trtllm_fp4_block_scale_moe_op(
     )
     inputs = [
         output,
-        torch.empty(num_tokens, num_experts, dtype=routing_dtype, device="meta")
+        torch.empty(
+            num_tokens,
+            num_experts,
+            dtype=routing_dtype,
+            device=hidden_states.device,
+        )
         if routing_logits is None
         else routing_logits,
         topk_ids,
@@ -1851,6 +1858,7 @@ def trtllm_fp4_block_scale_moe_op(
         [moe_runner],
         tunning_config,
         inputs,
+        skip_routing=(routing_logits is None),
         num_experts=num_experts,
         routing_bias=routing_bias,
         gemm1_weights=gemm1_weights,
diff --git a/tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py b/tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py
index 07df7e2e45..d6678b4d76 100644
--- a/tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py
+++ b/tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py
@@ -253,6 +253,100 @@ def test_bf16_moe_all_supported_tile_n_inference_succeed(
     )
 
 
+@pytest.mark.parametrize("num_tokens", [1, 16])
+@pytest.mark.parametrize("num_experts", [16])
+@pytest.mark.parametrize("top_k", [4])
+def test_fp4_routed_moe_autotune_no_crash(
+    num_tokens: int,
+    num_experts: int,
+    top_k: int,
+):
+    """Regression test: trtllm_fp4_block_scale_routed_moe must not crash during
+    autotuning. Before the fix, the autotuner received a meta-device placeholder
+    for routing_logits and passed it to the C++ kernel via TVM FFI, which raised
+    'Cannot pack tensors on meta'.
+    """
+    _require_sm100()
+    reset_autotuner()
+    device = torch.device("cuda:0")
+
+    from flashinfer.fused_moe import trtllm_fp4_block_scale_routed_moe
+
+    hidden_size = 3072
+    intermediate_size = 3072
+
+    topk_ids = torch.randint(0, num_experts, (num_tokens, top_k), device=device)
+    topk_weights = torch.randn(num_tokens, top_k, dtype=torch.bfloat16, device=device)
+    packed_topk_ids = (topk_ids.to(torch.int32) << 16) | topk_weights.view(
+        torch.int16
+    ).to(torch.int32)
+
+    hidden_states = torch.randn(
+        num_tokens, hidden_size, dtype=torch.bfloat16, device=device
+    )
+    gemm1_weights = torch.empty(
+        num_experts,
+        intermediate_size * 2,
+        hidden_size // 2,
+        dtype=torch.uint8,
+        device=device,
+    )
+    gemm1_weights_scale = torch.empty(
+        num_experts,
+        intermediate_size * 2,
+        hidden_size // 2 // 16,
+        dtype=torch.float8_e4m3fn,
+        device=device,
+    )
+    gemm2_weights = torch.empty(
+        num_experts,
+        hidden_size,
+        intermediate_size // 2,
+        dtype=torch.uint8,
+        device=device,
+    )
+    gemm2_weights_scale = torch.empty(
+        num_experts,
+        hidden_size,
+        intermediate_size // 2 // 16,
+        dtype=torch.float8_e4m3fn,
+        device=device,
+    )
+    output = torch.empty(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)
+
+    with autotune(tune_mode=True):
+        trtllm_fp4_block_scale_routed_moe(
+            topk_ids=packed_topk_ids,
+            routing_bias=None,
+            hidden_states=hidden_states,
+            hidden_states_scale=None,
+            gemm1_weights=gemm1_weights,
+            gemm1_weights_scale=gemm1_weights_scale,
+            gemm1_bias=None,
+            gemm1_alpha=None,
+            gemm1_beta=None,
+            gemm1_clamp_limit=None,
+            gemm2_weights=gemm2_weights,
+            gemm2_weights_scale=gemm2_weights_scale,
+            gemm2_bias=None,
+            output1_scale_scalar=None,
+            output1_scale_gate_scalar=None,
+            output2_scale_scalar=None,
+            num_experts=num_experts,
+            top_k=top_k,
+            n_group=None,
+            topk_group=None,
+            intermediate_size=intermediate_size,
+            local_expert_offset=0,
+            local_num_experts=num_experts,
+            routed_scaling_factor=None,
+            routing_method_type=1,
+            do_finalize=True,
+            output=output,
+            tune_max_num_tokens=1,
+        )
+
+
 @pytest.mark.parametrize(
     "invalid_tactic",
     [