From 167009996664c3dd2095a8efeaef775bd6fb0f48 Mon Sep 17 00:00:00 2001 From: Anerudhan Gopal Date: Tue, 13 Jan 2026 23:31:32 -0800 Subject: [PATCH 1/4] Added the deepseek sizes to the Ragged KV Cache wrapper --- flashinfer/prefill.py | 73 +++++++++++++++++- tests/attention/test_cudnn_prefill.py | 4 +- .../attention/test_cudnn_prefill_deepseek.py | 77 ++++++++++--------- 3 files changed, 112 insertions(+), 42 deletions(-) diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index bfdcfd9048..9ab2f51ebc 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -2598,6 +2598,12 @@ def plan( max_item_len_ptr: Optional[torch.Tensor] = None, fixed_split_size: Optional[int] = None, disable_split_kv: bool = False, + seq_lens: Optional[torch.Tensor] = None, + seq_lens_q: Optional[torch.Tensor] = None, + max_token_per_sequence: Optional[int] = None, + max_sequence_kv: Optional[int] = None, + v_indptr: Optional[torch.Tensor] = None, + o_indptr: Optional[torch.Tensor] = None, ) -> None: r"""Plan batch prefill/append attention on Ragged KV-Cache for given problem specification. @@ -2692,6 +2698,19 @@ def plan( and lead to a varied number of launched CTAs. disable_split_kv : bool, Whether to disable the split-kv for determinism in CUDA Graph, defaults to ``False``. + seq_lens: Optional[torch.Tensor] + A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``. + seq_lens_q: Optional[torch.Tensor] + A uint32 1D tensor indicating the q sequence length of each prompt. shape: ``[batch_size]``. + If not provided, will be set to the same value as ``seq_lens``. + max_token_per_sequence: Optional[int], + Required for cudnn backend. This is the scalar max token length of each sequence. + max_sequence_kv: Optional[int], + Required for cudnn backend. This is the scalar max sequence length of each sequence in kv cache. + v_indptr: Optional[torch.Tensor] + Required for cudnn backend. This is the indptr of the value tensor. + o_indptr: Optional[torch.Tensor] + Required for cudnn backend. This is the indptr of the output tensor. Note ---- The :meth:`plan` method should be called before any :meth:`run` or @@ -2781,6 +2800,17 @@ def plan( self.device, non_blocking=non_blocking ) + self._o_indptr_buf = ( + o_indptr.to(self.device, non_blocking=non_blocking) + if o_indptr is not None + else self._qo_indptr_buf + ) + self._v_indptr_buf = ( + v_indptr.to(self.device, non_blocking=non_blocking) + if v_indptr is not None + else self._kv_indptr_buf + ) + self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type self._cached_o_data_type = o_data_type @@ -2791,6 +2821,11 @@ def plan( self._token_pos_in_items_len = token_pos_in_items_len self._max_item_len_ptr = max_item_len_ptr + self._seq_lens_q = seq_lens_q + self._seq_lens_kv = seq_lens + self._max_token_per_sequence = max_token_per_sequence + self._max_sequence_kv = max_sequence_kv + if self._jit_module is not None: self._cached_module = self._jit_module else: @@ -2822,7 +2857,7 @@ def plan( get_module_args[:9] + (qo_indptr.device,) + get_module_args[9:] ) self._cached_module = get_fmha_module(*new_get_module_args) - else: + elif self._backend != "cudnn": self._cached_module = get_batch_prefill_module( self._backend, *get_module_args ) @@ -2832,7 +2867,7 @@ def plan( self._cached_module, qo_indptr, kv_indptr, num_qo_heads, causal ) self._max_qo_len = torch.max(qo_indptr[1:] - qo_indptr[:-1]).item() - else: + elif self._backend != "cudnn": assert self._cached_module is not None, "cached module is not initialized" args = [ self._float_workspace_buffer, @@ -3040,6 +3075,40 @@ def run( lse=lse, ) return (out, lse) if return_lse else out + elif self._backend == "cudnn": + if self._seq_lens_q.dim() == 1: + batch_size = self._seq_lens_q.shape[0] + if self._seq_lens_q is not None and self._seq_lens_q.dim() == 1: + self._seq_lens_q = self._seq_lens_q.reshape(batch_size, 1, 1, 1) + + if self._seq_lens_kv is not None and self._seq_lens_kv.dim() == 1: + self._seq_lens_kv = self._seq_lens_kv.reshape(batch_size, 1, 1, 1) + + cudnn_batch_prefill_with_kv_cache( + q, + k, + v, + sm_scale, + self._float_workspace_buffer, + max_token_per_sequence=self._max_token_per_sequence, + max_sequence_kv=self._max_sequence_kv, + actual_seq_lens_q=self._seq_lens_q, + actual_seq_lens_kv=self._seq_lens_kv, + return_lse=return_lse, + causal=self._causal, + q_scale=q_scale, + k_scale=k_scale, + v_scale=v_scale, + batch_offsets_q=self._qo_indptr_buf, + batch_offsets_k=self._kv_indptr_buf, + batch_offsets_v=self._v_indptr_buf, + batch_offsets_o=self._o_indptr_buf, + is_cuda_graph_compatible=True, + out=out, + lse=lse, + ) + + return (out, lse) if return_lse else out # Skip FP8->FP16 conversion for FA3 backend with FP8 support # The JIT module will handle FP8 natively diff --git a/tests/attention/test_cudnn_prefill.py b/tests/attention/test_cudnn_prefill.py index 3e03e309d5..af9fc09f2b 100644 --- a/tests/attention/test_cudnn_prefill.py +++ b/tests/attention/test_cudnn_prefill.py @@ -44,7 +44,7 @@ def test_cudnn_prefill( ) cumsum_s_qo = torch.sum(actual_seq_lens_q) - q = torch.ones( + q = torch.randn( cumsum_s_qo, num_qo_heads, head_dim, device=device, dtype=torch.bfloat16 ) @@ -60,7 +60,7 @@ def test_cudnn_prefill( total_num_pages = num_pages_per_seq * batch_size kv_cache_shape = (total_num_pages, 2, num_kv_heads, page_size, head_dim) - kv_cache = torch.ones(size=kv_cache_shape, dtype=torch.bfloat16).to(device) + kv_cache = torch.randn(size=kv_cache_shape, dtype=torch.bfloat16).to(device) kv_cache = kv_cache.as_strided( kv_cache.shape, ( diff --git a/tests/attention/test_cudnn_prefill_deepseek.py b/tests/attention/test_cudnn_prefill_deepseek.py index 8362934ece..db3c076609 100644 --- a/tests/attention/test_cudnn_prefill_deepseek.py +++ b/tests/attention/test_cudnn_prefill_deepseek.py @@ -5,10 +5,10 @@ @pytest.mark.parametrize("batch_size", [1, 4]) -@pytest.mark.parametrize("s_qo", [32, 64, 87]) -@pytest.mark.parametrize("s_kv", [32, 64, 87]) -@pytest.mark.parametrize("num_kv_heads", [1]) -@pytest.mark.parametrize("num_qo_heads", [1, 16]) +@pytest.mark.parametrize("s_qo", [32, 64, 87, 256]) +@pytest.mark.parametrize("s_kv", [32, 87, 512]) +@pytest.mark.parametrize("num_kv_heads", [1, 4]) +@pytest.mark.parametrize("num_qo_heads", [1, 8]) @pytest.mark.parametrize("causal", [True, False]) def test_cudnn_prefill_deepseek( batch_size, s_qo, s_kv, num_kv_heads, num_qo_heads, causal @@ -16,11 +16,12 @@ def test_cudnn_prefill_deepseek( if s_qo > s_kv: pytest.skip("s_qo > s_kv, skipping test as causal") + if num_qo_heads < num_kv_heads: + pytest.skip("num_qo_heads < num_kv_heads, skipping test") + head_dim_qk = 192 head_dim_vo = 128 - return_lse = True - # test set up basics seed = 0 torch.manual_seed(seed) @@ -76,14 +77,14 @@ def test_cudnn_prefill_deepseek( ] ).int() - batch_offsets_stats = torch.cat( - [ - torch.zeros( - 1, device=actual_seq_lens_q.device, dtype=actual_seq_lens_q.dtype - ), - torch.cumsum(actual_seq_lens_q.flatten(), dim=0) * num_qo_heads, - ] - ).cuda() + # batch_offsets_stats = torch.cat( + # [ + # torch.zeros( + # 1, device=actual_seq_lens_q.device, dtype=actual_seq_lens_q.dtype + # ), + # torch.cumsum(actual_seq_lens_q.flatten(), dim=0) * num_qo_heads, + # ] + # ).cuda() k_cache = torch.randn( batch_size * s_kv, @@ -103,29 +104,34 @@ def test_cudnn_prefill_deepseek( # Initialize scale scale = float(1.0 / (head_dim_qk**0.5)) - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) + workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=device) - # output = torch.zeros_like(q) - output, lse = flashinfer.prefill.cudnn_batch_prefill_with_kv_cache( - q, - k_cache, - v_cache, - scale, - workspace_buffer, + wrapper_cudnn = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, "NHD", backend="cudnn" + ) + + wrapper_cudnn.plan( + qo_indptr=q_indptr, + kv_indptr=k_indptr, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, + causal=causal, + sm_scale=scale, + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, + o_data_type=torch.bfloat16, + seq_lens=actual_seq_lens_kv, + seq_lens_q=actual_seq_lens_q, max_token_per_sequence=s_qo, max_sequence_kv=s_kv, - actual_seq_lens_q=actual_seq_lens_q, - actual_seq_lens_kv=actual_seq_lens_kv, - causal=causal, - return_lse=return_lse, - batch_offsets_q=q_indptr, - batch_offsets_k=k_indptr, - batch_offsets_v=v_indptr, - batch_offsets_o=o_indptr, - batch_offsets_stats=batch_offsets_stats, - is_cuda_graph_compatible=True, + v_indptr=v_indptr, + o_indptr=o_indptr, ) + output = wrapper_cudnn.run(q, k_cache, v_cache) + qo_indptr = torch.cat( [ torch.tensor([0], device=device), @@ -133,15 +139,10 @@ def test_cudnn_prefill_deepseek( ] ).int() - # kv_indptr = torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * s_kv - # Create kv_indptr as cumulative sum of actual_seq_lens_kv kv_indptr = torch.cat( [ - torch.tensor( - [0], - device=device, - ), + torch.tensor([0], device=device), torch.cumsum(actual_seq_lens_kv.view(-1), dim=0), ] ).int() From 06ca655b0d16f9ab9c770a81001047a0250d744d Mon Sep 17 00:00:00 2001 From: Anerudhan Gopal Date: Wed, 14 Jan 2026 18:37:24 -0800 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- flashinfer/prefill.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 9ab2f51ebc..33814dfdf6 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -2698,7 +2698,7 @@ def plan( and lead to a varied number of launched CTAs. disable_split_kv : bool, Whether to disable the split-kv for determinism in CUDA Graph, defaults to ``False``. - seq_lens: Optional[torch.Tensor] + seq_lens: Optional[torch.Tensor] A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``. seq_lens_q: Optional[torch.Tensor] A uint32 1D tensor indicating the q sequence length of each prompt. shape: ``[batch_size]``. @@ -3103,7 +3103,7 @@ def run( batch_offsets_k=self._kv_indptr_buf, batch_offsets_v=self._v_indptr_buf, batch_offsets_o=self._o_indptr_buf, - is_cuda_graph_compatible=True, + is_cuda_graph_compatible=self._use_cuda_graph, out=out, lse=lse, ) From d0b6b55435ab15f69f036416d91aa2b173aec295 Mon Sep 17 00:00:00 2001 From: Anerudhan Gopal Date: Thu, 15 Jan 2026 21:55:52 -0800 Subject: [PATCH 3/4] Update the docstring --- flashinfer/prefill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 33814dfdf6..08b3608fb4 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -2478,7 +2478,7 @@ def __init__( will be used in attention computation. backend : str - The implementation backend, could be ``auto``/``fa2``/``fa3`` or ``cutlass``. + The implementation backend, could be ``auto``/``fa2``/``fa3``/``cudnn`` or ``cutlass``. Defaults to ``auto``. If set to ``auto``, the wrapper will automatically choose the backend based on the device architecture and kernel availability. From f6ca31b9c013fe60f895207bd7e425ccda62267c Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 16 Jan 2026 05:14:30 +0000 Subject: [PATCH 4/4] Add cudnn via wrapper to benchmark --- benchmarks/README.md | 12 +++--- benchmarks/routines/attention.py | 43 +++++++++++++++++++ .../routines/flashinfer_benchmark_utils.py | 15 ++++--- 3 files changed, 58 insertions(+), 12 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index efcafdc403..15f561ca29 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -16,7 +16,7 @@ Currently supports testing attention, gemm, fused MOE, normalization, and quanti - `BatchPrefillWithPagedKVCacheWrapper` - Prefill attention with paged KV cache. - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_batch_context_with_kv_cache`. - `BatchPrefillWithRaggedKVCacheWrapper` - Prefill attention with ragged KV cache. - - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_ragged_attention_deepseek`. + - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` (cudnn-native) and `trtllm_ragged_attention_deepseek`. - `BatchMLAPagedAttentionWrapper` - MLA attention proposed in DeepSeek series of models. - Also supports computationally similar `trtllm_batch_decode_with_kv_cache_mla`. - GEMM: @@ -280,7 +280,8 @@ Legend: - fa2: FlashAttention-2 - fa2_tc: FlashAttention-2 (Tensor Core) - fa3: FlashAttention-3 -- cudnn: cuDNN +- cudnn: cuDNN (via wrapper API) +- cudnn-native: cuDNN (direct API call) - cutlass: CUTLASS - trtllm: TensorRT-LLM - trtllm-gen: TensorRT-LLM (generic wrapper) @@ -289,8 +290,8 @@ Legend: | Routine | 7.5 | 8.0 | 8.6 | 8.9 | 9.0 | 10.0 | 10.3 | 12.0 | |---------|-----|-----|-----|-----|-----|-------|-------|-------| | **BatchDecodeWithPagedKVCacheWrapper** | fa2 | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native | fa2, fa2_tc, cudnn | -| **BatchPrefillWithPagedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, trtllm-gen, trtllm-native | fa2, cudnn, trtllm-gen, trtllm-native | fa2, cudnn | -| **BatchPrefillWithRaggedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, cutlass, trtllm-native | fa2, cudnn, cutlass, trtllm-native | fa2, cudnn | +| **BatchPrefillWithPagedKVCacheWrapper** | | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, fa3, cudnn, cudnn-native | fa2, cudnn, cudnn-native, trtllm-gen, trtllm-native | fa2, cudnn, cudnn-native, trtllm-gen, trtllm-native | fa2, cudnn, cudnn-native | +| **BatchPrefillWithRaggedKVCacheWrapper** | | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, fa3, cudnn, cudnn-native | fa2, cudnn, cudnn-native, cutlass, trtllm-native | fa2, cudnn, cudnn-native, cutlass, trtllm-native | fa2, cudnn, cudnn-native | | **BatchMLAPagedAttentionWrapper** | | fa2 | fa2 | fa2 | fa2, fa3 | fa2, cutlass, trtllm-native | fa2, cutlass, trtllm-native | fa2 | | **gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | | | **group_gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | | @@ -314,8 +315,9 @@ Backend Legend: - fa2: FlashAttention2 - fa2_tc: FlashAttention2 (with Tensor Cores for `BatchDecodeWithPagedKVCacheWrapper`) - fa3: FlashAttention-3 -- cudnn: cuDNN - cublas: cuBLAS +- cudnn: cuDNN (via wrapper API) +- cudnn-native: cuDNN (direct API call) - cutlass: CUTLASS - trtllm: TensorRT-LLM - trtllm-gen: TensorRT-LLM diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py index c4ee7e5c1d..d217f6ba3e 100644 --- a/benchmarks/routines/attention.py +++ b/benchmarks/routines/attention.py @@ -1396,6 +1396,17 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): if remove_cudnn: backends.remove("cudnn") + if "cudnn-native" in backends: + remove_cudnn_native = False + if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + ]: + print("[INFO] CUDNN-native backend does not support FP8. Skipping.") + remove_cudnn_native = True + if remove_cudnn_native: + backends.remove("cudnn-native") + if "cutlass" in backends: remove_cutlass = False if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ @@ -1609,6 +1620,34 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): q_data_type=q_dtype, kv_data_type=kv_dtype, ) + elif backend == "cudnn": + # cuDNN uses NHD layout and the wrapper API + backend_wrappers[backend] = ( + flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, + "NHD", + backend="cudnn", + ) + ) + backend_wrappers[backend].plan( + qo_indptr=q_indptr, + kv_indptr=k_indptr, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, + causal=causal, + sm_scale=scale, + q_data_type=q_dtype, + kv_data_type=kv_dtype, + o_data_type=q_dtype, + seq_lens=actual_seq_lens_kv_device, + seq_lens_q=actual_seq_lens_q_device, + max_token_per_sequence=s_qo, + max_sequence_kv=s_kv, + v_indptr=v_indptr, + o_indptr=o_indptr, + ) k_scale, v_scale = None, None if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: @@ -1639,6 +1678,10 @@ def run_backend_wrapper( if backend in ["cutlass", "fa2", "fa3", "trtllm-gen"]: return backend_wrappers[backend].run_return_lse(q, k, v)[0] elif backend == "cudnn": + # cuDNN uses wrapper API + return backend_wrappers[backend].run(q, k, v) + elif backend == "cudnn-native": + # Direct cudnn_batch_prefill_with_kv_cache call return flashinfer.prefill.cudnn_batch_prefill_with_kv_cache( q, k, diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 00bb04fb6f..1ecce0cb99 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -220,14 +220,15 @@ def dtype_str_to_torch_dtype(dtype_str): }, "BatchPrefillWithRaggedKVCacheWrapper": { # NOTE: trtllm-native calls trtllm_ragged_attention_deepseek + # NOTE: cudnn-native calls cudnn_batch_prefill_with_kv_cache "7.5": [], - "8.0": ["fa2", "cudnn"], - "8.6": ["fa2", "cudnn"], - "8.9": ["fa2", "cudnn"], - "9.0": ["fa2", "fa3", "cudnn"], - "10.0": ["fa2", "cudnn", "cutlass", "trtllm-native"], - "10.3": ["fa2", "cudnn", "cutlass", "trtllm-native"], - "12.0": ["fa2", "cudnn"], + "8.0": ["fa2", "cudnn", "cudnn-native"], + "8.6": ["fa2", "cudnn", "cudnn-native"], + "8.9": ["fa2", "cudnn", "cudnn-native"], + "9.0": ["fa2", "fa3", "cudnn", "cudnn-native"], + "10.0": ["fa2", "cudnn", "cudnn-native", "cutlass", "trtllm-native"], + "10.3": ["fa2", "cudnn", "cudnn-native", "cutlass", "trtllm-native"], + "12.0": ["fa2", "cudnn", "cudnn-native"], }, "BatchMLAPagedAttentionWrapper": { # NOTE: trtllm-native calls trtllm_batch_decode_with_kv_cache_mla