Skip to content

[CK_TILE] Update gfx11 FMHA forward kernel configs#5088

Merged
illsilin merged 6 commits intodevelopfrom
users/hyoon1/ck/rdna-fmha-tuning
Mar 10, 2026
Merged

[CK_TILE] Update gfx11 FMHA forward kernel configs#5088
illsilin merged 6 commits intodevelopfrom
users/hyoon1/ck/rdna-fmha-tuning

Conversation

@hyoon1
Copy link
Copy Markdown
Contributor

@hyoon1 hyoon1 commented Mar 4, 2026

Motivation

Tune gfx11 FMHA codegen to recover performance for mainly PSSK (padded seqlen_q/k) cases.
This tuning is based on heuristic search and improves performance in most tested shapes.
Performance should be evaluated on top of ROCm/rocm-libraries#5018 (required baseline).

Technical Details

  • Updated gfx11 codegen heuristic choices for tile size and occupancy.
  • Updated gfx11 pipeline selection:
    - Disabled the npad (f,f,f,f) qr entry because it was consistently slower than the pssk (t,t,f,f) path, and kept pssk enabled so npad cases are dispatched to the faster kernel path.`
  • Kept gfx12 unchanged: with PSSK support from ROCm/rocm-libraries#4957, existing gfx12 config is already sufficient.
  • Tuning rationale:
    • In some cases, higher kBlockPerCu lowers register pressure.
    • On RDNA, this generally aligns with better performance when waves_per_eu >= 6.

Test Plan

  • test_ck_tile_fmha
  • tile_example_fmha_fwd: tested this on gfx1100 and gfx1151
    ./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=24 -d=128 -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}

Test Result

  • TFLOPs by sequence length target: gfx1100 layout: bhsd
  • mode: batch
SeqLen Gain
1024 0.97x
4096 1.17x
8192 1.11x
12288 1.17x
16384 1.34x
20480 1.40x
24576 1.41x
27280 1.63x
  • mode: group
SeqLen Gain
1024 0.99x
4096 1.19x
8192 1.17x
12288 1.20x
16384 1.42x
20480 1.56x
24576 1.64x
27280 1.67x
  • TFLOPs by sequence length target: gfx1151 layout: bshd
  • mode: batch
Batch Gain
1024 1.09x
4096 1.05x
8192 1.01x
12288 0.99x
16384 1.05x
20480 0.99x
24576 1.04x
27280 2.44x
  • mode: group
Batch Gain
1024 0.97x
4096 1.05x
8192 1.07x
12288 1.00x
16384 1.24x
20480 1.71x
24576 2.33x
27280 2.61x

Submission Checklist

@hyoon1 hyoon1 requested a review from a team as a code owner March 4, 2026 08:24
@hyoon1 hyoon1 requested a review from ex-rzr March 4, 2026 08:24
@hyoon1 hyoon1 marked this pull request as draft March 4, 2026 11:21
Copy link
Copy Markdown
Contributor

@ex-rzr ex-rzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about other hdims? Is "f", "f", "f", "f" slower than "t", "t", "f", "f" for them too? (I know that padding for seqlen dimensions does not affect performance as much as padding for hdim, I'm just curious because it's really interesting observation)

Otherwise LGTM (btw, my approval does not make it possible to merge the PR because I'm not a code owner)

@hyoon1 hyoon1 marked this pull request as ready for review March 5, 2026 14:53
@hyoon1
Copy link
Copy Markdown
Contributor Author

hyoon1 commented Mar 9, 2026

What about other hdims? Is "f", "f", "f", "f" slower than "t", "t", "f", "f" for them too? (I know that padding for seqlen dimensions does not affect performance as much as padding for hdim, I'm just curious because it's really interesting observation)

Otherwise LGTM (btw, my approval does not make it possible to merge the PR because I'm not a code owner)

It varied a bit depending on the GPU target and head dimension. In some cases ffff was slightly faster than ttff, but most of the time the performance was very similar and the difference wasn’t really noticeable.

Since the workloads I looked at mostly use hdim 64 or 128 and ttff already performs well there, it doesn’t seem worth adding extra kernel variants for now. Enabling only ttff should be sufficient.

@illsilin illsilin merged commit 36ca523 into develop Mar 10, 2026
30 checks passed
@illsilin illsilin deleted the users/hyoon1/ck/rdna-fmha-tuning branch March 10, 2026 16:46
assistant-librarian bot pushed a commit to ROCm/composable_kernel that referenced this pull request Mar 10, 2026
[CK_TILE] Update gfx11 FMHA forward kernel configs

## Motivation
Tune gfx11 FMHA codegen to recover performance for mainly PSSK (padded
seqlen_q/k) cases.
This tuning is based on heuristic search and improves performance in
most tested shapes.
Performance should be evaluated on top of
[`ROCm/rocm-libraries#5018`](ROCm/rocm-libraries#5018)
(required baseline).

