From c1a380ca0db6a7c0e993551b1140774d06f7576e Mon Sep 17 00:00:00 2001 From: jingchao Date: Wed, 17 Dec 2025 18:02:36 +0800 Subject: [PATCH] fix mha bwd golden perf issue --- op_tests/cpp/mha/benchmark_mha_bwd.cpp | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/op_tests/cpp/mha/benchmark_mha_bwd.cpp b/op_tests/cpp/mha/benchmark_mha_bwd.cpp index 82cf9769a7..b8829a1372 100644 --- a/op_tests/cpp/mha/benchmark_mha_bwd.cpp +++ b/op_tests/cpp/mha/benchmark_mha_bwd.cpp @@ -953,17 +953,24 @@ bool run(const ck_tile::ArgParser& arg_parser) } // dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i) - ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) { - AccDataType do_dot_o = 0; + // Precompute dO_i dot O_i for each (head, seq_q) to avoid redundant computation + // This reduces complexity from O(nhead * seqlen_q * seqlen_k * hdim_v) to + // O(nhead * seqlen_q * hdim_v) + O(nhead * seqlen_q * seqlen_k) + ck_tile::HostTensor do_dot_o_ref({nhead, real_seqlen_q}); + do_dot_o_ref.ForEach([&](auto& self, auto idx_gm) { + AccDataType sum = 0; for(int o = 0; o < hdim_v; o++) { - auto idx_gmo = idx_gmn; - idx_gmo[2] = o; - do_dot_o += ck_tile::type_convert(do_host_ref(idx_gmo)) * - ck_tile::type_convert(o_host_refs[wb](idx_gmo)); + sum += ck_tile::type_convert(do_host_ref(idx_gm[0], idx_gm[1], o)) * + ck_tile::type_convert(o_host_refs[wb](idx_gm[0], idx_gm[1], o)); } + self(idx_gm) = sum; + }); + + ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) { self(idx_gmn) = ck_tile::type_convert( - p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o)); + p_hp_host_refs[wb](idx_gmn) * + (dp_hp_host_ref(idx_gmn) - do_dot_o_ref(idx_gmn[0], idx_gmn[1]))); }); if(use_dbias)