[CK_TILE] Enable V3 persistent kernel dispatch for FMHA forward on gfx950 #6529
goldcoderZ wants to merge 1 commit into ROCm:develop
Replace the compile-time F_is_v3_enabled gate with a runtime CK_FMHA_ENABLE_V3 environment variable check (opt-in, disabled by default). When enabled:
- Prefill workloads (seqlen_q > 1) dispatch to the V3 persistent pipeline
- Decode workloads (seqlen_q == 1) always use V2

Also adds #include <cstdlib> and <string> for std::getenv usage.
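The runtime gate described above could look roughly like the sketch below. The helper names `v3_requested_from_env` and `use_v3_pipeline` are hypothetical, and the exact string comparison (only the value "1" enables V3) is an assumption rather than something the PR text confirms.

```cpp
#include <cassert>
#include <cstdlib>
#include <string>

// Hypothetical helper: true only when CK_FMHA_ENABLE_V3 is set to exactly "1".
// An unset variable, "0", or any other value leaves V3 disabled (opt-in).
inline bool v3_requested_from_env()
{
    const char* v3_env = std::getenv("CK_FMHA_ENABLE_V3");
    return v3_env != nullptr && std::string(v3_env) == "1";
}

// Hypothetical predicate mirroring the description: prefill (seqlen_q > 1)
// may use V3, while decode (seqlen_q == 1) always stays on V2.
inline bool use_v3_pipeline(bool v3_requested, int seqlen_q)
{
    assert(seqlen_q >= 1); // a query length below 1 is not a valid workload
    return v3_requested && seqlen_q > 1;
}
```

Keeping the environment lookup separate from the predicate makes the dispatch decision testable without touching the process environment.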
    if ({F_is_v3_enabled}) {{
    FMHA_FWD_API_FOOTER = """
    float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {
    const char* v3_env = std::getenv("CK_FMHA_ENABLE_V3");
args.seqlen_q is the total concatenated length in group mode (see fmha_fwd_runner.hpp#L1239: args.seqlen_q = shape_seqlen_q; // unused in group mode). Use args.max_seqlen_q instead -- it correctly reflects the longest individual sequence length in both batch and group modes.
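To illustrate why the two fields differ, here is a small sketch that derives both values from cumulative sequence offsets. The `GroupSeqlens` struct, `summarize_group`, and `is_prefill` are hypothetical names for illustration only, and the offset layout is an assumption about how group mode packs sequences.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical view of group-mode query lengths: seqstart_q holds cumulative
// offsets, so sequence i spans [seqstart_q[i], seqstart_q[i + 1]).
struct GroupSeqlens
{
    int total_seqlen_q; // concatenated length of all sequences
    int max_seqlen_q;   // longest individual sequence
};

inline GroupSeqlens summarize_group(const std::vector<int>& seqstart_q)
{
    assert(seqstart_q.size() >= 2); // at least one sequence
    GroupSeqlens out{seqstart_q.back() - seqstart_q.front(), 0};
    for(std::size_t i = 0; i + 1 < seqstart_q.size(); ++i)
        out.max_seqlen_q = std::max(out.max_seqlen_q, seqstart_q[i + 1] - seqstart_q[i]);
    return out;
}

// Decode guard keyed on the longest individual sequence, as the review suggests.
inline bool is_prefill(int max_seqlen_q) { return max_seqlen_q > 1; }
```

For a decode batch of 64 sequences of length 1, the total is 64 while the maximum is 1, so a guard on the total would misclassify the batch as prefill.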
    # False
    ]
    ),
    FMHA_FWD_API_FOOTER,
FP8 V3 dispatch is already enabled on develop -- F_is_v3_enabled resolves to true at codegen time because FP8 V3 traits are registered (line 1153: qr_async_trload_v3 for _DT_FP8BF16). Replacing this compile-time gate with a runtime env var that is off by default regresses FP8 V3 dispatch unnecessarily.
Consider keeping the original codegen-time gate and just adding a decode guard (args.max_seqlen_q > 1). If you want to conditionally enable bf16/fp16 V3 instances, the env var check belongs in get_hdim_tile_size_dict() and get_pipelines() where the bf16/fp16 V3 tile sizes and pipelines are currently commented out (lines 1090, 1141) -- guarding instance generation at codegen time rather than dispatch at runtime. e.g. os.environ.get("CK_FMHA_FWD_GENERATE_V3_BF16FP16", "0") == "1"
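A minimal sketch of that alternative, keeping the codegen-time gate and adding only the decode guard, might look like this. The `Pipeline` enum and `select_pipeline` are hypothetical; `kV3Enabled` stands in for whatever value codegen substitutes for {F_is_v3_enabled}.

```cpp
#include <cassert>

enum class Pipeline { V2, V3 };

// kV3Enabled models the codegen-time constant (assumption: it is true when
// V3 traits, such as the FP8 ones, are registered for the build).
template <bool kV3Enabled>
Pipeline select_pipeline(int max_seqlen_q)
{
    assert(max_seqlen_q >= 1);
    // Decode (max_seqlen_q == 1) always falls through to V2.
    if(kV3Enabled && max_seqlen_q > 1)
        return Pipeline::V3;
    return Pipeline::V2;
}
```

With this shape, builds that already register V3 instances keep dispatching to them by default, and no runtime environment lookup is involved.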
Motivation
Enable the existing V3 persistent kernel path for CK-Tile FMHA forward on gfx950 (MI350X/MI355X). The V3 kernel and codegen infrastructure already exist but are disabled via hardcoded F_is_v3_enabled=False.

This change replaces the compile-time gate with a runtime environment variable, CK_FMHA_ENABLE_V3=1 (disabled by default, opt-in). When enabled:
- Prefill workloads (seqlen_q > 1) dispatch to the V3 persistent pipeline
- Decode workloads (seqlen_q == 1) always use V2

The V3 persistent kernel uses grid-stride scheduling, XCD-interleave tile assignment for L2 locality, LPT reversal for causal masks, and gfx950 async buffer loads.
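As a rough host-side model of two of these ideas, the sketch below lists the tiles one persistent worker would visit under grid-stride scheduling, with an optional reversal for causal masks. The function `tiles_for_worker` is hypothetical, the reversal is only an assumed form of LPT (longest processing time first) ordering, and the XCD-interleave mapping is deliberately left out because its details are not given here.

```cpp
#include <cassert>
#include <vector>

// Hypothetical model of a persistent worker's tile order.
// Grid-stride: worker w visits w, w + num_workers, w + 2 * num_workers, ...
// LPT reversal (assumed form): with a causal mask, later query tiles attend
// to more keys, so the order is flipped to start the heaviest tiles first.
inline std::vector<int> tiles_for_worker(int worker, int num_workers, int num_tiles, bool causal)
{
    assert(num_workers > 0 && worker >= 0 && worker < num_workers && num_tiles >= 0);
    std::vector<int> order;
    for(int i = worker; i < num_tiles; i += num_workers)
        order.push_back(causal ? num_tiles - 1 - i : i);
    return order;
}
```

Every tile is visited by exactly one worker in both modes; only the visiting order changes when `causal` is set.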
Technical Details
Single file: example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
- Add #include <cstdlib> and <string> for std::getenv
- Replace the {F_is_v3_enabled} template parameter with a runtime env var check
- Add a seqlen_q > 1 guard (decode always uses V2)
- Update the .format() call in write_fwd_api()

Dependencies
Depends on #6501: builds on XCD-interleave and LPT scheduling infrastructure.
Test Plan
- ./build/bin/tile_example_fmha_fwd -b=2 -h=8 -s=4096 -d=128 -prec=bf16 -v=1 -warmup=1 -repeat=3
- CK_FMHA_ENABLE_V3=1 ./build/bin/tile_example_fmha_fwd -b=2 -h=8 -s=4096 -d=128 -prec=bf16 -v=1 -warmup=1 -repeat=3
- ./build/bin/tile_example_fmha_fwd -b=64 -h=32 -h_k=8 -s=1 -s_k=4096 -d=128 -prec=bf16 -mode=group -v=1 -warmup=1 -repeat=3

Test Result
Benchmark results (MI350X, gfx950, ROCm 7.0):
Additional validation:
- CK_FMHA_ENABLE_V3=0 correctly falls back to V2 (default behavior unchanged)
- CK_FMHA_ENABLE_V3=1 dispatches to V3 for prefill, V2 for decode