diff --git a/tests/kernels/moe/test_fused_topk.py b/tests/kernels/moe/test_fused_topk.py index a0e3580ee5a0..825cd20263d7 100644 --- a/tests/kernels/moe/test_fused_topk.py +++ b/tests/kernels/moe/test_fused_topk.py @@ -202,3 +202,72 @@ def test_fused_topk_nan_inf_clamp( f"Row {row} has non-finite weights {topk_weights[row].tolist()} " f"(bad_value={bad_value}, scoring_func={scoring_func})" ) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." +) +@pytest.mark.parametrize("num_experts", [6, 8, 16]) +@pytest.mark.parametrize("topk", [3, 4]) +@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) +@pytest.mark.parametrize("bad_value", [float("nan"), float("inf")]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32]) +def test_fused_topk_bias_nan_inf_clamp( + num_experts: int, + topk: int, + scoring_func: str, + bad_value: float, + dtype: torch.dtype, +): + """Regression test: NaN/Inf in gating logits must not produce duplicate + expert IDs or non-finite weights when e_score_correction_bias is present. + + Same scenario as test_fused_topk_nan_inf_clamp but exercising the bias + path (fused_topk_bias) so the fix in topk_softmax_kernels.cu is covered + for that entry point as well. + """ + torch.manual_seed(0) + num_tokens = 4 + hidden_size = 1024 + hidden_states = torch.randn((num_tokens, hidden_size), dtype=dtype, device="cuda") + e_score_correction_bias = torch.randn( + (num_experts,), dtype=torch.float32, device="cuda" + ) + + gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda") + gating_output[1:, :] = bad_value + + topk_weights, topk_ids = fused_topk_bias( + hidden_states=hidden_states, + gating_output=gating_output, + e_score_correction_bias=e_score_correction_bias, + topk=topk, + renormalize=False, + scoring_func=scoring_func, + ) + + # Normal row must still match the torch reference. + ref_weights, ref_ids = torch_topk( + gating_output=gating_output[:1], + topk=topk, + renormalize=False, + e_score_correction_bias=e_score_correction_bias, + scoring_func=scoring_func, + ) + torch.testing.assert_close( + ref_weights.to(torch.float32), topk_weights[:1], atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close(ref_ids.to(torch.int32), topk_ids[:1], atol=0, rtol=0) + + # Poisoned rows: IDs must be unique (no duplicates) and weights must be + # finite (no NaN/Inf propagation into downstream MoE kernels). + for row in range(1, num_tokens): + row_ids = topk_ids[row] + assert row_ids.unique().numel() == topk, ( + f"Row {row} has duplicate expert IDs {row_ids.tolist()} " + f"(bad_value={bad_value}, scoring_func={scoring_func})" + ) + assert torch.isfinite(topk_weights[row]).all(), ( + f"Row {row} has non-finite weights {topk_weights[row].tolist()} " + f"(bad_value={bad_value}, scoring_func={scoring_func})" + )