From aae29e613a636869fec5e47891f79fe58fbc26dd Mon Sep 17 00:00:00 2001 From: "L.B.R." Date: Tue, 17 Feb 2026 19:53:22 +0000 Subject: [PATCH 1/3] [ROCm] Enable FP8 KV-cache and relax constraints for RDNA4 custom paged attention Add FP8 KV-cache support for the gfx12 (RDNA4) custom paged attention kernel via software dequantization, and relax several constraints that were unnecessarily restrictive: Kernel changes (attention.cu): - Add convert_b8x8_to_b16x8() for portable FP8->FP16/BF16 dequant - Wire FP8 dequant path into gfx12 QKV kernel using per-block KV cache scale factors - Add GQA ratio 1-2 support for gfx12 (was gqa_ratio >= 3) Platform guard changes (rocm.py): - Add _ON_GFX12 flag to distinguish gfx12 from gfx11 - Allow block_size=32 on gfx12 (VBLOCKS_PER_LANE=1 is correct) - Restrict Navi kernel to head_size=128 (kernel assumes 128-wide heads) - Accept kv_cache_dtype fp8/fp8_e4m3 on gfx12 Test changes (test_attention.py): - Allow FP8 KV-cache tests on supports_fp8() platforms instead of blanket-skipping all FP8 on Navi Signed-off-by: L.B.R. --- csrc/rocm/attention.cu | 142 ++++++++++++++++++---- tests/kernels/attention/test_attention.py | 13 +- vllm/platforms/rocm.py | 10 +- 3 files changed, 130 insertions(+), 35 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index a339c5641bb4..b65ddbe7567e 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -2410,6 +2410,45 @@ typedef struct _B8x16 { _B8x8 xy[2]; } _B8x16; +// Convert 8 FP8 (e4m3) values to 8 f16/bf16 values for software dequant. +// Uses HIP runtime FP8 conversion (portable across gfx11/gfx12). +template +__device__ __forceinline__ _B16x8 convert_b8x8_to_b16x8(const _B8x8 input) { + _B16x8 ret; + if constexpr (std::is_same::value) { + union { + uint4 u32x4; + _B16x8 b16x8; + } cvt; + cvt.u32x4 = vllm::fp8::vec_conversion(input); + ret = cvt.b16x8; + } else if constexpr (std::is_same::value) { + // Reuse vector fp8->half conversion, then convert 4 half2 packs to bf16x2. + union { + uint4 u32x4; + uint32_t u32[4]; + } f16x8; + f16x8.u32x4 = vllm::fp8::vec_conversion(input); + union { + __hip_bfloat162 bf16x2[4]; + _B16x8 b16x8; + } cvt; + #pragma unroll + for (int i = 0; i < 4; i++) { + union { + uint32_t u32; + __half2 h2; + } half2_u; + half2_u.u32 = f16x8.u32[i]; + cvt.bf16x2[i] = __float22bfloat162_rn(__half22float2(half2_u.h2)); + } + ret = cvt.b16x8; + } else { + static_assert(false, "unsupported 16b dtype"); + } + return ret; +} + template __device__ __forceinline__ floatx8 gcn_wmma16x16x16_instr(const bit16x8& inpA, const bit16x8& inpB, @@ -2566,8 +2605,12 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 2xQKHE_16B across // warp - _B16x8 Qlocal[QKHELOOP]; // note that 16 contiguous elements of Q should - // be fetched per lane for 16 bit cache types + // Q loading always uses scalar_t-based constants (independent of cache type) + constexpr int Q_ELEMS_16B = 16 / sizeof(scalar_t); // always 8 for f16/bf16 + constexpr int Q_PER_FETCH = Q_ELEMS_16B * ROWS_PER_WARP; // always 16 + constexpr int QHELOOP = HEAD_SIZE / Q_PER_FETCH; // always HEAD_SIZE/16 + + _B16x8 Qlocal[QHELOOP]; constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); @@ -2598,12 +2641,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( const int local_qhead_idx = lane16id % GQA_RATIO; const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; const scalar_t* q_ptr = q + query_start_off * q_stride + - global_qhead_idx * HEAD_SIZE + - rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + global_qhead_idx * HEAD_SIZE + rowid * Q_ELEMS_16B; if (lane16id < GQA_RATIO) { #pragma unroll - for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { - const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * QKHE_PER_FETCH; + for (int qkhe_depth = 0; qkhe_depth < QHELOOP; qkhe_depth++) { + const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * Q_PER_FETCH; const _B16x8* q_fetch_ptr_16B = reinterpret_cast(q_fetch_ptr); Qlocal[qkhe_depth] = *q_fetch_ptr_16B; @@ -2632,7 +2674,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( __syncthreads(); #pragma unroll - for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + for (int qkhe_depth = 0; qkhe_depth < QHELOOP; qkhe_depth++) { Qlocal[qkhe_depth] = shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO][0]; } @@ -2739,16 +2781,36 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( } } + // calculate post qk wmma scale + float scale2 = scale; + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + scale2 *= *k_scale; + } + floatx8 dout[TLOOP]; // qk wmma for (int token_depth = 0; token_depth < TLOOP; token_depth++) { dout[token_depth] = {0}; for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { - dout[token_depth] = gcn_wmma16x16x16_instr( - Klocal[token_depth][qkhe_depth].u16x8, Qlocal[qkhe_depth].u16x8, - dout[token_depth]); + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + dout[token_depth] = gcn_wmma16x16x16_instr( + Klocal[token_depth][qkhe_depth].u16x8, Qlocal[qkhe_depth].u16x8, + dout[token_depth]); + } else { + // FP8 KV cache: each _B16x8 contains 16 FP8 values (16 bytes). + // Split into two _B8x8 (8 FP8 each), convert to _B16x8 (8 f16/bf16), + // and do two WMMA calls. QKHELOOP is halved for FP8, so total WMMA + // iterations remain the same. + auto Ktmp = Klocal[token_depth][qkhe_depth]; + _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp); + for (int j = 0; j < 2; j++) { + _B16x8 Kconv = convert_b8x8_to_b16x8(Ktmp8x16.xy[j]); + dout[token_depth] = gcn_wmma16x16x16_instr( + Kconv.u16x8, Qlocal[qkhe_depth * 2 + j].u16x8, dout[token_depth]); + } + } } - dout[token_depth] *= scale; + dout[token_depth] *= scale2; } // calculate qk_max and exp_sum per warp and write to shared memory @@ -2833,22 +2895,50 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( _B16x8 outelems[VHELOOP]; // Softmax V wmma // v layout: 16he across lanes x 16 tokens per lane + // VTLANELOOP_F16 is the number of 8-f16-element chunks per V fetch group. + // For f16 cache: VTLANELOOP=2, matching directly. + // For FP8 cache: VTLANELOOP=1 (one 16-byte fetch = 16 FP8 values), + // which we split into 2 halves of 8, so effective loop count is 2. + constexpr int VTLANELOOP_F16 = DIVIDE_ROUND_UP( + VTOKENS_PER_LANE, Q_ELEMS_16B); // always 2 for 16 vtokens / 8 per f16 for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { floatx8 tmp_out = {0}; for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { - for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { - const int offset = rowid * VTLANELOOP + vfetch_depth; - const int offset1 = offset % ROWS_PER_WARP; - const int offset2 = offset / ROWS_PER_WARP; - // if output format is 16 qheads across 16 lanes, 16 head elems spread - // across rows - tmp_out = gcn_wmma16x16x16_instr( - Vlocal[vtoken_depth][vhe_depth][vfetch_depth].u16x8, - shared_logits[vtoken_depth][offset2][lane16id][offset1].u16x8, - tmp_out); + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int offset = rowid * VTLANELOOP + vfetch_depth; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + tmp_out = gcn_wmma16x16x16_instr( + Vlocal[vtoken_depth][vhe_depth][vfetch_depth].u16x8, + shared_logits[vtoken_depth][offset2][lane16id][offset1].u16x8, + tmp_out); + } + } else { + // FP8 KV cache: each Vlocal entry has 16 FP8 values (16 bytes). + // Split into two _B8x8, convert each to _B16x8, and do 2 WMMA calls. + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + auto Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth]; + _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); + for (int j = 0; j < 2; j++) { + _B16x8 Vconv = convert_b8x8_to_b16x8(Vtmp8x16.xy[j]); + const int vf_idx = vfetch_depth * 2 + j; + const int offset = rowid * VTLANELOOP_F16 + vf_idx; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + tmp_out = gcn_wmma16x16x16_instr( + Vconv.u16x8, + shared_logits[vtoken_depth][offset2][lane16id][offset1].u16x8, + tmp_out); + } + } } } + // apply post Softmax V wmma v_scale + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + tmp_out *= *v_scale; + } outelems[vhe_depth] = from_floatx8(tmp_out); } @@ -3391,7 +3481,7 @@ void paged_attention_custom_launcher_navi( torch::Tensor& block_tables, torch::Tensor& seq_lens, const std::optional& query_start_loc, int max_seq_len, const std::optional& alibi_slopes, torch::Tensor& k_scale, - torch::Tensor& v_scale) { + torch::Tensor& v_scale, const std::optional& fp8_out_scale) { int num_seqs = block_tables.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -3421,8 +3511,10 @@ void paged_attention_custom_launcher_navi( const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); - // NOTE: Navi does not support fp8. - const auto fp8_out_scale_ptr = nullptr; + const auto fp8_out_scale_ptr = + fp8_out_scale + ? reinterpret_cast(fp8_out_scale.value().data_ptr()) + : nullptr; OUTT* out_ptr = reinterpret_cast(out.data_ptr()); const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE); @@ -3568,7 +3660,7 @@ void paged_attention_custom_launcher_navi( ALIBI_ENABLED, MFMA_TYPE>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \ - max_seq_len, alibi_slopes, k_scale, v_scale); \ + max_seq_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \ } #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index 9ddceef8fb38..a94271903268 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -141,14 +141,13 @@ def test_paged_attention( ): pytest.skip() - if ( - version == "rocm" - and current_platform.is_navi() - and ( - kv_cache_dtype == "fp8" or head_size != 128 or block_size != 16 or use_alibi + if version == "rocm" and current_platform.is_navi(): + # gfx12 (RDNA4) supports FP8 KV cache via software dequant + fp8_unsupported = ( + kv_cache_dtype == "fp8" and not current_platform.supports_fp8() ) - ): - pytest.skip() + if fp8_unsupported or head_size != 128 or block_size != 16 or use_alibi: + pytest.skip() global PARTITION_SIZE diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 2f76aedbf27c..a7c09650a278 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -146,6 +146,7 @@ def _get_gcn_arch() -> str: _GCN_ARCH = _get_gcn_arch() _ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11", "gfx12"]) +_ON_GFX12 = "gfx12" in _GCN_ARCH _ON_MI3XX = any(arch in _GCN_ARCH for arch in ["gfx942", "gfx950"]) _ON_GFX9 = any(arch in _GCN_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) _ON_GFX942 = "gfx942" in _GCN_ARCH @@ -269,16 +270,19 @@ def use_rocm_custom_paged_attention( ) else: + # gfx12 (RDNA4) supports FP8 KV cache via software dequant + fp8_ok = kv_cache_dtype in ("fp8", "fp8_e4m3") and _ON_GFX12 + block_size_ok = block_size == 16 or (_ON_GFX12 and block_size == 32) return ( _ON_GFX1X and (sliding_window == 0 or sliding_window == (-1, -1)) and (qtype == torch.half or qtype == torch.bfloat16) and head_size == 128 - and block_size == 16 - and (gqa_ratio >= 3 and gqa_ratio <= 16) + and block_size_ok + and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 128 * 1024 and alibi_slopes is None - and kv_cache_dtype == "auto" + and (kv_cache_dtype == "auto" or fp8_ok) and sinks is None ) From 8c712626ec903ad1d5aaf85935e6cfb9a8f296ba Mon Sep 17 00:00:00 2001 From: "L.B.R." Date: Sun, 22 Feb 2026 08:45:09 +0000 Subject: [PATCH 2/3] Fix gfx12 FP8 QK WMMA head-element misalignment in paged attention decode The gfx12 WMMA16x16x16 wave32 QK dot-product expects 16 contiguous head elements across both rows (row 0: lower 8, row 1: upper 8). For FP8, each row loads 16 values covering non-overlapping ranges (row 0: base+[0..15], row 1: base+[16..31]). Splitting by byte position into xy[0]/xy[1] created a cross-row mismatch where the wrong head dimensions were multiplied together, producing incorrect attention scores and degenerate model output (loops, nonsense). Fix by exchanging inner halves between rows via __shfl_xor(val, 16) before the WMMA calls, so each iteration covers a contiguous 16 head-element range that aligns with Q's layout. Non-FP8 path is unaffected (compile-time constexpr branch). V*logits WMMA is unaffected (rows access different token blocks, not different head elements). Signed-off-by: L.B.R. --- csrc/rocm/attention.cu | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index b65ddbe7567e..e6b035eb6783 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -2797,14 +2797,34 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel( Klocal[token_depth][qkhe_depth].u16x8, Qlocal[qkhe_depth].u16x8, dout[token_depth]); } else { - // FP8 KV cache: each _B16x8 contains 16 FP8 values (16 bytes). - // Split into two _B8x8 (8 FP8 each), convert to _B16x8 (8 f16/bf16), - // and do two WMMA calls. QKHELOOP is halved for FP8, so total WMMA - // iterations remain the same. + // FP8 KV cache: each row loads 16 FP8 K values covering different + // head-element ranges (row 0: base+[0..15], row 1: base+[16..31]). + // But Q is loaded with 16 contiguous head elements per WMMA split + // across both rows (row 0: lower 8, row 1: upper 8). + // Splitting the 16 FP8 bytes into xy[0]/xy[1] by byte position + // creates a cross-row mismatch: + // j=0: row0=[0..7] OK, row1=[16..23] WRONG (Q expects [8..15]) + // j=1: row0=[8..15] WRONG (Q expects [16..23]), row1=[24..31] OK + // Fix: exchange inner halves between rows via cross-row shuffle. auto Ktmp = Klocal[token_depth][qkhe_depth]; _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp); + + // Row 0 sends xy[1] (he 8..15), row 1 sends xy[0] (he 16..23). + // After shfl_xor: row 0 gets he 16..23, row 1 gets he 8..15. + const _B8x8 inner = Ktmp8x16.xy[1 - rowid]; + _B8x8 cross; + cross.x = __shfl_xor(inner.x, 16); + cross.y = __shfl_xor(inner.y, 16); + + #pragma unroll for (int j = 0; j < 2; j++) { - _B16x8 Kconv = convert_b8x8_to_b16x8(Ktmp8x16.xy[j]); + // j==rowid: use own outer half; j!=rowid: use cross-row data. + // j=0: row0=xy[0](he 0..7), row1=cross(he 8..15) -> he [0..15] + // j=1: row0=cross(he 16..23), row1=xy[1](he 24..31) -> he [16..31] + _B8x8 Kfp8; + Kfp8.x = (j == rowid) ? Ktmp8x16.xy[j].x : cross.x; + Kfp8.y = (j == rowid) ? Ktmp8x16.xy[j].y : cross.y; + _B16x8 Kconv = convert_b8x8_to_b16x8(Kfp8); dout[token_depth] = gcn_wmma16x16x16_instr( Kconv.u16x8, Qlocal[qkhe_depth * 2 + j].u16x8, dout[token_depth]); } From 61d467f57b09b6024af88ad49b82271910f3756c Mon Sep 17 00:00:00 2001 From: "L.B.R." Date: Fri, 27 Mar 2026 14:45:05 +0000 Subject: [PATCH 3/3] Fix gfx11 behavior regression: scope gqa_ratio and block_size relaxations to gfx12 only - gqa_ratio >= 1 now only applies on gfx12; gfx11 retains >= 3 - Test skip updated to allow block_size=32 on gfx12 - Added NOTE about k_scale/v_scale test limitation Co-authored-by: Claude Signed-off-by: L.B.R. --- tests/kernels/attention/test_attention.py | 13 ++++++++----- vllm/platforms/rocm.py | 3 ++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index a94271903268..ab29b81bacb3 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -142,11 +142,12 @@ def test_paged_attention( pytest.skip() if version == "rocm" and current_platform.is_navi(): - # gfx12 (RDNA4) supports FP8 KV cache via software dequant - fp8_unsupported = ( - kv_cache_dtype == "fp8" and not current_platform.supports_fp8() - ) - if fp8_unsupported or head_size != 128 or block_size != 16 or use_alibi: + # gfx12 (RDNA4) supports FP8 KV cache via software dequant; + # within is_navi(), supports_fp8() implies gfx12. + is_gfx12 = current_platform.supports_fp8() + fp8_unsupported = kv_cache_dtype == "fp8" and not is_gfx12 + block_size_ok = block_size == 16 or (is_gfx12 and block_size == 32) + if fp8_unsupported or head_size != 128 or not block_size_ok or use_alibi: pytest.skip() global PARTITION_SIZE @@ -195,6 +196,8 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale + # NOTE: non-trivial k_scale/v_scale would exercise FP8 dequant paths but + # the reference computation does not apply scales, so keep at 1.0 for now. k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Call the paged attention kernel. diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index a7c09650a278..85f79ab404e4 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -273,13 +273,14 @@ def use_rocm_custom_paged_attention( # gfx12 (RDNA4) supports FP8 KV cache via software dequant fp8_ok = kv_cache_dtype in ("fp8", "fp8_e4m3") and _ON_GFX12 block_size_ok = block_size == 16 or (_ON_GFX12 and block_size == 32) + gqa_min = 1 if _ON_GFX12 else 3 return ( _ON_GFX1X and (sliding_window == 0 or sliding_window == (-1, -1)) and (qtype == torch.half or qtype == torch.bfloat16) and head_size == 128 and block_size_ok - and (gqa_ratio >= 1 and gqa_ratio <= 16) + and (gqa_ratio >= gqa_min and gqa_ratio <= 16) and max_seq_len <= 128 * 1024 and alibi_slopes is None and (kv_cache_dtype == "auto" or fp8_ok)