diff --git a/README.md b/README.md index fe320b604c6..788453d9b13 100755 --- a/README.md +++ b/README.md @@ -124,7 +124,7 @@ container from ROCm, which has all the required tools to install FlashAttention. #### Composable Kernel Backend FlashAttention-2 ROCm CK backend currently supports: -1. MI200x, MI250x, MI300x, and MI355x GPUs. +1. MI200x, MI250x, MI300x, MI355x, and RDNA4 GPUs. 2. Datatype fp16 and bf16 3. Both forward's and backward's head dimensions up to 256. diff --git a/csrc/flash_attn_ck/flash_common.hpp b/csrc/flash_attn_ck/flash_common.hpp index cc86546ea54..d7922ced3e0 100644 --- a/csrc/flash_attn_ck/flash_common.hpp +++ b/csrc/flash_attn_ck/flash_common.hpp @@ -9,6 +9,10 @@ #include #include #include +#include +#ifdef USE_ROCM +#include +#endif #ifdef OLD_GENERATOR_PATH #include @@ -73,4 +77,21 @@ inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int nu int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits); +inline bool is_gfx12_arch() { +#ifdef USE_ROCM + int dev = 0; + if (hipGetDevice(&dev) != hipSuccess) { + return false; + } + hipDeviceProp_t prop{}; + if (hipGetDeviceProperties(&prop, dev) != hipSuccess) { + return false; + } + std::string arch = prop.gcnArchName; + return !arch.empty() && arch.rfind("gfx12", 0) == 0; +#else + return false; +#endif +} + } // namespace flash diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp index 083494f5b0c..fac4f087b03 100644 --- a/csrc/flash_attn_ck/mha_bwd.cpp +++ b/csrc/flash_attn_ck/mha_bwd.cpp @@ -319,6 +319,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num at::cuda::CUDAGuard device_guard{q.device()}; auto opts = q.options(); + if (deterministic && flash::is_gfx12_arch()) { + TORCH_CHECK(false, + "Deterministic CK backward is unstable on gfx12 GPUs. " + "Please rerun with deterministic=False."); + }
 auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor dq_accum; diff --git a/csrc/flash_attn_ck/mha_varlen_bwd.cpp b/csrc/flash_attn_ck/mha_varlen_bwd.cpp index 3cd01c32d48..1c338928e8f 100644 --- a/csrc/flash_attn_ck/mha_varlen_bwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_bwd.cpp @@ -333,6 +333,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads at::cuda::CUDAGuard device_guard{q.device()}; auto opts = q.options(); + if (deterministic && flash::is_gfx12_arch()) { + TORCH_CHECK(false, + "Deterministic CK backward is unstable on gfx12 GPUs. " + "Please rerun with deterministic=False."); + } auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); at::Tensor dq_accum; @@ -431,4 +436,4 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads } return { dq, dk, dv, softmax_d }; -} \ No newline at end of file +} diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index a53b4a3108a..dcc5fbad67f 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -19,6 +19,13 @@ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x +def is_gfx12(device="cuda"): + if not torch.cuda.is_available(): + return False + props = torch.cuda.get_device_properties(device) + name = getattr(props, "gcnArchName", "") or getattr(props, "name", "") + return "gfx12" in name.lower() + def _get_block_size_n(device, head_dim, is_dropout, is_causal): # This should match the block sizes in the CUDA kernel diff --git a/setup.py b/setup.py index fafea904998..ebe82824bfb 100644 --- a/setup.py +++ b/setup.py @@ -197,7 +197,7 @@ def rename_cpp_to_cu(cpp_files): def validate_and_update_archs(archs): # List of allowed architectures - allowed_archs = ["native", "gfx90a", "gfx950", "gfx942"] + allowed_archs = ["native", "gfx90a", "gfx950", "gfx942", 
