
Fix: FilteredTopKUnifiedKernel read value out of length #2308

Merged

yzh119 merged 3 commits into flashinfer-ai:main from HarryWu99:topk_fix on Jan 9, 2026

Conversation

@HarryWu99 (Contributor) commented Jan 8, 2026

📌 Description

I found an issue in the implementation of top_k_ragged_transform where elements beyond the specified lengths range can be read. This leads to out-of-bounds indices in the result.

Minimal repro:

import torch
import flashinfer

torch.set_default_device(torch.cuda.current_device())
torch.manual_seed(22)
topk = 4
block_logits = torch.randn(1, 8, dtype=torch.float32)
print(block_logits)

idxs = flashinfer.top_k_ragged_transform(
    block_logits,
    offsets=torch.zeros(1, dtype=torch.int32),
    lengths=torch.tensor([5], dtype=torch.int32),
    k=topk,
)
print(idxs)
print(block_logits[:, :5].topk(k=4).indices)

Output:

tensor([[ 0.4519,  1.0099, -0.3167, -2.1224,  1.0826,  1.9583,  0.2751,  0.0463]],
       device='cuda:0')
tensor([[0, 4, 1, 5]], device='cuda:0', dtype=torch.int32)
tensor([[4, 1, 0, 2]], device='cuda:0')

In this example, lengths = [5], so indices in idxs should be strictly less than 5. The returned index 5 is clearly out of bounds.

The root cause is that the vec_load logic does not properly guard against reading past the valid range near the boundary. After fixing the boundary handling in vec_load, this case produces the expected result.
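The failure mode, and the aligned-length plus scalar-tail pattern that fixes it, can be sketched in host-side Python (illustrative only: `VEC_SIZE`, the binning function, and both helper names are stand-ins for the kernel's vec_load/histogram logic, not the real CUDA code):

```python
import numpy as np

VEC_SIZE = 4  # width of one vectorized load, mirroring the kernel's vec_load

def histogram_buggy(score, length, num_bins=8):
    """Mimics the old loop: vector loads run while base < length, so the
    last load pulls in allocated-but-invalid elements past `length`."""
    hist = np.zeros(num_bins, dtype=np.int64)
    for base in range(0, length, VEC_SIZE):
        vec = score[base:base + VEC_SIZE]   # may cover indices >= length
        for v in vec:                       # no per-element guard
            hist[int(v) % num_bins] += 1
    return hist

def histogram_fixed(score, length, num_bins=8):
    """Aligned-length main loop plus a scalar tail, the pattern in the fix."""
    hist = np.zeros(num_bins, dtype=np.int64)
    aligned_length = (length // VEC_SIZE) * VEC_SIZE
    for base in range(0, aligned_length, VEC_SIZE):
        vec = score[base:base + VEC_SIZE]   # fully in-bounds
        for v in vec:
            hist[int(v) % num_bins] += 1
    for i in range(aligned_length, length):  # scalar tail
        hist[int(score[i]) % num_bins] += 1
    return hist

score = np.arange(8)  # 8 allocated elements, but only 5 are valid
print(histogram_buggy(score, 5).sum())  # 8: three out-of-range elements counted
print(histogram_fixed(score, 5).sum())  # 5: exactly `length` elements counted
```

When `length` is a multiple of `VEC_SIZE` the two agree; the divergence appears exactly when the last vector straddles the boundary, which is the `lengths = [5]` case in the repro above.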

Affected APIs:

  • top_k_ragged_transform
  • top_k_page_table_transform

I also checked the implementation of RadixTopKMultiCTA, and it does not appear to suffer from this bug.
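For reference, the expected semantics can be approximated with a small NumPy model (a hypothetical sketch, not the library's implementation: it assumes a flat score buffer where row i owns `scores[offsets[i] : offsets[i] + lengths[i]]` and pads with -1 when `length < k`, while the real API operates on CUDA tensors):

```python
import numpy as np

def ref_top_k_ragged(scores, offsets, lengths, k):
    """Reference top-k over ragged rows: rank only each row's valid
    window and return global indices; -1 pads rows shorter than k."""
    num_rows = len(offsets)
    out = np.full((num_rows, k), -1, dtype=np.int64)
    for i in range(num_rows):
        off, length = int(offsets[i]), int(lengths[i])
        kk = min(k, length)
        window = scores[off:off + length]                 # never read past length
        local = np.argsort(-window, kind="stable")[:kk]   # descending top-k
        out[i, :kk] = local + off                         # convert to global indices
    return out

# The logits from the repro above, with length = 5 and k = 4:
scores = np.array([0.4519, 1.0099, -0.3167, -2.1224, 1.0826, 1.9583, 0.2751, 0.0463])
print(ref_top_k_ragged(scores, offsets=[0], lengths=[5], k=4))  # [[4 1 0 2]]
```

Every index this reference returns is strictly below `offset + length`, which is the invariant the fixed kernel restores.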

🔍 Related Issues

None

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • All tests are passing (unittest, etc.).
>> python tests/utils/test_topk.py
Testing page table transform...
Testing ragged transform...
Testing trivial cases...
Testing variable lengths...
Testing large scale...

Testing SGLang-style comparisons...

All tests passed!

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes
    • Improved boundary-safe processing for vectorized and tail elements to prevent out-of-range reads and ensure consistent histogram/top-k results on non-multiple lengths.
  • Tests
    • Added a test verifying top-k behavior with ragged/out-of-length inputs, ensuring outputs remain within valid index ranges and match a reference implementation.


@coderabbitai (coderabbitai bot) commented Jan 8, 2026

📝 Walkthrough

Aligned-length bounds were added to the top-k kernel to limit the vectorized loops, explicit tail-handling loops were introduced, the per-element histogram logic was refactored into a reusable lambda, and a test was added validating out-of-length behavior for ragged top-k transforms.

Changes

  • Top-k kernel updates — include/flashinfer/topk.cuh: Compute aligned_length to bound the vectorized loops, iterate only to that boundary, add explicit tail-handling loops, introduce a filter_and_add_to_histogram lambda to centralize per-element histogram/filter logic, and replace inline histogram updates with lambda calls across paths.
  • Tests (ragged top-k) — tests/utils/test_topk.py: Add test_top_k_ragged_transform_out_of_length validating that ragged top-k outputs remain within [offset, offset+length), compare outputs to a reference implementation, check shape/dtype, require ≥95% accuracy, and skip when k > max_len.

Sequence Diagram(s)

(omitted)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers

  • IwakuraRein
  • kahyunnam
  • jiahanc

Poem

🐰
I bound the vectors, tidy and light,
Tail crumbs handled, kept just right.
A lambda hops, collects each bit,
Histograms hum—no overreaded skit.
Joyful rabbit, code compiled bright.

🚥 Pre-merge checks | ✅ 3

✅ Passed checks (3 passed)
  • Title check — Passed: The title clearly summarizes the main fix: correcting out-of-bounds read behavior in the FilteredTopKUnifiedKernel, which is the core issue addressed in this PR.
  • Description check — Passed: The description includes all required template sections: a clear problem statement with minimal repro, related issues, and completed pre-commit and test checklists with evidence of passing tests.
  • Docstring Coverage — Passed: Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.



📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 10ab3d4 and 0a01d1c.

📒 Files selected for processing (1)
  • include/flashinfer/topk.cuh
🚧 Files skipped from review as they are similar to previous changes (1)
  • include/flashinfer/topk.cuh


@gemini-code-assist (Contributor) commented:

Hello @HarryWu99, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a memory safety vulnerability in the top-k transformation functions by preventing out-of-bounds memory access during vector loading operations. The fix ensures that data is read strictly within the defined boundaries, thereby enhancing the robustness and correctness of these performance-critical kernels.

Highlights

  • Out-of-bounds Read Fix: Addressed a critical bug in top_k_ragged_transform and top_k_page_table_transform where the vec_load operation could read beyond the specified lengths range, leading to out-of-bounds indices in the result.
  • Boundary Condition Guard: Implemented a precise boundary check (&& base + j < length) within the for loops of the FilteredTopKUnifiedKernel in topk.cuh to ensure all vector loads respect the valid data length.


@gemini-code-assist (gemini-code-assist bot) left a comment

Code Review

This pull request fixes a critical out-of-bounds read bug in FilteredTopKUnifiedKernel. The issue occurs when vectorized loads read past the specified length of the input array, leading to incorrect indices in the top-k results. The fix correctly adds a boundary check within the inner loops to ensure that only elements within the valid range are processed. While this prevents the use of out-of-bounds data, it doesn't prevent the out-of-bounds read itself. I've added a comment suggesting a more robust approach that avoids out-of-bounds memory access altogether by handling tail elements separately, which is a common and safer pattern. Overall, this is a good and necessary fix.

Comment on lines 1993 to 1999:

    for (int base = tx * VEC_SIZE; base < length; base += BLOCK_SIZE * VEC_SIZE) {
      score_vec.cast_load(&score[base]);
  #pragma unroll
  -   for (int j = 0; j < VEC_SIZE; ++j) {
  +   for (int j = 0; j < VEC_SIZE && base + j < length; ++j) {
        const auto bin = Traits::ToCoarseKey(score_vec[j]);
        atomicAdd(&s_histogram[bin], 1);
      }

Severity: medium

This change correctly prevents processing of data read from out-of-bounds memory. However, the vectorized load score_vec.cast_load(&score[base]) on line 1994 can still read past the length boundary when base is close to length.

While this pattern of over-reading and then masking is common in performance-critical GPU code, it relies on behavior that isn't guaranteed to be safe across all hardware and memory allocators. A more robust and safer approach would be to handle the tail elements separately with scalar loads, avoiding any out-of-bounds access.

Consider refactoring this loop to separate the vectorized main part from the scalar tail part. This is a common pattern seen elsewhere in this file (e.g., LoadToSharedOrdered).

Example of a safer implementation:

  const int aligned_length = (length / VEC_SIZE) * VEC_SIZE;
#pragma unroll 2
  for (int base = tx * VEC_SIZE; base < aligned_length; base += BLOCK_SIZE * VEC_SIZE) {
    score_vec.cast_load(&score[base]);
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      const auto bin = Traits::ToCoarseKey(score_vec[j]);
      atomicAdd(&s_histogram[bin], 1);
    }
  }

  // Scalar tail processing
  for (int i = aligned_length + tx; i < length; i += BLOCK_SIZE) {
    const auto bin = Traits::ToCoarseKey(score[i]);
    atomicAdd(&s_histogram[bin], 1);
  }

Since this is a critical bug fix, the current change is acceptable if the performance impact of a safer approach is a concern. However, I recommend considering the more robust implementation for long-term maintainability and safety.

@coderabbitai (coderabbitai bot) left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/topk.cuh (1)

1993-2000: Critical: Vectorized load still reads out-of-bounds despite inner loop guard.

While the inner loop check && base + j < length prevents processing out-of-bounds elements, the score_vec.cast_load(&score[base]) on line 1994 still performs an out-of-bounds read when base < length but base + VEC_SIZE > length.

Example: If length = 5 and VEC_SIZE = 4, when base = 4:

  • Outer loop condition base < length (4 < 5) passes
  • cast_load(&score[4]) reads 4 elements: indices 4, 5, 6, 7
  • Only index 4 is valid; indices 5, 6, 7 are out-of-bounds
  • Inner loop only processes j=0 due to the check, but the invalid read already occurred

This same issue exists at lines 2037 and 2056. The correct pattern (used in LoadToSharedOrdered at lines 423-434) is:

const int aligned_length = (length / VEC_SIZE) * VEC_SIZE;
#pragma unroll 2
for (int base = tx * VEC_SIZE; base < aligned_length; base += BLOCK_SIZE * VEC_SIZE) {
  score_vec.cast_load(&score[base]);
  #pragma unroll
  for (int j = 0; j < VEC_SIZE; ++j) {
    // process score_vec[j]
  }
}
// Handle tail elements individually
for (int i = aligned_length + tx; i < length; i += BLOCK_SIZE) {
  const auto raw_input = score[i];
  // process raw_input
}
🔧 Proposed fix for out-of-bounds reads

Apply this pattern to all three locations (lines 1993-2000, 2036-2047, 2056-2075):

+  const int aligned_length = (length / VEC_SIZE) * VEC_SIZE;
+
 #pragma unroll 2
-  for (int base = tx * VEC_SIZE; base < length; base += BLOCK_SIZE * VEC_SIZE) {
+  for (int base = tx * VEC_SIZE; base < aligned_length; base += BLOCK_SIZE * VEC_SIZE) {
     score_vec.cast_load(&score[base]);
 #pragma unroll
-    for (int j = 0; j < VEC_SIZE && base + j < length; ++j) {
+    for (int j = 0; j < VEC_SIZE; ++j) {
       const auto bin = Traits::ToCoarseKey(score_vec[j]);
       atomicAdd(&s_histogram[bin], 1);
     }
   }
+  // Handle tail elements
+  for (int i = aligned_length + tx; i < length; i += BLOCK_SIZE) {
+    const auto bin = Traits::ToCoarseKey(score[i]);
+    atomicAdd(&s_histogram[bin], 1);
+  }

Apply analogous changes to lines 2036-2047 and 2056-2075.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4230a48 and bd3f16b.

📒 Files selected for processing (1)
  • include/flashinfer/topk.cuh
🧰 Additional context used
📓 Path-based instructions (1)
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/topk.cuh
🧠 Learnings (1)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • include/flashinfer/topk.cuh

@yzh119 (Collaborator) left a comment

Thanks for the timely fix, would you mind creating corresponding unittests?

@flashinfer-ai deleted a comment from the claude bot on Jan 8, 2026
@coderabbitai (coderabbitai bot) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
tests/utils/test_topk.py (2)

422-443: Update docstring and clarify comment.

  1. Line 427: The docstring is copy-pasted from test_top_k_ragged_transform and should be updated to reflect this test's focus on out-of-length handling:

    """Test top_k_ragged_transform with randomized lengths to verify correct boundary handling."""
  2. Line 437: The comment "Generate naive offsets" is unclear. Consider:

    # Generate zero offsets (testing boundary handling, not offset logic)

461-468: Consider more direct bounds checking.

The current approach of clamping -1 padding values to 0 before checking bounds works correctly since offsets are all 0 in this test. However, it's conceptually indirect—it checks that padding values (after clamping) are in bounds, rather than checking only valid (non-padding) indices.

For improved clarity, consider filtering out padding values first:

# Check out of length - only validate non-padding indices
for i in range(num_rows):
    valid_mask = output[i] != -1
    if valid_mask.any():
        valid_indices = output[i][valid_mask]
        offset = offsets[i].item()
        length = lengths[i].item()
        assert torch.all((valid_indices >= offset) & (valid_indices < offset + length)), (
            f"Row {i}: indices out of range [{offset}, {offset + length})"
        )

The current implementation is correct and simpler, so this is just a nice-to-have improvement for clarity.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bd3f16b and 10ab3d4.

📒 Files selected for processing (2)
  • include/flashinfer/topk.cuh
  • tests/utils/test_topk.py
🧰 Additional context used
📓 Path-based instructions (2)
include/**/*.cuh


Files:

  • include/flashinfer/topk.cuh
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/utils/test_topk.py
🧬 Code graph analysis (1)
tests/utils/test_topk.py (1)
flashinfer/topk.py (1)
  • top_k_ragged_transform (348-421)
🔇 Additional comments (4)
include/flashinfer/topk.cuh (3)

1992-2006: LGTM! Correct fix for out-of-bounds read.

The aligned-length boundary approach correctly prevents vectorized loads from reading past the valid range. The bounds check at line 1997 (base + j < length) is technically redundant given that base < aligned_length already ensures base + VEC_SIZE <= aligned_length <= length, but it serves as good defensive programming.

Optional improvement: Consider adding a brief comment before line 1992 explaining the aligned-length approach for future maintainers:

// Restrict vectorized loads to aligned boundaries to prevent reading beyond valid range
const int aligned_length = (length / VEC_SIZE) * VEC_SIZE;

Based on coding guidelines: For performance-critical hot paths, leave comments explaining special algorithmic choices.


2043-2061: Consistent application of the boundary fix.

The aligned-length approach is correctly applied to the topk == 0 case, maintaining consistency with the coarse histogram path. The tail handling at lines 2055-2061 correctly processes remaining elements with scalar loads.


2069-2095: Good refactoring with the lambda.

The filter_and_add_to_histogram lambda effectively consolidates per-element histogram logic, improving code reuse across the vectorized loop (lines 2085-2091) and tail loop (lines 2093-2095). The absence of an explicit bounds check in the vectorized loop at line 2088 is correct, as the loop condition base < aligned_length already ensures that base + j < length for all j in [0, VEC_SIZE).

tests/utils/test_topk.py (1)

444-460: LGTM!

The test correctly validates both the accuracy of the top-k selection and the shape/dtype of outputs. The 0.95 accuracy threshold is appropriately consistent with other transform tests.

@HarryWu99 (Contributor, Author) commented:

I updated the logic for detecting the out-of-length condition following the guidance from Gemini, and added a corresponding unit test. The new test fails on the previous implementation and passes after this fix.

I also ran the performance benchmark benchmarks/bench_topk.py, and did not observe any noticeable performance differences compared to the previous version.

after fix

====================================================================================================
top_k_ragged_transform: Fused top-k + ragged index transform (dtype=FP32)
====================================================================================================
 batch    seq_len      k |   FlashInfer
----------------------------------------------------------------------
     1      32768   2048 |      23.95us
     4      32768   2048 |      24.45us
     8      32768   2048 |      24.37us
    16      32768   2048 |      25.54us

before fix

====================================================================================================
top_k_ragged_transform: Fused top-k + ragged index transform (dtype=FP32)
====================================================================================================
 batch    seq_len      k |   FlashInfer
----------------------------------------------------------------------
     1      32768   2048 |      24.26us
     4      32768   2048 |      23.60us
     8      32768   2048 |      24.58us
    16      32768   2048 |      24.85us

@yzh119 (Collaborator) commented Jan 9, 2026

/bot run

@flashinfer-bot (Collaborator) commented:

GitLab MR !232 has been created, and the CI pipeline #41404178 is currently running. I'll report back once the pipeline job completes.

@yzh119 mentioned this pull request on Jan 9, 2026
@yzh119 (Collaborator) left a comment

LGTM

@flashinfer-bot (Collaborator) commented:

[SUCCESS] Pipeline #41404178: 13/20 passed

@yzh119 merged commit c6dcedf into flashinfer-ai:main on Jan 9, 2026
5 checks passed
