diff --git a/csrc/flashinfer_topk_binding.cu b/csrc/flashinfer_topk_binding.cu index 36c23ec386..e40e0b9e0d 100644 --- a/csrc/flashinfer_topk_binding.cu +++ b/csrc/flashinfer_topk_binding.cu @@ -19,17 +19,19 @@ using tvm::ffi::Optional; void radix_topk(TensorView input, TensorView output_indices, TensorView output_values, Optional maybe_row_states_buffer, int64_t top_k, bool sorted_output, - bool deterministic, int64_t tie_break); + bool deterministic, int64_t tie_break, bool dsa_graph_safe); void radix_topk_page_table_transform(TensorView input, TensorView output_page_table, TensorView src_page_table, Optional maybe_row_to_batch, TensorView lengths, Optional maybe_row_states_buffer, int64_t top_k, - bool deterministic, int64_t tie_break); + bool deterministic, int64_t tie_break, bool dsa_graph_safe, + Optional maybe_row_starts); void radix_topk_ragged_transform(TensorView input, TensorView output_indices, TensorView offsets, TensorView lengths, Optional maybe_row_states_buffer, - int64_t top_k, bool deterministic, int64_t tie_break); + int64_t top_k, bool deterministic, int64_t tie_break, + bool dsa_graph_safe, Optional maybe_row_starts); bool can_implement_filtered_topk(); diff --git a/csrc/topk.cu b/csrc/topk.cu index 45ef16c906..7e97f86ca5 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -40,7 +40,7 @@ inline sampling::TopKTieBreak ParseTopKTieBreak(int64_t tie_break) { void radix_topk(TensorView input, TensorView output_indices, TensorView output_values, Optional maybe_row_states_buffer, int64_t top_k, bool sorted_output, - bool deterministic, int64_t tie_break) { + bool deterministic, int64_t tie_break, bool dsa_graph_safe) { CHECK_INPUT(input); CHECK_INPUT(output_indices); CHECK_INPUT(output_values); @@ -72,7 +72,7 @@ void radix_topk(TensorView input, TensorView output_indices, TensorView output_v status = sampling::TopKDispatch( static_cast(input.data_ptr()), static_cast(output_indices.data_ptr()), static_cast(output_values.data_ptr()), batch_size, 
static_cast(top_k), d, - row_states_ptr, sorted_output, deterministic, tie_break_mode, stream); + row_states_ptr, sorted_output, deterministic, tie_break_mode, stream, dsa_graph_safe); return true; }); @@ -84,7 +84,8 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta TensorView src_page_table, Optional maybe_row_to_batch, TensorView lengths, Optional maybe_row_states_buffer, int64_t top_k, - bool deterministic, int64_t tie_break) { + bool deterministic, int64_t tie_break, bool dsa_graph_safe, + Optional maybe_row_starts) { CHECK_INPUT(input); CHECK_INPUT(output_page_table); CHECK_INPUT(src_page_table); @@ -93,6 +94,10 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta CHECK_DIM(2, output_page_table); // output_page_table: (num_rows, top_k) CHECK_DIM(2, src_page_table); // src_page_table: (batch_size, max_len) CHECK_DIM(1, lengths); // lengths: (num_rows,) + if (maybe_row_starts.has_value()) { + CHECK_INPUT(maybe_row_starts.value()); + CHECK_DIM(1, maybe_row_starts.value()); + } unsigned int num_rows = input.size(0); unsigned int max_len = input.size(1); @@ -118,14 +123,21 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta if (maybe_row_to_batch.has_value()) { row_to_batch_ptr = static_cast(maybe_row_to_batch.value().data_ptr()); } + int32_t* row_starts_ptr = nullptr; + if (maybe_row_starts.has_value()) { + TVM_FFI_ICHECK(static_cast(maybe_row_starts.value().size(0)) == num_rows) + << "row_starts must have shape (num_rows,)"; + row_starts_ptr = static_cast(maybe_row_starts.value().data_ptr()); + } // Use unified dispatch with heuristics to choose between FilteredTopK and RadixTopK DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] { status = sampling::TopKPageTableTransformDispatch( static_cast(input.data_ptr()), static_cast(output_page_table.data_ptr()), - static_cast(src_page_table.data_ptr()), src_stride, row_to_batch_ptr, - static_cast(lengths.data_ptr()), 
num_rows, static_cast(top_k), max_len, - row_states_ptr, deterministic, tie_break_mode, stream); + static_cast(src_page_table.data_ptr()), src_stride, + static_cast(lengths.data_ptr()), row_starts_ptr, row_to_batch_ptr, num_rows, + static_cast(top_k), max_len, row_states_ptr, deterministic, tie_break_mode, + stream, dsa_graph_safe); return true; }); @@ -135,7 +147,8 @@ void radix_topk_page_table_transform(TensorView input, TensorView output_page_ta void radix_topk_ragged_transform(TensorView input, TensorView output_indices, TensorView offsets, TensorView lengths, Optional maybe_row_states_buffer, - int64_t top_k, bool deterministic, int64_t tie_break) { + int64_t top_k, bool deterministic, int64_t tie_break, + bool dsa_graph_safe, Optional maybe_row_starts) { CHECK_INPUT(input); CHECK_INPUT(output_indices); CHECK_INPUT(offsets); @@ -144,6 +157,10 @@ void radix_topk_ragged_transform(TensorView input, TensorView output_indices, Te CHECK_DIM(2, output_indices); // output_indices: (num_rows, top_k) CHECK_DIM(1, offsets); // offsets: (num_rows,) CHECK_DIM(1, lengths); // lengths: (num_rows,) + if (maybe_row_starts.has_value()) { + CHECK_INPUT(maybe_row_starts.value()); + CHECK_DIM(1, maybe_row_starts.value()); + } unsigned int num_rows = input.size(0); unsigned int max_len = input.size(1); @@ -163,14 +180,20 @@ void radix_topk_ragged_transform(TensorView input, TensorView output_indices, Te row_states_ptr = static_cast(maybe_row_states_buffer.value().data_ptr()); } + int32_t* row_starts_ptr = nullptr; + if (maybe_row_starts.has_value()) { + TVM_FFI_ICHECK(static_cast(maybe_row_starts.value().size(0)) == num_rows) + << "row_starts must have shape (num_rows,)"; + row_starts_ptr = static_cast(maybe_row_starts.value().data_ptr()); + } // Use unified dispatch with heuristics to choose between FilteredTopK and RadixTopK DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(dtype, c_type, [&] { status = sampling::TopKRaggedTransformDispatch( static_cast(input.data_ptr()), 
static_cast(output_indices.data_ptr()), static_cast(offsets.data_ptr()), static_cast(lengths.data_ptr()), - num_rows, static_cast(top_k), max_len, row_states_ptr, deterministic, - tie_break_mode, stream); + row_starts_ptr, num_rows, static_cast(top_k), max_len, row_states_ptr, + deterministic, tie_break_mode, stream, dsa_graph_safe); return true; }); diff --git a/flashinfer/topk.py b/flashinfer/topk.py index eef1998d03..2ac4423c9c 100644 --- a/flashinfer/topk.py +++ b/flashinfer/topk.py @@ -72,6 +72,7 @@ def radix_topk( tie_break: int, row_states_buffer: Optional[torch.Tensor], output_values: torch.Tensor, + dsa_graph_safe: bool = False, ) -> torch.Tensor: device = input.device # Supports float32, float16, bfloat16 @@ -91,6 +92,7 @@ def radix_topk( sorted_output, deterministic, tie_break, + dsa_graph_safe, ) return output_indices @@ -103,6 +105,7 @@ def _fake_radix_topk( tie_break: int, row_states_buffer: Optional[torch.Tensor], output_values: torch.Tensor, + dsa_graph_safe: bool = False, ) -> torch.Tensor: batch_size = input.size(0) return torch.empty(batch_size, top_k, dtype=torch.int32, device=input.device) @@ -250,6 +253,8 @@ def radix_topk_page_table_transform( top_k: int, deterministic: bool, tie_break: int, + dsa_graph_safe: bool = False, + row_starts: Optional[torch.Tensor] = None, ) -> None: assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], ( f"Unsupported dtype {input.dtype}, expected float32, float16, or bfloat16" @@ -264,6 +269,8 @@ def radix_topk_page_table_transform( top_k, deterministic, tie_break, + dsa_graph_safe, + row_starts, ) @register_fake_op("flashinfer::radix_topk_page_table_transform") @@ -277,6 +284,8 @@ def _fake_radix_topk_page_table_transform( top_k: int, deterministic: bool, tie_break: int, + dsa_graph_safe: bool = False, + row_starts: Optional[torch.Tensor] = None, ) -> None: pass @@ -293,6 +302,8 @@ def radix_topk_ragged_transform( top_k: int, deterministic: bool, tie_break: int, + dsa_graph_safe: bool = False, + 
row_starts: Optional[torch.Tensor] = None, ) -> None: assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], ( f"Unsupported dtype {input.dtype}, expected float32, float16, or bfloat16" @@ -306,6 +317,8 @@ def radix_topk_ragged_transform( top_k, deterministic, tie_break, + dsa_graph_safe, + row_starts, ) @register_fake_op("flashinfer::radix_topk_ragged_transform") @@ -318,6 +331,8 @@ def _fake_radix_topk_ragged_transform( top_k: int, deterministic: bool, tie_break: int, + dsa_graph_safe: bool = False, + row_starts: Optional[torch.Tensor] = None, ) -> None: pass @@ -477,7 +492,9 @@ def topk_clusters_ragged_transform(logits, seq_lens, offsets, top_k, pdl=False): return indices -def can_use_clusters_topk(device, deterministic): +def can_use_clusters_topk(device, deterministic, dsa_graph_safe): + if dsa_graph_safe: + return False algo = os.environ.get("FLASHINFER_TOPK_ALGO") cap = get_compute_capability(device) return (algo is None or algo == "clusters") and not deterministic and cap[0] == 10 @@ -490,6 +507,7 @@ def top_k( sorted: bool = False, deterministic: bool = False, tie_break: int = TopKTieBreak.NONE, + dsa_graph_safe: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Radix-based Top-K selection. @@ -525,6 +543,9 @@ def top_k( - ``2``: prefer larger indices Default is ``0``. + dsa_graph_safe : bool, optional + If True, force FilteredTopK path and graph-safe vectorization (VEC_SIZE=1). + Default is False. 
Returns ------- @@ -578,7 +599,7 @@ def top_k( if tie_break != TopKTieBreak.NONE: deterministic = True - if can_use_clusters_topk(input.device, deterministic): + if can_use_clusters_topk(input.device, deterministic, dsa_graph_safe): indices, output_values = topk_clusters_exact( input, k, output_values=True, out_dtype=torch.int64 ) @@ -613,6 +634,7 @@ def top_k( tie_break, row_states_buffer, output_values, + dsa_graph_safe, ) # Convert to int64 for compatibility @@ -642,6 +664,8 @@ def top_k_page_table_transform( row_to_batch: Optional[torch.Tensor] = None, deterministic: bool = False, tie_break: int = TopKTieBreak.NONE, + dsa_graph_safe: bool = False, + row_starts: Optional[torch.Tensor] = None, ) -> torch.Tensor: r"""Fused Top-K selection + Page Table Transform for sparse attention. @@ -682,6 +706,13 @@ def top_k_page_table_transform( - ``2``: prefer larger indices Default is ``0``. + dsa_graph_safe : bool, optional + If True, force FilteredTopK path and graph-safe vectorization (VEC_SIZE=1). + Default is False. + row_starts : Optional[torch.Tensor], optional + Per-row start indices of shape ``(num_rows,)`` with dtype ``int32``. + Top-k is computed over ``[row_starts[i], row_starts[i] + lengths[i])`` for row ``i``. + Default is None (equivalent to all zeros). Returns @@ -694,7 +725,9 @@ def top_k_page_table_transform( Note ---- - This is specifically designed for sparse attention's second stage. - - If lengths[i] <= k, the output simply contains src_page_table[batch_idx, 0:lengths[i]] + - If lengths[i] <= k, the output simply contains + ``src_page_table[batch_idx, row_starts[i]:row_starts[i] + lengths[i]]`` (or start 0 when + ``row_starts`` is None) with remaining positions set to -1. 
Examples @@ -718,7 +751,11 @@ def top_k_page_table_transform( if tie_break != TopKTieBreak.NONE: deterministic = True - if can_use_clusters_topk(input.device, deterministic) and row_to_batch is None: + if ( + can_use_clusters_topk(input.device, deterministic, dsa_graph_safe) + and row_to_batch is None + and row_starts is None + ): return topk_clusters_page_table_transform(input, lengths, src_page_table, k) # Allocate row_states buffer for multi-CTA path @@ -742,6 +779,8 @@ def top_k_page_table_transform( k, deterministic, tie_break, + dsa_graph_safe, + row_starts=row_starts, ) return output_page_table @@ -755,6 +794,8 @@ def top_k_ragged_transform( k: int, deterministic: bool = False, tie_break: int = TopKTieBreak.NONE, + dsa_graph_safe: bool = False, + row_starts: Optional[torch.Tensor] = None, ) -> torch.Tensor: r"""Fused Top-K selection + Ragged Index Transform for sparse attention. @@ -788,6 +829,14 @@ def top_k_ragged_transform( - ``2``: prefer larger indices Default is ``0``. + dsa_graph_safe : bool, optional + If True, force FilteredTopK path and graph-safe vectorization (VEC_SIZE=1). + Default is False. + row_starts : Optional[torch.Tensor], optional + Per-row start indices of shape ``(num_rows,)`` with dtype ``int32``. + Top-k is computed over ``[row_starts[i], row_starts[i] + lengths[i])`` for row ``i``. + Output indices remain ``local_topk + offsets[i]`` where ``local_topk`` is relative to + ``row_starts[i]``. Default is None (equivalent to all zeros). 
Returns @@ -825,7 +874,10 @@ def top_k_ragged_transform( if tie_break != TopKTieBreak.NONE: deterministic = True - if can_use_clusters_topk(input.device, deterministic): + if ( + can_use_clusters_topk(input.device, deterministic, dsa_graph_safe) + and row_starts is None + ): return topk_clusters_ragged_transform(input, lengths, offsets, k) # Allocate row_states buffer for multi-CTA path @@ -848,6 +900,8 @@ def top_k_ragged_transform( k, deterministic, tie_break, + dsa_graph_safe, + row_starts=row_starts, ) return output_indices diff --git a/include/flashinfer/topk.cuh b/include/flashinfer/topk.cuh index 4fa5108f9e..e9d7893d98 100644 --- a/include/flashinfer/topk.cuh +++ b/include/flashinfer/topk.cuh @@ -1098,6 +1098,7 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( const IdType* aux_data, // Mode-specific: top_k_arr (Basic), src_page_table (PageTable), offsets (Ragged) IdType* lengths, // [num_rows] per-row lengths, nullptr for Basic (uses stride) + const IdType* row_starts, // [num_rows] per-row start indices, nullptr => 0 const IdType* row_to_batch, // [num_rows] batch mapping for PageTable, nullptr otherwise int64_t aux_stride, // src_page_table stride for PageTable mode, 0 otherwise uint32_t top_k_val, uint32_t stride, uint32_t num_rows, RadixRowState* row_states, @@ -1140,6 +1141,9 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( for (uint32_t iter = 0; iter < total_iterations; iter++) { uint32_t row_idx = group_id + iter * num_groups; if (row_idx >= num_rows) break; + const uint32_t row_start = + (row_starts != nullptr && MODE != RadixTopKMode::Basic) ? 
row_starts[row_idx] : 0; + DType* row_input = input + static_cast(row_idx) * stride + row_start; // Mode-specific: get row length and k value uint32_t length, k; @@ -1186,7 +1190,7 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( const IdType* src_page_entry = aux_data + batch_idx * aux_stride; if (length <= top_k_val) { for (uint32_t i = tx; i < top_k_val; i += BLOCK_THREADS) { - row_output[i] = (i < length) ? src_page_entry[i] : static_cast(-1); + row_output[i] = (i < length) ? src_page_entry[row_start + i] : static_cast(-1); } // Clear histogram for next iteration if constexpr (!SINGLE_CTA) { @@ -1229,9 +1233,9 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( uint32_t cta_local_eq_count = 0; OrderedType ordered_pivot = RadixSelectFindPivot( - input + static_cast(row_idx) * stride, shared_ordered, local_histogram, - suffix_sum, shared_scalars, state, chunk_start, actual_chunk_size, k, barrier_phase, - ctas_per_group, cta_in_group, tx, iter, cta_local_gt_count, cta_local_eq_count); + row_input, shared_ordered, local_histogram, suffix_sum, shared_scalars, state, + chunk_start, actual_chunk_size, k, barrier_phase, ctas_per_group, cta_in_group, tx, + iter, cta_local_gt_count, cta_local_eq_count); auto collect_indices = [&](auto&& output_func) { if constexpr (DETERMINISTIC) { @@ -1268,7 +1272,7 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( // Transform through page table with coalesced access for (uint32_t i = tx; i < k; i += BLOCK_THREADS) { IdType idx = row_output[i]; - row_output[i] = src_page_entry[idx]; + row_output[i] = src_page_entry[row_start + idx]; } } else { // Barrier to ensure all CTAs finished writing indices @@ -1280,7 +1284,7 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( uint32_t my_end = min(my_start + elems_per_cta, k); for (uint32_t i = my_start + tx; i < my_end; i += BLOCK_THREADS) { IdType idx = row_output[i]; - row_output[i] = 
src_page_entry[idx]; + row_output[i] = src_page_entry[row_start + idx]; } } } else { // RaggedTransform @@ -1874,6 +1878,7 @@ cudaError_t RadixTopKRenormProbMultiCTA(DType* probs, DType* renormed_prob, IdTy * \param src_stride Stride of source page table (typically max_len) * \param row_to_batch Mapping from row index to batch index [num_rows], or nullptr if 1:1 * \param lengths Sequence lengths per row [num_rows] + * \param row_starts Start indices per row [num_rows], or nullptr to use 0 * \param num_rows Number of rows to process * \param top_k_val Number of top elements to select * \param max_len Maximum sequence length (input stride) @@ -1884,12 +1889,13 @@ template cudaError_t RadixTopKPageTableTransformMultiCTA(DType* input, IdType* output_page_table, const IdType* src_page_table, int64_t src_stride, const IdType* row_to_batch, IdType* lengths, - uint32_t num_rows, uint32_t top_k_val, - uint32_t max_len, RadixRowState* row_states_buffer, + const IdType* row_starts, uint32_t num_rows, + uint32_t top_k_val, uint32_t max_len, + RadixRowState* row_states_buffer, bool deterministic, cudaStream_t stream = 0) { using OrderedType = typename RadixTopKTraits::OrderedType; constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(DType), max_len); + const uint32_t vec_size = (row_starts != nullptr) ? 
1 : std::gcd(16 / sizeof(DType), max_len); int device; FLASHINFER_CUDA_CALL(cudaGetDevice(&device)); @@ -1934,10 +1940,10 @@ cudaError_t RadixTopKPageTableTransformMultiCTA(DType* input, IdType* output_pag DType* output_values = nullptr; // Not used in PageTableTransform mode dim3 nblks(total_ctas); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&input, &output_page_table, &output_values, &src_page_table, - &lengths, &row_to_batch, &src_stride, &top_k_val, - &max_len, &num_rows, &row_states_buffer, &det_scratch_buffer, - &chunk_size, &ctas_per_group}; + void* args[] = { + &input, &output_page_table, &output_values, &src_page_table, &lengths, + &row_starts, &row_to_batch, &src_stride, &top_k_val, &max_len, + &num_rows, &row_states_buffer, &det_scratch_buffer, &chunk_size, &ctas_per_group}; #define LAUNCH_PAGE_TABLE_KERNEL(THREADS, SINGLE_CTA_FLAG, DET_FLAG) \ do { \ @@ -1988,12 +1994,13 @@ cudaError_t RadixTopKPageTableTransformMultiCTA(DType* input, IdType* output_pag template cudaError_t RadixTopKRaggedTransformMultiCTA(DType* input, IdType* output_indices, const IdType* offsets, IdType* lengths, - uint32_t num_rows, uint32_t top_k_val, - uint32_t max_len, RadixRowState* row_states_buffer, - bool deterministic, cudaStream_t stream = 0) { + const IdType* row_starts, uint32_t num_rows, + uint32_t top_k_val, uint32_t max_len, + RadixRowState* row_states_buffer, bool deterministic, + cudaStream_t stream = 0) { using OrderedType = typename RadixTopKTraits::OrderedType; constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(DType), max_len); + const uint32_t vec_size = (row_starts != nullptr) ? 
1 : std::gcd(16 / sizeof(DType), max_len); int device; FLASHINFER_CUDA_CALL(cudaGetDevice(&device)); @@ -2040,10 +2047,10 @@ cudaError_t RadixTopKRaggedTransformMultiCTA(DType* input, IdType* output_indice int64_t aux_stride = 0; // Not used in RaggedTransform mode dim3 nblks(total_ctas); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&input, &output_indices, &output_values, &offsets, - &lengths, &row_to_batch, &aux_stride, &top_k_val, - &max_len, &num_rows, &row_states_buffer, &det_scratch_buffer, - &chunk_size, &ctas_per_group}; + void* args[] = { + &input, &output_indices, &output_values, &offsets, &lengths, + &row_starts, &row_to_batch, &aux_stride, &top_k_val, &max_len, + &num_rows, &row_states_buffer, &det_scratch_buffer, &chunk_size, &ctas_per_group}; #define LAUNCH_RAGGED_KERNEL(THREADS, SINGLE_CTA_FLAG, DET_FLAG) \ do { \ @@ -2144,14 +2151,15 @@ cudaError_t RadixTopKMultiCTA(DType* input, IdType* output_indices, DType* outpu // Unified kernel parameters IdType* lengths = nullptr; // Not used in Basic mode + const IdType* row_starts = nullptr; // Not used in Basic mode const IdType* row_to_batch = nullptr; // Not used in Basic mode int64_t aux_stride = 0; // Not used in Basic mode dim3 nblks(total_ctas); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&input, &output_indices, &output_values, &top_k_arr, - &lengths, &row_to_batch, &aux_stride, &top_k_val, - &vocab_size, &batch_size, &row_states_buffer, &det_scratch_buffer, - &chunk_size, &ctas_per_group}; + void* args[] = { + &input, &output_indices, &output_values, &top_k_arr, &lengths, + &row_starts, &row_to_batch, &aux_stride, &top_k_val, &vocab_size, + &batch_size, &row_states_buffer, &det_scratch_buffer, &chunk_size, &ctas_per_group}; #define LAUNCH_BASIC_KERNEL(THREADS, SINGLE_CTA_FLAG, DET_FLAG) \ do { \ @@ -2284,8 +2292,9 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) const IdType* __restrict__ aux_input, // page_table or offsets int64_t aux_stride, // src_stride for PageTable const 
IdType* __restrict__ row_to_batch, // for PageTable - const IdType* __restrict__ lengths, uint32_t num_rows, uint32_t top_k, - uint32_t max_len) { + const IdType* __restrict__ lengths, + const IdType* __restrict__ row_starts, // per-row score start + uint32_t num_rows, uint32_t top_k, uint32_t max_len) { constexpr uint32_t BLOCK_SIZE = FILTERED_TOPK_BLOCK_THREADS; constexpr int RADIX = 256; constexpr int SMEM_INPUT_SIZE = FILTERED_TOPK_SMEM_INPUT_SIZE; @@ -2297,7 +2306,9 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) if (bid >= num_rows) return; const int length = (lengths != nullptr) ? lengths[bid] : static_cast(max_len); - const DType* score = input + static_cast(bid) * max_len; + const IdType row_start = + (row_starts != nullptr && MODE != FilteredTopKMode::Plain) ? row_starts[bid] : 0; + const DType* score = input + static_cast(bid) * max_len + row_start; IdType* dst = output + bid * top_k; // Mode-specific setup @@ -2329,7 +2340,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) // In deterministic mode the page-table/ragged transform happens in SortTopKByIndexKernel dst[i] = (i < length) ? static_cast(i) : static_cast(-1); } else if constexpr (MODE == FilteredTopKMode::PageTable) { - dst[i] = (i < length) ? src_page_entry[i] : static_cast(-1); + dst[i] = (i < length) ? src_page_entry[row_start + i] : static_cast(-1); } else { // Ragged dst[i] = (i < length) ? 
static_cast(i) + offset_val : static_cast(-1); } @@ -2820,7 +2831,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS) } else if constexpr (DETERMINISTIC) { // transform in SortTopKByIndexKernel dst[base] = static_cast(idx); } else if constexpr (MODE == FilteredTopKMode::PageTable) { - dst[base] = src_page_entry[idx]; + dst[base] = src_page_entry[row_start + idx]; } else { // Ragged dst[base] = static_cast(idx) + offset_val; } @@ -2840,7 +2851,10 @@ constexpr uint32_t gcd(uint32_t a, uint32_t b) { // Compute optimal VEC_SIZE based on max_len and dtype // Returns 1, 2, 4, or 8 template -constexpr int ComputeFilteredTopKVecSize(uint32_t max_len) { +constexpr int ComputeFilteredTopKVecSize(uint32_t max_len, bool dsa_graph_safe = false) { + if (dsa_graph_safe) { + return 1; + } constexpr int MAX_VEC = 16 / sizeof(DType); // 4 for float32, 8 for fp16/bf16 // Use GCD to find largest power-of-2 divisor const uint32_t g = gcd(max_len, static_cast(MAX_VEC)); @@ -2864,8 +2878,8 @@ template __global__ void __launch_bounds__(BLOCK_THREADS) SortTopKByIndexKernel(IdType* output_indices, DType* output_values, const IdType* aux_input, - int64_t aux_stride, const IdType* row_to_batch, uint32_t top_k, - uint32_t max_len) { + int64_t aux_stride, const IdType* row_starts, const IdType* row_to_batch, + uint32_t top_k, uint32_t max_len) { constexpr bool WITH_VALUES = (MODE == FilteredTopKMode::Plain); using BlockRadixSortT = typename SortTopKByIndexBlockRadixSort::Type; @@ -2904,9 +2918,11 @@ __global__ void __launch_bounds__(BLOCK_THREADS) const IdType* src_page_entry = nullptr; IdType offset = 0; + IdType row_start = 0; if constexpr (MODE == FilteredTopKMode::PageTable) { const uint32_t batch_idx = (row_to_batch != nullptr) ? row_to_batch[row] : row; src_page_entry = aux_input + static_cast(batch_idx) * aux_stride; + row_start = (row_starts != nullptr) ? 
row_starts[row] : 0; } else if constexpr (MODE == FilteredTopKMode::Ragged) { offset = aux_input[row]; } @@ -2920,7 +2936,7 @@ __global__ void __launch_bounds__(BLOCK_THREADS) row_output[pos] = static_cast(idx); output_values[static_cast(row) * top_k + pos] = values[i]; } else if constexpr (MODE == FilteredTopKMode::PageTable) { - row_output[pos] = (idx != ~0u) ? src_page_entry[idx] : static_cast(-1); + row_output[pos] = (idx != ~0u) ? src_page_entry[row_start + idx] : static_cast(-1); } else { // Ragged row_output[pos] = (idx != ~0u) ? static_cast(idx) + offset : static_cast(-1); @@ -2932,8 +2948,9 @@ __global__ void __launch_bounds__(BLOCK_THREADS) template cudaError_t LaunchSortTopKByIndex(IdType* output_indices, DType* output_values, const IdType* aux_input, int64_t aux_stride, - const IdType* row_to_batch, uint32_t num_rows, uint32_t top_k_val, - uint32_t max_len, cudaStream_t stream = 0) { + const IdType* row_starts, const IdType* row_to_batch, + uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, + cudaStream_t stream = 0) { // Block-local sort variants cover at most 256 * 8 = 2048 elements. 
if (top_k_val > 2048) { return cudaErrorInvalidValue; @@ -2949,7 +2966,7 @@ cudaError_t LaunchSortTopKByIndex(IdType* output_indices, DType* output_values, dim3 grid(num_rows); void* args[] = {&output_indices, &output_values, &aux_input, &aux_stride, - &row_to_batch, &top_k_val, &max_len}; + &row_starts, &row_to_batch, &top_k_val, &max_len}; auto launch_sort = [&](auto kernel, uint32_t threads) -> cudaError_t { dim3 block(threads); return cudaLaunchKernel((void*)kernel, grid, block, args, 0, stream); @@ -3063,19 +3080,22 @@ template cudaError_t LaunchFilteredTopKUnified(DType* input, IdType* output, DType* aux_output, const IdType* aux_input, int64_t aux_stride, const IdType* row_to_batch, const IdType* lengths, - uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, + const IdType* row_starts, uint32_t num_rows, + uint32_t top_k_val, uint32_t max_len, bool deterministic = false, TopKTieBreak tie_break = TopKTieBreak::None, - cudaStream_t stream = 0) { + cudaStream_t stream = 0, bool dsa_graph_safe = false) { constexpr size_t smem_size = FILTERED_TOPK_SMEM_DYNAMIC; constexpr int MAX_VEC = 16 / sizeof(DType); dim3 grid(num_rows); dim3 block(FILTERED_TOPK_BLOCK_THREADS); - void* args[] = {&input, &output, &aux_output, &aux_input, &aux_stride, - &row_to_batch, &lengths, &num_rows, &top_k_val, &max_len}; + void* args[] = {&input, &output, &aux_output, &aux_input, &aux_stride, &row_to_batch, + &lengths, &row_starts, &num_rows, &top_k_val, &max_len}; - const int vec_size = ComputeFilteredTopKVecSize(max_len); + const int vec_size = (row_starts != nullptr && MODE != FilteredTopKMode::Plain) + ? 
1 + : ComputeFilteredTopKVecSize(max_len, dsa_graph_safe); #define LAUNCH_FILTERED_KERNEL(VS, DET, TIE) \ do { \ @@ -3118,41 +3138,45 @@ template cudaError_t FilteredTopKPageTableTransform(DType* input, IdType* output_page_table, const IdType* src_page_table, int64_t src_stride, const IdType* row_to_batch, IdType* lengths, - uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, + const IdType* row_starts, uint32_t num_rows, + uint32_t top_k_val, uint32_t max_len, bool deterministic = false, TopKTieBreak tie_break = TopKTieBreak::None, - cudaStream_t stream = 0) { + cudaStream_t stream = 0, bool dsa_graph_safe = false) { DType* aux_output = nullptr; // Not used for PageTable mode return LaunchFilteredTopKUnified( input, output_page_table, aux_output, src_page_table, src_stride, row_to_batch, lengths, - num_rows, top_k_val, max_len, deterministic, tie_break, stream); + row_starts, num_rows, top_k_val, max_len, deterministic, tie_break, stream, dsa_graph_safe); } template cudaError_t FilteredTopKRaggedTransform(DType* input, IdType* output_indices, const IdType* offsets, - IdType* lengths, uint32_t num_rows, uint32_t top_k_val, - uint32_t max_len, bool deterministic = false, + IdType* lengths, const IdType* row_starts, + uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, + bool deterministic = false, TopKTieBreak tie_break = TopKTieBreak::None, - cudaStream_t stream = 0) { + cudaStream_t stream = 0, bool dsa_graph_safe = false) { DType* aux_output = nullptr; // Not used for Ragged mode int64_t aux_stride = 0; // Not used for Ragged mode const IdType* row_to_batch = nullptr; // Not used for Ragged mode return LaunchFilteredTopKUnified( - input, output_indices, aux_output, offsets, aux_stride, row_to_batch, lengths, num_rows, - top_k_val, max_len, deterministic, tie_break, stream); + input, output_indices, aux_output, offsets, aux_stride, row_to_batch, lengths, row_starts, + num_rows, top_k_val, max_len, deterministic, tie_break, stream, dsa_graph_safe); } 
template cudaError_t FilteredTopK(DType* input, IdType* output_indices, DType* output_values, const IdType* lengths, uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, bool deterministic = false, - TopKTieBreak tie_break = TopKTieBreak::None, cudaStream_t stream = 0) { + TopKTieBreak tie_break = TopKTieBreak::None, cudaStream_t stream = 0, + bool dsa_graph_safe = false) { const IdType* aux_input = nullptr; // Not used for Plain mode int64_t aux_stride = 0; // Not used for Plain mode + const IdType* row_starts = nullptr; // Not used for Plain mode const IdType* row_to_batch = nullptr; // Not used for Plain mode return LaunchFilteredTopKUnified( - input, output_indices, output_values, aux_input, aux_stride, row_to_batch, lengths, num_rows, - top_k_val, max_len, deterministic, tie_break, stream); + input, output_indices, output_values, aux_input, aux_stride, row_to_batch, lengths, + row_starts, num_rows, top_k_val, max_len, deterministic, tie_break, stream, dsa_graph_safe); } /*! @@ -3198,7 +3222,12 @@ inline TopKAlgoOverride GetTopKAlgoOverride() { */ template inline bool ShouldUseFilteredTopK(uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, - bool deterministic, TopKTieBreak tie_break) { + bool deterministic, TopKTieBreak tie_break, + bool dsa_graph_safe = false) { + // DSA graph safe mode always uses FilteredTopK + if (dsa_graph_safe) { + return true; + } // Tie-break modes are only supported by FilteredTopK if (tie_break != TopKTieBreak::None) { return true; @@ -3248,59 +3277,64 @@ inline bool ShouldUseFilteredTopK(uint32_t num_rows, uint32_t top_k_val, uint32_ template cudaError_t TopKPageTableTransformDispatch(DType* input, IdType* output_page_table, const IdType* src_page_table, int64_t src_stride, - const IdType* row_to_batch, IdType* lengths, - uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, + IdType* lengths, const IdType* row_starts, + const IdType* row_to_batch, uint32_t num_rows, + uint32_t top_k_val, uint32_t max_len, 
RadixRowState* row_states_buffer, bool deterministic, TopKTieBreak tie_break = TopKTieBreak::None, - cudaStream_t stream = 0) { + cudaStream_t stream = 0, bool dsa_graph_safe = false) { + const bool require_filtered = dsa_graph_safe || tie_break != TopKTieBreak::None; if (tie_break != TopKTieBreak::None) { deterministic = true; - if (top_k_val > FILTERED_TOPK_MAX_K || !CanImplementFilteredTopK()) { - return cudaErrorNotSupported; - } } - if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len, deterministic, tie_break)) { + if (require_filtered && (top_k_val > FILTERED_TOPK_MAX_K || !CanImplementFilteredTopK())) { + return cudaErrorNotSupported; + } + if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len, deterministic, tie_break, + dsa_graph_safe)) { FLASHINFER_CUDA_CALL((FilteredTopKPageTableTransform( - input, output_page_table, src_page_table, src_stride, row_to_batch, lengths, num_rows, - top_k_val, max_len, deterministic, tie_break, stream))); + input, output_page_table, src_page_table, src_stride, row_to_batch, lengths, row_starts, + num_rows, top_k_val, max_len, deterministic, tie_break, stream, dsa_graph_safe))); if (deterministic) { FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex( - output_page_table, static_cast(nullptr), src_page_table, src_stride, + output_page_table, static_cast(nullptr), src_page_table, src_stride, row_starts, row_to_batch, num_rows, top_k_val, max_len, stream))); } return cudaSuccess; } return RadixTopKPageTableTransformMultiCTA( - input, output_page_table, src_page_table, src_stride, row_to_batch, lengths, num_rows, - top_k_val, max_len, row_states_buffer, deterministic, stream); + input, output_page_table, src_page_table, src_stride, row_to_batch, lengths, row_starts, + num_rows, top_k_val, max_len, row_states_buffer, deterministic, stream); } template cudaError_t TopKRaggedTransformDispatch(DType* input, IdType* output_indices, const IdType* offsets, - IdType* lengths, uint32_t num_rows, uint32_t top_k_val, - uint32_t max_len, 
RadixRowState* row_states_buffer, - bool deterministic, + IdType* lengths, const IdType* row_starts, + uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, + RadixRowState* row_states_buffer, bool deterministic, TopKTieBreak tie_break = TopKTieBreak::None, - cudaStream_t stream = 0) { + cudaStream_t stream = 0, bool dsa_graph_safe = false) { + const bool require_filtered = dsa_graph_safe || tie_break != TopKTieBreak::None; if (tie_break != TopKTieBreak::None) { deterministic = true; - if (top_k_val > FILTERED_TOPK_MAX_K || !CanImplementFilteredTopK()) { - return cudaErrorNotSupported; - } } - if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len, deterministic, tie_break)) { + if (require_filtered && (top_k_val > FILTERED_TOPK_MAX_K || !CanImplementFilteredTopK())) { + return cudaErrorNotSupported; + } + if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len, deterministic, tie_break, + dsa_graph_safe)) { FLASHINFER_CUDA_CALL((FilteredTopKRaggedTransform( - input, output_indices, offsets, lengths, num_rows, top_k_val, max_len, deterministic, - tie_break, stream))); + input, output_indices, offsets, lengths, row_starts, num_rows, top_k_val, max_len, + deterministic, tie_break, stream, dsa_graph_safe))); if (deterministic) { FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex( - output_indices, static_cast(nullptr), offsets, 0, nullptr, num_rows, top_k_val, - max_len, stream))); + output_indices, static_cast(nullptr), offsets, 0, row_starts, nullptr, num_rows, + top_k_val, max_len, stream))); } return cudaSuccess; } return RadixTopKRaggedTransformMultiCTA(input, output_indices, offsets, lengths, - num_rows, top_k_val, max_len, + row_starts, num_rows, top_k_val, max_len, row_states_buffer, deterministic, stream); } @@ -3309,20 +3343,22 @@ cudaError_t TopKDispatch(DType* input, IdType* output_indices, DType* output_val uint32_t num_rows, uint32_t top_k_val, uint32_t max_len, RadixRowState* row_states_buffer, bool sorted_output = false, bool deterministic = false, 
TopKTieBreak tie_break = TopKTieBreak::None, - cudaStream_t stream = 0) { + cudaStream_t stream = 0, bool dsa_graph_safe = false) { + const bool require_filtered = dsa_graph_safe || tie_break != TopKTieBreak::None; if (tie_break != TopKTieBreak::None) { deterministic = true; - if (top_k_val > FILTERED_TOPK_MAX_K || !CanImplementFilteredTopK()) { - return cudaErrorNotSupported; - } } - if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len, deterministic, tie_break)) { - FLASHINFER_CUDA_CALL( - (FilteredTopK(input, output_indices, output_values, nullptr, num_rows, - top_k_val, max_len, deterministic, tie_break, stream))); + if (require_filtered && (top_k_val > FILTERED_TOPK_MAX_K || !CanImplementFilteredTopK())) { + return cudaErrorNotSupported; + } + if (ShouldUseFilteredTopK(num_rows, top_k_val, max_len, deterministic, tie_break, + dsa_graph_safe)) { + FLASHINFER_CUDA_CALL((FilteredTopK(input, output_indices, output_values, nullptr, + num_rows, top_k_val, max_len, deterministic, + tie_break, stream, dsa_graph_safe))); if (deterministic) { FLASHINFER_CUDA_CALL((LaunchSortTopKByIndex( - output_indices, output_values, nullptr, 0, nullptr, num_rows, top_k_val, max_len, + output_indices, output_values, nullptr, 0, nullptr, nullptr, num_rows, top_k_val, max_len, stream))); } } else { diff --git a/tests/utils/test_topk.py b/tests/utils/test_topk.py index 235e908574..dd2363b1ae 100644 --- a/tests/utils/test_topk.py +++ b/tests/utils/test_topk.py @@ -414,6 +414,7 @@ def reference_page_table_transform( lengths: torch.Tensor, k: int, row_to_batch: torch.Tensor = None, + row_starts: torch.Tensor = None, ) -> torch.Tensor: """Reference implementation for page table transform using torch.topk.""" num_rows = scores.size(0) @@ -424,17 +425,20 @@ def reference_page_table_transform( for i in range(num_rows): length = lengths[i].item() + row_start = row_starts[i].item() if row_starts is not None else 0 batch_idx = row_to_batch[i].item() if row_to_batch is not None else i if length 
<= k: # Trivial case: just copy first `length` entries - output[i, :length] = src_page_table[batch_idx, :length] + output[i, :length] = src_page_table[ + batch_idx, row_start : row_start + length + ] else: # Get top-k indices - row_scores = scores[i, :length] + row_scores = scores[i, row_start : row_start + length] _, topk_indices = torch.topk(row_scores.float(), k) # Gather from page table - output[i] = src_page_table[batch_idx, topk_indices.long()] + output[i] = src_page_table[batch_idx, row_start + topk_indices.long()] return output @@ -444,6 +448,7 @@ def reference_ragged_transform( offsets: torch.Tensor, lengths: torch.Tensor, k: int, + row_starts: torch.Tensor = None, ) -> torch.Tensor: """Reference implementation for ragged transform using torch.topk.""" num_rows = scores.size(0) @@ -453,6 +458,7 @@ def reference_ragged_transform( for i in range(num_rows): length = lengths[i].item() + row_start = row_starts[i].item() if row_starts is not None else 0 offset = offsets[i].item() if length <= k: @@ -462,7 +468,7 @@ def reference_ragged_transform( ) else: # Get top-k indices - row_scores = scores[i, :length] + row_scores = scores[i, row_start : row_start + length] _, topk_indices = torch.topk(row_scores.float(), k) # Add offset output[i] = topk_indices.int() + offset @@ -573,6 +579,92 @@ def test_top_k_ragged_transform(num_rows, max_len, k, dtype): assert accuracy >= min_accuracy, f"Accuracy {accuracy:.4f} < {min_accuracy}" +@pytest.mark.parametrize("algo", ["multi_cta", "filtered"]) +@pytest.mark.parametrize("dsa_graph_safe", [False, True]) +@pytest.mark.parametrize( + "num_rows,max_len,k", + [ + (2, 128 * 1024, 2048), + (1, 256 * 1024, 1024), + (74, 16 * 1024, 512), + ], +) +def test_top_k_transform_with_row_starts( + algo, dsa_graph_safe, num_rows, max_len, k, set_topk_algo +): + """Transform APIs should honor row_starts windowing with local-index semantics.""" + if (algo == "filtered" or dsa_graph_safe) and not can_implement_filtered_topk(): + 
pytest.skip("Filtered top-k not supported on this device") + + set_topk_algo(algo) + device = "cuda" + + base = -torch.arange(max_len, device=device, dtype=torch.float32) + scores = base.unsqueeze(0).repeat(num_rows, 1).contiguous() + + max_start = max_len - (k + 1) + start_stride = max(1, max_start // max(1, num_rows - 1)) + row_starts = ( + torch.arange(num_rows, device=device, dtype=torch.int32) * start_stride + ).clamp(max=max_start) + max_windows = max_len - row_starts + lengths = torch.minimum( + max_windows, + k + 1 + (torch.arange(num_rows, device=device, dtype=torch.int32) % 4), + ) + offsets = torch.arange(num_rows, device=device, dtype=torch.int32) * 100 + row_to_batch = torch.arange(num_rows - 1, -1, -1, device=device, dtype=torch.int32) + + src_page_table = ( + torch.arange(max_len, device=device, dtype=torch.int32) + .unsqueeze(0) + .repeat(num_rows, 1) + ) + src_page_table = ( + src_page_table + + 1000 * torch.arange(num_rows, device=device, dtype=torch.int32).unsqueeze(1) + ).contiguous() + + output_page = flashinfer.top_k_page_table_transform( + scores, + src_page_table, + lengths, + k, + row_to_batch=row_to_batch, + row_starts=row_starts, + deterministic=True, + dsa_graph_safe=dsa_graph_safe, + ) + output_ragged = flashinfer.top_k_ragged_transform( + scores, + offsets, + lengths, + k, + row_starts=row_starts, + deterministic=True, + dsa_graph_safe=dsa_graph_safe, + ) + ref_page = reference_page_table_transform( + scores, + src_page_table, + lengths, + k, + row_to_batch=row_to_batch, + row_starts=row_starts, + ) + + ref_ragged = reference_ragged_transform( + scores, offsets, lengths, k, row_starts=row_starts + ) + output_page_sorted, _ = torch.sort(output_page, dim=-1) + ref_page_sorted, _ = torch.sort(ref_page, dim=-1) + assert torch.equal(output_page_sorted, ref_page_sorted) + + output_ragged_sorted, _ = torch.sort(output_ragged, dim=-1) + ref_ragged_sorted, _ = torch.sort(ref_ragged, dim=-1) + assert torch.equal(output_ragged_sorted, 
ref_ragged_sorted) + + @pytest.mark.parametrize("num_rows", [1, 8, 32]) @pytest.mark.parametrize("max_len", [1024, 4096, 8192]) @pytest.mark.parametrize("k", [64, 256, 512])