diff --git a/.buildkite/test_areas/kernels.yaml b/.buildkite/test_areas/kernels.yaml index a05ee886f5b9..362c733c8d73 100644 --- a/.buildkite/test_areas/kernels.yaml +++ b/.buildkite/test_areas/kernels.yaml @@ -18,10 +18,9 @@ steps: source_file_dependencies: - csrc/ - tests/kernels/core - - tests/kernels/test_top_k_per_row.py - tests/kernels/test_concat_mla_q.py commands: - - pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py + - pytest -v -s kernels/core kernels/test_concat_mla_q.py - label: Kernels Attention Test %N timeout_in_minutes: 35 @@ -107,6 +106,7 @@ steps: - vllm/v1/attention/backends/mla/flashinfer_mla.py - vllm/v1/attention/selector.py - vllm/platforms/cuda.py + - tests/kernels/test_top_k_per_row.py commands: - nvidia-smi - python3 examples/basic/offline_inference/chat.py @@ -117,6 +117,7 @@ steps: - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py - pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py + - pytest -v -s tests/kernels/test_top_k_per_row.py # Quantization - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py diff --git a/csrc/ops.h b/csrc/ops.h index cc58422231ff..686aa5196603 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -114,9 +114,9 @@ void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n, int64_t numRows, int64_t stride0, int64_t stride1, int64_t topK); -void large_context_topk(const torch::Tensor& score, torch::Tensor& indices, - const torch::Tensor& lengths, - std::optional row_starts_opt); +void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, + torch::Tensor& output, torch::Tensor& workspace, int64_t k, + int64_t max_seq_len); void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, torch::Tensor& scale, diff --git 
a/csrc/persistent_topk.cuh b/csrc/persistent_topk.cuh new file mode 100644 index 000000000000..694fedad39f1 --- /dev/null +++ b/csrc/persistent_topk.cuh @@ -0,0 +1,1321 @@ +/* + * Persistent TopK Scheduler for DSA Indexer + */ + +#ifndef PERSISTENT_TOPK_CUH_ +#define PERSISTENT_TOPK_CUH_ + +#include +#include +#include +#include +#include + +namespace vllm { +namespace persistent { + +// ============================================================================ +// Constants +// ============================================================================ + +constexpr int TopK = 2048; +constexpr int kThreadsPerBlock = 1024; +constexpr int RADIX = 256; + +// Medium path: all shared state in dynamic smem (no static __shared__, +// which would inflate the kernel's smem footprint and kill occupancy +// for the decode/trivial paths). +constexpr size_t kMediumHistBytes = 2 * (RADIX + 128) * sizeof(int); // 3072 +constexpr size_t kMediumScalarsBytes = 5 * sizeof(int); // 20 +constexpr size_t kMediumHeaderSize = + (kMediumHistBytes + kMediumScalarsBytes + 127) & ~size_t(127); // 3200 +constexpr int MAX_BUFFERED_ITEMS = 4096; +constexpr size_t kSmemMedium = + kMediumHeaderSize + 2 * MAX_BUFFERED_ITEMS * sizeof(int); // 35968 +constexpr uint32_t RADIX_THRESHOLD = 32768; + +// Decode path constants +constexpr int kDecodeBins = 2048; +constexpr uint32_t HIST2048_THRESHOLD = 8192; + +// Large path: fixed shared memory for histograms + scalars +constexpr size_t kFixedSmemLarge = + ((RADIX + RADIX + 5) * sizeof(uint32_t) + 15) & ~size_t(15); + +// ============================================================================ +// Common helpers +// ============================================================================ + +__device__ __forceinline__ auto convert_to_uint32_v2(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? 
+      ~bits : (bits | 0x80000000u);
+}
+
+__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
+  __half h = __float2half_rn(x);
+  uint16_t bits = __half_as_ushort(h);
+  uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits)
+                                 : static_cast<uint16_t>(bits | 0x8000);
+  return static_cast<uint8_t>(key >> 8);
+}
+
+// ============================================================================
+// Vectorized load helpers
+// ============================================================================
+
+// Unconditional float4 load with cache hint (.cg = cache at global level only).
+__device__ __forceinline__ void load_float4(const float* ptr, float& v0,
+                                            float& v1, float& v2, float& v3) {
+  uint32_t r0, r1, r2, r3;
+  asm volatile("ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%4];\n"
+               : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
+               : "l"(ptr));
+  v0 = __uint_as_float(r0);
+  v1 = __uint_as_float(r1);
+  v2 = __uint_as_float(r2);
+  v3 = __uint_as_float(r3);
+}
+
+// Per-element predicated scalar loads with -inf default.
+__device__ __forceinline__ void load_float4_predicated(const float* ptr, + int base, int seq_len, + float& v0, float& v1, + float& v2, float& v3) { + uint32_t r0, r1, r2, r3; + int p0 = (base < seq_len); + int p1 = (base + 1 < seq_len); + int p2 = (base + 2 < seq_len); + int p3 = (base + 3 < seq_len); + asm volatile( + "{\n" + " .reg .pred pr0, pr1, pr2, pr3;\n" + " setp.ne.u32 pr0, %4, 0;\n" + " setp.ne.u32 pr1, %5, 0;\n" + " setp.ne.u32 pr2, %6, 0;\n" + " setp.ne.u32 pr3, %7, 0;\n" + " mov.u32 %0, 0xFF800000;\n" + " mov.u32 %1, 0xFF800000;\n" + " mov.u32 %2, 0xFF800000;\n" + " mov.u32 %3, 0xFF800000;\n" + " @pr0 ld.global.cg.u32 %0, [%8];\n" + " @pr1 ld.global.cg.u32 %1, [%8+4];\n" + " @pr2 ld.global.cg.u32 %2, [%8+8];\n" + " @pr3 ld.global.cg.u32 %3, [%8+12];\n" + "}\n" + : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) + : "r"(p0), "r"(p1), "r"(p2), "r"(p3), "l"(ptr)); + v0 = __uint_as_float(r0); + v1 = __uint_as_float(r1); + v2 = __uint_as_float(r2); + v3 = __uint_as_float(r3); +} + +// ============================================================================ +// Large path: inter-CTA coordination state (one per group) +// ============================================================================ + +struct RadixRowState { + uint32_t histogram[3][256]; // Triple-buffered histograms + uint32_t remaining_k; + uint32_t prefix; + int arrival_counter; + int output_counter; +}; + +// ============================================================================ +// Kernel parameters +// ============================================================================ + +struct PersistentTopKParams { + const float* __restrict__ input; // [num_rows, stride] + int32_t* __restrict__ output; // [num_rows, TopK] + int32_t* __restrict__ lengths; // [num_rows] + RadixRowState* row_states; // large path: per-group state + uint32_t num_rows; + uint32_t stride; + uint32_t chunk_size; // large path: elements per CTA + uint32_t ctas_per_group; // 1=medium, >1=large + uint32_t max_seq_len; 
// max seq_len across all rows (for early CTA exit) +}; + +// ============================================================================ +// Decode path: 2048-bin histogram for short sequences (seq_len <= 8192) +// Uses 11-bit half-precision bins for fine granularity. +// One histogram pass typically suffices since 8192/2048 = 4 elements/bin avg. +// ============================================================================ + +// 11-bit bin from half-precision representation (ascending: high values -> high +// bins) +__device__ __forceinline__ uint32_t decode_bin(float x) { + __half hx = __float2half(x); + uint16_t bits = __half_as_ushort(hx); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) + : static_cast(bits | 0x8000); + return key >> 5; +} + +__device__ __noinline__ void histogram_2048_topk( + const float* __restrict__ logits, int32_t* __restrict__ output_indices, + int32_t seq_len) { + extern __shared__ int decode_smem[]; + const int tx = threadIdx.x; + const int lane = tx & 31; + + // ---- Layout constants ---- + constexpr int SBASE = 8192 - 8; // 8184 + constexpr int RHIST = RADIX + 128; // 384 + constexpr int BOFF = 2 * RHIST; // 768 + constexpr int DBUF = (SBASE - BOFF) / 2; // 3708 + constexpr int MAX_ITEMS_PER_THREAD = + (HIST2048_THRESHOLD + kThreadsPerBlock - 1) / kThreadsPerBlock; + + enum : int { sTHR = 0, sOUT = 1, sREF = 2, sFIN = 3, sBUF0 = 4, sBUF1 = 5 }; + + // ---- Initialize scalars (prevents stale data from prior rows) ---- + if (tx < 8) { + decode_smem[SBASE + tx] = 0; + } + + // ---- Phase 1: Build 2048-bin histogram with float4 vectorized loads ---- + int* histo = decode_smem; + uint16_t reg_bins[MAX_ITEMS_PER_THREAD]; + int nitems = 0; + + for (int i = tx; i < kDecodeBins; i += kThreadsPerBlock) { + histo[i] = 0; + } + __syncthreads(); + + const int n_vec = (seq_len + 3) >> 2; + const bool row_aligned = ((reinterpret_cast(logits) & 15) == 0); + + for (int i = tx; i < n_vec; i += kThreadsPerBlock) { + const int base = i << 2; + 
float v0, v1, v2, v3; + + if (row_aligned && base + 3 < seq_len) { + load_float4(logits + base, v0, v1, v2, v3); + } else { + load_float4_predicated(logits + base, base, seq_len, v0, v1, v2, v3); + } + + const uint16_t b0 = static_cast(decode_bin(v0)); + const uint16_t b1 = static_cast(decode_bin(v1)); + const uint16_t b2 = static_cast(decode_bin(v2)); + const uint16_t b3 = static_cast(decode_bin(v3)); + reg_bins[nitems++] = b0; + reg_bins[nitems++] = b1; + reg_bins[nitems++] = b2; + reg_bins[nitems++] = b3; + atomicAdd(&histo[b0], 1); + atomicAdd(&histo[b1], 1); + atomicAdd(&histo[b2], 1); + atomicAdd(&histo[b3], 1); + } + __syncthreads(); + + // ---- CUB suffix sum ---- + using BlockScanT = cub::BlockScan; + const int h0 = histo[2 * tx]; + const int pair_sum = h0 + histo[2 * tx + 1]; + + auto& scan_storage = *reinterpret_cast( + decode_smem + kDecodeBins); + + int pair_prefix, total; + BlockScanT(scan_storage).ExclusiveSum(pair_sum, pair_prefix, total); + + // Find threshold bin purely from registers + const int pair_suffix = total - pair_prefix; + + if (pair_suffix >= TopK && (pair_suffix - h0) < TopK) { + decode_smem[SBASE + sTHR] = 2 * tx; + } + { + const int right_suf = pair_suffix - h0; + const int next_suf = pair_suffix - pair_sum; + if (right_suf >= TopK && next_suf < TopK) { + decode_smem[SBASE + sTHR] = 2 * tx + 1; + } + } + __syncthreads(); + + const int threshold = decode_smem[SBASE + sTHR]; + + // ---- Phase 2: Collection with warp-aggregated atomicAdds ---- + int* bufs[2] = {decode_smem + BOFF, decode_smem + BOFF + DBUF}; + const int sOUT_abs = SBASE + sOUT; + const int sBUF0_abs = SBASE + sBUF0; + + { + const uint32_t uthr = static_cast(threshold); + int item = 0; + const int n_vec_iters = (n_vec + kThreadsPerBlock - 1) / kThreadsPerBlock; + + for (int iter = 0; iter < n_vec_iters; iter++) { + const int i = tx + iter * kThreadsPerBlock; + const bool vec_valid = (i < n_vec); + const int base_idx = i << 2; + +#pragma unroll 4 + for (int sub = 0; sub < 
4; sub++) { + const int elem_idx = base_idx + sub; + uint32_t bin = 0; + if (vec_valid) bin = reg_bins[item++]; + const bool is_above = vec_valid && (bin > uthr); + const bool is_equal = vec_valid && (bin == uthr); + + const uint32_t above_mask = __ballot_sync(0xffffffff, is_above); + if (above_mask) { + const int above_count = __popc(above_mask); + const int above_rank = __popc(above_mask & ((1u << lane) - 1)); + int above_base; + if (lane == 0) { + above_base = atomicAdd(&decode_smem[sOUT_abs], above_count); + } + above_base = __shfl_sync(0xffffffff, above_base, 0); + if (is_above) { + output_indices[above_base + above_rank] = elem_idx; + } + } + + const uint32_t equal_mask = __ballot_sync(0xffffffff, is_equal); + if (equal_mask) { + const int equal_count = __popc(equal_mask); + const int equal_rank = __popc(equal_mask & ((1u << lane) - 1)); + int equal_base; + if (lane == 0) { + equal_base = atomicAdd(&decode_smem[sBUF0_abs], equal_count); + } + equal_base = __shfl_sync(0xffffffff, equal_base, 0); + if (is_equal && __builtin_expect(equal_base + equal_rank < DBUF, 1)) { + bufs[0][equal_base + equal_rank] = elem_idx; + } + } + } + } + } + __syncthreads(); + + int remaining_k = TopK - decode_smem[SBASE + sOUT]; + if (remaining_k <= 0) return; + + // If all buffered elements fit, output them all (common for short seqs) + const int raw_buf0 = decode_smem[SBASE + sBUF0]; + if (raw_buf0 <= remaining_k) { + const int nb = (raw_buf0 < DBUF) ? raw_buf0 : DBUF; + const int base = decode_smem[SBASE + sOUT]; + for (int i = tx; i < nb; i += kThreadsPerBlock) { + output_indices[base + i] = bufs[0][i]; + } + __syncthreads(); + return; + } + + // ---- Phase 3: Deferred refinement (rare path) ---- + int* refine[2] = {decode_smem, decode_smem + RHIST}; + const int num_buf0 = (raw_buf0 < DBUF) ? 
raw_buf0 : DBUF; + + for (int i = tx; i < RHIST; i += kThreadsPerBlock) { + refine[0][i] = 0; + } + __syncthreads(); + + for (int i = tx; i < num_buf0; i += kThreadsPerBlock) { + const uint32_t fp32 = convert_to_uint32_v2(logits[bufs[0][i]]); + atomicAdd(&refine[0][(fp32 >> 24) & 0xFF], 1); + } + __syncthreads(); + + auto compute_suffix_sum = [&]() { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + if (tx < RADIX) { + const int stride = 1 << i; + const int s = i & 1; + const int d = s ^ 1; + int value = refine[s][tx]; + if (tx < RADIX - stride) value += refine[s][tx + stride]; + refine[d][tx] = value; + } + __syncthreads(); + } + }; + +#pragma unroll 4 + for (int pass = 0; pass < 4; ++pass) { + const int src = pass & 1; + const int dst = src ^ 1; + + const int raw_buf = decode_smem[SBASE + sBUF0 + src]; + const int num_buffered = (raw_buf < DBUF) ? raw_buf : DBUF; + + compute_suffix_sum(); + + if (tx < RADIX && refine[0][tx] > remaining_k && + refine[0][tx + 1] <= remaining_k) { + decode_smem[SBASE + sREF] = tx; + decode_smem[SBASE + sBUF0 + dst] = 0; + decode_smem[SBASE + sFIN] = remaining_k - refine[0][tx + 1]; + } + __syncthreads(); + + const int ref_thr = decode_smem[SBASE + sREF]; + remaining_k -= refine[0][ref_thr + 1]; + const int bit_offset = 24 - pass * 8; + + if (remaining_k == 0) { + for (int i = tx; i < num_buffered; i += kThreadsPerBlock) { + const int idx = bufs[src][i]; + const uint32_t fp32 = convert_to_uint32_v2(logits[idx]); + if (((fp32 >> bit_offset) & 0xFF) > static_cast(ref_thr)) { + const int pos = atomicAdd(&decode_smem[SBASE + sOUT], 1); + output_indices[pos] = idx; + } + } + __syncthreads(); + break; + } + + __syncthreads(); + if (tx < RADIX + 1) refine[0][tx] = 0; + __syncthreads(); + + for (int i = tx; i < num_buffered; i += kThreadsPerBlock) { + const int idx = bufs[src][i]; + const float logit_val = logits[idx]; + const uint32_t fp32 = convert_to_uint32_v2(logit_val); + const int bin = (fp32 >> bit_offset) & 0xFF; + + if (bin > 
ref_thr) { + const int pos = atomicAdd(&decode_smem[SBASE + sOUT], 1); + output_indices[pos] = idx; + } else if (bin == ref_thr) { + if (pass == 3) { + const int slot = atomicAdd(&decode_smem[SBASE + sFIN], -1); + if (slot > 0) output_indices[TopK - slot] = idx; + } else { + const int bp = atomicAdd(&decode_smem[SBASE + sBUF0 + dst], 1); + if (__builtin_expect(bp < DBUF, 1)) { + bufs[dst][bp] = idx; + const int nbo = bit_offset - 8; + atomicAdd(&refine[0][(fp32 >> nbo) & 0xFF], 1); + } + } + } + } + __syncthreads(); + } +} + +// ============================================================================ +// Medium path: coarse FP16 histogram + 4-pass FP32 radix refinement +// For sequences 8K < seq_len <= 64K. +// ============================================================================ + +// Adapted from: +// https://github.com/sgl-project/sglang/blob/v0.5.8/sgl-kernel/csrc/elementwise/topk.cu#L87 +// by: DarkSharpness +// which at the same time is an optimized topk kernel copied from tilelang +// kernel +__device__ __noinline__ void histogram_256_topk( + const float* __restrict__ logits, int* __restrict__ output_indices, + int logits_offset, int seq_len) { + // All shared state lives in dynamic shared memory to avoid static + extern __shared__ char medium_smem[]; + + int (*shared_histogram)[RADIX + 128] = + reinterpret_cast(medium_smem); + int* medium_scalars = reinterpret_cast(medium_smem + kMediumHistBytes); + int& shared_output_count = medium_scalars[0]; + int& shared_threshold_bin = medium_scalars[1]; + int* shared_buffered_count = &medium_scalars[2]; + int& shared_final_k = medium_scalars[4]; + int (*buffered_indices)[MAX_BUFFERED_ITEMS] = + reinterpret_cast(medium_smem + + kMediumHeaderSize); + + const int thread_id = threadIdx.x; + int remaining_k = TopK; + + if (thread_id < RADIX + 1) { + shared_histogram[0][thread_id] = 0; + } + __syncthreads(); + + for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) { + const auto bin = 
convert_to_uint8(logits[idx + logits_offset]); + atomicAdd(&shared_histogram[0][bin], 1); + } + __syncthreads(); + + auto compute_cumulative_sum = [&]() { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + if (__builtin_expect(thread_id < RADIX, 1)) { + const int stride = 1 << i; + const int src_buffer = i & 1; + const int dst_buffer = src_buffer ^ 1; + int value = shared_histogram[src_buffer][thread_id]; + if (thread_id < RADIX - stride) { + value += shared_histogram[src_buffer][thread_id + stride]; + } + shared_histogram[dst_buffer][thread_id] = value; + } + __syncthreads(); + } + }; + + compute_cumulative_sum(); + + if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k && + shared_histogram[0][thread_id + 1] <= remaining_k) { + shared_threshold_bin = thread_id; + shared_buffered_count[0] = 0; + shared_output_count = 0; + } + __syncthreads(); + + const int threshold_bin = shared_threshold_bin; + remaining_k -= shared_histogram[0][threshold_bin + 1]; + + if (remaining_k == 0) { + for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) { + const int bin = convert_to_uint8(logits[idx + logits_offset]); + if (bin > threshold_bin) { + const int output_pos = atomicAdd(&shared_output_count, 1); + output_indices[output_pos] = idx; + } + } + __syncthreads(); + return; + } + + __syncthreads(); + if (thread_id < RADIX + 1) { + shared_histogram[0][thread_id] = 0; + } + __syncthreads(); + + for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) { + const float logit_value = logits[idx + logits_offset]; + const int bin = convert_to_uint8(logit_value); + if (bin > threshold_bin) { + const int output_pos = atomicAdd(&shared_output_count, 1); + output_indices[output_pos] = idx; + } else if (bin == threshold_bin) { + const int buffer_pos = atomicAdd(&shared_buffered_count[0], 1); + if (__builtin_expect(buffer_pos < MAX_BUFFERED_ITEMS, 1)) { + buffered_indices[0][buffer_pos] = idx; + const uint32_t fp32_bits = convert_to_uint32_v2(logit_value); 
+ const int next_bin = (fp32_bits >> 24) & 0xFF; + atomicAdd(&shared_histogram[0][next_bin], 1); + } + } + } + __syncthreads(); + +#pragma unroll 4 + for (int pass = 0; pass < 4; ++pass) { + const int src_buffer = pass % 2; + const int dst_buffer = src_buffer ^ 1; + const int raw_buffered = shared_buffered_count[src_buffer]; + const int num_buffered = + (raw_buffered < MAX_BUFFERED_ITEMS) ? raw_buffered : MAX_BUFFERED_ITEMS; + + compute_cumulative_sum(); + + if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k && + shared_histogram[0][thread_id + 1] <= remaining_k) { + shared_threshold_bin = thread_id; + shared_buffered_count[dst_buffer] = 0; + shared_final_k = remaining_k - shared_histogram[0][thread_id + 1]; + } + __syncthreads(); + + const int threshold_bin = shared_threshold_bin; + remaining_k -= shared_histogram[0][threshold_bin + 1]; + const int bit_offset = 24 - pass * 8; + + if (remaining_k == 0) { + for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) { + const int idx = buffered_indices[src_buffer][i]; + const uint32_t fp32_bits = + convert_to_uint32_v2(logits[idx + logits_offset]); + const int bin = (fp32_bits >> bit_offset) & 0xFF; + if (bin > threshold_bin) { + const int output_pos = atomicAdd(&shared_output_count, 1); + output_indices[output_pos] = idx; + } + } + __syncthreads(); + break; + } + + __syncthreads(); + if (thread_id < RADIX + 1) { + shared_histogram[0][thread_id] = 0; + } + __syncthreads(); + + for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) { + const int idx = buffered_indices[src_buffer][i]; + const float logit_value = logits[idx + logits_offset]; + const uint32_t fp32_bits = convert_to_uint32_v2(logit_value); + const int bin = (fp32_bits >> bit_offset) & 0xFF; + if (bin > threshold_bin) { + const int output_pos = atomicAdd(&shared_output_count, 1); + output_indices[output_pos] = idx; + } else if (bin == threshold_bin) { + if (pass == 3) { + const int slot = atomicAdd(&shared_final_k, -1); + 
if (slot > 0) { + output_indices[TopK - slot] = idx; + } + } else { + const int buffer_pos = + atomicAdd(&shared_buffered_count[dst_buffer], 1); + if (__builtin_expect(buffer_pos < MAX_BUFFERED_ITEMS, 1)) { + buffered_indices[dst_buffer][buffer_pos] = idx; + const int next_bit_offset = bit_offset - 8; + const int next_bin = (fp32_bits >> next_bit_offset) & 0xFF; + atomicAdd(&shared_histogram[0][next_bin], 1); + } + } + } + } + __syncthreads(); + } +} + +// ============================================================================ +// Inter-CTA sync primitives +// ============================================================================ + +__device__ __forceinline__ int ld_acquire(int* ptr) { + int state = 0; +#if (__CUDA_ARCH__ >= 700) + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(ptr)); +#else + asm volatile("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); +#endif + return state; +} + +__device__ __forceinline__ void red_release(int* ptr, int val) { +#if (__CUDA_ARCH__ >= 700) + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(ptr), "r"(val)); +#else + __threadfence(); + atomicAdd(ptr, val); +#endif +} + +__device__ __forceinline__ void st_release(int* ptr, int val) { +#if (__CUDA_ARCH__ >= 700) + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("st.release.gpu.global.b32 [%0], %1;\n" : : "l"(ptr), "r"(val)); +#else + __threadfence(); + atomicExch(ptr, val); +#endif +} + +__device__ __forceinline__ void wait_ge(int* ptr, int target_val, + int thread_idx) { + if (thread_idx == 0) { +#pragma unroll 1 + while (ld_acquire(ptr) < target_val) { + } + } + __syncthreads(); +} + +// ============================================================================ +// Large path: multi-CTA radix select for sequences > 64K +// +// Each row is processed by a group of CTAs. 
Each CTA loads its chunk into +// shared memory as ordered uint32, then participates in 4 rounds of +// coordinated radix select via global-memory histograms and barriers. +// ============================================================================ + +// ============================================================================ +// Multi-CTA cooperative RadixTopK for a single large row. +// Adapted from https://github.com/flashinfer-ai/flashinfer/pull/2215 +// ============================================================================ + +template +__device__ void radix_topk(const float* __restrict__ row_input, + int32_t* __restrict__ row_output, uint32_t seq_len, + uint32_t my_chunk_start, uint32_t chunk_size, + uint32_t* local_histogram, uint32_t* suffix_sum, + uint32_t* shared_scalars, uint32_t* shared_ordered, + RadixRowState* state, uint32_t cta_in_group, + uint32_t ctas_per_group, int& barrier_phase, + uint32_t iter, uint32_t tx) { + const uint32_t my_chunk_end = (my_chunk_start + chunk_size < seq_len) + ? my_chunk_start + chunk_size + : seq_len; + const uint32_t actual_chunk_size = + (my_chunk_start < seq_len) ? 
(my_chunk_end - my_chunk_start) : 0; + + // -- Stage 1: Load chunk to shared memory as ordered uint32 -- + { + const uint32_t aligned_size = (actual_chunk_size / VEC_SIZE) * VEC_SIZE; + + for (uint32_t i = tx * VEC_SIZE; i < aligned_size; + i += kThreadsPerBlock * VEC_SIZE) { + const float* src = row_input + my_chunk_start + i; + if constexpr (VEC_SIZE == 4) { + float4 v = *reinterpret_cast(src); + shared_ordered[i] = convert_to_uint32_v2(v.x); + shared_ordered[i + 1] = convert_to_uint32_v2(v.y); + shared_ordered[i + 2] = convert_to_uint32_v2(v.z); + shared_ordered[i + 3] = convert_to_uint32_v2(v.w); + } else if constexpr (VEC_SIZE == 2) { + float2 v = *reinterpret_cast(src); + shared_ordered[i] = convert_to_uint32_v2(v.x); + shared_ordered[i + 1] = convert_to_uint32_v2(v.y); + } else { + shared_ordered[i] = convert_to_uint32_v2(*src); + } + } + for (uint32_t i = aligned_size + tx; i < actual_chunk_size; + i += kThreadsPerBlock) { + shared_ordered[i] = convert_to_uint32_v2(row_input[my_chunk_start + i]); + } + } + __syncthreads(); + + // -- Init radix select state -- + if (tx == 0) { + shared_scalars[0] = 0; // prefix + shared_scalars[1] = TopK; // remaining_k + } + __syncthreads(); + + // -- Initial barrier -- + if (tx == 0) { + red_release(&state->arrival_counter, 1); + } + wait_ge(&state->arrival_counter, + (barrier_phase + 1) * static_cast(ctas_per_group), tx); + barrier_phase++; + __syncthreads(); + + if (cta_in_group == 0 && tx == 0) { + st_release(&state->output_counter, 0); + } + + // -- Stage 2: 4 rounds of radix select -- + for (uint32_t round = 0; round < 4; round++) { + const uint32_t global_round = iter * 4 + round; + const uint32_t shift = 24 - round * 8; + const uint32_t prefix = shared_scalars[0]; + const uint32_t remaining_k = shared_scalars[1]; + + uint32_t* current_hist = state->histogram[global_round % 3]; + uint32_t* next_hist = state->histogram[(global_round + 1) % 3]; + + for (uint32_t i = tx; i < RADIX; i += kThreadsPerBlock) { + 
+      local_histogram[i] = 0;
+    }
+    __syncthreads();
+
+    for (uint32_t i = tx; i < actual_chunk_size; i += kThreadsPerBlock) {
+      uint32_t ordered = shared_ordered[i];
+      uint32_t mask = (round == 0) ? 0u : (~0u << (32 - round * 8));
+      if ((ordered & mask) == prefix) {
+        uint32_t bucket = (ordered >> shift) & 0xFF;
+        atomicAdd(&local_histogram[bucket], 1);
+      }
+    }
+    __syncthreads();
+
+    for (uint32_t i = tx; i < RADIX; i += kThreadsPerBlock) {
+      if (local_histogram[i] > 0) {
+        atomicAdd(&current_hist[i], local_histogram[i]);
+      }
+    }
+
+    if (cta_in_group == 0) {
+      for (uint32_t i = tx; i < RADIX; i += kThreadsPerBlock) {
+        next_hist[i] = 0;
+      }
+    }
+
+    if (tx == 0) {
+      red_release(&state->arrival_counter, 1);
+    }
+    wait_ge(&state->arrival_counter,
+            (barrier_phase + 1) * static_cast<int>(ctas_per_group), tx);
+    barrier_phase++;
+    __syncthreads();
+
+    for (uint32_t i = tx; i < RADIX; i += kThreadsPerBlock) {
+      suffix_sum[i] = current_hist[i];
+    }
+    __syncthreads();
+
+    for (uint32_t stride = 1; stride < RADIX; stride *= 2) {
+      uint32_t val = 0;
+      if (tx < RADIX) {
+        val = suffix_sum[tx];
+        if (tx + stride < RADIX) val += suffix_sum[tx + stride];
+      }
+      __syncthreads();
+      if (tx < RADIX) suffix_sum[tx] = val;
+      __syncthreads();
+    }
+
+    if (tx == 0) {
+      shared_scalars[2] = 0;
+      shared_scalars[3] = remaining_k;
+    }
+    __syncthreads();
+
+    if (tx < RADIX) {
+      uint32_t count_ge = suffix_sum[tx];
+      uint32_t count_gt = (tx + 1 < RADIX) ?
suffix_sum[tx + 1] : 0; + if (count_ge >= remaining_k && count_gt < remaining_k) { + shared_scalars[2] = tx; + shared_scalars[3] = remaining_k - count_gt; + } + } + __syncthreads(); + + if (tx == 0) { + shared_scalars[0] = prefix | (shared_scalars[2] << shift); + shared_scalars[1] = shared_scalars[3]; + } + __syncthreads(); + } // end 4 radix rounds + + // -- Count local > pivot elements -- + const uint32_t ordered_pivot = shared_scalars[0]; + + if (tx == 0) suffix_sum[0] = 0; + __syncthreads(); + + uint32_t my_gt_count = 0; + for (uint32_t i = tx; i < actual_chunk_size; i += kThreadsPerBlock) { + if (shared_ordered[i] > ordered_pivot) my_gt_count++; + } + for (int offset = 16; offset > 0; offset /= 2) { + my_gt_count += __shfl_down_sync(0xffffffff, my_gt_count, offset); + } + if (tx % 32 == 0 && my_gt_count > 0) { + atomicAdd(&suffix_sum[0], my_gt_count); + } + __syncthreads(); + const uint32_t local_gt_count = suffix_sum[0]; + + // -- Stage 3: Collect top-k indices -- + if (tx == 0) { + local_histogram[0] = 0; + if (local_gt_count > 0) { + local_histogram[1] = + atomicAdd(&state->output_counter, static_cast(local_gt_count)); + } + } + __syncthreads(); + + for (uint32_t i = tx; i < actual_chunk_size; i += kThreadsPerBlock) { + if (shared_ordered[i] > ordered_pivot) { + uint32_t local_pos = atomicAdd(&local_histogram[0], 1); + int pos = static_cast(local_histogram[1]) + local_pos; + row_output[pos] = static_cast(my_chunk_start + i); + } + } + + if (tx == 0) { + red_release(&state->arrival_counter, 1); + } + wait_ge(&state->arrival_counter, + (barrier_phase + 1) * static_cast(ctas_per_group), tx); + barrier_phase++; + __syncthreads(); + + for (uint32_t i = tx; i < actual_chunk_size; i += kThreadsPerBlock) { + if (shared_ordered[i] == ordered_pivot) { + int pos = atomicAdd(&state->output_counter, 1); + if (pos < TopK) { + row_output[pos] = static_cast(my_chunk_start + i); + } + } + } +} + +// 
+// ============================================================================
+// Persistent kernel — BS≤32, decode/medium/large paths with RadixTopK
+// BS>32 uses standalone histogram_256_buffered_topk (separate kernel,
+// see filtered_topk.cuh)
+// ============================================================================
+
+template
+__global__ void __launch_bounds__(kThreadsPerBlock, 2)
+    persistent_topk_kernel(PersistentTopKParams params) {
+  const uint32_t tx = threadIdx.x;
+  extern __shared__ uint8_t smem_raw[];
+
+  // ========================================================================
+  // Group mode: multi-CTA groups with static round-robin row assignment.
+  // Non-large rows: CTA-0 handles trivial/decode/medium.
+  // Large rows: all CTAs in the group cooperate via RadixTopK.
+  // ========================================================================
+  const uint32_t ctas_per_group = params.ctas_per_group;
+  const uint32_t group_id = blockIdx.x / ctas_per_group;
+  const uint32_t cta_in_group = blockIdx.x % ctas_per_group;
+  const uint32_t num_groups = gridDim.x / ctas_per_group;
+  const uint32_t chunk_size = params.chunk_size;
+
+  if (blockIdx.x >= num_groups * ctas_per_group) return;
+
+  // Early exit: non-CTA-0 threads are never needed if no large rows exist
+  if (cta_in_group != 0 && params.max_seq_len <= RADIX_THRESHOLD) return;
+
+  uint32_t* local_histogram = reinterpret_cast<uint32_t*>(smem_raw);
+  uint32_t* suffix_sum = local_histogram + RADIX;
+  uint32_t* shared_scalars = suffix_sum + RADIX;
+  uint32_t* shared_ordered =
+      reinterpret_cast<uint32_t*>(smem_raw + kFixedSmemLarge);
+
+  // RadixRowState for multi-CTA cooperative radix
+  RadixRowState* state = &params.row_states[group_id];
+
+  // -- Initialize RadixRowState (only needed if large rows exist) --
+  if (params.max_seq_len > RADIX_THRESHOLD) {
+    if (cta_in_group == 0) {
+      for (uint32_t buf = 0; buf < 3; buf++) {
+        for (uint32_t i = tx; i < RADIX; i += kThreadsPerBlock) {
+          state->histogram[buf][i] = 0;
+        }
} + if (tx == 0) { + state->remaining_k = 0; + state->prefix = 0; + state->arrival_counter = 0; + state->output_counter = 0; + } + } + __syncthreads(); + } + + int barrier_phase = 0; + const uint32_t total_iters = (params.num_rows + num_groups - 1) / num_groups; + + for (uint32_t iter = 0; iter < total_iters; iter++) { + // Static round-robin: all CTAs in the group implicitly agree on the row + uint32_t row_idx = group_id + iter * num_groups; + if (row_idx >= params.num_rows) break; + + const uint32_t seq_len = params.lengths[row_idx]; + int32_t* row_output = params.output + row_idx * TopK; + const float* row_input = params.input + row_idx * params.stride; + + if (seq_len <= RADIX_THRESHOLD) { + if (cta_in_group == 0) { + if (seq_len <= static_cast(TopK)) { + // Trivial case: seq_len <= TopK + for (uint32_t i = tx; i < static_cast(TopK); + i += kThreadsPerBlock) { + row_output[i] = (i < seq_len) ? static_cast(i) : -1; + } + } else if (seq_len <= static_cast(HIST2048_THRESHOLD)) { + histogram_2048_topk(row_input, row_output, seq_len); + } else { + histogram_256_topk(row_input, row_output, 0, seq_len); + } + } + continue; + } + + const uint32_t my_chunk_start = cta_in_group * chunk_size; + radix_topk(row_input, row_output, seq_len, my_chunk_start, + chunk_size, local_histogram, suffix_sum, + shared_scalars, shared_ordered, state, cta_in_group, + ctas_per_group, barrier_phase, iter, tx); + } +} + +} // namespace persistent + +// ============================================================================ +// FlashInfer FilteredTopK (BS>32 dispatch) — float32 only. +// Extracted from flashinfer_topk.cuh. Lives in namespace vllm (not persistent). +// Adapted from https://github.com/flashinfer-ai/flashinfer/pull/2215 +// ============================================================================ + +#define FLASHINFER_CUDA_CALL(func, ...) 
\ + { \ + cudaError_t e = (func); \ + if (e != cudaSuccess) { \ + return e; \ + } \ + } + +#define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__ + +template +struct vec_t { + T data[N]; + + FLASHINFER_INLINE T& operator[](size_t i) { return data[i]; } + FLASHINFER_INLINE const T& operator[](size_t i) const { return data[i]; } + + FLASHINFER_INLINE void cast_load(const T* ptr) { +#pragma unroll + for (size_t i = 0; i < N; ++i) { + data[i] = ptr[i]; + } + } + + FLASHINFER_INLINE void cast_store(T* ptr) const { +#pragma unroll + for (size_t i = 0; i < N; ++i) { + ptr[i] = data[i]; + } + } +}; +#undef FLASHINFER_INLINE + +// FilteredTopK traits for different data types +template +struct FilteredTopKTraits; + +// Specialization for float (32-bit): coarse histogram uses FP16 high 8 bits, 4 +// refinement rounds +template <> +struct FilteredTopKTraits { + using OrderedType = uint32_t; + static constexpr int NUM_REFINE_ROUNDS = 4; + static constexpr int FIRST_REFINE_SHIFT = 24; + + __device__ __forceinline__ static uint8_t ToCoarseKey(float x) { + // Convert to FP16 representation and extract high 8 bits + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) + : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + __device__ __forceinline__ static OrderedType ToOrdered(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + } +}; + +constexpr uint32_t FILTERED_TOPK_MAX_K = 2048; +constexpr uint32_t FILTERED_TOPK_BLOCK_THREADS = 1024; +constexpr uint32_t FILTERED_TOPK_SMEM_INPUT_SIZE = + 16 * 1024; // 16K indices per buffer +constexpr size_t FILTERED_TOPK_SMEM_DYNAMIC = + sizeof(int) * 2 * FILTERED_TOPK_SMEM_INPUT_SIZE; // 128KB + +/*! + * \brief Filtered Top-K kernel for ragged sequences. 
+ * + * \tparam DType Data type (float, half, nv_bfloat16) + * \tparam IdType Index type (int32_t) + * \tparam VEC_SIZE Vector size for input loads (1, 2, 4, or 8) + */ +template +__global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) + FilteredTopKUnifiedKernel(const DType* __restrict__ input, + IdType* __restrict__ output, + const IdType* __restrict__ lengths, + uint32_t num_rows, uint32_t top_k, + uint32_t max_len) { + constexpr uint32_t BLOCK_SIZE = FILTERED_TOPK_BLOCK_THREADS; + constexpr int RADIX = 256; + constexpr int SMEM_INPUT_SIZE = FILTERED_TOPK_SMEM_INPUT_SIZE; + + const uint32_t bid = blockIdx.x; + const int tx = threadIdx.x; + + if (bid >= num_rows) return; + + const int length = + (lengths != nullptr) ? lengths[bid] : static_cast(max_len); + const DType* score = input + bid * max_len; + IdType* dst = output + bid * top_k; + + // Trivial case: length <= top_k + if (length <= static_cast(top_k)) { + for (int i = tx; i < static_cast(top_k); i += BLOCK_SIZE) { + dst[i] = (i < length) ? 
static_cast(i) : static_cast(-1); + } + return; + } + + // Static shared memory + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + alignas(128) __shared__ int s_indices[FILTERED_TOPK_MAX_K]; + + auto& s_histogram = s_histogram_buf[0]; + + // Dynamic shared memory for input double buffer + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + using Traits = FilteredTopKTraits; + int topk = top_k; + + // Stage 1: 8-bit coarse histogram with vectorized loads + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + vec_t score_vec; + + const int aligned_length = (length / VEC_SIZE) * VEC_SIZE; +#pragma unroll 2 + for (int base = tx * VEC_SIZE; base < aligned_length; + base += BLOCK_SIZE * VEC_SIZE) { + score_vec.cast_load(&score[base]); +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + const auto bin = Traits::ToCoarseKey(score_vec[j]); + atomicAdd(&s_histogram[bin], 1); + } + } + // Handle tail + for (int i = aligned_length + tx; i < length; i += BLOCK_SIZE) { + const auto bin = Traits::ToCoarseKey(score[i]); + atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + // Suffix sum + const auto run_cumsum = [&]() { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + if (tx < RADIX) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + constexpr int NUM_ROUNDS = Traits::NUM_REFINE_ROUNDS; + constexpr int FIRST_SHIFT = Traits::FIRST_REFINE_SHIFT; + + if 
(topk == 0) { + // Collect indices where bin > threshold +#pragma unroll 2 + for (int base = tx * VEC_SIZE; base < aligned_length; + base += BLOCK_SIZE * VEC_SIZE) { + score_vec.cast_load(&score[base]); +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + const auto bin = static_cast(Traits::ToCoarseKey(score_vec[j])); + if (bin > threshold_bin) { + const auto pos = atomicAdd(&s_counter, 1); + s_indices[pos] = base + j; + } + } + } + // Handle tail + for (int i = aligned_length + tx; i < length; i += BLOCK_SIZE) { + const auto bin = static_cast(Traits::ToCoarseKey(score[i])); + if (bin > threshold_bin) { + const auto pos = atomicAdd(&s_counter, 1); + s_indices[pos] = i; + } + } + __syncthreads(); + } else { + __syncthreads(); + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + // Filter + histogram for refinement + auto filter_and_add_to_histogram = [&](auto raw_input, int index) { + const auto bin = static_cast(Traits::ToCoarseKey(raw_input)); + if (bin > threshold_bin) { + const auto pos = atomicAdd(&s_counter, 1); + s_indices[pos] = index; + } else if (bin == threshold_bin) { + const auto pos = atomicAdd(&s_num_input[0], 1); + if (__builtin_expect(pos < SMEM_INPUT_SIZE, 1)) { + s_input_idx[0][pos] = index; + const auto ordered = Traits::ToOrdered(raw_input); + const auto sub_bin = (ordered >> FIRST_SHIFT) & 0xFF; + atomicAdd(&s_histogram[sub_bin], 1); + } + } + }; +#pragma unroll 2 + for (int base = tx * VEC_SIZE; base < aligned_length; + base += BLOCK_SIZE * VEC_SIZE) { + score_vec.cast_load(&score[base]); +#pragma unroll + for (int j = 0; j < VEC_SIZE; ++j) { + filter_and_add_to_histogram(score_vec[j], base + j); + } + } + // Handle tail + for (int i = aligned_length + tx; i < length; i += BLOCK_SIZE) { + filter_and_add_to_histogram(score[i], i); + } + __syncthreads(); + + // Stage 2: refine with 8bit radix passes +#pragma unroll + for (int round = 0; round < NUM_ROUNDS; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 
2; + + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = + (_raw_num_input < SMEM_INPUT_SIZE) ? _raw_num_input : SMEM_INPUT_SIZE; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold = s_threshold_bin_id; + topk -= s_histogram[threshold + 1]; + + const int offset = FIRST_SHIFT - round * 8; + const bool is_last_round = (round == NUM_ROUNDS - 1); + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto bin = (Traits::ToOrdered(score[idx]) >> offset) & 0xFF; + if (static_cast(bin) > threshold) { + const auto pos = atomicAdd(&s_counter, 1); + s_indices[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = score[idx]; + const auto bin = (Traits::ToOrdered(raw_input) >> offset) & 0xFF; + if (static_cast(bin) > threshold) { + const auto pos = atomicAdd(&s_counter, 1); + s_indices[pos] = idx; + } else if (static_cast(bin) == threshold) { + if (is_last_round) { + const auto pos = atomicAdd(&s_last_remain, -1); + if (pos > 0) { + s_indices[top_k - pos] = idx; + } + } else { + const auto pos = atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (__builtin_expect(pos < SMEM_INPUT_SIZE, 1)) { + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin32 = Traits::ToOrdered(raw_input); + const auto sub_bin = (bin32 >> (offset - 8)) & 0xFF; + atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } + } + + // Output phase - mode-specific +#pragma unroll 2 + for (int base = tx; base < static_cast(top_k); base += BLOCK_SIZE) { + const int idx = s_indices[base]; + dst[base] = 
static_cast(idx); + } +} + +// Helper to compute GCD for VEC_SIZE selection +constexpr uint32_t gcd(uint32_t a, uint32_t b) { + while (b != 0) { + uint32_t t = b; + b = a % b; + a = t; + } + return a; +} + +// Compute optimal VEC_SIZE based on max_len and dtype +// Returns 1, 2, 4, or 8 +template +constexpr int ComputeFilteredTopKVecSize(uint32_t max_len) { + constexpr int MAX_VEC = 16 / sizeof(DType); // 4 for float32, 8 for fp16/bf16 + // Use GCD to find largest power-of-2 divisor + const uint32_t g = gcd(max_len, static_cast(MAX_VEC)); + return static_cast(g); +} + +template +cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices, + IdType* lengths, uint32_t num_rows, + uint32_t top_k_val, uint32_t max_len, + cudaStream_t stream = 0) { + constexpr size_t smem_size = FILTERED_TOPK_SMEM_DYNAMIC; + constexpr int MAX_VEC = 16 / sizeof(DType); + + dim3 grid(num_rows); + dim3 block(FILTERED_TOPK_BLOCK_THREADS); + void* args[] = {&input, &output_indices, &lengths, + &num_rows, &top_k_val, &max_len}; + + const int vec_size = ComputeFilteredTopKVecSize(max_len); + +#define DISPATCH_VEC_SIZE(VS) \ + if (vec_size == VS) { \ + auto kernel = FilteredTopKUnifiedKernel; \ + FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( \ + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); \ + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid, block, args, \ + smem_size, stream)); \ + return cudaSuccess; \ + } + + DISPATCH_VEC_SIZE(1) + DISPATCH_VEC_SIZE(2) + DISPATCH_VEC_SIZE(4) + if constexpr (MAX_VEC >= 8) { + DISPATCH_VEC_SIZE(8) + } +#undef DISPATCH_VEC_SIZE + + return cudaSuccess; +} + +} // namespace vllm + +#endif // PERSISTENT_TOPK_CUH_ diff --git a/csrc/topk.cu b/csrc/topk.cu index a7850f5363b9..402b64b027ae 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -1,373 +1,154 @@ -// Portions of this file are adapted from SGLang PR: -// https://github.com/sgl-project/sglang/pull/11194 -// and -// https://github.com/sgl-project/sglang/pull/17747 +// 
Persistent TopK kernel for DeepSeek V3 sparse attention indexer. +// See persistent_topk.cuh for kernel implementation. -#include "cuda_compat.h" -#include "dispatch_utils.h" - -#include -#include +#include +#include +#include +#include #ifndef USE_ROCM - #include -#else - #include + #include "persistent_topk.cuh" #endif -namespace vllm { - -constexpr int TopK = 2048; // DeepSeek V3 sparse attention top-k -constexpr int kThreadsPerBlock = 1024; // Threads per block - -// Shared memory budget -#if defined(USE_ROCM) -constexpr size_t kSmem = 48 * 1024; // ROCm default: 48KB -#else -// Reduced from 128KB to 32KB to improve occupancy. -// Each radix pass needs at most ~TopK candidates in the threshold bin, -// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. -constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) -#endif - -struct FastTopKParams { - const float* __restrict__ input; // [batch, seq_len] Logits - const int32_t* __restrict__ row_starts; // [batch] Offset into each row - // (optional) - int32_t* __restrict__ indices; // [batch, TopK] Output top-k indices - int32_t* __restrict__ lengths; // [batch] Sequence lengths per row - int64_t input_stride; // Stride between rows -}; - -__device__ __forceinline__ auto convert_to_uint32_v2(float x) -> uint32_t { - uint32_t bits = __float_as_uint(x); - return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); -} - -__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { - __half h = __float2half_rn(x); - uint16_t bits = __half_as_ushort(h); - uint16_t key = (bits & 0x8000) ? static_cast(~bits) - : static_cast(bits | 0x8000); - return static_cast(key >> 8); -} - -__device__ void naive_topk_cuda(const float* __restrict__ logits, - int32_t* __restrict__ output_indices, - int32_t seq_len) { - const int thread_id = threadIdx.x; - for (int i = thread_id; i < TopK; i += kThreadsPerBlock) { - output_indices[i] = (i < seq_len) ? 
i : -1; - } -} - -// Adapted from: -// https://github.com/sgl-project/sglang/blob/v0.5.8/sgl-kernel/csrc/elementwise/topk.cu#L87 -// by: DarkSharpness -// which at the same time is an optimized topk kernel copied from tilelang -// kernel -__device__ void fast_topk_cuda_tl( - const float* __restrict__ logits, // Input logits [seq_len] - int* __restrict__ output_indices, // Output top-k indices [TopK] - int logits_offset, // Starting offset in logits array - int seq_len) // Number of valid logits to process -{ - constexpr int RADIX = 256; - constexpr int MAX_BUFFERED_ITEMS = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int shared_histogram[2][RADIX + 128]; - alignas(128) __shared__ int shared_output_count; - alignas(128) __shared__ int shared_threshold_bin; - alignas(128) __shared__ int shared_buffered_count[2]; - - extern __shared__ int buffered_indices[][MAX_BUFFERED_ITEMS]; - - const int thread_id = threadIdx.x; - int remaining_k = TopK; - - // Pass 0: Build coarse 8-bit histogram using FP16 high bits - if (thread_id < RADIX + 1) { - shared_histogram[0][thread_id] = 0; - } - __syncthreads(); - - for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) { - const auto bin = convert_to_uint8(logits[idx + logits_offset]); - ::atomicAdd(&shared_histogram[0][bin], 1); - } - __syncthreads(); - - // Helper: Compute cumulative sum (suffix sum) over histogram using ping-pong - // buffers - auto compute_cumulative_sum = [&]() { - static_assert(1 << 8 == RADIX, - "Radix must be 256 for 8 unrolled iterations"); -#pragma unroll 8 - for (int i = 0; i < 8; ++i) { - if (C10_LIKELY(thread_id < RADIX)) { - const int stride = 1 << i; - const int src_buffer = i & 1; - const int dst_buffer = src_buffer ^ 1; - - int value = shared_histogram[src_buffer][thread_id]; - if (thread_id < RADIX - stride) { - value += shared_histogram[src_buffer][thread_id + stride]; - } - shared_histogram[dst_buffer][thread_id] = value; - } - __syncthreads(); - } - }; - - 
compute_cumulative_sum(); - - // Find threshold bin where cumsum crosses remaining_k - if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k && - shared_histogram[0][thread_id + 1] <= remaining_k) { - shared_threshold_bin = thread_id; - shared_buffered_count[0] = 0; - shared_output_count = 0; - } - __syncthreads(); - - const int threshold_bin = shared_threshold_bin; - remaining_k -= shared_histogram[0][threshold_bin + 1]; - - // Early exit if threshold bin perfectly matches remaining_k - if (remaining_k == 0) { - for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) { - const int bin = convert_to_uint8(logits[idx + logits_offset]); - if (bin > threshold_bin) { - const int output_pos = ::atomicAdd(&shared_output_count, 1); - output_indices[output_pos] = idx; - } - } - __syncthreads(); - return; - } - - // Prepare for refinement passes: Process threshold bin - __syncthreads(); - if (thread_id < RADIX + 1) { - shared_histogram[0][thread_id] = 0; - } - __syncthreads(); - - // Scan all elements and: - // 1. Write indices > threshold_bin to output - // 2. Buffer indices == threshold_bin for refinement - // 3. 
Build histogram for next refinement pass (fused optimization) - for (int idx = thread_id; idx < seq_len; idx += kThreadsPerBlock) { - const float logit_value = logits[idx + logits_offset]; - const int bin = convert_to_uint8(logit_value); - - if (bin > threshold_bin) { - // in top-k, write to output - const int output_pos = ::atomicAdd(&shared_output_count, 1); - output_indices[output_pos] = idx; - } else if (bin == threshold_bin) { - // Candidate for top-k, needs refinement - const int buffer_pos = ::atomicAdd(&shared_buffered_count[0], 1); - if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) { - buffered_indices[0][buffer_pos] = idx; - // Fused: Build histogram for next pass - const uint32_t fp32_bits = convert_to_uint32_v2(logit_value); - const int next_bin = (fp32_bits >> 24) & 0xFF; - ::atomicAdd(&shared_histogram[0][next_bin], 1); - } - } +void persistent_topk(const torch::Tensor& logits, const torch::Tensor& lengths, + torch::Tensor& output, torch::Tensor& workspace, int64_t k, + int64_t max_seq_len) { +#ifndef USE_ROCM + TORCH_CHECK(logits.is_cuda(), "logits must be CUDA tensor"); + TORCH_CHECK(lengths.is_cuda(), "lengths must be CUDA tensor"); + TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor"); + TORCH_CHECK(logits.dtype() == torch::kFloat32, "Only float32 supported"); + TORCH_CHECK(lengths.dtype() == torch::kInt32, "lengths must be int32"); + TORCH_CHECK(output.dtype() == torch::kInt32, "output must be int32"); + TORCH_CHECK(logits.dim() == 2, "logits must be 2D"); + TORCH_CHECK(lengths.dim() == 1, "lengths must be 1D"); + TORCH_CHECK(output.dim() == 2, "output must be 2D"); + + const int64_t num_rows = logits.size(0); + const int64_t stride = logits.size(1); + + TORCH_CHECK(lengths.size(0) == num_rows, "lengths size mismatch"); + TORCH_CHECK(output.size(0) == num_rows && output.size(1) == k, + "output size mismatch"); + namespace P = vllm::persistent; + + TORCH_CHECK(k == P::TopK, "k must be 2048"); + TORCH_CHECK(k <= stride, "k out of range"); + 
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + static int num_sms = 0; + static int max_smem_per_block = 0; + if (num_sms == 0) { + int device; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device); + cudaDeviceGetAttribute(&max_smem_per_block, + cudaDevAttrMaxSharedMemoryPerBlockOptin, device); } - __syncthreads(); - - // ============================================================================ - // Passes 1-4: Refine using 8-bit passes over FP32 bits - // ============================================================================ - // FP32 bits [31:0] split into 4 bytes processed MSB-first: - // Pass 1: bits [31:24], Pass 2: bits [23:16], Pass 3: bits [15:8], Pass 4: - // bits [7:0] -#pragma unroll 4 - for (int pass = 0; pass < 4; ++pass) { - __shared__ int shared_final_k; // For final pass: remaining slots to fill - const int src_buffer = pass % 2; - const int dst_buffer = src_buffer ^ 1; - - // Clamp buffered count to prevent overflow - const int raw_buffered = shared_buffered_count[src_buffer]; - const int num_buffered = - (raw_buffered < MAX_BUFFERED_ITEMS) ? 
raw_buffered : MAX_BUFFERED_ITEMS; - - compute_cumulative_sum(); - - // Find threshold bin for this pass - if (thread_id < RADIX && shared_histogram[0][thread_id] > remaining_k && - shared_histogram[0][thread_id + 1] <= remaining_k) { - shared_threshold_bin = thread_id; - shared_buffered_count[dst_buffer] = 0; - shared_final_k = remaining_k - shared_histogram[0][thread_id + 1]; - } - __syncthreads(); - - const int threshold_bin = shared_threshold_bin; - remaining_k -= shared_histogram[0][threshold_bin + 1]; - - // Bit offset for this pass: 24, 16, 8, 0 - const int bit_offset = 24 - pass * 8; - // Early exit if threshold bin perfectly matches - if (remaining_k == 0) { - for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) { - const int idx = buffered_indices[src_buffer][i]; - const uint32_t fp32_bits = - convert_to_uint32_v2(logits[idx + logits_offset]); - const int bin = (fp32_bits >> bit_offset) & 0xFF; - if (bin > threshold_bin) { - const int output_pos = ::atomicAdd(&shared_output_count, 1); - output_indices[output_pos] = idx; - } - } - __syncthreads(); - break; + if (num_rows > 32 && max_smem_per_block >= 128 * 1024) { + cudaError_t status = vllm::FilteredTopKRaggedTransform( + logits.data_ptr(), output.data_ptr(), + lengths.data_ptr(), static_cast(num_rows), + static_cast(k), static_cast(stride), stream); + TORCH_CHECK(status == cudaSuccess, + "FilteredTopK failed: ", cudaGetErrorString(status)); + } else { + TORCH_CHECK(workspace.is_cuda(), "workspace must be CUDA tensor"); + TORCH_CHECK(workspace.dtype() == torch::kUInt8, "workspace must be uint8"); + + // Smem cap: smaller smem → more CTAs/group → more per-row parallelism for + // large path. Empirically tuned. 
+ int effective_max_smem; + if (num_rows <= 4) { + effective_max_smem = + std::min(max_smem_per_block, static_cast(P::kSmemMedium)); + } else if (num_rows <= 8) { + constexpr int kSmemCapMedium = 48 * 1024; + effective_max_smem = std::min(max_smem_per_block, kSmemCapMedium); + } else { + effective_max_smem = max_smem_per_block; } - // Continue refinement - __syncthreads(); - if (thread_id < RADIX + 1) { - shared_histogram[0][thread_id] = 0; + size_t available_for_ordered = + static_cast(effective_max_smem) - P::kFixedSmemLarge; + uint32_t max_chunk_elements = + static_cast(available_for_ordered / sizeof(uint32_t)); + + uint32_t vec_size = 1; + if (stride % 4 == 0) + vec_size = 4; + else if (stride % 2 == 0) + vec_size = 2; + + max_chunk_elements = (max_chunk_elements / vec_size) * vec_size; + uint32_t min_chunk = vec_size * P::kThreadsPerBlock; + if (max_chunk_elements < min_chunk) max_chunk_elements = min_chunk; + + uint32_t ctas_per_group = + (static_cast(stride) + max_chunk_elements - 1) / + max_chunk_elements; + uint32_t chunk_size = + (static_cast(stride) + ctas_per_group - 1) / ctas_per_group; + chunk_size = ((chunk_size + vec_size - 1) / vec_size) * vec_size; + if (chunk_size > max_chunk_elements) chunk_size = max_chunk_elements; + + size_t smem_size = P::kFixedSmemLarge + chunk_size * sizeof(uint32_t); + if (smem_size < P::kSmemMedium) smem_size = P::kSmemMedium; + + int occupancy = 1; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, P::persistent_topk_kernel<4>, P::kThreadsPerBlock, + smem_size); + if (occupancy < 1) occupancy = 1; + + uint32_t max_resident_ctas = static_cast(num_sms) * occupancy; + uint32_t num_groups = std::min(max_resident_ctas / ctas_per_group, + static_cast(num_rows)); + if (num_groups == 0) num_groups = 1; + uint32_t total_ctas = num_groups * ctas_per_group; + + size_t state_bytes = num_groups * sizeof(P::RadixRowState); + TORCH_CHECK(workspace.size(0) >= static_cast(state_bytes), + "workspace too small, need ", 
state_bytes, " bytes"); + + P::PersistentTopKParams params; + params.input = logits.data_ptr(); + params.output = output.data_ptr(); + params.lengths = lengths.data_ptr(); + params.num_rows = static_cast(num_rows); + params.stride = static_cast(stride); + params.chunk_size = chunk_size; + params.row_states = + reinterpret_cast(workspace.data_ptr()); + params.ctas_per_group = ctas_per_group; + params.max_seq_len = static_cast(max_seq_len); + + #define LAUNCH_PERSISTENT(VS) \ + do { \ + auto kernel = &P::persistent_topk_kernel; \ + cudaError_t err = cudaFuncSetAttribute( \ + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); \ + TORCH_CHECK(err == cudaSuccess, \ + "Failed to set smem: ", cudaGetErrorString(err)); \ + kernel<<>>(params); \ + } while (0) + + if (vec_size == 4) { + LAUNCH_PERSISTENT(4); + } else if (vec_size == 2) { + LAUNCH_PERSISTENT(2); + } else { + LAUNCH_PERSISTENT(1); } - __syncthreads(); - - for (int i = thread_id; i < num_buffered; i += kThreadsPerBlock) { - const int idx = buffered_indices[src_buffer][i]; - const float logit_value = logits[idx + logits_offset]; - const uint32_t fp32_bits = convert_to_uint32_v2(logit_value); - const int bin = (fp32_bits >> bit_offset) & 0xFF; - - if (bin > threshold_bin) { - // Definitely in top-k - const int output_pos = ::atomicAdd(&shared_output_count, 1); - output_indices[output_pos] = idx; - } else if (bin == threshold_bin) { - if (pass == 3) { - // Final pass (bits [7:0]): No more refinement possible - // Fill remaining slots in reverse order to maintain descending order - const int slot = ::atomicAdd(&shared_final_k, -1); - if (slot > 0) { - output_indices[TopK - slot] = idx; - } - } else { - // Buffer for next pass and build next histogram - const int buffer_pos = - ::atomicAdd(&shared_buffered_count[dst_buffer], 1); - if (C10_LIKELY(buffer_pos < MAX_BUFFERED_ITEMS)) { - buffered_indices[dst_buffer][buffer_pos] = idx; - // Fused: Build histogram for next pass - const int next_bit_offset = 
bit_offset - 8; - const int next_bin = (fp32_bits >> next_bit_offset) & 0xFF; - ::atomicAdd(&shared_histogram[0][next_bin], 1); - } - } - } - } - __syncthreads(); + #undef LAUNCH_PERSISTENT } -} -__global__ __launch_bounds__(kThreadsPerBlock) void topk_kernel( - const FastTopKParams params) { - const auto& [input, row_starts, indices, lengths, input_stride] = params; - const uint64_t batch_idx = blockIdx.x; - const int logits_offset = row_starts == nullptr ? 0 : row_starts[batch_idx]; - const int seq_len = lengths[batch_idx]; - int* output_indices = indices + batch_idx * TopK; - const float* logits = input + batch_idx * input_stride; - - if (seq_len <= TopK) { - // Shortcut: All elements are in top-k - return naive_topk_cuda(logits, output_indices, seq_len); - } else { - return fast_topk_cuda_tl(logits, output_indices, logits_offset, seq_len); - } -} - -FastTopKParams get_params( - const at::Tensor& score, const at::Tensor& lengths, - std::optional row_starts_opt = std::nullopt, - std::optional indices_opt = std::nullopt) { - const int64_t batch_size = score.size(0); - - TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1, - "score must be 2D with contiguous rows"); - TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous() && - lengths.size(0) == batch_size, - "lengths must be 1D contiguous with size matching batch"); - - const int32_t* row_starts_ptr = nullptr; - if (row_starts_opt.has_value()) { - const auto& row_starts = *row_starts_opt; - TORCH_CHECK(row_starts.dim() == 1 && row_starts.size(0) == batch_size, - "row_starts must be 1D with size matching batch"); - row_starts_ptr = row_starts.data_ptr(); - } - - int32_t* indices_ptr = nullptr; - if (indices_opt.has_value()) { - const auto& indices = *indices_opt; - TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous() && - indices.size(0) == batch_size && indices.size(1) == TopK, - "indices must be 2D contiguous [batch, TopK]"); - indices_ptr = indices.data_ptr(); - } - - return FastTopKParams{ - .input = 
score.data_ptr(), - .row_starts = row_starts_ptr, - .indices = indices_ptr, - .lengths = lengths.data_ptr(), - .input_stride = score.stride(0), - }; -} - -template -void setup_kernel_smem_once() { - static const cudaError_t result = []() -> cudaError_t { -#ifdef USE_ROCM - auto func_ptr = reinterpret_cast(kernel_func); + cudaError_t err = cudaGetLastError(); + TORCH_CHECK(err == cudaSuccess, + "persistent_topk failed: ", cudaGetErrorString(err)); #else - auto func_ptr = kernel_func; + TORCH_CHECK(false, "persistent_topk is not supported on ROCm"); #endif - return cudaFuncSetAttribute( - func_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); - }(); - - TORCH_CHECK( - result == cudaSuccess, - "Failed to set kernel shared memory limit: ", cudaGetErrorString(result)); } - -} // namespace vllm - -void large_context_topk( - const torch::Tensor& logits, torch::Tensor& indices, - const torch::Tensor& seq_lens, - std::optional row_starts = std::nullopt) { - TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor"); - TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor"); - TORCH_CHECK(seq_lens.is_cuda(), "seq_lens must be a CUDA tensor"); - if (row_starts.has_value()) { - TORCH_CHECK(row_starts->is_cuda(), "row_starts must be a CUDA tensor"); - } - - const auto params = vllm::get_params(logits, seq_lens, row_starts, indices); - const int64_t batch_size = logits.size(0); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - const dim3 grid(static_cast(batch_size)); - const dim3 block(vllm::kThreadsPerBlock); - - vllm::setup_kernel_smem_once(); - vllm::topk_kernel<<>>(params); - - const cudaError_t result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "large_context_topk kernel failed: ", cudaGetErrorString(result)); -} \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 3593f1d2225b..5c57a0df44f2 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -185,10 
+185,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode); ops.def( - "large_context_topk(Tensor score, Tensor indices, Tensor lengths, " - "Tensor? " - "row_starts_opt) -> ()"); - ops.impl("large_context_topk", torch::kCUDA, &large_context_topk); + "persistent_topk(Tensor logits, Tensor lengths, Tensor! output, " + "Tensor workspace, int k, int max_seq_len) -> ()"); + ops.impl("persistent_topk", torch::kCUDA, &persistent_topk); // Layernorm-quant // Apply Root Mean Square (RMS) Normalization to the input tensor. diff --git a/tests/kernels/test_top_k_per_row.py b/tests/kernels/test_top_k_per_row.py index f4bfc1666c09..40bd2af65ea1 100644 --- a/tests/kernels/test_top_k_per_row.py +++ b/tests/kernels/test_top_k_per_row.py @@ -122,6 +122,39 @@ def compare_top_k_results( return True +def validate_topk_against_reference( + logits: torch.Tensor, + cuda_indices: torch.Tensor, + row_starts: torch.Tensor, + row_ends: torch.Tensor, + top_k: int, + kernel_name: str, +) -> None: + """ + Validate CUDA top-k results against PyTorch reference implementation. 
+ + Args: + logits: Input logits tensor + cuda_indices: CUDA kernel output indices + row_starts: Row start positions + row_ends: Row end positions + top_k: Number of top elements to select + kernel_name: Name of the kernel being tested (for error messages) + """ + num_rows = cuda_indices.shape[0] + torch_indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") + + for i in range(num_rows): + row_end = int(row_ends[i]) + k_i = min(top_k, row_end) + idx = logits[i, :row_end].topk(k_i, dim=-1)[1] + torch_indices[i, :k_i] = idx + + assert compare_top_k_results( + logits, cuda_indices, torch_indices, row_starts, row_ends, top_k + ), f"{kernel_name} results don't match torch.topk" + + @pytest.mark.parametrize("num_rows", NUM_ROWS) @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("clean_logits", [True, False]) @@ -278,111 +311,540 @@ def test_top_k_per_row_decode_large_vocab_size(clean_logits: bool) -> None: @pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") +@pytest.mark.parametrize( + "seq_len_range,test_id", + [ + pytest.param((4000, 8000), "short_sequences", id="short"), + pytest.param((8000, 32000), "medium_sequences", id="medium"), + pytest.param((32000, 163840), "long_sequences", id="long"), + ], +) @pytest.mark.parametrize("clean_logits", [True, False]) +@pytest.mark.parametrize("top_k", [2048]) +@pytest.mark.parametrize("next_n", [1, 4]) @torch.inference_mode() -def test_deepseek_hybrid_topk(clean_logits: bool) -> None: +def test_deepseek_persistent_topk( + seq_len_range: tuple[int, int], + test_id: str, + clean_logits: bool, + top_k: int, + next_n: int, +) -> None: + """ + Test persistent_topk with varying sequence lengths and speculative decoding. + Supports speculative decoding with next_n > 1. 
+ """ + set_random_seed(42 if test_id == "short_sequences" else 43) torch.set_default_device("cuda:0") - top_k = 2048 - - # Test case 1: Short sequences (< 8192) - batch_size_short = 4 - next_n = 1 - num_rows_short = batch_size_short * next_n + batch_size = 4 + num_rows = batch_size * next_n - # Create sequences with max length < 8192 - seq_lens_short = torch.randint( - 4000, 8000, (batch_size_short,), dtype=torch.int32, device="cuda" + seq_lens = torch.randint( + seq_len_range[0], + seq_len_range[1], + (batch_size,), + dtype=torch.int32, + device="cuda", ) - row_starts_short = torch.zeros(num_rows_short, dtype=torch.int32, device="cuda") - row_indices_short = torch.arange(num_rows_short, device="cuda") // next_n - next_n_offset_short = torch.arange(num_rows_short, device="cuda") % next_n - row_ends_short = ( - seq_lens_short[row_indices_short] - next_n + next_n_offset_short + 1 + # Compute row boundaries for speculative decoding + row_starts = torch.zeros(num_rows, dtype=torch.int32, device="cuda") + row_indices = torch.arange(num_rows, device="cuda") // next_n + next_n_offset = torch.arange(num_rows, device="cuda") % next_n + row_ends = seq_lens[row_indices] - next_n + next_n_offset + 1 + + logits = create_random_logits( + row_starts, row_ends, torch.float32, 42, clean_logits, "random" ) - logits_short = create_random_logits( - row_starts_short, row_ends_short, torch.float32, 42, clean_logits, "random" + indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") + + if next_n == 1: + lengths = seq_lens + else: + offsets = torch.arange(next_n, device=logits.device, dtype=torch.int32) + lengths = (seq_lens.unsqueeze(1) - next_n + 1 + offsets).flatten() + + workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda") + max_seq_len = int(seq_lens.max().item()) + torch.ops._C.persistent_topk( + logits, lengths, indices, workspace, top_k, max_seq_len ) - indices_vllm = torch.empty( - (num_rows_short, top_k), dtype=torch.int32, device="cuda" + 
validate_topk_against_reference( + logits, indices, row_starts, row_ends, top_k, f"persistent_topk ({test_id})" ) - # Use vllm's kernel for short sequences - torch.ops._C.top_k_per_row_decode( - logits_short, - next_n, - seq_lens_short, - indices_vllm, - num_rows_short, - logits_short.stride(0), - logits_short.stride(1), - top_k, + +def run_large_context_topk_test( + batch_size: int, + seq_lens: list[int], + top_k: int, + data_type: str = "random", + seed: int = 42, +) -> None: + """ + Helper to run persistent_topk kernel test with given parameters. + + Args: + batch_size: Number of rows/sequences + seq_lens: List of sequence lengths (one per row) + top_k: Number of top elements to select + data_type: Type of test data to generate + seed: Random seed for reproducibility + """ + torch.set_default_device("cuda:0") + set_random_seed(seed) + + # Create test data + num_rows = batch_size + max_len = max(seq_lens) + lengths = torch.tensor(seq_lens, dtype=torch.int32, device="cuda") + + if data_type == "random": + logits = torch.randn(num_rows, max_len, dtype=torch.float32, device="cuda") + elif data_type == "sorted_asc": + # Each row gets its own ascending sequence based on its length + logits = torch.empty(num_rows, max_len, dtype=torch.float32, device="cuda") + for i, length in enumerate(seq_lens): + logits[i, :length] = torch.arange( + length, dtype=torch.float32, device="cuda" + ) + if length < max_len: + logits[i, length:] = float("-inf") + elif data_type == "sorted_desc": + # Each row gets its own descending sequence based on its length + logits = torch.empty(num_rows, max_len, dtype=torch.float32, device="cuda") + for i, length in enumerate(seq_lens): + logits[i, :length] = torch.arange( + length, 0, -1, dtype=torch.float32, device="cuda" + ) + if length < max_len: + logits[i, length:] = float("-inf") + elif data_type == "all_same": + logits = torch.ones(num_rows, max_len, dtype=torch.float32, device="cuda") + for i, length in enumerate(seq_lens): + if length < 
max_len: + logits[i, length:] = float("-inf") + elif data_type == "many_ties": + # Only 10 unique values, many duplicates + logits = torch.randint(0, 10, (num_rows, max_len), device="cuda").float() / 10.0 + for i, length in enumerate(seq_lens): + if length < max_len: + logits[i, length:] = float("-inf") + elif data_type == "small_differences": + # Very small differences to test float precision + base = torch.randn(num_rows, max_len, dtype=torch.float32, device="cuda") + noise = ( + torch.randn(num_rows, max_len, dtype=torch.float32, device="cuda") * 1e-6 + ) + logits = base + noise + for i, length in enumerate(seq_lens): + if length < max_len: + logits[i, length:] = float("-inf") + else: + raise ValueError(f"Unknown data_type: {data_type}") + + # Create output tensor + indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") + + workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda") + max_seq_len = max(seq_lens) + torch.ops._C.persistent_topk( + logits, lengths, indices, workspace, top_k, max_seq_len ) - # Test case 2: Long sequences (>= 8192) - should use large_context_topk kernel - batch_size_long = 4 - num_rows_long = batch_size_long * next_n + torch.accelerator.synchronize() + + torch_indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda") + for i in range(num_rows): + length = seq_lens[i] + k_i = min(top_k, length) + if k_i > 0: + idx = logits[i, :length].topk(k_i, dim=-1)[1] + torch_indices[i, :k_i] = idx + if k_i < top_k: + torch_indices[i, k_i:] = -1 + else: + torch_indices[i, :] = -1 + + # Compare results + for i in range(num_rows): + length = seq_lens[i] + k_i = min(top_k, length) + + if k_i == 0: + continue + + cuda_row = indices[i, :k_i].cpu() + torch_row = torch_indices[i, :k_i].cpu() + + # Filter out -1 padding values from cuda_row + valid_mask = cuda_row >= 0 + cuda_row = cuda_row[valid_mask] + + # Compare sets (order may differ for ties) + cuda_set = set(cuda_row.tolist()) + torch_set = 
set(torch_row.tolist()) + + if cuda_set == torch_set: + continue + + # If sets differ, check if it's due to equal values (ties) + cuda_vals = logits[i, cuda_row].cpu() + torch_vals = logits[i, torch_row].cpu() + + # Check that min CUDA value >= max of values NOT in top-k + if k_i < length: + non_topk_indices = torch.tensor( + list(set(range(length)) - cuda_set), dtype=torch.int32 + ) + if len(non_topk_indices) > 0: + non_topk_vals = logits[i, non_topk_indices].cpu() + min_cuda_val = cuda_vals.min() + max_non_topk = non_topk_vals.max() + + # Allow small tolerance for floating point errors + assert min_cuda_val >= max_non_topk - 1e-4, ( + f"Row {i}: CUDA top-k contains values smaller than non-top-k. " + f"Min CUDA: {min_cuda_val}, Max non-top-k: {max_non_topk}, " + f"Length: {length}, k: {k_i}, CUDA indices: {sorted(cuda_set)[:10]}..., " # noqa: E501 + f"Expected indices: {sorted(torch_set)[:10]}..." + ) + + # For ties, verify the values are close + assert torch.allclose( + cuda_vals.sort(descending=True)[0], + torch_vals.sort(descending=True)[0], + rtol=1e-4, + atol=1e-4, + ), f"""Row {i}: Top-k values don't match. 
+ CUDA: {cuda_vals.sort(descending=True)[0][:10]}, + Torch: {torch_vals.sort(descending=True)[0][:10]}""" - # Create sequences with max length >= 8192 - seq_lens_long = torch.randint( - 8192, 16384, (batch_size_long,), dtype=torch.int32, device="cuda" + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") +@pytest.mark.parametrize( + "test_config", + [ + # ==================== CATEGORY: Sequence Length Edge Cases ==================== + pytest.param( + {"seq_lens": [1, 10, 100, 2048], "top_k": 2048, "data_type": "random"}, + id="seq_len_edge_very_small_to_medium", + ), + pytest.param( + { + "seq_lens": [2049, 2100, 2500, 3000], + "top_k": 2048, + "data_type": "random", + }, + id="seq_len_edge_above_k", + ), + pytest.param( + {"seq_lens": [8000, 16384, 20000], "top_k": 2048, "data_type": "random"}, + id="algo_transition_filtered_radix", + ), + # ==================== CATEGORY: Data Distributions ==================== + pytest.param( + {"seq_lens": [5000, 10000], "top_k": 2048, "data_type": "sorted_asc"}, + id="data_sorted_ascending", + ), + pytest.param( + {"seq_lens": [5000, 10000], "top_k": 2048, "data_type": "sorted_desc"}, + id="data_sorted_descending", + ), + pytest.param( + {"seq_lens": [5000, 10000], "top_k": 2048, "data_type": "all_same"}, + id="data_all_same", + ), + pytest.param( + {"seq_lens": [5000, 10000], "top_k": 2048, "data_type": "many_ties"}, + id="data_many_ties", + ), + pytest.param( + { + "seq_lens": [5000, 10000], + "top_k": 2048, + "data_type": "small_differences", + }, + id="data_float_precision", + ), + # ==================== CATEGORY: Alignment / Vectorization ==================== + pytest.param( + { + "seq_lens": [2055, 2056, 2057, 2063], + "top_k": 2048, + "data_type": "random", + }, + id="align_vec_boundaries_low", + ), + pytest.param( + { + "seq_lens": [4095, 4096, 4097, 4102], + "top_k": 2048, + "data_type": "random", + }, + id="align_4k_boundary", + ), + pytest.param( + { + "seq_lens": [8191, 8192, 8193, 
8198], + "top_k": 2048, + "data_type": "random", + }, + id="align_8k_boundary", + ), + pytest.param( + { + "seq_lens": [16383, 16384, 16385, 16390], + "top_k": 2048, + "data_type": "random", + }, + id="align_16k_boundary", + ), + ], +) +@torch.inference_mode() +def test_persistent_topk_correctness(test_config: dict) -> None: + """ + Comprehensive correctness tests covering: + - Sequence length edge cases (trivial, boundary, varied) + - Very small sequences (< 100 elements) + - Mixed sequence lengths in same batch + - Data distributions (sorted, ties, precision) + - Memory alignment / vectorization boundaries + """ + run_large_context_topk_test( + batch_size=len(test_config["seq_lens"]), + seq_lens=test_config["seq_lens"], + top_k=test_config["top_k"], + data_type=test_config.get("data_type", "random"), ) - row_starts_long = torch.zeros(num_rows_long, dtype=torch.int32, device="cuda") - row_indices_long = torch.arange(num_rows_long, device="cuda") // next_n - next_n_offset_long = torch.arange(num_rows_long, device="cuda") % next_n - row_ends_long = seq_lens_long[row_indices_long] - next_n + next_n_offset_long + 1 - logits_long = create_random_logits( - row_starts_long, row_ends_long, torch.float32, 43, clean_logits, "random" +@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") +@pytest.mark.parametrize( + "test_config", + [ + # ==================== CATEGORY: Batch Size Scalability ==================== + pytest.param( + {"batch_size": 1, "seq_len": 5000, "top_k": 2048}, + id="batch_1", + ), + pytest.param( + {"batch_size": 4, "seq_len": 5000, "top_k": 2048}, + id="batch_4", + ), + pytest.param( + {"batch_size": 32, "seq_len": 5000, "top_k": 2048}, + id="batch_32", + ), + pytest.param( + {"batch_size": 256, "seq_len": 5000, "top_k": 2048}, + id="batch_256", + ), + # ==================== CATEGORY: Single-CTA vs Multi-CTA ==================== + pytest.param( + {"batch_size": 2, "seq_len": 4096, "top_k": 2048}, + id="single_cta_4k", + ), 
+ pytest.param( + {"batch_size": 2, "seq_len": 8192, "top_k": 2048}, + id="single_cta_8k", + ), + pytest.param( + {"batch_size": 2, "seq_len": 163840, "top_k": 2048}, + id="multi_cta_163840_dsv3_max", + ), + # ==================== CATEGORY: Extreme Cases ==================== + pytest.param( + {"batch_size": 512, "seq_len": 5000, "top_k": 2048}, + id="extreme_large_batch", + ), + pytest.param( + {"batch_size": 2, "seq_len": 163840, "top_k": 2048}, + id="extreme_dsv3_max_context", + ), + ], +) +@torch.inference_mode() +def test_persistent_topk_algorithm_paths(test_config: dict) -> None: + """ + Test different algorithm execution paths (capped at 163840 for DeepSeek V3.2): + - Batch size scalability (1, 4, 32, 256) + - Single-CTA vs Multi-CTA execution + - Extreme configurations (large batch, max context length) + """ + run_large_context_topk_test( + batch_size=test_config["batch_size"], + seq_lens=[test_config["seq_len"]] * test_config["batch_size"], + top_k=test_config["top_k"], ) - indices = torch.empty((num_rows_long, top_k), dtype=torch.int32, device="cuda") - # Use large_context_topk kernel for long sequences - if next_n == 1: - lengths = seq_lens_long - else: - offsets = torch.arange(next_n, device=logits_long.device, dtype=torch.int32) - lengths = (seq_lens_long.unsqueeze(1) - next_n + 1 + offsets).flatten() +@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") +@torch.inference_mode() +def test_persistent_topk_stress() -> None: + """ + Stress test with random configurations to catch edge cases. + Capped at 163840 (DeepSeek V3.2 max context) for realistic testing. 
+ """ + torch.set_default_device("cuda:0") + top_k = 2048 - torch.ops._C.large_context_topk( - logits_long, - indices, - lengths, - None, + for seed in range(3): + set_random_seed(seed) + + # Random batch size (limited for speed) + batch_size = torch.randint(1, 32, (1,)).item() + + # Random sequence lengths capped at DeepSeek V3.2 max context + seq_lens = torch.randint(100, 163840, (batch_size,)).tolist() + + run_large_context_topk_test( + batch_size=batch_size, + seq_lens=seq_lens, + top_k=top_k, + seed=seed, + ) + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") +@pytest.mark.parametrize( + "test_config", + [ + # Mixed batch: rows spanning all four paths (trivial, decode, medium, large) + pytest.param( + { + "seq_lens": [2000, 6000, 30000, 80000], + "top_k": 2048, + "data_type": "random", + }, + id="mixed_all_paths", + ), + # All decode/medium rows (typical decode scenario) + pytest.param( + { + "seq_lens": [2048, 4096, 8192, 16000], + "top_k": 2048, + "data_type": "random", + }, + id="all_decode_medium", + ), + # All large rows + pytest.param( + { + "seq_lens": [70000, 100000, 163840], + "top_k": 2048, + "data_type": "random", + }, + id="all_large", + ), + # Boundary around LARGE_THRESHOLD (32K) + pytest.param( + { + "seq_lens": [32767, 32768, 32769, 32772], + "top_k": 2048, + "data_type": "random", + }, + id="large_threshold_boundary", + ), + # Single row medium + pytest.param( + { + "seq_lens": [5000], + "top_k": 2048, + "data_type": "random", + }, + id="single_row_medium", + ), + # Single row large + pytest.param( + { + "seq_lens": [100000], + "top_k": 2048, + "data_type": "random", + }, + id="single_row_large", + ), + # Trivial rows mixed with medium and large + pytest.param( + { + "seq_lens": [100, 2048, 10000, 80000], + "top_k": 2048, + "data_type": "random", + }, + id="trivial_medium_large_mix", + ), + ], +) +@torch.inference_mode() +def test_persistent_topk(test_config: dict) -> None: + """ + Tests specific to the 
persistent_topk kernel: + - Mixed medium/large rows in the same batch (dynamic per-row dispatch) + - Boundary around LARGE_THRESHOLD (32K) + - Trivial + medium + large rows in a single batch + """ + run_large_context_topk_test( + batch_size=len(test_config["seq_lens"]), + seq_lens=test_config["seq_lens"], + top_k=test_config["top_k"], + data_type=test_config.get("data_type", "random"), ) - torch_indices_short = torch.empty( - (num_rows_short, top_k), dtype=torch.int32, device="cuda" + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA") +@torch.inference_mode() +def test_persistent_topk_padded_stride() -> None: + """ + Test persistent_topk with padded logits (large stride, small seq_len) + to simulate the e2e CUDAGraph scenario where fp8_paged_mqa_logits + returns [B, max_model_len] with max_model_len=163840. + """ + set_random_seed(42) + torch.set_default_device("cuda:0") + + top_k = 2048 + batch_size = 4 + padded_stride = 163840 # DeepSeek-V3.2 max_model_len + actual_seq_lens = [3000, 5000, 8000, 12000] + + # Create padded logits tensor (like fp8_paged_mqa_logits output) + logits = torch.full( + (batch_size, padded_stride), + float("-inf"), + dtype=torch.float32, + device="cuda", ) - for i in range(num_rows_short): - row_end = int(row_ends_short[i]) - k_i = min(top_k, row_end) - idx = logits_short[i, :row_end].topk(k_i, dim=-1)[1] - torch_indices_short[i, :k_i] = idx + for i, sl in enumerate(actual_seq_lens): + logits[i, :sl] = torch.randn(sl, dtype=torch.float32, device="cuda") - assert compare_top_k_results( - logits_short, - indices_vllm, - torch_indices_short, - row_starts_short, - row_ends_short, - top_k, - ), "top_k_per_row_decode kernel (short sequences) doesn't match torch.topk" + lengths = torch.tensor(actual_seq_lens, dtype=torch.int32, device="cuda") + indices = torch.empty((batch_size, top_k), dtype=torch.int32, device="cuda") + workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda") - 
torch_indices_long = torch.empty( - (num_rows_long, top_k), dtype=torch.int32, device="cuda" + torch.ops._C.persistent_topk( + logits, lengths, indices, workspace, top_k, max(actual_seq_lens) ) - for i in range(num_rows_long): - row_end = int(row_ends_long[i]) - k_i = min(top_k, row_end) - idx = logits_long[i, :row_end].topk(k_i, dim=-1)[1] - torch_indices_long[i, :k_i] = idx + torch.accelerator.synchronize() - assert compare_top_k_results( - logits_long, indices, torch_indices_long, row_starts_long, row_ends_long, top_k - ), "large_context_topk kernel (long sequences) doesn't match torch.topk" + # Validate against torch.topk + for i in range(batch_size): + sl = actual_seq_lens[i] + k_i = min(top_k, sl) + expected = logits[i, :sl].topk(k_i, dim=-1)[1].cpu() + actual = indices[i, :k_i].cpu() + + expected_set = set(expected.tolist()) + actual_set = set(actual.tolist()) + + if expected_set != actual_set: + # Allow ties + expected_vals = logits[i, expected].cpu().sort(descending=True)[0] + actual_vals = logits[i, actual].cpu().sort(descending=True)[0] + assert torch.allclose(expected_vals, actual_vals, rtol=1e-4, atol=1e-4), ( + f"Row {i}: persistent_topk with padded stride doesn't match. " + f"seq_len={sl}, stride={padded_stride}" + ) diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index ca148536f327..1844b75561e1 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -25,6 +25,8 @@ logger = init_logger(__name__) +RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024 + def sparse_attn_indexer( hidden_states: torch.Tensor, @@ -51,6 +53,7 @@ def sparse_attn_indexer( current_workspace_manager().get_simultaneous( ((total_seq_lens, head_dim), torch.float8_e4m3fn), ((total_seq_lens, 4), torch.uint8), + ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8), ) # Dummy allocation to simulate for peak logits tensor memory during inference. 
@@ -157,15 +160,6 @@ def sparse_attn_indexer( topk_tokens, ) - # Compute lengths from row spans - # lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32) - # torch.ops._C.large_context_topk( - # logits, - # topk_indices, - # lengths, - # chunk.cu_seqlen_ks, # row_starts - # ) - if has_decode: decode_metadata = attn_metadata.decode assert decode_metadata is not None @@ -204,23 +198,29 @@ def sparse_attn_indexer( num_rows = logits.shape[0] topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] - if decode_metadata.use_large_context_topk: - if next_n == 1: - lengths = decode_metadata.seq_lens - else: - # (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,) - lengths = ( - decode_metadata.seq_lens.unsqueeze(1) - - next_n - + 1 - + decode_metadata.offsets - ).flatten() - - torch.ops._C.large_context_topk( + if next_n == 1: + lengths = decode_metadata.seq_lens + else: + # (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,) + lengths = ( + decode_metadata.seq_lens.unsqueeze(1) + - next_n + + 1 + + decode_metadata.offsets + ).flatten() + + if current_platform.is_cuda(): + workspace_manager = current_workspace_manager() + (topk_workspace,) = workspace_manager.get_simultaneous( + ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8), + ) + torch.ops._C.persistent_topk( logits, - topk_indices, lengths, - None, + topk_indices, + topk_workspace, + topk_tokens, + attn_metadata.max_seq_len, ) else: if current_platform.is_xpu(): diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index cfeb36f4af25..17ddd5edeced 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -67,7 +67,9 @@ per_token_group_quant_fp8, ) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer +from vllm.model_executor.layers.sparse_attn_indexer import ( + SparseAttnIndexer, +) from 
vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -1203,7 +1205,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: DeepseekV2DecoderLayer( - vllm_config, prefix, topk_indices_buffer=topk_indices_buffer + vllm_config, + prefix, + topk_indices_buffer=topk_indices_buffer, ), prefix=f"{prefix}.layers", ) diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 927583a0f17b..5cb4b46a7345 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -145,7 +145,6 @@ class DeepSeekV32IndexerDecodeMetadata: decode_lens: torch.Tensor requires_padding: bool schedule_metadata: torch.Tensor - use_large_context_topk: bool offsets: torch.Tensor | None # Precomputed offsets for speculative decoding @@ -437,7 +436,6 @@ def build( if use_native and next_n > 1: offsets = self.offsets_buffer - batch_size = num_decodes elif max_decode_len > 1: # Flatten multi-token decode requests into single-token # batch entries, expanding seq_lens and block tables so @@ -496,10 +494,8 @@ def build( self.decode_lens_buffer[:num_decode_tokens] = 1 decode_lens = self.decode_lens_buffer[:num_decode_tokens] offsets = None - batch_size = num_decode_tokens else: offsets = None - batch_size = num_decodes # DeepGEMM is required for the paged MQA logits on CUDA devices if current_platform.is_cuda() and has_deep_gemm(): @@ -509,20 +505,12 @@ def build( self.num_sms, ) - # Decide which top-k kernel to use based on batch size and sequence length - # Decision logic based on micro-benchmark results: - # - large_context_topk wins for batch <= 128 and seq_len > 8K - # - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K - _is_large_context = common_attn_metadata.max_seq_len > 8192 - use_large_context_topk = batch_size <= 128 and 
_is_large_context - decode_metadata = DeepSeekV32IndexerDecodeMetadata( block_table=block_table, seq_lens=seq_lens, decode_lens=decode_lens, requires_padding=False, schedule_metadata=self.scheduler_metadata_buffer, - use_large_context_topk=use_large_context_topk, offsets=offsets, )