## Technical Details

  - Updated gfx11 codegen heuristic choices for tile size and occupancy.
   - Updated gfx11 pipeline selection:
- Disabled the `npad` (`f,f,f,f`) qr entry because it was consistently
slower than the `pssk` (`t,t,f,f`) path, and kept `pssk` enabled so npad
cases are dispatched to the faster kernel path.`
- Kept gfx12 unchanged: with PSSK support from
[`ROCm/rocm-libraries#4957`](ROCm/rocm-libraries#4957),
existing gfx12 config is already sufficient.
  - Tuning rationale:
    - In some cases, higher `kBlockPerCu` lowers register pressure.
- On RDNA, this generally aligns with better performance when
`waves_per_eu >= 6`.

## Test Plan
- test_ck_tile_fmha
- tile_example_fmha_fwd: tested this on gfx1100 and gfx1151
./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=24
-d=128 -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}

## Test Result
- TFLOPs by sequence length target: `gfx1100` layout: `bhsd`
- mode: batch / VGPR usage: 225 vs 214

SeqLen | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 74.10 | 71.97 | 0.97x
4096 | 66.26 | 77.79 | 1.17x
8192 | 68.18 | 75.88 | 1.11x
12288 | 68.47 | 80.44 | 1.17x
16384 | 59.54 | 79.66 | 1.34x
20480 | 55.78 | 77.91 | 1.40x
24576 | 55.08 | 77.47 | 1.41x
27280 | 47.45 | 77.16 | 1.63x
- mode: group / VGPR usage: 256 vs 214

SeqLen | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 71.47 | 70.6 | 0.99x
4096 | 64.74 | 77.06 | 1.19x
8192 | 64.68 | 75.47 | 1.17x
12288 | 66.43 | 79.95 | 1.20x
16384 | 56.02 | 79.73 | 1.42x
20480 | 50.21 | 78.15 | 1.56x
24576 | 47.29 | 77.53 | 1.64x
27280 | 46.13 | 77.04 | 1.67x

- TFLOPs by sequence length target: `gfx1151` layout: `bshd`
- mode: batch / VGPR usage: 225 vs 223

Batch | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 26.85 | 29.17 | 1.09x
4096 | 24.75 | 26.01 | 1.05x
8192 | 25.24 | 25.50 | 1.01x
12288 | 25.18 | 25.00 | 0.99x
16384 | 24.79 | 25.91 | 1.05x
20480 | 25.56 | 25.24 | 0.99x
24576 | 25.13 | 26.20 | 1.04x
27280 | 10.78 | 26.35 | 2.44x
- mode: group / VGPR usage: 256 vs 229

Batch | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 27.44 | 26.71 | 0.97x
4096 | 21.89 | 23.09 | 1.05x
8192 | 22.85 | 24.49 | 1.07x
12288 | 24.33 | 24.42 | 1.00x
16384 | 20.05 | 24.98 | 1.24x
20480 | 14.70 | 25.15 | 1.71x
24576 | 11.30 | 26.31 | 2.33x
27280 | 10.10 | 26.32 | 2.61x

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
kokolchin pushed a commit to kokolchin/rocm-libraries that referenced this pull request Mar 11, 2026
## Motivation
Tune gfx11 FMHA codegen to recover performance for mainly PSSK (padded
seqlen_q/k) cases.
This tuning is based on heuristic search and improves performance in
most tested shapes.
Performance should be evaluated on top of
[`ROCm#5018`](ROCm#5018)
(required baseline).

## Technical Details

  - Updated gfx11 codegen heuristic choices for tile size and occupancy.
   - Updated gfx11 pipeline selection:
- Disabled the `npad` (`f,f,f,f`) qr entry because it was consistently
slower than the `pssk` (`t,t,f,f`) path, and kept `pssk` enabled so npad
cases are dispatched to the faster kernel path.`
- Kept gfx12 unchanged: with PSSK support from
[`ROCm#4957`](ROCm#4957),
existing gfx12 config is already sufficient.
  - Tuning rationale:
    - In some cases, higher `kBlockPerCu` lowers register pressure.
- On RDNA, this generally aligns with better performance when
`waves_per_eu >= 6`.

## Test Plan
- test_ck_tile_fmha
- tile_example_fmha_fwd: tested this on gfx1100 and gfx1151
./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=24
-d=128 -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}

## Test Result
- TFLOPs by sequence length target: `gfx1100` layout: `bhsd`
- mode: batch / VGPR usage: 225 vs 214

