
allow cudnn to be chosen for prefill #2622

Open
scottyokim wants to merge 3 commits into flashinfer-ai:main from scottyokim:syokim/prefill_cudnn

Conversation


@scottyokim scottyokim commented Feb 23, 2026

📌 Description

Add a cudnn choice to the prefill implementation options.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • KV layout configuration now influences attention backend selection.
    • Automatic preference for cuDNN backend when available and compatible with the configured KV layout.
    • Added runtime detection of cuDNN availability for attention.
    • Backend selection logic improved to consider device, positional encoding mode, precision, masking, and KV layout.

@gemini-code-assist
Contributor

Summary of Changes

Hello @scottyokim, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the FlashInfer library by integrating cuDNN as an available backend for prefill attention operations. By introducing a kv_layout parameter, the system can now intelligently select cuDNN when the KV cache is configured in the "NHD" layout, providing more flexibility in backend choice and potentially leveraging cuDNN's optimized performance for specific configurations.

Highlights

  • cuDNN Backend Support: Enabled cuDNN as an attention backend option for prefill operations, allowing for potential performance improvements.
  • KV Layout Consideration: Introduced kv_layout as a parameter in backend determination, specifically allowing cuDNN to be chosen when the KV cache layout is "NHD".
  • Backend Selection Logic: Modified the determine_attention_backend utility function to conditionally select "cudnn" if cuDNN is available and the kv_layout is "NHD".


Changelog
  • flashinfer/prefill.py
    • Passed kv_layout argument to the internal _get_attention_module calls within the plan methods for both single-request and batch prefill.
  • flashinfer/utils.py
    • Added a new private helper function _is_cudnn_available_for_attention to check for cuDNN library availability.
    • Modified the determine_attention_backend function to accept an optional kv_layout parameter.
    • Updated the logic in determine_attention_backend to prioritize "cudnn" if kv_layout is "NHD" and cuDNN is available, before falling back to "fa2".
Activity
  • No human activity recorded for this pull request.

Contributor

coderabbitai Bot commented Feb 23, 2026

📝 Walkthrough


The diff threads kv_layout through prefill planning/module construction and makes attention-backend selection layout-aware by adding _is_cudnn_available_for_attention() and extending determine_attention_backend(..., kv_layout=None) to prefer "cudnn" for kv_layout == "NHD" when available.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Prefill: propagate KV layout (flashinfer/prefill.py) | Pass kv_layout=self._kv_layout into plan/module construction calls so module selection and backend argument assembly consider KV layout. |
| Backend selection & helpers (flashinfer/utils.py) | Add _is_cudnn_available_for_attention(); update determine_attention_backend(..., kv_layout=None) to return "cudnn" when kv_layout == "NHD" and cuDNN is available, otherwise fall back to the existing FA3/FA2 decision path. Docstring updated to document kv_layout. |

Sequence Diagram(s)

sequenceDiagram
    participant Prefill as Prefill Module
    participant Selector as determine_attention_backend()
    participant cuDNN as cuDNN Backend
    participant FA as FA2/FA3 Backends

    Prefill->>Selector: determine_attention_backend(..., kv_layout="NHD")
    Selector->>cuDNN: _is_cudnn_available_for_attention()? (device/drivers)
    alt cuDNN available and kv_layout == "NHD"
        Selector->>Prefill: return "cudnn"
        Prefill-->>cuDNN: construct module (kv_layout="NHD")
    else
        Selector->>FA: evaluate FA3/FA2 checks
        Selector->>Prefill: return "fa3" or "fa2"
        Prefill-->>FA: construct module (kv_layout passed where applicable)
    end
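
The flow in the diagram can be sketched in plain Python as follows. This is a minimal illustration only, not the code in flashinfer/utils.py: `select_backend`, `cudnn_available`, and `fa3_ok` are hypothetical names, and the two booleans stand in for `_is_cudnn_available_for_attention()` and the combined SM90a/FA3 checks.

```python
from typing import Optional


def select_backend(
    kv_layout: Optional[str],
    cudnn_available: bool,
    fa3_ok: bool,
) -> str:
    """Hypothetical stand-in for the selection order shown in the diagram."""
    # cuDNN is preferred only for the NHD layout and only when it is importable.
    if kv_layout == "NHD" and cudnn_available:
        return "cudnn"
    # Otherwise the pre-existing FA3/FA2 decision applies.
    return "fa3" if fa3_ok else "fa2"


print(select_backend("NHD", cudnn_available=True, fa3_ok=True))   # cudnn
print(select_backend("HND", cudnn_available=True, fa3_ok=True))   # fa3
print(select_backend(None, cudnn_available=True, fa3_ok=False))   # fa2
```

Because `kv_layout` defaults to `None` in the described signature, callers that do not pass a layout would keep the earlier FA3/FA2 behavior under this sketch.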

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested reviewers

  • yzh119
  • cyx-6
  • bkryu
  • nvmbreughe
  • jimmyzho
  • aleozlx

Poem

🐰 I nudged KV layout down the lane,
so prefill knows where caches remain.
NHD may call cudnn’s gate,
backends align and choose their fate. 🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 inconclusive)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Description check | ❓ Inconclusive | The description is incomplete; it contains the template structure but lacks detail about the actual changes and motivation, and does not indicate whether tests were added or pre-commit checks were run. | Provide more detail about what was changed and why, confirm whether tests were added/updated, and indicate if pre-commit checks were completed by marking the appropriate checkboxes. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and specifically describes the main change: enabling cuDNN as a backend option for prefill operations. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 83.33%, which is sufficient. The required threshold is 80.00%. |


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for using the cudnn backend for prefill operations. The changes look good and are well-contained. I have a few suggestions to improve correctness and performance.

Specifically, I've recommended adding a check for supported head dimensions before selecting the cudnn backend to prevent potential runtime errors. I've also suggested caching the cudnn availability check to improve performance by avoiding repeated module import attempts.

Comment thread flashinfer/utils.py
return True


def _is_cudnn_available_for_attention() -> bool:
Contributor

medium

To avoid repeated import attempts, which can be costly, it's a good practice to cache the result of this function. Since this function is pure and its result won't change during the program's execution, using @functools.cache is ideal.

Suggested change
+@functools.cache
 def _is_cudnn_available_for_attention() -> bool:
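
A small sketch of what the decorator buys, assuming Python 3.9+ for `functools.cache`. The helper `is_module_available` is a hypothetical generalization of the function under review, and the stdlib `json` module stands in for `cudnn` so the example runs anywhere.

```python
import functools
import importlib


@functools.cache  # equivalent to functools.lru_cache(maxsize=None)
def is_module_available(name: str) -> bool:
    """Hypothetical helper: report whether `name` can be imported."""
    try:
        importlib.import_module(name)
    except ImportError:
        return False
    return True


for _ in range(3):
    is_module_available("json")  # stdlib module, used here instead of cudnn

info = is_module_available.cache_info()
print(info.misses, info.hits)  # 1 2
```

Only the first call runs the import; later calls are answered from the cache. The trade-off is that a package installed after the first call would not be noticed until the process restarts or `cache_clear()` is called.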

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
flashinfer/utils.py (1)

461-468: Move return True to else block; optionally cache availability result

Ruff TRY300 flags the return True inside the try body. Moving it to an else clause also clarifies intent and keeps any statements added later outside the scope of the except ImportError handler, so an unrelated ImportError raised by that later code cannot be misreported as cuDNN being unavailable.

♻️ Proposed fix
 def _is_cudnn_available_for_attention() -> bool:
     """Return True if cuDNN is available for attention (prefill)."""
     try:
         import cudnn  # noqa: F401
-
-        return True
     except ImportError:
         return False
+    else:
+        return True

Additionally, consider decorating with @functools.cache so repeated calls avoid the sys.modules lookup on every determine_attention_backend invocation.
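
To illustrate why the placement matters, here is a hedged sketch with two hypothetical helpers. `_post_check` is an invented follow-up step that fails with an ImportError of its own; the stdlib `json` module stands in for `cudnn`.

```python
import importlib


def _post_check() -> None:
    """Hypothetical follow-up step that fails for an unrelated reason."""
    raise ImportError("unrelated failure in follow-up step")


def available_in_try(name: str) -> bool:
    try:
        importlib.import_module(name)
        _post_check()  # its ImportError is swallowed by the handler below
        return True
    except ImportError:
        return False


def available_with_else(name: str) -> bool:
    try:
        importlib.import_module(name)
    except ImportError:
        return False
    else:
        _post_check()  # runs outside the handler, so its error propagates
        return True


print(available_in_try("json"))  # False, although json imported fine
```

In the first form the module imports successfully yet the function reports it as unavailable; in the second form the unrelated failure surfaces as an exception instead of being hidden.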

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/utils.py` around lines 461 - 468, Move the immediate True return
out of the try body in _is_cudnn_available_for_attention: keep the try to import
cudnn, have except ImportError: return False, and put return True in an else:
block so future changes in the try won't accidentally run on ImportError;
optionally add `@functools.cache` (or `@functools.lru_cache`(maxsize=1)) above
_is_cudnn_available_for_attention to memoize the availability check and avoid
repeated sys.modules lookups in determine_attention_backend.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 26ef055 and c318136.

📒 Files selected for processing (2)
  • flashinfer/prefill.py
  • flashinfer/utils.py

Comment thread flashinfer/utils.py Outdated
Comment on lines +517 to +521
if (
kv_layout == "NHD"
and _is_cudnn_available_for_attention()
):
return "cudnn"
Contributor

@coderabbitai coderabbitai Bot Feb 23, 2026


⚠️ Potential issue | 🔴 Critical

cuDNN selected without constraint validation — silent correctness regression

Unlike fa3 (gated by is_fa3_backend_supported()), the cuDNN branch has no equivalent guard. This creates two concrete regressions:

  1. Custom mask silently discarded: when use_custom_mask=True + NHD layout + cuDNN installed, fa3 is correctly rejected (is_fa3_backend_supported returns False), but now the code falls through to cuDNN instead of fa2. Neither BatchPrefillWithPagedKVCacheWrapper.run() nor BatchPrefillWithRaggedKVCacheWrapper.run() passes the custom mask to cudnn_batch_prefill_with_kv_cache, so the mask is silently dropped.

  2. Positional encoding silently ignored: same path when pos_encoding_mode != NONE. The cuDNN call sites accept only a causal flag, not a positional-encoding mode, so the requested encoding is never applied.

A guard function analogous to is_fa3_backend_supported is needed:

🐛 Proposed fix
+def is_cudnn_backend_supported(
+    pos_encoding_mode: int,
+    use_fp16_qk_reductions: bool,
+    use_custom_mask: bool,
+    dtype_q: torch.dtype,
+    dtype_kv: torch.dtype,
+) -> bool:
+    if use_custom_mask:
+        return False
+    if pos_encoding_mode != PosEncodingMode.NONE.value:
+        return False
+    return True
+
+
 def determine_attention_backend(
     ...
 ) -> str:
     if is_sm90a_supported(device) and is_fa3_backend_supported(
         pos_encoding_mode,
         use_fp16_qk_reductions,
         use_custom_mask,
         dtype_q,
         dtype_kv,
     ):
         return "fa3"
     if (
         kv_layout == "NHD"
         and _is_cudnn_available_for_attention()
+        and is_cudnn_backend_supported(
+            pos_encoding_mode,
+            use_fp16_qk_reductions,
+            use_custom_mask,
+            dtype_q,
+            dtype_kv,
+        )
     ):
         return "cudnn"
     return "fa2"
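
The effect of the proposed guard can be checked with a self-contained sketch. Everything here is a simplified stand-in rather than the library's code: `PosEncodingMode` is a local enum whose member values are assumed, `select_backend_guarded` is a hypothetical selector, the dtype and fp16-reduction arguments are omitted, and the booleans replace the real device and import checks.

```python
from enum import Enum
from typing import Optional


class PosEncodingMode(Enum):
    """Local stand-in for the library enum; member values are assumed."""

    NONE = 0
    ROPE_LLAMA = 1
    ALIBI = 2


def is_cudnn_backend_supported(pos_encoding_mode: int, use_custom_mask: bool) -> bool:
    """Reject the features the review says the cuDNN path does not forward."""
    if use_custom_mask:
        return False
    return pos_encoding_mode == PosEncodingMode.NONE.value


def select_backend_guarded(
    kv_layout: Optional[str],
    cudnn_available: bool,
    fa3_ok: bool,
    pos_encoding_mode: int,
    use_custom_mask: bool,
) -> str:
    """Hypothetical selector following the order in the diff above."""
    if fa3_ok:
        return "fa3"
    if (
        kv_layout == "NHD"
        and cudnn_available
        and is_cudnn_backend_supported(pos_encoding_mode, use_custom_mask)
    ):
        return "cudnn"
    return "fa2"


NONE = PosEncodingMode.NONE.value
ROPE = PosEncodingMode.ROPE_LLAMA.value
print(select_backend_guarded("NHD", True, False, NONE, use_custom_mask=False))  # cudnn
print(select_backend_guarded("NHD", True, False, NONE, use_custom_mask=True))   # fa2
print(select_backend_guarded("NHD", True, False, ROPE, use_custom_mask=False))  # fa2
```

With the guard in place, the two configurations named in the review (custom mask, non-NONE positional encoding) fall back to "fa2" instead of reaching a path that would ignore them.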
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/utils.py` around lines 517 - 521, The selection of the "cudnn"
backend in flashinfer/utils.py currently uses only
_is_cudnn_available_for_attention() and must be guarded similarly to
is_fa3_backend_supported(): add a new guard function (e.g.,
is_cudnn_backend_supported) that validates constraints required by cudnn (reject
when use_custom_mask=True or when pos_encoding_mode != NONE since cudnn only
supports causal), then update the backend selection branch that returns "cudnn"
to call this guard; ensure callers like
BatchPrefillWithPagedKVCacheWrapper.run() and
BatchPrefillWithRaggedKVCacheWrapper.run() will fall back to fa2 (or other safe
backend) when the guard returns False so custom masks and non-causal positional
encodings are not silently dropped.

Collaborator

+1 for this, otherwise looks good to me.

Contributor

@saltyminty, understood! I can help implement the guard function to ensure cuDNN is only selected when it supports the requested features. Would you like me to generate the code changes that add the is_cudnn_backend_supported function and integrate it into the backend selection logic as CodeRabbit suggested?

Contributor

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
flashinfer/utils.py (1)

509-519: ⚠️ Potential issue | 🔴 Critical

Guard cuDNN selection with feature-compat checks.

The cuDNN branch is still chosen solely by availability + NHD layout. If custom masks or non-NONE positional encoding are requested, cuDNN prefill paths that don’t accept those inputs will silently drop them. Add a dedicated guard (analogous to FA3) so unsupported configs fall back to FA2.

🐛 Proposed fix
+def is_cudnn_backend_supported(
+    pos_encoding_mode: int,
+    use_fp16_qk_reductions: bool,
+    use_custom_mask: bool,
+    dtype_q: torch.dtype,
+    dtype_kv: torch.dtype,
+) -> bool:
+    if use_custom_mask:
+        return False
+    if pos_encoding_mode != PosEncodingMode.NONE.value:
+        return False
+    return True
+
 def determine_attention_backend(
     device: torch.device,
     pos_encoding_mode: int,
     use_fp16_qk_reductions: bool,
     use_custom_mask: bool,
     dtype_q: torch.dtype,
     dtype_kv: torch.dtype,
     kv_layout: Optional[str] = None,
 ) -> str:
     ...
-    if kv_layout == "NHD" and _is_cudnn_available_for_attention():
+    if (
+        kv_layout == "NHD"
+        and _is_cudnn_available_for_attention()
+        and is_cudnn_backend_supported(
+            pos_encoding_mode,
+            use_fp16_qk_reductions,
+            use_custom_mask,
+            dtype_q,
+            dtype_kv,
+        )
+    ):
         return "cudnn"
     return "fa2"
#!/bin/bash
# Locate cuDNN prefill call sites and verify which args are supported.
rg -n -C3 '\bcudnn_\w*prefill\w*\s*\(' -g '*.py'

# Check how custom masks / positional encodings are threaded in prefill paths.
rg -n -C3 '\b(custom_mask|pos_encoding_mode|kv_layout)\b' -g 'flashinfer/prefill.py'
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/utils.py` around lines 509 - 519, The cuDNN branch currently
selects based only on kv_layout and availability and must be guarded against
unsupported features (positional encodings, custom masks, FP16 Q/K reductions,
dtype combos); add a dedicated compatibility check (e.g., call or add a helper
like is_cudnn_backend_supported or is_cudnn_prefill_supported) that accepts
pos_encoding_mode, use_custom_mask, use_fp16_qk_reductions, dtype_q, dtype_kv
(same signature style as is_fa3_backend_supported) and only return "cudnn" when
that helper returns true and kv_layout == "NHD" and
_is_cudnn_available_for_attention(); otherwise fall back to "fa2" so unsupported
configs don't silently drop inputs.

📥 Commits

Reviewing files that changed from the base of the PR and between c318136 and 749a036.

📒 Files selected for processing (1)
  • flashinfer/utils.py

Collaborator

aleozlx commented Feb 24, 2026

cc @saltyminty for review

Contributor

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
flashinfer/utils.py (1)

509-510: ⚠️ Potential issue | 🔴 Critical

cuDNN is selected without any constraint validation — silent correctness regression.

The cuDNN branch fires solely on kv_layout == "NHD" + import availability. Unlike the fa3 branch (which is gated by is_fa3_backend_supported), there is no analogous guard here. Two concrete regressions remain:

  1. Custom mask silently dropped – when use_custom_mask=True + NHD layout + cuDNN installed, fa3 is correctly rejected, but the code now routes to "cudnn". Neither BatchPrefillWithPagedKVCacheWrapper.run() nor BatchPrefillWithRaggedKVCacheWrapper.run() forwards the custom mask to cudnn_batch_prefill_with_kv_cache, so the mask is silently ignored.
  2. Positional encoding silently ignored – same path when pos_encoding_mode != NONE; the cuDNN call sites accept only a causal flag, not a full positional-encoding mode.

A guard function analogous to is_fa3_backend_supported is needed before this can be safely merged:

🐛 Proposed fix
+def is_cudnn_backend_supported(
+    pos_encoding_mode: int,
+    use_fp16_qk_reductions: bool,
+    use_custom_mask: bool,
+    dtype_q: torch.dtype,
+    dtype_kv: torch.dtype,
+) -> bool:
+    if use_custom_mask:
+        return False
+    if pos_encoding_mode != PosEncodingMode.NONE.value:
+        return False
+    return True
+
+
 def determine_attention_backend(...) -> str:
     if (
         kv_layout == "NHD"
         and _is_cudnn_available_for_attention()
+        and is_cudnn_backend_supported(
+            pos_encoding_mode,
+            use_fp16_qk_reductions,
+            use_custom_mask,
+            dtype_q,
+            dtype_kv,
+        )
     ):
         return "cudnn"
     if is_sm90a_supported(device) and is_fa3_backend_supported(...):
         return "fa3"
     return "fa2"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/utils.py` around lines 509 - 510, The cuDNN path is chosen solely
by kv_layout and import availability which causes silent regressions (custom
masks and non-causal positional encodings dropped); add a guard like
is_cudnn_backend_supported(...) and use it instead of only
_is_cudnn_available_for_attention() in the selector where kv_layout == "NHD"
(the code around kv_layout and _is_cudnn_available_for_attention). Implement
is_cudnn_backend_supported to validate that callers' constraints are safe (e.g.,
use_custom_mask is False or masks can be forwarded, and pos_encoding_mode is
NONE, since causal is a separate flag rather than a positional-encoding mode),
update the selector to return "cudnn" only when that guard
passes, and either modify BatchPrefillWithPagedKVCacheWrapper.run /
BatchPrefillWithRaggedKVCacheWrapper.run to forward custom masks/positional
encoding to cudnn_batch_prefill_with_kv_cache or prevent selecting cudnn when
those features are in use.
🧹 Nitpick comments (1)
flashinfer/utils.py (1)

461-468: Cache _is_cudnn_available_for_attention and restructure the try/except.

Two minor improvements:

  1. TRY300 (Ruff): The return True inside the try body should move to an else block, which is the idiomatic way to express "run this only if no exception was raised."
  2. Caching: The result of this function is constant for the entire process lifetime (cuDNN is either installed or not). Without @functools.cache the try/import round-trip happens on every call to determine_attention_backend.
♻️ Proposed refactor
+@functools.cache
 def _is_cudnn_available_for_attention() -> bool:
     """Return True if cuDNN is available for attention (prefill)."""
     try:
         import cudnn  # noqa: F401
-
-        return True
     except ImportError:
         return False
+    else:
+        return True

As per coding guidelines, Python functions in flashinfer/*.py should use @functools.cache for module caching to avoid repeated overhead.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/utils.py` around lines 461 - 468, The function
_is_cudnn_available_for_attention should be decorated with functools.cache and
the try/except should use an else block: import cudnn inside try, return False
in the except ImportError, and return True in the else. Add "from functools
import cache" or use "functools.cache" and apply `@cache` to
_is_cudnn_available_for_attention so the import check is performed only once per
process.

📥 Commits

Reviewing files that changed from the base of the PR and between 749a036 and 972efae.

📒 Files selected for processing (1)
  • flashinfer/utils.py



3 participants