Review notes (free text ahead of the first "diff --git" header).

* Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch. The leading whitespace inside the removed
  V3_MULTI_TARGET_API literal is a best guess; if a hunk is rejected, retry
  with whitespace-insensitive matching and confirm against the target file.
* Compared with the collapsed copy, the added code (a) drops the f-prefix from
  literals that interpolate nothing, (b) builds the if / else-if chain in one
  loop, and (c) emits the unguarded direct call only when the build targets a
  single arch. The "index" lines are unchanged, so their post-image hashes no
  longer describe the amended result.
* NOTE(review): the helpers are duplicated verbatim in both generators; moving
  them to a shared module may be worthwhile, but no such module is visible
  from this patch.

diff --git a/csrc/cpp_itfs/mha_bwd_generate.py b/csrc/cpp_itfs/mha_bwd_generate.py
index 432a352edf..bf0fe15283 100644
--- a/csrc/cpp_itfs/mha_bwd_generate.py
+++ b/csrc/cpp_itfs/mha_bwd_generate.py
@@ -101,24 +101,66 @@
 V2_API = "t = fmha_bwd(traits, args, stream_config);"
 
 
-V3_MULTI_TARGET_API = """
-    if (get_gpu_arch() == "gfx942") {
-        t = gfx942::fmha_bwd_v3(traits, args, stream_config, seqlen_q_padded, seqlen_k_padded, is_v3_api_check);
-    } else if (get_gpu_arch() == "gfx950") {
-        t = gfx950::fmha_bwd_v3(traits, args, stream_config, seqlen_q_padded, seqlen_k_padded, is_v3_api_check);
-    } else {
-        std::cout << "No supported GPU arch found!" << std::endl;
-        return -1;
-    }
-"""
+V3_SUPPORTED_ARCH = ["gfx942", "gfx950"]
+
+
+def _build_call(arch, call_args):
+    """Return the C++ statement that calls `call_args` in the `arch` namespace."""
+    return f"t = {arch}::{call_args};"
+
+
+def _build_no_supported_arch_block():
+    """Return the C++ statements that report a missing arch and return -1."""
+    # Plain literals: nothing is interpolated, so no f-prefix is needed.
+    return [
+        'std::cout << "No supported GPU arch found!" << std::endl;',
+        "return -1;",
+    ]
+
+
+def _build_multi_target_api(supported_archs, call_args):
+    """Return a C++ if / else-if chain that dispatches on get_gpu_arch().
+
+    One branch is emitted per entry of `supported_archs`, in order, followed
+    by an else branch holding the no-supported-arch block. An empty
+    `supported_archs` yields just that block, with no surrounding braces.
+    """
+    fallback = _build_no_supported_arch_block()
+    if not supported_archs:
+        return "\n".join(fallback)
+
+    lines = []
+    for index, arch in enumerate(supported_archs):
+        # Only the first branch opens the chain; the rest are 'else if'.
+        keyword = "if" if index == 0 else "else if"
+        lines.append(
+            f'{keyword} (get_gpu_arch() == "{arch}") {{\n'
+            f"    {_build_call(arch, call_args)}\n"
+            f"}}"
+        )
+    # The trailing "" keeps the newline after the closing brace.
+    lines += ["else {", *fallback, "}", ""]
+    return "\n".join(lines)
 
 
 def get_v3_api():
+    """Return the C++ snippet that invokes fmha_bwd_v3 for the compiled archs."""
     gfx_list = get_gfx_list()
-    if len(gfx_list) == 1:
-        return f"t = {gfx_list[0]}::fmha_bwd_v3(traits, args, stream_config, seqlen_q_padded, seqlen_k_padded, is_v3_api_check);"
+    call_args = "fmha_bwd_v3(traits, args, stream_config, seqlen_q_padded, seqlen_k_padded, is_v3_api_check)"
+
+    # Intersect in V3_SUPPORTED_ARCH order so the generated dispatch is stable.
+    supported_gfx_list = [arch for arch in V3_SUPPORTED_ARCH if arch in gfx_list]
+
+    if not supported_gfx_list:
+        # No supported arch compiled
+        return "\n".join(_build_no_supported_arch_block())
+    elif len(gfx_list) == 1:
+        # The build targets exactly one arch and it is supported: direct call
+        return _build_call(supported_gfx_list[0], call_args)
     else:
-        return V3_MULTI_TARGET_API
+        # Several archs are compiled (some possibly unsupported), so keep the
+        # runtime get_gpu_arch() check rather than assuming the device matches.
+        return _build_multi_target_api(supported_gfx_list, call_args)
 
 
 V3_API = get_v3_api()
diff --git a/csrc/cpp_itfs/mha_fwd_generate.py b/csrc/cpp_itfs/mha_fwd_generate.py
index 48ee4d6939..16d4d9a2c1 100644
--- a/csrc/cpp_itfs/mha_fwd_generate.py
+++ b/csrc/cpp_itfs/mha_fwd_generate.py
@@ -163,24 +163,66 @@
 V2_API = """t = fmha_fwd(traits, args, stream_config);"""
 
 
-V3_MULTI_TARGET_API = """
-    if (get_gpu_arch() == "gfx942") {
-        t = gfx942::fmha_fwd_v3(traits, args, stream_config, seqstart_q_padding_ptr, seqstart_k_padding_ptr, is_v3_api_check);
-    } else if (get_gpu_arch() == "gfx950") {
-        t = gfx950::fmha_fwd_v3(traits, args, stream_config, seqstart_q_padding_ptr, seqstart_k_padding_ptr, is_v3_api_check);
-    } else {
-        std::cout << "No supported GPU arch found!" << std::endl;
-        return -1;
-    }
-"""
+V3_SUPPORTED_ARCH = ["gfx942", "gfx950"]
+
+
+def _build_call(arch, call_args):
+    """Return the C++ statement that calls `call_args` in the `arch` namespace."""
+    return f"t = {arch}::{call_args};"
+
+
+def _build_no_supported_arch_block():
+    """Return the C++ statements that report a missing arch and return -1."""
+    # Plain literals: nothing is interpolated, so no f-prefix is needed.
+    return [
+        'std::cout << "No supported GPU arch found!" << std::endl;',
+        "return -1;",
+    ]
+
+
+def _build_multi_target_api(supported_archs, call_args):
+    """Return a C++ if / else-if chain that dispatches on get_gpu_arch().
+
+    One branch is emitted per entry of `supported_archs`, in order, followed
+    by an else branch holding the no-supported-arch block. An empty
+    `supported_archs` yields just that block, with no surrounding braces.
+    """
+    fallback = _build_no_supported_arch_block()
+    if not supported_archs:
+        return "\n".join(fallback)
+
+    lines = []
+    for index, arch in enumerate(supported_archs):
+        # Only the first branch opens the chain; the rest are 'else if'.
+        keyword = "if" if index == 0 else "else if"
+        lines.append(
+            f'{keyword} (get_gpu_arch() == "{arch}") {{\n'
+            f"    {_build_call(arch, call_args)}\n"
+            f"}}"
+        )
+    # The trailing "" keeps the newline after the closing brace.
+    lines += ["else {", *fallback, "}", ""]
+    return "\n".join(lines)
 
 
 def get_v3_api():
+    """Return the C++ snippet that invokes fmha_fwd_v3 for the compiled archs."""
     gfx_list = get_gfx_list()
-    if len(gfx_list) == 1:
-        return f"t = {gfx_list[0]}::fmha_fwd_v3(traits, args, stream_config, seqstart_q_padding_ptr, seqstart_k_padding_ptr, is_v3_api_check);"
+    call_args = "fmha_fwd_v3(traits, args, stream_config, seqstart_q_padding_ptr, seqstart_k_padding_ptr, is_v3_api_check)"
+
+    # Intersect in V3_SUPPORTED_ARCH order so the generated dispatch is stable.
+    supported_gfx_list = [arch for arch in V3_SUPPORTED_ARCH if arch in gfx_list]
+
+    if not supported_gfx_list:
+        # No supported arch compiled
+        return "\n".join(_build_no_supported_arch_block())
+    elif len(gfx_list) == 1:
+        # The build targets exactly one arch and it is supported: direct call
+        return _build_call(supported_gfx_list[0], call_args)
     else:
-        return V3_MULTI_TARGET_API
+        # Several archs are compiled (some possibly unsupported), so keep the
+        # runtime get_gpu_arch() check rather than assuming the device matches.
+        return _build_multi_target_api(supported_gfx_list, call_args)
 
 
 V3_API = get_v3_api()