diff --git a/tests/utils/test_topk.py b/tests/utils/test_topk.py
index 540e015621..094459cea1 100644
--- a/tests/utils/test_topk.py
+++ b/tests/utils/test_topk.py
@@ -1234,6 +1234,60 @@ def test_algorithms_with_large_k(algo, set_topk_algo):
     assert accuracy >= 0.98, f"Algorithm {algo}: Accuracy {accuracy:.4f} < 0.98"
 
 
+@pytest.mark.parametrize("num_rows", [4, 8])
+@pytest.mark.parametrize("top_k", [256, 2048])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+def test_ragged_transform_multi_cta_short_rows(num_rows, top_k, dtype):
+    """Regression test for uint32 underflow in multi-CTA chunk_size calculation."""
+    torch.manual_seed(42)
+    device = "cuda"
+
+    max_len = 131072
+
+    # Force multi_cta path so the test exercises the vulnerable code path
+    # regardless of the heuristic.
+    old_algo = os.environ.get("FLASHINFER_TOPK_ALGO", None)
+    os.environ["FLASHINFER_TOPK_ALGO"] = "multi_cta"
+
+    try:
+        scores = torch.randn(num_rows, max_len, device=device, dtype=dtype)
+        offsets = torch.zeros(num_rows, device=device, dtype=torch.int32)
+
+        # Mix short and long rows. Short rows (4K-8K) are well below chunk_size
+        # on any GPU, so CTAs beyond the first will have chunk_start > length.
+        lengths_list = []
+        for i in range(num_rows):
+            if i % 2 == 0:
+                lengths_list.append(max_len)
+            else:
+                lengths_list.append(torch.randint(4000, 8000, (1,)).item())
+        lengths = torch.tensor(lengths_list, device=device, dtype=torch.int32)
+
+        output = flashinfer.top_k_ragged_transform(scores, offsets, lengths, top_k)
+        ref_output = reference_ragged_transform(scores, offsets, lengths, top_k)
+
+        assert output.shape == (num_rows, top_k)
+        assert output.dtype == torch.int32
+
+        accuracy = compute_transform_accuracy(output, ref_output, num_rows, top_k)
+        min_accuracy = 0.90
+        assert accuracy >= min_accuracy, f"Accuracy {accuracy:.4f} < {min_accuracy}"
+
+        # Verify indices stay within [offset, offset + length) for each row
+        for i in range(num_rows):
+            length = lengths[i].item()
+            row_out = output[i]
+            valid = row_out[row_out >= 0]
+            assert torch.all(valid < length), (
+                f"Row {i}: index out of bounds (max={valid.max().item()}, length={length})"
+            )
+    finally:
+        if old_algo is None:
+            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
+        else:
+            os.environ["FLASHINFER_TOPK_ALGO"] = old_algo
+
+
 @pytest.mark.parametrize("algo", ["auto", "multi_cta", "filtered"])
 def test_bf16_long_seq_regression_across_algorithms(algo, set_topk_algo):
     """Regression for bf16 long-seq topk across algorithm overrides."""