Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
[submodule "third_party/composable_kernel_tiled"]
path = third_party/composable_kernel_tiled
url = https://github.com/ROCm/composable_kernel.git
branch = develop
branch = users/jam/rdna3-rdna4-fmha-tile-load-fixes
62 changes: 54 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,8 @@ def get_rocm_agent_arch():
arches = subprocess.check_output([exec_path], universal_newlines=True)
arch_list = arches.strip().split()
return arch_list[0]
elif torch.cuda.is_available() and torch.version.hip:
return torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
else:
return "gfx942"

Expand Down Expand Up @@ -579,7 +581,9 @@ def get_extensions():
and (torch.cuda.is_available() or os.getenv("HIP_ARCHITECTURE", "") != "")
):
rename_cpp_cu(source_hip)
hip_version = get_hip_version(ROCM_HOME)
hip_version = get_hip_version(
None if platform.system() == "Windows" else ROCM_HOME
)

source_hip_cu = []
for ff in source_hip:
Expand All @@ -597,27 +601,45 @@ def get_extensions():
]

cc_flag = ["-DBUILD_PYTHON_PACKAGE"]
use_rtn_bf16_convert = os.getenv("ENABLE_HIP_FMHA_RTN_BF16_CONVERT", "0")
if use_rtn_bf16_convert == "1":
cc_flag += ["-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3"]
else:
cc_flag += ["-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=2"]

arch = os.getenv("HIP_ARCHITECTURE", "native")

if arch == "native":
arch = get_rocm_agent_arch()

if arch not in ["gfx908", "gfx90a", "gfx942", "gfx950"]:
use_rtn_bf16_convert = os.getenv("ENABLE_HIP_FMHA_RTN_BF16_CONVERT")
if use_rtn_bf16_convert is None:
# RDNA CK bf16 forward output can be reused by FlashAttention backward;
# truncating fp32 accumulators is not accurate enough for that pairing.
use_rtn_bf16_convert = "1" if arch.startswith(("gfx11", "gfx12")) else "0"
if use_rtn_bf16_convert == "1":
cc_flag += ["-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3"]
else:
cc_flag += ["-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=2"]

if (
arch not in ["gfx908", "gfx90a", "gfx942", "gfx950"]
and not arch.startswith(("gfx11", "gfx12"))
):
raise ValueError(f"Not supported AMD GPU arch: {arch}")

if arch == "gfx950":
cc_flag += ["-DFMHA_BUILD_ON_GFX950"]
elif arch.startswith("gfx11"):
cc_flag += ["-DFMHA_BUILD_ON_GFX11"]
elif arch.startswith("gfx12"):
cc_flag += ["-DFMHA_BUILD_ON_GFX12"]

offload_compress_flag = []
if hip_version >= "6.2.":
offload_compress_flag = ["--offload-compress"]

if platform.system() == "Windows":
cc_flag += [
"-Wno-deprecated-declarations",
"-Wno-unused-command-line-argument",
"-Wno-unknown-attributes",
]

extra_compile_args["nvcc"] = [
"-O3",
"-std=c++20",
Expand Down Expand Up @@ -697,7 +719,31 @@ def __init__(self, *args, **kwargs) -> None:
self.pkg_name = "xformers"
super().__init__(*args, **kwargs)

def _use_windows_link_response_files(self) -> None:
    """Route over-long Windows linker commands through a response file.

    On Windows this replaces ``self.compiler.spawn`` with a wrapper. When the
    command being spawned is ``link.exe`` or ``lld-link.exe`` and its total
    length reaches 30,000 characters, the wrapper writes every argument after
    the executable to a ``.rsp`` file under ``self.build_temp`` (one quoted
    argument per line) and runs ``<linker> @<file>`` instead. All other
    commands are passed to the original ``spawn`` untouched.

    On every other platform this method does nothing.
    """
    if platform.system() != "Windows":
        return

    original_spawn = self.compiler.spawn
    linker_names = {"link.exe", "lld-link.exe"}
    # Kept below the 32,767-character Windows process command-line limit so
    # that quoting added later by the spawn layer still fits.
    max_inline_length = 30_000

    def spawn_with_response_file(cmd, *args, **kwargs):
        # Extra positional/keyword arguments are forwarded unchanged so the
        # wrapper stays call-compatible with whatever ``spawn`` signature the
        # underlying compiler object exposes (e.g. ``spawn(cmd, **kwargs)``).
        tool = os.path.basename(str(cmd[0])).lower() if cmd else ""
        command_length = sum(len(str(arg)) + 1 for arg in cmd)
        if tool not in linker_names or command_length < max_inline_length:
            return original_spawn(cmd, *args, **kwargs)

        rsp_path = Path(self.build_temp, f"xformers_link_{id(cmd)}.rsp").resolve()
        rsp_path.parent.mkdir(parents=True, exist_ok=True)
        # list2cmdline on a single-element list applies Windows quoting rules
        # to that one argument, keeping paths with spaces intact per line.
        quoted_args = (subprocess.list2cmdline([str(arg)]) for arg in cmd[1:])
        rsp_path.write_text("\n".join(quoted_args) + "\n")

        return original_spawn([cmd[0], f"@{rsp_path}"], *args, **kwargs)

    self.compiler.spawn = spawn_with_response_file

def build_extensions(self) -> None:
self._use_windows_link_response_files()
super().build_extensions()

# Fix incorrect output names caused by py_limited_api=True on Windows. see item #1272
Expand Down
2 changes: 1 addition & 1 deletion third_party/composable_kernel_tiled
113 changes: 87 additions & 26 deletions xformers/csrc/attention/hip_fmha/attention_ck_rand_uniform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,63 @@
#include <torch/types.h>
#include <ATen/cuda/PhiloxUtils.cuh>

#include <ck_tile/core.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <cstdint>

#include "ck_tiled_rand_uniform_kernel.h"
#include <ck_tile/core.hpp>

namespace {

/**
 * Fill a byte buffer with Philox-generated random values, one 32x32 tile per
 * thread block.
 *
 * Block indexing: blockIdx.x is the tile row, blockIdx.y is the tile column,
 * and blockIdx.z is the flattened (batch, head) index (batch * num_heads +
 * head). Only threads with threadIdx.x < 64 do any work; each of them writes
 * 16 bytes, so 64 lanes cover the full 32x32 tile.
 *
 * randvals      base pointer of the output buffer
 * M, N          logical extent of the last two dimensions; writes outside
 *               [0, M) x [0, N) are skipped
 * num_heads     number of heads, used to split blockIdx.z into batch and head
 * stride_*      element strides used to address randvals; since the element
 *               type is uint8_t these are also byte strides
 * philox_seed, philox_offset
 *               seed and base offset handed to ck_tile::philox
 *
 * NOTE(review): the lane-to-element mapping below is presumably meant to
 * reproduce the layout used by the CK tile dropout path -- confirm against
 * the CK sources before changing it.
 */
__global__ void rand_uniform_int_kernel(
uint8_t* randvals,
int M,
int N,
int num_heads,
int64_t stride_m,
int64_t stride_n,
int64_t stride_head,
int64_t stride_batch,
uint64_t philox_seed,
uint64_t philox_offset) {
// Number of lanes that each own one Philox stream within a tile.
constexpr int kPhiloxPerTile = 64;
// Edge length of the square tile handled by one block.
constexpr int kWarpGemmMN = 32;

const int row = blockIdx.x;
const int col = blockIdx.y;
const int batch_head = blockIdx.z;
// Split the flattened z index back into (batch, head).
const int i_batch = batch_head / num_heads;
const int i_head = batch_head - i_batch * num_heads;
const int lane = threadIdx.x;

// Threads beyond the first 64 have no tile elements assigned to them.
if (lane >= kPhiloxPerTile) {
return;
}

// Pack the tile coordinates (row, col) into one 64-bit value that is passed
// to get_random_16x8 as the subsequence.
const auto subsequence =
ck_tile::bit_cast<unsigned long long>(make_uint2(row, col));
// Per-lane generator: the offset advances by 64 per (batch, head) pair plus
// the lane index.
ck_tile::philox ph(
philox_seed,
philox_offset + (i_batch * num_heads + i_head) * kPhiloxPerTile + lane);

// Presumably filled with 16 random bytes by get_random_16x8 -- the helper is
// defined outside this file.
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t, subsequence);

// Start of the [M, N] plane belonging to this (batch, head).
uint8_t* out = randvals +
static_cast<int64_t>(i_batch) * stride_batch +
static_cast<int64_t>(i_head) * stride_head;

for (int r = 0; r < 16; ++r) {
// In-tile row: bytes 0-7 go to rows [0, 8) and bytes 8-15 to rows [16, 24)
// for lanes 0-31; lanes 32-63 are shifted down by 8 rows. For r < 16 the
// "% kWarpGemmMN" never changes the value (16 * (r / 8) is 0 or 16).
const int i = (16 * (r / 8) % kWarpGemmMN) + 8 * (lane / 32) + (r % 8);
// In-tile column: one column per lane, wrapping at 32.
const int j = lane % kWarpGemmMN;
// Global coordinates of this element.
const int m = row * kWarpGemmMN + i;
const int n = col * kWarpGemmMN + j;

// Edge tiles may extend past the tensor; skip out-of-range elements.
if (m < M && n < N) {
out[static_cast<int64_t>(m) * stride_m +
static_cast<int64_t>(n) * stride_n] = random_uint8_t[r];
}
}
}

/**
* generate a tensor with random uniform values. only used for testing, not much
* attention is paid to performance
Expand All @@ -28,6 +78,8 @@ at::Tensor rand_uniform_int(
double dropout_prob,
const at::Tensor& out_pattern) // [Batches, num_head, query_len, key_len]
{
(void)dropout_prob;

int B = out_pattern.size(0);
int num_heads = out_pattern.size(1);
int M = out_pattern.size(2);
Expand Down Expand Up @@ -56,34 +108,43 @@ at::Tensor rand_uniform_int(
randvals = at::empty(
{B, num_heads, M, N}, out_pattern.options().dtype(at::ScalarType::Byte));

{
// only work for batched mode
using FmhaRandUniformKernel_ = FmhaRandUniformKernel<uint8_t, false>;

const auto kargs = FmhaRandUniformKernel_::MakeKargs(
randvals.data_ptr(),
if (B > 0 && num_heads > 0 && M > 0 && N > 0) {
constexpr int kWarpGemmMN = 32;
const dim3 grid(
(M + kWarpGemmMN - 1) / kWarpGemmMN,
(N + kWarpGemmMN - 1) / kWarpGemmMN,
B * num_heads);
const dim3 block(64);

hipLaunchKernelGGL(
rand_uniform_int_kernel,
grid,
block,
0,
stream,
static_cast<uint8_t*>(randvals.data_ptr()),
M,
N,
num_heads,
B,
static_cast<int>(randvals.stride(2)),
static_cast<int>(randvals.stride(3)),
static_cast<int>(randvals.stride(1)),
static_cast<int>(randvals.stride(0)),
{philox_seed, philox_offset});

dim3 kGridSize = FmhaRandUniformKernel_::GridSize(B, num_heads, M, N);
dim3 kBlockSize = FmhaRandUniformKernel_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu =
FmhaRandUniformKernel_::kBlockPerCu;

(void)ck_tile::launch_kernel(
ck_tile::stream_config{stream, false},
ck_tile::make_kernel<kBlockPerCu>(
FmhaRandUniformKernel_{}, kGridSize, kBlockSize, 0, kargs));
randvals.stride(2),
randvals.stride(3),
randvals.stride(1),
randvals.stride(0),
static_cast<uint64_t>(philox_seed),
static_cast<uint64_t>(philox_offset));

const auto launch_status = hipGetLastError();
TORCH_CHECK(
launch_status == hipSuccess,
"HIP rand_uniform_int_kernel launch failed: ",
hipGetErrorString(launch_status));
}

(void)hipStreamSynchronize(stream);
const auto sync_status = hipStreamSynchronize(stream);
TORCH_CHECK(
sync_status == hipSuccess,
"HIP rand_uniform_int_kernel failed: ",
hipGetErrorString(sync_status));

return randvals;
} // namespace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,19 @@ struct batched_infer_mask_bias_dropout_dispatch {
};
#endif

#if defined(FMHA_BUILD_ON_GFX11) || defined(FMHA_BUILD_ON_GFX12)
// Current RDNA3/4 CK FMHA builds use the sync pipeline; the async
// global-to-LDS path fails for these targets, so keep it uninstantiated
// instead of relying on a core CK fallback.
constexpr bool enable_async_pipeline = false;
#else
const bool enable_async_pipeline = []() {
const char* env_p = std::getenv("FMHA_ENABLE_ASYNC_PIPELINE");
if (env_p == nullptr)
return false;
return static_cast<bool>(atoi(env_p));
}();
#endif

const bool use_async_pipeline =
(!kHasBias && (param.K % 8 == 0) && (param.Kv % 8 == 0) &&
Expand Down Expand Up @@ -267,7 +274,9 @@ struct batched_infer_mask_bias_dropout_dispatch {
RunWithKernel<FmhaKernel>(param, stream);
}
});
} else {
}
#if !defined(FMHA_BUILD_ON_GFX11) && !defined(FMHA_BUILD_ON_GFX12)
else {
using FmhaShape = typename FmhaFwdCommonShape<MaxK, MTile>::Type;

const bool pad_seqlen_k = !(param.N % FmhaShape::kN0 == 0);
Expand Down Expand Up @@ -307,7 +316,9 @@ struct batched_infer_mask_bias_dropout_dispatch {
/* runtime will never get here, so no codes to compile */
};
});
};
}
#endif
;
};

template <typename FmhaKernel>
Expand Down
6 changes: 6 additions & 0 deletions xformers/csrc/attention/hip_fmha/ck_tiled_fmha_fwd_setting.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,13 @@ struct FmhaFwdCommonBlockTile<128, 128> {

template <ck_tile::index_t MTile>
struct FmhaFwdCommonBlockTile<256, MTile> {
#if defined(FMHA_BUILD_ON_GFX11)
// gfx11 WMMA duplicates Q data across subgroups, so a 128x256 Q tile
// can exceed static_for's 256-iteration limit. Keep hdim-256 at M=64.
using tile_lengths = ck_tile::sequence<64, 128, 32, 256, 32, 256>;
#else
using tile_lengths = ck_tile::sequence<128, 128, 32, 256, 32, 256>;
#endif
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,19 @@ struct grouped_infer_mask_bias_dropout_dispatch {
constexpr bool kPadSeqLenQ = false;
constexpr bool kPadSeqLenK = true;

#if defined(FMHA_BUILD_ON_GFX11) || defined(FMHA_BUILD_ON_GFX12)
// Current RDNA3/4 CK FMHA builds use the sync pipeline; the async
// global-to-LDS path fails for these targets, so keep it uninstantiated
// instead of relying on a core CK fallback.
constexpr bool enable_async_pipeline = false;
#else
const bool enable_async_pipeline = []() {
const char* env_p = std::getenv("FMHA_ENABLE_ASYNC_PIPELINE");
if (env_p == nullptr)
return false;
return static_cast<bool>(atoi(env_p));
}();
#endif

const bool use_async_pipeline =
(!kHasBias && (param.K % 8 == 0) && (param.Kv % 8 == 0) &&
Expand Down Expand Up @@ -261,7 +268,9 @@ struct grouped_infer_mask_bias_dropout_dispatch {
RunWithKernel<FmhaKernel>(param, stream);
}
});
} else {
}
#if !defined(FMHA_BUILD_ON_GFX11) && !defined(FMHA_BUILD_ON_GFX12)
else {
if constexpr (MaxK <= 128 && MTile <= 128) {
using FmhaShape = typename FmhaFwdCommonShape<MaxK, MTile>::Type;

Expand Down Expand Up @@ -297,7 +306,9 @@ struct grouped_infer_mask_bias_dropout_dispatch {
} else {
/* runtime will never get here, so no codes to compile */
};
};
}
#endif
;
};

template <typename FmhaKernel>
Expand Down
12 changes: 12 additions & 0 deletions xformers/csrc/attention/hip_fmha/ck_tiled_fmha_warp_tile_define.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@

#include <ck_tile/core.hpp>

#if defined(FMHA_BUILD_ON_GFX11) || defined(FMHA_BUILD_ON_GFX12)
// RDNA WMMA only registers 16x16x16 in CK tile's warp_gemm dispatcher for
// fp16 / bf16 (and fp8 / int8). The xformers wrapper layer was authored for
// CDNA MFMA shapes (32x32x16 and 16x16x32). To keep the per-arch *_setting.h
// headers small, we alias all three names down to the WMMA-supported shape.
// Block tile geometry is unchanged — each warp simply issues more WMMA
// instructions to cover the same output tile.
using WarpTile_32x32x16 = ck_tile::sequence<16, 16, 16>;
using WarpTile_16x16x32 = ck_tile::sequence<16, 16, 16>;
using WarpTile_16x16x16 = ck_tile::sequence<16, 16, 16>;
#else
using WarpTile_32x32x16 = ck_tile::sequence<32, 32, 16>;
using WarpTile_16x16x32 = ck_tile::sequence<16, 16, 32>;
using WarpTile_16x16x16 = ck_tile::sequence<16, 16, 16>;
#endif
1 change: 1 addition & 0 deletions xformers/ops/differentiable_collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Optional, Tuple

Expand Down
Loading
Loading