Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions vllm_flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@
FA3_UNAVAILABLE_REASON = str(e)
FA3_AVAILABLE = False

try:
from flash_attn.cute.interface import _flash_attn_fwd # noqa: F401
FA4_UNAVAILABLE_REASON = None
FA4_AVAILABLE = True
except ImportError as e:
FA4_UNAVAILABLE_REASON = str(e)
FA4_AVAILABLE = False

# isort: on

DEFAULT_FA_VERSION = 2
Expand All @@ -49,20 +57,32 @@ def _is_fa3_supported(device = None) -> Tuple[bool, Optional[str]]:
" excluding 8.6 and 8.9 and Blackwell archs (>=10)"
return True, None

def _is_fa4_supported(device = None) -> Tuple[bool, Optional[str]]:
if not FA4_AVAILABLE:
return False, f"FA4 is unavaible due to: {FA4_UNAVAILABLE_REASON}"
if torch.cuda.get_device_capability(device)[0] != 10:
return False, \
"FA4 is only supported on devices with compute capability == 10"
return True, None

def is_fa_version_supported(fa_version: int, device = None) -> bool:
assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
assert fa_version in [2, 3, 4], f"Unsupported FA version: {fa_version}"
if fa_version == 2:
return _is_fa2_supported(device)[0]
elif fa_version == 3:
return _is_fa3_supported(device)[0]
elif fa_version == 4:
return _is_fa4_supported(device)[0]

def fa_version_unsupported_reason(fa_version: int, device = None) \
-> Optional[str]:
assert fa_version in [2, 3], f"Unsupported FA version: {fa_version}"
assert fa_version in [2, 3, 4], f"Unsupported FA version: {fa_version}"
if fa_version == 2:
return _is_fa2_supported(device)[1]
elif fa_version == 3:
return _is_fa3_supported(device)[1]
elif fa_version == 4:
return _is_fa4_supported(device)[1]

#
# For vLLM we only care about `flash_attn_varlen_func` and
Expand Down