Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ def get_flash_attention2_nvcc_archs_flags(cuda_version: int):
if platform.system() != "Linux" and cuda_version < 1200:
return []
# Figure out default archs to target
DEFAULT_ARCHS_LIST = ""
if cuda_version >= 1208:
DEFAULT_ARCHS_LIST = "8.0;8.6;9.0;10.0;12.0"
elif cuda_version >= 1108:
Expand Down Expand Up @@ -283,7 +282,7 @@ def get_flash_attention3_nvcc_archs_flags(cuda_version: int):
return []
archs_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
if archs_list is None:
if torch.cuda.get_device_capability("cuda") != (9, 0):
if torch.cuda.get_device_capability("cuda") != (9, 0) and torch.cuda.get_device_capability("cuda") != (8, 0):
return []
archs_list = "8.0 9.0a"
nvcc_archs_flags = []
Expand Down Expand Up @@ -512,7 +511,7 @@ def get_extensions():
if cuda_version >= 1102:
nvcc_flags += [
"--threads",
"4",
os.getenv("NVCC_THREADS", "5"),
"--ptxas-options=-v",
]
if sys.platform == "win32":
Expand Down
2 changes: 1 addition & 1 deletion xformers/ops/fmha/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _get_use_fa3() -> bool:

def fa3_available() -> bool:
has_cuda = torch.version.cuda is not None
is_90a = has_cuda and torch.cuda.get_device_capability() >= (9, 0)
is_90a = has_cuda and (8, 0) <= torch.cuda.get_device_capability() <= (9, 0)
has_valid_flash3 = flash3._C_flashattention3 is not None # pyre-ignore[16]
return is_90a and has_valid_flash3

Expand Down
2 changes: 1 addition & 1 deletion xformers/ops/fmha/flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@

FLASH_VERSION = flash_attn.__version__
FLASH_VER_MIN = parse_version("2.7.1")
FLASH_VER_LAST = parse_version("2.8.2") # last supported, inclusive
FLASH_VER_LAST = parse_version("2.8.3") # last supported, inclusive
flash_ver_parsed = parse_version(FLASH_VERSION)
if (
flash_ver_parsed < FLASH_VER_MIN or flash_ver_parsed > FLASH_VER_LAST
Expand Down
20 changes: 20 additions & 0 deletions xformers/ops/fmha/flash3.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,16 @@ class FwOp(AttentionFwOpBase):
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(FwOp, cls).not_supported_reasons(d)

device_type = d.query.device.type
if device_type == "cuda" and (torch.version.hip is None):
device_capability = torch.cuda.get_device_capability(d.device)
if device_capability > cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
reasons.append(
f"requires device with capability == {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} "
f"but your GPU has capability {device_capability} (too new)"
)

check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
check_lastdim_alignment_stride1(reasons, "key", d.value, 8)
check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
Expand Down Expand Up @@ -796,6 +806,16 @@ class BwOp(AttentionBwOpBase):
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(BwOp, cls).not_supported_reasons(d)

device_type = d.query.device.type
if device_type == "cuda" and (torch.version.hip is None):
device_capability = torch.cuda.get_device_capability(d.device)
if device_capability > cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
reasons.append(
f"requires device with capability == {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} "
f"but your GPU has capability {device_capability} (too new)"
)

check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
check_lastdim_alignment_stride1(reasons, "key", d.value, 8)
check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
Expand Down