Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 130 additions & 57 deletions include/flashinfer/topk.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -194,23 +194,45 @@ struct RadixRowState {
* \param tx Thread index within the block
*/
template <uint32_t BLOCK_THREADS>
__device__ __forceinline__ void RadixSuffixSum(uint32_t* suffix_sum, uint32_t tx) {
__device__ __forceinline__ void RadixSuffixSum(uint32_t* suffix_sum, uint32_t* scratch,
uint32_t tx) {
constexpr uint32_t RADIX = 256;
// Parallel suffix sum: compute count of elements >= each bucket
for (uint32_t stride = 1; stride < RADIX; stride *= 2) {
uint32_t val = 0;
if (tx < RADIX) {
val = suffix_sum[tx];
if (tx + stride < RADIX) {
val += suffix_sum[tx + stride];
}
// Warp-parallel suffix sum: 256 bins across 8 warps, 3 __syncthreads (was 16)
uint32_t val = 0;
const unsigned lane = tx & 31;
const unsigned warp_id = tx >> 5;

if (tx < RADIX) {
val = suffix_sum[tx];
// Warp-level suffix sum using shuffle-down
#pragma unroll
for (uint32_t stride = 1; stride < 32; stride *= 2) {
uint32_t n = __shfl_down_sync(0xFFFFFFFF, val, stride);
if (lane + stride < 32) val += n;
}
__syncthreads();
if (tx < RADIX) {
suffix_sum[tx] = val;
// Lane 0 of each warp has the sum of that warp's 32 bins
if (lane == 0) {
scratch[warp_id] = val;
}
}
__syncthreads();

// Sequential suffix sum of 8 warp totals
if (tx == 0) {
for (int i = 6; i >= 0; --i) {
scratch[i] += scratch[i + 1];
}
__syncthreads();
}
__syncthreads();

// Add cross-warp correction and write result
if (tx < RADIX) {
if (warp_id < 7) {
val += scratch[warp_id + 1];
}
suffix_sum[tx] = val;
}
__syncthreads();
}

/*!
Expand Down Expand Up @@ -383,8 +405,8 @@ __device__ __forceinline__ void RadixSelectOneRound(
}
__syncthreads();

// Compute suffix sum
RadixSuffixSum<BLOCK_THREADS>(suffix_sum, tx);
// Compute suffix sum (local_histogram is free to use as scratch here)
RadixSuffixSum<BLOCK_THREADS>(suffix_sum, local_histogram, tx);

// Find threshold bucket using shared_scalars for found_bucket and found_remaining_k
// shared_scalars[0] = found_bucket, shared_scalars[1] = found_remaining_k
Expand Down Expand Up @@ -420,9 +442,14 @@ __device__ __forceinline__ void LoadToSharedOrdered(const DType* input,
using OrderedType = typename Traits::OrderedType;
vec_t<DType, VEC_SIZE> input_vec;
const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE;
constexpr uint32_t STRIDE = BLOCK_THREADS * VEC_SIZE;

#pragma unroll 2
for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) {
for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += STRIDE) {
// Prefetch the data two strides ahead of the current load into L2
if (i + 2 * STRIDE < aligned_size) {
asm volatile("prefetch.global.L2 [%0];" ::"l"(input + chunk_start + i + 2 * STRIDE));
}
input_vec.cast_load(input + chunk_start + i);
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
Expand Down Expand Up @@ -573,7 +600,7 @@ __device__ __forceinline__ OrderedType RadixSelectFromSharedMemory(
__syncthreads();

// Compute suffix sum
RadixSuffixSum<BLOCK_THREADS>(suffix_sum, tx);
RadixSuffixSum<BLOCK_THREADS>(suffix_sum, local_histogram, tx);

// Find threshold bucket
if (tx == 0) {
Expand Down Expand Up @@ -2055,9 +2082,14 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
const int aligned_length = (length / VEC_SIZE) * VEC_SIZE;
// Full-row scan helper (vectorized body + tail). Overflow fallback reuses this traversal.
auto for_each_score_full = [&](auto&& fn) {
// vectorized body
// vectorized body with L2 prefetch 2 strides ahead for better latency hiding
constexpr int STRIDE = BLOCK_SIZE * VEC_SIZE;
#pragma unroll 2
for (int base = tx * VEC_SIZE; base < aligned_length; base += BLOCK_SIZE * VEC_SIZE) {
for (int base = tx * VEC_SIZE; base < aligned_length; base += STRIDE) {
// Prefetch 2 chunks ahead to L2 — gives more time for prefetch to complete
if (base + 2 * STRIDE < aligned_length) {
asm volatile("prefetch.global.L2 [%0];" ::"l"(&score[base + 2 * STRIDE]));
}
score_vec.cast_load(&score[base]);
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
Expand All @@ -2076,21 +2108,44 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
for_each_score_full(accumulate_coarse_hist);
__syncthreads();

// Suffix sum (Hillis Steele Scan)
// Suffix sum: warp-parallel approach using shuffle-down
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The warp-level suffix sum code here is duplicated from above. Would it be possible to refactor them into a single function?

// Reduces __syncthreads from 8 to 3 per invocation (called 5+ times for fp32)
const auto run_cumsum = [&]() {
#pragma unroll 8
for (int i = 0; i < 8; ++i) {
if (tx < RADIX) {
const auto j = 1 << i;
const auto k = i & 1;
auto value = s_histogram_buf[k][tx];
if (tx < RADIX - j) {
value += s_histogram_buf[k][tx + j];
}
s_histogram_buf[k ^ 1][tx] = value;
int val = 0;
const unsigned lane = tx & 31;
const unsigned warp_id = tx >> 5;

if (tx < RADIX) {
val = s_histogram[tx];
// Warp-level suffix sum (32 elements per warp, 8 warps for 256 bins)
#pragma unroll
for (int stride = 1; stride < 32; stride *= 2) {
int n = __shfl_down_sync(0xFFFFFFFF, val, stride);
if (lane + stride < 32) val += n;
}
__syncthreads();
// Lane 0 of each warp has the sum of that warp's 32 bins
if (lane == 0) {
s_histogram_buf[1][warp_id] = val;
}
}
__syncthreads();

// Suffix sum of 8 warp totals (thread 0 only, trivial cost)
if (tx == 0) {
for (int i = 6; i >= 0; --i) {
s_histogram_buf[1][i] += s_histogram_buf[1][i + 1];
}
}
__syncthreads();

// Add cross-warp correction and write final result
if (tx < RADIX) {
if (warp_id < 7) {
val += s_histogram_buf[1][warp_id + 1];
}
s_histogram[tx] = val;
}
__syncthreads();
};
auto update_refine_threshold = [&](int next_input_idx, auto reset_next_input_tag) {
constexpr bool RESET_NEXT_INPUT = decltype(reset_next_input_tag)::value;
Expand Down Expand Up @@ -2124,33 +2179,42 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
// Collect indices where bin > threshold
auto collect_coarse_gt = [&](auto raw_input, int index) {
const auto bin = static_cast<int>(Traits::ToCoarseKey(raw_input));
if (bin > threshold_bin) {
if (__builtin_expect(bin > threshold_bin, 0)) {
const auto pos = atomicAdd(&s_counter, 1);
s_indices[pos] = index;
}
};
for_each_score_full(collect_coarse_gt);
__syncthreads();
} else {
__syncthreads();
if (tx < RADIX + 1) s_histogram[tx] = 0;
__syncthreads();

// Filter + histogram for refinement
// Use first half of each input buffer for indices, second half for ordered values.
// This avoids re-reading global memory in refine rounds at the cost of halved capacity.
constexpr int HALF_BUF = SMEM_INPUT_SIZE / 2;
auto filter_and_add_to_histogram = [&](auto raw_input, int index) {
const auto bin = static_cast<int>(Traits::ToCoarseKey(raw_input));
if (bin > threshold_bin) {
const auto pos = atomicAdd(&s_counter, 1);
s_indices[pos] = index;
} else if (bin == threshold_bin) {
const auto pos = atomicAdd(&s_num_input[0], 1);
if (__builtin_expect(pos < SMEM_INPUT_SIZE, 1)) {
s_input_idx[0][pos] = index;
const auto ordered = Traits::ToOrdered(raw_input);
const auto sub_bin = (ordered >> FIRST_SHIFT) & 0xFF;
atomicAdd(&s_histogram[sub_bin], 1);
if (__builtin_expect(bin >= threshold_bin, 0)) {
if (bin > threshold_bin) {
const auto pos = atomicAdd(&s_counter, 1);
s_indices[pos] = index;
} else {
atomicOr(&s_refine_overflow, 1);
const auto pos = atomicAdd(&s_num_input[0], 1);
if (__builtin_expect(pos < HALF_BUF, 1)) {
const auto ordered = Traits::ToOrdered(raw_input);
s_input_idx[0][pos] = index;
s_input_idx[0][HALF_BUF + pos] = static_cast<int>(ordered);
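// For fp32: build the first refine histogram at bit 16 rather than FIRST_SHIFT, skipping
// the redundant round 0, on the assumption that the coarse pass (fp16 high 8 bits) already
// determines ordered bits 24-31.
// NOTE(review): presumably this only holds for values in the normal fp16 range; inputs that
// round to fp16 zero/subnormal or saturate to inf can share a coarse bin while differing in
// bits 24-31 — confirm before relying on the skipped round.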
constexpr int EFFECTIVE_FIRST_SHIFT =
(FIRST_SHIFT >= 16) ? FIRST_SHIFT - 8 : FIRST_SHIFT;
const auto sub_bin = (ordered >> EFFECTIVE_FIRST_SHIFT) & 0xFF;
atomicAdd(&s_histogram[sub_bin], 1);
} else {
atomicOr(&s_refine_overflow, 1);
}
}
}
};
Expand All @@ -2165,8 +2229,9 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
int threshold) {
for (int i = tx; i < num_input; i += BLOCK_SIZE) {
const auto idx = s_input_idx[r_idx][i];
const auto raw_input = score[idx];
const auto bin = (Traits::ToOrdered(raw_input) >> offset) & 0xFF;
const auto ordered = static_cast<typename Traits::OrderedType>(
static_cast<uint32_t>(s_input_idx[r_idx][HALF_BUF + i]));
const auto bin = (ordered >> offset) & 0xFF;
if (static_cast<int>(bin) > threshold) {
const auto pos = atomicAdd(&s_counter, 1);
s_indices[pos] = idx;
Expand All @@ -2187,17 +2252,17 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
__syncthreads();
for (int i = tx; i < num_input; i += BLOCK_SIZE) {
const auto idx = s_input_idx[r_idx][i];
const auto raw_input = score[idx];
const auto bin = (Traits::ToOrdered(raw_input) >> offset) & 0xFF;
const auto ordered = static_cast<uint32_t>(s_input_idx[r_idx][HALF_BUF + i]);
const auto bin = (ordered >> offset) & 0xFF;
if (static_cast<int>(bin) > threshold) {
const auto pos = atomicAdd(&s_counter, 1);
s_indices[pos] = idx;
} else if (static_cast<int>(bin) == threshold) {
const auto pos = atomicAdd(&s_num_input[next_r_idx], 1);
if (__builtin_expect(pos < SMEM_INPUT_SIZE, 1)) {
if (__builtin_expect(pos < HALF_BUF, 1)) {
s_input_idx[next_r_idx][pos] = idx;
const auto bin32 = Traits::ToOrdered(raw_input);
const auto sub_bin = (bin32 >> (offset - 8)) & 0xFF;
s_input_idx[next_r_idx][HALF_BUF + pos] = static_cast<int>(ordered);
const auto sub_bin = (ordered >> (offset - 8)) & 0xFF;
atomicAdd(&s_histogram[sub_bin], 1);
} else {
atomicOr(&s_refine_overflow, 1);
Expand All @@ -2209,7 +2274,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
auto run_refine_round = [&](int r_idx, int offset, auto is_last_round_tag) {
constexpr bool IS_LAST_ROUND = decltype(is_last_round_tag)::value;
const auto raw_num_input = s_num_input[r_idx];
const auto num_input = (raw_num_input < SMEM_INPUT_SIZE) ? raw_num_input : SMEM_INPUT_SIZE;
const auto num_input = (raw_num_input < HALF_BUF) ? raw_num_input : HALF_BUF;

update_refine_threshold(r_idx ^ 1, std::true_type{});

Expand All @@ -2220,7 +2285,8 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
// Final round reached: only collect bins strictly greater than threshold.
for (int i = tx; i < num_input; i += BLOCK_SIZE) {
const auto idx = s_input_idx[r_idx][i];
const auto bin = (Traits::ToOrdered(score[idx]) >> offset) & 0xFF;
const auto ordered = static_cast<uint32_t>(s_input_idx[r_idx][HALF_BUF + i]);
const auto bin = (ordered >> offset) & 0xFF;
if (static_cast<int>(bin) > threshold) {
const auto pos = atomicAdd(&s_counter, 1);
s_indices[pos] = idx;
Expand Down Expand Up @@ -2297,12 +2363,19 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
// Multi-round refine path (float32): if any refine-buffer overflow is detected,
// switch to a correctness-first full rebuild of the threshold-bin selection.
// This fallback may be slower than the fast path, but avoids partial-state corruption.
//
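// For fp32 (FIRST_SHIFT >= 16): the filter_and_add_to_histogram step already builds the
// histogram at FIRST_SHIFT-8 (bit 16) instead of FIRST_SHIFT (bit 24), on the assumption
// that the coarse pass (fp16 high 8 bits) carries the same information as ordered bits
// 24-31. Refine round 0 is therefore skipped: the loop below runs EFFECTIVE_NUM_ROUNDS
// rounds, the first at offset FIRST_SHIFT - 8, with the buffer index taken from eff_round.
// NOTE(review): presumably the fp16 coarse bin fixes bits 24-31 only for values in the
// normal fp16 range; values that become fp16 zero/subnormal/inf can share a bin yet differ
// in those bits — confirm inputs are constrained, or keep round 0 for such bins.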
constexpr int SKIP_ROUNDS = (FIRST_SHIFT >= 16) ? 1 : 0;
constexpr int EFFECTIVE_NUM_ROUNDS = NUM_ROUNDS - SKIP_ROUNDS;
if (!s_refine_overflow) {
#pragma unroll
for (int round = 0; round < NUM_ROUNDS; ++round) {
const auto r_idx = round % 2;
const int offset = FIRST_SHIFT - round * 8;
if (round == NUM_ROUNDS - 1) {
for (int eff_round = 0; eff_round < EFFECTIVE_NUM_ROUNDS; ++eff_round) {
const auto r_idx = eff_round % 2;
const int offset = FIRST_SHIFT - 8 * SKIP_ROUNDS - eff_round * 8;
if (eff_round == EFFECTIVE_NUM_ROUNDS - 1) {
if (run_refine_round(r_idx, offset, std::true_type{})) {
break;
}
Expand Down
Loading