[ROCm] Enable FP8 KV-cache and relax constraints for RDNA4 custom paged attention #34741
laudney wants to merge 3 commits into vllm-project:main
Conversation
Code Review
This pull request introduces well-structured and robust changes to enable FP8 KV-cache support for RDNA4 (gfx12) and relaxes several constraints for the Navi kernel. The implementation of software dequantization in the attention kernel is clean and leverages if constexpr effectively to handle different data types. The refactoring of Q-loading logic to be independent of the cache data type improves correctness and readability. The Python-level guards have been appropriately updated to reflect the new capabilities and fixes. The test suite has also been correctly modified to include the new FP8 configurations on supported platforms. Overall, this is a high-quality contribution that significantly enhances performance and capabilities on ROCm hardware.
Hi @laudney, the pre-commit checks have failed. Please run:

```
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
Force-pushed from f80d69a to 6cbf04b
vllm/platforms/rocm.py

```diff
-            and block_size == 16
-            and (gqa_ratio >= 3 and gqa_ratio <= 16)
+            and block_size_ok
+            and (gqa_ratio >= 1 and gqa_ratio <= 16)
```
won't this affect the condition of gfx11x?
The custom paged attention kernel implementations for gfx12x and gfx11x are different.
```diff
         )
     ):
         pytest.skip()
+    if fp8_unsupported or head_size != 128 or block_size != 16 or use_alibi:
```
I thought gfx12 supports a block size of 32, as stated in:

```python
block_size_ok = block_size == 16 or (_ON_GFX12 and block_size == 32)
```
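To make the relaxed guard concrete, here is a minimal Python sketch of how the Navi custom-paged-attention eligibility check could combine the conditions discussed in this PR (block size, GQA ratio, head size, KV-cache dtype). The function name and flag parameters are illustrative, not the exact vLLM identifiers.

```python
def use_custom_paged_attention(
    on_gfx11: bool,
    on_gfx12: bool,
    block_size: int,
    gqa_ratio: int,
    head_size: int,
    kv_cache_dtype: str,
) -> bool:
    """Hypothetical sketch of the relaxed Navi kernel guard."""
    # block_size=32 is only valid on gfx12 (VBLOCKS_PER_LANE=1)
    block_size_ok = block_size == 16 or (on_gfx12 and block_size == 32)
    # gfx12 supports GQA ratios down to 1; gfx11 still requires >= 3
    min_gqa = 1 if on_gfx12 else 3
    # FP8 KV-cache is software-dequantized on gfx12 only
    dtype_ok = kv_cache_dtype == "auto" or (
        on_gfx12 and kv_cache_dtype in ("fp8", "fp8_e4m3")
    )
    return (
        (on_gfx11 or on_gfx12)
        and block_size_ok
        and min_gqa <= gqa_ratio <= 16
        and head_size == 128  # kernel assumes 128-wide heads
        and dtype_ok
    )
```

Under this sketch, `block_size=32` passes only when the gfx12 flag is set, so the gfx11 condition questioned above is unchanged.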
Cherry-picked and adapted from 4 open PRs:

- vllm-project#34740 (laudney): Replace `on_gfx9()`/`on_mi3xx()` FP8 gates with `supports_fp8()`, unblocking FP8 on RDNA4/gfx12
- vllm-project#34709 (laudney): Enable wvSplitK/wvSplitKQ skinny GEMM kernels for RDNA4 decode (~15% improvement), wave32 DPP reduction
- vllm-project#34741 (laudney): FP8 KV-cache for RDNA4 custom paged attention via software dequantization
- vllm-project#36659 (vllmellm): Tuned FP8 MoE Triton configs for AMD Radeon AI PRO R9700, AITER mha_v3 attention on gfx12x
…ed attention

Add FP8 KV-cache support for the gfx12 (RDNA4) custom paged attention kernel via software dequantization, and relax several constraints that were unnecessarily restrictive.

Kernel changes (attention.cu):
- Add convert_b8x8_to_b16x8() for portable FP8->FP16/BF16 dequant
- Wire FP8 dequant path into gfx12 QKV kernel using per-block KV cache scale factors
- Add GQA ratio 1-2 support for gfx12 (was gqa_ratio >= 3)

Platform guard changes (rocm.py):
- Add _ON_GFX12 flag to distinguish gfx12 from gfx11
- Allow block_size=32 on gfx12 (VBLOCKS_PER_LANE=1 is correct)
- Restrict Navi kernel to head_size=128 (kernel assumes 128-wide heads)
- Accept kv_cache_dtype fp8/fp8_e4m3 on gfx12

Test changes (test_attention.py):
- Allow FP8 KV-cache tests on supports_fp8() platforms instead of blanket-skipping all FP8 on Navi

Signed-off-by: L.B.R. <lbr@mmonad.com>
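The software-dequantization path described above decodes FP8 bytes to 16-bit floats and applies the per-block KV-cache scale. As a hedged illustration of the arithmetic (not the actual HIP kernel code), here is a scalar Python model of FP8 E4M3 decoding followed by scale application; the function names are hypothetical.

```python
def fp8_e4m3fn_to_float(byte: int) -> float:
    """Decode one FP8 E4M3 byte (fn variant: no inf, single NaN pattern)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return float("nan")
    if exp == 0:
        # subnormal: mantissa/8 scaled by 2^(1-bias), bias = 7
        return sign * (mant / 8.0) * 2.0 ** -6
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)

def dequant_block(kv_bytes, scale):
    """Apply the per-block KV-cache scale after decoding each byte."""
    return [fp8_e4m3fn_to_float(b) * scale for b in kv_bytes]
```

In the real kernel this happens vectorized (8 bytes at a time via `convert_b8x8_to_b16x8()`) before the dot-product accumulation, but the per-element math is the same shape.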
…code

The gfx12 WMMA16x16x16 wave32 QK dot-product expects 16 contiguous head elements across both rows (row 0: lower 8, row 1: upper 8). For FP8, each row loads 16 values covering non-overlapping ranges (row 0: base+[0..15], row 1: base+[16..31]). Splitting by byte position into xy[0]/xy[1] created a cross-row mismatch where the wrong head dimensions were multiplied together, producing incorrect attention scores and degenerate model output (loops, nonsense).

Fix by exchanging inner halves between rows via __shfl_xor(val, 16) before the WMMA calls, so each iteration covers a contiguous 16 head-element range that aligns with Q's layout.

Non-FP8 path is unaffected (compile-time constexpr branch). V*logits WMMA is unaffected (rows access different token blocks, not different head elements).

Signed-off-by: L.B.R. <lbr@mmonad.com>
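The inner-half exchange can be modeled with a small Python sketch. This is a simplified toy of the lane layout (real lanes hold packed bytes, and `__shfl_xor(val, 16)` pairs lane i with lane i^16, i.e. WMMA row 0 with row 1); it only demonstrates why swapping the inner halves makes each iteration cover a contiguous 16-element range.

```python
def exchange_inner_halves(base: int = 0):
    """Toy model: row 0 loads base+[0..16), row 1 loads base+[16..32),
    each as two 8-element chunks."""
    row0 = [list(range(base, base + 8)), list(range(base + 8, base + 16))]
    row1 = [list(range(base + 16, base + 24)), list(range(base + 24, base + 32))]
    # __shfl_xor(val, 16): row 0 sends its upper chunk, row 1 sends its
    # lower chunk, and each receives the partner's value.
    row0[1], row1[0] = row1[0], row0[1]
    return row0, row1

row0, row1 = exchange_inner_halves()
# Each WMMA iteration k now covers contiguous head elements [16k, 16k+16):
# row 0 supplies the lower 8, row 1 the upper 8, matching Q's layout.
for k in range(2):
    assert row0[k] + row1[k] == list(range(16 * k, 16 * k + 16))
```

Without the exchange, iteration 0 would pair row 0's [0..8) against row 1's [16..24), multiplying mismatched head dimensions, which is the bug described above.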
Force-pushed from e9efafe to 8c71262
…ions to gfx12 only

- gqa_ratio >= 1 now only applies on gfx12; gfx11 retains >= 3
- Test skip updated to allow block_size=32 on gfx12
- Added NOTE about k_scale/v_scale test limitation

Co-authored-by: Claude
Signed-off-by: L.B.R. <lbr@mmonad.com>
Rebased onto latest main. Conflict resolution:

Fixes for review comments (new commit):

Pre-commit hooks all pass. All changes reviewed by GPT-5.4 (Codex); no blockers found.
Summary
Enable FP8 KV-cache support for the gfx12 (RDNA4) custom paged attention kernel via software dequantization, and fix several overly-restrictive constraints that prevented the Navi kernel from being used in valid configurations.
FP8 KV-cache halves the KV-cache memory footprint (16-bit → 8-bit per element), roughly doubling token capacity. On an AMD Radeon AI PRO R9700 (32GB), this increases KV-cache from ~132K tokens to ~264K tokens for a Qwen3-Coder-30B AWQ model — the difference between 1.4x and 2.75x concurrent 96K-context requests.
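The capacity claim follows directly from the per-token KV-cache footprint. A back-of-the-envelope sketch (the layer/head numbers below are illustrative assumptions, not the exact Qwen3-Coder-30B configuration):

```python
def kv_tokens(cache_bytes: int, layers: int, kv_heads: int,
              head_size: int, bytes_per_elem: int) -> int:
    """Tokens that fit in a KV-cache budget: K + V, per layer, per head."""
    per_token = 2 * layers * kv_heads * head_size * bytes_per_elem
    return cache_bytes // per_token

budget = 24 * 1024**3  # hypothetical VRAM left over for the KV-cache
fp16_tokens = kv_tokens(budget, layers=48, kv_heads=4, head_size=128,
                        bytes_per_elem=2)
fp8_tokens = kv_tokens(budget, layers=48, kv_heads=4, head_size=128,
                       bytes_per_elem=1)
assert fp8_tokens == 2 * fp16_tokens  # halving element width doubles capacity
```

The exact token counts depend on the model and the memory left after weights, but the 2x ratio holds whenever the element width is the only thing that changes.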
Kernel changes (`csrc/rocm/attention.cu`)

- `convert_b8x8_to_b16x8()`: portable FP8 → FP16/BF16 vector conversion using the HIP runtime `fp8::vec_conversion` (works across gfx11/gfx12 without arch-specific intrinsics)
- When `kv_cache_dtype == "fp8"`, V-cache loads go through software dequant using per-block KV cache scale factors before the dot-product accumulation
- GQA ratio requirement relaxed from `gqa_ratio >= 3` to `gqa_ratio >= 1`

Platform guard fixes (`vllm/platforms/rocm.py`)

- Added `_ON_GFX12` flag
- `block_size`: was `== 16` only; now `== 16` on gfx11, `== 16 or == 32` on gfx12 (VBLOCKS_PER_LANE=1)
- `head_size`: was `64 or 128` on Navi, now `128` only; the kernel computes `vhead_elem = warpid * 16 + lane16id` without a bounds check and reads past the head slice for 64-wide heads
- `kv_cache_dtype`: was `"auto"` only; now `"auto"` or `"fp8"`/`"fp8_e4m3"` on gfx12

Test changes (`tests/kernels/attention/test_attention.py`)

- Allow FP8 KV-cache tests on `supports_fp8()` platforms instead of blanket-skipping all FP8 on Navi

Related PRs (RDNA4/gfx12 series)
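The head_size restriction can be illustrated with simple index arithmetic. Assuming 8 warps of 16 lanes per row (an assumption chosen so the indices span a 128-wide head; the real kernel's warp count may differ), `vhead_elem = warpid * 16 + lane16id` stays in bounds for head_size=128 but overruns a 64-wide head:

```python
def vhead_elems(num_warps: int = 8, lanes_per_row: int = 16):
    """All V head-element indices the kernel would touch, with no bounds check."""
    return [w * 16 + l for w in range(num_warps) for l in range(lanes_per_row)]

elems = vhead_elems()
assert max(elems) == 127              # fine for head_size == 128
assert any(e >= 64 for e in elems)    # out-of-bounds reads for head_size == 64
```

This is why the guard now restricts the Navi kernel to `head_size == 128` rather than also advertising 64.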
Test plan

- `kv-cache-dtype: fp8` on gfx1201: server reports "Using fp8 data type to store kv cache", 263K tokens KV capacity, correct inference output
- `kv-cache-dtype: fp8_e4m3` on gfx1201: 150K tokens, correct output
- `kv-cache-dtype: fp8` on gfx1201: correct output
- (`_ON_GFX12`)
- (`__HIP__GFX9__` guard)
- `test_paged_attention` FP8 variants pass on gfx12