feat: further optimize top-k and add fused top-k page construction kernels for DSA #2215

Merged
yzh119 merged 16 commits into flashinfer-ai:main from yzh119:fused-topk-page
Dec 19, 2025
Conversation

@yzh119
Collaborator

@yzh119 yzh119 commented Dec 13, 2025

📌 Description

Follow-up of #2119, this PR implements the top_k_page_table_transform and top_k_ragged_transform functions required for DSA in SGLang (fp16/bf16/fp32 are supported), as requested in #2221.

This PR also adds more top-k algorithms to choose from:

  1. filtered top-k (used in TileLang and sgl-kernel): stores top-k candidates in shared memory; suitable for small k.
  2. multi-CTA top-k (implemented in #2119, "perf: bunch of features and optimizations for top-k (sampling + sparse attention)"): caches the input in shared memory; works for any k.

and adds heuristics to choose between the algorithms.
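To make the trade-off concrete, a heuristic of this kind might look roughly like the sketch below. The function name and the threshold value are illustrative assumptions, not FlashInfer's actual dispatch logic:

```python
# Hypothetical sketch of an algorithm-selection heuristic. The cutoff
# constant and function name are assumptions; the real heuristic in the
# PR may use different criteria (e.g. shared-memory capacity).

FILTERED_TOPK_MAX_K = 256  # assumed cutoff for the small-k filtered path


def choose_topk_algorithm(k: int, seq_len: int) -> str:
    """Pick between the two algorithms described above:

    - "filtered": stores top-k candidates in shared memory, viable for small k.
    - "multi_cta": caches the input in shared memory, works for any k.
    """
    if k <= FILTERED_TOPK_MAX_K and k <= seq_len:
        return "filtered"
    return "multi_cta"
```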

Besides the new features, this PR also moves the top-k related functions out of sampling.cuh (because they are not used only for sampling).

Benchmark on B200 for fp32 input:

====================================================================================================
top_k_page_table_transform: Fused top-k + page table gather (dtype=FP32)
NOTE: SGLang only supports k=2048 and float32
====================================================================================================
 batch    seq_len      k |   FlashInfer       SGLang    Speedup
------------------------------------------------------------------------------------------
     1       4096    256 |       8.13us
     1       4096    512 |       8.32us
     1       4096   1024 |       8.35us
     1       4096   2048 |       8.06us       8.03us      1.00x
     1       4096   4096 |       4.03us
     1      16384    256 |      11.78us
     1      16384    512 |      12.51us
     1      16384   1024 |      13.63us
     1      16384   2048 |      14.67us      14.85us      1.01x
     1      16384   4096 |      20.48us
     1      65536    256 |      31.14us
     1      65536    512 |      31.68us
     1      65536   1024 |      32.32us
     1      65536   2048 |      33.63us      40.03us      1.19x
     1      65536   4096 |      35.04us
     1     131072    256 |      35.90us
     1     131072    512 |      36.37us
     1     131072   1024 |      37.09us
     1     131072   2048 |      38.11us      68.02us      1.78x
     1     131072   4096 |      40.67us
     1     262144    256 |      40.56us
     1     262144    512 |      40.77us
     1     262144   1024 |      41.33us
     1     262144   2048 |      42.67us     123.57us      2.90x
     1     262144   4096 |      44.99us
     1     524288    256 |      40.58us
     1     524288    512 |      40.86us
     1     524288   1024 |      41.04us
     1     524288   2048 |      41.66us     208.86us      5.01x
     1     524288   4096 |      43.17us
    16       4096    256 |       8.51us
    16       4096    512 |       8.67us
    16       4096   1024 |       8.70us
    16       4096   2048 |       8.32us       8.16us      0.98x
    16       4096   4096 |       4.38us
    16      16384    256 |      12.13us
    16      16384    512 |      12.77us
    16      16384   1024 |      13.86us
    16      16384   2048 |      15.84us      15.90us      1.00x
    16      16384   4096 |      20.93us
    16      65536    256 |      24.19us
    16      65536    512 |      29.63us
    16      65536   1024 |      31.18us
    16      65536   2048 |      33.54us      41.60us      1.24x
    16      65536   4096 |      39.62us
    16     131072    256 |      40.62us
    16     131072    512 |      40.96us
    16     131072   1024 |      50.69us
    16     131072   2048 |      52.48us      69.63us      1.33x
    16     131072   4096 |      45.50us
    16     262144    256 |      45.28us
    16     262144    512 |      45.57us
    16     262144   1024 |      46.06us
    16     262144   2048 |      47.30us     125.23us      2.65x
    16     262144   4096 |      49.63us
    16     524288    256 |      87.14us
    16     524288    512 |      87.30us
    16     524288   1024 |      87.97us
    16     524288   2048 |      89.41us     211.57us      2.37x
    16     524288   4096 |      91.90us
    64       4096    256 |       8.74us
    64       4096    512 |       9.60us
    64       4096   1024 |       9.70us
    64       4096   2048 |       8.26us       8.26us      1.00x
    64       4096   4096 |       4.67us
    64      16384    256 |      13.15us
    64      16384    512 |      13.25us
    64      16384   1024 |      14.62us
    64      16384   2048 |      15.97us      16.35us      1.02x
    64      16384   4096 |      21.57us
    64      65536    256 |      24.96us
    64      65536    512 |      30.86us
    64      65536   1024 |      31.07us
    64      65536   2048 |      34.66us      43.15us      1.25x
    64      65536   4096 |      41.34us
    64     131072    256 |      41.54us
    64     131072    512 |      42.26us
    64     131072   1024 |      51.71us
    64     131072   2048 |      53.38us      71.25us      1.33x
    64     131072   4096 |      90.32us
    64     262144    256 |      72.61us
    64     262144    512 |      78.94us
    64     262144   1024 |      80.22us
    64     262144   2048 |      96.06us     147.15us      1.53x
    64     262144   4096 |     146.43us
    64     524288    256 |     157.69us
    64     524288    512 |     158.27us
    64     524288   1024 |     170.81us
    64     524288   2048 |     171.42us     353.32us      2.06x
    64     524288   4096 |     229.28us
   256       4096    256 |      16.53us
   256       4096    512 |      17.15us
   256       4096   1024 |      17.82us
   256       4096   2048 |      15.87us      15.68us      0.99x
   256       4096   4096 |       7.36us
   256      16384    256 |      25.34us
   256      16384    512 |      26.98us
   256      16384   1024 |      29.22us
   256      16384   2048 |      32.32us      33.34us      1.03x
   256      16384   4096 |      41.86us
   256      65536    256 |      51.73us
   256      65536    512 |      62.02us
   256      65536   1024 |      64.50us
   256      65536   2048 |      72.64us      89.60us      1.23x
   256      65536   4096 |     153.09us
   256     131072    256 |      88.58us
   256     131072    512 |      89.98us
   256     131072   1024 |     108.46us
   256     131072   2048 |     112.22us     152.14us      1.36x
   256     131072   4096 |     260.67us
   256     262144    256 |     173.65us
   256     262144    512 |     186.22us
   256     262144   1024 |     188.06us
   256     262144   2048 |     226.05us     396.01us      1.75x
   256     262144   4096 |     420.38us
   256     524288    256 |     322.33us
   256     524288    512 |     323.42us
   256     524288   1024 |     349.85us
   256     524288   2048 |     353.49us     709.18us      2.01x
   256     524288   4096 |     837.13us

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Fused Top‑K transforms (page‑table and ragged) exposed in the public API, plus a public capability check for filtered Top‑K and new benchmarking CLI with algorithm and optional SGLang comparisons.
  • Tests

    • Extensive fused Top‑K test suite covering dtypes, lengths, mappings, ragged/prefill modes, algorithm modes, and SGLang comparisons.
  • Documentation

    • New Top‑K API docs and cross‑references added.
  • Chores

    • Removed legacy multi‑CTA Radix Top‑K implementation.


@coderabbitai
Contributor

coderabbitai bot commented Dec 13, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds fused Top‑K transforms (page‑table and ragged) across Python, FFI, CUDA host, and device layers; removes multi‑CTA RadixTopK device machinery; introduces benchmarks and extensive tests for the new fused transforms and exposes capability inquiry for filtered TopK.

Changes

Cohort / File(s) Summary
Benchmark Suite
benchmarks/bench_topk.py
New CLI benchmarking module for top_k, page_table, and ragged benchmarks; supports SGLang comparisons, GPU timing, OOM handling, dtype parsing, and top-k algorithm toggles.
FFI / Bindings
csrc/flashinfer_topk_binding.cu
Added TVM/FFI export wrappers for radix_topk_page_table_transform, radix_topk_ragged_transform, and can_implement_filtered_topk.
CUDA Host Entrypoints
csrc/topk.cu
Added radix_topk_page_table_transform, radix_topk_ragged_transform, and can_implement_filtered_topk host entrypoints with validation, dtype dispatch, optional auxiliary-buffer handling, and unified TopK dispatch.
Device / Sampling Headers
include/flashinfer/sampling.cuh
Removed multi‑CTA RadixTopK machinery and RadixTopKTraits specializations; added topk.cuh include and collapsed prior multi‑CTA implementations.
Python API
flashinfer/topk.py
Added FFI wrappers/fake ops and high‑level fused ops top_k_page_table_transform / top_k_ragged_transform, allocation/row_states handling, aliases, and extended public module exports; added can_implement_filtered_topk.
Package Exports
flashinfer/__init__.py
Exported top_k_page_table_transform and top_k_ragged_transform at package level.
Test Suite
tests/utils/test_topk.py
Large new test suite covering page‑table and ragged fused transforms, SGLang-style references, many dtype/shape/padding/row_to_batch/offset cases, prefill/decode modes, and algorithm permutations.
Documentation
docs/api/topk.rst, docs/api/sampling.rst, docs/index.rst
Added Top‑K API docs and cross‑references; updated sampling docs to reference fused top-k utilities.
Build Manifest
CMakeLists.txt
Updated to reflect added topk host/device entry points and bindings.

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant PyAPI as Python API\n(flashinfer/topk.py)
    participant FFI as TVM FFI\n(csrc/flashinfer_topk_binding.cu)
    participant Host as CUDA Host\n(csrc/topk.cu)
    participant Kernel as Device Kernel\n(include/.../topk.cuh)

    User->>PyAPI: call top_k_page_table_transform(...) / top_k_ragged_transform(...)
    activate PyAPI
    PyAPI->>PyAPI: validate inputs\nallocate row_states & outputs
    PyAPI->>FFI: call radix_topk_*(tensors...)
    deactivate PyAPI

    activate FFI
    FFI->>Host: forward tensors / invoke host entrypoint
    deactivate FFI

    activate Host
    Host->>Host: validate & dispatch dtype\nprepare CUDA stream / optional buffers
    Host->>Kernel: launch device kernel (page_table / ragged)
    deactivate Host

    activate Kernel
    Kernel->>Kernel: perform selection & transform\nwrite outputs
    Kernel-->>Host: outputs ready
    deactivate Kernel

    Host-->>FFI: return status
    FFI-->>PyAPI: return to caller

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • Attention points:
    • include/flashinfer/sampling.cuh — ensure removed multi‑CTA code has no remaining call sites and declarations moved to topk.cuh.
    • csrc/topk.cu & csrc/flashinfer_topk_binding.cu — verify dtype dispatch, optional buffer handling, stream synchronization, and FFI signatures.
    • flashinfer/topk.py — validate shape/stride contracts, allocations, and backwards compatibility.
    • tests/utils/test_topk.py — confirm reference implementations and edge-case coverage.

Possibly related PRs

Suggested reviewers

  • djmmoss
  • cyx-6
  • wenscarl
  • IwakuraRein
  • nvmbreughe
  • kahyunnam
  • jimmyzho
  • aleozlx

Poem

🐇 I hopped through kernels with indices bright,

Page tables folded neat in CUDA light,
Ragged offsets danced in tidy rows,
Benchmarks hummed where every fast path goes,
A rabbit cheers the top‑k's speedy flight.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 76.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and specifically describes the main change: optimizing top-k and adding fused top-k page construction kernels for DSA, which aligns with the PR's primary objectives.
Description check ✅ Passed The PR description comprehensively covers the changes, references related issues, includes benchmark results, and completes all template checklist items with implementation details and algorithm choices explained.

@gemini-code-assist
Contributor

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances FlashInfer's capabilities for sparse attention by introducing highly optimized, fused top-k kernels. These new kernels, top_k_page_table_transform and top_k_ragged_transform, integrate the top-k selection process with subsequent index transformations directly on the GPU. This fusion reduces kernel launch overhead and improves overall efficiency, particularly beneficial for complex sparse attention patterns. The changes involve a major refactoring of the CUDA top-k implementation into a unified kernel architecture, accompanied by new benchmarks and thorough testing to ensure correctness and performance gains.

Highlights

  • New Fused Top-K Kernels: Introduced top_k_page_table_transform and top_k_ragged_transform for sparse attention, combining top-k selection with index transformations.
  • Unified CUDA Kernel Architecture: Refactored the underlying CUDA implementation into a single, unified kernel (RadixTopKKernel_Unified) that efficiently handles basic top-k, page table transforms, and ragged transforms.
  • Performance Benchmarks: Added a new benchmark script (benchmarks/bench_topk.py) to evaluate the performance of these new fused kernels against torch.topk and SGLang's sgl_kernel.
  • Comprehensive Testing: Included extensive unit tests for correctness, edge cases, and large-scale scenarios for the new fused top-k transform operations, including comparisons with SGLang-style references.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces two new fused Top-K operations, top_k_page_table_transform and top_k_ragged_transform, to FlashInfer, specifically designed for sparse attention. The changes include adding C++ CUDA kernel implementations for these operations, which utilize a unified multi-CTA radix select algorithm for efficiency, along with their Python bindings and high-level API functions. Comprehensive unit tests have been added to verify correctness across various scenarios, including trivial cases, variable lengths, row_to_batch mapping, and comparisons against SGLang-style reference implementations. A new benchmark script bench_topk.py was added to measure the performance of these new operations against torch.topk and optionally SGLang's sgl_kernel. Review comments suggest refactoring the main function in the benchmark script to reduce code duplication and removing redundant aliases in flashinfer/topk.py for cleaner code.

Comment on lines +192 to +294
    if args.op in ["all", "top_k"]:
        print("=" * 100)
        print("top_k: Basic radix-based top-k selection")
        print("=" * 100)
        print(
            f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12} {'torch.topk':>12} {'Speedup':>10}"
        )
        print("-" * 70)

        for batch_size in batch_sizes:
            for seq_len in seq_lens:
                for k in k_values:
                    if k > seq_len:
                        continue
                    try:
                        result = bench_top_k(batch_size, seq_len, k)
                        print(
                            f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | "
                            f"{result['flashinfer_us']:>10.2f}us {result['torch_us']:>10.2f}us "
                            f"{result['speedup_vs_torch']:>9.2f}x"
                        )
                    except RuntimeError as e:
                        if "out of memory" in str(e):
                            print(f"{batch_size:>6} {seq_len:>10} {k:>6} | OOM")
                            torch.cuda.empty_cache()
                        else:
                            raise

    if args.op in ["all", "page_table"]:
        print("\n" + "=" * 100)
        print("top_k_page_table_transform: Fused top-k + page table gather")
        if args.compare_sglang:
            print("NOTE: SGLang only supports k=2048")
        print("=" * 100)

        header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12}"
        if args.compare_sglang:
            header += f" {'SGLang':>12} {'Speedup':>10}"
        print(header)
        print("-" * (70 if not args.compare_sglang else 90))

        for batch_size in batch_sizes:
            for seq_len in seq_lens:
                for k in k_values:
                    if k > seq_len:
                        continue
                    try:
                        result = bench_page_table_transform(
                            batch_size, seq_len, k, compare_sglang=args.compare_sglang
                        )
                        line = (
                            f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | "
                            f"{result['flashinfer_us']:>10.2f}us"
                        )
                        if "sglang_us" in result:
                            line += f" {result['sglang_us']:>10.2f}us {result['speedup_vs_sglang']:>9.2f}x"
                        elif args.compare_sglang and k == 2048:
                            line += " (SGLang error)"
                        print(line)
                    except RuntimeError as e:
                        if "out of memory" in str(e):
                            print(f"{batch_size:>6} {seq_len:>10} {k:>6} | OOM")
                            torch.cuda.empty_cache()
                        else:
                            raise

    if args.op in ["all", "ragged"]:
        print("\n" + "=" * 100)
        print("top_k_ragged_transform: Fused top-k + ragged index transform")
        if args.compare_sglang:
            print("NOTE: SGLang only supports k=2048")
        print("=" * 100)

        header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12}"
        if args.compare_sglang:
            header += f" {'SGLang':>12} {'Speedup':>10}"
        print(header)
        print("-" * (70 if not args.compare_sglang else 90))

        for batch_size in batch_sizes:
            for seq_len in seq_lens:
                for k in k_values:
                    if k > seq_len:
                        continue
                    try:
                        result = bench_ragged_transform(
                            batch_size, seq_len, k, compare_sglang=args.compare_sglang
                        )
                        line = (
                            f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | "
                            f"{result['flashinfer_us']:>10.2f}us"
                        )
                        if "sglang_us" in result:
                            line += f" {result['sglang_us']:>10.2f}us {result['speedup_vs_sglang']:>9.2f}x"
                        elif args.compare_sglang and k == 2048:
                            line += " (SGLang error)"
                        print(line)
                    except RuntimeError as e:
                        if "out of memory" in str(e):
                            print(f"{batch_size:>6} {seq_len:>10} {k:>6} | OOM")
                            torch.cuda.empty_cache()
                        else:
                            raise
Contributor


medium

The main function contains significant code duplication for benchmarking each operation (top_k, page_table, ragged). The loop structure, OOM handling, and result printing are very similar across the three blocks. This could be refactored into a helper function to improve maintainability and reduce code size.

A helper function could take parameters like the operation name, the benchmark function to call, and configuration for printing headers and results. This would make the main function much cleaner and easier to extend with new benchmark operations in the future.
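A rough sketch of the suggested helper is below. The function name, the result-dict keys, and the simplified single-column output are assumptions based on the script excerpt above (the real refactor would also handle the SGLang columns and device cache cleanup):

```python
# Illustrative refactor along the lines suggested above: one grid runner
# shared by all three benchmark ops. Names and dict keys are assumptions
# mirroring the quoted script, not the actual PR code.


def run_benchmark_grid(title, bench_fn, batch_sizes, seq_lens, k_values):
    """Run one benchmark op over the parameter grid with shared header
    printing and OOM handling. Returns the number of configs benchmarked."""
    print("=" * 100)
    print(title)
    print("=" * 100)
    count = 0
    for batch_size in batch_sizes:
        for seq_len in seq_lens:
            for k in k_values:
                if k > seq_len:
                    continue
                try:
                    result = bench_fn(batch_size, seq_len, k)
                    print(
                        f"{batch_size:>6} {seq_len:>10} {k:>6} | "
                        f"{result['flashinfer_us']:>10.2f}us"
                    )
                    count += 1
                except RuntimeError as e:
                    if "out of memory" not in str(e):
                        raise
                    # Device-specific cleanup (torch.cuda.empty_cache())
                    # omitted in this sketch.
                    print(f"{batch_size:>6} {seq_len:>10} {k:>6} | OOM")
    return count
```

Each of the three blocks in main would then collapse to a single call with its own bench function and title.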

Comment on lines +407 to +408
topk_page_table_transform = top_k_page_table_transform
topk_ragged_transform = top_k_ragged_transform
Contributor


medium

These aliases are redundant as the functions are already defined with these names. They can be removed for cleaner code.

Member


seems the aliases can be removed here

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (3)
tests/utils/test_topk.py (3)

296-314: Minor: Unused parameter k in compute_transform_accuracy.

The k parameter is unused as the function derives valid entries from the -1 mask. This is actually correct behavior since some rows may have fewer than k valid entries.

Consider removing the unused k parameter to avoid confusion:

-def compute_transform_accuracy(test_output, ref_output, num_rows, k):
+def compute_transform_accuracy(test_output, ref_output, num_rows):

And update the call sites accordingly, or add a docstring explaining why k is present but unused (for API consistency).


1056-1065: Loop variable capture issue in benchmark closures.

The inner functions run_page_table and run_ragged capture loop variables (src_page_table, lengths, offsets) by reference. In Python, closures capture variables by reference, not by value. However, since these closures are immediately executed within the same loop iteration before the variables change, this is safe in practice.

For defensive coding, you could bind the variables explicitly:

def run_page_table(scores, k, _src_page_table=src_page_table, _lengths=lengths, **kw):
    return flashinfer.top_k_page_table_transform(
        scores, _src_page_table, _lengths, k
    )

However, since the closure is called immediately within the same iteration, the current code works correctly.


987-1013: Minor: Unused impl_name parameter.

The impl_name parameter is declared but never used in the function body.

Either remove the unused parameter or use it for logging/output:

def benchmark_topk_transform(
-   impl_name: str,
    impl_func,
    scores: torch.Tensor,
    k: int,
    warmup: int = 10,
    repeat: int = 100,
    **kwargs,
):
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1ac4e1d and 30d01e3.

📒 Files selected for processing (7)
  • benchmarks/bench_topk.py (1 hunks)
  • csrc/flashinfer_topk_binding.cu (1 hunks)
  • csrc/topk.cu (1 hunks)
  • flashinfer/__init__.py (1 hunks)
  • flashinfer/topk.py (2 hunks)
  • include/flashinfer/sampling.cuh (10 hunks)
  • tests/utils/test_topk.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/sampling.cuh
🧬 Code graph analysis (5)
csrc/topk.cu (1)
csrc/tvm_ffi_utils.h (1)
  • get_stream (294-296)
csrc/flashinfer_topk_binding.cu (2)
flashinfer/topk.py (4)
  • top_k (139-234)
  • radix_topk (34-52)
  • radix_topk_page_table_transform (68-88)
  • radix_topk_ragged_transform (106-119)
csrc/topk.cu (6)
  • radix_topk (24-60)
  • radix_topk (24-25)
  • radix_topk_page_table_transform (62-107)
  • radix_topk_page_table_transform (62-65)
  • radix_topk_ragged_transform (109-146)
  • radix_topk_ragged_transform (109-111)
tests/utils/test_topk.py (1)
flashinfer/topk.py (2)
  • top_k_page_table_transform (241-327)
  • top_k_ragged_transform (330-403)
benchmarks/bench_topk.py (2)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (998-1059)
flashinfer/topk.py (3)
  • top_k (139-234)
  • top_k_page_table_transform (241-327)
  • top_k_ragged_transform (330-403)
flashinfer/__init__.py (1)
flashinfer/topk.py (2)
  • top_k_page_table_transform (241-327)
  • top_k_ragged_transform (330-403)
🪛 Ruff (0.14.8)
tests/utils/test_topk.py

296-296: Unused function argument: k

(ARG001)


988-988: Unused function argument: impl_name

(ARG001)


1056-1056: Unused function argument: kw

(ARG001)


1058-1058: Function definition does not bind loop variable src_page_table

(B023)


1058-1058: Function definition does not bind loop variable lengths

(B023)


1064-1064: Unused function argument: kw

(ARG001)


1065-1065: Function definition does not bind loop variable offsets

(B023)


1065-1065: Function definition does not bind loop variable lengths

(B023)

benchmarks/bench_topk.py

32-32: Unused function argument: compare_sglang

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (25)
include/flashinfer/sampling.cuh (8)

2102-2141: LGTM - Clean abstraction for Stage 1 loading.

The LoadToSharedOrdered function provides a clean, reusable abstraction for vectorized loading and ordered representation conversion. The tail handling is correct.
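The "ordered representation" for floating-point radix selection is commonly the standard monotone float-to-unsigned-int mapping; a scalar Python sketch of that trick is below. Whether LoadToSharedOrdered uses exactly this encoding is an assumption here:

```python
import struct


def float_to_ordered_u32(x: float) -> int:
    """Map an IEEE-754 float32 to a uint32 whose unsigned order matches the
    float order: flip all bits for negatives, flip only the sign bit for
    non-negatives. This is the standard key-twiddling trick for radix
    sort/select on floats; it is illustrative of what an 'ordered
    representation' typically means, not a transcription of the kernel."""
    (u,) = struct.unpack("<I", struct.pack("<f", x))
    return (u ^ 0xFFFFFFFF) if (u & 0x80000000) else (u | 0x80000000)
```

With this mapping, radix passes over the unsigned keys select the same elements as comparisons on the original floats.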


2169-2315: LGTM - Well-structured radix select with proper synchronization.

The RadixSelectFromSharedMemory function correctly implements multi-round radix selection with proper barrier synchronization for multi-CTA mode. The use of macro aliases for shared scalar indices improves readability, and the #undef cleanup is correct.


2396-2477: LGTM - Correct two-pass index collection with tie handling.

The RadixCollectIndices function correctly implements the two-pass approach: first collecting elements strictly greater than pivot, then filling with equal elements until k is reached. The barrier synchronization for multi-CTA mode is properly placed.
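The two-pass semantics described here can be sketched in scalar Python (illustrative of the tie-handling idea only, not the parallel CUDA implementation):

```python
def collect_topk_indices(values, pivot, k):
    """Scalar sketch of the two-pass collection: pass 1 takes indices of
    values strictly greater than the pivot; pass 2 fills the remaining
    slots with indices of values equal to the pivot, so ties at the pivot
    are truncated deterministically at k."""
    out = [i for i, v in enumerate(values) if v > pivot]  # pass 1
    for i, v in enumerate(values):  # pass 2: fill with pivot-equal elements
        if len(out) >= k:
            break
        if v == pivot:
            out.append(i)
    return out[:k]
```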


2484-2488: LGTM - Clear enum for epilogue mode selection.

The RadixTopKMode enum clearly defines the three epilogue variants, enabling compile-time specialization in the unified kernel.


3131-3193: LGTM - Host entry point follows established pattern.

The RadixTopKRenormProbMultiCTA function correctly mirrors the structure of other multi-CTA entry points with proper device property queries, shared memory sizing, and kernel dispatch.


3214-3286: LGTM - Page table transform entry point is well-structured.

The RadixTopKPageTableTransformMultiCTA function correctly sets up the unified kernel with RadixTopKMode::PageTableTransform, properly handles optional row_to_batch mapping, and follows the established multi-CTA dispatch pattern.


3304-3377: LGTM - Ragged transform entry point correctly implemented.

The RadixTopKRaggedTransformMultiCTA function properly sets up the unified kernel with RadixTopKMode::RaggedTransform and follows the same pattern as the page table variant.


2505-2660: Trivial case handling is consistent and safe across all modes.

The unified kernel correctly handles trivial cases for each mode:

  • Basic mode (k >= length): The loop is bounded by chunk_end = min(chunk_start + chunk_size, length), ensuring no read beyond valid input even though k >= length. The condition chunk_start + i < k is checked before writing, though it's redundant since the loop bounds already prevent exceeding length.
  • PageTableTransform (length <= top_k_val): Properly guards access with i < length before reading from src_page_entry.
  • RaggedTransform (length <= top_k_val): Properly guards access with i < length before computing the offset index.

All three modes allocate output arrays as [num_rows, top_k_val], matching the write indices in their respective trivial case paths. No out-of-bounds access occurs in any mode.

flashinfer/__init__.py (1)

146-147: LGTM - New API exports follow established pattern.

The new top_k_page_table_transform and top_k_ragged_transform functions are correctly exported from the .topk module, maintaining consistency with the existing top_k export on line 145.

csrc/flashinfer_topk_binding.cu (1)

23-39: LGTM - FFI bindings correctly declare and export new transforms.

The function declarations match the implementations in csrc/topk.cu, and the TVM FFI export macros follow the established pattern for radix_topk.

tests/utils/test_topk.py (7)

234-263: LGTM - Reference implementation is correct.

The reference_page_table_transform function correctly implements the expected behavior: trivial case handling for length <= k and proper torch.topk-based selection otherwise. The -1 padding for remaining positions is handled via the initial torch.full allocation.
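The described per-row semantics can be summarized in plain Python (a hedged stand-in for the torch.topk-based reference; the function and argument names here are illustrative):

```python
def ref_page_table_transform(scores, src_page_table, length, k):
    """Per-row reference sketch: if length <= k, copy the first `length`
    page entries and pad the rest with -1 (the trivial case); otherwise
    gather page-table entries at the top-k score positions. A sort stands
    in for torch.topk."""
    out = [-1] * k
    if length <= k:
        for i in range(length):
            out[i] = src_page_table[i]
    else:
        topk_idx = sorted(range(length), key=lambda i: scores[i], reverse=True)[:k]
        for j, i in enumerate(topk_idx):
            out[j] = src_page_table[i]
    return out
```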


265-293: LGTM - Ragged reference implementation is correct.

The reference_ragged_transform function properly handles the offset addition for both trivial and non-trivial cases.
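The ragged variant differs only in producing global indices, i.e. the row offset plus the selected positions; a minimal sketch (names illustrative):

```python
def ref_ragged_transform(scores, offset, length, k):
    """Per-row reference sketch for the ragged variant: outputs are
    offset + selected position, -1 padded, covering both the trivial
    (length <= k) and top-k cases."""
    out = [-1] * k
    if length <= k:
        for i in range(length):
            out[i] = offset + i
    else:
        topk_idx = sorted(range(length), key=lambda i: scores[i], reverse=True)[:k]
        for j, i in enumerate(topk_idx):
            out[j] = offset + i
    return out
```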


317-396: LGTM - Parametrized tests provide good coverage.

The test_top_k_page_table_transform and test_top_k_ragged_transform tests cover multiple configurations (num_rows, max_len, k, dtype) with appropriate accuracy thresholds.


402-473: LGTM - Trivial case tests verify padding behavior.

The trivial case tests properly verify both the copied values and the -1 padding behavior when length <= k.


709-760: LGTM - Correctness tests validate actual output values.

The test_page_table_transform_correctness_exact and test_ragged_transform_offset_correctness tests provide strong validation by checking that output values are valid (exist in page table or within offset range).


766-854: LGTM - SGLang-style references enable compatibility testing.

The SGLang-style reference implementations provide valuable cross-validation with an external implementation style, including prefill mode with cu_seqlens_q mapping.


1082-1114: LGTM - Main block provides quick smoke tests.

The if __name__ == "__main__" block provides a convenient way to run a subset of tests and benchmarks locally without pytest.

csrc/topk.cu (2)

62-107: LGTM - Page table transform entry point is well-implemented.

The radix_topk_page_table_transform function correctly:

  • Validates all required inputs with appropriate dimension checks
  • Extracts src_stride for potentially non-contiguous page tables
  • Handles optional row_to_batch and row_states_buffer parameters
  • Uses the type dispatch macro consistently
  • Includes proper error checking

109-146: LGTM - Ragged transform entry point follows established pattern.

The radix_topk_ragged_transform function correctly validates inputs, handles the optional row_states_buffer, and dispatches to the kernel with proper error checking.

benchmarks/bench_topk.py (2)

121-164: LGTM!

The bench_ragged_transform function is well-structured with proper parameter usage and consistent SGLang comparison logic.


167-318: LGTM with note on performance summary.

The main function provides comprehensive benchmark coverage with good error handling. The performance summary (lines 300-314) makes specific claims that should be verified by users on their hardware, as actual speedups may vary.

flashinfer/topk.py (4)

64-130: LGTM!

The custom op registrations for radix_topk_page_table_transform and radix_topk_ragged_transform follow the established pattern with proper dtype validation, mutation declarations, and fake op implementations for torch.compile compatibility.


132-136: LGTM!

The module namespace correctly exports the new fused operations, and the aliases provide API flexibility for users who prefer either naming convention.

Also applies to: 406-408


241-327: Well-designed API with comprehensive documentation and thorough test coverage.

The top_k_page_table_transform function provides a clean interface with excellent documentation and examples. The implementation correctly allocates buffers and delegates to the kernel. The edge case behavior described in the docstring (lines 286-287: "If lengths[i] <= k, the output simply contains src_page_table[batch_idx, 0:lengths[i]]") is thoroughly tested in test_page_table_transform_trivial_case (tests/utils/test_topk.py, lines 402-435), which explicitly validates this scenario with lengths uniformly sampled in the range [1, k].


330-403: Well-designed API with comprehensive documentation and test coverage.

The top_k_ragged_transform function is properly implemented with thorough documentation. The edge case behavior described in the docstring ("If lengths[i] <= k, the output contains [offsets[i], offsets[i]+1, ..., offsets[i]+lengths[i]-1]") is explicitly covered by test_ragged_transform_trivial_case, which validates both the valid entries and -1 padding for lengths ≤ k. Additional tests including variable lengths, large-scale scenarios, and correctness checks provide robust coverage.

Comment on lines +27 to +33
def bench_top_k(
batch_size: int,
seq_len: int,
k: int,
dtype: torch.dtype = torch.float32,
compare_sglang: bool = False,
) -> dict:
Contributor

⚠️ Potential issue | 🟡 Minor

Remove unused compare_sglang parameter.

The compare_sglang parameter is declared but never used in this function, unlike bench_page_table_transform and bench_ragged_transform which implement SGLang comparisons. This creates API inconsistency and may confuse users.

Apply this diff to remove the unused parameter:

 def bench_top_k(
     batch_size: int,
     seq_len: int,
     k: int,
     dtype: torch.dtype = torch.float32,
-    compare_sglang: bool = False,
 ) -> dict:
🧰 Tools
🪛 Ruff (0.14.8)

32-32: Unused function argument: compare_sglang

(ARG001)

🤖 Prompt for AI Agents
In benchmarks/bench_topk.py around lines 27 to 33, the function bench_top_k
declares an unused parameter compare_sglang which creates an API inconsistency;
remove the compare_sglang parameter from the function signature and any
corresponding callers/tests (update imports/usage sites) so the function only
accepts batch_size, seq_len, k, and dtype, and update any docstrings/type hints
to reflect the removed parameter.

Comment on lines +2686 to +2689

  // Fixed shared memory (at the beginning)
- constexpr size_t fixed_smem_size =
-     sizeof(uint32_t) * (RADIX + RADIX + 4);  // histogram + suffix + 4 scalars
+ // histogram[256] + suffix[256] + 5 scalars (for RadixSelectFromSharedMemory)
+ constexpr size_t fixed_smem_size = sizeof(uint32_t) * (RADIX + RADIX + 5);
Contributor

⚠️ Potential issue | 🟠 Major

Verify shared_scalars allocation size consistency.

The fixed shared memory size calculation uses num_scalars = SINGLE_CTA ? 5 : 4 (line 2530), but the comment at line 2688-2689 mentions "5 scalars" for RadixSelectFromSharedMemory. Ensure the scalar count is consistent:

  • RadixSelectFromSharedMemory uses indices 0-4 (5 scalars)
  • For multi-CTA mode, the kernel sets num_scalars = 4, but the callee needs 5

This could cause a buffer overflow in multi-CTA mode.

The num_scalars calculation at line 2530 sets it to 4 for multi-CTA mode, but RadixSelectFromSharedMemory uses shared_scalars[4] for shared_output_counter in single-CTA mode. Verify that multi-CTA mode doesn't access shared_scalars[4]:

- constexpr size_t num_scalars = SINGLE_CTA ? 5 : 4;
+ constexpr size_t num_scalars = 5;  // Always allocate 5 for consistency

Or confirm that RadixSelectFromSharedMemory only uses shared_scalars[4] when SINGLE_CTA is true.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In include/flashinfer/sampling.cuh around lines 2686-2689 (and refer to
num_scalars at line 2530), the fixed shared-memory size comment and allocation
assume 5 scalar slots but num_scalars is computed as SINGLE_CTA ? 5 : 4, which
can under-allocate in multi-CTA mode while RadixSelectFromSharedMemory expects
index 0-4; fix by making the shared memory sizing consistent: either ensure
RadixSelectFromSharedMemory only accesses shared_scalars[4] when SINGLE_CTA is
true (add appropriate #if/SINGLE_CTA guards around that access), or change the
fixed_smem_size calculation to allocate for the maximum scalar count (use max
num_scalars = 5) and document it, and update any runtime num_scalars assignment
to match the allocation so no buffer overflow can occur.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
include/flashinfer/sampling.cuh (1)

2533-2542: Fix the out-of-bounds pointer formation in multi-CTA mode.

The allocation uses num_scalars = SINGLE_CTA ? 5 : 4;, but shared_output_counter is defined as shared_scalars[4] unconditionally. In multi-CTA mode, this forms a pointer to index 4 when only indices 0-3 are allocated.

While shared_output_counter is only dereferenced within if constexpr (SINGLE_CTA) blocks (lines 2447, 2468), forming the out-of-bounds pointer &shared_scalars[4] in multi-CTA mode violates C++ memory safety rules and could trigger warnings from strict sanitizers like AddressSanitizer. This inconsistency also conflicts with other kernels in the same file (e.g., RadixTopKMaskLogitsKernel_MultiCTA always allocates 5 scalars).

Consider one of these approaches:

  1. Always allocate 5 scalars: constexpr size_t num_scalars = 5;
  2. Use conditional pointer: Pass SINGLE_CTA ? &shared_scalars[4] : nullptr to functions
  3. Add a compile-time assertion to document the constraint
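The allocation mismatch can be illustrated with a byte-offset model of the layout described above (RADIX and the layout come from the review text; the function is a sketch, not kernel code):

```python
RADIX = 256
U32 = 4  # sizeof(uint32_t)

def scalar_offset(idx, num_scalars):
    """Byte offset of shared_scalars[idx] after histogram + suffix sum;
    raises if idx falls outside the scalar allocation."""
    base = U32 * (RADIX + RADIX)
    if idx >= num_scalars:
        raise IndexError(f"shared_scalars[{idx}] outside {num_scalars}-scalar allocation")
    return base + U32 * idx

# Single-CTA mode (5 scalars): index 4 is in bounds.
assert scalar_offset(4, num_scalars=5) == U32 * (RADIX + RADIX + 4)

# Multi-CTA mode with num_scalars = 4: forming shared_scalars[4] is out of bounds,
# which is exactly the inconsistency flagged above.
try:
    scalar_offset(4, num_scalars=4)
except IndexError:
    pass
```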
🧹 Nitpick comments (2)
include/flashinfer/sampling.cuh (2)

2721-2725: Consider aligning scalar allocation with RadixTopKKernel_Unified.

This kernel always allocates 5 scalars regardless of the SINGLE_CTA template parameter, whereas RadixTopKKernel_Unified (line 2533) conditionally allocates 4 or 5.

In multi-CTA mode, the 5th scalar is unused, wasting 4 bytes per block. While the memory overhead is minor, aligning the allocation strategy across kernels would improve consistency and maintainability.

Apply this diff if consistency is desired:

-  // Fixed shared memory (at the beginning)
-  // histogram[256] + suffix[256] + 5 scalars (for RadixSelectFromSharedMemory)
-  constexpr size_t fixed_smem_size = sizeof(uint32_t) * (RADIX + RADIX + 5);
+  // Fixed shared memory (at the beginning)
+  // histogram[256] + suffix[256] + scalars (5 for single-CTA, 4 for multi-CTA)
+  constexpr size_t num_scalars = SINGLE_CTA ? 5 : 4;
+  constexpr size_t fixed_smem_size = sizeof(uint32_t) * (RADIX + RADIX + num_scalars);

2930-2934: Fragile shared memory layout relies on shared_sum for overflow space.

The allocation provides 4 scalars (shared_scalars[0-3]) followed by shared_sum, but RadixSelectFromSharedMemory (called via RadixSelectFindPivot) expects shared_scalars[0-4]. In SINGLE_CTA mode, accessing shared_scalars[4] overlaps with the shared_sum memory.

This is currently safe because:

  1. shared_output_counter (shared_scalars[4]) is only accessed in SINGLE_CTA mode
  2. The memory at that location (shared_sum) provides valid backing storage
  3. No actual reads/writes cause type confusion

However, this implicit dependency makes the code fragile. If RadixSelectFromSharedMemory changes to read shared_output_counter in multi-CTA mode, or if the memory layout shifts, silent bugs could occur.

Consider making the layout explicit:

-  constexpr size_t fixed_smem_size = sizeof(uint32_t) * (RADIX + RADIX + 4) + sizeof(float);
+  constexpr size_t num_scalars = 5;  // RadixSelectFromSharedMemory needs indices 0-4
+  constexpr size_t fixed_smem_size = sizeof(uint32_t) * (RADIX + RADIX + num_scalars) + sizeof(float);
   uint32_t* local_histogram = reinterpret_cast<uint32_t*>(smem);
   uint32_t* suffix_sum = local_histogram + RADIX;
   uint32_t* shared_scalars = suffix_sum + RADIX;
-  float* shared_sum = reinterpret_cast<float*>(shared_scalars + 4);
+  float* shared_sum = reinterpret_cast<float*>(shared_scalars + num_scalars);

This documents that 5 scalars are intentionally allocated, even if only 4 are used in multi-CTA mode, and clarifies the memory layout.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 30d01e3 and 76d3e27.

📒 Files selected for processing (1)
  • include/flashinfer/sampling.cuh (10 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/sampling.cuh
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
benchmarks/bench_topk.py (2)

37-73: Remove unused compare_sglang parameter.

This parameter is unused in bench_top_k, unlike the other benchmark functions which implement SGLang comparisons. This creates API inconsistency.


98-100: Remove commented-out code.

These lines appear to be leftover from development and should be removed.

🧹 Nitpick comments (4)
tests/utils/test_topk.py (3)

243-243: Remove discarded expression.

scores.size(1) is called but its return value is not used. This appears to be dead code, possibly a leftover from debugging or an incomplete validation.

     num_rows = scores.size(0)
-    scores.size(1)
     device = scores.device

296-314: Unused parameter k in function signature.

The k parameter is declared but never used. The function computes accuracy based on ref_set size from the actual data rather than the expected k. This is actually correct behavior (handles padding properly), but the parameter should be removed to avoid confusion.

-def compute_transform_accuracy(test_output, ref_output, num_rows, k):
+def compute_transform_accuracy(test_output, ref_output, num_rows):
     """Compute accuracy for transform outputs, handling -1 padding correctly."""

Note: You'll need to update all call sites (lines 353, 394, 518, 555, 592, 627, 662, 685, 705, 891, 928, 977) to remove the k argument.
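A plain-Python sketch of the metric shows why `k` is unneeded: the per-row denominator comes from the reference's valid-entry set, which already accounts for -1 padding (the real helper operates on torch tensors; this version is illustrative):

```python
def compute_transform_accuracy(test_output, ref_output):
    """Average per-row overlap between the non-padding (-1) entry sets."""
    total = 0.0
    for test_row, ref_row in zip(test_output, ref_output):
        ref_set = {v for v in ref_row if v != -1}
        test_set = {v for v in test_row if v != -1}
        total += len(test_set & ref_set) / max(len(ref_set), 1)
    return total / len(ref_output)

assert compute_transform_accuracy([[1, 2, -1]], [[2, 1, -1]]) == 1.0  # order-insensitive
assert compute_transform_accuracy([[1, 3, -1]], [[1, 2, -1]]) == 0.5  # one of two matched
```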


785-785: Remove discarded expression.

Same issue as line 243 - scores.size(1) result is unused.

     num_rows = scores.size(0)
-    scores.size(1)
     device = scores.device
benchmarks/bench_topk.py (1)

20-25: Environment variable manipulation may have side effects.

The set_topk_algo function modifies os.environ which persists across benchmark runs. While this is reset to "auto" after algorithm comparison (line 262), if an exception occurs between setting the algorithm and the reset, the environment will be left in an inconsistent state.

Consider using a context manager pattern for safer environment variable handling:

from contextlib import contextmanager

@contextmanager
def topk_algo(algo: str):
    """Context manager to temporarily set topk algorithm."""
    old_value = os.environ.get("FLASHINFER_TOPK_ALGO")
    try:
        if algo == "auto":
            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
        else:
            os.environ["FLASHINFER_TOPK_ALGO"] = algo
        yield
    finally:
        if old_value is None:
            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
        else:
            os.environ["FLASHINFER_TOPK_ALGO"] = old_value
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 76d3e27 and 8b04191.

📒 Files selected for processing (3)
  • benchmarks/bench_topk.py (1 hunks)
  • csrc/topk.cu (1 hunks)
  • tests/utils/test_topk.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/utils/test_topk.py (1)
flashinfer/topk.py (2)
  • top_k_page_table_transform (241-327)
  • top_k_ragged_transform (330-403)
benchmarks/bench_topk.py (2)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (998-1059)
flashinfer/topk.py (3)
  • top_k (139-234)
  • top_k_page_table_transform (241-327)
  • top_k_ragged_transform (330-403)
🪛 Ruff (0.14.8)
tests/utils/test_topk.py

296-296: Unused function argument: k

(ARG001)

benchmarks/bench_topk.py

42-42: Unused function argument: compare_sglang

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (8)
csrc/topk.cu (2)

62-108: LGTM! Well-structured page table transform entry point.

The function follows the established pattern from radix_topk, with proper input validation, device/stream setup, optional buffer handling, and error checking. The src_stride extraction for non-contiguous page tables is a good detail.


110-148: LGTM! Ragged transform entry point is consistent with existing code.

The implementation mirrors the page table transform with appropriate adjustments for the ragged use case (offsets instead of page table, no row_to_batch mapping). Error handling and type dispatch are consistent.

tests/utils/test_topk.py (3)

317-396: Good test coverage for the new transform operations.

The parameterized tests cover a solid matrix of configurations (num_rows, max_len, k, dtype) for both page table and ragged transforms. The 0.95 accuracy threshold appropriately accounts for tie-breaking differences in radix-based vs heap-based selection.


399-474: Trivial case tests validate edge behavior correctly.

Tests properly verify that when length <= k, the output contains the expected sequential entries (page table values or offset-based indices) with -1 padding for remaining positions.


709-761: Exact correctness tests provide strong guarantees.

These tests use unique page table values and verify that all outputs exist in the correct range, providing stronger correctness guarantees beyond accuracy metrics.

benchmarks/bench_topk.py (3)

76-128: Page table transform benchmark implementation looks good.

The function properly sets up input tensors with contiguous page tables and handles the optional SGLang comparison with appropriate constraints (k=2048, float32).


131-174: Ragged transform benchmark is consistent with page table benchmark.

The implementation follows the same pattern with appropriate adjustments for ragged layout (offsets instead of page table).


191-277: Algorithm comparison mode is useful for performance analysis.

The comparison between Multi-CTA and Filtered algorithms provides valuable insights for tuning the automatic algorithm selection heuristics.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8b04191 and 32a55a4.

📒 Files selected for processing (1)
  • include/flashinfer/sampling.cuh (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/sampling.cuh
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
include/flashinfer/sampling.cuh (1)

36-36: Good refactoring: Top-K logic externalized.

The inclusion of topk.cuh appropriately replaces the inline multi-CTA Radix Top-K implementation that was removed from this file. This separation of concerns improves modularity and enables reuse of Top-K functionality.

Comment on lines +24 to +25
#include <cstdlib>
#include <cstring>
Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Description: Check if cstdlib or cstring functions are used in sampling.cuh

# Common cstdlib functions
echo "=== Checking for cstdlib usage ==="
rg -nP '\b(malloc|calloc|realloc|free|atoi|atof|strtol|strtod|abs|rand|srand|exit|getenv)\s*\(' include/flashinfer/sampling.cuh

# Common cstring functions  
echo "=== Checking for cstring usage ==="
rg -nP '\b(memcpy|memset|memmove|memcmp|strlen|strcpy|strncpy|strcmp|strncmp|strcat|strchr|strstr)\s*\(' include/flashinfer/sampling.cuh

# Check if they're needed for other headers or forward declarations
echo "=== Checking topk.cuh dependencies ==="
rg -nP '\b(malloc|calloc|realloc|free|memcpy|memset|strlen|strcmp)\s*\(' include/flashinfer/topk.cuh

# Also check the includes at lines 24-25 to confirm they're there
echo "=== Current includes at lines 24-25 ==="
sed -n '24,25p' include/flashinfer/sampling.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 418


🏁 Script executed:

#!/bin/bash
# Check what includes topk.cuh has
echo "=== Includes in topk.cuh (first 50 lines) ==="
head -50 include/flashinfer/topk.cuh | grep -E "^#include"

echo ""
echo "=== Full include section of topk.cuh ==="
sed -n '1,/^[^#]/p' include/flashinfer/topk.cuh | grep "^#include"

Repository: flashinfer-ai/flashinfer

Length of output: 285


🏁 Script executed:

#!/bin/bash
# Check the include section of sampling.cuh to see include order
echo "=== First 50 lines of sampling.cuh ==="
head -50 include/flashinfer/sampling.cuh | cat -n

echo ""
echo "=== Search for topk.cuh include ==="
rg -n "topk.cuh" include/flashinfer/sampling.cuh

echo ""
echo "=== Check if cstring is included via other headers in sampling.cuh ==="
head -50 include/flashinfer/sampling.cuh | grep "#include"

Repository: flashinfer-ai/flashinfer

Length of output: 2405


Move <cstring> include to topk.cuh where it's actually used.

The includes <cstdlib> and <cstring> at lines 24-25 are not used directly in sampling.cuh. However, topk.cuh (included at line 36) uses std::strcmp at lines 2469-2470, which requires <cstring>. Currently, topk.cuh doesn't include <cstring> itself, creating an improper transitive dependency where sampling.cuh provides it. Move <cstring> to topk.cuh and remove both unused includes from sampling.cuh.

🤖 Prompt for AI Agents
In include/flashinfer/sampling.cuh around lines 24-25, remove the unused
includes <cstdlib> and <cstring>; then open include/flashinfer/topk.cuh and add
#include <cstring> near the other standard headers so topk.cuh directly provides
std::strcmp, eliminating the transitive dependency on sampling.cuh.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (3)
benchmarks/bench_topk.py (3)

37-69: Unused compare_sglang parameter.

The compare_sglang parameter is declared but never used, unlike the other benchmark functions. This was flagged in a previous review.


94-96: Remove commented-out code.

This was flagged in a previous review. The commented-out benchmark alternative should be removed.


267-381: Code duplication in benchmark loops.

The three benchmark loops (top_k, page_table, ragged) share similar structure for iteration, OOM handling, and result printing. This was flagged in a previous review as a refactoring opportunity.

🧹 Nitpick comments (1)
benchmarks/bench_topk.py (1)

165-176: Consider adding error handling for invalid dtype strings.

Invalid dtype strings will raise a KeyError. A more descriptive error message would improve usability.

 def parse_dtype(dtype_str: str) -> torch.dtype:
     """Parse dtype string to torch.dtype."""
     dtype_map = {
         "fp32": torch.float32,
         "float32": torch.float32,
         "fp16": torch.float16,
         "float16": torch.float16,
         "half": torch.float16,
         "bf16": torch.bfloat16,
         "bfloat16": torch.bfloat16,
     }
-    return dtype_map[dtype_str.lower()]
+    key = dtype_str.lower()
+    if key not in dtype_map:
+        raise ValueError(f"Unknown dtype '{dtype_str}'. Valid options: {list(dtype_map.keys())}")
+    return dtype_map[key]
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 32a55a4 and 65ec004.

📒 Files selected for processing (1)
  • benchmarks/bench_topk.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_topk.py (2)
flashinfer/testing/utils.py (1)
  • bench_gpu_time_with_cupti (646-878)
flashinfer/topk.py (3)
  • top_k (139-234)
  • top_k_page_table_transform (241-327)
  • top_k_ragged_transform (330-403)
🪛 Ruff (0.14.8)
benchmarks/bench_topk.py

42-42: Unused function argument: compare_sglang

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (6)
benchmarks/bench_topk.py (6)

1-18: LGTM!

Clear module docstring and clean imports. All imported modules are used appropriately.


20-26: LGTM!

Simple and correct environment variable management for algorithm selection.


28-34: LGTM!

Standard pattern for optional dependency handling with a clear flag.


123-162: LGTM!

Correct construction of offsets for ragged indexing and consistent SGLang comparison logic.


179-265: LGTM!

Well-structured CLI with clear argument descriptions and algorithm comparison mode. The OOM handling correctly distinguishes memory errors from other runtime exceptions.


384-385: LGTM!

Standard entry point pattern.

Comment on lines +110 to +114
cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda")
measurements = bench_gpu_time_with_cupti(
lambda: sgl_kernel.fast_topk_transform_fused(
scores, lengths, src_page_table, cu_seqlens_q, k
),
Contributor

⚠️ Potential issue | 🔴 Critical

Fix cu_seqlens_q construction to use actual sequence lengths.

The current construction torch.arange(0, batch_size + 1) produces sequential indices, but cumulative sequence lengths should reflect actual token positions. Compute it as torch.cumsum(lengths, dim=0, dtype=torch.int32) with a 0 prepended, or use torch.cat([torch.tensor([0], dtype=torch.int32, device='cuda'), torch.cumsum(lengths, dim=0, dtype=torch.int32)]).

🤖 Prompt for AI Agents
In benchmarks/bench_topk.py around lines 110 to 114, cu_seqlens_q is incorrectly
built using torch.arange(0, batch_size + 1) which creates sequential indices
instead of cumulative token positions; replace it with the cumulative sum of the
per-sequence lengths with a leading zero, e.g. construct cu_seqlens_q by
prepending 0 to torch.cumsum(lengths, dim=0, dtype=torch.int32) and ensure the
tensor is on device='cuda' (or use torch.cat([torch.tensor([0],
dtype=torch.int32, device='cuda'), torch.cumsum(lengths, dim=0,
dtype=torch.int32, device='cuda')])) so cu_seqlens_q correctly represents
cumulative sequence lengths.

@yzh119 yzh119 changed the title feat: add fused top-k page construction kernels for DSA feat: further optimize top-k and add fused top-k page construction kernels for DSA Dec 16, 2025
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 88a412e and 659ad09.

📒 Files selected for processing (1)
  • csrc/topk.cu (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/topk.cu (2)
flashinfer/topk.py (1)
  • top_k (139-234)
csrc/tvm_ffi_utils.h (1)
  • get_stream (294-296)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
csrc/topk.cu (2)

49-60: LGTM! Clean refactor to unified dispatch.

The switch from direct RadixTopKMultiCTA to TopKDispatch enables algorithm selection via heuristics, which aligns with the PR objectives. The error message update is appropriate.


50-50: No issue found: bf16 is supported by the dispatch macro.

The DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16 macro includes _DISPATCH_CASE_BF16 at line 105 of csrc/tvm_ffi_utils.h, which handles the bfloat16 data type by mapping it to the nv_bfloat16 C type. The macro's comment explicitly documents support for "FP32/FP16/BF16 data types." While the macro name is misleading (omitting BF16 from the identifier), the implementation correctly supports all three types as claimed in the PR objectives.

Comment on lines +62 to +108
void radix_topk_page_table_transform(TensorView input, TensorView output_page_table,
TensorView src_page_table,
Optional<TensorView> maybe_row_to_batch, TensorView lengths,
Optional<TensorView> maybe_row_states_buffer, int64_t top_k) {
CHECK_INPUT(input);
CHECK_INPUT(output_page_table);
CHECK_INPUT(src_page_table);
CHECK_INPUT(lengths);
CHECK_DIM(2, input); // input: (num_rows, max_len)
CHECK_DIM(2, output_page_table); // output_page_table: (num_rows, top_k)
CHECK_DIM(2, src_page_table); // src_page_table: (batch_size, max_len)
CHECK_DIM(1, lengths); // lengths: (num_rows,)

unsigned int num_rows = input.size(0);
unsigned int max_len = input.size(1);
int64_t src_stride = src_page_table.stride(0);

cudaSetDevice(input.device().device_id);
auto stream = get_stream(input.device());

cudaError_t status;
auto dtype = input.dtype();

sampling::RadixRowState* row_states_ptr = nullptr;
if (maybe_row_states_buffer.has_value()) {
row_states_ptr =
static_cast<sampling::RadixRowState*>(maybe_row_states_buffer.value().data_ptr());
}

int32_t* row_to_batch_ptr = nullptr;
if (maybe_row_to_batch.has_value()) {
row_to_batch_ptr = static_cast<int32_t*>(maybe_row_to_batch.value().data_ptr());
}

// Use unified dispatch with heuristics to choose between FilteredTopK and RadixTopK
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
status = sampling::TopKPageTableTransformDispatch<c_type, int32_t>(
static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_page_table.data_ptr()),
static_cast<const int32_t*>(src_page_table.data_ptr()), src_stride, row_to_batch_ptr,
static_cast<int32_t*>(lengths.data_ptr()), num_rows, static_cast<uint32_t>(top_k), max_len,
row_states_ptr, stream);
return true;
});

TVM_FFI_ICHECK(status == cudaSuccess)
<< "TopKPageTableTransform failed with error code " << cudaGetErrorString(status);
}
Contributor

🛠️ Refactor suggestion | 🟠 Major

Consider adding size validations for early error detection.

The function validates tensor ranks but not the actual dimension sizes. Given the complex parameter relationships in this page-table transform, adding explicit size checks would improve debuggability and user experience.

Consider adding these validations after line 73:

```diff
 CHECK_DIM(1, lengths);            // lengths: (num_rows,)
+
+unsigned int num_rows = input.size(0);
+unsigned int max_len = input.size(1);
+
+TVM_FFI_ICHECK(output_page_table.size(0) == num_rows)
+    << "output_page_table batch size mismatch: expected " << num_rows
+    << ", got " << output_page_table.size(0);
+TVM_FFI_ICHECK(output_page_table.size(1) == top_k)
+    << "output_page_table second dimension mismatch: expected " << top_k
+    << ", got " << output_page_table.size(1);
+TVM_FFI_ICHECK(lengths.size(0) == num_rows)
+    << "lengths size mismatch: expected " << num_rows
+    << ", got " << lengths.size(0);
+TVM_FFI_ICHECK(src_page_table.size(1) == max_len)
+    << "src_page_table second dimension mismatch: expected " << max_len
+    << ", got " << src_page_table.size(1);
-
-unsigned int num_rows = input.size(0);
-unsigned int max_len = input.size(1);
```
🤖 Prompt for AI Agents
In csrc/topk.cu around lines 62 to 108 (add checks after line 73), add explicit
size validations: ensure output_page_table.size(0) == input.size(0) and
lengths.size(0) == input.size(0) (num_rows), ensure input.size(1) ==
src_page_table.size(1) (max_len), ensure output_page_table.size(1) >=
static_cast<size_t>(top_k) (or == if contract requires exact top_k), if
maybe_row_to_batch is present check its numel == num_rows and that all entries
are within [0, src_page_table.size(0)-1], if maybe_row_states_buffer is present
validate its size is sufficient for num_rows (per RadixRowState element size),
and check that all values in lengths are <= max_len; on any mismatch return a
clear error via TVM_FFI_ICHECK or equivalent with descriptive message.

Comment on lines +110 to 148
```cpp
void radix_topk_ragged_transform(TensorView input, TensorView output_indices, TensorView offsets,
                                 TensorView lengths, Optional<TensorView> maybe_row_states_buffer,
                                 int64_t top_k) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_indices);
  CHECK_INPUT(offsets);
  CHECK_INPUT(lengths);
  CHECK_DIM(2, input);           // input: (num_rows, max_len)
  CHECK_DIM(2, output_indices);  // output_indices: (num_rows, top_k)
  CHECK_DIM(1, offsets);         // offsets: (num_rows,)
  CHECK_DIM(1, lengths);         // lengths: (num_rows,)

  unsigned int num_rows = input.size(0);
  unsigned int max_len = input.size(1);

  cudaSetDevice(input.device().device_id);
  auto stream = get_stream(input.device());

  cudaError_t status;
  auto dtype = input.dtype();

  sampling::RadixRowState* row_states_ptr = nullptr;
  if (maybe_row_states_buffer.has_value()) {
    row_states_ptr =
        static_cast<sampling::RadixRowState*>(maybe_row_states_buffer.value().data_ptr());
  }

  // Use unified dispatch with heuristics to choose between FilteredTopK and RadixTopK
  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
    status = sampling::TopKRaggedTransformDispatch<c_type, int32_t>(
        static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_indices.data_ptr()),
        static_cast<const int32_t*>(offsets.data_ptr()), static_cast<int32_t*>(lengths.data_ptr()),
        num_rows, static_cast<uint32_t>(top_k), max_len, row_states_ptr, stream);
    return true;
  });

  TVM_FFI_ICHECK(status == cudaSuccess)
      << "TopKRaggedTransform failed with error code " << cudaGetErrorString(status);
}
```
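The ragged variant's semantics can likewise be sketched in NumPy: per-row top-k positions shifted by the row's start offset into the ragged buffer (the "ragged offset addition" mode). As above, the name and the `-1` padding convention are illustrative assumptions, not the library's contract.

```python
import numpy as np

def topk_ragged_reference(scores, offsets, lengths, k):
    """Per row: indices of the k highest scores (within the valid length),
    shifted by the row's start offset into the ragged buffer.

    scores:  (num_rows, max_len) float array
    offsets: (num_rows,) start offset of each row in the ragged buffer
    lengths: (num_rows,) valid length per row
    Rows shorter than k are padded with -1 (an illustrative convention).
    """
    num_rows = scores.shape[0]
    out = np.full((num_rows, k), -1, dtype=np.int32)
    for r in range(num_rows):
        n = int(lengths[r])
        kk = min(k, n)
        if kk > 0:
            # local top-k positions within the row, then add the ragged offset
            local = np.argpartition(-scores[r, :n], kk - 1)[:kk]
            out[r, :kk] = local + int(offsets[r])
    return out
```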
🛠️ Refactor suggestion | 🟠 Major

Consider adding size validations for early error detection.

Similar to radix_topk_page_table_transform, this function would benefit from explicit size checks to catch parameter mismatches early with clear error messages.

Consider adding these validations after line 120:

```diff
 CHECK_DIM(1, lengths);         // lengths: (num_rows,)
+
+unsigned int num_rows = input.size(0);
+unsigned int max_len = input.size(1);
+
+TVM_FFI_ICHECK(output_indices.size(0) == num_rows)
+    << "output_indices batch size mismatch: expected " << num_rows
+    << ", got " << output_indices.size(0);
+TVM_FFI_ICHECK(output_indices.size(1) == top_k)
+    << "output_indices second dimension mismatch: expected " << top_k
+    << ", got " << output_indices.size(1);
+TVM_FFI_ICHECK(offsets.size(0) == num_rows)
+    << "offsets size mismatch: expected " << num_rows
+    << ", got " << offsets.size(0);
+TVM_FFI_ICHECK(lengths.size(0) == num_rows)
+    << "lengths size mismatch: expected " << num_rows
+    << ", got " << lengths.size(0);
-
-unsigned int num_rows = input.size(0);
-unsigned int max_len = input.size(1);
```
🤖 Prompt for AI Agents
In csrc/topk.cu around lines 120-126, add explicit size validations after
computing num_rows and max_len: verify offsets.size(0) == num_rows and
lengths.size(0) == num_rows, verify output_indices.size(0) == num_rows and
output_indices.size(1) == static_cast<size_t>(top_k), check that top_k > 0 and
top_k <= static_cast<int64_t>(max_len), and if
maybe_row_states_buffer.has_value() validate its byte size/element count matches
expected sampling::RadixRowState array length for num_rows; on failure, emit
clear TVM_FFI_ICHECK messages identifying the mismatched parameter.

@yzh119
Collaborator Author

yzh119 commented Dec 17, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !198 has been created, and the CI pipeline #40387618 is currently running. I'll report back once the pipeline job completes.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between aef0d0a and 2efb96b.

📒 Files selected for processing (3)
  • docs/api/sampling.rst (1 hunks)
  • docs/api/topk.rst (1 hunks)
  • docs/index.rst (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • docs/api/topk.rst
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
docs/index.rst (1)

39-39: The api/topk toctree entry is correct and the referenced documentation file is properly formatted.

The verification confirms that docs/api/topk.rst exists and contains complete, properly formatted documentation for the top-k functions (top_k, top_k_page_table_transform, top_k_ragged_transform, and can_implement_filtered_topk). The toctree entry placement is logical and follows the existing naming convention.

Comment on lines +8 to +12
.. seealso::

For efficient Top-K selection (without sampling), see :ref:`apitopk` which provides
:func:`~flashinfer.top_k`, :func:`~flashinfer.top_k_page_table_transform`, and
:func:`~flashinfer.top_k_ragged_transform`.
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

```bash
#!/bin/bash
# Description: Verify that the referenced functions are exposed in the public API and documentation exists

# Check 1: Verify the three top-k functions are exported in flashinfer/__init__.py
echo "=== Checking if top-k functions are exported in __init__.py ==="
rg -n "top_k_page_table_transform|top_k_ragged_transform|^top_k" flashinfer/__init__.py | head -20

# Check 2: Verify docs/api/topk.rst exists and contains the functions
echo -e "\n=== Checking if docs/api/topk.rst exists and documents these functions ==="
if [ -f "docs/api/topk.rst" ]; then
  echo "✓ docs/api/topk.rst exists"
  rg -n "top_k_page_table_transform|top_k_ragged_transform|.. _apitopk" docs/api/topk.rst
else
  echo "✗ docs/api/topk.rst not found"
fi

# Check 3: Verify the :ref:`apitopk` reference label is defined in docs/api/topk.rst
echo -e "\n=== Checking for the apitopk reference label ==="
rg -n ".. _apitopk" docs/api/topk.rst 2>/dev/null || echo "Not found or file doesn't exist"

# Check 4: Verify flashinfer/topk.py implements these functions
echo -e "\n=== Checking if topk.py exists and implements these functions ==="
if [ -f "flashinfer/topk.py" ]; then
  echo "✓ flashinfer/topk.py exists"
  rg -n "^def.*top_k" flashinfer/topk.py | head -10
else
  echo "✗ flashinfer/topk.py not found"
fi

# Check 5: Look for the functions anywhere in the codebase
echo -e "\n=== Searching for top_k function definitions anywhere ==="
rg -n "^def (top_k|top_k_page_table_transform|top_k_ragged_transform)" --type py
```

Repository: flashinfer-ai/flashinfer

Length of output: 1192


Remove reference to top_k from the seealso section.

The seealso section references :func:`~flashinfer.top_k`, but this function is not included in the apitopk documentation. Only :func:`~flashinfer.top_k_page_table_transform` and :func:`~flashinfer.top_k_ragged_transform` are documented on that page.

🤖 Prompt for AI Agents
In docs/api/sampling.rst around lines 8 to 12, the seealso currently references
:func:`~flashinfer.top_k` which does not exist in the apitopk page; remove the
:func:`~flashinfer.top_k` entry so the seealso lists only
:func:`~flashinfer.top_k_page_table_transform` and
:func:`~flashinfer.top_k_ragged_transform`, keeping the surrounding text and
Sphinx cross-reference syntax intact.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #40387618: 7/20 passed

Member

@yongwww yongwww left a comment


Thanks for the great work on the Top-K optimizations and the new fused kernels for DSA. The performance gains are impressive.

Comment on lines +407 to +408
```python
topk_page_table_transform = top_k_page_table_transform
topk_ragged_transform = top_k_ragged_transform
```

seems the aliases can be removed here

@aleozlx
Collaborator

aleozlx commented Dec 18, 2025

everything else lgtm

Collaborator

@bkryu bkryu left a comment


LGTM except for tagging the functions with @flashinfer_api for logging purposes. Can we get these added?

The achieved speedups are impressive 👀

@yzh119 yzh119 enabled auto-merge (squash) December 18, 2025 22:24

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
flashinfer/topk.py (2)

258-344: Consider adding validation for k parameter.

The implementation is well-documented and follows existing patterns correctly. However, adding a defensive check for k > 0 would catch invalid inputs early with a clear error message rather than relying on kernel-side validation.

🔎 Optional validation check
```diff
 def top_k_page_table_transform(
     input: torch.Tensor,
     src_page_table: torch.Tensor,
     lengths: torch.Tensor,
     k: int,
     row_to_batch: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     r"""Fused Top-K selection + Page Table Transform for sparse attention.
+    if k <= 0:
+        raise ValueError(f"k must be positive, got {k}")
     device = input.device
     num_rows = input.size(0)
```

348-421: Consider adding validation for k parameter.

The implementation is well-documented and correctly structured. Similar to top_k_page_table_transform, adding a defensive check for k > 0 would improve early error detection.

🔎 Optional validation check
```diff
 def top_k_ragged_transform(
     input: torch.Tensor,
     offsets: torch.Tensor,
     lengths: torch.Tensor,
     k: int,
 ) -> torch.Tensor:
     r"""Fused Top-K selection + Ragged Index Transform for sparse attention.
+    if k <= 0:
+        raise ValueError(f"k must be positive, got {k}")
     device = input.device
     num_rows = input.size(0)
```
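The k > 0 guard could also be folded into a small standalone validator. This is a hypothetical sketch: the helper name and the exact set of checks are illustrative, drawn from the validations proposed in this review rather than from FlashInfer's code.

```python
def validate_topk_args(num_rows, max_len, k, out_shape):
    """Fail fast with descriptive errors instead of deferring to the kernel."""
    if k <= 0:
        raise ValueError(f"k must be positive, got {k}")
    if k > max_len:
        # whether k may exceed max_len is an assumption; the kernel may
        # instead treat short rows as a trivial (padded) case
        raise ValueError(f"k ({k}) cannot exceed max_len ({max_len})")
    if out_shape != (num_rows, k):
        raise ValueError(
            f"output shape mismatch: expected {(num_rows, k)}, got {out_shape}"
        )
```

Calling such a helper at the top of the Python wrappers would surface bad arguments as Python exceptions with clear messages rather than kernel-side failures.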
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2efb96b and e7ea229.

📒 Files selected for processing (1)
  • flashinfer/topk.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/topk.py (4)
flashinfer/api_logging.py (1)
  • flashinfer_api (464-565)
flashinfer/utils.py (4)
  • register_custom_op (319-328)
  • register_custom_op (338-357)
  • register_fake_op (330-334)
  • register_fake_op (359-364)
csrc/flashinfer_topk_binding.cu (4)
  • radix_topk_page_table_transform (23-26)
  • radix_topk_ragged_transform (28-30)
  • radix_topk (20-21)
  • can_implement_filtered_topk (32-32)
csrc/topk.cu (8)
  • radix_topk_page_table_transform (63-109)
  • radix_topk_page_table_transform (63-66)
  • radix_topk_ragged_transform (111-149)
  • radix_topk_ragged_transform (111-113)
  • radix_topk (25-61)
  • radix_topk (25-26)
  • can_implement_filtered_topk (151-151)
  • can_implement_filtered_topk (151-151)
🔇 Additional comments (4)
flashinfer/topk.py (4)

65-131: LGTM: Custom ops correctly implement fused Top-K transforms.

The custom ops and their fake counterparts follow the correct pattern:

  • Proper dtype validation (fp32/fp16/bf16)
  • Correct mutates_args specification
  • In-place mutation design matches PyTorch custom op best practices
  • Fake ops appropriately stubbed for shape inference

141-152: LGTM: Clean capability query with clear documentation.

The helper function provides a clean interface to check GPU support for FilteredTopK, with good documentation about the 128KB shared memory requirement.


257-258: Good: Decorator added per previous review feedback.

The @flashinfer_api decorator has been added as requested by aleozlx in previous reviews.


347-348: Good: Decorator added per previous review feedback.

The @flashinfer_api decorator has been added as requested by bkryu in previous reviews.

@yzh119 yzh119 disabled auto-merge December 19, 2025 05:45
@yzh119 yzh119 merged commit 3a301a1 into flashinfer-ai:main Dec 19, 2025
4 checks passed
vincentzed added a commit to bzhng-development/sglang that referenced this pull request Dec 31, 2025
Signed-off-by: vincentzed
<207368749+vincentzed@users.noreply.github.com>
vincentzed added a commit to bzhng-development/sglang that referenced this pull request Jan 16, 2026
Signed-off-by: vincentzed
<207368749+vincentzed@users.noreply.github.com>
vincentzed added a commit to bzhng-development/sglang that referenced this pull request Jan 16, 2026
Signed-off-by: vincentzed
<207368749+vincentzed@users.noreply.github.com>
@coderabbitai coderabbitai bot mentioned this pull request Mar 1, 2026
5 tasks
hammersam added a commit to hammersam/sglang that referenced this pull request Mar 8, 2026
Port the multi-CTA radix-based top-k kernel from flashinfer PR sgl-project#2215
(flashinfer-ai/flashinfer#2215) into sglang as
a JIT-compiled kernel. This replaces the existing AOT single-CTA top-k
implementation for NSA attention, providing better performance on long
sequences (32K+) where the multi-CTA path activates.

Key changes:

- Add `python/sglang/jit_kernel/topk.py`: Python API exposing three
  JIT top-k variants (basic, page-table transform, ragged transform)
  with workspace management and lazy compilation via `cache_once`.

- Add `python/sglang/jit_kernel/csrc/elementwise/topk.cuh`: CUDA wrapper
  providing TVM FFI entry points that dispatch to the flashinfer adaptive
  top-k kernels (TopKDispatch, TopKPageTableTransformDispatch,
  TopKRaggedTransformDispatch).

- Add `python/sglang/jit_kernel/include/sgl_kernel/topk_fi.cuh`: Core
  CUDA implementation adapted from flashinfer, featuring:
  - 8-bit radix selection algorithm with multi-CTA support for large
    sequences (threshold configurable, default 32K)
  - Support for float32, float16, and bfloat16 input types
  - row_starts parameter for ragged input score layouts (sglang-specific)
  - Three output modes: indices-only, page-table lookup, and ragged
    offset addition

- Update `python/sglang/srt/layers/attention/nsa_backend.py`: Switch
  NSA indexer to import from JIT kernel instead of AOT sgl_kernel.

- Update `sgl-kernel/python/sgl_kernel/top_k.py`: Add JIT fallback path
  controlled by SGLANG_USE_JIT_TOPK env var (default enabled). When JIT
  is available, fast_topk_v2 / fast_topk_transform_fused /
  fast_topk_transform_ragged_fused transparently delegate to JIT kernels.

- Add `sgl-kernel/tests/test_topk_jit.py`: Correctness tests covering
  basic, page-table, ragged, and trivial (length <= topk) cases across
  various batch sizes and sequence lengths up to 131K.

- Add `sgl-kernel/benchmarks/bench_topk_jit.py`: Latency benchmark
  comparing JIT multi-CTA vs AOT single-CTA kernels.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
hammersam added a commit to hammersam/sglang that referenced this pull request Mar 8, 2026
Port the multi-CTA radix-based top-k kernel from
flashinfer-ai/flashinfer#2215 into sglang as
a JIT-compiled kernel. This replaces the existing AOT single-CTA top-k
implementation for NSA attention, providing better performance on long
sequences (32K+) where the multi-CTA path activates.

Key changes:

- Add `python/sglang/jit_kernel/topk.py`: Python API exposing three
  JIT top-k variants (basic, page-table transform, ragged transform)
  with workspace management and lazy compilation via `cache_once`.

- Add `python/sglang/jit_kernel/csrc/elementwise/topk.cuh`: CUDA wrapper
  providing TVM FFI entry points that dispatch to the flashinfer adaptive
  top-k kernels (TopKDispatch, TopKPageTableTransformDispatch,
  TopKRaggedTransformDispatch).

- Add `python/sglang/jit_kernel/include/sgl_kernel/topk_fi.cuh`: Core
  CUDA implementation adapted from flashinfer, featuring:
  - 8-bit radix selection algorithm with multi-CTA support for large
    sequences (threshold configurable, default 32K)
  - Support for float32, float16, and bfloat16 input types
  - row_starts parameter for ragged input score layouts (sglang-specific)
  - Three output modes: indices-only, page-table lookup, and ragged
    offset addition

- Update `python/sglang/srt/layers/attention/nsa_backend.py`: Switch
  NSA indexer to import from JIT kernel instead of AOT sgl_kernel.

- Update `sgl-kernel/python/sgl_kernel/top_k.py`: Add JIT fallback path
  controlled by SGLANG_USE_JIT_TOPK env var (default enabled). When JIT
  is available, fast_topk_v2 / fast_topk_transform_fused /
  fast_topk_transform_ragged_fused transparently delegate to JIT kernels.

- Add `sgl-kernel/tests/test_topk_jit.py`: Correctness tests covering
  basic, page-table, ragged, and trivial (length <= topk) cases across
  various batch sizes and sequence lengths up to 131K. (TODO(yifan): 1M)

- Add `sgl-kernel/benchmarks/bench_topk_jit.py`: Latency benchmark
  comparing JIT multi-CTA vs AOT single-CTA kernels.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
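The 8-bit radix selection this commit message describes can be illustrated with a simplified, single-threaded NumPy sketch. This is not the kernel code: the multi-CTA partitioning, shared-memory caching, and fp16/bf16 paths are omitted, and the sketch assumes 0 < k <= len(x).

```python
import numpy as np

def float_to_ordered_u32(x):
    """Map float32 bit patterns to uint32 keys whose unsigned order matches
    the floats' numeric order (positives: set sign bit; negatives: flip all bits)."""
    b = x.view(np.uint32)
    return np.where((b & 0x80000000) == 0, b | 0x80000000, ~b)

def radix_topk_indices(x, k):
    """Indices of the k largest values via MSB-first 8-bit radix selection."""
    keys = float_to_ordered_u32(np.ascontiguousarray(x, dtype=np.float32))
    idx = np.arange(keys.size)
    out = []
    for shift in (24, 16, 8, 0):
        digits = ((keys >> shift) & 0xFF).astype(np.int64)
        hist = np.bincount(digits, minlength=256)
        # scan buckets from the largest digit down to locate the bucket
        # containing the k-th largest element
        running = 0
        for d in range(255, -1, -1):
            if running + hist[d] >= k:
                break
            running += hist[d]
        out.append(idx[digits > d])  # strictly larger buckets are all winners
        k -= running                 # still need this many from bucket d
        keep = digits == d
        keys, idx = keys[keep], idx[keep]
    out.append(idx[:k])  # full 32-bit keys tie here; take any k of them
    return np.concatenate(out)
```

Roughly speaking, the GPU version parallelizes the same digit-histogram narrowing: per-thread histograms are reduced (across CTAs in the multi-CTA path) and the threshold bucket is refined one byte at a time.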
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026

5 participants