diff --git a/setup.py b/setup.py index f9cb207f2..0b4c99059 100644 --- a/setup.py +++ b/setup.py @@ -175,7 +175,6 @@ def get_flash_attention2_nvcc_archs_flags(cuda_version: int): if platform.system() != "Linux" and cuda_version < 1200: return [] # Figure out default archs to target - DEFAULT_ARCHS_LIST = "" if cuda_version >= 1208: DEFAULT_ARCHS_LIST = "8.0;8.6;9.0;10.0;12.0" elif cuda_version >= 1108: @@ -283,7 +282,7 @@ def get_flash_attention3_nvcc_archs_flags(cuda_version: int): return [] archs_list = os.environ.get("TORCH_CUDA_ARCH_LIST") if archs_list is None: - if torch.cuda.get_device_capability("cuda") != (9, 0): + if torch.cuda.get_device_capability("cuda") != (9, 0) and torch.cuda.get_device_capability("cuda") != (8, 0): return [] archs_list = "8.0 9.0a" nvcc_archs_flags = [] @@ -512,7 +511,7 @@ def get_extensions(): if cuda_version >= 1102: nvcc_flags += [ "--threads", - "4", + os.getenv("NVCC_THREADS", "5"), "--ptxas-options=-v", ] if sys.platform == "win32": diff --git a/xformers/ops/fmha/dispatch.py b/xformers/ops/fmha/dispatch.py index 5908635da..be0f75ead 100644 --- a/xformers/ops/fmha/dispatch.py +++ b/xformers/ops/fmha/dispatch.py @@ -31,7 +31,7 @@ def _get_use_fa3() -> bool: def fa3_available() -> bool: has_cuda = torch.version.cuda is not None - is_90a = has_cuda and torch.cuda.get_device_capability() >= (9, 0) + is_90a = has_cuda and (8, 0) <= torch.cuda.get_device_capability() <= (9, 0) has_valid_flash3 = flash3._C_flashattention3 is not None # pyre-ignore[16] return is_90a and has_valid_flash3 diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py index 29868800c..646f84dc3 100644 --- a/xformers/ops/fmha/flash.py +++ b/xformers/ops/fmha/flash.py @@ -74,7 +74,7 @@ FLASH_VERSION = flash_attn.__version__ FLASH_VER_MIN = parse_version("2.7.1") - FLASH_VER_LAST = parse_version("2.8.2") # last supported, inclusive + FLASH_VER_LAST = parse_version("2.8.3") # last supported, inclusive flash_ver_parsed = parse_version(FLASH_VERSION) if ( flash_ver_parsed < FLASH_VER_MIN or flash_ver_parsed > FLASH_VER_LAST diff --git a/xformers/ops/fmha/flash3.py b/xformers/ops/fmha/flash3.py index e9136c329..de716c4fb 100644 --- a/xformers/ops/fmha/flash3.py +++ b/xformers/ops/fmha/flash3.py @@ -641,6 +641,16 @@ class FwOp(AttentionFwOpBase): @classmethod def not_supported_reasons(cls, d: Inputs) -> List[str]: reasons = super(FwOp, cls).not_supported_reasons(d) + + device_type = d.query.device.type + if device_type == "cuda" and (torch.version.hip is None): + device_capability = torch.cuda.get_device_capability(d.device) + if device_capability > cls.CUDA_MINIMUM_COMPUTE_CAPABILITY: + reasons.append( + f"requires device with capability == {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} " + f"but your GPU has capability {device_capability} (too new)" + ) + check_lastdim_alignment_stride1(reasons, "query", d.query, 8) check_lastdim_alignment_stride1(reasons, "key", d.value, 8) check_lastdim_alignment_stride1(reasons, "value", d.value, 8) @@ -796,6 +806,16 @@ class BwOp(AttentionBwOpBase): @classmethod def not_supported_reasons(cls, d: Inputs) -> List[str]: reasons = super(BwOp, cls).not_supported_reasons(d) + + device_type = d.query.device.type + if device_type == "cuda" and (torch.version.hip is None): + device_capability = torch.cuda.get_device_capability(d.device) + if device_capability > cls.CUDA_MINIMUM_COMPUTE_CAPABILITY: + reasons.append( + f"requires device with capability == {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} " + f"but your GPU has capability {device_capability} (too new)" + ) + check_lastdim_alignment_stride1(reasons, "query", d.query, 8) check_lastdim_alignment_stride1(reasons, "key", d.value, 8) check_lastdim_alignment_stride1(reasons, "value", d.value, 8)