Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ container from ROCm, which has all the required tools to install FlashAttention.

#### Composable Kernel Backend
FlashAttention-2 ROCm CK backend currently supports:
1. MI200x, MI250x, MI300x, and MI355x GPUs.
1. MI200x, MI250x, MI300x, MI355x, and RDNA4 GPUs.
2. Datatype fp16 and bf16
3. Head dimensions up to 256 for both the forward and backward passes.

Expand Down
21 changes: 21 additions & 0 deletions csrc/flash_attn_ck/flash_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <string>
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
Expand Down Expand Up @@ -73,4 +77,21 @@ inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int nu

int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits);

// Returns true when the currently active HIP device reports a gfx12-family
// (RDNA4) architecture, e.g. "gfx1200" or "gfx1201".
// Conservatively returns false on any HIP API failure, and always returns
// false in non-ROCm builds.
inline bool is_gfx12_arch() {
#ifdef USE_ROCM
    int device_id = 0;
    if (hipGetDevice(&device_id) != hipSuccess) {
        return false;
    }
    hipDeviceProp_t props{};
    if (hipGetDeviceProperties(&props, device_id) != hipSuccess) {
        return false;
    }
    // gcnArchName looks like "gfx1201:sramecc+:xnack-"; match on the prefix.
    const std::string arch_name = props.gcnArchName;
    return arch_name.size() >= 5 && arch_name.compare(0, 5, "gfx12") == 0;
#else
    return false;
#endif
}

} // namespace flash
5 changes: 5 additions & 0 deletions csrc/flash_attn_ck/mha_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
if (deterministic && flash::is_gfx12_arch()) {
TORCH_CHECK(false,
"Deterministic CK backward is unstable on gfx12 GPUs. "
"Please rerun with deterministic=False.");
}
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor dq_accum;

Expand Down
7 changes: 6 additions & 1 deletion csrc/flash_attn_ck/mha_varlen_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
if (deterministic && flash::is_gfx12_arch()) {
TORCH_CHECK(false,
"Deterministic CK backward is unstable on gfx12 GPUs. "
"Please rerun with deterministic=False.");
}
auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
at::Tensor dq_accum;

Expand Down Expand Up @@ -431,4 +436,4 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
}

return { dq, dk, dv, softmax_d };
}
}
7 changes: 7 additions & 0 deletions flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
def maybe_contiguous(x):
    """Return ``x`` unchanged when it is None or already has unit stride in its
    last dimension; otherwise return a contiguous copy."""
    if x is None:
        return None
    return x if x.stride(-1) == 1 else x.contiguous()

def is_gfx12(device="cuda"):
    """Return True when ``device`` reports a gfx12 (AMD RDNA4) architecture.

    Checks ``gcnArchName`` first (present on ROCm builds of PyTorch) and falls
    back to the generic device ``name``. Returns False when no accelerator is
    available at all.
    """
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(device)
    for attr in ("gcnArchName", "name"):
        label = getattr(props, attr, "")
        if label:
            return "gfx12" in label.lower()
    return False


def _get_block_size_n(device, head_dim, is_dropout, is_causal):
# This should match the block sizes in the CUDA kernel
Expand Down
31 changes: 18 additions & 13 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def rename_cpp_to_cu(cpp_files):

def validate_and_update_archs(archs):
# List of allowed architectures
allowed_archs = ["native", "gfx90a", "gfx950", "gfx942"]
allowed_archs = ["native", "gfx90a", "gfx950", "gfx942", "gfx1200", "gfx1201"]

# Validate if each element in archs is in allowed_archs
assert all(
Expand Down Expand Up @@ -382,10 +382,22 @@ def validate_and_update_archs(archs):
os.makedirs("build")

optdim = os.getenv("OPT_DIM", "32,64,128,256")
subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True)
subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True)
subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True)
subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True)
archs = os.getenv("GPU_ARCHS", "native").split(";")
validate_and_update_archs(archs)

