Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions aiter/ops/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,7 +1416,8 @@ def psskddv():
return ret

# basic
ret = alibi_slopes is None
ret = get_gfx() == "gfx942"
ret &= alibi_slopes is None
ret &= bias is None
ret &= dbias is None
ret &= dropout_p == 0.0
Expand Down Expand Up @@ -2028,7 +2029,8 @@ def psskddv():

def can_impl_fmha_v3_bwd():
# basic
ret = alibi_slopes is None
ret = get_gfx() == "gfx942"
ret &= alibi_slopes is None
# ret &= bias is None
# ret &= dbias is None
ret &= dropout_p == 0.0
Expand Down
8 changes: 5 additions & 3 deletions csrc/cpp_itfs/mha_bwd_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,12 @@

def get_v3_api():
gfx_list = get_gfx_list()
if len(gfx_list) == 1:
return f"t = {gfx_list[0]}::fmha_bwd_v3(traits, args, stream_config, seqlen_q_padded, seqlen_k_padded, is_v3_api_check);"
else:
if "gfx942" in gfx_list and "gfx950" in gfx_list:
return V3_MULTI_TARGET_API
elif "gfx942" in gfx_list:
return """if (get_gpu_arch() == "gfx942") { t = gfx942::fmha_bwd_v3(traits, args, stream_config, seqlen_q_padded, seqlen_k_padded, is_v3_api_check); }"""
elif "gfx950" in gfx_list:
return """if (get_gpu_arch() == "gfx950") { t = gfx950::fmha_bwd_v3(traits, args, stream_config, seqlen_q_padded, seqlen_k_padded, is_v3_api_check); }"""


V3_API = get_v3_api()
Expand Down