Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions tests/utils/test_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,60 @@ def test_algorithms_with_large_k(algo, set_topk_algo):
assert accuracy >= 0.98, f"Algorithm {algo}: Accuracy {accuracy:.4f} < 0.98"


@pytest.mark.parametrize("num_rows", [4, 8])
@pytest.mark.parametrize("top_k", [256, 2048])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_ragged_transform_multi_cta_short_rows(num_rows, top_k, dtype):
    """Regression test for uint32 underflow in multi-CTA chunk_size calculation.

    Mixes short and long rows so that, on the forced multi-CTA path, CTAs
    beyond the first see chunk_start > row length, which previously
    underflowed an unsigned chunk-size computation.
    """
    torch.manual_seed(42)
    device = "cuda"

    max_len = 131072

    # Force multi_cta path so the test exercises the vulnerable code path
    # regardless of the heuristic. Save the prior override so it can be
    # restored even if the test fails.
    old_algo = os.environ.get("FLASHINFER_TOPK_ALGO", None)
    os.environ["FLASHINFER_TOPK_ALGO"] = "multi_cta"

    try:
        scores = torch.randn(num_rows, max_len, device=device, dtype=dtype)
        offsets = torch.zeros(num_rows, device=device, dtype=torch.int32)

        # Mix short and long rows. Short rows (4K-8K) are well below chunk_size
        # on any GPU, so CTAs beyond the first will have chunk_start > length.
        lengths_list = []
        for i in range(num_rows):
            if i % 2 == 0:
                lengths_list.append(max_len)
            else:
                lengths_list.append(torch.randint(4000, 8000, (1,)).item())
        lengths = torch.tensor(lengths_list, device=device, dtype=torch.int32)

        output = flashinfer.top_k_ragged_transform(scores, offsets, lengths, top_k)
        ref_output = reference_ragged_transform(scores, offsets, lengths, top_k)

        assert output.shape == (num_rows, top_k)
        assert output.dtype == torch.int32

        accuracy = compute_transform_accuracy(output, ref_output, num_rows, top_k)
        min_accuracy = 0.90
        assert accuracy >= min_accuracy, f"Accuracy {accuracy:.4f} < {min_accuracy}"

        # Verify indices stay within [offset, offset + length) for each row.
        # Negative entries are padding and are filtered out before the check.
        for i in range(num_rows):
            length = lengths[i].item()
            row_out = output[i]
            valid = row_out[row_out >= 0]
            # Guard against an empty `valid`: calling .max() on an empty
            # tensor raises RuntimeError, so only build the failure message
            # (and run the bounds check) when there are valid indices.
            if valid.numel() > 0:
                assert torch.all(valid < length), (
                    f"Row {i}: index out of bounds (max={valid.max().item()}, length={length})"
                )

    finally:
        # Restore the algorithm override so later tests see the original env.
        if old_algo is None:
            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
        else:
            os.environ["FLASHINFER_TOPK_ALGO"] = old_algo


@pytest.mark.parametrize("algo", ["auto", "multi_cta", "filtered"])
def test_bf16_long_seq_regression_across_algorithms(algo, set_topk_algo):
"""Regression for bf16 long-seq topk across algorithm overrides."""
Expand Down
Loading