
feat: implement deterministic topk#2661

Open
jiangyinzuo wants to merge 1 commit into flashinfer-ai:main from jiangyinzuo:feat/deterministic-topk

Conversation

@jiangyinzuo
Contributor

@jiangyinzuo jiangyinzuo commented Mar 1, 2026

📌 Description

Part of the FilteredTopK implementation refers to or is adapted from @Linda-Stadter's work in #2759

Deterministic Mode for Top-K Kernels

FilteredTopK Kernel

FilteredTopKKernel implements deterministic mode as follows:

  1. Build a coarse histogram.
  • A histogram over the top 8 bits locates the coarse threshold bin that contains the k-th largest element.
  • As in non-deterministic mode, elements with bin > threshold_bin are appended to s_indices via atomicAdd (see collect_gt_and_nondet_eq_threshold); their final order is fixed by the post-sort kernel.
  2. Refine with 8-bit radix passes.
  • Multiple 8-bit refine passes find the exact pivot.
  • Deterministic selection of == pivot elements is performed by collect_det_eq_pivot, which writes the selected tie elements into s_indices in a deterministic thread-strided order.

Thread-strided order means that, with BLOCK_THREADS = 4 for example, the logical scan order is:

  • thread 0: 0, 4, 8, ...
  • thread 1: 1, 5, 9, ...
  • thread 2: 2, 6, 10, ...
  • thread 3: 3, 7, 11, ...

If the == pivot positions are:

  • thread 0: 0, 8
  • thread 1: 5
  • thread 2: none
  • thread 3: 3, 7

then the deterministic collection order is: [0, 8, 5, 3, 7].
That is, we order elements first by thread ID, and then by each thread's strided traversal order.
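The thread-strided tie-collection order above can be reproduced on the host by sorting positions first by owning thread, then by stride step. This is an illustrative sketch of the ordering rule only, not the kernel code:

```python
def thread_strided_collect(eq_positions, block_threads):
    # Thread tid owns positions tid, tid + T, tid + 2T, ... for T = block_threads.
    # Ordering first by owning thread (p % T), then by stride step (p // T)
    # reproduces the deterministic collection order described above.
    return sorted(eq_positions, key=lambda p: (p % block_threads, p // block_threads))

# The example from the text: == pivot positions {0, 3, 5, 7, 8} with 4 threads.
print(thread_strided_collect([0, 3, 5, 7, 8], 4))  # -> [0, 8, 5, 3, 7]
```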

  3. Post-sort kernels.
  • After FilteredTopKKernel finishes, SortTopKByIndexKernel produces index-ascending output, making the final ordering deterministic (stage 1 collects > pivot elements with atomicAdd, so their order is otherwise unspecified).
  • If the Python API is called with sorted=True, StableSortTopKByValueKernel is applied afterward to produce value-descending output.
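The effect of the two post-sort passes can be sketched on the host (illustrative Python, not the CUDA kernels): an index sort followed by a stable value sort leaves ties in index-ascending order.

```python
def post_sort(indices, values, sorted_output):
    # Pass 1: canonicalize by index (ascending) so the output is deterministic
    # regardless of the atomicAdd collection order in stage 1.
    pairs = sorted(zip(indices, values))
    if sorted_output:
        # Pass 2: stable sort by value (descending); Python's sort is stable,
        # so equal values keep their index-ascending order.
        pairs = sorted(pairs, key=lambda p: -p[1])
    idx, val = zip(*pairs)
    return list(idx), list(val)

# Ties on value 2.0 end up index-ascending even in the value-sorted output.
print(post_sort([7, 2, 5], [2.0, 3.0, 2.0], sorted_output=True))
# -> ([2, 5, 7], [3.0, 2.0, 2.0])
```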

RadixTopK Kernel

  1. RadixSelectFindPivot
  • Finds ordered_pivot, which stage 2 uses to decide whether an element is >= ordered_pivot.
  • Computes cta_local_eq_count and cta_local_gt_count, which stage 2 uses to determine how many elements the current CTA may emit and where each emitted element should be placed.
  2. collect_indices (RadixCollectIndicesDeterministic)
  • After the pivot is known, each CTA is assigned a fixed output range; it then writes all > pivot elements followed by the required == pivot elements in a deterministic order.

Order definition:

  • Emit > pivot elements first, then == pivot elements.
  • For each category, earlier CTAs write to earlier output positions.
  • Within each CTA, emit elements in thread-strided order.
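The placement rule above amounts to a prefix sum over per-CTA counts. A host-side sketch (cta_gt and cta_eq_needed are illustrative stand-ins for the stage-1 counters; this is not the kernel code):

```python
def cta_output_ranges(cta_gt, cta_eq_needed):
    """Return (gt_start, eq_start) output offsets for each CTA.

    All > pivot elements come first (earlier CTAs at earlier positions),
    followed by the required == pivot elements in the same CTA order.
    """
    total_gt = sum(cta_gt)
    gt_starts, eq_starts = [], []
    gt_off, eq_off = 0, total_gt  # == pivot region begins after all > pivot elements
    for gt, eq in zip(cta_gt, cta_eq_needed):
        gt_starts.append(gt_off)
        eq_starts.append(eq_off)
        gt_off += gt
        eq_off += eq
    return gt_starts, eq_starts

# Two CTAs: CTA0 emits 3 '>' and 1 '==' element, CTA1 emits 2 '>' and 0 '=='.
print(cta_output_ranges([3, 2], [1, 0]))  # -> ([0, 3], [5, 6])
```

Within each assigned range, a CTA then emits its elements in the thread-strided order described earlier.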

Benchmarks

machine: NVIDIA A100-PCIE-40GB

command: (fp32/fp16/bf16)

python -u benchmarks/bench_topk.py \
  --op all \
  --dtype fp32 \
  --deterministic \
  --compare-torch-deterministic \
  --input-pattern random

raw results:

output.txt
Summary

| dtype | geomean det slowdown vs non-det | geomean speedup vs torch.det |
|-------|---------------------------------|------------------------------|
| fp32  | 1.0992x                         | 1.7660x                      |
| fp16  | 1.0777x                         | 1.3381x                      |
| bf16  | 1.0745x                         | 1.3055x                      |

NOTE: FlashInfer deterministic underperforms PyTorch mainly on short-sequence workloads. Importantly, this is not unique to the deterministic path: FlashInfer non-deterministic top-k is also slower than PyTorch in the same short-sequence regime. This suggests the gap is primarily a short-sequence top-k issue rather than a deterministic-specific regression. Optimizing short-sequence top-k, for both non-deterministic and deterministic modes, is better treated as future work.

🔍 Related Issues

close: #2584

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).
Unit tests I ran:
test_topk.py
test_sampling.py
test_logits_processor.py

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Deterministic mode for top‑k and fused transforms (stable, repeatable tie ordering) with API flag to enable deterministic outputs and stable sorting behavior.
  • Benchmarks

    • Expanded benchmarking to compare deterministic vs nondeterministic runs, pre-generated input patterns, DSA workload cases, and richer CLI output.
  • Tests

    • Large suite of determinism and correctness tests (ties, multi‑CTA, streams, sorted behavior, cache transitions).
  • Bug Fixes

    • Improved runtime-error labeling and benchmark cache handling.

@coderabbitai
Contributor

coderabbitai bot commented Mar 1, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough


Adds an opt-in deterministic mode across the top-k stack: Python APIs, FFI bindings, C++ dispatch, and CUDA kernels; implements deterministic multi-CTA collection and stable tie‑breaking, updates benchmarks/CLI for deterministic comparisons and DSA workloads, and adds deterministic-focused tests and helpers.

Changes

  • Benchmarks & CLI (benchmarks/bench_topk.py): Refactor benchmark flow to pre-generate scores; add deterministic benchmarking infrastructure (deterministic vs nondeterministic timings, torch-deterministic comparison), DSA workload generation, and new CLI flags/options.
  • Python API surface (flashinfer/topk.py): Add deterministic: bool to top_k, top_k_page_table_transform, and top_k_ragged_transform; forward sorted_output/deterministic into the CUDA bindings; adjust kernel selection and stable-sort fallback behavior.
  • FFI binding layer (csrc/flashinfer_topk_binding.cu): Extend the exported bindings radix_topk, radix_topk_page_table_transform, and radix_topk_ragged_transform to accept new sorted_output/deterministic boolean parameters.
  • C++ dispatcher & glue (csrc/topk.cu): Thread the new sorted_output and deterministic flags into TopKDispatch/fused dispatch calls and propagate them to kernel launch paths.
  • CUDA kernels & headers (include/flashinfer/topk.cuh): Major deterministic additions: ordered SMEM sizing helper, deterministic multi-CTA scratch/barrier primitives, deterministic collection and pivot eq-count tracking, deterministic-aware FilteredTopK and stable post-sort transforms, and updated dispatch/heuristics with deterministic guards.
  • Tests (tests/utils/test_topk.py): Add deterministic repeatability/tie/stability tests and cached radix row-states buffer inspection/eviction helpers; parameterize tests over deterministic mode; expand transform/regression coverage to validate deterministic behavior.

Sequence Diagram(s)

sequenceDiagram
    participant Py as Python API
    participant Bind as FFI Binding
    participant Dispatch as C++ Dispatcher
    participant Kernel as CUDA Kernel
    Py->>Bind: call top_k(..., deterministic=True)
    Bind->>Dispatch: radix_topk(..., sorted_output=..., deterministic=...)
    Dispatch->>Kernel: launch deterministic-aware kernel (det scratch, DETERMINISTIC)
    Kernel->>Kernel: deterministic collect / stable tie-breaking / optional stable sort
    Kernel-->>Dispatch: return indices & values
    Dispatch-->>Bind: propagate results
    Bind-->>Py: deliver deterministic outputs

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • jiahanc
  • kahyunnam
  • cyx-6
  • yzh119
  • IwakuraRein
  • jimmyzho
  • bkryu
  • nvmbreughe

Poem

🐰 I hopped through kernels, threads aligned,
I made each tie yield the smaller mind.
Runs now repeat, indices hold true,
Determinism carrot—crisp and new.
🥕

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 66.10%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Linked Issues check ✅ Passed: The PR implements all key acceptance criteria from issue #2584: opt-in deterministic mode, bitwise reproducibility via thread-strided collection and CTA-assigned ranges, well-defined tie-breaking by thread-strided order, and comprehensive benchmarking showing ~1.08-1.10x slowdown and 1.43-2.06x speedup vs PyTorch.
  • Out of Scope Changes check ✅ Passed: All changes are scoped to the deterministic top-k implementation: kernel modifications (radix and filtered paths), Python API extensions, benchmark infrastructure updates, and comprehensive test additions. No unrelated changes detected.
  • Title check ✅ Passed: The title 'feat: implement deterministic topk' directly summarizes the primary change of the pull request: adding deterministic mode support for top-k kernels.
  • Description check ✅ Passed: The PR description is comprehensive and well-structured, following the repository's template with complete Description, Related Issues, and Pull Request Checklist sections.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant feature by enabling deterministic behavior for all top-k related operations within FlashInfer. It provides users with fine-grained control over determinism, which is crucial for reproducibility in machine learning and scientific computing. The changes span Python APIs, CUDA kernels, and benchmarking tools, ensuring robust implementation and verification of this new capability.

Highlights

  • Deterministic Top-K Operations: Implemented deterministic behavior for all top-k operations (top_k, top_k_page_table_transform, top_k_ragged_transform) to ensure repeatable results, especially when ties are present in input scores.
  • Flexible Deterministic Mode Control: Introduced a DeterministicMode enum (NON_DETERMINISTIC, REPRODUCIBLE) and corresponding parameters (deterministic, deterministic_mode) in Python APIs, allowing users to explicitly control the level of determinism. The deterministic=True flag is maintained for backward compatibility, mapping to REPRODUCIBLE mode.
  • Enhanced Benchmarking and Testing: Updated the benchmarking script (bench_topk.py) with new arguments for deterministic mode and various input patterns (random, tie_heavy, pivot_tie). Comprehensive unit tests (test_topk.py) were added to verify the repeatability and correctness of deterministic top-k operations across different scenarios and algorithms.
  • CUDA Kernel Modifications: Modified underlying CUDA kernels (topk.cu, topk.cuh) to incorporate deterministic logic, including changes to radix selection, index collection (with new RadixBlockExclusivePrefix and RadixCollectIndicesReproducible functions), and filtered top-k sorting (FilteredTopKBitonicSortIndices). Heuristics for algorithm selection (ShouldUseFilteredTopKDeterministicAware) were also updated to consider deterministic requirements.


Changelog
  • benchmarks/bench_topk.py
    • Added contextlib.contextmanager for temporary PyTorch deterministic algorithm mode.
    • Implemented run_torch_topk to wrap torch.topk with optional deterministic mode.
    • Introduced generate_scores to create benchmark input with various tie patterns.
    • Extended bench_top_k, bench_page_table_transform, and bench_ragged_transform functions with input_pattern, deterministic, and deterministic_mode parameters.
    • Added command-line arguments --deterministic, --deterministic-mode, --compare-torch-deterministic, and --input-pattern to main function.
    • Updated benchmark output formatting to reflect new deterministic options.
  • csrc/flashinfer_topk_binding.cu
    • Modified C++ function signatures for radix_topk, radix_topk_page_table_transform, and radix_topk_ragged_transform to include a deterministic_mode parameter.
  • csrc/topk.cu
    • Added ParseDeterministicMode helper to convert integer mode to sampling::DeterministicMode enum.
    • Passed deterministic_mode to sampling::TopKDispatch functions.
    • Updated RadixSelectFromSharedMemory to optionally track equal counts for deterministic tie-breaking.
    • Implemented RadixBlockExclusivePrefix and RadixCollectIndicesReproducible for deterministic index collection.
    • Introduced RadixCollectIndicesDispatch to select between deterministic and non-deterministic collection paths.
    • Modified kernel launch logic to use the DETERMINISTIC template parameter.
  • flashinfer/__init__.py
    • Exported DeterministicMode enum from flashinfer.topk.
  • flashinfer/topk.py
    • Defined DeterministicMode as an IntEnum for NON_DETERMINISTIC and REPRODUCIBLE.
    • Added _DETERMINISTIC_MODE_ALIASES for string-based mode selection.
    • Implemented _resolve_deterministic_mode to parse and validate deterministic mode parameters.
    • Added deterministic and deterministic_mode parameters to top_k, top_k_page_table_transform, and top_k_ragged_transform Python APIs.
    • Modified top_k to use stable=True for torch.sort when sorted=True and in reproducible mode.
    • Passed the resolved deterministic mode to the underlying C++ kernel calls.
  • include/flashinfer/topk.cuh
    • Defined DeterministicMode enum and IsDeterministicMode helper.
    • Added GetReproducibleTargetCTAsPerGroup and MaybeBoostReproducibleCTAsPerGroup for dynamic CTA adjustment in reproducible mode.
    • Updated RadixSelectFromSharedMemory to optionally track eq_count.
    • Introduced RadixBlockExclusivePrefix for block-level exclusive prefix sums.
    • Implemented RadixCollectIndicesReproducible for deterministic index collection with tie-breaking.
    • Added RadixCollectIndicesDispatch to conditionally use deterministic collection.
    • Modified RadixTopKKernel_Unified to accept a DETERMINISTIC template parameter.
    • Added FilteredTopKBitonicSortIndices for sorting indices in deterministic filtered top-k.
    • Updated FilteredTopKUnifiedKernel to accept DETERMINISTIC template parameter and use bitonic sort.
    • Introduced SelectFilteredTopKBlockThreads for dynamic block size selection in filtered top-k.
    • Added LaunchFilteredTopKUnified to centralize filtered top-k kernel launches.
    • Modified ShouldUseFilteredTopKDeterministicAware to include deterministic mode heuristics for algorithm selection.
    • Updated TopKPageTableTransformDispatch, TopKRaggedTransformDispatch, and TopKDispatch to pass deterministic_mode and use the new deterministic-aware heuristics.
  • tests/utils/test_topk.py
    • Imported DeterministicMode enum.
    • Added test_top_k_deterministic_mode_bool_compatibility to verify backward compatibility.
    • Included test_top_k_reproducible_mode_repeatability and test_top_k_reproducible_mode_repeatability_multi_cta for repeatable results.
    • Added test_top_k_invalid_deterministic_mode to check error handling for invalid modes.
    • Implemented test_top_k_deterministic_bitwise_repeatability for strict bitwise repeatability.
    • Added repeatability tests for top_k_page_table_transform and top_k_ragged_transform in both deterministic=True and DeterministicMode.REPRODUCIBLE modes.
Activity
  • The pull request is currently a Work In Progress (WIP).
  • Pre-commit checks have been completed.
  • Tests are not yet marked as added/updated or passing, indicating ongoing development and verification.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
|---------|---------|-------------|
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a significant feature: deterministic top-k selection. The changes are extensive, adding new execution paths to both the radix and filtered top-k algorithms to ensure reproducible results, which is particularly important for handling ties. The implementation includes backward compatibility for existing APIs by adding new optional parameters. The benchmarks and tests have been updated comprehensively to cover the new deterministic modes. The overall implementation is well-designed and robust. I have one suggestion to improve code clarity and remove a minor redundancy in the CUDA kernel.

Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_topk.py`:
- Around line 51-56: The benchmark currently enters the
torch_deterministic_algorithms context inside run_torch_topk on every iteration,
adding overhead; instead enable deterministic mode once before the timing loop
and restore the prior state afterwards, removing the per-iteration context from
run_torch_topk (and the analogous per-iteration context in the other benchmark
at lines 136-145); specifically, call the global deterministic enable API once
(save the previous value), run the repeated torch.topk calls normally inside
run_torch_topk, then restore the saved deterministic setting after the loop so
the timing measures only torch.topk cost.

In `@tests/utils/test_topk.py`:
- Around line 1492-1514: The BF16 reproducibility test
(test_top_k_reproducible_mode_repeatability_multi_cta) runs unconditionally but
must be skipped on GPUs with compute capability < SM80; add a guard at the start
of the test that calls flashinfer.utils.get_compute_capability() (or the project
helper like flashinfer.utils.is_sm90a_supported/is_sm80_supported) and use
pytest.skip(...) when the capability is below 80 to avoid running BF16 on
unsupported hardware; also add an import for pytest if it's not present.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f521fe1 and a595ced.

📒 Files selected for processing (7)
  • benchmarks/bench_topk.py
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu
  • flashinfer/__init__.py
  • flashinfer/topk.py
  • include/flashinfer/topk.cuh
  • tests/utils/test_topk.py

@jiangyinzuo jiangyinzuo marked this pull request as draft March 1, 2026 09:31
@jiangyinzuo jiangyinzuo force-pushed the feat/deterministic-topk branch 12 times, most recently from 4358ff1 to 9e88bc8 Compare March 5, 2026 16:48
@jiangyinzuo jiangyinzuo force-pushed the feat/deterministic-topk branch from 9e88bc8 to 42a86f9 Compare March 8, 2026 12:00
@jiangyinzuo jiangyinzuo marked this pull request as ready for review March 8, 2026 12:01
@jiangyinzuo jiangyinzuo force-pushed the feat/deterministic-topk branch from 42a86f9 to 7679c40 Compare March 8, 2026 12:02
@jiangyinzuo jiangyinzuo force-pushed the feat/deterministic-topk branch 2 times, most recently from c196697 to 66fe404 Compare March 24, 2026 08:41
@jiangyinzuo
Contributor Author

jiangyinzuo commented Mar 26, 2026

@coderabbitai review

/gemini review

@coderabbitai
Contributor

coderabbitai bot commented Mar 26, 2026

@jiangyinzuo: I'll kick off a full review of the PR now, including the latest changes with the CUB sort integration and overflow fix cherry-picked from @Linda-Stadter's fork.

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces deterministic top-k operations across FlashInfer's core functions, including top_k, top_k_page_table_transform, and top_k_ragged_transform. The changes involve adding new helper functions for score generation and benchmarking, extending command-line arguments in benchmarks/bench_topk.py to support deterministic mode and various input patterns, and modifying the C++ kernels to propagate and utilize a deterministic_mode flag. New test cases have been added to validate the repeatability and correctness of these deterministic operations, especially under tie-heavy conditions. Feedback includes removing a debug print statement from the benchmarks and redundant int() casts in the Python API for clarity.

Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/topk.py (1)

266-283: ⚠️ Potential issue | 🟠 Major

Invalidate the cached topk JIT module for this ABI change.

This PR adds new parameters (sorted_output, deterministic_mode) to the radix_topk* function signatures. The Python wrappers now pass these parameters to the CUDA exports (e.g., module.radix_topk(..., sorted_cuda, mode, ...)), but the cache key is purely name-based (topk.so). If an older cached .so exists from a prior version, it will be reused without recompilation, causing the new Python code to call stale CUDA exports with mismatched arity. Bump the JIT module name or add a version/hash suffix to invalidate existing caches alongside this binding change.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/topk.py` around lines 266 - 283, Cached JIT module name is
unchanged so old topk.so can be reused even though radix_topk signatures
changed; update the cache key used by get_topk_module (or wherever the extension
is built/loaded) to include a version or hash/suffix (e.g., bump the "topk"
module name or append a ABI version string) so that new builds produce a new .so
and old caches are not loaded; ensure the new key is used when calling
get_topk_module() so the Python wrappers (radix_topk, radix_topk_cuda exports)
are always matched to the correct compiled ABI.
🧹 Nitpick comments (3)
benchmarks/bench_topk.py (3)

38-38: Remove debug print statement.

This print("message: ", message) appears to be a debug artifact that will clutter benchmark output when an unclassified error occurs. Consider removing it or gating behind a verbose/debug flag.

🔧 Proposed fix
     if "invalid argument" in message or "operation not supported" in message:
         return "UNSUPPORTED"
 
-    print("message: ", message)
     return None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_topk.py` at line 38, Remove the stray debug print that
prints the local variable message (the line print("message: ", message) in
bench_topk.py); either delete this statement or wrap it behind an explicit
verbose/debug flag so benchmark output remains clean (refer to the
print("message: ", message) call and the local variable message when making the
change).

64-71: Consider enabling CUPTI timing for more accurate measurements.

The benchmark uses enable_cupti=False, but CUPTI typically provides more accurate kernel timing. Based on learnings, flashinfer.testing.bench_gpu_time() is recommended with CUPTI timing enabled (with auto-fallback to CUDA events).

🔧 Proposed change
 def bench_median_ms(fn) -> float:
     measurements = bench_gpu_time(
         fn,
-        enable_cupti=False,
+        enable_cupti=True,
         dry_run_iters=10,
         repeat_iters=100,
     )
     return float(np.median(measurements))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_topk.py` around lines 64 - 71, The benchmark helper
bench_median_ms currently calls bench_gpu_time(fn, enable_cupti=False, ...)
which disables CUPTI; change it to enable CUPTI timing by calling bench_gpu_time
with enable_cupti=True so CUPTI is used (falling back to CUDA events
automatically as implemented in bench_gpu_time), keeping dry_run_iters and
repeat_iters unchanged; update the call site in function bench_median_ms to set
enable_cupti=True.

633-654: Consider extracting header formatting logic.

The conditional header and divider length calculations are getting complex with multiple flag combinations. For maintainability, consider extracting this into a helper function. However, this is acceptable for a benchmark script.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_topk.py` around lines 633 - 654, Extract the header/divider
construction into a small helper (e.g., format_header_and_divider) that takes
the flags args.deterministic, args.compare_torch_deterministic, and
args.compare_sglang and returns the header string and divider_len; move the
current conditional logic that builds header and computes divider_len into that
helper and replace the block that sets header, divider_len and prints with a
single call to format_header_and_divider to improve readability and
maintainability while keeping the existing output formatting and width
calculations unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/topk.cuh`:
- Around line 3259-3261: The failing pre-commit is due to an unformatted call to
RadixTopKRaggedTransformMultiCTA<DType, IdType>; update the formatting of this
return statement to match clang-format (wrap parameters/line breaks as
clang-format expects for the call using the same function name and parameters:
input, output_indices, offsets, lengths, num_rows, top_k_val, max_len,
row_states_buffer, deterministic, stream) and commit that rewritten hunk so
pre-commit/CI no longer reformats it.
- Around line 3269-3293: The radix path needs the same index-based
canonicalization as the filtered path when deterministic is true: after calling
RadixTopKMultiCTA<DType, IdType>(...) and before any
StableSortTopKByValue<DType, IdType>(...) call, invoke SortTopKByIndex<DType,
IdType>(output_indices, output_values, num_rows, top_k_val, max_len, stream) to
reorder ties by index; then, if sorted_output is true, follow that with the
existing StableSortTopKByValue<DType, IdType>(...) call so ties are preserved
according to the canonical index order.

---

Outside diff comments:
In `@flashinfer/topk.py`:
- Around line 266-283: Cached JIT module name is unchanged so old topk.so can be
reused even though radix_topk signatures changed; update the cache key used by
get_topk_module (or wherever the extension is built/loaded) to include a version
or hash/suffix (e.g., bump the "topk" module name or append a ABI version
string) so that new builds produce a new .so and old caches are not loaded;
ensure the new key is used when calling get_topk_module() so the Python wrappers
(radix_topk, radix_topk_cuda exports) are always matched to the correct compiled
ABI.

---

Nitpick comments:
In `@benchmarks/bench_topk.py`:
- Line 38: Remove the stray debug print that prints the local variable message
(the line print("message: ", message) in bench_topk.py); either delete this
statement or wrap it behind an explicit verbose/debug flag so benchmark output
remains clean (refer to the print("message: ", message) call and the local
variable message when making the change).
- Around line 64-71: The benchmark helper bench_median_ms currently calls
bench_gpu_time(fn, enable_cupti=False, ...) which disables CUPTI; change it to
enable CUPTI timing by calling bench_gpu_time with enable_cupti=True so CUPTI is
used (falling back to CUDA events automatically as implemented in
bench_gpu_time), keeping dry_run_iters and repeat_iters unchanged; update the
call site in function bench_median_ms to set enable_cupti=True.
- Around line 633-654: Extract the header/divider construction into a small
helper (e.g., format_header_and_divider) that takes the flags
args.deterministic, args.compare_torch_deterministic, and args.compare_sglang
and returns the header string and divider_len; move the current conditional
logic that builds header and computes divider_len into that helper and replace
the block that sets header, divider_len and prints with a single call to
format_header_and_divider to improve readability and maintainability while
keeping the existing output formatting and width calculations unchanged.
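The header/divider extraction suggested above could look like the following sketch (column names and widths are assumptions, not the benchmark's real layout; only the flag-driven structure is the point):

```python
def format_header_and_divider(deterministic: bool,
                              compare_torch_deterministic: bool,
                              compare_sglang: bool):
    """Build the benchmark table header and the divider width from the flags."""
    columns = ["batch", "seq_len", "k", "flashinfer_us"]  # assumed base columns
    if deterministic:
        columns.append("det")
    if compare_torch_deterministic:
        columns.append("torch_det_us")
    if compare_sglang:
        columns.append("sglang_us")
    header = "  ".join(f"{c:>14}" for c in columns)
    return header, len(header)


header, divider_len = format_header_and_divider(True, False, True)
print(header)
print("-" * divider_len)
```

The call site then shrinks to the two lines above, and the conditional column logic lives in one place.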

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b335df74-24d1-444c-ba84-c593cd054fa5

📥 Commits

Reviewing files that changed from the base of the PR and between da308c7 and 1b99b0e.

📒 Files selected for processing (6)
  • benchmarks/bench_topk.py
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu
  • flashinfer/topk.py
  • include/flashinfer/topk.cuh
  • tests/utils/test_topk.py

@jiangyinzuo
Contributor Author

@coderabbitai review

/gemini review

@coderabbitai
Contributor

coderabbitai bot commented Mar 26, 2026

@jiangyinzuo: I'll kick off a full review of the PR now, covering the deterministic top-k implementation, the CUB sort integration, the overflow fix from @Linda-Stadter's fork, and all associated benchmark and test changes.

✅ Actions performed

Full review triggered.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new deterministic mode for FlashInfer's top-k operations, including top_k, top_k_page_table_transform, and top_k_ragged_transform. This mode ensures bitwise-reproducible output for the selected top-k elements, especially in cases with ties. Key changes include extending the API with a deterministic parameter, modifying CUDA kernels to implement deterministic collection logic, adjusting shared memory management for CUB's BlockScan in deterministic mode, and updating heuristics for algorithm selection. The benchmarking script was enhanced to support deterministic mode and various input patterns, and extensive new test cases were added to verify correctness and repeatability. Additionally, new sorting kernels were introduced for post-processing deterministic filtered top-k results when sorted output is requested.

@jiangyinzuo
Contributor Author

@coderabbitai review

/gemini review

@coderabbitai
Contributor

coderabbitai bot commented Mar 28, 2026

@jiangyinzuo: I'll kick off a full review of the PR now, covering the deterministic top-k implementation across all modified files (include/flashinfer/topk.cuh, csrc/topk.cu, csrc/flashinfer_topk_binding.cu, flashinfer/topk.py, benchmarks/bench_topk.py, and tests/utils/test_topk.py), including the CUB sort integration and overflow fix from @Linda-Stadter's fork.

✅ Actions performed

Full review triggered.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements deterministic mode for FlashInfer's top-k operations, covering basic top-k, page table transform, and ragged transform. The changes involve significant CUDA kernel updates to ensure bitwise reproducibility, particularly for multi-CTA and tie-heavy scenarios, along with new post-processing kernels for stable sorting. The PR also expands the benchmarking suite to include DeepSeek DSA-like workloads and adds rigorous tests for repeatability across streams. Feedback indicates that the IsRadixTopKUnsupportedConfig helper function is currently unused and should be removed to maintain code quality.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/topk.cuh`:
- Around line 3010-3038: StableSortTopKByValue currently chooses kernel variants
whose capacity (threads * ITEMS_PER_THREAD determined by the kernel template
parameters used in StableSortTopKByValueKernel) tops out at 2048, so when
callers (e.g. the radix deterministic path) pass top_k_val > 2048 the function
leaves higher-index entries untouched; fix by adding logic in
StableSortTopKByValue to ensure the chosen kernel can handle at least top_k_val
items per row: either select or add larger kernel template instantiations
(increase threads and/or ITEMS_PER_THREAD parameters for
StableSortTopKByValueKernel) or implement a chunking/fallback path that
processes rows in multiple launches until all top_k_val items are covered;
update the launch selection (the if/else chain that calls launch_sort with
StableSortTopKByValueKernel<...>) so the product of the kernel params >=
top_k_val for all possible inputs.
- Around line 206-209: RadixDeterministicCollectScratch's fixed arrays
gt_count/eq_count only allow 256 CTAs but ctas_per_group (used when indexing
cta_in_group) can be larger and overflow; fix by either bounding ctas_per_group
to a safe MAX_CTAS (e.g., 256) before any indexing into gt_count/eq_count or by
changing the design to allocate gt_count/eq_count dynamically (shared or heap)
sized to ctas_per_group and pass that pointer into the kernels that use
RadixDeterministicCollectScratch; update all call sites and the kernel logic
that compute/use ctas_per_group and cta_in_group so they respect the new cap or
use the dynamically-sized storage to avoid out-of-bounds writes.
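Both topk.cuh findings above reduce to simple host-side capacity logic. Here is a Python sketch of that logic only (the kernel configurations, the MAX_CTAS value, and the function names are assumptions for illustration; the real fix lives in the C++ launch-selection code):

```python
# Assumed kernel variants as (threads, items_per_thread); capacity = product.
SORT_KERNEL_CONFIGS = [(128, 4), (256, 4), (512, 4), (1024, 4)]
MAX_CTAS = 256  # assumed fixed capacity of the gt_count/eq_count scratch arrays


def select_sort_kernel(top_k_val: int):
    """Pick the smallest variant whose capacity covers top_k_val.

    Returning None signals that no single-launch variant fits and the caller
    must take a chunked/fallback path instead of silently truncating at 2048.
    """
    for threads, items in SORT_KERNEL_CONFIGS:
        if threads * items >= top_k_val:
            return threads, items
    return None


def clamp_ctas_per_group(requested: int) -> int:
    """Bound ctas_per_group so indexing the fixed scratch never overflows."""
    return min(requested, MAX_CTAS)
```

The invariant to enforce in the C++ if/else chain is exactly `threads * ITEMS_PER_THREAD >= top_k_val` for the chosen instantiation, and `ctas_per_group <= MAX_CTAS` before any `cta_in_group` indexing.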

In `@tests/utils/test_topk.py`:
- Around line 1515-1527: The helper function _assert_top_k_matches_torch uses a
parameter named sorted which shadows Python's builtin and triggers a Ruff lint;
rename the parameter (for example to sorted_output) in the function signature
and update all references inside the function (the calls to flashinfer.top_k and
torch.topk should pass sorted=sorted_output or use the renamed keyword) as well
as any external callers if present so the behavior remains identical.
- Around line 1884-1913: The test currently only checks values and repeatability
(values_a/values_b and indices_a/indices_b) but doesn't verify that returned
indices are valid and include the strictly greater winners; update
test_top_k_deterministic_sorted_tie_break_oracle to also assert that indices_a
(and indices_b) have length k, contain no duplicates, all entries are within [0,
vocab_size), and include the index range of the gt winners (vocab_size -
gt_count ... vocab_size - 1) when pattern != "all_equal"; use the existing
variables indices_a/indices_b, logits, vocab_size, gt_count and k to implement
these checks so any kernel repeating an index, dropping a winner, or returning
out-of-bounds indices will fail the test.
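The index-validity checks requested above can be sketched tensor-free (the real test would run these on torch tensors from indices_a/indices_b; the helper name is hypothetical):

```python
def check_topk_indices(indices, k, vocab_size, winner_range=None):
    """Validate one row of top-k indices: length k, unique, in-bounds, and
    containing every strictly-greater winner index in [lo, hi)."""
    assert len(indices) == k, "must return exactly k indices"
    assert len(set(indices)) == k, "indices must not repeat"
    assert all(0 <= i < vocab_size for i in indices), "index out of bounds"
    if winner_range is not None:
        lo, hi = winner_range
        missing = set(range(lo, hi)) - set(indices)
        assert not missing, f"dropped strictly-greater winners: {sorted(missing)}"
    return True
```

In the test, `winner_range` would be `(vocab_size - gt_count, vocab_size)` whenever `pattern != "all_equal"`, so a kernel that repeats an index, drops a winner, or emits out-of-bounds indices fails immediately.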

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d3df0407-8612-4a73-895a-efac3f2f7800

📥 Commits

Reviewing files that changed from the base of the PR and between 31b63bc and 8293ae0.

📒 Files selected for processing (6)
  • benchmarks/bench_topk.py
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu
  • flashinfer/topk.py
  • include/flashinfer/topk.cuh
  • tests/utils/test_topk.py

@jiangyinzuo
Contributor Author

@coderabbitai review

/gemini review

@coderabbitai
Contributor

coderabbitai bot commented Mar 28, 2026

@jiangyinzuo: I'll kick off a full review of the PR now, covering the deterministic top-k implementation across all modified files (include/flashinfer/topk.cuh, csrc/topk.cu, csrc/flashinfer_topk_binding.cu, flashinfer/topk.py, benchmarks/bench_topk.py, and tests/utils/test_topk.py), including the CUB sort integration and overflow fix cherry-picked from @Linda-Stadter's fork.


🧠 Learnings used
Learnt from: Aalanli
Repo: flashinfer-ai/flashinfer PR: 2814
File: csrc/fast_topk_clusters.cu:135-141
Timestamp: 2026-03-20T17:55:09.735Z
Learning: In `csrc/fast_topk_clusters.cu`, the `fast_topk_cuda_v4` device function intentionally drops elements that fall in the threshold bin when the shared cache is full (`cached_offset >= num_cached`). This is a deliberate performance trade-off in the "fast" (inexact) TopK variant. The exact/correctness-preserving alternative is `fast_topk_clusters_exact` in `csrc/fast_topk_clusters_exact.cu`. Do not flag this overflow-drop behavior as a bug.
✅ Actions performed

Full review triggered.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a deterministic mode for FlashInfer's top-k kernels, ensuring bitwise-reproducible output ordering. The changes include updates to the CUDA kernels to support deterministic collection paths, modifications to the Python bindings and benchmark utilities to expose this functionality, and the addition of comprehensive tests to verify repeatability and correctness across different routing modes and input patterns.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/topk.py (1)

270-283: ⚠️ Potential issue | 🟠 Major

Pass stable=True to torch.sort() when deterministic=True.

When deterministic=True and k > 2048, the fallback sort path is executed but uses torch.sort() with the default stable=False. This allows equal values to be reordered arbitrarily, violating the documented guarantee that "deterministic mode guarantees repeatable FlashInfer output ordering."

The fix passes stable=deterministic to ensure stable sorting only in deterministic mode, preserving performance for non-deterministic cases.

🔧 Suggested fix
    if sorted and not sorted_cuda:
        # Sort within each row by value (descending)
-       sorted_values, sort_indices = torch.sort(output_values, dim=-1, descending=True)
+       sorted_values, sort_indices = torch.sort(
+           output_values,
+           dim=-1,
+           descending=True,
+           stable=deterministic,
+       )
        sorted_indices = torch.gather(indices, dim=-1, index=sort_indices)
        return sorted_values, sorted_indices
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/topk.py` around lines 270 - 283, In the fallback path (when sorted
and not sorted_cuda) ensure torch.sort is invoked with stable=deterministic so
equal values remain in a repeatable order when deterministic=True: update the
call to torch.sort(output_values, dim=-1, descending=True) inside the
sorted/not-sorted_cuda block to pass stable=deterministic (affecting
sorted_values, sort_indices used with indices via torch.gather to produce
sorted_indices) so deterministic mode preserves stable ordering while leaving
performance unchanged when deterministic is False.
♻️ Duplicate comments (2)
include/flashinfer/topk.cuh (1)

3243-3260: ⚠️ Potential issue | 🟠 Major

Canonicalize radix ties before the deterministic value sort.

The filtered branch already normalizes deterministic output with LaunchSortTopKByIndex(...), but the radix branch still goes straight to StableSortTopKByValue(...). Because that sort is stable, equal values preserve the collect order from RadixCollectIndicesDeterministic(...), so deterministic=True, sorted=True still changes tie order based on which algorithm was selected.

🔧 Suggested fix
   } else {
     FLASHINFER_CUDA_CALL((RadixTopKMultiCTA<DType, IdType>(
         input, output_indices, output_values, nullptr, num_rows, top_k_val, max_len,
         row_states_buffer, deterministic, stream)));
+    if (deterministic && sorted_output) {
+      FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType>(
+          output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len,
+          stream)));
+    }
   }
   if (sorted_output) {
     FLASHINFER_CUDA_CALL((StableSortTopKByValue<DType, IdType>(
         output_indices, output_values, num_rows, top_k_val, max_len, stream)));
   }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 3243 - 3260, The radix path doesn't
canonicalize ties before the deterministic value sort, so when deterministic and
sorted both true ties can differ by algorithm; after calling
RadixTopKMultiCTA<DType, IdType>(...) inside the else branch, add the same
canonicalization step used in the filtered branch: if (deterministic) call
LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType>(output_indices,
output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len, stream) so
radix-collected ties are normalized prior to the subsequent
StableSortTopKByValue<DType, IdType>(...).
tests/utils/test_topk.py (1)

1851-1881: ⚠️ Potential issue | 🟠 Major

Don’t derive the sorted oracle from FlashInfer’s unsorted order.

This test currently asserts that sorted=True is exactly a stable sort of sorted=False. The public API only promises deterministic sorted output, not that the sorted path must preserve the unsorted emission order, so this will reject valid implementations.

🔧 Safer oracle
-    unsorted_values, unsorted_indices = flashinfer.top_k(
-        logits, k, deterministic=True, sorted=False
-    )
     sorted_values_a, sorted_indices_a = flashinfer.top_k(
         logits, k, deterministic=True, sorted=True
     )
     sorted_values_b, sorted_indices_b = flashinfer.top_k(
         logits, k, deterministic=True, sorted=True
     )
 
-    expected_values, sort_order = torch.sort(
-        unsorted_values, dim=-1, descending=True, stable=True
-    )
-    expected_indices = torch.gather(unsorted_indices, dim=-1, index=sort_order)
+    expected_row = torch.argsort(pattern, descending=True, stable=True)[:k]
+    expected_indices = expected_row.unsqueeze(0).expand(batch_size, -1)
+    expected_values = torch.gather(logits, dim=-1, index=expected_indices)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/utils/test_topk.py` around lines 1851 - 1881, The test currently
derives the "expected" sorted results from flashinfer.top_k(..., sorted=False)
which wrongly assumes the sorted path must be a stable reordering of the
unsorted path; change the oracle so it computes expected results directly from
the input logits instead. In
test_top_k_deterministic_sorted_matches_stable_sort, replace the block that
computes expected_values and expected_indices from
unsorted_values/unsorted_indices with a direct sort/top-k on logits (e.g.,
torch.topk(logits, k, dim=-1, largest=True, sorted=True) or torch.sort on logits
then gather indices) and assert flashinfer.top_k(..., sorted=True) matches that
expected output and remains deterministic across calls. Ensure references remain
to flashinfer.top_k, logits, and the test function name.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_topk.py`:
- Around line 30-35: classify_benchmark_runtime_error currently only matches
"invalid argument" or "operation not supported" but misses CUDA's "invalid
configuration" text; update the function (classify_benchmark_runtime_error) to
also detect "invalid configuration" (or "invalid configuration argument") in the
lowercased exception message and return "UNSUPPORTED" for that case so
deterministic multi-CTA invalid-configuration failures from
include/flashinfer/topk.cuh are classified correctly.
- Around line 124-137: The code currently omits setting result["sglang_us"] when
the SGLang comparison is skipped or fails, and callers treat its absence as an
error; change the SGLang comparison block (the branch using compare_sglang,
HAS_SGL_KERNEL, k==2048, scores.dtype==torch.float32 and calling
sgl_kernel.fast_topk_v2 via bench_median_ms) to always set a status-aware entry
in result (e.g., set result["sglang_us"] to None or a sentinel and add
result["sglang_status"]="skipped" or "error" on skips/failures), mirroring the
handling used in the page-table and ragged comparison blocks so that skipped
comparisons do not print “(SGLang error)” and real exceptions are caught,
logged, and reflected in the result instead of aborting before emitting the
FlashInfer result; ensure you also compute result["speedup_vs_sglang"] only when
sg_ms is valid.
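Both bench_topk.py fixes above can be sketched as follows (marker strings, result keys, and helper names are assumptions based on the review text, not the benchmark's exact code):

```python
def classify_benchmark_runtime_error(exc: Exception) -> str:
    """Extended classifier: also treat CUDA's "invalid configuration
    argument" as UNSUPPORTED, alongside the existing markers."""
    msg = str(exc).lower()
    markers = ("invalid argument", "operation not supported", "invalid configuration")
    return "UNSUPPORTED" if any(m in msg for m in markers) else "ERROR"


def record_sglang_comparison(result: dict, fi_us: float, run_sglang):
    """Always leave a status-aware sglang entry instead of omitting the key."""
    try:
        sg_ms = run_sglang()  # returns median ms, or None when the comparison is skipped
    except Exception as exc:  # log, don't abort before emitting the FlashInfer result
        result["sglang_us"], result["sglang_status"] = None, f"error: {exc}"
        return
    if sg_ms is None:
        result["sglang_us"], result["sglang_status"] = None, "skipped"
        return
    result["sglang_us"], result["sglang_status"] = sg_ms * 1000.0, "ok"
    result["speedup_vs_sglang"] = result["sglang_us"] / fi_us
```

With this shape, callers can distinguish "skipped" from "error" and only compute the speedup when a valid measurement exists, mirroring the page-table and ragged comparison blocks.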

In `@include/flashinfer/topk.cuh`:
- Around line 232-240: AdvanceRadixGroupBarrier currently lets thread 0 publish
the group arrival before the CTA has synchronized, allowing other CTAs to
observe partially written state; change the order so the block first
synchronizes and flushes memory, then thread 0 performs red_release. Concretely,
in AdvanceRadixGroupBarrier add a __syncthreads() (and a __threadfence() or
__threadfence_system() as appropriate for visibility to other CTAs) before the
if (tx == 0) red_release(&state->arrival_counter, 1) so all threads finish
histogram/output writes and those writes are visible before arrival_counter is
incremented; leave the rest of the function (wait_ge, barrier_phase++,
__syncthreads()) unchanged.

---

Outside diff comments:
In `@flashinfer/topk.py`:
- Around line 270-283: In the fallback path (when sorted and not sorted_cuda)
ensure torch.sort is invoked with stable=deterministic so equal values remain in
a repeatable order when deterministic=True: update the call to
torch.sort(output_values, dim=-1, descending=True) inside the
sorted/not-sorted_cuda block to pass stable=deterministic (affecting
sorted_values, sort_indices used with indices via torch.gather to produce
sorted_indices) so deterministic mode preserves stable ordering while leaving
performance unchanged when deterministic is False.

---

Duplicate comments:
In `@include/flashinfer/topk.cuh`:
- Around line 3243-3260: The radix path doesn't canonicalize ties before the
deterministic value sort, so when deterministic and sorted both true ties can
differ by algorithm; after calling RadixTopKMultiCTA<DType, IdType>(...) inside
the else branch, add the same canonicalization step used in the filtered branch:
if (deterministic) call LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType,
IdType>(output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val,
max_len, stream) so radix-collected ties are normalized prior to the subsequent
StableSortTopKByValue<DType, IdType>(...).

In `@tests/utils/test_topk.py`:
- Around line 1851-1881: The test currently derives the "expected" sorted
results from flashinfer.top_k(..., sorted=False) which wrongly assumes the
sorted path must be a stable reordering of the unsorted path; change the oracle
so it computes expected results directly from the input logits instead. In
test_top_k_deterministic_sorted_matches_stable_sort, replace the block that
computes expected_values and expected_indices from
unsorted_values/unsorted_indices with a direct sort/top-k on logits (e.g.,
torch.topk(logits, k, dim=-1, largest=True, sorted=True) or torch.sort on logits
then gather indices) and assert flashinfer.top_k(..., sorted=True) matches that
expected output and remains deterministic across calls. Ensure references remain
to flashinfer.top_k, logits, and the test function name.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8d61e949-8848-4bff-b118-c9851e4f86ea

📥 Commits

Reviewing files that changed from the base of the PR and between 904fa8c and fe9e258.

📒 Files selected for processing (6)
  • benchmarks/bench_topk.py
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu
  • flashinfer/topk.py
  • include/flashinfer/topk.cuh
  • tests/utils/test_topk.py



@pytest.mark.parametrize("algo", ["auto", "multi_cta", "filtered"])
def test_fp32_long_seq_refine_overflow_regression_across_algorithms(
Contributor Author

@jiangyinzuo jiangyinzuo Mar 28, 2026


test_fp32_long_seq_refine_overflow_regression_across_algorithms
-> test_fp32_long_seq_regression_matrix(builder=_build_fp32_long_seq_overflow_inputs, algo in {auto, multi_cta, filtered},
api_kind="top_k")

test_fp32_long_seq_refine_overflow_transform_regression_across_algorithms
-> test_fp32_long_seq_regression_matrix(builder=_build_fp32_long_seq_overflow_inputs, algo in {auto, multi_cta, filtered},
api_kind in {"page_table", "ragged"})

test_fp32_long_seq_pivot_rebuild_regression_filtered
-> test_fp32_long_seq_regression_matrix(builder=_build_fp32_long_seq_pivot_mismatch_inputs, algo="filtered",
api_kind="top_k")

test_fp32_long_seq_pivot_rebuild_transform_regression_filtered
-> test_fp32_long_seq_regression_matrix(builder=_build_fp32_long_seq_pivot_mismatch_inputs, algo="filtered", api_kind in
{"page_table", "ragged"})

Contributor Author


@coderabbitai review

/gemini review

Contributor


The provided code changes implement a deterministic mode for Top-K operations in FlashInfer. The implementation introduces a deterministic flag across the public API and internal CUDA kernels, ensuring bitwise-reproducible output ordering for the selected top-k set on a fixed input and system.

I have reviewed the changes, and they appear to correctly handle the necessary synchronization and deterministic collection logic required for this feature. The addition of RadixDeterministicCollectScratch and the use of DeterministicThreadStridedCollect are appropriate for maintaining determinism across multiple CTAs.

One minor suggestion for the test_fp32_long_seq_regression_matrix test: ensure that the builder functions are robust against potential OOM errors when running on smaller GPUs, as the stress cases added (stress_b2048_l131072_k1024 and stress_b4096_l200000_k1024) are quite memory-intensive. The current classify_benchmark_runtime_error helper is a good addition for handling these cases gracefully in the benchmark suite.

Contributor


@jiangyinzuo, the refactoring is clean and the coverage mapping is accurate. Here's a detailed breakdown:

Confirmed equivalence (old → new):

Old function -> new test_fp32_long_seq_regression_matrix params:

  • test_fp32_long_seq_refine_overflow_regression_across_algorithms -> builder=_build_fp32_long_seq_overflow_inputs, algo ∈ {auto, multi_cta, filtered}, api_kind="top_k"
  • test_fp32_long_seq_refine_overflow_transform_regression_across_algorithms -> builder=_build_fp32_long_seq_overflow_inputs, algo ∈ {auto, multi_cta, filtered}, api_kind ∈ {page_table, ragged}
  • test_fp32_long_seq_pivot_rebuild_regression_filtered -> builder=_build_fp32_long_seq_pivot_mismatch_inputs, algo="filtered", api_kind="top_k"
  • test_fp32_long_seq_pivot_rebuild_transform_regression_filtered -> builder=_build_fp32_long_seq_pivot_mismatch_inputs, algo="filtered", api_kind ∈ {page_table, ragged}

Observations:

  1. Additive coverage bonus: _build_fp32_long_seq_pivot_mismatch_inputs with algo="filtered" is now also exercised against api_kind ∈ {page_table, ragged} — this is new coverage beyond what the original four functions provided, which is a net positive.

  2. Correct assertion strategy: _assert_top_k_matches_torch (used for api_kind="top_k") asserts exact sorted order via torch.testing.assert_close + torch.equal. _assert_transform_identity_matches_torch (used for transform api_kinds) correctly uses _assert_unordered_indices_match, since transform APIs don't expose a sorted parameter.

  3. ids=[...] usage: The explicit IDs (refine_overflow-auto, refine_overflow-multi_cta, etc.) make test reports readable — good practice.

  4. set_topk_algo fixture: Correctly scoped via yield with teardown restoring the env var, so algo overrides are isolated per test case.

  5. Minor nit: logits, _, _, k = builder() silently discards batch_size and vocab_size. A comment like # batch_size, vocab_size unused would aid readability, but this is very minor.

The refactoring looks correct and the consolidation into a single parameterized matrix test is an improvement. ✅


🧠 Learnings used
Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2773
File: include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh:27-32
Timestamp: 2026-03-12T21:29:16.342Z
Learning: In `include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh` (flashinfer-ai/flashinfer), the `static_assert` inside the `PHILOX_ROUNDS > 0` block that restricts stochastic rounding to fp16 state (`std::is_same_v<state_t, half>`) is intentionally kept in the CUDA header close to the implementation rather than being guarded by a pre-JIT Python-side runtime check. The maintainer prefers this colocation for easier auditability. Do not suggest moving or duplicating this constraint to the Python layer.

Learnt from: DomBrown
Repo: flashinfer-ai/flashinfer PR: 2770
File: flashinfer/decode.py:2231-2235
Timestamp: 2026-03-19T20:24:35.442Z
Learning: In `flashinfer/decode.py` (and related files `flashinfer/prefill.py`, `flashinfer/mla.py`), the `uses_shared_paged_kv_idx=False` mode is intended for direct TRT-LLM integration. When this flag is False, the `kv_cache` and `kv_block_scales` are expected to already be in TRT-LLM's native paged layout (separate K/V page indices, 3D block_tables `[batch_size, 2, max_num_pages_per_seq]`). The test code interleaves/reshapes tensors only to simulate TRT-LLM layout from a FlashInfer/vLLM-layout fixture — this is a test artifact, not a requirement imposed on real callers.

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2865
File: include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh:343-366
Timestamp: 2026-03-23T21:04:23.630Z
Learning: FlashInfer Mamba SSU MTP kernels: When applying the z gate in epilogues, index z using its own strides (z_stride_batch/z_stride_mtp), not the output strides. Verified in include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh (role_epilogue) and include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh (inline epilogue).

Learnt from: yzh119
Repo: flashinfer-ai/flashinfer PR: 2370
File: tests/gdn/conftest.py:25-34
Timestamp: 2026-01-21T21:26:00.701Z
Learning: Tests in the repository assume CUDA is available and do not require torch.cuda.is_available() guards in pytest fixtures. Ensure test files under tests/ follow this convention and avoid adding CPU-only guards in fixtures unless explicitly handling a non-CUDA environment.

@jiangyinzuo
Contributor Author

@coderabbitai review

/gemini review

@coderabbitai
Contributor

coderabbitai bot commented Mar 29, 2026

@jiangyinzuo: I'll kick off a full review of the PR now, covering the deterministic top-k implementation across all modified files (include/flashinfer/topk.cuh, csrc/topk.cu, csrc/flashinfer_topk_binding.cu, flashinfer/topk.py, benchmarks/bench_topk.py, and tests/utils/test_topk.py), including the CUB sort integration and overflow fix cherry-picked from @Linda-Stadter's fork.


🧠 Learnings used
Learnt from: Aalanli
Repo: flashinfer-ai/flashinfer PR: 2814
File: csrc/fast_topk_clusters.cu:135-141
Timestamp: 2026-03-20T17:55:09.735Z
Learning: In `csrc/fast_topk_clusters.cu`, the `fast_topk_cuda_v4` device function intentionally drops elements that fall in the threshold bin when the shared cache is full (`cached_offset >= num_cached`). This is a deliberate performance trade-off in the "fast" (inexact) TopK variant. The exact/correctness-preserving alternative is `fast_topk_clusters_exact` in `csrc/fast_topk_clusters_exact.cu`. Do not flag this overflow-drop behavior as a bug.

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2709
File: include/flashinfer/mamba/seq_chunk_cumsum.cuh:0-0
Timestamp: 2026-03-06T20:52:57.849Z
Learning: In `include/flashinfer/mamba/seq_chunk_cumsum.cuh` and `csrc/seq_chunk_cumsum.cu`, the maintainer explicitly does not want runtime validation of metadata (chunk_indices, chunk_offsets, seq_idx bounds, monotonicity) in the kernel launcher or device code because this is a high-throughput kernel. Do not suggest adding such checks. Debug-mode assertions may be acceptable but should not be pushed.

Learnt from: Aalanli
Repo: flashinfer-ai/flashinfer PR: 2814
File: csrc/fast_topk_clusters_exact.cu:406-410
Timestamp: 2026-03-20T17:56:48.864Z
Learning: In `csrc/fast_topk_clusters.cu` and `csrc/fast_topk_clusters_exact.cu`, the launcher functions (`launch_fast_topk_clusters` and `launch_fast_topk_clusters_exact`) intentionally and silently fall back to `num_clusters=1, pdl_enabled=false` when an unsupported `num_clusters` value is passed. No warning log is emitted. Do not flag this silent fallback as a bug or suggest adding a warning.
✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces deterministic mode for FlashInfer's top-k operations, ensuring bitwise-reproducible results across basic, page table, and ragged transform variants. The implementation includes new CUDA kernels for deterministic collection and stable sorting, along with an expanded benchmarking suite for DeepSeek DSA workloads. The review feedback highlights opportunities to improve code maintainability by refactoring duplicated printing logic in the benchmark script and replacing macro-based shared memory aliasing in the CUDA kernels with a more readable struct-based approach.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (3)
tests/utils/test_topk.py (1)

1487-1499: ⚠️ Potential issue | 🟡 Minor

Rename sorted; Ruff still flags this helper.

Line 1488 shadows the Python builtin, so this helper keeps tripping A002. sorted_output avoids the lint with no behavior change.

🧹 Minimal rename
 def _assert_top_k_matches_torch(
-    logits: torch.Tensor, k: int, *, deterministic: bool = False, sorted: bool = True
+    logits: torch.Tensor,
+    k: int,
+    *,
+    deterministic: bool = False,
+    sorted_output: bool = True,
 ):
     """Assert FlashInfer top_k matches torch.topk for exact-order cases."""
     values, indices = flashinfer.top_k(
-        logits, k, deterministic=deterministic, sorted=sorted
+        logits, k, deterministic=deterministic, sorted=sorted_output
     )
-    ref_values, ref_indices = torch.topk(logits, k, dim=-1, sorted=sorted)
+    ref_values, ref_indices = torch.topk(
+        logits, k, dim=-1, sorted=sorted_output
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/utils/test_topk.py` around lines 1487 - 1499, Rename the parameter
named sorted in the helper function _assert_top_k_matches_torch to avoid
shadowing the built-in; change the parameter name to sorted_output and update
all uses inside the function (the flashinfer.top_k call and torch.topk call) to
pass sorted=sorted_output (and any internal references if present), leaving the
behavior and variable names values, indices, ref_values, ref_indices unchanged.
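The A002 concern above can be illustrated with a standalone sketch (hypothetical helper names, not the actual test code): once a parameter is named `sorted`, the builtin is unreachable for the whole function body, while renaming it to `sorted_output` restores normal behavior.

```python
# Hypothetical sketch (not the real test helper): the parameter `sorted`
# shadows the builtin for the entire function body.
def top_k_shadowed(values, k, *, sorted=True):
    # `sorted` is a bool here, so this call raises TypeError.
    return sorted(values, reverse=True)[:k]


def top_k_renamed(values, k, *, sorted_output=True):
    # Renamed parameter: the builtin sorted() is usable again.
    topk = sorted(values, reverse=True)[:k]
    return topk if sorted_output else list(reversed(topk))


try:
    top_k_shadowed([3, 1, 4, 1, 5], 2)
except TypeError as exc:
    print("shadowed:", exc)  # 'bool' object is not callable

print(top_k_renamed([3, 1, 4, 1, 5], 2))  # [5, 4]
```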
include/flashinfer/topk.cuh (2)

232-240: ⚠️ Potential issue | 🔴 Critical

Synchronize the CTA before publishing the radix-group arrival.

AdvanceRadixGroupBarrier() still lets Line 235 advance arrival_counter before the rest of the block is forced to finish its preceding histogram/output writes. The current callers at Line 468, Line 648, and Line 851 hit it immediately after per-thread atomics/stores, so another CTA can observe partially updated state and break correctness/determinism again.

🔧 Minimal fix
 __device__ __forceinline__ void AdvanceRadixGroupBarrier(RadixRowState* state, int& barrier_phase,
                                                          uint32_t ctas_per_group, uint32_t tx) {
+  __syncthreads();
   if (tx == 0) {
     red_release(&state->arrival_counter, 1);
   }
   int target = (barrier_phase + 1) * ctas_per_group;
   wait_ge(&state->arrival_counter, target, tx);

Expected result: either the helper owns the CTA sync, or every releasing call site shows an immediate __syncthreads() before it.

#!/bin/bash
set -euo pipefail
sed -n '232,240p' include/flashinfer/topk.cuh
sed -n '452,470p' include/flashinfer/topk.cuh
sed -n '635,650p' include/flashinfer/topk.cuh
sed -n '835,855p' include/flashinfer/topk.cuh
sed -n '1256,1263p' include/flashinfer/topk.cuh
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 232 - 240, AdvanceRadixGroupBarrier
currently releases the radix-group arrival (red_release(&state->arrival_counter,
1)) before the CTA is synchronized, allowing other CTAs to observe partially
written per-thread state; fix it by owning the CTA sync inside
AdvanceRadixGroupBarrier: add a __syncthreads() immediately before the tx==0
release path so the block finishes all histogram/output stores before calling
red_release, leaving the existing wait_ge(&state->arrival_counter, target, tx),
barrier_phase++, and trailing __syncthreads() intact.

3241-3258: ⚠️ Potential issue | 🟠 Major

Canonicalize radix ties before the stable value sort.

Line 3246 index-sorts only the filtered deterministic path. When Line 3251 routes deterministic work through radix, StableSortTopKByValue() on Line 3256 preserves the deterministic collection order from RadixCollectIndicesDeterministic, so sorted=True, deterministic=True still returns a different tie order depending on which algorithm was selected.

🔧 Suggested fix
   } else {
     FLASHINFER_CUDA_CALL((RadixTopKMultiCTA<DType, IdType>(
         input, output_indices, output_values, nullptr, num_rows, top_k_val, max_len,
         row_states_buffer, deterministic, stream)));
+    if (deterministic && sorted_output) {
+      FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType>(
+          output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len,
+          stream)));
+    }
   }
   if (sorted_output) {
     FLASHINFER_CUDA_CALL((StableSortTopKByValue<DType, IdType>(
         output_indices, output_values, num_rows, top_k_val, max_len, stream)));
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 3241 - 3258, The deterministic
canonicalization (index-sort via LaunchSortTopKByIndex) is only applied in the
filtered path; ensure radix-based deterministic results are canonicalized the
same way before the stable value sort. After calling RadixTopKMultiCTA in the
else branch, if deterministic is true call
LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType> with the same
arguments used in the filtered branch (output_indices, output_values, nullptr,
0, nullptr, num_rows, top_k_val, max_len, stream) so that StableSortTopKByValue
sees a canonical tie order regardless of which algorithm ran; keep the existing
filtered-path LaunchSortTopKByIndex and the final StableSortTopKByValue intact.
🧹 Nitpick comments (2)
benchmarks/bench_topk.py (1)

209-223: Consider using -float('inf') consistently for neg_inf fallback.

For fp16/bf16, using torch.finfo(dtype).min instead of -inf means values at the minimum representable float could still be selected over the masked positions. If the intent is to fully exclude masked positions from top-k selection, -inf (which is representable in fp16/bf16) would be more robust.

🔧 Suggested fix
-        neg_inf = -torch.inf if dtype == torch.float32 else torch.finfo(dtype).min
+        neg_inf = float('-inf')
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_topk.py` around lines 209 - 223, The masking uses
torch.finfo(dtype).min for neg_inf when dtype is fp16/bf16 which can still be
chosen; change the neg_inf computation in the causal_chunk block (where
start_pos, lengths, q_len, dtype are used) to use a true negative infinity
constant (e.g. -float('inf')) for the masked_fill value so masked positions are
fully excluded when you call scores = scores.masked_fill(invalid, neg_inf).
flashinfer/topk.py (1)

176-182: Docstring could clarify tie-breaking strategy for deterministic mode.

The PR objectives and issue #2584 mention that deterministic mode uses "lower element index wins" for tie-breaking. Consider adding this detail to the docstring so users understand the expected behavior when values are equal.

📝 Suggested docstring enhancement
     deterministic : bool, optional
         If True, uses deterministic mode.
         Default is False (non-deterministic, which is faster).
 
         Deterministic mode guarantees repeatable FlashInfer output ordering for
-        the selected top-k set on a fixed input and system.
+        the selected top-k set on a fixed input and system. When values are equal,
+        elements with lower indices are selected first (stable tie-breaking).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/topk.py` around lines 176 - 182, Update the docstring for the
top_k function to explicitly state the tie-breaking rule used when
deterministic=True: when values are equal the element with the lower index is
chosen ("lower element index wins"). Mention this behavior near the
deterministic parameter description in the top_k docstring so callers know how
ties are resolved and that ordering may differ from non-deterministic behavior.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/topk.cuh`:
- Around line 213-218: The cached row-state workspace currently doesn't include
space for the deterministic scratch tail used by
MaybeGetRadixDeterministicCollectScratchBuffer, so when deterministic &&
!single_cta the pointer (row_states_buffer + num_groups) can walk past the
allocated buffer; update all allocation sites that create
radix_topk_row_states_* (in flashinfer/topk.py and any C++/CUDA allocs) to
reserve room for both RadixRowState[num_groups] and
RadixDeterministicCollectScratch[num_groups] (i.e. allocate num_groups of
RadixRowState plus num_groups of RadixDeterministicCollectScratch, or
equivalently adjust byte-size to
num_groups*(sizeof(RadixRowState)+sizeof(RadixDeterministicCollectScratch))),
and ensure any cached size calculations and related comments reflect this change
so deterministic multi-CTA no longer overruns the buffer.

In `@tests/utils/test_topk.py`:
- Around line 1896-1937: The tests
test_top_k_deterministic_sorted_large_k_matches_torch_by_algo and
test_top_k_deterministic_trivial_k_equals_length_by_algo currently parametrize
over "filtered" but use k values (4096 and vocab_size) larger than
FILTERED_TOPK_MAX_K (defined in include/flashinfer/topk.cuh as 2048), so they
never exercise FilteredTopK; update the parametrization to only use ["auto",
"multi_cta"] for these two tests, or alternatively add a separate test case that
explicitly uses set_topk_algo("filtered") with k <= FILTERED_TOPK_MAX_K (e.g.,
k=2048) to validate the filtered path.

---

Duplicate comments:
In `@include/flashinfer/topk.cuh`:
- Around line 232-240: AdvanceRadixGroupBarrier currently releases the
radix-group arrival (red_release(&state->arrival_counter, 1)) before the CTA is
synchronized, allowing other CTAs to observe partially written per-thread state;
fix it by owning the CTA sync inside AdvanceRadixGroupBarrier: add a
__syncthreads() immediately before the tx==0 release path so the block finishes
all histogram/output stores before calling red_release, leaving the existing
wait_ge(&state->arrival_counter, target, tx), barrier_phase++, and trailing
__syncthreads() intact.
- Around line 3241-3258: The deterministic canonicalization (index-sort via
LaunchSortTopKByIndex) is only applied in the filtered path; ensure radix-based
deterministic results are canonicalized the same way before the stable value
sort. After calling RadixTopKMultiCTA in the else branch, if deterministic is
true call LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType> with the
same arguments used in the filtered branch (output_indices, output_values,
nullptr, 0, nullptr, num_rows, top_k_val, max_len, stream) so that
StableSortTopKByValue sees a canonical tie order regardless of which algorithm
ran; keep the existing filtered-path LaunchSortTopKByIndex and the final
StableSortTopKByValue intact.

In `@tests/utils/test_topk.py`:
- Around line 1487-1499: Rename the parameter named sorted in the helper
function _assert_top_k_matches_torch to avoid shadowing the built-in; change the
parameter name to sorted_output and update all uses inside the function (the
flashinfer.top_k call and torch.topk call) to pass sorted=sorted_output (and any
internal references if present), leaving the behavior and variable names values,
indices, ref_values, ref_indices unchanged.

---

Nitpick comments:
In `@benchmarks/bench_topk.py`:
- Around line 209-223: The masking uses torch.finfo(dtype).min for neg_inf when
dtype is fp16/bf16 which can still be chosen; change the neg_inf computation in
the causal_chunk block (where start_pos, lengths, q_len, dtype are used) to use
a true negative infinity constant (e.g. -float('inf')) for the masked_fill value
so masked positions are fully excluded when you call scores =
scores.masked_fill(invalid, neg_inf).

In `@flashinfer/topk.py`:
- Around line 176-182: Update the docstring for the top_k function to explicitly
state the tie-breaking rule used when deterministic=True: when values are equal
the element with the lower index is chosen ("lower element index wins"). Mention
this behavior near the deterministic parameter description in the top_k
docstring so callers know how ties are resolved and that ordering may differ
from non-deterministic behavior.
ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ae026e77-9523-469a-8f61-393293ddf092

📥 Commits

Reviewing files that changed from the base of the PR and between fe9e258 and da7ad5e.

📒 Files selected for processing (6)
  • benchmarks/bench_topk.py
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu
  • flashinfer/topk.py
  • include/flashinfer/topk.cuh
  • tests/utils/test_topk.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu

@jiangyinzuo
Contributor Author

@coderabbitai review

/gemini review

@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@coderabbitai
Contributor

coderabbitai bot commented Mar 30, 2026

@jiangyinzuo: I'll kick off a full review of the PR now, covering the deterministic top-k implementation across all modified files (include/flashinfer/topk.cuh, csrc/topk.cu, csrc/flashinfer_topk_binding.cu, flashinfer/topk.py, benchmarks/bench_topk.py, and tests/utils/test_topk.py), including the CUB sort integration and overflow fix cherry-picked from @Linda-Stadter's fork.


🧠 Learnings used
Learnt from: Aalanli
Repo: flashinfer-ai/flashinfer PR: 2814
File: csrc/fast_topk_clusters.cu:135-141
Timestamp: 2026-03-20T17:55:09.735Z
Learning: In `csrc/fast_topk_clusters.cu`, the `fast_topk_cuda_v4` device function intentionally drops elements that fall in the threshold bin when the shared cache is full (`cached_offset >= num_cached`). This is a deliberate performance trade-off in the "fast" (inexact) TopK variant. The exact/correctness-preserving alternative is `fast_topk_clusters_exact` in `csrc/fast_topk_clusters_exact.cu`. Do not flag this overflow-drop behavior as a bug.

Learnt from: Aalanli
Repo: flashinfer-ai/flashinfer PR: 2814
File: csrc/fast_topk_clusters_exact.cu:406-410
Timestamp: 2026-03-20T17:56:48.864Z
Learning: In `csrc/fast_topk_clusters.cu` and `csrc/fast_topk_clusters_exact.cu`, the launcher functions (`launch_fast_topk_clusters` and `launch_fast_topk_clusters_exact`) intentionally and silently fall back to `num_clusters=1, pdl_enabled=false` when an unsupported `num_clusters` value is passed. No warning log is emitted. Do not flag this silent fallback as a bug or suggest adding a warning.

Learnt from: ishovkun
Repo: flashinfer-ai/flashinfer PR: 2709
File: include/flashinfer/mamba/seq_chunk_cumsum.cuh:0-0
Timestamp: 2026-03-06T20:52:57.849Z
Learning: In `include/flashinfer/mamba/seq_chunk_cumsum.cuh` and `csrc/seq_chunk_cumsum.cu`, the maintainer explicitly does not want runtime validation of metadata (chunk_indices, chunk_offsets, seq_idx bounds, monotonicity) in the kernel launcher or device code because this is a high-throughput kernel. Do not suggest adding such checks. Debug-mode assertions may be acceptable but should not be pushed.
✅ Actions performed

Full review triggered.

Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
include/flashinfer/topk.cuh (1)

3259-3267: ⚠️ Potential issue | 🟠 Major

Radix deterministic path omits index canonicalization; tie order differs from filtered path.

The filtered path calls SortTopKByIndex when deterministic=true (line 3255-3257), but the radix path (line 3260-3262) does not. This means:

  • Filtered + deterministic: ties ordered by ascending index
  • Radix + deterministic: ties ordered by thread-strided emission

When sorted_output=true, StableSortTopKByValue preserves prior ordering for ties, so the final tie-breaking strategy depends on which algorithm was selected—breaking the contract that deterministic mode produces consistent, index-ordered ties.

🔧 Suggested fix
   } else {
     FLASHINFER_CUDA_CALL((RadixTopKMultiCTA<DType, IdType>(
         input, output_indices, output_values, nullptr, num_rows, top_k_val, max_len,
         row_states_buffer, deterministic, stream)));
+    if (deterministic) {
+      FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex<FilteredTopKMode::Plain, DType, IdType>(
+          output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len,
+          stream)));
+    }
   }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 3259 - 3267, The radix
deterministic path omits the canonical index tie-break used by the filtered
path; after RadixTopKMultiCTA completes and when deterministic==true, call
SortTopKByIndex with the same arguments used by the filtered path (e.g.,
output_indices, output_values, num_rows, top_k_val, max_len, stream) before any
StableSortTopKByValue call so ties are canonicalized by ascending index; ensure
this conditional mirrors the filtered path's deterministic branch around
SortTopKByIndex so both algorithms produce identical tie order.
🧹 Nitpick comments (2)
flashinfer/topk.py (1)

63-73: Ruff flags input shadowing Python builtin.

The static analysis tool flags line 65 for shadowing Python's built-in input. However, this pattern is consistent with the existing codebase conventions for tensor parameter naming in this file. Given the "Chill" review mode and that this is a widespread pattern, this can be addressed in a separate cleanup if desired.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/topk.py` around lines 63 - 73, The parameter name input in the
_fake_radix_topk function shadows Python's builtin and should be renamed to
avoid the Ruff warning; update the function signature of _fake_radix_topk
(registered as "flashinfer::radix_topk") to use a non-builtins name (e.g.,
tensor, inp, or src_tensor) and replace all uses inside the function (input.size
and input.device) accordingly so behavior is unchanged.
include/flashinfer/topk.cuh (1)

3174-3198: Consider documenting heuristic rationale.

The deterministic-mode algorithm selection heuristics (lines 3174-3184) differ significantly from non-deterministic heuristics (lines 3186-3197). Consider adding a brief comment explaining the trade-off (e.g., filtered deterministic overhead vs. radix multi-CTA coordination cost).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/topk.cuh` around lines 3174 - 3198, Add a short
explanatory comment immediately above the block that branches on deterministic
and DType size (referencing variables/conditions: deterministic, sizeof(DType),
max_len, num_rows, and batch_threshold) that explains why deterministic-mode
thresholds differ from non-deterministic ones — e.g., deterministic
implementation favors simpler per-row filtered scans to avoid non-deterministic
cross-CTA radix coordination (hence lower thresholds like 16384 and the special
256 divisor), while non-deterministic heuristics accept radix/multi-CTA
strategies for larger max_len (notice thresholds 16384/32768 and the use of
max_len/4096 or /16384 to compute batch_threshold); keep the comment concise
(1–3 lines) describing the trade-off and pointing to the key constants so future
readers understand the rationale.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@include/flashinfer/topk.cuh`:
- Around line 3259-3267: The radix deterministic path omits the canonical index
tie-break used by the filtered path; after RadixTopKMultiCTA completes and when
deterministic==true, call SortTopKByIndex with the same arguments used by the
filtered path (e.g., output_indices, output_values, num_rows, top_k_val,
max_len, stream) before any StableSortTopKByValue call so ties are canonicalized
by ascending index; ensure this conditional mirrors the filtered path's
deterministic branch around SortTopKByIndex so both algorithms produce identical
tie order.

---

Nitpick comments:
In `@flashinfer/topk.py`:
- Around line 63-73: The parameter name input in the _fake_radix_topk function
shadows Python's builtin and should be renamed to avoid the Ruff warning; update
the function signature of _fake_radix_topk (registered as
"flashinfer::radix_topk") to use a non-builtins name (e.g., tensor, inp, or
src_tensor) and replace all uses inside the function (input.size and
input.device) accordingly so behavior is unchanged.

In `@include/flashinfer/topk.cuh`:
- Around line 3174-3198: Add a short explanatory comment immediately above the
block that branches on deterministic and DType size (referencing
variables/conditions: deterministic, sizeof(DType), max_len, num_rows, and
batch_threshold) that explains why deterministic-mode thresholds differ from
non-deterministic ones — e.g., deterministic implementation favors simpler
per-row filtered scans to avoid non-deterministic cross-CTA radix coordination
(hence lower thresholds like 16384 and the special 256 divisor), while
non-deterministic heuristics accept radix/multi-CTA strategies for larger
max_len (notice thresholds 16384/32768 and the use of max_len/4096 or /16384 to
compute batch_threshold); keep the comment concise (1–3 lines) describing the
trade-off and pointing to the key constants so future readers understand the
rationale.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4b20947c-0a7d-4e41-ace0-d46414a25216

📥 Commits

Reviewing files that changed from the base of the PR and between 779c24d and e8eaf33.

📒 Files selected for processing (6)
  • benchmarks/bench_topk.py
  • csrc/flashinfer_topk_binding.cu
  • csrc/topk.cu
  • flashinfer/topk.py
  • include/flashinfer/topk.cuh
  • tests/utils/test_topk.py

@jiangyinzuo
Contributor Author

@yzh119 @Linda-Stadter This PR is ready for review now.

also add cub stable radix sort and overflow handling

Co-authored-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Signed-off-by: Yinzuo Jiang <jiangyinzuo@foxmail.com>
}
}
if (tx == 0) {
st_release(&state->output_counter, 0);

This is not safe for the nondeterministic path because there is no final __syncthreads() in RadixCollectIndices. But it is also redundant anyway, so I would just remove it.

uint32_t cta_local_eq_count = 0;
OrderedType ordered_pivot =
RadixSelectFindPivot<BLOCK_THREADS, VEC_SIZE, SINGLE_CTA, DETERMINISTIC, DType>(
input + row_idx * stride, shared_ordered, local_histogram, suffix_sum, shared_scalars,

This doesn't contain my overflow fix by casting to size_t. I will create another commit on top of this :)



Development

Successfully merging this pull request may close these issues.

Feature Request: Deterministic top-k kernels for sparse attention
