allow cudnn to be chosen for prefill #2622
Summary of Changes
This pull request enhances the FlashInfer library by integrating cuDNN as an available backend for prefill attention operations.
📝 Walkthrough
Sequence Diagram(s)
sequenceDiagram
participant Prefill as Prefill Module
participant Selector as determine_attention_backend()
participant cuDNN as cuDNN Backend
participant FA as FA2/FA3 Backends
Prefill->>Selector: determine_attention_backend(..., kv_layout="NHD")
Selector->>cuDNN: _is_cudnn_available_for_attention()? (device/drivers)
alt cuDNN available and kv_layout == "NHD"
Selector->>Prefill: return "cudnn"
Prefill-->>cuDNN: construct module (kv_layout="NHD")
else
Selector->>FA: evaluate FA3/FA2 checks
Selector->>Prefill: return "fa3" or "fa2"
Prefill-->>FA: construct module (kv_layout passed where applicable)
end
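The flow in the diagram can be sketched in a few lines of Python. This is only an illustration: `choose_prefill_backend` and its boolean parameters are hypothetical stand-ins for the real checks (`_is_cudnn_available_for_attention()` and the FA3 support checks), not FlashInfer's actual signature.

```python
from typing import Optional


def choose_prefill_backend(
    kv_layout: Optional[str],
    cudnn_available: bool,  # hypothetical: result of the cuDNN availability probe
    fa3_supported: bool,    # hypothetical: result of the FA3 device/feature checks
) -> str:
    """Mirror the diagram: prefer cuDNN for NHD layout, otherwise FA3 or FA2."""
    if cudnn_available and kv_layout == "NHD":
        return "cudnn"
    # The "else" branch of the diagram: evaluate the FlashAttention backends.
    return "fa3" if fa3_supported else "fa2"


print(choose_prefill_backend("NHD", cudnn_available=True, fa3_supported=True))
print(choose_prefill_backend("HND", cudnn_available=True, fa3_supported=False))
```

Note that this sketch has no feature-compatibility guard on the cuDNN branch, which is the gap the review comments further down point out.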
Code Review
This pull request adds support for using the cudnn backend for prefill operations. The changes look good and are well-contained. I have a few suggestions to improve correctness and performance.
Specifically, I've recommended adding a check for supported head dimensions before selecting the cudnn backend to prevent potential runtime errors. I've also suggested caching the cudnn availability check to improve performance by avoiding repeated module import attempts.
    return True


def _is_cudnn_available_for_attention() -> bool:
To avoid repeated import attempts, which can be costly, it's a good practice to cache the result of this function. Since this function is pure and its result won't change during the program's execution, using @functools.cache is ideal.
-def _is_cudnn_available_for_attention() -> bool:
+@functools.cache
+def _is_cudnn_available_for_attention() -> bool:
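To make the effect of the suggested decorator concrete, here is a small self-contained sketch. The function `module_available` and the counter `probe_count` are hypothetical names used only for illustration, and the standard-library `json` module stands in for the optional `cudnn` package so the example runs anywhere.

```python
import functools
import importlib

probe_count = 0  # counts how often the probe body actually executes


@functools.cache
def module_available(name: str) -> bool:
    """Return True if `name` can be imported; the result is memoized per name."""
    global probe_count
    probe_count += 1
    try:
        importlib.import_module(name)
    except ImportError:
        return False
    return True


# The second call is served from the cache, so the body runs only once.
module_available("json")
module_available("json")
print(probe_count)
```

One caveat with this design: a cached result will not notice a package that is installed after the first call within the same process, which is usually acceptable for an availability check.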
Actionable comments posted: 1
🧹 Nitpick comments (1)
flashinfer/utils.py (1)
461-468: Move `return True` to `else` block; optionally cache availability result
Ruff TRY300 flags the `return True` inside the `try` body. Moving it to an `else` clause also clarifies intent and prevents any future `try`-body additions from accidentally being reachable on an `ImportError`.
♻️ Proposed fix
 def _is_cudnn_available_for_attention() -> bool:
     """Return True if cuDNN is available for attention (prefill)."""
     try:
         import cudnn  # noqa: F401
-
-        return True
     except ImportError:
         return False
+    else:
+        return True
Additionally, consider decorating with `@functools.cache` so repeated calls avoid the `sys.modules` lookup on every `determine_attention_backend` invocation.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/utils.py` around lines 461 - 468, Move the immediate True return out of the try body in _is_cudnn_available_for_attention: keep the try to import cudnn, have except ImportError: return False, and put return True in an else: block so future changes in the try won't accidentally run on ImportError; optionally add `@functools.cache` (or `@functools.lru_cache`(maxsize=1)) above _is_cudnn_available_for_attention to memoize the availability check and avoid repeated sys.modules lookups in determine_attention_backend.
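The `try`/`except`/`else` shape recommended here can be tried in isolation. The helper below is a hypothetical generalization (`can_import` takes the module name as a parameter, which the PR's function does not) so that both outcomes can be exercised without cuDNN installed.

```python
import importlib


def can_import(module_name: str) -> bool:
    """Return True if `module_name` imports cleanly, False on ImportError."""
    try:
        # Keep only the statement that may raise inside the try body.
        importlib.import_module(module_name)
    except ImportError:
        return False
    else:
        # Runs only when the try body raised no exception.
        return True


print(can_import("math"))
print(can_import("surely_not_a_real_module_xyz"))
```

Keeping the `try` body down to the single import means an `ImportError` raised by any later logic cannot be mistaken for a missing module.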
    if (
        kv_layout == "NHD"
        and _is_cudnn_available_for_attention()
    ):
        return "cudnn"
cuDNN selected without constraint validation — silent correctness regression
Unlike fa3 (gated by is_fa3_backend_supported()), the cuDNN branch has no equivalent guard. This creates two concrete regressions:
- Custom mask silently discarded: when `use_custom_mask=True` + NHD layout + cuDNN installed, fa3 is correctly rejected (`is_fa3_backend_supported` returns `False`), but now the code falls through to cuDNN instead of fa2. Neither `BatchPrefillWithPagedKVCacheWrapper.run()` nor `BatchPrefillWithRaggedKVCacheWrapper.run()` passes the custom mask to `cudnn_batch_prefill_with_kv_cache`, so the mask is silently dropped.
- Positional encoding silently ignored: same path when `pos_encoding_mode != NONE`; the cuDNN call sites accept only `causal`, not a pos-encoding mode.
A guard function analogous to is_fa3_backend_supported is needed:
🐛 Proposed fix
+def is_cudnn_backend_supported(
+ pos_encoding_mode: int,
+ use_fp16_qk_reductions: bool,
+ use_custom_mask: bool,
+ dtype_q: torch.dtype,
+ dtype_kv: torch.dtype,
+) -> bool:
+ if use_custom_mask:
+ return False
+ if pos_encoding_mode != PosEncodingMode.NONE.value:
+ return False
+ return True
+
+
def determine_attention_backend(
...
) -> str:
if is_sm90a_supported(device) and is_fa3_backend_supported(
pos_encoding_mode,
use_fp16_qk_reductions,
use_custom_mask,
dtype_q,
dtype_kv,
):
return "fa3"
if (
kv_layout == "NHD"
and _is_cudnn_available_for_attention()
+ and is_cudnn_backend_supported(
+ pos_encoding_mode,
+ use_fp16_qk_reductions,
+ use_custom_mask,
+ dtype_q,
+ dtype_kv,
+ )
):
return "cudnn"
return "fa2"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/utils.py` around lines 517 - 521, The selection of the "cudnn"
backend in flashinfer/utils.py currently uses only
_is_cudnn_available_for_attention() and must be guarded similarly to
is_fa3_backend_supported(): add a new guard function (e.g.,
is_cudnn_backend_supported) that validates constraints required by cudnn (reject
when use_custom_mask=True or when pos_encoding_mode != NONE since cudnn only
supports causal), then update the backend selection branch that returns "cudnn"
to call this guard; ensure callers like
BatchPrefillWithPagedKVCacheWrapper.run() and
BatchPrefillWithRaggedKVCacheWrapper.run() will fall back to fa2 (or other safe
backend) when the guard returns False so custom masks and non-causal positional
encodings are not silently dropped.
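The fallback behaviour this comment asks for can be checked with a torch-free sketch. Everything here is a simplified stand-in: the `PosEncodingMode` enum below is a hypothetical two-member version, the guard takes only the two arguments it actually inspects, and availability and FA3 support are passed in as plain booleans rather than probed.

```python
from enum import Enum
from typing import Optional


class PosEncodingMode(Enum):  # hypothetical stand-in for the library's enum
    NONE = 0
    ROPE_LLAMA = 1


def is_cudnn_backend_supported(pos_encoding_mode: int, use_custom_mask: bool) -> bool:
    """Reject configurations the cuDNN prefill path would silently ignore."""
    if use_custom_mask:
        return False
    return pos_encoding_mode == PosEncodingMode.NONE.value


def determine_backend(
    kv_layout: Optional[str],
    pos_encoding_mode: int,
    use_custom_mask: bool,
    cudnn_available: bool,  # assumption: injected instead of probing the import
    fa3_ok: bool,           # assumption: injected instead of checking the device
) -> str:
    if fa3_ok:
        return "fa3"
    if (
        kv_layout == "NHD"
        and cudnn_available
        and is_cudnn_backend_supported(pos_encoding_mode, use_custom_mask)
    ):
        return "cudnn"
    return "fa2"  # safe fallback for custom masks and positional encodings


print(determine_backend("NHD", PosEncodingMode.NONE.value, True, True, False))
```

With a custom mask requested, the selector falls back to `"fa2"` even though cuDNN is reported as available, so the mask is not dropped.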
+1 for this, otherwise looks good to me.
@saltyminty, understood! I can help implement the guard function to ensure cuDNN is only selected when it supports the requested features. Would you like me to generate the code changes that add the is_cudnn_backend_supported function and integrate it into the backend selection logic as CodeRabbit suggested?
♻️ Duplicate comments (1)
flashinfer/utils.py (1)
509-519: ⚠️ Potential issue | 🔴 Critical
Guard cuDNN selection with feature-compat checks.
The cuDNN branch is still chosen solely by availability + NHD layout. If custom masks or non-`NONE` positional encoding are requested, cuDNN prefill paths that don't accept those inputs will silently drop them. Add a dedicated guard (analogous to FA3) so unsupported configs fall back to FA2.
🐛 Proposed fix
+def is_cudnn_backend_supported(
+    pos_encoding_mode: int,
+    use_fp16_qk_reductions: bool,
+    use_custom_mask: bool,
+    dtype_q: torch.dtype,
+    dtype_kv: torch.dtype,
+) -> bool:
+    if use_custom_mask:
+        return False
+    if pos_encoding_mode != PosEncodingMode.NONE.value:
+        return False
+    return True
+
 def determine_attention_backend(
     device: torch.device,
     pos_encoding_mode: int,
     use_fp16_qk_reductions: bool,
     use_custom_mask: bool,
     dtype_q: torch.dtype,
     dtype_kv: torch.dtype,
     kv_layout: Optional[str] = None,
 ) -> str:
     ...
-    if kv_layout == "NHD" and _is_cudnn_available_for_attention():
+    if (
+        kv_layout == "NHD"
+        and _is_cudnn_available_for_attention()
+        and is_cudnn_backend_supported(
+            pos_encoding_mode,
+            use_fp16_qk_reductions,
+            use_custom_mask,
+            dtype_q,
+            dtype_kv,
+        )
+    ):
         return "cudnn"
     return "fa2"
#!/bin/bash
# Locate cuDNN prefill call sites and verify which args are supported.
rg -n -C3 '\bcudnn_\w*prefill\w*\s*\(' -g '*.py'
# Check how custom masks / positional encodings are threaded in prefill paths.
rg -n -C3 '\b(custom_mask|pos_encoding_mode|kv_layout)\b' -g 'flashinfer/prefill.py'
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/utils.py` around lines 509 - 519, The cuDNN branch currently selects based only on kv_layout and availability and must be guarded against unsupported features (positional encodings, custom masks, FP16 Q/K reductions, dtype combos); add a dedicated compatibility check (e.g., call or add a helper like is_cudnn_backend_supported or is_cudnn_prefill_supported) that accepts pos_encoding_mode, use_custom_mask, use_fp16_qk_reductions, dtype_q, dtype_kv (same signature style as is_fa3_backend_supported) and only return "cudnn" when that helper returns true and kv_layout == "NHD" and _is_cudnn_available_for_attention(); otherwise fall back to "fa2" so unsupported configs don't silently drop inputs.
cc @saltyminty for review
♻️ Duplicate comments (1)
flashinfer/utils.py (1)
509-510: ⚠️ Potential issue | 🔴 Critical
cuDNN is selected without any constraint validation: silent correctness regression.
The cuDNN branch fires solely on `kv_layout == "NHD"` + import availability. Unlike the `fa3` branch (which is gated by `is_fa3_backend_supported`), there is no analogous guard here. Two concrete regressions remain:
- Custom mask silently dropped: when `use_custom_mask=True` + NHD layout + cuDNN installed, `fa3` is correctly rejected, but the code now routes to `"cudnn"`. Neither `BatchPrefillWithPagedKVCacheWrapper.run()` nor `BatchPrefillWithRaggedKVCacheWrapper.run()` forwards the custom mask to `cudnn_batch_prefill_with_kv_cache`, so the mask is silently ignored.
- Positional encoding silently ignored: same path when `pos_encoding_mode != NONE`; the cuDNN call sites accept only `causal`, not a full positional-encoding mode.
A guard function analogous to `is_fa3_backend_supported` is needed before this can be safely merged:
🐛 Proposed fix
+def is_cudnn_backend_supported(
+    pos_encoding_mode: int,
+    use_fp16_qk_reductions: bool,
+    use_custom_mask: bool,
+    dtype_q: torch.dtype,
+    dtype_kv: torch.dtype,
+) -> bool:
+    if use_custom_mask:
+        return False
+    if pos_encoding_mode != PosEncodingMode.NONE.value:
+        return False
+    return True
+
+
 def determine_attention_backend(...) -> str:
     if (
         kv_layout == "NHD"
         and _is_cudnn_available_for_attention()
+        and is_cudnn_backend_supported(
+            pos_encoding_mode,
+            use_fp16_qk_reductions,
+            use_custom_mask,
+            dtype_q,
+            dtype_kv,
+        )
     ):
         return "cudnn"
     if is_sm90a_supported(device) and is_fa3_backend_supported(...):
         return "fa3"
     return "fa2"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/utils.py` around lines 509 - 510, The cuDNN path is chosen solely by kv_layout and import availability which causes silent regressions (custom masks and non-causal positional encodings dropped); add a guard like is_cudnn_backend_supported(...) and use it instead of only _is_cudnn_available_for_attention() in the selector where kv_layout == "NHD" (the code around kv_layout and _is_cudnn_available_for_attention). Implement is_cudnn_backend_supported to validate that callers' constraints are safe (e.g., use_custom_mask is False or masks can be forwarded, and pos_encoding_mode is NONE or CAUSAL only), update the selector to return "cudnn" only when that guard passes, and either modify BatchPrefillWithPagedKVCacheWrapper.run / BatchPrefillWithRaggedKVCacheWrapper.run to forward custom masks/positional encoding to cudnn_batch_prefill_with_kv_cache or prevent selecting cudnn when those features are in use.
🧹 Nitpick comments (1)
flashinfer/utils.py (1)
461-468: Cache `_is_cudnn_available_for_attention` and restructure the try/except.
Two minor improvements:
- TRY300 (Ruff): The `return True` inside the `try` body should move to an `else` block, which is the idiomatic way to express "run this only if no exception was raised."
- Caching: The result of this function is constant for the entire process lifetime (cuDNN is either installed or not). Without `@functools.cache` the `try`/`import` round-trip happens on every call to `determine_attention_backend`.
♻️ Proposed refactor
+@functools.cache
 def _is_cudnn_available_for_attention() -> bool:
     """Return True if cuDNN is available for attention (prefill)."""
     try:
         import cudnn  # noqa: F401
-
-        return True
     except ImportError:
         return False
+    else:
+        return True
As per coding guidelines, Python functions in `flashinfer/*.py` should use `@functools.cache` for module caching to avoid repeated overhead.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/utils.py` around lines 461 - 468, The function _is_cudnn_available_for_attention should be decorated with functools.cache and the try/except should use an else block: import cudnn inside try, return False in the except ImportError, and return True in the else. Add "from functools import cache" or use "functools.cache" and apply `@cache` to _is_cudnn_available_for_attention so the import check is performed only once per process.
📌 Description
Add a cudnn choice to the prefill implementation options.