[CK] [CK_Tile] Add FMHA scaffolding to CK kernel dispatcher #5260
vidyasagar-amd wants to merge 26 commits into develop
Conversation
Pull request overview
Adds Fused Multi-Head Attention (FMHA) integration to the CK Tile dispatcher ecosystem by introducing new FMHA examples (Python/C++), ctypes bindings, and codegen utilities, alongside updates to grouped convolution and GEMM codegen/shared infrastructure.
Changes:
- Added many FMHA C++/Python examples plus a ctypes C API library and fallback-kernel generator for Python integration.
- Updated build system to compile FMHA examples and produce a Python FMHA shared library (and expanded supported GPU arch list).
- Refactored GEMM codegen to reuse shared codegen infrastructure and improved codegen parallelism/cleanup; refreshed docs to reflect GEMM + grouped conv (and related tooling).
Reviewed changes
Copilot reviewed 63 out of 168 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| projects/composablekernel/dispatcher/examples/fmha/python/15_gqa_fmha.py | New Python example exercising GQA/MQA and optional GPU validation via dispatcher. |
| projects/composablekernel/dispatcher/examples/fmha/python/14_dropout_fmha.py | New Python example demonstrating dropout/LSE behavior and baseline GPU run. |
| projects/composablekernel/dispatcher/examples/fmha/python/13_bias_fmha.py | New Python example demonstrating bias modes with CPU reference + limited GPU baseline. |
| projects/composablekernel/dispatcher/examples/fmha/python/12_masks_fmha.py | New Python example demonstrating mask patterns with CPU reference + limited GPU baseline. |
| projects/composablekernel/dispatcher/examples/fmha/python/11_bf16_fmha.py | New Python example showing bf16 handling and fallback behavior. |
| projects/composablekernel/dispatcher/examples/fmha/python/10_advanced_benchmark.py | New Python benchmark driver with warmup/repeat/cache flush for FMHA. |
| projects/composablekernel/dispatcher/examples/fmha/python/09_multi_registry.py | New Python example showing separate registries for different optimization targets. |
| projects/composablekernel/dispatcher/examples/fmha/python/07_stress_test.py | New Python stress test generating/building/validating multiple FMHA kernels. |
| projects/composablekernel/dispatcher/examples/fmha/python/06_json_export.py | New Python example exporting FMHA registry/kernel configs to JSON. |
| projects/composablekernel/dispatcher/examples/fmha/python/05_numpy_integration.py | New NumPy-friendly FMHA wrapper demo with GPU execution and validation. |
| projects/composablekernel/dispatcher/examples/fmha/python/04_validation.py | New Python validation suite comparing dispatcher output to CPU reference. |
| projects/composablekernel/dispatcher/examples/fmha/python/03_benchmark.py | New Python benchmark over batch/sequence sizes. |
| projects/composablekernel/dispatcher/examples/fmha/python/02_multi_shape.py | New Python multi-shape demo reusing one kernel over multiple shapes. |
| projects/composablekernel/dispatcher/examples/fmha/python/01_basic_fmha.py | New Python multi-kernel build + run + validate example. |
| projects/composablekernel/dispatcher/examples/fmha/cpp/15_multi_shape_fmha.cpp | New C++ example running a single FMHA kernel across multiple shapes. |
| projects/composablekernel/dispatcher/examples/fmha/cpp/12_registry_json_fmha.cpp | New C++ example exporting FMHA registry to JSON. |
| projects/composablekernel/dispatcher/examples/fmha/cpp/11_receipt_aliases_fmha.cpp | New C++ example demonstrating receipt-alias planning. |
| projects/composablekernel/dispatcher/examples/fmha/cpp/10_profile_fp32_fp8_fmha.cpp | New C++ example demonstrating fp32/fp8 profile planning + JSON export. |
| projects/composablekernel/dispatcher/examples/fmha/cpp/09_profile_aiter_fmha.cpp | New C++ example demonstrating AITER profile planning. |
| projects/composablekernel/dispatcher/examples/fmha/cpp/08_profile_flash_fmha.cpp | New C++ example demonstrating flash profile planning for fwd/bwd stages. |
| projects/composablekernel/dispatcher/examples/fmha/cpp/07_profile_pytorch_fmha.cpp | New C++ example demonstrating PyTorch profile planning across multiple families. |
| projects/composablekernel/dispatcher/examples/fmha/cpp/06_batch_prefill_fmha.cpp | New C++ example for batch-prefill planning. |
| projects/composablekernel/dispatcher/examples/fmha/cpp/05_appendkv_fmha.cpp | New C++ example for append-KV planning. |
| projects/composablekernel/dispatcher/examples/fmha/cpp/04_bwd_fmha.cpp | New C++ example for backward planning (multi-stage). |
| projects/composablekernel/dispatcher/examples/fmha/cpp/03_kvcache_fmha.cpp | New C++ example for paged-KV/append-KV/batch-prefill planning. |
| projects/composablekernel/dispatcher/examples/fmha/cpp/02_splitkv_fmha.cpp | New C++ example for split-KV planning (2-stage). |
| projects/composablekernel/dispatcher/examples/README.md | Updated examples README to reflect grouped conv and reorganized content. |
| projects/composablekernel/dispatcher/examples/CMakeLists.txt | Added FMHA & grouped conv example targets; added FMHA/conv Python libs; expanded include dirs and supported GPU arch list. |
| projects/composablekernel/dispatcher/codegen/unified_gemm_codegen.py | Refactored to use shared codegen_common and improved parallel generation flow + temp config cleanup. |
| projects/composablekernel/dispatcher/codegen/kernel_config_loader.py | Renamed convolution config classes/functions to “GroupedConv*” and updated generated macro name. |
| projects/composablekernel/dispatcher/codegen/generate_kernel_wrappers.py | Documentation formatting update for wrapper output tree. |
| projects/composablekernel/dispatcher/codegen/generate_fmha_fallback.py | New script to generate + optionally compile an FMHA fallback kernel and dispatch header for Python ctypes. |
| projects/composablekernel/dispatcher/codegen/generate_dispatcher_registration.py | Updated console output messages for registration generation. |
| projects/composablekernel/dispatcher/codegen/fmha_symbol_map.py | New FMHA symbol/config canonicalization + naming utilities (arch specs driven). |
| projects/composablekernel/dispatcher/codegen/fmha_profiles.py | New FMHA profile/receipt aliasing and config filters. |
| projects/composablekernel/dispatcher/codegen/README.md | Updated codegen README to cover grouped conv and shared infrastructure. |
| projects/composablekernel/dispatcher/codegen/ADDING_NEW_GPU.md | Documentation formatting tweaks (ASCII arrows/tree). |
| projects/composablekernel/dispatcher/bindings/ctypes/fmha_ctypes_lib.cpp | New FMHA ctypes C API to initialize registry/dispatcher and run forward. |
| projects/composablekernel/dispatcher/bindings/README.md | Updated bindings README to reflect grouped conv terminology and structure. |
| projects/composablekernel/dispatcher/README.md | Updated top-level dispatcher README to include grouped conv, refreshed docs/trees. |
| projects/composablekernel/dispatcher/CMakeLists.txt | Added FMHA registry/dispatcher sources and widened include path scope. |
```python
hdim_q=128,
hdim_v=128,
```
This example accepts --hdim, but the kernel config is hard-coded to hdim_q=128/hdim_v=128. If a user passes a non-128 --hdim, the dispatcher setup/run can fail or behave inconsistently with the generated FmhaProblem (which uses args.hdim). Use hdim_q=args.hdim and hdim_v=args.hdim (or validate/restrict the CLI to 128 only) so the config matches the problem being executed.
Suggested change:

```python
hdim_q=args.hdim,
hdim_v=args.hdim,
```
```python
config = FmhaKernelConfig(
    data_type="fp16",
    hdim_q=128,
    hdim_v=128,
    gfx_arch=args.arch,
)
```
Like other examples, this script accepts --hdim but always JIT-builds a 128x128 kernel. That can cause runtime failures when args.hdim != 128 (e.g., fmha_matmul derives hdim_q/hdim_v from the NumPy arrays and creates an FmhaProblem accordingly). Either (a) set hdim_q/hdim_v from the actual inputs/args, or (b) constrain --hdim to 128 for this example and emit a clear error if the user requests unsupported head dims.
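A minimal sketch of option (b), using argparse's `choices` to reject unsupported head dims up front. The flag name is taken from the examples; the supported set being just 128 is an assumption about this example:

```python
import argparse

SUPPORTED_HDIMS = (128,)  # assumption: this example only JIT-builds a 128x128 kernel

parser = argparse.ArgumentParser(description="FMHA example CLI (sketch)")
parser.add_argument(
    "--hdim",
    type=int,
    default=128,
    choices=SUPPORTED_HDIMS,
    help="head dimension (only 128 is built by this example)",
)

args = parser.parse_args(["--hdim", "128"])  # accepted

# An unsupported head dim fails fast with a clear CLI error
# instead of a confusing dispatcher/runtime failure later.
rejected = False
try:
    parser.parse_args(["--hdim", "64"])
except SystemExit:
    rejected = True
```

Option (a) is still preferable long-term, since it lets the config follow the inputs; this sketch just closes the immediate config/problem mismatch.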
```cpp
#define HIP_CHECK(call)            \
    {                              \
        hipError_t err = call;     \
        if(err != hipSuccess)      \
            return -1;             \
    }
```
The HIP_CHECK macro returns immediately on failure, but fmha_dispatcher_run_fwd allocates multiple device buffers before subsequent HIP calls. If any later HIP_CHECK(...) fails (e.g., a memcpy), this will leak the buffers already allocated. Prefer a cleanup-on-error path (e.g., a single goto cleanup;/scope guard that frees any non-null pointers) or RAII wrappers for device allocations so all allocated resources are released on every error path.
Suggested change:

```cpp
#define HIP_CHECK(call)                              \
    do                                               \
    {                                                \
        hipError_t err = (call);                     \
        if(err != hipSuccess)                        \
        {                                            \
            /* Reset device to release all resources \
               allocated on it before returning. */  \
            hipDeviceReset();                        \
            return -1;                               \
        }                                            \
    } while(0)
```
```
|---- gemm/
|    |---- cpp/      # 6 C++ GEMM examples
|    +---- python/   # 11 Python GEMM examples
|
+---- README.md
```
The directory-tree snippet appears to include diff-style prefixes (||----, +----) inside the rendered Markdown code block, which makes the README misleading/confusing. Replace this with a clean tree (e.g., ├──, │, └──) without any leading |/+ markers so it reads correctly when viewed on GitHub.
Suggested change:

```
├── gemm/
│   ├── cpp/      # 6 C++ GEMM examples
│   └── python/   # 11 Python GEMM examples
│
└── README.md
```
```
|---- README.md                           # This file
|---- CMakeLists.txt                      # Build configuration
|
|---- include/ck_tile/dispatcher/         # C++ headers
|    |---- dispatcher.hpp                 # Main dispatcher include
|    |---- registry.hpp                   # GEMM kernel registry
|    |---- kernel_key.hpp                 # Kernel configuration
|    |---- grouped_conv_config.hpp        # Grouped conv configuration
|    |---- grouped_conv_problem.hpp       # Grouped conv problem (with builder)
|    |---- grouped_conv_kernel_decl.hpp   # Grouped conv kernel declarations
|    |---- grouped_conv_registry.hpp      # Grouped conv registry (thread-safe)
|    +---- grouped_conv_utils.hpp         # Grouped conv utilities
|
|---- src/                                # C++ implementation
|
|---- codegen/                            # Kernel generation
|    |---- codegen_common.py              # Shared: TileConfig, TraitConfigBase, type mappings
|    |---- unified_gemm_codegen.py        # GEMM kernel generator
|    |---- unified_grouped_conv_codegen.py # Grouped conv kernel generator
|    +---- arch_specs.json                # GPU specifications
|
|---- python/                             # Python utilities
|    |---- dispatcher_common.py           # Shared: paths, validation, Colors, phased output
|    |---- ctypes_utils.py                # GEMM ctypes utilities
|    +---- grouped_conv_utils.py          # Grouped conv utilities
```
Several directory-tree blocks in this README use ||/|----/+---- markers that look like diff artifacts and don’t render as a conventional tree. Consider rewriting these blocks using a consistent tree format (├──/│/└──) and removing the extra leading | characters so the structure is readable and copy/paste-friendly.
Suggested change:

```
├── README.md                           # This file
├── CMakeLists.txt                      # Build configuration
├── include/ck_tile/dispatcher/         # C++ headers
│   ├── dispatcher.hpp                  # Main dispatcher include
│   ├── registry.hpp                    # GEMM kernel registry
│   ├── kernel_key.hpp                  # Kernel configuration
│   ├── grouped_conv_config.hpp         # Grouped conv configuration
│   ├── grouped_conv_problem.hpp        # Grouped conv problem (with builder)
│   ├── grouped_conv_kernel_decl.hpp    # Grouped conv kernel declarations
│   ├── grouped_conv_registry.hpp       # Grouped conv registry (thread-safe)
│   └── grouped_conv_utils.hpp          # Grouped conv utilities
├── src/                                # C++ implementation
├── codegen/                            # Kernel generation
│   ├── codegen_common.py               # Shared: TileConfig, TraitConfigBase, type mappings
│   ├── unified_gemm_codegen.py         # GEMM kernel generator
│   ├── unified_grouped_conv_codegen.py # Grouped conv kernel generator
│   └── arch_specs.json                 # GPU specifications
└── python/                             # Python utilities
    ├── dispatcher_common.py            # Shared: paths, validation, Colors, phased output
    ├── ctypes_utils.py                 # GEMM ctypes utilities
    └── grouped_conv_utils.py           # Grouped conv utilities
```
GFX_ARCH is not defined for Conv.
```
|    |---- conv_bwdw_ctypes_lib.cpp   # Grouped conv backward weight C API
|    |---- gpu_helper.cpp             # CLI helper for Python
|    +---- CMakeLists.txt
+---- README.md
```
Should we mention the FMHA ctypes library here too?
```cpp
const int64_t q_bytes  = static_cast<int64_t>(batch) * nhead_q * seqlen_q * hdim_q * 2;
const int64_t k_bytes  = static_cast<int64_t>(batch) * nhead_k * seqlen_k * hdim_q * 2;
const int64_t v_bytes  = static_cast<int64_t>(batch) * nhead_k * seqlen_k * hdim_v * 2;
const int64_t o_bytes  = static_cast<int64_t>(batch) * nhead_q * seqlen_q * hdim_v * 2;
const int64_t do_bytes = o_bytes;
```
Better to compute allocation sizes from the dtype's element size instead of hard-coding `* 2`.
```cpp
args.nhead_stride_q_descale = 0;
args.nhead_stride_k_descale = 0;
args.nhead_stride_v_descale = 0;
args.batch_stride_bias = (bias_type_int > 0) ? nhead_q * seqlen_q * seqlen_k : 0;
```
`nhead_q * seqlen_q * seqlen_k` is `int * int * int` with no widening cast. With nhead_q=32 and seqlen_q=seqlen_k=8192, the product overflows int32 (undefined behavior). The same issue exists at lines 227-230.
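To illustrate the reviewer's numbers, a pure-Python sketch of the two's-complement wrap an int32 product would produce (the C++ fix is to widen one operand, e.g. `static_cast<int64_t>(nhead_q) * seqlen_q * seqlen_k`):

```python
nhead_q, seqlen_q, seqlen_k = 32, 8192, 8192

# Exact product as an unbounded Python int: 32 * 8192 * 8192 == 2**31,
# which is one past INT32_MAX (2**31 - 1).
product = nhead_q * seqlen_q * seqlen_k

# Emulate the int32 two's-complement wraparound the hardware would produce
# (in C++, the signed overflow itself is already undefined behavior).
wrapped = product & 0xFFFFFFFF
if wrapped >= 2**31:
    wrapped -= 2**32

print(product, wrapped)  # 2147483648 -2147483648
```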
```python
    traits_hdim_v: int = 0,
) -> Tuple[int, float]:
    time_ms = ctypes.c_float(0.0)
    rc = self._lib.fmha_dispatcher_run_fwd(
```
run_fwd() passes 19 arguments to a 25-parameter C function; missing are is_v_rowmajor, perm, data_type_str, is_group_mode, window_left, and window_right. Calling it would corrupt the stack.
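Declaring `argtypes` on the binding would turn this kind of arity mismatch into an immediate TypeError instead of silent stack corruption. A self-contained sketch; the 3-parameter `fake_c_func` is illustrative, not the real 25-parameter `fmha_dispatcher_run_fwd` prototype:

```python
import ctypes

# A C-callable function with a declared 3-int signature.
proto = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int)

@proto
def fake_c_func(a, b, c):
    return a + b + c

# With the signature declared, ctypes validates the argument count
# at call time instead of pushing a malformed frame.
mismatch_caught = False
try:
    fake_c_func(1, 2)  # one argument short
except TypeError:
    mismatch_caught = True

result = fake_c_func(1, 2, 3)
print(mismatch_caught, result)  # True 6
```

For the real library, setting `self._lib.fmha_dispatcher_run_fwd.argtypes` to the full 25-entry list would make any future signature drift fail loudly at the first call.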
```cmake
add_custom_target(benchmark_fmha
    COMMAND ${Python3_EXECUTABLE} ${FMHA_TE_DIR}/fmha_benchmark.py
            ${FMHA_TE_CONFIGS}/fwd.json
            --arch ${USER_GPU_TARGETS}
```
${USER_GPU_TARGETS} is a boolean (0/1), not an arch string, so every benchmark target passes --arch 0 or --arch 1 and fails at runtime.
Fix: use ${SUPPORTED_GPU_TARGETS}, following the GEMM pattern.
```cpp
using namespace ck_tile::dispatcher;

static std::unique_ptr<FmhaRegistry> g_registry;
static std::unique_ptr<FmhaDispatcher> g_dispatcher;
static bool g_initialized = false;
```
g_registry, g_dispatcher, g_initialized are globals with no synchronization. Concurrent init/run/cleanup calls are data races (UB).
Fix: Add std::mutex guard or document single-threaded restriction.
Claude downgraded this to a suggestion after being questioned about it.
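Until the C side is synchronized, the Python wrapper could serialize access itself. A sketch with a hypothetical `_FmhaLib` class; the real binding's init call would replace the comment:

```python
import threading

class _FmhaLib:
    """Hypothetical wrapper serializing init calls into the unsynchronized C library."""
    _lock = threading.Lock()
    _initialized = False

    @classmethod
    def init(cls) -> bool:
        # Checked under the lock: only one caller performs the real init.
        with cls._lock:
            if not cls._initialized:
                # lib.fmha_dispatcher_init() would go here
                cls._initialized = True
                return True
            return False

results = []
threads = [threading.Thread(target=lambda: results.append(_FmhaLib.init()))
           for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(sum(results))  # exactly one thread wins the initialization -> 1
```

The equivalent C++ fix is a `std::mutex` (or `std::call_once`) around the globals, as the comment suggests.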
```python
    return True


def expand_sweep(config_path: str, arch: str) -> List[FmhaKernelConfig]:
```
expand_sweep() ignores the tile_config section of the JSON configs entirely, so the tile_config in all 8 JSON files is dead configuration.
Fix: either remove tile_config from the JSONs or implement the filtering.
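A sketch of the filtering option; the tile_config schema and kernel-dict keys here are assumptions, not the real JSON layout:

```python
from typing import Any, Dict, List

def filter_by_tile_config(kernels: List[Dict[str, Any]],
                          tile_config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Keep only kernels whose fields match every constraint in tile_config;
    an empty tile_config (today's de facto behavior) keeps everything."""
    return [k for k in kernels
            if all(k.get(field) == value for field, value in tile_config.items())]

kernels = [
    {"tile_m": 128, "tile_n": 128, "tile_k": 32},
    {"tile_m": 64,  "tile_n": 64,  "tile_k": 32},
]

kept = filter_by_tile_config(kernels, {"tile_m": 128})
print([k["tile_m"] for k in kept])  # [128]
assert filter_by_tile_config(kernels, {}) == kernels  # empty filter is a no-op
```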
```cpp
#if defined(CK_TILE_FMHA_FWD_TYPES_FROM_EXAMPLE)
static_assert(sizeof(fmha_fwd_traits) >= 40, "fmha_fwd_traits layout may have changed upstream");
static_assert(sizeof(fmha_fwd_args) >= 300, "fmha_fwd_args layout may have changed upstream");
#endif
#if defined(CK_TILE_FMHA_BWD_TYPES_FROM_EXAMPLE)
static_assert(sizeof(fmha_bwd_traits) >= 32, "fmha_bwd_traits layout may have changed upstream");
static_assert(sizeof(fmha_bwd_args) >= 350, "fmha_bwd_args layout may have changed upstream");
#endif
```
Use == for an exact ABI check. Even that may not be good enough if the contents change while the sizes stay the same.
```python
    cwd=str(self.codegen_dir),
)
self._tick()
if r.returncode != 0:
```
A failed codegen/compile/link step returns None silently; r.stderr is captured but never logged, so debugging failures requires manual re-runs.
Fix: log r.stderr on failure.
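A sketch of the suggested fix, assuming the surrounding code uses `subprocess.run` with captured output; the helper and logger names are illustrative:

```python
import logging
import subprocess
import sys

logging.basicConfig(level=logging.ERROR)
log = logging.getLogger("codegen")

def run_step(cmd):
    """Run one codegen/compile/link step; log stderr on failure instead of
    returning None silently."""
    r = subprocess.run(cmd, capture_output=True, text=True)
    if r.returncode != 0:
        log.error("step failed (rc=%d): %s\n--- stderr ---\n%s",
                  r.returncode, " ".join(cmd), r.stderr)
        return None
    return r.stdout

ok = run_step([sys.executable, "-c", "print('ok')"])
bad = run_step([sys.executable, "-c",
                "import sys; sys.stderr.write('boom'); sys.exit(2)"])
```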
```cpp
bool matches(const GroupedConvProblem& problem) const
{
    // Check if this kernel can handle the problem
    return problem.op == key_.op;
```
The GroupedConvKernelKey contains 20+ fields that could constrain kernel selection:
- Signature fields: dtype_in, dtype_wei, dtype_out, layout, ndim_spatial, op
- Tile config: tile_m, tile_n, tile_k
- Wave/warp config: wave_m, wave_n, wave_k, warp_m, warp_n, warp_k
- Pipeline: pipeline, scheduler, epilogue
- Other: vector_size_a/b/c, block_per_cu, arch
Yet matches() only checks op.
```cpp
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr;
const void* q_descale_ptr;
const void* k_descale_ptr;
const void* v_descale_ptr;
void* rand_val_ptr;
void* lse_ptr;
void* o_ptr;
```
```cpp
    return oss.str();
}

auto tie() const
```
tie() with 57 fields is a maintenance hazard. No compile-time enforcement when adding fields.
```cpp
};

struct FmhaProblem
{
```
20+ fields are duplicated between FmhaProblem and FmhaKernelKey::Signature, and from_invocation() is 100+ lines of field-by-field copying.
```cpp
ASSERT_EQ(plan.stages.size(), 1u);
EXPECT_EQ(plan.stages[0].family, FmhaKernelFamily::FwdAppendKv);
}
```
No test for select_kernel() returning nullptr on empty registry or forward no-match.
```cpp
}

[[nodiscard]] std::uint64_t num_ops() const
{
```
num_ops() casts a potentially negative int64_t to uint64_t, which wraps to a huge value. Better to assert that the value is >= 0.
```cpp
/// Call after registration to trigger auto-export if enabled.
void perform_auto_export()
{
    if(auto_export_enabled_.load(std::memory_order_acquire) && auto_export_on_register_)
```
This reads non-atomic fields without holding the lock, while enable_auto_export() writes them under the mutex. A data race is possible.
shumway left a comment:
I made it about a third of the way through this PR.
There is a lot of cleanup and refactoring needed. This code could be much cleaner and easier to maintain, extend, and troubleshoot.
I'm concerned about the apparent lack of tests on all the Python generation code. Once this is in heavy use and people are modifying and improving CK, this will be fragile, and our sets of available kernels may not stay accurate.
It's unlikely we can really fix this without breaking this PR into smaller PRs. Also, I don't know how well we can verify that we're not breaking behavior when we refactor and simplify. We may have to submit as is and then go through clean up systematically, essentially treating this new code as legacy (untested) code.
```cpp
#include "ck_tile/dispatcher.hpp"

#ifndef GFX_ARCH
```
If I'm reading this file correctly, this code is intended for the Python bindings. It is CPU code and should not depend on any GPU architecture; any GPU-architecture decision on the host should be made at run time. What use-case requires conditional compilation based on a GFX architecture? As written, we end up with a different version of the Python binding library for each GFX architecture.
```cpp
static std::unique_ptr<FmhaDispatcher> g_dispatcher;
static bool g_initialized = false;

#define HIP_CHECK(call) \
```
Should these helper macros live in one place? I think we're duplicating them across these ctypes .cpp files.
```cpp
    }
}

static int dtype_input_bytes(const char* dtype)
```
This should be common code, not specific to FMHA. We can have static (file-scoped) helpers in .cpp files, but something general like this probably belongs in a common include.
Also, let's avoid this use of the static keyword for file-scoped functions, since static means many different things in C++. The modern pattern for file-scoped definitions in .cpp files is an unnamed namespace:

```cpp
namespace {
void myFileScopedHelper() {
    // do something
}
} // namespace
```
```cpp
// Run the single registered kernel directly, bypassing the multi-stage plan()
// that requires split+combine for splitkv or dot+dq+convert for bwd.
// Used for single-kernel .so benchmarking.
static float run_single_kernel(const FmhaInvocation& invocation)
```
CK style for functions is camel case: runSingleKernel.
```cpp
// Used for single-kernel .so benchmarking.
static float run_single_kernel(const FmhaInvocation& invocation)
{
    auto kernels = g_registry->get_all();
```
CK style is to not use auto when the specific type makes the design easier to read.
I'm guessing this is a std::vector<FmhaKernel>, but if the type is explicit I don't have to guess; I'll know as I'm reading.
```python
def _check_feature(spec: PipelineSpec) -> bool:
    """logits_soft_cap requires no bias."""
    if spec.logits == "t" and spec.bias != "no":
```
This can be written as

```python
return not (spec.logits == "t" and spec.bias != "no")
```

which is

```python
return spec.logits != "t" or spec.bias == "no"
```

Are these still booleans? This is so confusing. Does spec.logits != "t" mean that spec.logits == "f"? What is going on? Is the other flag "yes" and "no"?
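The simplification is a direct De Morgan rewrite; a sketch verifying the two forms agree on every flag combination (`Spec` is a stand-in for PipelineSpec, and the non-"no" bias values are illustrative):

```python
from dataclasses import dataclass
from itertools import product

@dataclass
class Spec:
    logits: str  # "t"/"f" flag in the generator
    bias: str    # "no" plus whatever bias modes exist

def original(spec: Spec) -> bool:
    """logits_soft_cap requires no bias (early-return style)."""
    if spec.logits == "t" and spec.bias != "no":
        return False
    return True

def simplified(spec: Spec) -> bool:
    # De Morgan: not (A and B) == (not A) or (not B)
    return spec.logits != "t" or spec.bias == "no"

cases = [Spec(l, b) for l, b in product(["t", "f"], ["no", "elementwise", "alibi"])]
agree = all(original(s) == simplified(s) for s in cases)
print(agree, len(cases))  # True 6
```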
```python
# ===== Receipt / Product filters =====

RECEIPT_FILTERS = {
```
What are these numbers 0, 2, 4, 100, 200, 400, 600, 888, 800? Why is 888 before 800?
```python
# Multiple tiles per hdim for splitkv, matching PR #5482 benchmarking additions.
# The instance builder iterates all tiles per hdim, letting the benchmark find the best.
SPLITKV_TILES_FP16 = {
```
Can these be derived from architecture specs? What happens if there's a mistake here? Where did these numbers come from?
If I'm reading this correctly, these are arbitrary, based on which kernels have been implemented in CK Tile. Do these change? How do developers know what to keep in sync? We're down at line 578 of a large Python file, really mixing code and configuration.
```python
bias: str
logits: str
sink: str
pagedkv: str = "f"
```
```python
"""Split-KV main kernel pipelines (matches KernelComponentFactoryBase.get_pipelines)."""
specs: List[SplitKVPipelineSpec] = []

SPLITKV_MASKS = ["no", "causal"]
```
More of the pattern above that can be simplified.
shumway left a comment:
OK, had a lot of offline discussion. There is a lot we can debate and improve in this PR. Since this is new code that provides customer value, and it is not yet on the critical path, I think we should merge as-is once we have CI passing.
This will help our code quality, since we can then have individual PRs to clean up code and improve the design, and write them at a size and scope for good engineering discussion.
This FMHA dispatcher is a great new capability that will be the foundation of customer-friendly kernel delivery.
shumway left a comment:
OK, we had a lot of productive offline discussion.
Since this is new code that introduces a major new feature, the best way forward is to merge as is. There is a lot of code cleanup and improvements we can do. Putting those changes in smaller follow up PRs with limited scope will make review and discussion much easier, resulting in higher quality. Attempting to make those changes on this PR delays this library capability and is not practical at this scale.
DDEle left a comment:
Compilation failures were found while building fmha_01_basic; in addition, I failed to build fmha_04_bwd.
```cpp
/// Enable or disable GPU benchmarking (timing).
/// When disabled, kernels execute once with no timing overhead.
void set_benchmarking(bool enable) { benchmarking_ = enable; }
[[nodiscard]] bool benchmarking_enabled() const { return benchmarking_; }

private:
    bool benchmarking_ = true;
```
set_benchmarking(), benchmarking_enabled(), and bool benchmarking_ are each declared twice in GroupedConvDispatcher, causing a compilation error on gfx950:

```
error: class member cannot be redeclared: set_benchmarking
error: class member cannot be redeclared: benchmarking_enabled
error: duplicate member 'benchmarking_'
```
Motivation
The CK Tile dispatcher currently supports GEMM and Grouped Convolution but has no support for Fused Multi-Head Attention (FMHA). The example/ck_tile/01_fmha folder contains a comprehensive FMHA implementation with forward, backward, split-KV, paged-KV, append-KV, and batch-prefill kernels across multiple GPU architectures — but there is no unified dispatch layer for it. This PR ports the FMHA stack into the dispatcher, following the same architectural patterns established by GEMM and Grouped Convolution, enabling runtime kernel selection, JIT compilation from Python, and a declarative C++ example flow. Autotuning heuristics to follow.
Technical Details
This PR adds FMHA scaffolding to the CK dispatcher framework, mirroring GEMM's layered architecture.
- Runtime: Seven new C++ headers provide type definitions (coexisting with upstream headers via __has_include, requiring zero modifications to example/ck_tile/01_fmha/), a problem builder with 18+ setters, Signature + Algorithm kernel key matching, a virtual kernel instance, a DECL_FMHA_KERNEL_SET macro with wildcard support and named tile/wave/warp setters, an arch-aware registry with JSON export, and a dispatcher with seqtune-aware selection, configurable timing, and multi-stage execution plans for split-KV (two-stage) and backward (three-stage).
- Codegen: The pipeline is driven by a fmha_arch_specs.json capturing per-arch tile tables and pipeline constraints for five architectures (gfx90a/942/950/1100/1201), migrated from hardcoded logic in 01_fmha/codegen/, with supporting modules for C++ symbol mappings, validation rules, and named receipt profiles (ck_default, flash, pytorch, aiter, fp32, fp8).
- Python integration: fmha_utils.py mirrors the C++ layer with JIT compilation, parallel multi-kernel builds, HIP memory management via ctypes, tolerance-based validation, and a NumPy CPU reference with GQA support.
- Examples: Twenty-seven C++ and thirty-two Python examples cover the full feature surface (forward, split-KV, masks, bias, dropout, GQA, backward, append-KV, batch prefill, fp8, logits soft cap, sink tokens, and parameter sweeps), all JIT-compiled on the fly.
Test Plan
Seven test files cover the runtime types, codegen, and end-to-end correctness. C++ unit tests validate the problem builder, dispatcher planning (single-stage for forward/paged-KV/append-KV; multi-stage for split-KV and backward), registry operations, and the kernel-set declaration macro. Python unit tests verify codegen emission, profile filtering, and 15 validation rules for masks, hdim constraints, and pipeline requirements. GPU execution validation in 01_basic_fmha --validate reports zero errors across 65,536 elements with max absolute error of 7.29e-05. A gold-standard parity suite (test_fmha_parity.py) runs 14 configurations through both the upstream tile_example_fmha_fwd and the dispatcher, comparing exit codes to confirm behavioral parity — all 14 match.
Test Result
The C++ smoke test builds and passes all 9 compiled examples, and a Python JIT sweep (29_sweep_seqlen.py) passes 7/7 configurations reaching up to 375 TFLOPS at seqlen 2048.
Submission Checklist