From e4ea76d0c1783f43b4e87df5bd23bab340f7357e Mon Sep 17 00:00:00 2001
From: Brian Ryu
Date: Fri, 7 Nov 2025 00:26:00 +0000
Subject: [PATCH] Optimize helper function

---
 include/flashinfer/sampling.cuh | 43 ++++++++++++++++++---------------
 1 file changed, 24 insertions(+), 19 deletions(-)

diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh
index 6b134630cf..2952d88d0b 100644
--- a/include/flashinfer/sampling.cuh
+++ b/include/flashinfer/sampling.cuh
@@ -249,27 +249,31 @@ __device__ __forceinline__ std::tuple<float, float> GetMinMaxValue(float* in_dat
                                                                    TempStorage& temp_storage) {
   const uint32_t tx = threadIdx.x;
   vec_t<float, VEC_SIZE> in_data_vec;
-  float max_val = -cuda::std::numeric_limits<float>::infinity(),
-        min_val = cuda::std::numeric_limits<float>::infinity();
+  // Thread-local min/max accumulation (deferred reduction)
+  float thread_max = -cuda::std::numeric_limits<float>::infinity();
+  float thread_min = cuda::std::numeric_limits<float>::infinity();
+
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     in_data_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
       in_data_vec.cast_load(in_data + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
     }
-    float in_data_[VEC_SIZE];
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-      in_data_[j] = in_data_vec[j];
+      thread_max = max(thread_max, static_cast<float>(in_data_vec[j]));
+      thread_min = min(thread_min, static_cast<float>(in_data_vec[j]));
     }
-    max_val = max(
-        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                     .Reduce<VEC_SIZE>(in_data_, MaxReduceOp{}));
-    __syncthreads();
-    min_val = min(
-        min_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                     .Reduce<VEC_SIZE>(in_data_, MinReduceOp{}));
-    __syncthreads();
   }
+
+  // Single block reduction after loop completes
+  float max_val =
+      BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+          .Reduce(thread_max, MaxReduceOp{});
+  __syncthreads();
+  float min_val =
+      BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+          .Reduce(thread_min, MinReduceOp{});
+
   if (tx == 0) {
     temp_storage.max_val = max_val;
     temp_storage.min_val = min_val;
@@ -288,22 +292,23 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
                                              TempStorage& temp_storage) {
   const uint32_t tx = threadIdx.x;
   vec_t<float, VEC_SIZE> in_data_vec;
-  float max_val = 0;
+  // Thread-local max accumulation (deferred reduction)
+  float thread_max = 0.0f;
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
     in_data_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
       in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
     }
-    float in_data_[VEC_SIZE];
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; ++j) {
-      in_data_[j] = in_data_vec[j];
+      thread_max = max(thread_max, static_cast<float>(in_data_vec[j]));
     }
-    max_val = max(
-        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
-                     .template Reduce<VEC_SIZE>(in_data_, MaxReduceOp{}));
-    __syncthreads();
   }
+
+  // Single block reduction after loop completes
+  float max_val =
+      BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+          .Reduce(thread_max, MaxReduceOp{});
   if (tx == 0) {
     temp_storage.max_val = max_val;
   }