Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,8 @@ def get_rocm_agent_arch():
arches = subprocess.check_output([exec_path], universal_newlines=True)
arch_list = arches.strip().split()
return arch_list[0]
elif torch.cuda.is_available() and torch.version.hip:
return torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
else:
return "gfx942"

Expand Down Expand Up @@ -579,7 +581,9 @@ def get_extensions():
and (torch.cuda.is_available() or os.getenv("HIP_ARCHITECTURE", "") != "")
):
rename_cpp_cu(source_hip)
hip_version = get_hip_version(ROCM_HOME)
hip_version = get_hip_version(
None if platform.system() == "Windows" else ROCM_HOME
)

source_hip_cu = []
for ff in source_hip:
Expand Down Expand Up @@ -608,16 +612,30 @@ def get_extensions():
if arch == "native":
arch = get_rocm_agent_arch()

if arch not in ["gfx908", "gfx90a", "gfx942", "gfx950"]:
if (
arch not in ["gfx908", "gfx90a", "gfx942", "gfx950"]
and not arch.startswith(("gfx11", "gfx12"))
):
raise ValueError(f"Not supported AMD GPU arch: {arch}")

if arch == "gfx950":
cc_flag += ["-DFMHA_BUILD_ON_GFX950"]
elif arch.startswith("gfx11"):
cc_flag += ["-DFMHA_BUILD_ON_GFX11"]
elif arch.startswith("gfx12"):
cc_flag += ["-DFMHA_BUILD_ON_GFX12"]

offload_compress_flag = []
if hip_version >= "6.2.":
offload_compress_flag = ["--offload-compress"]

if platform.system() == "Windows":
cc_flag += [
"-Wno-deprecated-declarations",
"-Wno-unused-command-line-argument",
"-Wno-unknown-attributes",
]

