Added the cudnn backend Ragged KV Cache wrapper#2352
Conversation
|
Note Other AI code review bot(s) detectedCodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review. 📝 WalkthroughWalkthroughExtends BatchPrefillWithPagedKVCacheWrapper.plan/run to accept sequence length and indptr parameters, buffer indptrs, and add a cuDNN-specific execution path; tests and benchmarks updated to exercise cudnn-native and wrapper-based ragged/paged KV prefill flows. Changes
Sequence Diagram(s)sequenceDiagram
participant Test
participant Wrapper as BatchPrefillWithPagedKVCacheWrapper
participant Run as run()
participant CuDNN as cudnn_batch_prefill_with_kv_cache
participant Module as other_batch_prefill_backend
Test->>Wrapper: plan(..., seq_lens, seq_lens_q, v_indptr, o_indptr, max_...)
Wrapper->>Wrapper: store seq attrs and indptr buffers
Test->>Run: run(q, k_cache, v_cache)
alt backend == "cudnn"
Run->>Run: reshape/expand seq tensors if 1D -> 4D
Run->>CuDNN: cudnn_batch_prefill_with_kv_cache(seq_lens_q, seq_lens_kv, max_token_per_sequence, max_sequence_kv, qo_indptr, kv_indptr, v_indptr, o_indptr, ...)
CuDNN-->>Run: outputs (and lse if requested)
else
Run->>Module: call existing batch_prefill path
Module-->>Run: outputs
end
Run-->>Test: return outputs
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Summary of ChangesHello @Anerudhan, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates a new cuDNN backend into the Ragged KV Cache prefill functionality, aiming to leverage NVIDIA's optimized deep neural network primitives for improved efficiency. The core changes involve extending the Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request adds a cudnn backend for the Ragged KV Cache wrapper and updates tests. The changes are a good addition, but I've identified a few issues in the implementation. These include a critical bug that could cause a runtime error due to an undefined variable, a hardcoded value that should be dynamic, and a minor docstring formatting issue. I've also pointed out an API inconsistency that could be confusing for users. My review includes specific suggestions to address these points.
There was a problem hiding this comment.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/attention/test_cudnn_prefill_deepseek.py (1)
44-78: Add GPU architecture check for cuDNN backend support.The indptr calculations are correct—they create element-level offsets (cumsum × head_dim × num_heads) which match cuDNN's
batch_offsets_*expectations. However, the test is missing a GPU architecture validation check. According to coding guidelines, test implementations should useflashinfer.utilsfunctions to skip tests on unsupported GPU architectures. Similar totest_cudnn_prefill.py, this test should check compute capability before running, as cuDNN prefill is backend-specific and may not be supported on all GPU architectures.Add at the beginning of the test function:
from flashinfer.utils import get_compute_capabilityAnd after the seed/device setup:
major, _ = get_compute_capability(torch.device(device)) if major < 10: # or your minimum supported compute capability pytest.skip(f"cuDNN prefill not supported on compute capability {major}")
🤖 Fix all issues with AI agents
In `@flashinfer/prefill.py`:
- Around line 2701-2714: The docstring currently runs "disable_split_kv" and
"seq_lens:" together causing bad rendering; edit the relevant function/class
docstring to insert a blank line (newline) before the parameter block so
"seq_lens:" starts on its own paragraph. Locate the docstring containing
"disable_split_kv" and "seq_lens" and ensure there is an empty line between the
prose and the parameter list (preserving existing indentation and parameter
descriptions like seq_lens_q, max_token_per_sequence, max_sequence_kv, v_indptr,
o_indptr).
- Around line 3087-3111: The cuDNN call assumes required size parameters but may
receive None; before calling cudnn_batch_prefill_with_kv_cache check that
self._max_token_per_sequence and self._max_sequence_kv are not None (and
optionally >0) and raise a clear ValueError (or AssertionError) referencing the
function call context if they are missing; update the caller (the method
invoking cudnn_batch_prefill_with_kv_cache in prefill.py) to perform this
validation early (and include the attribute names _max_token_per_sequence and
_max_sequence_kv in the error message) so the backend is never invoked with
invalid parameters.
🧹 Nitpick comments (2)
tests/attention/test_cudnn_prefill_deepseek.py (2)
1-21: Missing GPU architecture skip check per coding guidelines.The test should use
flashinfer.utilsfunctions to skip tests on unsupported GPU architectures. The cuDNN backend may not be supported on all GPUs.Proposed fix
import pytest import torch import flashinfer +from flashinfer.utils import get_compute_capability `@pytest.mark.parametrize`("batch_size", [1, 4]) `@pytest.mark.parametrize`("s_qo", [32, 64, 87, 256]) `@pytest.mark.parametrize`("s_kv", [32, 87, 512]) `@pytest.mark.parametrize`("num_kv_heads", [1, 4]) `@pytest.mark.parametrize`("num_qo_heads", [1, 8]) `@pytest.mark.parametrize`("causal", [True, False]) def test_cudnn_prefill_deepseek( batch_size, s_qo, s_kv, num_kv_heads, num_qo_heads, causal ): if s_qo > s_kv: pytest.skip("s_qo > s_kv, skipping test as causal") if num_qo_heads < num_kv_heads: pytest.skip("num_qo_heads < num_kv_heads, skipping test") + + device = "cuda:0" + major, _ = get_compute_capability(torch.device(device)) + if major < 8: + pytest.skip(f"cuDNN backend requires compute capability >= 8.0, got {major}")
80-87: Consider removing commented-out code.The commented-out
batch_offsets_statsblock appears to be dead code. If it's no longer needed, consider removing it to keep the test clean.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between f0277fd and 6fcf574dde224351d8fa45f7e58d82af4bbff1c0.
📒 Files selected for processing (3)
flashinfer/prefill.pytests/attention/test_cudnn_prefill.pytests/attention/test_cudnn_prefill_deepseek.py
🧰 Additional context used
📓 Path-based instructions (2)
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use@functools.cachedecorator on Python API functions to implement module-level caching and avoid recompilation
Use@flashinfer_apidecorator for debugging API calls, enable viaFLASHINFER_LOGLEVELenvironment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Files:
flashinfer/prefill.py
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
tests/**/*.py: Test implementations should useflashinfer.utilsfunctions (get_compute_capability,is_sm90a_supported,is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing withmpirunon multi-GPU systems, use the pattern:mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes -tests/conftest.pyprovides auto-skipping for OOM tests as a safety net but should not be relied upon
Files:
tests/attention/test_cudnn_prefill.pytests/attention/test_cudnn_prefill_deepseek.py
🧬 Code graph analysis (2)
flashinfer/prefill.py (1)
flashinfer/cudnn/prefill.py (1)
cudnn_batch_prefill_with_kv_cache(555-724)
tests/attention/test_cudnn_prefill.py (1)
flashinfer/logits_processor/types.py (1)
size(132-136)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (9)
tests/attention/test_cudnn_prefill_deepseek.py (2)
109-133: LGTM!The cuDNN wrapper setup and plan call correctly use the new API parameters for ragged KV cache support.
150-173: LGTM!The reference wrapper comparison is correctly set up with standard token-count indptrs and appropriate tolerances for bfloat16 comparison.
flashinfer/prefill.py (5)
2601-2606: LGTM!New optional parameters are correctly typed and positioned for cuDNN backend support while maintaining backward compatibility.
2803-2812: LGTM!The indptr buffer assignments correctly handle optional parameters with sensible fallbacks to existing buffers.
2824-2827: LGTM!Instance attributes are correctly stored for use in the cuDNN execution path.
2860-2863: LGTM!The conditional correctly skips batch prefill module initialization for the cuDNN backend.
2870-2896: LGTM!The plan info generation correctly skips the cuDNN backend.
tests/attention/test_cudnn_prefill.py (2)
47-49: LGTM!Using
torch.randninstead oftorch.onesprovides better test coverage with more realistic input distributions while maintaining reproducibility through the seed.
63-73: LGTM!Using
torch.randnfor KV cache initialization improves test robustness by avoiding the edge case of all-ones values.
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
bkryu
left a comment
There was a problem hiding this comment.
Generally looking good, but left some minor comments.
Additionally, I created a branch on my end and added this commit to add microbenchmarking support for
cudnn--> cuDNN via wrapper APIcudnn-native-->cudnn_batch_prefill_with_kv_cache
Do you mind copy-pasting or cherrypicking the changes from the linked commit?
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2cdb5f6 to
f6ca31b
Compare
|
/bot run |
There was a problem hiding this comment.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
benchmarks/routines/attention.py (1)
1681-1704: Respect--no_cuda_graphfor the cudnn-native path.
is_cuda_graph_compatibleis hardcoded toTrue, ignoring the CLI flag and diverging from other backends. Pass the computed flag instead.🔧 Use the computed flag
- is_cuda_graph_compatible=True, + is_cuda_graph_compatible=is_cuda_graph_compatible,
🤖 Fix all issues with AI agents
In `@benchmarks/README.md`:
- Line 19: The new sub-bullet "Also supports computationally similar
`cudnn_batch_prefill_with_kv_cache` (cudnn-native) and
`trtllm_ragged_attention_deepseek`" has inconsistent indentation causing MD007
failures; edit benchmarks/README.md to match the indentation/level used by the
other "Also supports" entries (make this line align with the other sibling
bullets under that list so it uses the same number of spaces or tab characters
as the other "Also supports" lines).
In `@benchmarks/routines/attention.py`:
- Around line 1399-1408: The decode routine's cudnn-native filter lacks the
CUDNN_AVAILABLE guard and may select "cudnn-native" when cuDNN isn't present;
update the block that inspects backends, q_dtype, kv_dtype and
remove_cudnn_native to first check the CUDNN_AVAILABLE flag (same guard used in
prefill), skipping/removing "cudnn-native" immediately if CUDNN_AVAILABLE is
false before evaluating FP8 dtype constraints (refer to variables/backends list,
q_dtype, kv_dtype, remove_cudnn_native and the "cudnn-native" string).
In `@flashinfer/prefill.py`:
- Around line 2601-2606: In BatchPrefillWithRaggedKVCacheWrapper.plan(),
seq_lens_q is left as None despite the docstring saying it should default to
seq_lens; this causes later calls to self._seq_lens_q.dim() to crash. Fix by
assigning seq_lens_q = seq_lens when seq_lens_q is None (same fallback used in
BatchPrefillWithPagedKVCacheWrapper), and ensure the method stores the resolved
value to self._seq_lens_q before any use; reference the plan() method and
variables seq_lens_q and seq_lens in the BatchPrefillWithRaggedKVCacheWrapper
class.
In `@tests/attention/test_cudnn_prefill_deepseek.py`:
- Line 107: The test hardcodes a 512MB workspace (workspace_buffer) which can
OOM on smaller GPUs; replace the fixed size with a safe cap based on the
device's total memory by querying
torch.cuda.get_device_properties(device).total_memory and computing a
workspace_size_bytes = min(512*1024*1024, int(total_mem * 0.1)) (or another safe
fraction like 0.05), then allocate workspace_buffer =
torch.empty(workspace_size_bytes, dtype=torch.int8, device=device) so the buffer
scales to the GPU and reduces OOM risk.
- Around line 7-20: Before allocating tensors in test_cudnn_prefill_deepseek,
add gates to skip the test if no CUDA device or cuDNN is available and if the
GPU's compute capability or free memory is insufficient for the 512MB workspace;
specifically check torch.cuda.is_available(),
torch.backends.cudnn.is_available(), and torch.cuda.get_device_capability() (or
device major/minor) and optionally
torch.cuda.get_device_properties().total_memory/free memory to skip when the
device lacks required SM capability or memory; place these checks at the top of
test_cudnn_prefill_deepseek (before any use of s_qo, s_kv, num_qo_heads,
num_kv_heads or tensor allocations) so the test is skipped early on unsupported
hardware.
♻️ Duplicate comments (1)
flashinfer/prefill.py (1)
3079-3085: Potential None deref / UnboundLocalError in cuDNN run path.
self._seq_lens_q.dim()is called before aNonecheck andbatch_sizecan be undefined when_seq_lens_qis missing or not 1D. This remains a crash risk.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between 2cdb5f62c59b5467fdcce1f4dbfb5cca263c587a and f6ca31b.
📒 Files selected for processing (6)
benchmarks/README.mdbenchmarks/routines/attention.pybenchmarks/routines/flashinfer_benchmark_utils.pyflashinfer/prefill.pytests/attention/test_cudnn_prefill.pytests/attention/test_cudnn_prefill_deepseek.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/attention/test_cudnn_prefill.py
🧰 Additional context used
📓 Path-based instructions (2)
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
tests/**/*.py: Test implementations should useflashinfer.utilsfunctions (get_compute_capability,is_sm90a_supported,is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing withmpirunon multi-GPU systems, use the pattern:mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes -tests/conftest.pyprovides auto-skipping for OOM tests as a safety net but should not be relied upon
Files:
tests/attention/test_cudnn_prefill_deepseek.py
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use@functools.cachedecorator on Python API functions to implement module-level caching and avoid recompilation
Use@flashinfer_apidecorator for debugging API calls, enable viaFLASHINFER_LOGLEVELenvironment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Files:
flashinfer/prefill.py
🧠 Learnings (3)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
Applied to files:
benchmarks/routines/flashinfer_benchmark_utils.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
Applied to files:
benchmarks/routines/flashinfer_benchmark_utils.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Use `FLASHINFER_CUDA_ARCH_LIST` environment variable to specify target GPU architectures (e.g., '8.0 9.0a') and `FLASHINFER_NVCC_THREADS` to control parallel compilation threads
Applied to files:
benchmarks/README.md
🧬 Code graph analysis (2)
benchmarks/routines/attention.py (1)
flashinfer/prefill.py (7)
BatchPrefillWithRaggedKVCacheWrapper(2337-3201)run(2019-2031)run(2034-2046)run(2049-2288)run(2935-2945)run(2948-2958)run(2961-3170)
flashinfer/prefill.py (1)
flashinfer/cudnn/prefill.py (1)
cudnn_batch_prefill_with_kv_cache(555-724)
🪛 markdownlint-cli2 (0.18.1)
benchmarks/README.md
19-19: Unordered list indentation
Expected: 4; Actual: 8
(MD007, ul-indent)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (5)
benchmarks/routines/attention.py (1)
1623-1650: cudnn wrapper planning integration looks solid.flashinfer/prefill.py (2)
2481-2484: Doc updates for cuDNN parameters are clear.Also applies to: 2701-2713
2803-2812: Buffering v/o indptrs and skipping module planning for cuDNN looks good.Also applies to: 2860-2871
benchmarks/routines/flashinfer_benchmark_utils.py (1)
223-231: Backend support matrix update aligns with cudnn-native integration.benchmarks/README.md (1)
283-285: cudnn-native legend/table updates read well.Also applies to: 293-294, 319-320
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
| - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_batch_context_with_kv_cache`. | ||
| - `BatchPrefillWithRaggedKVCacheWrapper` - Prefill attention with ragged KV cache. | ||
| - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` and `trtllm_ragged_attention_deepseek`. | ||
| - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` (cudnn-native) and `trtllm_ragged_attention_deepseek`. |
There was a problem hiding this comment.
Fix unordered list indentation to satisfy MD007.
Markdownlint flags inconsistent indentation on this new sub-bullet; align it with the other “Also supports” entries.
🧹 Markdownlint fix
- - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` (cudnn-native) and `trtllm_ragged_attention_deepseek`.
+ - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` (cudnn-native) and `trtllm_ragged_attention_deepseek`.📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` (cudnn-native) and `trtllm_ragged_attention_deepseek`. | |
| - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` (cudnn-native) and `trtllm_ragged_attention_deepseek`. |
🧰 Tools
🪛 markdownlint-cli2 (0.18.1)
19-19: Unordered list indentation
Expected: 4; Actual: 8
(MD007, ul-indent)
🤖 Prompt for AI Agents
In `@benchmarks/README.md` at line 19, The new sub-bullet "Also supports
computationally similar `cudnn_batch_prefill_with_kv_cache` (cudnn-native) and
`trtllm_ragged_attention_deepseek`" has inconsistent indentation causing MD007
failures; edit benchmarks/README.md to match the indentation/level used by the
other "Also supports" entries (make this line align with the other sibling
bullets under that list so it uses the same number of spaces or tab characters
as the other "Also supports" lines).
| if "cudnn-native" in backends: | ||
| remove_cudnn_native = False | ||
| if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ | ||
| torch.float8_e4m3fn, | ||
| torch.float8_e5m2, | ||
| ]: | ||
| print("[INFO] CUDNN-native backend does not support FP8. Skipping.") | ||
| remove_cudnn_native = True | ||
| if remove_cudnn_native: | ||
| backends.remove("cudnn-native") |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, let's examine the file structure and imports
head -50 benchmarks/routines/attention.pyRepository: flashinfer-ai/flashinfer
Length of output: 1213
🏁 Script executed:
# Search for CUDNN_AVAILABLE or similar constants in the file
rg "CUDNN_AVAILABLE|cudnn_available|is_cuda" benchmarks/routines/attention.pyRepository: flashinfer-ai/flashinfer
Length of output: 1207
🏁 Script executed:
# Search for CUDNN_AVAILABLE across the entire codebase
rg "CUDNN_AVAILABLE" --type pyRepository: flashinfer-ai/flashinfer
Length of output: 1417
🏁 Script executed:
# Check for torch.backends.cudnn availability checks
rg "torch\.backends\.cudnn|CUDNN|cudnn" benchmarks/routines/attention.py | head -30Repository: flashinfer-ai/flashinfer
Length of output: 1323
🏁 Script executed:
# Let's look at the complete context around lines 1399-1408
sed -n '1380,1420p' benchmarks/routines/attention.pyRepository: flashinfer-ai/flashinfer
Length of output: 1633
🏁 Script executed:
# Get exact line numbers where CUDNN_AVAILABLE check appears in relation to cudnn-native
rg -n "cudnn-native|CUDNN_AVAILABLE" benchmarks/routines/attention.py | grep -A 10 -B 5 "cudnn-native"Repository: flashinfer-ai/flashinfer
Length of output: 829
🏁 Script executed:
# Get the complete cudnn and cudnn-native backend sections with line numbers
sed -n '1370,1450p' benchmarks/routines/attention.py | cat -n | sed 's/^[[:space:]]*//' | sed 's/^/1370+/'Repository: flashinfer-ai/flashinfer
Length of output: 3843
🏁 Script executed:
# Check context around the first cudnn-native block (lines 784-798)
sed -n '770,810p' benchmarks/routines/attention.pyRepository: flashinfer-ai/flashinfer
Length of output: 1801
🏁 Script executed:
# Check context around the second cudnn-native block (lines 1399-1408)
sed -n '1390,1415p' benchmarks/routines/attention.pyRepository: flashinfer-ai/flashinfer
Length of output: 1030
Add cuDNN availability guard to cudnn-native backend filter.
The decode routine (lines 1399-1408) checks for FP8 dtype compatibility but lacks the CUDNN_AVAILABLE guard that exists in the prefill routine. If cuDNN is not available, the backend will fail at runtime. Add the availability check before evaluating dtype constraints.
Suggested fix
if "cudnn-native" in backends:
remove_cudnn_native = False
+ if not CUDNN_AVAILABLE or CUDNN_BACKEND_VERSION < 91701:
+ print("[INFO] cuDNN not available. Skipping cudnn-native backend.")
+ remove_cudnn_native = True
- if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
+ elif q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
print("[INFO] CUDNN-native backend does not support FP8. Skipping.")
remove_cudnn_native = True🤖 Prompt for AI Agents
In `@benchmarks/routines/attention.py` around lines 1399 - 1408, The decode
routine's cudnn-native filter lacks the CUDNN_AVAILABLE guard and may select
"cudnn-native" when cuDNN isn't present; update the block that inspects
backends, q_dtype, kv_dtype and remove_cudnn_native to first check the
CUDNN_AVAILABLE flag (same guard used in prefill), skipping/removing
"cudnn-native" immediately if CUDNN_AVAILABLE is false before evaluating FP8
dtype constraints (refer to variables/backends list, q_dtype, kv_dtype,
remove_cudnn_native and the "cudnn-native" string).
| seq_lens: Optional[torch.Tensor] = None, | ||
| seq_lens_q: Optional[torch.Tensor] = None, | ||
| max_token_per_sequence: Optional[int] = None, | ||
| max_sequence_kv: Optional[int] = None, | ||
| v_indptr: Optional[torch.Tensor] = None, | ||
| o_indptr: Optional[torch.Tensor] = None, |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
git ls-files flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer
Length of output: 89
🏁 Script executed:
# Read lines around 2601-2606 and surrounding context
sed -n '2590,2620p' flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer
Length of output: 1480
🏁 Script executed:
# Read lines around 2824-2828 and surrounding context
sed -n '2810,2840p' flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer
Length of output: 1268
🏁 Script executed:
# Search for where seq_lens_q is first used in function definitions with docstrings
rg -A 20 "seq_lens_q" flashinfer/prefill.py | head -100Repository: flashinfer-ai/flashinfer
Length of output: 3474
🏁 Script executed:
# Let's look at the function containing lines 2601-2606
sed -n '2570,2640p' flashinfer/prefill.py | cat -nRepository: flashinfer-ai/flashinfer
Length of output: 3802
🏁 Script executed:
#!/bin/bash
# Find all occurrences of seq_lens_q assignment and read context
rg -B 5 -A 5 "self\._seq_lens_q\s*=" flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer
Length of output: 2286
🏁 Script executed:
# Search for seq_lens_q in docstrings or comments
rg "seq_lens_q" flashinfer/prefill.py -B 2 -A 2 | head -150Repository: flashinfer-ai/flashinfer
Length of output: 5119
🏁 Script executed:
#!/bin/bash
# Get line count first
total_lines=$(wc -l < flashinfer/prefill.py)
echo "Total lines: $total_lines"
# Read the section around 2601-2606
echo "=== Lines 2590-2620 ==="
sed -n '2590,2620p' flashinfer/prefill.py
echo ""
echo "=== Lines 2810-2840 ==="
sed -n '2810,2840p' flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer
Length of output: 2748
🏁 Script executed:
#!/bin/bash
# Find the method definitions around the problem areas
echo "=== Around line 2600 (first occurrence) ==="
sed -n '2565,2580p' flashinfer/prefill.py
echo ""
echo "=== Around line 2800 (second occurrence) ==="
sed -n '2790,2810p' flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer
Length of output: 1607
🏁 Script executed:
#!/bin/bash
# Find the seq_lens_q documentation
echo "=== Docstring for seq_lens_q around line 2671 ==="
sed -n '2665,2680p' flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer
Length of output: 1194
🏁 Script executed:
#!/bin/bash
# Search for class definitions in the file
rg "^class " flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer
Length of output: 154
🏁 Script executed:
#!/bin/bash
# Count occurrences of the problematic pattern
echo "=== Pattern 1: self._seq_lens_q = seq_lens_q (without fallback) ==="
rg -n "self\._seq_lens_q = seq_lens_q$" flashinfer/prefill.py
echo ""
echo "=== Pattern 2: self._seq_lens_q = seq_lens_q if (with fallback) ==="
rg -n "self\._seq_lens_q = seq_lens_q if" flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer
Length of output: 330
🏁 Script executed:
#!/bin/bash
# Find the class definitions and their line numbers
echo "=== Class definitions with line numbers ==="
rg -n "^class " flashinfer/prefill.py
echo ""
echo "=== Location of line 1988 assignment ==="
sed -n '1980,1995p' flashinfer/prefill.py | cat -n
echo ""
echo "=== Location of line 2824 assignment ==="
sed -n '2820,2830p' flashinfer/prefill.py | cat -nRepository: flashinfer-ai/flashinfer
Length of output: 1522
🏁 Script executed:
#!/bin/bash
# Search for seq_lens_q in docstrings specifically mentioning the default
rg -B 2 -A 2 "seq_lens_q.*will be set to" flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Read the full docstring section describing seq_lens_q
sed -n '2620,2720p' flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer
Length of output: 7077
🏁 Script executed:
#!/bin/bash
# Find which method line 1988 belongs to
echo "=== Method containing line 1988 ==="
sed -n '1900,1990p' flashinfer/prefill.py | grep -E "(def |class )" | tail -1
echo ""
echo "=== Method containing line 2824 ==="
sed -n '2550,2824p' flashinfer/prefill.py | grep -E "(def |class )" | tail -1
echo ""
echo "=== Verify line 2824 is in plan method ==="
sed -n '2570,2580p' flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer
Length of output: 446
Set seq_lens_q fallback in BatchPrefillWithRaggedKVCacheWrapper.plan() to match documented API contract.
The docstring states that seq_lens_q defaults to seq_lens when not provided, but the assignment at line 2824 leaves it as None. This causes crashes on the cuDNN backend when the code later calls self._seq_lens_q.dim(). The BatchPrefillWithPagedKVCacheWrapper class already implements the correct fallback pattern; apply the same fix here.
Suggested fix
- self._seq_lens_q = seq_lens_q
self._seq_lens_kv = seq_lens
+ self._seq_lens_q = seq_lens_q if seq_lens_q is not None else seq_lens🤖 Prompt for AI Agents
In `@flashinfer/prefill.py` around lines 2601 - 2606, In
BatchPrefillWithRaggedKVCacheWrapper.plan(), seq_lens_q is left as None despite
the docstring saying it should default to seq_lens; this causes later calls to
self._seq_lens_q.dim() to crash. Fix by assigning seq_lens_q = seq_lens when
seq_lens_q is None (same fallback used in BatchPrefillWithPagedKVCacheWrapper),
and ensure the method stores the resolved value to self._seq_lens_q before any
use; reference the plan() method and variables seq_lens_q and seq_lens in the
BatchPrefillWithRaggedKVCacheWrapper class.
| @pytest.mark.parametrize("batch_size", [1, 4]) | ||
| @pytest.mark.parametrize("s_qo", [32, 64, 87]) | ||
| @pytest.mark.parametrize("s_kv", [32, 64, 87]) | ||
| @pytest.mark.parametrize("num_kv_heads", [1]) | ||
| @pytest.mark.parametrize("num_qo_heads", [1, 16]) | ||
| @pytest.mark.parametrize("s_qo", [32, 64, 87, 256]) | ||
| @pytest.mark.parametrize("s_kv", [32, 87, 512]) | ||
| @pytest.mark.parametrize("num_kv_heads", [1, 4]) | ||
| @pytest.mark.parametrize("num_qo_heads", [1, 8]) | ||
| @pytest.mark.parametrize("causal", [True, False]) | ||
| def test_cudnn_prefill_deepseek( | ||
| batch_size, s_qo, s_kv, num_kv_heads, num_qo_heads, causal | ||
| ): | ||
| if s_qo > s_kv: | ||
| pytest.skip("s_qo > s_kv, skipping test as causal") | ||
|
|
||
| if num_qo_heads < num_kv_heads: | ||
| pytest.skip("num_qo_heads < num_kv_heads, skipping test") |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, find the test file and read it
cat -n tests/attention/test_cudnn_prefill_deepseek.pyRepository: flashinfer-ai/flashinfer
Length of output: 5877
🏁 Script executed:
# Check what's available in flashinfer.utils
find . -name "utils.py" -o -name "utils" -type d | grep flashinfer | head -20Repository: flashinfer-ai/flashinfer
Length of output: 301
🏁 Script executed:
# Look for get_compute_capability and related functions in flashinfer
rg "get_compute_capability|is_sm90a_supported|is_sm100a_supported" --type pyRepository: flashinfer-ai/flashinfer
Length of output: 17993
🏁 Script executed:
cat -n tests/attention/test_cudnn_prefill.py | head -50Repository: flashinfer-ai/flashinfer
Length of output: 1740
🏁 Script executed:
# Check flashinfer.utils for cuDNN-specific checks
rg "cudnn|get_compute_capability" flashinfer/utils.py -A 3 -B 1Repository: flashinfer-ai/flashinfer
Length of output: 2283
🏁 Script executed:
# Check if there's documentation or comments about cuDNN GPU requirements
rg "cuDNN|cudnn.*sm|cudnn.*arch|cudnn.*compute" --type py -i -B 2 -A 2 | head -80Repository: flashinfer-ai/flashinfer
Length of output: 4924
🏁 Script executed:
# Check conftest to see if there are global skip mechanisms for cuDNN
cat -n tests/conftest.py 2>/dev/null | head -100Repository: flashinfer-ai/flashinfer
Length of output: 4344
🏁 Script executed:
# Continue reading conftest
tail -n +100 tests/conftest.py | head -80Repository: flashinfer-ai/flashinfer
Length of output: 3055
🏁 Script executed:
# Look for OOM skip mechanisms mentioned in the coding guidelines
rg "OOM|out.of.memory|OutOfMemory" --type py tests/Repository: flashinfer-ai/flashinfer
Length of output: 451
🏁 Script executed:
# Check if there are any SM requirements documented in flashinfer for cuDNN prefill
rg "cudnn.*prefill|prefill.*cudnn" --type py -B 3 -A 3Repository: flashinfer-ai/flashinfer
Length of output: 6803
🏁 Script executed:
# Check what compute capabilities are actually needed by looking at BatchPrefillWithRaggedKVCacheWrapper
rg "BatchPrefillWithRaggedKVCacheWrapper" --type py -lRepository: flashinfer-ai/flashinfer
Length of output: 836
🏁 Script executed:
# Check the actual prefill.py to see if there are SM requirements
grep -n "cudnn\|backend" flashinfer/prefill.py | head -30Repository: flashinfer-ai/flashinfer
Length of output: 1608
🏁 Script executed:
# Read the full test_cudnn_prefill.py to see what GPU checks it has
cat -n tests/attention/test_cudnn_prefill.py | grep -A 30 "def test_cudnn_prefill"Repository: flashinfer-ai/flashinfer
Length of output: 1982
🏁 Script executed:
# Check flashinfer/cudnn/__init__.py to understand cuDNN availability
cat -n flashinfer/cudnn/__init__.pyRepository: flashinfer-ai/flashinfer
Length of output: 189
🏁 Script executed:
# Check if there's a CUDNN_AVAILABLE flag or similar
rg "CUDNN_AVAILABLE" --type pyRepository: flashinfer-ai/flashinfer
Length of output: 1417
Add GPU capability and cuDNN availability checks to gate test execution.
This test uses the cuDNN backend explicitly but lacks guards for GPU architecture support and cuDNN availability. As per coding guidelines, tests must skip on unsupported hardware. The 512MB workspace allocation can also cause OOM on smaller GPUs. Add checks before tensor allocations.
🧪 Suggested skip guards
import pytest
import torch
import flashinfer
+from flashinfer.utils import get_compute_capability
+
+try:
+ import cudnn # type: ignore
+ CUDNN_AVAILABLE = True
+except (ImportError, OSError):
+ CUDNN_AVAILABLE = False
`@pytest.mark.parametrize`("batch_size", [1, 4])
@@ -26,6 +34,14 @@ def test_cudnn_prefill_deepseek(
batch_size, s_qo, s_kv, num_kv_heads, num_qo_heads, causal
):
if s_qo > s_kv:
pytest.skip("s_qo > s_kv, skipping test as causal")
if num_qo_heads < num_kv_heads:
pytest.skip("num_qo_heads < num_kv_heads, skipping test")
+
+ if not CUDNN_AVAILABLE:
+ pytest.skip("cuDNN not available")
+ major, _ = get_compute_capability(torch.device("cuda:0"))
+ if major < 8:
+ pytest.skip("cuDNN prefill requires SM80+")🤖 Prompt for AI Agents
In `@tests/attention/test_cudnn_prefill_deepseek.py` around lines 7 - 20, Before
allocating tensors in test_cudnn_prefill_deepseek, add gates to skip the test if
no CUDA device or cuDNN is available and if the GPU's compute capability or free
memory is insufficient for the 512MB workspace; specifically check
torch.cuda.is_available(), torch.backends.cudnn.is_available(), and
torch.cuda.get_device_capability() (or device major/minor) and optionally
torch.cuda.get_device_properties().total_memory/free memory to skip when the
device lacks required SM capability or memory; place these checks at the top of
test_cudnn_prefill_deepseek (before any use of s_qo, s_kv, num_qo_heads,
num_kv_heads or tensor allocations) so the test is skipped early on unsupported
hardware.
| scale = float(1.0 / (head_dim_qk**0.5)) | ||
|
|
||
| workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) | ||
| workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=device) |
There was a problem hiding this comment.
Right-size the workspace buffer to reduce OOM risk.
Hardcoding a 512MB workspace can exhaust memory on smaller GPUs. Consider capping it relative to device memory. As per coding guidelines, avoid OOM-prone test sizes.
💡 Safer workspace sizing
- workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=device)
+ total_mem = torch.cuda.get_device_properties(device).total_memory
+ workspace_bytes = min(512 * 1024 * 1024, total_mem // 8)
+ workspace_buffer = torch.empty(workspace_bytes, dtype=torch.int8, device=device)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=device) | |
| total_mem = torch.cuda.get_device_properties(device).total_memory | |
| workspace_bytes = min(512 * 1024 * 1024, total_mem // 8) | |
| workspace_buffer = torch.empty(workspace_bytes, dtype=torch.int8, device=device) |
🤖 Prompt for AI Agents
In `@tests/attention/test_cudnn_prefill_deepseek.py` at line 107, The test
hardcodes a 512MB workspace (workspace_buffer) which can OOM on smaller GPUs;
replace the fixed size with a safe cap based on the device's total memory by
querying torch.cuda.get_device_properties(device).total_memory and computing a
workspace_size_bytes = min(512*1024*1024, int(total_mem * 0.1)) (or another safe
fraction like 0.05), then allocate workspace_buffer =
torch.empty(workspace_size_bytes, dtype=torch.int8, device=device) so the buffer
scales to the GPU and reduces OOM risk.
bkryu
left a comment
There was a problem hiding this comment.
Thanks @Anerudhan, LGTM. Unit test failures are unrelated
|
[FAILED] Pipeline #41861590: 14/20 passed |
📌 Description
Added the cudnn backend Ragged KV Cache wrapper
Fixed the test_prefill.py to not use torch.ones (accidentally did it before)
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit
New Features
Tests
Documentation
✏️ Tip: You can customize this high-level summary in your review settings.