Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
cxx11_abi: ["FALSE", "TRUE"]
include:
- torch-version: "2.9.0.dev20250904"
cuda-version: "13.0"
cuda-version: "13.0.0"
exclude:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# Pytorch < 2.5 does not support Python 3.13
Expand Down
74 changes: 59 additions & 15 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@

@functools.lru_cache(maxsize=None)
def cuda_archs() -> str:
return os.getenv("FLASH_ATTN_CUDA_ARCHS", "80;90;100;120").split(";")
return os.getenv("FLASH_ATTN_CUDA_ARCHS", "80;90;100;110;120").split(";")


def get_platform():
Expand All @@ -94,6 +94,59 @@ def get_cuda_bare_metal_version(cuda_dir):
return raw_output, bare_metal_version


def add_cuda_gencodes(cc_flag, archs, bare_metal_version):
"""
Adds -gencode flags based on nvcc capabilities:
- sm_80/90 (regular)
- sm_100/120 on CUDA >= 12.8
- Use 100f on CUDA >= 12.9 (Blackwell family-specific)
- Map requested 110 -> 101 if CUDA < 13.0 (Thor rename)
- Embed PTX for newest arch for forward compatibility
"""
# Always-regular 80
if "80" in archs:
cc_flag += ["-gencode", "arch=compute_80,code=sm_80"]

# Hopper 9.0 needs >= 11.8
if bare_metal_version >= Version("11.8") and "90" in archs:
cc_flag += ["-gencode", "arch=compute_90,code=sm_90"]

# Blackwell 10.x requires >= 12.8
if bare_metal_version >= Version("12.8"):
if "100" in archs:
# CUDA 12.9 introduced "family-specific" for Blackwell (100f)
if bare_metal_version >= Version("12.9"):
cc_flag += ["-gencode", "arch=compute_100f,code=sm_100"]
else:
cc_flag += ["-gencode", "arch=compute_100,code=sm_100"]

if "120" in archs:
# sm_120 is supported in CUDA 12.8/12.9+ toolkits
if bare_metal_version >= Version("12.9"):
cc_flag += ["-gencode", "arch=compute_120f,code=sm_120"]
else:
cc_flag += ["-gencode", "arch=compute_120,code=sm_120"]


# Thor rename: 12.9 uses sm_101; 13.0+ uses sm_110
if "110" in archs:
if bare_metal_version >= Version("13.0"):
cc_flag += ["-gencode", "arch=compute_110f,code=sm_110"]
else:
# Provide Thor support for CUDA 12.9 via sm_101
if bare_metal_version >= Version("12.8"):
cc_flag += ["-gencode", "arch=compute_101,code=sm_101"]
# else: no Thor support in older toolkits

# PTX for newest requested arch (forward-compat)
numeric = [a for a in archs if a.isdigit()]
if numeric:
newest = max(numeric, key=int)
cc_flag += ["-gencode", f"arch=compute_{newest},code=compute_{newest}"]

return cc_flag


def get_hip_version():
return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))

Expand Down Expand Up @@ -175,20 +228,11 @@ def validate_and_update_archs(archs):
"FlashAttention is only supported on CUDA 11.7 and above. "
"Note: make sure nvcc has a supported version by running nvcc -V."
)

if "80" in cuda_archs():
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if CUDA_HOME is not None:
if bare_metal_version >= Version("11.8") and "90" in cuda_archs():
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
if bare_metal_version >= Version("12.8") and "100" in cuda_archs():
cc_flag.append("-gencode")
cc_flag.append("arch=compute_100,code=sm_100")
if bare_metal_version >= Version("12.8") and "120" in cuda_archs():
cc_flag.append("-gencode")
cc_flag.append("arch=compute_120,code=sm_120")
# Build -gencode (regular + PTX + family-specific 'f' when available)
add_cuda_gencodes(cc_flag, set(cuda_archs()), bare_metal_version)
else:
# No nvcc present; warnings already emitted above
pass

# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
# torch._C._GLIBCXX_USE_CXX11_ABI
Expand Down