Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flash_attn/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
if "all_gather_into_tensor" not in dir(torch.distributed) and hasattr(torch.distributed, "_all_gather_base"):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
if "reduce_scatter_tensor" not in dir(torch.distributed) and hasattr(torch.distributed, "_reduce_scatter_base"):
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


Expand Down
21 changes: 19 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,27 @@ def validate_and_update_archs(archs):
# files included in the source distribution, in case the user compiles from source.
if IS_ROCM:
if ROCM_BACKEND == "triton":
if os.path.isdir(".git"):
if os.path.isdir("third_party/aiter"):
# If the submodule is already present (possibly with local patches applied),
# skip `git submodule update` to avoid overwriting local changes.
pass
elif os.path.isdir(".git"):
subprocess.run(["git", "submodule", "update", "--init", "third_party/aiter"], check=True)
else:
assert os.path.isdir("third_party/aiter"), (
"third_party/aiter is missing, please use source distribution or git clone"
)
aiter_install_env = os.environ.copy()
if sys.platform == "win32":
# CK (composable_kernel) is not available on Windows; disable it
# and skip pre-building HIP C++ kernels so only the pure-Triton
# flash-attention path is installed.
aiter_install_env["ENABLE_CK"] = "0"
aiter_install_env["PREBUILD_KERNELS"] = "0"
subprocess.run(
[sys.executable, "-m", "pip", "install", "--no-build-isolation", "third_party/aiter"],
check=True,
env=aiter_install_env,
)
elif ROCM_BACKEND == "ck":
if os.path.isdir(".git"):
Expand Down Expand Up @@ -615,8 +627,12 @@ def __init__(self, *args, **kwargs) -> None:
# Note: torch is excluded because pip resolves it to CUDA PyTorch from PyPI, overwriting any pre-installed ROCm PyTorch. Users must have torch installed.
install_requires = [
"einops",
"triton==3.5.1",
]
if sys.platform == "win32":
# triton-windows is the community port of Triton for Windows ROCm
install_requires.append("triton-windows>=3.2.0")
else:
install_requires.append("triton==3.5.1")
else:
install_requires = [
"torch",
Expand Down Expand Up @@ -650,6 +666,7 @@ def __init__(self, *args, **kwargs) -> None:
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",
"Operating System :: Unix",
"Operating System :: Microsoft :: Windows",
],
ext_modules=ext_modules,
cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": NinjaBuildExtension}
Expand Down