From f91ff7e394584268fa276426da4e37c07b0e13ea Mon Sep 17 00:00:00 2001 From: Johnny Date: Fri, 12 Sep 2025 12:11:31 +0200 Subject: [PATCH 1/5] fix typo --- .github/workflows/publish.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index e88090f336d..26013ad5d67 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -51,7 +51,7 @@ jobs: cxx11_abi: ["FALSE", "TRUE"] include: - torch-version: "2.9.0.dev20250904" - cuda-version: "13.0" + cuda-version: "13.0.0" exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix # Pytorch < 2.5 does not support Python 3.13 From d0746bf30d29d188e2c5b470c76aa5ae77696db7 Mon Sep 17 00:00:00 2001 From: Johnny Date: Fri, 12 Sep 2025 12:15:13 +0200 Subject: [PATCH 2/5] Update setup.py --- setup.py | 63 ++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 15 deletions(-) diff --git a/setup.py b/setup.py index a108c412c00..c9148762c2c 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,7 @@ @functools.lru_cache(maxsize=None) def cuda_archs() -> str: - return os.getenv("FLASH_ATTN_CUDA_ARCHS", "80;90;100;120").split(";") + return os.getenv("FLASH_ATTN_CUDA_ARCHS", "80;90;100;110;120").split(";") def get_platform(): @@ -94,6 +94,48 @@ def get_cuda_bare_metal_version(cuda_dir): return raw_output, bare_metal_version +def add_cuda_gencodes(cc_flag, archs, bare_metal_version): + """ + Fills cc_flag with: + - regular 'sm_XX' targets + - PTX of the newest arch for forward-compat + - family-specific 'f' (100f/110f/120f) if CUDA >= 12.9 + """ + # Regular targets + if "80" in archs: + cc_flag += ["-gencode", "arch=compute_80,code=sm_80"] + + if bare_metal_version >= Version("11.8") and "90" in archs: + cc_flag += ["-gencode", "arch=compute_90,code=sm_90"] + + if bare_metal_version >= Version("12.8") and "100" in archs: + cc_flag += ["-gencode", "arch=compute_100,code=sm_100"] + + if bare_metal_version >= Version("12.8") and "110" in archs: + cc_flag += ["-gencode", "arch=compute_110,code=sm_110"] + + if bare_metal_version >= Version("12.8") and "120" in archs: + cc_flag += ["-gencode", "arch=compute_120,code=sm_120"] + + # PTX for newest arch (forward-compat) + numeric_archs = [a for a in archs if a.isdigit()] + if numeric_archs: + newest = max(numeric_archs, key=int) + cc_flag += ["-gencode", f"arch=compute_{newest},code=compute_{newest}"] + + # Family-specific (CUDA >= 12.9): 100f/110f/120f + if bare_metal_version >= Version("12.9"): + if "100" in archs: + # code=sm_100 and code=sm_100f are aliases for the 100f family + cc_flag += ["-gencode", "arch=compute_100f,code=sm_100"] + if "110" in archs: + cc_flag += ["-gencode", "arch=compute_110f,code=sm_110"] + if "120" in archs: + cc_flag += ["-gencode", "arch=compute_120f,code=sm_120"] + + return cc_flag + + def get_hip_version(): return parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+')) @@ -175,20 +217,11 @@ def validate_and_update_archs(archs): "FlashAttention is only supported on CUDA 11.7 and above. " "Note: make sure nvcc has a supported version by running nvcc -V." ) - - if "80" in cuda_archs(): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_80,code=sm_80") - if CUDA_HOME is not None: - if bare_metal_version >= Version("11.8") and "90" in cuda_archs(): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - if bare_metal_version >= Version("12.8") and "100" in cuda_archs(): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_100,code=sm_100") - if bare_metal_version >= Version("12.8") and "120" in cuda_archs(): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_120,code=sm_120") + # Build -gencode (regular + PTX + family-specific 'f' when available) + add_cuda_gencodes(cc_flag, set(cuda_archs()), bare_metal_version) + else: + # No nvcc present; warnings already emitted above + pass # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as # torch._C._GLIBCXX_USE_CXX11_ABI From e0260ea55baa5f9847430906ae1c2c32a53dc264 Mon Sep 17 00:00:00 2001 From: Johnny Date: Fri, 12 Sep 2025 12:20:22 +0200 Subject: [PATCH 3/5] Update setup.py --- setup.py | 38 ++++++++++++++------------------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/setup.py b/setup.py index c9148762c2c..17707755b8b 100644 --- a/setup.py +++ b/setup.py @@ -96,43 +96,33 @@ def get_cuda_bare_metal_version(cuda_dir): def add_cuda_gencodes(cc_flag, archs, bare_metal_version): """ - Fills cc_flag with: - - regular 'sm_XX' targets - - PTX of the newest arch for forward-compat - - family-specific 'f' (100f/110f/120f) if CUDA >= 12.9 + Adds -gencode flags: + - Regular sm_XX targets (80/90 always regular) + - For 100/110/120: family-specific 'f' if CUDA >= 12.9, else regular + - PTX for the newest arch for forward-compat """ - # Regular targets + # Always regular 80 + 90 if "80" in archs: cc_flag += ["-gencode", "arch=compute_80,code=sm_80"] if bare_metal_version >= Version("11.8") and "90" in archs: cc_flag += ["-gencode", "arch=compute_90,code=sm_90"] - if bare_metal_version >= Version("12.8") and "100" in archs: - cc_flag += ["-gencode", "arch=compute_100,code=sm_100"] + # 100/110/120 → choose family-specific if supported + if bare_metal_version >= Version("12.8"): + for a in ("100", "110", "120"): + if a in archs: + if bare_metal_version >= Version("12.9"): + cc_flag += ["-gencode", f"arch=compute_{a}f,code=sm_{a}"] + else: + cc_flag += ["-gencode", f"arch=compute_{a},code=sm_{a}"] - if bare_metal_version >= Version("12.8") and "110" in archs: - cc_flag += ["-gencode", "arch=compute_110,code=sm_110"] - - if bare_metal_version >= Version("12.8") and "120" in archs: - cc_flag += ["-gencode", "arch=compute_120,code=sm_120"] - - # PTX for newest arch (forward-compat) + # Add PTX of newest arch for forward-compat numeric_archs = [a for a in archs if a.isdigit()] if numeric_archs: newest = max(numeric_archs, key=int) cc_flag += ["-gencode", f"arch=compute_{newest},code=compute_{newest}"] - # Family-specific (CUDA >= 12.9): 100f/110f/120f - if bare_metal_version >= Version("12.9"): - if "100" in archs: - # code=sm_100 and code=sm_100f are aliases for the 100f family - cc_flag += ["-gencode", "arch=compute_100f,code=sm_100"] - if "110" in archs: - cc_flag += ["-gencode", "arch=compute_110f,code=sm_110"] - if "120" in archs: - cc_flag += ["-gencode", "arch=compute_120f,code=sm_120"] - return cc_flag From 82c9a25e9caee098de1d77403903a4c4b2e67736 Mon Sep 17 00:00:00 2001 From: Johnny Date: Fri, 12 Sep 2025 12:23:23 +0200 Subject: [PATCH 4/5] Update setup.py --- setup.py | 49 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/setup.py b/setup.py index 17707755b8b..7954daed679 100644 --- a/setup.py +++ b/setup.py @@ -96,31 +96,48 @@ def get_cuda_bare_metal_version(cuda_dir): def add_cuda_gencodes(cc_flag, archs, bare_metal_version): """ - Adds -gencode flags: - - Regular sm_XX targets (80/90 always regular) - - For 100/110/120: family-specific 'f' if CUDA >= 12.9, else regular - - PTX for the newest arch for forward-compat + Adds -gencode flags based on nvcc capabilities: + - sm_80/90 (regular) + - sm_100/120 on CUDA >= 12.8 + - Use 100f on CUDA >= 12.9 (Blackwell family-specific) + - Map requested 110 -> 101 if CUDA < 13.0 (Thor rename) + - Embed PTX for newest arch for forward compatibility """ - # Always regular 80 + 90 + # Always-regular 80 if "80" in archs: cc_flag += ["-gencode", "arch=compute_80,code=sm_80"] + # Hopper 9.0 needs >= 11.8 if bare_metal_version >= Version("11.8") and "90" in archs: cc_flag += ["-gencode", "arch=compute_90,code=sm_90"] - # 100/110/120 → choose family-specific if supported + # Blackwell 10.x requires >= 12.8 if bare_metal_version >= Version("12.8"): - for a in ("100", "110", "120"): - if a in archs: + if "100" in archs: + # CUDA 12.9 introduced "family-specific" for Blackwell (100f) + if bare_metal_version >= Version("12.9"): + cc_flag += ["-gencode", "arch=compute_100f,code=sm_100"] + else: + cc_flag += ["-gencode", "arch=compute_100,code=sm_100"] + + if "120" in archs: + # sm_120 is supported in CUDA 12.8/12.9+ toolkits + cc_flag += ["-gencode", "arch=compute_120,code=sm_120"] + + # Thor rename: 12.9 uses sm_101; 13.0+ uses sm_110 + if "110" in archs: + if bare_metal_version >= Version("13.0"): + cc_flag += ["-gencode", "arch=compute_110,code=sm_110"] + else: + # Provide Thor support for CUDA 12.9 via sm_101 if bare_metal_version >= Version("12.9"): - cc_flag += ["-gencode", f"arch=compute_{a}f,code=sm_{a}"] - else: - cc_flag += ["-gencode", f"arch=compute_{a},code=sm_{a}"] - - # Add PTX of newest arch for forward-compat - numeric_archs = [a for a in archs if a.isdigit()] - if numeric_archs: - newest = max(numeric_archs, key=int) + cc_flag += ["-gencode", "arch=compute_101,code=sm_101"] + # else: no Thor support in older toolkits + + # PTX for newest requested arch (forward-compat) + numeric = [a for a in archs if a.isdigit()] + if numeric: + newest = max(numeric, key=int) cc_flag += ["-gencode", f"arch=compute_{newest},code=compute_{newest}"] return cc_flag From 5e68ff080697dc9ed4637ff2457e6505e638004d Mon Sep 17 00:00:00 2001 From: Johnny Date: Fri, 12 Sep 2025 12:26:06 +0200 Subject: [PATCH 5/5] Update setup.py --- setup.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 7954daed679..9a406839e7f 100644 --- a/setup.py +++ b/setup.py @@ -122,15 +122,19 @@ def add_cuda_gencodes(cc_flag, archs, bare_metal_version): if "120" in archs: # sm_120 is supported in CUDA 12.8/12.9+ toolkits - cc_flag += ["-gencode", "arch=compute_120,code=sm_120"] + if bare_metal_version >= Version("12.9"): + cc_flag += ["-gencode", "arch=compute_120f,code=sm_120"] + else: + cc_flag += ["-gencode", "arch=compute_120,code=sm_120"] + # Thor rename: 12.9 uses sm_101; 13.0+ uses sm_110 if "110" in archs: if bare_metal_version >= Version("13.0"): - cc_flag += ["-gencode", "arch=compute_110,code=sm_110"] + cc_flag += ["-gencode", "arch=compute_110f,code=sm_110"] else: # Provide Thor support for CUDA 12.9 via sm_101 - if bare_metal_version >= Version("12.9"): + if bare_metal_version >= Version("12.8"): cc_flag += ["-gencode", "arch=compute_101,code=sm_101"] # else: no Thor support in older toolkits