diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index e88090f336d..26013ad5d67 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -51,7 +51,7 @@ jobs:
           cxx11_abi: ["FALSE", "TRUE"]
           include:
             - torch-version: "2.9.0.dev20250904"
-              cuda-version: "13.0"
+              cuda-version: "13.0.0"
           exclude:
             # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
             # Pytorch < 2.5 does not support Python 3.13
diff --git a/setup.py b/setup.py
index a108c412c00..9a406839e7f 100644
--- a/setup.py
+++ b/setup.py
@@ -67,7 +67,7 @@
 
 @functools.lru_cache(maxsize=None)
 def cuda_archs() -> str:
-    return os.getenv("FLASH_ATTN_CUDA_ARCHS", "80;90;100;120").split(";")
+    return os.getenv("FLASH_ATTN_CUDA_ARCHS", "80;90;100;110;120").split(";")
 
 
 def get_platform():
@@ -94,6 +94,56 @@ def get_cuda_bare_metal_version(cuda_dir):
     return raw_output, bare_metal_version
 
 
+def add_cuda_gencodes(cc_flag, archs, bare_metal_version):
+    """
+    Append the nvcc -gencode flags for every requested arch this toolkit supports.
+
+    cc_flag is extended in place and returned; archs holds the requested
+    compute capabilities as strings (e.g. {"80", "90"}); bare_metal_version
+    is the parsed nvcc release.
+
+    - sm_80 always, sm_90 on CUDA >= 11.8
+    - sm_100/120 on CUDA >= 12.8, family-specific 100f/120f on CUDA >= 12.9
+    - Requested 110 (Thor) builds as sm_110 on CUDA >= 13.0 and as sm_101 on
+      CUDA 12.8/12.9 (Thor rename)
+    - PTX is embedded for the newest arch actually emitted, never for one the
+      toolkit cannot compile
+    """
+    has_family = bare_metal_version >= Version("12.9")
+    targets = []  # (virtual arch, real arch) pairs in emission order
+
+    if "80" in archs:
+        targets.append(("80", "80"))
+
+    # Hopper 9.0 needs >= 11.8
+    if bare_metal_version >= Version("11.8") and "90" in archs:
+        targets.append(("90", "90"))
+
+    # Blackwell 10.x/12.x and Thor require >= 12.8
+    if bare_metal_version >= Version("12.8"):
+        if "100" in archs:
+            targets.append(("100f" if has_family else "100", "100"))
+        if "120" in archs:
+            targets.append(("120f" if has_family else "120", "120"))
+        if "110" in archs:
+            # Thor rename: 12.8/12.9 call it sm_101; 13.0+ calls it sm_110
+            if bare_metal_version >= Version("13.0"):
+                targets.append(("110f", "110"))
+            else:
+                targets.append(("101", "101"))
+
+    for virtual, real in targets:
+        cc_flag += ["-gencode", f"arch=compute_{virtual},code=sm_{real}"]
+
+    # PTX (forward-compat) only for the newest arch this nvcc accepted above;
+    # a requested-but-unsupported arch here would abort the whole build.
+    if targets:
+        newest = max((real for _, real in targets), key=int)
+        cc_flag += ["-gencode", f"arch=compute_{newest},code=compute_{newest}"]
+
+    return cc_flag
+
+
 def get_hip_version():
     return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))
 
@@ -175,20 +225,11 @@ def validate_and_update_archs(archs):
                 "FlashAttention is only supported on CUDA 11.7 and above. "
                 "Note: make sure nvcc has a supported version by running nvcc -V."
             )
-
-    if "80" in cuda_archs():
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_80,code=sm_80")
-    if CUDA_HOME is not None:
-        if bare_metal_version >= Version("11.8") and "90" in cuda_archs():
-            cc_flag.append("-gencode")
-            cc_flag.append("arch=compute_90,code=sm_90")
-        if bare_metal_version >= Version("12.8") and "100" in cuda_archs():
-            cc_flag.append("-gencode")
-            cc_flag.append("arch=compute_100,code=sm_100")
-        if bare_metal_version >= Version("12.8") and "120" in cuda_archs():
-            cc_flag.append("-gencode")
-            cc_flag.append("arch=compute_120,code=sm_120")
+        # Build -gencode (regular + PTX + family-specific 'f' when available)
+        add_cuda_gencodes(cc_flag, set(cuda_archs()), bare_metal_version)
+    else:
+        # No nvcc present; warnings already emitted above
+        pass
 
     # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
     # torch._C._GLIBCXX_USE_CXX11_ABI