-
Notifications
You must be signed in to change notification settings - Fork 587
perf: Optimize helper max/minmax function in sampling.cuh #2058
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -249,27 +249,31 @@ __device__ __forceinline__ std::tuple<float, float> GetMinMaxValue(float* in_dat | |
| TempStorage& temp_storage) { | ||
| const uint32_t tx = threadIdx.x; | ||
| vec_t<float, VEC_SIZE> in_data_vec; | ||
| float max_val = -cuda::std::numeric_limits<float>::infinity(), | ||
| min_val = cuda::std::numeric_limits<float>::infinity(); | ||
| // Thread-local min/max accumulation (deferred reduction) | ||
| float thread_max = -cuda::std::numeric_limits<float>::infinity(); | ||
| float thread_min = cuda::std::numeric_limits<float>::infinity(); | ||
|
|
||
| for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { | ||
| in_data_vec.fill(0); | ||
| if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { | ||
| in_data_vec.cast_load(in_data + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); | ||
| } | ||
| float in_data_[VEC_SIZE]; | ||
| #pragma unroll | ||
| for (uint32_t j = 0; j < VEC_SIZE; ++j) { | ||
| in_data_[j] = in_data_vec[j]; | ||
| thread_max = max(thread_max, static_cast<float>(in_data_vec[j])); | ||
| thread_min = min(thread_min, static_cast<float>(in_data_vec[j])); | ||
| } | ||
| max_val = max( | ||
| max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce) | ||
| .Reduce<VEC_SIZE>(in_data_, MaxReduceOp{})); | ||
| __syncthreads(); | ||
| min_val = min( | ||
| min_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce) | ||
| .Reduce<VEC_SIZE>(in_data_, MinReduceOp{})); | ||
| __syncthreads(); | ||
| } | ||
|
|
||
| // Single block reduction after loop completes | ||
| float max_val = | ||
| BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce) | ||
| .Reduce(thread_max, MaxReduceOp{}); | ||
| __syncthreads(); | ||
| float min_val = | ||
| BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce) | ||
| .Reduce(thread_min, MinReduceOp{}); | ||
|
|
||
| if (tx == 0) { | ||
| temp_storage.max_val = max_val; | ||
| temp_storage.min_val = min_val; | ||
|
|
@@ -288,22 +292,23 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u | |
| const uint32_t tx = threadIdx.x; | ||
| vec_t<float, VEC_SIZE> in_data_vec; | ||
|
|
||
| float max_val = 0; | ||
| // Thread-local max accumulation (deferred reduction) | ||
| float thread_max = 0.0f; | ||
bkryu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { | ||
| in_data_vec.fill(0); | ||
| if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { | ||
| in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); | ||
| } | ||
| float in_data_[VEC_SIZE]; | ||
| #pragma unroll | ||
| for (uint32_t j = 0; j < VEC_SIZE; ++j) { | ||
| in_data_[j] = in_data_vec[j]; | ||
| thread_max = max(thread_max, static_cast<float>(in_data_vec[j])); | ||
| } | ||
| max_val = max( | ||
| max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce) | ||
| .template Reduce<VEC_SIZE>(in_data_, MaxReduceOp{})); | ||
| __syncthreads(); | ||
| } | ||
|
Comment on lines
+295
to
306
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function has two correctness issues:
Here is a suggested fix that addresses both issues: Please note that |
||
|
|
||
| // Single block reduction after loop completes | ||
| float max_val = | ||
| BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce) | ||
| .Reduce(thread_max, MaxReduceOp{}); | ||
| if (tx == 0) { | ||
| temp_storage.max_val = max_val; | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reduction loop has correctness issues with out-of-bounds data handling. It processes padded zeros for threads that are entirely out of bounds and doesn't handle tail elements for partially out-of-bounds vectors. This can lead to incorrect min/max values.
The accumulation should only happen for valid data. The following change corrects the logic:
Please note that
cast_loadmay still read out of bounds if the input is not padded. A fully robust solution might require masked loads or scalar access for tail elements, but the suggested change fixes the logical error in reduction.