diff --git a/setup.py b/setup.py index 6306cefea2..6c935c9596 100644 --- a/setup.py +++ b/setup.py @@ -404,6 +404,8 @@ def get_rocm_agent_arch(): arches = subprocess.check_output([exec_path], universal_newlines=True) arch_list = arches.strip().split() return arch_list[0] + elif torch.cuda.is_available() and torch.version.hip: + return torch.cuda.get_device_properties(0).gcnArchName.split(":")[0] else: return "gfx942" @@ -579,7 +581,9 @@ def get_extensions(): and (torch.cuda.is_available() or os.getenv("HIP_ARCHITECTURE", "") != "") ): rename_cpp_cu(source_hip) - hip_version = get_hip_version(ROCM_HOME) + hip_version = get_hip_version( + None if platform.system() == "Windows" else ROCM_HOME + ) source_hip_cu = [] for ff in source_hip: @@ -608,16 +612,30 @@ def get_extensions(): if arch == "native": arch = get_rocm_agent_arch() - if arch not in ["gfx908", "gfx90a", "gfx942", "gfx950"]: + if ( + arch not in ["gfx908", "gfx90a", "gfx942", "gfx950"] + and not arch.startswith(("gfx11", "gfx12")) + ): raise ValueError(f"Not supported AMD GPU arch: {arch}") if arch == "gfx950": cc_flag += ["-DFMHA_BUILD_ON_GFX950"] + elif arch.startswith("gfx11"): + cc_flag += ["-DFMHA_BUILD_ON_GFX11"] + elif arch.startswith("gfx12"): + cc_flag += ["-DFMHA_BUILD_ON_GFX12"] offload_compress_flag = [] if hip_version >= "6.2.": offload_compress_flag = ["--offload-compress"] + if platform.system() == "Windows": + cc_flag += [ + "-Wno-deprecated-declarations", + "-Wno-unused-command-line-argument", + "-Wno-unknown-attributes", + ] + extra_compile_args["nvcc"] = [ "-O3", "-std=c++20", @@ -697,7 +715,31 @@ def __init__(self, *args, **kwargs) -> None: self.pkg_name = "xformers" super().__init__(*args, **kwargs) + def _use_windows_link_response_files(self) -> None: + if platform.system() != "Windows": + return + + original_spawn = self.compiler.spawn + + def spawn_with_response_file(cmd): + tool = os.path.basename(str(cmd[0])).lower() if cmd else "" + command_length = sum(len(str(arg)) + 1 for arg 
in cmd) + if tool not in {"link.exe", "lld-link.exe"} or command_length < 30_000: + return original_spawn(cmd) + + rsp_path = Path(self.build_temp, f"xformers_link_{id(cmd)}.rsp").resolve() + rsp_path.parent.mkdir(parents=True, exist_ok=True) + rsp_path.write_text( + "\n".join(subprocess.list2cmdline([str(arg)]) for arg in cmd[1:]) + + "\n" + ) + + return original_spawn([cmd[0], f"@{rsp_path}"]) + + self.compiler.spawn = spawn_with_response_file + def build_extensions(self) -> None: + self._use_windows_link_response_files() super().build_extensions() # Fix incorrect output names caused by py_limited_api=True on Windows. see item #1272 diff --git a/xformers/csrc/attention/hip_fmha/attention_ck_rand_uniform.cpp b/xformers/csrc/attention/hip_fmha/attention_ck_rand_uniform.cpp index 08d9ae13a3..f07ab44cd2 100644 --- a/xformers/csrc/attention/hip_fmha/attention_ck_rand_uniform.cpp +++ b/xformers/csrc/attention/hip_fmha/attention_ck_rand_uniform.cpp @@ -13,13 +13,63 @@ #include #include -#include -#include +#include -#include "ck_tiled_rand_uniform_kernel.h" +#include namespace { +__global__ void rand_uniform_int_kernel( + uint8_t* randvals, + int M, + int N, + int num_heads, + int64_t stride_m, + int64_t stride_n, + int64_t stride_head, + int64_t stride_batch, + uint64_t philox_seed, + uint64_t philox_offset) { + constexpr int kPhiloxPerTile = 64; + constexpr int kWarpGemmMN = 32; + + const int row = blockIdx.x; + const int col = blockIdx.y; + const int batch_head = blockIdx.z; + const int i_batch = batch_head / num_heads; + const int i_head = batch_head - i_batch * num_heads; + const int lane = threadIdx.x; + + if (lane >= kPhiloxPerTile) { + return; + } + + const auto subsequence = + ck_tile::bit_cast(make_uint2(row, col)); + ck_tile::philox ph( + philox_seed, + philox_offset + (i_batch * num_heads + i_head) * kPhiloxPerTile + lane); + + uint8_t random_uint8_t[16]; + ph.get_random_16x8(random_uint8_t, subsequence); + + uint8_t* out = randvals + + static_cast(i_batch) 
* stride_batch + + static_cast(i_head) * stride_head; + + for (int r = 0; r < 16; ++r) { + const int i = (16 * (r / 8) % kWarpGemmMN) + 8 * (lane / 32) + (r % 8); + const int j = lane % kWarpGemmMN; + const int m = row * kWarpGemmMN + i; + const int n = col * kWarpGemmMN + j; + + if (m < M && n < N) { + out[static_cast(m) * stride_m + + static_cast(n) * stride_n] = random_uint8_t[r]; + } + } +} + /** * generate a tensor with random uniform values. only used for testing, not much * attention is paid to performance @@ -28,6 +78,8 @@ at::Tensor rand_uniform_int( double dropout_prob, const at::Tensor& out_pattern) // [Batches, num_head, query_len, key_len] { + (void)dropout_prob; + int B = out_pattern.size(0); int num_heads = out_pattern.size(1); int M = out_pattern.size(2); @@ -56,34 +108,43 @@ at::Tensor rand_uniform_int( randvals = at::empty( {B, num_heads, M, N}, out_pattern.options().dtype(at::ScalarType::Byte)); - { - // only work for batched mode - using FmhaRandUniformKernel_ = FmhaRandUniformKernel; - - const auto kargs = FmhaRandUniformKernel_::MakeKargs( - randvals.data_ptr(), + if (B > 0 && num_heads > 0 && M > 0 && N > 0) { + constexpr int kWarpGemmMN = 32; + const dim3 grid( + (M + kWarpGemmMN - 1) / kWarpGemmMN, + (N + kWarpGemmMN - 1) / kWarpGemmMN, + B * num_heads); + const dim3 block(64); + + hipLaunchKernelGGL( + rand_uniform_int_kernel, + grid, + block, + 0, + stream, + static_cast(randvals.data_ptr()), M, N, num_heads, - B, - static_cast(randvals.stride(2)), - static_cast(randvals.stride(3)), - static_cast(randvals.stride(1)), - static_cast(randvals.stride(0)), - {philox_seed, philox_offset}); - - dim3 kGridSize = FmhaRandUniformKernel_::GridSize(B, num_heads, M, N); - dim3 kBlockSize = FmhaRandUniformKernel_::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = - FmhaRandUniformKernel_::kBlockPerCu; - - (void)ck_tile::launch_kernel( - ck_tile::stream_config{stream, false}, - ck_tile::make_kernel( - FmhaRandUniformKernel_{}, kGridSize, 
kBlockSize, 0, kargs)); + randvals.stride(2), + randvals.stride(3), + randvals.stride(1), + randvals.stride(0), + static_cast(philox_seed), + static_cast(philox_offset)); + + const auto launch_status = hipGetLastError(); + TORCH_CHECK( + launch_status == hipSuccess, + "HIP rand_uniform_int_kernel launch failed: ", + hipGetErrorString(launch_status)); } - (void)hipStreamSynchronize(stream); + const auto sync_status = hipStreamSynchronize(stream); + TORCH_CHECK( + sync_status == hipSuccess, + "HIP rand_uniform_int_kernel failed: ", + hipGetErrorString(sync_status)); return randvals; } // namespace diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_warp_tile_define.h b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_warp_tile_define.h index 55aac7814f..4f5a8852b6 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_warp_tile_define.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_fmha_warp_tile_define.h @@ -8,6 +8,18 @@ #include +#if defined(FMHA_BUILD_ON_GFX11) || defined(FMHA_BUILD_ON_GFX12) +// RDNA WMMA only registers 16x16x16 in CK tile's warp_gemm dispatcher for +// fp16 / bf16 (and fp8 / int8). The xformers wrapper layer was authored for +// CDNA MFMA shapes (32x32x16 and 16x16x32). To keep the per-arch *_setting.h +// headers small, we alias all three names down to the WMMA-supported shape. +// Block tile geometry is unchanged — each warp simply issues more WMMA +// instructions to cover the same output tile. 
+using WarpTile_32x32x16 = ck_tile::sequence<16, 16, 16>; +using WarpTile_16x16x32 = ck_tile::sequence<16, 16, 16>; +using WarpTile_16x16x16 = ck_tile::sequence<16, 16, 16>; +#else using WarpTile_32x32x16 = ck_tile::sequence<32, 32, 16>; using WarpTile_16x16x32 = ck_tile::sequence<16, 16, 32>; using WarpTile_16x16x16 = ck_tile::sequence<16, 16, 16>; +#endif diff --git a/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h b/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h index 801960a432..9a21459f7a 100644 --- a/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h +++ b/xformers/csrc/attention/hip_fmha/ck_tiled_rand_uniform_kernel.h @@ -171,8 +171,12 @@ struct FmhaRandUniformKernel { return ck_tile::make_tuple(i_block, i_nhead, i_batch); } - __host__ static constexpr auto BlockSize() { - return dim3(kBlockSize); + __host__ static dim3 BlockSize() { + if (ck_tile::is_wave32()) { + return dim3(kBlockSize / 2); + } else { + return dim3(kBlockSize); + } } __device__ static constexpr ck_tile::index_t GetSmemSize() { diff --git a/xformers/ops/differentiable_collectives.py b/xformers/ops/differentiable_collectives.py index 9d87629dd1..e4a49cecb9 100644 --- a/xformers/ops/differentiable_collectives.py +++ b/xformers/ops/differentiable_collectives.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. 
def _pack_kv_for_seqused_k_flash(
    attn_bias: AttentionBias,
    key: torch.Tensor,
    value: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Compact the logical K/V sequences described by ``attn_bias``.

    Every per-sequence slice of ``key`` / ``value`` is extracted and the
    slices are concatenated along dim 0, so the result can be addressed with
    plain cumulative sequence lengths.

    Args:
        attn_bias: bias whose ``k_seqinfo`` describes the key sequences. For
            the two paged mask types ``block_tables`` is read as well.
        key: 4-D tensor for paged biases (indexed by page along dim 0),
            3-D tensor otherwise.
        value: tensor laid out like ``key``.

    Returns:
        ``(key, value, cu_seqlens_k, max_seqlen_k)`` where ``key`` / ``value``
        are the concatenated slices, ``cu_seqlens_k`` is an int32 tensor on
        ``key.device`` holding the running sum of slice lengths (starting at
        0), and ``max_seqlen_k`` is the largest slice length (0 if there are
        no sequences).
    """
    key_chunks: List[torch.Tensor] = []
    value_chunks: List[torch.Tensor] = []
    chunk_lengths: List[int] = []

    if isinstance(
        attn_bias, (PagedBlockDiagonalGappyKeysMask, PagedBlockDiagonalPaddedKeysMask)
    ):
        assert key.ndim == 4
        assert value.ndim == 4
        is_gappy = isinstance(attn_bias, PagedBlockDiagonalGappyKeysMask)
        # Guard against a zero divisor; dim 1 of the paged layout is the
        # number of tokens per page.
        page_size = max(int(key.shape[1]), 1)
        # Convert the whole table once rather than once per row.
        block_tables = attn_bias.block_tables.to(
            device=key.device, dtype=torch.long
        )
        for row_idx, seqlen in enumerate(attn_bias.k_seqinfo.seqlen_py):
            start = attn_bias.k_seqinfo.seqstart_py[row_idx] if is_gappy else 0
            end = seqlen
            # Only tokens below `end` are kept, so pages past
            # ceil(end / page_size) can never contribute to [start:end];
            # skipping them avoids materializing the full row.
            num_pages = max(0, -(-end // page_size))
            page_indices = block_tables[row_idx][:num_pages]
            row_key = key.index_select(0, page_indices).flatten(0, 1)
            row_value = value.index_select(0, page_indices).flatten(0, 1)
            key_chunks.append(row_key[start:end])
            value_chunks.append(row_value[start:end])
            chunk_lengths.append(end - start)
    else:
        assert key.ndim == 3
        assert value.ndim == 3
        for start, end in attn_bias.k_seqinfo.intervals():
            key_chunks.append(key[start:end])
            value_chunks.append(value[start:end])
            chunk_lengths.append(end - start)

    if key_chunks:
        key = torch.cat(key_chunks, dim=0)
        value = torch.cat(value_chunks, dim=0)
    else:
        # No sequences at all: keep the trailing two dims, zero tokens.
        key = key.new_empty([0, *key.shape[-2:]])
        value = value.new_empty([0, *value.shape[-2:]])

    cu_seqlens_k = [0]
    for length in chunk_lengths:
        cu_seqlens_k.append(cu_seqlens_k[-1] + length)
    cu_seqlens_k_tensor = torch.tensor(
        cu_seqlens_k, dtype=torch.int32, device=key.device
    )
    max_seqlen_k = max(chunk_lengths, default=0)
    return key, value, cu_seqlens_k_tensor, max_seqlen_k


def _is_rocm_device(device: torch.device) -> bool:
    """Return True when torch is a HIP build and ``device`` has type "cuda"."""
    return torch.version.hip is not None and device.type == "cuda"
+ ERROR_ATOL = { + **AttentionFwOpBase.ERROR_ATOL, + **({torch.bfloat16: 4e-2} if torch.version.hip is not None else {}), + } SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = ( type(None), LowerTriangularMask, @@ -683,6 +769,7 @@ def apply( block_tables = ( inp.attn_bias.block_tables if isinstance(inp.attn_bias, PagedBlockDiagonalPaddedKeysMask) + and inp.key.ndim == 4 else None ) out, softmax_lse, rng_state = cls.OPERATOR( @@ -731,9 +818,9 @@ def apply( out=out, lse=_post_process_lse(softmax_lse, inp, original_query_shape), ) + ctx.rng_state = rng_state if inp.p != 0.0: ctx.op_bw = BwOp - ctx.rng_state = rng_state return (out, ctx) @@ -817,6 +904,14 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: ] assert grad.dtype in cls.SUPPORTED_DTYPES + rng_state = ctx.rng_state + if rng_state is None: + if inp.p != 0.0: + raise RuntimeError( + "Flash-Attention backward requires an RNG state when dropout is enabled" + ) + rng_state = torch.zeros((2,), dtype=torch.int64, device=inp.query.device) + if inp.query.numel() and inp.key.numel(): win_left, win_right = _window_size(inp.attn_bias) grads = Gradients( @@ -837,7 +932,7 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: _is_causal(inp.attn_bias), window_left=win_left, window_right=win_right, - rng_state=ctx.rng_state if inp.p > 0.0 else None, + rng_state=rng_state, ) ) else: diff --git a/xformers/ops/seqpar.py b/xformers/ops/seqpar.py index b81eaf3304..3114e7495b 100644 --- a/xformers/ops/seqpar.py +++ b/xformers/ops/seqpar.py @@ -7,7 +7,15 @@ from typing import Callable, List, Tuple import torch -from torch.distributed.distributed_c10d import _resolve_process_group + +# Some torch builds (e.g. Windows ROCm / TheRock) ship a stripped-down +# torch.distributed without distributed_c10d. Importing this module must not +# break xformers for single-GPU users; the fallback is only reachable if one +# of the sequence-parallel ops below is actually called. 
+try: + from torch.distributed.distributed_c10d import _resolve_process_group +except ImportError: + _resolve_process_group = None # type: ignore[assignment] from .differentiable_collectives import ( gather_along_first_dim, diff --git a/xformers/ops/sequence_parallel_fused_ops.py b/xformers/ops/sequence_parallel_fused_ops.py index 1f76667382..179125abd6 100644 --- a/xformers/ops/sequence_parallel_fused_ops.py +++ b/xformers/ops/sequence_parallel_fused_ops.py @@ -9,7 +9,15 @@ import torch import torch.distributed as dist import torch.multiprocessing.reductions -from torch.distributed._symmetric_memory import get_symm_mem_workspace + +# torch.distributed._symmetric_memory is absent from some torch builds +# (e.g. Windows ROCm / TheRock). Guard the import so xformers stays usable +# on single-GPU; the None fallback is only reached if the fused ops below +# are actually invoked. +try: + from torch.distributed._symmetric_memory import get_symm_mem_workspace +except ImportError: + get_symm_mem_workspace = None # type: ignore[assignment] OP_FINISHED_CHANNEL = 0 COMMS_READY_CHANNEL = 1