[BugFix] guard against uint32 underflow in multi-CTA TopK chunk calculation#2592

Merged
bkryu merged 4 commits into flashinfer-ai:main from LopezCastroRoberto:bug/topKraggedTransform on Mar 13, 2026
Conversation

@LopezCastroRoberto (Contributor) commented Feb 19, 2026

Summary

  • Fix unsigned integer underflow in RadixTopKKernel when chunk_start >= length in multi-CTA mode with variable-length rows.
  • Add regression test for ragged transform mode.

Problem

In multi-CTA mode, chunk_size and ctas_per_group are derived from max_len (the input tensor stride). In ragged/page-table modes, each row has its own length which can be much shorter than max_len. When a CTA's chunk_start = cta_in_group * chunk_size exceeds a row's actual length:

```cpp
const uint32_t chunk_end = min(chunk_start + chunk_size, length);  // = length
const uint32_t actual_chunk_size = chunk_end - chunk_start;        // unsigned underflow
```

chunk_end resolves to length (since length < chunk_start), and the subtraction underflows, causing out-of-bounds memory access and a segfault.

Fix

```cpp
const uint32_t actual_chunk_size = (chunk_start < length) ? (chunk_end - chunk_start) : 0;
```

CTAs whose chunk falls beyond a row's length get actual_chunk_size = 0. They still participate in multi-CTA barriers (required for correctness) but process no data.
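To make the wraparound concrete, here is a small standalone Python sketch (illustrative only, not the kernel code) that emulates uint32 arithmetic by reducing modulo 2**32 and shows what the guard changes; the chunk sizes and lengths below are made-up example values:

```python
# uint32 arithmetic wraps modulo 2**32; emulate that explicitly.
U32 = 1 << 32

def actual_chunk_size(chunk_start, chunk_size, length, guarded):
    """Mirror the kernel's per-CTA chunk-size computation."""
    chunk_end = min(chunk_start + chunk_size, length)
    if guarded and chunk_start >= length:
        return 0  # this CTA's chunk lies entirely past the row's data
    return (chunk_end - chunk_start) % U32  # wraps when chunk_start > length

# A short row (length=5000) and a CTA whose chunk starts at 65536:
print(actual_chunk_size(65536, 65536, 5000, guarded=False))  # 4294906760 (wrapped)
print(actual_chunk_size(65536, 65536, 5000, guarded=True))   # 0
print(actual_chunk_size(0, 65536, 5000, guarded=True))       # 5000 (normal case)
```

The huge wrapped value is what the kernel then used as a loop bound, which is why the bug manifested as out-of-bounds reads rather than a clean error.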

Test plan

  - pytest tests/utils/test_topk.py::test_ragged_transform_multi_cta_short_rows

Summary by CodeRabbit

  • Tests
    • Added a regression test covering top-k ragged transform with mixed short and long sequences to improve edge-case coverage, verify shapes/dtypes/accuracy, ensure returned indices are within valid per-row ranges, and confirm environment variable handling is restored after test execution.

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
@gemini-code-assist (bot) commented
Summary of Changes

Hello @LopezCastroRoberto, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical uint32 underflow bug within the RadixTopKKernel when processing variable-length rows in multi-CTA mode. The fix ensures correct chunk_size calculation, preventing memory access violations. A new regression test has been added to validate the fix and improve the robustness of the top_k_ragged_transform functionality.

Highlights

  • Bug Fix: Resolved an unsigned integer underflow issue in the RadixTopKKernel's multi-CTA chunk calculation, which occurred when chunk_start exceeded length in variable-length rows, preventing potential out-of-bounds memory access and segfaults.
  • Testing: Introduced a new regression test, test_ragged_transform_multi_cta_short_rows, to specifically cover the scenario that previously led to the uint32 underflow, ensuring the fix's effectiveness and preventing future regressions.


Changelog
  • tests/utils/test_topk.py
    • Added test_ragged_transform_multi_cta_short_rows to test the uint32 underflow fix.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a fix for a uint32 underflow bug in the multi-CTA TopK chunk calculation and adds a corresponding regression test. The test effectively simulates the scenario that triggers the bug by using a mix of long and short rows with the multi_cta algorithm forced. My review of the new test code includes suggestions to improve code quality, consistency, and robustness, such as using an existing fixture for environment variable management, leveraging a list comprehension for more concise code, and adding a guard to make an assertion more robust against different test parameters.

Comment on lines +1242 to +1283
```python
old_algo = os.environ.get("FLASHINFER_TOPK_ALGO", None)
os.environ["FLASHINFER_TOPK_ALGO"] = "multi_cta"

try:
    scores = torch.randn(num_rows, max_len, device=device, dtype=dtype)
    offsets = torch.zeros(num_rows, device=device, dtype=torch.int32)

    # Mix short and long rows. Short rows (4K-8K) are well below chunk_size
    # on any GPU, so CTAs beyond the first will have chunk_start > length.
    lengths_list = []
    for i in range(num_rows):
        if i % 2 == 0:
            lengths_list.append(max_len)
        else:
            lengths_list.append(torch.randint(4000, 8000, (1,)).item())
    lengths = torch.tensor(lengths_list, device=device, dtype=torch.int32)

    output = flashinfer.top_k_ragged_transform(scores, offsets, lengths, top_k)
    ref_output = reference_ragged_transform(scores, offsets, lengths, top_k)

    assert output.shape == (num_rows, top_k)
    assert output.dtype == torch.int32

    accuracy = compute_transform_accuracy(output, ref_output, num_rows, top_k)
    min_accuracy = 0.90
    assert accuracy >= min_accuracy, f"Accuracy {accuracy:.4f} < {min_accuracy}"

    # Verify indices stay within [offset, offset + length) for each row
    for i in range(num_rows):
        length = lengths[i].item()
        row_out = output[i]
        valid = row_out[row_out >= 0]
        assert torch.all(valid < length), (
            f"Row {i}: index out of bounds (max={valid.max().item()}, length={length})"
        )
finally:
    if old_algo is None:
        os.environ.pop("FLASHINFER_TOPK_ALGO", None)
    else:
        os.environ["FLASHINFER_TOPK_ALGO"] = old_algo
```

Severity: medium

For consistency with other tests in this file and to simplify the code, consider using the set_topk_algo fixture to manage the FLASHINFER_TOPK_ALGO environment variable. This removes the need for a manual try...finally block.

To apply this change:

  1. Add set_topk_algo to the test function's parameters on line 1233: def test_ragged_transform_multi_cta_short_rows(num_rows, top_k, dtype, set_topk_algo):
  2. Replace lines 1242-1245 with a single call to the fixture: set_topk_algo("multi_cta").
  3. Remove the finally block on lines 1279-1283.
  4. Un-indent the code that is currently inside the try block (lines 1246-1278).
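For readers unfamiliar with the pattern, here is a hedged sketch of what such an env-var fixture typically looks like (the real set_topk_algo fixture lives at lines 27-43 of tests/utils/test_topk.py; this version is illustrative and is shown as a plain generator so the setup/teardown flow can be followed directly; in the test file it would carry the @pytest.fixture decorator):

```python
import os

def set_topk_algo_fixture():
    """Yield a setter for FLASHINFER_TOPK_ALGO; restore the old value on teardown.

    Illustrative sketch: in a real test file this function would be
    decorated with @pytest.fixture, and pytest would drive the generator.
    """
    old_algo = os.environ.get("FLASHINFER_TOPK_ALGO")

    def _set(algo):
        os.environ["FLASHINFER_TOPK_ALGO"] = algo

    try:
        yield _set  # the test body runs while the generator is suspended here
    finally:
        # Teardown runs even if the test fails, replacing manual try/finally
        # blocks in each test.
        if old_algo is None:
            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
        else:
            os.environ["FLASHINFER_TOPK_ALGO"] = old_algo
```

A test then accepts set_topk_algo as a parameter and calls set_topk_algo("multi_cta") once, with no cleanup code of its own.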

Comment on lines +1251 to +1258
```python
lengths_list = []
for i in range(num_rows):
    if i % 2 == 0:
        lengths_list.append(max_len)
    else:
        lengths_list.append(torch.randint(4000, 8000, (1,)).item())
```

Severity: medium

This loop can be expressed more concisely using a list comprehension, which is more idiomatic in Python.

```python
lengths_list = [
    max_len if i % 2 == 0 else torch.randint(4000, 8000, (1,)).item()
    for i in range(num_rows)
]
```

Comment on lines +1276 to +1278
```python
assert torch.all(valid < length), (
    f"Row {i}: index out of bounds (max={valid.max().item()}, length={length})"
)
```

Severity: medium

Calling .max() on an empty tensor will raise a RuntimeError. While the current test parameters ensure valid is never empty, it's good practice to make the test more robust by adding a guard. This will prevent the test from failing unexpectedly if its parameters are changed in the future.

```python
if valid.numel() > 0:
    assert torch.all(valid < length), (
        f"Row {i}: index out of bounds (max={valid.max().item()}, length={length})"
    )
```
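The hazard the reviewer points at is general: reducing an empty collection raises. A plain-Python analogue of the suggested guard (no torch required; max_or_none is an illustrative name, not part of the test file):

```python
def max_or_none(values):
    # Reducing an empty sequence raises ValueError, much as calling
    # .max() on an empty tensor raises RuntimeError in PyTorch.
    return max(values) if len(values) > 0 else None

print(max_or_none([3, 1, 2]))  # 3
print(max_or_none([]))         # None
```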

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
@coderabbitai bot commented Feb 19, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8d2b45c and 6bf84cc.

📒 Files selected for processing (1)
  • tests/utils/test_topk.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/utils/test_topk.py

📝 Walkthrough

A new regression test test_ragged_transform_multi_cta_short_rows was added to tests/utils/test_topk.py to force the multi-CTA top-k path and validate behavior on batches with mixed short and long sequence lengths, including index bounds and accuracy checks.

Changes

  • tests/utils/test_topk.py (Top-K Regression Test): Added regression test test_ragged_transform_multi_cta_short_rows that forces FLASHINFER_TOPK_ALGO="multi_cta", generates mixed-length ragged input, compares outputs to a reference implementation, validates shapes/dtypes/index bounds/accuracy, and restores env state.


🚥 Pre-merge checks: 3 passed

  • Title check ✅ Passed: The title clearly and specifically describes the main change: guarding against uint32 underflow in multi-CTA TopK chunk calculation.
  • Description check ✅ Passed: The description provides a comprehensive explanation of the problem, the fix, and the test plan. However, the PR template's pre-commit and test checklists are not completed.
  • Docstring Coverage ✅ Passed: Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%.


@coderabbitai bot left a comment

🧹 Nitpick comments (2)
tests/utils/test_topk.py (2)

1230-1232: Consider adding torch.bfloat16 to the dtype parameterization.

The underflow is in uint32_t arithmetic and is dtype-independent, so float32 + float16 fully cover the regression. However, all analogous ragged-transform tests in this file (test_top_k_ragged_transform, test_top_k_ragged_transform_out_of_length) include bfloat16, and omitting it here creates a minor gap in consistency.


1233-1282: Use the existing set_topk_algo fixture instead of manual env-var management.

The file already has a set_topk_algo fixture (lines 27–43) that handles exactly this old_algo/try/finally pattern. Using it removes ~10 lines of boilerplate and keeps teardown consistent with the rest of the suite.

♻️ Proposed refactor
```diff
-def test_ragged_transform_multi_cta_short_rows(num_rows, top_k, dtype):
+def test_ragged_transform_multi_cta_short_rows(num_rows, top_k, dtype, set_topk_algo):
     """Regression test for uint32 underflow in multi-CTA chunk_size calculation."""
     torch.manual_seed(42)
     device = "cuda"

     max_len = 131072

-    # Force multi_cta path so the test exercises the vulnerable code path
-    # regardless of the heuristic.
-    old_algo = os.environ.get("FLASHINFER_TOPK_ALGO", None)
-    os.environ["FLASHINFER_TOPK_ALGO"] = "multi_cta"
-
-    try:
-        scores = torch.randn(num_rows, max_len, device=device, dtype=dtype)
-        ...
-        assert torch.all(valid < length), (
-            f"Row {i}: index out of bounds (max={valid.max().item()}, length={length})"
-        )
-    finally:
-        if old_algo is None:
-            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
-        else:
-            os.environ["FLASHINFER_TOPK_ALGO"] = old_algo
+    # Force multi_cta path so the test exercises the vulnerable code path
+    # regardless of the heuristic.
+    set_topk_algo("multi_cta")
+
+    scores = torch.randn(num_rows, max_len, device=device, dtype=dtype)
+    ...
+    assert torch.all(valid < length), (
+        f"Row {i}: index out of bounds (max={valid.max().item()}, length={length})"
+    )
```

@yzh119 (Collaborator) commented Feb 19, 2026

Hi @LopezCastroRoberto this looks similar to #2489? #2489 has more guard than this PR, would you mind double checking?

Thanks for working on the unittest btw.

@LopezCastroRoberto (Author) commented
Hey @yzh119.

Oh yeah, didn't see PR #2489.

In my case I didn't need RadixTopKMode::Basic, so I missed that guard. Perfectly fine to land #2489 and just delete this PR. Thanks!

@yzh119 (Collaborator) commented Feb 20, 2026

Hi @LopezCastroRoberto would you mind rebasing your PR to main (#2489 was already merged) and keep your unittest? We can still merge it (the test is helpful).

@coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/utils/test_topk.py (1)

1242-1281: Use the existing set_topk_algo fixture instead of duplicating env-var management.

The manual old_algo / os.environ / try/finally block (lines 1242–1281) replicates the set_topk_algo fixture verbatim. Accept it as a parameter and call set_topk_algo("multi_cta") instead.

♻️ Proposed refactor
```diff
-def test_ragged_transform_multi_cta_short_rows(num_rows, top_k, dtype):
+def test_ragged_transform_multi_cta_short_rows(num_rows, top_k, dtype, set_topk_algo):
     """Regression test for uint32 underflow in multi-CTA chunk_size calculation."""
     torch.manual_seed(42)
     device = "cuda"

     max_len = 131072

-    # Force multi_cta path so the test exercises the vulnerable code path
-    # regardless of the heuristic.
-    old_algo = os.environ.get("FLASHINFER_TOPK_ALGO", None)
-    os.environ["FLASHINFER_TOPK_ALGO"] = "multi_cta"
+    set_topk_algo("multi_cta")

-    try:
-        scores = torch.randn(num_rows, max_len, device=device, dtype=dtype)
-        ...
-        # Verify indices stay within [offset, offset + length) for each row
-        for i in range(num_rows):
-            ...
-    finally:
-        if old_algo is None:
-            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
-        else:
-            os.environ["FLASHINFER_TOPK_ALGO"] = old_algo
+    scores = torch.randn(num_rows, max_len, device=device, dtype=dtype)
+    ...
+    # Verify indices stay within [offset, offset + length) for each row
+    for i in range(num_rows):
+        ...
```
Inline comment (tests/utils/test_topk.py, lines 1230-1233):

Add a GPU architecture guard at the start of test_ragged_transform_multi_cta_short_rows so the test is skipped on unsupported GPUs; use the flashinfer.utils helpers (for example is_sm90a_supported() or get_compute_capability()) together with pytest.skip(...) or pytest.mark.skipif(...), following the project coding guideline for architecture-gated tests.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bf6ec24 and 8d2b45c.

📒 Files selected for processing (1)
  • tests/utils/test_topk.py

@bkryu (Collaborator) commented Mar 2, 2026

/bot run

@flashinfer-bot (Collaborator) commented
GitLab MR !361 has been created, and the CI pipeline #45159141 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator) commented
[FAILED] Pipeline #45159141: 1/20 passed

@bkryu (Collaborator) commented Mar 2, 2026

/bot run

@flashinfer-bot (Collaborator) commented
GitLab MR !361 has been updated with latest changes, and the CI pipeline #45169718 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator) commented
[CANCELING] Pipeline #45169718: canceled

bkryu merged commit f487726 into flashinfer-ai:main on Mar 13, 2026
19 checks passed
frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026