Skip to content

[PA] Optimize PA Decode Gluon Performance for BF16/FP16 with KV_BLOCK_SIZE=64 and Fix ROCm 7.0 AOT Compilation#1691

Merged
coderfeli merged 4 commits into
mainfrom
pa_gluon_opt_bf16
Dec 23, 2025
Merged

[PA] Optimize PA Decode Gluon Performance for BF16/FP16 with KV_BLOCK_SIZE=64 and Fix ROCm 7.0 AOT Compilation#1691
coderfeli merged 4 commits into
mainfrom
pa_gluon_opt_bf16

Conversation

@yanguahe

@yanguahe yanguahe commented Dec 19, 2025

Copy link
Copy Markdown
Contributor

Motivation

This PR addresses two issues in the pa_decode_gluon kernel:

  1. Performance Issue: The existing implementation shows suboptimal performance when using tl.bfloat16 or tl.float16 compute types with KV_BLOCK_SIZE=64. The BF16-Gluon configuration with block_size=64 had significantly lower bandwidth (~0.54 TB/s) compared to FP8 configuration.

  2. AOT Compilation Issue: On ROCm 7.0 platform, the AOT (Ahead-of-Time) compilation fails because triton.tools.compile module with compile_kernel and CompileArgs is not available in some Triton versions.

Technical Details

Performance Optimization

  • Layout Optimization: Introduced separate blocked layouts (blocked_key_layout_fp8, blocked_key_layout_f16, blocked_value_layout_fp8, blocked_value_layout_f16) to optimize memory access patterns based on compute type and KV block size.

  • Warp Configuration: Adjusted warps_per_cta and threads_per_warp configurations dynamically based on KV_BLOCK_SIZE (16 vs 64) to better utilize GPU resources for different block sizes.

  • Value Loading Layout: Added DistributedLinearLayout (v_blK_64_layout) for improved value cache loading with better register and lane base configurations for KV_BLOCK_SIZE=64.

  • Computation Reordering: Moved the attention score computation (QK dot product) before value loading to enable better instruction pipelining and hide memory latency.

  • Wave Configuration: Changed waves_per_eu from 2 to 1 for better occupancy control.

  • DOT k_width: Made DOT_QK_K_WIDTH dynamic based on KV_16B_ELEMENT_COUNT to properly handle different data type widths.

AOT Compilation Fix

  • Added a new standalone AOT compile utility at csrc/cpp_itfs/gluon_aot_tools/compile.py that provides compile_kernel and CompileArgs functionality independent of Triton version.

  • Updated pa_decode_gluon_aot.py to import from the new local module instead of triton.tools.compile, ensuring compatibility with ROCm 7.0 and different Triton versions.

Test Plan

Ran performance benchmarks with the following test configurations:

  • BLOCK_SIZE_OPTIONS: [16, 64]
  • COMPUTE_TYPE_OPTIONS: ["fp8", "bf16"]
  • CONTEXT_LENGTH_OPTIONS: [2048, 4096, 8192]
  • BATCH_SIZE_OPTIONS: [1, 2, 4, 8, 16, 32, 64, 128]
  • QUERY_LENGTH_OPTIONS: [1, 2, 3, 4]
  • HEAD_CONFIGURATIONS: [(64, 4), (64, 8)]
  • HEAD_DIMENSION_OPTIONS: [128]
  • AOT implementation enabled (USE_AOT_IMPL_OPTIONS: [True])

Test Result

block_size Summary Before (TB/s) After (TB/s) Speedup
64 FP8 0.871 0.904 1.04x
64 BF16-Gluon 0.544 1.229 2.26x
16 FP8 0.959 0.984 1.03x
16 BF16-Gluon 1.263 1.301 1.03x

The BF16-Gluon configuration with block_size=64 shows the most significant improvement with ~2.26x bandwidth increase.

Submission Checklist

  • Code compiles without errors
  • Performance regression tests pass
  • AOT compilation works on ROCm 7.0 platform
  • Backward compatibility maintained for existing configurations
  • No changes to public API

…m 7.0 AOT

