
[bugfix] Fix FilteredTopK overflow correctness#2605

Merged
yzh119 merged 1 commit into flashinfer-ai:main from jiangyinzuo:fix/bf16-topk-filtered-overflow
Feb 25, 2026

Conversation

@jiangyinzuo
Contributor

@jiangyinzuo jiangyinzuo commented Feb 20, 2026

📌 Description

This PR fixes issue #2604: FilteredTopK could produce incorrect results when the refine candidate buffer
overflows in long-sequence, tie-heavy workloads.

Root Cause

FilteredTopK uses a fixed-size refine candidate buffer (SMEM_INPUT_SIZE=16K), which can overflow when the threshold bin is too large. The previous logic could continue with partially truncated state after overflow, causing incorrect top-k outputs.

Main Changes

1) Kernel correctness fixes (include/flashinfer/topk.cuh)

  • Added a refine-overflow flag s_refine_overflow, set atomically on candidate write overflow.
  • On refine-buffer overflow, we now use a full-scan correctness fallback: rebuild the needed histogram/threshold state from the full input, recompute pivot selection state, reset partial intermediate state, and recollect winners from the full
    input to guarantee correct top-k results.
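The staged logic above can be modeled in a few lines of plain Python. This is an illustrative sketch, not the CUDA kernel: `bucket()` and `top_k_filtered()` are stand-ins for the kernel's bin mapping and staged selection, and the buffer capacity is shrunk from 16K so the fallback triggers in a small demo:

```python
# Illustrative Python model of the overflow-aware refinement -- NOT the CUDA
# kernel. SMEM_INPUT_SIZE mirrors the kernel's fixed candidate buffer.
SMEM_INPUT_SIZE = 16  # shrunk from 16K so the fallback triggers in a demo


def bucket(x, num_bins=4):
    """Coarse bin index for a score in [0, 1); higher bin == larger value."""
    return min(int(x * num_bins), num_bins - 1)


def top_k_filtered(values, k, num_bins=4):
    # Stage 1: histogram the coarse bins and find the threshold bin -- the
    # bin at which the running count of higher-binned elements reaches k.
    hist = [0] * num_bins
    for v in values:
        hist[bucket(v, num_bins)] += 1
    taken, threshold_bin = 0, 0
    for b in range(num_bins - 1, -1, -1):
        if taken + hist[b] >= k:
            threshold_bin = b
            break
        taken += hist[b]

    # Everything strictly above the threshold bin is a guaranteed winner.
    winners = [i for i, v in enumerate(values) if bucket(v, num_bins) > threshold_bin]
    remain = k - len(winners)

    # Stage 2 fast path: stash threshold-bin candidates in a fixed buffer.
    buf, overflow = [], False
    for i, v in enumerate(values):
        if bucket(v, num_bins) == threshold_bin:
            if len(buf) < SMEM_INPUT_SIZE:
                buf.append(i)
            else:
                overflow = True  # analogue of atomically setting s_refine_overflow

    # Slow path: on overflow, re-scan the full input for the threshold bin
    # instead of trusting the truncated buffer.
    if overflow:
        buf = [i for i, v in enumerate(values) if bucket(v, num_bins) == threshold_bin]

    winners += sorted(buf, key=lambda i: values[i], reverse=True)[:remain]
    return winners
```

With 256 inputs spread over 8 tied levels and k=40, the threshold bin holds 64 candidates, overflows the 16-slot buffer, and the slow path still yields a correct top-k set; the pre-fix failure mode corresponds to keeping only the truncated `buf`.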

2) Regression test coverage (tests/utils/test_topk.py)

  • Added BF16 long-sequence regressions across auto / multi_cta / filtered.
  • Added FP32 long-sequence refine-overflow regressions for both top_k and transform APIs (page_table / ragged).
  • Added FP32 pivot reconstruction regressions for both top_k and transform APIs.
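These regressions lean on a repeated-bucket input pattern (the `_build_bf16_long_seq_bucket_inputs` helper discussed in review below builds it as a bf16 tensor). A pure-Python sketch with the sizes quoted in the tests shows why the pattern is tie-heavy:

```python
# Pure-Python sketch of the tie-heavy input pattern; the real test helper
# builds a bf16 CUDA tensor, but the tie structure is visible with plain floats.
vocab_size = 65536
num_buckets = 64
k = 1024

# Every position takes one of only 64 distinct levels...
logits = [(i % num_buckets) / num_buckets for i in range(vocab_size)]

# ...so each level, including the one at the top-k boundary, is shared by
# vocab_size / num_buckets = 1024 positions.
ties_at_top = logits.count(max(logits))
assert ties_at_top == vocab_size // num_buckets == k
```

With k equal to the per-level tie count, the threshold bin consists entirely of tied values, which is exactly the regime where the refine candidate buffer can overflow.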

🔍 Related Issues

fix: #2604

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).
  • Checked for performance regression (benchmark results below).

Run the following benchmark 10 times. There is no noticeable performance degradation in this PR.

python benchmarks/bench_topk.py --compare-algorithms --dtype bf16/fp16/fp32

Environment

  • GPU: NVIDIA A100-PCIE-40GB (SM80)
  • PyTorch: 2.8.0+cu128
  • CUDA (torch): 12.8

bf16

batch seq_len multi-CTA(us) filtered(us) this PR filtered(us) base filtered this PR / base
1 4096 27.470 20.887 20.377 1.0250
1 16384 44.983 39.322 38.911 1.0106
1 65536 106.397 89.090 88.629 1.0052
1 131072 107.572 96.564 101.887 0.9478
1 262144 89.549 154.777 154.625 1.0010
1 524288 93.492 242.329 258.611 0.9370
16 4096 18.739 16.178 16.896 0.9575
16 16384 28.055 25.087 25.190 0.9959
16 65536 63.949 55.551 55.294 1.0046
16 131072 65.381 67.379 66.968 1.0061
16 262144 65.591 114.792 114.689 1.0009
16 524288 136.191 213.658 213.708 0.9998
64 4096 19.304 18.227 16.997 1.0724
64 16384 28.875 26.620 26.620 1.0000
64 65536 69.630 61.440 61.440 1.0000
64 131072 133.197 76.800 76.800 1.0000
64 262144 190.563 140.392 140.596 0.9985
64 524288 335.819 264.088 263.935 1.0006
256 4096 38.910 29.904 29.700 1.0069
256 16384 79.870 73.576 73.524 1.0007
256 65536 200.700 173.060 173.060 1.0000
256 131072 322.662 215.142 215.091 1.0002
256 262144 588.391 421.581 421.477 1.0002
256 524288 1141.863 811.111 811.367 0.9997

fp16

batch seq_len multi-CTA(us) filtered(us) this PR filtered(us) base filtered this PR / base
1 4096 27.599 20.324 19.252 1.0557
1 16384 44.545 36.860 36.707 1.0042
1 65536 103.780 70.145 70.863 0.9899
1 131072 103.270 97.638 93.080 1.0490
1 262144 89.141 154.778 147.199 1.0515
1 524288 84.274 201.217 202.087 0.9957
16 4096 19.738 15.820 16.793 0.9421
16 16384 29.695 25.498 25.447 1.0020
16 65536 65.562 46.283 46.847 0.9880
16 131072 65.537 67.072 67.174 0.9985
16 262144 65.024 114.484 114.175 1.0027
16 524288 133.426 176.232 176.027 1.0012
64 4096 18.485 15.667 17.815 0.8794
64 16384 28.670 24.682 24.580 1.0041
64 65536 68.147 51.200 51.200 1.0000
64 131072 131.714 76.800 76.800 1.0000
64 262144 188.930 137.322 137.322 1.0000
64 524288 328.855 208.900 208.900 1.0000
256 4096 38.885 27.547 27.341 1.0075
256 16384 78.901 68.610 68.610 1.0000
256 65536 195.580 142.237 142.134 1.0007
256 131072 318.975 214.938 214.836 1.0005
256 262144 582.863 410.672 410.777 0.9997
256 524288 1119.438 643.019 642.815 1.0003

fp32

batch seq_len multi-CTA(us) filtered(us) this PR filtered(us) base filtered this PR / base
1 4096 35.840 19.970 19.970 1.0000
1 16384 58.370 44.957 45.161 0.9955
1 65536 99.072 87.911 89.089 0.9868
1 131072 92.674 132.098 124.827 1.0582
1 262144 88.114 188.930 187.392 1.0082
1 524288 77.157 263.422 260.658 1.0106
16 4096 22.270 16.692 16.794 0.9939
16 16384 36.150 28.673 29.799 0.9622
16 65536 60.879 56.014 56.731 0.9874
16 131072 60.522 86.632 86.581 1.0006
16 262144 123.594 151.860 151.653 1.0014
16 524288 141.516 275.816 275.559 1.0009
64 4096 22.581 17.306 17.817 0.9713
64 16384 37.916 31.740 31.740 1.0000
64 65536 121.860 66.866 66.662 1.0031
64 131072 177.150 112.742 112.742 1.0000
64 262144 308.426 203.728 203.780 0.9997
64 524288 503.938 310.373 310.270 1.0003
256 4096 53.250 30.516 30.924 0.9868
256 16384 108.618 88.781 88.987 0.9977
256 65536 295.657 188.420 188.471 0.9997
256 131072 543.282 333.514 333.361 1.0005
256 262144 1002.573 654.493 654.646 0.9998
256 524288 1899.034 1035.158 1035.468 0.9997
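As a sanity check on the tables above, the geometric mean of a ratio column should sit near 1.0 when there is no systematic regression. A small sketch using the bf16, batch=1 rows:

```python
# Geometric mean of the "filtered this PR / base" ratios; the six values are
# copied from the bf16, batch=1 rows of the table above.
import math

ratios = [1.0250, 1.0106, 1.0052, 0.9478, 1.0010, 0.9370]
geomean = math.exp(sum(math.log(r) for r in ratios) / len(ratios))
print(round(geomean, 3))  # ~0.987: on average slightly faster than base
```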

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Improved Top-K correctness for long sequences with overflow-safe, staged refinement (fast and slow paths), multi-pass threshold handling, and coordination across parallel units to handle tie-heavy and buffer-overflow edge cases (BF16/FP32).
  • Tests

    • Added regression tests, data builders, and helpers for BF16 long-sequence transforms and FP32 long-sequence overflow/pivot cases, plus tie-insensitive index comparison utilities.
  • Documentation

    • Added note about potential threshold-bin overflow (bf16/tie-heavy) with suggestion to use a multi-CTA mode when needed.

@coderabbitai
Contributor

coderabbitai bot commented Feb 20, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


Adds staged, overflow-aware multi-round refinement to the FilteredTopK kernel (shared-memory refine state, inter-CTA coordination, staged collection and slow-path fallback) and expands BF16/FP32 long-sequence tests to cover filtered/multi-CTA/auto behaviors, overflow regressions, and pivot reconstruction cases.

Changes

Cohort / File(s) Summary
TopK Kernel Refinement
include/flashinfer/topk.cuh
Reworks FilteredTopKUnifiedKernel: adds shared refine state (s_refine_overflow, s_last_remain), initializes/updates refine thresholds, breaks refinement into staged helper lambdas, supports 1-round fast-path and multi-round slow-path with overflow-safe collection (including full-threshold re-histogram), and adds single-CTA vs multi-CTA coordination and extra synchronization.
Test Coverage Expansion
tests/utils/test_topk.py
Adds BF16 and FP32 long-sequence regression tests across algorithms and transform modes, capability gating for BF16, helpers to build tie-heavy/overflow/pivot-mismatch inputs (_build_bf16_long_seq_bucket_inputs, _build_fp32_long_seq_overflow_inputs, _build_fp32_long_seq_pivot_mismatch_inputs), unordered-index comparison helper, and new test cases exercising filtered/multi-CTA/auto behaviors.

Sequence Diagram(s)

(omitted)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~55 minutes

Possibly related PRs

Suggested labels

v0.6.1

Suggested reviewers

  • jiahanc
  • kahyunnam
  • IwakuraRein
  • yzh119
  • nv-yunzheq

Poem

"I hopped through bins and thresholds bright,
I nudged the counts by shared-memory light,
When overflow loomed and buffers swayed,
I staged my rounds and found the way,
A rabbit cheers — the top‑k’s right tonight! 🐇"

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title '[bugfix] Fix FilteredTopK overflow correctness' is clear, specific, and directly related to the main changes addressing the BF16 long-sequence FilteredTopK correctness bug.
Linked Issues check ✅ Passed The PR changes directly address issue #2604 by implementing overflow handling and threshold refinement logic to fix FilteredTopK correctness for BF16 long sequences, and adds comprehensive regression tests.
Out of Scope Changes check ✅ Passed All changes are directly scoped to fixing the FilteredTopK overflow bug: kernel modifications for overflow handling, refinement logic, and comprehensive test coverage for BF16/FP32 scenarios.
Docstring Coverage ✅ Passed Docstring coverage is 91.67% which is sufficient. The required threshold is 80.00%.
Description check ✅ Passed The pull request description is comprehensive and well-structured, addressing all key template sections including problem description, root cause analysis, main changes, related issues, and checklist completion.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Contributor

Summary of Changes

Hello @jiangyinzuo, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical correctness bug within the FilteredTopK CUDA kernel, specifically impacting bf16 data types. The issue occurred when the fast path for threshold bin refinement experienced an overflow, leading to incorrect results. To mitigate this, a new 'slow path' has been implemented that re-histograms the entire threshold bin, guaranteeing accurate top-k computations under these specific conditions. Accompanying this fix are new regression tests designed to thoroughly validate the corrected behavior across various top-k algorithms and data transformation methods.

Highlights

  • BF16 FilteredTopK Correctness Fix: Implemented a bug fix for the bf16 FilteredTopK algorithm to address correctness issues arising from threshold bin overflow during refinement.
  • Slow Path for Overflow Handling: Introduced a 'slow path' mechanism in the FilteredTopK kernel that re-histograms the full threshold bin when an overflow is detected in the fast path, ensuring accurate results for bf16 data types.
  • New Regression Tests: Added comprehensive regression tests for bf16 long sequence top-k operations, covering different algorithms and transform modes, to validate the overflow fix and ensure robust behavior.


Changelog
  • include/flashinfer/topk.cuh
    • Included <type_traits> for advanced type manipulation.
    • Declared new shared memory variables s_refine_overflow and s_last_remain to manage overflow state and remaining elements.
    • Initialized s_refine_overflow to zero at the start of the kernel.
    • Introduced a new lambda update_refine_threshold to encapsulate the logic for updating the threshold bin and remaining elements.
    • Modified the collect_from_input lambda to set s_refine_overflow if the number of inputs exceeds SMEM_INPUT_SIZE.
    • Refactored the Stage 2 refinement loop into a conditional block that checks for NUM_ROUNDS == 1 (for bf16/fp16) and s_refine_overflow.
    • Implemented a 'slow path' within the NUM_ROUNDS == 1 block that re-histograms the full threshold bin when s_refine_overflow is true, using new helper lambdas for_each_score_overflow, build_full_threshold_hist, and collect_from_full_threshold_bin.
    • Consolidated the fast path and multi-round refinement logic into a new run_refine_round lambda for better modularity.
  • tests/utils/test_topk.py
    • Added test_bf16_long_seq_regression_across_algorithms to verify bf16 top-k correctness for long sequences across 'auto', 'multi_cta', and 'filtered' algorithms, specifically using a data pattern prone to threshold bin overflow.
    • Introduced _build_bf16_long_seq_bucket_inputs to generate a specific bf16 workload with many tied values.
    • Added _assert_unordered_indices_match helper for comparing index sets while ignoring order due to ties.
    • Implemented test_bf16_long_seq_transform_regression_filtered to test bf16 top-k with page table and ragged transforms under the 'filtered' algorithm, ensuring correctness in overflow scenarios.
Activity
  • No specific activity (comments, reviews, or progress updates) has been recorded for this pull request yet.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a correctness bug in the FilteredTopK algorithm for bfloat16 data types, which could occur when there are many elements with tied values, causing an overflow in an intermediate shared memory buffer. The fix introduces a robust fallback path that re-scans the input to ensure correctness in these overflow scenarios. The implementation has also been significantly refactored into smaller, more manageable helper lambdas, improving code clarity. Additionally, comprehensive regression tests have been added to verify the fix and prevent future issues. My review includes one suggestion for a minor performance optimization.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (3)
tests/utils/test_topk.py (1)

1276-1278: _assert_unordered_indices_match assumes output dtype matches expected dtype.

output from the transform tests is torch.int32 while expected is also torch.int32, so this works. However, for the plain top_k path, indices are torch.int64 — if this helper is ever reused there, the torch.equal will fail on dtype mismatch. A minor robustness nit, not blocking.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/utils/test_topk.py` around lines 1276 - 1278, The helper
_assert_unordered_indices_match can fail when output and expected have different
integer dtypes (e.g., int32 vs int64); modify the function to normalize dtypes
before comparison by casting both sorted values to a common integer dtype (for
example expected.dtype or torch.long) prior to calling torch.equal so that dtype
mismatches do not cause false failures when only order differs; ensure you apply
the cast to the result of torch.sort(...).values for both output and expected.
include/flashinfer/topk.cuh (2)

2199-2216: collect_with_threshold_last_round calls reset_histogram() unnecessarily.

Line 2200 clears s_histogram, but the histogram is never read again after this function. This is a harmless no-op but adds an unnecessary __syncthreads() barrier. Minor nit — not worth changing unless performance-sensitive.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 2199 - 2216, The lambda
collect_with_threshold_last_round unnecessarily calls reset_histogram() and then
issues a __syncthreads(), but s_histogram is never referenced afterward; remove
the reset_histogram() invocation and the ensuing __syncthreads() inside
collect_with_threshold_last_round to eliminate the no-op clear and extra barrier
(keep the rest of the loop logic and atomic updates to s_counter/s_last_remain
and s_indices unchanged).

2217-2239: Silent overflow in multi-round (float32) collect_with_threshold_non_last_round.

Line 2230 checks pos < SMEM_INPUT_SIZE but silently drops elements on overflow without setting any flag (unlike the 1-round path at lines 2156–2158). For float32 (NUM_ROUNDS=4), an overflow here means some equal-to-threshold elements are lost, leading to incorrect results.

In practice, this is extremely unlikely for float32 (would require >16K elements sharing identical upper 16+ bits), but for completeness the same s_refine_overflow-style fallback could apply. Worth noting for future hardening.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 2217 - 2239,
collect_with_threshold_non_last_round currently drops equal-to-threshold
elements when pos >= SMEM_INPUT_SIZE without signaling overflow; update the else
branch that handles static_cast<int>(bin) == threshold inside
collect_with_threshold_non_last_round to set the same overflow/fallback flag
used by the 1-round path (e.g., s_refine_overflow or
s_refine_overflow[next_r_idx]) when pos is out of bounds instead of silently
dropping items, mirroring the behavior in the single-round code path and keeping
existing uses of s_input_idx, s_num_input, s_histogram, s_indices and s_counter
intact.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/utils/test_topk.py`:
- Around line 1230-1254: test_bf16_long_seq_regression_across_algorithms tries
to run the "filtered" TopK kernel unguarded, which will fail on GPUs without the
128KB shared memory requirement; add a guard at the start of the test that skips
when algo == "filtered" and can_implement_filtered_topk() is false (mirror other
tests), i.e., call can_implement_filtered_topk() and pytest.skip when it returns
False before calling set_topk_algo or invoking flashinfer.top_k.
- Around line 1281-1306: The test
test_bf16_long_seq_transform_regression_filtered forces the "filtered" path via
set_topk_algo("filtered") but lacks the can_implement_filtered_topk() guard used
elsewhere; add a guard after the BF16 support check that calls
can_implement_filtered_topk() and calls pytest.skip("Filtered top-k not
supported on this device") if it returns False so the test only runs when the
filtered kernel can be launched (keep set_topk_algo("filtered") and the rest of
the test unchanged).


@yzh119
Collaborator

yzh119 commented Feb 20, 2026

/bot run

@yzh119
Collaborator

yzh119 commented Feb 20, 2026

@flashinfer-bot run

@flashinfer-bot
Collaborator

GitLab MR !335 has been created, and the CI pipeline #44473794 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #44473794: 14/20 passed

@jiangyinzuo force-pushed the fix/bf16-topk-filtered-overflow branch from f6a28e6 to f72961c on February 21, 2026, 03:49
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/topk.cuh`:
- Around line 2228-2238: The multi-round float32 refine path doesn't handle
s_refine_overflow set by collect_with_threshold_non_last_round, so add a check
of s_refine_overflow before proceeding with the multi-round loop (the same place
get_num_input is used) and apply the same fallback used for the 1-round overflow
case: detect overflow, set a safe clamp/exit path (or force single-round
processing) and log/propagate the overflow state so elements aren't silently
lost; update the logic around get_num_input and the multi-round loop to
early-handle s_refine_overflow and mirror the bf16/1-round overflow handling.

@jiangyinzuo force-pushed the fix/bf16-topk-filtered-overflow branch from f72961c to 3d32e6d on February 21, 2026, 06:44
@jiangyinzuo changed the title from "[bugfix] Fix bf16 FilteredTopK overflow correctness" to "[bugfix] Fix FilteredTopK overflow correctness" on Feb 21, 2026
@jiangyinzuo
Contributor Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces overflow detection and fallback mechanisms for the FilteredTopK algorithm to ensure correctness when processing long sequences or workloads with heavy ties. The implementation adds a slow-path fallback that performs full scans of the threshold bin when the shared memory candidate buffer overflows. While the overall logic is sound and the added regression tests are comprehensive, I identified a few issues in the fallback paths, including an incorrect pivot construction in the multi-round case and a missing counter reset in the single-round case, which could lead to incorrect results or out-of-bounds memory access.

Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (4)
tests/utils/test_topk.py (2)

1246-1259: DRY: use _build_bf16_long_seq_bucket_inputs() instead of duplicating logit construction.

The logit tensor construction at lines 1246–1251 is byte-for-byte identical to _build_bf16_long_seq_bucket_inputs defined at lines 1269–1274. Call the helper instead:

♻️ Proposed refactor
-    batch_size = 4
-    vocab_size = 65536
-    k = 1024
-    device = "cuda"
-
-    # Repeated-value buckets trigger large threshold-bin occupancy in bf16.
-    logits = (
-        ((torch.arange(vocab_size, device=device, dtype=torch.float32) % 64) / 64.0)
-        .unsqueeze(0)
-        .repeat(batch_size, 1)
-        .to(torch.bfloat16)
-    )
-
-    values, indices = flashinfer.top_k(logits, k, sorted=True)
-    ref_values, _ = torch.topk(logits, k, dim=-1, sorted=True)
+    logits, _, _, batch_size, _, k, _ = _build_bf16_long_seq_bucket_inputs()
+    values, indices = flashinfer.top_k(logits, k, sorted=True)
+    ref_values, _ = torch.topk(logits, k, dim=-1, sorted=True)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/utils/test_topk.py` around lines 1246 - 1259, Replace the duplicated
logit tensor construction with a call to the existing helper
_build_bf16_long_seq_bucket_inputs: remove the manual creation of logits (the
block that builds the bfloat16 repeated arange using vocab_size, batch_size,
device) and instead call _build_bf16_long_seq_bucket_inputs(batch_size,
vocab_size, device) to produce logits before calling flashinfer.top_k; ensure
the variable name logits is preserved so subsequent uses (flashinfer.top_k,
torch.topk, torch.gather) remain unchanged.

1233-1234: Consider replacing torch.cuda.is_bf16_supported() with a flashinfer.utils capability check.

torch.cuda.is_bf16_supported() is functional, but the coding guidelines require architecture-gating via flashinfer.utils functions (e.g. get_compute_capability). BF16 native support requires SM80+, so:

from flashinfer.utils import get_compute_capability
if get_compute_capability() < (8, 0):
    pytest.skip("BF16 requires SM80+")

If flashinfer.utils does not currently expose get_compute_capability, this can be deferred, but it should be tracked for consistency with the guidelines. The same applies to line 1315.

Based on coding guidelines: "Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/utils/test_topk.py` around lines 1233 - 1234, Replace the direct
torch.cuda.is_bf16_supported() check with the flashinfer.utils
architecture-capability check: import and call get_compute_capability() and skip
the test when the returned compute capability is less than (8, 0) (BF16 requires
SM80+); update the two locations that use torch.cuda.is_bf16_supported() in
tests/utils/test_topk.py to use get_compute_capability() instead (or defer and
create/get_compute_capability in flashinfer.utils if missing) so gating follows
the project's utility functions.
include/flashinfer/topk.cuh (2)

2157-2159: Optional: prefer atomicOr over atomicExch for flag-setting semantics.

Both instructions are correct, but atomicOr(&s_refine_overflow, 1) communicates intent more clearly (set a flag, ignoring the prior value) than atomicExch, and avoids a redundant full-word write when the flag is already 1.

🔧 Suggested change
-        atomicExch(&s_refine_overflow, 1);
+        atomicOr(&s_refine_overflow, 1);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 2157 - 2159, Replace the use of
atomicExch for setting the overflow flag with atomicOr to express flag semantics
and avoid redundant writes: in the block that currently calls
atomicExch(&s_refine_overflow, 1) (the shared/global flag variable
s_refine_overflow inside the refine/overflow handling code), change it to
atomicOr(&s_refine_overflow, 1) so the flag is set atomically without
overwriting prior value; ensure the variable type matches atomicOr's expected
integer type and no other logic depends on the returned old value from
atomicExch.

2275-2321: Add a comment documenting the s_counter continuity invariant.

s_counter is intentionally not reset before collect_from_full_threshold_bin. It already holds the count of coarse_bin > threshold_bin elements (written during filter_and_add_to_histogram). collect_from_full_threshold_bin appends the sub_bin > refine_threshold elements on top, maintaining the invariant:

s_counter_final + s_last_remain_initial == top_k

without which the output positions would overlap or leave gaps. A brief comment at line 2302 would save future readers from reverse-engineering this.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 2275 - 2321, Add a short in-place
comment where s_counter is left untouched before defining/using
collect_from_full_threshold_bin (near the block that calls
update_refine_threshold and then defines collect_from_full_threshold_bin)
explaining the continuity invariant: s_counter already contains the count of
elements with coarse_bin > threshold_bin (from the earlier
filter_and_add_to_histogram phase), collect_from_full_threshold_bin appends
sub_bin > refine_threshold items on top of that, and together with s_last_remain
they satisfy s_counter_final + s_last_remain_initial == top_k; mention that
s_counter is intentionally not reset to prevent overwriting positions in
s_indices.
Nitpick comments:
In `@tests/utils/test_topk.py`:
- Around line 1246-1259: Replace the duplicated logit tensor construction with a
call to the existing helper _build_bf16_long_seq_bucket_inputs: remove the
manual creation of logits (the block that builds the bfloat16 repeated arange
using vocab_size, batch_size, device) and instead call
_build_bf16_long_seq_bucket_inputs(batch_size, vocab_size, device) to produce
logits before calling flashinfer.top_k; ensure the variable name logits is
preserved so subsequent uses (flashinfer.top_k, torch.topk, torch.gather) remain
unchanged.
- Around line 1233-1234: Replace the direct torch.cuda.is_bf16_supported() check
with the flashinfer.utils architecture-capability check: import and call
get_compute_capability() and skip the test when the returned compute capability
is less than (8, 0) (BF16 requires SM80+); update the two locations that use
torch.cuda.is_bf16_supported() in tests/utils/test_topk.py to use
get_compute_capability() instead (or defer and create/get_compute_capability in
flashinfer.utils if missing) so gating follows the project's utility functions.
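The suggested gating reduces to a tuple comparison. A minimal sketch (plain Python; `bf16_supported` is a hypothetical stand-in — the real helper would obtain the tuple from `flashinfer.utils` / `torch.cuda.get_device_capability()`):

```python
def bf16_supported(compute_capability):
    # BF16 compute requires SM80+ (Ampere or newer); compare as a
    # (major, minor) tuple instead of probing runtime support.
    return tuple(compute_capability) >= (8, 0)

# A100 (SM80) and H100 (SM90) pass; T4 (SM75) would be skipped.
checks = [bf16_supported(cc) for cc in [(8, 0), (9, 0), (7, 5)]]
```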

@jiangyinzuo jiangyinzuo force-pushed the fix/bf16-topk-filtered-overflow branch from 3d32e6d to da39748 Compare February 21, 2026 08:07
@jiangyinzuo
Contributor Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a bugfix for FilteredTopK to handle overflow scenarios correctly, especially for long sequences with many ties. The changes involve adding a fallback path that performs a full scan when the intermediate buffer for refinement overflows. This ensures correctness at the cost of performance in those edge cases. The changes are well-structured, refactoring complex logic into helper lambdas. Additionally, comprehensive regression tests have been added to verify the correctness of the new overflow handling logic for both bf16 and fp32 data types.

I've found one critical correctness issue in the multi-round overflow fallback path where a counter was being incorrectly reset. Please see the detailed comment.

@jiangyinzuo jiangyinzuo force-pushed the fix/bf16-topk-filtered-overflow branch from da39748 to 58f0927 Compare February 21, 2026 08:26
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
include/flashinfer/topk.cuh (1)

2334-2390: Consider adding an invariant comment for s_threshold_bin_id usage in the multi-round fallback.

In run_full_scan_overflow_fallback_multi_round, s_threshold_bin_id is overwritten each round by the histogram threshold-finding block (line 2379). This is safe because the overflow invariant guarantees C > SMEM_INPUT_SIZE ≫ topk_remain, so the suffix-sum condition s_histogram[tx] > topk_remain is always satisfiable. However, there is no reset of s_threshold_bin_id to a sentinel and no comment documenting this invariant, making the code fragile if the invariant is later weakened.

Suggested documentation comment:

```diff
+        // Invariant: overflow was triggered because C (threshold-bin count) >
+        // SMEM_INPUT_SIZE = 16384 >> top_k >= topk_remain, so the histogram
+        // threshold condition s_histogram[tx] > topk_remain is always satisfiable
+        // and s_threshold_bin_id is always written before being read below.
         for (int round = 0; round < NUM_ROUNDS; ++round) {
```
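The suffix-sum threshold search that this invariant protects can be sketched host-side (plain Python; `find_threshold_bin` is an illustrative stand-in for the kernel's `run_cumsum` + `s_threshold_bin_id` selection over a 256-bin radix histogram):

```python
def find_threshold_bin(byte_counts, topk_remain):
    """Suffix-sum threshold search over a 256-entry radix histogram.
    Returns the highest bin whose suffix count (elements in this bin or
    any higher bin) exceeds topk_remain. This is satisfiable whenever the
    total count exceeds topk_remain, which the overflow invariant
    (C > SMEM_INPUT_SIZE >> topk_remain) guarantees."""
    suffix = 0
    for b in range(255, -1, -1):
        suffix += byte_counts[b]
        if suffix > topk_remain:
            return b
    raise ValueError("total count must exceed topk_remain")

hist = [0] * 256
hist[200] = 5        # a few elements in a high bin
hist[100] = 30000    # tie-heavy bin, far larger than topk_remain
bin_id = find_threshold_bin(hist, topk_remain=64)
```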
Inline comments:
In `@tests/utils/test_topk.py`:
- Line 1246: The test unpacks batch_size from
_build_bf16_long_seq_bucket_inputs() but never uses it, triggering RUF059;
change the unpacking to replace the unused variable name batch_size with an
underscore (_) so the call becomes e.g. logits, _, _, _, _, k, _ =
_build_bf16_long_seq_bucket_inputs() (keep the original surrounding underscores
intact) to silence the warning while preserving the other returned values used
by the test.


@jiangyinzuo jiangyinzuo force-pushed the fix/bf16-topk-filtered-overflow branch 2 times, most recently from db9e011 to a7f5908 Compare February 21, 2026 08:43
@jiangyinzuo
Contributor Author

/gemini review

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (2)
include/flashinfer/topk.cuh (1)

2088-2099: s_threshold_bin_id not defensively reset in the multi-round overflow fallback, unlike the 1-round slow path.

Lines 2292–2293 reset s_threshold_bin_id = 0 before calling update_refine_threshold in the 1-round slow path. The multi-round overflow fallback at lines 2396–2402 does not do the same reset before each round's cumsum + threshold search. The IVT argument guarantees the condition always finds a unique bucket for valid inputs, so this is not an active bug. However, defensive parity with the 1-round path would guard against stale data under any edge case.

Suggested addition inside the per-round block of the multi-round fallback loop:

```diff
           for (int round = 0; round < NUM_ROUNDS; ++round) {
             const int offset = FIRST_SHIFT - round * 8;
 
             if (tx < RADIX + 1) s_histogram[tx] = 0;
             __syncthreads();
+
+            if (tx == 0) s_threshold_bin_id = 0;
+            __syncthreads();
 
             auto build_hist = [&](auto raw_input, int /*index*/) {
```
tests/utils/test_topk.py (1)

1279-1294: Inconsistent integer width vs. the sibling _build_fp32_long_seq_pivot_mismatch_inputs helper.

idx and bits use dtype=torch.int32, whereas _build_fp32_long_seq_pivot_mismatch_inputs (lines 1311-1312) uses torch.int64 + explicit & 0xFFFFFFFF masking for the same pattern. The current values don't overflow (0x3F800000 + 65535 = 0x3F80FFFF), but aligning to the same defensive idiom prevents silent breakage if the base or range changes.

Suggested fix:

```diff
-    idx = torch.arange(vocab_size, device=device, dtype=torch.int32)
-    bits = torch.full((vocab_size,), 0x3F800000, device=device, dtype=torch.int32) + idx
-    logits = bits.view(torch.float32).unsqueeze(0).contiguous()
+    idx = torch.arange(vocab_size, device=device, dtype=torch.int64)
+    bits = (
+        (torch.tensor(0x3F800000, device=device, dtype=torch.int64) + idx) & 0xFFFFFFFF
+    ).to(torch.int32)
+    logits = bits.view(torch.float32).unsqueeze(0).contiguous()
```
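The defensive idiom is independent of torch; a plain-Python sketch using `struct` shows why adding the element index to the base bit pattern 0x3F800000 (the bits of 1.0f) yields strictly increasing float32 logits (the helper name is illustrative):

```python
import struct

def monotone_fp32_logits(vocab_size, base_bits=0x3F800000):
    """Build strictly increasing float32 values by adding the element
    index to a base bit pattern (0x3F800000 == bits of 1.0f). Arithmetic
    is done at full Python-int width with an explicit 32-bit mask, so it
    cannot silently overflow if base or range changes."""
    out = []
    for idx in range(vocab_size):
        bits = (base_bits + idx) & 0xFFFFFFFF
        # Reinterpret the 32-bit pattern as a little-endian float32.
        out.append(struct.unpack("<f", struct.pack("<I", bits))[0])
    return out

logits = monotone_fp32_logits(8)
```

Incrementing the bit pattern of a positive float steps to the next representable value, so every element compares strictly greater than its predecessor — exactly the tie-free ordering the pivot tests rely on.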
Duplicate comments:
In `@include/flashinfer/topk.cuh`:
- Around line 2334-2351: There is a duplicate/ redundant check for
s_refine_overflow inside the NUM_ROUNDS loop; remove the inner "if
(s_refine_overflow) { break; }" at the end of the loop body because
s_refine_overflow is already guarded before entering the loop and
run_refine_round(r_idx, offset, ...) should drive early exits via its return
value; update the loop in topk.cuh (the NUM_ROUNDS for-loop that calls
run_refine_round and references s_refine_overflow) to rely on the
run_refine_round return path and the outer condition only.

In `@tests/utils/test_topk.py`:
- Line 1246: The unpacking assigns an unused variable batch_size from
_build_bf16_long_seq_bucket_inputs(), triggering Ruff RUF059; change the unpack
to either discard that element (replace batch_size with an underscore) or rename
it to a prefixed-underscore (e.g. _batch_size) so the linter recognizes it as
intentionally unused—update the unpack expression where logits, _, _,
batch_size, _, k, _ = _build_bf16_long_seq_bucket_inputs() is used.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request significantly improves the correctness of the FilteredTopK algorithm for long sequences and tie-heavy workloads by introducing a robust overflow handling mechanism. When the shared memory candidate buffer overflows during refinement, the kernel now correctly falls back to a "slow path" that re-histograms the full threshold bin, ensuring that no winners are missed. This change affects both 16-bit (FP16/BF16) and 32-bit (FP32) data types, with a multi-round fallback implemented for the latter. The addition of comprehensive regression tests for these edge cases further ensures the reliability of the fix. The implementation is well-structured, leveraging modern CUDA C++ features and maintaining performance through a fast-path/slow-path design.

@jiangyinzuo jiangyinzuo force-pushed the fix/bf16-topk-filtered-overflow branch 3 times, most recently from 8e495ce to 3e3231f Compare February 21, 2026 11:59
@jiangyinzuo
Contributor Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request provides a critical correctness fix for FilteredTopK by adding a fallback mechanism to handle buffer overflows during the refinement stage. The changes are well-implemented, refactoring complex logic into more manageable helper functions, which improves code clarity. The addition of comprehensive regression tests targeting the specific overflow scenarios is excellent and ensures the robustness of the fix. I have one minor suggestion to improve code consistency.

Signed-off-by: Yinzuo Jiang <jiangyinzuo@foxmail.com>
@jiangyinzuo
Contributor Author

@yzh119 I am ready now. Can you please retrigger the CI?

@yzh119
Collaborator

yzh119 commented Feb 22, 2026

@flashinfer-bot run

@yzh119
Collaborator

yzh119 commented Feb 22, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !335 has been updated with latest changes, and the CI pipeline #44540308 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #44540308: 11/20 passed

@jiangyinzuo
Contributor Author

jiangyinzuo commented Feb 22, 2026

I can't see the GitLab CI log. Do the failures matter? @yzh119

@yongwww
Member

yongwww commented Feb 22, 2026

@flashinfer-bot run

@yongwww yongwww added the run-ci label Feb 22, 2026
@yzh119
Collaborator

yzh119 commented Feb 22, 2026

Thanks for the thorough fix. One issue I noticed:

s_refine_overflow is declared at function scope and set via atomicOr when the buffer overflows, but it's never reset between iterations of the persistent loop (the for (uint32_t iter = 0; iter < total_iterations; iter++) loop over rows). Once a row triggers overflow, all subsequent rows processed by the same CTA group will also take the slow fallback path, even if their threshold bin fits in the buffer.

Suggest adding if (tx == 0) s_refine_overflow = 0; alongside the existing s_counter = 0 reset at the top of each iteration.

Also worth adding a test with num_rows > 1 where only some rows have tie-heavy distributions that overflow, to verify per-row correctness in the persistent loop.

@jiangyinzuo
Contributor Author

> Thanks for the thorough fix. One issue I noticed:
>
> s_refine_overflow is declared at function scope and set via atomicOr when the buffer overflows, but it's never reset between iterations of the persistent loop (the for (uint32_t iter = 0; iter < total_iterations; iter++) loop over rows). Once a row triggers overflow, all subsequent rows processed by the same CTA group will also take the slow fallback path, even if their threshold bin fits in the buffer.
>
> Suggest adding if (tx == 0) s_refine_overflow = 0; alongside the existing s_counter = 0 reset at the top of each iteration.
>
> Also worth adding a test with num_rows > 1 where only some rows have tie-heavy distributions that overflow, to verify per-row correctness in the persistent loop.

Thanks for the review.

This PR focuses on FilteredTopKUnifiedKernel, which has no persistent row loop `for (uint32_t iter = 0; iter < total_iterations; iter++)`; each CTA handles exactly one row (single-CTA mode), so s_refine_overflow cannot carry over to another row within the same launch.

The `for (uint32_t iter = 0; iter < total_iterations; ++iter)` persistent-loop pattern exists in the multi-CTA radix kernels (RadixTopKKernel_Unified, RadixTopKMaskLogitsKernel_MultiCTA, RadixTopKRenormProbKernel_MultiCTA), and none of them has an s_refine_overflow variable.

Your test suggestion is still very useful for those persistent kernels. I plan to prepare those tests and send them in a separate PR to keep this FilteredTopK bugfix PR scoped.
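For the persistent-loop kernels where the concern does apply, the sticky-flag failure mode and the suggested per-row reset can be modeled with a toy sequential loop (plain Python; names are illustrative, not kernel code):

```python
def process_rows(rows, buffer_cap, reset_flag_per_row=True):
    """Toy persistent loop: one worker processes all rows in sequence.
    Without the per-row reset, an overflow in one row leaves the flag
    set, and every later row takes the slow fallback path too."""
    paths = []
    overflow = False
    for row in rows:
        if reset_flag_per_row:
            overflow = False            # the suggested `s_refine_overflow = 0` reset
        if len(row) > buffer_cap:
            overflow = True             # candidate buffer would overflow
        paths.append("slow" if overflow else "fast")
    return paths

rows = [[0] * 20, [0] * 3, [0] * 3]     # only the first row overflows
sticky = process_rows(rows, buffer_cap=16, reset_flag_per_row=False)
fixed = process_rows(rows, buffer_cap=16, reset_flag_per_row=True)
```

With the sticky flag, every row after the first is (correct but needlessly) slow; with the reset, only the overflowing row pays for the fallback.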

@jiangyinzuo
Contributor Author

@yzh119 can this PR be merged? 👀

Collaborator

@yzh119 yzh119 left a comment


@jiangyinzuo thanks for clarification, sorry it's my bad, I think you are right.

Failed gitlab tests are infrastructure issues and not relevant. Let's merge it now, thanks for your contribution!

@yzh119 yzh119 merged commit 82ff6a9 into flashinfer-ai:main Feb 25, 2026
45 checks passed
@coderabbitai coderabbitai bot mentioned this pull request Mar 1, 2026
5 tasks
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

This PR fixes issue flashinfer-ai#2604: FilteredTopK could produce incorrect results when the refine candidate buffer overflows in long-sequence, tie-heavy workloads.

### Root Cause

FilteredTopK uses a fixed-size refine candidate buffer
(SMEM_INPUT_SIZE=16K), which can overflow when the threshold bin is too
large. The previous logic could continue with partially truncated state
after overflow, causing incorrect top-k outputs.

### Main Changes

#### 1) Kernel correctness fixes (include/flashinfer/topk.cuh)

- Added a refine-overflow flag s_refine_overflow, set atomically on candidate write overflow.
- On refine-buffer overflow, we now use a full-scan correctness fallback: rebuild the needed histogram/threshold state from the full input, recompute pivot selection state, reset partial intermediate state, and recollect winners from the full input to guarantee correct top-k results.
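The fast-path / full-scan-fallback structure can be sketched in plain Python (illustrative only; `buffer_cap` and `coarse_pivot` stand in for SMEM_INPUT_SIZE and the refine threshold, and `filtered_topk` is not a FlashInfer API):

```python
def filtered_topk(values, top_k, buffer_cap, coarse_pivot):
    """Fast path: stage only values above a coarse pivot in a bounded
    candidate buffer. If the candidate set overflows the buffer (many
    ties around the pivot), fall back to a full scan of the input rather
    than continuing with truncated state."""
    candidates = [v for v in values if v > coarse_pivot]
    if len(candidates) > buffer_cap:
        candidates = list(values)  # full-scan correctness fallback
    return sorted(candidates, reverse=True)[:top_k]

# Tie-heavy input: 100 copies of 5 overflow a 16-slot buffer.
overflow_result = filtered_topk([5] * 100 + [9, 8], top_k=3,
                                buffer_cap=16, coarse_pivot=4)
# Well-spread input: the fast path suffices.
fast_result = filtered_topk(list(range(100)), top_k=3,
                            buffer_cap=16, coarse_pivot=95)
```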

#### 2) Regression test coverage (tests/utils/test_topk.py)

- Added BF16 long-sequence regressions across auto / multi_cta / filtered.
- Added FP32 long-sequence refine-overflow regressions for both top_k and transform APIs (page_table / ragged).
- Added FP32 pivot reconstruction regressions for both top_k and transform APIs.

## 🔍 Related Issues

fix: flashinfer-ai#2604

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
- [x] Performance test regression.

Run the following benchmark 10 times. There is no noticeable performance
degradation in this PR.
```python
python benchmarks/bench_topk.py --compare-algorithms --dtype bf16/fp16/fp32
```

**Environment**
- GPU: NVIDIA A100-PCIE-40GB (SM80)
- PyTorch: 2.8.0+cu128
- CUDA (torch): 12.8

### bf16

| batch | seq_len | multi-CTA (us) | filtered (us), this PR | filtered (us), base | filtered, this PR / base |
|---:|---:|---:|---:|---:|---:|
| 1 | 4096 | 27.470 | 20.887 | 20.377 | 1.0250 |
| 1 | 16384 | 44.983 | 39.322 | 38.911 | 1.0106 |
| 1 | 65536 | 106.397 | 89.090 | 88.629 | 1.0052 |
| 1 | 131072 | 107.572 | 96.564 | 101.887 | 0.9478 |
| 1 | 262144 | 89.549 | 154.777 | 154.625 | 1.0010 |
| 1 | 524288 | 93.492 | 242.329 | 258.611 | 0.9370 |
| 16 | 4096 | 18.739 | 16.178 | 16.896 | 0.9575 |
| 16 | 16384 | 28.055 | 25.087 | 25.190 | 0.9959 |
| 16 | 65536 | 63.949 | 55.551 | 55.294 | 1.0046 |
| 16 | 131072 | 65.381 | 67.379 | 66.968 | 1.0061 |
| 16 | 262144 | 65.591 | 114.792 | 114.689 | 1.0009 |
| 16 | 524288 | 136.191 | 213.658 | 213.708 | 0.9998 |
| 64 | 4096 | 19.304 | 18.227 | 16.997 | 1.0724 |
| 64 | 16384 | 28.875 | 26.620 | 26.620 | 1.0000 |
| 64 | 65536 | 69.630 | 61.440 | 61.440 | 1.0000 |
| 64 | 131072 | 133.197 | 76.800 | 76.800 | 1.0000 |
| 64 | 262144 | 190.563 | 140.392 | 140.596 | 0.9985 |
| 64 | 524288 | 335.819 | 264.088 | 263.935 | 1.0006 |
| 256 | 4096 | 38.910 | 29.904 | 29.700 | 1.0069 |
| 256 | 16384 | 79.870 | 73.576 | 73.524 | 1.0007 |
| 256 | 65536 | 200.700 | 173.060 | 173.060 | 1.0000 |
| 256 | 131072 | 322.662 | 215.142 | 215.091 | 1.0002 |
| 256 | 262144 | 588.391 | 421.581 | 421.477 | 1.0002 |
| 256 | 524288 | 1141.863 | 811.111 | 811.367 | 0.9997 |
### fp16

| batch | seq_len | multi-CTA (us) | filtered (us), this PR | filtered (us), base | filtered, this PR / base |
|---:|---:|---:|---:|---:|---:|
| 1 | 4096 | 27.599 | 20.324 | 19.252 | 1.0557 |
| 1 | 16384 | 44.545 | 36.860 | 36.707 | 1.0042 |
| 1 | 65536 | 103.780 | 70.145 | 70.863 | 0.9899 |
| 1 | 131072 | 103.270 | 97.638 | 93.080 | 1.0490 |
| 1 | 262144 | 89.141 | 154.778 | 147.199 | 1.0515 |
| 1 | 524288 | 84.274 | 201.217 | 202.087 | 0.9957 |
| 16 | 4096 | 19.738 | 15.820 | 16.793 | 0.9421 |
| 16 | 16384 | 29.695 | 25.498 | 25.447 | 1.0020 |
| 16 | 65536 | 65.562 | 46.283 | 46.847 | 0.9880 |
| 16 | 131072 | 65.537 | 67.072 | 67.174 | 0.9985 |
| 16 | 262144 | 65.024 | 114.484 | 114.175 | 1.0027 |
| 16 | 524288 | 133.426 | 176.232 | 176.027 | 1.0012 |
| 64 | 4096 | 18.485 | 15.667 | 17.815 | 0.8794 |
| 64 | 16384 | 28.670 | 24.682 | 24.580 | 1.0041 |
| 64 | 65536 | 68.147 | 51.200 | 51.200 | 1.0000 |
| 64 | 131072 | 131.714 | 76.800 | 76.800 | 1.0000 |
| 64 | 262144 | 188.930 | 137.322 | 137.322 | 1.0000 |
| 64 | 524288 | 328.855 | 208.900 | 208.900 | 1.0000 |
| 256 | 4096 | 38.885 | 27.547 | 27.341 | 1.0075 |
| 256 | 16384 | 78.901 | 68.610 | 68.610 | 1.0000 |
| 256 | 65536 | 195.580 | 142.237 | 142.134 | 1.0007 |
| 256 | 131072 | 318.975 | 214.938 | 214.836 | 1.0005 |
| 256 | 262144 | 582.863 | 410.672 | 410.777 | 0.9997 |
| 256 | 524288 | 1119.438 | 643.019 | 642.815 | 1.0003 |
### fp32

| batch | seq_len | multi-CTA (us) | filtered (us), this PR | filtered (us), base | filtered, this PR / base |
|---:|---:|---:|---:|---:|---:|
| 1 | 4096 | 35.840 | 19.970 | 19.970 | 1.0000 |
| 1 | 16384 | 58.370 | 44.957 | 45.161 | 0.9955 |
| 1 | 65536 | 99.072 | 87.911 | 89.089 | 0.9868 |
| 1 | 131072 | 92.674 | 132.098 | 124.827 | 1.0582 |
| 1 | 262144 | 88.114 | 188.930 | 187.392 | 1.0082 |
| 1 | 524288 | 77.157 | 263.422 | 260.658 | 1.0106 |
| 16 | 4096 | 22.270 | 16.692 | 16.794 | 0.9939 |
| 16 | 16384 | 36.150 | 28.673 | 29.799 | 0.9622 |
| 16 | 65536 | 60.879 | 56.014 | 56.731 | 0.9874 |
| 16 | 131072 | 60.522 | 86.632 | 86.581 | 1.0006 |
| 16 | 262144 | 123.594 | 151.860 | 151.653 | 1.0014 |
| 16 | 524288 | 141.516 | 275.816 | 275.559 | 1.0009 |
| 64 | 4096 | 22.581 | 17.306 | 17.817 | 0.9713 |
| 64 | 16384 | 37.916 | 31.740 | 31.740 | 1.0000 |
| 64 | 65536 | 121.860 | 66.866 | 66.662 | 1.0031 |
| 64 | 131072 | 177.150 | 112.742 | 112.742 | 1.0000 |
| 64 | 262144 | 308.426 | 203.728 | 203.780 | 0.9997 |
| 64 | 524288 | 503.938 | 310.373 | 310.270 | 1.0003 |
| 256 | 4096 | 53.250 | 30.516 | 30.924 | 0.9868 |
| 256 | 16384 | 108.618 | 88.781 | 88.987 | 0.9977 |
| 256 | 65536 | 295.657 | 188.420 | 188.471 | 0.9997 |
| 256 | 131072 | 543.282 | 333.514 | 333.361 | 1.0005 |
| 256 | 262144 | 1002.573 | 654.493 | 654.646 | 0.9998 |
| 256 | 524288 | 1899.034 | 1035.158 | 1035.468 | 0.9997 |

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Bug Fixes**
* Improved Top-K correctness for long sequences with overflow-safe,
staged refinement (fast and slow paths), multi-pass threshold handling,
and coordination across parallel units to handle tie-heavy and
buffer-overflow edge cases (BF16/FP32).

* **Tests**
* Added regression tests, data builders, and helpers for BF16
long-sequence transforms and FP32 long-sequence overflow/pivot cases,
plus tie-insensitive index comparison utilities.

* **Documentation**
* Added note about potential threshold-bin overflow (bf16/tie-heavy)
with suggestion to use a multi-CTA mode when needed.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Yinzuo Jiang <jiangyinzuo@foxmail.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Successfully merging this pull request may close these issues.

[Bug] BF16 top_k returns incorrect values on long sequences when forcing filtered algorithm

4 participants