diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index bc401538c5..83e5c8d8a1 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -1500,10 +1500,11 @@ def can_impl_fmha_v3_bwd_gfx950(): ret &= dbias is None ret &= dropout_p == 0.0 ret &= not deterministic or is_950_1block - ret &= hdim_q == hdim_v ret &= nhead_q % nhead_k == 0 - ret &= hdim_q > 64 and hdim_q <= 128 and hdim_q % 8 == 0 - + ret &= ( + (hdim_q > 64 and hdim_q <= 128) + or (hdim_q == 192 and hdim_v == 128 and nmask) + ) and hdim_q % 8 == 0 return ret can_impl_fmha_v3_bwd_ |= can_impl_fmha_v3_bwd_gfx950() diff --git a/csrc/include/mha_bwd.h b/csrc/include/mha_bwd.h index aae35fe10b..ec4315a1a7 100644 --- a/csrc/include/mha_bwd.h +++ b/csrc/include/mha_bwd.h @@ -386,7 +386,8 @@ struct fmha_bwd_v3_traits int ts_dq = 64; }; -template struct fmha_bwd_dq_dk_dv_v3_traits_ { - static constexpr ck_tile::index_t HDim = HDim_; + static constexpr ck_tile::index_t HDim_q = HDim_q_; + static constexpr ck_tile::index_t HDim_v = HDim_v_; using DataType = ck_tile::remove_cvref_t; static constexpr int mask_type = mask_type_; static constexpr bool kIsAtomic32 = kIsAtomic32_; diff --git a/csrc/py_itfs_ck/mha_bwd_kernels.cu b/csrc/py_itfs_ck/mha_bwd_kernels.cu index 704188a29b..b0c1420cc0 100644 --- a/csrc/py_itfs_ck/mha_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_bwd_kernels.cu @@ -329,14 +329,14 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v] at::Tensor dq_accum; if (!deterministic) { - dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(at::kFloat)); + dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size_q}, opts.dtype(at::kFloat)); } else { - const ck_tile::index_t kN0 = head_size_v <= 128 ? 128 : 64; + const ck_tile::index_t kN0 = head_size_q <= 128 ? 128 : 64; const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0); if (mask.type == mask_enum::no_mask) - dq_accum = torch::empty({nsplits, batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(at::kFloat)); + dq_accum = torch::empty({nsplits, batch_size, seqlen_q, num_heads, head_size_q}, opts.dtype(at::kFloat)); else // Some block may be skipped with causal mask and dq are not set to zeros - dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(at::kFloat)); + dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_q}, opts.dtype(at::kFloat)); } at::Tensor dk_expanded, dv_expanded; diff --git a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu index 1fd6fb9063..c93c15ddeb 100644 --- a/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu @@ -318,11 +318,11 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] at::Tensor dq_accum; if (!deterministic) { - dq_accum = torch::zeros({1, total_q, num_heads, head_size_v}, opts.dtype(at::kFloat)); + dq_accum = torch::zeros({1, total_q, num_heads, head_size_q}, opts.dtype(at::kFloat)); } else { const ck_tile::index_t kN0 = head_size_q <= 128 ? 128 : 64; const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(max_seqlen_k, kN0); - dq_accum = torch::zeros({nsplits, total_q, num_heads, head_size_v}, opts.dtype(at::kFloat)); + dq_accum = torch::zeros({nsplits, total_q, num_heads, head_size_q}, opts.dtype(at::kFloat)); } at::Tensor dk_expanded, dv_expanded; diff --git a/csrc/py_itfs_cu/asm_mha_bwd.cu b/csrc/py_itfs_cu/asm_mha_bwd.cu index 01efd1a46f..730f49d511 100644 --- a/csrc/py_itfs_cu/asm_mha_bwd.cu +++ b/csrc/py_itfs_cu/asm_mha_bwd.cu @@ -303,16 +303,13 @@ std::vector fmha_v3_bwd(const at::Tensor &dout, // [b, sq, h if (!deterministic) { if (is_v3_atomic_fp32) { - dq_accum = torch::zeros({1, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat)); + dq_accum = torch::zeros({1, batch_size, num_heads, seqlen_q, head_size_q}, opts.dtype(at::kFloat)); } else { - // When atomic16, padding dq_accum seqlen to 16x, head dim to 128 + // When atomic16, padding dq_accum seqlen to 16x, head dim to 128/192 // In this case, dq_accum could have any layout, we set it to be `bhsd` - dq_accum = torch::zeros({1, batch_size, num_heads, (seqlen_q + 15) / 16 * 16, 128}, opts.dtype(q_dtype)); + int padded_head_size_q = head_size_q == 192? 192: 128; + dq_accum = torch::zeros({1, batch_size, num_heads, (seqlen_q + 15) / 16 * 16, padded_head_size_q}, opts.dtype(q_dtype)); } - } else { - const ck_tile::index_t kN0 = head_size_v <= 128 ? 128 : 64; - const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0); - dq_accum = torch::zeros({nsplits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat)); } at::Tensor dk_expanded, dv_expanded; diff --git a/csrc/py_itfs_cu/asm_mha_varlen_bwd.cu b/csrc/py_itfs_cu/asm_mha_varlen_bwd.cu index 81e36d4025..0dad6f0c04 100644 --- a/csrc/py_itfs_cu/asm_mha_varlen_bwd.cu +++ b/csrc/py_itfs_cu/asm_mha_varlen_bwd.cu @@ -94,7 +94,7 @@ fmha_bwd_args get_asm_fmha_varlen_bwd_args(const mask_info &mask, ck_tile::index_t batch_stride_dq_acc; ck_tile::index_t nhead_stride_dq_acc; ck_tile::index_t stride_dq_acc; - // For atomic32, dq_acc layout is (1, num_heads, total_q, head_size_v) + // For atomic32, dq_acc layout is (1, num_heads, total_q, head_size_q) // For atomic16, dq_acc layout is (1, batch_size, num_heads, (max_seqlen_q + 15) / 16 * 16, 128) if (is_v3_atomic_fp32) { split_stride_dq_acc = dq_acc.stride(0); @@ -338,16 +338,12 @@ fmha_v3_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v if (!deterministic) { if (is_v3_atomic_fp32) { - dq_accum = torch::zeros({1, num_heads, total_q, head_size_v}, opts.dtype(at::kFloat)); + dq_accum = torch::zeros({1, num_heads, total_q, head_size_q}, opts.dtype(at::kFloat)); } else { // When atomic16, padding dq_accum seqlen to 16x of max_seqlen_q, head dim to 128 // In this case, dq_accum could have any layout, we set it to be `bhsd` dq_accum = torch::zeros({1, batch_size, num_heads, (max_seqlen_q + 15) / 16 * 16, 128}, opts.dtype(q_dtype)); } - } else { - const ck_tile::index_t kN0 = head_size_q <= 128 ? 128 : 64; - const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(max_seqlen_k, kN0); - dq_accum = torch::zeros({nsplits, num_heads, total_q, head_size_v}, opts.dtype(at::kFloat)); } at::Tensor dk_expanded, dv_expanded; diff --git a/hsa/gfx942/fmha_v3_bwd/codegen.py b/hsa/gfx942/fmha_v3_bwd/codegen.py index e44e4b41fc..3749ce8558 100644 --- a/hsa/gfx942/fmha_v3_bwd/codegen.py +++ b/hsa/gfx942/fmha_v3_bwd/codegen.py @@ -19,371 +19,371 @@ namespace aiter { -// ########################################################|HDim| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode| GPUArch| -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtne"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtna"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtz"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtne"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtna"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtz"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtne"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtna"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtz"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtne"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtna"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtz"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a16"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a32"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a16"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a32"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtne_pddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtna_pddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtz_pddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtne_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtna_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtz_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtne_pddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtna_pddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtz_pddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtne_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtna_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtz_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtne_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtna_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtz_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a16_pddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a32_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a16_pddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a32_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_br_a32_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a16_rtne"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a16_rtna"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a16_rtz"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtne_pssk"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtna_pssk"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtz_pssk"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a16_rtne"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a16_rtna"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a16_rtz"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtne_pssk"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtna_pssk"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtz_pssk"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtne_pssk"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtna_pssk"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtz_pssk"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_a16"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_a32_pssk"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_causal_a16"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_causal_a32_pssk"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_causal_br_a32_pssk"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_a32_rtne_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_a32_rtna_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_a32_rtz_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_a32_rtne_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_a32_rtna_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_a32_rtz_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtne_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtna_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtz_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_fp16_a32_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_fp16_causal_a32_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_fp16_causal_br_a32_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_swa_a32_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_swa_a32_rtne_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_swa_a32_rtna_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_swa_a32_rtz_psskddv"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtne_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtna_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtz_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtne_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtna_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtz_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtne_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtna_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtz_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_a32_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_causal_a32_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_causal_br_a32_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a32_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a32_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_br_a32_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_br_a32_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a32_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a32_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtne_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtna_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtz_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtne_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtna_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtz_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtne_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtna_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtz_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtne_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtna_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtz_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtne_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtna_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtz_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtne_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtna_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtz_pssk_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_a32_rtne_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_a32_rtna_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_a32_rtz_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_a32_rtne_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_a32_rtna_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_a32_rtz_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtne_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtna_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtz_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_fp16_a32_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_fp16_causal_a32_psskddv_group"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_fp16_causal_br_a32_psskddv_group"; }; +// ########################################################|HDim_q|HDim_v| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode| GPUArch| +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtne"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtna"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtz"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtne"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtna"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtz"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtne"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtna"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtz"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtne"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtna"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtz"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a16"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a32"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a16"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a32"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtne_pddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtna_pddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtz_pddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtne_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtna_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtz_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtne_pddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtna_pddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtz_pddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtne_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtna_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtz_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtne_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtna_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtz_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a16_pddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a32_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a16_pddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a32_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_br_a32_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a16_rtne"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a16_rtna"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a16_rtz"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtne_pssk"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtna_pssk"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtz_pssk"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a16_rtne"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a16_rtna"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a16_rtz"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtne_pssk"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtna_pssk"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtz_pssk"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtne_pssk"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtna_pssk"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtz_pssk"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_a16"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_a32_pssk"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_causal_a16"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_causal_a32_pssk"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_causal_br_a32_pssk"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_a32_rtne_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_a32_rtna_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_a32_rtz_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_a32_rtne_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_a32_rtna_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_a32_rtz_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtne_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtna_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtz_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_fp16_a32_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_fp16_causal_a32_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_fp16_causal_br_a32_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_swa_a32_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_swa_a32_rtne_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_swa_a32_rtna_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_swa_a32_rtz_psskddv"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtne_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtna_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtz_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtne_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtna_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtz_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtne_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtna_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtz_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_a32_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_causal_a32_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_causal_br_a32_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a32_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a32_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_br_a32_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_br_a32_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a32_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a32_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtne_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtna_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtz_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtne_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtna_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtz_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtne_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtna_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtz_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtne_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtna_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtz_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtne_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtna_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtz_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtne_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtna_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_br_a32_rtz_pssk_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_a32_rtne_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_a32_rtna_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_a32_rtz_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_a32_rtne_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_a32_rtna_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_a32_rtz_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtne_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtna_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtz_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_fp16_a32_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_fp16_causal_a32_psskddv_group"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_fp16_causal_br_a32_psskddv_group"; }; -// ########################################################|HDim| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode| GPUArch| -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtne.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtna.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtz.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtne.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtna.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtz.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtne.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtna.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtz.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtne.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtna.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtz.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a16.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a32.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a16.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a32.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtne_pddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtna_pddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtz_pddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtne_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtna_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtz_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtne_pddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtna_pddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtz_pddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtne_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtna_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtz_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtne_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtna_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtz_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a16_pddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a32_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a16_pddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a32_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_br_a32_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a16_rtne.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a16_rtna.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a16_rtz.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtne_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtna_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtz_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a16_rtne.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a16_rtna.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a16_rtz.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtne_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtna_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtz_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_br_a32_rtne_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_br_a32_rtna_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_br_a32_rtz_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_a16.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_a32_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_causal_a16.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_causal_a32_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_causal_br_a32_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_a32_rtne_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_a32_rtna_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_a32_rtz_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_a32_rtne_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_a32_rtna_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_a32_rtz_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_br_a32_rtne_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_br_a32_rtna_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_br_a32_rtz_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_fp16_a32_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_fp16_causal_a32_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_fp16_causal_br_a32_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_swa_a32_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_swa_a32_rtne_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_swa_a32_rtna_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_swa_a32_rtz_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtne_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtna_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtz_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtne_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtna_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtz_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_br_a32_rtne_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_br_a32_rtna_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_br_a32_rtz_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_a32_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_causal_a32_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_causal_br_a32_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a32_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a32_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_br_a32_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_br_a32_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a32_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a32_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtne_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtna_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtz_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtne_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtna_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtz_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtne_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtna_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtz_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtne_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtna_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtz_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtne_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtna_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtz_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtne_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtna_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtz_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_a32_rtne_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_a32_rtna_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_a32_rtz_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_a32_rtne_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_a32_rtna_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_a32_rtz_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_br_a32_rtne_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_br_a32_rtna_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_br_a32_rtz_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_fp16_a32_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_fp16_causal_a32_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_fp16_causal_br_a32_psskddv_group.co"; }; +// ########################################################|HDim_q|HDim_v| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode| GPUArch| +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtne.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtna.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtz.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtne.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtna.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtz.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtne.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtna.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtz.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtne.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtna.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtz.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a16.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a32.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a16.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a32.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtne_pddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtna_pddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtz_pddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtne_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtna_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtz_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtne_pddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtna_pddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtz_pddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtne_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtna_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtz_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtne_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtna_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtz_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a16_pddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a32_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a16_pddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a32_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_br_a32_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a16_rtne.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a16_rtna.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a16_rtz.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtne_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtna_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtz_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a16_rtne.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a16_rtna.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a16_rtz.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtne_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtna_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtz_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_br_a32_rtne_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_br_a32_rtna_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_br_a32_rtz_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_a16.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_a32_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_causal_a16.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_causal_a32_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_causal_br_a32_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_a32_rtne_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_a32_rtna_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_a32_rtz_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_a32_rtne_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_a32_rtna_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_a32_rtz_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_br_a32_rtne_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_br_a32_rtna_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_br_a32_rtz_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_fp16_a32_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_fp16_causal_a32_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_fp16_causal_br_a32_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_swa_a32_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_swa_a32_rtne_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_swa_a32_rtna_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_swa_a32_rtz_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtne_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtna_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtz_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtne_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtna_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtz_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_br_a32_rtne_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_br_a32_rtna_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_br_a32_rtz_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_a32_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_causal_a32_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_causal_br_a32_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a32_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a32_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_br_a32_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_br_a32_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a32_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a32_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtne_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtna_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtz_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtne_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtna_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtz_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtne_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtna_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtz_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtne_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtna_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtz_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtne_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtna_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtz_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtne_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtna_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_br_a32_rtz_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_a32_rtne_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_a32_rtna_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_a32_rtz_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_a32_rtne_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_a32_rtna_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_a32_rtz_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_br_a32_rtne_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_br_a32_rtna_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_br_a32_rtz_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_fp16_a32_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_fp16_causal_a32_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * bwd_v3_buf = "bwd_hd192_fp16_causal_br_a32_psskddv_group.co"; }; -// ########################################################|HDim| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode| GPUArch| -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +// ########################################################|HDim_q|HDim_v| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode| GPUArch| +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; namespace gfx942{ class fmha_bwd_v3_kernel @@ -937,7 +937,7 @@ class fmha_bwd_v3_kernel if((t.is_group_mode == false) && (t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ if(t.mask_type == mask_enum::no_mask){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd192_fp16_a32_psskddv"; if (is_v3_api_check) { @@ -949,7 +949,7 @@ class fmha_bwd_v3_kernel else if((((t.mask_type != mask_enum::no_mask) && (a.seqlen_q == a.seqlen_k)) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))) && ((a.window_size_left == -1) && (a.window_size_right == 0))){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd192_fp16_causal_a32_psskddv"; if (is_v3_api_check) { @@ -960,7 +960,7 @@ class fmha_bwd_v3_kernel } else if((t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left == -1) && (a.window_size_right == 0))){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd192_fp16_causal_br_a32_psskddv"; if (is_v3_api_check) { @@ -973,7 +973,7 @@ class fmha_bwd_v3_kernel else if((t.is_group_mode == true) && (t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){//group mode if(t.mask_type == mask_enum::no_mask){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, false, true, 0, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdFp16, false, true, 0, true, true, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, true, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd192_fp16_a32_psskddv_group"; if (is_v3_api_check) { @@ -984,7 +984,7 @@ class fmha_bwd_v3_kernel } else if(((a.window_size_left == -1) && (a.window_size_right == 0)) && (t.mask_type == mask_enum::mask_top_left)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, true, true, 0, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdFp16, true, true, 0, true, true, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, true, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd192_fp16_causal_a32_psskddv_group"; if (is_v3_api_check) { @@ -995,7 +995,7 @@ class fmha_bwd_v3_kernel } else if((t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left == -1) && (a.window_size_right == 0))){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, 3, true, 0, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdFp16, 3, true, 0, true, true, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, true, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd192_fp16_causal_br_a32_psskddv_group"; if (is_v3_api_check) { @@ -1011,7 +1011,7 @@ class fmha_bwd_v3_kernel if(t.mask_type == mask_enum::no_mask){ if(t.how_v3_bf16_cvt == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1022,7 +1022,7 @@ class fmha_bwd_v3_kernel } else if(t.how_v3_bf16_cvt == 1){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, false, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -1033,7 +1033,7 @@ class fmha_bwd_v3_kernel } else if(t.how_v3_bf16_cvt == 2){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, false, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -1047,7 +1047,7 @@ class fmha_bwd_v3_kernel ((a.window_size_left == -1) && (a.window_size_right == 0))){ if(t.how_v3_bf16_cvt == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_causal_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1058,7 +1058,7 @@ class fmha_bwd_v3_kernel } else if(t.how_v3_bf16_cvt == 1){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, true, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_causal_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -1069,7 +1069,7 @@ class fmha_bwd_v3_kernel } else if(t.how_v3_bf16_cvt == 2){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, true, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_causal_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -1082,7 +1082,7 @@ class fmha_bwd_v3_kernel else if((t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left == -1) && (a.window_size_right == 0))){ if(t.how_v3_bf16_cvt == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_causal_br_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1093,7 +1093,7 @@ class fmha_bwd_v3_kernel } else if(t.how_v3_bf16_cvt == 1){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 3, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, 3, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_causal_br_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -1104,7 +1104,7 @@ class fmha_bwd_v3_kernel } else if(t.how_v3_bf16_cvt == 2){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 3, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, 3, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_causal_br_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -1120,7 +1120,7 @@ class fmha_bwd_v3_kernel using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, true, true, true>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, true, true, true, false, 0>; if(t.how_v3_bf16_cvt == 0){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 0, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, false, true, 0, true, true, true, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_a32_rtne_psskddv_group"; if (is_v3_api_check) { return 1; @@ -1129,7 +1129,7 @@ class fmha_bwd_v3_kernel return r; } else if(t.how_v3_bf16_cvt == 1){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 1, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, false, true, 1, true, true, true, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_a32_rtna_psskddv_group"; if (is_v3_api_check) { return 1; @@ -1138,7 +1138,7 @@ class fmha_bwd_v3_kernel return r; } else if(t.how_v3_bf16_cvt == 2){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 2, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, false, true, 2, true, true, true, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_a32_rtz_psskddv_group"; if (is_v3_api_check) { return 1; @@ -1153,7 +1153,7 @@ class fmha_bwd_v3_kernel using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, true, true, true, false, 0>; if(t.how_v3_bf16_cvt == 0){ // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_causal_a32_rtne_psskddv_group"; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 0, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, true, true, 0, true, true, true, GPUArch::gfx942>; if (is_v3_api_check) { return 1; } @@ -1162,7 +1162,7 @@ class fmha_bwd_v3_kernel } else if(t.how_v3_bf16_cvt == 1){ // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_causal_a32_rtna_psskddv_group"; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 1, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, true, true, 1, true, true, true, GPUArch::gfx942>; if (is_v3_api_check) { return 1; } @@ -1171,7 +1171,7 @@ class fmha_bwd_v3_kernel } else if(t.how_v3_bf16_cvt == 2){ // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_causal_a32_rtz_psskddv_group"; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 2, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, true, true, 2, true, true, true, GPUArch::gfx942>; if (is_v3_api_check) { return 1; } @@ -1184,7 +1184,7 @@ class fmha_bwd_v3_kernel using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, true, true, true, false, 0>; if(t.how_v3_bf16_cvt == 0){ // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_causal_br_a32_rtne_psskddv_group"; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 3, true, 0, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, 3, true, 0, true, true, true, GPUArch::gfx942>; if (is_v3_api_check) { return 1; } @@ -1193,7 +1193,7 @@ class fmha_bwd_v3_kernel } else if(t.how_v3_bf16_cvt == 1){ // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_causal_br_a32_rtna_psskddv_group"; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 3, true, 1, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, 3, true, 1, true, true, true, GPUArch::gfx942>; if (is_v3_api_check) { return 1; } @@ -1202,7 +1202,7 @@ class fmha_bwd_v3_kernel } else if(t.how_v3_bf16_cvt == 2){ // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_causal_br_a32_rtz_psskddv_group"; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 3, true, 2, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, 3, true, 2, true, true, true, GPUArch::gfx942>; if (is_v3_api_check) { return 1; } @@ -1221,7 +1221,7 @@ class fmha_bwd_v3_kernel (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, true, 0, false, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32"; if (is_v3_api_check) { @@ -1232,7 +1232,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_psskddv"; if (is_v3_api_check) { @@ -1243,7 +1243,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_psskddv"; if (is_v3_api_check) { @@ -1254,7 +1254,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_psskddv"; if (is_v3_api_check) { @@ -1265,7 +1265,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_psskddv"; if (is_v3_api_check) { @@ -1280,7 +1280,7 @@ class fmha_bwd_v3_kernel (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ if(a.hdim_q == 128){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, false, 0, false, false, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a16"; if (is_v3_api_check) { return 1; @@ -1290,7 +1290,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, false, 0, false, true, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a16_pddv"; if (is_v3_api_check) { return 1; @@ -1304,7 +1304,7 @@ class fmha_bwd_v3_kernel if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ if(a.hdim_q == 128){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, true, 0, true, false, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_pssk_group"; if (is_v3_api_check) { @@ -1315,7 +1315,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, true, 0, true, true, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_psskddv_group"; if (is_v3_api_check) { @@ -1332,7 +1332,7 @@ class fmha_bwd_v3_kernel (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, true, 0, false, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32"; if (is_v3_api_check) { @@ -1344,7 +1344,7 @@ class fmha_bwd_v3_kernel else if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){ if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_psskddv"; if (is_v3_api_check) { @@ -1355,7 +1355,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_psskddv"; if (is_v3_api_check) { @@ -1366,7 +1366,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_psskddv"; if (is_v3_api_check) { @@ -1377,7 +1377,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_psskddv"; if (is_v3_api_check) { @@ -1390,7 +1390,7 @@ class fmha_bwd_v3_kernel else if(t.mask_type == mask_enum::mask_bottom_right){ if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_br_a32_psskddv"; if (is_v3_api_check) { @@ -1401,7 +1401,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_br_a32_psskddv"; if (is_v3_api_check) { @@ -1412,7 +1412,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_br_a32_psskddv"; if (is_v3_api_check) { @@ -1423,7 +1423,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_br_a32_psskddv"; if (is_v3_api_check) { @@ -1439,7 +1439,7 @@ class fmha_bwd_v3_kernel (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ if(a.hdim_q == 128){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, false, 0, false, false, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a16"; if (is_v3_api_check) { return 1; @@ -1449,7 +1449,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, false, 0, false, true, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a16_pddv"; if (is_v3_api_check) { return 1; @@ -1463,7 +1463,7 @@ class fmha_bwd_v3_kernel if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1474,7 +1474,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1485,7 +1485,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv; if (is_v3_api_check) { @@ -1496,7 +1496,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1512,7 +1512,7 @@ class fmha_bwd_v3_kernel if(t.mask_type == mask_enum::mask_top_left){ if(a.hdim_q == 128){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, true, 0, true, false, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_pssk_group"; if (is_v3_api_check) { @@ -1523,7 +1523,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, true, 0, true, true, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_psskddv_group"; if (is_v3_api_check) { @@ -1536,7 +1536,7 @@ class fmha_bwd_v3_kernel else if(t.mask_type == mask_enum::mask_bottom_right){ if(a.hdim_q == 128){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, true, 0, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, true, 0, true, false, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_br_a32_pssk_group"; if (is_v3_api_check) { @@ -1547,7 +1547,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, true, 0, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, true, 0, true, true, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_br_a32_psskddv_group"; if (is_v3_api_check) { @@ -1568,7 +1568,7 @@ class fmha_bwd_v3_kernel (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 0, false, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne"; if (is_v3_api_check) { @@ -1579,7 +1579,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1590,7 +1590,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1601,7 +1601,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1612,7 +1612,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1627,7 +1627,7 @@ class fmha_bwd_v3_kernel (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 1, false, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna"; if (is_v3_api_check) { @@ -1638,7 +1638,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -1649,7 +1649,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -1660,7 +1660,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -1671,7 +1671,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -1686,7 +1686,7 @@ class fmha_bwd_v3_kernel (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 2, false, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz"; if (is_v3_api_check) { @@ -1697,7 +1697,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -1708,7 +1708,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -1719,7 +1719,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -1730,7 +1730,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -1747,7 +1747,7 @@ class fmha_bwd_v3_kernel if(t.how_v3_bf16_cvt == 0){ if(a.hdim_q == 128 && (a.seqlen_k % 64 == 0)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, false, 0, false, false, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtne"; if (is_v3_api_check) { return 1; @@ -1757,7 +1757,7 @@ class fmha_bwd_v3_kernel } else if(a.hdim_q != 128 && (a.seqlen_k % 64 == 0)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, false, 0, false, true, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtne_pddv"; if (is_v3_api_check) { return 1; @@ -1769,7 +1769,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 1){ if(a.hdim_q == 128 && (a.seqlen_k % 64 == 0)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, false, 1, false, false, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtna"; if (is_v3_api_check) { return 1; @@ -1779,7 +1779,7 @@ class fmha_bwd_v3_kernel } else if(a.hdim_q != 128 && (a.seqlen_k % 64 == 0)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, false, 1, false, true, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtna_pddv"; if (is_v3_api_check) { return 1; @@ -1791,7 +1791,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 2){ if(a.hdim_q == 128 && (a.seqlen_k % 64 == 0)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, false, 2, false, false, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtz"; if (is_v3_api_check) { return 1; @@ -1801,7 +1801,7 @@ class fmha_bwd_v3_kernel } else if(a.hdim_q != 128 && (a.seqlen_k % 64 == 0)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, false, 2, false, true, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtz_pddv"; if (is_v3_api_check) { return 1; @@ -1817,7 +1817,7 @@ class fmha_bwd_v3_kernel if(t.how_v3_bf16_cvt == 0){ if(a.hdim_q == 128){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 0, true, false, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_pssk_group"; if (is_v3_api_check) { @@ -1828,7 +1828,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 0, true, true, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_psskddv_group"; if (is_v3_api_check) { @@ -1841,7 +1841,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 1){ if(a.hdim_q == 128){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 1, true, false, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_rtna_pssk_group"; if (is_v3_api_check) { @@ -1852,7 +1852,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 1, true, true, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_rtna_psskddv_group"; if (is_v3_api_check) { @@ -1865,7 +1865,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 2){ if(a.hdim_q == 128){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 2, true, false, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_rtz_pssk_group"; if (is_v3_api_check) { @@ -1876,7 +1876,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 2, true, true, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_rtz_psskddv_group"; if (is_v3_api_check) { @@ -1895,7 +1895,7 @@ class fmha_bwd_v3_kernel (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 0, false, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne"; if (is_v3_api_check) { @@ -1907,7 +1907,7 @@ class fmha_bwd_v3_kernel else if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){ if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1918,7 +1918,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1929,7 +1929,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1940,7 +1940,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1953,7 +1953,7 @@ class fmha_bwd_v3_kernel else if(t.mask_type == mask_enum::mask_bottom_right){ if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_br_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1964,7 +1964,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_br_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1975,7 +1975,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_br_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -1986,7 +1986,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_br_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -2002,7 +2002,7 @@ class fmha_bwd_v3_kernel (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 1, false, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna"; if (is_v3_api_check) { @@ -2014,7 +2014,7 @@ class fmha_bwd_v3_kernel else if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){ if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -2025,7 +2025,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -2036,7 +2036,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -2047,7 +2047,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -2060,7 +2060,7 @@ class fmha_bwd_v3_kernel else if(t.mask_type == mask_enum::mask_bottom_right){ if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_br_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -2071,7 +2071,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_br_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -2082,7 +2082,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_br_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -2093,7 +2093,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_br_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -2109,7 +2109,7 @@ class fmha_bwd_v3_kernel (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 2, false, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz"; if (is_v3_api_check) { @@ -2121,7 +2121,7 @@ class fmha_bwd_v3_kernel else if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){ if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -2132,7 +2132,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -2143,7 +2143,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -2154,7 +2154,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -2167,7 +2167,7 @@ class fmha_bwd_v3_kernel else if(t.mask_type == mask_enum::mask_bottom_right){ if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_br_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -2178,7 +2178,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_br_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -2189,7 +2189,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_br_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -2200,7 +2200,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_br_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -2218,7 +2218,7 @@ class fmha_bwd_v3_kernel if(t.how_v3_bf16_cvt == 0){ if(a.hdim_q == 128 && (a.seqlen_k % 64 == 0)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, false, 0, false, false, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtne"; if (is_v3_api_check) { return 1; @@ -2228,7 +2228,7 @@ class fmha_bwd_v3_kernel } else if(a.hdim_q != 128 && (a.seqlen_k % 64 == 0)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, false, 0, false, true, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtne_pddv"; if (is_v3_api_check) { return 1; @@ -2240,7 +2240,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 1){ if(a.hdim_q == 128 && (a.seqlen_k % 64 == 0)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, false, 1, false, false, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtna"; if (is_v3_api_check) { return 1; @@ -2250,7 +2250,7 @@ class fmha_bwd_v3_kernel } else if(a.hdim_q != 128 && (a.seqlen_k % 64 == 0)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, false, 1, false, true, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtna_pddv"; if (is_v3_api_check) { return 1; @@ -2262,7 +2262,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 2){ if(a.hdim_q == 128 && (a.seqlen_k % 64 == 0)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, false, 2, false, false, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtz"; if (is_v3_api_check) { return 1; @@ -2272,7 +2272,7 @@ class fmha_bwd_v3_kernel } else if(a.hdim_q != 128 && (a.seqlen_k % 64 == 0)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, false, 2, false, true, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtz_pddv"; if (is_v3_api_check) { return 1; @@ -2288,7 +2288,7 @@ class fmha_bwd_v3_kernel if(t.how_v3_bf16_cvt == 0){ if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -2299,7 +2299,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -2310,7 +2310,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtne_psskddv; if (is_v3_api_check) { @@ -2321,7 +2321,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtne_psskddv"; if (is_v3_api_check) { @@ -2334,7 +2334,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 1){ if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -2345,7 +2345,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -2356,7 +2356,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtna_psskddv; if (is_v3_api_check) { @@ -2367,7 +2367,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtna_psskddv"; if (is_v3_api_check) { @@ -2380,7 +2380,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 2){ if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -2391,7 +2391,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -2402,7 +2402,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtz_psskddv; if (is_v3_api_check) { @@ -2413,7 +2413,7 @@ class fmha_bwd_v3_kernel } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtz_psskddv"; if (is_v3_api_check) { @@ -2430,7 +2430,7 @@ class fmha_bwd_v3_kernel if(t.how_v3_bf16_cvt == 0){ if(a.hdim_q == 128){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 0, true, false, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_pssk_group"; if (is_v3_api_check) { @@ -2441,7 +2441,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 0, true, true, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_psskddv_group"; if (is_v3_api_check) { @@ -2454,7 +2454,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 1){ if(a.hdim_q == 128){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 1, true, false, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_rtna_pssk_group"; if (is_v3_api_check) { @@ -2465,7 +2465,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 1, true, true, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_rtna_psskddv_group"; if (is_v3_api_check) { @@ -2478,7 +2478,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 2){ if(a.hdim_q == 128){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 2, true, false, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_rtz_pssk_group"; if (is_v3_api_check) { @@ -2489,7 +2489,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 2, true, true, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_rtz_psskddv_group"; if (is_v3_api_check) { @@ -2504,7 +2504,7 @@ class fmha_bwd_v3_kernel if(t.how_v3_bf16_cvt == 0){ if(a.hdim_q == 128){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 0, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 0, true, false, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_br_a32_rtne_pssk_group"; if (is_v3_api_check) { @@ -2515,7 +2515,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 0, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 0, true, true, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_br_a32_rtne_psskddv_group"; if (is_v3_api_check) { @@ -2528,7 +2528,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 1){ if(a.hdim_q == 128){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 1, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 1, true, false, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_br_a32_rtna_pssk_group"; if (is_v3_api_check) { @@ -2539,7 +2539,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 1, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 1, true, true, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_br_a32_rtna_psskddv_group"; if (is_v3_api_check) { @@ -2552,7 +2552,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 2){ if(a.hdim_q == 128){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 2, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 2, true, false, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_br_32_rtz_pssk_group"; if (is_v3_api_check) { @@ -2563,7 +2563,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 2, true, true, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 2, true, true, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_br_a32_rtz_psskddv_group"; if (is_v3_api_check) { @@ -2584,7 +2584,7 @@ class fmha_bwd_v3_kernel if(t.is_group_mode == false){ if(a.seqlen_q % 64 == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, false, true, 0, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk"; if (is_v3_api_check) { @@ -2595,7 +2595,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, false, true, 0, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk"; if (is_v3_api_check) { @@ -2607,7 +2607,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, false, true, 0, true, false, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, true, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk_group"; if (is_v3_api_check) { @@ -2621,7 +2621,7 @@ class fmha_bwd_v3_kernel (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, false, 0, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, false, false, 0, false, false, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a16"; if (is_v3_api_check) { return 1; @@ -2636,7 +2636,7 @@ class fmha_bwd_v3_kernel if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){ if(a.seqlen_q % 64 == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, true, true, 0, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk"; if (is_v3_api_check) { @@ -2647,7 +2647,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, true, true, 0, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk"; if (is_v3_api_check) { @@ -2660,7 +2660,7 @@ class fmha_bwd_v3_kernel else if(t.mask_type == mask_enum::mask_bottom_right){ if(a.seqlen_q % 64 == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, 3, true, 0, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, 3, true, 0, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_br_a32_pssk"; if (is_v3_api_check) { @@ -2671,7 +2671,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, 3, true, 0, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, 3, true, 0, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_br_a32_pssk"; if (is_v3_api_check) { @@ -2685,7 +2685,7 @@ class fmha_bwd_v3_kernel else if(t.is_group_mode == true){ if(t.mask_type == mask_enum::mask_top_left){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, true, true, 0, true, false, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, true, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk_group"; if (is_v3_api_check) { @@ -2696,7 +2696,7 @@ class fmha_bwd_v3_kernel } else if(t.mask_type == mask_enum::mask_bottom_right){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, 3, true, 0, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, 3, true, 0, true, false, true, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, true, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_br_a32_pssk_group"; if (is_v3_api_check) { @@ -2711,7 +2711,7 @@ class fmha_bwd_v3_kernel (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, false, 0, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, true, false, 0, false, false, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16"; if (is_v3_api_check) { return 1; @@ -2728,7 +2728,7 @@ class fmha_bwd_v3_kernel if(t.how_v3_bf16_cvt == 0){ if(a.seqlen_q % 64 == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 0, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk"; if (is_v3_api_check) { @@ -2739,7 +2739,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 0, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk"; if (is_v3_api_check) { @@ -2752,7 +2752,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 1){ if(a.seqlen_q % 64 == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 1, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk"; if (is_v3_api_check) { @@ -2763,7 +2763,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 1, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk"; if (is_v3_api_check) { @@ -2776,7 +2776,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 2){ if(a.seqlen_q % 64 == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 2, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk"; if (is_v3_api_check) { @@ -2787,7 +2787,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 2, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk"; if (is_v3_api_check) { @@ -2802,21 +2802,21 @@ class fmha_bwd_v3_kernel using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, true, true, false>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, true, true, false, false, 0>; if(t.how_v3_bf16_cvt == 0){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 0, true, false, true, GPUArch::gfx942>; if (is_v3_api_check) { return 1; } r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); } else if(t.how_v3_bf16_cvt == 1){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 1, true, false, true, GPUArch::gfx942>; if (is_v3_api_check) { return 1; } r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); } else{ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 2, true, false, true, GPUArch::gfx942>; if (is_v3_api_check) { return 1; } @@ -2830,7 +2830,7 @@ class fmha_bwd_v3_kernel (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ if(t.how_v3_bf16_cvt == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 0, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, false, 0, false, false, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne"; if (is_v3_api_check) { return 1; @@ -2840,7 +2840,7 @@ class fmha_bwd_v3_kernel } else if(t.how_v3_bf16_cvt == 1){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 1, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, false, 1, false, false, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna"; if (is_v3_api_check) { return 1; @@ -2850,7 +2850,7 @@ class fmha_bwd_v3_kernel } else if(t.how_v3_bf16_cvt == 2){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 2, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, false, 2, false, false, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz"; if (is_v3_api_check) { return 1; @@ -2867,7 +2867,7 @@ class fmha_bwd_v3_kernel if(t.how_v3_bf16_cvt == 0){ if(a.seqlen_q % 64 == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 0, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk"; if (is_v3_api_check) { @@ -2878,7 +2878,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 0, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk"; if (is_v3_api_check) { @@ -2891,7 +2891,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 1){ if(a.seqlen_q % 64 == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 1, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk"; if (is_v3_api_check) { @@ -2902,7 +2902,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 1, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk"; if (is_v3_api_check) { @@ -2915,7 +2915,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 2){ if(a.seqlen_q % 64 == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 2, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk"; if (is_v3_api_check) { @@ -2926,7 +2926,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 2, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk"; if (is_v3_api_check) { @@ -2941,7 +2941,7 @@ class fmha_bwd_v3_kernel if(t.how_v3_bf16_cvt == 0){ if(a.seqlen_q % 64 == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 0, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 0, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_br_a32_rtne_pssk"; if (is_v3_api_check) { @@ -2952,7 +2952,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 0, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 0, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_br_a32_rtne_pssk"; if (is_v3_api_check) { @@ -2965,7 +2965,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 1){ if(a.seqlen_q % 64 == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 1, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 1, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_br_a32_rtna_pssk"; if (is_v3_api_check) { @@ -2976,7 +2976,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 1, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 1, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_br_a32_rtna_pssk"; if (is_v3_api_check) { @@ -2989,7 +2989,7 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 2){ if(a.seqlen_q % 64 == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 2, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 2, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_br_a32_rtz_pssk"; if (is_v3_api_check) { @@ -3000,7 +3000,7 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 2, true, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 2, true, false, false, GPUArch::gfx942>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_br_a32_rtz_pssk"; if (is_v3_api_check) { @@ -3017,21 +3017,21 @@ class fmha_bwd_v3_kernel using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, true, true, false, false, 0>; if(t.mask_type == mask_enum::mask_top_left){ if(t.how_v3_bf16_cvt == 0){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 0, true, false, true, GPUArch::gfx942>; if (is_v3_api_check) { return 1; } r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); } else if(t.how_v3_bf16_cvt == 1){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 1, true, false, true, GPUArch::gfx942>; if (is_v3_api_check) { return 1; } r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); } else{ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 2, true, false, true, GPUArch::gfx942>; if (is_v3_api_check) { return 1; } @@ -3041,21 +3041,21 @@ class fmha_bwd_v3_kernel } else if(t.mask_type == mask_enum::mask_bottom_right){ if(t.how_v3_bf16_cvt == 0){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 0, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 0, true, false, true, GPUArch::gfx942>; if (is_v3_api_check) { return 1; } r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); } else if(t.how_v3_bf16_cvt == 1){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 1, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 1, true, false, true, GPUArch::gfx942>; if (is_v3_api_check) { return 1; } r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); } else{ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 2, true, false, true, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 2, true, false, true, GPUArch::gfx942>; if (is_v3_api_check) { return 1; } @@ -3070,7 +3070,7 @@ class fmha_bwd_v3_kernel (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ if(t.how_v3_bf16_cvt == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 0, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, false, 0, false, false, false, GPUArch::gfx942>; const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; if (is_v3_api_check) { return 1; @@ -3080,7 +3080,7 @@ class fmha_bwd_v3_kernel } else if(t.how_v3_bf16_cvt == 1){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 1, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, false, 1, false, false, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; if (is_v3_api_check) { return 1; @@ -3090,7 +3090,7 @@ class fmha_bwd_v3_kernel } else if(t.how_v3_bf16_cvt == 2){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 2, false, false, false, GPUArch::gfx942>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, false, 2, false, false, false, GPUArch::gfx942>; // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; if (is_v3_api_check) { return 1; diff --git a/hsa/gfx950/fmha_v3_bwd/bwd_hd192_dq_shuffle.co b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_dq_shuffle.co new file mode 100755 index 0000000000..c7b0616d37 Binary files /dev/null and b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_dq_shuffle.co differ diff --git a/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_bf16_a16_pssk.co b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_bf16_a16_pssk.co new file mode 100755 index 0000000000..6169f8f7e6 Binary files /dev/null and b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_bf16_a16_pssk.co differ diff --git a/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_bf16_a32_pssk.co b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_bf16_a32_pssk.co new file mode 100755 index 0000000000..7eca9acf92 Binary files /dev/null and b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_bf16_a32_pssk.co differ diff --git a/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_bf16_causal_a16_pssk.co b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_bf16_causal_a16_pssk.co new file mode 100755 index 0000000000..561e3f8c82 Binary files /dev/null and b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_bf16_causal_a16_pssk.co differ diff --git a/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_bf16_causal_a32_pssk.co b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_bf16_causal_a32_pssk.co new file mode 100755 index 0000000000..285d02ed79 Binary files /dev/null and b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_bf16_causal_a32_pssk.co differ diff --git a/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_fp16_a16_pssk.co b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_fp16_a16_pssk.co new file mode 100755 index 0000000000..8b6806da68 Binary files /dev/null and b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_fp16_a16_pssk.co differ diff --git a/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_fp16_a32_pssk.co b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_fp16_a32_pssk.co new file mode 100755 index 0000000000..506e083a38 Binary files /dev/null and b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_fp16_a32_pssk.co differ diff --git a/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_fp16_causal_a16_pssk.co b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_fp16_causal_a16_pssk.co new file mode 100755 index 0000000000..652eaea2dd Binary files /dev/null and b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_fp16_causal_a16_pssk.co differ diff --git a/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_fp16_causal_a32_pssk.co b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_fp16_causal_a32_pssk.co new file mode 100755 index 0000000000..9e42768879 Binary files /dev/null and b/hsa/gfx950/fmha_v3_bwd/bwd_hd192_hd128_fp16_causal_a32_pssk.co differ diff --git a/hsa/gfx950/fmha_v3_bwd/codegen.py b/hsa/gfx950/fmha_v3_bwd/codegen.py index 65f582d20d..7c51f1b0e6 100644 --- a/hsa/gfx950/fmha_v3_bwd/codegen.py +++ b/hsa/gfx950/fmha_v3_bwd/codegen.py @@ -137,275 +137,314 @@ static constexpr int ts_dq = 64; }; -// ########################################################|HDim| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode| GPUArch| -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a16_rtne_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a16_rtna_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a16_rtz_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a32_rtne_pssk_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a32_rtna_pssk_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a32_rtz_pssk_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a16_rtne_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a16_rtna_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a16_rtz_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a32_rtne_pssk_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a32_rtna_pssk_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a32_rtz_pssk_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtne_pssk_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtna_pssk_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtz_pssk_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_a16_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_a32_pssk_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_causal_a16_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_causal_a32_pssk_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_causal_br_a32_pssk_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_a16"; }; // native gfx950, currently not used -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_a32"; }; // native gfx950, currently not used -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_a16"; }; // native gfx950, currently not used -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_a32"; }; // native gfx950, currently not used -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_a16_psskddv"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_a32_psskddv"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_br_a16_psskddv"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_a16_psskddv"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_br_a32_psskddv"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_a32_psskddv"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_a16_psskddv"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_a32_psskddv"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_br_a16_psskddv"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_a16_psskddv"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_br_a32_psskddv"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_a32_psskddv"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_a32_rtne_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_a32_rtna_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_a32_rtz_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_a32_rtne_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_a32_rtna_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_a32_rtz_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtne_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtna_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtz_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_fp16_a32_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_fp16_causal_a32_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_fp16_causal_br_a32_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_swa_a32_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_swa_a32_rtne_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_swa_a32_rtna_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_swa_a32_rtz_psskddv_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a32_rtne_pssk_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a32_rtna_pssk_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a32_rtz_pssk_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a32_rtne_pssk_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a32_rtna_pssk_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a32_rtz_pssk_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtne_pssk_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtna_pssk_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtz_pssk_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_a32_pssk_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_causal_a32_pssk_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_causal_br_a32_pssk_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_a16_psskddv_group"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_a32_psskddv_group"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_br_a16_psskddv_group"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_a16_psskddv_group"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_br_a32_psskddv_group"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_a32_psskddv_group"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_a16_psskddv_group"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_a32_psskddv_group"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_br_a16_psskddv_group"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_a16_psskddv_group"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_br_a32_psskddv_group"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_a32_psskddv_group"; }; // native gfx950 -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_a32_rtne_psskddv_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_a32_rtna_psskddv_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_a32_rtz_psskddv_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_a32_rtne_psskddv_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_a32_rtna_psskddv_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_a32_rtz_psskddv_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtne_psskddv_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtna_psskddv_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtz_psskddv_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_fp16_a32_psskddv_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_fp16_causal_a32_psskddv_group_recompile"; }; -template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_fp16_causal_br_a32_psskddv_group_recompile"; }; - -// ########################################################|HDim| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode| GPUArch| -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a16_rtne.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a16_rtna.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a16_rtz.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a32_rtne_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a32_rtna_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a32_rtz_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a16_rtne.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a16_rtna.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a16_rtz.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a32_rtne_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a32_rtna_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a32_rtz_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_br_a32_rtne_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_br_a32_rtna_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_br_a32_rtz_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_a16.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_a32_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_causal_a16.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_causal_a32_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_causal_br_a32_pssk.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_a16.co"; }; // native gfx950, currently not used -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_a32.co"; }; // native gfx950, currently not used -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_a16.co"; }; // native gfx950, currently not used -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_a32.co"; }; // native gfx950, currently not used -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_a16_psskddv.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_a32_psskddv.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_br_a16_psskddv.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_a16_psskddv.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_br_a32_psskddv.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_a32_psskddv.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_a16_psskddv.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_a32_psskddv.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_br_a16_psskddv.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_a16_psskddv.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_br_a32_psskddv.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_a32_psskddv.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_a32_rtne_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_a32_rtna_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_a32_rtz_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_a32_rtne_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_a32_rtna_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_a32_rtz_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_br_a32_rtne_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_br_a32_rtna_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_br_a32_rtz_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_fp16_a32_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_fp16_causal_a32_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_fp16_causal_br_a32_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_swa_a32_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_swa_a32_rtne_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_swa_a32_rtna_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_swa_a32_rtz_psskddv.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a32_rtne_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a32_rtna_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a32_rtz_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a32_rtne_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a32_rtna_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a32_rtz_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_br_a32_rtne_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_br_a32_rtna_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_br_a32_rtz_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_a32_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_causal_a32_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_causal_br_a32_pssk_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_a16_psskddv_group.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_a32_psskddv_group.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_br_a16_psskddv_group.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_a16_psskddv_group.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_br_a32_psskddv_group.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_a32_psskddv_group.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_a16_psskddv_group.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_a32_psskddv_group.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_br_a16_psskddv_group.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_a16_psskddv_group.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_br_a32_psskddv_group.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_a32_psskddv_group.co"; }; // native gfx950 -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_a32_rtne_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_a32_rtna_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_a32_rtz_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_a32_rtne_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_a32_rtna_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_a32_rtz_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_br_a32_rtne_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_br_a32_rtna_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_br_a32_rtz_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_fp16_a32_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_fp16_causal_a32_psskddv_group.co"; }; -template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_fp16_causal_br_a32_psskddv_group.co"; }; - -// ########################################################|HDim| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode| GPUArch| -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950, currently not used -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950, currently not used -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950, currently not used -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950, currently not used -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; -template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +// ########################################################|HDim_q|HDim_v| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode| GPUArch| +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a16_rtne_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a16_rtna_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a16_rtz_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a32_rtne_pssk_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a32_rtna_pssk_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a32_rtz_pssk_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a16_rtne_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a16_rtna_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a16_rtz_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a32_rtne_pssk_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a32_rtna_pssk_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a32_rtz_pssk_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtne_pssk_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtna_pssk_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtz_pssk_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_a16_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_a32_pssk_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_causal_a16_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_causal_a32_pssk_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_causal_br_a32_pssk_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_a16"; }; // native gfx950, currently not used +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_a32"; }; // native gfx950, currently not used +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_a16"; }; // native gfx950, currently not used +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_a32"; }; // native gfx950, currently not used +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_a16_psskddv"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_a32_psskddv"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_br_a16_psskddv"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_a16_psskddv"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_br_a32_psskddv"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_a32_psskddv"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_a16_psskddv"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_a32_psskddv"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_br_a16_psskddv"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_a16_psskddv"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_br_a32_psskddv"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_a32_psskddv"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_a32_rtne_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_a32_rtna_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_a32_rtz_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_a32_rtne_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_a32_rtna_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_a32_rtz_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtne_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtna_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtz_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_fp16_a32_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_fp16_causal_a32_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_fp16_causal_br_a32_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_swa_a32_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_swa_a32_rtne_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_swa_a32_rtna_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_swa_a32_rtz_psskddv_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a32_rtne_pssk_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a32_rtna_pssk_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_a32_rtz_pssk_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a32_rtne_pssk_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a32_rtna_pssk_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_a32_rtz_pssk_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtne_pssk_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtna_pssk_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_bf16_causal_br_a32_rtz_pssk_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_a32_pssk_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_causal_a32_pssk_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd64_fp16_causal_br_a32_pssk_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_a16_psskddv_group"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_a32_psskddv_group"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_br_a16_psskddv_group"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_a16_psskddv_group"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_br_a32_psskddv_group"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_bf16_causal_a32_psskddv_group"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_a16_psskddv_group"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_a32_psskddv_group"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_br_a16_psskddv_group"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_a16_psskddv_group"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_br_a32_psskddv_group"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd128_fp16_causal_a32_psskddv_group"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_a32_rtne_psskddv_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_a32_rtna_psskddv_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_a32_rtz_psskddv_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_a32_rtne_psskddv_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_a32_rtna_psskddv_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_a32_rtz_psskddv_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtne_psskddv_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtna_psskddv_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_bf16_causal_br_a32_rtz_psskddv_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_fp16_a32_psskddv_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_fp16_causal_a32_psskddv_group_recompile"; }; +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_fp16_causal_br_a32_psskddv_group_recompile"; }; + +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_hd128_bf16_a32_pssk"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_hd128_bf16_causal_a32_pssk"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_hd128_bf16_causal_br_a32_pssk"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_hd128_bf16_a16_pssk"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_hd128_bf16_causal_a16_pssk"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_hd128_bf16_causal_br_a16_pssk"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_hd128_fp16_a32_pssk"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_hd128_fp16_causal_a32_pssk"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_hd128_fp16_causal_br_a32_pssk"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_hd128_fp16_a16_pssk"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_hd128_fp16_causal_a16_pssk"; }; // native gfx950 +template<> struct FmhaBwdV3Name> { static constexpr const char * kernel_name = "fmha_bwd_hd192_hd128_fp16_causal_br_a16_pssk"; }; // native gfx950 + +// ########################################################|HDim_q|HDim_v| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode| GPUArch| +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a16_rtne.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a16_rtna.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a16_rtz.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a32_rtne_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a32_rtna_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a32_rtz_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a16_rtne.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a16_rtna.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a16_rtz.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a32_rtne_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a32_rtna_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a32_rtz_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_br_a32_rtne_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_br_a32_rtna_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_br_a32_rtz_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_a16.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_a32_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_causal_a16.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_causal_a32_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_causal_br_a32_pssk.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_a16.co"; }; // native gfx950, currently not used +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_a32.co"; }; // native gfx950, currently not used +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_a16.co"; }; // native gfx950, currently not used +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_a32.co"; }; // native gfx950, currently not used +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_a16_psskddv.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_a32_psskddv.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_br_a16_psskddv.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_a16_psskddv.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_br_a32_psskddv.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_a32_psskddv.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_a16_psskddv.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_a32_psskddv.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_br_a16_psskddv.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_a16_psskddv.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_br_a32_psskddv.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_a32_psskddv.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_a32_rtne_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_a32_rtna_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_a32_rtz_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_a32_rtne_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_a32_rtna_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_a32_rtz_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_br_a32_rtne_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_br_a32_rtna_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_br_a32_rtz_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_fp16_a32_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_fp16_causal_a32_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_fp16_causal_br_a32_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_swa_a32_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_swa_a32_rtne_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_swa_a32_rtna_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_swa_a32_rtz_psskddv.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a32_rtne_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a32_rtna_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_a32_rtz_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a32_rtne_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a32_rtna_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_a32_rtz_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_br_a32_rtne_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_br_a32_rtna_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_bf16_causal_br_a32_rtz_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_a32_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_causal_a32_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd64_fp16_causal_br_a32_pssk_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_a16_psskddv_group.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_a32_psskddv_group.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_br_a16_psskddv_group.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_a16_psskddv_group.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_br_a32_psskddv_group.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_fp16_causal_a32_psskddv_group.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_a16_psskddv_group.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_a32_psskddv_group.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_br_a16_psskddv_group.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_a16_psskddv_group.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_br_a32_psskddv_group.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd128_bf16_causal_a32_psskddv_group.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_a32_rtne_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_a32_rtna_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_a32_rtz_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_a32_rtne_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_a32_rtna_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_a32_rtz_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_br_a32_rtne_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_br_a32_rtna_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_bf16_causal_br_a32_rtz_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_fp16_a32_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_fp16_causal_a32_psskddv_group.co"; }; +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_fp16_causal_br_a32_psskddv_group.co"; }; + +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_hd128_bf16_a32_pssk.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_hd128_bf16_causal_a32_pssk.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_hd128_bf16_causal_br_a32_pssk.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_hd128_bf16_a16_pssk.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_hd128_bf16_causal_a16_pssk.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_hd128_bf16_causal_br_a16_pssk.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_hd128_fp16_a32_pssk.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_hd128_fp16_causal_a32_pssk.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_hd128_fp16_causal_br_a32_pssk.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_hd128_fp16_a16_pssk.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_hd128_fp16_causal_a16_pssk.co"; }; // native gfx950 +template<> struct FmhaBwdV3Buf> { static constexpr const char * file_name = "bwd_hd192_hd128_fp16_causal_br_a16_pssk.co"; }; // native gfx950 + +// ########################################################|HDim_q|HDim_v| DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode| GPUArch| +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950, currently not used +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950, currently not used +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950, currently not used +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950, currently not used +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 256; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }; + +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; // native gfx950 +template<> struct FmhaBwdV3Ts> { static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }; // native gfx950 namespace gfx950{ class fmha_dq_shuffle_kernel @@ -682,13 +721,13 @@ class fmha_bwd_v3_kernel args.Seqs_kv = a.stride_k * 2; args.Seqs_dkv = a.stride_dk * 2; auto traits = fmha_bwd_v3_traits{a.batch, - a.nhead_q, - a.seqlen_q, - a.seqlen_k, - a.hdim_q, - a.mask_type, - FmhaBwdV3Ts::ts_qo, - FmhaBwdV3Ts::ts_kv}; + a.nhead_q, + a.seqlen_q, + a.seqlen_k, + a.hdim_q, + a.mask_type, + FmhaBwdV3Ts::ts_qo, + FmhaBwdV3Ts::ts_kv}; static thread_local fmha_bwd_v3_kernel impl(FmhaBwdV3Name::kernel_name, FmhaBwdV3Buf::file_name); // static here is for thread safety. @@ -729,13 +768,13 @@ class fmha_bwd_v3_kernel args.Seqs_dkv = a.stride_dk * 2; args.head_dim = a.hdim_q; auto traits = fmha_bwd_v3_traits{a.batch, - a.nhead_q, - a.seqlen_q, - a.seqlen_k, - a.hdim_q, - a.mask_type, - FmhaBwdV3Ts::ts_qo, - FmhaBwdV3Ts::ts_kv}; + a.nhead_q, + a.seqlen_q, + a.seqlen_k, + a.hdim_q, + a.mask_type, + FmhaBwdV3Ts::ts_qo, + FmhaBwdV3Ts::ts_kv}; static thread_local fmha_bwd_v3_kernel impl(FmhaBwdV3Name::kernel_name, FmhaBwdV3Buf::file_name); // static here is for thread safety. return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){ fmha_bwd_dot_do_o_oneshot_(s_, a); }, @@ -773,13 +812,13 @@ class fmha_bwd_v3_kernel args.Seqs_kv = a.stride_k * 2; args.Seqs_dkv = a.stride_dk * 2; auto traits = fmha_bwd_v3_traits{a.batch, - a.nhead_q, - a.seqlen_q, - a.seqlen_k, - a.hdim_q, - a.mask_type, - FmhaBwdV3Ts::ts_qo, - FmhaBwdV3Ts::ts_kv}; + a.nhead_q, + a.seqlen_q, + a.seqlen_k, + a.hdim_q, + a.mask_type, + FmhaBwdV3Ts::ts_qo, + FmhaBwdV3Ts::ts_kv}; static thread_local fmha_bwd_v3_kernel impl(FmhaBwdV3Name::kernel_name, FmhaBwdV3Buf::file_name); // static here is for thread safety. return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){ fmha_bwd_dot_do_o_oneshot_(s_, a); }, @@ -819,13 +858,13 @@ class fmha_bwd_v3_kernel args.Seqs_dkv = a.stride_dk * 2; args.head_dim = a.hdim_q; auto traits = fmha_bwd_v3_traits{a.batch, - a.nhead_q, - a.seqlen_q, - a.seqlen_k, - a.hdim_q, - a.mask_type, - FmhaBwdV3Ts::ts_qo, - FmhaBwdV3Ts::ts_kv}; + a.nhead_q, + a.seqlen_q, + a.seqlen_k, + a.hdim_q, + a.mask_type, + FmhaBwdV3Ts::ts_qo, + FmhaBwdV3Ts::ts_kv}; static thread_local fmha_bwd_v3_kernel impl(FmhaBwdV3Name::kernel_name, FmhaBwdV3Buf::file_name); // static here is for thread safety. return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){ fmha_bwd_dot_do_o_oneshot_(s_, a); }, @@ -876,13 +915,13 @@ class fmha_bwd_v3_kernel args.Seqs_dv = a.stride_dv * 2; auto traits = fmha_bwd_v3_traits{a.batch, - a.nhead_q, - a.seqlen_q, - a.seqlen_k, - a.hdim_q, - a.mask_type, - FmhaBwdV3Ts::ts_qo, - FmhaBwdV3Ts::ts_kv}; + a.nhead_q, + a.seqlen_q, + a.seqlen_k, + a.hdim_q, + a.mask_type, + FmhaBwdV3Ts::ts_qo, + FmhaBwdV3Ts::ts_kv}; static thread_local fmha_bwd_v3_kernel impl(FmhaBwdV3Name::kernel_name, FmhaBwdV3Buf::file_name); // static here is for thread safety. return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){ fmha_bwd_dot_do_o_oneshot_(s_, a); }, @@ -935,14 +974,14 @@ class fmha_bwd_v3_kernel args.Seqs_dv = a.stride_dv * 2; args.head_dim = a.hdim_q; - auto traits = fmha_bwd_v3_traits{ a.batch, - a.nhead_q, - a.max_seqlen_q, - a.max_seqlen_k, - a.hdim_q, - a.mask_type, - FmhaBwdV3Ts::ts_qo, - FmhaBwdV3Ts::ts_kv }; + auto traits = fmha_bwd_v3_traits{a.batch, + a.nhead_q, + a.max_seqlen_q, + a.max_seqlen_k, + a.hdim_q, + a.mask_type, + FmhaBwdV3Ts::ts_qo, + FmhaBwdV3Ts::ts_kv }; static thread_local fmha_bwd_v3_kernel impl(FmhaBwdV3Name::kernel_name, FmhaBwdV3Buf::file_name); // static here is for thread safety. return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){ fmha_bwd_dot_do_o_oneshot_(s_, a); }, @@ -1005,13 +1044,13 @@ class fmha_bwd_v3_kernel args.mask_x = generic_mask.at(ck_tile::number<1>{}); auto traits = fmha_bwd_v3_traits{a.batch, - a.nhead_q, - a.seqlen_q, - a.seqlen_k, - a.hdim_q, - a.mask_type, - FmhaBwdV3Ts::ts_qo, - FmhaBwdV3Ts::ts_kv}; + a.nhead_q, + a.seqlen_q, + a.seqlen_k, + a.hdim_q, + a.mask_type, + FmhaBwdV3Ts::ts_qo, + FmhaBwdV3Ts::ts_kv}; static thread_local fmha_bwd_v3_kernel impl(FmhaBwdV3Name::kernel_name, FmhaBwdV3Buf::file_name); // static here is for thread safety. return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){ fmha_bwd_dot_do_o_oneshot_(s_, a); }, @@ -1028,41 +1067,42 @@ class fmha_bwd_v3_kernel if (is_v3_api_check) return 1; fmha_bwd_v3_args_gfx950 args; - args.ptr_dq = a.dq_acc_ptr; - args.ptr_dk = a.dk_ptr; - args.ptr_dv = a.dv_ptr; - args.ptr_q = a.q_ptr; - args.ptr_k = a.k_ptr; - args.ptr_v = a.v_ptr; - args.ptr_do = a.do_ptr; - args.ptr_lse = a.lse_ptr; - args.ptr_d = a.d_ptr; - args.scalar = a.scale; - args.log2e = ck_tile::log2e_v;; - args.ratio = a.nhead_q / a.nhead_k; - args.seqlen_q = a.seqlen_q; - args.seqlen_k = a.seqlen_k; - args.head_dim_q = a.hdim_q; - args.nhead_q = a.nhead_q; - args.Ts = FmhaBwdV3Ts::ts_kv * a.stride_k * 2; - args.Hs_q = a.nhead_stride_q * 2; - args.BAs_q = a.batch_stride_q * 2; - args.Seqs_q = a.stride_q * 2; - args.Hs_k = a.nhead_stride_k * 2; - args.BAs_k = a.batch_stride_k * 2; - args.Seqs_k = a.stride_k * 2; - args.Hs_v = a.nhead_stride_v * 2; - args.BAs_v = a.batch_stride_v * 2; - args.Seqs_v = a.stride_v * 2; - args.Hs_do = a.nhead_stride_do * 2; - args.BAs_do = a.batch_stride_do * 2; - args.Seqs_do = a.stride_do * 2; - args.Hs_dk = a.nhead_stride_dk * 2; - args.BAs_dk = a.batch_stride_dk * 2; - args.Seqs_dk = a.stride_dk * 2; - args.Hs_dv = a.nhead_stride_dv * 2; - args.BAs_dv = a.batch_stride_dv * 2; - args.Seqs_dv = a.stride_dv * 2; + args.ptr_dq = a.dq_acc_ptr; + args.ptr_dk = a.dk_ptr; + args.ptr_dv = a.dv_ptr; + args.ptr_q = a.q_ptr; + args.ptr_k = a.k_ptr; + args.ptr_v = a.v_ptr; + args.ptr_do = a.do_ptr; + args.ptr_lse = a.lse_ptr; + args.ptr_d = a.d_ptr; + args.scalar = a.scale; + args.log2e = ck_tile::log2e_v;; + args.ratio = a.nhead_q / a.nhead_k; + args.seqlen_q = a.seqlen_q; + args.seqlen_k = a.seqlen_k; + args.head_dim_q = a.hdim_q; + args.head_dim_v = a.hdim_v; + args.nhead_q = a.nhead_q; + args.Ts = FmhaBwdV3Ts::ts_kv * a.stride_k * 2; + args.Hs_q = a.nhead_stride_q * 2; + args.BAs_q = a.batch_stride_q * 2; + args.Seqs_q = a.stride_q * 2; + args.Hs_k = a.nhead_stride_k * 2; + args.BAs_k = a.batch_stride_k * 2; + args.Seqs_k = a.stride_k * 2; + args.Hs_v = a.nhead_stride_v * 2; + args.BAs_v = a.batch_stride_v * 2; + args.Seqs_v = a.stride_v * 2; + args.Hs_do = a.nhead_stride_do * 2; + args.BAs_do = a.batch_stride_do * 2; + args.Seqs_do = a.stride_do * 2; + args.Hs_dk = a.nhead_stride_dk * 2; + args.BAs_dk = a.batch_stride_dk * 2; + args.Seqs_dk = a.stride_dk * 2; + args.Hs_dv = a.nhead_stride_dv * 2; + args.BAs_dv = a.batch_stride_dv * 2; + args.Seqs_dv = a.stride_dv * 2; args.Hs_lsed = a.nhead_stride_lsed * 4; args.ptr_qseq = a.seqstart_q_ptr; args.ptr_kseq = a.seqstart_k_ptr; @@ -1076,8 +1116,8 @@ class fmha_bwd_v3_kernel auto traits = fmha_bwd_v3_traits{a.batch, a.nhead_q, - a.seqlen_q, - a.seqlen_k, + a.max_seqlen_q, // when batch mode, max_seqlen equal to seqlen + a.max_seqlen_k, // when batch mode, max_seqlen equal to seqlen a.hdim_q, a.mask_type, FmhaBwdV3Ts::ts_qo, @@ -1094,7 +1134,7 @@ class fmha_bwd_v3_kernel template float fmha_bwd_v3_genl_gfx950(const ck_tile::stream_config& s, fmha_bwd_args a, bool is_v3_api_check, const void* seqlen_q_padded = nullptr, const void* seqlen_k_padded = nullptr) { - using dq_shuffle_traits = dq_shuffle_traits_; + using dq_shuffle_traits = dq_shuffle_traits_; if(s.log_level_ > 0) std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << FmhaBwdV3Name::kernel_name << ", " << dq_shuffle_traits::kernel_name() << std::flush; @@ -1116,6 +1156,7 @@ class fmha_bwd_v3_kernel args.seqlen_q = a.seqlen_q; args.seqlen_k = a.seqlen_k; args.head_dim_q = a.hdim_q; + args.head_dim_v = a.hdim_v; args.nhead_q = a.nhead_q; args.Ts = FmhaBwdV3Ts::ts_kv * a.stride_k * 2; args.Hs_q = a.nhead_stride_q * 2; @@ -1195,182 +1236,39 @@ class fmha_bwd_v3_kernel if (t.use_ext_asm == true){ if ((t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && - (t.is_deterministic == false) && (a.hdim_q == a.hdim_v) && (a.nhead_q % a.nhead_k == 0)) { - if((a.hdim_q > 128) && (a.hdim_q <= 192) && (a.hdim_q % 8 == 0)){ - if(t.data_type.compare("fp16") == 0){ - if((t.is_group_mode == false) && (t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ - if(t.mask_type == mask_enum::no_mask){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd192_fp16_a32_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); - return r; - } - else if((((t.mask_type != mask_enum::no_mask) && (a.seqlen_q == a.seqlen_k)) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))) && - ((a.window_size_left == -1) && (a.window_size_right == 0))){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd192_fp16_causal_a32_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); - return r; - } - else if((t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left == -1) && (a.window_size_right == 0))){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd192_fp16_causal_br_a32_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); - return r; - } - } - else if((t.is_group_mode == true) && (t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){//group mode - if(t.mask_type == mask_enum::no_mask){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, false, true, 0, true, true, true, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, true, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd192_fp16_a32_psskddv_group"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); - return r; - } - else if(((a.window_size_left == -1) && (a.window_size_right == 0)) && (t.mask_type == mask_enum::mask_top_left)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, true, true, 0, true, true, true, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, true, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd192_fp16_causal_a32_psskddv_group"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); - return r; - } - else if(((a.window_size_left == -1) && (a.window_size_right == 0)) && (t.mask_type == mask_enum::mask_bottom_right)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, 3, true, 0, true, true, true, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, true, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd192_fp16_causal_br_a32_psskddv_group"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); - return r; - } - } - } - else if(t.data_type.compare("bf16") == 0){ - if((t.is_group_mode == false) && (t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ - if(t.mask_type == mask_enum::no_mask){ - if(t.how_v3_bf16_cvt == 0){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd192_bf16_a32_rtne_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); - return r; - } - else if(t.how_v3_bf16_cvt == 1){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 1, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd192_bf16_a32_rtna_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); - return r; - } - else if(t.how_v3_bf16_cvt == 2){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 2, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd192_bf16_a32_rtz_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); - return r; - } - } - else if((((t.mask_type != mask_enum::no_mask) && (a.seqlen_q == a.seqlen_k)) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))) && - ((a.window_size_left == -1) && (a.window_size_right == 0))){ - if(t.how_v3_bf16_cvt == 0){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_a32_rtne_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); - return r; - } - else if(t.how_v3_bf16_cvt == 1){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 1, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_a32_rtna_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); - return r; - } - else if(t.how_v3_bf16_cvt == 2){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 2, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_a32_rtz_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); - return r; - } - } - else if((t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left == -1) && (a.window_size_right == 0))){ - if(t.how_v3_bf16_cvt == 0){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_br_a32_rtne_psskddv"; + (t.is_deterministic == false) && (a.nhead_q % a.nhead_k == 0) && (a.hdim_q % 8 == 0) && (a.hdim_v % 8 == 0)) { + if(a.hdim_q == a.hdim_v){ + if((a.hdim_q > 128) && (a.hdim_q <= 192)){ + if(t.data_type.compare("fp16") == 0){ + if((t.is_group_mode == false) && (t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ + if(t.mask_type == mask_enum::no_mask){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd192_fp16_a32_psskddv"; if (is_v3_api_check) { return 1; } r = fmha_bwd_v3_genl_(s, a); return r; } - else if(t.how_v3_bf16_cvt == 1){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 3, true, 1, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_br_a32_rtna_psskddv"; + else if((((t.mask_type != mask_enum::no_mask) && (a.seqlen_q == a.seqlen_k)) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))) && + ((a.window_size_left == -1) && (a.window_size_right == 0))){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd192_fp16_causal_a32_psskddv"; if (is_v3_api_check) { return 1; } r = fmha_bwd_v3_genl_(s, a); return r; } - else if(t.how_v3_bf16_cvt == 2){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 3, true, 2, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_br_a32_rtz_psskddv"; + else if((t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left == -1) && (a.window_size_right == 0))){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd192_fp16_causal_br_a32_psskddv"; if (is_v3_api_check) { return 1; } @@ -1378,95 +1276,34 @@ class fmha_bwd_v3_kernel return r; } } - } - else if((t.is_group_mode == true) && (t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){//group mode - if(t.mask_type == mask_enum::no_mask){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, true, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, true, true, true, false, 0>; - if(t.how_v3_bf16_cvt == 0){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 0, true, true, true, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_v3_hd192_bf16_a32_rtne_psskddv_group"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); - return r; - } - else if(t.how_v3_bf16_cvt == 1){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 1, true, true, true, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_v3_hd192_bf16_a32_rtna_psskddv_group"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); - return r; - } - else if(t.how_v3_bf16_cvt == 2){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 2, true, true, true, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_v3_hd192_bf16_a32_rtz_psskddv_group"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); - return r; - } - - } - else if(((a.window_size_left == -1) && (a.window_size_right == 0)) && (t.mask_type == mask_enum::mask_top_left)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, true, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, true, true, true, false, 0>; - if(t.how_v3_bf16_cvt == 0){ - // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_a32_rtne_psskddv_group"; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 0, true, true, true, GPUArch::gfx950>; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); - return r; - } - else if(t.how_v3_bf16_cvt == 1){ - // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_a32_rtna_psskddv_group"; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 1, true, true, true, GPUArch::gfx950>; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); - return r; - } - else if(t.how_v3_bf16_cvt == 2){ - // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_a32_rtz_psskddv_group"; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 2, true, true, true, GPUArch::gfx950>; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); - return r; - } - } - else if(((a.window_size_left == -1) && (a.window_size_right == 0)) && (t.mask_type == mask_enum::mask_bottom_right)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, true, true, true>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, true, true, true, false, 0>; - if(t.how_v3_bf16_cvt == 0){ - // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_br_a32_rtne_psskddv_group"; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 3, true, 0, true, true, true, GPUArch::gfx950>; + else if((t.is_group_mode == true) && (t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){//group mode + if(t.mask_type == mask_enum::no_mask){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, true, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdFp16, false, true, 0, true, true, true, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, true, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd192_fp16_a32_psskddv_group"; if (is_v3_api_check) { return 1; } r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); return r; } - else if(t.how_v3_bf16_cvt == 1){ - // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_br_a32_rtna_psskddv_group"; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 3, true, 1, true, true, true, GPUArch::gfx950>; + else if(((a.window_size_left == -1) && (a.window_size_right == 0)) && (t.mask_type == mask_enum::mask_top_left)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, true, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdFp16, true, true, 0, true, true, true, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, true, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd192_fp16_causal_a32_psskddv_group"; if (is_v3_api_check) { return 1; } r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); return r; } - else if(t.how_v3_bf16_cvt == 2){ - // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_br_a32_rtz_psskddv_group"; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, 3, true, 2, true, true, true, GPUArch::gfx950>; + else if(((a.window_size_left == -1) && (a.window_size_right == 0)) && (t.mask_type == mask_enum::mask_bottom_right)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, true, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdFp16, 3, true, 0, true, true, true, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, true, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd192_fp16_causal_br_a32_psskddv_group"; if (is_v3_api_check) { return 1; } @@ -1475,894 +1312,907 @@ class fmha_bwd_v3_kernel } } } - } - } - else if ((a.hdim_q > 64) && (a.hdim_q <= 128) && (a.hdim_q % 8 == 0) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ - if (t.data_type.compare("fp16") == 0){ - if (t.is_group_mode == false){ - if (t.mask_type == mask_enum::no_mask) { - if (t.is_v3_atomic_fp32 == true){ - if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_fp16_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_fp16_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + else if(t.data_type.compare("bf16") == 0){ + if((t.is_group_mode == false) && (t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ + if(t.mask_type == mask_enum::no_mask){ + if(t.how_v3_bf16_cvt == 0){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd192_bf16_a32_rtne_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); return r; } - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_fp16_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + else if(t.how_v3_bf16_cvt == 1){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, false, true, 1, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd192_bf16_a32_rtna_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); return r; } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_fp16_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + else if(t.how_v3_bf16_cvt == 2){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, false, true, 2, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd192_bf16_a32_rtz_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); return r; } } - else if (t.is_v3_atomic_fp32 == false){ - if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_fp16_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_fp16_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + else if((((t.mask_type != mask_enum::no_mask) && (a.seqlen_q == a.seqlen_k)) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))) && + ((a.window_size_left == -1) && (a.window_size_right == 0))){ + if(t.how_v3_bf16_cvt == 0){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_a32_rtne_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); return r; } - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_fp16_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + else if(t.how_v3_bf16_cvt == 1){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, true, true, 1, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_a32_rtna_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); return r; } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_fp16_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + else if(t.how_v3_bf16_cvt == 2){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, true, true, 2, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_a32_rtz_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); return r; } } - } else if ((t.mask_type == mask_enum::mask_top_left) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { - if (t.is_v3_atomic_fp32 == true){ - if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + else if((t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left == -1) && (a.window_size_right == 0))){ + if(t.how_v3_bf16_cvt == 0){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_br_a32_rtne_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); return r; } - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + else if(t.how_v3_bf16_cvt == 1){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, 3, true, 1, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_br_a32_rtna_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); return r; } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + else if(t.how_v3_bf16_cvt == 2){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, 3, true, 2, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_br_a32_rtz_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); return r; } - } else if (t.is_v3_atomic_fp32 == false){ - if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + } + } + else if((t.is_group_mode == true) && (t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){//group mode + if(t.mask_type == mask_enum::no_mask){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, true, true, true, false, 0>; + if(t.how_v3_bf16_cvt == 0){ + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, false, true, 0, true, true, true, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_v3_hd192_bf16_a32_rtne_psskddv_group"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); return r; } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + else if(t.how_v3_bf16_cvt == 1){ + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, false, true, 1, true, true, true, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_v3_hd192_bf16_a32_rtna_psskddv_group"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); return r; } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + else if(t.how_v3_bf16_cvt == 2){ + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, false, true, 2, true, true, true, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_v3_hd192_bf16_a32_rtz_psskddv_group"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); return r; } + } - } else if ((t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { - if (t.is_v3_atomic_fp32 == true){ - if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - } else if (t.is_v3_atomic_fp32 == false){ - if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + else if(((a.window_size_left == -1) && (a.window_size_right == 0)) && (t.mask_type == mask_enum::mask_top_left)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, true, true, true, false, 0>; + if(t.how_v3_bf16_cvt == 0){ + // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_a32_rtne_psskddv_group"; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, true, true, 0, true, true, true, GPUArch::gfx950>; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); return r; } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + else if(t.how_v3_bf16_cvt == 1){ + // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_a32_rtna_psskddv_group"; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, true, true, 1, true, true, true, GPUArch::gfx950>; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); return r; } - } - } else if (((t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left > 0) || (a.window_size_right > 0))) || (t.mask_type == mask_enum::window_generic)){ - if(t.is_v3_atomic_fp32 == true){ - if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv"; + else if(t.how_v3_bf16_cvt == 2){ + // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_a32_rtz_psskddv_group"; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, true, true, 2, true, true, true, GPUArch::gfx950>; if (is_v3_api_check) { return 1; } - r = fmha_bwd_v3_swa_genl_(s, a); + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); return r; } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv"; + } + else if(((a.window_size_left == -1) && (a.window_size_right == 0)) && (t.mask_type == mask_enum::mask_bottom_right)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, true, true, true>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, true, true, true, false, 0>; + if(t.how_v3_bf16_cvt == 0){ + // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_br_a32_rtne_psskddv_group"; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, 3, true, 0, true, true, true, GPUArch::gfx950>; if (is_v3_api_check) { return 1; } - r = fmha_bwd_v3_swa_genl_(s, a); + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); return r; } - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv; + else if(t.how_v3_bf16_cvt == 1){ + // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_br_a32_rtna_psskddv_group"; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, 3, true, 1, true, true, true, GPUArch::gfx950>; if (is_v3_api_check) { return 1; } - r = fmha_bwd_v3_swa_genl_(s, a); + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); return r; } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv"; + else if(t.how_v3_bf16_cvt == 2){ + // const std::string kernel_name = "bwd_v3_hd192_bf16_causal_br_a32_rtz_psskddv_group"; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 192, FmhaBwdBf16, 3, true, 2, true, true, true, GPUArch::gfx950>; if (is_v3_api_check) { return 1; } - r = fmha_bwd_v3_swa_genl_(s, a); + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); return r; } } } } - else if (t.is_group_mode == true){ - if (t.mask_type == mask_enum::no_mask) { - if (t.is_v3_atomic_fp32 == true){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true, true, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, true, false, 0>; - // const std::string bwd_v3_name = "bwd_hd128_fp16_a32_psskddv_group"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); - return r; - } else { - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, true, true, true, GPUArch::gfx950>; - // const std::string bwd_v3_name = "bwd_hd128_fp16_a16_psskddv_group"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); - return r; - } - } else if ((t.mask_type == mask_enum::mask_top_left) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { - if (t.is_v3_atomic_fp32 == true){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true, true, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, true, false, 0>; - // const std::string bwd_v3_name = "bwd_hd128_fp16_causal_a32_psskddv_group"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); - return r; - } else { - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, true, true, true, GPUArch::gfx950>; - // const std::string bwd_v3_name = "bwd_hd128_fp16_causal_a16_psskddv_group"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); - return r; - } - } else if ((t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { - if (t.is_v3_atomic_fp32 == true){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, true, 0, true, true, true, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, true, false, 0>; - // const std::string bwd_v3_name = "bwd_hd128_fp16_causal_br_a32_psskddv_group"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); - return r; - } else { - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 3, false, 0, true, true, true, GPUArch::gfx950>; - // const std::string bwd_v3_name = "bwd_hd128_fp16_causal_br_a16_psskddv_group"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); - return r; - } - } - } } - else if(t.data_type.compare("bf16") == 0){ - if (t.is_group_mode == false){ - if (t.mask_type == mask_enum::no_mask) { - if (t.is_v3_atomic_fp32 == true){ - if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - } - else if (t.is_v3_atomic_fp32 == false){ - if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_bf16_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_bf16_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_bf16_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_bf16_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - } - } else if ((t.mask_type == mask_enum::mask_top_left) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { - if (t.is_v3_atomic_fp32 == true){ - if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; + else if ((a.hdim_q > 64) && (a.hdim_q <= 128) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ + if (t.data_type.compare("fp16") == 0){ + if (t.is_group_mode == false){ + if (t.mask_type == mask_enum::no_mask) { + if (t.is_v3_atomic_fp32 == true){ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_hd128_fp16_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_hd128_fp16_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false, 0>; + // const std::string kernel_name = "bwd_hd128_fp16_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_hd128_fp16_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } } - } - else if (t.is_v3_atomic_fp32 == false){ - if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; + else if (t.is_v3_atomic_fp32 == false){ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_fp16_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_fp16_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_fp16_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_fp16_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; + } else if ((t.mask_type == mask_enum::mask_top_left) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { + if (t.is_v3_atomic_fp32 == true){ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false, 0>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + } else if (t.is_v3_atomic_fp32 == false){ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } } - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; + } else if ((t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { + if (t.is_v3_atomic_fp32 == true){ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false, 0>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + } else if (t.is_v3_atomic_fp32 == false){ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_fp16_causal_br_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; + } else if (((t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left > 0) || (a.window_size_right > 0))) || (t.mask_type == mask_enum::window_generic)){ + if(t.is_v3_atomic_fp32 == true){ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; + } + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 2, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; + } } } - } else if ((t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { - if (t.is_v3_atomic_fp32 == true){ - if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + } + else if (t.is_group_mode == true){ + if (t.mask_type == mask_enum::no_mask) { + if (t.is_v3_atomic_fp32 == true){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, true, 0, true, true, true, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, true, false, 0>; + // const std::string bwd_v3_name = "bwd_hd128_fp16_a32_psskddv_group"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + } else { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, false, false, 0, true, true, true, GPUArch::gfx950>; + // const std::string bwd_v3_name = "bwd_hd128_fp16_a16_psskddv_group"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); return r; } - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + } else if ((t.mask_type == mask_enum::mask_top_left) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { + if (t.is_v3_atomic_fp32 == true){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, true, 0, true, true, true, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, true, false, 0>; + // const std::string bwd_v3_name = "bwd_hd128_fp16_causal_a32_psskddv_group"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a32_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + } else { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, true, false, 0, true, true, true, GPUArch::gfx950>; + // const std::string bwd_v3_name = "bwd_hd128_fp16_causal_a16_psskddv_group"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); return r; } - } - else if (t.is_v3_atomic_fp32 == false){ - if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); - return r; - } - else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + } else if ((t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { + if (t.is_v3_atomic_fp32 == true){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, true, 0, true, true, true, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, true, false, 0>; + // const std::string bwd_v3_name = "bwd_hd128_fp16_causal_br_a32_psskddv_group"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); return r; - } - else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, false, 0, true, true, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a16_psskddv"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + } else { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdFp16, 3, false, 0, true, true, true, GPUArch::gfx950>; + // const std::string bwd_v3_name = "bwd_hd128_fp16_causal_br_a16_psskddv_group"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); return r; } } - } else if (((t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left > 0) || (a.window_size_right > 0))) || (t.mask_type == mask_enum::window_generic)){ - if(t.is_v3_atomic_fp32 == true){ - if(t.how_v3_bf16_cvt == 0){ + } + } + else if(t.data_type.compare("bf16") == 0){ + if (t.is_group_mode == false){ + if (t.mask_type == mask_enum::no_mask) { + if (t.is_v3_atomic_fp32 == true){ if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtne_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_swa_genl_(s, a); + // const std::string kernel_name = "bwd_hd128_bf16_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); return r; } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtne_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_swa_genl_(s, a); + // const std::string kernel_name = "bwd_hd128_bf16_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); return r; } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtne_psskddv; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_swa_genl_(s, a); + // const std::string kernel_name = "bwd_hd128_bf16_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); return r; } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 0, true, true, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtne_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_swa_genl_(s, a); + // const std::string kernel_name = "bwd_hd128_bf16_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); return r; } } - else if(t.how_v3_bf16_cvt == 1){ + else if (t.is_v3_atomic_fp32 == false){ if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_bf16_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_bf16_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_bf16_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_bf16_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + } + } else if ((t.mask_type == mask_enum::mask_top_left) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { + if (t.is_v3_atomic_fp32 == true){ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtna_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_swa_genl_(s, a); + // const std::string kernel_name = "bwd_hd128_bf16_causal_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); return r; } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtna_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_swa_genl_(s, a); + // const std::string kernel_name = "bwd_hd128_bf16_causal_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); return r; } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtna_psskddv; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_swa_genl_(s, a); + // const std::string kernel_name = "bwd_hd128_bf16_causal_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); return r; } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 0, true, true, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtna_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_swa_genl_(s, a); + // const std::string kernel_name = "bwd_hd128_bf16_causal_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); return r; } } - else if(t.how_v3_bf16_cvt == 2){ + else if (t.is_v3_atomic_fp32 == false){ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_bf16_causal_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_bf16_causal_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_bf16_causal_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_bf16_causal_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + } + } else if ((t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { + if (t.is_v3_atomic_fp32 == true){ if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtz_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_swa_genl_(s, a); + // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); return r; } else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtz_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_swa_genl_(s, a); + // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); return r; } else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtz_psskddv; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_swa_genl_(s, a); + // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); return r; } else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 0, true, true, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; - // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtz_psskddv"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_swa_genl_(s, a); + // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a32_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); return r; } } - } - } - } - else if (t.is_group_mode == true){ - if (t.mask_type == mask_enum::no_mask) { - if (t.is_v3_atomic_fp32 == true){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true, true, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false, 0>; - // const std::string bwd_v3_name = "bwd_hd128_bf16_a32_psskddv_group"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); - return r; - } else { - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, true, true, true, GPUArch::gfx950>; - // const std::string bwd_v3_name = "bwd_hd128_bf16_a16_psskddv_group"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); - return r; - } - } else if ((t.mask_type == mask_enum::mask_top_left) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { - if (t.is_v3_atomic_fp32 == true){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true, true, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false, 0>; - // const std::string bwd_v3_name = "bwd_hd128_bf16_causal_a32_psskddv_group"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); - return r; - } else { - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, true, true, true, GPUArch::gfx950>; - // const std::string bwd_v3_name = "bwd_hd128_bf16_causal_a16_psskddv_group"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); - return r; - } - } else if ((t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { - if (t.is_v3_atomic_fp32 == true){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, true, 0, true, true, true, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false, 0>; - // const std::string bwd_v3_name = "bwd_hd128_bf16_causal_br_a32_psskddv_group"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); - return r; - } else { - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 3, false, 0, true, true, true, GPUArch::gfx950>; - // const std::string bwd_v3_name = "bwd_hd128_bf16_causal_br_a16_psskddv_group"; - r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); - return r; - } - } - } - } - } - else if(a.hdim_q == 64){ - if(t.data_type.compare("fp16") == 0){ - if(t.mask_type == mask_enum::no_mask){ - if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ - if(t.is_group_mode == false){ - if(a.seqlen_q % 64 == 0){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_fp16_a32_pssk"; - if (is_v3_api_check) { - return 1; + else if (t.is_v3_atomic_fp32 == false){ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; } - r = fmha_bwd_v3_genl_(s, a); - return r; - } - else{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_fp16_a32_pssk"; - if (is_v3_api_check) { - return 1; + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; } - r = fmha_bwd_v3_genl_(s, a); - return r; - } - } - else{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false, true, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, true, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_fp16_a32_pssk_group"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); - return r; - } - } - else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) && - (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && - (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, false, 0, false, false, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_v3_hd64_fp16_a16"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_(s, a); - return r; - } - } - else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){ - if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ - if(t.is_group_mode == false){ - if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){ - if(a.seqlen_q % 64 == 0){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_fp16_causal_a32_pssk"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); return r; } - else{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_fp16_causal_a32_pssk"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, false, 0, true, true, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd128_bf16_causal_br_a16_psskddv"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); return r; } } - else if(t.mask_type == mask_enum::mask_bottom_right){ - if(a.seqlen_q % 64 == 0){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, 3, true, 0, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_fp16_causal_br_a32_pssk"; - if (is_v3_api_check) { - return 1; + } else if (((t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left > 0) || (a.window_size_right > 0))) || (t.mask_type == mask_enum::window_generic)){ + if(t.is_v3_atomic_fp32 == true){ + if(t.how_v3_bf16_cvt == 0){ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtne_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtne_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; + } + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; + // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtne_psskddv; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 0, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtne_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; } - r = fmha_bwd_v3_genl_(s, a); - return r; } - else{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, 3, true, 0, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_fp16_causal_br_a32_pssk"; - if (is_v3_api_check) { - return 1; + else if(t.how_v3_bf16_cvt == 1){ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtna_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtna_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; + } + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; + // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtna_psskddv; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 1, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtna_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; + } + } + else if(t.how_v3_bf16_cvt == 2){ + if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtz_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtz_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; + } + else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false, 0>; + // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtz_psskddv; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; + } + else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 2, true, 2, true, true, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_hd128_bf16_swa_a32_rtz_psskddv"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_swa_genl_(s, a); + return r; } - r = fmha_bwd_v3_genl_(s, a); - return r; } } } - else if(t.is_group_mode == true){ - if(t.mask_type == mask_enum::mask_top_left){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false, true, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, true, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_fp16_causal_a32_pssk_group"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); + } + else if (t.is_group_mode == true){ + if (t.mask_type == mask_enum::no_mask) { + if (t.is_v3_atomic_fp32 == true){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, true, 0, true, true, true, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false, 0>; + // const std::string bwd_v3_name = "bwd_hd128_bf16_a32_psskddv_group"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); return r; - } - else if(t.mask_type == mask_enum::mask_bottom_right){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, true, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, 3, true, 0, true, false, true, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, true, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_fp16_causal_br_a32_pssk_group"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); + } else { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, false, false, 0, true, true, true, GPUArch::gfx950>; + // const std::string bwd_v3_name = "bwd_hd128_bf16_a16_psskddv_group"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); + return r; + } + } else if ((t.mask_type == mask_enum::mask_top_left) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { + if (t.is_v3_atomic_fp32 == true){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, true, 0, true, true, true, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false, 0>; + // const std::string bwd_v3_name = "bwd_hd128_bf16_causal_a32_psskddv_group"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); + return r; + } else { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, true, false, 0, true, true, true, GPUArch::gfx950>; + // const std::string bwd_v3_name = "bwd_hd128_bf16_causal_a16_psskddv_group"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); + return r; + } + } else if ((t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { + if (t.is_v3_atomic_fp32 == true){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, true, 0, true, true, true, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false, 0>; + // const std::string bwd_v3_name = "bwd_hd128_bf16_causal_br_a32_psskddv_group"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); + return r; + } else { + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, 128, FmhaBwdBf16, 3, false, 0, true, true, true, GPUArch::gfx950>; + // const std::string bwd_v3_name = "bwd_hd128_bf16_causal_br_a16_psskddv_group"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check, seqlen_q_padded, seqlen_k_padded); return r; } } } - else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) && - (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && - (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, false, 0, false, false, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_v3_hd64_fp16_causal_a16"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_(s, a); - return r; - } } } - else if(t.data_type.compare("bf16") == 0){ - if(t.mask_type == mask_enum::no_mask){ - if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ - if(t.is_group_mode == false){ - if(t.how_v3_bf16_cvt == 0){ - if(a.seqlen_q % 64 == 0){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_a32_rtne_pssk"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); - return r; - } - else{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_a32_rtne_pssk"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); - return r; - } - } - else if(t.how_v3_bf16_cvt == 1){ - if(a.seqlen_q % 64 == 0){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_a32_rtna_pssk"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); - return r; - } - else{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_a32_rtna_pssk"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); - return r; - } - } - else if(t.how_v3_bf16_cvt == 2){ + else if(a.hdim_q == 64){ + if(t.data_type.compare("fp16") == 0){ + if(t.mask_type == mask_enum::no_mask){ + if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ + if(t.is_group_mode == false){ if(a.seqlen_q % 64 == 0){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_a32_rtz_pssk"; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, false, true, 0, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_fp16_a32_pssk"; if (is_v3_api_check) { return 1; } @@ -2370,10 +2220,10 @@ class fmha_bwd_v3_kernel return r; } else{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_a32_rtz_pssk"; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, false, true, 0, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_fp16_a32_pssk"; if (is_v3_api_check) { return 1; } @@ -2381,61 +2231,24 @@ class fmha_bwd_v3_kernel return r; } } - } - else{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, true, true, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, true, true, false, false, 0>; - if(t.how_v3_bf16_cvt == 0){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false, true, GPUArch::gfx950>; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); - } - else if(t.how_v3_bf16_cvt == 1){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false, true, GPUArch::gfx950>; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); - } else{ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false, true, GPUArch::gfx950>; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, true, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, false, true, 0, true, false, true, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, true, true, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_fp16_a32_pssk_group"; if (is_v3_api_check) { return 1; } r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); + return r; } - return r; - } - } - else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) && - (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && - (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ - if(t.how_v3_bf16_cvt == 0){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 0, false, false, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_a16_rtne"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_(s, a); - return r; - } - else if(t.how_v3_bf16_cvt == 1){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 1, false, false, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_a16_rtna"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_(s, a); - return r; } - else if(t.how_v3_bf16_cvt == 2){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 2, false, false, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_a16_rtz"; + else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) && + (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && + (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, false, false, 0, false, false, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_v3_hd64_fp16_a16"; if (is_v3_api_check) { return 1; } @@ -2443,17 +2256,15 @@ class fmha_bwd_v3_kernel return r; } } - } - else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){ - if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ - if(t.is_group_mode == false){ - if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){ - if(t.how_v3_bf16_cvt == 0){ + else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){ + if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ + if(t.is_group_mode == false){ + if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){ if(a.seqlen_q % 64 == 0){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk"; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, true, true, 0, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_fp16_causal_a32_pssk"; if (is_v3_api_check) { return 1; } @@ -2461,10 +2272,10 @@ class fmha_bwd_v3_kernel return r; } else{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk"; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, true, true, 0, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_fp16_causal_a32_pssk"; if (is_v3_api_check) { return 1; } @@ -2472,12 +2283,12 @@ class fmha_bwd_v3_kernel return r; } } - else if(t.how_v3_bf16_cvt == 1){ + else if(t.mask_type == mask_enum::mask_bottom_right){ if(a.seqlen_q % 64 == 0){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk"; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, 3, true, 0, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_fp16_causal_br_a32_pssk"; if (is_v3_api_check) { return 1; } @@ -2485,10 +2296,10 @@ class fmha_bwd_v3_kernel return r; } else{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk"; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, 3, true, 0, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_fp16_causal_br_a32_pssk"; if (is_v3_api_check) { return 1; } @@ -2496,38 +2307,56 @@ class fmha_bwd_v3_kernel return r; } } - else if(t.how_v3_bf16_cvt == 2){ - if(a.seqlen_q % 64 == 0){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); - return r; + } + else if(t.is_group_mode == true){ + if(t.mask_type == mask_enum::mask_top_left){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, true, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, true, true, 0, true, false, true, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, true, true, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_fp16_causal_a32_pssk_group"; + if (is_v3_api_check) { + return 1; } - else{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false, false, GPUArch::gfx950>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk"; - if (is_v3_api_check) { - return 1; - } - r = fmha_bwd_v3_genl_(s, a); - return r; + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); + return r; + } + else if(t.mask_type == mask_enum::mask_bottom_right){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, true, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, 3, true, 0, true, false, true, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, true, true, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_fp16_causal_br_a32_pssk_group"; + if (is_v3_api_check) { + return 1; } + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); + return r; } } - else if(t.mask_type == mask_enum::mask_bottom_right){ + } + else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) && + (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && + (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdFp16, true, false, 0, false, false, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_v3_hd64_fp16_causal_a16"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_(s, a); + return r; + } + } + } + else if(t.data_type.compare("bf16") == 0){ + if(t.mask_type == mask_enum::no_mask){ + if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ + if(t.is_group_mode == false){ if(t.how_v3_bf16_cvt == 0){ if(a.seqlen_q % 64 == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 0, true, false, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 0, true, false, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_br_a32_rtne_pssk"; + // const std::string kernel_name = "bwd_v3_hd64_bf16_a32_rtne_pssk"; if (is_v3_api_check) { return 1; } @@ -2536,9 +2365,9 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 0, true, false, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 0, true, false, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_br_a32_rtne_pssk"; + // const std::string kernel_name = "bwd_v3_hd64_bf16_a32_rtne_pssk"; if (is_v3_api_check) { return 1; } @@ -2549,9 +2378,9 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 1){ if(a.seqlen_q % 64 == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 1, true, false, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 1, true, false, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_br_a32_rtna_pssk"; + // const std::string kernel_name = "bwd_v3_hd64_bf16_a32_rtna_pssk"; if (is_v3_api_check) { return 1; } @@ -2560,9 +2389,9 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 1, true, false, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 1, true, false, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_br_a32_rtna_pssk"; + // const std::string kernel_name = "bwd_v3_hd64_bf16_a32_rtna_pssk"; if (is_v3_api_check) { return 1; } @@ -2573,9 +2402,9 @@ class fmha_bwd_v3_kernel else if(t.how_v3_bf16_cvt == 2){ if(a.seqlen_q % 64 == 0){ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 2, true, false, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 2, true, false, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_br_a32_rtz_pssk"; + // const std::string kernel_name = "bwd_v3_hd64_bf16_a32_rtz_pssk"; if (is_v3_api_check) { return 1; } @@ -2584,9 +2413,9 @@ class fmha_bwd_v3_kernel } else{ using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 2, true, false, false, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 2, true, false, false, GPUArch::gfx950>; using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_br_a32_rtz_pssk"; + // const std::string kernel_name = "bwd_v3_hd64_bf16_a32_rtz_pssk"; if (is_v3_api_check) { return 1; } @@ -2595,27 +2424,25 @@ class fmha_bwd_v3_kernel } } } - } - else if(t.is_group_mode == true){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, true, true, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, true, true, false, false, 0>; - if(t.mask_type == mask_enum::mask_top_left){ + else{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, true, true, false, false, 0>; if(t.how_v3_bf16_cvt == 0){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false, true, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 0, true, false, true, GPUArch::gfx950>; if (is_v3_api_check) { return 1; } r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); } else if(t.how_v3_bf16_cvt == 1){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false, true, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 1, true, false, true, GPUArch::gfx950>; if (is_v3_api_check) { return 1; } r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); } else{ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false, true, GPUArch::gfx950>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, true, 2, true, false, true, GPUArch::gfx950>; if (is_v3_api_check) { return 1; } @@ -2623,64 +2450,356 @@ class fmha_bwd_v3_kernel } return r; } - else if(t.mask_type == mask_enum::mask_bottom_right){ - if(t.how_v3_bf16_cvt == 0){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 0, true, false, true, GPUArch::gfx950>; - if (is_v3_api_check) { - return 1; + } + else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) && + (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && + (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ + if(t.how_v3_bf16_cvt == 0){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, false, 0, false, false, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_a16_rtne"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_(s, a); + return r; + } + else if(t.how_v3_bf16_cvt == 1){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, false, 1, false, false, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_a16_rtna"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_(s, a); + return r; + } + else if(t.how_v3_bf16_cvt == 2){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, false, false, 2, false, false, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_a16_rtz"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_(s, a); + return r; + } + } + } + else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){ + if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ + if(t.is_group_mode == false){ + if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){ + if(t.how_v3_bf16_cvt == 0){ + if(a.seqlen_q % 64 == 0){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 0, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); + return r; + } + else{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 0, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); + return r; + } + } + else if(t.how_v3_bf16_cvt == 1){ + if(a.seqlen_q % 64 == 0){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 1, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); + return r; + } + else{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 1, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); + return r; + } + } + else if(t.how_v3_bf16_cvt == 2){ + if(a.seqlen_q % 64 == 0){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 2, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); + return r; + } + else{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 2, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); + return r; + } } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); } - else if(t.how_v3_bf16_cvt == 1){ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 1, true, false, true, GPUArch::gfx950>; - if (is_v3_api_check) { - return 1; + else if(t.mask_type == mask_enum::mask_bottom_right){ + if(t.how_v3_bf16_cvt == 0){ + if(a.seqlen_q % 64 == 0){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 0, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_br_a32_rtne_pssk"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); + return r; + } + else{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 0, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_br_a32_rtne_pssk"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); + return r; + } + } + else if(t.how_v3_bf16_cvt == 1){ + if(a.seqlen_q % 64 == 0){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 1, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_br_a32_rtna_pssk"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); + return r; + } + else{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 1, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_br_a32_rtna_pssk"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); + return r; + } + } + else if(t.how_v3_bf16_cvt == 2){ + if(a.seqlen_q % 64 == 0){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 2, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_br_a32_rtz_pssk"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); + return r; + } + else{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 2, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false, 0>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_br_a32_rtz_pssk"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_genl_(s, a); + return r; + } } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); } - else{ - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, 3, true, 2, true, false, true, GPUArch::gfx950>; - if (is_v3_api_check) { - return 1; + } + else if(t.is_group_mode == true){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, true, true, false>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, true, true, false, false, 0>; + if(t.mask_type == mask_enum::mask_top_left){ + if(t.how_v3_bf16_cvt == 0){ + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 0, true, false, true, GPUArch::gfx950>; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); } - r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); + else if(t.how_v3_bf16_cvt == 1){ + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 1, true, false, true, GPUArch::gfx950>; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); + } + else{ + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, true, 2, true, false, true, GPUArch::gfx950>; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); + } + return r; + } + else if(t.mask_type == mask_enum::mask_bottom_right){ + if(t.how_v3_bf16_cvt == 0){ + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 0, true, false, true, GPUArch::gfx950>; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); + } + else if(t.how_v3_bf16_cvt == 1){ + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 1, true, false, true, GPUArch::gfx950>; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); + } + else{ + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, 3, true, 2, true, false, true, GPUArch::gfx950>; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_group_(s, a, seqlen_q_padded, seqlen_k_padded); + } + return r; + } + } + } + else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) && + (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && + (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ + if(t.how_v3_bf16_cvt == 0){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, false, 0, false, false, false, GPUArch::gfx950>; + const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_(s, a); + return r; + } + else if(t.how_v3_bf16_cvt == 1){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, false, 1, false, false, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; + if (is_v3_api_check) { + return 1; + } + r = fmha_bwd_v3_(s, a); + return r; + } + else if(t.how_v3_bf16_cvt == 2){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, 64, FmhaBwdBf16, true, false, 2, false, false, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; + if (is_v3_api_check) { + return 1; } + r = fmha_bwd_v3_(s, a); return r; } } } - else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) && - (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) && - (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){ - if(t.how_v3_bf16_cvt == 0){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 0, false, false, false, GPUArch::gfx950>; - const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a16_rtne"; - if (is_v3_api_check) { - return 1; + } + } + } else { + if ((a.hdim_q == 192) && (a.hdim_v == 128) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){ + if (t.data_type.compare("fp16") == 0){ + if (t.is_group_mode == false){ + if (t.mask_type == mask_enum::no_mask) { + if (t.is_v3_atomic_fp32 == true){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 128, FmhaBwdFp16, 0, true, 0, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_hd192_hd128_fp16_a32_pssk"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if (t.is_v3_atomic_fp32 == false){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 128, FmhaBwdFp16, 0, false, 0, true, false, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd192_hd128_fp16_a16_pssk"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + } else if ((t.mask_type == mask_enum::mask_top_left) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { + if (t.is_v3_atomic_fp32 == true){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 128, FmhaBwdFp16, 1, true, 0, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_hd192_hd128_fp16_causal_a32_pssk"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } else if (t.is_v3_atomic_fp32 == false){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 128, FmhaBwdFp16, 1, false, 0, true, false, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd192_hd128_fp16_causal_a16_pssk"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; } - r = fmha_bwd_v3_(s, a); - return r; } - else if(t.how_v3_bf16_cvt == 1){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 1, false, false, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a16_rtna"; - if (is_v3_api_check) { - return 1; + } + } + else if(t.data_type.compare("bf16") == 0){ + if (t.is_group_mode == false){ + if (t.mask_type == mask_enum::no_mask) { + if (t.is_v3_atomic_fp32 == true){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 128, FmhaBwdBf16, 0, true, 0, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_hd192_hd128_bf16_a32_pssk"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } + else if (t.is_v3_atomic_fp32 == false){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 128, FmhaBwdBf16, 0, false, 0, true, false, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd192_hd128_bf16_a16_pssk"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; } - r = fmha_bwd_v3_(s, a); - return r; } - else if(t.how_v3_bf16_cvt == 2){ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>; - using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 2, false, false, false, GPUArch::gfx950>; - // const std::string kernel_name = "bwd_v3_hd64_bf16_causal_a16_rtz"; - if (is_v3_api_check) { - return 1; + else if ((t.mask_type == mask_enum::mask_top_left) && ((a.window_size_left == -1) && (a.window_size_right == 0))) { + if (t.is_v3_atomic_fp32 == true){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 128, FmhaBwdBf16, 1, true, 0, true, false, false, GPUArch::gfx950>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false, 0>; + // const std::string kernel_name = "bwd_hd192_hd128_bf16_causal_a32_pssk"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; + } else if (t.is_v3_atomic_fp32 == false){ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>; + using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, 128, FmhaBwdBf16, 1, false, 0, true, false, false, GPUArch::gfx950>; + // const std::string kernel_name = "bwd_hd192_hd128_bf16_causal_a16_pssk"; + r = fmha_bwd_v3_genl_gfx950(s, a, is_v3_api_check); + return r; } - r = fmha_bwd_v3_(s, a); - return r; } } } diff --git a/op_tests/cpp/mha/benchmark_mha_bwd.cpp b/op_tests/cpp/mha/benchmark_mha_bwd.cpp index f96e7ae7c1..8d719dd254 100644 --- a/op_tests/cpp/mha/benchmark_mha_bwd.cpp +++ b/op_tests/cpp/mha/benchmark_mha_bwd.cpp @@ -352,7 +352,8 @@ bool run(const ck_tile::ArgParser& arg_parser) deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1; const ck_tile::index_t a16_dq_acc_seq = v3_atomic_fp32 ? shape_seqlen_q : (mode == mode_enum::batch ? (seqlen_q + 15) / 16 * 16 : (max_seqlen_q + 15) / 16 * 16); - const ck_tile::index_t a16_dq_acc_hdim = v3_atomic_fp32 ? hdim_q : 128; + // hdim_q = 192 pipline currently don't support hdim padding + const ck_tile::index_t a16_dq_acc_hdim = v3_atomic_fp32 ? hdim_q : hdim_q == 192? 192: 128; ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); diff --git a/op_tests/cpp/mha/smoke_test_bwd_v3.sh b/op_tests/cpp/mha/smoke_test_bwd_v3.sh index bfe6fbb8f5..be219eca40 100644 --- a/op_tests/cpp/mha/smoke_test_bwd_v3.sh +++ b/op_tests/cpp/mha/smoke_test_bwd_v3.sh @@ -100,14 +100,19 @@ run_gfx950_bwd_v3() { for prec in "bf16" "fp16" ; do for mask in 0 1 2 ; do for v3_atomic_fp32 in 1 0 ; do + for hdim in 72 96 112 120 192 ; do for batch in 1 3 ; do for head in 2 4 ; do - for hdim in 72 96 112 120 ; do for sq in 13 62 174 ; do - for sk in 65 174 299 577 799; do + for sk in 65 174 299 577 799 ; do for perm in 0 1 ; do - $EXE -prec=$prec -b=$batch -h=$head -h_k=2 -d=$hdim -s=$sq -s_k=$sk -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -mode=0 -kname=$KNAME $COMMON_ARGS + hdim_v=$hdim + if [ $hdim -eq 192 ]; then + hdim_v=128 + fi + + $EXE -prec=$prec -b=$batch -h=$head -h_k=2 -d=$hdim -d_v=$hdim_v -s=$sq -s_k=$sk -iperm=$perm -operm=$perm -mask=$mask -bwd_v3=1 -v3_atomic_fp32=$v3_atomic_fp32 -mode=0 -kname=$KNAME $COMMON_ARGS done done diff --git a/op_tests/test_mha.py b/op_tests/test_mha.py index 61bb1e89e8..a5b16e9574 100644 --- a/op_tests/test_mha.py +++ b/op_tests/test_mha.py @@ -178,6 +178,7 @@ def run_ck( (192, 192), (224, 224), (256, 256), + (192, 128), ], ) @pytest.mark.parametrize( @@ -449,6 +450,7 @@ def flash_attn_output_benchmark( (192, 192), (224, 224), (256, 256), + (192, 128), ], ) @pytest.mark.parametrize( @@ -710,7 +712,15 @@ def test_flash_attn_seq_padding( "-d_qk_v", type=dtypes.str2tuple, nargs="+", - default=[(32, 32), (40, 40), (64, 64), (111, 111), (128, 128), (160, 160)], + default=[ + (32, 32), + (40, 40), + (64, 64), + (111, 111), + (128, 128), + (160, 160), + (192, 128), + ], help="""Dimension of query and key. Default is None. e.g.: -qk_v 256,256""", )