diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index ebf7f6f28e..8b94aeaba7 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -1424,7 +1424,8 @@ def psskddv(): return ret # basic - ret = alibi_slopes is None + ret = get_gfx() == "gfx942" + ret &= alibi_slopes is None ret &= bias is None ret &= dbias is None ret &= dropout_p == 0.0 @@ -2035,7 +2036,8 @@ def psskddv(): def can_impl_fmha_v3_bwd(): # basic - ret = alibi_slopes is None + ret = get_gfx() == "gfx942" + ret &= alibi_slopes is None # ret &= bias is None # ret &= dbias is None ret &= dropout_p == 0.0 diff --git a/csrc/cpp_itfs/mha_bwd_generate.py b/csrc/cpp_itfs/mha_bwd_generate.py index 432a352edf..9ef3c4dc29 100644 --- a/csrc/cpp_itfs/mha_bwd_generate.py +++ b/csrc/cpp_itfs/mha_bwd_generate.py @@ -101,24 +101,31 @@ V2_API = "t = fmha_bwd(traits, args, stream_config);" -V3_MULTI_TARGET_API = """ - if (get_gpu_arch() == "gfx942") { - t = gfx942::fmha_bwd_v3(traits, args, stream_config, seqlen_q_padded, seqlen_k_padded, is_v3_api_check); - } else if (get_gpu_arch() == "gfx950") { - t = gfx950::fmha_bwd_v3(traits, args, stream_config, seqlen_q_padded, seqlen_k_padded, is_v3_api_check); - } else { - std::cout << "No supported GPU arch found!" << std::endl; - return -1; - } -""" - def get_v3_api(): + v3_call = "fmha_bwd_v3(traits, args, stream_config, seqlen_q_padded, seqlen_k_padded, is_v3_api_check)" gfx_list = get_gfx_list() + v3_arch_list = [arch for arch in ["gfx942", "gfx950"] if arch in gfx_list] + + if len(v3_arch_list) == 0: + return "" # no v3 support if len(gfx_list) == 1: - return f"t = {gfx_list[0]}::fmha_bwd_v3(traits, args, stream_config, seqlen_q_padded, seqlen_k_padded, is_v3_api_check);" - else: - return V3_MULTI_TARGET_API + return f"t = {gfx_list[0]}::{v3_call};" + + api = """{ + const std::string gpu_arch = get_gpu_arch();""" + for arch in v3_arch_list: + api = ( + api + + f""" + if (gpu_arch == "{arch}") {{ t = {arch}::{v3_call}; }}""" + ) + api = ( + api + + """ + }""" + ) + return api V3_API = get_v3_api() diff --git a/csrc/cpp_itfs/mha_fwd_generate.py b/csrc/cpp_itfs/mha_fwd_generate.py index 48ee4d6939..5fe5190064 100644 --- a/csrc/cpp_itfs/mha_fwd_generate.py +++ b/csrc/cpp_itfs/mha_fwd_generate.py @@ -163,24 +163,31 @@ V2_API = """t = fmha_fwd(traits, args, stream_config);""" -V3_MULTI_TARGET_API = """ - if (get_gpu_arch() == "gfx942") { - t = gfx942::fmha_fwd_v3(traits, args, stream_config, seqstart_q_padding_ptr, seqstart_k_padding_ptr, is_v3_api_check); - } else if (get_gpu_arch() == "gfx950") { - t = gfx950::fmha_fwd_v3(traits, args, stream_config, seqstart_q_padding_ptr, seqstart_k_padding_ptr, is_v3_api_check); - } else { - std::cout << "No supported GPU arch found!" << std::endl; - return -1; - } -""" - def get_v3_api(): + v3_call = "fmha_fwd_v3(traits, args, stream_config, seqstart_q_padding_ptr, seqstart_k_padding_ptr, is_v3_api_check)" gfx_list = get_gfx_list() + v3_arch_list = [arch for arch in ["gfx942", "gfx950"] if arch in gfx_list] + + if len(v3_arch_list) == 0: + return "" # no v3 support if len(gfx_list) == 1: - return f"t = {gfx_list[0]}::fmha_fwd_v3(traits, args, stream_config, seqstart_q_padding_ptr, seqstart_k_padding_ptr, is_v3_api_check);" - else: - return V3_MULTI_TARGET_API + return f"t = {gfx_list[0]}::{v3_call};" + + api = """{ + const std::string gpu_arch = get_gpu_arch();""" + for arch in v3_arch_list: + api = ( + api + + f""" + if (gpu_arch == "{arch}") {{ t = {arch}::{v3_call}; }}""" + ) + api = ( + api + + """ + }""" + ) + return api V3_API = get_v3_api()