SeqLen | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 74.10 | 71.97 | 0.97x
4096 | 66.26 | 77.79 | 1.17x
8192 | 68.18 | 75.88 | 1.11x
12288 | 68.47 | 80.44 | 1.17x
16384 | 59.54 | 79.66 | 1.34x
20480 | 55.78 | 77.91 | 1.40x
24576 | 55.08 | 77.47 | 1.41x
27280 | 47.45 | 77.16 | 1.63x
- mode: group / VGPR usage: 256 vs 214

SeqLen | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 71.47 | 70.6 | 0.99x
4096 | 64.74 | 77.06 | 1.19x
8192 | 64.68 | 75.47 | 1.17x
12288 | 66.43 | 79.95 | 1.20x
16384 | 56.02 | 79.73 | 1.42x
20480 | 50.21 | 78.15 | 1.56x
24576 | 47.29 | 77.53 | 1.64x
27280 | 46.13 | 77.04 | 1.67x

- TFLOPs by sequence length target: `gfx1151` layout: `bshd`
- mode: batch / VGPR usage: 225 vs 223

Batch | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 26.85 | 29.17 | 1.09x
4096 | 24.75 | 26.01 | 1.05x
8192 | 25.24 | 25.50 | 1.01x
12288 | 25.18 | 25.00 | 0.99x
16384 | 24.79 | 25.91 | 1.05x
20480 | 25.56 | 25.24 | 0.99x
24576 | 25.13 | 26.20 | 1.04x
27280 | 10.78 | 26.35 | 2.44x
- mode: group / VGPR usage: 256 vs 229

Batch | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 27.44 | 26.71 | 0.97x
4096 | 21.89 | 23.09 | 1.05x
8192 | 22.85 | 24.49 | 1.07x
12288 | 24.33 | 24.42 | 1.00x
16384 | 20.05 | 24.98 | 1.24x
20480 | 14.70 | 25.15 | 1.71x
24576 | 11.30 | 26.31 | 2.33x
27280 | 10.10 | 26.32 | 2.61x


## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
jovanau pushed a commit to jovanau/rocm-libraries that referenced this pull request Mar 19, 2026
## Motivation
Tune gfx11 FMHA codegen to recover performance for mainly PSSK (padded
seqlen_q/k) cases.
This tuning is based on heuristic search and improves performance in
most tested shapes.
Performance should be evaluated on top of
[`ROCm#5018`](ROCm#5018)
(required baseline).

## Technical Details

  - Updated gfx11 codegen heuristic choices for tile size and occupancy.
   - Updated gfx11 pipeline selection:
- Disabled the `npad` (`f,f,f,f`) qr entry because it was consistently
slower than the `pssk` (`t,t,f,f`) path, and kept `pssk` enabled so npad
cases are dispatched to the faster kernel path.`
- Kept gfx12 unchanged: with PSSK support from
[`ROCm#4957`](ROCm#4957),
existing gfx12 config is already sufficient.
  - Tuning rationale:
    - In some cases, higher `kBlockPerCu` lowers register pressure.
- On RDNA, this generally aligns with better performance when
`waves_per_eu >= 6`.

## Test Plan
- test_ck_tile_fmha
- tile_example_fmha_fwd: tested this on gfx1100 and gfx1151
./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=24
-d=128 -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}

## Test Result
- TFLOPs by sequence length target: `gfx1100` layout: `bhsd`
- mode: batch / VGPR usage: 225 vs 214

SeqLen | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 74.10 | 71.97 | 0.97x
4096 | 66.26 | 77.79 | 1.17x
8192 | 68.18 | 75.88 | 1.11x
12288 | 68.47 | 80.44 | 1.17x
16384 | 59.54 | 79.66 | 1.34x
20480 | 55.78 | 77.91 | 1.40x
24576 | 55.08 | 77.47 | 1.41x
27280 | 47.45 | 77.16 | 1.63x
- mode: group / VGPR usage: 256 vs 214

SeqLen | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 71.47 | 70.6 | 0.99x
4096 | 64.74 | 77.06 | 1.19x
8192 | 64.68 | 75.47 | 1.17x
12288 | 66.43 | 79.95 | 1.20x
16384 | 56.02 | 79.73 | 1.42x
20480 | 50.21 | 78.15 | 1.56x
24576 | 47.29 | 77.53 | 1.64x
27280 | 46.13 | 77.04 | 1.67x

- TFLOPs by sequence length target: `gfx1151` layout: `bshd`
- mode: batch / VGPR usage: 225 vs 223

