Skip to content

[CK_TILE] Enable V3 persistent kernel dispatch for FMHA forward on gfx950#6529

Open
goldcoderZ wants to merge 1 commit intoROCm:developfrom
goldcoderZ:meta/fmha-fwd-persistent-kernel
Open

[CK_TILE] Enable V3 persistent kernel dispatch for FMHA forward on gfx950#6529
goldcoderZ wants to merge 1 commit intoROCm:developfrom
goldcoderZ:meta/fmha-fwd-persistent-kernel

Conversation

@goldcoderZ
Copy link
Copy Markdown

@goldcoderZ goldcoderZ commented Apr 17, 2026

[CK_TILE] Enable V3 persistent kernel dispatch for FMHA forward on gfx950

Motivation

Enable the existing V3 persistent kernel path for CK-Tile FMHA forward on
gfx950 (MI350X/MI355X). The V3 kernel and codegen infrastructure already
exist but are disabled via hardcoded F_is_v3_enabled=False.

This change replaces the compile-time gate with a runtime environment variable
CK_FMHA_ENABLE_V3=1 (disabled by default, opt-in). When enabled:

  • Prefill workloads (seqlen_q > 1) dispatch to V3 persistent pipeline
  • Decode workloads (seqlen_q == 1) always use V2 (memory-bound, better suited)

The V3 persistent kernel uses grid-stride scheduling, XCD-interleave tile
assignment for L2 locality, LPT reversal for causal masks, and gfx950 async
buffer loads.

Technical Details

Single file: example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

  • Add #include <cstdlib> and <string> for std::getenv
  • Replace {F_is_v3_enabled} template parameter with runtime env var check
  • Add seqlen_q > 1 guard (decode always uses V2)
  • Remove .format() call in write_fwd_api()

Dependencies

Depends on #6501 — builds on
XCD-interleave and LPT scheduling infrastructure.

Test Plan

  • GPU validation on MI350X (gfx950, ROCm 7.0):
  • Command (V2): ./build/bin/tile_example_fmha_fwd -b=2 -h=8 -s=4096 -d=128 -prec=bf16 -v=1 -warmup=1 -repeat=3
  • Command (V3): CK_FMHA_ENABLE_V3=1 ./build/bin/tile_example_fmha_fwd -b=2 -h=8 -s=4096 -d=128 -prec=bf16 -v=1 -warmup=1 -repeat=3
  • Command (decode, always V2): ./build/bin/tile_example_fmha_fwd -b=64 -h=32 -h_k=8 -s=1 -s_k=4096 -d=128 -prec=bf16 -mode=group -v=1 -warmup=1 -repeat=3

Test Result

Benchmark results (MI350X, gfx950, ROCm 7.0):

Config V2 (TFlops) V3 (TFlops) Speedup
Non-causal b=2 h=8 hk=2 s=4096 d=128 bf16 696.3 884.2 +27.0%
Causal b=2 h=8 hk=2 s=4096 d=128 bf16 371.3 494.9 +33.3%
GQA b=2 h=32 hk=8 s=2048 d=128 bf16 671.3 831.7 +23.9%
LLaMA-70B b=1 h=64 hk=8 s=4096 d=128 bf16 761.5 927.3 +21.8%
Causal GQA b=2 h=32 hk=8 s=2048 d=128 bf16 345.4 631.9 +82.9%
Long-seq b=1 h=16 s=16384 d=128 bf16 797.8 969.9 +21.6%
Decode b=64 h=32 hk=8 s=1 s_k=4096 bf16 1828 GB/s — (V2 path) unaffected

Additional validation:

  • CK_FMHA_ENABLE_V3=0 correctly falls back to V2 (default behavior unchanged)
  • CK_FMHA_ENABLE_V3=1 dispatches to V3 for prefill, V2 for decode
  • Validation passes across fp16/bf16, batch/group mode, causal/non-causal
  • No regression on decode path

Replace compile-time F_is_v3_enabled gate with runtime CK_FMHA_ENABLE_V3
environment variable check (opt-in, disabled by default). When enabled:
- Prefill workloads (seqlen_q > 1) dispatch to V3 persistent pipeline
- Decode workloads (seqlen_q == 1) always use V2

Also adds #include <cstdlib> and <string> for std::getenv usage.
@goldcoderZ goldcoderZ requested a review from a team as a code owner April 17, 2026 16:15
@assistant-librarian assistant-librarian bot added the external contribution Code contribution from users community.. label Apr 17, 2026
@goldcoderZ goldcoderZ changed the title [CK FMHA FWD] Enable V3 persistent kernel dispatch [CK_TILE] Enable V3 persistent kernel dispatch for FMHA forward on gfx950 Apr 17, 2026
if ({F_is_v3_enabled}) {{
FMHA_FWD_API_FOOTER = """
float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {
const char* v3_env = std::getenv("CK_FMHA_ENABLE_V3");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

args.seqlen_q is the total concatenated length in group mode (see fmha_fwd_runner.hpp#L1239: args.seqlen_q = shape_seqlen_q; // unused in group mode). Use args.max_seqlen_q instead -- it correctly reflects the longest individual sequence length in both batch and group modes.

# False
]
),
FMHA_FWD_API_FOOTER,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FP8 V3 dispatch is already enabled on develop -- F_is_v3_enabled resolves to true at codegen time because FP8 V3 traits are registered (line 1153: qr_async_trload_v3 for _DT_FP8BF16). Replacing this compile-time gate with a runtime env var that is off by default regresses FP8 V3 dispatch unnecessarily.

Consider keeping the original codegen-time gate and just adding a decode guard (args.max_seqlen_q > 1). If you want to conditionally enable bf16/fp16 V3 instances, the env var check belongs in get_hdim_tile_size_dict() and get_pipelines() where the bf16/fp16 V3 tile sizes and pipelines are currently commented out (lines 1090, 1141) -- guarding instance generation at codegen time rather than dispatch at runtime. e.g. os.environ.get("CK_FMHA_FWD_GENERATE_V3_BF16FP16", "0") == "1"

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

external contribution Code contribution from users community.. project: composablekernel

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants