diff --git a/.gitmodules b/.gitmodules
index e37b611b640..7fd95283aec 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,10 +1,6 @@
 [submodule "csrc/cutlass"]
     path = csrc/cutlass
    url = https://github.com/NVIDIA/cutlass.git
-[submodule "csrc/composable_kernel"]
-    path = csrc/composable_kernel
-    url = https://github.com/ROCm/composable_kernel.git
-    branch = amd-master
 [submodule "third_party/aiter"]
     path = third_party/aiter
     url = https://github.com/ROCm/aiter.git
diff --git a/csrc/composable_kernel b/csrc/composable_kernel
deleted file mode 160000
index 13f6d635653..00000000000
--- a/csrc/composable_kernel
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 13f6d635653bd5ffbfcac8577f1ef09590c23d78
diff --git a/csrc/flash_attn_ck/ck_build_config.hpp b/csrc/flash_attn_ck/ck_build_config.hpp
new file mode 100644
index 00000000000..e16318bd104
--- /dev/null
+++ b/csrc/flash_attn_ck/ck_build_config.hpp
@@ -0,0 +1,5 @@
+// Generated by setup.py. Do not edit manually.
+#pragma once
+
+#define FLASHATTN_CK_GIT_COMMIT "unknown"
+#define FLASHATTN_CK_USE_CURRENT_API 0
diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp
index 083494f5b0c..231f83c7401 100644
--- a/csrc/flash_attn_ck/mha_bwd.cpp
+++ b/csrc/flash_attn_ck/mha_bwd.cpp
@@ -4,8 +4,10 @@
 
 #include "flash_common.hpp"
 
+#include "ck_build_config.hpp"
 #include "fmha_bwd.hpp"
 #include "mask.hpp"
+#include <variant>
 
 fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
                                        std::string dtype,
@@ -14,16 +16,38 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
                                        bool enable_alibi,
                                        bool deterministic)
 {
+#if FLASHATTN_CK_USE_CURRENT_API
+    return fmha_bwd_traits{
+        .seqlen_q = -1,
+        .seqlen_k = -1,
+        .batch = -1,
+        .max_seqlen_q = -1,
+        .max_seqlen_k = -1,
+        .hdim_q = head_size,
+        .hdim_v = head_size,
+        .nhead_q = -1,
+        .nhead_k = -1,
+        .data_type = std::move(dtype),
+        .is_group_mode = false,
+        .mask_type = mask.type,
+        .bias_type = enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
+        .has_dbias = false,
+        .has_dropout = has_dropout,
+        .is_store_randval = false,
+        .is_deterministic = deterministic,
+    };
+#else
     return fmha_bwd_traits{head_size,
                            head_size,
                            dtype,
-                           false, // is_group_mode
+                           false,
                            mask.type,
                            enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
-                           false, // has_dbias
+                           false,
                            has_dropout,
-                           false, // s_randval
+                           false,
                            deterministic};
+#endif
 }
 
 fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
@@ -98,11 +122,11 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
     ck_tile::index_t stride_dv = dv.stride(1);
     ck_tile::index_t nhead_stride_dv = dv.stride(2);
 
-    // dq_acc: (split, batch_size, seqlen_q, nheads, hdim)
-    ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
-    ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1);
-    ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
-    ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3);
+    // dq_acc: (batch_size, nheads, split, seqlen_q, hdim)
+    ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(0);
+    ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(1);
+    ck_tile::index_t split_stride_dq_acc = dq_acc.stride(2);
+    ck_tile::index_t stride_dq_acc = dq_acc.stride(3);
 
     float p_undrop = 1.0 - p_dropout;
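To sanity-check the stride bookkeeping in the hunk above: a minimal PyTorch sketch (illustrative shapes, not part of the patch) showing that `stride(0)`..`stride(3)` of the new `(batch_size, nheads, split, seqlen_q, hdim)` allocation are exactly the batch, head, split, and row strides read out above.

```python
# Illustrative check of the new dq_acc layout; shapes are made up.
import torch

b, h, nsplits, sq, hd = 2, 4, 3, 128, 64
dq_acc = torch.zeros(b, h, nsplits, sq, hd)

assert dq_acc.stride(0) == h * nsplits * sq * hd  # batch_stride_dq_acc
assert dq_acc.stride(1) == nsplits * sq * hd      # nhead_stride_dq_acc
assert dq_acc.stride(2) == sq * hd                # split_stride_dq_acc
assert dq_acc.stride(3) == hd                     # stride_dq_acc (per-row)
```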
@@ -119,74 +143,159 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
         stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     }
 
+#if FLASHATTN_CK_USE_CURRENT_API
+    auto drop_seed_var = std::variant<std::pair<uint64_t, uint64_t>,
+                                      std::pair<const void*, const void*>>{
+        std::pair{0, 0}};
+    if (drop_seed_offset.first != nullptr && drop_seed_offset.second != nullptr) {
+        drop_seed_var = std::pair{drop_seed_offset.first, drop_seed_offset.second};
+    }
+    return fmha_bwd_args{
+        .q_ptr = q.data_ptr(),
+        .k_ptr = k.data_ptr(),
+        .v_ptr = v.data_ptr(),
+        .bias_ptr = alibi_slopes_ptr,
+        .o_ptr = out.data_ptr(),
+        .lse_ptr = softmax_lse.data_ptr(),
+        .do_ptr = dout.data_ptr(),
+        .d_ptr = d.data_ptr(),
+        .rand_val_ptr = nullptr,
+        .dq_ptr = dq.data_ptr(),
+        .dk_ptr = dk.data_ptr(),
+        .dv_ptr = dv.data_ptr(),
+        .dbias_ptr = nullptr,
+        .dq_acc_ptr = dq_acc.data_ptr(),
+        .seqstart_q_ptr = nullptr,
+        .seqstart_k_ptr = nullptr,
+        .seqlen_q_ptr = nullptr,
+        .seqlen_k_ptr = nullptr,
+        .cu_seqlen_q_ptr = nullptr,
+        .cu_seqlen_k_ptr = nullptr,
+        .seqlen_q = seqlen_q,
+        .seqlen_k = seqlen_k,
+        .batch = b,
+        .max_seqlen_q = seqlen_q,
+        .max_seqlen_k = seqlen_k,
+        .hdim_q = hdim,
+        .hdim_v = hdim,
+        .nhead_q = h,
+        .nhead_k = h_k,
+        .scale = softmax_scale,
+        .stride_q = stride_q,
+        .stride_k = stride_k,
+        .stride_v = stride_v,
+        .stride_bias = stride_alibi_slopes,
+        .stride_o = stride_o,
+        .stride_randval = 0,
+        .stride_do = stride_do,
+        .stride_dq_acc = stride_dq_acc,
+        .stride_dq = stride_dq,
+        .stride_dk = stride_dk,
+        .stride_dv = stride_dv,
+        .stride_dbias = 0,
+        .nhead_stride_q = nhead_stride_q,
+        .nhead_stride_k = nhead_stride_k,
+        .nhead_stride_v = nhead_stride_v,
+        .nhead_stride_bias = 0,
+        .nhead_stride_o = nhead_stride_o,
+        .nhead_stride_randval = 0,
+        .nhead_stride_do = nhead_stride_do,
+        .nhead_stride_lsed = nhead_stride_lse,
+        .nhead_stride_dq_acc = nhead_stride_dq_acc,
+        .nhead_stride_dq = nhead_stride_dq,
+        .nhead_stride_dk = nhead_stride_dk,
+        .nhead_stride_dv = nhead_stride_dv,
+        .nhead_stride_dbias = 0,
+        .batch_stride_q = batch_stride_q,
+        .batch_stride_k = batch_stride_k,
+        .batch_stride_v = batch_stride_v,
+        .batch_stride_bias = 0,
+        .batch_stride_o = batch_stride_o,
+        .batch_stride_randval = 0,
+        .batch_stride_do = batch_stride_do,
+        .batch_stride_lsed = batch_stride_lse,
+        .batch_stride_dq_acc = batch_stride_dq_acc,
+        .batch_stride_dq = batch_stride_dq,
+        .batch_stride_dk = batch_stride_dk,
+        .batch_stride_dv = batch_stride_dv,
+        .batch_stride_dbias = 0,
+        .split_stride_dq_acc = split_stride_dq_acc,
+        .window_size_left = mask.left,
+        .window_size_right = mask.right,
+        .mask_type = static_cast<ck_tile::index_t>(mask.type),
+        .p_drop = p_dropout,
+        .p_undrop = p_undrop,
+        .drop_seed_offset = std::move(drop_seed_var),
+    };
+#else
     return fmha_bwd_args{q.data_ptr(),
                          k.data_ptr(),
                          v.data_ptr(),
-                         alibi_slopes_ptr, // bias
+                         alibi_slopes_ptr,
                          out.data_ptr(),
                          softmax_lse.data_ptr(),
                          dout.data_ptr(),
                          d.data_ptr(),
-                         nullptr, // rand_val
+                         nullptr,
                          dq.data_ptr(),
                          dk.data_ptr(),
                          dv.data_ptr(),
-                         nullptr, // dbias
-                         dq_acc.data_ptr(), // dq_acc
-                         nullptr, // seqstart_q_ptr
-                         nullptr, // seqstart_k_ptr
-                         nullptr, // seqlen_q_ptr
-                         nullptr, // seqlen_k_ptr
-                         nullptr, // cu_seqlen_q_ptr
-                         nullptr, // cu_seqlen_k_ptr
+                         nullptr,
+                         dq_acc.data_ptr(),
+                         nullptr,
+                         nullptr,
+                         nullptr,
+                         nullptr,
+                         nullptr,
+                         nullptr,
                          seqlen_q,
                          seqlen_k,
                          b,
-                         seqlen_q, // max_seqlen_q
-                         seqlen_k, // max_seqlen_k
-                         hdim, // hdim_q
-                         hdim, // hdim_v
-                         h, // nhead
-                         h_k, // nhead_k
+                         seqlen_q,
+                         seqlen_k,
+                         hdim,
+                         hdim,
+                         h,
+                         h_k,
                          softmax_scale,
                          stride_q,
                          stride_k,
                          stride_v,
                          stride_alibi_slopes,
                          stride_o,
-                         0, // stride_randval
+                         0,
                          stride_do,
                          stride_dq_acc,
                          stride_dq,
                          stride_dk,
                          stride_dv,
-                         0, // stride_dbias, FA without bias
+                         0,
                          nhead_stride_q,
                          nhead_stride_k,
                          nhead_stride_v,
-                         0, // nhead_stride_bias, FA without bias
+                         0,
                          nhead_stride_o,
-                         0, // nhead_stride_randval
+                         0,
                          nhead_stride_do,
                          nhead_stride_lse,
                          nhead_stride_dq_acc,
                          nhead_stride_dq,
                          nhead_stride_dk,
                          nhead_stride_dv,
-                         0, // nhead_stride_dbias, FA without dbias
+                         0,
                          batch_stride_q,
                          batch_stride_k,
                          batch_stride_v,
-                         0 , // batch_stride_bias, FA without bias
+                         0,
                          batch_stride_o,
-                         0, // batch_stride_randval
+                         0,
                          batch_stride_do,
                          batch_stride_lse,
                          batch_stride_dq_acc,
                          batch_stride_dq,
                          batch_stride_dk,
                          batch_stride_dv,
-                         0 , // batch_stride_dbias, FA without dbias
+                         0,
                          split_stride_dq_acc,
                          mask.left,
                          mask.right,
@@ -194,6 +303,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                          p_dropout,
                          p_undrop,
                          drop_seed_offset};
+#endif
 }
 
 std::vector<at::Tensor>
@@ -323,11 +433,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
 
     at::Tensor dq_accum;
     if (!deterministic) {
-        dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
+        dq_accum = torch::zeros({batch_size, num_heads, 1, seqlen_q, head_size}, opts.dtype(at::kFloat));
     } else {
         const ck_tile::index_t kN0 = head_size <= 128 ? 128 : 64;
         const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
-        dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
+        dq_accum = torch::zeros({batch_size, num_heads, nsplits, seqlen_q, head_size}, opts.dtype(at::kFloat));
     }
 
     at::Tensor dk_expanded, dv_expanded;
@@ -358,8 +468,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
     }
 
     if (seqlen_q > 0) {
-        auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
-        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
+        std::pair<uint64_t*, uint64_t*> drop_seed_offset = {nullptr, nullptr};
+        if (is_dropout) {
+            auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+            drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
+        }
         ck_tile::stream_config stream_config{stream};
 
         auto traits =
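For context on the guard above: the seed/offset pair is read from the same two-element `rng_state` tensor the forward pass fills. A minimal PyTorch sketch (illustrative, assuming the usual 2-element int64 `rng_state`) of the pointer arithmetic the C++ performs with `rng_state_ptr` and `rng_state_ptr + 1`:

```python
# rng_state holds {philox seed, philox offset}; the backward pass now hands
# real pointers to CK only when dropout is enabled, else {nullptr, nullptr}.
import torch

rng_state = torch.zeros(2, dtype=torch.int64)  # device="cuda" in practice
seed_addr = rng_state.data_ptr()                               # &rng_state[0]
offset_addr = rng_state.data_ptr() + rng_state.element_size()  # &rng_state[1]

is_dropout = False
drop_seed_offset = (seed_addr, offset_addr) if is_dropout else (None, None)
```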
diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp
index 0229e777cd5..bc59a0e591f 100644
--- a/csrc/flash_attn_ck/mha_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_fwd.cpp
@@ -4,8 +4,10 @@
 
 #include "flash_common.hpp"
 
+#include "ck_build_config.hpp"
 #include "fmha_fwd.hpp"
 #include "mask.hpp"
+#include <variant>
 
 fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
                                        std::string dtype,
@@ -91,32 +93,107 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
         stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     }
 
+#if FLASHATTN_CK_USE_CURRENT_API
+    std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>> drop_seed_var;
+    if (drop_seed_offset.first && drop_seed_offset.second) {
+        drop_seed_var = std::pair{drop_seed_offset.first, drop_seed_offset.second};
+    } else {
+        drop_seed_var = std::pair{0, 0};
+    }
+
+    return fmha_fwd_args{
+        .q_ptr = q.data_ptr(),
+        .k_ptr = k.data_ptr(),
+        .v_ptr = v.data_ptr(),
+        .bias_ptr = alibi_slopes_ptr,
+        .q_descale_ptr = nullptr,
+        .k_descale_ptr = nullptr,
+        .v_descale_ptr = nullptr,
+        .rand_val_ptr = has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
+        .lse_ptr = has_lse ? softmax_lse.data_ptr() : nullptr,
+        .o_ptr = out.data_ptr(),
+        .seqstart_q_ptr = nullptr,
+        .seqstart_k_ptr = nullptr,
+        .seqlen_q_ptr = nullptr,
+        .seqlen_k_ptr = nullptr,
+        .cu_seqlen_q_ptr = nullptr,
+        .cu_seqlen_k_ptr = nullptr,
+        .block_scale_seqstart_q_ptr = nullptr,
+        .block_scale_seqstart_k_ptr = nullptr,
+        .sink_ptr = nullptr,
+        .seqlen_q = seqlen_q,
+        .seqlen_k = seqlen_k,
+        .batch = b,
+        .max_seqlen_q = seqlen_q,
+        .hdim_q = d,
+        .hdim_v = d,
+        .nhead_q = h,
+        .nhead_k = h_k,
+        .scale_s = softmax_scale,
+        .logits_soft_cap = 0.0f,
+        .stride_q = stride_q,
+        .stride_k = stride_k,
+        .stride_v = stride_v,
+        .stride_bias = stride_alibi_slopes,
+        .stride_randval = stride_randval,
+        .stride_o = stride_o,
+        .nhead_stride_q = nhead_stride_q,
+        .nhead_stride_k = nhead_stride_k,
+        .nhead_stride_v = nhead_stride_v,
+        .nhead_stride_bias = 0,
+        .nhead_stride_randval = nhead_stride_randval,
+        .nhead_stride_lse = nhead_stride_lse,
+        .nhead_stride_o = nhead_stride_o,
+        .nhead_stride_q_descale = 0,
+        .nhead_stride_k_descale = 0,
+        .nhead_stride_v_descale = 0,
+        .batch_stride_q = batch_stride_q,
+        .batch_stride_k = batch_stride_k,
+        .batch_stride_v = batch_stride_v,
+        .batch_stride_bias = 0,
+        .batch_stride_randval = batch_stride_randval,
+        .batch_stride_lse = batch_stride_lse,
+        .batch_stride_o = batch_stride_o,
+        .batch_stride_q_descale = 0,
+        .batch_stride_k_descale = 0,
+        .batch_stride_v_descale = 0,
+        .window_size_left = mask.left,
+        .window_size_right = mask.right,
+        .sink_size = 0,
+        .mask_type = static_cast<ck_tile::index_t>(mask.type),
+        .min_seqlen_q = 0,
+        .p_drop = p_dropout,
+        .s_randval = has_dropout_randval,
+        .drop_seed_offset = drop_seed_var,
+        .block_scale_size_q = 0,
+        .block_scale_size_kv = 0};
+#else
     return fmha_fwd_args{q.data_ptr(),
                          k.data_ptr(),
                          v.data_ptr(),
-                         alibi_slopes_ptr, // bias
-                         nullptr, // q_descale_ptr
-                         nullptr, // k_descale_ptr
-                         nullptr, // v_descale_ptr
+                         alibi_slopes_ptr,
+                         nullptr,
+                         nullptr,
+                         nullptr,
                          has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
                          has_lse ? softmax_lse.data_ptr() : nullptr,
                          out.data_ptr(),
-                         nullptr, // seqstart_q_ptr
-                         nullptr, // seqstart_k_ptr
-                         nullptr, // seqlen_q_ptr
-                         nullptr, // seqlen_k_ptr
-                         nullptr, // cu_seqlen_q_ptr
-                         nullptr, // cu_seqlen_k_ptr
+                         nullptr,
+                         nullptr,
+                         nullptr,
+                         nullptr,
+                         nullptr,
+                         nullptr,
                          seqlen_q,
                          seqlen_k,
                          b,
-                         seqlen_q, // max_seqlen_q
-                         d, // hdim_q
-                         d, // hdim_v
-                         h, // nhead
-                         h_k, // nhead_k
-                         softmax_scale, // scale_s
-                         0.0f, // logits_soft_cap
+                         seqlen_q,
+                         d,
+                         d,
+                         h,
+                         h_k,
+                         softmax_scale,
+                         0.0f,
                          stride_q,
                          stride_k,
                          stride_v,
@@ -126,24 +203,25 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                          nhead_stride_q,
                          nhead_stride_k,
                          nhead_stride_v,
-                         0, // nhead_stride_bias, FA without bias
+                         0,
                          nhead_stride_randval,
                          nhead_stride_lse,
                          nhead_stride_o,
                          batch_stride_q,
                          batch_stride_k,
                          batch_stride_v,
-                         0, // batch_stride_bias, FA without bias
+                         0,
                          batch_stride_randval,
                          batch_stride_lse,
                          batch_stride_o,
                          mask.left,
                          mask.right,
                          static_cast<ck_tile::index_t>(mask.type),
-                         0, // min_seqlen_q
+                         0,
                          p_dropout,
                          has_dropout_randval,
                          drop_seed_offset};
+#endif
 }
 
 std::vector<at::Tensor>
@@ -275,7 +353,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
     }
 
     if (seqlen_k > 0) {
-        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
+        auto drop_seed_offset = p_dropout > 0.0
+            ? std::make_pair(rng_state_ptr, rng_state_ptr + 1)
+            : std::make_pair<uint64_t*, uint64_t*>(nullptr, nullptr);
 #ifdef HIPIFY_V2
         auto stream = at::cuda::getCurrentCUDAStream().stream();
 #else
diff --git a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp
index 27866f1902e..07793dc9934 100644
--- a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp
+++ b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp
@@ -73,7 +73,7 @@ fmha_fwd_appendkv_args get_ck_fmha_fwd_appendkv_args(const int b,
     // rotary_sin: (seqlen_ro, rotary_dim / 2)
     // block_table: (batch_size, max_num_blocks_per_seq)
 
-    fmha_fwd_appendkv_args args;
+    fmha_fwd_appendkv_args args{};
     args.q_ptr = q.data_ptr();
     args.k_ptr = kcache.data_ptr();
     args.knew_ptr = knew.data_ptr();
@@ -168,7 +168,7 @@ fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse,
     // lse_acc: (split, batch_size, nheads, seqlen_q)
     // o_acc: (split, batch_size, nheads, seqlen_q, d)
 
-    fmha_fwd_splitkv_args args;
+    fmha_fwd_splitkv_args args{};
     args.q_ptr = q.data_ptr();
     args.k_ptr = k.data_ptr();
     args.v_ptr = v.data_ptr();
@@ -197,6 +197,7 @@ fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse,
     args.seqstart_q_ptr = nullptr;
     args.seqstart_k_ptr = nullptr;
     args.seqlen_k_ptr = seqlens_k.data_ptr();
+    args.sink_ptr = nullptr;
 
     args.seqlen_q = seqlen_q;
     args.seqlen_k = seqlen_k;
@@ -211,6 +212,7 @@ fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse,
     args.scale_s = softmax_scale;
     args.scale_p = 1;
     args.scale_o = 1;
+    args.logits_soft_cap = 0.0f;
 
     args.batch_stride_q = q.stride(0);
     args.stride_q = q.stride(1);
@@ -261,6 +263,7 @@ fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse,
 
     args.window_size_left = mask.left;
     args.window_size_right = mask.right;
+    args.sink_size = 0;
     args.mask_type = static_cast<ck_tile::index_t>(mask.type);
 
     return args;
diff --git a/csrc/flash_attn_ck/mha_varlen_bwd.cpp b/csrc/flash_attn_ck/mha_varlen_bwd.cpp
index 3cd01c32d48..485e62f931b 100644
--- a/csrc/flash_attn_ck/mha_varlen_bwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_bwd.cpp
@@ -4,8 +4,10 @@
 
 #include "flash_common.hpp"
 
+#include "ck_build_config.hpp"
 #include "fmha_bwd.hpp"
 #include "mask.hpp"
+#include <variant>
 
 fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
                                               std::string dtype,
@@ -14,16 +16,38 @@ fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
                                               bool enable_alibi,
                                               bool deterministic)
 {
+#if FLASHATTN_CK_USE_CURRENT_API
+    return fmha_bwd_traits{
+        .seqlen_q = -1,
+        .seqlen_k = -1,
+        .batch = -1,
+        .max_seqlen_q = -1,
+        .max_seqlen_k = -1,
+        .hdim_q = head_size,
+        .hdim_v = head_size,
+        .nhead_q = -1,
+        .nhead_k = -1,
+        .data_type = std::move(dtype),
+        .is_group_mode = true,
+        .mask_type = mask.type,
+        .bias_type = enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
+        .has_dbias = false,
+        .has_dropout = has_dropout,
+        .is_store_randval = false,
+        .is_deterministic = deterministic,
+    };
+#else
     return fmha_bwd_traits{head_size,
                           head_size,
                           dtype,
-                          true, // is_group_mode
+                          true,
                           mask.type,
                          enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
-                          false, // has_dbias
+                          false,
                           has_dropout,
-                          false, // s_randval
+                          false,
                           deterministic};
+#endif
 }
 
 fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
@@ -104,11 +128,11 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
     ck_tile::index_t stride_dv = dv.stride(0);
     ck_tile::index_t nhead_stride_dv = dv.stride(1);
 
-    // dq_acc: (split, total_q, nheads, hdim)
-    ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
+    // dq_acc: (nheads, split, total_q, hdim)
+    ck_tile::index_t split_stride_dq_acc = dq_acc.stride(1);
     ck_tile::index_t batch_stride_dq_acc = 0;
-    ck_tile::index_t stride_dq_acc = dq_acc.stride(1);
-    ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2);
+    ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
+    ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(0);
 
     float p_undrop = 1.0 - p_dropout;
@@ -125,74 +149,159 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
         stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     }
 
+#if FLASHATTN_CK_USE_CURRENT_API
+    auto drop_seed_var = std::variant<std::pair<uint64_t, uint64_t>,
+                                      std::pair<const void*, const void*>>{
+        std::pair{0, 0}};
+    if (drop_seed_offset.first != nullptr && drop_seed_offset.second != nullptr) {
+        drop_seed_var = std::pair{drop_seed_offset.first, drop_seed_offset.second};
+    }
+    return fmha_bwd_args{
+        .q_ptr = q.data_ptr(),
+        .k_ptr = k.data_ptr(),
+        .v_ptr = v.data_ptr(),
+        .bias_ptr = alibi_slopes_ptr,
+        .o_ptr = out.data_ptr(),
+        .lse_ptr = softmax_lse.data_ptr(),
+        .do_ptr = dout.data_ptr(),
+        .d_ptr = d.data_ptr(),
+        .rand_val_ptr = nullptr,
+        .dq_ptr = dq.data_ptr(),
+        .dk_ptr = dk.data_ptr(),
+        .dv_ptr = dv.data_ptr(),
+        .dbias_ptr = nullptr,
+        .dq_acc_ptr = dq_acc.data_ptr(),
+        .seqstart_q_ptr = seqlens_q.data_ptr(),
+        .seqstart_k_ptr = seqlens_k.data_ptr(),
+        .seqlen_q_ptr = nullptr,
+        .seqlen_k_ptr = nullptr,
+        .cu_seqlen_q_ptr = nullptr,
+        .cu_seqlen_k_ptr = nullptr,
+        .seqlen_q = total_q,
+        .seqlen_k = total_k,
+        .batch = b,
+        .max_seqlen_q = max_seqlen_q,
+        .max_seqlen_k = max_seqlen_k,
+        .hdim_q = hdim,
+        .hdim_v = hdim,
+        .nhead_q = h,
+        .nhead_k = h_k,
+        .scale = softmax_scale,
+        .stride_q = stride_q,
+        .stride_k = stride_k,
+        .stride_v = stride_v,
+        .stride_bias = stride_alibi_slopes,
+        .stride_o = stride_o,
+        .stride_randval = 0,
+        .stride_do = stride_do,
+        .stride_dq_acc = stride_dq_acc,
+        .stride_dq = stride_dq,
+        .stride_dk = stride_dk,
+        .stride_dv = stride_dv,
+        .stride_dbias = 0,
+        .nhead_stride_q = nhead_stride_q,
+        .nhead_stride_k = nhead_stride_k,
+        .nhead_stride_v = nhead_stride_v,
+        .nhead_stride_bias = 0,
+        .nhead_stride_o = nhead_stride_o,
+        .nhead_stride_randval = 0,
+        .nhead_stride_do = nhead_stride_do,
+        .nhead_stride_lsed = nhead_stride_lse,
+        .nhead_stride_dq_acc = nhead_stride_dq_acc,
+        .nhead_stride_dq = nhead_stride_dq,
+        .nhead_stride_dk = nhead_stride_dk,
+        .nhead_stride_dv = nhead_stride_dv,
+        .nhead_stride_dbias = 0,
+        .batch_stride_q = batch_stride_q,
+        .batch_stride_k = batch_stride_k,
+        .batch_stride_v = batch_stride_v,
+        .batch_stride_bias = 0,
+        .batch_stride_o = batch_stride_o,
+        .batch_stride_randval = 0,
+        .batch_stride_do = batch_stride_do,
+        .batch_stride_lsed = batch_stride_lse,
+        .batch_stride_dq_acc = batch_stride_dq_acc,
+        .batch_stride_dq = batch_stride_dq,
+        .batch_stride_dk = batch_stride_dk,
+        .batch_stride_dv = batch_stride_dv,
+        .batch_stride_dbias = 0,
+        .split_stride_dq_acc = split_stride_dq_acc,
+        .window_size_left = mask.left,
+        .window_size_right = mask.right,
+        .mask_type = static_cast<ck_tile::index_t>(mask.type),
+        .p_drop = p_dropout,
+        .p_undrop = p_undrop,
+        .drop_seed_offset = std::move(drop_seed_var),
+    };
+#else
     return fmha_bwd_args{q.data_ptr(),
                          k.data_ptr(),
                          v.data_ptr(),
-                         alibi_slopes_ptr, // bias
+                         alibi_slopes_ptr,
                          out.data_ptr(),
                          softmax_lse.data_ptr(),
                          dout.data_ptr(),
                          d.data_ptr(),
-                         nullptr, // rand_val
+                         nullptr,
                          dq.data_ptr(),
                          dk.data_ptr(),
                          dv.data_ptr(),
-                         nullptr, // dbias
-                         dq_acc.data_ptr(), // dq_acc
-                         seqlens_q.data_ptr(), // seqstart_q_ptr
-                         seqlens_k.data_ptr(), // seqstart_k_ptr
-                         nullptr, // seqlen_q_ptr
-                         nullptr, // seqlen_k_ptr
-                         nullptr, // cu_seqlen_q_ptr
-                         nullptr, // cu_seqlen_k_ptr
+                         nullptr,
+                         dq_acc.data_ptr(),
+                         seqlens_q.data_ptr(),
+                         seqlens_k.data_ptr(),
+                         nullptr,
+                         nullptr,
+                         nullptr,
+                         nullptr,
                          total_q,
                          total_k,
                          b,
-                         max_seqlen_q, // max_seqlen_q
-                         max_seqlen_k, // max_seqlen_k
-                         hdim, // hdim_q
-                         hdim, // hdim_v
-                         h, // nhead
-                         h_k, // nhead_k
+                         max_seqlen_q,
+                         max_seqlen_k,
+                         hdim,
+                         hdim,
+                         h,
+                         h_k,
                          softmax_scale,
                          stride_q,
                          stride_k,
                          stride_v,
                          stride_alibi_slopes,
                          stride_o,
-                         0, // stride_randval
+                         0,
                          stride_do,
                          stride_dq_acc,
                          stride_dq,
                          stride_dk,
                          stride_dv,
-                         0, // stride_dbias, FA without bias
+                         0,
                          nhead_stride_q,
                          nhead_stride_k,
                          nhead_stride_v,
-                         0, // nhead_stride_bias, FA without bias
+                         0,
                          nhead_stride_o,
-                         0, // nhead_stride_randval
+                         0,
                          nhead_stride_do,
                          nhead_stride_lse,
                          nhead_stride_dq_acc,
                          nhead_stride_dq,
                          nhead_stride_dk,
                          nhead_stride_dv,
-                         0, // nhead_stride_dbias, FA without dbias
+                         0,
                          batch_stride_q,
                          batch_stride_k,
                          batch_stride_v,
-                         0 , // batch_stride_bias, FA without bias
+                         0,
                          batch_stride_o,
-                         0, // batch_stride_randval
+                         0,
                          batch_stride_do,
                          batch_stride_lse,
                          batch_stride_dq_acc,
                          batch_stride_dq,
                          batch_stride_dk,
                          batch_stride_dv,
-                         0 , // batch_stride_dbias, FA without dbias
+                         0,
                          split_stride_dq_acc,
                          mask.left,
                          mask.right,
@@ -200,6 +309,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                          p_dropout,
                          p_undrop,
                          drop_seed_offset};
+#endif
 }
 
 std::vector<at::Tensor>
@@ -337,11 +447,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
 
     at::Tensor dq_accum;
     if (!deterministic) {
-        dq_accum = torch::zeros({1, total_q, num_heads, head_size}, opts.dtype(at::kFloat));
+        dq_accum = torch::zeros({num_heads, 1, total_q, head_size}, opts.dtype(at::kFloat));
     } else {
         const ck_tile::index_t kN0 = head_size <= 128 ? 128 : 64;
         const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(max_seqlen_k, kN0);
-        dq_accum = torch::zeros({nsplits, total_q, num_heads, head_size}, opts.dtype(at::kFloat));
+        dq_accum = torch::zeros({num_heads, nsplits, total_q, head_size}, opts.dtype(at::kFloat));
     }
 
     at::Tensor dk_expanded, dv_expanded;
@@ -381,8 +491,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
     }
 
     if (max_seqlen_q > 0) {
-        auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
-        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
+        std::pair<uint64_t*, uint64_t*> drop_seed_offset = {nullptr, nullptr};
+        if (is_dropout) {
+            auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+            drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
+        }
         ck_tile::stream_config stream_config{stream};
 
         auto traits =
@@ -431,4 +544,4 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
     }
 
     return { dq, dk, dv, softmax_d };
-}
\ No newline at end of file
+}
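The varlen variant gets the same reordering minus the batch dimension (ragged sequences share one `total_q` axis, hence `batch_stride_dq_acc = 0` above); a quick illustrative check, again not part of the patch:

```python
# Illustrative check of the varlen (nheads, split, total_q, hdim) layout.
import torch

h, nsplits, total_q, hd = 4, 3, 300, 64
dq_acc = torch.zeros(h, nsplits, total_q, hd)

assert dq_acc.stride(0) == nsplits * total_q * hd  # nhead_stride_dq_acc
assert dq_acc.stride(1) == total_q * hd            # split_stride_dq_acc
assert dq_acc.stride(2) == hd                      # stride_dq_acc
```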
diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
index 00b0fcd5738..1596c2ebade 100644
--- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -4,8 +4,10 @@
 
 #include "flash_common.hpp"
 
+#include "ck_build_config.hpp"
 #include "fmha_fwd.hpp"
 #include "mask.hpp"
+#include <variant>
 
 fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
                                               std::string dtype,
@@ -112,32 +114,107 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
         stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     }
 
+#if FLASHATTN_CK_USE_CURRENT_API
+    std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>> drop_seed_var;
+    if (drop_seed_offset.first && drop_seed_offset.second) {
+        drop_seed_var = std::pair{drop_seed_offset.first, drop_seed_offset.second};
+    } else {
+        drop_seed_var = std::pair{0, 0};
+    }
+
+    return fmha_fwd_args{
+        .q_ptr = q.data_ptr(),
+        .k_ptr = k.data_ptr(),
+        .v_ptr = v.data_ptr(),
+        .bias_ptr = alibi_slopes_ptr,
+        .q_descale_ptr = nullptr,
+        .k_descale_ptr = nullptr,
+        .v_descale_ptr = nullptr,
+        .rand_val_ptr = has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
+        .lse_ptr = has_lse ? softmax_lse.data_ptr() : nullptr,
+        .o_ptr = out.data_ptr(),
+        .seqstart_q_ptr = seqlens_q.data_ptr(),
+        .seqstart_k_ptr = seqlens_k.data_ptr(),
+        .seqlen_q_ptr = nullptr,
+        .seqlen_k_ptr = nullptr,
+        .cu_seqlen_q_ptr = nullptr,
+        .cu_seqlen_k_ptr = nullptr,
+        .block_scale_seqstart_q_ptr = nullptr,
+        .block_scale_seqstart_k_ptr = nullptr,
+        .sink_ptr = nullptr,
+        .seqlen_q = total_q,
+        .seqlen_k = total_k,
+        .batch = b,
+        .max_seqlen_q = max_seqlen_q,
+        .hdim_q = d,
+        .hdim_v = d,
+        .nhead_q = h,
+        .nhead_k = h_k,
+        .scale_s = softmax_scale,
+        .logits_soft_cap = 0.0f,
+        .stride_q = stride_q,
+        .stride_k = stride_k,
+        .stride_v = stride_v,
+        .stride_bias = stride_alibi_slopes,
+        .stride_randval = stride_randval,
+        .stride_o = stride_o,
+        .nhead_stride_q = nhead_stride_q,
+        .nhead_stride_k = nhead_stride_k,
+        .nhead_stride_v = nhead_stride_v,
+        .nhead_stride_bias = 0,
+        .nhead_stride_randval = nhead_stride_randval,
+        .nhead_stride_lse = nhead_stride_lse,
+        .nhead_stride_o = nhead_stride_o,
+        .nhead_stride_q_descale = 0,
+        .nhead_stride_k_descale = 0,
+        .nhead_stride_v_descale = 0,
+        .batch_stride_q = batch_stride_q,
+        .batch_stride_k = batch_stride_k,
+        .batch_stride_v = batch_stride_v,
+        .batch_stride_bias = 0,
+        .batch_stride_randval = batch_stride_randval,
+        .batch_stride_lse = batch_stride_lse,
+        .batch_stride_o = batch_stride_o,
+        .batch_stride_q_descale = 0,
+        .batch_stride_k_descale = 0,
+        .batch_stride_v_descale = 0,
+        .window_size_left = mask.left,
+        .window_size_right = mask.right,
+        .sink_size = 0,
+        .mask_type = static_cast<ck_tile::index_t>(mask.type),
+        .min_seqlen_q = 0,
+        .p_drop = p_dropout,
+        .s_randval = has_dropout_randval,
+        .drop_seed_offset = drop_seed_var,
+        .block_scale_size_q = 0,
+        .block_scale_size_kv = 0};
+#else
     return fmha_fwd_args{q.data_ptr(),
                          k.data_ptr(),
                          v.data_ptr(),
-                         alibi_slopes_ptr, // bias
-                         nullptr, // q_descale_ptr
-                         nullptr, // k_descale_ptr
-                         nullptr, // v_descale_ptr
+                         alibi_slopes_ptr,
+                         nullptr,
+                         nullptr,
+                         nullptr,
                          has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
                         has_lse ? softmax_lse.data_ptr() : nullptr,
                          out.data_ptr(),
-                         seqlens_q.data_ptr(), // seqstart_q_ptr
-                         seqlens_k.data_ptr(), // seqstart_k_ptr
-                         nullptr, // seqlen_q_ptr
-                         nullptr, // seqlen_k_ptr
-                         nullptr, // cu_seqlen_q_ptr
-                         nullptr, // cu_seqlen_kv_ptr
+                         seqlens_q.data_ptr(),
+                         seqlens_k.data_ptr(),
+                         nullptr,
+                         nullptr,
+                         nullptr,
+                         nullptr,
                          total_q,
                          total_k,
                          b,
                          max_seqlen_q,
-                         d, // hdim_q
-                         d, // hdim_v
-                         h, // nhead
-                         h_k, // nhead_k
-                         softmax_scale, // scale_s
-                         0.0f, // logits_soft_cap
+                         d,
+                         d,
+                         h,
+                         h_k,
+                         softmax_scale,
+                         0.0f,
                          stride_q,
                          stride_k,
                          stride_v,
@@ -147,24 +224,25 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                          nhead_stride_q,
                          nhead_stride_k,
                          nhead_stride_v,
-                         0, // nhead_stride_bias, FA without bias
+                         0,
                          nhead_stride_randval,
                          nhead_stride_lse,
                          nhead_stride_o,
                          batch_stride_q,
                          batch_stride_k,
                          batch_stride_v,
-                         0, // batch_stride_bias, FA without bias
+                         0,
                          batch_stride_randval,
                          batch_stride_lse,
                          batch_stride_o,
                          mask.left,
                          mask.right,
                          static_cast<ck_tile::index_t>(mask.type),
-                         0, // min_seqlen_q
+                         0,
                          p_dropout,
                          has_dropout_randval,
                          drop_seed_offset};
+#endif
 }
 
 fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
@@ -201,7 +279,7 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
     // o_acc: (nheads, split, total_q, d)
     // block_table: (batch_size, max_num_blocks_per_seq)
 
-    fmha_fwd_splitkv_args args;
+    fmha_fwd_splitkv_args args{};
     args.q_ptr = q.data_ptr();
     args.k_ptr = k.data_ptr();
     args.v_ptr = v.data_ptr();
@@ -231,7 +309,10 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
     args.seqstart_q_ptr = seqlens_q.data_ptr();
     args.seqstart_k_ptr = seqlens_k.data_ptr();
     args.seqlen_k_ptr = nullptr;
+    args.sink_ptr = nullptr;
+
+    args.seqlen_q = q.size(0);
+    args.seqlen_k = k.size(0);
     args.batch = b;
     args.max_seqlen_q = max_seqlen_q;
     args.hdim_q = d;
@@ -243,6 +324,7 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
     args.scale_s = softmax_scale;
     args.scale_p = 1;
     args.scale_o = 1;
+    args.logits_soft_cap = 0.0f;
 
     args.batch_stride_q = 0;
     args.stride_q = q.stride(0);
@@ -293,6 +375,7 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
 
     args.window_size_left = mask.left;
     args.window_size_right = mask.right;
+    args.sink_size = 0;
     args.mask_type = static_cast<ck_tile::index_t>(mask.type);
 
     return args;
@@ -519,7 +602,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
     }
     else {
-        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
+        auto drop_seed_offset = p_dropout > 0.0
+            ? std::make_pair(rng_state_ptr, rng_state_ptr + 1)
+            : std::make_pair<uint64_t*, uint64_t*>(nullptr, nullptr);
 
         auto traits =
             get_ck_fmha_varlen_fwd_traits(
diff --git a/setup.py b/setup.py
index 3bf4a904b00..0b077a44cd0 100644
--- a/setup.py
+++ b/setup.py
@@ -67,6 +67,12 @@
 if IS_ROCM:
     ROCM_BACKEND = "triton" if os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" else "ck"
 NVCC_THREADS = os.getenv("NVCC_THREADS") or "4"
+CK_GIT_REF = os.getenv("FLASH_ATTENTION_CK_GIT_REF", "").strip()
+
+CK_API_VARIANT_DROP_SEED_COMMIT = "c24fae234600aa2863e945d072e6f5b3aec2a6b2"
+CK_API_CURRENT_COMMIT = "7c6430eca04e62454217630ae2a0bbd70ff50a00"
+AITER_SUBMODULE = "third_party/aiter"
+AITER_CK_DIR = os.path.join(AITER_SUBMODULE, "3rdparty", "composable_kernel")
 
 @functools.lru_cache(maxsize=None)
 def cuda_archs() -> str:
@@ -197,9 +203,86 @@ def rename_cpp_to_cu(cpp_files):
         shutil.copy(entry, os.path.splitext(entry)[0] + ".cu")
 
 
+def maybe_checkout_ck_git_ref() -> None:
+    if not CK_GIT_REF:
+        return
+    subprocess.run(["git", "-C", AITER_CK_DIR, "checkout", CK_GIT_REF], check=True)
+
+
+def get_ck_git_commit(ck_dir: str):
+    try:
+        return subprocess.check_output(
+            ["git", "-C", ck_dir, "rev-parse", "HEAD"],
+            text=True,
+        ).strip()
+    except Exception:
+        return None
+
+
+def ck_commit_is_at_least(ck_dir: str, current_commit: str, minimum_commit: str) -> bool:
+    if current_commit is None:
+        return False
+    if current_commit == minimum_commit:
+        return True
+    result = subprocess.run(
+        ["git", "-C", ck_dir, "merge-base", "--is-ancestor", minimum_commit, current_commit],
+        check=False,
+    )
+    return result.returncode == 0
+
+
+def detect_ck_current_api_by_headers(ck_dir: str) -> Optional[bool]:
+    fmha_bwd = Path(ck_dir) / "example" / "ck_tile" / "01_fmha" / "fmha_bwd.hpp"
+    fmha_fwd = Path(ck_dir) / "example" / "ck_tile" / "01_fmha" / "fmha_fwd.hpp"
+    if not fmha_bwd.exists() or not fmha_fwd.exists():
+        return None
+
+    bwd_text = fmha_bwd.read_text(encoding="utf-8")
+    fwd_text = fmha_fwd.read_text(encoding="utf-8")
+    has_variant_drop_seed = (
+        "std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>"
+        in bwd_text
+        and "std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>"
+        in fwd_text
+    )
+    has_current_bwd_traits = "struct fmha_bwd_traits" in bwd_text and "int seqlen_q;" in bwd_text
+    return has_variant_drop_seed and has_current_bwd_traits
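A hypothetical driver showing how the three probes above compose (the header check is authoritative, commit ancestry is the fallback); it mirrors `write_ck_build_config` below, minus the transition-range error and the file write:

```python
# Hypothetical usage of the helpers above; prints the detected API flavor.
commit = get_ck_git_commit(AITER_CK_DIR)            # None outside a git repo
by_headers = detect_ck_current_api_by_headers(AITER_CK_DIR)
if by_headers is not None:
    use_current_api = by_headers
else:
    use_current_api = commit is not None and ck_commit_is_at_least(
        AITER_CK_DIR, commit, CK_API_CURRENT_COMMIT
    )
print(f"CK commit={commit or 'unknown'} use_current_api={int(use_current_api)}")
```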
Do not edit manually.", + "#pragma once", + "", + f'#define FLASHATTN_CK_GIT_COMMIT "{ck_commit or "unknown"}"', + f"#define FLASHATTN_CK_USE_CURRENT_API {1 if use_current_api else 0}", + "", + ] + config_path.write_text("\n".join(content), encoding="utf-8") + + def validate_and_update_archs(archs): # List of allowed architectures - allowed_archs = ["native", "gfx90a", "gfx950", "gfx942"] + allowed_archs = ["native", "gfx90a", "gfx950", "gfx942", "gfx1100", "gfx1101", "gfx1102", "gfx1103"] # Validate if each element in archs is in allowed_archs assert all( @@ -213,24 +296,24 @@ def validate_and_update_archs(archs): # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp # files included in the source distribution, in case the user compiles from source. if IS_ROCM: + if os.path.isdir(".git"): + subprocess.run(["git", "submodule", "update", "--init", AITER_SUBMODULE], check=True) + if ROCM_BACKEND == "ck": + maybe_checkout_ck_git_ref() + else: + assert os.path.isdir(AITER_SUBMODULE), ( + f"{AITER_SUBMODULE} is missing, please use source distribution or git clone" + ) + if ROCM_BACKEND == "triton": - if os.path.isdir(".git"): - subprocess.run(["git", "submodule", "update", "--init", "third_party/aiter"], check=True) - else: - assert os.path.isdir("third_party/aiter"), ( - "third_party/aiter is missing, please use source distribution or git clone" - ) subprocess.run( - [sys.executable, "-m", "pip", "install", "--no-build-isolation", "third_party/aiter"], + [sys.executable, "-m", "pip", "install", "--no-build-isolation", AITER_SUBMODULE], check=True, ) elif ROCM_BACKEND == "ck": - if os.path.isdir(".git"): - subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"], check=True) - else: - assert os.path.exists("csrc/composable_kernel/example/ck_tile/01_fmha/generate.py"), ( - "csrc/composable_kernel is missing, please use source distribution or git clone" - ) + assert os.path.exists(os.path.join(AITER_CK_DIR, "example", "ck_tile", "01_fmha", "generate.py")), ( + f"{AITER_CK_DIR} is missing, please initialize {AITER_SUBMODULE}" + ) else: # CUDA: cutlass submodule if os.path.isdir(".git"): @@ -390,17 +473,41 @@ def validate_and_update_archs(archs): # Skips CK C++ extension compilation if using Triton Backend if ROCM_BACKEND == "ck": - ck_dir = "csrc/composable_kernel" + ck_dir = AITER_CK_DIR + write_ck_build_config(ck_dir) + + archs = os.getenv("GPU_ARCHS", "native").split(";") + validate_and_update_archs(archs) + + if archs != ["native"]: + codegen_archs = archs + else: + codegen_archs = [torch.cuda.get_device_properties("cuda").gcnArchName.split(":")[0]] + + def map_to_codegen_target(arch): + if arch.startswith("gfx950"): + return "gfx950" + if arch.startswith("gfx9"): + return "gfx9" + if arch.startswith("gfx11"): + return "gfx11" + return None + + ck_codegen_targets = ",".join( + sorted({t for t in (map_to_codegen_target(a) for a in codegen_archs) if t is not None}) + ) + if not ck_codegen_targets: + raise RuntimeError(f"Unable to derive CK codegen targets from archs={codegen_archs}") #use codegen get code dispatch if not os.path.exists("./build"): os.makedirs("build") optdim = os.getenv("OPT_DIM", "32,64,128,256") - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--output_dir", "build", "--receipt", "2", 
"--optdim", optdim], check=True) - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) - subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--targets", ck_codegen_targets, "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--targets", ck_codegen_targets, "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--targets", ck_codegen_targets, "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--targets", ck_codegen_targets, "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h # See https://github.com/pytorch/pytorch/pull/70650 @@ -410,8 +517,6 @@ def validate_and_update_archs(archs): generator_flag = ["-DOLD_GENERATOR_PATH"] check_if_rocm_home_none("flash_attn") - archs = os.getenv("GPU_ARCHS", "native").split(";") - validate_and_update_archs(archs) if archs != ['native']: cc_flag = [f"--offload-arch={arch}" for arch in archs] @@ -488,9 +593,9 @@ def validate_and_update_archs(archs): } include_dirs = [ - Path(this_dir) / "csrc" / "composable_kernel" / "include", - Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include", - Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha", + Path(this_dir) / AITER_CK_DIR / "include", + Path(this_dir) / AITER_CK_DIR / "library" / "include", + Path(this_dir) / AITER_CK_DIR / "example" / "ck_tile" / "01_fmha", ] ext_modules.append(