From c4a2da83efa46993990e5a744d172199ef150df0 Mon Sep 17 00:00:00 2001 From: 0xDELUXA Date: Fri, 17 Apr 2026 12:25:08 +0300 Subject: [PATCH 1/2] Initial RDNA Windows bring-up for CK FMHA --- setup.py | 159 ++++++++++++++++-- .../hip_fmha/ck_tiled_fmha_warp_tile_define.h | 12 ++ xformers/ops/differentiable_collectives.py | 12 +- xformers/ops/seqpar.py | 10 +- xformers/ops/sequence_parallel_fused_ops.py | 10 +- 5 files changed, 180 insertions(+), 23 deletions(-) diff --git a/setup.py b/setup.py index 6306cefea2..dae6f3bba4 100644 --- a/setup.py +++ b/setup.py @@ -144,19 +144,32 @@ def get_cuda_version(cuda_dir) -> int: def get_hip_version(rocm_dir) -> Optional[str]: - hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc") - try: - raw_output = subprocess.check_output( - [hipcc_bin, "--version"], universal_newlines=True - ) - except Exception as e: - print( - f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}" - ) + candidates: List[str] = [] + if rocm_dir is not None: + # Standard Linux layout + candidates.append(os.path.join(rocm_dir, "bin", "hipcc")) + # Some Windows ROCm distributions (e.g. TheRock) place hipcc directly + # under the venv Scripts dir; rocm_dir may already point there. + candidates.append(os.path.join(rocm_dir, "hipcc")) + # Fall back to PATH lookup + candidates.append("hipcc") + last_error: Optional[Exception] = None + for hipcc_bin in candidates: + try: + raw_output = subprocess.check_output( + [hipcc_bin, "--version"], universal_newlines=True + ) + except Exception as e: + last_error = e + continue + for line in raw_output.split("\n"): + if "HIP version" in line: + return line.split()[-1] return None - for line in raw_output.split("\n"): - if "HIP version" in line: - return line.split()[-1] + print( + f"hip installation not found: {last_error} " + f"ROCM_PATH={os.environ.get('ROCM_PATH')}" + ) return None @@ -394,18 +407,76 @@ def get_flash_attention3_extensions(cuda_version: int, extra_compile_args): def rename_cpp_cu(cpp_files): + # Only overwrite the .cu copy if the source has actually changed. shutil.copy + # always bumps mtime, which invalidates ninja's incremental cache and forces + # a full HIP rebuild every invocation. for entry in cpp_files: - shutil.copy(entry, os.path.splitext(entry)[0] + ".cu") + dst = os.path.splitext(entry)[0] + ".cu" + if os.path.exists(dst): + try: + with open(entry, "rb") as f1, open(dst, "rb") as f2: + if f1.read() == f2.read(): + continue + except OSError: + pass + shutil.copy(entry, dst) + + +def get_rocm_root() -> Optional[str]: + """Locate the ROCm SDK root (the dir containing bin/, lib/llvm/, etc.). + + Needed on Windows: hipcc forwards compile commands to clang but doesn't + pass `--rocm-path`, so clang fails to find the device library bitcode. + The TheRock SDK lives at venv/Lib/site-packages/_rocm_sdk_devel; we + discover it via the bundled `rocm-sdk` helper, then fall back to env + vars or hipcc's binary location. + """ + # 1) Use the rocm-sdk helper if installed (TheRock). + try: + out = subprocess.check_output( + ["rocm-sdk", "path", "--root"], universal_newlines=True + ).strip() + if out and os.path.isdir(out): + return out + except Exception: + pass + # 2) Standard env vars. + for key in ("ROCM_HOME", "ROCM_PATH", "HIP_PATH"): + val = os.environ.get(key) + if val and os.path.isdir(val) and os.path.isdir( + os.path.join(val, "lib", "llvm", "amdgcn", "bitcode") + ): + return val + # 3) Walk up from hipcc's location. + hipcc = shutil.which("hipcc") + if hipcc: + # hipcc usually lives at /bin/hipcc; walk up one level. + candidate = os.path.dirname(os.path.dirname(os.path.realpath(hipcc))) + if os.path.isdir( + os.path.join(candidate, "lib", "llvm", "amdgcn", "bitcode") + ): + return candidate + return None def get_rocm_agent_arch(): + # 1) Linux: rocm_agent_enumerator if present. exec_path = "/opt/rocm/bin/rocm_agent_enumerator" if os.path.isfile(exec_path) and os.access(exec_path, os.X_OK): arches = subprocess.check_output([exec_path], universal_newlines=True) - arch_list = arches.strip().split() - return arch_list[0] - else: - return "gfx942" + arch_list = [a for a in arches.strip().split() if a and a != "gfx000"] + if arch_list: + return arch_list[0] + # 2) Cross-platform: ask torch (works on Windows ROCm via TheRock). + try: + if torch.cuda.is_available() and torch.version.hip: + arch_name = torch.cuda.get_device_properties(0).gcnArchName + # gcnArchName looks like "gfx1200" or "gfx942:sramecc+:xnack-"; strip features. + return arch_name.split(":")[0] + except Exception as e: + print(f"torch-based GPU arch detection failed: {e}") + # 3) Last-resort fallback (preserves prior behaviour). + return "gfx942" def get_extensions(): @@ -580,11 +651,21 @@ def get_extensions(): ): rename_cpp_cu(source_hip) hip_version = get_hip_version(ROCM_HOME) + rocm_root = get_rocm_root() source_hip_cu = [] for ff in source_hip: source_hip_cu += [ff.replace(".cpp", ".cu")] + # Mirror the CUDA-side XFORMERS_SELECTIVE_BUILD filter above so ROCm + # users also get a substring knob for dev-time iteration. Like the + # CUDA version, this is sharp-edged: the pattern must be wide enough + # to keep every instance referenced by the dispatcher TUs that remain + # in the build, or the link will fail. + if "XFORMERS_SELECTIVE_BUILD" in os.environ: + pattern = os.environ["XFORMERS_SELECTIVE_BUILD"] + source_hip_cu = [f for f in source_hip_cu if pattern in str(f)] + extension = CUDAExtension sources += source_hip_cu include_dirs += [ @@ -608,26 +689,68 @@ def get_extensions(): if arch == "native": arch = get_rocm_agent_arch() - if arch not in ["gfx908", "gfx90a", "gfx942", "gfx950"]: + # CDNA archs use MFMA. RDNA archs (gfx11xx / gfx12xx) use WMMA — the + # xformers wrapper layer also needs FMHA_BUILD_ON_GFX11/GFX12 defines + # so the per-arch warp_tile / pipeline selection picks WMMA-friendly + # shapes. CK tile itself supports both. + cdna_archs = ["gfx908", "gfx90a", "gfx942", "gfx950"] + rdna3_archs = [ + "gfx1100", "gfx1101", "gfx1102", "gfx1103", + "gfx1150", "gfx1151", "gfx1152", "gfx1153", + ] + rdna4_archs = ["gfx1200", "gfx1201"] + + if arch not in cdna_archs + rdna3_archs + rdna4_archs: raise ValueError(f"Not supported AMD GPU arch: {arch}") if arch == "gfx950": cc_flag += ["-DFMHA_BUILD_ON_GFX950"] + elif arch in rdna3_archs: + cc_flag += ["-DFMHA_BUILD_ON_GFX11"] + elif arch in rdna4_archs: + cc_flag += ["-DFMHA_BUILD_ON_GFX12"] offload_compress_flag = [] if hip_version >= "6.2.": offload_compress_flag = ["--offload-compress"] + # MSVC's UCRT marks std::getenv as _CRT_INSECURE_DEPRECATE, which fires + # -Wdeprecated-declarations inside ck_tile/core/utility/env.hpp. That + # header is transitively included by virtually every HIP TU, so under + # -Werror the Windows build fails to link. Only silence this class of + # warning on Windows — Linux/CDNA keeps -Werror enforcement unchanged. + windows_warning_flags = ( + ["-Wno-deprecated-declarations"] + if platform.system() == "Windows" + else [] + ) + + rocm_path_flag: List[str] = [] + if rocm_root is not None: + rocm_path_flag.append(f"--rocm-path={rocm_root}") + # TheRock layout puts the device-library bitcode at + # /lib/llvm/amdgcn/bitcode rather than the upstream + # /amdgcn/bitcode that clang's --rocm-path discovery + # expects. Pass the explicit override so clang can resolve + # oclc_isa_version_*.bc et al. + device_lib = os.path.join( + rocm_root, "lib", "llvm", "amdgcn", "bitcode" + ) + if os.path.isdir(device_lib): + rocm_path_flag.append(f"--rocm-device-lib-path={device_lib}") + extra_compile_args["nvcc"] = [ "-O3", "-std=c++20", f"--offload-arch={arch}", + *rocm_path_flag, *offload_compress_flag, "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-DCK_TILE_FMHA_FWD_FAST_EXP2=1", "-fgpu-flush-denormals-to-zero", "-Werror", + *windows_warning_flags, "-Wno-c++11-narrowing", "-Woverloaded-virtual", "-Wno-unknown-warning-option", diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_warp_tile_define.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_warp_tile_define.h index 55aac7814f..4f5a8852b6 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_warp_tile_define.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_warp_tile_define.h @@ -8,6 +8,18 @@ #include +#if defined(FMHA_BUILD_ON_GFX11) || defined(FMHA_BUILD_ON_GFX12) +// RDNA WMMA only registers 16x16x16 in CK tile's warp_gemm dispatcher for +// fp16 / bf16 (and fp8 / int8). The xformers wrapper layer was authored for +// CDNA MFMA shapes (32x32x16 and 16x16x32). To keep the per-arch *_setting.h +// headers small, we alias all three names down to the WMMA-supported shape. +// Block tile geometry is unchanged — each warp simply issues more WMMA +// instructions to cover the same output tile. +using WarpTile_32x32x16 = ck_tile::sequence<16, 16, 16>; +using WarpTile_16x16x32 = ck_tile::sequence<16, 16, 16>; +using WarpTile_16x16x16 = ck_tile::sequence<16, 16, 16>; +#else using WarpTile_32x32x16 = ck_tile::sequence<32, 32, 16>; using WarpTile_16x16x32 = ck_tile::sequence<16, 16, 32>; using WarpTile_16x16x16 = ck_tile::sequence<16, 16, 16>; +#endif diff --git a/xformers/ops/differentiable_collectives.py b/xformers/ops/differentiable_collectives.py index 9d87629dd1..1930ffe0a9 100644 --- a/xformers/ops/differentiable_collectives.py +++ b/xformers/ops/differentiable_collectives.py @@ -4,11 +4,17 @@ # LICENSE file in the root directory of this source tree. -from typing import Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple import torch import torch.distributed +if TYPE_CHECKING: + # torch.distributed.Work is not exposed on every build (notably Windows + # ROCm / TheRock), so we only import it for type checkers. The runtime + # annotation is a forward-ref string that is never evaluated. + from torch.distributed import Work + def all_reduce( x: torch.Tensor, *, process_group: torch.distributed.ProcessGroup @@ -24,7 +30,7 @@ def all_reduce( def gather_along_first_dim_async( input_: torch.Tensor, *, process_group: torch.distributed.ProcessGroup -) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: +) -> Tuple[torch.Tensor, Optional["Work"]]: mp_size = process_group.size() if mp_size == 1: return input_, None @@ -42,7 +48,7 @@ def gather_along_first_dim_async( def reduce_scatter_along_first_dim_async( input_: torch.Tensor, *, process_group: torch.distributed.ProcessGroup -) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: +) -> Tuple[torch.Tensor, Optional["Work"]]: mp_size = process_group.size() if mp_size == 1: return input_, None diff --git a/xformers/ops/seqpar.py b/xformers/ops/seqpar.py index b81eaf3304..3114e7495b 100644 --- a/xformers/ops/seqpar.py +++ b/xformers/ops/seqpar.py @@ -7,7 +7,15 @@ from typing import Callable, List, Tuple import torch -from torch.distributed.distributed_c10d import _resolve_process_group + +# Some torch builds (e.g. Windows ROCm / TheRock) ship a stripped-down +# torch.distributed without distributed_c10d. Importing this module must not +# break xformers for single-GPU users; the fallback is only reachable if one +# of the sequence-parallel ops below is actually called. +try: + from torch.distributed.distributed_c10d import _resolve_process_group +except ImportError: + _resolve_process_group = None # type: ignore[assignment] from .differentiable_collectives import ( gather_along_first_dim, diff --git a/xformers/ops/sequence_parallel_fused_ops.py b/xformers/ops/sequence_parallel_fused_ops.py index 1f76667382..179125abd6 100644 --- a/xformers/ops/sequence_parallel_fused_ops.py +++ b/xformers/ops/sequence_parallel_fused_ops.py @@ -9,7 +9,15 @@ import torch import torch.distributed as dist import torch.multiprocessing.reductions -from torch.distributed._symmetric_memory import get_symm_mem_workspace + +# torch.distributed._symmetric_memory is absent from some torch builds +# (e.g. Windows ROCm / TheRock). Guard the import so xformers stays usable +# on single-GPU; the None fallback is only reached if the fused ops below +# are actually invoked. +try: + from torch.distributed._symmetric_memory import get_symm_mem_workspace +except ImportError: + get_symm_mem_workspace = None # type: ignore[assignment] OP_FINISHED_CHANNEL = 0 COMMS_READY_CHANNEL = 1 From 66881bc5a3b56d34ad8c0b05415d9421e13f87da Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Sat, 2 May 2026 16:11:07 +0900 Subject: [PATCH 2/2] Fix RDNA FMHA attention tests --- setup.py | 197 ++++++------------ .../hip_fmha/attention_ck_rand_uniform.cpp | 113 +++++++--- .../hip_fmha/ck_tiled_rand_uniform_kernel.h | 8 +- xformers/ops/differentiable_collectives.py | 13 +- xformers/ops/fmha/flash.py | 101 ++++++++- 5 files changed, 253 insertions(+), 179 deletions(-) diff --git a/setup.py b/setup.py index dae6f3bba4..6c935c9596 100644 --- a/setup.py +++ b/setup.py @@ -144,32 +144,19 @@ def get_cuda_version(cuda_dir) -> int: def get_hip_version(rocm_dir) -> Optional[str]: - candidates: List[str] = [] - if rocm_dir is not None: - # Standard Linux layout - candidates.append(os.path.join(rocm_dir, "bin", "hipcc")) - # Some Windows ROCm distributions (e.g. TheRock) place hipcc directly - # under the venv Scripts dir; rocm_dir may already point there. - candidates.append(os.path.join(rocm_dir, "hipcc")) - # Fall back to PATH lookup - candidates.append("hipcc") - last_error: Optional[Exception] = None - for hipcc_bin in candidates: - try: - raw_output = subprocess.check_output( - [hipcc_bin, "--version"], universal_newlines=True - ) - except Exception as e: - last_error = e - continue - for line in raw_output.split("\n"): - if "HIP version" in line: - return line.split()[-1] + hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc") + try: + raw_output = subprocess.check_output( + [hipcc_bin, "--version"], universal_newlines=True + ) + except Exception as e: + print( + f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}" + ) return None - print( - f"hip installation not found: {last_error} " - f"ROCM_PATH={os.environ.get('ROCM_PATH')}" - ) + for line in raw_output.split("\n"): + if "HIP version" in line: + return line.split()[-1] return None @@ -407,76 +394,20 @@ def get_flash_attention3_extensions(cuda_version: int, extra_compile_args): def rename_cpp_cu(cpp_files): - # Only overwrite the .cu copy if the source has actually changed. shutil.copy - # always bumps mtime, which invalidates ninja's incremental cache and forces - # a full HIP rebuild every invocation. for entry in cpp_files: - dst = os.path.splitext(entry)[0] + ".cu" - if os.path.exists(dst): - try: - with open(entry, "rb") as f1, open(dst, "rb") as f2: - if f1.read() == f2.read(): - continue - except OSError: - pass - shutil.copy(entry, dst) - - -def get_rocm_root() -> Optional[str]: - """Locate the ROCm SDK root (the dir containing bin/, lib/llvm/, etc.). - - Needed on Windows: hipcc forwards compile commands to clang but doesn't - pass `--rocm-path`, so clang fails to find the device library bitcode. - The TheRock SDK lives at venv/Lib/site-packages/_rocm_sdk_devel; we - discover it via the bundled `rocm-sdk` helper, then fall back to env - vars or hipcc's binary location. - """ - # 1) Use the rocm-sdk helper if installed (TheRock). - try: - out = subprocess.check_output( - ["rocm-sdk", "path", "--root"], universal_newlines=True - ).strip() - if out and os.path.isdir(out): - return out - except Exception: - pass - # 2) Standard env vars. - for key in ("ROCM_HOME", "ROCM_PATH", "HIP_PATH"): - val = os.environ.get(key) - if val and os.path.isdir(val) and os.path.isdir( - os.path.join(val, "lib", "llvm", "amdgcn", "bitcode") - ): - return val - # 3) Walk up from hipcc's location. - hipcc = shutil.which("hipcc") - if hipcc: - # hipcc usually lives at /bin/hipcc; walk up one level. - candidate = os.path.dirname(os.path.dirname(os.path.realpath(hipcc))) - if os.path.isdir( - os.path.join(candidate, "lib", "llvm", "amdgcn", "bitcode") - ): - return candidate - return None + shutil.copy(entry, os.path.splitext(entry)[0] + ".cu") def get_rocm_agent_arch(): - # 1) Linux: rocm_agent_enumerator if present. exec_path = "/opt/rocm/bin/rocm_agent_enumerator" if os.path.isfile(exec_path) and os.access(exec_path, os.X_OK): arches = subprocess.check_output([exec_path], universal_newlines=True) - arch_list = [a for a in arches.strip().split() if a and a != "gfx000"] - if arch_list: - return arch_list[0] - # 2) Cross-platform: ask torch (works on Windows ROCm via TheRock). - try: - if torch.cuda.is_available() and torch.version.hip: - arch_name = torch.cuda.get_device_properties(0).gcnArchName - # gcnArchName looks like "gfx1200" or "gfx942:sramecc+:xnack-"; strip features. - return arch_name.split(":")[0] - except Exception as e: - print(f"torch-based GPU arch detection failed: {e}") - # 3) Last-resort fallback (preserves prior behaviour). - return "gfx942" + arch_list = arches.strip().split() + return arch_list[0] + elif torch.cuda.is_available() and torch.version.hip: + return torch.cuda.get_device_properties(0).gcnArchName.split(":")[0] + else: + return "gfx942" def get_extensions(): @@ -650,22 +581,14 @@ def get_extensions(): and (torch.cuda.is_available() or os.getenv("HIP_ARCHITECTURE", "") != "") ): rename_cpp_cu(source_hip) - hip_version = get_hip_version(ROCM_HOME) - rocm_root = get_rocm_root() + hip_version = get_hip_version( + None if platform.system() == "Windows" else ROCM_HOME + ) source_hip_cu = [] for ff in source_hip: source_hip_cu += [ff.replace(".cpp", ".cu")] - # Mirror the CUDA-side XFORMERS_SELECTIVE_BUILD filter above so ROCm - # users also get a substring knob for dev-time iteration. Like the - # CUDA version, this is sharp-edged: the pattern must be wide enough - # to keep every instance referenced by the dispatcher TUs that remain - # in the build, or the link will fail. - if "XFORMERS_SELECTIVE_BUILD" in os.environ: - pattern = os.environ["XFORMERS_SELECTIVE_BUILD"] - source_hip_cu = [f for f in source_hip_cu if pattern in str(f)] - extension = CUDAExtension sources += source_hip_cu include_dirs += [ @@ -689,68 +612,40 @@ def get_extensions(): if arch == "native": arch = get_rocm_agent_arch() - # CDNA archs use MFMA. RDNA archs (gfx11xx / gfx12xx) use WMMA — the - # xformers wrapper layer also needs FMHA_BUILD_ON_GFX11/GFX12 defines - # so the per-arch warp_tile / pipeline selection picks WMMA-friendly - # shapes. CK tile itself supports both. - cdna_archs = ["gfx908", "gfx90a", "gfx942", "gfx950"] - rdna3_archs = [ - "gfx1100", "gfx1101", "gfx1102", "gfx1103", - "gfx1150", "gfx1151", "gfx1152", "gfx1153", - ] - rdna4_archs = ["gfx1200", "gfx1201"] - - if arch not in cdna_archs + rdna3_archs + rdna4_archs: + if ( + arch not in ["gfx908", "gfx90a", "gfx942", "gfx950"] + and not arch.startswith(("gfx11", "gfx12")) + ): raise ValueError(f"Not supported AMD GPU arch: {arch}") if arch == "gfx950": cc_flag += ["-DFMHA_BUILD_ON_GFX950"] - elif arch in rdna3_archs: + elif arch.startswith("gfx11"): cc_flag += ["-DFMHA_BUILD_ON_GFX11"] - elif arch in rdna4_archs: + elif arch.startswith("gfx12"): cc_flag += ["-DFMHA_BUILD_ON_GFX12"] offload_compress_flag = [] if hip_version >= "6.2.": offload_compress_flag = ["--offload-compress"] - # MSVC's UCRT marks std::getenv as _CRT_INSECURE_DEPRECATE, which fires - # -Wdeprecated-declarations inside ck_tile/core/utility/env.hpp. That - # header is transitively included by virtually every HIP TU, so under - # -Werror the Windows build fails to link. Only silence this class of - # warning on Windows — Linux/CDNA keeps -Werror enforcement unchanged. - windows_warning_flags = ( - ["-Wno-deprecated-declarations"] - if platform.system() == "Windows" - else [] - ) - - rocm_path_flag: List[str] = [] - if rocm_root is not None: - rocm_path_flag.append(f"--rocm-path={rocm_root}") - # TheRock layout puts the device-library bitcode at - # /lib/llvm/amdgcn/bitcode rather than the upstream - # /amdgcn/bitcode that clang's --rocm-path discovery - # expects. Pass the explicit override so clang can resolve - # oclc_isa_version_*.bc et al. - device_lib = os.path.join( - rocm_root, "lib", "llvm", "amdgcn", "bitcode" - ) - if os.path.isdir(device_lib): - rocm_path_flag.append(f"--rocm-device-lib-path={device_lib}") + if platform.system() == "Windows": + cc_flag += [ + "-Wno-deprecated-declarations", + "-Wno-unused-command-line-argument", + "-Wno-unknown-attributes", + ] extra_compile_args["nvcc"] = [ "-O3", "-std=c++20", f"--offload-arch={arch}", - *rocm_path_flag, *offload_compress_flag, "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-DCK_TILE_FMHA_FWD_FAST_EXP2=1", "-fgpu-flush-denormals-to-zero", "-Werror", - *windows_warning_flags, "-Wno-c++11-narrowing", "-Woverloaded-virtual", "-Wno-unknown-warning-option", @@ -820,7 +715,31 @@ def __init__(self, *args, **kwargs) -> None: self.pkg_name = "xformers" super().__init__(*args, **kwargs) + def _use_windows_link_response_files(self) -> None: + if platform.system() != "Windows": + return + + original_spawn = self.compiler.spawn + + def spawn_with_response_file(cmd): + tool = os.path.basename(str(cmd[0])).lower() if cmd else "" + command_length = sum(len(str(arg)) + 1 for arg in cmd) + if tool not in {"link.exe", "lld-link.exe"} or command_length < 30_000: + return original_spawn(cmd) + + rsp_path = Path(self.build_temp, f"xformers_link_{id(cmd)}.rsp").resolve() + rsp_path.parent.mkdir(parents=True, exist_ok=True) + rsp_path.write_text( + "\n".join(subprocess.list2cmdline([str(arg)]) for arg in cmd[1:]) + + "\n" + ) + + return original_spawn([cmd[0], f"@{rsp_path}"]) + + self.compiler.spawn = spawn_with_response_file + def build_extensions(self) -> None: + self._use_windows_link_response_files() super().build_extensions() # Fix incorrect output names caused by py_limited_api=True on Windows. see item #1272 diff --git a/xformers/csrc/attention/hip_fmha/attention_ck_rand_uniform.cpp b/xformers/csrc/attention/hip_fmha/attention_ck_rand_uniform.cpp index 08d9ae13a3..f07ab44cd2 100644 --- a/xformers/csrc/attention/hip_fmha/attention_ck_rand_uniform.cpp +++ b/xformers/csrc/attention/hip_fmha/attention_ck_rand_uniform.cpp @@ -13,13 +13,63 @@ #include #include -#include -#include +#include -#include "ck_tiled_rand_uniform_kernel.h" +#include namespace { +__global__ void rand_uniform_int_kernel( + uint8_t* randvals, + int M, + int N, + int num_heads, + int64_t stride_m, + int64_t stride_n, + int64_t stride_head, + int64_t stride_batch, + uint64_t philox_seed, + uint64_t philox_offset) { + constexpr int kPhiloxPerTile = 64; + constexpr int kWarpGemmMN = 32; + + const int row = blockIdx.x; + const int col = blockIdx.y; + const int batch_head = blockIdx.z; + const int i_batch = batch_head / num_heads; + const int i_head = batch_head - i_batch * num_heads; + const int lane = threadIdx.x; + + if (lane >= kPhiloxPerTile) { + return; + } + + const auto subsequence = + ck_tile::bit_cast(make_uint2(row, col)); + ck_tile::philox ph( + philox_seed, + philox_offset + (i_batch * num_heads + i_head) * kPhiloxPerTile + lane); + + uint8_t random_uint8_t[16]; + ph.get_random_16x8(random_uint8_t, subsequence); + + uint8_t* out = randvals + + static_cast(i_batch) * stride_batch + + static_cast(i_head) * stride_head; + + for (int r = 0; r < 16; ++r) { + const int i = (16 * (r / 8) % kWarpGemmMN) + 8 * (lane / 32) + (r % 8); + const int j = lane % kWarpGemmMN; + const int m = row * kWarpGemmMN + i; + const int n = col * kWarpGemmMN + j; + + if (m < M && n < N) { + out[static_cast(m) * stride_m + + static_cast(n) * stride_n] = random_uint8_t[r]; + } + } +} + /** * generate a tensor with random uniform values. only used for testing, not much * attention is paid to performance @@ -28,6 +78,8 @@ at::Tensor rand_uniform_int( double dropout_prob, const at::Tensor& out_pattern) // [Batches, num_head, query_len, key_len] { + (void)dropout_prob; + int B = out_pattern.size(0); int num_heads = out_pattern.size(1); int M = out_pattern.size(2); @@ -56,34 +108,43 @@ at::Tensor rand_uniform_int( randvals = at::empty( {B, num_heads, M, N}, out_pattern.options().dtype(at::ScalarType::Byte)); - { - // only work for batched mode - using FmhaRandUniformKernel_ = FmhaRandUniformKernel; - - const auto kargs = FmhaRandUniformKernel_::MakeKargs( - randvals.data_ptr(), + if (B > 0 && num_heads > 0 && M > 0 && N > 0) { + constexpr int kWarpGemmMN = 32; + const dim3 grid( + (M + kWarpGemmMN - 1) / kWarpGemmMN, + (N + kWarpGemmMN - 1) / kWarpGemmMN, + B * num_heads); + const dim3 block(64); + + hipLaunchKernelGGL( + rand_uniform_int_kernel, + grid, + block, + 0, + stream, + static_cast(randvals.data_ptr()), M, N, num_heads, - B, - static_cast(randvals.stride(2)), - static_cast(randvals.stride(3)), - static_cast(randvals.stride(1)), - static_cast(randvals.stride(0)), - {philox_seed, philox_offset}); - - dim3 kGridSize = FmhaRandUniformKernel_::GridSize(B, num_heads, M, N); - dim3 kBlockSize = FmhaRandUniformKernel_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = - FmhaRandUniformKernel_::kBlockPerCu; - - (void)ck_tile::launch_kernel( - ck_tile::stream_config{stream, false}, - ck_tile::make_kernel( - FmhaRandUniformKernel_{}, kGridSize, kBlockSize, 0, kargs)); + randvals.stride(2), + randvals.stride(3), + randvals.stride(1), + randvals.stride(0), + static_cast(philox_seed), + static_cast(philox_offset)); + + const auto launch_status = hipGetLastError(); + TORCH_CHECK( + launch_status == hipSuccess, + "HIP rand_uniform_int_kernel launch failed: ", + hipGetErrorString(launch_status)); } - (void)hipStreamSynchronize(stream); + const auto sync_status = hipStreamSynchronize(stream); + TORCH_CHECK( + sync_status == hipSuccess, + "HIP rand_uniform_int_kernel failed: ", + hipGetErrorString(sync_status)); return randvals; } // namespace diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h b/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h index 801960a432..9a21459f7a 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h @@ -171,8 +171,12 @@ struct FmhaRandUniformKernel { return ck_tile::make_tuple(i_block, i_nhead, i_batch); } - __host__ static constexpr auto BlockSize() { - return dim3(kBlockSize); + __host__ static dim3 BlockSize() { + if (ck_tile::is_wave32()) { + return dim3(kBlockSize / 2); + } else { + return dim3(kBlockSize); + } } __device__ static constexpr ck_tile::index_t GetSmemSize() { diff --git a/xformers/ops/differentiable_collectives.py b/xformers/ops/differentiable_collectives.py index 1930ffe0a9..e4a49cecb9 100644 --- a/xformers/ops/differentiable_collectives.py +++ b/xformers/ops/differentiable_collectives.py @@ -3,18 +3,13 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Tuple +from typing import Optional, Tuple import torch import torch.distributed -if TYPE_CHECKING: - # torch.distributed.Work is not exposed on every build (notably Windows - # ROCm / TheRock), so we only import it for type checkers. The runtime - # annotation is a forward-ref string that is never evaluated. - from torch.distributed import Work - def all_reduce( x: torch.Tensor, *, process_group: torch.distributed.ProcessGroup @@ -30,7 +25,7 @@ def all_reduce( def gather_along_first_dim_async( input_: torch.Tensor, *, process_group: torch.distributed.ProcessGroup -) -> Tuple[torch.Tensor, Optional["Work"]]: +) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: mp_size = process_group.size() if mp_size == 1: return input_, None @@ -48,7 +43,7 @@ def gather_along_first_dim_async( def reduce_scatter_along_first_dim_async( input_: torch.Tensor, *, process_group: torch.distributed.ProcessGroup -) -> Tuple[torch.Tensor, Optional["Work"]]: +) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: mp_size = process_group.size() if mp_size == 1: return input_, None diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py index 1201b3360c..8b6701e683 100644 --- a/xformers/ops/fmha/flash.py +++ b/xformers/ops/fmha/flash.py @@ -73,7 +73,7 @@ FLASH_VERSION = flash_attn.__version__ FLASH_VER_MIN = (2, 7, 1) - FLASH_VER_LAST = (2, 8, 0) # last supported, inclusive + FLASH_VER_LAST = (2, 8, 4) # last supported, inclusive flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3]) if ( flash_ver_parsed < FLASH_VER_MIN or flash_ver_parsed > FLASH_VER_LAST @@ -358,6 +358,68 @@ def _create_dq_dk_dv( return torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) +def _pack_kv_for_seqused_k_flash( + attn_bias: AttentionBias, + key: torch.Tensor, + value: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + key_chunks: List[torch.Tensor] = [] + value_chunks: List[torch.Tensor] = [] + cu_seqlens_k = [0] + + if isinstance( + attn_bias, (PagedBlockDiagonalGappyKeysMask, PagedBlockDiagonalPaddedKeysMask) + ): + assert key.ndim == 4 + assert value.ndim == 4 + for row_idx, seqlen in enumerate(attn_bias.k_seqinfo.seqlen_py): + page_indices = attn_bias.block_tables[row_idx].to( + device=key.device, dtype=torch.long + ) + row_key = key.index_select(0, page_indices).reshape( + [-1, *key.shape[2:]] + ) + row_value = value.index_select(0, page_indices).reshape( + [-1, *value.shape[2:]] + ) + if isinstance(attn_bias, PagedBlockDiagonalGappyKeysMask): + start = attn_bias.k_seqinfo.seqstart_py[row_idx] + end = seqlen + else: + start = 0 + end = seqlen + key_chunks.append(row_key[start:end]) + value_chunks.append(row_value[start:end]) + cu_seqlens_k.append(cu_seqlens_k[-1] + end - start) + else: + assert key.ndim == 3 + assert value.ndim == 3 + for start, end in attn_bias.k_seqinfo.intervals(): + key_chunks.append(key[start:end]) + value_chunks.append(value[start:end]) + cu_seqlens_k.append(cu_seqlens_k[-1] + end - start) + + if key_chunks: + key = torch.cat(key_chunks, dim=0) + value = torch.cat(value_chunks, dim=0) + else: + key = key.new_empty([0, *key.shape[-2:]]) + value = value.new_empty([0, *value.shape[-2:]]) + + cu_seqlens_k_tensor = torch.tensor( + cu_seqlens_k, dtype=torch.int32, device=key.device + ) + max_seqlen_k = max( + (end - start for start, end in zip(cu_seqlens_k[:-1], cu_seqlens_k[1:])), + default=0, + ) + return key, value, cu_seqlens_k_tensor, max_seqlen_k + + +def _is_rocm_device(device: torch.device) -> bool: + return torch.version.hip is not None and device.type == "cuda" + + def _convert_input_format( inp: Inputs, supports_mqa: bool, @@ -446,6 +508,23 @@ def fold(x): num_pages = value.shape[0] // attn_bias.page_size key = key.view(num_pages, attn_bias.page_size, *key.shape[1:]) value = value.view(num_pages, attn_bias.page_size, *value.shape[1:]) + if _is_rocm_device(query.device) and isinstance( + attn_bias, + ( + BlockDiagonalGappyKeysMask, + BlockDiagonalPaddedKeysMask, + PagedBlockDiagonalGappyKeysMask, + PagedBlockDiagonalPaddedKeysMask, + ), + ): + # The ROCm CK flash-attn varlen path accepts seqused_k but does not + # consume it in the CK host wrapper, so padded/gappy K/V layouts can + # be read as real tokens. Compacting logical K/V sequences lets us + # use the ordinary cu_seqlens path. + key, value, cu_seqlen_k, max_seqlen_k = _pack_kv_for_seqused_k_flash( + attn_bias, key, value + ) + seqused_k = None new_inp = Inputs( query=query, @@ -604,6 +683,13 @@ class FwOp(AttentionFwOpBase): CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16} SUPPORTED_MAX_K = 256 + # ROCm flash-attn bf16 forward (observed with 2.8.4 on gfx1201) can + # exceed the default absolute tolerance while staying within relative + # tolerance, so keep the relaxed threshold scoped to HIP bf16 only. + ERROR_ATOL = { + **AttentionFwOpBase.ERROR_ATOL, + **({torch.bfloat16: 4e-2} if torch.version.hip is not None else {}), + } SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = ( type(None), LowerTriangularMask, @@ -683,6 +769,7 @@ def apply( block_tables = ( inp.attn_bias.block_tables if isinstance(inp.attn_bias, PagedBlockDiagonalPaddedKeysMask) + and inp.key.ndim == 4 else None ) out, softmax_lse, rng_state = cls.OPERATOR( @@ -731,9 +818,9 @@ def apply( out=out, lse=_post_process_lse(softmax_lse, inp, original_query_shape), ) + ctx.rng_state = rng_state if inp.p != 0.0: ctx.op_bw = BwOp - ctx.rng_state = rng_state return (out, ctx) @@ -817,6 +904,14 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: ] assert grad.dtype in cls.SUPPORTED_DTYPES + rng_state = ctx.rng_state + if rng_state is None: + if inp.p != 0.0: + raise RuntimeError( + "Flash-Attention backward requires an RNG state when dropout is enabled" + ) + rng_state = torch.zeros((2,), dtype=torch.int64, device=inp.query.device) + if inp.query.numel() and inp.key.numel(): win_left, win_right = _window_size(inp.attn_bias) grads = Gradients( @@ -837,7 +932,7 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: _is_causal(inp.attn_bias), window_left=win_left, window_right=win_right, - rng_state=ctx.rng_state if inp.p > 0.0 else None, + rng_state=rng_state, ) ) else: