perf: Optimize helper max/minmax function in sampling.cuh #2058
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

This change refactors the sampling kernel's min/max reduction: threads accumulate a thread-local min/max across loop iterations, then perform a single block-wide reduction after the loop to compute the final min/max. This removes the per-iteration reductions, the intermediate `in_data_` arrays, and the in-loop synchronizations.
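To make the walkthrough concrete, here is a minimal standalone sketch of the deferred-reduction pattern it describes, assuming a CUB-style block reduction. The kernel name, parameters, and variable names are illustrative, not the actual sampling.cuh code:

```cpp
#include <cfloat>
#include <cub/cub.cuh>

// Sketch: each thread accumulates a running min/max over its strided slice
// of the input, then the block reduces once after the loop. There is no
// per-iteration reduction and no __syncthreads() inside the loop.
template <uint32_t BLOCK_THREADS>
__global__ void BlockMinMaxSketch(const float* in, uint32_t d, float* out_min, float* out_max) {
  using BlockReduce = cub::BlockReduce<float, BLOCK_THREADS>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  float thread_min = FLT_MAX;
  float thread_max = -FLT_MAX;
  for (uint32_t i = threadIdx.x; i < d; i += BLOCK_THREADS) {
    thread_min = fminf(thread_min, in[i]);
    thread_max = fmaxf(thread_max, in[i]);
  }

  // Single block-wide reduction after the loop; temp_storage is reused, so a
  // barrier separates the two reductions.
  float block_max = BlockReduce(temp_storage).Reduce(thread_max, cub::Max());
  __syncthreads();
  float block_min = BlockReduce(temp_storage).Reduce(thread_min, cub::Min());
  if (threadIdx.x == 0) {  // Reduce() returns the valid result on thread 0
    *out_max = block_max;
    *out_min = block_min;
  }
}
```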
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Host
    participant Kernel as SamplingKernel
    participant Threads as Thread[block]
    participant BlockReduce
    Host->>Kernel: launch with data pointers
    Kernel->>Threads: start loop over iterations
    alt per-iteration (old)
        Note over Threads: per-iteration reduction & sync\n(store intermediate in_data_)
        Threads->>Threads: per-iteration min/max reduction
        Threads->>Threads: __syncthreads()
    end
    Note over Threads: new flow — each thread updates\nthread_local_min / thread_local_max per iteration
    Threads->>BlockReduce: after loop, single block-wide reduce
    BlockReduce-->>Threads: block_min / block_max
    Threads->>Kernel: write final min/max to temp_storage
    Kernel-->>Host: kernel completes
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces performance optimizations to the max/min helper functions in sampling.cuh.
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
📒 Files selected for processing (1)
include/flashinfer/sampling.cuh (2 hunks)
🔇 Additional comments (1)
include/flashinfer/sampling.cuh (1)
252-275: Optimization looks correct for `GetMinMaxValue`. The thread-local accumulation pattern followed by a single block reduction is a good optimization that reduces synchronization overhead. The initialization values (negative infinity for max, positive infinity for min) are correct, and the logic properly reuses `temp_storage` with appropriate synchronization between the max and min reductions.
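As a small hedged illustration of that reuse point (not the exact sampling.cuh code), the two reductions share one temp_storage, so a barrier must separate them:

```cpp
// Both reductions reuse the same shared-memory temp_storage. Without the
// barrier, threads beginning the min reduction could overwrite shared memory
// that other threads are still reading for the max reduction.
float block_max = BlockReduce(temp_storage).Reduce(thread_max, cub::Max());
__syncthreads();  // every thread is done with temp_storage before reuse
float block_min = BlockReduce(temp_storage).Reduce(thread_min, cub::Min());
```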
Code Review
This pull request aims to optimize the max/minmax helper functions by deferring block-wide reductions until after the main loop. This is a solid performance optimization strategy. However, the current implementation introduces or preserves critical correctness bugs related to out-of-bounds data handling and incorrect initialization. Specifically, the reduction loops process padded zeros and potentially uninitialized data, and one of the functions incorrectly initializes its accumulator. I've provided suggestions to fix these critical issues to ensure correctness while retaining the performance benefits.
```cpp
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
  in_data_vec.fill(0);
  if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
    in_data_vec.cast_load(in_data + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
  }
  float in_data_[VEC_SIZE];
#pragma unroll
  for (uint32_t j = 0; j < VEC_SIZE; ++j) {
    in_data_[j] = in_data_vec[j];
    thread_max = max(thread_max, static_cast<float>(in_data_vec[j]));
    thread_min = min(thread_min, static_cast<float>(in_data_vec[j]));
  }
  max_val = max(
      max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                   .Reduce<VEC_SIZE>(in_data_, MaxReduceOp{}));
  __syncthreads();
  min_val = min(
      min_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                   .Reduce<VEC_SIZE>(in_data_, MinReduceOp{}));
  __syncthreads();
}
```
The reduction loop has correctness issues with out-of-bounds data handling. It processes padded zeros for threads that are entirely out of bounds and doesn't handle tail elements for partially out-of-bounds vectors. This can lead to incorrect min/max values.
The accumulation should only happen for valid data. The following change corrects the logic:
Please note that cast_load may still read out of bounds if the input is not padded. A fully robust solution might require masked loads or scalar access for tail elements, but the suggested change fixes the logical error in reduction.
```cpp
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
  const uint32_t base_offset = (i * BLOCK_THREADS + tx) * VEC_SIZE;
  if (base_offset < d) {
    in_data_vec.cast_load(in_data + row_idx * d + base_offset);
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      if (base_offset + j < d) {
        thread_max = max(thread_max, in_data_vec[j]);
        thread_min = min(thread_min, in_data_vec[j]);
      }
    }
  }
}
```
```cpp
// Thread-local max accumulation (deferred reduction)
float thread_max = 0.0f;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
  in_data_vec.fill(0);
  if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
    in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
  }
  float in_data_[VEC_SIZE];
#pragma unroll
  for (uint32_t j = 0; j < VEC_SIZE; ++j) {
    in_data_[j] = in_data_vec[j];
    thread_max = max(thread_max, static_cast<float>(in_data_vec[j]));
  }
  max_val = max(
      max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                   .template Reduce<VEC_SIZE>(in_data_, MaxReduceOp{}));
  __syncthreads();
}
```
This function has two correctness issues:

1. Incorrect Initialization: `thread_max` is initialized to `0.0f` on line 296. If all input values are negative, this will cause the function to incorrectly return 0 instead of the true maximum. It should be initialized to -infinity.
2. Faulty Reduction Logic: The loop structure on lines 297-306 suffers from the same issues as in `GetMinMaxValue`. It processes padded zeros for out-of-bounds threads and doesn't handle tail elements correctly, leading to incorrect results.
Here is a suggested fix that addresses both issues:
Please note that cast_load may still read out of bounds if the input is not padded. A fully robust solution might require masked loads or scalar access for tail elements, but the suggested change fixes the logical error in reduction.
```cpp
// Thread-local max accumulation (deferred reduction)
float thread_max = -cuda::std::numeric_limits<float>::infinity();
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
  const uint32_t base_offset = (i * BLOCK_THREADS + tx) * VEC_SIZE;
  if (base_offset < d) {
    in_data_vec.cast_load(in_data + row_idx * d + base_offset);
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      if (base_offset + j < d) {
        thread_max = max(thread_max, in_data_vec[j]);
      }
    }
  }
}
```
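Regarding the masked-load/tail caveat noted above, one possible shape for a fully in-bounds variant is to split each row into an aligned vectorized part and a scalar tail. This is only a sketch under the same assumptions as the suggestion above; `aligned_d` is an introduced, hypothetical name:

```cpp
// Vectorized loads are issued only for vectors fully inside the row, so
// cast_load never reads past the end of unpadded input; the remaining
// tail elements are read one-by-one.
const uint32_t aligned_d = (d / VEC_SIZE) * VEC_SIZE;
for (uint32_t i = 0; i < ceil_div(aligned_d, BLOCK_THREADS * VEC_SIZE); ++i) {
  const uint32_t base_offset = (i * BLOCK_THREADS + tx) * VEC_SIZE;
  if (base_offset < aligned_d) {  // whole vector is in bounds
    in_data_vec.cast_load(in_data + row_idx * d + base_offset);
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      thread_max = max(thread_max, static_cast<float>(in_data_vec[j]));
    }
  }
}
// Scalar loop over the tail: at most VEC_SIZE - 1 elements per row.
for (uint32_t i = aligned_d + tx; i < d; i += BLOCK_THREADS) {
  thread_max = max(thread_max, static_cast<float>(in_data[row_idx * d + i]));
}
```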
/bot run

[FAILED] Pipeline #38046670: 8/17 passed
/bot run |
yzh119 left a comment:
LGTM, should be ready to merge when GitLab CI passes.
📌 Description
Apply optimizations similar to #2044 to max/min functions.
top_p_renorm_probs, top_k_renorm_probs, and top_k_mask_logits see 14%, 7%, and 19% speedups, respectively, on B200 (logs from bench_sampling.py: bench_sampling_before.txt and bench_sampling_after.txt).

🔍 Related Issues
#2044
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed, and all tests are passing (unittest, etc.).

Reviewer Notes