Merged
12 changes: 7 additions & 5 deletions benchmarks/README.md
@@ -16,7 +16,7 @@ Currently supports testing attention, gemm, fused MOE, normalization, and quanti
   - `BatchPrefillWithPagedKVCacheWrapper` - Prefill attention with paged KV cache.
     - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_batch_context_with_kv_cache`.
   - `BatchPrefillWithRaggedKVCacheWrapper` - Prefill attention with ragged KV cache.
-    - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_ragged_attention_deepseek`.
+        - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` (cudnn-native) and  `trtllm_ragged_attention_deepseek`.
⚠️ Potential issue | 🟑 Minor

Fix unordered list indentation to satisfy MD007.

Markdownlint flags inconsistent indentation on this new sub-bullet; align it with the other “Also supports” entries.

🧹 Markdownlint fix
-        - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` (cudnn-native) and  `trtllm_ragged_attention_deepseek`.
+    - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` (cudnn-native) and `trtllm_ragged_attention_deepseek`.
🧰 Tools
πŸͺ› markdownlint-cli2 (0.18.1)

19-19: Unordered list indentation
Expected: 4; Actual: 8

(MD007, ul-indent)

🤖 Prompt for AI Agents
In `@benchmarks/README.md` at line 19, The new sub-bullet "Also supports
computationally similar `cudnn_batch_prefill_with_kv_cache` (cudnn-native) and
`trtllm_ragged_attention_deepseek`" has inconsistent indentation causing MD007
failures; edit benchmarks/README.md to match the indentation/level used by the
other "Also supports" entries (make this line align with the other sibling
bullets under that list so it uses the same number of spaces or tab characters
as the other "Also supports" lines).

   - `BatchMLAPagedAttentionWrapper` - MLA attention proposed in DeepSeek series of models.
     - Also supports computationally similar `trtllm_batch_decode_with_kv_cache_mla`.
 - GEMM:
@@ -280,7 +280,8 @@ Legend:
 - fa2: FlashAttention-2
 - fa2_tc: FlashAttention-2 (Tensor Core)
 - fa3: FlashAttention-3
-- cudnn: cuDNN
+- cudnn: cuDNN (via wrapper API)
+- cudnn-native: cuDNN (direct API call)
 - cutlass: CUTLASS
 - trtllm: TensorRT-LLM
 - trtllm-gen: TensorRT-LLM (generic wrapper)
@@ -289,8 +290,8 @@ Legend:
| Routine | 7.5 | 8.0 | 8.6 | 8.9 | 9.0 | 10.0 | 10.3 | 12.0 |
|---------|-----|-----|-----|-----|-----|-------|-------|-------|
| **BatchDecodeWithPagedKVCacheWrapper** | fa2 | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native | fa2, fa2_tc, cudnn |
-| **BatchPrefillWithPagedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, trtllm-gen, trtllm-native | fa2, cudnn, trtllm-gen, trtllm-native | fa2, cudnn |
-| **BatchPrefillWithRaggedKVCacheWrapper** | | fa2, cudnn | fa2, cudnn | fa2, cudnn | fa2, fa3, cudnn | fa2, cudnn, cutlass, trtllm-native | fa2, cudnn, cutlass, trtllm-native | fa2, cudnn |
+| **BatchPrefillWithPagedKVCacheWrapper** | | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, fa3, cudnn, cudnn-native | fa2, cudnn, cudnn-native, trtllm-gen, trtllm-native | fa2, cudnn, cudnn-native, trtllm-gen, trtllm-native | fa2, cudnn, cudnn-native |
+| **BatchPrefillWithRaggedKVCacheWrapper** | | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, fa3, cudnn, cudnn-native | fa2, cudnn, cudnn-native, cutlass, trtllm-native | fa2, cudnn, cudnn-native, cutlass, trtllm-native | fa2, cudnn, cudnn-native |
| **BatchMLAPagedAttentionWrapper** | | fa2 | fa2 | fa2 | fa2, fa3 | fa2, cutlass, trtllm-native | fa2, cutlass, trtllm-native | fa2 |
| **gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | |
| **group_gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | |
@@ -314,8 +315,9 @@ Backend Legend:
 - fa2: FlashAttention2
 - fa2_tc: FlashAttention2 (with Tensor Cores for `BatchDecodeWithPagedKVCacheWrapper`)
 - fa3: FlashAttention-3
-- cudnn: cuDNN
 - cublas: cuBLAS
+- cudnn: cuDNN (via wrapper API)
+- cudnn-native: cuDNN (direct API call)
 - cutlass: CUTLASS
 - trtllm: TensorRT-LLM
 - trtllm-gen: TensorRT-LLM
43 changes: 43 additions & 0 deletions benchmarks/routines/attention.py
@@ -1396,6 +1396,17 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
         if remove_cudnn:
             backends.remove("cudnn")

+    if "cudnn-native" in backends:
+        remove_cudnn_native = False
+        if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+        ]:
+            print("[INFO] CUDNN-native backend does not support FP8. Skipping.")
+            remove_cudnn_native = True
+        if remove_cudnn_native:
+            backends.remove("cudnn-native")
Comment on lines +1399 to +1408

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's examine the file structure and imports
head -50 benchmarks/routines/attention.py

Repository: flashinfer-ai/flashinfer

Length of output: 1213


🏁 Script executed:

# Search for CUDNN_AVAILABLE or similar constants in the file
rg "CUDNN_AVAILABLE|cudnn_available|is_cuda" benchmarks/routines/attention.py

Repository: flashinfer-ai/flashinfer

Length of output: 1207


🏁 Script executed:

# Search for CUDNN_AVAILABLE across the entire codebase
rg "CUDNN_AVAILABLE" --type py

Repository: flashinfer-ai/flashinfer

Length of output: 1417


🏁 Script executed:

# Check for torch.backends.cudnn availability checks
rg "torch\.backends\.cudnn|CUDNN|cudnn" benchmarks/routines/attention.py | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 1323


🏁 Script executed:

# Let's look at the complete context around lines 1399-1408
sed -n '1380,1420p' benchmarks/routines/attention.py

Repository: flashinfer-ai/flashinfer

Length of output: 1633


🏁 Script executed:

# Get exact line numbers where CUDNN_AVAILABLE check appears in relation to cudnn-native
rg -n "cudnn-native|CUDNN_AVAILABLE" benchmarks/routines/attention.py | grep -A 10 -B 5 "cudnn-native"

Repository: flashinfer-ai/flashinfer

Length of output: 829


🏁 Script executed:

# Get the complete cudnn and cudnn-native backend sections with line numbers
sed -n '1370,1450p' benchmarks/routines/attention.py | cat -n | sed 's/^[[:space:]]*//' | sed 's/^/1370+/'

Repository: flashinfer-ai/flashinfer

Length of output: 3843


🏁 Script executed:

# Check context around the first cudnn-native block (lines 784-798)
sed -n '770,810p' benchmarks/routines/attention.py

Repository: flashinfer-ai/flashinfer

Length of output: 1801


🏁 Script executed:

# Check context around the second cudnn-native block (lines 1399-1408) 
sed -n '1390,1415p' benchmarks/routines/attention.py

Repository: flashinfer-ai/flashinfer

Length of output: 1030


Add cuDNN availability guard to cudnn-native backend filter.

The ragged prefill routine `testBatchPrefillWithRaggedKVCacheWrapper` (lines 1399-1408) checks for FP8 dtype compatibility but lacks the CUDNN_AVAILABLE guard that the other prefill routine already has. If cuDNN is not available, the backend will fail at runtime. Add the availability check before evaluating dtype constraints.

Suggested fix
     if "cudnn-native" in backends:
         remove_cudnn_native = False
+        if not CUDNN_AVAILABLE or CUDNN_BACKEND_VERSION < 91701:
+            print("[INFO] cuDNN not available. Skipping cudnn-native backend.")
+            remove_cudnn_native = True
-        if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
+        elif q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
             torch.float8_e4m3fn,
             torch.float8_e5m2,
         ]:
             print("[INFO] CUDNN-native backend does not support FP8. Skipping.")
             remove_cudnn_native = True
🤖 Prompt for AI Agents
In `@benchmarks/routines/attention.py` around lines 1399 - 1408, The ragged
prefill routine's cudnn-native filter lacks the CUDNN_AVAILABLE guard and may
select "cudnn-native" when cuDNN isn't present; update the block that inspects
backends, q_dtype, kv_dtype and remove_cudnn_native to first check the
CUDNN_AVAILABLE flag (same guard used in the other prefill routine),
skipping/removing "cudnn-native" immediately if CUDNN_AVAILABLE is false before
evaluating FP8 dtype constraints (refer to variables/backends list, q_dtype,
kv_dtype, remove_cudnn_native and the "cudnn-native" string).
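
The guard ordering proposed in this comment can be sketched in isolation. This is a minimal illustration rather than the benchmark's actual code: `filter_cudnn_native` is a hypothetical helper, dtypes are passed as plain strings instead of `torch` dtypes, and the `91701` threshold is simply the value quoted in the suggested fix.

```python
# Hypothetical stand-ins: dtype names as strings instead of torch dtypes.
FP8_DTYPES = {"float8_e4m3fn", "float8_e5m2"}
# Version threshold quoted in the suggested fix; treated here as an assumption.
MIN_CUDNN_BACKEND_VERSION = 91701


def filter_cudnn_native(backends, q_dtype, kv_dtype, cudnn_available, cudnn_version):
    """Return a copy of `backends` without "cudnn-native" when it cannot run."""
    if "cudnn-native" not in backends:
        return list(backends)

    # Availability is checked first, so dtype constraints are only evaluated
    # when cuDNN can actually be used.
    if not cudnn_available or cudnn_version < MIN_CUDNN_BACKEND_VERSION:
        reason = "cuDNN not available or too old"
    elif q_dtype in FP8_DTYPES or kv_dtype in FP8_DTYPES:
        reason = "FP8 is not supported"
    else:
        return list(backends)

    print(f"[INFO] Skipping cudnn-native backend: {reason}.")
    return [b for b in backends if b != "cudnn-native"]
```

Returning a new list instead of mutating `backends` keeps the helper easy to test; the benchmark itself removes entries in place.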


if "cutlass" in backends:
remove_cutlass = False
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
@@ -1609,6 +1620,34 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
q_data_type=q_dtype,
kv_data_type=kv_dtype,
)
elif backend == "cudnn":
# cuDNN uses NHD layout and the wrapper API
backend_wrappers[backend] = (
flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
workspace_buffer,
"NHD",
backend="cudnn",
)
)
backend_wrappers[backend].plan(
qo_indptr=q_indptr,
kv_indptr=k_indptr,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim_qk=head_dim_qk,
head_dim_vo=head_dim_vo,
causal=causal,
sm_scale=scale,
q_data_type=q_dtype,
kv_data_type=kv_dtype,
o_data_type=q_dtype,
seq_lens=actual_seq_lens_kv_device,
seq_lens_q=actual_seq_lens_q_device,
max_token_per_sequence=s_qo,
max_sequence_kv=s_kv,
v_indptr=v_indptr,
o_indptr=o_indptr,
)

k_scale, v_scale = None, None
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
@@ -1639,6 +1678,10 @@ def run_backend_wrapper(
if backend in ["cutlass", "fa2", "fa3", "trtllm-gen"]:
return backend_wrappers[backend].run_return_lse(q, k, v)[0]
elif backend == "cudnn":
# cuDNN uses wrapper API
return backend_wrappers[backend].run(q, k, v)
elif backend == "cudnn-native":
# Direct cudnn_batch_prefill_with_kv_cache call
return flashinfer.prefill.cudnn_batch_prefill_with_kv_cache(
q,
k,
15 changes: 8 additions & 7 deletions benchmarks/routines/flashinfer_benchmark_utils.py
@@ -220,14 +220,15 @@ def dtype_str_to_torch_dtype(dtype_str):
},
"BatchPrefillWithRaggedKVCacheWrapper": {
# NOTE: trtllm-native calls trtllm_ragged_attention_deepseek
+# NOTE: cudnn-native calls cudnn_batch_prefill_with_kv_cache
"7.5": [],
-"8.0": ["fa2", "cudnn"],
-"8.6": ["fa2", "cudnn"],
-"8.9": ["fa2", "cudnn"],
-"9.0": ["fa2", "fa3", "cudnn"],
-"10.0": ["fa2", "cudnn", "cutlass", "trtllm-native"],
-"10.3": ["fa2", "cudnn", "cutlass", "trtllm-native"],
-"12.0": ["fa2", "cudnn"],
+"8.0": ["fa2", "cudnn", "cudnn-native"],
+"8.6": ["fa2", "cudnn", "cudnn-native"],
+"8.9": ["fa2", "cudnn", "cudnn-native"],
+"9.0": ["fa2", "fa3", "cudnn", "cudnn-native"],
+"10.0": ["fa2", "cudnn", "cudnn-native", "cutlass", "trtllm-native"],
+"10.3": ["fa2", "cudnn", "cudnn-native", "cutlass", "trtllm-native"],
+"12.0": ["fa2", "cudnn", "cudnn-native"],
},
"BatchMLAPagedAttentionWrapper": {
# NOTE: trtllm-native calls trtllm_batch_decode_with_kv_cache_mla
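
A minimal sketch of how a per-architecture table like the one in this hunk can be used to filter requested backends. The table copies the new `BatchPrefillWithRaggedKVCacheWrapper` entries from the diff; the helper name `filter_supported_backends` is hypothetical and is not claimed to match the benchmark's own filtering function.

```python
# New entries for BatchPrefillWithRaggedKVCacheWrapper, copied from the diff.
RAGGED_PREFILL_BACKENDS = {
    "7.5": [],
    "8.0": ["fa2", "cudnn", "cudnn-native"],
    "8.6": ["fa2", "cudnn", "cudnn-native"],
    "8.9": ["fa2", "cudnn", "cudnn-native"],
    "9.0": ["fa2", "fa3", "cudnn", "cudnn-native"],
    "10.0": ["fa2", "cudnn", "cudnn-native", "cutlass", "trtllm-native"],
    "10.3": ["fa2", "cudnn", "cudnn-native", "cutlass", "trtllm-native"],
    "12.0": ["fa2", "cudnn", "cudnn-native"],
}


def filter_supported_backends(requested, compute_capability, table=RAGGED_PREFILL_BACKENDS):
    """Keep only the requested backends listed for this compute capability.

    Hypothetical helper: unknown compute capabilities yield an empty list.
    """
    allowed = table.get(compute_capability, [])
    return [backend for backend in requested if backend in allowed]
```

For example, requesting `["fa3", "cudnn-native", "cutlass"]` on compute capability `"9.0"` keeps `fa3` and `cudnn-native` and drops `cutlass`.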
75 changes: 72 additions & 3 deletions flashinfer/prefill.py
@@ -2478,7 +2478,7 @@ def __init__(
will be used in attention computation.

backend : str
-The implementation backend, could be ``auto``/``fa2``/``fa3`` or ``cutlass``.
+The implementation backend, could be ``auto``/``fa2``/``fa3``/``cudnn`` or ``cutlass``.
Defaults to ``auto``.
If set to ``auto``, the wrapper will automatically choose the backend based on the
device architecture and kernel availability.
@@ -2598,6 +2598,12 @@ def plan(
max_item_len_ptr: Optional[torch.Tensor] = None,
fixed_split_size: Optional[int] = None,
disable_split_kv: bool = False,
seq_lens: Optional[torch.Tensor] = None,
seq_lens_q: Optional[torch.Tensor] = None,
max_token_per_sequence: Optional[int] = None,
max_sequence_kv: Optional[int] = None,
v_indptr: Optional[torch.Tensor] = None,
o_indptr: Optional[torch.Tensor] = None,
Comment on lines +2601 to +2606

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

git ls-files flashinfer/prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 89


🏁 Script executed:

# Read lines around 2601-2606 and surrounding context
sed -n '2590,2620p' flashinfer/prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 1480


🏁 Script executed:

# Read lines around 2824-2828 and surrounding context
sed -n '2810,2840p' flashinfer/prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 1268


🏁 Script executed:

# Search for where seq_lens_q is first used in function definitions with docstrings
rg -A 20 "seq_lens_q" flashinfer/prefill.py | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 3474


🏁 Script executed:

# Let's look at the function containing lines 2601-2606
sed -n '2570,2640p' flashinfer/prefill.py | cat -n

Repository: flashinfer-ai/flashinfer

Length of output: 3802


🏁 Script executed:

#!/bin/bash
# Find all occurrences of seq_lens_q assignment and read context
rg -B 5 -A 5 "self\._seq_lens_q\s*=" flashinfer/prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 2286


🏁 Script executed:

# Search for seq_lens_q in docstrings or comments
rg "seq_lens_q" flashinfer/prefill.py -B 2 -A 2 | head -150

Repository: flashinfer-ai/flashinfer

Length of output: 5119


🏁 Script executed:

#!/bin/bash
# Get line count first
total_lines=$(wc -l < flashinfer/prefill.py)
echo "Total lines: $total_lines"

# Read the section around 2601-2606
echo "=== Lines 2590-2620 ==="
sed -n '2590,2620p' flashinfer/prefill.py

echo ""
echo "=== Lines 2810-2840 ==="
sed -n '2810,2840p' flashinfer/prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 2748


🏁 Script executed:

#!/bin/bash
# Find the method definitions around the problem areas
echo "=== Around line 2600 (first occurrence) ==="
sed -n '2565,2580p' flashinfer/prefill.py

echo ""
echo "=== Around line 2800 (second occurrence) ==="
sed -n '2790,2810p' flashinfer/prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 1607


🏁 Script executed:

#!/bin/bash
# Find the seq_lens_q documentation
echo "=== Docstring for seq_lens_q around line 2671 ==="
sed -n '2665,2680p' flashinfer/prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 1194


🏁 Script executed:

#!/bin/bash
# Search for class definitions in the file
rg "^class " flashinfer/prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 154


🏁 Script executed:

#!/bin/bash
# Count occurrences of the problematic pattern
echo "=== Pattern 1: self._seq_lens_q = seq_lens_q (without fallback) ==="
rg -n "self\._seq_lens_q = seq_lens_q$" flashinfer/prefill.py

echo ""
echo "=== Pattern 2: self._seq_lens_q = seq_lens_q if (with fallback) ==="
rg -n "self\._seq_lens_q = seq_lens_q if" flashinfer/prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 330


🏁 Script executed:

#!/bin/bash
# Find the class definitions and their line numbers
echo "=== Class definitions with line numbers ==="
rg -n "^class " flashinfer/prefill.py

echo ""
echo "=== Location of line 1988 assignment ==="
sed -n '1980,1995p' flashinfer/prefill.py | cat -n

echo ""
echo "=== Location of line 2824 assignment ==="
sed -n '2820,2830p' flashinfer/prefill.py | cat -n

Repository: flashinfer-ai/flashinfer

Length of output: 1522


🏁 Script executed:

#!/bin/bash
# Search for seq_lens_q in docstrings specifically mentioning the default
rg -B 2 -A 2 "seq_lens_q.*will be set to" flashinfer/prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Read the full docstring section describing seq_lens_q
sed -n '2620,2720p' flashinfer/prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 7077


🏁 Script executed:

#!/bin/bash
# Find which method line 1988 belongs to
echo "=== Method containing line 1988 ==="
sed -n '1900,1990p' flashinfer/prefill.py | grep -E "(def |class )" | tail -1

echo ""
echo "=== Method containing line 2824 ==="
sed -n '2550,2824p' flashinfer/prefill.py | grep -E "(def |class )" | tail -1

echo ""
echo "=== Verify line 2824 is in plan method ==="
sed -n '2570,2580p' flashinfer/prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 446


Set seq_lens_q fallback in BatchPrefillWithRaggedKVCacheWrapper.plan() to match documented API contract.

The docstring states that seq_lens_q defaults to seq_lens when not provided, but the assignment at line 2824 leaves it as None. This causes crashes on the cuDNN backend when the code later calls self._seq_lens_q.dim(). The BatchPrefillWithPagedKVCacheWrapper class already implements the correct fallback pattern; apply the same fix here.

Suggested fix
-        self._seq_lens_q = seq_lens_q
         self._seq_lens_kv = seq_lens
+        self._seq_lens_q = seq_lens_q if seq_lens_q is not None else seq_lens
🤖 Prompt for AI Agents
In `@flashinfer/prefill.py` around lines 2601 - 2606, In
BatchPrefillWithRaggedKVCacheWrapper.plan(), seq_lens_q is left as None despite
the docstring saying it should default to seq_lens; this causes later calls to
self._seq_lens_q.dim() to crash. Fix by assigning seq_lens_q = seq_lens when
seq_lens_q is None (same fallback used in BatchPrefillWithPagedKVCacheWrapper),
and ensure the method stores the resolved value to self._seq_lens_q before any
use; reference the plan() method and variables seq_lens_q and seq_lens in the
BatchPrefillWithRaggedKVCacheWrapper class.
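
The fallback described in this comment amounts to resolving the default once, at plan time. The sketch below illustrates only that pattern: `RaggedPlanState` is a hypothetical stand-in for the wrapper's stored state, and plain Python sequences replace tensors.

```python
from typing import Optional, Sequence


class RaggedPlanState:
    """Hypothetical stand-in for the state that plan() stores on the wrapper."""

    def plan(
        self,
        seq_lens: Optional[Sequence[int]] = None,
        seq_lens_q: Optional[Sequence[int]] = None,
    ) -> None:
        self._seq_lens_kv = seq_lens
        # Fall back to the KV lengths when no separate query lengths are given,
        # matching the docstring's stated default.
        self._seq_lens_q = seq_lens_q if seq_lens_q is not None else seq_lens
```

With this in place, later code that inspects `self._seq_lens_q` only sees `None` when `seq_lens` was also omitted.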

) -> None:
r"""Plan batch prefill/append attention on Ragged KV-Cache for given problem specification.

@@ -2692,6 +2698,19 @@ def plan(
and lead to a varied number of launched CTAs.
disable_split_kv : bool,
Whether to disable the split-kv for determinism in CUDA Graph, defaults to ``False``.
seq_lens: Optional[torch.Tensor]
A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``.
seq_lens_q: Optional[torch.Tensor]
A uint32 1D tensor indicating the q sequence length of each prompt. shape: ``[batch_size]``.
If not provided, will be set to the same value as ``seq_lens``.
max_token_per_sequence: Optional[int],
Required for cudnn backend. This is the scalar max token length of each sequence.
max_sequence_kv: Optional[int],
Required for cudnn backend. This is the scalar max sequence length of each sequence in kv cache.
v_indptr: Optional[torch.Tensor]
Required for cudnn backend. This is the indptr of the value tensor.
o_indptr: Optional[torch.Tensor]
Required for cudnn backend. This is the indptr of the output tensor.
Note
----
The :meth:`plan` method should be called before any :meth:`run` or
@@ -2781,6 +2800,17 @@ def plan(
self.device, non_blocking=non_blocking
)

self._o_indptr_buf = (
o_indptr.to(self.device, non_blocking=non_blocking)
if o_indptr is not None
else self._qo_indptr_buf
)
self._v_indptr_buf = (
v_indptr.to(self.device, non_blocking=non_blocking)
if v_indptr is not None
else self._kv_indptr_buf
)

self._cached_q_data_type = q_data_type
self._cached_kv_data_type = kv_data_type
self._cached_o_data_type = o_data_type
@@ -2791,6 +2821,11 @@ def plan(
self._token_pos_in_items_len = token_pos_in_items_len
self._max_item_len_ptr = max_item_len_ptr

self._seq_lens_q = seq_lens_q
self._seq_lens_kv = seq_lens
self._max_token_per_sequence = max_token_per_sequence
self._max_sequence_kv = max_sequence_kv

if self._jit_module is not None:
self._cached_module = self._jit_module
else:
@@ -2822,7 +2857,7 @@ def plan(
get_module_args[:9] + (qo_indptr.device,) + get_module_args[9:]
)
self._cached_module = get_fmha_module(*new_get_module_args)
-else:
+elif self._backend != "cudnn":
self._cached_module = get_batch_prefill_module(
self._backend, *get_module_args
)
@@ -2832,7 +2867,7 @@ def plan(
self._cached_module, qo_indptr, kv_indptr, num_qo_heads, causal
)
self._max_qo_len = torch.max(qo_indptr[1:] - qo_indptr[:-1]).item()
-else:
+elif self._backend != "cudnn":
assert self._cached_module is not None, "cached module is not initialized"
args = [
self._float_workspace_buffer,
@@ -3040,6 +3075,40 @@ def run(
lse=lse,
)
return (out, lse) if return_lse else out
elif self._backend == "cudnn":
if self._seq_lens_q.dim() == 1:
batch_size = self._seq_lens_q.shape[0]
if self._seq_lens_q is not None and self._seq_lens_q.dim() == 1:
self._seq_lens_q = self._seq_lens_q.reshape(batch_size, 1, 1, 1)

if self._seq_lens_kv is not None and self._seq_lens_kv.dim() == 1:
self._seq_lens_kv = self._seq_lens_kv.reshape(batch_size, 1, 1, 1)

cudnn_batch_prefill_with_kv_cache(
q,
k,
v,
sm_scale,
self._float_workspace_buffer,
max_token_per_sequence=self._max_token_per_sequence,
max_sequence_kv=self._max_sequence_kv,
actual_seq_lens_q=self._seq_lens_q,
actual_seq_lens_kv=self._seq_lens_kv,
return_lse=return_lse,
causal=self._causal,
q_scale=q_scale,
k_scale=k_scale,
v_scale=v_scale,
batch_offsets_q=self._qo_indptr_buf,
batch_offsets_k=self._kv_indptr_buf,
batch_offsets_v=self._v_indptr_buf,
batch_offsets_o=self._o_indptr_buf,
is_cuda_graph_compatible=self._use_cuda_graph,
out=out,
lse=lse,
)

return (out, lse) if return_lse else out

# Skip FP8->FP16 conversion for FA3 backend with FP8 support
# The JIT module will handle FP8 natively
4 changes: 2 additions & 2 deletions tests/attention/test_cudnn_prefill.py
@@ -44,7 +44,7 @@ def test_cudnn_prefill(
)

cumsum_s_qo = torch.sum(actual_seq_lens_q)
-    q = torch.ones(
+    q = torch.randn(
cumsum_s_qo, num_qo_heads, head_dim, device=device, dtype=torch.bfloat16
)

@@ -60,7 +60,7 @@ def test_cudnn_prefill(
total_num_pages = num_pages_per_seq * batch_size

kv_cache_shape = (total_num_pages, 2, num_kv_heads, page_size, head_dim)
-    kv_cache = torch.ones(size=kv_cache_shape, dtype=torch.bfloat16).to(device)
+    kv_cache = torch.randn(size=kv_cache_shape, dtype=torch.bfloat16).to(device)
kv_cache = kv_cache.as_strided(
kv_cache.shape,
(
77 changes: 39 additions & 38 deletions tests/attention/test_cudnn_prefill_deepseek.py
@@ -5,22 +5,23 @@


@pytest.mark.parametrize("batch_size", [1, 4])
-@pytest.mark.parametrize("s_qo", [32, 64, 87])
-@pytest.mark.parametrize("s_kv", [32, 64, 87])
-@pytest.mark.parametrize("num_kv_heads", [1])
-@pytest.mark.parametrize("num_qo_heads", [1, 16])
+@pytest.mark.parametrize("s_qo", [32, 64, 87, 256])
+@pytest.mark.parametrize("s_kv", [32, 87, 512])
+@pytest.mark.parametrize("num_kv_heads", [1, 4])
+@pytest.mark.parametrize("num_qo_heads", [1, 8])
@pytest.mark.parametrize("causal", [True, False])
def test_cudnn_prefill_deepseek(
batch_size, s_qo, s_kv, num_kv_heads, num_qo_heads, causal
):
if s_qo > s_kv:
pytest.skip("s_qo > s_kv, skipping test as causal")

if num_qo_heads < num_kv_heads:
pytest.skip("num_qo_heads < num_kv_heads, skipping test")
Comment on lines 7 to +20

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, find the test file and read it
cat -n tests/attention/test_cudnn_prefill_deepseek.py

Repository: flashinfer-ai/flashinfer

Length of output: 5877


🏁 Script executed:

# Check what's available in flashinfer.utils
find . -name "utils.py" -o -name "utils" -type d | grep flashinfer | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 301


🏁 Script executed:

# Look for get_compute_capability and related functions in flashinfer
rg "get_compute_capability|is_sm90a_supported|is_sm100a_supported" --type py

Repository: flashinfer-ai/flashinfer

Length of output: 17993


🏁 Script executed:

cat -n tests/attention/test_cudnn_prefill.py | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 1740


🏁 Script executed:

# Check flashinfer.utils for cuDNN-specific checks
rg "cudnn|get_compute_capability" flashinfer/utils.py -A 3 -B 1

Repository: flashinfer-ai/flashinfer

Length of output: 2283


🏁 Script executed:

# Check if there's documentation or comments about cuDNN GPU requirements
rg "cuDNN|cudnn.*sm|cudnn.*arch|cudnn.*compute" --type py -i -B 2 -A 2 | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 4924


🏁 Script executed:

# Check conftest to see if there are global skip mechanisms for cuDNN
cat -n tests/conftest.py 2>/dev/null | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 4344


🏁 Script executed:

# Continue reading conftest
tail -n +100 tests/conftest.py | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 3055


🏁 Script executed:

# Look for OOM skip mechanisms mentioned in the coding guidelines
rg "OOM|out.of.memory|OutOfMemory" --type py tests/

Repository: flashinfer-ai/flashinfer

Length of output: 451


🏁 Script executed:

# Check if there are any SM requirements documented in flashinfer for cuDNN prefill
rg "cudnn.*prefill|prefill.*cudnn" --type py -B 3 -A 3

Repository: flashinfer-ai/flashinfer

Length of output: 6803


🏁 Script executed:

# Check what compute capabilities are actually needed by looking at BatchPrefillWithRaggedKVCacheWrapper
rg "BatchPrefillWithRaggedKVCacheWrapper" --type py -l

Repository: flashinfer-ai/flashinfer

Length of output: 836


🏁 Script executed:

# Check the actual prefill.py to see if there are SM requirements
grep -n "cudnn\|backend" flashinfer/prefill.py | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 1608


🏁 Script executed:

# Read the full test_cudnn_prefill.py to see what GPU checks it has
cat -n tests/attention/test_cudnn_prefill.py | grep -A 30 "def test_cudnn_prefill"

Repository: flashinfer-ai/flashinfer

Length of output: 1982


🏁 Script executed:

# Check flashinfer/cudnn/__init__.py to understand cuDNN availability
cat -n flashinfer/cudnn/__init__.py

Repository: flashinfer-ai/flashinfer

Length of output: 189


🏁 Script executed:

# Check if there's a CUDNN_AVAILABLE flag or similar
rg "CUDNN_AVAILABLE" --type py

Repository: flashinfer-ai/flashinfer

Length of output: 1417


Add GPU capability and cuDNN availability checks to gate test execution.

This test uses the cuDNN backend explicitly but lacks guards for GPU architecture support and cuDNN availability. As per coding guidelines, tests must skip on unsupported hardware. The 512MB workspace allocation can also cause OOM on smaller GPUs. Add checks before tensor allocations.

🧪 Suggested skip guards
 import pytest
 import torch

 import flashinfer
+from flashinfer.utils import get_compute_capability
+
+try:
+    import cudnn  # type: ignore
+    CUDNN_AVAILABLE = True
+except (ImportError, OSError):
+    CUDNN_AVAILABLE = False


 @pytest.mark.parametrize("batch_size", [1, 4])
@@ -26,6 +34,14 @@ def test_cudnn_prefill_deepseek(
     batch_size, s_qo, s_kv, num_kv_heads, num_qo_heads, causal
 ):
     if s_qo > s_kv:
         pytest.skip("s_qo > s_kv, skipping test as causal")

     if num_qo_heads < num_kv_heads:
         pytest.skip("num_qo_heads < num_kv_heads, skipping test")
+
+    if not CUDNN_AVAILABLE:
+        pytest.skip("cuDNN not available")
+    major, _ = get_compute_capability(torch.device("cuda:0"))
+    if major < 8:
+        pytest.skip("cuDNN prefill requires SM80+")
🤖 Prompt for AI Agents
In `@tests/attention/test_cudnn_prefill_deepseek.py` around lines 7 - 20, Before
allocating tensors in test_cudnn_prefill_deepseek, add gates to skip the test if
no CUDA device or cuDNN is available and if the GPU's compute capability or free
memory is insufficient for the 512MB workspace; specifically check
torch.cuda.is_available(), torch.backends.cudnn.is_available(), and
torch.cuda.get_device_capability() (or device major/minor) and optionally
torch.cuda.get_device_properties().total_memory/free memory to skip when the
device lacks required SM capability or memory; place these checks at the top of
test_cudnn_prefill_deepseek (before any use of s_qo, s_kv, num_qo_heads,
num_kv_heads or tensor allocations) so the test is skipped early on unsupported
hardware.
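
One way to keep these gates readable is to compute a single skip reason up front. The helper below is a hypothetical sketch, not code from the PR: the SM80+ threshold comes from the suggested guard in this comment, the memory check from the prompt text, and the caller is assumed to obtain the inputs from whatever device queries the test already uses.

```python
from typing import Optional


def cudnn_skip_reason(
    cudnn_available: bool,
    sm_major: int,
    free_bytes: int,
    workspace_bytes: int,
) -> Optional[str]:
    """Return a skip message, or None when the test may run.

    Hypothetical helper: checks run from cheapest to most specific.
    """
    if not cudnn_available:
        return "cuDNN not available"
    if sm_major < 8:
        # Threshold taken from the suggested guard; treat it as an assumption.
        return "cuDNN prefill requires SM80+"
    if free_bytes < workspace_bytes:
        return "not enough free GPU memory for the workspace buffer"
    return None
```

In the test body this would be used as `reason = cudnn_skip_reason(...)` followed by `if reason: pytest.skip(reason)`, placed before any tensor allocation.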


head_dim_qk = 192
head_dim_vo = 128

return_lse = True

# test set up basics
seed = 0
torch.manual_seed(seed)
@@ -76,14 +77,14 @@ def test_cudnn_prefill_deepseek(
]
).int()

-    batch_offsets_stats = torch.cat(
-        [
-            torch.zeros(
-                1, device=actual_seq_lens_q.device, dtype=actual_seq_lens_q.dtype
-            ),
-            torch.cumsum(actual_seq_lens_q.flatten(), dim=0) * num_qo_heads,
-        ]
-    ).cuda()
+    # batch_offsets_stats = torch.cat(
+    # [
+    # torch.zeros(
+    # 1, device=actual_seq_lens_q.device, dtype=actual_seq_lens_q.dtype
+    # ),
+    # torch.cumsum(actual_seq_lens_q.flatten(), dim=0) * num_qo_heads,
+    # ]
+    # ).cuda()

k_cache = torch.randn(
batch_size * s_kv,
@@ -103,45 +104,45 @@ def test_cudnn_prefill_deepseek(
# Initialize scale
scale = float(1.0 / (head_dim_qk**0.5))

-    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
+    workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=device)

⚠️ Potential issue | 🟑 Minor

Right-size the workspace buffer to reduce OOM risk.

Hardcoding a 512MB workspace can exhaust memory on smaller GPUs. Consider capping it relative to device memory. As per coding guidelines, avoid OOM-prone test sizes.

💡 Safer workspace sizing
-    workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=device)
+    total_mem = torch.cuda.get_device_properties(device).total_memory
+    workspace_bytes = min(512 * 1024 * 1024, total_mem // 8)
+    workspace_buffer = torch.empty(workspace_bytes, dtype=torch.int8, device=device)
🤖 Prompt for AI Agents
In `@tests/attention/test_cudnn_prefill_deepseek.py` at line 107, The test
hardcodes a 512MB workspace (workspace_buffer) which can OOM on smaller GPUs;
replace the fixed size with a safe cap based on the device's total memory by
querying torch.cuda.get_device_properties(device).total_memory and computing a
workspace_size_bytes = min(512*1024*1024, int(total_mem * 0.1)) (or another safe
fraction like 0.05), then allocate workspace_buffer =
torch.empty(workspace_size_bytes, dtype=torch.int8, device=device) so the buffer
scales to the GPU and reduces OOM risk.
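
The sizing rule suggested here is small enough to isolate and check without a GPU. The sketch below is illustrative only: `capped_workspace_bytes` is a hypothetical helper, the 512 MiB cap mirrors the value in the diff, and the divisor is a parameter because the suggestion uses `total_mem // 8` while the prompt mentions a fraction such as 0.1.

```python
MIB = 1024 * 1024


def capped_workspace_bytes(total_mem_bytes, cap_bytes=512 * MIB, divisor=8):
    """Return the smaller of a fixed cap and a fraction of device memory.

    Hypothetical helper; `total_mem_bytes` would come from the device
    properties query shown in the suggestion.
    """
    if divisor <= 0:
        raise ValueError("divisor must be positive")
    return min(cap_bytes, total_mem_bytes // divisor)
```

On a device with 2 GiB of memory this yields 256 MiB with the default divisor, while larger devices still get the full 512 MiB. Whether the cuDNN backend can run with less than 512 MiB is not established by this thread and would need to be verified.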


# output = torch.zeros_like(q)
-    output, lse = flashinfer.prefill.cudnn_batch_prefill_with_kv_cache(
-        q,
-        k_cache,
-        v_cache,
-        scale,
-        workspace_buffer,
+    wrapper_cudnn = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
+        workspace_buffer, "NHD", backend="cudnn"
+    )
+
+    wrapper_cudnn.plan(
+        qo_indptr=q_indptr,
+        kv_indptr=k_indptr,
+        num_qo_heads=num_qo_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim_qk=head_dim_qk,
+        head_dim_vo=head_dim_vo,
+        causal=causal,
+        sm_scale=scale,
+        q_data_type=torch.bfloat16,
+        kv_data_type=torch.bfloat16,
+        o_data_type=torch.bfloat16,
+        seq_lens=actual_seq_lens_kv,
+        seq_lens_q=actual_seq_lens_q,
         max_token_per_sequence=s_qo,
         max_sequence_kv=s_kv,
-        actual_seq_lens_q=actual_seq_lens_q,
-        actual_seq_lens_kv=actual_seq_lens_kv,
-        causal=causal,
-        return_lse=return_lse,
-        batch_offsets_q=q_indptr,
-        batch_offsets_k=k_indptr,
-        batch_offsets_v=v_indptr,
-        batch_offsets_o=o_indptr,
-        batch_offsets_stats=batch_offsets_stats,
-        is_cuda_graph_compatible=True,
+        v_indptr=v_indptr,
+        o_indptr=o_indptr,
)

output = wrapper_cudnn.run(q, k_cache, v_cache)

qo_indptr = torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(actual_seq_lens_q.view(-1), dim=0),
]
).int()

# kv_indptr = torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * s_kv

# Create kv_indptr as cumulative sum of actual_seq_lens_kv
kv_indptr = torch.cat(
[
-            torch.tensor(
-                [0],
-                device=device,
-            ),
+            torch.tensor([0], device=device),
torch.cumsum(actual_seq_lens_kv.view(-1), dim=0),
]
).int()
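
The `qo_indptr` and `kv_indptr` tensors built in this hunk follow the usual ragged layout: a leading zero followed by the running sum of per-sequence lengths. A minimal sketch with plain Python lists instead of tensors; `build_indptr` is a hypothetical name used only for illustration.

```python
from itertools import accumulate


def build_indptr(seq_lens):
    """Return [0, l0, l0 + l1, ...] so that sequence i occupies rows
    indptr[i]:indptr[i + 1] of the packed (ragged) tensor."""
    return [0, *accumulate(seq_lens)]
```

Adjacent differences of the result recover the original lengths, which is the same quantity the diff computes with `qo_indptr[1:] - qo_indptr[:-1]` when deriving the maximum query length.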