diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index bf0fcfc822..59e0be191c 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -1909,6 +1909,7 @@ def plan( self._custom_mask_buf is not None, # use_custom_mask q_data_type, kv_data_type, + kv_layout=self._kv_layout, ) if self._backend != "cudnn": get_module_args = ( @@ -2852,6 +2853,7 @@ def plan( self._custom_mask_buf is not None, # use_custom_mask q_data_type, kv_data_type, + kv_layout=self._kv_layout, ) get_module_args = ( diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 44e8f1b762..592a2a3594 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -458,6 +458,16 @@ def is_cutlass_backend_supported( return True +def _is_cudnn_available_for_attention() -> bool: + """Return True if cuDNN is available for attention (prefill).""" + try: + import cudnn # noqa: F401 + + return True + except ImportError: + return False + + def determine_attention_backend( device: torch.device, pos_encoding_mode: int, @@ -465,6 +475,7 @@ def determine_attention_backend( use_custom_mask: bool, dtype_q: torch.dtype, dtype_kv: torch.dtype, + kv_layout: Optional[str] = None, ) -> str: """ Determine the appropriate attention backend based on the device and parameters. @@ -485,12 +496,18 @@ def determine_attention_backend( The data type of the query tensor. dtype_kv : torch.dtype The data type of the key-value tensor. + kv_layout : Optional[str] + The KV cache layout (``"NHD"`` or ``"HND"``). When ``"NHD"`` and cuDNN is + available, ``"cudnn"`` is returned before the FA3/FA2 checks. When ``None`` + (e.g. decode/sparse callers), cuDNN is not considered. Defaults to ``None``. Returns ------- str The name of the attention backend to be used.
""" + if kv_layout == "NHD" and _is_cudnn_available_for_attention(): + return "cudnn" if is_sm90a_supported(device) and is_fa3_backend_supported( pos_encoding_mode, use_fp16_qk_reductions, @@ -499,8 +516,7 @@ def determine_attention_backend( dtype_kv, ): return "fa3" - else: - return "fa2" + return "fa2" def version_at_least(version: str, base_version: str) -> bool: