Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include <cstdio>
#include <cstdlib>
#include <string>

#include <hip/hip_runtime.h>

Expand Down Expand Up @@ -220,17 +222,16 @@
}}
}} // namespace
"""
FMHA_FWD_API_FOOTER_TEMPLATE = """
float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunreachable-code"
if ({F_is_v3_enabled}) {{
FMHA_FWD_API_FOOTER = """
float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {
const char* v3_env = std::getenv("CK_FMHA_ENABLE_V3");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`args.seqlen_q` is the total concatenated length in group mode (see `fmha_fwd_runner.hpp#L1239`: `args.seqlen_q = shape_seqlen_q; // unused in group mode`). Use `args.max_seqlen_q` instead — it correctly reflects the longest individual sequence length in both batch and group modes.

bool v3_enabled = v3_env && std::string(v3_env) == "1";
if (v3_enabled && args.seqlen_q > 1) {
float r = fmha_fwd_v3(traits, args, config);
if (r >= 0) return r;
}}
#pragma clang diagnostic pop
}
return fmha_fwd_v2(traits, args, config);
}}
}
"""

FMHA_FWD_API_PER_ARCH = """{F_if}({F_arch.device_name_check}) {{
Expand Down Expand Up @@ -1566,13 +1567,7 @@ def accept_only_v2(trait: FmhaFwdApiTrait) -> bool:
FMHA_FWD_API_HEADER,
api_pool.render("fmha_fwd_v2", filter_fn=accept_only_v2),
api_pool.render("fmha_fwd_v3", filter_fn=accept_only_v3),
FMHA_FWD_API_FOOTER_TEMPLATE.format(
F_is_v3_enabled=BOOL_MAP[
# NOTE: enable v3 pipelines when ready
0 < api_pool.get_num_traits(filter_fn=accept_only_v3)
# False
]
),
FMHA_FWD_API_FOOTER,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FP8 V3 dispatch is already enabled on develop — `F_is_v3_enabled` resolves to true at codegen time because FP8 V3 traits are registered (line 1153: `qr_async_trload_v3` for `_DT_FP8BF16`). Replacing this compile-time gate with a runtime env var that is off by default regresses FP8 V3 dispatch unnecessarily.

Consider keeping the original codegen-time gate and just adding a decode guard (`args.max_seqlen_q > 1`). If you want to conditionally enable bf16/fp16 V3 instances, the env var check belongs in `get_hdim_tile_size_dict()` and `get_pipelines()`, where the bf16/fp16 V3 tile sizes and pipelines are currently commented out (lines 1090, 1141) — that guards instance generation at codegen time rather than dispatch at runtime, e.g. `os.environ.get("CK_FMHA_FWD_GENERATE_V3_BF16FP16", "0") == "1"`.

]
)
update_file(autogen_dir / FMHA_FWD_API_FILENAME, content)
Expand Down