diff --git a/hopper/setup.py b/hopper/setup.py index 519d1c04f42..022899e7f37 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -67,6 +67,12 @@ DISABLE_HDIMDIFF64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF64", "FALSE") == "TRUE" DISABLE_HDIMDIFF192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF192", "FALSE") == "TRUE" +# Auto-disable the Python limited API (abi3) for free-threaded Python (3.13t+) +# Free-threaded Python does not support the limited API +# See: https://docs.python.org/3/howto/free-threading-extensions.html +FREE_THREADED = sysconfig.get_config_var("Py_GIL_DISABLED") +DISABLE_ABI3 = FREE_THREADED or os.getenv("FLASH_ATTENTION_DISABLE_ABI3", "FALSE") == "TRUE" + # HACK: we monkey patch pytorch's _write_ninja_file to pass # "-gencode arch=compute_sm90a,code=sm_90a" to files ending in '_sm90.cu', # and pass "-gencode arch=compute_sm80,code=sm_80" to files ending in '_sm80.cu' @@ -585,11 +591,11 @@ def nvcc_threads_args(): name=f"{PACKAGE_NAME}._C", sources=sources, extra_compile_args={ - "cxx": ["-O3", "-std=c++17", "-DPy_LIMITED_API=0x03090000"] + stable_args + feature_args, + "cxx": ["-O3", "-std=c++17"] + ([] if DISABLE_ABI3 else ["-DPy_LIMITED_API=0x03090000"]) + stable_args + feature_args, "nvcc": nvcc_threads_args() + nvcc_flags + cc_flag + feature_args, }, include_dirs=include_dirs, - py_limited_api=True, + py_limited_api=not DISABLE_ABI3, ) ) @@ -698,5 +704,5 @@ def run(self): "packaging", "ninja", ], - options={"bdist_wheel": {"py_limited_api": "cp39"}}, + options={} if DISABLE_ABI3 else {"bdist_wheel": {"py_limited_api": "cp39"}}, )