diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtna.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtna.co index e0db10bc9d..af10ab6df5 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtna.co and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtna.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtna_group.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtna_group.co index adceeae5cd..ac1ca972ac 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtna_group.co and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtna_group.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtne.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtne.co index f64b868ae6..e4a46bd725 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtne.co and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtne.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtne_group.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtne_group.co index 979ce35bf8..4b8000efeb 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtne_group.co and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtne_group.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtz.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtz.co index 00aad252b7..ab519bce8e 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtz.co and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtz.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtz_group.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtz_group.co index 04a9ec90b0..fbd5eee308 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtz_group.co and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_causal_rtz_group.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtna.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtna.co index d577bb1157..092e402ad1 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtna.co and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtna.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtna_group.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtna_group.co index 3bac18b0e5..48ac9e54a5 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtna_group.co and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtna_group.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtne.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtne.co index 538c8b148d..a63a8c2940 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtne.co and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtne.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtne_group.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtne_group.co index 9c36a56a12..27c55d1937 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtne_group.co and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtne_group.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtz.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtz.co index 36573d04b1..6edbc54bbd 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtz.co and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtz.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtz_group.co b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtz_group.co index b9bcccdb1f..b6d3e01639 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtz_group.co and b/hsa/gfx942/fmha_v3_fwd/MI300/fwd_hd128_bf16_rtz_group.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtna.co b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtna.co index da6200fd29..69ab645de8 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtna.co and b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtna.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtna_group.co b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtna_group.co index 45d352a095..9b97c147e9 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtna_group.co and b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtna_group.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtne.co b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtne.co index c9945fe4ae..7f5c12bc05 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtne.co and b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtne.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtne_group.co b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtne_group.co index 4e466f63f3..4032e8d161 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtne_group.co and b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtne_group.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtz.co b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtz.co index 6b3d3d0f1f..302a658023 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtz.co and b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtz.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtz_group.co b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtz_group.co index d4b171e339..6c2cb554b4 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtz_group.co and b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_causal_rtz_group.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtna.co b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtna.co index 0d7ec4c0d5..fd75121550 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtna.co and b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtna.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtna_group.co b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtna_group.co index f82fb0949f..0b4bf90955 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtna_group.co and b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtna_group.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtne.co b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtne.co index 9aac653494..3f0351962b 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtne.co and b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtne.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtne_group.co b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtne_group.co index 27b4c5b5f2..52f27a9f9d 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtne_group.co and b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtne_group.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtz.co b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtz.co index 4cba024e94..a484b398bf 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtz.co and b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtz.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtz_group.co b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtz_group.co index 1cb07b552e..00d8f51b7a 100755 Binary files a/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtz_group.co and b/hsa/gfx942/fmha_v3_fwd/MI308/fwd_hd128_bf16_rtz_group.co differ diff --git a/hsa/gfx942/fmha_v3_fwd/codegen.py b/hsa/gfx942/fmha_v3_fwd/codegen.py index dba6bc95ba..497c8a6b87 100644 --- a/hsa/gfx942/fmha_v3_fwd/codegen.py +++ b/hsa/gfx942/fmha_v3_fwd/codegen.py @@ -142,7 +142,35 @@ class fmha_fwd_v3_kernel int gdx = ((fmha_v3_traits.s + fmha_v3_traits.ts_qo - 1) / fmha_v3_traits.ts_qo + tg_div - 1) / tg_div; int gdy = fmha_v3_traits.h; int gdz = fmha_v3_traits.b; + HIP_CALL(hipModuleLaunchKernel(kernel_func, + gdx, + gdy, + gdz, + bdx, + 1, + 1, + 0, + s.stream_id_, + NULL, + reinterpret_cast(&config))); + } + void + launch_kernel_group(fmha_fwd_v3_traits fmha_v3_traits, fmha_fwd_v3_args args, const ck_tile::stream_config& s) const + { + size_t arg_size = sizeof(args); + void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, + &args, + HIP_LAUNCH_PARAM_BUFFER_SIZE, + &arg_size, + HIP_LAUNCH_PARAM_END}; + + int tg_div = (fmha_v3_traits.mask != 0) ? 2 : 1; + + int bdx = 512; + int gdx = fmha_v3_traits.h; + int gdy = fmha_v3_traits.b; + int gdz = ((fmha_v3_traits.s + fmha_v3_traits.ts_qo - 1) / fmha_v3_traits.ts_qo + tg_div - 1) / tg_div; HIP_CALL(hipModuleLaunchKernel(kernel_func, gdx, gdy, @@ -224,6 +252,69 @@ class fmha_fwd_v3_kernel ); } +template +float fmha_fwd_v3_group_dispatcher(const ck_tile::stream_config& s, mha_fwd_args a, + const void* seqstart_q_padding_ptr, const void* seqstart_k_padding_ptr) +{ + if(s.log_level_ > 0) + std::cout << ", " << FmhaFwdV3Name::fwd_v3_name << std::flush; + + int tune_opt = 5; + if (a.mask_type != 0 && ((a.nhead_q % 8 != 0) || (a.seqlen_q > 16384))) //if num_head is not 8N, or seqlen is bigger than 16K, downgrade to 2and3 + { + tune_opt -= 2; + } + + fmha_fwd_v3_args args; + args.ptr_o = a.o_ptr; + args.ptr_q = a.q_ptr; + args.ptr_k = a.k_ptr; + args.ptr_v = a.v_ptr; + args.ptr_lse = a.lse_ptr; + + args.scalar = a.scale_s; + args.s_seq_len = a.seqlen_q; + args.s_Seqs = a.stride_q * 2; + args.s_Ts = FmhaFwdV3Ts::ts_qo * a.stride_q * 2; + args.s_Hs = a.nhead_stride_q * 2; + args.s_Bs = a.batch_stride_q * 2; + args.s_gqa = a.nhead_q / a.nhead_k; + args.s_k_Seqs = a.stride_k * 2; + args.s_k_Hs = a.nhead_stride_k * 2; + args.s_k_Bs = a.batch_stride_k * 2; + args.s_opt = tune_opt; + args.s_lse = fmha_fwd_kernel_selector::kStoreLSE; + args.s_kv_seq_len = a.seqlen_k; + args.s_qk_head_dim = a.hdim_q; + args.s_v_head_dim = a.hdim_v; + args.s_q_head_num = a.nhead_q; + args.s_v_Seqs = a.stride_v * 2; + args.s_v_Hs = a.nhead_stride_v * 2; + args.s_v_Bs = a.batch_stride_v * 2; + args.s_o_Seqs = a.stride_o * 2; + args.s_o_Hs = a.nhead_stride_o * 2; + args.s_o_Bs = a.batch_stride_o * 2; + + args.s_lse_Hs = a.nhead_stride_lse * 4; + args.ptr_qseq = a.seqstart_q_ptr; + args.ptr_kseq = a.seqstart_k_ptr; + args.ptr_qseq_padding = seqstart_q_padding_ptr == nullptr ? a.seqstart_q_ptr : seqstart_q_padding_ptr; + args.ptr_kseq_padding = seqstart_k_padding_ptr == nullptr ? a.seqstart_k_ptr : seqstart_k_padding_ptr; + + auto traits = fmha_fwd_v3_traits{a.batch, + a.nhead_q, + a.seqlen_q, + a.hdim_q, + a.mask_type, + FmhaFwdV3Ts::ts_qo, + FmhaFwdV3Ts::ts_kv}; + + static thread_local fmha_fwd_v3_kernel impl(FmhaFwdV3Name::fwd_v3_name, FmhaFwdV3Buf::fwd_v3_buf); // static here is for thread safety. + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){ impl.launch_kernel_group(traits, args, s_); } + ); +} + float fmha_fwd_v3(mha_fwd_traits t, mha_fwd_args a, const ck_tile::stream_config& s, const void* seqstart_q_padding_ptr, const void* seqstart_k_padding_ptr, bool is_v3_api_check) { float r = -1; @@ -354,14 +445,14 @@ class fmha_fwd_v3_kernel if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } else { using fmha_fwd_kernel = fmha_fwd_kernel_selector; if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } } else if(t.how_v3_bf16_cvt == 1) { @@ -370,14 +461,14 @@ class fmha_fwd_v3_kernel if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } else { using fmha_fwd_kernel = fmha_fwd_kernel_selector; if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } } else if(t.how_v3_bf16_cvt == 2) { @@ -386,14 +477,14 @@ class fmha_fwd_v3_kernel if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } else { using fmha_fwd_kernel = fmha_fwd_kernel_selector; if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } } } @@ -404,14 +495,14 @@ class fmha_fwd_v3_kernel if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } else { using fmha_fwd_kernel = fmha_fwd_kernel_selector; if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } } else if(t.how_v3_bf16_cvt == 1) { @@ -420,14 +511,14 @@ class fmha_fwd_v3_kernel if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } else { using fmha_fwd_kernel = fmha_fwd_kernel_selector; if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } } else if(t.how_v3_bf16_cvt == 2) { @@ -436,14 +527,14 @@ class fmha_fwd_v3_kernel if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } else { using fmha_fwd_kernel = fmha_fwd_kernel_selector; if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } } } diff --git a/hsa/gfx950/fmha_v3_fwd/codegen.py b/hsa/gfx950/fmha_v3_fwd/codegen.py index 019562e363..e5ab464e12 100644 --- a/hsa/gfx950/fmha_v3_fwd/codegen.py +++ b/hsa/gfx950/fmha_v3_fwd/codegen.py @@ -135,6 +135,36 @@ class fmha_fwd_v3_kernel reinterpret_cast(&config))); } + void + launch_kernel_group(fmha_fwd_v3_traits fmha_v3_traits, fmha_fwd_v3_args args, const ck_tile::stream_config& s) const + { + size_t arg_size = sizeof(args); + void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, + &args, + HIP_LAUNCH_PARAM_BUFFER_SIZE, + &arg_size, + HIP_LAUNCH_PARAM_END}; + + int tg_div = (fmha_v3_traits.mask != 0) ? 2 : 1; + + int bdx = (fmha_v3_traits.d == 192) ? 256 : 512; + int gdx = fmha_v3_traits.h; + int gdy = fmha_v3_traits.b; + int gdz = ((fmha_v3_traits.s + fmha_v3_traits.ts_qo - 1) / fmha_v3_traits.ts_qo + tg_div - 1) / tg_div; + + HIP_CALL(hipModuleLaunchKernel(kernel_func, + gdx, + gdy, + gdz, + bdx, + 1, + 1, + 0, + s.stream_id_, + NULL, + reinterpret_cast(&config))); + } + private: hipModule_t module; hipFunction_t kernel_func; @@ -203,6 +233,69 @@ class fmha_fwd_v3_kernel ); } +template +float fmha_fwd_v3_group_dispatcher(const ck_tile::stream_config& s, mha_fwd_args a, + const void* seqstart_q_padding_ptr, const void* seqstart_k_padding_ptr) +{ + if(s.log_level_ > 0) + std::cout << ", " << FmhaFwdV3Name::fwd_v3_name << std::flush; + + int tune_opt = 5; + if (a.mask_type != 0 && ((a.nhead_q % 8 != 0) || (a.seqlen_q > 16384))) //if num_head is not 8N, or seqlen is bigger than 16K, downgrade to 2and3 + { + tune_opt -= 2; + } + + fmha_fwd_v3_args args; + args.ptr_o = a.o_ptr; + args.ptr_q = a.q_ptr; + args.ptr_k = a.k_ptr; + args.ptr_v = a.v_ptr; + args.ptr_lse = a.lse_ptr; + + args.scalar = a.scale_s; + args.s_seq_len = a.seqlen_q; + args.s_Seqs = a.stride_q * 2; + args.s_Ts = FmhaFwdV3Ts::ts_qo * a.stride_q * 2; + args.s_Hs = a.nhead_stride_q * 2; + args.s_Bs = a.batch_stride_q * 2; + args.s_gqa = a.nhead_q / a.nhead_k; + args.s_k_Seqs = a.stride_k * 2; + args.s_k_Hs = a.nhead_stride_k * 2; + args.s_k_Bs = a.batch_stride_k * 2; + args.s_opt = tune_opt; + args.s_lse = fmha_fwd_kernel_selector::kStoreLSE; + args.s_kv_seq_len = a.seqlen_k; + args.s_qk_head_dim = a.hdim_q; + args.s_v_head_dim = a.hdim_v; + args.s_q_head_num = a.nhead_q; + args.s_v_Seqs = a.stride_v * 2; + args.s_v_Hs = a.nhead_stride_v * 2; + args.s_v_Bs = a.batch_stride_v * 2; + args.s_o_Seqs = a.stride_o * 2; + args.s_o_Hs = a.nhead_stride_o * 2; + args.s_o_Bs = a.batch_stride_o * 2; + + args.s_lse_Hs = a.nhead_stride_lse * 4; + args.ptr_qseq = a.seqstart_q_ptr; + args.ptr_kseq = a.seqstart_k_ptr; + args.ptr_qseq_padding = seqstart_q_padding_ptr == nullptr ? a.seqstart_q_ptr : seqstart_q_padding_ptr; + args.ptr_kseq_padding = seqstart_k_padding_ptr == nullptr ? a.seqstart_k_ptr : seqstart_k_padding_ptr; + + auto traits = fmha_fwd_v3_traits{a.batch, + a.nhead_q, + a.seqlen_q, + a.hdim_q, + a.mask_type, + FmhaFwdV3Ts::ts_qo, + FmhaFwdV3Ts::ts_kv}; + + static thread_local fmha_fwd_v3_kernel impl(FmhaFwdV3Name::fwd_v3_name, FmhaFwdV3Buf::fwd_v3_buf); // static here is for thread safety. + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){ impl.launch_kernel_group(traits, args, s_); } + ); +} + float fmha_fwd_v3(mha_fwd_traits t, mha_fwd_args a, const ck_tile::stream_config& s, const void* seqstart_q_padding_ptr, const void* seqstart_k_padding_ptr, bool is_v3_api_check) { float r = -1; @@ -256,14 +349,14 @@ class fmha_fwd_v3_kernel if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } else { using fmha_fwd_kernel = fmha_fwd_kernel_selector; if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } } else if (t.mask_type == mask_enum::no_mask) { @@ -272,14 +365,14 @@ class fmha_fwd_v3_kernel if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } else { using fmha_fwd_kernel = fmha_fwd_kernel_selector; if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } } } @@ -331,14 +424,14 @@ class fmha_fwd_v3_kernel if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } else { using fmha_fwd_kernel = fmha_fwd_kernel_selector; if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } } else if (t.mask_type == mask_enum::no_mask) { @@ -347,14 +440,14 @@ class fmha_fwd_v3_kernel if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } else { using fmha_fwd_kernel = fmha_fwd_kernel_selector; if (is_v3_api_check) { return 1; } - r = fmha_fwd_v3_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); + r = fmha_fwd_v3_group_dispatcher(s, a, seqstart_q_padding_ptr, seqstart_k_padding_ptr); } } } diff --git a/hsa/gfx950/fmha_v3_fwd/fwd_hd128_bf16_causal.co b/hsa/gfx950/fmha_v3_fwd/fwd_hd128_bf16_causal.co index 53af4806f4..2d8bf0bb7f 100755 Binary files a/hsa/gfx950/fmha_v3_fwd/fwd_hd128_bf16_causal.co and b/hsa/gfx950/fmha_v3_fwd/fwd_hd128_bf16_causal.co differ diff --git a/hsa/gfx950/fmha_v3_fwd/fwd_hd128_bf16_causal_group.co b/hsa/gfx950/fmha_v3_fwd/fwd_hd128_bf16_causal_group.co index 8dee38b802..35a76e1b82 100755 Binary files a/hsa/gfx950/fmha_v3_fwd/fwd_hd128_bf16_causal_group.co and b/hsa/gfx950/fmha_v3_fwd/fwd_hd128_bf16_causal_group.co differ diff --git a/hsa/gfx950/fmha_v3_fwd/fwd_hd128_bf16_group.co b/hsa/gfx950/fmha_v3_fwd/fwd_hd128_bf16_group.co index 9505268402..21758854b5 100755 Binary files a/hsa/gfx950/fmha_v3_fwd/fwd_hd128_bf16_group.co and b/hsa/gfx950/fmha_v3_fwd/fwd_hd128_bf16_group.co differ diff --git a/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal.co b/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal.co index bb17f6d63c..e3a45201f4 100755 Binary files a/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal.co and b/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal.co differ diff --git a/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal_group.co b/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal_group.co index 42d94613ed..ec05c4ed25 100755 Binary files a/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal_group.co and b/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_causal_group.co differ diff --git a/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_group.co b/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_group.co index f476539a2d..b2c6afde13 100755 Binary files a/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_group.co and b/hsa/gfx950/fmha_v3_fwd/fwd_hd192_hd128_bf16_group.co differ diff --git a/op_tests/cpp/mha/smoke_test_fwd_v3.sh b/op_tests/cpp/mha/smoke_test_fwd_v3.sh index 008425d269..23a980d35e 100644 --- a/op_tests/cpp/mha/smoke_test_fwd_v3.sh +++ b/op_tests/cpp/mha/smoke_test_fwd_v3.sh @@ -23,8 +23,8 @@ run_gfx950_fwd_v3() { for o_perm in 0 1 ; do for mask in 0 2 ; do for lse in 0 1 ; do - for seqlen_q in 127 192 301 512; do - for seqlen_k in 512 700 1023; do + for seqlen_q in 127 192 301 512 1024; do + for seqlen_k in 512 700 1023 1058; do $EXE -prec=bf16 -b=2 -h=4 -h_k=2 -d=$head_dim -d_v=128 -s=$seqlen_q -s_k=$seqlen_k -iperm=$i_perm -operm=$o_perm -mask=$mask -lse=$lse -fwd_v3=1 -mode=$mode -kname=$KNAME $COMMON_ARGS $EXE -prec=bf16 -b=1 -h=3 -h_k=1 -d=$head_dim -d_v=128 -s=$seqlen_q -s_k=$seqlen_k -iperm=$i_perm -operm=$o_perm -mask=$mask -lse=$lse -fwd_v3=1 -mode=$mode -kname=$KNAME $COMMON_ARGS @@ -53,8 +53,8 @@ run_gfx942_fwd_v3() { for o_perm in 0 1 ; do for mask in 0 2 ; do for lse in 0 1 ; do - for seqlen_q in 127 192 301 512; do - for seqlen_k in 512 700 1023; do + for seqlen_q in 127 192 301 512 1024; do + for seqlen_k in 512 700 1023 1058; do for v3_bf16_cvt in 0 1 2; do $EXE -prec=bf16 -b=2 -h=4 -h_k=2 -d=128 -s=$seqlen_q -s_k=$seqlen_k -iperm=$i_perm -operm=$o_perm -mask=$mask -lse=$lse -fwd_v3=1 -v3_bf16_cvt=$v3_bf16_cvt -mode=$mode -kname=$KNAME $COMMON_ARGS