Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,10 @@ container from ROCm, which has all the required tools to install FlashAttention.

#### Composable Kernel Backend
FlashAttention-2 ROCm CK backend currently supports:
1. MI200x, MI250x, MI300x, and MI355x GPUs.
1. MI200x, MI250x, MI300x, MI355x, and RDNA 3/4 GPUs.
2. Datatype fp16 and bf16
3. Both forward's and backward's head dimensions up to 256.
4. RDNA 3 GPUs do not currently support the backward pass, and RDNA 4 GPUs support the backward pass only with `deterministic=False`.

#### Triton Backend
The Triton implementation of [Flash Attention](https://tridao.me/publications/flash2/flash2.pdf) supports AMD's CDNA (MI200, MI300) and RDNA GPUs using fp16, bf16, and fp32 datatypes. It provides forward and backward passes with causal masking, variable sequence lengths, arbitrary Q/KV sequence lengths and head sizes, MQA/GQA, dropout, rotary embeddings, ALiBi, paged attention, and FP8 (via the Flash Attention v3 interface). Sliding window attention is currently a work in progress.
Expand Down
2 changes: 1 addition & 1 deletion csrc/composable_kernel
Submodule composable_kernel updated 38 files
+1 −1 example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+8 −2 example/ck_tile/01_fmha/fmha_fwd.hpp
+418 −0 example/ck_tile/01_fmha/fmha_fwd_head_grouping.hpp
+89 −1 example/ck_tile/01_fmha/fmha_fwd_runner.hpp
+1 −1 example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp
+1 −1 example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp
+2 −2 example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.h
+2 −2 example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
+5 −5 example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
+42 −16 experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp
+3 −1 experimental/grouped_convolution_tile_instances/generate_instances.py
+22 −1 include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
+3 −19 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+52 −24 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+3 −18 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp
+8 −7 include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp
+116 −26 include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+4 −0 include/ck_tile/ops/gemm.hpp
+266 −0 include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_eight_waves_v1.hpp
+240 −0 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp
+81 −93 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp
+563 −0 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp
+3 −3 include/ck_tile/ops/gemm_quant.hpp
+10 −65 include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr_eight_waves.hpp
+323 −0 include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves.hpp
+159 −0 include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves_policy.hpp
+0 −581 include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eightwarps.hpp
+18 −5 profiler/include/profiler/grouped_convolution_backward_weight_tile_algs.hpp
+21 −3 profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp
+23 −1 profiler/src/profile_grouped_conv_bwd_weight_tile.cpp
+7 −0 test/ck_tile/gemm/CMakeLists.txt
+22 −0 test/ck_tile/gemm/test_gemm_pipeline_comp_async_eight_waves.cpp
+23 −1 test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp
+3 −0 test/ck_tile/gemm/test_gemm_pipeline_ut_cases.inc
+14 −2 test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+4 −4 test/ck_tile/gemm_block_scale/CMakeLists.txt
+4 −4 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_eightwaves.cpp
+7 −7 test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
47 changes: 47 additions & 0 deletions csrc/flash_attn_ck/flash_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <string>

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
Expand Down Expand Up @@ -73,4 +78,46 @@ inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int nu

int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits);

/// Returns the gcnArchName string (e.g. "gfx1100") of the current HIP
/// device. Returns an empty string when the device cannot be queried,
/// or always on non-ROCm builds.
inline std::string get_gcn_arch_name() {
#ifdef USE_ROCM
    int device_id = 0;
    hipDeviceProp_t properties{};
    // Short-circuit keeps the original call order: query the device id
    // first, then its properties; either failure yields an empty name.
    const bool queried = hipGetDevice(&device_id) == hipSuccess
                      && hipGetDeviceProperties(&properties, device_id) == hipSuccess;
    return queried ? std::string{properties.gcnArchName} : std::string{};
#else
    return "";
#endif
}

/// True when the current device reports a gfx11-family architecture.
inline bool is_gfx11_arch() {
    const std::string name = get_gcn_arch_name();
    // compare(0, 5, ...) == 0 holds only when the first five characters
    // exist and equal "gfx11", so the empty string is rejected too.
    return name.compare(0, 5, "gfx11") == 0;
}

/// True when the current device reports a gfx12-family architecture.
inline bool is_gfx12_arch() {
    const std::string name = get_gcn_arch_name();
    // compare(0, 5, ...) == 0 holds only when the first five characters
    // exist and equal "gfx12", so the empty string is rejected too.
    return name.compare(0, 5, "gfx12") == 0;
}

/// True for any gfx11 or gfx12 device (RDNA 3/4 families).
inline bool is_gfx1x_arch() {
    if (is_gfx11_arch()) {
        return true;
    }
    return is_gfx12_arch();
}

/// Rejects CK backward configurations the current RDNA device cannot
/// run: gfx11 has no backward support at all, and gfx12 supports the
/// backward pass only when deterministic mode is off.
inline void check_gfx1x_bwd_supported(bool deterministic) {
    if (is_gfx11_arch()) {
        TORCH_CHECK(false, "CK backward is not supported on gfx11.");
    }
    // Cheap flag test first; the device query runs only when needed.
    if (deterministic && is_gfx12_arch()) {
        TORCH_CHECK(false,
                    "Deterministic CK backward is not supported on gfx12. "
                    "Please rerun with deterministic=False.");
    }
}

} // namespace flash
3 changes: 3 additions & 0 deletions csrc/flash_attn_ck/mha_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
if (flash::is_gfx1x_arch()) {
flash::check_gfx1x_bwd_supported(deterministic);
}
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor dq_accum = torch::zeros({batch_size, num_heads, nsplits, seqlen_q, head_size}, opts.dtype(at::kFloat));

Expand Down
30 changes: 27 additions & 3 deletions csrc/flash_attn_ck/mha_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

#include "flash_common.hpp"
#include "mha_fwd_head_grouping_utils.hpp"

#include "fmha_fwd.hpp"
#include "mask.hpp"

#include <optional>
#include <string>

fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
Expand Down Expand Up @@ -119,6 +121,8 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
d, // hdim_v
h, // nhead
h_k, // nhead_k
0, // num_head_q_total
0, // head_start
softmax_scale, // scale_s
0.0f, // logits_soft_cap
stride_q,
Expand Down Expand Up @@ -330,7 +334,27 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
p_dropout,
drop_seed_offset);

float t = fmha_fwd(traits, args, stream_config);
float t =
flash::maybe_dispatch_head_grouped_fwd(
stream_config,
traits,
args,
num_heads,
num_heads_k,
batch_size,
seqlen_k,
head_size,
head_size,
k.element_size(),
v.element_size(),
q.scalar_type(),
[&](const auto& grouped_traits, auto& grouped_args, const auto& grouped_sc) {
return fmha_fwd(grouped_traits, grouped_args, grouped_sc);
});

if (t < 0.0f) {
t = fmha_fwd(traits, args, stream_config);
}
TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
}
else {
Expand Down
101 changes: 101 additions & 0 deletions csrc/flash_attn_ck/mha_fwd_head_grouping_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

#pragma once

#include "flash_common.hpp"

#include "fmha_fwd.hpp"
#include "fmha_fwd_head_grouping.hpp"

#include <iostream>

namespace flash {

/// Tries to dispatch the forward FMHA call through the LLC head-grouping
/// path. Returns the kernel time reported by the grouped run, or a
/// negative value when grouping is disabled, inapplicable, or the dtype
/// is unsupported — the caller then falls back to the plain fmha_fwd.
///
/// @param stream_config  ck_tile stream/timing configuration.
/// @param traits/args    the already-built forward traits and arguments.
/// @param num_heads      number of query heads; num_heads_k for K/V.
/// @param elem_bytes_k/v element sizes used for the LLC footprint model.
/// @param q_dtype        query dtype; only fp16 and bf16 are grouped.
/// @param fmha_fwd_fn    callable invoked once per head group.
template <typename FmhaFwdTraits, typename FmhaFwdArgs, typename FmhaFwdFn>
inline float maybe_dispatch_head_grouped_fwd(const ck_tile::stream_config& stream_config,
                                             const FmhaFwdTraits& traits,
                                             const FmhaFwdArgs& args,
                                             int num_heads,
                                             int num_heads_k,
                                             int batch_size,
                                             int seqlen_k,
                                             int head_size_q,
                                             int head_size_v,
                                             size_t elem_bytes_k,
                                             size_t elem_bytes_v,
                                             at::ScalarType q_dtype,
                                             FmhaFwdFn&& fmha_fwd_fn)
{
    namespace head_grouping = fmha_fwd_head_grouping;

    // Environment-variable kill switch takes precedence over everything.
    if (head_grouping::disabled_by_env()) {
        if (head_grouping::log_enabled()) {
            std::cout << "[LLC Head Grouping] disabled by env" << std::endl;
        }
        return -1.0f;
    }

    // The heuristic may decline (no value) or propose a group covering
    // all heads, in which case grouping buys nothing.
    const auto maybe_group_size = head_grouping::get_head_group_size(num_heads,
                                                                     num_heads_k,
                                                                     batch_size,
                                                                     seqlen_k,
                                                                     head_size_q,
                                                                     head_size_v,
                                                                     elem_bytes_k,
                                                                     elem_bytes_v);
    if (!maybe_group_size.has_value() || maybe_group_size.value() >= num_heads) {
        if (head_grouping::log_enabled()) {
            std::cout << "[LLC Head Grouping] skipped (group_size not set or >= nhead)"
                      << std::endl;
        }
        return -1.0f;
    }

    const ck_tile::index_t group_size = maybe_group_size.value();

    if (head_grouping::log_enabled()) {
        const std::string device_arch = ck_tile::get_device_name();
        const size_t llc_bytes = head_grouping::get_llc_cache_bytes(device_arch);
        const ck_tile::index_t gqa_ratio = (num_heads_k > 0 ? (num_heads / num_heads_k) : 1);
        const ck_tile::index_t n_groups = ck_tile::integer_divide_ceil(num_heads, group_size);
        std::cout << "[LLC Head Grouping] enabled"
                  << " arch=" << (device_arch.empty() ? "unknown" : device_arch)
                  << " llc_mb=" << (llc_bytes / (1024ull * 1024ull))
                  << " nhead_q=" << num_heads << " nhead_k=" << num_heads_k
                  << " gqa_ratio=" << gqa_ratio << " group_size=" << group_size
                  << " groups=" << n_groups << std::endl;
    }

    const bool use_blockscale_qscale = traits.qscale_type == quant_scale_enum::blockscale;

    // Tag-dispatch on the fmha type-config so the grouped runner is
    // instantiated with the matching CK data types.
    auto run_with_type_config = [&](auto type_config_tag) {
        using TypeConfig = decltype(type_config_tag);
        return head_grouping::run_fwd_head_grouped<typename TypeConfig::QDataType,
                                                   typename TypeConfig::KDataType,
                                                   typename TypeConfig::VDataType,
                                                   typename TypeConfig::ODataType,
                                                   float,
                                                   typename TypeConfig::LSEDataType,
                                                   typename TypeConfig::RandValOutputDataType>(
            stream_config,
            traits,
            args,
            num_heads,
            num_heads_k,
            group_size,
            use_blockscale_qscale,
            [&](const auto& grouped_traits, auto& grouped_args, const auto& grouped_sc) {
                return fmha_fwd_fn(grouped_traits, grouped_args, grouped_sc);
            });
    };

    if (q_dtype == torch::kBFloat16) {
        return run_with_type_config(FmhaFwdTypeConfig<FmhaFwdBf16>{});
    }
    if (q_dtype == torch::kFloat16) {
        return run_with_type_config(FmhaFwdTypeConfig<FmhaFwdFp16>{});
    }
    // Unsupported dtype: signal the caller to use the plain path.
    return -1.0f;
}

} // namespace flash
5 changes: 4 additions & 1 deletion csrc/flash_attn_ck/mha_varlen_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
if (flash::is_gfx1x_arch()) {
flash::check_gfx1x_bwd_supported(deterministic);
}
auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
at::Tensor dq_accum = torch::zeros({num_heads, nsplits, total_q, head_size}, opts.dtype(at::kFloat));

Expand Down Expand Up @@ -450,4 +453,4 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
}

return { dq, dk, dv, softmax_d };
}
}
30 changes: 27 additions & 3 deletions csrc/flash_attn_ck/mha_varlen_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

#include "flash_common.hpp"
#include "mha_fwd_head_grouping_utils.hpp"

#include "fmha_fwd.hpp"
#include "mask.hpp"

#include <optional>
#include <string>

fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
std::string dtype,
int head_size,
Expand Down Expand Up @@ -141,6 +143,8 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
d, // hdim_v
h, // nhead
h_k, // nhead_k
0, // num_head_q_total
0, // head_start
softmax_scale, // scale_s
0.0f, // logits_soft_cap
stride_q,
Expand Down Expand Up @@ -572,7 +576,27 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
p_dropout,
drop_seed_offset);

float t = fmha_fwd(traits, args, stream_config);
float t =
flash::maybe_dispatch_head_grouped_fwd(
stream_config,
traits,
args,
num_heads,
num_heads_k,
batch_size,
max_seqlen_k,
head_size,
head_size,
k.element_size(),
v.element_size(),
q.scalar_type(),
[&](const auto& grouped_traits, auto& grouped_args, const auto& grouped_sc) {
return fmha_fwd(grouped_traits, grouped_args, grouped_sc);
});

if (t < 0.0f) {
t = fmha_fwd(traits, args, stream_config);
}
TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
}
}
Expand Down
Loading