Add gfx1201 (RDNA4) support across CK Dispatcher, MIOpen, and hipBLASLt#5455
ChrisLundquist wants to merge 41 commits into ROCm:users/vanantha/ck/dispatcher-fmha
Conversation
vidyasagar-amd
left a comment
Thanks for the fixes and example, good to see gfx1201 support being added and tested
Pull request overview
This PR improves gfx1201 / RDNA4 support in the CK Tile dispatcher by aligning warp-tile filtering, FMHA architecture tags/config defaults, and example/codegen tooling so kernels can be generated/validated correctly for gfx12 targets.
Changes:
- Add/validate RDNA4 warp-tile support (16x16x16) across BF16/FP8/BF8/INT8 and add unit tests to prevent regressions.
- Make example kernel build/codegen path architecture-aware (wave/warp tile defaults + `--gpu-target` passed through).
- Update FMHA defaults/spec metadata for gfx11/gfx12 (arch_tag + gfx12 wave/warp overrides) and add a gfx1201 GEMM benchmark example.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| projects/composablekernel/dispatcher/tests/test_dispatcher_common.py | Adds RDNA4 warp-tile tests and gfx12 FMHA spec_to_config regression tests. |
| projects/composablekernel/dispatcher/scripts/example_kernel_builder.py | Makes GEMM wildcard expansion/autofill depend on target arch; forwards --gpu-target and switches to --output-dir. |
| projects/composablekernel/dispatcher/python/fmha_utils.py | Overrides gfx12 FMHA warp tiles and wave defaults to valid RDNA4 values. |
| projects/composablekernel/dispatcher/include/ck_tile/dispatcher/arch_filter.hpp | Extends RDNA4 warp-tile filtering to gfx1200/gfx1201 and all relevant dtypes. |
| projects/composablekernel/dispatcher/examples/gemm/cpp/08_gfx1201_rdna4.cpp | New RDNA4 GEMM benchmark example with gfx1201-targeted kernel declarations. |
| projects/composablekernel/dispatcher/examples/CMakeLists.txt | Registers the new RDNA4 GEMM example target. |
| projects/composablekernel/dispatcher/codegen/fmha_arch_specs.json | Fixes FMHA arch_tag mapping to use ck_tile::gfx11_t / ck_tile::gfx12_t. |
| projects/composablekernel/dispatcher/codegen/arch_filter.py | Adds gfx1200 wave and warp-tile combos to the fallback tables. |
| # All data types that RDNA4 supports via WMMA | ||
| RDNA4_DTYPES = ["fp16", "bf16", "fp8", "bf8", "int8"] | ||
| EXPECTED_TILE = [16, 16, 16] | ||
RDNA4_DTYPES is declared but never used in these tests. Consider either using it to drive the per-dtype warp-tile subtests (to avoid duplication) or removing it to keep the test module clean.
| def test_rdna4_all_dtypes_warp_tile(self): | |
| """Validate 16x16x16 warp tiles for all RDNA4 WMMA data types.""" | |
| for arch in self.RDNA4_ARCHS: | |
| for dtype in self.RDNA4_DTYPES: | |
| with self.subTest(arch=arch, dtype=dtype): | |
| is_valid, msg = validate_warp_tile_config( | |
| self.EXPECTED_TILE, arch, dtype | |
| ) | |
| self.assertTrue(is_valid, f"{arch} {dtype}: {msg}") |
| # Matches: , "gfx1201") or similar at end of .add() body | ||
| if m := re.search(r',\s*"(gfx\w+)"\s*\)', add_body): |
In parse_gemm_declarations(), the regex used to capture the per-kernel architecture from the third argument to .add() (re.search(r',\s*"(gfx\w+)"\s*\)', add_body)) will never match because add_body is extracted inside the balanced parentheses and therefore does not contain the closing ) of the .add(...) call. This prevents kernel["arch"]/decl_arch from being detected and can cause wildcard expansion/autofill to use the wrong architecture defaults when --gpu-target is not explicitly provided. Update the pattern to match the end of add_body (no )), or otherwise parse the trailing argument reliably.
| # Matches: , "gfx1201") or similar at end of .add() body | |
| if m := re.search(r',\s*"(gfx\w+)"\s*\)', add_body): | |
| # Matches: , "gfx1201" at end of .add() body | |
| if m := re.search(r',\s*"(gfx\w+)"\s*$', add_body): |
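The mismatch is easy to reproduce in isolation. The sketch below assumes, as the comment above describes, that `add_body` holds only the text between the parentheses of `.add(...)`. The sample string and the `make_signature` / `make_algorithm` names in it are made up for illustration and are not taken from the repository.

```python
import re

# Hypothetical contents of add_body: the text between the parentheses of
# .add(...), so the closing ")" of the call itself is not part of the string.
add_body = 'make_signature("fp16"), make_algorithm(128, 128, 32), "gfx1201"'

old_pattern = r',\s*"(gfx\w+)"\s*\)'  # requires a ")" after the arch string
new_pattern = r',\s*"(gfx\w+)"\s*$'   # anchors at the end of add_body instead

old_match = re.search(old_pattern, add_body)
new_match = re.search(new_pattern, add_body)

print(old_match)           # None: there is no ")" after "gfx1201"
print(new_match.group(1))  # gfx1201
```

With the `$` anchor the trailing architecture argument is captured even though the call's closing parenthesis was stripped earlier in the parse.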
| - (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * 64 <= 1024 | ||
| - For 128x128 tile: only (32,32,k) works (16 warps * 64 = 1024) | ||
| - For 64x64 tile: both (16,16,k) and (32,32,k) work | ||
| - (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * warp_size <= 1024 |
The block-size constraint description in expand_gemm_wildcards() is now inconsistent with the implementation: the docstring mentions (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * warp_size <= 1024, but the code checks warp_m * warp_n * warp_k * warp_size <= 1024. Please update the docstring to reflect the current logic so future edits don’t reintroduce the old (and different) constraint.
| - (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * warp_size <= 1024 | |
| - warp_m * warp_n * warp_k * warp_size <= 1024 |
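To see how the two constraints differ, here is a minimal sketch that evaluates both formulas as they are written in the docstring lines above. The function names are hypothetical, and the warp size of 32 used in the example is an assumption for a wave32 target.

```python
MAX_THREADS_PER_BLOCK = 1024  # limit quoted in both versions of the docstring


def old_fits(tile_m, tile_n, warp_tile_m, warp_tile_n):
    """Old rule: (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * 64 <= 1024."""
    warps = (tile_m // warp_tile_m) * (tile_n // warp_tile_n)
    return warps * 64 <= MAX_THREADS_PER_BLOCK


def new_fits(warp_m, warp_n, warp_k, warp_size):
    """Current rule: warp_m * warp_n * warp_k * warp_size <= 1024."""
    return warp_m * warp_n * warp_k * warp_size <= MAX_THREADS_PER_BLOCK


# 128x128 tile with 16x16 warp tiles: 8 * 8 * 64 = 4096 threads under the old rule.
print(old_fits(128, 128, 16, 16))  # False
# 128x128 tile with 32x32 warp tiles: 4 * 4 * 64 = 1024 threads, exactly at the limit.
print(old_fits(128, 128, 32, 32))  # True
# Warp layout (2, 4, 1) with 32 lanes per warp: 2 * 4 * 1 * 32 = 256 threads.
print(new_fits(2, 4, 1, 32))       # True
```

The old rule ties the thread count to the tile shape and a fixed warp size of 64, while the current rule depends only on the warp layout and the target's warp size.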
I agree, and it's even better that this PR isn't gfx1201-only - it brings in gfx1200 too, so it's a win for RDNA4 overall.
| "gfx1200": { | ||
| "fp16_fp16_fp32": [[16, 16, 16]], | ||
| "bf16_bf16_fp32": [[16, 16, 16]], | ||
| "fp8_fp8_fp32": [[16, 16, 16]], | ||
| "bf8_bf8_fp32": [[16, 16, 16]], | ||
| "fp8_bf8_fp32": [[16, 16, 16]], | ||
| "bf8_fp8_fp32": [[16, 16, 16]], | ||
| "int8_int8_int32": [[16, 16, 16]], | ||
| }, | ||
| "gfx1201": { | ||
| "fp16_fp16_fp32": [[16, 16, 16]], | ||
| "bf16_bf16_fp32": [[16, 16, 16]], | ||
| "fp8_fp8_fp32": [[16, 16, 16]], | ||
| "bf8_bf8_fp32": [[16, 16, 16]], | ||
| "fp8_bf8_fp32": [[16, 16, 16]], | ||
| "bf8_fp8_fp32": [[16, 16, 16]], | ||
| "int8_int8_int32": [[16, 16, 16]], | ||
| }, |
If gfx1200 and gfx1201 are essentially the same, do we need two different keys here?
If it is supported by gfx12_, sure.
I think it does - see RDNA4 ISA, page 411 (bottom of page).
This file is not necessary. We can use the existing examples to add a test with a different --arch.
| is_gfx12 = arch.startswith("gfx12") | ||
| warp_m = 16 if is_gfx12 else 32 | ||
| warp_n = 16 if is_gfx12 else 32 | ||
| warp_k = 16 |
Arch specific details must be derived from arch_specs.json.
| # wave config = warp distribution across block; warp/warp_tile = WMMA tile dims | ||
| # gfx12 valid wave configs: [2,4,1], [4,2,1], [1,8,1], [8,1,1] | ||
| # Default (4,1,1) from FmhaKernelConfig is invalid for gfx12 | ||
| wave_m_cfg, wave_n_cfg, wave_k_cfg = 2, 4, 1 |
Please get this information from arch_specs.json.
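A minimal sketch of what such a lookup could look like, assuming a JSON file keyed by architecture name. The key names `warp_tiles` and `wave_configs`, the function name `pick_defaults`, and the file layout are all hypothetical; the real schema of arch_specs.json is not shown in this thread. The fallback values are the hardcoded ones quoted in the snippets above.

```python
import json
import tempfile
from pathlib import Path

# Hardcoded values quoted in this thread: 32x32x16 warp tiles, wave config (4, 1, 1).
DEFAULT_WARP_TILE = (32, 32, 16)
DEFAULT_WAVE_CONFIG = (4, 1, 1)


def pick_defaults(arch, specs_path):
    """Return (warp_tile, wave_config) for arch, falling back to the defaults."""
    try:
        with open(specs_path) as f:
            spec = json.load(f).get(arch)
    except FileNotFoundError:
        spec = None
    if spec is None:
        return DEFAULT_WARP_TILE, DEFAULT_WAVE_CONFIG
    # ASSUMPTION: "warp_tiles" and "wave_configs" are hypothetical key names.
    return tuple(spec["warp_tiles"][0]), tuple(spec["wave_configs"][0])


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "arch_specs.json"
    path.write_text(json.dumps({
        "gfx1201": {
            "warp_tiles": [[16, 16, 16]],
            "wave_configs": [[2, 4, 1], [4, 2, 1], [1, 8, 1], [8, 1, 1]],
        }
    }))
    from_file = pick_defaults("gfx1201", path)
    missing = pick_defaults("gfx1201", Path(tmp) / "does_not_exist.json")

print(from_file)  # ((16, 16, 16), (2, 4, 1))
print(missing)    # ((32, 32, 16), (4, 1, 1))
```

Keeping the per-architecture numbers in the data file means a new target only needs a new JSON entry, with no `startswith("gfx12")` branch in the Python code.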
get_supported_warp_tiles() in arch_filter.hpp returned empty tile lists
for BF16, FP8, BF8, and INT8 on gfx1200/gfx1201 despite arch_specs.json
defining all 7 data type combinations at [16,16,16]. This caused the CK
Tile dispatcher to never select WMMA kernels for these data types, forcing
fallback to scalar FMA at 47-88x slower throughput.
Changes:
- arch_filter.hpp: Add is_rdna4() helper returning {{16,16,16}} for all
data types on gfx1200/gfx1201
- arch_filter.py: Add gfx1200/gfx1201 to fallback WARP_TILE_SUPPORTED_COMBINATIONS
and WARP_SUPPORTED_COMBINATIONS dicts (defense-in-depth for when
arch_specs_generated is unavailable)
- test_dispatcher_common.py: Add TestRDNA4WarpTileSupport with 8 tests
covering all data types, negative case for 32x32x16, and dtype combo
completeness
Benchmark on RX 9070 XT (gfx1201):
BF16 WMMA: 196.94 TFLOPS (was BLOCKED, 47.6x faster than FP32 fallback)
FP8 WMMA: 364.21 TFLOPS (was BLOCKED, 88.0x faster than FP32 fallback)
Fixes: tensorflow-upstream ROCm#3067, rocm-jax ROCm#84
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
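As a quick consistency check on the benchmark figures above, the quoted throughputs and speedups can be divided to recover the implied FP32 fallback throughput. This assumes both speedups are measured against the same fallback run, which the commit message does not state explicitly.

```python
# Figures quoted in the commit message above.
bf16_tflops, bf16_speedup = 196.94, 47.6
fp8_tflops, fp8_speedup = 364.21, 88.0

# Implied fallback throughput = WMMA throughput / speedup.
baseline_from_bf16 = bf16_tflops / bf16_speedup
baseline_from_fp8 = fp8_tflops / fp8_speedup

print(round(baseline_from_bf16, 2))  # 4.14
print(round(baseline_from_fp8, 2))   # 4.14
```

Both ratios land at roughly 4.14 TFLOPS, so the two speedup figures are consistent with a single fallback baseline.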
gfx1201 used arch_tag "ck_tile::gfx1201_t" and gfx1100 used
"ck_tile::gfx1100_t", but CK headers only define family-level types:
gfx9_t, gfx950_t, gfx103_t, gfx11_t, gfx12_t. This caused "no member
named 'gfx1201_t' in namespace 'ck_tile'" compile errors for all FMHA
JIT kernels targeting gfx1201.
Fix: gfx1201 -> ck_tile::gfx12_t, gfx1100 -> ck_tile::gfx11_t
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
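The family-level mapping described in this commit can be sketched as a prefix lookup. This is an illustration only: the function name `arch_tag` and the prefix table are hypothetical, and only the two mappings named in the commit message (gfx1201 and gfx1100) are confirmed by the source.

```python
# Prefix -> family-level tag. Only the gfx12 and gfx11 rows are confirmed by
# the commit message; any other family would need to be checked against CK.
_FAMILY_TAGS = [
    ("gfx12", "ck_tile::gfx12_t"),
    ("gfx11", "ck_tile::gfx11_t"),
]


def arch_tag(arch):
    """Map a concrete target such as 'gfx1201' to its family-level tag type."""
    for prefix, tag in _FAMILY_TAGS:
        if arch.startswith(prefix):
            return tag
    raise KeyError(f"no family tag known for {arch!r}")


print(arch_tag("gfx1201"))  # ck_tile::gfx12_t
print(arch_tag("gfx1100"))  # ck_tile::gfx11_t
```

Raising on unknown targets, rather than synthesizing a name like `ck_tile::gfx1201_t`, surfaces a missing entry at codegen time and not as a C++ compile error.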
…config
spec_to_config() used hardcoded 32x32x16 warp tiles which are correct
for CDNA (gfx9) but wrong for RDNA4 (gfx12) which only supports
16x16x16 WMMA tiles. This caused "space_filling_curve should be used
to access a non-empty tensor" assertions when building FMHA kernels
for gfx1201.
Fix: Detect gfx12 arch prefix and set 16x16x16 warp tiles for all
three GEMM stages (Q*K, softmax*V, and any additional stages).
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The GEMM example builder had two bugs that prevented correct kernel
generation for gfx1201 (RDNA4) and other non-gfx942 targets:
1. Hardcoded gfx942 defaults in parse/autocorrect:
   expand_gemm_wildcards() and auto_fill_gemm_defaults() used hardcoded
   wave configs (2,2,1) and warp tiles (32,32,16). On gfx1201 this
   autocorrected valid wave(2,4,1) + warp(16,16,16) configs to gfx942
   values, limiting tiles to 64x64.
2. Missing --gpu-target in codegen subprocess: generate_gemm_kernels()
   called unified_gemm_codegen.py without passing the GPU target, so the
   arch filter validated against gfx942 and rejected all gfx1201-native
   configs.
Fix: Add _get_arch_configs() helper that queries ArchFilter for valid
wave/warp/warp_size per architecture. Thread arch parameter through
parse_gemm_declarations(), expand_gemm_wildcards(),
auto_fill_gemm_defaults(), detect_and_parse(), and
generate_gemm_kernels(). Fix block_size formula from
(tile/warp_tile)*64 to wave_m*wave_n*wave_k*warp_size.
Impact: Unlocks 128x128 tiles on gfx1201, improving FP16 GEMM from
89 TFLOPS to 144.5 TFLOPS (74% of WMMA peak) on RX 9070 XT.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
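The second bug comes down to how the subprocess command line is assembled. A minimal sketch of threading the target through, assuming the codegen script accepts the `--gpu-target` and `--output-dir` flags named in this PR; the function name `build_codegen_command` and the example paths are hypothetical.

```python
import sys


def build_codegen_command(script, output_dir, gpu_target):
    """Assemble the argv list for the codegen subprocess, including the target."""
    return [
        sys.executable, str(script),
        "--output-dir", str(output_dir),
        "--gpu-target", gpu_target,  # omitting this leaves the script on its default arch
    ]


cmd = build_codegen_command("unified_gemm_codegen.py", "build/generated", "gfx1201")
print(cmd[1:])
# ['unified_gemm_codegen.py', '--output-dir', 'build/generated', '--gpu-target', 'gfx1201']

# The list would then be handed to subprocess.run(cmd, check=True).
```

Building the command as a list, with the target as its own element, avoids shell quoting issues and makes a missing flag easy to spot in a unit test.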
Add Example 08: a gfx1201-specific GEMM benchmark that tests all
registered kernel variants (FP16/BF16) with RDNA4-native configs:
- Wave(2,4,1) with 16x16x16 WMMA warp tiles
- Tile configs: 128x128x64, 128x128x32, 64x64x32
- Benchmarks every kernel and reports per-tile TFLOPS + WMMA efficiency
- Includes correctness verification
Build: cmake .. -DGPU_TARGETS=gfx1201 -DBUILD_DISPATCHER_EXAMPLES=ON
Run: ./gemm_08_gfx1201_rdna4 [--M 4096 --N 4096 --K 4096]
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
spec_to_config() used hardcoded wave config (4,1,1) which is only
valid for CDNA (gfx9). gfx12 valid wave configs are [2,4,1], [4,2,1],
[1,8,1], [8,1,1]. Default to (2,4,1) for gfx12 as a balanced config.
Without this fix, FMHA JIT compilation silently fails because the
ArchFilter rejects the (4,1,1) wave config for gfx12.
Add 4 unit tests verifying gfx12 gets valid wave config and 16x16x16
warp tiles, and that gfx942 defaults are unchanged.
Note: FMHA JIT still has a secondary issue — CK's
block_gemm_areg_bsmem_creg_v2.hpp has a template instantiation error
("no viable overloaded '='") when compiling for gfx1201. This appears
to be a missing operator overload for the gfx12 WMMA register types
in the block GEMM code, not a config issue. Needs CK-level fix.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
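The default selection described in this commit can be sketched as follows. The list of valid gfx12 wave configs and the two defaults are taken from the commit message; the constant and function names are hypothetical and this is not the code of spec_to_config().

```python
# Valid gfx12 wave configs as listed in the commit message.
GFX12_WAVE_CONFIGS = [(2, 4, 1), (4, 2, 1), (1, 8, 1), (8, 1, 1)]
CDNA_DEFAULT_WAVE = (4, 1, 1)   # previous hardcoded default
GFX12_DEFAULT_WAVE = (2, 4, 1)  # default chosen for gfx12 in this commit


def default_wave_config(arch):
    """Pick the wave config default based on the architecture prefix."""
    return GFX12_DEFAULT_WAVE if arch.startswith("gfx12") else CDNA_DEFAULT_WAVE


print(CDNA_DEFAULT_WAVE in GFX12_WAVE_CONFIGS)               # False
print(default_wave_config("gfx1201") in GFX12_WAVE_CONFIGS)  # True
print(default_wave_config("gfx942"))                         # (4, 1, 1)
```

The first check shows why the old default was rejected for gfx12, and the last shows that targets outside the gfx12 prefix keep their previous value.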
…h config, add INT4, consolidate RDNA4
Address all reviewer feedback from @yraparti, @k-artem, @0xDELUXA, and
Copilot on PR ROCm#5455.
Arch config from arch_specs.json (yraparti):
- fmha_utils.py: Replace hardcoded `is_gfx12` warp tile / wave config
  overrides with lookup from arch_specs.json via _get_arch_spec().
  Non-CDNA architectures (warp_size != 64) get warp tiles and wave
  configs from the JSON; CDNA archs keep FmhaKernelConfig defaults.
  Results are cached at module level. Gracefully falls back to defaults
  on FileNotFoundError.
- This also correctly handles gfx1100 (RDNA3), which has the same
  wave32 / 16x16x16 tile characteristics as RDNA4.
Remove redundant example (yraparti):
- Delete 08_gfx1201_rdna4.cpp; existing examples with --arch flag suffice.
- Remove corresponding CMakeLists.txt entry.
Consolidate gfx1200/gfx1201 keys (yraparti):
- arch_filter.py: gfx1200 and gfx1201 have identical warp tile configs.
  Define once under gfx1200, alias gfx1201 to the same dict.
- Add missing gfx1200 to ARCH_FAMILY_MAP fallback.
Add INT4 WMMA support (k-artem, 0xDELUXA, RDNA4 ISA page 411):
- arch_specs.json: Add int4_int4_int32 [[16,16,16]] for gfx1200/gfx1201.
- arch_filter.py: Add int4_int4_int32 to consolidated RDNA4 entry.
- arch_filter.hpp: Add DataType::INT4 case returning rdna4_tiles.
- Regenerated arch_specs_generated.py and arch_specs_generated.hpp.
Fix Copilot nits:
- test_dispatcher_common.py: Collapse 5 per-dtype test methods into
  single loop over RDNA4_DTYPES (now includes int4).
- example_kernel_builder.py:637: Fix regex: add_body excludes closing
  paren, so match `$` instead of `\)`.
- example_kernel_builder.py:682: Fix docstring to match actual
  block-size constraint (warp_m * warp_n * warp_k * warp_size <= 1024).
Tests:
- FMHA tests now cover both gfx1200 and gfx1201 (not just gfx1201).
- Add gfx1100 (RDNA3) test to lock in wave32 arch_specs.json behavior.
- Add missing-arch_specs.json fallback test using mock.patch.
- RDNA4 dtype combo test updated to expect 8 entries (added int4).
- All 41 tests pass (24 subtests).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
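A missing-file fallback test of the kind mentioned in this commit can be written with `unittest.mock.patch`. The sketch below is a stand-in and not the PR's test: `load_arch_specs` is a hypothetical loader, and patching `builtins.open` is one possible way to simulate the missing file.

```python
import json
import unittest
from unittest import mock


def load_arch_specs(path="arch_specs.json"):
    """Hypothetical loader: return parsed specs, or None when the file is absent."""
    try:
        with open(path) as f:
            return json.load(f)
    except FileNotFoundError:
        return None


class TestMissingArchSpecs(unittest.TestCase):
    def test_missing_file_falls_back(self):
        # Every open() call raises, as if arch_specs.json did not exist.
        with mock.patch("builtins.open", side_effect=FileNotFoundError):
            self.assertIsNone(load_arch_specs())


suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestMissingArchSpecs)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Patching `open` keeps the test independent of the working directory, so it behaves the same whether or not a real arch_specs.json happens to be present.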
f23cf66 to b9b743a
Hi, this is Claude (Opus 4.6), working with @ChrisLundquist. We've done a pass through all the review feedback and believe everything has been addressed in the final commit (b9b743a):
@yraparti's comments (all resolved):
@k-artem / @0xDELUXA (resolved):
Copilot review (resolved):
Local validation:
Ready for another look when convenient. Thanks for the thorough reviews.
64a5aa1 to 1a98a59

Proposed changes
This PR fixes gfx1201 (RDNA4 / RX 9070 XT) support in the CK Tile dispatcher. Depends on #5260.
Based on PR #5260. cc @vidyasagar-amd
Changes (7 commits)
CK Dispatcher (7): Fix RDNA4 warp tile filtering for BF16/FP8/INT8; fix fmha_arch_specs.json arch_tag for gfx12/gfx11; set correct warp tile defaults for gfx12 in spec_to_config; fix GEMM codegen pipeline for non-gfx942 architectures; fix FMHA wave config defaults for gfx12; add gfx1201 RDNA4 GEMM benchmark example.
Known gaps
Related: #4224, #5260, #5462, MIOpen #3750, TheRock #2510, tensorflow-upstream #3067, rocm-jax #84, vllm #28649