Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion flashinfer/fused_moe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,8 @@ def forward(
hidden_states,
*extra_inputs,
) = inputs
if kwargs.get("skip_routing", False):
routing_logits = None
num_tokens = hidden_states.shape[0]

extra_input_idx = 0
Expand Down Expand Up @@ -1836,7 +1838,12 @@ def trtllm_fp4_block_scale_moe_op(
)
inputs = [
output,
torch.empty(num_tokens, num_experts, dtype=routing_dtype, device="meta")
torch.empty(
num_tokens,
num_experts,
dtype=routing_dtype,
device=hidden_states.device,
)
if routing_logits is None
else routing_logits,
topk_ids,
Expand All @@ -1851,6 +1858,7 @@ def trtllm_fp4_block_scale_moe_op(
[moe_runner],
tunning_config,
inputs,
skip_routing=(routing_logits is None),
num_experts=num_experts,
routing_bias=routing_bias,
gemm1_weights=gemm1_weights,
Expand Down
94 changes: 94 additions & 0 deletions tests/autotuner/test_trtllm_fused_moe_autotuner_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,100 @@ def test_bf16_moe_all_supported_tile_n_inference_succeed(
)


@pytest.mark.parametrize("num_tokens", [1, 16])
@pytest.mark.parametrize("num_experts", [16])
@pytest.mark.parametrize("top_k", [4])
def test_fp4_routed_moe_autotune_no_crash(
    num_tokens: int,
    num_experts: int,
    top_k: int,
):
    """Regression test for autotuning trtllm_fp4_block_scale_routed_moe.

    Prior to the fix, autotuning built a meta-device placeholder for
    routing_logits and handed it to the C++ kernel through TVM FFI, which
    failed with 'Cannot pack tensors on meta'. This test only checks that
    the tuned call completes without raising.
    """
    _require_sm100()
    reset_autotuner()
    dev = torch.device("cuda:0")

    from flashinfer.fused_moe import trtllm_fp4_block_scale_routed_moe

    hidden_size = 3072
    intermediate_size = 3072

    # Pack (expert_id, bf16 weight bits) into a single int32 per top-k slot:
    # expert id in the high 16 bits, raw bf16 weight bits in the low 16 bits.
    expert_ids = torch.randint(0, num_experts, (num_tokens, top_k), device=dev)
    expert_weights = torch.randn(
        num_tokens, top_k, dtype=torch.bfloat16, device=dev
    )
    weight_bits = expert_weights.view(torch.int16).to(torch.int32)
    packed_topk_ids = (expert_ids.to(torch.int32) << 16) | weight_bits

    activations = torch.randn(
        num_tokens, hidden_size, dtype=torch.bfloat16, device=dev
    )
    # FP4 weights are stored two values per uint8 byte, hence the // 2 on the
    # innermost dim; block scales use one fp8 scale per 16 packed values.
    w13 = torch.empty(
        num_experts,
        intermediate_size * 2,
        hidden_size // 2,
        dtype=torch.uint8,
        device=dev,
    )
    w13_scale = torch.empty(
        num_experts,
        intermediate_size * 2,
        hidden_size // 2 // 16,
        dtype=torch.float8_e4m3fn,
        device=dev,
    )
    w2 = torch.empty(
        num_experts,
        hidden_size,
        intermediate_size // 2,
        dtype=torch.uint8,
        device=dev,
    )
    w2_scale = torch.empty(
        num_experts,
        hidden_size,
        intermediate_size // 2 // 16,
        dtype=torch.float8_e4m3fn,
        device=dev,
    )
    result = torch.empty(
        num_tokens, hidden_size, dtype=torch.bfloat16, device=dev
    )

    with autotune(tune_mode=True):
        trtllm_fp4_block_scale_routed_moe(
            topk_ids=packed_topk_ids,
            routing_bias=None,
            hidden_states=activations,
            hidden_states_scale=None,
            gemm1_weights=w13,
            gemm1_weights_scale=w13_scale,
            gemm1_bias=None,
            gemm1_alpha=None,
            gemm1_beta=None,
            gemm1_clamp_limit=None,
            gemm2_weights=w2,
            gemm2_weights_scale=w2_scale,
            gemm2_bias=None,
            output1_scale_scalar=None,
            output1_scale_gate_scalar=None,
            output2_scale_scalar=None,
            num_experts=num_experts,
            top_k=top_k,
            n_group=None,
            topk_group=None,
            intermediate_size=intermediate_size,
            local_expert_offset=0,
            local_num_experts=num_experts,
            routed_scaling_factor=None,
            routing_method_type=1,
            do_finalize=True,
            output=result,
            tune_max_num_tokens=1,
        )


@pytest.mark.parametrize(
"invalid_tactic",
[
Expand Down
Loading