extra_compile_args["nvcc"] = [
"-O3",
"-std=c++20",
Expand Down Expand Up @@ -697,7 +715,31 @@ def __init__(self, *args, **kwargs) -> None:
self.pkg_name = "xformers"
super().__init__(*args, **kwargs)

def _use_windows_link_response_files(self) -> None:
if platform.system() != "Windows":
return

original_spawn = self.compiler.spawn

def spawn_with_response_file(cmd):
tool = os.path.basename(str(cmd[0])).lower() if cmd else ""
command_length = sum(len(str(arg)) + 1 for arg in cmd)
if tool not in {"link.exe", "lld-link.exe"} or command_length < 30_000:
return original_spawn(cmd)

rsp_path = Path(self.build_temp, f"xformers_link_{id(cmd)}.rsp").resolve()
rsp_path.parent.mkdir(parents=True, exist_ok=True)
rsp_path.write_text(
"\n".join(subprocess.list2cmdline([str(arg)]) for arg in cmd[1:])
+ "\n"
)

return original_spawn([cmd[0], f"@{rsp_path}"])

self.compiler.spawn = spawn_with_response_file

def build_extensions(self) -> None:
self._use_windows_link_response_files()
super().build_extensions()

# Fix incorrect output names caused by py_limited_api=True on Windows. see item #1272
Expand Down
113 changes: 87 additions & 26 deletions xformers/csrc/attention/hip_fmha/attention_ck_rand_uniform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,63 @@
#include <torch/types.h>
#include <ATen/cuda/PhiloxUtils.cuh>

#include <ck_tile/core.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <cstdint>

#include "ck_tiled_rand_uniform_kernel.h"
#include <ck_tile/core.hpp>

namespace {

__global__ void rand_uniform_int_kernel(
uint8_t* randvals,
int M,
int N,
int num_heads,
int64_t stride_m,
int64_t stride_n,
int64_t stride_head,
int64_t stride_batch,
uint64_t philox_seed,
uint64_t philox_offset) {
constexpr int kPhiloxPerTile = 64;
constexpr int kWarpGemmMN = 32;

const int row = blockIdx.x;
const int col = blockIdx.y;
const int batch_head = blockIdx.z;
const int i_batch = batch_head / num_heads;
const int i_head = batch_head - i_batch * num_heads;
const int lane = threadIdx.x;

if (lane >= kPhiloxPerTile) {
return;
}

const auto subsequence =
ck_tile::bit_cast<unsigned long long>(make_uint2(row, col));
ck_tile::philox ph(
philox_seed,
philox_offset + (i_batch * num_heads + i_head) * kPhiloxPerTile + lane);

uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t, subsequence);

uint8_t* out = randvals +
static_cast<int64_t>(i_batch) * stride_batch +
static_cast<int64_t>(i_head) * stride_head;

for (int r = 0; r < 16; ++r) {
const int i = (16 * (r / 8) % kWarpGemmMN) + 8 * (lane / 32) + (r % 8);
const int j = lane % kWarpGemmMN;
const int m = row * kWarpGemmMN + i;
const int n = col * kWarpGemmMN + j;

if (m < M && n < N) {
out[static_cast<int64_t>(m) * stride_m +
static_cast<int64_t>(n) * stride_n] = random_uint8_t[r];
}
}
}

/**
* generate a tensor with random uniform values. only used for testing, not much
* attention is paid to performance
Expand All @@ -28,6 +78,8 @@ at::Tensor rand_uniform_int(
double dropout_prob,
const at::Tensor& out_pattern) // [Batches, num_head, query_len, key_len]
{
(void)dropout_prob;

int B = out_pattern.size(0);
int num_heads = out_pattern.size(1);
int M = out_pattern.size(2);
Expand Down Expand Up @@ -56,34 +108,43 @@ at::Tensor rand_uniform_int(
randvals = at::empty(
{B, num_heads, M, N}, out_pattern.options().dtype(at::ScalarType::Byte));

{
// only work for batched mode
using FmhaRandUniformKernel_ = FmhaRandUniformKernel<uint8_t, false>;

const auto kargs = FmhaRandUniformKernel_::MakeKargs(
randvals.data_ptr(),
if (B > 0 && num_heads > 0 && M > 0 && N > 0) {
constexpr int kWarpGemmMN = 32;
const dim3 grid(
(M + kWarpGemmMN - 1) / kWarpGemmMN,
(N + kWarpGemmMN - 1) / kWarpGemmMN,
B * num_heads);
const dim3 block(64);

hipLaunchKernelGGL(
rand_uniform_int_kernel,
grid,
block,
0,
stream,
static_cast<uint8_t*>(randvals.data_ptr()),
M,
N,
num_heads,
B,
static_cast<int>(randvals.stride(2)),
static_cast<int>(randvals.stride(3)),
static_cast<int>(randvals.stride(1)),
static_cast<int>(randvals.stride(0)),
{philox_seed, philox_offset});

dim3 kGridSize = FmhaRandUniformKernel_::GridSize(B, num_heads, M, N);
dim3 kBlockSize = FmhaRandUniformKernel_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu =
FmhaRandUniformKernel_::kBlockPerCu;

(void)ck_tile::launch_kernel(
ck_tile::stream_config{stream, false},
ck_tile::make_kernel<kBlockPerCu>(
FmhaRandUniformKernel_{}, kGridSize, kBlockSize, 0, kargs));
randvals.stride(2),
randvals.stride(3),
randvals.stride(1),
randvals.stride(0),
static_cast<uint64_t>(philox_seed),
static_cast<uint64_t>(philox_offset));

const auto launch_status = hipGetLastError();
TORCH_CHECK(
launch_status == hipSuccess,
"HIP rand_uniform_int_kernel launch failed: ",
hipGetErrorString(launch_status));
}

(void)hipStreamSynchronize(stream);
const auto sync_status = hipStreamSynchronize(stream);
TORCH_CHECK(
sync_status == hipSuccess,
"HIP rand_uniform_int_kernel failed: ",
hipGetErrorString(sync_status));

return randvals;
} // namespace
Expand Down
12 changes: 12 additions & 0 deletions xformers/csrc/attention/hip_fmha/ck_tiled_fmha_warp_tile_define.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@

#include <ck_tile/core.hpp>

#if defined(FMHA_BUILD_ON_GFX11) || defined(FMHA_BUILD_ON_GFX12)
// RDNA WMMA only registers 16x16x16 in CK tile's warp_gemm dispatcher for
// fp16 / bf16 (and fp8 / int8). The xformers wrapper layer was authored for
// CDNA MFMA shapes (32x32x16 and 16x16x32). To keep the per-arch *_setting.h
// headers small, we alias all three names down to the WMMA-supported shape.
// Block tile geometry is unchanged — each warp simply issues more WMMA
// instructions to cover the same output tile.
using WarpTile_32x32x16 = ck_tile::sequence<16, 16, 16>;
using WarpTile_16x16x32 = ck_tile::sequence<16, 16, 16>;
using WarpTile_16x16x16 = ck_tile::sequence<16, 16, 16>;
#else
using WarpTile_32x32x16 = ck_tile::sequence<32, 32, 16>;
using WarpTile_16x16x32 = ck_tile::sequence<16, 16, 32>;
using WarpTile_16x16x16 = ck_tile::sequence<16, 16, 16>;
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,12 @@ struct FmhaRandUniformKernel {
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}

__host__ static constexpr auto BlockSize() {
return dim3(kBlockSize);
__host__ static dim3 BlockSize() {
if (ck_tile::is_wave32()) {
return dim3(kBlockSize / 2);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// suggest to use kBlockSize/get_warp_size() * 32

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addressed in #87

} else {
return dim3(kBlockSize);
}
}

__device__ static constexpr ck_tile::index_t GetSmemSize() {
Expand Down
1 change: 1 addition & 0 deletions xformers/ops/differentiable_collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Optional, Tuple

Expand Down
Loading