diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py index 2e83a78a6d..b787b27c9a 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/moe_dispatch.py @@ -186,7 +186,9 @@ def allocate_sm120_static_workspace( active_expert_count=torch.zeros(1, dtype=torch.int32, device=device), weight_expert_ids=torch.arange(state_E, dtype=torch.int32, device=device), global_to_local_expert=torch.empty(weight_E, dtype=torch.int32, device=device), - compact_topk_ids=torch.empty(state_E, dtype=torch.int32, device=device), + compact_topk_ids=torch.empty( + max(state_E, max_rows), dtype=torch.int32, device=device + ), ) # Finalize views diff --git a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py index 4c8af5fbfa..deecb29d36 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/triton_compact.py @@ -68,8 +68,8 @@ def compact_topk_ids( return if compact_topk_ids.numel() < total_pairs: raise ValueError("compact_topk_ids must have at least total_pairs elements") - if weight_expert_ids.numel() < total_pairs: - raise ValueError("weight_expert_ids must have at least total_pairs elements") + # weight_expert_ids writes at indices 0..active_expert_count-1 (bounded by + # the number of local experts, not total_pairs), so no size check is needed here. if active_expert_count.numel() != 1: raise ValueError("active_expert_count must have shape [1]") diff --git a/tests/moe/test_b12x_fused_moe.py b/tests/moe/test_b12x_fused_moe.py index 087426f25e..8453395037 100644 --- a/tests/moe/test_b12x_fused_moe.py +++ b/tests/moe/test_b12x_fused_moe.py @@ -1067,6 +1067,78 @@ def test_micro_single_token_unique_path(self): f"(atol={atol:.4f})" ) + @pytest.mark.parametrize( + "num_tokens,top_k,num_experts", + [ + (2, 8, 8), # total_pairs=16 > num_local_experts=8 + (4, 8, 16), # total_pairs=32 > num_local_experts=16 + (4, 4, 8), # total_pairs=16 > num_local_experts=8 + ], + ) + def test_micro_pairs_exceed_local_experts( + self, num_tokens: int, top_k: int, num_experts: int + ): + """Regression test: micro kernel when num_tokens * top_k > num_local_experts. + + The workspace compact_topk_ids buffer was previously sized state_E + (num_local_experts), but the micro kernel fills it with total_pairs = + num_tokens * top_k. When total_pairs > num_local_experts the assertion + 'flat_ids.numel() <= workspace.compact_topk_ids.numel()' fired. + + Fixed by sizing compact_topk_ids as max(state_E, max_rows). + """ + from flashinfer import b12x_fused_moe + + hidden_size, intermediate_size = 256, 512 + + tensors = create_moe_tensors( + num_tokens=num_tokens, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_experts=num_experts, + num_local_experts=num_experts, + top_k=top_k, + ) + + result = b12x_fused_moe( + x=tensors["x_bf16"], + w1_weight=tensors["w1_weight"], + w1_weight_sf=tensors["w1_weight_sf"], + w1_alpha=tensors["w1_alpha"], + fc2_input_scale=tensors["fc2_input_scale"], + w2_weight=tensors["w2_weight"], + w2_weight_sf=tensors["w2_weight_sf"], + w2_alpha=tensors["w2_alpha"], + token_selected_experts=tensors["token_selected_experts"], + token_final_scales=tensors["token_final_scales"], + num_experts=num_experts, + top_k=top_k, + ) + + assert result.shape == (num_tokens, hidden_size) + assert not torch.isnan(result).any() + assert not torch.isinf(result).any() + + ref_output = compute_reference_moe_fp4( + hidden_states=tensors["x_bf16"].float().cuda(), + gemm1_weights=tensors["w1_weight_bf16"].float().cuda(), + gemm2_weights=tensors["w2_weight_bf16"].float().cuda(), + token_selected_experts=tensors["token_selected_experts"], + token_final_scales=tensors["token_final_scales"], + num_tokens=num_tokens, + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + fc2_input_scale=tensors["fc2_input_scale"], + ) + + passed, percent_within, atol = check_accuracy(result, ref_output) + assert passed, ( + f"Micro pairs>experts: {percent_within * 100:.2f}% within tolerance " + f"(atol={atol:.4f}, tokens={num_tokens}, top_k={top_k}, experts={num_experts})" + ) + # ============================================================================= # Test Class: ReLU2 Activation (SM120-only, non-gated)