detected_arch = None
if archs != ['native']:
kernel_targets = archs
else:
detected_arch = torch.cuda.get_device_properties("cuda").gcnArchName.split(":")[0]
kernel_targets = [detected_arch.lower()]
validate_and_update_archs(kernel_targets)

targets_arg = ",".join(kernel_targets)
subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim, "--targets", targets_arg], check=True)
subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim, "--targets", targets_arg], check=True)
subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim, "--targets", targets_arg], check=True)
subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim, "--targets", targets_arg], check=True)

# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
# See https://github.com/pytorch/pytorch/pull/70650
Expand All @@ -395,14 +407,7 @@ def validate_and_update_archs(archs):
generator_flag = ["-DOLD_GENERATOR_PATH"]

check_if_rocm_home_none("flash_attn")
archs = os.getenv("GPU_ARCHS", "native").split(";")
validate_and_update_archs(archs)

if archs != ['native']:
cc_flag = [f"--offload-arch={arch}" for arch in archs]
else:
arch = torch.cuda.get_device_properties("cuda").gcnArchName.split(":")[0]
cc_flag = [f"--offload-arch={arch}"]
cc_flag = [f"--offload-arch={arch}" for arch in kernel_targets]

# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
# torch._C._GLIBCXX_USE_CXX11_ABI
Expand Down
21 changes: 16 additions & 5 deletions tests/test_flash_attn_ck.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache,
)
from flash_attn.flash_attn_interface import is_gfx12

from test_flash_attn import (
attn_bias_from_alibi_slopes,
Expand All @@ -27,9 +28,16 @@

from flash_attn.layers.rotary import apply_rotary_emb


def skip_deterministic_bwd(deterministic: bool) -> bool:
    """True when a deterministic backward pass must be skipped (deterministic
    CK backward is not supported on gfx12 GPUs)."""
    if not deterministic:
        return False
    return is_gfx12()

def is_bwd_hdim_supported(d):
    """CK backward kernels only support head dimensions up to 256."""
    max_bwd_hdim = 256
    return d <= max_bwd_hdim

def is_bwd_supported(d, deterministic):
    """Backward pass is runnable when the head dimension fits the CK kernels
    and the deterministic-mode restriction does not apply."""
    if skip_deterministic_bwd(deterministic):
        return False
    return is_bwd_hdim_supported(d)


def ck_randval_to_dropout_mask(randval, p):
# If p = 0.3, randval in 255 * (0.7, 1.0] will be dropout
Expand Down Expand Up @@ -143,7 +151,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

g = torch.randn_like(out)
if is_bwd_hdim_supported(d):
if is_bwd_supported(d, deterministic):
(dqkv,) = torch.autograd.grad(out, qkv, g)
(dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
(dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
Expand Down Expand Up @@ -260,7 +268,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi,
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

g = torch.randn_like(out)
if is_bwd_hdim_supported(d):
if is_bwd_supported(d, deterministic):
(dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
dqkv = dqkv_pad_fn(dqkv_unpad)
(dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
Expand Down Expand Up @@ -442,7 +450,7 @@ def test_flash_attn_output(
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

g = torch.randn_like(out)
if is_bwd_hdim_supported(d):
if is_bwd_supported(d, deterministic):
if kvpacked:
(
dq,
Expand Down Expand Up @@ -703,7 +711,7 @@ def test_flash_attn_varlen_output(
assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item()

g = torch.randn_like(out)
if is_bwd_hdim_supported(d):
if is_bwd_supported(d, deterministic):
if kvpacked:
(
dq_unpad,
Expand Down Expand Up @@ -1513,6 +1521,8 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype):
],
)
def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
if is_gfx12():
pytest.skip("Deterministic backward not yet supported on CK backend")
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
Expand Down Expand Up @@ -1561,6 +1571,8 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc
],
)
def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype):
if is_gfx12():
pytest.skip("Deterministic backward not yet supported on CK backend")
if (
max(seqlen_q, seqlen_k) >= 2048
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
Expand Down Expand Up @@ -1615,4 +1627,3 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
assert torch.equal(dv, dv0)
assert torch.equal(dk, dk0)
assert torch.equal(dq, dq0)