Batch | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 26.85 | 29.17 | 1.09x
4096 | 24.75 | 26.01 | 1.05x
8192 | 25.24 | 25.50 | 1.01x
12288 | 25.18 | 25.00 | 0.99x
16384 | 24.79 | 25.91 | 1.05x
20480 | 25.56 | 25.24 | 0.99x
24576 | 25.13 | 26.20 | 1.04x
27280 | 10.78 | 26.35 | 2.44x
- mode: group / VGPR usage: 256 vs 229

Batch | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 27.44 | 26.71 | 0.97x
4096 | 21.89 | 23.09 | 1.05x
8192 | 22.85 | 24.49 | 1.07x
12288 | 24.33 | 24.42 | 1.00x
16384 | 20.05 | 24.98 | 1.24x
20480 | 14.70 | 25.15 | 1.71x
24576 | 11.30 | 26.31 | 2.33x
27280 | 10.10 | 26.32 | 2.61x


## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
OwOwTsai pushed a commit to ROCm/flash-attention that referenced this pull request Mar 19, 2026
First Branch and use develop tag 574c1c121a0f3c0b44155b2b1987d89d16159b58
Next need
     CK have
     ROCm/rocm-libraries#5018
     ROCm/rocm-libraries#5088

and modified flash_attn_ck to use grouping for RDNA default pytorch
invocation
also add bshd or bhsd selection in application layer to choose a faster
kernel
app layer may need to do qkv and output permute(0, 2, 1, 3)
johannes-graner pushed a commit that referenced this pull request Mar 20, 2026
## Motivation
Tune gfx11 FMHA codegen to recover performance for mainly PSSK (padded
seqlen_q/k) cases.
This tuning is based on heuristic search and improves performance in
most tested shapes.
Performance should be evaluated on top of
[`#5018`](#5018)
(required baseline).

## Technical Details

  - Updated gfx11 codegen heuristic choices for tile size and occupancy.
   - Updated gfx11 pipeline selection:
- Disabled the `npad` (`f,f,f,f`) qr entry because it was consistently
slower than the `pssk` (`t,t,f,f`) path, and kept `pssk` enabled so npad
cases are dispatched to the faster kernel path.`
- Kept gfx12 unchanged: with PSSK support from
[`#4957`](#4957),
existing gfx12 config is already sufficient.
  - Tuning rationale:
    - In some cases, higher `kBlockPerCu` lowers register pressure.
- On RDNA, this generally aligns with better performance when
`waves_per_eu >= 6`.

## Test Plan
- test_ck_tile_fmha
- tile_example_fmha_fwd: tested this on gfx1100 and gfx1151
./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=24
-d=128 -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}

## Test Result
- TFLOPs by sequence length target: `gfx1100` layout: `bhsd`
- mode: batch / VGPR usage: 225 vs 214

SeqLen | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 74.10 | 71.97 | 0.97x
4096 | 66.26 | 77.79 | 1.17x
8192 | 68.18 | 75.88 | 1.11x
12288 | 68.47 | 80.44 | 1.17x
16384 | 59.54 | 79.66 | 1.34x
20480 | 55.78 | 77.91 | 1.40x
24576 | 55.08 | 77.47 | 1.41x
27280 | 47.45 | 77.16 | 1.63x
- mode: group / VGPR usage: 256 vs 214

SeqLen | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 71.47 | 70.6 | 0.99x
4096 | 64.74 | 77.06 | 1.19x
8192 | 64.68 | 75.47 | 1.17x
12288 | 66.43 | 79.95 | 1.20x
16384 | 56.02 | 79.73 | 1.42x
20480 | 50.21 | 78.15 | 1.56x
24576 | 47.29 | 77.53 | 1.64x
27280 | 46.13 | 77.04 | 1.67x

- TFLOPs by sequence length target: `gfx1151` layout: `bshd`
- mode: batch / VGPR usage: 225 vs 223

Batch | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 26.85 | 29.17 | 1.09x
4096 | 24.75 | 26.01 | 1.05x
8192 | 25.24 | 25.50 | 1.01x
12288 | 25.18 | 25.00 | 0.99x
16384 | 24.79 | 25.91 | 1.05x
20480 | 25.56 | 25.24 | 0.99x
24576 | 25.13 | 26.20 | 1.04x
27280 | 10.78 | 26.35 | 2.44x
- mode: group / VGPR usage: 256 vs 229

Batch | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 27.44 | 26.71 | 0.97x
4096 | 21.89 | 23.09 | 1.05x
8192 | 22.85 | 24.49 | 1.07x
12288 | 24.33 | 24.42 | 1.00x
16384 | 20.05 | 24.98 | 1.24x
20480 | 14.70 | 25.15 | 1.71x
24576 | 11.30 | 26.31 | 2.33x
27280 | 10.10 | 26.32 | 2.61x


## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants