diff --git a/csrc/cpp_itfs/mha_bwd.cpp b/csrc/cpp_itfs/mha_bwd.cpp index e2f97f8541..ffd3ccc2a7 100644 --- a/csrc/cpp_itfs/mha_bwd.cpp +++ b/csrc/cpp_itfs/mha_bwd.cpp @@ -127,8 +127,9 @@ std::tuple get_heuristic_kernel(std::stri float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) { + float asm_ret = fmha_v3_bwd(a, s); #if ONLY_FAV3 - return fmha_v3_bwd(a, s); + return asm_ret; #else fmha_bwd_traits traits{a.hdim_q, a.hdim_v, @@ -225,11 +226,11 @@ float mha_bwd(mha_bwd_args a, const ck_tile::stream_config& s) /* drop_seed_offset */ a.drop_seed_offset, }; - float asm_ret = fmha_v3_bwd(a, s); if(asm_ret == -1) { return fmha_bwd(traits, ck_args, s); } + return asm_ret; #endif }