Skip to content

Add gfx1201 (RDNA4) support across CK Dispatcher, MIOpen, and hipBLASLt#5455

Open
ChrisLundquist wants to merge 41 commits intoROCm:users/vanantha/ck/dispatcher-fmhafrom
ChrisLundquist:gfx1201-work
Open

Add gfx1201 (RDNA4) support across CK Dispatcher, MIOpen, and hipBLASLt#5455
ChrisLundquist wants to merge 41 commits intoROCm:users/vanantha/ck/dispatcher-fmhafrom
ChrisLundquist:gfx1201-work

Conversation

@ChrisLundquist
Copy link
Copy Markdown

@ChrisLundquist ChrisLundquist commented Mar 15, 2026

Proposed changes

This PR fixes gfx1201 (RDNA4 / RX 9070 XT) support in the CK Tile dispatcher. Depends on #5260.

Note: The MIOpen and hipBLASLt enablement commits have been split out into #5462, which targets develop directly with no dispatcher dependency.

Based on PR #5260. cc @vidyasagar-amd

Changes (7 commits)

CK Dispatcher (7): Fix RDNA4 warp tile filtering for BF16/FP8/INT8; fix fmha_arch_specs.json arch_tag for gfx12/gfx11; set correct warp tile defaults for gfx12 in spec_to_config; fix GEMM codegen pipeline for non-gfx942 architectures; fix FMHA wave config defaults for gfx12; add gfx1201 RDNA4 GEMM benchmark example.

Known gaps

Related: #4224, #5260, #5462, MIOpen #3750, TheRock #2510, tensorflow-upstream #3067, rocm-jax #84, vllm #28649

Copy link
Copy Markdown
Contributor

@vidyasagar-amd vidyasagar-amd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fixes and example, good to see gfx1201 support being added and tested

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR improves gfx1201 / RDNA4 support in the CK Tile dispatcher by aligning warp-tile filtering, FMHA architecture tags/config defaults, and example/codegen tooling so kernels can be generated/validated correctly for gfx12 targets.

Changes:

  • Add/validate RDNA4 warp-tile support (16x16x16) across BF16/FP8/BF8/INT8 and add unit tests to prevent regressions.
  • Make example kernel build/codegen path architecture-aware (wave/warp tile defaults + --gpu-target passed through).
  • Update FMHA defaults/spec metadata for gfx11/gfx12 (arch_tag + gfx12 wave/warp overrides) and add a gfx1201 GEMM benchmark example.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
projects/composablekernel/dispatcher/tests/test_dispatcher_common.py Adds RDNA4 warp-tile tests and gfx12 FMHA spec_to_config regression tests.
projects/composablekernel/dispatcher/scripts/example_kernel_builder.py Makes GEMM wildcard expansion/autofill depend on target arch; forwards --gpu-target and switches to --output-dir.
projects/composablekernel/dispatcher/python/fmha_utils.py Overrides gfx12 FMHA warp tiles and wave defaults to valid RDNA4 values.
projects/composablekernel/dispatcher/include/ck_tile/dispatcher/arch_filter.hpp Extends RDNA4 warp-tile filtering to gfx1200/gfx1201 and all relevant dtypes.
projects/composablekernel/dispatcher/examples/gemm/cpp/08_gfx1201_rdna4.cpp New RDNA4 GEMM benchmark example with gfx1201-targeted kernel declarations.
projects/composablekernel/dispatcher/examples/CMakeLists.txt Registers the new RDNA4 GEMM example target.
projects/composablekernel/dispatcher/codegen/fmha_arch_specs.json Fixes FMHA arch_tag mapping to use ck_tile::gfx11_t / ck_tile::gfx12_t.
projects/composablekernel/dispatcher/codegen/arch_filter.py Adds gfx1200 wave and warp-tile combos to the fallback tables.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

# All data types that RDNA4 supports via WMMA
RDNA4_DTYPES = ["fp16", "bf16", "fp8", "bf8", "int8"]
EXPECTED_TILE = [16, 16, 16]

Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RDNA4_DTYPES is declared but never used in these tests. Consider either using it to drive the per-dtype warp-tile subtests (to avoid duplication) or removing it to keep the test module clean.

Suggested change
def test_rdna4_all_dtypes_warp_tile(self):
"""Validate 16x16x16 warp tiles for all RDNA4 WMMA data types."""
for arch in self.RDNA4_ARCHS:
for dtype in self.RDNA4_DTYPES:
with self.subTest(arch=arch, dtype=dtype):
is_valid, msg = validate_warp_tile_config(
self.EXPECTED_TILE, arch, dtype
)
self.assertTrue(is_valid, f"{arch} {dtype}: {msg}")

Copilot uses AI. Check for mistakes.
Comment on lines +636 to +637
# Matches: , "gfx1201") or similar at end of .add() body
if m := re.search(r',\s*"(gfx\w+)"\s*\)', add_body):
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In parse_gemm_declarations(), the regex used to capture the per-kernel architecture from the third argument to .add() (re.search(r',\s*"(gfx\w+)"\s*\)', add_body)) will never match because add_body is extracted inside the balanced parentheses and therefore does not contain the closing ) of the .add(...) call. This prevents kernel["arch"]/decl_arch from being detected and can cause wildcard expansion/autofill to use the wrong architecture defaults when --gpu-target is not explicitly provided. Update the pattern to match the end of add_body (no )), or otherwise parse the trailing argument reliably.

Suggested change
# Matches: , "gfx1201") or similar at end of .add() body
if m := re.search(r',\s*"(gfx\w+)"\s*\)', add_body):
# Matches: , "gfx1201" at end of .add() body
if m := re.search(r',\s*"(gfx\w+)"\s*$', add_body):

Copilot uses AI. Check for mistakes.
- (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * 64 <= 1024
- For 128x128 tile: only (32,32,k) works (16 warps * 64 = 1024)
- For 64x64 tile: both (16,16,k) and (32,32,k) work
- (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * warp_size <= 1024
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The block-size constraint description in expand_gemm_wildcards() is now inconsistent with the implementation: the docstring mentions (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * warp_size <= 1024, but the code checks warp_m * warp_n * warp_k * warp_size <= 1024. Please update the docstring to reflect the current logic so future edits don’t reintroduce the old (and different) constraint.

Suggested change
- (tile_m/warp_tile_m) * (tile_n/warp_tile_n) * warp_size <= 1024
- warp_m * warp_n * warp_k * warp_size <= 1024

Copilot uses AI. Check for mistakes.
@0xDELUXA
Copy link
Copy Markdown

Thanks for the fixes and example, good to see gfx1201 support being added and tested

I agree, and it's even better that this PR isn't gfx1201-only - it brings in gfx1200 too, so it's a win for RDNA4 overall.

Comment thread projects/composablekernel/dispatcher/codegen/arch_filter.py
Comment on lines +206 to +223
"gfx1200": {
"fp16_fp16_fp32": [[16, 16, 16]],
"bf16_bf16_fp32": [[16, 16, 16]],
"fp8_fp8_fp32": [[16, 16, 16]],
"bf8_bf8_fp32": [[16, 16, 16]],
"fp8_bf8_fp32": [[16, 16, 16]],
"bf8_fp8_fp32": [[16, 16, 16]],
"int8_int8_int32": [[16, 16, 16]],
},
"gfx1201": {
"fp16_fp16_fp32": [[16, 16, 16]],
"bf16_bf16_fp32": [[16, 16, 16]],
"fp8_fp8_fp32": [[16, 16, 16]],
"bf8_bf8_fp32": [[16, 16, 16]],
"fp8_bf8_fp32": [[16, 16, 16]],
"bf8_fp8_fp32": [[16, 16, 16]],
"int8_int8_int32": [[16, 16, 16]],
},
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if gfx1200 and 1201 are essentially the same, do we need 2 different keys here?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yraparti , is it make sense to add int4?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is supported by gfx12_, sure.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it does - see RDNA4 ISA, page 411 (bottom of page).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

@fjankovi fjankovi requested a review from hyoon1 March 30, 2026 16:35
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this file is not necessary. We can use the existing examples to add a test with a different --arch.

Comment on lines +1420 to +1423
is_gfx12 = arch.startswith("gfx12")
warp_m = 16 if is_gfx12 else 32
warp_n = 16 if is_gfx12 else 32
warp_k = 16
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arch specific details must be derived from arch_specs.json.

# wave config = warp distribution across block; warp/warp_tile = WMMA tile dims
# gfx12 valid wave configs: [2,4,1], [4,2,1], [1,8,1], [8,1,1]
# Default (4,1,1) from FmhaKernelConfig is invalid for gfx12
wave_m_cfg, wave_n_cfg, wave_k_cfg = 2, 4, 1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please get this information from arch_specs.json

ChrisLundquist and others added 7 commits April 3, 2026 23:08
get_supported_warp_tiles() in arch_filter.hpp returned empty tile lists
for BF16, FP8, BF8, and INT8 on gfx1200/gfx1201 despite arch_specs.json
defining all 7 data type combinations at [16,16,16]. This caused the CK
Tile dispatcher to never select WMMA kernels for these data types, forcing
fallback to scalar FMA at 47-88x slower throughput.

Changes:
- arch_filter.hpp: Add is_rdna4() helper returning {{16,16,16}} for all
  data types on gfx1200/gfx1201
- arch_filter.py: Add gfx1200/gfx1201 to fallback WARP_TILE_SUPPORTED_COMBINATIONS
  and WARP_SUPPORTED_COMBINATIONS dicts (defense-in-depth for when
  arch_specs_generated is unavailable)
- test_dispatcher_common.py: Add TestRDNA4WarpTileSupport with 8 tests
  covering all data types, negative case for 32x32x16, and dtype combo
  completeness

Benchmark on RX 9070 XT (gfx1201):
  BF16 WMMA: 196.94 TFLOPS (was BLOCKED, 47.6x faster than FP32 fallback)
  FP8  WMMA: 364.21 TFLOPS (was BLOCKED, 88.0x faster than FP32 fallback)

Fixes: tensorflow-upstream ROCm#3067, rocm-jax ROCm#84

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
gfx1201 used arch_tag "ck_tile::gfx1201_t" and gfx1100 used
"ck_tile::gfx1100_t", but CK headers only define family-level types:
gfx9_t, gfx950_t, gfx103_t, gfx11_t, gfx12_t. This caused "no member
named 'gfx1201_t' in namespace 'ck_tile'" compile errors for all FMHA
JIT kernels targeting gfx1201.

Fix: gfx1201 -> ck_tile::gfx12_t, gfx1100 -> ck_tile::gfx11_t

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…config

spec_to_config() used hardcoded 32x32x16 warp tiles which are correct
for CDNA (gfx9) but wrong for RDNA4 (gfx12) which only supports
16x16x16 WMMA tiles. This caused "space_filling_curve should be used to
access a non-empty tensor" assertions when building FMHA kernels for
gfx1201.

Fix: Detect gfx12 arch prefix and set 16x16x16 warp tiles for all three
GEMM stages (Q*K, softmax*V, and any additional stages).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The GEMM example builder had two bugs that prevented correct kernel
generation for gfx1201 (RDNA4) and other non-gfx942 targets:

1. Hardcoded gfx942 defaults in parse/autocorrect: expand_gemm_wildcards()
   and auto_fill_gemm_defaults() used hardcoded wave configs (2,2,1) and
   warp tiles (32,32,16). On gfx1201 this autocorrected valid wave(2,4,1)
   + warp(16,16,16) configs to gfx942 values, limiting tiles to 64x64.

2. Missing --gpu-target in codegen subprocess: generate_gemm_kernels()
   called unified_gemm_codegen.py without passing the GPU target, so the
   arch filter validated against gfx942 and rejected all gfx1201-native
   configs.

Fix: Add _get_arch_configs() helper that queries ArchFilter for valid
wave/warp/warp_size per architecture. Thread arch parameter through
parse_gemm_declarations(), expand_gemm_wildcards(), auto_fill_gemm_defaults(),
detect_and_parse(), and generate_gemm_kernels(). Fix block_size formula
from (tile/warp_tile)*64 to wave_m*wave_n*wave_k*warp_size.

Impact: Unlocks 128x128 tiles on gfx1201, improving FP16 GEMM from
89 TFLOPS to 144.5 TFLOPS (74% of WMMA peak) on RX 9070 XT.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add Example 08: a gfx1201-specific GEMM benchmark that tests all
registered kernel variants (FP16/BF16) with RDNA4-native configs:
- Wave(2,4,1) with 16x16x16 WMMA warp tiles
- Tile configs: 128x128x64, 128x128x32, 64x64x32
- Benchmarks every kernel and reports per-tile TFLOPS + WMMA efficiency
- Includes correctness verification

Build: cmake .. -DGPU_TARGETS=gfx1201 -DBUILD_DISPATCHER_EXAMPLES=ON
Run:   ./gemm_08_gfx1201_rdna4 [--M 4096 --N 4096 --K 4096]

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
spec_to_config() used hardcoded wave config (4,1,1) which is only
valid for CDNA (gfx9). gfx12 valid wave configs are [2,4,1], [4,2,1],
[1,8,1], [8,1,1]. Default to (2,4,1) for gfx12 as a balanced config.

Without this fix, FMHA JIT compilation silently fails because the
ArchFilter rejects the (4,1,1) wave config for gfx12.

Add 4 unit tests verifying gfx12 gets valid wave config and 16x16x16
warp tiles, and that gfx942 defaults are unchanged.

Note: FMHA JIT still has a secondary issue — CK's
block_gemm_areg_bsmem_creg_v2.hpp has a template instantiation error
("no viable overloaded '='") when compiling for gfx1201. This appears
to be a missing operator overload for the gfx12 WMMA register types
in the block GEMM code, not a config issue. Needs CK-level fix.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…h config, add INT4, consolidate RDNA4

Address all reviewer feedback from @yraparti, @k-artem, @0xDELUXA, and Copilot
on PR ROCm#5455.

Arch config from arch_specs.json (yraparti):
- fmha_utils.py: Replace hardcoded `is_gfx12` warp tile / wave config
  overrides with lookup from arch_specs.json via _get_arch_spec(). Non-CDNA
  architectures (warp_size != 64) get warp tiles and wave configs from the
  JSON; CDNA archs keep FmhaKernelConfig defaults. Results are cached at
  module level. Gracefully falls back to defaults on FileNotFoundError.
- This also correctly handles gfx1100 (RDNA3), which has the same wave32 /
  16x16x16 tile characteristics as RDNA4.

Remove redundant example (yraparti):
- Delete 08_gfx1201_rdna4.cpp — existing examples with --arch flag suffice.
- Remove corresponding CMakeLists.txt entry.

Consolidate gfx1200/gfx1201 keys (yraparti):
- arch_filter.py: gfx1200 and gfx1201 have identical warp tile configs.
  Define once under gfx1200, alias gfx1201 to the same dict.
- Add missing gfx1200 to ARCH_FAMILY_MAP fallback.

Add INT4 WMMA support (k-artem, 0xDELUXA — RDNA4 ISA page 411):
- arch_specs.json: Add int4_int4_int32 [[16,16,16]] for gfx1200/gfx1201.
- arch_filter.py: Add int4_int4_int32 to consolidated RDNA4 entry.
- arch_filter.hpp: Add DataType::INT4 case returning rdna4_tiles.
- Regenerated arch_specs_generated.py and arch_specs_generated.hpp.

Fix Copilot nits:
- test_dispatcher_common.py: Collapse 5 per-dtype test methods into single
  loop over RDNA4_DTYPES (now includes int4).
- example_kernel_builder.py:637: Fix regex — add_body excludes closing paren,
  so match `$` instead of `\)`.
- example_kernel_builder.py:682: Fix docstring to match actual block-size
  constraint (warp_m * warp_n * warp_k * warp_size <= 1024).

Tests:
- FMHA tests now cover both gfx1200 and gfx1201 (not just gfx1201).
- Add gfx1100 (RDNA3) test to lock in wave32 arch_specs.json behavior.
- Add missing-arch_specs.json fallback test using mock.patch.
- RDNA4 dtype combo test updated to expect 8 entries (added int4).
- All 41 tests pass (24 subtests).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ChrisLundquist
Copy link
Copy Markdown
Author

Hi — this is Claude (Opus 4.6), working with @ChrisLundquist. We've done a pass through all the review feedback and believe everything has been addressed in the final commit (b9b743a):

@yraparti's comments (all resolved):

  • ✅ Added gfx1200 to the family map
  • ✅ Consolidated gfx1200/gfx1201 via dict alias (single definition, no duplication)
  • ✅ FMHA warp tile and wave config overrides are now data-driven from arch_specs.json via _get_arch_spec(), not hardcoded
  • ✅ Removed 08_gfx1201_rdna4.cpp example file — existing examples with --arch flag suffice

@k-artem / @0xDELUXA (resolved):

  • ✅ Added int4_int4_int32 support for gfx1200/gfx1201 per RDNA4 ISA (page 411)

Copilot review (resolved):

  • ✅ Docstring in expand_gemm_wildcards() updated to match new warp_m * warp_n * warp_k * warp_size <= 1024 logic
  • RDNA4_DTYPES is used (referenced in test_rdna4_all_dtypes_warp_tile via self.RDNA4_DTYPES)

Local validation:

  • 41/41 unit tests pass (including the 6 FMHA spec_to_config tests that verify gfx12 gets 16x16x16 tiles and valid wave configs)
  • JSON specs consistency verified — all 8 dtype combos present for both gfx1200/gfx1201, fmha_arch_specs.json arch_tags correct

Ready for another look when convenient. Thanks for the thorough reviews.

@vidyasagar-amd vidyasagar-amd force-pushed the users/vanantha/ck/dispatcher-fmha branch from 64a5aa1 to 1a98a59 Compare April 9, 2026 19:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants