Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions include/flashinfer/sampling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -249,27 +249,31 @@ __device__ __forceinline__ std::tuple<float, float> GetMinMaxValue(float* in_dat
TempStorage& temp_storage) {
const uint32_t tx = threadIdx.x;
vec_t<float, VEC_SIZE> in_data_vec;
float max_val = -cuda::std::numeric_limits<float>::infinity(),
min_val = cuda::std::numeric_limits<float>::infinity();
// Thread-local min/max accumulation (deferred reduction)
float thread_max = -cuda::std::numeric_limits<float>::infinity();
float thread_min = cuda::std::numeric_limits<float>::infinity();

for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
in_data_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
in_data_vec.cast_load(in_data + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
float in_data_[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
in_data_[j] = in_data_vec[j];
thread_max = max(thread_max, static_cast<float>(in_data_vec[j]));
thread_min = min(thread_min, static_cast<float>(in_data_vec[j]));
}
max_val = max(
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce<VEC_SIZE>(in_data_, MaxReduceOp{}));
__syncthreads();
min_val = min(
min_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce<VEC_SIZE>(in_data_, MinReduceOp{}));
__syncthreads();
}
Comment on lines 256 to 266
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

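The reduction loop has correctness issues with out-of-bounds data handling. Because `in_data_vec` is zero-filled before the guarded load and the accumulation runs unconditionally, threads whose vector lies entirely out of bounds still fold the padded zeros into `thread_max`/`thread_min`, and for vectors that are only partially out of bounds the tail elements past `d` are not excluded. This can lead to incorrect min/max values.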

The accumulation should only happen for valid data. The following change corrects the logic:

Please note that `cast_load` may still read out of bounds when a vector starts in bounds but extends past `d`, unless the input is padded. A fully robust solution might require masked loads or scalar access for the tail elements; the suggested change below fixes only the logical error in the reduction.

  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    const uint32_t base_offset = (i * BLOCK_THREADS + tx) * VEC_SIZE;
    if (base_offset < d) {
      in_data_vec.cast_load(in_data + row_idx * d + base_offset);
#pragma unroll
      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
        if (base_offset + j < d) {
          thread_max = max(thread_max, in_data_vec[j]);
          thread_min = min(thread_min, in_data_vec[j]);
        }
      }
    }
  }


// Single block reduction after loop completes
float max_val =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce(thread_max, MaxReduceOp{});
__syncthreads();
float min_val =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce(thread_min, MinReduceOp{});

if (tx == 0) {
temp_storage.max_val = max_val;
temp_storage.min_val = min_val;
Expand All @@ -288,22 +292,23 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
const uint32_t tx = threadIdx.x;
vec_t<float, VEC_SIZE> in_data_vec;

float max_val = 0;
// Thread-local max accumulation (deferred reduction)
float thread_max = 0.0f;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
in_data_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}
float in_data_[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
in_data_[j] = in_data_vec[j];
thread_max = max(thread_max, static_cast<float>(in_data_vec[j]));
}
max_val = max(
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.template Reduce<VEC_SIZE>(in_data_, MaxReduceOp{}));
__syncthreads();
}
Comment on lines +295 to 306
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This function has two correctness issues:

  1. Incorrect Initialization: thread_max is initialized to 0.0f on line 296. If all input values are negative, this will cause the function to incorrectly return 0 instead of the true maximum. It should be initialized to -infinity.

  2. Faulty Reduction Logic: The loop structure from lines 297-306 suffers from the same issues as in GetMinMaxValue. It processes padded zeros for out-of-bounds threads and doesn't handle tail elements correctly, leading to incorrect results.

Here is a suggested fix that addresses both issues:

Please note that `cast_load` may still read out of bounds when a vector starts in bounds but extends past `d`, unless the input is padded. A fully robust solution might require masked loads or scalar access for the tail elements; the suggested change below fixes only the initialization and the logical error in the reduction.

  // Thread-local max accumulation (deferred reduction)
  float thread_max = -cuda::std::numeric_limits<float>::infinity();
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    const uint32_t base_offset = (i * BLOCK_THREADS + tx) * VEC_SIZE;
    if (base_offset < d) {
      in_data_vec.cast_load(in_data + row_idx * d + base_offset);
#pragma unroll
      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
        if (base_offset + j < d) {
          thread_max = max(thread_max, in_data_vec[j]);
        }
      }
    }
  }


// Single block reduction after loop completes
float max_val =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce(thread_max, MaxReduceOp{});
if (tx == 0) {
temp_storage.max_val = max_val;
}
Expand Down