"gfx1200", "gfx1201"] # Validate if each element in archs is in allowed_archs assert all( @@ -382,10 +382,22 @@ def validate_and_update_archs(archs): os.makedirs("build") optdim = os.getenv("OPT_DIM", "32,64,128,256") - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + archs = os.getenv("GPU_ARCHS", "native").split(";") + validate_and_update_archs(archs) + + detected_arch = None + if archs != ['native']: + kernel_targets = archs + else: + detected_arch = torch.cuda.get_device_properties("cuda").gcnArchName.split(":")[0] + kernel_targets = [detected_arch.lower()] + validate_and_update_archs(kernel_targets) + + targets_arg = ",".join(kernel_targets) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim, "--targets", targets_arg], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim, "--targets", targets_arg], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim, "--targets", targets_arg], check=True) + subprocess.run([sys.executable, 
f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim, "--targets", targets_arg], check=True) # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h # See https://github.com/pytorch/pytorch/pull/70650 @@ -395,14 +407,7 @@ def validate_and_update_archs(archs): generator_flag = ["-DOLD_GENERATOR_PATH"] check_if_rocm_home_none("flash_attn") - archs = os.getenv("GPU_ARCHS", "native").split(";") - validate_and_update_archs(archs) - - if archs != ['native']: - cc_flag = [f"--offload-arch={arch}" for arch in archs] - else: - arch = torch.cuda.get_device_properties("cuda").gcnArchName.split(":")[0] - cc_flag = [f"--offload-arch={arch}"] + cc_flag = [f"--offload-arch={arch}" for arch in kernel_targets] # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as # torch._C._GLIBCXX_USE_CXX11_ABI diff --git a/tests/test_flash_attn_ck.py b/tests/test_flash_attn_ck.py index d5590fcfc82..795776236b5 100644 --- a/tests/test_flash_attn_ck.py +++ b/tests/test_flash_attn_ck.py @@ -13,6 +13,7 @@ flash_attn_varlen_qkvpacked_func, flash_attn_with_kvcache, ) +from flash_attn.flash_attn_interface import is_gfx12 from test_flash_attn import ( attn_bias_from_alibi_slopes, @@ -27,9 +28,16 @@ from flash_attn.layers.rotary import apply_rotary_emb + +def skip_deterministic_bwd(deterministic: bool) -> bool: + return deterministic and is_gfx12() + def is_bwd_hdim_supported(d): return d <= 256 +def is_bwd_supported(d, deterministic): + return is_bwd_hdim_supported(d) and not skip_deterministic_bwd(deterministic) + def ck_randval_to_dropout_mask(randval, p): # If p = 0.3, randval in 255 * (0.7, 1.0] will be dropout @@ -143,7 +151,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() g = torch.randn_like(out) - if is_bwd_hdim_supported(d): + if is_bwd_supported(d, deterministic):
 (dqkv,) = torch.autograd.grad(out, qkv, g) (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g) (dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g) @@ -260,7 +268,7 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() g = torch.randn_like(out) - if is_bwd_hdim_supported(d): + if is_bwd_supported(d, deterministic): (dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g) dqkv = dqkv_pad_fn(dqkv_unpad) (dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g) @@ -442,7 +450,7 @@ def test_flash_attn_output( assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() g = torch.randn_like(out) - if is_bwd_hdim_supported(d): + if is_bwd_supported(d, deterministic): if kvpacked: ( dq, @@ -703,7 +711,7 @@ def test_flash_attn_varlen_output( assert (out - out_ref).abs().max().item() <= 4 * (out_pt - out_ref).abs().max().item() g = torch.randn_like(out) - if is_bwd_hdim_supported(d): + if is_bwd_supported(d, deterministic): if kvpacked: ( dq_unpad, @@ -1513,6 +1521,8 @@ def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): ], ) def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype): + if is_gfx12(): + pytest.skip("Deterministic backward not yet supported on CK backend") if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1561,6 +1571,8 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype): + if is_gfx12(): + pytest.skip("Deterministic backward not yet supported on CK backend") if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 @@ -1615,4 +1627,3 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, 
caus assert torch.equal(dv, dv0) assert torch.equal(dk, dk0) assert torch.equal(dq, dq0) -