@@ -249,27 +249,31 @@ __device__ __forceinline__ std::tuple<float, float> GetMinMaxValue(float* in_dat
                                                     TempStorage& temp_storage) {
   const uint32_t tx = threadIdx.x;
   vec_t<float, VEC_SIZE> in_data_vec;
-  float max_val = -cuda::std::numeric_limits<float>::infinity(),
-        min_val = cuda::std::numeric_limits<float>::infinity();
+  // Thread-local min/max accumulation (deferred reduction)
+  float thread_max = -cuda::std::numeric_limits<float>::infinity();
+  float thread_min = cuda::std::numeric_limits<float>::infinity();
+
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     in_data_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
       in_data_vec.cast_load(in_data + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
     }
-    float in_data_[VEC_SIZE];
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-      in_data_[j] = in_data_vec[j];
+      thread_max = max(thread_max, static_cast<float>(in_data_vec[j]));
+      thread_min = min(thread_min, static_cast<float>(in_data_vec[j]));
     }
-    max_val = max(
-        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                     .Reduce<VEC_SIZE>(in_data_, MaxReduceOp{}));
-    __syncthreads();
-    min_val = min(
-        min_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                     .Reduce<VEC_SIZE>(in_data_, MinReduceOp{}));
-    __syncthreads();
   }
+
+  // Single block reduction after loop completes
+  float max_val =
+      BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+          .Reduce(thread_max, MaxReduceOp{});
+  __syncthreads();
+  float min_val =
+      BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+          .Reduce(thread_min, MinReduceOp{});
+
   if (tx == 0) {
     temp_storage.max_val = max_val;
     temp_storage.min_val = min_val;
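
For readers skimming the hunk above: the change replaces a per-iteration BlockReduce (two block-wide reductions and two __syncthreads() on every loop trip) with register-resident thread-local accumulation, followed by a single pair of block reductions after the loop. Below is a minimal standalone sketch of that deferred-reduction pattern using cub::BlockReduce directly; the kernel name, launch shape, and buffer names are illustrative assumptions, not code from this PR.

#include <cub/cub.cuh>
#include <cfloat>
#include <cstdint>

// Hypothetical standalone kernel, one block per row; a sketch of the
// deferred-reduction pattern, not this PR's actual code.
template <uint32_t BLOCK_THREADS>
__global__ void RowMinMaxKernel(const float* in, float* out_min, float* out_max, uint32_t d) {
  using BlockReduceT = cub::BlockReduce<float, BLOCK_THREADS>;
  __shared__ typename BlockReduceT::TempStorage temp_storage;

  const uint32_t row = blockIdx.x;
  // Phase 1: fold each thread's strided slice of the row into registers.
  float thread_max = -FLT_MAX;
  float thread_min = FLT_MAX;
  for (uint32_t i = threadIdx.x; i < d; i += BLOCK_THREADS) {
    const float v = in[row * d + i];
    thread_max = max(thread_max, v);
    thread_min = min(thread_min, v);
  }
  // Phase 2: one block-wide reduction per statistic, not one per loop iteration.
  const float row_max = BlockReduceT(temp_storage).Reduce(thread_max, cub::Max());
  __syncthreads();  // required before reusing the same temp_storage
  const float row_min = BlockReduceT(temp_storage).Reduce(thread_min, cub::Min());
  if (threadIdx.x == 0) {  // only thread 0 holds the valid aggregate
    out_max[row] = row_max;
    out_min[row] = row_min;
  }
}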
@@ -288,22 +292,23 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
   const uint32_t tx = threadIdx.x;
   vec_t<float, VEC_SIZE> in_data_vec;
 
-  float max_val = 0;
+  // Thread-local max accumulation (deferred reduction)
+  float thread_max = 0.0f;
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     in_data_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
       in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
     }
-    float in_data_[VEC_SIZE];
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-      in_data_[j] = in_data_vec[j];
+      thread_max = max(thread_max, static_cast<float>(in_data_vec[j]));
     }
-    max_val = max(
-        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                     .template Reduce<VEC_SIZE>(in_data_, MaxReduceOp{}));
-    __syncthreads();
   }
+
+  // Single block reduction after loop completes
+  float max_val =
+      BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+          .Reduce(thread_max, MaxReduceOp{});
   if (tx == 0) {
     temp_storage.max_val = max_val;
   }
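
GetMaxValue gets the same treatment, with 0.0f as the accumulator identity (the inputs here are presumably non-negative, e.g. probabilities). For completeness, a hypothetical host-side launch of the sketch above, purely for illustration:

// Hypothetical launch for the sketch kernel; d_in, d_min, d_max are device
// buffers and num_rows/d describe the matrix shape (assumed names, not PR code).
constexpr uint32_t kThreads = 128;
RowMinMaxKernel<kThreads><<<num_rows, kThreads>>>(d_in, d_min, d_max, d);
cudaDeviceSynchronize();  // real code would use stream-ordered synchronization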