[AMD] Enable IN_THREAD_TRANSPOSE to GFX1201 by default #10185

Merged: antiagainst merged 4 commits into triton-lang:main from skysnow2001:rdna4_thread_transp on May 1, 2026

@skysnow2001
Contributor

Enables in-thread transpose by default for gfx1201, matching the existing default for gfx942. The pass transposes elements within a thread, in registers, after the global load and before the local write. This keeps global memory accesses coalesced, keeps ds instructions wide, and avoids shared memory bank conflicts.
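The register-level reshuffle described above can be modelled in a few lines of plain Python. This is a conceptual sketch only, not the Triton pass: the function name `in_thread_transpose`, the flat-list representation of registers, and the 4 x 8 tile size are all assumptions made for illustration.

```python
# Conceptual model only; names and tile sizes are illustrative assumptions.

def in_thread_transpose(regs, rows, cols):
    """Transpose a rows x cols tile held in one thread's registers.

    `regs` is a flat list in load order: `rows` vectors of `cols` elements,
    where each vector models one wide, coalesced global load.
    The result is a flat list of `cols` vectors of `rows` elements, so each
    vector written to shared memory is contiguous along the other dimension.
    """
    assert len(regs) == rows * cols
    return [regs[r * cols + c] for c in range(cols) for r in range(rows)]


# One thread loads 4 vectors of 8 elements; label each element (row, col).
rows, cols = 4, 8
loaded = [(r, c) for r in range(rows) for c in range(cols)]

transposed = in_thread_transpose(loaded, rows, cols)

# The first vector written to shared memory now walks down a column.
first_write = transposed[:rows]
print(first_write)
```

No data leaves the thread in this model: the same 32 values are only reordered, which is why the transpose can sit between the global load and the local write without extra memory traffic.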

Flash-Attention 2 Results (bf16):

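| headdim | batch_size | seqlen | TFLOPS (ITT=0) | TFLOPS (ITT=1) | Δ |
|---|---|---|---|---|---|
| 128 | 32 | 512 | 45.16 | 47.74 | +5.7% |
| 128 | 16 | 1024 | 53.34 | 63.69 | +19.4% |
| 128 | 8 | 2048 | 57.77 | 70.62 | +22.3% |
| 128 | 4 | 4096 | 61.90 | 76.49 | +23.6% |
| 128 | 2 | 8192 | 64.18 | 79.10 | +23.3% |
| 128 | 1 | 16384 | 62.15 | 73.48 | +18.2% |
| 128 | 1 | 32768 | 61.45 | 71.19 | +15.9% |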

GEMM Results (bf16)

| M | N | K | TFLOPS (ITT=0) | TFLOPS (ITT=1) | Δ |
|---|---|---|---|---|---|
| 1024 | 1024 | 1024 | 60.64 | 66.26 | +9.3% |
| 2048 | 2048 | 2048 | 101.69 | 101.26 | -0.4% |
| 4096 | 4096 | 4096 | 109.69 | 120.89 | +10.2% |
| 8192 | 8192 | 8192 | 107.64 | 118.09 | +9.7% |
| 4096 | 11008 | 4096 | 107.50 | 117.04 | +8.9% |
| 4096 | 4096 | 11008 | 109.55 | 119.10 | +8.7% |
| 4096 | 14336 | 4096 | 108.94 | 118.11 | +8.4% |
| 4096 | 4096 | 14336 | 109.35 | 118.83 | +8.7% |
| 4096 | 12288 | 4096 | 108.27 | 117.76 | +8.8% |
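For readers who want to check the Δ column, it appears to be the relative change in throughput from ITT=0 to ITT=1. A minimal sketch, assuming that definition (the helper name `relative_delta_percent` is hypothetical and not part of the PR):

```python
def relative_delta_percent(baseline_tflops, new_tflops):
    """Relative throughput change, in percent, going from baseline to new."""
    return (new_tflops / baseline_tflops - 1.0) * 100.0


# Two GEMM rows from the table above: (shape, TFLOPS ITT=0, TFLOPS ITT=1).
gemm_rows = [
    ((1024, 1024, 1024), 60.64, 66.26),
    ((2048, 2048, 2048), 101.69, 101.26),
]

for shape, itt_off, itt_on in gemm_rows:
    print(shape, f"{relative_delta_percent(itt_off, itt_on):+.1f}%")
```

Because the published TFLOPS values are already rounded to two decimals, a recomputed Δ can differ from the table in the last digit for some rows.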

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because it only flips the default value of is_in_thread_transpose_enabled for a new architecture (gfx1201).
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

Comment thread third_party/amd/backend/compiler.py (Outdated)

      def is_in_thread_transpose_enabled(arch):
    -     return (arch == "gfx942") if knobs.amd.use_in_thread_transpose is None else knobs.amd.use_in_thread_transpose
    +     return (arch in ("gfx942",
    +                      "gfx1201")) if knobs.amd.use_in_thread_transpose is None else knobs.amd.use_in_thread_transpose
Member

We should include other gfx12 targets too. See how is_hip_rdna4 is defined.

Contributor Author

Added
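The shape of the gate discussed in this thread can be sketched as a standalone function. This is not the merged code: the `override` parameter is a stand-in for `knobs.amd.use_in_thread_transpose`, and matching gfx12 targets by the `gfx120` prefix is an assumption about how "other gfx12 targets" might be covered.

```python
from typing import Optional

# Assumption: architectures named in this thread; the prefix match is illustrative.
_DEFAULT_ON_EXACT = {"gfx942"}
_DEFAULT_ON_PREFIXES = ("gfx120",)  # gfx120x targets such as gfx1200 and gfx1201


def is_in_thread_transpose_enabled(arch: str, override: Optional[bool] = None) -> bool:
    """Return whether in-thread transpose should run for `arch`.

    `override` models a tri-state knob: None means "use the per-architecture
    default", while True or False forces the pass on or off.
    """
    if override is not None:
        return override
    return arch in _DEFAULT_ON_EXACT or arch.startswith(_DEFAULT_ON_PREFIXES)


for arch in ("gfx942", "gfx1200", "gfx1201", "gfx90a"):
    print(arch, is_in_thread_transpose_enabled(arch))
```

Keeping the knob tri-state means an explicit user setting always wins, so flipping the default for a new architecture does not take away the ability to turn the pass off for a regressing workload.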

@antiagainst antiagainst marked this pull request as ready for review May 1, 2026 19:07
@antiagainst antiagainst requested a review from zhanglx13 as a code owner May 1, 2026 19:07
@antiagainst antiagainst enabled auto-merge (squash) May 1, 2026 19:08
@antiagainst antiagainst merged commit 5d69e1c into triton-lang:main May 1, 2026
8 of 9 checks passed
antiagainst pushed a commit that referenced this pull request May 27, 2026
#10390)

`InThreadTranspose` rewrites `tt.load -> ttg.local_alloc ->
ttg.local_load -> dot_op` so the K-contiguous WMMA/MFMA operand can be
read from LDS as wide `ds_load_b128` instead of scalar `ds_load_u16`
pairs when the load order doesn't match the consumer's K dimension.
The pattern matcher in `matchInThreadTransposePattern` already accepts
`AMDWmmaEncodingAttr` alongside `AMDMfmaEncodingAttr`, but the gate in
`is_in_thread_transpose_enabled` only activates the pass on gfx942
(CDNA3) and gfx120x (RDNA4, enabled in #10185). Extend it to also
cover RDNA3 (gfx110x/gfx1103) and RDNA3.5 (gfx115x).

Added an `inThreadTranspose_wmma` sub-test to
`test/TritonGPU/amd/in-thread-transpose.mlir` (gfx1151, wave32, WMMA
encoding) that verifies the pass produces an `amdg.in_thread_transpose`
op and that the downstream `ttg.local_load` returns the K-contiguous
`dot_op` layout (`kWidth = 16`).

On AITER's `flash_attn_2.varlen_fwd` at the Qwen3-Omni ViT prefill
shape (B=1, S=3200, H=16, head_dim=72, fp16) on gfx1151, this lifts
the inner-loop V `local_load` from 512 scalar `ds_load_u16(_d16_hi)`
to 144 vectorized `ds_load_b128` and gives a 3.8% median speedup
(3.042 -> 2.925 ms).