Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,6 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
"hopper/instantiations/flash_fwd_hdim64_512_bf16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_128_bf16*_sm90.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})
file(GLOB FA3_BF16_GEN_SRCS_
"hopper/instantiations/flash_fwd_*_bf16_*_sm80.cu")
list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_})

# FP16 source files
file(GLOB FA3_FP16_GEN_SRCS
Expand All @@ -208,9 +205,6 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
"hopper/instantiations/flash_fwd_hdim64_512_fp16*_sm90.cu"
"hopper/instantiations/flash_fwd_hdim192_128_fp16*_sm90.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})
file(GLOB FA3_FP16_GEN_SRCS_
"hopper/instantiations/flash_fwd_*_fp16_*_sm80.cu")
list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_})

# FP8 source files
file(GLOB FA3_FP8_GEN_SRCS
Expand All @@ -229,7 +223,7 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0)
# For CUDA we set the architectures on a per file basis
# FaV3 is not yet supported in Blackwell
if (VLLM_GPU_LANG STREQUAL "CUDA")
cuda_archs_loose_intersection(FA3_ARCHS "8.0;9.0a;" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(FA3_ARCHS "9.0a;" "${CUDA_ARCHS}")
message(STATUS "FA3_ARCHS: ${FA3_ARCHS}")

set_gencode_flags_for_srcs(
Expand Down
9 changes: 3 additions & 6 deletions vllm_flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,10 @@ def _is_fa2_supported(device = None) -> Tuple[bool, Optional[str]]:
def _is_fa3_supported(device = None) -> Tuple[bool, Optional[str]]:
if not FA3_AVAILABLE:
return False, f"FA3 is unavailable due to: {FA3_UNAVAILABLE_REASON}"
if torch.cuda.get_device_capability(device)[0] < 8 \
or torch.cuda.get_device_capability(device)[0] >= 10 \
or torch.cuda.get_device_capability(device) == (8, 6) \
or torch.cuda.get_device_capability(device) == (8, 9):
if torch.cuda.get_device_capability(device)[0] < 9 \
or torch.cuda.get_device_capability(device)[0] >= 10:
return False, \
"FA3 is only supported on devices with compute capability >= 8" \
" excluding 8.6 and 8.9 and Blackwell archs (>=10)"
"FA3 is only supported on devices with compute capability 9.0"
return True, None

def _is_fa4_supported(device = None) -> Tuple[bool, Optional[str]]:
Expand Down