[BugFix] guard against uint32 underflow in multi-CTA TopK chunk calculation #2592
Conversation
Summary of Changes

Hello @LopezCastroRoberto, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a critical uint32 underflow bug within the `RadixTopKKernel` when processing variable-length rows in multi-CTA mode. The fix ensures correct `chunk_size` calculation, preventing memory access violations. A new regression test has been added to validate the fix and improve the robustness of the `top_k_ragged_transform` functionality.

Highlights
Activity
Code Review
This pull request introduces a fix for a uint32 underflow bug in the multi-CTA TopK chunk calculation and adds a corresponding regression test. The test effectively simulates the scenario that triggers the bug by using a mix of long and short rows with the multi_cta algorithm forced. My review of the new test code includes suggestions to improve code quality, consistency, and robustness, such as using an existing fixture for environment variable management, leveraging a list comprehension for more concise code, and adding a guard to make an assertion more robust against different test parameters.
```python
old_algo = os.environ.get("FLASHINFER_TOPK_ALGO", None)
os.environ["FLASHINFER_TOPK_ALGO"] = "multi_cta"

try:
    scores = torch.randn(num_rows, max_len, device=device, dtype=dtype)
    offsets = torch.zeros(num_rows, device=device, dtype=torch.int32)

    # Mix short and long rows. Short rows (4K-8K) are well below chunk_size
    # on any GPU, so CTAs beyond the first will have chunk_start > length.
    lengths_list = []
    for i in range(num_rows):
        if i % 2 == 0:
            lengths_list.append(max_len)
        else:
            lengths_list.append(torch.randint(4000, 8000, (1,)).item())
    lengths = torch.tensor(lengths_list, device=device, dtype=torch.int32)

    output = flashinfer.top_k_ragged_transform(scores, offsets, lengths, top_k)
    ref_output = reference_ragged_transform(scores, offsets, lengths, top_k)

    assert output.shape == (num_rows, top_k)
    assert output.dtype == torch.int32

    accuracy = compute_transform_accuracy(output, ref_output, num_rows, top_k)
    min_accuracy = 0.90
    assert accuracy >= min_accuracy, f"Accuracy {accuracy:.4f} < {min_accuracy}"

    # Verify indices stay within [offset, offset + length) for each row
    for i in range(num_rows):
        length = lengths[i].item()
        row_out = output[i]
        valid = row_out[row_out >= 0]
        assert torch.all(valid < length), (
            f"Row {i}: index out of bounds (max={valid.max().item()}, length={length})"
        )
finally:
    if old_algo is None:
        os.environ.pop("FLASHINFER_TOPK_ALGO", None)
    else:
        os.environ["FLASHINFER_TOPK_ALGO"] = old_algo
```
For consistency with other tests in this file and to simplify the code, consider using the `set_topk_algo` fixture to manage the `FLASHINFER_TOPK_ALGO` environment variable. This removes the need for a manual `try...finally` block.

To apply this change:
- Add `set_topk_algo` to the test function's parameters on line 1233: `def test_ragged_transform_multi_cta_short_rows(num_rows, top_k, dtype, set_topk_algo):`
- Replace lines 1242-1245 with a single call to the fixture: `set_topk_algo("multi_cta")`.
- Remove the `finally` block on lines 1279-1283.
- Un-indent the code that is currently inside the `try` block (lines 1246-1278).
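The save-and-restore pattern such a fixture encapsulates can be sketched as a context manager. This is an illustrative stand-in only; the real `set_topk_algo` fixture in `tests/utils/test_topk.py` is not shown here and may differ in detail:

```python
import os
from contextlib import contextmanager


@contextmanager
def topk_algo(algo):
    # Save the previous value (if any), pin the algorithm for the
    # duration of the block, then restore or remove the variable.
    old = os.environ.get("FLASHINFER_TOPK_ALGO")
    os.environ["FLASHINFER_TOPK_ALGO"] = algo
    try:
        yield
    finally:
        if old is None:
            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
        else:
            os.environ["FLASHINFER_TOPK_ALGO"] = old
```

Centralizing this in one fixture or context manager guarantees teardown runs even when the test body raises, which is exactly what the manual `try...finally` in the diff above is doing by hand.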
tests/utils/test_topk.py
Outdated
```python
lengths_list = []
for i in range(num_rows):
    if i % 2 == 0:
        lengths_list.append(max_len)
    else:
        lengths_list.append(torch.randint(4000, 8000, (1,)).item())
```
```python
assert torch.all(valid < length), (
    f"Row {i}: index out of bounds (max={valid.max().item()}, length={length})"
)
```
Calling `.max()` on an empty tensor will raise a RuntimeError. While the current test parameters ensure `valid` is never empty, it's good practice to make the test more robust by adding a guard. This will prevent the test from failing unexpectedly if its parameters are changed in the future.

```python
if valid.numel() > 0:
    assert torch.all(valid < length), (
        f"Row {i}: index out of bounds (max={valid.max().item()}, length={length})"
    )
```

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
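The hazard the reviewer flags is a general one: reducing an empty collection raises. A torch-free sketch of the same guard, using a hypothetical helper (not from the test file) with plain Python builtins, where `max([])` raises ValueError just as `tensor.max()` errors on an empty tensor:

```python
def max_index_or_none(indices, length):
    """Return the largest out-of-bounds index, or None when no candidate
    exists. The emptiness check mirrors the reviewer's valid.numel() guard."""
    valid = [i for i in indices if i >= 0]
    if not valid:  # guard: max() on an empty list raises ValueError
        return None
    worst = max(valid)
    return worst if worst >= length else None
```

With the guard in place, rows whose output contains only sentinel (negative) indices report no violation instead of crashing the assertion message formatting.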
No actionable comments were generated in the recent review. 🎉

📝 Walkthrough: A new regression test was added.

Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks: 3 passed.
🧹 Nitpick comments (2)
tests/utils/test_topk.py (2)
1230-1232: Consider adding `torch.bfloat16` to the dtype parameterization.

The underflow is in `uint32_t` arithmetic and is dtype-independent, so `float32` + `float16` fully cover the regression. However, all analogous ragged-transform tests in this file (`test_top_k_ragged_transform`, `test_top_k_ragged_transform_out_of_length`) include `bfloat16`, and omitting it here creates a minor gap in consistency.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tests/utils/test_topk.py` around lines 1230-1232: add `torch.bfloat16` to the `@pytest.mark.parametrize("dtype", ...)` list, which currently contains `torch.float32` and `torch.float16`, so the test matrix matches the other ragged-transform tests and includes `bfloat16` for consistency.
1233-1282: Use the existing `set_topk_algo` fixture instead of manual env-var management.

The file already has a `set_topk_algo` fixture (lines 27–43) that handles exactly this `old_algo`/`try`/`finally` pattern. Using it removes ~10 lines of boilerplate and keeps teardown consistent with the rest of the suite.

♻️ Proposed refactor

```diff
-def test_ragged_transform_multi_cta_short_rows(num_rows, top_k, dtype):
+def test_ragged_transform_multi_cta_short_rows(num_rows, top_k, dtype, set_topk_algo):
     """Regression test for uint32 underflow in multi-CTA chunk_size calculation."""
     torch.manual_seed(42)
     device = "cuda"
     max_len = 131072
-    # Force multi_cta path so the test exercises the vulnerable code path
-    # regardless of the heuristic.
-    old_algo = os.environ.get("FLASHINFER_TOPK_ALGO", None)
-    os.environ["FLASHINFER_TOPK_ALGO"] = "multi_cta"
-
-    try:
-        scores = torch.randn(num_rows, max_len, device=device, dtype=dtype)
-        ...
-        assert torch.all(valid < length), (
-            f"Row {i}: index out of bounds (max={valid.max().item()}, length={length})"
-        )
-    finally:
-        if old_algo is None:
-            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
-        else:
-            os.environ["FLASHINFER_TOPK_ALGO"] = old_algo
+    # Force multi_cta path so the test exercises the vulnerable code path
+    # regardless of the heuristic.
+    set_topk_algo("multi_cta")
+
+    scores = torch.randn(num_rows, max_len, device=device, dtype=dtype)
+    ...
+    assert torch.all(valid < length), (
+        f"Row {i}: index out of bounds (max={valid.max().item()}, length={length})"
+    )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tests/utils/test_topk.py` around lines 1233-1282: replace the manual env-var management in `test_ragged_transform_multi_cta_short_rows` with the existing `set_topk_algo` fixture so `FLASHINFER_TOPK_ALGO` is set to "multi_cta" for the test and restored automatically; update the test function definition to accept the fixture, call `set_topk_algo("multi_cta")`, and delete the explicit `os.environ` manipulation and `finally` cleanup.
Hi @LopezCastroRoberto, this looks similar to #2489? #2489 has more guards than this PR, would you mind double-checking? Thanks for working on the unit test, btw.
Hi @LopezCastroRoberto, would you mind rebasing your PR onto main (#2489 was already merged) and keeping your unit test? We can still merge it (the test is helpful).
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/utils/test_topk.py (1)
1242-1281: Use the existing `set_topk_algo` fixture instead of duplicating env-var management.

The manual `old_algo` / `os.environ` / `try`/`finally` block (lines 1242-1281) replicates the `set_topk_algo` fixture verbatim. Accept it as a parameter and call `set_topk_algo("multi_cta")` instead.

♻️ Proposed refactor

```diff
-def test_ragged_transform_multi_cta_short_rows(num_rows, top_k, dtype):
+def test_ragged_transform_multi_cta_short_rows(num_rows, top_k, dtype, set_topk_algo):
     """Regression test for uint32 underflow in multi-CTA chunk_size calculation."""
     torch.manual_seed(42)
     device = "cuda"
     max_len = 131072
-    # Force multi_cta path so the test exercises the vulnerable code path
-    # regardless of the heuristic.
-    old_algo = os.environ.get("FLASHINFER_TOPK_ALGO", None)
-    os.environ["FLASHINFER_TOPK_ALGO"] = "multi_cta"
+    set_topk_algo("multi_cta")
 
-    try:
-        scores = torch.randn(num_rows, max_len, device=device, dtype=dtype)
-        ...
-        # Verify indices stay within [offset, offset + length) for each row
-        for i in range(num_rows):
-            ...
-    finally:
-        if old_algo is None:
-            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
-        else:
-            os.environ["FLASHINFER_TOPK_ALGO"] = old_algo
+    scores = torch.randn(num_rows, max_len, device=device, dtype=dtype)
+    ...
+    # Verify indices stay within [offset, offset + length) for each row
+    for i in range(num_rows):
+        ...
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tests/utils/test_topk.py` around lines 1242-1281: replace the manual env-var save/restore block with the existing `set_topk_algo` fixture by accepting it as a test parameter and calling `set_topk_algo("multi_cta")` at the start of the test; leave the rest of the test (scores, lengths, `flashinfer.top_k_ragged_transform`, assertions, and bounds checks) unchanged so the fixture manages `FLASHINFER_TOPK_ALGO` setup and teardown.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `tests/utils/test_topk.py`:
- Around lines 1230-1233: add a GPU architecture guard at the start of `test_ragged_transform_multi_cta_short_rows` so the test is skipped on unsupported GPUs; use the `flashinfer.utils` helpers (for example `is_sm90a_supported()` or `get_compute_capability()`) and call `pytest.skip(...)` or apply `pytest.mark.skipif(...)` when the required capability is absent, following the project coding guideline for architecture-gated tests.
|
/bot run |
|
[FAILED] Pipeline #45159141: 1/20 passed |
|
/bot run |
|
[CANCELING] Pipeline #45169718: canceled |
…lation (flashinfer-ai#2592)

## Summary
- Fix unsigned integer underflow in `RadixTopKKernel` when `chunk_start >= length` in multi-CTA mode with variable-length rows.
- Add regression test for ragged transform mode.

## Problem
In multi-CTA mode, `chunk_size` and `ctas_per_group` are derived from `max_len` (the input tensor stride). In ragged/page-table modes, each row has its own `length`, which can be much shorter than `max_len`. When a CTA's `chunk_start = cta_in_group * chunk_size` exceeds a row's actual `length`:

```cpp
const uint32_t chunk_end = min(chunk_start + chunk_size, length);  // = length
const uint32_t actual_chunk_size = chunk_end - chunk_start;        // unsigned underflow
```

`chunk_end` resolves to `length` (since `length < chunk_start`), and the subtraction underflows, causing out-of-bounds memory access and a segfault.

### Fix
```cpp
const uint32_t actual_chunk_size = (chunk_start < length) ? (chunk_end - chunk_start) : 0;
```

CTAs whose chunk falls beyond a row's length get `actual_chunk_size = 0`. They still participate in multi-CTA barriers (required for correctness) but process no data.

### Test plan
```
pytest tests/utils/test_topk.py::test_ragged_transform_multi_cta_short_rows
```

## Summary by CodeRabbit
- Tests: Added a regression test covering top-k ragged transform with mixed short and long sequences to improve edge-case coverage, verify shapes/dtypes/accuracy, ensure returned indices are within valid per-row ranges, and confirm environment variable handling is restored after test execution.

---------
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Co-authored-by: Brian Ryu <bryu@nvidia.com>
Summary
- Fix unsigned integer underflow in `RadixTopKKernel` when `chunk_start >= length` in multi-CTA mode with variable-length rows.
- Add regression test for ragged transform mode.

Problem
In multi-CTA mode, `chunk_size` and `ctas_per_group` are derived from `max_len` (the input tensor stride). In ragged/page-table modes, each row has its own `length`, which can be much shorter than `max_len`. When a CTA's `chunk_start = cta_in_group * chunk_size` exceeds a row's actual `length`:

```cpp
const uint32_t chunk_end = min(chunk_start + chunk_size, length);  // = length
const uint32_t actual_chunk_size = chunk_end - chunk_start;        // unsigned underflow
```

`chunk_end` resolves to `length` (since `length < chunk_start`), and the subtraction underflows, causing out-of-bounds memory access and a segfault.

Fix
```cpp
const uint32_t actual_chunk_size = (chunk_start < length) ? (chunk_end - chunk_start) : 0;
```

CTAs whose chunk falls beyond a row's length get `actual_chunk_size = 0`. They still participate in multi-CTA barriers (required for correctness) but process no data.

Test plan
```
pytest tests/utils/test_topk.py::test_ragged_transform_multi_cta_short_rows
```

Summary by CodeRabbit
- Tests: Added a regression test covering top-k ragged transform with mixed short and long sequences to improve edge-case coverage, verify shapes/dtypes/accuracy, ensure returned indices are within valid per-row ranges, and confirm environment variable handling is restored after test execution.
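The wraparound described above can be modeled on the host with plain integer arithmetic masked to 32 bits. This is an illustrative sketch of the kernel's `uint32_t` arithmetic, not the CUDA source itself, and the example chunk sizes are assumptions:

```python
UINT32_MASK = 0xFFFFFFFF  # emulate uint32_t modular arithmetic


def actual_chunk_size(chunk_start, chunk_size, length, guarded=True):
    # chunk_end clamps to the row's true length, as in the kernel.
    chunk_end = min(chunk_start + chunk_size, length)
    if guarded:
        # The fix: a CTA whose chunk starts past the row's length gets 0.
        return chunk_end - chunk_start if chunk_start < length else 0
    # Pre-fix behavior: chunk_end - chunk_start wraps when chunk_start > length.
    return (chunk_end - chunk_start) & UINT32_MASK
```

With `chunk_start = 40000` and `length = 6000` (a short row whose second CTA starts past the end), the unguarded subtraction yields `2**32 - 34000` elements, the huge bogus chunk behind the out-of-bounds access; the guarded version returns 0, so the CTA processes no data while still reaching the multi-CTA barriers.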