Fix filtered topk unified kernel#2319

Closed
murphymatt wants to merge 7 commits into flashinfer-ai:main from murphymatt:fix-filtered-topk-unified-kernel

Conversation

@murphymatt

@murphymatt murphymatt commented Jan 9, 2026

📌 Description

There is currently a minor bug in the filtered topk unified kernel affecting some edge cases (missing bounds checking in vectorized loops).

We currently loop over the length axis and perform vectorized processing over chunks. However, when length % VEC_SIZE != 0, the vectorized load loops can process elements beyond the valid range, and out-of-bounds indices can appear in the output.

The fix is to apply bounds checking in the affected locations.
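The failure mode can be sketched in plain Python (VEC_SIZE, length, and gather_indices are illustrative names, not the kernel's actual code): without a per-element guard, the tail chunk emits indices past the end of the row; with a base + j < length check, output stays in range.

```python
VEC_SIZE = 8  # illustrative vector width; the real kernel's width depends on dtype

def gather_indices(length, guard_tail):
    """Simulate a vectorized loop that visits `length` elements in VEC_SIZE chunks."""
    num_chunks = (length + VEC_SIZE - 1) // VEC_SIZE
    indices = []
    for chunk in range(num_chunks):
        base = chunk * VEC_SIZE
        for j in range(VEC_SIZE):              # vectorized lanes within the chunk
            if guard_tail and base + j >= length:
                continue                       # the fix: skip lanes past the valid range
            indices.append(base + j)
    return indices

# With length=10 and VEC_SIZE=8, the unguarded loop emits indices 10..15,
# which are out of bounds; the guarded loop stops at index 9.
```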

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or via your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

pytest tests/utils/test_topk.py
  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Fixed boundary checks to prevent reads/writes beyond input length, avoiding out-of-bounds memory access.
    • Made tail-element handling explicit so only valid elements update histograms and indices, improving stability for edge cases.
  • Tests

    • Added ragged-length bounds test (parametrized by length offset and dtype) to validate indices remain within each row's bounds.

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai
Contributor

coderabbitai bot commented Jan 9, 2026

Warning

Rate limit exceeded

@murphymatt has exceeded the limit for the number of commits that can be reviewed per hour. Please wait 2 minutes and 23 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between 1dd0022 and 03497b7.

📒 Files selected for processing (1)
  • tests/utils/test_topk.py
📝 Walkthrough

Walkthrough

Adds explicit per-element length checks to a top-k CUDA kernel to prevent out-of-bounds reads/writes for tail elements and adds a ragged-row test exercising unaligned per-row lengths to validate index bounds.

Changes

  • Kernel boundary guards — include/flashinfer/topk.cuh: Adds base + j < length guards in three code paths to prevent out-of-bounds histogram increments, bin/threshold calculations, and raw_input/s_indices accesses for tail elements.
  • Ragged-row test updates — tests/utils/test_topk.py: Adds test_ragged_transform_unaligned_length_bounds, parametrized over length_offset and dtype, constructing per-row length = max_len - length_offset and asserting computed indices lie within offset <= index < offset + length.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers

  • IwakuraRein
  • jiahanc
  • kahyunnam

Poem

🐰 I hop through kernels, careful and spry,

I count each tail where stray bytes may lie,
I guard the bins and fence each stride,
No out-of-bounds where indices hide,
A carrot for safety, nibble and sigh.

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Title check — ✅ Passed: The title "Fix filtered topk unified kernel" accurately describes the main change: a bug fix addressing bounds-checking issues in the filtered topk kernel.
  • Docstring Coverage — ✅ Passed: Docstring coverage is 100.00%, which meets the required threshold of 80.00%.
  • Description check — ✅ Passed: The description explains the bug (missing bounds checking in vectorized loops), describes the fix (adding bounds checking), and includes completed pre-commit and testing checklist items with a test command.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



@gemini-code-assist
Contributor

Summary of Changes

Hello @murphymatt, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical bug in the filtered topk unified kernel that caused out-of-bounds memory access during vectorized processing. The issue arose when the input data length was not perfectly divisible by the vector size, leading to incorrect indexing and atomicAdd operations on invalid memory locations. The solution involves adding precise bounds checking within the affected vectorized loops to ensure all memory accesses and data manipulations are confined to the valid input range, thereby enhancing the kernel's robustness and correctness.

Highlights

  • Bug Fix: Filtered TopK Unified Kernel: Addressed a minor bug in the filtered topk unified kernel where vectorized loops could lead to out-of-bounds memory access in edge cases, specifically when the input length is not a multiple of the vector size.
  • Bounds Checking Implementation: Introduced explicit bounds checking (if (base + j < length)) within the vectorized processing loops to ensure that operations like atomicAdd and index assignments (s_indices[pos] = base + j) only occur within the valid range of the input data.
  • Prevention of Incorrect Indexing: The fix prevents potential issues where the vectorized load loops might process elements beyond the valid range, which could result in out-of-bound indices appearing in the output.


Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review — /gemini review: Performs a code review for the current pull request in its current state.
  • Pull Request Summary — /gemini summary: Provides a summary of the current pull request in its current state.
  • Comment — @gemini-code-assist: Responds in comments when explicitly tagged, in both pull request comments and review comments.
  • Help — /gemini help: Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.


Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to fix a bug in the filtered_topk_unified kernel by adding bounds checking to vectorized loops. This is an important fix for edge cases where the input length is not a multiple of the vector size. However, the current implementation introduces two critical issues. First, a logic error in the histogram update will cause double-counting of elements. Second, a typo in a conditional check uses a variable before it's declared, which will lead to a compilation error. I've provided suggestions to correct these issues. The other bounds check added in the PR appears to be correct.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @include/flashinfer/topk.cuh:
- Around line 2044-2051: The loop in topk.cuh uses an undefined variable `bin`
in the bounds check; change the condition to check the computed index using
`base + j` before computing `bin`. Specifically, in the for-loop that iterates
VEC_SIZE, replace `if (bin + j < length)` with a bounds check like `if (base + j
< length)` so you only compute `const auto bin =
static_cast<int>(Traits::ToCoarseKey(score_vec[j]));` afterwards, and then
proceed to compare `bin` to `threshold_bin` and use `atomicAdd(&s_counter, 1)` /
`s_indices[pos] = base + j` as before.
- Around line 1996-2002: The loop currently increments s_histogram twice for
in-bounds elements because there is an unconditional atomicAdd followed by a
guarded atomicAdd; remove the unconditional increment (the first call to
Traits::ToCoarseKey(score_vec[j]) and atomicAdd(&s_histogram[bin], 1)) and keep
only the bounds-checked version so each element contributes at most one
histogram increment; ensure you compute bin once per j using
Traits::ToCoarseKey(score_vec[j]) and call atomicAdd(&s_histogram[bin], 1) only
inside the if (base + j < length) branch to avoid double-counting.
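The double-count described above can be reproduced with a serial stand-in (Counter plays the role of the shared-memory histogram and atomicAdd, and bins stands in for Traits::ToCoarseKey(score_vec[j]) — all illustrative, not the kernel's actual code): an unconditional increment followed by a guarded one counts every in-bounds element twice, so only the bounds-checked increment should remain.

```python
from collections import Counter

def build_histogram(bins, length, fixed):
    """Serial stand-in for one VEC_SIZE chunk of the histogram build.

    bins[j] plays the role of Traits::ToCoarseKey(score_vec[j]);
    Counter increments stand in for atomicAdd on shared memory.
    """
    hist = Counter()
    for j, b in enumerate(bins):
        if not fixed:
            hist[b] += 1    # the bug: unconditional increment
        if j < length:      # bounds-checked increment (the one to keep)
            hist[b] += 1
    return hist

# With bins = [0, 0, 1, 1] and length=3, the buggy path counts each
# in-bounds element twice (and the tail lane once); the fixed path
# counts each valid element exactly once and ignores the tail.
```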
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bd2b033 and ace332f.

📒 Files selected for processing (1)
  • include/flashinfer/topk.cuh
🧰 Additional context used
📓 Path-based instructions (1)
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/topk.cuh
🧠 Learnings (1)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • include/flashinfer/topk.cuh
🔇 Additional comments (1)
include/flashinfer/topk.cuh (1)

2066-2081: LGTM: Correct bounds checking implementation.

This segment properly guards all operations on vector elements with if (base + j < length), preventing out-of-bounds reads from score_vec[j] and invalid writes to shared memory arrays when the vector chunk extends beyond the input length.

This is the correct pattern for vectorized boundary handling in CUDA kernels.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/utils/test_topk.py (1)

382-382: Good addition to test non-aligned lengths.

The length_offset parameter effectively creates test cases where actual_length = max_len - length_offset may not be aligned to VEC_SIZE, which exercises the bounds checking fix described in the PR.

📝 Optional: Consider adding a comment explaining the purpose

Adding a brief comment would help future maintainers understand why these specific offset values are tested:

+# Test both VEC_SIZE-aligned (0) and non-aligned (1) lengths to ensure bounds checking
 @pytest.mark.parametrize("length_offset", [0, 1])

Alternatively, you could test additional offset values (2, 3, etc.) to cover different VEC_SIZE alignments (e.g., VEC_SIZE=2, 4, 8), though [0, 1] is sufficient for catching the out-of-bounds issue.
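A quick way to see which offsets matter (the vector widths 2, 4, and 8 are assumptions about per-dtype VEC_SIZE, and max_len = 128 is arbitrary): an offset exercises the tail path for a given width exactly when the resulting length is not a multiple of that width.

```python
max_len = 128  # arbitrary multiple of 8, standing in for a row's padded length

# For each assumed vector width, which offsets in 0..7 yield an unaligned tail?
alignments = {
    vec_size: [off for off in range(8) if (max_len - off) % vec_size != 0]
    for vec_size in (2, 4, 8)
}
# vec_size=8 leaves a tail for every nonzero offset, while vec_size=2 only
# does so for odd offsets — so offsets [0, 1] already cover both the aligned
# and unaligned cases for all three widths.
```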

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9d55fd5 and b40569d.

📒 Files selected for processing (1)
  • tests/utils/test_topk.py
🧰 Additional context used
📓 Path-based instructions (1)
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/utils/test_topk.py
🔇 Additional comments (4)
tests/utils/test_topk.py (4)

386-391: LGTM! Correctly implements non-aligned length testing.

The function signature update and actual_length computation properly introduce variable row lengths into the test. The skip condition correctly uses actual_length to prevent invalid test cases where k > actual_length.


405-406: Correctly applies actual_length to test data.

The lengths tensor now uses actual_length instead of max_len, ensuring the kernel processes non-VEC_SIZE-aligned row lengths. This setup properly exercises the bounds checking fix.


420-430: Excellent bounds validation to verify the kernel fix.

This explicit validation loop ensures that all output indices fall within the valid range [offset, offset + length) for each row. This is the critical check that would catch the bug described in the PR, where vectorized loads without bounds checking could produce out-of-range indices when length % VEC_SIZE != 0.

The validation logic is sound:

  • Filters padding values (>= 0)
  • Converts to relative indices for range checking
  • Provides clear error messages with actual values
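Stand-alone, that validation looks roughly like the sketch below (check_row_indices and the sample numbers are illustrative, not the test file's actual code; padding is assumed to be marked by negative entries):

```python
def check_row_indices(indices, offset, length):
    """Assert every non-padding index lies in [offset, offset + length)."""
    valid = [i for i in indices if i >= 0]     # negative entries mark padding
    for i in valid:
        rel = i - offset                       # convert to a row-relative index
        assert 0 <= rel < length, (
            f"index {i} outside [{offset}, {offset + length})"
        )
    return True

# A row with offset=5 and length=5 may only contain indices 5..9:
check_row_indices([5, 6, -1, 9], offset=5, length=5)   # passes
# check_row_indices([5, 12], offset=5, length=5) would raise AssertionError
```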

1209-1211: Good coverage of aligned and non-aligned test cases.

The test invocations now exercise both length_offset=0 (VEC_SIZE-aligned) and length_offset=1 (potentially non-aligned) scenarios. This ensures the kernel fix handles both common and edge cases correctly.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
tests/utils/test_topk.py (2)

786-788: Consider adding torch.bfloat16 to dtype parameters.

Other tests in this file include bfloat16 (e.g., lines 343, 384), and the function supports it per the relevant code snippets. Since bfloat16 would have the same VEC_SIZE as float16 (8), it would provide additional validation of the bounds-checking fix for unaligned lengths.

♻️ Suggested enhancement
 @pytest.mark.parametrize("length_offset", [1, 2, 3, 5, 6, 7])
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
 def test_ragged_transform_unaligned_length_bounds(length_offset, dtype):

788-819: Consider testing top_k_page_table_transform with unaligned lengths.

The enriched summary indicates that bounds checks were added to multiple locations in topk.cuh for both histogram increments and bin computations. While this test covers top_k_ragged_transform, it may be worth adding a similar test for top_k_page_table_transform to ensure the bounds-checking fix applies to both code paths.

Additionally, since the PR specifically mentions the "filtered topk unified kernel," you might consider explicitly testing with the filtered algorithm:

def test_ragged_transform_unaligned_length_bounds_filtered(length_offset, dtype):
    """Test with filtered algorithm explicitly."""
    from flashinfer.topk import can_implement_filtered_topk
    if not can_implement_filtered_topk():
        pytest.skip("GPU does not support filtered topk")
    
    os.environ["FLASHINFER_TOPK_ALGO"] = "filtered"
    try:
        # ... existing test logic ...
    finally:
        os.environ.pop("FLASHINFER_TOPK_ALGO", None)

However, the existing test should catch the bug regardless of algorithm selection, so this is optional.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b40569d and 1dd0022.

📒 Files selected for processing (1)
  • tests/utils/test_topk.py
🧰 Additional context used
📓 Path-based instructions (1)
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/utils/test_topk.py
🧬 Code graph analysis (1)
tests/utils/test_topk.py (1)
flashinfer/topk.py (1)
  • top_k_ragged_transform (348-421)
🔇 Additional comments (1)
tests/utils/test_topk.py (1)

794-819: Well-designed test for the bounds-checking bug.

The test implementation effectively validates the fix:

  • Creates unaligned lengths (max_len - length_offset) to trigger the vectorized boundary condition
  • Properly validates that all returned indices fall within [offset, offset + length)
  • Clear assertions with helpful error messages

@yzh119
Collaborator

yzh119 commented Jan 9, 2026

Hi @murphymatt thanks for working on this, looking like a duplicate of #2308 ?

@murphymatt
Author

Hi @murphymatt thanks for working on this, looking like a duplicate of #2308 ?

awesome, didn't see this before

@murphymatt murphymatt closed this Jan 9, 2026