From 2129b0f32ae4e05aefe3f6ebb2af8ab09bd511d3 Mon Sep 17 00:00:00 2001 From: wangqixiang9 Date: Mon, 30 Mar 2026 17:37:08 +0800 Subject: [PATCH 1/2] perf: optimize top-k kernel for fp32 --- include/flashinfer/topk.cuh | 186 +++++++++++++++++++++++++----------- 1 file changed, 129 insertions(+), 57 deletions(-) diff --git a/include/flashinfer/topk.cuh b/include/flashinfer/topk.cuh index 468ea55495..4982bceaa0 100644 --- a/include/flashinfer/topk.cuh +++ b/include/flashinfer/topk.cuh @@ -194,23 +194,45 @@ struct RadixRowState { * \param tx Thread index within the block */ template -__device__ __forceinline__ void RadixSuffixSum(uint32_t* suffix_sum, uint32_t tx) { +__device__ __forceinline__ void RadixSuffixSum(uint32_t* suffix_sum, uint32_t* scratch, + uint32_t tx) { constexpr uint32_t RADIX = 256; - // Parallel suffix sum: compute count of elements >= each bucket - for (uint32_t stride = 1; stride < RADIX; stride *= 2) { - uint32_t val = 0; - if (tx < RADIX) { - val = suffix_sum[tx]; - if (tx + stride < RADIX) { - val += suffix_sum[tx + stride]; - } + // Warp-parallel suffix sum: 256 bins across 8 warps, 3 __syncthreads (was 16) + uint32_t val = 0; + const unsigned lane = tx & 31; + const unsigned warp_id = tx >> 5; + + if (tx < RADIX) { + val = suffix_sum[tx]; + // Warp-level suffix sum using shuffle-down +#pragma unroll + for (uint32_t stride = 1; stride < 32; stride *= 2) { + uint32_t n = __shfl_down_sync(0xFFFFFFFF, val, stride); + if (lane + stride < 32) val += n; } - __syncthreads(); - if (tx < RADIX) { - suffix_sum[tx] = val; + // Lane 0 of each warp has the sum of that warp's 32 bins + if (lane == 0) { + scratch[warp_id] = val; + } + } + __syncthreads(); + + // Sequential suffix sum of 8 warp totals + if (tx == 0) { + for (int i = 6; i >= 0; --i) { + scratch[i] += scratch[i + 1]; } - __syncthreads(); } + __syncthreads(); + + // Add cross-warp correction and write result + if (tx < RADIX) { + if (warp_id < 7) { + val += scratch[warp_id + 1]; + } + suffix_sum[tx] = val; + } + __syncthreads(); } /*! @@ -383,8 +405,8 @@ __device__ __forceinline__ void RadixSelectOneRound( } __syncthreads(); - // Compute suffix sum - RadixSuffixSum(suffix_sum, tx); + // Compute suffix sum (local_histogram is free to use as scratch here) + RadixSuffixSum(suffix_sum, local_histogram, tx); // Find threshold bucket using shared_scalars for found_bucket and found_remaining_k // shared_scalars[0] = found_bucket, shared_scalars[1] = found_remaining_k @@ -420,9 +442,14 @@ __device__ __forceinline__ void LoadToSharedOrdered(const DType* input, using OrderedType = typename Traits::OrderedType; vec_t input_vec; const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE; + constexpr uint32_t STRIDE = BLOCK_THREADS * VEC_SIZE; #pragma unroll 2 - for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += BLOCK_THREADS * VEC_SIZE) { + for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += STRIDE) { + // Prefetch next chunk to L2 + if (i + 2 * STRIDE < aligned_size) { + asm volatile("prefetch.global.L2 [%0];" :: "l"(input + chunk_start + i + 2 * STRIDE)); + } input_vec.cast_load(input + chunk_start + i); #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { @@ -573,7 +600,7 @@ __device__ __forceinline__ OrderedType RadixSelectFromSharedMemory( __syncthreads(); // Compute suffix sum - RadixSuffixSum(suffix_sum, tx); + RadixSuffixSum(suffix_sum, local_histogram, tx); // Find threshold bucket if (tx == 0) { @@ -2055,9 +2082,14 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) const int aligned_length = (length / VEC_SIZE) * VEC_SIZE; // Full-row scan helper (vectorized body + tail). Overflow fallback reuses this traversal. auto for_each_score_full = [&](auto&& fn) { - // vectorized body + // vectorized body with L2 prefetch 2 strides ahead for better latency hiding + constexpr int STRIDE = BLOCK_SIZE * VEC_SIZE; #pragma unroll 2 - for (int base = tx * VEC_SIZE; base < aligned_length; base += BLOCK_SIZE * VEC_SIZE) { + for (int base = tx * VEC_SIZE; base < aligned_length; base += STRIDE) { + // Prefetch 2 chunks ahead to L2 — gives more time for prefetch to complete + if (base + 2 * STRIDE < aligned_length) { + asm volatile("prefetch.global.L2 [%0];" :: "l"(&score[base + 2 * STRIDE])); + } score_vec.cast_load(&score[base]); #pragma unroll for (int j = 0; j < VEC_SIZE; ++j) { @@ -2076,21 +2108,44 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) for_each_score_full(accumulate_coarse_hist); __syncthreads(); - // Suffix sum (Hillis Steele Scan) + // Suffix sum: warp-parallel approach using shuffle-down + // Reduces __syncthreads from 8 to 3 per invocation (called 5+ times for fp32) const auto run_cumsum = [&]() { -#pragma unroll 8 - for (int i = 0; i < 8; ++i) { - if (tx < RADIX) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = s_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += s_histogram_buf[k][tx + j]; - } - s_histogram_buf[k ^ 1][tx] = value; + int val = 0; + const unsigned lane = tx & 31; + const unsigned warp_id = tx >> 5; + + if (tx < RADIX) { + val = s_histogram[tx]; + // Warp-level suffix sum (32 elements per warp, 8 warps for 256 bins) +#pragma unroll + for (int stride = 1; stride < 32; stride *= 2) { + int n = __shfl_down_sync(0xFFFFFFFF, val, stride); + if (lane + stride < 32) val += n; } - __syncthreads(); + // Lane 0 of each warp has the sum of that warp's 32 bins + if (lane == 0) { + s_histogram_buf[1][warp_id] = val; + } + } + __syncthreads(); + + // Suffix sum of 8 warp totals (thread 0 only, trivial cost) + if (tx == 0) { + for (int i = 6; i >= 0; --i) { + s_histogram_buf[1][i] += s_histogram_buf[1][i + 1]; + } + } + __syncthreads(); + + // Add cross-warp correction and write final result + if (tx < RADIX) { + if (warp_id < 7) { + val += s_histogram_buf[1][warp_id + 1]; + } + s_histogram[tx] = val; } + __syncthreads(); }; auto update_refine_threshold = [&](int next_input_idx, auto reset_next_input_tag) { constexpr bool RESET_NEXT_INPUT = decltype(reset_next_input_tag)::value; @@ -2124,7 +2179,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) // Collect indices where bin > threshold auto collect_coarse_gt = [&](auto raw_input, int index) { const auto bin = static_cast(Traits::ToCoarseKey(raw_input)); - if (bin > threshold_bin) { + if (__builtin_expect(bin > threshold_bin, 0)) { const auto pos = atomicAdd(&s_counter, 1); s_indices[pos] = index; } @@ -2132,25 +2187,33 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) for_each_score_full(collect_coarse_gt); __syncthreads(); } else { - __syncthreads(); if (tx < RADIX + 1) s_histogram[tx] = 0; __syncthreads(); // Filter + histogram for refinement + // Use first half of each input buffer for indices, second half for ordered values. + // This avoids re-reading global memory in refine rounds at the cost of halved capacity. + constexpr int HALF_BUF = SMEM_INPUT_SIZE / 2; auto filter_and_add_to_histogram = [&](auto raw_input, int index) { const auto bin = static_cast(Traits::ToCoarseKey(raw_input)); - if (bin > threshold_bin) { - const auto pos = atomicAdd(&s_counter, 1); - s_indices[pos] = index; - } else if (bin == threshold_bin) { - const auto pos = atomicAdd(&s_num_input[0], 1); - if (__builtin_expect(pos < SMEM_INPUT_SIZE, 1)) { - s_input_idx[0][pos] = index; - const auto ordered = Traits::ToOrdered(raw_input); - const auto sub_bin = (ordered >> FIRST_SHIFT) & 0xFF; - atomicAdd(&s_histogram[sub_bin], 1); + if (__builtin_expect(bin >= threshold_bin, 0)) { + if (bin > threshold_bin) { + const auto pos = atomicAdd(&s_counter, 1); + s_indices[pos] = index; } else { - atomicOr(&s_refine_overflow, 1); + const auto pos = atomicAdd(&s_num_input[0], 1); + if (__builtin_expect(pos < HALF_BUF, 1)) { + const auto ordered = Traits::ToOrdered(raw_input); + s_input_idx[0][pos] = index; + s_input_idx[0][HALF_BUF + pos] = static_cast(ordered); + // For fp32: coarse pass (fp16 high 8 bits) already captures bits 24-31 info, + // so start refinement at bit 16 to skip redundant round 0. + constexpr int EFFECTIVE_FIRST_SHIFT = (FIRST_SHIFT >= 16) ? FIRST_SHIFT - 8 : FIRST_SHIFT; + const auto sub_bin = (ordered >> EFFECTIVE_FIRST_SHIFT) & 0xFF; + atomicAdd(&s_histogram[sub_bin], 1); + } else { + atomicOr(&s_refine_overflow, 1); + } } } }; @@ -2165,8 +2228,9 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) int threshold) { for (int i = tx; i < num_input; i += BLOCK_SIZE) { const auto idx = s_input_idx[r_idx][i]; - const auto raw_input = score[idx]; - const auto bin = (Traits::ToOrdered(raw_input) >> offset) & 0xFF; + const auto ordered = static_cast( + static_cast(s_input_idx[r_idx][HALF_BUF + i])); + const auto bin = (ordered >> offset) & 0xFF; if (static_cast(bin) > threshold) { const auto pos = atomicAdd(&s_counter, 1); s_indices[pos] = idx; @@ -2187,17 +2251,17 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) __syncthreads(); for (int i = tx; i < num_input; i += BLOCK_SIZE) { const auto idx = s_input_idx[r_idx][i]; - const auto raw_input = score[idx]; - const auto bin = (Traits::ToOrdered(raw_input) >> offset) & 0xFF; + const auto ordered = static_cast(s_input_idx[r_idx][HALF_BUF + i]); + const auto bin = (ordered >> offset) & 0xFF; if (static_cast(bin) > threshold) { const auto pos = atomicAdd(&s_counter, 1); s_indices[pos] = idx; } else if (static_cast(bin) == threshold) { const auto pos = atomicAdd(&s_num_input[next_r_idx], 1); - if (__builtin_expect(pos < SMEM_INPUT_SIZE, 1)) { + if (__builtin_expect(pos < HALF_BUF, 1)) { s_input_idx[next_r_idx][pos] = idx; - const auto bin32 = Traits::ToOrdered(raw_input); - const auto sub_bin = (bin32 >> (offset - 8)) & 0xFF; + s_input_idx[next_r_idx][HALF_BUF + pos] = static_cast(ordered); + const auto sub_bin = (ordered >> (offset - 8)) & 0xFF; atomicAdd(&s_histogram[sub_bin], 1); } else { atomicOr(&s_refine_overflow, 1); @@ -2209,7 +2273,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) auto run_refine_round = [&](int r_idx, int offset, auto is_last_round_tag) { constexpr bool IS_LAST_ROUND = decltype(is_last_round_tag)::value; const auto raw_num_input = s_num_input[r_idx]; - const auto num_input = (raw_num_input < SMEM_INPUT_SIZE) ? raw_num_input : SMEM_INPUT_SIZE; + const auto num_input = (raw_num_input < HALF_BUF) ? raw_num_input : HALF_BUF; update_refine_threshold(r_idx ^ 1, std::true_type{}); @@ -2220,7 +2284,8 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) // Final round reached: only collect bins strictly greater than threshold. for (int i = tx; i < num_input; i += BLOCK_SIZE) { const auto idx = s_input_idx[r_idx][i]; - const auto bin = (Traits::ToOrdered(score[idx]) >> offset) & 0xFF; + const auto ordered = static_cast(s_input_idx[r_idx][HALF_BUF + i]); + const auto bin = (ordered >> offset) & 0xFF; if (static_cast(bin) > threshold) { const auto pos = atomicAdd(&s_counter, 1); s_indices[pos] = idx; @@ -2297,12 +2362,19 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) // Multi-round refine path (float32): if any refine-buffer overflow is detected, // switch to a correctness-first full rebuild of the threshold-bin selection. // This fallback may be slower than the fast path, but avoids partial-state corruption. + // + // For fp32: the filter_and_add_to_histogram step already builds the histogram at + // FIRST_SHIFT-8 (bit 16) instead of FIRST_SHIFT (bit 24), because the coarse pass + // (fp16 high 8 bits) captures the same information as bits 24-31. So we skip + // refine round 0 and start from round 1 with adjusted buffer indexing. + constexpr int SKIP_ROUNDS = (FIRST_SHIFT >= 16) ? 1 : 0; + constexpr int EFFECTIVE_NUM_ROUNDS = NUM_ROUNDS - SKIP_ROUNDS; if (!s_refine_overflow) { #pragma unroll - for (int round = 0; round < NUM_ROUNDS; ++round) { - const auto r_idx = round % 2; - const int offset = FIRST_SHIFT - round * 8; - if (round == NUM_ROUNDS - 1) { + for (int eff_round = 0; eff_round < EFFECTIVE_NUM_ROUNDS; ++eff_round) { + const auto r_idx = eff_round % 2; + const int offset = FIRST_SHIFT - 8 * SKIP_ROUNDS - eff_round * 8; + if (eff_round == EFFECTIVE_NUM_ROUNDS - 1) { if (run_refine_round(r_idx, offset, std::true_type{})) { break; } From dc854cb6c059a45301de82a787005679983a825a Mon Sep 17 00:00:00 2001 From: wangqixiang9 Date: Mon, 30 Mar 2026 18:06:26 +0800 Subject: [PATCH 2/2] perf: optimize top-k kernel for fp32 --- include/flashinfer/topk.cuh | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/include/flashinfer/topk.cuh b/include/flashinfer/topk.cuh index 4982bceaa0..b755494738 100644 --- a/include/flashinfer/topk.cuh +++ b/include/flashinfer/topk.cuh @@ -448,7 +448,7 @@ __device__ __forceinline__ void LoadToSharedOrdered(const DType* input, for (uint32_t i = tx * VEC_SIZE; i < aligned_size; i += STRIDE) { // Prefetch next chunk to L2 if (i + 2 * STRIDE < aligned_size) { - asm volatile("prefetch.global.L2 [%0];" :: "l"(input + chunk_start + i + 2 * STRIDE)); + asm volatile("prefetch.global.L2 [%0];" ::"l"(input + chunk_start + i + 2 * STRIDE)); } input_vec.cast_load(input + chunk_start + i); #pragma unroll @@ -2082,13 +2082,13 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) const int aligned_length = (length / VEC_SIZE) * VEC_SIZE; // Full-row scan helper (vectorized body + tail). Overflow fallback reuses this traversal. auto for_each_score_full = [&](auto&& fn) { - // vectorized body with L2 prefetch 2 strides ahead for better latency hiding - constexpr int STRIDE = BLOCK_SIZE * VEC_SIZE; + // vectorized body with L2 prefetch 2 strides ahead for better latency hiding + constexpr int STRIDE = BLOCK_SIZE * VEC_SIZE; #pragma unroll 2 for (int base = tx * VEC_SIZE; base < aligned_length; base += STRIDE) { // Prefetch 2 chunks ahead to L2 — gives more time for prefetch to complete if (base + 2 * STRIDE < aligned_length) { - asm volatile("prefetch.global.L2 [%0];" :: "l"(&score[base + 2 * STRIDE])); + asm volatile("prefetch.global.L2 [%0];" ::"l"(&score[base + 2 * STRIDE])); } score_vec.cast_load(&score[base]); #pragma unroll @@ -2208,7 +2208,8 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) s_input_idx[0][HALF_BUF + pos] = static_cast(ordered); // For fp32: coarse pass (fp16 high 8 bits) already captures bits 24-31 info, // so start refinement at bit 16 to skip redundant round 0. - constexpr int EFFECTIVE_FIRST_SHIFT = (FIRST_SHIFT >= 16) ? FIRST_SHIFT - 8 : FIRST_SHIFT; + constexpr int EFFECTIVE_FIRST_SHIFT = + (FIRST_SHIFT >= 16) ? FIRST_SHIFT - 8 : FIRST_SHIFT; const auto sub_bin = (ordered >> EFFECTIVE_FIRST_SHIFT) & 0xFF; atomicAdd(&s_histogram[sub_bin], 1); } else {