Skip to content

Commit e4ea76d

Browse files
committed
Optimize helper function
1 parent 55ea787 commit e4ea76d

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

include/flashinfer/sampling.cuh

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -249,27 +249,31 @@ __device__ __forceinline__ std::tuple<float, float> GetMinMaxValue(float* in_dat
249249
TempStorage& temp_storage) {
250250
const uint32_t tx = threadIdx.x;
251251
vec_t<float, VEC_SIZE> in_data_vec;
252-
float max_val = -cuda::std::numeric_limits<float>::infinity(),
253-
min_val = cuda::std::numeric_limits<float>::infinity();
252+
// Thread-local min/max accumulation (deferred reduction)
253+
float thread_max = -cuda::std::numeric_limits<float>::infinity();
254+
float thread_min = cuda::std::numeric_limits<float>::infinity();
255+
254256
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
255257
in_data_vec.fill(0);
256258
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
257259
in_data_vec.cast_load(in_data + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
258260
}
259-
float in_data_[VEC_SIZE];
260261
#pragma unroll
261262
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
262-
in_data_[j] = in_data_vec[j];
263+
thread_max = max(thread_max, static_cast<float>(in_data_vec[j]));
264+
thread_min = min(thread_min, static_cast<float>(in_data_vec[j]));
263265
}
264-
max_val = max(
265-
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
266-
.Reduce<VEC_SIZE>(in_data_, MaxReduceOp{}));
267-
__syncthreads();
268-
min_val = min(
269-
min_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
270-
.Reduce<VEC_SIZE>(in_data_, MinReduceOp{}));
271-
__syncthreads();
272266
}
267+
268+
// Single block reduction after loop completes
269+
float max_val =
270+
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
271+
.Reduce(thread_max, MaxReduceOp{});
272+
__syncthreads();
273+
float min_val =
274+
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
275+
.Reduce(thread_min, MinReduceOp{});
276+
273277
if (tx == 0) {
274278
temp_storage.max_val = max_val;
275279
temp_storage.min_val = min_val;
@@ -288,22 +292,23 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
288292
const uint32_t tx = threadIdx.x;
289293
vec_t<float, VEC_SIZE> in_data_vec;
290294

291-
float max_val = 0;
295+
// Thread-local max accumulation (deferred reduction)
296+
float thread_max = 0.0f;
292297
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
293298
in_data_vec.fill(0);
294299
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
295300
in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
296301
}
297-
float in_data_[VEC_SIZE];
298302
#pragma unroll
299303
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
300-
in_data_[j] = in_data_vec[j];
304+
thread_max = max(thread_max, static_cast<float>(in_data_vec[j]));
301305
}
302-
max_val = max(
303-
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
304-
.template Reduce<VEC_SIZE>(in_data_, MaxReduceOp{}));
305-
__syncthreads();
306306
}
307+
308+
// Single block reduction after loop completes
309+
float max_val =
310+
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
311+
.Reduce(thread_max, MaxReduceOp{});
307312
if (tx == 0) {
308313
temp_storage.max_val = max_val;
309314
}

0 commit comments

Comments (0)