diff --git a/CMakeLists.txt b/CMakeLists.txt
index 95319e27c6a..e5798a446b4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -191,9 +191,6 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
         "hopper/instantiations/flash_fwd_hdim64_512_bf16*_sm90.cu"
         "hopper/instantiations/flash_fwd_hdim192_128_bf16*_sm90.cu")
     list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
-    file(GLOB FA3_BF16_GEN_SRCS_
-        "hopper/instantiations/flash_fwd_*_bf16_*_sm80.cu")
-    list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
 
     # FP16 source files
     file(GLOB FA3_FP16_GEN_SRCS
@@ -208,9 +205,6 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
         "hopper/instantiations/flash_fwd_hdim64_512_fp16*_sm90.cu"
         "hopper/instantiations/flash_fwd_hdim192_128_fp16*_sm90.cu")
     list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
-    file(GLOB FA3_FP16_GEN_SRCS_
-        "hopper/instantiations/flash_fwd_*_fp16_*_sm80.cu")
-    list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
 
     # FP8 source files
     file(GLOB FA3_FP8_GEN_SRCS
@@ -229,7 +223,7 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
     # For CUDA we set the architectures on a per file basis
     # FaV3 is not yet supported in Blackwell
     if (VLLM_GPU_LANG STREQUAL "CUDA")
-      cuda_archs_loose_intersection(FA3_ARCHS "8.0;9.0a;" "${CUDA_ARCHS}")
+      cuda_archs_loose_intersection(FA3_ARCHS "9.0a;" "${CUDA_ARCHS}")
       message(STATUS "FA3_ARCHS: ${FA3_ARCHS}")
 
       set_gencode_flags_for_srcs(
diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py
index 3d3aca79f0d..96f9335841d 100644
--- a/vllm_flash_attn/flash_attn_interface.py
+++ b/vllm_flash_attn/flash_attn_interface.py
@@ -48,13 +48,10 @@ def _is_fa2_supported(device = None) -> Tuple[bool, Optional[str]]:
 def _is_fa3_supported(device = None) -> Tuple[bool, Optional[str]]:
     if not FA3_AVAILABLE:
         return False, f"FA3 is unavaible due to: {FA3_UNAVAILABLE_REASON}"
-    if torch.cuda.get_device_capability(device)[0] < 8 \
-        or torch.cuda.get_device_capability(device)[0] >= 10 \
-        or torch.cuda.get_device_capability(device) == (8, 6) \
-        or torch.cuda.get_device_capability(device) == (8, 9):
+    if torch.cuda.get_device_capability(device)[0] < 9 \
+        or torch.cuda.get_device_capability(device)[0] >= 10:
         return False, \
-            "FA3 is only supported on devices with compute capability >= 8" \
-            " excluding 8.6 and 8.9 and Blackwell archs (>=10)"
+            "FA3 is only supported on devices with compute capability 9.0"
     return True, None
 
 def _is_fa4_supported(device = None) -> Tuple[bool, Optional[str]]: