
[ROCm] Enable FP8 KV-cache and relax constraints for RDNA4 custom paged attention#34741

Open
laudney wants to merge 3 commits into vllm-project:main from mmonad:feat/rocm-rdna4-custom-attn-fp8kv

Conversation

Contributor

laudney commented Feb 17, 2026

Summary

Enable FP8 KV-cache support for the gfx12 (RDNA4) custom paged attention kernel via software dequantization, and fix several overly-restrictive constraints that prevented the Navi kernel from being used in valid configurations.

FP8 KV-cache halves the KV-cache memory footprint (16-bit → 8-bit per element), roughly doubling token capacity. On an AMD Radeon AI PRO R9700 (32GB), this increases KV-cache from ~132K tokens to ~264K tokens for a Qwen3-Coder-30B AWQ model — the difference between 1.4x and 2.75x concurrent 96K-context requests.
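The capacity arithmetic can be sanity-checked in a couple of lines (figures taken from the description above; the 2x factor is exact only insofar as KV-cache elements shrink from 2 bytes to 1 within the same memory budget):

```python
# Back-of-the-envelope check of the KV-cache capacity numbers quoted above.
# FP8 stores 1 byte per element instead of 2 (FP16/BF16), so for a fixed
# memory budget the token capacity roughly doubles.

def tokens_with_fp8(tokens_fp16: int) -> int:
    """Halving bytes per element doubles token capacity for the same budget."""
    return tokens_fp16 * 2

tokens_fp16 = 132_000                      # ~132K tokens with 16-bit KV cache
tokens_fp8 = tokens_with_fp8(tokens_fp16)  # ~264K tokens with FP8

context = 96_000                           # one 96K-context request
print(round(tokens_fp16 / context, 2))     # ~1.38 concurrent requests
print(round(tokens_fp8 / context, 2))      # 2.75 concurrent requests
```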

Kernel changes (csrc/rocm/attention.cu)

  • convert_b8x8_to_b16x8(): Portable FP8 → FP16/BF16 vector conversion using HIP runtime fp8::vec_conversion (works across gfx11/gfx12 without arch-specific intrinsics)
  • FP8 dequant wired into gfx12 QKV kernel: When kv_cache_dtype == "fp8", V-cache loads go through software dequant using per-block KV cache scale factors before the dot-product accumulation
  • GQA ratio 1–2 support on gfx12: The gfx12 kernel correctly handles GQA ratios down to 1 (MHA), but the Python guard restricted it to gqa_ratio >= 3; now gqa_ratio >= 1
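The software-dequant numerics can be illustrated in pure Python (the actual kernel converts packed 8-element vectors with HIP's fp8 `vec_conversion`; this scalar sketch only shows the OCP e4m3fn decode plus the per-block scale, and the helper names are illustrative):

```python
# Scalar sketch of OCP FP8 E4M3 (e4m3fn) decode plus a per-block V-cache
# scale, standing in for the vectorized convert_b8x8_to_b16x8() path.

def fp8_e4m3_to_float(b: int) -> float:
    """Decode one e4m3fn byte: 1 sign, 4 exponent (bias 7), 3 mantissa bits."""
    sign = -1.0 if b & 0x80 else 1.0
    exp = (b >> 3) & 0xF
    man = b & 0x7
    if exp == 0:                      # subnormal: sign * (man/8) * 2^-6
        return sign * (man / 8.0) * 2.0 ** -6
    if exp == 0xF and man == 0x7:     # e4m3fn reserves only this code for NaN
        return float("nan")
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

def dequant_v_block(raw_bytes: list[int], v_scale: float) -> list[float]:
    """Apply the per-block V-cache scale after the FP8 decode."""
    return [fp8_e4m3_to_float(b) * v_scale for b in raw_bytes]

# 0x38 encodes 1.0, 0x40 encodes 2.0, 0x7E encodes the max normal 448.0
print(dequant_v_block([0x38, 0x40, 0x7E], v_scale=0.5))  # [0.5, 1.0, 224.0]
```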

Platform guard fixes (vllm/platforms/rocm.py)

| Guard | Before (bug) | After (fix) |
| --- | --- | --- |
| `_ON_GFX12` | didn't exist | new flag distinguishing gfx12 (FP8-capable) from gfx11 (no FP8) |
| `block_size` | `== 16` only | `== 16` on gfx11; `== 16` or `== 32` on gfx12 (`VBLOCKS_PER_LANE=1`) |
| `head_size` | 64 or 128 on Navi | 128 only: the kernel computes `vhead_elem = warpid * 16 + lane16id` with no bounds check and reads past the head slice for 64-wide heads |
| `kv_cache_dtype` | `"auto"` only | `"auto"`, plus `"fp8"`/`"fp8_e4m3"` on gfx12 |
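Taken together, the relaxed guards can be sketched as a single predicate (paraphrased from the description above; `_ON_GFX12` is the flag this PR adds, but the function itself is illustrative, not the actual rocm.py code):

```python
# Illustrative predicate combining the guard conditions described above.

def use_navi_custom_paged_attention(
    on_gfx12: bool,
    kv_cache_dtype: str,
    block_size: int,
    gqa_ratio: int,
    head_size: int,
) -> bool:
    block_size_ok = block_size == 16 or (on_gfx12 and block_size == 32)
    dtype_ok = kv_cache_dtype == "auto" or (
        on_gfx12 and kv_cache_dtype in ("fp8", "fp8_e4m3")
    )
    gqa_min = 1 if on_gfx12 else 3   # gfx11 keeps the >= 3 floor
    return (
        dtype_ok
        and block_size_ok
        and gqa_min <= gqa_ratio <= 16
        and head_size == 128         # kernel assumes 128-wide heads
    )
```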

Test changes (tests/kernels/attention/test_attention.py)

  • FP8 KV-cache tests now run on supports_fp8() platforms instead of blanket-skipping all FP8 on Navi
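The revised gating amounts to a capability check rather than a blanket platform skip; a minimal sketch (the `supports_fp8()` helper is what this PR series adds, and this standalone function is illustrative, not the test file's actual code):

```python
# Sketch: skip a parametrized FP8 KV-cache test case only when the
# platform capability check fails, instead of skipping all FP8 on Navi.

def should_skip_fp8_case(kv_cache_dtype: str, platform_supports_fp8: bool) -> bool:
    """True when a parametrized FP8 KV-cache test case must be skipped."""
    return kv_cache_dtype.startswith("fp8") and not platform_supports_fp8

print(should_skip_fp8_case("fp8_e4m3", False))  # True: skip on non-FP8 hardware
print(should_skip_fp8_case("fp8", True))        # False: runs on gfx12
print(should_skip_fp8_case("auto", False))      # False: non-FP8 cases unaffected
```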

Related PRs (RDNA4/gfx12 series)

Test plan

  • Qwen3-Coder-30B-A3B AWQ-4bit with kv-cache-dtype: fp8 on gfx1201 — server reports "Using fp8 data type to store kv cache", 263K tokens KV capacity, correct inference output
  • Qwen3-14B-FP8 with kv-cache-dtype: fp8_e4m3 on gfx1201 — 150K tokens, correct output
  • GPT-OSS-20B MXFP4 with kv-cache-dtype: fp8 on gfx1201 — correct output
  • No regression on gfx11 (FP8 path gated behind _ON_GFX12)
  • No regression on gfx9/MI-series (separate kernel under __HIP__GFX9__ guard)
  • test_paged_attention FP8 variants pass on gfx12


gemini-code-assist bot left a comment


Code Review

This pull request introduces well-structured and robust changes to enable FP8 KV-cache support for RDNA4 (gfx12) and relaxes several constraints for the Navi kernel. The implementation of software dequantization in the attention kernel is clean and leverages if constexpr effectively to handle different data types. The refactoring of Q-loading logic to be independent of the cache data type improves correctness and readability. The Python-level guards have been appropriately updated to reflect the new capabilities and fixes. The test suite has also been correctly modified to include the new FP8 configurations on supported platforms. Overall, this is a high-quality contribution that significantly enhances performance and capabilities on ROCm hardware.


mergify bot commented Feb 17, 2026

Hi @laudney, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

laudney force-pushed the feat/rocm-rdna4-custom-attn-fp8kv branch 2 times, most recently from f80d69a to 6cbf04b on February 17, 2026 20:40
-    and block_size == 16
-    and (gqa_ratio >= 3 and gqa_ratio <= 16)
+    and block_size_ok
+    and (gqa_ratio >= 1 and gqa_ratio <= 16)
Collaborator


won't this affect the condition of gfx11x?

Collaborator


The custom paged attention kernel implementations for gfx12x and gfx11x are different.

)
):
pytest.skip()
if fp8_unsupported or head_size != 128 or block_size != 16 or use_alibi:
Collaborator


I thought on gfx12 it supports block size of 32 as stated in

block_size_ok = block_size == 16 or (_ON_GFX12 and block_size == 32)

wi-adam added a commit to wi-adam/vllm that referenced this pull request Mar 12, 2026
Cherry-picked and adapted from 4 open PRs:

- vllm-project#34740 (laudney): Replace on_gfx9()/on_mi3xx() FP8 gates with
  supports_fp8(), unblocking FP8 on RDNA4/gfx12
- vllm-project#34709 (laudney): Enable wvSplitK/wvSplitKQ skinny GEMM kernels
  for RDNA4 decode (~15% improvement), wave32 DPP reduction
- vllm-project#34741 (laudney): FP8 KV-cache for RDNA4 custom paged attention
  via software dequantization
- vllm-project#36659 (vllmellm): Tuned FP8 MoE Triton configs for AMD Radeon
  AI PRO R9700, AITER mha_v3 attention on gfx12x
L.B.R. added 2 commits March 27, 2026 13:44
…ed attention

Add FP8 KV-cache support for the gfx12 (RDNA4) custom paged attention
kernel via software dequantization, and relax several constraints that
were unnecessarily restrictive:

Kernel changes (attention.cu):
- Add convert_b8x8_to_b16x8() for portable FP8->FP16/BF16 dequant
- Wire FP8 dequant path into gfx12 QKV kernel using per-block
  KV cache scale factors
- Add GQA ratio 1-2 support for gfx12 (was gqa_ratio >= 3)

Platform guard changes (rocm.py):
- Add _ON_GFX12 flag to distinguish gfx12 from gfx11
- Allow block_size=32 on gfx12 (VBLOCKS_PER_LANE=1 is correct)
- Restrict Navi kernel to head_size=128 (kernel assumes 128-wide heads)
- Accept kv_cache_dtype fp8/fp8_e4m3 on gfx12

Test changes (test_attention.py):
- Allow FP8 KV-cache tests on supports_fp8() platforms instead of
  blanket-skipping all FP8 on Navi

Signed-off-by: L.B.R. <lbr@mmonad.com>
…code

The gfx12 WMMA16x16x16 wave32 QK dot-product expects 16 contiguous
head elements across both rows (row 0: lower 8, row 1: upper 8).
For FP8, each row loads 16 values covering non-overlapping ranges
(row 0: base+[0..15], row 1: base+[16..31]). Splitting by byte
position into xy[0]/xy[1] created a cross-row mismatch where the
wrong head dimensions were multiplied together, producing incorrect
attention scores and degenerate model output (loops, nonsense).

Fix by exchanging inner halves between rows via __shfl_xor(val, 16)
before the WMMA calls, so each iteration covers a contiguous 16
head-element range that aligns with Q's layout.

Non-FP8 path is unaffected (compile-time constexpr branch).
V*logits WMMA is unaffected (rows access different token blocks,
not different head elements).

Signed-off-by: L.B.R. <lbr@mmonad.com>
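The cross-row exchange in that commit can be modeled in a few lines of Python (a toy model of one wave32 lane pair; `shfl_xor16` stands in for HIP's `__shfl_xor(val, 16)`, and head-element indices stand in for the loaded FP8 values):

```python
# Toy model of the cross-row half exchange: lane L (row 0) and lane L^16
# (row 1) each hold two 8-element halves of the K/V head data. Before the
# fix, each row covers a non-overlapping 16-element range; after swapping
# the inner halves across rows, each WMMA iteration sees one contiguous
# 16-element range (row 0: lower 8, row 1: upper 8), matching Q's layout.

def shfl_xor16(values: dict[int, list[int]], lane: int) -> list[int]:
    """Model __shfl_xor(val, 16): read the partner lane's value."""
    return values[lane ^ 16]

lane0, lane16 = 0, 16
halves = {
    lane0:  [list(range(0, 8)),   list(range(8, 16))],   # row 0: 0..15
    lane16: [list(range(16, 24)), list(range(24, 32))],  # row 1: 16..31
}

# Exchange the inner halves (row 0's upper 8 <-> row 1's lower 8).
inner = {lane0: halves[lane0][1], lane16: halves[lane16][0]}
halves[lane0][1] = shfl_xor16(inner, lane0)
halves[lane16][0] = shfl_xor16(inner, lane16)

# Each WMMA iteration now covers one contiguous head-element range:
iter0 = halves[lane0][0] + halves[lane16][0]   # elements 0..15
iter1 = halves[lane0][1] + halves[lane16][1]   # elements 16..31
print(iter0, iter1)
```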
laudney force-pushed the feat/rocm-rdna4-custom-attn-fp8kv branch from e9efafe to 8c71262 on March 27, 2026 13:45
…ions to gfx12 only

- gqa_ratio >= 1 now only applies on gfx12; gfx11 retains >= 3
- Test skip updated to allow block_size=32 on gfx12
- Added NOTE about k_scale/v_scale test limitation

Co-authored-by: Claude

Signed-off-by: L.B.R. <lbr@mmonad.com>
Contributor Author

laudney commented Mar 27, 2026

Rebased onto latest main and addressed review feedback:

Conflict resolution:

  • rocm.py had a conflict in the navi else branch — resolved cleanly (upstream removed VLLM_ROCM_CUSTOM_PAGED_ATTN env var, so we dropped that line)

Fixes for review comments (new commit):

  1. @tjtanaa's concern about gfx11 behavior: gqa_ratio >= 1 was incorrectly applied to all gfx1x. Now scoped: gqa_min = 1 if _ON_GFX12 else 3, so gfx11 retains >= 3.
  2. block_size=32 test coverage — test skip now allows block_size=32 on gfx12 instead of unconditionally requiring 16.
  3. k_scale/v_scale at 1.0 — kept as-is because the reference computation doesn't apply dequant scales. Added a NOTE comment documenting this limitation.

Pre-commit hooks all pass. All changes reviewed by GPT-5.4 (Codex) — no blockers found.


Labels

rocm Related to AMD ROCm

Projects

Status: Todo
