
[CK] [CK_Tile] Add FMHA scaffolding to CK kernel dispatcher #5260

Open
vidyasagar-amd wants to merge 26 commits into develop from users/vanantha/ck/dispatcher-fmha

Conversation

@vidyasagar-amd
Contributor

Motivation

The CK Tile dispatcher currently supports GEMM and Grouped Convolution but has no support for Fused Multi-Head Attention (FMHA). The example/ck_tile/01_fmha folder contains a comprehensive FMHA implementation with forward, backward, split-KV, paged-KV, append-KV, and batch-prefill kernels across multiple GPU architectures — but there is no unified dispatch layer for it. This PR ports the FMHA stack into the dispatcher, following the same architectural patterns established by GEMM and Grouped Convolution, enabling runtime kernel selection, JIT compilation from Python, and a declarative C++ example flow. Autotuning heuristics to follow.

Technical Details

This PR adds FMHA scaffolding to the CK dispatcher framework, mirroring GEMM's layered architecture:

  • C++ runtime: seven new headers provide type definitions (coexisting with upstream headers via __has_include, requiring zero modifications to example/ck_tile/01_fmha/), a problem builder with 18+ setters, Signature + Algorithm kernel-key matching, a virtual kernel instance, a DECL_FMHA_KERNEL_SET macro with wildcard support and named tile/wave/warp setters, an arch-aware registry with JSON export, and a dispatcher with seqtune-aware selection, configurable timing, and multi-stage execution plans for split-KV (two-stage) and backward (three-stage).
  • Codegen: the pipeline is driven by fmha_arch_specs.json, which captures per-arch tile tables and pipeline constraints for five architectures (gfx90a/942/950/1100/1201), migrated from hardcoded logic in 01_fmha/codegen/, with supporting modules for C++ symbol mappings, validation rules, and named receipt profiles (ck_default, flash, pytorch, aiter, fp32, fp8).
  • Python integration: fmha_utils.py mirrors the C++ layer with JIT compilation, parallel multi-kernel builds, HIP memory management via ctypes, tolerance-based validation, and a NumPy CPU reference with GQA support.
  • Examples: twenty-seven C++ and thirty-two Python examples cover the full feature surface — forward, split-KV, masks, bias, dropout, GQA, backward, append-KV, batch prefill, fp8, logits soft cap, sink tokens, and parameter sweeps — all JIT-compiled on the fly.

Test Plan

Seven test files cover the runtime types, codegen, and end-to-end correctness. C++ unit tests validate the problem builder, dispatcher planning (single-stage for forward/paged-KV/append-KV; multi-stage for split-KV and backward), registry operations, and the kernel-set declaration macro. Python unit tests verify codegen emission, profile filtering, and 15 validation rules for masks, hdim constraints, and pipeline requirements. GPU execution validation in 01_basic_fmha --validate reports zero errors across 65,536 elements with max absolute error of 7.29e-05. A gold-standard parity suite (test_fmha_parity.py) runs 14 configurations through both the upstream tile_example_fmha_fwd and the dispatcher, comparing exit codes to confirm behavioral parity — all 14 match.

Test Result

The C++ smoke test builds and passes all 9 compiled examples, and a Python JIT sweep (29_sweep_seqlen.py) passes 7/7 configurations, reaching up to 375 TFLOPS at seqlen 2048.

Submission Checklist

Copilot AI left a comment

Pull request overview

Adds Fused Multi-Head Attention (FMHA) integration to the CK Tile dispatcher ecosystem by introducing new FMHA examples (Python/C++), ctypes bindings, and codegen utilities, alongside updates to grouped convolution and GEMM codegen/shared infrastructure.

Changes:

  • Added many FMHA C++/Python examples plus a ctypes C API library and fallback-kernel generator for Python integration.
  • Updated build system to compile FMHA examples and produce a Python FMHA shared library (and expanded supported GPU arch list).
  • Refactored GEMM codegen to reuse shared codegen infrastructure and improved codegen parallelism/cleanup; refreshed docs to reflect GEMM + grouped conv (and related tooling).

Reviewed changes

Copilot reviewed 63 out of 168 changed files in this pull request and generated 5 comments.

Summary per file (file — description):
projects/composablekernel/dispatcher/examples/fmha/python/15_gqa_fmha.py New Python example exercising GQA/MQA and optional GPU validation via dispatcher.
projects/composablekernel/dispatcher/examples/fmha/python/14_dropout_fmha.py New Python example demonstrating dropout/LSE behavior and baseline GPU run.
projects/composablekernel/dispatcher/examples/fmha/python/13_bias_fmha.py New Python example demonstrating bias modes with CPU reference + limited GPU baseline.
projects/composablekernel/dispatcher/examples/fmha/python/12_masks_fmha.py New Python example demonstrating mask patterns with CPU reference + limited GPU baseline.
projects/composablekernel/dispatcher/examples/fmha/python/11_bf16_fmha.py New Python example showing bf16 handling and fallback behavior.
projects/composablekernel/dispatcher/examples/fmha/python/10_advanced_benchmark.py New Python benchmark driver with warmup/repeat/cache flush for FMHA.
projects/composablekernel/dispatcher/examples/fmha/python/09_multi_registry.py New Python example showing separate registries for different optimization targets.
projects/composablekernel/dispatcher/examples/fmha/python/07_stress_test.py New Python stress test generating/building/validating multiple FMHA kernels.
projects/composablekernel/dispatcher/examples/fmha/python/06_json_export.py New Python example exporting FMHA registry/kernel configs to JSON.
projects/composablekernel/dispatcher/examples/fmha/python/05_numpy_integration.py New NumPy-friendly FMHA wrapper demo with GPU execution and validation.
projects/composablekernel/dispatcher/examples/fmha/python/04_validation.py New Python validation suite comparing dispatcher output to CPU reference.
projects/composablekernel/dispatcher/examples/fmha/python/03_benchmark.py New Python benchmark over batch/sequence sizes.
projects/composablekernel/dispatcher/examples/fmha/python/02_multi_shape.py New Python multi-shape demo reusing one kernel over multiple shapes.
projects/composablekernel/dispatcher/examples/fmha/python/01_basic_fmha.py New Python multi-kernel build + run + validate example.
projects/composablekernel/dispatcher/examples/fmha/cpp/15_multi_shape_fmha.cpp New C++ example running a single FMHA kernel across multiple shapes.
projects/composablekernel/dispatcher/examples/fmha/cpp/12_registry_json_fmha.cpp New C++ example exporting FMHA registry to JSON.
projects/composablekernel/dispatcher/examples/fmha/cpp/11_receipt_aliases_fmha.cpp New C++ example demonstrating receipt-alias planning.
projects/composablekernel/dispatcher/examples/fmha/cpp/10_profile_fp32_fp8_fmha.cpp New C++ example demonstrating fp32/fp8 profile planning + JSON export.
projects/composablekernel/dispatcher/examples/fmha/cpp/09_profile_aiter_fmha.cpp New C++ example demonstrating AITER profile planning.
projects/composablekernel/dispatcher/examples/fmha/cpp/08_profile_flash_fmha.cpp New C++ example demonstrating flash profile planning for fwd/bwd stages.
projects/composablekernel/dispatcher/examples/fmha/cpp/07_profile_pytorch_fmha.cpp New C++ example demonstrating PyTorch profile planning across multiple families.
projects/composablekernel/dispatcher/examples/fmha/cpp/06_batch_prefill_fmha.cpp New C++ example for batch-prefill planning.
projects/composablekernel/dispatcher/examples/fmha/cpp/05_appendkv_fmha.cpp New C++ example for append-KV planning.
projects/composablekernel/dispatcher/examples/fmha/cpp/04_bwd_fmha.cpp New C++ example for backward planning (multi-stage).
projects/composablekernel/dispatcher/examples/fmha/cpp/03_kvcache_fmha.cpp New C++ example for paged-KV/append-KV/batch-prefill planning.
projects/composablekernel/dispatcher/examples/fmha/cpp/02_splitkv_fmha.cpp New C++ example for split-KV planning (2-stage).
projects/composablekernel/dispatcher/examples/README.md Updated examples README to reflect grouped conv and reorganized content.
projects/composablekernel/dispatcher/examples/CMakeLists.txt Added FMHA & grouped conv example targets; added FMHA/conv Python libs; expanded include dirs and supported GPU arch list.
projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py Refactored to use shared codegen_common and improved parallel generation flow + temp config cleanup.
projects/composablekernel/dispatcher/codegen/kernel_config_loader.py Renamed convolution config classes/functions to “GroupedConv*” and updated generated macro name.
projects/composablekernel/dispatcher/codegen/generate_kernel_wrappers.py Documentation formatting update for wrapper output tree.
projects/composablekernel/dispatcher/codegen/generate_fmha_fallback.py New script to generate + optionally compile an FMHA fallback kernel and dispatch header for Python ctypes.
projects/composablekernel/dispatcher/codegen/generate_dispatcher_registration.py Updated console output messages for registration generation.
projects/composablekernel/dispatcher/codegen/fmha_symbol_map.py New FMHA symbol/config canonicalization + naming utilities (arch specs driven).
projects/composablekernel/dispatcher/codegen/fmha_profiles.py New FMHA profile/receipt aliasing and config filters.
projects/composablekernel/dispatcher/codegen/README.md Updated codegen README to cover grouped conv and shared infrastructure.
projects/composablekernel/dispatcher/codegen/ADDING_NEW_GPU.md Documentation formatting tweaks (ASCII arrows/tree).
projects/composablekernel/dispatcher/bindings/ctypes/fmha_ctypes_lib.cpp New FMHA ctypes C API to initialize registry/dispatcher and run forward.
projects/composablekernel/dispatcher/bindings/README.md Updated bindings README to reflect grouped conv terminology and structure.
projects/composablekernel/dispatcher/README.md Updated top-level dispatcher README to include grouped conv, refreshed docs/trees.
projects/composablekernel/dispatcher/CMakeLists.txt Added FMHA registry/dispatcher sources and widened include path scope.

Comment on lines +71 to +72
hdim_q=128,
hdim_v=128,

Copilot AI Mar 9, 2026

This example accepts --hdim, but the kernel config is hard-coded to hdim_q=128/hdim_v=128. If a user passes a non-128 --hdim, the dispatcher setup/run can fail or behave inconsistently with the generated FmhaProblem (which uses args.hdim). Use hdim_q=args.hdim and hdim_v=args.hdim (or validate/restrict the CLI to 128 only) so the config matches the problem being executed.

Suggested change
hdim_q=128,
hdim_v=128,
hdim_q=args.hdim,
hdim_v=args.hdim,

Comment on lines +100 to +105
config = FmhaKernelConfig(
data_type="fp16",
hdim_q=128,
hdim_v=128,
gfx_arch=args.arch,
)

Copilot AI Mar 9, 2026

Like other examples, this script accepts --hdim but always JIT-builds a 128x128 kernel. That can cause runtime failures when args.hdim != 128 (e.g., fmha_matmul derives hdim_q/hdim_v from the NumPy arrays and creates an FmhaProblem accordingly). Either (a) set hdim_q/hdim_v from the actual inputs/args, or (b) constrain --hdim to 128 for this example and emit a clear error if the user requests unsupported head dims.

Comment on lines +26 to +31
#define HIP_CHECK(call) \
{ \
hipError_t err = call; \
if(err != hipSuccess) \
return -1; \
}

Copilot AI Mar 9, 2026

The HIP_CHECK macro returns immediately on failure, but fmha_dispatcher_run_fwd allocates multiple device buffers before subsequent HIP calls. If any later HIP_CHECK(...) fails (e.g., a memcpy), this will leak the buffers already allocated. Prefer a cleanup-on-error path (e.g., a single goto cleanup;/scope guard that frees any non-null pointers) or RAII wrappers for device allocations so all allocated resources are released on every error path.

Suggested change
#define HIP_CHECK(call) \
{ \
hipError_t err = call; \
if(err != hipSuccess) \
return -1; \
}
#define HIP_CHECK(call) \
do \
{ \
hipError_t err = (call); \
if(err != hipSuccess) \
{ \
/* Reset device to release all resources \
allocated on it before returning. */ \
hipDeviceReset(); \
return -1; \
} \
} while(0)
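Beyond resetting the device, the RAII approach the review mentions can be sketched generically. This is a hypothetical `ScopedBuffer` (name and shape assumed, not part of this PR), with `std::malloc`/`std::free` standing in for `hipMalloc`/`hipFree` so the idea is self-contained:

```cpp
#include <cstdlib>
#include <utility>

// Hypothetical RAII wrapper: owns a raw buffer and frees it on every exit
// path, so an early return from a function that allocates several device
// buffers cannot leak the ones already allocated. std::malloc/std::free
// stand in for hipMalloc/hipFree here for illustration only.
class ScopedBuffer
{
    public:
    explicit ScopedBuffer(std::size_t bytes) : ptr_(std::malloc(bytes)) {}
    ~ScopedBuffer() { std::free(ptr_); }

    // Non-copyable, movable: exactly one owner frees the buffer.
    ScopedBuffer(const ScopedBuffer&)            = delete;
    ScopedBuffer& operator=(const ScopedBuffer&) = delete;
    ScopedBuffer(ScopedBuffer&& other) noexcept : ptr_(std::exchange(other.ptr_, nullptr)) {}

    void* get() const { return ptr_; }
    explicit operator bool() const { return ptr_ != nullptr; }

    private:
    void* ptr_;
};
```

With such a wrapper, each `HIP_CHECK` can simply `return -1` and the destructors release whatever was already allocated.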

Comment on lines +61 to +65
|---- gemm/
| |---- cpp/ # 6 C++ GEMM examples
| +---- python/ # 11 Python GEMM examples
|
+---- README.md

Copilot AI Mar 9, 2026

The directory-tree snippet appears to include diff-style prefixes (||----, +----) inside the rendered Markdown code block, which makes the README misleading/confusing. Replace this with a clean tree (e.g., ├──, │, └──) without any leading |/+ markers so it reads correctly when viewed on GitHub.

Suggested change
|---- gemm/
| |---- cpp/ # 6 C++ GEMM examples
| +---- python/ # 11 Python GEMM examples
|
+---- README.md
├── gemm/
│   ├── cpp/        # 6 C++ GEMM examples
│   └── python/     # 11 Python GEMM examples
└── README.md

Comment on lines +701 to +726
|---- README.md # This file
|---- CMakeLists.txt # Build configuration
|
|---- include/ck_tile/dispatcher/ # C++ headers
| |---- dispatcher.hpp # Main dispatcher include
| |---- registry.hpp # GEMM kernel registry
| |---- kernel_key.hpp # Kernel configuration
| |---- grouped_conv_config.hpp # Grouped conv configuration
| |---- grouped_conv_problem.hpp # Grouped conv problem (with builder)
| |---- grouped_conv_kernel_decl.hpp # Grouped conv kernel declarations
| |---- grouped_conv_registry.hpp # Grouped conv registry (thread-safe)
| +---- grouped_conv_utils.hpp # Grouped conv utilities
|
|---- src/ # C++ implementation
|
|---- codegen/ # Kernel generation
| |---- codegen_common.py # Shared: TileConfig, TraitConfigBase, type mappings
| |---- unified_gemm_codegen.py # GEMM kernel generator
| |---- unified_grouped_conv_codegen.py # Grouped conv kernel generator
| +---- arch_specs.json # GPU specifications
|
|---- python/ # Python utilities
| |---- dispatcher_common.py # Shared: paths, validation, Colors, phased output
| |---- ctypes_utils.py # GEMM ctypes utilities
| +---- grouped_conv_utils.py # Grouped conv utilities
|

Copilot AI Mar 9, 2026

Several directory-tree blocks in this README use ||/|----/+---- markers that look like diff artifacts and don’t render as a conventional tree. Consider rewriting these blocks using a consistent tree format (├──/│/└──) and removing the extra leading | characters so the structure is readable and copy/paste-friendly.

Suggested change
|---- README.md # This file
|---- CMakeLists.txt # Build configuration
|
|---- include/ck_tile/dispatcher/ # C++ headers
| |---- dispatcher.hpp # Main dispatcher include
| |---- registry.hpp # GEMM kernel registry
| |---- kernel_key.hpp # Kernel configuration
| |---- grouped_conv_config.hpp # Grouped conv configuration
| |---- grouped_conv_problem.hpp # Grouped conv problem (with builder)
| |---- grouped_conv_kernel_decl.hpp # Grouped conv kernel declarations
| |---- grouped_conv_registry.hpp # Grouped conv registry (thread-safe)
| +---- grouped_conv_utils.hpp # Grouped conv utilities
|
|---- src/ # C++ implementation
|
|---- codegen/ # Kernel generation
| |---- codegen_common.py # Shared: TileConfig, TraitConfigBase, type mappings
| |---- unified_gemm_codegen.py # GEMM kernel generator
| |---- unified_grouped_conv_codegen.py # Grouped conv kernel generator
| +---- arch_specs.json # GPU specifications
|
|---- python/ # Python utilities
| |---- dispatcher_common.py # Shared: paths, validation, Colors, phased output
| |---- ctypes_utils.py # GEMM ctypes utilities
| +---- grouped_conv_utils.py # Grouped conv utilities
|
├── README.md # This file
├── CMakeLists.txt # Build configuration
├── include/ck_tile/dispatcher/ # C++ headers
│ ├── dispatcher.hpp # Main dispatcher include
│ ├── registry.hpp # GEMM kernel registry
│ ├── kernel_key.hpp # Kernel configuration
│ ├── grouped_conv_config.hpp # Grouped conv configuration
│ ├── grouped_conv_problem.hpp # Grouped conv problem (with builder)
│ ├── grouped_conv_kernel_decl.hpp # Grouped conv kernel declarations
│ ├── grouped_conv_registry.hpp # Grouped conv registry (thread-safe)
│ └── grouped_conv_utils.hpp # Grouped conv utilities
├── src/ # C++ implementation
├── codegen/ # Kernel generation
│ ├── codegen_common.py # Shared: TileConfig, TraitConfigBase, type mappings
│ ├── unified_gemm_codegen.py # GEMM kernel generator
│ ├── unified_grouped_conv_codegen.py # Grouped conv kernel generator
│ └── arch_specs.json # GPU specifications
└── python/ # Python utilities
├── dispatcher_common.py # Shared: paths, validation, Colors, phased output
├── ctypes_utils.py # GEMM ctypes utilities
└── grouped_conv_utils.py # Grouped conv utilities

@vidyasagar-amd vidyasagar-amd force-pushed the users/vanantha/ck/dispatcher-fmha branch from 886cb39 to 5dc38ca on March 11, 2026 at 18:43
@yraparti yraparti self-requested a review March 12, 2026 19:42
Contributor

GFX_ARCH is not defined for Conv.

| |---- conv_bwdw_ctypes_lib.cpp # Grouped conv backward weight C API
| |---- gpu_helper.cpp # CLI helper for Python
| +---- CMakeLists.txt
+---- README.md
Contributor

should we mention about fmha ctypes?

Comment thread projects/composablekernel/dispatcher/codegen/codegen_common.py Outdated
Comment thread projects/composablekernel/dispatcher/codegen/fmha_profiles.py Outdated
Comment on lines +338 to +342
const int64_t q_bytes = static_cast<int64_t>(batch) * nhead_q * seqlen_q * hdim_q * 2;
const int64_t k_bytes = static_cast<int64_t>(batch) * nhead_k * seqlen_k * hdim_q * 2;
const int64_t v_bytes = static_cast<int64_t>(batch) * nhead_k * seqlen_k * hdim_v * 2;
const int64_t o_bytes = static_cast<int64_t>(batch) * nhead_q * seqlen_q * hdim_v * 2;
const int64_t do_bytes = o_bytes;
Contributor

@yraparti yraparti Mar 12, 2026

Better to allocate based on the dtype's element size instead of hard-coding * 2 (which assumes 2-byte fp16/bf16 elements).

args.nhead_stride_q_descale = 0;
args.nhead_stride_k_descale = 0;
args.nhead_stride_v_descale = 0;
args.batch_stride_bias = (bias_type_int > 0) ? nhead_q * seqlen_q * seqlen_k : 0;
Contributor

nhead_q * seqlen_q * seqlen_k is int * int * int with no widening cast. With nhead_q=32 and seqlen_q=seqlen_k=8192 the product is 2^31, which overflows int32 (UB). Same at lines 227-230.
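The fix the comment implies is to promote the first operand before multiplying, so the whole product is computed in 64 bits. A minimal sketch (the helper name is illustrative, not from the PR):

```cpp
#include <cstdint>

// Promoting the first operand to int64_t makes every subsequent
// multiplication happen in 64 bits, avoiding the int32 overflow:
// 32 * 8192 * 8192 == 2^31, one past INT32_MAX, so the all-int
// expression is undefined behavior.
inline std::int64_t bias_batch_stride(int nhead_q, int seqlen_q, int seqlen_k)
{
    return static_cast<std::int64_t>(nhead_q) * seqlen_q * seqlen_k;
}
```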

traits_hdim_v: int = 0,
) -> Tuple[int, float]:
time_ms = ctypes.c_float(0.0)
rc = self._lib.fmha_dispatcher_run_fwd(
Contributor

run_fwd() passes 19 args to a 25-param C function -- missing is_v_rowmajor, perm, data_type_str, is_group_mode, window_left, window_right. Would cause stack corruption if called.

add_custom_target(benchmark_fmha
COMMAND ${Python3_EXECUTABLE} ${FMHA_TE_DIR}/fmha_benchmark.py
${FMHA_TE_CONFIGS}/fwd.json
--arch ${USER_GPU_TARGETS}
Contributor

${USER_GPU_TARGETS} is a boolean (0/1), not an arch string. All benchmark targets pass --arch 0 or --arch 1, causing runtime failure.
Fix: Use ${SUPPORTED_GPU_TARGETS} following gemm pattern

Comment on lines +20 to +24
using namespace ck_tile::dispatcher;

static std::unique_ptr<FmhaRegistry> g_registry;
static std::unique_ptr<FmhaDispatcher> g_dispatcher;
static bool g_initialized = false;
Contributor

@yraparti yraparti Mar 12, 2026

g_registry, g_dispatcher, g_initialized are globals with no synchronization. Concurrent init/run/cleanup calls are data races (UB).
Fix: Add std::mutex guard or document single-threaded restriction.
Claude downgraded this to a suggestion after questioning about it.
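The mutex guard the reviewer suggests can be sketched as follows (a hedged illustration: `FakeDispatcher` stands in for `FmhaDispatcher`, and the function names are hypothetical, not the PR's C API):

```cpp
#include <memory>
#include <mutex>

// Stand-in for FmhaDispatcher, so the sketch is self-contained.
struct FakeDispatcher { int id = 0; };

namespace {
// One mutex guards all shared init state, so concurrent init/run/cleanup
// calls from Python threads serialize instead of racing.
std::mutex g_mutex;
std::unique_ptr<FakeDispatcher> g_dispatcher;
bool g_initialized = false;
} // namespace

int dispatcher_init()
{
    std::lock_guard<std::mutex> lock(g_mutex);
    if(!g_initialized)
    {
        g_dispatcher  = std::make_unique<FakeDispatcher>();
        g_initialized = true;
    }
    return 0;
}

bool dispatcher_ready()
{
    std::lock_guard<std::mutex> lock(g_mutex);
    return g_initialized && g_dispatcher != nullptr;
}
```

Alternatively, documenting a single-threaded restriction avoids the lock entirely, at the cost of pushing the burden onto callers.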

return True


def expand_sweep(config_path: str, arch: str) -> List[FmhaKernelConfig]:
Contributor

expand_sweep() ignores tile_config section of JSON configs entirely. All 8 JSON files' tile_config is dead configuration.
Fix: Either remove tile_config from JSONs or implement filtering

Comment on lines +591 to +598
#if defined(CK_TILE_FMHA_FWD_TYPES_FROM_EXAMPLE)
static_assert(sizeof(fmha_fwd_traits) >= 40, "fmha_fwd_traits layout may have changed upstream");
static_assert(sizeof(fmha_fwd_args) >= 300, "fmha_fwd_args layout may have changed upstream");
#endif
#if defined(CK_TILE_FMHA_BWD_TYPES_FROM_EXAMPLE)
static_assert(sizeof(fmha_bwd_traits) >= 32, "fmha_bwd_traits layout may have changed upstream");
static_assert(sizeof(fmha_bwd_args) >= 350, "fmha_bwd_args layout may have changed upstream");
#endif
Contributor

Use == for an exact ABI check. Though I'm not sure even that is good enough if the contents change without changing the size.
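The reviewer's point can be illustrated on a toy struct (sizes below are for this toy, not the real `fmha_fwd_traits`): an exact-size `static_assert` catches accidental field additions or removals, and `offsetof` checks catch reordering, but neither catches a same-size, same-offset semantic change.

```cpp
#include <cstddef>

// Toy stand-in for an upstream traits struct.
struct toy_traits
{
    int hdim_q;
    int hdim_v;
};

// Exact size pins the field count; offsetof pins the layout order.
static_assert(sizeof(toy_traits) == 2 * sizeof(int), "toy_traits ABI size changed");
static_assert(offsetof(toy_traits, hdim_v) == sizeof(int), "toy_traits field order changed");
```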

Comment thread projects/composablekernel/tile_engine/ops/fmha/fmha_benchmark.py Outdated
cwd=str(self.codegen_dir),
)
self._tick()
if r.returncode != 0:
Contributor

Failed codegen/compile/link returns None silently. r.stderr captured but never logged. Debugging failures requires manual re-runs.
Fix: Log r.stderr on failure

Comment thread projects/composablekernel/dispatcher/codegen/fmha_arch_specs.json
bool matches(const GroupedConvProblem& problem) const
{
// Check if this kernel can handle the problem
return problem.op == key_.op;
Contributor

The GroupedConvKernelKey contains 20+ fields that could constrain kernel selection:

  • Signature fields: dtype_in, dtype_wei, dtype_out, layout, ndim_spatial, op
  • Tile config: tile_m, tile_n, tile_k
  • Wave/warp config: wave_m, wave_n, wave_k, warp_m, warp_n, warp_k
  • Pipeline: pipeline, scheduler, epilogue
  • Other: vector_size_a/b/c, block_per_cu, arch

Yet matches() only checks op
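A sketch of what a fuller matches() could look like, using toy Problem/Key structs with a few assumed fields (the real GroupedConvKernelKey has many more, per the list above):

```cpp
// Toy stand-ins: field names are assumed for illustration; the real key has
// 20+ fields (signature, tile, wave/warp, pipeline, arch, ...).
struct Problem { int op; int dtype_in; int ndim_spatial; };
struct Key     { int op; int dtype_in; int ndim_spatial; };

struct Kernel
{
    Key key_;

    // Each signature field actually constrains selection, so a kernel
    // built for one dtype/spatial-rank cannot be picked for another.
    bool matches(const Problem& p) const
    {
        return p.op == key_.op
            && p.dtype_in == key_.dtype_in
            && p.ndim_spatial == key_.ndim_spatial;
    }
};
```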

Comment on lines +80 to +89
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr;
const void* q_descale_ptr;
const void* k_descale_ptr;
const void* v_descale_ptr;
void* rand_val_ptr;
void* lse_ptr;
void* o_ptr;
Contributor

Should these default to nullptr?

return oss.str();
}

auto tie() const
Contributor

tie() with 57 fields is a maintenance hazard. No compile-time enforcement when adding fields.
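One partial mitigation (a sketch on a toy struct, not a claim about the real FmhaKernelKey) is to pin sizeof next to tie(), so adding a field without revisiting tie() at least trips a compile-time assert a maintainer must acknowledge:

```cpp
#include <tuple>

// Toy key with three fields; the real key has 57.
struct ToyKey
{
    int a;
    int b;
    int c;

    auto tie() const { return std::tie(a, b, c); }
};

// If a field is added, sizeof changes (assuming no padding surprises for
// same-type members) and this line must be updated together with tie().
// Not airtight, but it converts a silent omission into a build break.
static_assert(sizeof(ToyKey) == 3 * sizeof(int),
              "update ToyKey::tie() and this assert together");
```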

};

struct FmhaProblem
{
Contributor

20+ fields duplicated between FmhaProblem and FmhaKernelKey::Signature. from_invocation() is 100+ lines of field-by-field copying.

ASSERT_EQ(plan.stages.size(), 1u);
EXPECT_EQ(plan.stages[0].family, FmhaKernelFamily::FwdAppendKv);
}

Contributor

No test for select_kernel() returning nullptr on empty registry or forward no-match.

}

[[nodiscard]] std::uint64_t num_ops() const
{
Contributor

@yraparti yraparti Mar 12, 2026

num_ops() casts potentially negative int64_t to uint64_t. Wraps to huge values. Better to assert if the value is < 0.
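The suggested assert-before-cast can be sketched as a small helper (name hypothetical):

```cpp
#include <cassert>
#include <cstdint>

// Assert before the signed-to-unsigned cast: a negative intermediate fails
// loudly in debug builds instead of silently wrapping to a value near
// 1.8e19 in the returned uint64_t.
inline std::uint64_t to_num_ops(std::int64_t ops)
{
    assert(ops >= 0 && "num_ops computed a negative value");
    return static_cast<std::uint64_t>(ops);
}
```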

/// Call after registration to trigger auto-export if enabled.
void perform_auto_export()
{
if(auto_export_enabled_.load(std::memory_order_acquire) && auto_export_on_register_)
Contributor

perform_auto_export() reads non-atomic fields without a lock while enable_auto_export() writes them under a mutex. Possible data race.
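Two standard fixes: take the same mutex in the reader, or make every flag the reader touches atomic. A sketch of the second option, with hypothetical names modeled on the snippet:

```cpp
#include <atomic>
#include <mutex>

class AutoExportFlags
{
    public:
    // Writer: still serialized under the mutex so multi-field updates
    // stay consistent with any other guarded state.
    void enable(bool on_register)
    {
        std::lock_guard<std::mutex> lock(mutex_);
        auto_export_on_register_.store(on_register, std::memory_order_release);
        auto_export_enabled_.store(true, std::memory_order_release);
    }

    // Reader: both flags are atomic, so this hot-path check needs no lock
    // and no longer races with enable().
    bool should_export() const
    {
        return auto_export_enabled_.load(std::memory_order_acquire) &&
               auto_export_on_register_.load(std::memory_order_acquire);
    }

    private:
    std::mutex mutex_;
    std::atomic<bool> auto_export_enabled_{false};
    std::atomic<bool> auto_export_on_register_{false};
};
```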

@vidyasagar-amd vidyasagar-amd force-pushed the users/vanantha/ck/dispatcher-fmha branch from 64a5aa1 to 1a98a59 on April 9, 2026 at 19:36
Contributor

@shumway shumway left a comment

I made it about a third of the way through this PR.

There is a lot of cleanup and refactoring needed. This code could be much cleaner and easier to maintain, extend, and troubleshoot.

I'm concerned about the apparent lack of tests on all the Python generation code. Once this is in heavy use and people are modifying and improving CK, this will be fragile and our sets of available kernels may not stay accurate.

It's unlikely we can really fix this without breaking this PR into smaller PRs. Also, I don't know how well we can verify that we're not breaking behavior when we refactor and simplify. We may have to submit as is and then go through clean up systematically, essentially treating this new code as legacy (untested) code.


#include "ck_tile/dispatcher.hpp"

#ifndef GFX_ARCH
Contributor

If I'm reading this file correctly, this code is intended for Python bindings. This is CPU code and should not depend on any GPU architecture. That is, any GPU architecture decisions on host should be made at run time, right? What is the use case that requires conditional compilation based on a GFX architecture? Including this means that we have different versions of the Python binding library for each GFX architecture.

static std::unique_ptr<FmhaDispatcher> g_dispatcher;
static bool g_initialized = false;

#define HIP_CHECK(call) \
Contributor

Should these helper macros be in one place? I think we're duplicating them in these ctypes cpp files.

}
}

static int dtype_input_bytes(const char* dtype)
Contributor

This should be common code, not specific to FMHA.

We can have static (file-scoped helpers) in cpp files, but something general like this probably belongs in a common include.

Also, let's avoid this use of the static keyword (file scoped functions) since static means many different things in c++. The modern pattern for file scoped definitions in cpp files is to use an unnamed namespace:

namespace {

void myFileScopedHelper() {
  // do something
}

} // namespace

// Run the single registered kernel directly, bypassing the multi-stage plan()
// that requires split+combine for splitkv or dot+dq+convert for bwd.
// Used for single-kernel .so benchmarking.
static float run_single_kernel(const FmhaInvocation& invocation)
Contributor

CK style for functions is camel case: runSingleKernel.

// Used for single-kernel .so benchmarking.
static float run_single_kernel(const FmhaInvocation& invocation)
{
auto kernels = g_registry->get_all();
Contributor

CK style is to not use auto when the specific type makes the design easier to read.

I'm guessing this is a std::vector<FmhaKernel>, but if it's explicit I don't have to guess and I'll know as I'm reading.


def _check_feature(spec: PipelineSpec) -> bool:
"""logits_soft_cap requires no bias."""
if spec.logits == "t" and spec.bias != "no":
Contributor

return not (spec.logits == "t" and spec.bias != "no")

which is:

return spec.logits != "t" or spec.bias == "no"

Are these still booleans? This is so confusing. Does spec.logits != "t" mean that spec.logits = "f"? What is going on? Is the other boolean "yes" and "no"?


# ===== Receipt / Product filters =====

RECEIPT_FILTERS = {
Contributor

What are these numbers 0, 2, 4, 100, 200, 400, 600, 888, 800? Why is 888 before 800?


# Multiple tiles per hdim for splitkv, matching PR #5482 benchmarking additions.
# The instance builder iterates all tiles per hdim, letting the benchmark find the best.
SPLITKV_TILES_FP16 = {
Contributor

Can these be derived from architecture specs? What happens if there's a mistake here? Where did these numbers come from?

If I'm reading this correctly, these are arbitrary based on which kernels have been implemented in CK Tile. Do these change? How do developers know what to keep in sync? We're down in line 578 of a large python file, really mixing code and configuration.

bias: str
logits: str
sink: str
pagedkv: str = "f"
Contributor

Why isn't this bool?

"""Split-KV main kernel pipelines (matches KernelComponentFactoryBase.get_pipelines)."""
specs: List[SplitKVPipelineSpec] = []

SPLITKV_MASKS = ["no", "causal"]
Contributor

More of the pattern above that can be simplified.

Contributor

@shumway shumway left a comment

OK, had a lot of offline discussion. There is a lot we can debate and improve in this PR. Since this is new code that provides customer value, and it is not yet on the critical path, I think we should merge as-is once we have CI passing.

This will help our code quality, since we can then have individual PRs to clean up code and improve the design, and write them at a size and scope for good engineering discussion.

This FMHA dispatcher is a great new capability that will be the foundation of customer-friendly kernel delivery.

Contributor

@shumway shumway left a comment

OK, we had a lot of productive offline discussion.

Since this is new code that introduces a major new feature, the best way forward is to merge as is. There is a lot of code cleanup and improvements we can do. Putting those changes in smaller follow up PRs with limited scope will make review and discussion much easier, resulting in higher quality. Attempting to make those changes on this PR delays this library capability and is not practical at this scale.

Contributor

@DDEle DDEle left a comment

Compilation failures were found while building fmha_01_basic. In addition, I failed to build fmha_04_bwd.

Comment on lines +586 to +592
/// Enable or disable GPU benchmarking (timing).
/// When disabled, kernels execute once with no timing overhead.
void set_benchmarking(bool enable) { benchmarking_ = enable; }
[[nodiscard]] bool benchmarking_enabled() const { return benchmarking_; }

private:
bool benchmarking_ = true;
Contributor

set_benchmarking(), benchmarking_enabled(), and bool benchmarking_ are each declared twice in GroupedConvDispatcher, causing a compilation error on gfx950:

  error: class member cannot be redeclared: set_benchmarking
  error: class member cannot be redeclared: benchmarking_enabled
  error: duplicate member 'benchmarking_'



8 participants