-
Notifications
You must be signed in to change notification settings - Fork 267
[CK_TILE] Enable V3 persistent kernel dispatch for FMHA forward on gfx950 #6529
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -164,6 +164,8 @@ | |
| // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n | ||
| // auto generated by generate.py | ||
| #include <cstdio> | ||
| #include <cstdlib> | ||
| #include <string> | ||
|
|
||
| #include <hip/hip_runtime.h> | ||
|
|
||
|
|
@@ -220,17 +222,16 @@ | |
| }} | ||
| }} // namespace | ||
| """ | ||
| FMHA_FWD_API_FOOTER_TEMPLATE = """ | ||
| float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {{ | ||
| #pragma clang diagnostic push | ||
| #pragma clang diagnostic ignored "-Wunreachable-code" | ||
| if ({F_is_v3_enabled}) {{ | ||
| FMHA_FWD_API_FOOTER = """ | ||
| float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) { | ||
| const char* v3_env = std::getenv("CK_FMHA_ENABLE_V3"); | ||
| bool v3_enabled = v3_env && std::string(v3_env) == "1"; | ||
| if (v3_enabled && args.seqlen_q > 1) { | ||
| float r = fmha_fwd_v3(traits, args, config); | ||
| if (r >= 0) return r; | ||
| }} | ||
| #pragma clang diagnostic pop | ||
| } | ||
| return fmha_fwd_v2(traits, args, config); | ||
| }} | ||
| } | ||
| """ | ||
|
|
||
| FMHA_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{ | ||
|
|
@@ -1566,13 +1567,7 @@ def accept_only_v2(trait: FmhaFwdApiTrait) -> bool: | |
| FMHA_FWD_API_HEADER, | ||
| api_pool.render("fmha_fwd_v2", filter_fn=accept_only_v2), | ||
| api_pool.render("fmha_fwd_v3", filter_fn=accept_only_v3), | ||
| FMHA_FWD_API_FOOTER_TEMPLATE.format( | ||
| F_is_v3_enabled=BOOL_MAP[ | ||
| # NOTE: enable v3 pipelines when ready | ||
| 0 < api_pool.get_num_traits(filter_fn=accept_only_v3) | ||
| # False | ||
| ] | ||
| ), | ||
| FMHA_FWD_API_FOOTER, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. FP8 V3 dispatch is already enabled on develop -- consider keeping the original codegen-time gate and just adding a decode guard ( |
||
| ] | ||
| ) | ||
| update_file(autogen_dir / FMHA_FWD_API_FILENAME, content) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`args.seqlen_q` is the total concatenated length in group mode (see fmha_fwd_runner.hpp#L1239: `args.seqlen_q = shape_seqlen_q; // unused in group mode`). Use `args.max_seqlen_q` instead -- it correctly reflects the longest individual sequence length in both batch and group modes.