Add FlashAttention 4 (FA4) detection to vllm_flash_attn/flash_attn_interface.py.

Summary of the change carried by this patch:
  * Try to import `_flash_attn_fwd` from `flash_attn.cute.interface` at module
    import time; record the outcome in FA4_AVAILABLE and, on ImportError, the
    error text in FA4_UNAVAILABLE_REASON.
  * Add `_is_fa4_supported(device)`, which returns (supported, reason). It
    reports unsupported when the import failed or when the major compute
    capability of `device` is not 10.
  * Let `is_fa_version_supported` and `fa_version_unsupported_reason` accept
    fa_version == 4 and dispatch to `_is_fa4_supported`.

NOTE(review): this patch had been flattened onto a single line; the line
structure below was rebuilt from the hunk line counts. Leading indentation of
the Python lines (in particular continuation lines) is reconstructed and
should be confirmed against the target file, or applied with a
whitespace-tolerant option, before relying on it.
NOTE(review): the added FA4 message spells "unavaible"; it is left exactly as
recorded here. Version validation relies on `assert`, which is skipped under
`python -O`.

diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py
index 27ef088ccaf..a186b14a2f4 100644
--- a/vllm_flash_attn/flash_attn_interface.py
+++ b/vllm_flash_attn/flash_attn_interface.py
@@ -25,6 +25,14 @@
     FA3_UNAVAILABLE_REASON = str(e)
     FA3_AVAILABLE = False
 
+try:
+    from flash_attn.cute.interface import _flash_attn_fwd  # noqa: F401
+    FA4_UNAVAILABLE_REASON = None
+    FA4_AVAILABLE = True
+except ImportError as e:
+    FA4_UNAVAILABLE_REASON = str(e)
+    FA4_AVAILABLE = False
+
 # isort: on
 
 DEFAULT_FA_VERSION = 2
@@ -49,20 +57,32 @@ def _is_fa3_supported(device = None) -> Tuple[bool, Optional[str]]:
             " excluding 8.6 and 8.9 and Blackwell archs (>=10)"
     return True, None
 
+def _is_fa4_supported(device = None) -> Tuple[bool, Optional[str]]:
+    if not FA4_AVAILABLE:
+        return False, f"FA4 is unavaible due to: {FA4_UNAVAILABLE_REASON}"
+    if torch.cuda.get_device_capability(device)[0] != 10:
+        return False, \
+            "FA4 is only supported on devices with compute capability == 10"
+    return True, None
+
 def is_fa_version_supported(fa_version: int, device = None) -> bool:
-    assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
+    assert fa_version in [2, 3, 4], f"Unsupported FA version: {fa_version}"
     if fa_version == 2:
         return _is_fa2_supported(device)[0]
     elif fa_version == 3:
         return _is_fa3_supported(device)[0]
+    elif fa_version == 4:
+        return _is_fa4_supported(device)[0]
 
 def fa_version_unsupported_reason(fa_version: int, device = None) \
     -> Optional[str]:
-    assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
+    assert fa_version in [2, 3, 4], f"Unsupported FA version: {fa_version}"
     if fa_version == 2:
         return _is_fa2_supported(device)[1]
     elif fa_version == 3:
         return _is_fa3_supported(device)[1]
+    elif fa_version == 4:
+        return _is_fa4_supported(device)[1]
 
 #
 # For vLLM we only care about `flash_attn_varlen_func` and