feat: implement deterministic topk #2661
jiangyinzuo wants to merge 1 commit into flashinfer-ai:main
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the CodeRabbit settings. Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 Walkthrough

Adds an opt-in deterministic mode across the top-k stack: Python APIs, FFI bindings, C++ dispatch, and CUDA kernels. Implements deterministic multi-CTA collection and stable tie‑breaking, updates benchmarks/CLI for deterministic comparisons and DSA workloads, and adds deterministic-focused tests and helpers.

Changes
Sequence Diagram(s)

sequenceDiagram
participant Py as Python API
participant Bind as FFI Binding
participant Dispatch as C++ Dispatcher
participant Kernel as CUDA Kernel
Py->>Bind: call top_k(..., deterministic=True)
Bind->>Dispatch: radix_topk(..., sorted_output=..., deterministic=...)
Dispatch->>Kernel: launch deterministic-aware kernel (det scratch, DETERMINISTIC)
Kernel->>Kernel: deterministic collect / stable tie-breaking / optional stable sort
Kernel-->>Dispatch: return indices & values
Dispatch-->>Bind: propagate results
Bind-->>Py: deliver deterministic outputs
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs
Suggested labels
Suggested reviewers
Poem
🚥 Pre-merge checks: ✅ 4 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a significant feature by enabling deterministic behavior for all top-k related operations within FlashInfer. It provides users with fine-grained control over determinism, which is crucial for reproducibility in machine learning and scientific computing. The changes span Python APIs, CUDA kernels, and benchmarking tools, ensuring robust implementation and verification of this new capability.

Highlights
Changelog
Activity
Code Review
This pull request introduces a significant feature: deterministic top-k selection. The changes are extensive, adding new execution paths to both the radix and filtered top-k algorithms to ensure reproducible results, which is particularly important for handling ties. The implementation includes backward compatibility for existing APIs by adding new optional parameters. The benchmarks and tests have been updated comprehensively to cover the new deterministic modes. The overall implementation is well-designed and robust. I have one suggestion to improve code clarity and remove a minor redundancy in the CUDA kernel.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_topk.py`:
- Around line 51-56: The benchmark currently enters the
torch_deterministic_algorithms context inside run_torch_topk on every iteration,
adding overhead; instead enable deterministic mode once before the timing loop
and restore the prior state afterwards, removing the per-iteration context from
run_torch_topk (and the analogous per-iteration context in the other benchmark
at lines 136-145); specifically, call the global deterministic enable API once
(save the previous value), run the repeated torch.topk calls normally inside
run_torch_topk, then restore the saved deterministic setting after the loop so
the timing measures only torch.topk cost.
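The restructuring this comment describes can be sketched in plain Python. The `Runtime` class below is a hypothetical stand-in for torch's global deterministic-algorithms state (in the real benchmark the save/restore calls would be `torch.are_deterministic_algorithms_enabled()` and `torch.use_deterministic_algorithms()`); the point is that the mode is toggled exactly twice, not once per timed iteration:

```python
# Sketch of hoisting a global mode toggle out of a timing loop.
# `Runtime` is a stand-in for torch's deterministic-algorithms state.

class Runtime:
    def __init__(self):
        self.deterministic = False
        self.toggle_count = 0  # counts how often the mode is switched

    def set_deterministic(self, enabled):
        self.toggle_count += 1
        self.deterministic = enabled

def bench_deterministic(runtime, fn, iters):
    """Enable deterministic mode once, run the timed loop, then restore."""
    previous = runtime.deterministic
    runtime.set_deterministic(True)
    try:
        for _ in range(iters):
            fn()  # only this cost lands inside the timed region
    finally:
        runtime.set_deterministic(previous)

runtime = Runtime()
calls = []
bench_deterministic(runtime, lambda: calls.append(runtime.deterministic), iters=100)
# Mode toggled exactly twice (enable + restore), not once per iteration.
```

In the real benchmark, `fn` would be the repeated `torch.topk` call, so the measurement covers only its cost.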
In `@tests/utils/test_topk.py`:
- Around line 1492-1514: The BF16 reproducibility test
(test_top_k_reproducible_mode_repeatability_multi_cta) runs unconditionally but
must be skipped on GPUs with compute capability < SM80; add a guard at the start
of the test that calls flashinfer.utils.get_compute_capability() (or the project
helper like flashinfer.utils.is_sm90a_supported/is_sm80_supported) and use
pytest.skip(...) when the capability is below 80 to avoid running BF16 on
unsupported hardware; also add an import for pytest if it's not present.
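A minimal sketch of such a guard; `needs_skip_bf16` is a hypothetical helper, and the real test would feed it the tuple returned by `flashinfer.utils.get_compute_capability()` and call `pytest.skip(...)` when it returns True:

```python
# Sketch of a compute-capability gate for a BF16 test. The helper name and
# threshold mirror the review suggestion; the actual check would query the
# device's compute capability at the start of the test.

def needs_skip_bf16(compute_capability):
    """BF16 kernels require SM80+ (Ampere or newer)."""
    major, minor = compute_capability
    return major * 10 + minor < 80

# In the test body this would become:
#   if needs_skip_bf16(cc):
#       pytest.skip("BF16 requires compute capability >= 8.0")
```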
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
benchmarks/bench_topk.py, csrc/flashinfer_topk_binding.cu, csrc/topk.cu, flashinfer/__init__.py, flashinfer/topk.py, include/flashinfer/topk.cuh, tests/utils/test_topk.py
Force-push history: 4358ff1 → 9e88bc8 → 42a86f9 → 7679c40; c196697 → 66fe404
@coderabbitai review

/gemini review
✅ Actions performed: Review triggered.
Code Review
This pull request introduces deterministic top-k operations across FlashInfer's core functions, including top_k, top_k_page_table_transform, and top_k_ragged_transform. The changes involve adding new helper functions for score generation and benchmarking, extending command-line arguments in benchmarks/bench_topk.py to support deterministic mode and various input patterns, and modifying the C++ kernels to propagate and utilize a deterministic_mode flag. New test cases have been added to validate the repeatability and correctness of these deterministic operations, especially under tie-heavy conditions. Feedback includes removing a debug print statement from the benchmarks and redundant int() casts in the Python API for clarity.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/topk.py (1)
266-283: ⚠️ Potential issue | 🟠 Major: Invalidate the cached `topk` JIT module for this ABI change.

This PR adds new parameters (`sorted_output`, `deterministic_mode`) to the `radix_topk*` function signatures. The Python wrappers now pass these parameters to the CUDA exports (e.g., `module.radix_topk(..., sorted_cuda, mode, ...)`), but the cache key is purely name-based (`topk.so`). If an older cached `.so` exists from a prior version, it will be reused without recompilation, causing the new Python code to call stale CUDA exports with mismatched arity. Bump the JIT module name or add a version/hash suffix to invalidate existing caches alongside this binding change.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/topk.py` around lines 266 - 283, Cached JIT module name is unchanged so old topk.so can be reused even though radix_topk signatures changed; update the cache key used by get_topk_module (or wherever the extension is built/loaded) to include a version or hash/suffix (e.g., bump the "topk" module name or append a ABI version string) so that new builds produce a new .so and old caches are not loaded; ensure the new key is used when calling get_topk_module() so the Python wrappers (radix_topk, radix_topk_cuda exports) are always matched to the correct compiled ABI.
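One way to implement the suggested cache invalidation, sketched with illustrative names (`TOPK_ABI_VERSION` and `topk_module_name` are not actual flashinfer symbols): derive the module name from an ABI version plus a hash of the exported signatures, so any signature change forces a rebuild.

```python
# Sketch of an ABI-versioned JIT cache key. Incorporating a version (or a
# hash of the exported signatures) means a signature change produces a new
# module name, so a stale cached .so can never be loaded by mistake.
import hashlib

TOPK_ABI_VERSION = 2  # bump whenever radix_topk* signatures change

def topk_module_name(exported_signatures):
    digest = hashlib.sha256("|".join(exported_signatures).encode()).hexdigest()[:8]
    return f"topk_v{TOPK_ABI_VERSION}_{digest}"

old = topk_module_name(["radix_topk(input, k)"])
new = topk_module_name(["radix_topk(input, k, sorted_output, deterministic_mode)"])
assert old != new  # changed signature -> different cache key -> rebuild
```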
🧹 Nitpick comments (3)
benchmarks/bench_topk.py (3)
38-38: Remove debug print statement.

This `print("message: ", message)` appears to be a debug artifact that will clutter benchmark output when an unclassified error occurs. Consider removing it or gating behind a verbose/debug flag.

🔧 Proposed fix

    if "invalid argument" in message or "operation not supported" in message:
        return "UNSUPPORTED"
-   print("message: ", message)
    return None

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_topk.py` at line 38, Remove the stray debug print that prints the local variable message (the line print("message: ", message) in bench_topk.py); either delete this statement or wrap it behind an explicit verbose/debug flag so benchmark output remains clean (refer to the print("message: ", message) call and the local variable message when making the change).
64-71: Consider enabling CUPTI timing for more accurate measurements.

The benchmark uses `enable_cupti=False`, but CUPTI typically provides more accurate kernel timing. Based on learnings, `flashinfer.testing.bench_gpu_time()` is recommended with CUPTI timing enabled (with auto-fallback to CUDA events).

🔧 Proposed change

    def bench_median_ms(fn) -> float:
        measurements = bench_gpu_time(
            fn,
-           enable_cupti=False,
+           enable_cupti=True,
            dry_run_iters=10,
            repeat_iters=100,
        )
        return float(np.median(measurements))

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_topk.py` around lines 64 - 71, The benchmark helper bench_median_ms currently calls bench_gpu_time(fn, enable_cupti=False, ...) which disables CUPTI; change it to enable CUPTI timing by calling bench_gpu_time with enable_cupti=True so CUPTI is used (falling back to CUDA events automatically as implemented in bench_gpu_time), keeping dry_run_iters and repeat_iters unchanged; update the call site in function bench_median_ms to set enable_cupti=True.
633-654: Consider extracting header formatting logic.

The conditional header and divider length calculations are getting complex with multiple flag combinations. For maintainability, consider extracting this into a helper function. However, this is acceptable for a benchmark script.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_topk.py` around lines 633 - 654, Extract the header/divider construction into a small helper (e.g., format_header_and_divider) that takes the flags args.deterministic, args.compare_torch_deterministic, and args.compare_sglang and returns the header string and divider_len; move the current conditional logic that builds header and computes divider_len into that helper and replace the block that sets header, divider_len and prints with a single call to format_header_and_divider to improve readability and maintainability while keeping the existing output formatting and width calculations unchanged.
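A sketch of what the extracted helper could look like; the column names and widths are illustrative, not the benchmark's actual output format:

```python
# Illustrative version of the suggested helper: build the benchmark header
# from the active comparison flags and derive the divider from its width.

def format_header_and_divider(deterministic, compare_torch_deterministic, compare_sglang):
    columns = ["shape", "flashinfer_us"]
    if deterministic:
        columns.append("det_us")
    if compare_torch_deterministic:
        columns.append("torch_det_us")
    if compare_sglang:
        columns.append("sglang_us")
    header = " | ".join(f"{c:>14}" for c in columns)
    return header, "-" * len(header)

header, divider = format_header_and_divider(True, False, True)
print(header)
print(divider)
```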
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/topk.cuh`:
- Around line 3259-3261: The failing pre-commit is due to an unformatted call to
RadixTopKRaggedTransformMultiCTA<DType, IdType>; update the formatting of this
return statement to match clang-format (wrap parameters/line breaks as
clang-format expects for the call using the same function name and parameters:
input, output_indices, offsets, lengths, num_rows, top_k_val, max_len,
row_states_buffer, deterministic, stream) and commit that rewritten hunk so
pre-commit/CI no longer reformats it.
- Around line 3269-3293: The radix path needs the same index-based
canonicalization as the filtered path when deterministic is true: after calling
RadixTopKMultiCTA<DType, IdType>(...) and before any
StableSortTopKByValue<DType, IdType>(...) call, invoke SortTopKByIndex<DType,
IdType>(output_indices, output_values, num_rows, top_k_val, max_len, stream) to
reorder ties by index; then, if sorted_output is true, follow that with the
existing StableSortTopKByValue<DType, IdType>(...) call so ties are preserved
according to the canonical index order.
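The index canonicalization step this comment describes can be modeled in pure Python (a sketch, not the CUDA implementation): reorder tied values into ascending index order so the final tie order is independent of the order in which CTAs emitted them.

```python
# Pure-Python model of index-based canonicalization: sort by value
# descending, breaking ties by ascending index, so two different emission
# orders of the same tied set produce one canonical result.

def canonicalize_by_index(values, indices):
    pairs = sorted(zip(values, indices), key=lambda p: (-p[0], p[1]))
    return [p[0] for p in pairs], [p[1] for p in pairs]

# Two different emission orders of the same tied set...
v1, i1 = canonicalize_by_index([3.0, 3.0, 1.0], [42, 7, 0])
v2, i2 = canonicalize_by_index([3.0, 1.0, 3.0], [7, 0, 42])
print(i1, i2)  # both canonicalize to the same index order
```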
---
Outside diff comments:
In `@flashinfer/topk.py`:
- Around line 266-283: Cached JIT module name is unchanged so old topk.so can be
reused even though radix_topk signatures changed; update the cache key used by
get_topk_module (or wherever the extension is built/loaded) to include a version
or hash/suffix (e.g., bump the "topk" module name or append a ABI version
string) so that new builds produce a new .so and old caches are not loaded;
ensure the new key is used when calling get_topk_module() so the Python wrappers
(radix_topk, radix_topk_cuda exports) are always matched to the correct compiled
ABI.
---
Nitpick comments:
In `@benchmarks/bench_topk.py`:
- Line 38: Remove the stray debug print that prints the local variable message
(the line print("message: ", message) in bench_topk.py); either delete this
statement or wrap it behind an explicit verbose/debug flag so benchmark output
remains clean (refer to the print("message: ", message) call and the local
variable message when making the change).
- Around line 64-71: The benchmark helper bench_median_ms currently calls
bench_gpu_time(fn, enable_cupti=False, ...) which disables CUPTI; change it to
enable CUPTI timing by calling bench_gpu_time with enable_cupti=True so CUPTI is
used (falling back to CUDA events automatically as implemented in
bench_gpu_time), keeping dry_run_iters and repeat_iters unchanged; update the
call site in function bench_median_ms to set enable_cupti=True.
- Around line 633-654: Extract the header/divider construction into a small
helper (e.g., format_header_and_divider) that takes the flags
args.deterministic, args.compare_torch_deterministic, and args.compare_sglang
and returns the header string and divider_len; move the current conditional
logic that builds header and computes divider_len into that helper and replace
the block that sets header, divider_len and prints with a single call to
format_header_and_divider to improve readability and maintainability while
keeping the existing output formatting and width calculations unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b335df74-24d1-444c-ba84-c593cd054fa5
📒 Files selected for processing (6)
benchmarks/bench_topk.py, csrc/flashinfer_topk_binding.cu, csrc/topk.cu, flashinfer/topk.py, include/flashinfer/topk.cuh, tests/utils/test_topk.py
@coderabbitai review

/gemini review

✅ Actions performed: Full review triggered.
Code Review
This pull request introduces a new deterministic mode for FlashInfer's top-k operations, including top_k, top_k_page_table_transform, and top_k_ragged_transform. This mode ensures bitwise-reproducible output for the selected top-k elements, especially in cases with ties. Key changes include extending the API with a deterministic parameter, modifying CUDA kernels to implement deterministic collection logic, adjusting shared memory management for CUB's BlockScan in deterministic mode, and updating heuristics for algorithm selection. The benchmarking script was enhanced to support deterministic mode and various input patterns, and extensive new test cases were added to verify correctness and repeatability. Additionally, new sorting kernels were introduced for post-processing deterministic filtered top-k results when sorted output is requested.
@coderabbitai review

/gemini review

✅ Actions performed: Full review triggered.
Code Review
This pull request implements deterministic mode for FlashInfer's top-k operations, covering basic top-k, page table transform, and ragged transform. The changes involve significant CUDA kernel updates to ensure bitwise reproducibility, particularly for multi-CTA and tie-heavy scenarios, along with new post-processing kernels for stable sorting. The PR also expands the benchmarking suite to include DeepSeek DSA-like workloads and adds rigorous tests for repeatability across streams. Feedback indicates that the IsRadixTopKUnsupportedConfig helper function is currently unused and should be removed to maintain code quality.
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/topk.cuh`:
- Around line 3010-3038: StableSortTopKByValue currently chooses kernel variants
whose capacity (threads * ITEMS_PER_THREAD determined by the kernel template
parameters used in StableSortTopKByValueKernel) tops out at 2048, so when
callers (e.g. the radix deterministic path) pass top_k_val > 2048 the function
leaves higher-index entries untouched; fix by adding logic in
StableSortTopKByValue to ensure the chosen kernel can handle at least top_k_val
items per row: either select or add larger kernel template instantiations
(increase threads and/or ITEMS_PER_THREAD parameters for
StableSortTopKByValueKernel) or implement a chunking/fallback path that
processes rows in multiple launches until all top_k_val items are covered;
update the launch selection (the if/else chain that calls launch_sort with
StableSortTopKByValueKernel<...>) so the product of the kernel params >=
top_k_val for all possible inputs.
- Around line 206-209: RadixDeterministicCollectScratch's fixed arrays
gt_count/eq_count only allow 256 CTAs but ctas_per_group (used when indexing
cta_in_group) can be larger and overflow; fix by either bounding ctas_per_group
to a safe MAX_CTAS (e.g., 256) before any indexing into gt_count/eq_count or by
changing the design to allocate gt_count/eq_count dynamically (shared or heap)
sized to ctas_per_group and pass that pointer into the kernels that use
RadixDeterministicCollectScratch; update all call sites and the kernel logic
that compute/use ctas_per_group and cta_in_group so they respect the new cap or
use the dynamically-sized storage to avoid out-of-bounds writes.
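The capacity-selection rule from the `StableSortTopKByValue` comment above can be modeled on the host side as picking the smallest variant whose `threads * items_per_thread` covers `top_k_val`; the variant table below is illustrative, not the kernel's actual instantiation list.

```python
# Host-side sketch of capacity selection: choose the smallest
# (threads, items_per_thread) variant that can hold top_k_val items per
# row, instead of silently topping out at a fixed capacity.

VARIANTS = [(128, 4), (256, 4), (256, 8), (512, 8), (1024, 8)]  # (threads, items)

def pick_sort_variant(top_k_val):
    for threads, items in VARIANTS:
        if threads * items >= top_k_val:
            return threads, items
    return None  # caller must fall back to a chunked multi-launch path

print(pick_sort_variant(2048))   # fits an existing variant
print(pick_sort_variant(8192))   # needs the largest variant
print(pick_sort_variant(16384))  # no single-launch variant: chunk instead
```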
In `@tests/utils/test_topk.py`:
- Around line 1515-1527: The helper function _assert_top_k_matches_torch uses a
parameter named sorted which shadows Python's builtin and triggers a Ruff lint;
rename the parameter (for example to sorted_output) in the function signature
and update all references inside the function (the calls to flashinfer.top_k and
torch.topk should pass sorted=sorted_output or use the renamed keyword) as well
as any external callers if present so the behavior remains identical.
- Around line 1884-1913: The test currently only checks values and repeatability
(values_a/values_b and indices_a/indices_b) but doesn't verify that returned
indices are valid and include the strictly greater winners; update
test_top_k_deterministic_sorted_tie_break_oracle to also assert that indices_a
(and indices_b) have length k, contain no duplicates, all entries are within [0,
vocab_size), and include the index range of the gt winners (vocab_size -
gt_count ... vocab_size - 1) when pattern != "all_equal"; use the existing
variables indices_a/indices_b, logits, vocab_size, gt_count and k to implement
these checks so any kernel repeating an index, dropping a winner, or returning
out-of-bounds indices will fail the test.
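The extra assertions the test comment asks for can be modeled in pure Python; `check_topk_indices` is an illustrative helper mirroring the test's variable names, not existing test code:

```python
# Sketch of the extra assertions: given a row of returned indices, verify
# length, uniqueness, bounds, and that the strictly-greater winners (placed
# at the tail of the vocab in this test pattern) were all selected.

def check_topk_indices(indices, k, vocab_size, gt_count, pattern):
    assert len(indices) == k
    assert len(set(indices)) == k  # no duplicates
    assert all(0 <= i < vocab_size for i in indices)
    if pattern != "all_equal":
        winners = set(range(vocab_size - gt_count, vocab_size))
        assert winners.issubset(set(indices))

# Example: vocab of 8, k=4, the last 2 entries strictly greater than the rest.
check_topk_indices([7, 6, 0, 1], k=4, vocab_size=8, gt_count=2, pattern="tie_heavy")
```

A kernel that repeats an index, drops a winner, or emits an out-of-bounds index would trip one of these assertions.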
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: d3df0407-8612-4a73-895a-efac3f2f7800
📒 Files selected for processing (6)
benchmarks/bench_topk.py, csrc/flashinfer_topk_binding.cu, csrc/topk.cu, flashinfer/topk.py, include/flashinfer/topk.cuh, tests/utils/test_topk.py
@coderabbitai review

/gemini review

🧠 Learnings used

✅ Actions performed: Full review triggered.
Code Review
This pull request introduces a deterministic mode for FlashInfer's top-k kernels, ensuring bitwise-reproducible output ordering. The changes include updates to the CUDA kernels to support deterministic collection paths, modifications to the Python bindings and benchmark utilities to expose this functionality, and the addition of comprehensive tests to verify repeatability and correctness across different routing modes and input patterns.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/topk.py (1)
270-283: ⚠️ Potential issue | 🟠 Major: Pass `stable=True` to `torch.sort()` when `deterministic=True`.

When `deterministic=True` and `k > 2048`, the fallback sort path is executed but uses `torch.sort()` with the default `stable=False`. This allows equal values to be reordered arbitrarily, violating the documented guarantee that "deterministic mode guarantees repeatable FlashInfer output ordering."

The fix passes `stable=deterministic` to ensure stable sorting only in deterministic mode, preserving performance for non-deterministic cases.

🔧 Suggested fix

    if sorted and not sorted_cuda:
        # Sort within each row by value (descending)
-       sorted_values, sort_indices = torch.sort(output_values, dim=-1, descending=True)
+       sorted_values, sort_indices = torch.sort(
+           output_values,
+           dim=-1,
+           descending=True,
+           stable=deterministic,
+       )
        sorted_indices = torch.gather(indices, dim=-1, index=sort_indices)
        return sorted_values, sorted_indices

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/topk.py` around lines 270 - 283, In the fallback path (when sorted and not sorted_cuda) ensure torch.sort is invoked with stable=deterministic so equal values remain in a repeatable order when deterministic=True: update the call to torch.sort(output_values, dim=-1, descending=True) inside the sorted/not-sorted_cuda block to pass stable=deterministic (affecting sorted_values, sort_indices used with indices via torch.gather to produce sorted_indices) so deterministic mode preserves stable ordering while leaving performance unchanged when deterministic is False.
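The role of stability here can be demonstrated with Python's built-in `sorted()`, which is guaranteed stable: sorting (value, index) pairs by value alone leaves tied entries in their original index order, which is exactly the repeatable tie order deterministic mode promises.

```python
# Pure-Python illustration of why stability matters for the fallback path:
# a stable descending sort by value keeps tied entries in their original
# order, so equal values can never be reordered arbitrarily.

values = [5.0, 7.0, 5.0, 7.0]
indices = [10, 11, 12, 13]

pairs = list(zip(values, indices))
stable_desc = sorted(pairs, key=lambda p: -p[0])  # stable descending by value
print(stable_desc)  # ties 7.0@11 before 7.0@13, and 5.0@10 before 5.0@12
```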
♻️ Duplicate comments (2)
include/flashinfer/topk.cuh (1)
3243-3260: ⚠️ Potential issue | 🟠 Major: Canonicalize radix ties before the deterministic value sort.

The filtered branch already normalizes deterministic output with `LaunchSortTopKByIndex(...)`, but the radix branch still goes straight to `StableSortTopKByValue(...)`. Because that sort is stable, equal values preserve the collect order from `RadixCollectIndicesDeterministic(...)`, so `deterministic=True, sorted=True` still changes tie order based on which algorithm was selected.

🔧 Suggested fix

    } else {
      FLASHINFER_CUDA_CALL((RadixTopKMultiCTA<DType, IdType>(
          input, output_indices, output_values, nullptr, num_rows, top_k_val, max_len,
          row_states_buffer, deterministic, stream)));
+     if (deterministic && sorted_output) {
+       FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType>(
+           output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len,
+           stream)));
+     }
    }
    if (sorted_output) {
      FLASHINFER_CUDA_CALL((StableSortTopKByValue<DType, IdType>(
          output_indices, output_values, num_rows, top_k_val, max_len, stream)));
    }

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 3243 - 3260, The radix path doesn't canonicalize ties before the deterministic value sort, so when deterministic and sorted both true ties can differ by algorithm; after calling RadixTopKMultiCTA<DType, IdType>(...) inside the else branch, add the same canonicalization step used in the filtered branch: if (deterministic) call LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType>(output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len, stream) so radix-collected ties are normalized prior to the subsequent StableSortTopKByValue<DType, IdType>(...).

tests/utils/test_topk.py (1)
1851-1881: ⚠️ Potential issue | 🟠 Major: Don't derive the sorted oracle from FlashInfer's unsorted order.

This test currently asserts that `sorted=True` is exactly a stable sort of `sorted=False`. The public API only promises deterministic sorted output, not that the sorted path must preserve the unsorted emission order, so this will reject valid implementations.

🔧 Safer oracle

-   unsorted_values, unsorted_indices = flashinfer.top_k(
-       logits, k, deterministic=True, sorted=False
-   )
    sorted_values_a, sorted_indices_a = flashinfer.top_k(
        logits, k, deterministic=True, sorted=True
    )
    sorted_values_b, sorted_indices_b = flashinfer.top_k(
        logits, k, deterministic=True, sorted=True
    )
-   expected_values, sort_order = torch.sort(
-       unsorted_values, dim=-1, descending=True, stable=True
-   )
-   expected_indices = torch.gather(unsorted_indices, dim=-1, index=sort_order)
+   expected_row = torch.argsort(pattern, descending=True, stable=True)[:k]
+   expected_indices = expected_row.unsqueeze(0).expand(batch_size, -1)
+   expected_values = torch.gather(logits, dim=-1, index=expected_indices)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/utils/test_topk.py` around lines 1851 - 1881, The test currently derives the "expected" sorted results from flashinfer.top_k(..., sorted=False) which wrongly assumes the sorted path must be a stable reordering of the unsorted path; change the oracle so it computes expected results directly from the input logits instead. In test_top_k_deterministic_sorted_matches_stable_sort, replace the block that computes expected_values and expected_indices from unsorted_values/unsorted_indices with a direct sort/top-k on logits (e.g., torch.topk(logits, k, dim=-1, largest=True, sorted=True) or torch.sort on logits then gather indices) and assert flashinfer.top_k(..., sorted=True) matches that expected output and remains deterministic across calls. Ensure references remain to flashinfer.top_k, logits, and the test function name.
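The "safer oracle" idea reduces to computing the expected top-k directly from the scores with a stable sort. In pure Python (a model of the idea, not the test's actual torch code), `sorted()` is stable, so ties canonically resolve to the lower index:

```python
# Pure-Python sketch of the safer oracle: derive the expected deterministic
# top-k straight from the input scores, never from the kernel's own output.

def expected_topk(scores, k):
    order = sorted(range(len(scores)), key=lambda i: -scores[i])  # stable
    idx = order[:k]
    return [scores[i] for i in idx], idx

# Tie-heavy row: three entries share the max value 9.0.
scores = [9.0, 1.0, 9.0, 3.0, 9.0]
values, indices = expected_topk(scores, k=3)
print(values, indices)  # ties kept in ascending index order
```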
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_topk.py`:
- Around line 30-35: classify_benchmark_runtime_error currently only matches
"invalid argument" or "operation not supported" but misses CUDA's "invalid
configuration" text; update the function (classify_benchmark_runtime_error) to
also detect "invalid configuration" (or "invalid configuration argument") in the
lowercased exception message and return "UNSUPPORTED" for that case so
deterministic multi-CTA invalid-configuration failures from
include/flashinfer/topk.cuh are classified correctly.
- Around line 124-137: The code currently omits setting result["sglang_us"] when
the SGLang comparison is skipped or fails, and callers treat its absence as an
error; change the SGLang comparison block (the branch using compare_sglang,
HAS_SGL_KERNEL, k==2048, scores.dtype==torch.float32 and calling
sgl_kernel.fast_topk_v2 via bench_median_ms) to always set a status-aware entry
in result (e.g., set result["sglang_us"] to None or a sentinel and add
result["sglang_status"]="skipped" or "error" on skips/failures), mirroring the
handling used in the page-table and ragged comparison blocks so that skipped
comparisons do not print “(SGLang error)” and real exceptions are caught,
logged, and reflected in the result instead of aborting before emitting the
FlashInfer result; ensure you also compute result["speedup_vs_sglang"] only when
sg_ms is valid.
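The status-aware result handling could look like the following sketch (`attach_sglang_result` is an illustrative helper, not existing benchmark code): always record an entry, distinguish skips from errors, and compute the speedup only for valid timings.

```python
# Sketch of status-aware SGLang result handling: the result dict always
# carries an sglang entry, and the speedup is computed only when a real
# timing exists.

def attach_sglang_result(result, fi_us, sg_us=None, status="ok"):
    result["sglang_us"] = sg_us
    result["sglang_status"] = status
    if status == "ok" and sg_us is not None and fi_us:
        result["speedup_vs_sglang"] = sg_us / fi_us
    return result

skipped = attach_sglang_result({}, fi_us=10.0, status="skipped")
ok = attach_sglang_result({}, fi_us=10.0, sg_us=25.0)
print(skipped["sglang_status"], ok["speedup_vs_sglang"])
```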
In `@include/flashinfer/topk.cuh`:
- Around line 232-240: AdvanceRadixGroupBarrier currently lets thread 0 publish
the group arrival before the CTA has synchronized, allowing other CTAs to
observe partially written state; change the order so the block first
synchronizes and flushes memory, then thread 0 performs red_release. Concretely,
in AdvanceRadixGroupBarrier add a __syncthreads() (and a __threadfence() or
__threadfence_system() as appropriate for visibility to other CTAs) before the
if (tx == 0) red_release(&state->arrival_counter, 1) so all threads finish
histogram/output writes and those writes are visible before arrival_counter is
incremented; leave the rest of the function (wait_ge, barrier_phase++,
__syncthreads()) unchanged.
---
Outside diff comments:
In `@flashinfer/topk.py`:
- Around line 270-283: In the fallback path (when sorted and not sorted_cuda)
ensure torch.sort is invoked with stable=deterministic so equal values remain in
a repeatable order when deterministic=True: update the call to
torch.sort(output_values, dim=-1, descending=True) inside the
sorted/not-sorted_cuda block to pass stable=deterministic (affecting
sorted_values, sort_indices used with indices via torch.gather to produce
sorted_indices) so deterministic mode preserves stable ordering while leaving
performance unchanged when deterministic is False.
---
Duplicate comments:
In `@include/flashinfer/topk.cuh`:
- Around line 3243-3260: The radix path doesn't canonicalize ties before the
deterministic value sort, so when deterministic and sorted both true ties can
differ by algorithm; after calling RadixTopKMultiCTA<DType, IdType>(...) inside
the else branch, add the same canonicalization step used in the filtered branch:
if (deterministic) call LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType,
IdType>(output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val,
max_len, stream) so radix-collected ties are normalized prior to the subsequent
StableSortTopKByValue<DType, IdType>(...).
In `@tests/utils/test_topk.py`:
- Around line 1851-1881: The test currently derives the "expected" sorted
results from flashinfer.top_k(..., sorted=False) which wrongly assumes the
sorted path must be a stable reordering of the unsorted path; change the oracle
so it computes expected results directly from the input logits instead. In
test_top_k_deterministic_sorted_matches_stable_sort, replace the block that
computes expected_values and expected_indices from
unsorted_values/unsorted_indices with a direct sort/top-k on logits (e.g.,
torch.topk(logits, k, dim=-1, largest=True, sorted=True) or torch.sort on logits
then gather indices) and assert flashinfer.top_k(..., sorted=True) matches that
expected output and remains deterministic across calls. Ensure references remain
to flashinfer.top_k, logits, and the test function name.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 8d61e949-8848-4bff-b118-c9851e4f86ea
📒 Files selected for processing (6)
benchmarks/bench_topk.py, csrc/flashinfer_topk_binding.cu, csrc/topk.cu, flashinfer/topk.py, include/flashinfer/topk.cuh, tests/utils/test_topk.py
@pytest.mark.parametrize("algo", ["auto", "multi_cta", "filtered"])
def test_fp32_long_seq_refine_overflow_regression_across_algorithms(
- test_fp32_long_seq_refine_overflow_regression_across_algorithms
  -> test_fp32_long_seq_regression_matrix(builder=_build_fp32_long_seq_overflow_inputs, algo in {auto, multi_cta, filtered}, api_kind="top_k")
- test_fp32_long_seq_refine_overflow_transform_regression_across_algorithms
  -> test_fp32_long_seq_regression_matrix(builder=_build_fp32_long_seq_overflow_inputs, algo in {auto, multi_cta, filtered}, api_kind in {"page_table", "ragged"})
- test_fp32_long_seq_pivot_rebuild_regression_filtered
  -> test_fp32_long_seq_regression_matrix(builder=_build_fp32_long_seq_pivot_mismatch_inputs, algo="filtered", api_kind="top_k")
- test_fp32_long_seq_pivot_rebuild_transform_regression_filtered
  -> test_fp32_long_seq_regression_matrix(builder=_build_fp32_long_seq_pivot_mismatch_inputs, algo="filtered", api_kind in {"page_table", "ragged"})
The provided code changes implement a deterministic mode for Top-K operations in FlashInfer. The implementation introduces a deterministic flag across the public API and internal CUDA kernels, ensuring bitwise-reproducible output ordering for the selected top-k set on a fixed input and system.
I have reviewed the changes, and they appear to correctly handle the necessary synchronization and deterministic collection logic required for this feature. The addition of RadixDeterministicCollectScratch and the use of DeterministicThreadStridedCollect are appropriate for maintaining determinism across multiple CTAs.
One minor suggestion for the test_fp32_long_seq_regression_matrix test: ensure that the builder functions are robust against potential OOM errors when running on smaller GPUs, as the stress cases added (stress_b2048_l131072_k1024 and stress_b4096_l200000_k1024) are quite memory-intensive. The current classify_benchmark_runtime_error helper is a good addition for handling these cases gracefully in the benchmark suite.
@jiangyinzuo, the refactoring is clean and the coverage mapping is accurate. Here's a detailed breakdown:
Confirmed equivalence (old → new):
| Old function | New: `test_fp32_long_seq_regression_matrix` params |
|---|---|
| `test_fp32_long_seq_refine_overflow_regression_across_algorithms` | `builder=_build_fp32_long_seq_overflow_inputs`, algo ∈ {auto, multi_cta, filtered}, `api_kind="top_k"` |
| `test_fp32_long_seq_refine_overflow_transform_regression_across_algorithms` | `builder=_build_fp32_long_seq_overflow_inputs`, algo ∈ {auto, multi_cta, filtered}, api_kind ∈ {page_table, ragged} |
| `test_fp32_long_seq_pivot_rebuild_regression_filtered` | `builder=_build_fp32_long_seq_pivot_mismatch_inputs`, `algo="filtered"`, `api_kind="top_k"` |
| `test_fp32_long_seq_pivot_rebuild_transform_regression_filtered` | `builder=_build_fp32_long_seq_pivot_mismatch_inputs`, `algo="filtered"`, api_kind ∈ {page_table, ragged} |
Observations:

- Additive coverage bonus: `_build_fp32_long_seq_pivot_mismatch_inputs` with `algo="filtered"` is now also exercised against api_kind ∈ {page_table, ragged} — this is new coverage beyond what the original four functions provided, which is a net positive.
- Correct assertion strategy: `_assert_top_k_matches_torch` (used for `api_kind="top_k"`) asserts exact sorted order via `torch.testing.assert_close` + `torch.equal`. `_assert_transform_identity_matches_torch` (used for transform api_kinds) correctly uses `_assert_unordered_indices_match`, since transform APIs don't expose a `sorted` parameter.
- `ids=[...]` usage: The explicit IDs (`refine_overflow-auto`, `refine_overflow-multi_cta`, etc.) make test reports readable — good practice.
- `set_topk_algo` fixture: Correctly scoped via `yield` with teardown restoring the env var, so algo overrides are isolated per test case.
- Minor nit: `logits, _, _, k = builder()` silently discards `batch_size` and `vocab_size`. A comment like `# batch_size, vocab_size unused` would aid readability, but this is very minor.
The refactoring looks correct and the consolidation into a single parameterized matrix test is an improvement. ✅
@coderabbitai review
/gemini review
✅ Actions performed: Review triggered.
Code Review
This pull request introduces deterministic mode for FlashInfer's top-k operations, ensuring bitwise-reproducible results across basic, page table, and ragged transform variants. The implementation includes new CUDA kernels for deterministic collection and stable sorting, along with an expanded benchmarking suite for DeepSeek DSA workloads. The review feedback highlights opportunities to improve code maintainability by refactoring duplicated printing logic in the benchmark script and replacing macro-based shared memory aliasing in the CUDA kernels with a more readable struct-based approach.
Actionable comments posted: 2
♻️ Duplicate comments (3)
tests/utils/test_topk.py (1)
1487-1499: ⚠️ Potential issue | 🟡 Minor: Rename `sorted`; Ruff still flags this helper.

Line 1488 shadows the Python builtin, so this helper keeps tripping A002. `sorted_output` avoids the lint with no behavior change.

🧹 Minimal rename

```diff
 def _assert_top_k_matches_torch(
-    logits: torch.Tensor, k: int, *, deterministic: bool = False, sorted: bool = True
+    logits: torch.Tensor,
+    k: int,
+    *,
+    deterministic: bool = False,
+    sorted_output: bool = True,
 ):
     """Assert FlashInfer top_k matches torch.topk for exact-order cases."""
     values, indices = flashinfer.top_k(
-        logits, k, deterministic=deterministic, sorted=sorted
+        logits, k, deterministic=deterministic, sorted=sorted_output
     )
-    ref_values, ref_indices = torch.topk(logits, k, dim=-1, sorted=sorted)
+    ref_values, ref_indices = torch.topk(
+        logits, k, dim=-1, sorted=sorted_output
+    )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/utils/test_topk.py` around lines 1487-1499: Rename the parameter named sorted in the helper function _assert_top_k_matches_torch to avoid shadowing the built-in; change the parameter name to sorted_output and update all uses inside the function (the flashinfer.top_k call and torch.topk call) to pass sorted=sorted_output (and any internal references if present), leaving the behavior and variable names values, indices, ref_values, ref_indices unchanged.

include/flashinfer/topk.cuh (2)
232-240: ⚠️ Potential issue | 🔴 Critical: Synchronize the CTA before publishing the radix-group arrival.

`AdvanceRadixGroupBarrier()` still lets line 235 advance `arrival_counter` before the rest of the block is forced to finish its preceding histogram/output writes. The current callers at lines 468, 648, and 851 hit it immediately after per-thread atomics/stores, so another CTA can observe partially updated state and break correctness/determinism again.

🔧 Minimal fix

```diff
 __device__ __forceinline__ void AdvanceRadixGroupBarrier(RadixRowState* state, int& barrier_phase,
                                                          uint32_t ctas_per_group, uint32_t tx) {
+  __syncthreads();
   if (tx == 0) {
     red_release(&state->arrival_counter, 1);
   }
   int target = (barrier_phase + 1) * ctas_per_group;
   wait_ge(&state->arrival_counter, target, tx);
```

Expected result: either the helper owns the CTA sync, or every releasing call site shows an immediate `__syncthreads()` before it.

```bash
#!/bin/bash
set -euo pipefail
sed -n '232,240p' include/flashinfer/topk.cuh
sed -n '452,470p' include/flashinfer/topk.cuh
sed -n '635,650p' include/flashinfer/topk.cuh
sed -n '835,855p' include/flashinfer/topk.cuh
sed -n '1256,1263p' include/flashinfer/topk.cuh
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 232 - 240, AdvanceRadixGroupBarrier currently releases the radix-group arrival (red_release(&state->arrival_counter, 1)) before the CTA is synchronized, allowing other CTAs to observe partially written per-thread state; fix it by owning the CTA sync inside AdvanceRadixGroupBarrier: add a __syncthreads() immediately before the tx==0 release path so the block finishes all histogram/output stores before calling red_release, leaving the existing wait_ge(&state->arrival_counter, target, tx), barrier_phase++, and trailing __syncthreads() intact.
3241-3258: ⚠️ Potential issue | 🟠 Major: Canonicalize radix ties before the stable value sort.

Line 3246 index-sorts only the filtered deterministic path. When line 3251 routes deterministic work through radix, `StableSortTopKByValue()` on line 3256 preserves the deterministic collection order from `RadixCollectIndicesDeterministic`, so `sorted=True, deterministic=True` still returns a different tie order depending on which algorithm was selected.

🔧 Suggested fix

```diff
   } else {
     FLASHINFER_CUDA_CALL((RadixTopKMultiCTA<DType, IdType>(
         input, output_indices, output_values, nullptr, num_rows, top_k_val, max_len,
         row_states_buffer, deterministic, stream)));
+    if (deterministic && sorted_output) {
+      FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType>(
+          output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len,
+          stream)));
+    }
   }
   if (sorted_output) {
     FLASHINFER_CUDA_CALL((StableSortTopKByValue<DType, IdType>(
         output_indices, output_values, num_rows, top_k_val, max_len, stream)));
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 3241 - 3258, The deterministic canonicalization (index-sort via LaunchSortTopKByIndex) is only applied in the filtered path; ensure radix-based deterministic results are canonicalized the same way before the stable value sort. After calling RadixTopKMultiCTA in the else branch, if deterministic is true call LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType> with the same arguments used in the filtered branch (output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len, stream) so that StableSortTopKByValue sees a canonical tie order regardless of which algorithm ran; keep the existing filtered-path LaunchSortTopKByIndex and the final StableSortTopKByValue intact.
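Why the two-pass normalization works can be shown with a small pure-Python model (an illustration of the ordering argument, not the CUDA code): canonicalizing by index first and then stable-sorting by descending value yields one fixed tie order regardless of how the pairs were collected.

```python
def canonical_order(pairs):
    """pairs: (index, value) tuples in arbitrary collection order.
    Step 1 models SortTopKByIndex: sort by ascending index.
    Step 2 models StableSortTopKByValue: stable sort by descending
    value, so equal values keep the ascending-index order from step 1."""
    by_index = sorted(pairs, key=lambda p: p[0])
    return sorted(by_index, key=lambda p: -p[1])

# Two different collection orders of the same top-k set converge:
a = canonical_order([(9, 3.0), (5, 1.0), (2, 3.0)])
b = canonical_order([(2, 3.0), (9, 3.0), (5, 1.0)])
```

This is the invariant the review asks for: after canonicalization, the filtered and radix paths would be indistinguishable to the final value sort.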
🧹 Nitpick comments (2)
benchmarks/bench_topk.py (1)
209-223: Consider using `-float('inf')` consistently for the `neg_inf` fallback.

For fp16/bf16, using `torch.finfo(dtype).min` instead of `-inf` means values at the minimum representable float could still be selected over the masked positions. If the intent is to fully exclude masked positions from top-k selection, `-inf` (which is representable in fp16/bf16) would be more robust.

🔧 Suggested fix

```diff
-    neg_inf = -torch.inf if dtype == torch.float32 else torch.finfo(dtype).min
+    neg_inf = float('-inf')
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_topk.py` around lines 209-223: The masking uses torch.finfo(dtype).min for neg_inf when dtype is fp16/bf16, which can still be chosen; change the neg_inf computation in the causal_chunk block (where start_pos, lengths, q_len, dtype are used) to use a true negative infinity constant (e.g. -float('inf')) for the masked_fill value so masked positions are fully excluded when you call scores = scores.masked_fill(invalid, neg_inf).

flashinfer/topk.py (1)
176-182: Docstring could clarify the tie-breaking strategy for deterministic mode.

The PR objectives and issue #2584 mention that deterministic mode uses "lower element index wins" for tie-breaking. Consider adding this detail to the docstring so users understand the expected behavior when values are equal.

📝 Suggested docstring enhancement

```diff
     deterministic : bool, optional
         If True, uses deterministic mode. Default is False (non-deterministic,
         which is faster). Deterministic mode guarantees repeatable FlashInfer
         output ordering for
-        the selected top-k set on a fixed input and system.
+        the selected top-k set on a fixed input and system. When values are equal,
+        elements with lower indices are selected first (stable tie-breaking).
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/topk.py` around lines 176 - 182, Update the docstring for the top_k function to explicitly state the tie-breaking rule used when deterministic=True: when values are equal the element with the lower index is chosen ("lower element index wins"). Mention this behavior near the deterministic parameter description in the top_k docstring so callers know how ties are resolved and that ordering may differ from non-deterministic behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/topk.cuh`:
- Around line 213-218: The cached row-state workspace currently doesn't include
space for the deterministic scratch tail used by
MaybeGetRadixDeterministicCollectScratchBuffer, so when deterministic &&
!single_cta the pointer (row_states_buffer + num_groups) can walk past the
allocated buffer; update all allocation sites that create
radix_topk_row_states_* (in flashinfer/topk.py and any C++/CUDA allocs) to
reserve room for both RadixRowState[num_groups] and
RadixDeterministicCollectScratch[num_groups] (i.e. allocate num_groups of
RadixRowState plus num_groups of RadixDeterministicCollectScratch, or
equivalently adjust byte-size to
num_groups*(sizeof(RadixRowState)+sizeof(RadixDeterministicCollectScratch))),
and ensure any cached size calculations and related comments reflect this change
so deterministic multi-CTA no longer overruns the buffer.
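The sizing rule from the note reduces to a one-line formula; a sketch with placeholder struct sizes (the real values would come from `sizeof` on the device structs, and the helper name is hypothetical):

```python
def det_workspace_bytes(num_groups, row_state_bytes, det_scratch_bytes):
    """Workspace sizing per the review note: reserve room for
    RadixRowState[num_groups] plus
    RadixDeterministicCollectScratch[num_groups], so the deterministic
    scratch tail never walks past the allocated buffer."""
    return num_groups * (row_state_bytes + det_scratch_bytes)
```

Any cached-size bookkeeping on the Python side would need the same formula so allocation and kernel-side pointer arithmetic agree.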
In `@tests/utils/test_topk.py`:
- Around line 1896-1937: The tests
test_top_k_deterministic_sorted_large_k_matches_torch_by_algo and
test_top_k_deterministic_trivial_k_equals_length_by_algo currently parametrize
over "filtered" but use k values (4096 and vocab_size) larger than
FILTERED_TOPK_MAX_K (defined in include/flashinfer/topk.cuh as 2048), so they
never exercise FilteredTopK; update the parametrization to only use ["auto",
"multi_cta"] for these two tests, or alternatively add a separate test case that
explicitly uses set_topk_algo("filtered") with k <= FILTERED_TOPK_MAX_K (e.g.,
k=2048) to validate the filtered path.
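A small helper sketch of the suggested guard; the 2048 constant mirrors FILTERED_TOPK_MAX_K from include/flashinfer/topk.cuh, and the function name is hypothetical:

```python
FILTERED_TOPK_MAX_K = 2048  # mirrors include/flashinfer/topk.cuh

def algos_for_k(k):
    """Backends worth parametrizing for a given k: the filtered kernel
    only handles k <= FILTERED_TOPK_MAX_K, so larger-k tests should
    drop it rather than silently exercise a fallback path."""
    algos = ["auto", "multi_cta"]
    if k <= FILTERED_TOPK_MAX_K:
        algos.append("filtered")
    return algos
```

A test using k=4096 would then parametrize only over `algos_for_k(4096)`.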
🚧 Files skipped from review as they are similar to previous changes (2)
- csrc/flashinfer_topk_binding.cu
- csrc/topk.cu
@coderabbitai review
/gemini review

Warning: You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

✅ Actions performed: Full review triggered.
♻️ Duplicate comments (1)
include/flashinfer/topk.cuh (1)
3259-3267: ⚠️ Potential issue | 🟠 Major: Radix deterministic path omits index canonicalization; tie order differs from filtered path.

The filtered path calls `SortTopKByIndex` when `deterministic=true` (lines 3255-3257), but the radix path (lines 3260-3262) does not. This means:

- Filtered + deterministic: ties ordered by ascending index
- Radix + deterministic: ties ordered by thread-strided emission

When `sorted_output=true`, `StableSortTopKByValue` preserves prior ordering for ties, so the final tie-breaking strategy depends on which algorithm was selected, breaking the contract that deterministic mode produces consistent, index-ordered ties.

🔧 Suggested fix

```diff
   } else {
     FLASHINFER_CUDA_CALL((RadixTopKMultiCTA<DType, IdType>(
         input, output_indices, output_values, nullptr, num_rows, top_k_val, max_len,
         row_states_buffer, deterministic, stream)));
+    if (deterministic) {
+      FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType>(
+          output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len,
+          stream)));
+    }
   }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 3259 - 3267, The radix deterministic path omits the canonical index tie-break used by the filtered path; after RadixTopKMultiCTA completes and when deterministic==true, call SortTopKByIndex with the same arguments used by the filtered path (e.g., output_indices, output_values, num_rows, top_k_val, max_len, stream) before any StableSortTopKByValue call so ties are canonicalized by ascending index; ensure this conditional mirrors the filtered path's deterministic branch around SortTopKByIndex so both algorithms produce identical tie order.
🧹 Nitpick comments (2)
flashinfer/topk.py (1)
63-73: Ruff flags `input` shadowing a Python builtin.

The static analysis tool flags line 65 for shadowing Python's built-in `input`. However, this pattern is consistent with the existing codebase conventions for tensor parameter naming in this file. Given the "Chill" review mode and that this is a widespread pattern, this can be addressed in a separate cleanup if desired.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/topk.py` around lines 63-73: The parameter name input in the _fake_radix_topk function shadows Python's builtin and should be renamed to avoid the Ruff warning; update the function signature of _fake_radix_topk (registered as "flashinfer::radix_topk") to use a non-builtin name (e.g., tensor, inp, or src_tensor) and replace all uses inside the function (input.size and input.device) accordingly so behavior is unchanged.

include/flashinfer/topk.cuh (1)
3174-3198: Consider documenting the heuristic rationale.

The deterministic-mode algorithm selection heuristics (lines 3174-3184) differ significantly from the non-deterministic heuristics (lines 3186-3197). Consider adding a brief comment explaining the trade-off (e.g., filtered deterministic overhead vs. radix multi-CTA coordination cost).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/topk.cuh` around lines 3174 - 3198, Add a short explanatory comment immediately above the block that branches on deterministic and DType size (referencing variables/conditions: deterministic, sizeof(DType), max_len, num_rows, and batch_threshold) that explains why deterministic-mode thresholds differ from non-deterministic ones — e.g., deterministic implementation favors simpler per-row filtered scans to avoid non-deterministic cross-CTA radix coordination (hence lower thresholds like 16384 and the special 256 divisor), while non-deterministic heuristics accept radix/multi-CTA strategies for larger max_len (notice thresholds 16384/32768 and the use of max_len/4096 or /16384 to compute batch_threshold); keep the comment concise (1–3 lines) describing the trade-off and pointing to the key constants so future readers understand the rationale.
@yzh119 @Linda-Stadter This PR is ready for review now.
also add cub stable radix sort and overflow handling

Co-authored-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Signed-off-by: Yinzuo Jiang <jiangyinzuo@foxmail.com>
>   }
> }
> if (tx == 0) {
>   st_release(&state->output_counter, 0);

This is not safe for the nondeterministic path because there is no final __syncthreads() in RadixCollectIndices. But also, this is redundant anyway? I would then just remove it.
> uint32_t cta_local_eq_count = 0;
> OrderedType ordered_pivot =
>     RadixSelectFindPivot<BLOCK_THREADS, VEC_SIZE, SINGLE_CTA, DETERMINISTIC, DType>(
>         input + row_idx * stride, shared_ordered, local_histogram, suffix_sum, shared_scalars,

This doesn't contain my overflow fix by casting to size_t. I will create another commit on top of this :)
📌 Description
Deterministic Mode for Top-K Kernels
FilteredTopK Kernel
FilteredTopKKernel implements deterministic mode as follows:

- collect_gt_and_nondet_eq_threshold); their final order is determined by the post-sort kernel.
- collect_det_eq_pivot, which writes the selected tie elements into s_indices in deterministic thread-strided order.
- SortTopKByIndexKernel is applied to produce index-ascending output and make the final ordering deterministic (we use atomicAdd to collect > pivot at stage 1).
- StableSortTopKByValueKernel is applied afterward to produce value-descending output.

RadixTopK Kernel

- ordered_pivot, which Stage 2 uses to determine whether an element is >= ordered_pivot.
- cta_local_eq_count and cta_local_gt_count, which Stage 2 uses to determine how many elements the current CTA may emit and where each emitted element should be placed.
- (RadixCollectIndicesDeterministic)

RadixCollectIndicesDeterministic: after the pivot is known, assigns each CTA a fixed output range, then writes all > pivot elements followed by the required == pivot elements in a deterministic order.
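The fixed-output-range idea can be modeled in a few lines of Python (a sketch of the ordering argument only; chunk boundaries and names are illustrative, not the CUDA layout):

```python
def deterministic_collect(row, pivot, k, num_chunks):
    """Model of per-CTA fixed output ranges.

    Each chunk independently lists its > pivot and == pivot element
    indices in ascending order; prefix counts then assign each chunk a
    fixed slice of the output, so the result does not depend on which
    chunk finishes first. Assumes the pivot yields gt_total <= k.
    """
    n = len(row)
    bounds = [n * c // num_chunks for c in range(num_chunks + 1)]
    gt = [[i for i in range(bounds[c], bounds[c + 1]) if row[i] > pivot]
          for c in range(num_chunks)]
    eq = [[i for i in range(bounds[c], bounds[c + 1]) if row[i] == pivot]
          for c in range(num_chunks)]
    gt_total = sum(len(g) for g in gt)
    out = [None] * k
    gt_off, eq_off = 0, gt_total
    for c in range(num_chunks):
        for i in gt[c]:      # all > pivot elements go first
            out[gt_off] = i
            gt_off += 1
        for i in eq[c]:      # then only as many ties as still fit
            if eq_off < k:
                out[eq_off] = i
                eq_off += 1
    return out
```

Running it with different num_chunks gives the same output, which is the property the deterministic collection kernel relies on.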
Order definition:
Benchmarks
machine: NVIDIA A100-PCIE-40GB
command: (fp32/fp16/bf16)
raw results:
output.txt
Summary
NOTE: FlashInfer deterministic underperforms PyTorch mainly on short-sequence workloads. Importantly, this is not unique to the deterministic path: FlashInfer non-deterministic top-k is also slower than PyTorch in the same short-sequence regime. This suggests the gap is primarily a short-sequence top-k issue rather than a deterministic-specific regression. Optimizing short-sequence top-k, for both non-deterministic and deterministic modes, is better treated as future work.
🔍 Related Issues
close: #2584
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- All tests are passing (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Benchmarks
Tests
Bug Fixes