Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 53 additions & 13 deletions csrc/cpp_itfs/mha_bwd_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,24 +101,64 @@

# C++ statement text for the v2 path: assigns the result of fmha_bwd(...) to `t`.
# NOTE(review): presumably spliced into a generated C++ source; the consumer
# is outside this view.
V2_API = "t = fmha_bwd(traits, args, stream_config);"

# Hand-written C++ dispatch text: compares get_gpu_arch() against "gfx942" and
# "gfx950", calls the matching <arch>::fmha_bwd_v3(...), and otherwise prints
# "No supported GPU arch found!" and returns -1.
V3_MULTI_TARGET_API = """
if (get_gpu_arch() == "gfx942") {
t = gfx942::fmha_bwd_v3(traits, args, stream_config, seqlen_q_padded, seqlen_k_padded, is_v3_api_check);
} else if (get_gpu_arch() == "gfx950") {
t = gfx950::fmha_bwd_v3(traits, args, stream_config, seqlen_q_padded, seqlen_k_padded, is_v3_api_check);
} else {
std::cout << "No supported GPU arch found!" << std::endl;
return -1;
}
"""
# Arch names for which a <arch>::fmha_bwd_v3 call may be emitted; get_v3_api()
# filters this list against get_gfx_list(), so its order fixes the branch order.
V3_SUPPORTED_ARCH = ["gfx942", "gfx950"]


def _build_call(arch, call_args):
    """Return the C++ statement that assigns ``<arch>::<call_args>`` to ``t``.

    Args:
        arch: Namespace placed before ``::`` (e.g. ``"gfx942"``).
        call_args: Text of the call expression, without a trailing semicolon.

    Returns:
        A single-line string of the form ``"t = <arch>::<call_args>;"``.
    """
    return "t = {}::{};".format(arch, call_args)


def _build_no_supported_arch_block():
    """Return the C++ statements emitted when no supported arch matches.

    Returns:
        A new two-element list of C++ source lines: one printing
        "No supported GPU arch found!" to ``std::cout`` and one ``return -1;``.
        A fresh list is built on every call, so callers may extend or mutate
        it freely.
    """
    # Plain literals: nothing is interpolated here, so no f-prefix is needed.
    return [
        'std::cout << "No supported GPU arch found!" << std::endl;',
        "return -1;",
    ]


def _build_multi_target_api(supported_archs, call_args):
    """Build a C++ ``if / else if / else`` chain dispatching on ``get_gpu_arch()``.

    Args:
        supported_archs: Arch names, one branch each, in the order given.
        call_args: Call expression text passed to ``_build_call`` per branch.

    Returns:
        The chain as one newline-joined string ending in a trailing newline.
        The final ``else`` holds the "no supported arch" statements. When
        ``supported_archs`` is empty, only those statements are returned
        (no braces, no trailing newline).
    """
    fallback = _build_no_supported_arch_block()
    if not supported_archs:
        return "\n".join(fallback)

    branches = []
    for position, arch in enumerate(supported_archs):
        # Only the very first branch opens the chain with a bare `if`.
        keyword = "if" if position == 0 else "else if"
        call = _build_call(arch, call_args)
        branches.append(f'{keyword} (get_gpu_arch() == "{arch}") {{\n {call}\n}}')

    # The empty last element yields the trailing newline after the closing brace.
    return "\n".join([*branches, "else {", *fallback, "}", ""])


def get_v3_api():
    """Return the C++ snippet that invokes ``fmha_bwd_v3`` for the built archs.

    The archs reported by ``get_gfx_list()`` are intersected with
    ``V3_SUPPORTED_ARCH`` (keeping ``V3_SUPPORTED_ARCH`` order), and the
    snippet is chosen from the size of that intersection:

    * empty    -> the "no supported arch" statements (print + ``return -1;``)
    * one arch -> a direct ``t = <arch>::fmha_bwd_v3(...);`` call
    * several  -> an ``if / else if / else`` chain on ``get_gpu_arch()``

    Returns:
        The snippet as a string.
    """
    call_args = "fmha_bwd_v3(traits, args, stream_config, seqlen_q_padded, seqlen_k_padded, is_v3_api_check)"

    # NOTE(review): get_gfx_list() is defined outside this view; it is treated
    # here as the list of archs being compiled for -- confirm.
    gfx_list = get_gfx_list()

    # Filter before branching so that a single *unsupported* arch never gets a
    # direct call into a namespace that has no fmha_bwd_v3.
    supported_gfx_list = [arch for arch in V3_SUPPORTED_ARCH if arch in gfx_list]

    if not supported_gfx_list:
        # No supported arch compiled
        return "\n".join(_build_no_supported_arch_block())
    if len(supported_gfx_list) == 1:
        # Single arch: direct call, no runtime dispatch needed
        return _build_call(supported_gfx_list[0], call_args)
    # Multiple archs: build the dispatch from the filtered list
    return _build_multi_target_api(supported_gfx_list, call_args)


# Module-level statement: get_v3_api() runs once when this module is loaded and
# its returned snippet is kept in V3_API.
V3_API = get_v3_api()
Expand Down
66 changes: 53 additions & 13 deletions csrc/cpp_itfs/mha_fwd_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,24 +163,64 @@

# C++ statement text for the v2 path: assigns the result of fmha_fwd(...) to `t`.
# NOTE(review): presumably spliced into a generated C++ source; the consumer
# is outside this view.
V2_API = """t = fmha_fwd(traits, args, stream_config);"""

# Hand-written C++ dispatch text: compares get_gpu_arch() against "gfx942" and
# "gfx950", calls the matching <arch>::fmha_fwd_v3(...), and otherwise prints
# "No supported GPU arch found!" and returns -1.
V3_MULTI_TARGET_API = """
if (get_gpu_arch() == "gfx942") {
t = gfx942::fmha_fwd_v3(traits, args, stream_config, seqstart_q_padding_ptr, seqstart_k_padding_ptr, is_v3_api_check);
} else if (get_gpu_arch() == "gfx950") {
t = gfx950::fmha_fwd_v3(traits, args, stream_config, seqstart_q_padding_ptr, seqstart_k_padding_ptr, is_v3_api_check);
} else {
std::cout << "No supported GPU arch found!" << std::endl;
return -1;
}
"""
# Arch names for which a <arch>::fmha_fwd_v3 call may be emitted; get_v3_api()
# filters this list against get_gfx_list(), so its order fixes the branch order.
V3_SUPPORTED_ARCH = ["gfx942", "gfx950"]


def _build_call(arch, call_args):
    """Return the C++ statement that assigns ``<arch>::<call_args>`` to ``t``.

    Args:
        arch: Namespace placed before ``::`` (e.g. ``"gfx950"``).
        call_args: Text of the call expression, without a trailing semicolon.

    Returns:
        A single-line string of the form ``"t = <arch>::<call_args>;"``.
    """
    return "t = {}::{};".format(arch, call_args)


def _build_no_supported_arch_block():
    """Return the C++ statements emitted when no supported arch matches.

    Returns:
        A new two-element list of C++ source lines: one printing
        "No supported GPU arch found!" to ``std::cout`` and one ``return -1;``.
        A fresh list is built on every call, so callers may extend or mutate
        it freely.
    """
    # Plain literals: nothing is interpolated here, so no f-prefix is needed.
    return [
        'std::cout << "No supported GPU arch found!" << std::endl;',
        "return -1;",
    ]


def _build_multi_target_api(supported_archs, call_args):
    """Build a C++ ``if / else if / else`` chain dispatching on ``get_gpu_arch()``.

    Args:
        supported_archs: Arch names, one branch each, in the order given.
        call_args: Call expression text passed to ``_build_call`` per branch.

    Returns:
        The chain as one newline-joined string ending in a trailing newline.
        The final ``else`` holds the "no supported arch" statements. When
        ``supported_archs`` is empty, only those statements are returned
        (no braces, no trailing newline).
    """
    fallback = _build_no_supported_arch_block()
    if not supported_archs:
        return "\n".join(fallback)

    branches = []
    for position, arch in enumerate(supported_archs):
        # Only the very first branch opens the chain with a bare `if`.
        keyword = "if" if position == 0 else "else if"
        call = _build_call(arch, call_args)
        branches.append(f'{keyword} (get_gpu_arch() == "{arch}") {{\n {call}\n}}')

    # The empty last element yields the trailing newline after the closing brace.
    return "\n".join([*branches, "else {", *fallback, "}", ""])


def get_v3_api():
    """Return the C++ snippet that invokes ``fmha_fwd_v3`` for the built archs.

    The archs reported by ``get_gfx_list()`` are intersected with
    ``V3_SUPPORTED_ARCH`` (keeping ``V3_SUPPORTED_ARCH`` order), and the
    snippet is chosen from the size of that intersection:

    * empty    -> the "no supported arch" statements (print + ``return -1;``)
    * one arch -> a direct ``t = <arch>::fmha_fwd_v3(...);`` call
    * several  -> an ``if / else if / else`` chain on ``get_gpu_arch()``

    Returns:
        The snippet as a string.
    """
    call_args = "fmha_fwd_v3(traits, args, stream_config, seqstart_q_padding_ptr, seqstart_k_padding_ptr, is_v3_api_check)"

    # NOTE(review): get_gfx_list() is defined outside this view; it is treated
    # here as the list of archs being compiled for -- confirm.
    gfx_list = get_gfx_list()

    # Filter before branching so that a single *unsupported* arch never gets a
    # direct call into a namespace that has no fmha_fwd_v3.
    supported_gfx_list = [arch for arch in V3_SUPPORTED_ARCH if arch in gfx_list]

    if not supported_gfx_list:
        # No supported arch compiled
        return "\n".join(_build_no_supported_arch_block())
    if len(supported_gfx_list) == 1:
        # Single arch: direct call, no runtime dispatch needed
        return _build_call(supported_gfx_list[0], call_args)
    # Multiple archs: build the dispatch from the filtered list
    return _build_multi_target_api(supported_gfx_list, call_args)


# Module-level statement: get_v3_api() runs once when this module is loaded and
# its returned snippet is kept in V3_API.
V3_API = get_v3_api()
Expand Down