- Add dedicated blocked layouts for f16/bf16 compute types
- Add local AOT compile tool to fix ROCm 7.0 compatibility
@yanguahe yanguahe requested review from a team, coderfeli and valarLip December 19, 2025 05:13
valarLip
valarLip previously approved these changes Dec 20, 2025
@yanguahe yanguahe changed the title [PA] Optimize f16/bf16 performance for KV_BLOCK_SIZE=64 and fix ROCm 7.0 AOT compilation [PA] Optimize PA Decode Gluon Performance for BF16/FP16 with KV_BLOCK_SIZE=64 and Fix ROCm 7.0 AOT Compilation Dec 22, 2025
@coderfeli coderfeli merged commit 14d92a0 into main Dec 23, 2025
23 checks passed
@coderfeli coderfeli deleted the pa_gluon_opt_bf16 branch December 23, 2025 04:09
Zzz9990 added a commit that referenced this pull request Dec 25, 2025
* fix sink error for asm fmha (#1652)

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* add guard in case pynccl init failed (#1671)

* One shot pa (#1670)

* add one shot pa kernel

* fix buffer load in sliding window kernel

* fix typo

* revert

---------

Co-authored-by: root <root@hjbog-srdc-24.amd.com>

* fix(pa_ps): fix pa_ps_asm .co for gfx950 (#1669)

Signed-off-by: Double Young <yang.yang2@amd.com>

* modify test_bf16gemm_test (#1678)

* Fix Ruff command in pre-checks (#1675)

* fix mha bwd golden perf issue (#1666)

* topk uplift v1 (#1662)

/lgtm

The customer has tested the code. It can work.

* topk uplift v1

* topk add api for choose topk_v1 or topk_v2

---------

Co-authored-by: yonshuai <yonshuai@amd.com>
Co-authored-by: yongshuai <yongshuai@amd.com>

* fix missing return in mha_bwd (#1688)

* Remove the input parameter "out" in gemm_a4w4 (#1679)

* Remove the input parameter "out" in gemm_a4w4

* update

* format

---------

Co-authored-by: valarLip <Lingpeng.Jin@amd.com>

* fwd v3 hd192 optimize inst alignment for causal mode (#1663)

Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com>

* fix swa case mismatch (#1694)

* fixing the fp4 gemm tune script Exception caused by tile_m name inconsistency (#1686)

* CI: Migrate Triton tests to aiter-1gpu-runner (#1690)

* add ntile 128 for a8 blkQ moe 1 stage (#1695)

* add fmoe co with tilesize 32x128

* add ps co

* fix pertoken co bug

* add co to csv

* add 128ntile logic for one stage asm

* fix mem fault during perf turn

* en vs for pertoken kernel

---------

Co-authored-by: feifei14119 <feiw@amd.com>
Co-authored-by: zufayu <zufayu@amd.com>

* Optimize RoPE in the cases that hdim is small. (#1698)

* Introduce new grid config strategy for compatibility with cases that hdim is small.

* add launch bound to make sure that occu is always 8

* follow Copilot the suggestions

* rm garbage from whl (#1696)

* enhance prebuild logic (#1672)

* enhance prebuild logic

* ATen.h build issues

* bug fix

* bug fix II

* bug fix III

---------

Co-authored-by: zufayu <zufayu@amd.com>
Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com>

* LLfp4 qr cap for atom (#1673)

* QR cap implemented to limit QR to prefill

* test git config

* Fix to genericize qr comm cap

* Incorrect cap number

* [MLA] MLA conditions rewrite  (#1665)

* open mla mtp and remove some logs

* fix qlen dense 128,N

* fix hint

* support sparse qlen input = 1

* change default splits

* fix dp causal (#1677)

* add two fp4 tune shapes and tuned config (#1687)

* add two fp4 tune shapes and tuned config

* change 32800 to 65536 to cover all cases between 32768 to 65536 as per feedback

* Dev/a8w4 and a8w8splitk (#1667)

* support moe a8w8 splitk  (#1654)

* Add support to a8w8_ck_moe_blk_gemm1 splitk

* add switch and add some logging

* tiny fix

* update ck 3rd party and add some logging

* add AITER_HEURISTIC_ONLY env

* update ck

* add condition to bypass tuned cfg

* change bypass type

* fix

* fix removed log

* upate ck submodule

* fix lint

* force to run tests

---------

Co-authored-by: oscar <huaiguxu@amd.com>

* Zan/moe a8w4 (#1655)

* update

* update

* update quant

* ut ready

* update quant type

* compile pass

* python3 op_tests/test_moe_2stage.py -t 16 -e 1 -k 1 -dim 256,256 ready

* update aiter dipatcher for bf16&fp8

* support a16 a8 dispatch

* finish quant & sort

* update aiter framework for a8w4 moe

* update ck

* update

* update

* update for atom

* update

---------

Co-authored-by: Zzz9990 <Zzz9990>
Co-authored-by: root <root@hjbog-srdc-24.amd.com>

* update ck

* fix dispatch

* fix too much logging

* update

* update ck

* update ck

* fix ruff code style

* revert aiter-test yaml

* fix ci

* fix ci

* fix ci

* add mocked tuned result and decoding cfg token to next power of 2

* Update tuned_fmoe.csv

remove duplicate

* remove hack dtype

* fix black

* unique index

* add empty arg to ck_moe_stage1

* resolve bias into lru cache

* rename bypass cfg to AITER_BYPASS_TUNE_CONFIG

---------

Co-authored-by: oscar <huaiguxu@amd.com>
Co-authored-by: Zzz9990 <zanzhang@amd.com>
Co-authored-by: root <root@hjbog-srdc-24.amd.com>
Co-authored-by: felix <felix.li@amd.com>
Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com>

* bf16_gemm_clean_in_kl (#1700)

* bf16_gemm_clean_in_kl

* update

* update

* update

* update

* fix tuner (#1701)

* fix tuner

* Update gradlib/gradlib/GemmTuner.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: amd-ruitang3 <145657428+amd-ruitang3@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* add gen_fake for 4 gemm operators (#1456)

Co-authored-by: Lin, Soga <soga.lin@amd.com>
Co-authored-by: sogalin <39478626+sogalin@users.noreply.github.com>

* fix llvm issue (#1703)

* fix llvm issue

* fix copilot

* feat: Adaptive topk algorithm selection based on input characteristics (#1578)

* Add radix-base selection

* Remove explicit template

* Update the selected k condition

* remove pos < k guard

* code format

* Update csrc/include/rocm_ops.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update csrc/kernels/topk_per_row_kernels.cu

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update csrc/kernels/topk_plain_kernels.cu

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update test_topk_plain.py

* Update TODO message

* Update csrc/kernels/topk_per_row_kernels.cu

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update op_tests/test_topk_plain.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* format test_topk_plain.py with black

* Disable triton test for a resonalbe execution time

* add explicit template instantiation

* fix explicit template instantiation

* add explicit template instantiation

* Add bf16 support

* Fix linter

* Fix build errors

* Fix condition

* Fix build and test

* Update conditions

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com>
Co-authored-by: MHYang <meng-hsuan.yang@amd.com>

* fix mha bwd build error (#1705)

* fix moe bug when pipever=v1 and nblk=64 (#1707)

* fix bug

* update

* fix (#1710)

* fix

* update lint

* [PA] Optimize PA Decode Gluon Performance for BF16/FP16 with KV_BLOCK_SIZE=64 and Fix ROCm 7.0 AOT Compilation (#1691)

* Optimize pa_decode_gluon f16/bf16 perf for KV_BLOCK_SIZE=64 & fix ROCm 7.0 AOT

- Add dedicated blocked layouts for f16/bf16 compute types
- Add local AOT compile tool to fix ROCm 7.0 compatibility

* black format file

* format file to pass the ruff check

* fix error in gfx950

* Fix argument parsing logic when AITER_JIT_DIR is set (#1715)

When AITER_JIT_DIR is defined the enum module is loaded as "module_aiter_enum" rather than "aiter.jit.module_aiter_enum".
This caused the docstring cleanup of enums to not work properly, causing a NameError exception in check_args.

* fix topk deocde bug in logit value is same (#1716)

Co-authored-by: yonshuai <yonshuai@amd.com>

* add fp32 input (#1706)

* add fp32 input

* format code

* perf bug fix

* logic fix : out type != input type

* bug fix

* format code

* remove dtype convert before act_and_mul in fused_moe

---------

Co-authored-by: zufayu <zufayu@amd.com>
Co-authored-by: chenjun <junchen2@amd.com>

* add sampling aot (#1711)

* add sampling aot

* simple compile

* fix compile bugs

* fix a bug

* revert changes

---------

Co-authored-by: root <root@hjbog-srdc-24.amd.com>

* update

* bugfix

* update

* update

---------

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Signed-off-by: Double Young <yang.yang2@amd.com>
Co-authored-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com>
Co-authored-by: who who who <fsx950223@outlook.com>
Co-authored-by: root <root@hjbog-srdc-24.amd.com>
Co-authored-by: Double Young <yang.yang2@amd.com>
Co-authored-by: amd-ruitang3 <145657428+amd-ruitang3@users.noreply.github.com>
Co-authored-by: Satya Nikhil Kodukula <nikhil.kodukula@gmail.com>
Co-authored-by: JaxChen29 <jichen@amd.com>
Co-authored-by: steamedMantou <82486092+steamedMantou@users.noreply.github.com>
Co-authored-by: yonshuai <yonshuai@amd.com>
Co-authored-by: yongshuai <yongshuai@amd.com>
Co-authored-by: Yu Guo <82124926+yuguo68@users.noreply.github.com>
Co-authored-by: la <46212055+junhaha666@users.noreply.github.com>
Co-authored-by: valarLip <Lingpeng.Jin@amd.com>
Co-authored-by: shay-li77 <xiangxli@amd.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Xin Huang <Xin.Huang@amd.com>
Co-authored-by: zufayu <zufa.yu@amd.com>
Co-authored-by: feifei14119 <feiw@amd.com>
Co-authored-by: zufayu <zufayu@amd.com>
Co-authored-by: ruanjm <jiming.ruan@amd.com>
Co-authored-by: amirumoAMD <Amelia.Moore@amd.com>
Co-authored-by: yadaish <yadai@amd.com>
Co-authored-by: oscar <huaiguxu@amd.com>
Co-authored-by: felix <felix.li@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: mqhc2020 <marvin.tsai@amd.com>
Co-authored-by: Lin, Soga <soga.lin@amd.com>
Co-authored-by: sogalin <39478626+sogalin@users.noreply.github.com>
Co-authored-by: ClementLinCF <162283536+ClementLinCF@users.noreply.github.com>
Co-authored-by: MHYang <meng-hsuan.yang@amd.com>
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: yanguahe <yanguahe@amd.com>
Co-authored-by: omoisis-dn <omoisis@drivenets.com>
Co-authored-by: chenjun <junchen2@amd.com>
ZhangLirong-amd pushed a commit that referenced this pull request Dec 29, 2025
…_SIZE=64 and Fix ROCm 7.0 AOT Compilation (#1691)

* Optimize pa_decode_gluon f16/bf16 perf for KV_BLOCK_SIZE=64 & fix ROCm 7.0 AOT

- Add dedicated blocked layouts for f16/bf16 compute types
- Add local AOT compile tool to fix ROCm 7.0 compatibility

* black format file

* format file to pass the ruff check

* fix error in gfx950
farlukas pushed a commit that referenced this pull request Jan 5, 2026
…_SIZE=64 and Fix ROCm 7.0 AOT Compilation (#1691)

* Optimize pa_decode_gluon f16/bf16 perf for KV_BLOCK_SIZE=64 & fix ROCm 7.0 AOT

- Add dedicated blocked layouts for f16/bf16 compute types
- Add local AOT compile tool to fix ROCm 7.0 compatibility

* black format file

* format file to pass the ruff check

* fix error in gfx950
zhuyuhua-v pushed a commit that referenced this pull request Jan 14, 2026
…_SIZE=64 and Fix ROCm 7.0 AOT Compilation (#1691)

* Optimize pa_decode_gluon f16/bf16 perf for KV_BLOCK_SIZE=64 & fix ROCm 7.0 AOT

- Add dedicated blocked layouts for f16/bf16 compute types
- Add local AOT compile tool to fix ROCm 7.0 compatibility

* black format file

* format file to pass the ruff check

* fix error in gfx950
valarLip pushed a commit that referenced this pull request Mar 18, 2026
…_SIZE=64 and Fix ROCm 7.0 AOT Compilation (#1691)

* Optimize pa_decode_gluon f16/bf16 perf for KV_BLOCK_SIZE=64 & fix ROCm 7.0 AOT

- Add dedicated blocked layouts for f16/bf16 compute types
- Add local AOT compile tool to fix ROCm 7.0 compatibility

* black format file

* format file to pass the ruff check

* fix error in gfx950
valarLip pushed a commit that referenced this pull request Mar 18, 2026
…_SIZE=64 and Fix ROCm 7.0 AOT Compilation (#1691)

* Optimize pa_decode_gluon f16/bf16 perf for KV_BLOCK_SIZE=64 & fix ROCm 7.0 AOT

- Add dedicated blocked layouts for f16/bf16 compute types
- Add local AOT compile tool to fix ROCm 7.0 compatibility

* black format file

* format file to pass the ruff check

* fix error in gfx950
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants