Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -1909,6 +1909,7 @@ def plan(
self._custom_mask_buf is not None, # use_custom_mask
q_data_type,
kv_data_type,
kv_layout=self._kv_layout,
)
if self._backend != "cudnn":
get_module_args = (
Expand Down Expand Up @@ -2852,6 +2853,7 @@ def plan(
self._custom_mask_buf is not None, # use_custom_mask
q_data_type,
kv_data_type,
kv_layout=self._kv_layout,
)

get_module_args = (
Expand Down
20 changes: 18 additions & 2 deletions flashinfer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,13 +458,24 @@ def is_cutlass_backend_supported(
return True


def _is_cudnn_available_for_attention() -> bool:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To avoid repeated import attempts, which can be costly, it's a good practice to cache the result of this function. Since this function is pure and its result won't change during the program's execution, using @functools.cache is ideal.

Suggested change
def _is_cudnn_available_for_attention() -> bool:
@functools.cache
def _is_cudnn_available_for_attention() -> bool:

"""Return True if cuDNN is available for attention (prefill)."""
try:
import cudnn # noqa: F401

return True
except ImportError:
return False


def determine_attention_backend(
device: torch.device,
pos_encoding_mode: int,
use_fp16_qk_reductions: bool,
use_custom_mask: bool,
dtype_q: torch.dtype,
dtype_kv: torch.dtype,
kv_layout: Optional[str] = None,
) -> str:
"""
Determine the appropriate attention backend based on the device and parameters.
Expand All @@ -485,12 +496,18 @@ def determine_attention_backend(
The data type of the query tensor.
dtype_kv : torch.dtype
The data type of the key-value tensor.
kv_layout : Optional[str]
The KV cache layout (``"NHD"`` or ``"HND"``). When ``"NHD"`` and cuDNN is
available, cuDNN may be chosen for prefill. When ``None`` (e.g. decode/sparse
callers), cuDNN is not considered. Defaults to ``None``.

Returns
-------
str
The name of the attention backend to be used.
"""
if kv_layout == "NHD" and _is_cudnn_available_for_attention():
return "cudnn"
if is_sm90a_supported(device) and is_fa3_backend_supported(
pos_encoding_mode,
use_fp16_qk_reductions,
Expand All @@ -499,8 +516,7 @@ def determine_attention_backend(
dtype_kv,
):
return "fa3"
else:
return "fa2"
return "fa2"


def version_at_least(version: str, base_version: str) -> bool:
Expand Down