[bugfix] Fix FilteredTopK overflow correctness#2605
yzh119 merged 1 commit into flashinfer-ai:main
Conversation
|
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the CodeRabbit settings.
📝 Walkthrough

Adds staged, overflow-aware multi-round refinement to the FilteredTopK kernel (shared-memory refine state, inter-CTA coordination, staged collection and slow-path fallback) and expands BF16/FP32 long-sequence tests to cover filtered/multi-CTA/auto behaviors, overflow regressions, and pivot reconstruction cases.
Sequence Diagram(s): (omitted). Estimated code review effort: 🎯 4 (Complex), ⏱️ ~55 minutes.
🚥 Pre-merge checks: ✅ 5 passed.
Summary of Changes

Hello @jiangyinzuo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request resolves a critical correctness bug within the FilteredTopK algorithm.
Activity
Code Review
This pull request addresses a correctness bug in the FilteredTopK algorithm for bfloat16 data types, which could occur when there are many elements with tied values, causing an overflow in an intermediate shared memory buffer. The fix introduces a robust fallback path that re-scans the input to ensure correctness in these overflow scenarios. The implementation has also been significantly refactored into smaller, more manageable helper lambdas, improving code clarity. Additionally, comprehensive regression tests have been added to verify the fix and prevent future issues. My review includes one suggestion for a minor performance optimization.
Actionable comments posted: 2
🧹 Nitpick comments (3)
tests/utils/test_topk.py (1)
1276-1278: `_assert_unordered_indices_match` assumes output dtype matches expected dtype.

`output` from the transform tests is `torch.int32` while `expected` is also `torch.int32`, so this works. However, for the plain `top_k` path, `indices` are `torch.int64` — if this helper is ever reused there, `torch.equal` will fail on dtype mismatch. A minor robustness nit, not blocking.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/utils/test_topk.py` around lines 1276 - 1278, The helper _assert_unordered_indices_match can fail when output and expected have different integer dtypes (e.g., int32 vs int64); modify the function to normalize dtypes before comparison by casting both sorted values to a common integer dtype (for example expected.dtype or torch.long) prior to calling torch.equal so that dtype mismatches do not cause false failures when only order differs; ensure you apply the cast to the result of torch.sort(...).values for both output and expected.

include/flashinfer/topk.cuh (2)
2199-2216: `collect_with_threshold_last_round` calls `reset_histogram()` unnecessarily.

Line 2200 clears `s_histogram`, but the histogram is never read again after this function. This is a harmless no-op but adds an unnecessary `__syncthreads()` barrier. Minor nit — not worth changing unless performance-sensitive.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 2199 - 2216, The lambda collect_with_threshold_last_round unnecessarily calls reset_histogram() and then issues a __syncthreads(), but s_histogram is never referenced afterward; remove the reset_histogram() invocation and the ensuing __syncthreads() inside collect_with_threshold_last_round to eliminate the no-op clear and extra barrier (keep the rest of the loop logic and atomic updates to s_counter/s_last_remain and s_indices unchanged).
2217-2239: Silent overflow in multi-round (float32) `collect_with_threshold_non_last_round`.

Line 2230 checks `pos < SMEM_INPUT_SIZE` but silently drops elements on overflow without setting any flag (unlike the 1-round path at lines 2156–2158). For float32 (NUM_ROUNDS=4), an overflow here means some equal-to-threshold elements are lost, leading to incorrect results. In practice, this is extremely unlikely for float32 (it would require >16K elements sharing identical upper 16+ bits), but for completeness the same `s_refine_overflow`-style fallback could apply. Worth noting for future hardening.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 2217 - 2239, collect_with_threshold_non_last_round currently drops equal-to-threshold elements when pos >= SMEM_INPUT_SIZE without signaling overflow; update the else branch that handles static_cast<int>(bin) == threshold inside collect_with_threshold_non_last_round to set the same overflow/fallback flag used by the 1-round path (e.g., s_refine_overflow or s_refine_overflow[next_r_idx]) when pos is out of bounds instead of silently dropping items, mirroring the behavior in the single-round code path and keeping existing uses of s_input_idx, s_num_input, s_histogram, s_indices and s_counter intact.
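The flag-versus-silent-drop distinction this comment draws can be sketched in plain Python. This is an illustrative model, not the CUDA code: `SMEM_INPUT_SIZE` and the overflow flag are borrowed from the kernel purely for readability, and the tiny buffer size is chosen so the overflow path is easy to exercise.

```python
# Illustrative model of overflow-aware candidate collection. In the real
# kernel the buffer lives in shared memory and the flag is set atomically;
# here a tiny buffer makes the overflow path trivial to trigger.
SMEM_INPUT_SIZE = 8

def collect_threshold_bin(bins, threshold):
    """Collect indices whose bin equals the threshold into a bounded buffer.

    Returns (buffer, overflow). On overflow the caller must fall back to a
    full re-scan instead of trusting the truncated buffer.
    """
    buffer, overflow = [], False
    for idx, b in enumerate(bins):
        if b == threshold:
            if len(buffer) < SMEM_INPUT_SIZE:
                buffer.append(idx)
            else:
                overflow = True  # signal, instead of silently dropping
    return buffer, overflow
```

The key point mirrored from the fix: a truncated buffer is only usable when `overflow` is false; otherwise the winners must be recollected from the full input.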
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/utils/test_topk.py`:
- Around line 1230-1254: test_bf16_long_seq_regression_across_algorithms tries
to run the "filtered" TopK kernel unguarded, which will fail on GPUs without the
128KB shared memory requirement; add a guard at the start of the test that skips
when algo == "filtered" and can_implement_filtered_topk() is false (mirror other
tests), i.e., call can_implement_filtered_topk() and pytest.skip when it returns
False before calling set_topk_algo or invoking flashinfer.top_k.
- Around line 1281-1306: The test
test_bf16_long_seq_transform_regression_filtered forces the "filtered" path via
set_topk_algo("filtered") but lacks the can_implement_filtered_topk() guard used
elsewhere; add a guard after the BF16 support check that calls
can_implement_filtered_topk() and calls pytest.skip("Filtered top-k not
supported on this device") if it returns False so the test only runs when the
filtered kernel can be launched (keep set_topk_algo("filtered") and the rest of
the test unchanged).
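A torch-free sketch of the normalization the dtype comment asks for: cast every element to a common integer type before comparing sorted values. The real helper operates on torch tensors; `unordered_indices_match` is a hypothetical stand-in over plain sequences.

```python
def unordered_indices_match(output, expected):
    # Cast every element to a plain int so that differing integer widths
    # (e.g. int32 vs int64 tensors in the real torch helper) cannot cause
    # a false mismatch when only the order of the indices differs.
    return sorted(int(x) for x in output) == sorted(int(x) for x in expected)
```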
/bot run

@flashinfer-bot run

[FAILED] Pipeline #44473794: 14/20 passed

f6a28e6 to f72961c (Compare)
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/topk.cuh`:
- Around line 2228-2238: The multi-round float32 refine path doesn't handle
s_refine_overflow set by collect_with_threshold_non_last_round, so add a check
of s_refine_overflow before proceeding with the multi-round loop (the same place
get_num_input is used) and apply the same fallback used for the 1-round overflow
case: detect overflow, set a safe clamp/exit path (or force single-round
processing) and log/propagate the overflow state so elements aren't silently
lost; update the logic around get_num_input and the multi-round loop to
early-handle s_refine_overflow and mirror the bf16/1-round overflow handling.
f72961c to 3d32e6d (Compare)

/gemini review
Code Review
This pull request introduces overflow detection and fallback mechanisms for the FilteredTopK algorithm to ensure correctness when processing long sequences or workloads with heavy ties. The implementation adds a slow-path fallback that performs full scans of the threshold bin when the shared memory candidate buffer overflows. While the overall logic is sound and the added regression tests are comprehensive, I identified a few issues in the fallback paths, including an incorrect pivot construction in the multi-round case and a missing counter reset in the single-round case, which could lead to incorrect results or out-of-bounds memory access.
🧹 Nitpick comments (4)
tests/utils/test_topk.py (2)
1246-1259: DRY: use `_build_bf16_long_seq_bucket_inputs()` instead of duplicating logit construction.

The logit tensor construction at lines 1246–1251 is byte-for-byte identical to `_build_bf16_long_seq_bucket_inputs` defined at lines 1269–1274. Call the helper instead:

♻️ Proposed refactor

```diff
- batch_size = 4
- vocab_size = 65536
- k = 1024
- device = "cuda"
-
- # Repeated-value buckets trigger large threshold-bin occupancy in bf16.
- logits = (
-     ((torch.arange(vocab_size, device=device, dtype=torch.float32) % 64) / 64.0)
-     .unsqueeze(0)
-     .repeat(batch_size, 1)
-     .to(torch.bfloat16)
- )
-
- values, indices = flashinfer.top_k(logits, k, sorted=True)
- ref_values, _ = torch.topk(logits, k, dim=-1, sorted=True)
+ logits, _, _, batch_size, _, k, _ = _build_bf16_long_seq_bucket_inputs()
+ values, indices = flashinfer.top_k(logits, k, sorted=True)
+ ref_values, _ = torch.topk(logits, k, dim=-1, sorted=True)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/utils/test_topk.py` around lines 1246 - 1259, Replace the duplicated logit tensor construction with a call to the existing helper _build_bf16_long_seq_bucket_inputs: remove the manual creation of logits (the block that builds the bfloat16 repeated arange using vocab_size, batch_size, device) and instead call _build_bf16_long_seq_bucket_inputs(batch_size, vocab_size, device) to produce logits before calling flashinfer.top_k; ensure the variable name logits is preserved so subsequent uses (flashinfer.top_k, torch.topk, torch.gather) remain unchanged.
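The repeated-value bucket pattern both snippets build can be reproduced without torch to see why it is tie-heavy. `build_bucket_logits` is an illustrative stand-in, with `num_buckets=64` matching the `% 64` modulus in the snippet above.

```python
def build_bucket_logits(vocab_size, num_buckets=64):
    # Values cycle through num_buckets levels in [0, 1), so each level is
    # shared by vocab_size // num_buckets elements -- exactly the tie-heavy
    # shape that packs the kernel's threshold bin full of equal candidates.
    return [(i % num_buckets) / num_buckets for i in range(vocab_size)]
```

With `vocab_size=65536` and `k=1024`, every value level is shared by 1024 elements, so the threshold bin alone can hold as many candidates as the whole top-k.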
1233-1234: Consider replacing `torch.cuda.is_bf16_supported()` with a `flashinfer.utils` capability check.

`torch.cuda.is_bf16_supported()` is functional, but the coding guidelines require architecture-gating via `flashinfer.utils` functions (e.g. `get_compute_capability`). BF16 native support requires SM80+, so:

```python
from flashinfer.utils import get_compute_capability

if get_compute_capability() < (8, 0):
    pytest.skip("BF16 requires SM80+")
```

If `flashinfer.utils` does not currently expose `get_compute_capability`, this can be deferred, but it should be tracked for consistency with the guidelines. The same applies to line 1315.

Based on coding guidelines: "Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/utils/test_topk.py` around lines 1233 - 1234, Replace the direct torch.cuda.is_bf16_supported() check with the flashinfer.utils architecture-capability check: import and call get_compute_capability() and skip the test when the returned compute capability is less than (8, 0) (BF16 requires SM80+); update the two locations that use torch.cuda.is_bf16_supported() in tests/utils/test_topk.py to use get_compute_capability() instead (or defer and create/get_compute_capability in flashinfer.utils if missing) so gating follows the project's utility functions.include/flashinfer/topk.cuh (2)
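The gate itself reduces to a tuple comparison. Assuming `get_compute_capability()` returns a `(major, minor)` tuple as the comment suggests (that return shape is an assumption here), a minimal predicate looks like:

```python
def bf16_supported(compute_capability):
    # BF16 has native hardware support from Ampere (SM80) onward, so gate
    # on the (major, minor) tuple rather than a torch runtime probe.
    return compute_capability >= (8, 0)
```

Python compares tuples lexicographically, so `(7, 5) < (8, 0) < (9, 0)` holds without any extra arithmetic.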
2157-2159: Optional: prefer `atomicOr` over `atomicExch` for flag-setting semantics.

Both instructions are correct, but `atomicOr(&s_refine_overflow, 1)` communicates intent more clearly (set a flag, ignoring the prior value) than `atomicExch`, and avoids a redundant full-word write when the flag is already `1`.

🔧 Suggested change

```diff
- atomicExch(&s_refine_overflow, 1);
+ atomicOr(&s_refine_overflow, 1);
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 2157 - 2159, Replace the use of atomicExch for setting the overflow flag with atomicOr to express flag semantics and avoid redundant writes: in the block that currently calls atomicExch(&s_refine_overflow, 1) (the shared/global flag variable s_refine_overflow inside the refine/overflow handling code), change it to atomicOr(&s_refine_overflow, 1) so the flag is set atomically without overwriting prior value; ensure the variable type matches atomicOr's expected integer type and no other logic depends on the returned old value from atomicExch.
2275-2321: Add a comment documenting the `s_counter` continuity invariant.

`s_counter` is intentionally not reset before `collect_from_full_threshold_bin`. It already holds the count of `coarse_bin > threshold_bin` elements (written during `filter_and_add_to_histogram`). `collect_from_full_threshold_bin` appends the `sub_bin > refine_threshold` elements on top, maintaining the invariant `s_counter_final + s_last_remain_initial == top_k`, without which the output positions would overlap or leave gaps. A brief comment at line 2302 would save future readers from reverse-engineering this.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 2275 - 2321, Add a short in-place comment where s_counter is left untouched before defining/using collect_from_full_threshold_bin (near the block that calls update_refine_threshold and then defines collect_from_full_threshold_bin) explaining the continuity invariant: s_counter already contains the count of elements with coarse_bin > threshold_bin (from the earlier filter_and_add_to_histogram phase), collect_from_full_threshold_bin appends sub_bin > refine_threshold items on top of that, and together with s_last_remain they satisfy s_counter_final + s_last_remain_initial == top_k; mention that s_counter is intentionally not reset to prevent overwriting positions in s_indices.
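The invariant described above can be spelled out as a tiny self-contained check. The counts below are illustrative; only the relationship mirrors the kernel's `s_counter` bookkeeping.

```python
def check_counter_continuity(above_threshold, appended, last_remain_initial, top_k):
    # s_counter starts at the number of elements strictly above the
    # threshold bin and collect_from_full_threshold_bin appends on top of
    # it WITHOUT a reset; the final counter plus the initial remainder
    # must reconstruct exactly top_k output slots, no overlap and no gap.
    s_counter_final = above_threshold + appended
    return s_counter_final + last_remain_initial == top_k
```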
3d32e6d to da39748 (Compare)

/gemini review
Code Review
This pull request introduces a bugfix for FilteredTopK to handle overflow scenarios correctly, especially for long sequences with many ties. The changes involve adding a fallback path that performs a full scan when the intermediate buffer for refinement overflows. This ensures correctness at the cost of performance in those edge cases. The changes are well-structured, refactoring complex logic into helper lambdas. Additionally, comprehensive regression tests have been added to verify the correctness of the new overflow handling logic for both bf16 and fp32 data types.
I've found one critical correctness issue in the multi-round overflow fallback path where a counter was being incorrectly reset. Please see the detailed comment.
da39748 to 58f0927 (Compare)
Actionable comments posted: 1
🧹 Nitpick comments (1)
include/flashinfer/topk.cuh (1)
2334-2390: Consider adding an invariant comment for `s_threshold_bin_id` usage in the multi-round fallback.

In `run_full_scan_overflow_fallback_multi_round`, `s_threshold_bin_id` is overwritten each round by the histogram threshold-finding block (line 2379). This is safe because the overflow invariant guarantees `C > SMEM_INPUT_SIZE ≫ topk_remain`, so the suffix-sum condition `s_histogram[tx] > topk_remain` is always satisfiable. However, there is no reset of `s_threshold_bin_id` to a sentinel and no comment documenting this invariant, making the code fragile if the invariant is later weakened.

Suggested documentation comment

```diff
+ // Invariant: overflow was triggered because C (threshold-bin count) >
+ // SMEM_INPUT_SIZE = 16384 >> top_k >= topk_remain, so the histogram
+ // threshold condition s_histogram[tx] > topk_remain is always satisfiable
+ // and s_threshold_bin_id is always written before being read below.
  for (int round = 0; round < NUM_ROUNDS; ++round) {
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 2334 - 2390, The loop in run_full_scan_overflow_fallback_multi_round writes s_threshold_bin_id each round without a sentinel or documentation; add a short invariant comment explaining why s_threshold_bin_id will always be assigned (overflow invariant: C > SMEM_INPUT_SIZE ≫ topk_remain) and make the code defensive by initializing/resetting s_threshold_bin_id to a sentinel (e.g. -1) before the per-round histogram/cumsum block and asserting (or checking) after run_cumsum() that s_threshold_bin_id was set when selecting threshold_bytes[round] and updating topk_remain; reference s_threshold_bin_id, topk_remain, threshold_bytes, run_cumsum, and the run_full_scan_overflow_fallback_multi_round lambda when applying the change.
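The sentinel-plus-assert pattern the prompt describes can be modeled per round in plain Python. `RADIX` and the suffix-sum condition loosely mirror the kernel; the names are illustrative, not the CUDA implementation.

```python
RADIX = 256
SENTINEL = -1

def find_threshold_bin(histogram, topk_remain):
    # Suffix-sum over bins from the top; the threshold bin is the first
    # one whose cumulative count exceeds topk_remain. Resetting to a
    # sentinel first makes a missed assignment fail loudly instead of
    # silently reusing the previous round's bin id.
    threshold_bin_id = SENTINEL
    running = 0
    for bin_id in range(RADIX - 1, -1, -1):
        running += histogram[bin_id]
        if running > topk_remain:
            threshold_bin_id = bin_id
            break
    assert threshold_bin_id != SENTINEL, "threshold bin never assigned"
    return threshold_bin_id
```

If the overflow invariant ever weakened so that no bin satisfied the condition, the assert would fire instead of the stale bin id from the previous round being reused.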
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/utils/test_topk.py`:
- Line 1246: The test unpacks batch_size from
_build_bf16_long_seq_bucket_inputs() but never uses it, triggering RUF059;
change the unpacking to replace the unused variable name batch_size with an
underscore (_) so the call becomes e.g. logits, _, _, _, _, k, _ =
_build_bf16_long_seq_bucket_inputs() (keep the original surrounding underscores
intact) to silence the warning while preserving the other returned values used
by the test.
db9e011 to a7f5908 (Compare)

/gemini review
There was a problem hiding this comment.
🧹 Nitpick comments (2)
include/flashinfer/topk.cuh (1)
2088-2099: `s_threshold_bin_id` not defensively reset in the multi-round overflow fallback, unlike the 1-round slow path.

Lines 2292–2293 reset `s_threshold_bin_id = 0` before calling `update_refine_threshold` in the 1-round slow path. The multi-round overflow fallback at lines 2396–2402 does not do the same reset before each round's cumsum + threshold search. The IVT argument guarantees the condition always finds a unique bucket for valid inputs, so this is not an active bug. However, defensive parity with the 1-round path would guard against stale data under any edge case.

Suggested addition inside the per-round block of the multi-round fallback loop

```diff
  for (int round = 0; round < NUM_ROUNDS; ++round) {
    const int offset = FIRST_SHIFT - round * 8;
    if (tx < RADIX + 1) s_histogram[tx] = 0;
    __syncthreads();
+
+   if (tx == 0) s_threshold_bin_id = 0;
+   __syncthreads();

    auto build_hist = [&](auto raw_input, int /*index*/) {
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 2088 - 2099, The multi-round overflow fallback loop doesn't defensively reset the shared variable s_threshold_bin_id before each round's cumulative-sum + threshold search like the 1-round slow path does; to fix, add a reset of s_threshold_bin_id = 0 immediately before calling update_refine_threshold inside the per-round block of the multi-round fallback loop so each invocation of update_refine_threshold (the lambda that runs run_cumsum() and checks s_histogram to set s_threshold_bin_id) starts from a known state and cannot rely on stale values across rounds.

tests/utils/test_topk.py (1)
1279-1294: Inconsistent integer width vs. the sibling `_build_fp32_long_seq_pivot_mismatch_inputs` helper.

`idx` and `bits` use `dtype=torch.int32`, whereas `_build_fp32_long_seq_pivot_mismatch_inputs` (lines 1311-1312) uses `torch.int64` + explicit `& 0xFFFFFFFF` masking for the same pattern. The current values don't overflow (0x3F800000 + 65535 = 0x3F80FFFF), but aligning to the same defensive idiom prevents silent breakage if the base or range changes.

Suggested fix

```diff
- idx = torch.arange(vocab_size, device=device, dtype=torch.int32)
- bits = torch.full((vocab_size,), 0x3F800000, device=device, dtype=torch.int32) + idx
- logits = bits.view(torch.float32).unsqueeze(0).contiguous()
+ idx = torch.arange(vocab_size, device=device, dtype=torch.int64)
+ bits = (
+     (torch.tensor(0x3F800000, device=device, dtype=torch.int64) + idx) & 0xFFFFFFFF
+ ).to(torch.int32)
+ logits = bits.view(torch.float32).unsqueeze(0).contiguous()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/utils/test_topk.py` around lines 1279 - 1294, The helper _build_fp32_long_seq_overflow_inputs currently uses torch.int32 for idx and bits; update it to match the defensive pattern used in _build_fp32_long_seq_pivot_mismatch_inputs by creating idx with dtype=torch.int64, computing bits as torch.int64 (e.g., torch.full(..., dtype=torch.int64) + idx), applying an explicit & 0xFFFFFFFF mask to bits, and then cast/convert as needed to the integer width expected by the subsequent view() operation so the constructed bit pattern cannot silently change if base/range change.
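The bit patterns in question are easy to sanity-check with the standard `struct` module: `0x3F800000` is the IEEE-754 encoding of `1.0f`, and the `& 0xFFFFFFFF` mask is the defensive idiom the comment recommends, keeping an offset that grew past 32 bits from breaking the reinterpretation.

```python
import struct

def bits_to_f32(bits):
    # Reinterpret a 32-bit integer pattern as an IEEE-754 float, masking
    # first so an offset that overflowed 32 bits cannot crash the pack.
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]
```

Adding small offsets to the base pattern yields densely packed, strictly increasing floats just above 1.0, which is exactly the shape the pivot-mismatch tests rely on.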
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@include/flashinfer/topk.cuh`:
- Around line 2334-2351: There is a duplicate/ redundant check for
s_refine_overflow inside the NUM_ROUNDS loop; remove the inner "if
(s_refine_overflow) { break; }" at the end of the loop body because
s_refine_overflow is already guarded before entering the loop and
run_refine_round(r_idx, offset, ...) should drive early exits via its return
value; update the loop in topk.cuh (the NUM_ROUNDS for-loop that calls
run_refine_round and references s_refine_overflow) to rely on the
run_refine_round return path and the outer condition only.
In `@tests/utils/test_topk.py`:
- Line 1246: The unpacking assigns an unused variable batch_size from
_build_bf16_long_seq_bucket_inputs(), triggering Ruff RUF059; change the unpack
to either discard that element (replace batch_size with an underscore) or rename
it to a prefixed-underscore (e.g. _batch_size) so the linter recognizes it as
intentionally unused—update the unpack expression where logits, _, _,
batch_size, _, k, _ = _build_bf16_long_seq_bucket_inputs() is used.
Code Review
The pull request significantly improves the correctness of the FilteredTopK algorithm for long sequences and tie-heavy workloads by introducing a robust overflow handling mechanism. When the shared memory candidate buffer overflows during refinement, the kernel now correctly falls back to a "slow path" that re-histograms the full threshold bin, ensuring that no winners are missed. This change affects both 16-bit (FP16/BF16) and 32-bit (FP32) data types, with a multi-round fallback implemented for the latter. The addition of comprehensive regression tests for these edge cases further ensures the reliability of the fix. The implementation is well-structured, leveraging modern CUDA C++ features and maintaining performance through a fast-path/slow-path design.
8e495ce to 3e3231f (Compare)

/gemini review
Code Review
This pull request provides a critical correctness fix for FilteredTopK by adding a fallback mechanism to handle buffer overflows during the refinement stage. The changes are well-implemented, refactoring complex logic into more manageable helper functions, which improves code clarity. The addition of comprehensive regression tests targeting the specific overflow scenarios is excellent and ensures the robustness of the fix. I have one minor suggestion to improve code consistency.
Signed-off-by: Yinzuo Jiang <jiangyinzuo@foxmail.com>
@yzh119 I am ready now. Can you please retrigger the CI?

@flashinfer-bot run

/bot run

[FAILED] Pipeline #44540308: 11/20 passed

I can't see the GitLab CI log. Do the failures matter? @yzh119

@flashinfer-bot run
Thanks for the thorough fix. One issue I noticed:
Suggest adding … Also worth adding a test with …
Thanks for the review. This PR focuses on … The … Your test suggestion is still very useful for those persistent kernels. I plan to prepare those tests and send them in a separate PR to keep this FilteredTopK bugfix PR scoped.
|
@yzh119 can this PR be merged? 👀 |
yzh119 left a comment
@jiangyinzuo thanks for clarification, sorry it's my bad, I think you are right.
Failed gitlab tests are infrastructure issues and not relevant. Let's merge it now, thanks for your contribution!
📌 Description
This PR fixes issue #2604: FilteredTopK could produce incorrect results when the refine candidate buffer
overflows in long-sequence, tie-heavy workloads.
Root Cause
FilteredTopK uses a fixed-size refine candidate buffer (SMEM_INPUT_SIZE=16K), which can overflow when the threshold bin is too large. The previous logic could continue with partially truncated state after overflow, causing incorrect top-k outputs.
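The overflow condition is easy to reproduce outside the kernel. Below is a toy Python model (a sketch, not the CUDA code: binning by exact value stands in for the kernel's radix bins, and the 16K constant mirrors SMEM_INPUT_SIZE), showing how a tie-heavy bf16-like distribution concentrates far more than 16K candidates in the bin that contains the k-th element:

```python
# Toy model of histogram-based top-k candidate selection.
# The real kernel bins by floating-point bit patterns; binning by exact
# value is enough to show how ties blow up the pivot bin.
from collections import Counter
import random

SMEM_INPUT_SIZE = 16 * 1024  # fixed refine-candidate buffer size (per the PR)


def pivot_bin_population(values, k):
    """Count how many candidates fall into the bin holding the k-th largest value."""
    hist = Counter(values)            # one bin per distinct value (toy binning)
    remaining = k
    for v in sorted(hist, reverse=True):
        if remaining <= hist[v]:
            return hist[v]            # this bin must be refined to break ties
        remaining -= hist[v]
    return 0


random.seed(0)
# Tie-heavy long sequence: bf16 has few distinct values near the top, so
# large inputs concentrate many elements in a single bin.
values = [random.choice([0.0, 0.5, 1.0]) for _ in range(256 * 1024)]
pop = pivot_bin_population(values, k=2048)
print(pop > SMEM_INPUT_SIZE)  # the pivot bin alone overflows the 16K buffer
```

With roughly a third of 256K elements tied at the maximum, the pivot bin holds ~87K candidates, several times the buffer capacity, which is exactly the regime the fix targets.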
Main Changes
1) Kernel correctness fixes (include/flashinfer/topk.cuh)

- Added a refine-overflow flag `s_refine_overflow`, set atomically when a candidate write overflows the buffer.
- On refine-buffer overflow, we now use a full-scan correctness fallback: rebuild the needed histogram/threshold state from the full input, recompute pivot selection state, reset partial intermediate state, and recollect winners from the full input to guarantee correct top-k results.
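The intent of the fallback can be modeled on the host in a few lines of Python. This is a sketch of the control flow only; `s_refine_overflow` and the buffer size mirror the description above, while the real logic lives in the CUDA kernel:

```python
def filtered_topk_with_fallback(values, k, buf_size=16 * 1024):
    """Toy model: the fast path collects pivot-bin candidates into a bounded
    buffer; on overflow we fall back to a full scan instead of continuing
    with truncated state (the bug this PR fixes)."""
    pivot = sorted(values, reverse=True)[k - 1]   # threshold for top-k

    # Fast path: gather candidates >= pivot into a fixed-size buffer.
    buffer, refine_overflow = [], False
    for i, v in enumerate(values):
        if v >= pivot:
            if len(buffer) >= buf_size:
                refine_overflow = True            # kernel sets s_refine_overflow atomically
                break
            buffer.append((v, i))

    if refine_overflow:
        # Slow path: discard partial state and recollect winners from the
        # full input so correctness never depends on the truncated buffer.
        buffer = [(v, i) for i, v in enumerate(values) if v >= pivot]

    buffer.sort(key=lambda t: (-t[0], t[1]))      # break ties by index
    return buffer[:k]
```

The key property is that once the overflow flag is raised, no result is derived from the partially filled buffer; everything needed is recomputed from the full input.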
2) Regression test coverage (tests/utils/test_topk.py)

- Added BF16 long-sequence regressions across auto / multi_cta / filtered.
- Added FP32 long-sequence refine-overflow regressions for both top_k and transform APIs (page_table / ragged).
- Added FP32 pivot reconstruction regressions for both top_k and transform APIs.
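Because any subset of elements tied at the pivot is a valid top-k answer, such regressions cannot compare indices exactly. A tie-insensitive check along these lines (a sketch; the actual helper in tests/utils/test_topk.py may differ) accepts any index set whose gathered value multiset matches the reference:

```python
def assert_topk_equivalent(values, out_idx, k):
    """Tie-insensitive top-k check: the gathered values must equal the
    reference top-k multiset; ties at the pivot may legally swap indices."""
    assert len(out_idx) == k
    # Indices must be unique and in range.
    assert len(set(out_idx)) == k and all(0 <= i < len(values) for i in out_idx)
    ref = sorted(values, reverse=True)[:k]
    got = sorted((values[i] for i in out_idx), reverse=True)
    assert got == ref, (got, ref)
```

With this shape of check, a kernel that picks a different but equally valid set of tied winners still passes, while any truncation-induced wrong value fails.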
🔍 Related Issues
fix: #2604
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
- [x] Performance test regression. Run the following benchmark 10 times. There is no noticeable performance degradation in this PR.

```shell
python benchmarks/bench_topk.py --compare-algorithms --dtype bf16/fp16/fp32
```
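Running the benchmark repeatedly and aggregating reduces timer noise. A generic Python harness of the same shape (median of repeated runs; the CPU top-k variants below are illustrative stand-ins, not the CUDA kernels under test):

```python
import heapq
import statistics
import time


def bench(fn, *args, repeats=10):
    """Run fn `repeats` times and report the median wall time, mirroring
    the run-it-10-times methodology used for the kernel benchmark."""
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args)
        times.append(time.perf_counter() - t0)
    return statistics.median(times)


data = list(range(100_000))
t_sorted = bench(lambda xs: sorted(xs, reverse=True)[:64], data)
t_heap = bench(lambda xs: heapq.nlargest(64, xs), data)
print(f"sorted[:64]: {t_sorted * 1e3:.2f} ms, heapq.nlargest: {t_heap * 1e3:.2f} ms")
```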
Environment

- GPU: NVIDIA A100-PCIE-40GB (SM80)
- PyTorch: 2.8.0+cu128
- CUDA (torch): 12.8
bf16

| batch | seq_len | multi-CTA (us) | filtered (us) this PR | filtered (us) base | filtered this PR / base |
|---:|---:|---:|---:|---:|---:|
| 1 | 4096 | 27.470 | 20.887 | 20.377 | 1.0250 |
| 1 | 16384 | 44.983 | 39.322 | 38.911 | 1.0106 |
| 1 | 65536 | 106.397 | 89.090 | 88.629 | 1.0052 |
| 1 | 131072 | 107.572 | 96.564 | 101.887 | 0.9478 |
| 1 | 262144 | 89.549 | 154.777 | 154.625 | 1.0010 |
| 1 | 524288 | 93.492 | 242.329 | 258.611 | 0.9370 |
| 16 | 4096 | 18.739 | 16.178 | 16.896 | 0.9575 |
| 16 | 16384 | 28.055 | 25.087 | 25.190 | 0.9959 |
| 16 | 65536 | 63.949 | 55.551 | 55.294 | 1.0046 |
| 16 | 131072 | 65.381 | 67.379 | 66.968 | 1.0061 |
| 16 | 262144 | 65.591 | 114.792 | 114.689 | 1.0009 |
| 16 | 524288 | 136.191 | 213.658 | 213.708 | 0.9998 |
| 64 | 4096 | 19.304 | 18.227 | 16.997 | 1.0724 |
| 64 | 16384 | 28.875 | 26.620 | 26.620 | 1.0000 |
| 64 | 65536 | 69.630 | 61.440 | 61.440 | 1.0000 |
| 64 | 131072 | 133.197 | 76.800 | 76.800 | 1.0000 |
| 64 | 262144 | 190.563 | 140.392 | 140.596 | 0.9985 |
| 64 | 524288 | 335.819 | 264.088 | 263.935 | 1.0006 |
| 256 | 4096 | 38.910 | 29.904 | 29.700 | 1.0069 |
| 256 | 16384 | 79.870 | 73.576 | 73.524 | 1.0007 |
| 256 | 65536 | 200.700 | 173.060 | 173.060 | 1.0000 |
| 256 | 131072 | 322.662 | 215.142 | 215.091 | 1.0002 |
| 256 | 262144 | 588.391 | 421.581 | 421.477 | 1.0002 |
| 256 | 524288 | 1141.863 | 811.111 | 811.367 | 0.9997 |
fp16

| batch | seq_len | multi-CTA (us) | filtered (us) this PR | filtered (us) base | filtered this PR / base |
|---:|---:|---:|---:|---:|---:|
| 1 | 4096 | 27.599 | 20.324 | 19.252 | 1.0557 |
| 1 | 16384 | 44.545 | 36.860 | 36.707 | 1.0042 |
| 1 | 65536 | 103.780 | 70.145 | 70.863 | 0.9899 |
| 1 | 131072 | 103.270 | 97.638 | 93.080 | 1.0490 |
| 1 | 262144 | 89.141 | 154.778 | 147.199 | 1.0515 |
| 1 | 524288 | 84.274 | 201.217 | 202.087 | 0.9957 |
| 16 | 4096 | 19.738 | 15.820 | 16.793 | 0.9421 |
| 16 | 16384 | 29.695 | 25.498 | 25.447 | 1.0020 |
| 16 | 65536 | 65.562 | 46.283 | 46.847 | 0.9880 |
| 16 | 131072 | 65.537 | 67.072 | 67.174 | 0.9985 |
| 16 | 262144 | 65.024 | 114.484 | 114.175 | 1.0027 |
| 16 | 524288 | 133.426 | 176.232 | 176.027 | 1.0012 |
| 64 | 4096 | 18.485 | 15.667 | 17.815 | 0.8794 |
| 64 | 16384 | 28.670 | 24.682 | 24.580 | 1.0041 |
| 64 | 65536 | 68.147 | 51.200 | 51.200 | 1.0000 |
| 64 | 131072 | 131.714 | 76.800 | 76.800 | 1.0000 |
| 64 | 262144 | 188.930 | 137.322 | 137.322 | 1.0000 |
| 64 | 524288 | 328.855 | 208.900 | 208.900 | 1.0000 |
| 256 | 4096 | 38.885 | 27.547 | 27.341 | 1.0075 |
| 256 | 16384 | 78.901 | 68.610 | 68.610 | 1.0000 |
| 256 | 65536 | 195.580 | 142.237 | 142.134 | 1.0007 |
| 256 | 131072 | 318.975 | 214.938 | 214.836 | 1.0005 |
| 256 | 262144 | 582.863 | 410.672 | 410.777 | 0.9997 |
| 256 | 524288 | 1119.438 | 643.019 | 642.815 | 1.0003 |
fp32

| batch | seq_len | multi-CTA (us) | filtered (us) this PR | filtered (us) base | filtered this PR / base |
|---:|---:|---:|---:|---:|---:|
| 1 | 4096 | 35.840 | 19.970 | 19.970 | 1.0000 |
| 1 | 16384 | 58.370 | 44.957 | 45.161 | 0.9955 |
| 1 | 65536 | 99.072 | 87.911 | 89.089 | 0.9868 |
| 1 | 131072 | 92.674 | 132.098 | 124.827 | 1.0582 |
| 1 | 262144 | 88.114 | 188.930 | 187.392 | 1.0082 |
| 1 | 524288 | 77.157 | 263.422 | 260.658 | 1.0106 |
| 16 | 4096 | 22.270 | 16.692 | 16.794 | 0.9939 |
| 16 | 16384 | 36.150 | 28.673 | 29.799 | 0.9622 |
| 16 | 65536 | 60.879 | 56.014 | 56.731 | 0.9874 |
| 16 | 131072 | 60.522 | 86.632 | 86.581 | 1.0006 |
| 16 | 262144 | 123.594 | 151.860 | 151.653 | 1.0014 |
| 16 | 524288 | 141.516 | 275.816 | 275.559 | 1.0009 |
| 64 | 4096 | 22.581 | 17.306 | 17.817 | 0.9713 |
| 64 | 16384 | 37.916 | 31.740 | 31.740 | 1.0000 |
| 64 | 65536 | 121.860 | 66.866 | 66.662 | 1.0031 |
| 64 | 131072 | 177.150 | 112.742 | 112.742 | 1.0000 |
| 64 | 262144 | 308.426 | 203.728 | 203.780 | 0.9997 |
| 64 | 524288 | 503.938 | 310.373 | 310.270 | 1.0003 |
| 256 | 4096 | 53.250 | 30.516 | 30.924 | 0.9868 |
| 256 | 16384 | 108.618 | 88.781 | 88.987 | 0.9977 |
| 256 | 65536 | 295.657 | 188.420 | 188.471 | 0.9997 |
| 256 | 131072 | 543.282 | 333.514 | 333.361 | 1.0005 |
| 256 | 262144 | 1002.573 | 654.493 | 654.646 | 0.9998 |
| 256 | 524288 | 1899.034 | 1035.158 | 1035.468 | 0.9997 |
Reviewer Notes
Summary by CodeRabbit
Bug Fixes

- Improved Top-K correctness for long sequences with overflow-safe, staged refinement (fast and slow paths), multi-pass threshold handling, and coordination across parallel units to handle tie-heavy and buffer-overflow edge cases (BF16/FP32).

Tests

- Added regression tests, data builders, and helpers for BF16 long-sequence transforms and FP32 long-sequence overflow/pivot cases, plus tie-insensitive index comparison utilities.

Documentation

- Added a note about potential threshold-bin overflow (bf16/tie-heavy) with a suggestion to use a multi-CTA mode when needed.