perf: optimize top-k kernel for fp32#2915
Archie-wang wants to merge 2 commits into flashinfer-ai:main
CodeRabbit walkthrough: The pull request optimizes the radix select and filtered top-k kernel implementations by replacing multi-stride Hillis-Steele suffix-sum loops with a warp-parallel shuffle-down approach, adding L2 prefetching during shared-memory operations, and refactoring shared buffer layouts and round control logic in the refinement path.
|
This PR may cause a Git merge conflict with PR #2661, but the goals of the two PRs are orthogonal.
|
Thanks for the heads-up! These two goals are indeed orthogonal: this PR focuses on optimizing the performance of the existing kernel, while #2661 adds deterministic mode.
|
Could someone from @flashinfer-ai/ci-users please authorize this PR and trigger CI with @flashinfer-bot run? Thank you.
|
  __syncthreads();

- // Suffix sum (Hillis Steele Scan)
+ // Suffix sum: warp-parallel approach using shuffle-down
The warp-level suffix sum code here is duplicated from above. Would it be possible to refactor them into a single function?
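One way to address the duplication would be a single helper that both kernels call. The sketch below is a host-side C++ model of that idea, not the PR's code: `ShflDownModel`, `WarpSuffixSum`, and `BlockSuffixSum` are hypothetical names, and the 256-bin histogram (one bin per 8-bit radix digit) is an assumption. It shows the shuffle-down suffix sum within a warp and how per-warp totals are carried into lower warps.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

constexpr int kWarpSize = 32;                    // lanes per warp
constexpr int kNumBins = 256;                    // assumed: one bin per 8-bit radix digit
constexpr int kNumWarps = kNumBins / kWarpSize;  // 8 warps cover the histogram

using Warp = std::array<uint32_t, kWarpSize>;
using Histogram = std::array<uint32_t, kNumBins>;

// Host model of __shfl_down_sync: lane i reads the value held by lane i + offset.
// Lanes whose source would fall outside the warp keep their own value.
Warp ShflDownModel(const Warp& v, int offset) {
  Warp out = v;
  for (int lane = 0; lane + offset < kWarpSize; ++lane) {
    out[lane] = v[lane + offset];
  }
  return out;
}

// Inclusive suffix sum within one warp: result[i] = v[i] + v[i+1] + ... + v[31].
// Five shuffle steps (offsets 1, 2, 4, 8, 16) replace a multi-pass scan.
Warp WarpSuffixSum(Warp v) {
  for (int offset = 1; offset < kWarpSize; offset <<= 1) {
    const Warp shifted = ShflDownModel(v, offset);
    for (int lane = 0; lane + offset < kWarpSize; ++lane) {
      v[lane] += shifted[lane];  // only lanes with an in-range source accumulate
    }
  }
  return v;
}

// Hypothetical shared helper: suffix sum over the whole histogram.
Histogram BlockSuffixSum(const Histogram& hist) {
  Histogram out{};
  std::array<uint32_t, kNumWarps> warp_total{};

  // Step 1: each warp scans its own 32 bins; lane 0 ends up holding the warp total.
  for (int w = 0; w < kNumWarps; ++w) {
    Warp chunk{};
    for (int lane = 0; lane < kWarpSize; ++lane) {
      chunk[lane] = hist[w * kWarpSize + lane];
    }
    chunk = WarpSuffixSum(chunk);
    for (int lane = 0; lane < kWarpSize; ++lane) {
      out[w * kWarpSize + lane] = chunk[lane];
    }
    warp_total[w] = chunk[0];
  }

  // Step 2: add the totals of all higher-indexed warps to each warp's bins.
  uint32_t carry = 0;
  for (int w = kNumWarps - 1; w >= 0; --w) {
    for (int lane = 0; lane < kWarpSize; ++lane) {
      out[w * kWarpSize + lane] += carry;
    }
    carry += warp_total[w];
  }
  return out;
}
```

On the device, the inner step of `WarpSuffixSum` would correspond to `__shfl_down_sync(0xffffffff, v, offset)` followed by an add guarded by `lane + offset < 32`, and a `__device__` helper of this shape could be called from both the RadixTopK and FilteredTopK paths.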
Description
Follow-up optimization of #2215. This PR improves the fp32 performance of both the `FilteredTopK` and Multi-CTA `RadixTopK` kernels on H200 (sm_90), achieving ~15% geomean speedup across 120 test cases.

All optimizations use standard CUDA primitives (`__shfl_down_sync`, PTX `prefetch.global.L2`, `__builtin_expect`) with no architecture-specific intrinsics.

Changes
- Warp-parallel suffix sum: Replace the Hillis-Steele scan with a `__shfl_down_sync`-based approach in both RadixTopK and FilteredTopK. Reduces `__syncthreads` calls from 16 to 3 per invocation.
- L2 prefetch: Add `prefetch.global.L2` two strides ahead in the FilteredTopK full-row scan and in Multi-CTA `LoadToSharedOrdered`.
- Cache ordered values in smem (FilteredTopK refine): Store `ToOrdered()` results during the filter pass, eliminating global memory re-reads in refine rounds.
- Skip redundant fp32 refine round 0: The coarse pass (fp16 high 8 bits) already covers bits 24-31, making round 0 redundant. Reduces fp32 refine from 4 rounds to 3.
- Branch prediction hints: Add `__builtin_expect` on rare filter branches.

Benchmark on H200
Summary (120 fp32 cases):
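Returning to the refine-path items under Changes, the following host-side C++ sketch illustrates what an order-preserving key and a per-round radix digit look like for fp32. It rests on assumptions: `ToOrderedModel` uses a common sign-flip mapping and may differ from the PR's `ToOrdered()`, and `RadixDigit` is a hypothetical helper that assumes 8-bit digits with round 0 covering bits 24-31, matching the round numbering in the description.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Assumed mapping: converts an fp32 value to an unsigned key whose integer
// order matches the float order. Negative values have all bits inverted;
// non-negative values have only the sign bit set.
uint32_t ToOrderedModel(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));  // well-defined way to read the bit pattern
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}

// Hypothetical helper: the 8-bit digit examined in a given refine round.
// Round 0 reads bits 24-31, round 1 bits 16-23, round 2 bits 8-15, round 3 bits 0-7.
uint32_t RadixDigit(uint32_t ordered, int round) {
  return (ordered >> (24 - 8 * round)) & 0xFFu;
}
```

Under these assumptions, computing the key once during the filter pass and keeping it in shared memory means later rounds only need `RadixDigit` on the cached key, and a refine loop that skips round 0 would start at `round = 1`.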