feat: further optimize top-k and add fused top-k page construction kernels for DSA #2215
yzh119 merged 16 commits into flashinfer-ai:main
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds fused Top-K transforms (page-table and ragged) across the Python, FFI, CUDA host, and device layers; removes the multi-CTA RadixTopK device machinery; introduces benchmarks and extensive tests for the new fused transforms; and exposes a capability inquiry for filtered TopK.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant PyAPI as Python API<br/>(flashinfer/topk.py)
    participant FFI as TVM FFI<br/>(csrc/flashinfer_topk_binding.cu)
    participant Host as CUDA Host<br/>(csrc/topk.cu)
    participant Kernel as Device Kernel<br/>(include/.../topk.cuh)
    User->>PyAPI: call top_k_page_table_transform(...) / top_k_ragged_transform(...)
    activate PyAPI
    PyAPI->>PyAPI: validate inputs<br/>allocate row_states & outputs
    PyAPI->>FFI: call radix_topk_*(tensors...)
    deactivate PyAPI
    activate FFI
    FFI->>Host: forward tensors / invoke host entrypoint
    deactivate FFI
    activate Host
    Host->>Host: validate & dispatch dtype<br/>prepare CUDA stream / optional buffers
    Host->>Kernel: launch device kernel (page_table / ragged)
    deactivate Host
    activate Kernel
    Kernel->>Kernel: perform selection & transform<br/>write outputs
    Kernel-->>Host: outputs ready
    deactivate Kernel
    Host-->>FFI: return status
    FFI-->>PyAPI: return to caller
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances FlashInfer's capabilities for sparse attention by introducing highly optimized, fused top-k kernels.

Highlights
Code Review
This pull request introduces two new fused Top-K operations, top_k_page_table_transform and top_k_ragged_transform, to FlashInfer, specifically designed for sparse attention. The changes include adding C++ CUDA kernel implementations for these operations, which utilize a unified multi-CTA radix select algorithm for efficiency, along with their Python bindings and high-level API functions. Comprehensive unit tests have been added to verify correctness across various scenarios, including trivial cases, variable lengths, row_to_batch mapping, and comparisons against SGLang-style reference implementations. A new benchmark script bench_topk.py was added to measure the performance of these new operations against torch.topk and optionally SGLang's sgl_kernel. Review comments suggest refactoring the main function in the benchmark script to reduce code duplication and removing redundant aliases in flashinfer/topk.py for cleaner code.
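To make the semantics of the fused page-table transform concrete, here is a pure-Python reference sketch of the behavior described above (per-row top-k selection on scores, a gather through the source page table, and -1 padding in the trivial case). This is an illustrative model of the operation's contract, not the CUDA implementation, and the exact argument order of the real API may differ:

```python
def reference_page_table_transform(scores, src_page_table, lengths, k):
    """Pure-Python model of the fused top-k + page-table gather.

    scores[i][j]:         score of position j in row i
    src_page_table[i][j]: page entry for position j in row i
    lengths[i]:           number of valid positions in row i
    Returns a [num_rows][k] list of page entries, padded with -1.
    """
    out = []
    for score_row, page_row, length in zip(scores, src_page_table, lengths):
        if length <= k:
            # Trivial case: copy the first `length` entries, pad with -1.
            picked = list(page_row[:length])
        else:
            # Indices of the k largest scores, then gather page entries.
            topk_idx = sorted(
                range(length), key=lambda j: score_row[j], reverse=True
            )[:k]
            picked = [page_row[j] for j in topk_idx]
        out.append(picked + [-1] * (k - len(picked)))
    return out
```

For example, with scores `[0.1, 0.9, 0.5, 0.3]`, page entries `[10, 11, 12, 13]`, and `k=2`, the output row is `[11, 12]`; with only two valid positions and `k=4`, the output is `[10, 11, -1, -1]`.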
benchmarks/bench_topk.py (excerpt under review):

```python
if args.op in ["all", "top_k"]:
    print("=" * 100)
    print("top_k: Basic radix-based top-k selection")
    print("=" * 100)
    print(
        f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12} {'torch.topk':>12} {'Speedup':>10}"
    )
    print("-" * 70)

    for batch_size in batch_sizes:
        for seq_len in seq_lens:
            for k in k_values:
                if k > seq_len:
                    continue
                try:
                    result = bench_top_k(batch_size, seq_len, k)
                    print(
                        f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | "
                        f"{result['flashinfer_us']:>10.2f}us {result['torch_us']:>10.2f}us "
                        f"{result['speedup_vs_torch']:>9.2f}x"
                    )
                except RuntimeError as e:
                    if "out of memory" in str(e):
                        print(f"{batch_size:>6} {seq_len:>10} {k:>6} | OOM")
                        torch.cuda.empty_cache()
                    else:
                        raise

if args.op in ["all", "page_table"]:
    print("\n" + "=" * 100)
    print("top_k_page_table_transform: Fused top-k + page table gather")
    if args.compare_sglang:
        print("NOTE: SGLang only supports k=2048")
    print("=" * 100)

    header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12}"
    if args.compare_sglang:
        header += f" {'SGLang':>12} {'Speedup':>10}"
    print(header)
    print("-" * (70 if not args.compare_sglang else 90))

    for batch_size in batch_sizes:
        for seq_len in seq_lens:
            for k in k_values:
                if k > seq_len:
                    continue
                try:
                    result = bench_page_table_transform(
                        batch_size, seq_len, k, compare_sglang=args.compare_sglang
                    )
                    line = (
                        f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | "
                        f"{result['flashinfer_us']:>10.2f}us"
                    )
                    if "sglang_us" in result:
                        line += f" {result['sglang_us']:>10.2f}us {result['speedup_vs_sglang']:>9.2f}x"
                    elif args.compare_sglang and k == 2048:
                        line += " (SGLang error)"
                    print(line)
                except RuntimeError as e:
                    if "out of memory" in str(e):
                        print(f"{batch_size:>6} {seq_len:>10} {k:>6} | OOM")
                        torch.cuda.empty_cache()
                    else:
                        raise

if args.op in ["all", "ragged"]:
    print("\n" + "=" * 100)
    print("top_k_ragged_transform: Fused top-k + ragged index transform")
    if args.compare_sglang:
        print("NOTE: SGLang only supports k=2048")
    print("=" * 100)

    header = f"{'batch':>6} {'seq_len':>10} {'k':>6} | {'FlashInfer':>12}"
    if args.compare_sglang:
        header += f" {'SGLang':>12} {'Speedup':>10}"
    print(header)
    print("-" * (70 if not args.compare_sglang else 90))

    for batch_size in batch_sizes:
        for seq_len in seq_lens:
            for k in k_values:
                if k > seq_len:
                    continue
                try:
                    result = bench_ragged_transform(
                        batch_size, seq_len, k, compare_sglang=args.compare_sglang
                    )
                    line = (
                        f"{result['batch_size']:>6} {result['seq_len']:>10} {result['k']:>6} | "
                        f"{result['flashinfer_us']:>10.2f}us"
                    )
                    if "sglang_us" in result:
                        line += f" {result['sglang_us']:>10.2f}us {result['speedup_vs_sglang']:>9.2f}x"
                    elif args.compare_sglang and k == 2048:
                        line += " (SGLang error)"
                    print(line)
                except RuntimeError as e:
                    if "out of memory" in str(e):
                        print(f"{batch_size:>6} {seq_len:>10} {k:>6} | OOM")
                        torch.cuda.empty_cache()
                    else:
                        raise
```
The main function contains significant code duplication for benchmarking each operation (top_k, page_table, ragged). The loop structure, OOM handling, and result printing are very similar across the three blocks. This could be refactored into a helper function to improve maintainability and reduce code size.
A helper function could take parameters like the operation name, the benchmark function to call, and configuration for printing headers and results. This would make the main function much cleaner and easier to extend with new benchmark operations in the future.
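A sketch of one possible shape for such a helper (the name `run_benchmark_grid` and its parameters are hypothetical; `bench_fn` stands in for the `bench_*` functions in the script, and the OOM handling mirrors the existing loop bodies):

```python
def run_benchmark_grid(title, bench_fn, batch_sizes, seq_lens, k_values,
                       fmt_result, on_oom=None):
    """Run bench_fn over a (batch, seq_len, k) grid; return report lines.

    bench_fn(batch_size, seq_len, k) -> result dict (may raise RuntimeError).
    fmt_result(result) -> one formatted report line.
    on_oom: optional cleanup callback (e.g. torch.cuda.empty_cache).
    """
    lines = ["=" * 100, title, "=" * 100]
    for batch_size in batch_sizes:
        for seq_len in seq_lens:
            for k in k_values:
                if k > seq_len:
                    continue  # skip infeasible configurations, as before
                try:
                    lines.append(fmt_result(bench_fn(batch_size, seq_len, k)))
                except RuntimeError as e:
                    if "out of memory" in str(e):
                        lines.append(f"{batch_size:>6} {seq_len:>10} {k:>6} | OOM")
                        if on_oom is not None:
                            on_oom()
                    else:
                        raise
    return lines
```

Each of the three benchmark blocks would then reduce to one call with its own title, bench function, and result formatter, and `print("\n".join(lines))` at the end.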
flashinfer/topk.py (Outdated)

```python
topk_page_table_transform = top_k_page_table_transform
topk_ragged_transform = top_k_ragged_transform
```
seems the aliases can be removed here
Actionable comments posted: 3
🧹 Nitpick comments (3)

tests/utils/test_topk.py (3)

296-314: Minor: unused parameter `k` in `compute_transform_accuracy`. The `k` parameter is unused, as the function derives valid entries from the `-1` mask. This is actually correct behavior, since some rows may have fewer than `k` valid entries. Consider removing the unused `k` parameter to avoid confusion:

```diff
-def compute_transform_accuracy(test_output, ref_output, num_rows, k):
+def compute_transform_accuracy(test_output, ref_output, num_rows):
```

And update the call sites accordingly, or add a docstring explaining why `k` is present but unused (for API consistency).

1056-1065: Loop variable capture issue in benchmark closures. The inner functions `run_page_table` and `run_ragged` capture loop variables (`src_page_table`, `lengths`, `offsets`) by reference. In Python, closures capture variables by reference, not by value. However, since these closures are executed within the same loop iteration, before the variables change, this is safe in practice. For defensive coding, you could bind the variables explicitly:

```python
def run_page_table(scores, k, _src_page_table=src_page_table, _lengths=lengths, **kw):
    return flashinfer.top_k_page_table_transform(
        scores, _src_page_table, _lengths, k
    )
```

However, since the closure is called immediately within the same iteration, the current code works correctly.

987-1013: Minor: unused `impl_name` parameter. The `impl_name` parameter is declared but never used in the function body. Either remove the unused parameter or use it for logging/output:

```diff
 def benchmark_topk_transform(
-    impl_name: str,
     impl_func,
     scores: torch.Tensor,
     k: int,
     warmup: int = 10,
     repeat: int = 100,
     **kwargs,
 ):
```
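The late-binding behavior behind the closure nitpick above is easy to demonstrate in isolation (a toy example, unrelated to the test file itself):

```python
# Closures capture variables, not values: every closure sees the loop
# variable's final value unless it is bound explicitly.
late = [lambda: i for i in range(3)]
bound = [lambda i=i: i for i in range(3)]  # default arg binds at definition time

assert [f() for f in late] == [2, 2, 2]
assert [f() for f in bound] == [0, 1, 2]
```

The default-argument idiom (`_lengths=lengths`) suggested in the review is the standard fix that Ruff's B023 rule points at.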
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- benchmarks/bench_topk.py (1 hunks)
- csrc/flashinfer_topk_binding.cu (1 hunks)
- csrc/topk.cu (1 hunks)
- flashinfer/__init__.py (1 hunks)
- flashinfer/topk.py (2 hunks)
- include/flashinfer/sampling.cuh (10 hunks)
- tests/utils/test_topk.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/sampling.cuh
🧬 Code graph analysis (5)
csrc/topk.cu (1)
- csrc/tvm_ffi_utils.h (1): get_stream (294-296)

csrc/flashinfer_topk_binding.cu (2)
- flashinfer/topk.py (4): top_k (139-234), radix_topk (34-52), radix_topk_page_table_transform (68-88), radix_topk_ragged_transform (106-119)
- csrc/topk.cu (6): radix_topk (24-60), radix_topk (24-25), radix_topk_page_table_transform (62-107), radix_topk_page_table_transform (62-65), radix_topk_ragged_transform (109-146), radix_topk_ragged_transform (109-111)

tests/utils/test_topk.py (1)
- flashinfer/topk.py (2): top_k_page_table_transform (241-327), top_k_ragged_transform (330-403)

benchmarks/bench_topk.py (2)
- flashinfer/testing/utils.py (1): bench_gpu_time (998-1059)
- flashinfer/topk.py (3): top_k (139-234), top_k_page_table_transform (241-327), top_k_ragged_transform (330-403)

flashinfer/__init__.py (1)
- flashinfer/topk.py (2): top_k_page_table_transform (241-327), top_k_ragged_transform (330-403)
🪛 Ruff (0.14.8)
tests/utils/test_topk.py
296-296: Unused function argument: k
(ARG001)
988-988: Unused function argument: impl_name
(ARG001)
1056-1056: Unused function argument: kw
(ARG001)
1058-1058: Function definition does not bind loop variable src_page_table
(B023)
1058-1058: Function definition does not bind loop variable lengths
(B023)
1064-1064: Unused function argument: kw
(ARG001)
1065-1065: Function definition does not bind loop variable offsets
(B023)
1065-1065: Function definition does not bind loop variable lengths
(B023)
benchmarks/bench_topk.py
32-32: Unused function argument: compare_sglang
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (25)
include/flashinfer/sampling.cuh (8)

2102-2141: LGTM - clean abstraction for Stage 1 loading. The `LoadToSharedOrdered` function provides a clean, reusable abstraction for vectorized loading and ordered-representation conversion. The tail handling is correct.

2169-2315: LGTM - well-structured radix select with proper synchronization. The `RadixSelectFromSharedMemory` function correctly implements multi-round radix selection with proper barrier synchronization for multi-CTA mode. The use of macro aliases for shared scalar indices improves readability, and the `#undef` cleanup is correct.

2396-2477: LGTM - correct two-pass index collection with tie handling. The `RadixCollectIndices` function correctly implements the two-pass approach: first collecting elements strictly greater than the pivot, then filling with equal elements until k is reached. The barrier synchronization for multi-CTA mode is properly placed.

2484-2488: LGTM - clear enum for epilogue mode selection. The `RadixTopKMode` enum clearly defines the three epilogue variants, enabling compile-time specialization in the unified kernel.

3131-3193: LGTM - host entry point follows the established pattern. The `RadixTopKRenormProbMultiCTA` function correctly mirrors the structure of other multi-CTA entry points, with proper device property queries, shared memory sizing, and kernel dispatch.

3214-3286: LGTM - page table transform entry point is well-structured. The `RadixTopKPageTableTransformMultiCTA` function correctly sets up the unified kernel with `RadixTopKMode::PageTableTransform`, properly handles the optional `row_to_batch` mapping, and follows the established multi-CTA dispatch pattern.

3304-3377: LGTM - ragged transform entry point correctly implemented. The `RadixTopKRaggedTransformMultiCTA` function properly sets up the unified kernel with `RadixTopKMode::RaggedTransform` and follows the same pattern as the page table variant.

2505-2660: Trivial-case handling is consistent and safe across all modes. The unified kernel correctly handles trivial cases for each mode:

- Basic mode (k >= length): the loop is bounded by `chunk_end = min(chunk_start + chunk_size, length)`, ensuring no read beyond valid input even though `k >= length`. The condition `chunk_start + i < k` is checked before writing, though it is redundant since the loop bounds already prevent exceeding length.
- PageTableTransform (length <= top_k_val): properly guards access with `i < length` before reading from `src_page_entry`.
- RaggedTransform (length <= top_k_val): properly guards access with `i < length` before computing the offset index.

All three modes allocate output arrays as `[num_rows, top_k_val]`, matching the write indices in their respective trivial-case paths. No out-of-bounds access occurs in any mode.

flashinfer/__init__.py (1)
146-147: LGTM - new API exports follow the established pattern. The new `top_k_page_table_transform` and `top_k_ragged_transform` functions are correctly exported from the `.topk` module, maintaining consistency with the existing `top_k` export on line 145.

csrc/flashinfer_topk_binding.cu (1)

23-39: LGTM - FFI bindings correctly declare and export the new transforms. The function declarations match the implementations in `csrc/topk.cu`, and the TVM FFI export macros follow the established pattern for `radix_topk`.

tests/utils/test_topk.py (7)
234-263: LGTM - the reference implementation is correct. The `reference_page_table_transform` function correctly implements the expected behavior: trivial-case handling for `length <= k`, and proper `torch.topk`-based selection otherwise. The -1 padding for remaining positions is handled via the initial `torch.full` allocation.

265-293: LGTM - the ragged reference implementation is correct. The `reference_ragged_transform` function properly handles the offset addition for both trivial and non-trivial cases.
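For readers without the test file at hand, the ragged reference being reviewed can be modeled in a few lines of pure Python (an illustrative sketch of its contract, not the torch-based test code: per-row top-k positions shifted by the row's offset, with -1 padding and in-order copy in the trivial case):

```python
def ragged_transform_model(scores, offsets, lengths, k):
    """Model of the fused top-k + ragged index transform: row i's output
    holds offsets[i] + j for each selected position j, padded with -1."""
    out = []
    for score_row, offset, length in zip(scores, offsets, lengths):
        if length <= k:
            # Trivial case: [offset, offset+1, ..., offset+length-1]
            picked = [offset + j for j in range(length)]
        else:
            topk_idx = sorted(
                range(length), key=lambda j: score_row[j], reverse=True
            )[:k]
            picked = [offset + j for j in topk_idx]
        out.append(picked + [-1] * (k - len(picked)))
    return out
```

For example, scores `[0.2, 0.8, 0.5]` with offset 100 and `k=2` select positions 1 and 2, giving `[101, 102]`.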
317-396: LGTM - parametrized tests provide good coverage. The `test_top_k_page_table_transform` and `test_top_k_ragged_transform` tests cover multiple configurations (num_rows, max_len, k, dtype) with appropriate accuracy thresholds.

402-473: LGTM - trivial-case tests verify padding behavior. The trivial-case tests properly verify both the copied values and the -1 padding behavior when `length <= k`.

709-760: LGTM - correctness tests validate actual output values. The `test_page_table_transform_correctness_exact` and `test_ragged_transform_offset_correctness` tests provide strong validation by checking that output values are valid (exist in the page table or within the offset range).

766-854: LGTM - SGLang-style references enable compatibility testing. The SGLang-style reference implementations provide valuable cross-validation with an external implementation style, including prefill mode with `cu_seqlens_q` mapping.

1082-1114: LGTM - the main block provides quick smoke tests. The `if __name__ == "__main__"` block provides a convenient way to run a subset of tests and benchmarks locally without pytest.

csrc/topk.cu (2)
62-107: LGTM - the page table transform entry point is well-implemented. The `radix_topk_page_table_transform` function correctly:

- validates all required inputs with appropriate dimension checks
- extracts `src_stride` for potentially non-contiguous page tables
- handles the optional `row_to_batch` and `row_states_buffer` parameters
- uses the type dispatch macro consistently
- includes proper error checking

109-146: LGTM - the ragged transform entry point follows the established pattern. The `radix_topk_ragged_transform` function correctly validates inputs, handles the optional `row_states_buffer`, and dispatches to the kernel with proper error checking.
benchmarks/bench_topk.py (2)

121-164: LGTM! The `bench_ragged_transform` function is well-structured, with proper parameter usage and consistent SGLang comparison logic.

167-318: LGTM, with a note on the performance summary. The main function provides comprehensive benchmark coverage with good error handling. The performance summary (lines 300-314) makes specific claims that should be verified by users on their hardware, as actual speedups may vary.

flashinfer/topk.py (4)
flashinfer/topk.py (4)
64-130: LGTM!The custom op registrations for
radix_topk_page_table_transformandradix_topk_ragged_transformfollow the established pattern with proper dtype validation, mutation declarations, and fake op implementations for torch.compile compatibility.
132-136: LGTM!The module namespace correctly exports the new fused operations, and the aliases provide API flexibility for users who prefer either naming convention.
Also applies to: 406-408
241-327: Well-designed API with comprehensive documentation and thorough test coverage.The
top_k_page_table_transformfunction provides a clean interface with excellent documentation and examples. The implementation correctly allocates buffers and delegates to the kernel. The edge case behavior described in the docstring (lines 286-287: "If lengths[i] <= k, the output simply contains src_page_table[batch_idx, 0:lengths[i]]") is thoroughly tested intest_page_table_transform_trivial_case(tests/utils/test_topk.py, lines 402-435), which explicitly validates this scenario with lengths uniformly sampled in the range [1, k].
330-403: Well-designed API with comprehensive documentation and test coverage.The
top_k_ragged_transformfunction is properly implemented with thorough documentation. The edge case behavior described in the docstring ("If lengths[i] <= k, the output contains [offsets[i], offsets[i]+1, ..., offsets[i]+lengths[i]-1]") is explicitly covered bytest_ragged_transform_trivial_case, which validates both the valid entries and -1 padding for lengths ≤ k. Additional tests including variable lengths, large-scale scenarios, and correctness checks provide robust coverage.
```python
def bench_top_k(
    batch_size: int,
    seq_len: int,
    k: int,
    dtype: torch.dtype = torch.float32,
    compare_sglang: bool = False,
) -> dict:
```
Remove unused compare_sglang parameter.
The compare_sglang parameter is declared but never used in this function, unlike bench_page_table_transform and bench_ragged_transform which implement SGLang comparisons. This creates API inconsistency and may confuse users.
Apply this diff to remove the unused parameter:
```diff
 def bench_top_k(
     batch_size: int,
     seq_len: int,
     k: int,
     dtype: torch.dtype = torch.float32,
-    compare_sglang: bool = False,
 ) -> dict:
```

🧰 Tools
🪛 Ruff (0.14.8)

32-32: Unused function argument: compare_sglang

(ARG001)
🤖 Prompt for AI Agents
In benchmarks/bench_topk.py around lines 27 to 33, the function bench_top_k
declares an unused parameter compare_sglang which creates an API inconsistency;
remove the compare_sglang parameter from the function signature and any
corresponding callers/tests (update imports/usage sites) so the function only
accepts batch_size, seq_len, k, and dtype, and update any docstrings/type hints
to reflect the removed parameter.
include/flashinfer/sampling.cuh (Outdated)

```diff
-  // Fixed shared memory (at the beginning)
-  constexpr size_t fixed_smem_size =
-      sizeof(uint32_t) * (RADIX + RADIX + 4);  // histogram + suffix + 4 scalars
+  // histogram[256] + suffix[256] + 5 scalars (for RadixSelectFromSharedMemory)
+  constexpr size_t fixed_smem_size = sizeof(uint32_t) * (RADIX + RADIX + 5);
```
Verify shared_scalars allocation size consistency.

The fixed shared memory size calculation uses `num_scalars = SINGLE_CTA ? 5 : 4` (line 2530), but the comment at lines 2688-2689 mentions "5 scalars" for `RadixSelectFromSharedMemory`. Ensure the scalar count is consistent:

- `RadixSelectFromSharedMemory` uses indices 0-4 (5 scalars)
- For multi-CTA mode, the kernel sets `num_scalars = 4`, but the callee needs 5

This could cause a buffer overflow in multi-CTA mode.

The `num_scalars` calculation at line 2530 sets it to 4 for multi-CTA mode, but `RadixSelectFromSharedMemory` uses `shared_scalars[4]` for `shared_output_counter` in single-CTA mode. Verify that multi-CTA mode doesn't access `shared_scalars[4]`:

```diff
-  constexpr size_t num_scalars = SINGLE_CTA ? 5 : 4;
+  constexpr size_t num_scalars = 5;  // Always allocate 5 for consistency
```

Or confirm that `RadixSelectFromSharedMemory` only uses `shared_scalars[4]` when `SINGLE_CTA` is true.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In include/flashinfer/sampling.cuh around lines 2686-2689 (and refer to
num_scalars at line 2530), the fixed shared-memory size comment and allocation
assume 5 scalar slots but num_scalars is computed as SINGLE_CTA ? 5 : 4, which
can under-allocate in multi-CTA mode while RadixSelectFromSharedMemory expects
index 0-4; fix by making the shared memory sizing consistent: either ensure
RadixSelectFromSharedMemory only accesses shared_scalars[4] when SINGLE_CTA is
true (add appropriate #if/SINGLE_CTA guards around that access), or change the
fixed_smem_size calculation to allocate for the maximum scalar count (use max
num_scalars = 5) and document it, and update any runtime num_scalars assignment
to match the allocation so no buffer overflow can occur.
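As a back-of-the-envelope check of the sizes under discussion (assuming `RADIX = 256` and 4-byte `uint32_t`, as the kernel comments imply), the difference between the two allocation strategies is exactly one 4-byte slot per block:

```python
RADIX = 256          # radix histogram width, per the kernel comments
SIZEOF_UINT32 = 4    # bytes

def fixed_smem_size(num_scalars):
    # histogram[RADIX] + suffix[RADIX] + num_scalars scalars, all uint32_t
    return SIZEOF_UINT32 * (RADIX + RADIX + num_scalars)

# Multi-CTA mode as written (num_scalars = 4) allocates 4 bytes less than
# what an access to shared_scalars[4] would require (num_scalars = 5).
assert fixed_smem_size(5) - fixed_smem_size(4) == SIZEOF_UINT32
```

So always allocating 5 scalars costs 4 bytes of shared memory per block, which is why the review treats that option as essentially free.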
Actionable comments posted: 0
♻️ Duplicate comments (1)
include/flashinfer/sampling.cuh (1)

2533-2542: Fix the out-of-bounds pointer formation in multi-CTA mode. The allocation uses `num_scalars = SINGLE_CTA ? 5 : 4;`, but `shared_output_counter` is defined as `shared_scalars[4]` unconditionally. In multi-CTA mode, this forms a pointer to index 4 when only indices 0-3 are allocated.

While `shared_output_counter` is only dereferenced within `if constexpr (SINGLE_CTA)` blocks (lines 2447, 2468), forming the out-of-bounds pointer `&shared_scalars[4]` in multi-CTA mode violates C++ memory safety rules and could trigger warnings from strict sanitizers such as AddressSanitizer. This inconsistency also conflicts with other kernels in the same file (e.g., `RadixTopKMaskLogitsKernel_MultiCTA` always allocates 5 scalars).

Consider one of these approaches:

- Always allocate 5 scalars: `constexpr size_t num_scalars = 5;`
- Use a conditional pointer: pass `SINGLE_CTA ? &shared_scalars[4] : nullptr` to functions
- Add a compile-time assertion to document the constraint
🧹 Nitpick comments (2)

include/flashinfer/sampling.cuh (2)

2721-2725: Consider aligning the scalar allocation with RadixTopKKernel_Unified. This kernel always allocates 5 scalars regardless of the `SINGLE_CTA` template parameter, whereas `RadixTopKKernel_Unified` (line 2533) conditionally allocates 4 or 5. In multi-CTA mode, the 5th scalar is unused, wasting 4 bytes per block. While the memory overhead is minor, aligning the allocation strategy across kernels would improve consistency and maintainability. Apply this diff if consistency is desired:

```diff
-  // Fixed shared memory (at the beginning)
-  // histogram[256] + suffix[256] + 5 scalars (for RadixSelectFromSharedMemory)
-  constexpr size_t fixed_smem_size = sizeof(uint32_t) * (RADIX + RADIX + 5);
+  // Fixed shared memory (at the beginning)
+  // histogram[256] + suffix[256] + scalars (5 for single-CTA, 4 for multi-CTA)
+  constexpr size_t num_scalars = SINGLE_CTA ? 5 : 4;
+  constexpr size_t fixed_smem_size = sizeof(uint32_t) * (RADIX + RADIX + num_scalars);
```

2930-2934: Fragile shared-memory layout relies on shared_sum for overflow space. The allocation provides 4 scalars (`shared_scalars[0-3]`) followed by `shared_sum`, but `RadixSelectFromSharedMemory` (called via `RadixSelectFindPivot`) expects `shared_scalars[0-4]`. In SINGLE_CTA mode, accessing `shared_scalars[4]` overlaps with the `shared_sum` memory.

This is currently safe because:

- `shared_output_counter` (`shared_scalars[4]`) is only accessed in SINGLE_CTA mode
- the memory at that location (`shared_sum`) provides valid backing storage
- no actual reads/writes cause type confusion

However, this implicit dependency makes the code fragile. If `RadixSelectFromSharedMemory` changes to read `shared_output_counter` in multi-CTA mode, or if the memory layout shifts, silent bugs could occur. Consider making the layout explicit:

```diff
-  constexpr size_t fixed_smem_size = sizeof(uint32_t) * (RADIX + RADIX + 4) + sizeof(float);
+  constexpr size_t num_scalars = 5;  // RadixSelectFromSharedMemory needs indices 0-4
+  constexpr size_t fixed_smem_size = sizeof(uint32_t) * (RADIX + RADIX + num_scalars) + sizeof(float);
   uint32_t* local_histogram = reinterpret_cast<uint32_t*>(smem);
   uint32_t* suffix_sum = local_histogram + RADIX;
   uint32_t* shared_scalars = suffix_sum + RADIX;
-  float* shared_sum = reinterpret_cast<float*>(shared_scalars + 4);
+  float* shared_sum = reinterpret_cast<float*>(shared_scalars + num_scalars);
```

This documents that 5 scalars are intentionally allocated, even if only 4 are used in multi-CTA mode, and clarifies the memory layout.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- include/flashinfer/sampling.cuh (10 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/sampling.cuh
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
Actionable comments posted: 0
♻️ Duplicate comments (2)
benchmarks/bench_topk.py (2)
37-73: Remove the unused `compare_sglang` parameter. This parameter is unused in `bench_top_k`, unlike the other benchmark functions, which implement SGLang comparisons. This creates API inconsistency.
98-100: Remove commented-out code.These lines appear to be leftover from development and should be removed.
🧹 Nitpick comments (4)
tests/utils/test_topk.py (3)

243-243: Remove the discarded expression. `scores.size(1)` is called but its return value is not used. This appears to be dead code, possibly a leftover from debugging or an incomplete validation.

```diff
 num_rows = scores.size(0)
-scores.size(1)
 device = scores.device
```

296-314: Unused parameter `k` in the function signature. The `k` parameter is declared but never used. The function computes accuracy based on the `ref_set` size from the actual data rather than the expected `k`. This is actually correct behavior (it handles padding properly), but the parameter should be removed to avoid confusion:

```diff
-def compute_transform_accuracy(test_output, ref_output, num_rows, k):
+def compute_transform_accuracy(test_output, ref_output, num_rows):
     """Compute accuracy for transform outputs, handling -1 padding correctly."""
```

Note: you'll need to update all call sites (lines 353, 394, 518, 555, 592, 627, 662, 685, 705, 891, 928, 977) to remove the `k` argument.

785-785: Remove the discarded expression. Same issue as line 243: the `scores.size(1)` result is unused.

```diff
 num_rows = scores.size(0)
-scores.size(1)
 device = scores.device
```

benchmarks/bench_topk.py (1)

20-25: Environment variable manipulation may have side effects. The `set_topk_algo` function modifies `os.environ`, which persists across benchmark runs. While this is reset to "auto" after the algorithm comparison (line 262), if an exception occurs between setting the algorithm and the reset, the environment will be left in an inconsistent state. Consider using a context manager pattern for safer environment variable handling:

```python
from contextlib import contextmanager

@contextmanager
def topk_algo(algo: str):
    """Context manager to temporarily set the topk algorithm."""
    old_value = os.environ.get("FLASHINFER_TOPK_ALGO")
    try:
        if algo == "auto":
            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
        else:
            os.environ["FLASHINFER_TOPK_ALGO"] = algo
        yield
    finally:
        if old_value is None:
            os.environ.pop("FLASHINFER_TOPK_ALGO", None)
        else:
            os.environ["FLASHINFER_TOPK_ALGO"] = old_value
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- benchmarks/bench_topk.py (1 hunks)
- csrc/topk.cu (1 hunks)
- tests/utils/test_topk.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/utils/test_topk.py (1)
- flashinfer/topk.py (2): top_k_page_table_transform (241-327), top_k_ragged_transform (330-403)

benchmarks/bench_topk.py (2)
- flashinfer/testing/utils.py (1): bench_gpu_time (998-1059)
- flashinfer/topk.py (3): top_k (139-234), top_k_page_table_transform (241-327), top_k_ragged_transform (330-403)
🪛 Ruff (0.14.8)
tests/utils/test_topk.py
296-296: Unused function argument: k
(ARG001)
benchmarks/bench_topk.py
42-42: Unused function argument: compare_sglang
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (8)
csrc/topk.cu (2)
62-108: LGTM! Well-structured page table transform entry point.

The function follows the established pattern from `radix_topk`, with proper input validation, device/stream setup, optional buffer handling, and error checking. The `src_stride` extraction for non-contiguous page tables is a good detail.
110-148: LGTM! Ragged transform entry point is consistent with existing code.The implementation mirrors the page table transform with appropriate adjustments for the ragged use case (offsets instead of page table, no row_to_batch mapping). Error handling and type dispatch are consistent.
tests/utils/test_topk.py (3)
317-396: Good test coverage for the new transform operations.The parameterized tests cover a solid matrix of configurations (num_rows, max_len, k, dtype) for both page table and ragged transforms. The 0.95 accuracy threshold appropriately accounts for tie-breaking differences in radix-based vs heap-based selection.
399-474: Trivial case tests validate edge behavior correctly.

Tests properly verify that when `length <= k`, the output contains the expected sequential entries (page table values or offset-based indices) with `-1` padding for remaining positions.
709-761: Exact correctness tests provide strong guarantees.These tests use unique page table values and verify that all outputs exist in the correct range, providing stronger correctness guarantees beyond accuracy metrics.
benchmarks/bench_topk.py (3)
76-128: Page table transform benchmark implementation looks good.The function properly sets up input tensors with contiguous page tables and handles the optional SGLang comparison with appropriate constraints (k=2048, float32).
131-174: Ragged transform benchmark is consistent with page table benchmark.The implementation follows the same pattern with appropriate adjustments for ragged layout (offsets instead of page table).
191-277: Algorithm comparison mode is useful for performance analysis.The comparison between Multi-CTA and Filtered algorithms provides valuable insights for tuning the automatic algorithm selection heuristics.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
include/flashinfer/sampling.cuh(2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/sampling.cuh
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
include/flashinfer/sampling.cuh (1)
36-36: Good refactoring: Top-K logic externalized.

The inclusion of `topk.cuh` appropriately replaces the inline multi-CTA Radix Top-K implementation that was removed from this file. This separation of concerns improves modularity and enables reuse of Top-K functionality.
```cpp
#include <cstdlib>
#include <cstring>
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Description: Check if cstdlib or cstring functions are used in sampling.cuh
# Common cstdlib functions
echo "=== Checking for cstdlib usage ==="
rg -nP '\b(malloc|calloc|realloc|free|atoi|atof|strtol|strtod|abs|rand|srand|exit|getenv)\s*\(' include/flashinfer/sampling.cuh
# Common cstring functions
echo "=== Checking for cstring usage ==="
rg -nP '\b(memcpy|memset|memmove|memcmp|strlen|strcpy|strncpy|strcmp|strncmp|strcat|strchr|strstr)\s*\(' include/flashinfer/sampling.cuh
# Check if they're needed for other headers or forward declarations
echo "=== Checking topk.cuh dependencies ==="
rg -nP '\b(malloc|calloc|realloc|free|memcpy|memset|strlen|strcmp)\s*\(' include/flashinfer/topk.cuh
# Also check the includes at lines 24-25 to confirm they're there
echo "=== Current includes at lines 24-25 ==="
sed -n '24,25p' include/flashinfer/sampling.cuh

Repository: flashinfer-ai/flashinfer
Length of output: 418
🏁 Script executed:
#!/bin/bash
# Check what includes topk.cuh has
echo "=== Includes in topk.cuh (first 50 lines) ==="
head -50 include/flashinfer/topk.cuh | grep -E "^#include"
echo ""
echo "=== Full include section of topk.cuh ==="
sed -n '1,/^[^#]/p' include/flashinfer/topk.cuh | grep "^#include"

Repository: flashinfer-ai/flashinfer
Length of output: 285
🏁 Script executed:
#!/bin/bash
# Check the include section of sampling.cuh to see include order
echo "=== First 50 lines of sampling.cuh ==="
head -50 include/flashinfer/sampling.cuh | cat -n
echo ""
echo "=== Search for topk.cuh include ==="
rg -n "topk.cuh" include/flashinfer/sampling.cuh
echo ""
echo "=== Check if cstring is included via other headers in sampling.cuh ==="
head -50 include/flashinfer/sampling.cuh | grep "#include"

Repository: flashinfer-ai/flashinfer
Length of output: 2405
Move the `<cstring>` include to `topk.cuh`, where it is actually used.

The includes `<cstdlib>` and `<cstring>` at lines 24-25 are not used directly in `sampling.cuh`. However, `topk.cuh` (included at line 36) uses `std::strcmp` at lines 2469-2470, which requires `<cstring>`. Currently, `topk.cuh` doesn't include `<cstring>` itself, creating an improper transitive dependency where `sampling.cuh` provides it. Move `<cstring>` to `topk.cuh` and remove both unused includes from `sampling.cuh`.
🤖 Prompt for AI Agents
In include/flashinfer/sampling.cuh around lines 24-25, remove the unused
includes <cstdlib> and <cstring>; then open include/flashinfer/topk.cuh and add
#include <cstring> near the other standard headers so topk.cuh directly provides
std::strcmp, eliminating the transitive dependency on sampling.cuh.
Actionable comments posted: 1
♻️ Duplicate comments (3)
benchmarks/bench_topk.py (3)
37-69: Unused `compare_sglang` parameter.

The `compare_sglang` parameter is declared but never used, unlike the other benchmark functions. This was flagged in a previous review.
94-96: Remove commented-out code.

This was flagged in a previous review. The commented-out benchmark alternative should be removed.
267-381: Code duplication in benchmark loops.

The three benchmark loops (top_k, page_table, ragged) share similar structure for iteration, OOM handling, and result printing. This was flagged in a previous review as a refactoring opportunity.
🧹 Nitpick comments (1)
benchmarks/bench_topk.py (1)
165-176: Consider adding error handling for invalid dtype strings.

Invalid dtype strings will raise a `KeyError`. A more descriptive error message would improve usability.

```diff
 def parse_dtype(dtype_str: str) -> torch.dtype:
     """Parse dtype string to torch.dtype."""
     dtype_map = {
         "fp32": torch.float32,
         "float32": torch.float32,
         "fp16": torch.float16,
         "float16": torch.float16,
         "half": torch.float16,
         "bf16": torch.bfloat16,
         "bfloat16": torch.bfloat16,
     }
-    return dtype_map[dtype_str.lower()]
+    key = dtype_str.lower()
+    if key not in dtype_map:
+        raise ValueError(f"Unknown dtype '{dtype_str}'. Valid options: {list(dtype_map.keys())}")
+    return dtype_map[key]
```
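The behavior of the suggested lookup can be exercised without torch by mapping to plain strings as stand-ins for the real torch dtypes (a hedged sketch, not the benchmark's actual code):

```python
def parse_dtype(dtype_str: str) -> str:
    """Validated dtype lookup; strings stand in for torch dtypes here."""
    dtype_map = {
        "fp32": "float32", "float32": "float32",
        "fp16": "float16", "float16": "float16", "half": "float16",
        "bf16": "bfloat16", "bfloat16": "bfloat16",
    }
    key = dtype_str.lower()
    if key not in dtype_map:
        # A named ValueError beats a bare KeyError for CLI users.
        raise ValueError(
            f"Unknown dtype '{dtype_str}'. Valid options: {sorted(dtype_map)}"
        )
    return dtype_map[key]
```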
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
benchmarks/bench_topk.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_topk.py (2)
flashinfer/testing/utils.py (1)
- `bench_gpu_time_with_cupti` (646-878)

flashinfer/topk.py (3)
- `top_k` (139-234)
- `top_k_page_table_transform` (241-327)
- `top_k_ragged_transform` (330-403)
🪛 Ruff (0.14.8)
benchmarks/bench_topk.py
42-42: Unused function argument: compare_sglang
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (6)
benchmarks/bench_topk.py (6)
1-18: LGTM! Clear module docstring and clean imports. All imported modules are used appropriately.

20-26: LGTM! Simple and correct environment variable management for algorithm selection.

28-34: LGTM! Standard pattern for optional dependency handling with a clear flag.

123-162: LGTM! Correct construction of offsets for ragged indexing and consistent SGLang comparison logic.

179-265: LGTM! Well-structured CLI with clear argument descriptions and algorithm comparison mode. The OOM handling correctly distinguishes memory errors from other runtime exceptions.

384-385: LGTM! Standard entry point pattern.
```python
cu_seqlens_q = torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda")
measurements = bench_gpu_time_with_cupti(
    lambda: sgl_kernel.fast_topk_transform_fused(
        scores, lengths, src_page_table, cu_seqlens_q, k
    ),
```
Fix `cu_seqlens_q` construction to use actual sequence lengths.

The current construction `torch.arange(0, batch_size + 1)` produces sequential indices, but cumulative sequence lengths should reflect actual token positions. Compute it as `torch.cumsum(lengths, dim=0, dtype=torch.int32)` with a 0 prepended, or use `torch.cat([torch.tensor([0], dtype=torch.int32, device='cuda'), torch.cumsum(lengths, dim=0, dtype=torch.int32)])`.
🤖 Prompt for AI Agents
In benchmarks/bench_topk.py around lines 110 to 114, cu_seqlens_q is incorrectly
built using torch.arange(0, batch_size + 1) which creates sequential indices
instead of cumulative token positions; replace it with the cumulative sum of the
per-sequence lengths with a leading zero, e.g. construct cu_seqlens_q by
prepending 0 to torch.cumsum(lengths, dim=0, dtype=torch.int32) and ensure the
tensor is on device='cuda' (or use torch.cat([torch.tensor([0],
dtype=torch.int32, device='cuda'), torch.cumsum(lengths, dim=0,
dtype=torch.int32, device='cuda')])) so cu_seqlens_q correctly represents
cumulative sequence lengths.
Actionable comments posted: 2
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
csrc/topk.cu(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/topk.cu (2)
flashinfer/topk.py (1)
- `top_k` (139-234)

csrc/tvm_ffi_utils.h (1)
- `get_stream` (294-296)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
csrc/topk.cu (2)
49-60: LGTM! Clean refactor to unified dispatch.

The switch from direct `RadixTopKMultiCTA` to `TopKDispatch` enables algorithm selection via heuristics, which aligns with the PR objectives. The error message update is appropriate.
50-50: No issue found: bf16 is supported by the dispatch macro.

The `DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16` macro includes `_DISPATCH_CASE_BF16` at line 105 of `csrc/tvm_ffi_utils.h`, which handles the bfloat16 data type by mapping it to the `nv_bfloat16` C type. The macro's comment explicitly documents support for "FP32/FP16/BF16 data types." While the macro name is misleading (omitting BF16 from the identifier), the implementation correctly supports all three types as claimed in the PR objectives.
```cpp
void radix_topk_page_table_transform(TensorView input, TensorView output_page_table,
                                     TensorView src_page_table,
                                     Optional<TensorView> maybe_row_to_batch, TensorView lengths,
                                     Optional<TensorView> maybe_row_states_buffer, int64_t top_k) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_page_table);
  CHECK_INPUT(src_page_table);
  CHECK_INPUT(lengths);
  CHECK_DIM(2, input);              // input: (num_rows, max_len)
  CHECK_DIM(2, output_page_table);  // output_page_table: (num_rows, top_k)
  CHECK_DIM(2, src_page_table);     // src_page_table: (batch_size, max_len)
  CHECK_DIM(1, lengths);            // lengths: (num_rows,)

  unsigned int num_rows = input.size(0);
  unsigned int max_len = input.size(1);
  int64_t src_stride = src_page_table.stride(0);

  cudaSetDevice(input.device().device_id);
  auto stream = get_stream(input.device());

  cudaError_t status;
  auto dtype = input.dtype();

  sampling::RadixRowState* row_states_ptr = nullptr;
  if (maybe_row_states_buffer.has_value()) {
    row_states_ptr =
        static_cast<sampling::RadixRowState*>(maybe_row_states_buffer.value().data_ptr());
  }

  int32_t* row_to_batch_ptr = nullptr;
  if (maybe_row_to_batch.has_value()) {
    row_to_batch_ptr = static_cast<int32_t*>(maybe_row_to_batch.value().data_ptr());
  }

  // Use unified dispatch with heuristics to choose between FilteredTopK and RadixTopK
  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
    status = sampling::TopKPageTableTransformDispatch<c_type, int32_t>(
        static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_page_table.data_ptr()),
        static_cast<const int32_t*>(src_page_table.data_ptr()), src_stride, row_to_batch_ptr,
        static_cast<int32_t*>(lengths.data_ptr()), num_rows, static_cast<uint32_t>(top_k), max_len,
        row_states_ptr, stream);
    return true;
  });

  TVM_FFI_ICHECK(status == cudaSuccess)
      << "TopKPageTableTransform failed with error code " << cudaGetErrorString(status);
}
```
🛠️ Refactor suggestion | 🟠 Major
Consider adding size validations for early error detection.
The function validates tensor ranks but not the actual dimension sizes. Given the complex parameter relationships in this page-table transform, adding explicit size checks would improve debuggability and user experience.
Consider adding these validations after line 73:

```diff
 CHECK_DIM(1, lengths);  // lengths: (num_rows,)
+
+unsigned int num_rows = input.size(0);
+unsigned int max_len = input.size(1);
+
+TVM_FFI_ICHECK(output_page_table.size(0) == num_rows)
+    << "output_page_table batch size mismatch: expected " << num_rows
+    << ", got " << output_page_table.size(0);
+TVM_FFI_ICHECK(output_page_table.size(1) == top_k)
+    << "output_page_table second dimension mismatch: expected " << top_k
+    << ", got " << output_page_table.size(1);
+TVM_FFI_ICHECK(lengths.size(0) == num_rows)
+    << "lengths size mismatch: expected " << num_rows
+    << ", got " << lengths.size(0);
+TVM_FFI_ICHECK(src_page_table.size(1) == max_len)
+    << "src_page_table second dimension mismatch: expected " << max_len
+    << ", got " << src_page_table.size(1);
-
-unsigned int num_rows = input.size(0);
-unsigned int max_len = input.size(1);
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
void radix_topk_page_table_transform(TensorView input, TensorView output_page_table,
                                     TensorView src_page_table,
                                     Optional<TensorView> maybe_row_to_batch, TensorView lengths,
                                     Optional<TensorView> maybe_row_states_buffer, int64_t top_k) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_page_table);
  CHECK_INPUT(src_page_table);
  CHECK_INPUT(lengths);
  CHECK_DIM(2, input);              // input: (num_rows, max_len)
  CHECK_DIM(2, output_page_table);  // output_page_table: (num_rows, top_k)
  CHECK_DIM(2, src_page_table);     // src_page_table: (batch_size, max_len)
  CHECK_DIM(1, lengths);            // lengths: (num_rows,)
  unsigned int num_rows = input.size(0);
  unsigned int max_len = input.size(1);
  TVM_FFI_ICHECK(output_page_table.size(0) == num_rows)
      << "output_page_table batch size mismatch: expected " << num_rows << ", got "
      << output_page_table.size(0);
  TVM_FFI_ICHECK(output_page_table.size(1) == top_k)
      << "output_page_table second dimension mismatch: expected " << top_k << ", got "
      << output_page_table.size(1);
  TVM_FFI_ICHECK(lengths.size(0) == num_rows)
      << "lengths size mismatch: expected " << num_rows << ", got " << lengths.size(0);
  TVM_FFI_ICHECK(src_page_table.size(1) == max_len)
      << "src_page_table second dimension mismatch: expected " << max_len << ", got "
      << src_page_table.size(1);
  int64_t src_stride = src_page_table.stride(0);
  cudaSetDevice(input.device().device_id);
  auto stream = get_stream(input.device());
  cudaError_t status;
  auto dtype = input.dtype();
  sampling::RadixRowState* row_states_ptr = nullptr;
  if (maybe_row_states_buffer.has_value()) {
    row_states_ptr =
        static_cast<sampling::RadixRowState*>(maybe_row_states_buffer.value().data_ptr());
  }
  int32_t* row_to_batch_ptr = nullptr;
  if (maybe_row_to_batch.has_value()) {
    row_to_batch_ptr = static_cast<int32_t*>(maybe_row_to_batch.value().data_ptr());
  }
  // Use unified dispatch with heuristics to choose between FilteredTopK and RadixTopK
  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
    status = sampling::TopKPageTableTransformDispatch<c_type, int32_t>(
        static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_page_table.data_ptr()),
        static_cast<const int32_t*>(src_page_table.data_ptr()), src_stride, row_to_batch_ptr,
        static_cast<int32_t*>(lengths.data_ptr()), num_rows, static_cast<uint32_t>(top_k), max_len,
        row_states_ptr, stream);
    return true;
  });
  TVM_FFI_ICHECK(status == cudaSuccess)
      << "TopKPageTableTransform failed with error code " << cudaGetErrorString(status);
}
```
🤖 Prompt for AI Agents
In csrc/topk.cu around lines 62 to 108 (add checks after line 73), add explicit
size validations: ensure output_page_table.size(0) == input.size(0) and
lengths.size(0) == input.size(0) (num_rows), ensure input.size(1) ==
src_page_table.size(1) (max_len), ensure output_page_table.size(1) >=
static_cast<size_t>(top_k) (or == if contract requires exact top_k), if
maybe_row_to_batch is present check its numel == num_rows and that all entries
are within [0, src_page_table.size(0)-1], if maybe_row_states_buffer is present
validate its size is sufficient for num_rows (per RadixRowState element size),
and check that all values in lengths are <= max_len; on any mismatch return a
clear error via TVM_FFI_ICHECK or equivalent with descriptive message.
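The proposed contract can also be stated compactly in Python. A hypothetical host-side mirror of those checks, operating on plain shape tuples (the function name and tuple convention are illustrative, not part of the library):

```python
def check_page_table_shapes(input_shape, out_pt_shape, src_pt_shape, lengths_shape, top_k):
    """Raise early, naming the offending tensor, on any dimension mismatch."""
    num_rows, max_len = input_shape
    if out_pt_shape != (num_rows, top_k):
        raise ValueError(
            f"output_page_table shape mismatch: expected {(num_rows, top_k)}, got {out_pt_shape}"
        )
    if lengths_shape != (num_rows,):
        raise ValueError(
            f"lengths shape mismatch: expected {(num_rows,)}, got {lengths_shape}"
        )
    if src_pt_shape[1] != max_len:
        raise ValueError(
            f"src_page_table max_len mismatch: expected {max_len}, got {src_pt_shape[1]}"
        )
```

The point of such checks, in either language, is that a shape bug surfaces as a named error at the API boundary instead of an opaque kernel failure.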
```cpp
void radix_topk_ragged_transform(TensorView input, TensorView output_indices, TensorView offsets,
                                 TensorView lengths, Optional<TensorView> maybe_row_states_buffer,
                                 int64_t top_k) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_indices);
  CHECK_INPUT(offsets);
  CHECK_INPUT(lengths);
  CHECK_DIM(2, input);           // input: (num_rows, max_len)
  CHECK_DIM(2, output_indices);  // output_indices: (num_rows, top_k)
  CHECK_DIM(1, offsets);         // offsets: (num_rows,)
  CHECK_DIM(1, lengths);         // lengths: (num_rows,)

  unsigned int num_rows = input.size(0);
  unsigned int max_len = input.size(1);

  cudaSetDevice(input.device().device_id);
  auto stream = get_stream(input.device());

  cudaError_t status;
  auto dtype = input.dtype();

  sampling::RadixRowState* row_states_ptr = nullptr;
  if (maybe_row_states_buffer.has_value()) {
    row_states_ptr =
        static_cast<sampling::RadixRowState*>(maybe_row_states_buffer.value().data_ptr());
  }

  // Use unified dispatch with heuristics to choose between FilteredTopK and RadixTopK
  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
    status = sampling::TopKRaggedTransformDispatch<c_type, int32_t>(
        static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_indices.data_ptr()),
        static_cast<const int32_t*>(offsets.data_ptr()), static_cast<int32_t*>(lengths.data_ptr()),
        num_rows, static_cast<uint32_t>(top_k), max_len, row_states_ptr, stream);
    return true;
  });

  TVM_FFI_ICHECK(status == cudaSuccess)
      << "TopKRaggedTransform failed with error code " << cudaGetErrorString(status);
}
```
🛠️ Refactor suggestion | 🟠 Major
Consider adding size validations for early error detection.
Similar to radix_topk_page_table_transform, this function would benefit from explicit size checks to catch parameter mismatches early with clear error messages.
Consider adding these validations after line 120:

```diff
 CHECK_DIM(1, lengths);  // lengths: (num_rows,)
+
+unsigned int num_rows = input.size(0);
+unsigned int max_len = input.size(1);
+
+TVM_FFI_ICHECK(output_indices.size(0) == num_rows)
+    << "output_indices batch size mismatch: expected " << num_rows
+    << ", got " << output_indices.size(0);
+TVM_FFI_ICHECK(output_indices.size(1) == top_k)
+    << "output_indices second dimension mismatch: expected " << top_k
+    << ", got " << output_indices.size(1);
+TVM_FFI_ICHECK(offsets.size(0) == num_rows)
+    << "offsets size mismatch: expected " << num_rows
+    << ", got " << offsets.size(0);
+TVM_FFI_ICHECK(lengths.size(0) == num_rows)
+    << "lengths size mismatch: expected " << num_rows
+    << ", got " << lengths.size(0);
-
-unsigned int num_rows = input.size(0);
-unsigned int max_len = input.size(1);
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
void radix_topk_ragged_transform(TensorView input, TensorView output_indices, TensorView offsets,
                                 TensorView lengths, Optional<TensorView> maybe_row_states_buffer,
                                 int64_t top_k) {
  CHECK_INPUT(input);
  CHECK_INPUT(output_indices);
  CHECK_INPUT(offsets);
  CHECK_INPUT(lengths);
  CHECK_DIM(2, input);           // input: (num_rows, max_len)
  CHECK_DIM(2, output_indices);  // output_indices: (num_rows, top_k)
  CHECK_DIM(1, offsets);         // offsets: (num_rows,)
  CHECK_DIM(1, lengths);         // lengths: (num_rows,)
  unsigned int num_rows = input.size(0);
  unsigned int max_len = input.size(1);
  TVM_FFI_ICHECK(output_indices.size(0) == num_rows)
      << "output_indices batch size mismatch: expected " << num_rows << ", got "
      << output_indices.size(0);
  TVM_FFI_ICHECK(output_indices.size(1) == top_k)
      << "output_indices second dimension mismatch: expected " << top_k << ", got "
      << output_indices.size(1);
  TVM_FFI_ICHECK(offsets.size(0) == num_rows)
      << "offsets size mismatch: expected " << num_rows << ", got " << offsets.size(0);
  TVM_FFI_ICHECK(lengths.size(0) == num_rows)
      << "lengths size mismatch: expected " << num_rows << ", got " << lengths.size(0);
  cudaSetDevice(input.device().device_id);
  auto stream = get_stream(input.device());
  cudaError_t status;
  auto dtype = input.dtype();
  sampling::RadixRowState* row_states_ptr = nullptr;
  if (maybe_row_states_buffer.has_value()) {
    row_states_ptr =
        static_cast<sampling::RadixRowState*>(maybe_row_states_buffer.value().data_ptr());
  }
  // Use unified dispatch with heuristics to choose between FilteredTopK and RadixTopK
  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] {
    status = sampling::TopKRaggedTransformDispatch<c_type, int32_t>(
        static_cast<c_type*>(input.data_ptr()), static_cast<int32_t*>(output_indices.data_ptr()),
        static_cast<const int32_t*>(offsets.data_ptr()), static_cast<int32_t*>(lengths.data_ptr()),
        num_rows, static_cast<uint32_t>(top_k), max_len, row_states_ptr, stream);
    return true;
  });
  TVM_FFI_ICHECK(status == cudaSuccess)
      << "TopKRaggedTransform failed with error code " << cudaGetErrorString(status);
}
```
🤖 Prompt for AI Agents
In csrc/topk.cu around lines 120-126, add explicit size validations after
computing num_rows and max_len: verify offsets.size(0) == num_rows and
lengths.size(0) == num_rows, verify output_indices.size(0) == num_rows and
output_indices.size(1) == static_cast<size_t>(top_k), check that top_k > 0 and
top_k <= static_cast<int64_t>(max_len), and if
maybe_row_states_buffer.has_value() validate its byte size/element count matches
expected sampling::RadixRowState array length for num_rows; on failure, emit
clear TVM_FFI_ICHECK messages identifying the mismatched parameter.
|
/bot run |
Actionable comments posted: 1
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- `docs/api/sampling.rst` (1 hunks)
- `docs/api/topk.rst` (1 hunks)
- `docs/index.rst` (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- docs/api/topk.rst
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
docs/index.rst (1)
39-39: The `api/topk` toctree entry is correct and the referenced documentation file is properly formatted.

The verification confirms that `docs/api/topk.rst` exists and contains complete, properly formatted documentation for the top-k functions (`top_k`, `top_k_page_table_transform`, `top_k_ragged_transform`, and `can_implement_filtered_topk`). The toctree entry placement is logical and follows the existing naming convention.
```rst
.. seealso::

   For efficient Top-K selection (without sampling), see :ref:`apitopk` which provides
   :func:`~flashinfer.top_k`, :func:`~flashinfer.top_k_page_table_transform`, and
   :func:`~flashinfer.top_k_ragged_transform`.
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Description: Verify that the referenced functions are exposed in the public API and documentation exists
# Check 1: Verify the three top-k functions are exported in flashinfer/__init__.py
echo "=== Checking if top-k functions are exported in __init__.py ==="
rg -n "top_k_page_table_transform|top_k_ragged_transform|^top_k" flashinfer/__init__.py | head -20
# Check 2: Verify docs/api/topk.rst exists and contains the functions
echo -e "\n=== Checking if docs/api/topk.rst exists and documents these functions ==="
if [ -f "docs/api/topk.rst" ]; then
echo "✓ docs/api/topk.rst exists"
rg -n "top_k_page_table_transform|top_k_ragged_transform|.. _apitopk" docs/api/topk.rst
else
echo "✗ docs/api/topk.rst not found"
fi
# Check 3: Verify the :ref:`apitopk` reference label is defined in docs/api/topk.rst
echo -e "\n=== Checking for the apitopk reference label ==="
rg -n ".. _apitopk" docs/api/topk.rst 2>/dev/null || echo "Not found or file doesn't exist"
# Check 4: Verify flashinfer/topk.py implements these functions
echo -e "\n=== Checking if topk.py exists and implements these functions ==="
if [ -f "flashinfer/topk.py" ]; then
echo "✓ flashinfer/topk.py exists"
rg -n "^def.*top_k" flashinfer/topk.py | head -10
else
echo "✗ flashinfer/topk.py not found"
fi
# Check 5: Look for the functions anywhere in the codebase
echo -e "\n=== Searching for top_k function definitions anywhere ==="
rg -n "^def (top_k|top_k_page_table_transform|top_k_ragged_transform)" --type py

Repository: flashinfer-ai/flashinfer
Length of output: 1192
Remove reference to top_k from the seealso section.

The seealso section references :func:`~flashinfer.top_k`, but this function is not included in the apitopk documentation. Only :func:`~flashinfer.top_k_page_table_transform` and :func:`~flashinfer.top_k_ragged_transform` are documented on that page.
🤖 Prompt for AI Agents
In docs/api/sampling.rst around lines 8 to 12, the seealso currently references
:func:`~flashinfer.top_k` which does not exist in the apitopk page; remove the
:func:`~flashinfer.top_k` entry so the seealso lists only
:func:`~flashinfer.top_k_page_table_transform` and
:func:`~flashinfer.top_k_ragged_transform`, keeping the surrounding text and
Sphinx cross-reference syntax intact.
[FAILED] Pipeline #40387618: 7/20 passed
yongwww
left a comment
Thanks for the great work on the Top-K optimizations and the new fused kernels for DSA. The performance gains are impressive.
flashinfer/topk.py
Outdated
topk_page_table_transform = top_k_page_table_transform
topk_ragged_transform = top_k_ragged_transform
seems the aliases can be removed here
everything else lgtm
bkryu
left a comment
LGTM except for tagging the functions with @flashinfer_api for logging purposes. Can we get these added?
The achieved speedups are impressive 👀
Actionable comments posted: 0
🧹 Nitpick comments (2)
flashinfer/topk.py (2)
258-344: Consider adding validation for the k parameter.

The implementation is well-documented and follows existing patterns correctly. However, adding a defensive check for k > 0 would catch invalid inputs early with a clear error message rather than relying on kernel-side validation.

🔎 Optional validation check

 def top_k_page_table_transform(
     input: torch.Tensor,
     src_page_table: torch.Tensor,
     lengths: torch.Tensor,
     k: int,
     row_to_batch: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     r"""Fused Top-K selection + Page Table Transform for sparse attention.
+    if k <= 0:
+        raise ValueError(f"k must be positive, got {k}")
     device = input.device
     num_rows = input.size(0)
348-421: Consider adding validation for the k parameter.

The implementation is well-documented and correctly structured. Similar to top_k_page_table_transform, adding a defensive check for k > 0 would improve early error detection.

🔎 Optional validation check

 def top_k_ragged_transform(
     input: torch.Tensor,
     offsets: torch.Tensor,
     lengths: torch.Tensor,
     k: int,
 ) -> torch.Tensor:
     r"""Fused Top-K selection + Ragged Index Transform for sparse attention.
+    if k <= 0:
+        raise ValueError(f"k must be positive, got {k}")
     device = input.device
     num_rows = input.size(0)
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/topk.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/topk.py (4)
flashinfer/api_logging.py (1)
- flashinfer_api (464-565)

flashinfer/utils.py (4)
- register_custom_op (319-328)
- register_custom_op (338-357)
- register_fake_op (330-334)
- register_fake_op (359-364)

csrc/flashinfer_topk_binding.cu (4)
- radix_topk_page_table_transform (23-26)
- radix_topk_ragged_transform (28-30)
- radix_topk (20-21)
- can_implement_filtered_topk (32-32)

csrc/topk.cu (8)
- radix_topk_page_table_transform (63-109)
- radix_topk_page_table_transform (63-66)
- radix_topk_ragged_transform (111-149)
- radix_topk_ragged_transform (111-113)
- radix_topk (25-61)
- radix_topk (25-26)
- can_implement_filtered_topk (151-151)
- can_implement_filtered_topk (151-151)
🔇 Additional comments (4)
flashinfer/topk.py (4)
65-131: LGTM: Custom ops correctly implement fused Top-K transforms.

The custom ops and their fake counterparts follow the correct pattern:
- Proper dtype validation (fp32/fp16/bf16)
- Correct mutates_args specification
- In-place mutation design matches PyTorch custom op best practices
- Fake ops appropriately stubbed for shape inference
141-152: LGTM: Clean capability query with clear documentation.

The helper function provides a clean interface to check GPU support for FilteredTopK, with good documentation about the 128KB shared memory requirement.
257-258: Good: Decorator added per previous review feedback.

The @flashinfer_api decorator has been added as requested by aleozlx in previous reviews.
347-348: Good: Decorator added per previous review feedback.

The @flashinfer_api decorator has been added as requested by bkryu in previous reviews.
Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com>
Port the multi-CTA radix-based top-k kernel from flashinfer-ai/flashinfer#2215 into sglang as a JIT-compiled kernel. This replaces the existing AOT single-CTA top-k implementation for NSA attention, providing better performance on long sequences (32K+) where the multi-CTA path activates.

Key changes:

- Add `python/sglang/jit_kernel/topk.py`: Python API exposing three JIT top-k variants (basic, page-table transform, ragged transform) with workspace management and lazy compilation via `cache_once`.
- Add `python/sglang/jit_kernel/csrc/elementwise/topk.cuh`: CUDA wrapper providing TVM FFI entry points that dispatch to the flashinfer adaptive top-k kernels (TopKDispatch, TopKPageTableTransformDispatch, TopKRaggedTransformDispatch).
- Add `python/sglang/jit_kernel/include/sgl_kernel/topk_fi.cuh`: Core CUDA implementation adapted from flashinfer, featuring:
  - 8-bit radix selection algorithm with multi-CTA support for large sequences (threshold configurable, default 32K)
  - Support for float32, float16, and bfloat16 input types
  - row_starts parameter for ragged input score layouts (sglang-specific)
  - Three output modes: indices-only, page-table lookup, and ragged offset addition
- Update `python/sglang/srt/layers/attention/nsa_backend.py`: Switch the NSA indexer to import from the JIT kernel instead of AOT sgl_kernel.
- Update `sgl-kernel/python/sgl_kernel/top_k.py`: Add a JIT fallback path controlled by the SGLANG_USE_JIT_TOPK env var (default enabled). When JIT is available, fast_topk_v2 / fast_topk_transform_fused / fast_topk_transform_ragged_fused transparently delegate to JIT kernels.
- Add `sgl-kernel/tests/test_topk_jit.py`: Correctness tests covering basic, page-table, ragged, and trivial (length <= topk) cases across various batch sizes and sequence lengths up to 131K. (TODO(yifan): 1M)
- Add `sgl-kernel/benchmarks/bench_topk_jit.py`: Latency benchmark comparing JIT multi-CTA vs AOT single-CTA kernels.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
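The 8-bit radix selection the commit message refers to can be illustrated with a hedged, host-side NumPy sketch (structure only; the actual kernel's multi-CTA histogramming, warp-level details, and tie policy may differ): map floats to unsigned keys whose ordering matches float ordering, then narrow the candidate set one byte at a time, most-significant byte first, until the k-th largest key is isolated.

```python
import numpy as np

def float_to_ordered_u32(x):
    # Monotone bijection: unsigned comparison on the result matches float
    # comparison on the input (set the sign bit for non-negatives, flip
    # all bits for negatives). A standard radix-sort trick.
    b = x.astype(np.float32).view(np.uint32)
    neg = (b >> np.uint32(31)).astype(bool)
    mask = np.where(neg, np.uint32(0xFFFFFFFF), np.uint32(0x80000000))
    return b ^ mask

def radix_topk_threshold(scores, k):
    """Return the value of the k-th largest score via 8-bit radix selection."""
    keys = float_to_ordered_u32(scores)
    candidates = keys
    remaining = k
    for shift in (24, 16, 8, 0):  # one radix digit (byte) per round, MSB first
        digits = ((candidates >> np.uint32(shift)) & np.uint32(0xFF)).astype(np.int64)
        hist = np.bincount(digits, minlength=256)
        above = 0  # keys falling in strictly larger digit buckets
        for d in range(255, -1, -1):
            if above + hist[d] >= remaining:
                remaining -= above              # rank of target within bucket d
                candidates = candidates[digits == d]
                break
            above += hist[d]
    # After four rounds every surviving candidate shares the same 32-bit key,
    # which is exactly the key of the k-th largest score.
    return float(scores[keys == candidates[0]][0])

x = np.array([3.0, -1.5, 2.0, 7.25, 0.0], dtype=np.float32)
print(radix_topk_threshold(x, 2))  # prints 3.0 (the 2nd largest value)
```

Once the threshold is known, the actual selection pass just keeps elements >= threshold, which is what makes the approach cheap to split across CTAs: the histogram phases reduce, and the final gather is embarrassingly parallel.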
📌 Description
Follow-up of #2119, this PR implements the top_k_page_table_transform and top_k_ragged_transform functions required for DSA in sglang (fp16/bf16/fp32 are supported), as requested in #2221. This PR also adds more top-k algorithms to choose from:

and adds heuristics to choose between algorithms.
Besides new features, this PR also moves the Top-K related functions outside of sampling.cuh (because they are not only used for sampling).
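As a rough illustration of what the fused top-k + page-table transform computes, here is a hedged NumPy reference: select the top-k score positions per row, then gather the corresponding page IDs from the source page table in the same pass. This is my reading of the semantics; the real flashinfer API also takes a row_to_batch mapping (rows index the table directly here for brevity), and the -1 padding and tie order are assumptions.

```python
import numpy as np

def topk_page_table_reference(scores, src_page_table, lengths, k):
    """For each row, pick the top-k score positions within lengths[i]
    and look up their page IDs in that row's source page table."""
    num_rows = scores.shape[0]
    out = np.full((num_rows, k), -1, dtype=np.int64)  # -1 marks unused slots
    for i in range(num_rows):
        n = int(lengths[i])
        take = min(k, n)
        local = np.argsort(-scores[i, :n], kind="stable")[:take]
        out[i, :take] = src_page_table[i, local]  # the fused gather step
    return out

scores = np.array([[0.2, 0.8, 0.5]], dtype=np.float32)
pages = np.array([[10, 11, 12]])
print(topk_page_table_reference(scores, pages, np.array([3]), k=2))
```

Fusing the gather into the top-k kernel avoids materializing the intermediate index tensor and a second kernel launch, which is where the reported speedups come from.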
Benchmark on B200 for fp32 input:
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

🧪 Tests
unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Tests
Documentation
Chores