From c318136830b21dbcbeffd945feabf9428adf9bf7 Mon Sep 17 00:00:00 2001 From: Scott Yokim Date: Mon, 23 Feb 2026 18:26:26 +0000 Subject: [PATCH 1/3] allow cudnn to be chosen for prefill --- flashinfer/prefill.py | 2 ++ flashinfer/utils.py | 23 +++++++++++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index bf0fcfc822..59e0be191c 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1909,6 +1909,7 @@ def plan( self._custom_mask_buf is not None, # use_custom_mask q_data_type, kv_data_type, + kv_layout=self._kv_layout, ) if self._backend != "cudnn": get_module_args = ( @@ -2852,6 +2853,7 @@ def plan( self._custom_mask_buf is not None, # use_custom_mask q_data_type, kv_data_type, + kv_layout=self._kv_layout, ) get_module_args = ( diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 44e8f1b762..720aed15f5 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -458,6 +458,16 @@ def is_cutlass_backend_supported( return True +def _is_cudnn_available_for_attention() -> bool: + """Return True if cuDNN is available for attention (prefill).""" + try: + import cudnn # noqa: F401 + + return True + except ImportError: + return False + + def determine_attention_backend( device: torch.device, pos_encoding_mode: int, @@ -465,6 +475,7 @@ def determine_attention_backend( use_custom_mask: bool, dtype_q: torch.dtype, dtype_kv: torch.dtype, + kv_layout: Optional[str] = None, ) -> str: """ Determine the appropriate attention backend based on the device and parameters. @@ -485,6 +496,10 @@ def determine_attention_backend( The data type of the query tensor. dtype_kv : torch.dtype The data type of the key-value tensor. + kv_layout : Optional[str] + The KV cache layout (``"NHD"`` or ``"HND"``). When ``"NHD"`` and cuDNN is + available, cuDNN may be chosen for prefill. When ``None`` (e.g. decode/sparse + callers), cuDNN is not considered. Defaults to ``None``. Returns ------- @@ -499,8 +514,12 @@ def determine_attention_backend( dtype_kv, ): return "fa3" - else: - return "fa2" + if ( + kv_layout == "NHD" + and _is_cudnn_available_for_attention() + ): + return "cudnn" + return "fa2" def version_at_least(version: str, base_version: str) -> bool: From 749a036ab89a6e43c2960bbf939738d31ff2601f Mon Sep 17 00:00:00 2001 From: Scott Yokim Date: Mon, 23 Feb 2026 20:09:59 +0000 Subject: [PATCH 2/3] ruff --- flashinfer/utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 720aed15f5..8bff972479 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -514,10 +514,7 @@ def determine_attention_backend( dtype_kv, ): return "fa3" - if ( - kv_layout == "NHD" - and _is_cudnn_available_for_attention() - ): + if kv_layout == "NHD" and _is_cudnn_available_for_attention(): return "cudnn" return "fa2" From 972efae7c4e5633d0b388a40d64444f662df2ce4 Mon Sep 17 00:00:00 2001 From: Scott Yokim Date: Tue, 24 Feb 2026 17:33:59 +0000 Subject: [PATCH 3/3] prefer cudnn --- flashinfer/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 8bff972479..592a2a3594 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -506,6 +506,8 @@ def determine_attention_backend( str The name of the attention backend to be used. """ + if kv_layout == "NHD" and _is_cudnn_available_for_attention(): + return "cudnn" if is_sm90a_supported(device) and is_fa3_backend_supported( pos_encoding_mode, use_fp16_qk_reductions, @@ -514,8 +516,6 @@ def determine_attention_backend( dtype_kv, ): return "fa3" - if kv_layout == "NHD" and _is_cudnn_available_for_attention(): - return "cudnn" return "fa2"