
perf: optimize top-k kernel for fp32 #2915

Open
Archie-wang wants to merge 2 commits into flashinfer-ai:main from Archie-wang:optimize-topk-fp32

Conversation


@Archie-wang Archie-wang commented Mar 30, 2026

Description

Follow-up optimization of #2215. This PR improves the fp32 performance of both FilteredTopK and Multi-CTA RadixTopK kernels on H200 (sm_90), achieving ~15% geomean speedup across 120 test cases.

All optimizations use standard CUDA primitives (__shfl_down_sync, PTX prefetch.global.L2, __builtin_expect) with no architecture-specific intrinsics.
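As a small host-side sketch of the last of these primitives (the function and names below are illustrative, not flashinfer's actual filter loop): `__builtin_expect` tells GCC/Clang that a branch is rarely taken, so the compiler lays out the common fall-through path first.

```cpp
// Hedged sketch of the branch-hint technique; illustrative names only.
// The hint marks the filter branch as rarely taken so the hot path
// (score fails the filter) stays straight-line code.
#define UNLIKELY(x) __builtin_expect(!!(x), 0)

int count_passing(const float* scores, int n, float threshold) {
  int count = 0;
  for (int i = 0; i < n; ++i) {
    if (UNLIKELY(scores[i] > threshold)) ++count;  // rare: candidate survives the filter
  }
  return count;
}
```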

Changes

  1. Warp-parallel suffix sum: Replace Hillis-Steele scan with __shfl_down_sync-based approach in both RadixTopK and FilteredTopK. Reduces __syncthreads from 16 to 3 per invocation.

  2. L2 prefetch: Issue prefetch.global.L2 two strides ahead in the FilteredTopK full-row scan and in Multi-CTA LoadToSharedOrdered.

  3. Cache ordered values in smem (FilteredTopK refine): Store ToOrdered() results during the filter pass, eliminating global memory re-reads in refine rounds.

  4. Skip redundant fp32 refine round 0: The coarse pass over the high 8 bits of the fp16 representation already covers bits 24-31, so round 0 is redundant; this reduces the fp32 refine from 4 rounds to 3.

  5. Branch prediction hints: Add __builtin_expect on rare filter branches.
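The warp-parallel suffix sum in change (1) can be sketched with a CPU emulation of one 32-lane warp. This is an assumption-laden sketch, not the kernel's code: the real implementation runs __shfl_down_sync per warp over 256 bins (8 warps) plus a cross-warp correction, with no __syncthreads inside the loop. Here shfl_down(v, off) is modeled as reading slot i + off, with out-of-range lanes contributing 0.

```cpp
// CPU emulation of a 32-lane warp suffix sum via shuffle-down.
// After the call, v[i] holds the sum of the original v[i..31].
// log2(32) = 5 doubling steps replace the multi-stride Hillis-Steele scan.
constexpr int kWarpSize = 32;

void warp_suffix_sum(int v[kWarpSize]) {
  for (int off = 1; off < kWarpSize; off <<= 1) {
    int incoming[kWarpSize];
    for (int i = 0; i < kWarpSize; ++i)
      incoming[i] = (i + off < kWarpSize) ? v[i + off] : 0;  // out-of-range lane -> 0
    for (int i = 0; i < kWarpSize; ++i) v[i] += incoming[i];
  }
}
```

Because each doubling step stays inside one warp, the only block-wide synchronization left is the cross-warp correction, which is where the reported drop from 16 to 3 __syncthreads comes from.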

Benchmark on H200

Summary (120 fp32 cases):

====================================================================================================
top_k_page_table_transform: Fused top-k + page table gather (dtype=FP32)
NOTE: sgl-kernel comparison only available for k=2048
====================================================================================================
 batch  seq_len      k |     PR#2215     This PR    Speedup |  sgl-kernel    Speedup
-----------------------------------------------------------------------------------------------
     1     4096    256 |     13.28us     10.25us      1.30x |
     1     4096    512 |     13.89us     10.22us      1.36x |
     1     4096   1024 |     12.53us     10.21us      1.23x |
     1     4096   2048 |     10.45us      9.19us      1.14x |      9.96us      1.08x
     1     4096   4096 |      7.83us      7.46us      1.05x |
     1    16384    256 |     15.70us     13.47us      1.17x |
     1    16384    512 |     16.41us     14.29us      1.15x |
     1    16384   1024 |     17.30us     15.04us      1.15x |
     1    16384   2048 |     19.06us     16.39us      1.16x |     17.82us      1.09x
     1    16384   4096 |     24.38us     22.54us      1.08x |
     1    65536    256 |     36.62us     33.94us      1.08x |
     1    65536    512 |     36.91us     34.37us      1.07x |
     1    65536   1024 |     37.64us     34.99us      1.08x |
     1    65536   2048 |     38.83us     36.25us      1.07x |     42.46us      1.17x
     1    65536   4096 |     40.50us     37.92us      1.07x |
     1   131072    256 |     42.10us     39.03us      1.08x |
     1   131072    512 |     42.49us     39.48us      1.08x |
     1   131072   1024 |     43.21us     41.93us      1.03x |
     1   131072   2048 |     43.93us     41.42us      1.06x |     67.98us      1.64x
     1   131072   4096 |     46.66us     44.00us      1.06x |
     1   262144    256 |     46.48us     43.08us      1.08x |
     1   262144    512 |     46.71us     43.30us      1.08x |
     1   262144   1024 |     47.23us     43.96us      1.07x |
     1   262144   2048 |     48.26us     45.02us      1.07x |    122.65us      2.72x
     1   262144   4096 |     50.43us     47.04us      1.07x |
     1   524288    256 |     47.26us     43.73us      1.08x |
     1   524288    512 |     47.61us     43.84us      1.09x |
     1   524288   1024 |     47.72us     44.35us      1.08x |
     1   524288   2048 |     48.23us     45.00us      1.07x |    204.49us      4.54x
     1   524288   4096 |     49.70us     45.91us      1.08x |
-----------------------------------------------------------------------------------------------
    16     4096    256 |     12.31us     10.44us      1.18x |
    16     4096    512 |     12.88us     10.73us      1.20x |
    16     4096   1024 |     12.53us     10.63us      1.18x |
    16     4096   2048 |     11.98us     10.36us      1.16x |     11.56us      1.12x
    16     4096   4096 |      7.80us      7.59us      1.03x |
    16    16384    256 |     16.28us     13.97us      1.17x |
    16    16384    512 |     17.51us     14.70us      1.19x |
    16    16384   1024 |     18.31us     15.69us      1.17x |
    16    16384   2048 |     20.05us     17.36us      1.16x |     18.95us      1.09x
    16    16384   4096 |     25.11us     23.07us      1.09x |
    16    65536    256 |     29.81us     24.38us      1.22x |
    16    65536    512 |     34.69us     28.47us      1.22x |
    16    65536   1024 |     35.45us     29.13us      1.22x |
    16    65536   2048 |     38.58us     32.45us      1.19x |     43.70us      1.35x
    16    65536   4096 |     43.26us     40.53us      1.07x |
    16   131072    256 |     46.94us     38.85us      1.21x |
    16   131072    512 |     47.42us     39.30us      1.21x |
    16   131072   1024 |     56.92us     47.32us      1.20x |
    16   131072   2048 |     58.23us     48.48us      1.20x |     69.79us      1.44x
    16   131072   4096 |     49.99us     46.88us      1.07x |
    16   262144    256 |     50.55us     47.07us      1.07x |
    16   262144    512 |     50.87us     47.53us      1.07x |
    16   262144   1024 |     51.35us     47.98us      1.07x |
    16   262144   2048 |     52.68us     49.26us      1.07x |    124.34us      2.52x
    16   262144   4096 |     55.17us     51.75us      1.07x |
    16   524288    256 |     95.28us     88.63us      1.08x |
    16   524288    512 |     95.73us     88.67us      1.08x |
    16   524288   1024 |     96.25us     89.29us      1.08x |
    16   524288   2048 |     97.59us     90.73us      1.08x |    238.07us      2.62x
    16   524288   4096 |    100.62us     93.57us      1.08x |
-----------------------------------------------------------------------------------------------
    64     4096    256 |     13.21us     11.27us      1.17x |
    64     4096    512 |     13.26us     11.69us      1.13x |
    64     4096   1024 |     13.24us     11.12us      1.19x |
    64     4096   2048 |     12.62us     10.77us      1.17x |     12.02us      1.12x
    64     4096   4096 |      8.25us      8.15us      1.01x |
    64    16384    256 |     17.71us     15.57us      1.14x |
    64    16384    512 |     18.07us     16.29us      1.11x |
    64    16384   1024 |     19.00us     17.23us      1.10x |
    64    16384   2048 |     20.89us     18.60us      1.12x |     20.08us      1.08x
    64    16384   4096 |     26.15us     24.23us      1.08x |
    64    65536    256 |     31.49us     25.54us      1.23x |
    64    65536    512 |     36.52us     29.85us      1.22x |
    64    65536   1024 |     37.53us     30.70us      1.22x |
    64    65536   2048 |     41.29us     34.70us      1.19x |     46.38us      1.34x
    64    65536   4096 |     47.07us     44.91us      1.05x |
    64   131072    256 |     52.95us     41.09us      1.29x |
    64   131072    512 |     53.49us     41.17us      1.30x |
    64   131072   1024 |     62.51us     49.83us      1.25x |
    64   131072   2048 |     64.33us     51.67us      1.25x |     80.99us      1.57x
    64   131072   4096 |     98.73us     93.12us      1.06x |
    64   262144    256 |     96.74us     66.51us      1.45x |
    64   262144    512 |    103.59us     72.32us      1.43x |
    64   262144   1024 |    104.21us     72.92us      1.43x |
    64   262144   2048 |    120.49us     88.53us      1.36x |    181.57us      2.05x
    64   262144   4096 |    154.53us    145.33us      1.06x |
    64   524288    256 |    173.74us    120.62us      1.44x |
    64   524288    512 |    174.28us    120.88us      1.44x |
    64   524288   1024 |    187.29us    131.79us      1.42x |
    64   524288   2048 |    188.14us    132.74us      1.42x |    316.33us      2.38x
    64   524288   4096 |    240.19us    224.12us      1.07x |
-----------------------------------------------------------------------------------------------
   256     4096    256 |     20.93us     17.76us      1.18x |
   256     4096    512 |     21.93us     18.29us      1.20x |
   256     4096   1024 |     22.28us     18.08us      1.23x |
   256     4096   2048 |     19.86us     16.35us      1.21x |     18.90us      1.16x
   256     4096   4096 |     11.08us     10.96us      1.01x |
   256    16384    256 |     31.45us     27.29us      1.15x |
   256    16384    512 |     33.34us     29.32us      1.14x |
   256    16384   1024 |     36.04us     31.62us      1.14x |
   256    16384   2048 |     40.13us     35.71us      1.12x |     38.73us      1.08x
   256    16384   4096 |     49.75us     46.33us      1.07x |
   256    65536    256 |     64.36us     54.59us      1.18x |
   256    65536    512 |     73.88us     62.69us      1.18x |
   256    65536   1024 |     75.83us     64.51us      1.18x |
   256    65536   2048 |     83.09us     72.59us      1.14x |     95.78us      1.32x
   256    65536   4096 |    160.92us    151.48us      1.06x |
   256   131072    256 |    117.91us     93.49us      1.26x |
   256   131072    512 |    118.73us     94.21us      1.26x |
   256   131072   1024 |    137.38us    109.11us      1.26x |
   256   131072   2048 |    141.89us    113.06us      1.26x |    201.97us      1.79x
   256   131072   4096 |    274.71us    258.09us      1.06x |
   256   262144    256 |    196.77us    158.45us      1.24x |
   256   262144    512 |    210.94us    166.44us      1.27x |
   256   262144   1024 |    213.29us    168.41us      1.27x |
   256   262144   2048 |    263.30us    198.33us      1.33x |    381.79us      1.93x
   256   262144   4096 |    480.43us    448.87us      1.07x |
   256   524288    256 |    364.47us    294.78us      1.24x |
   256   524288    512 |    365.91us    295.14us      1.24x |
   256   524288   1024 |    395.21us    311.07us      1.27x |
   256   524288   2048 |    400.29us    315.95us      1.27x |    651.57us      2.06x
   256   524288   4096 |    899.66us    835.37us      1.08x |
-----------------------------------------------------------------------------------------------
Geomean speedup vs PR#2215:   1.15x  (116 faster, 4 tied, 0 slower)
Geomean speedup vs sgl-kernel: 1.59x  (24 faster, 0 slower)

Summary by CodeRabbit

  • Refactor
    • Optimized suffix-sum algorithms in top-k selection kernels with improved warp-parallel operations.
    • Enhanced shared-memory data layout for refinement stages.
    • Added L2 cache prefetching for input scores to improve memory throughput.
    • Reduced thread synchronization overhead for better GPU kernel efficiency.


coderabbitai bot commented Mar 30, 2026

No actionable comments were generated in the recent review. 🎉

📥 Commits

Reviewing files that changed from the base of the PR and between 4941606 and dc854cb.

📒 Files selected for processing (1)
  • include/flashinfer/topk.cuh

📝 Walkthrough

The pull request optimizes the radix select and filtered top-k kernel implementations by replacing multi-stride Hillis-Steele suffix-sum loops with a warp-parallel shuffle-down approach, adding L2 prefetching during shared-memory operations, and refactoring shared buffer layouts and round control logic in the refinement path.
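A CPU-analogue sketch of the stride-ahead prefetch pattern: the actual kernel issues inline PTX prefetch.global.L2 on device pointers, while the sketch below uses GCC/Clang's __builtin_prefetch, which is purely a hint and never changes results. The function and parameter names are illustrative, not flashinfer's.

```cpp
// Hypothetical sketch: process the row in stride-sized chunks and prefetch
// the chunk two strides ahead, mirroring the PR's "2 strides ahead" scheme.
// __builtin_prefetch(addr, /*rw=*/0, /*locality=*/1) is a read hint only.
float sum_with_prefetch(const float* scores, int n, int stride) {
  float acc = 0.0f;
  for (int i = 0; i < n; i += stride) {
    if (i + 2 * stride < n)
      __builtin_prefetch(&scores[i + 2 * stride], 0, 1);  // warm the cache early
    for (int j = i; j < i + stride && j < n; ++j) acc += scores[j];
  }
  return acc;
}
```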

Changes

Cohort / File(s) | Summary

Suffix-Sum Algorithm Optimization (include/flashinfer/topk.cuh)
Replaced the RadixSuffixSum multi-stride Hillis-Steele loop with a warp-parallel shuffle-down over 256 bins (8 warps), introducing a cross-warp correction via a scratch pointer. Updated all call sites in RadixSelectOneRound and RadixSelectFromSharedMemory to pass the local_histogram scratch parameter.

Prefetching Enhancement (include/flashinfer/topk.cuh)
Added L2 prefetching during shared-memory loads in LoadToSharedOrdered using stride-based inline PTX prefetch.global.L2, and added L2 prefetching of input scores to the vectorized full-row scan loops in FilteredTopKUnifiedKernel.

Refinement Path Optimization (include/flashinfer/topk.cuh)
Split the input buffer halves between indices and cached ordered values in the FilteredTopKUnifiedKernel refinement, adjusting bounds checks and cache reconstruction. Added logic to skip round 0 for fp32 when FIRST_SHIFT >= 16, and adjusted control-flow synchronization and bin comparison predicates.
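The round-0 skip rests on the standard order-preserving bit transform used in fp32 radix selection. The sketch below is a common formulation of that technique (the name to_ordered is illustrative, not necessarily flashinfer's exact ToOrdered): after the transform, unsigned integer order matches float order, so a coarse pass over the high bits already resolves bits 24-31 of the ordered key.

```cpp
#include <cstdint>
#include <cstring>

// Order-preserving fp32 -> uint32 transform for radix selection.
// Non-negative floats: set the sign bit (shifts them above all negatives).
// Negative floats: flip all bits (reverses their order so that unsigned
// comparison of the results agrees with float comparison).
uint32_t to_ordered(float x) {
  uint32_t u;
  std::memcpy(&u, &x, sizeof(u));  // bit-cast without undefined behavior
  return (u & 0x80000000u) ? ~u : (u | 0x80000000u);
}
```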

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • nvmbreughe
  • cyx-6
  • yzh119
  • djmmoss
  • bkryu

Poem

🐰 Whiskers twitch with GPU glee,
Warps now shuffle, sync-free spree!
Prefetch whispers, buffers dance,
Radix hops in swift advance—
Faster top-k, what a gleam!

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name | Status | Explanation
Title check | ✅ Passed | The title accurately summarizes the main optimization focus (fp32 performance) and aligns with the primary changes across the codebase for radix and filter top-k kernels.
Description check | ✅ Passed | The description comprehensively covers changes, rationale, and benchmark results, but does not explicitly address the PR checklist items (pre-commit checks and tests) from the template.
Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.


@gemini-code-assist

Warning

Gemini is experiencing higher than usual traffic and was unable to create the review. Please try again in a few hours by commenting /gemini review.

@jiangyinzuo
Contributor

This PR may cause a Git merge conflict with PR #2661, but the goals of the two PRs are orthogonal.

@Archie-wang
Author

Thanks for the heads-up! These two goals are indeed orthogonal — this PR focuses on optimizing the performance of the existing kernel, while #2661 adds deterministic mode.
From the code overlap, most conflicts should be limited to the Multi-CTA RadixTopK path: LoadToSharedOrdered (my L2 prefetch changes vs. your RadixSelectFindPivot refactor) and RadixSelectFromSharedMemory (both PRs modify its function signature).
The FilteredTopK changes in this PR — including warp-parallel suffix sum, shared-memory caching, and skipping the refine round — should have minimal overlap with #2661.
I’m happy to rebase once either PR lands first. The conflicts are fairly localized and should be straightforward to resolve.

@Archie-wang
Author

Could someone from @flashinfer-ai/ci-users please authorize this PR and trigger CI with @flashinfer-bot run? Thank you.

@jiangyinzuo
Contributor

jiangyinzuo commented Apr 2, 2026

  1. I found an issue about top-k kernel optimization; if this PR helps address it, could you mention that in the PR description?
  2. PR feat: implement deterministic topk #2661 is merged now; could you please resolve the git conflicts and check whether your optimization also works in deterministic mode?
  3. Some of the optimization techniques here look like they could also improve fp16/bf16 top-k performance; could you please share benchmark results for those dtypes?

  __syncthreads();

- // Suffix sum (Hillis Steele Scan)
+ // Suffix sum: warp-parallel approach using shuffle-down
The warp-level suffix sum code here is duplicated from above. Would it be possible to refactor them into a single function?
