From 54084ea526770c2f0beb4ecf28f299c379766a92 Mon Sep 17 00:00:00 2001 From: ganyi Date: Fri, 1 May 2026 06:31:16 +0000 Subject: [PATCH 1/9] [ROCm][DSv4] MI300X (gfx942) support for DeepSeek V4 DSv4 on AMD MI300X (gfx942) hits several FP8-related issues that this commit addresses: 1. **fp8_utils.py**: ``process_fp8_weight_block_strategy`` calls ``normalize_e4m3fn_to_e4m3fnuz`` which doubles ``weight_scale`` by ``weight_scale * 2.0``. On models with UE8M0 scales (``torch.float8_e8m0fnu``), that ``mul`` is not implemented on CUDA/HIP and load aborts with:: NotImplementedError: "mul_cuda" not implemented for 'Float8_e8m0fnu' UE8M0 stores power-of-2 exponent values (2^(exp-127)) with no mantissa, so doubling the scale is equivalent to incrementing the exponent byte by 1. Handle the UE8M0 case explicitly and fall back to the float path otherwise. 2. **fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu**: gate ``kFp8Max`` to match the FNUZ/OCP path actually taken on each ROCm arch (240 on gfx942 FNUZ, 448 on gfx950 OCP). 3. **deepseek_v4_attention.py** + **cache_utils.py**: small MI300X path fixes that go with the FNUZ scale handling above. Co-authored-by: ganyi Signed-off-by: ganyi Signed-off-by: Markus Hartikainen Co-authored-by: Cursor --- ...deepseek_v4_qnorm_rope_kv_insert_kernel.cu | 4 +++ .../layers/quantization/utils/fp8_utils.py | 25 ++++++++++++++++--- .../deepseek_v4/common/ops/cache_utils.py | 10 ++++++-- 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index 589e5f7bac04..6108f1ba189b 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -85,7 +85,11 @@ constexpr int kQuantBlock = 64; constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7 constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad) constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576 +#if defined(USE_ROCM) && !defined(HIP_FP8_TYPE_OCP) +constexpr float kFp8Max = 240.0f; +#else constexpr float kFp8Max = 448.0f; +#endif #ifndef USE_ROCM // When num_tokens is less than this threshold, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 8b20c13a97f9..abac0726dea3 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -1310,9 +1310,28 @@ def process_fp8_weight_block_strategy( ) if current_platform.is_fp8_fnuz() and weight.dtype == torch.float8_e4m3fn: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=weight, weight_scale=weight_scale - ) + if weight_scale.dtype == torch.float8_e8m0fnu: + # UE8M0 scales: e8m0 stores exponent-only values (2^(exp-127)), + # so doubling the dequant scale == incrementing the exponent byte + # by 1. Convert the OCP E4M3 weight bytes to FNUZ in place by + # reinterpreting and patching the NaN sentinel (-128 in int8), + # then double the UE8M0 exponent so the dequantized magnitudes + # match. + weight_as_int8 = weight.view(torch.int8) + ROCM_FP8_NAN_AS_INT = -128 + weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0 + weight = weight_as_int8.view(torch.float8_e4m3fnuz) + exp_bytes = weight_scale.view(torch.uint8) + weight_scale = ( + (exp_bytes.to(torch.int16) + 1) + .clamp(max=254) + .to(torch.uint8) + .view(torch.float8_e8m0fnu) + ) + else: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale + ) weight = _maybe_pad_fp8_weight(weight) return weight, weight_scale diff --git a/vllm/models/deepseek_v4/common/ops/cache_utils.py b/vllm/models/deepseek_v4/common/ops/cache_utils.py index ac66751e3111..13baeae62456 100644 --- a/vllm/models/deepseek_v4/common/ops/cache_utils.py +++ b/vllm/models/deepseek_v4/common/ops/cache_utils.py @@ -216,6 +216,7 @@ def _dequantize_and_gather_k_kernel( output_dim: tl.constexpr, # 512 fp8_max: tl.constexpr, n_quant_blocks: tl.constexpr, # 7 real blocks + use_fnuz: tl.constexpr = False, ): batch_idx = tl.program_id(0) worker_id = tl.program_id(1) @@ -273,8 +274,11 @@ def _dequantize_and_gather_k_kernel( # Load quantized fp8 values (stored as uint8) x_uint8 = tl.load(token_fp8_ptr + offsets, mask=mask, other=0) - # Bitcast uint8 back to fp8 - x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + # Bitcast uint8 back to fp8 (FNUZ on gfx942, OCP otherwise) + if use_fnuz: + x_fp8 = x_uint8.to(tl.float8e4b8, bitcast=True) + else: + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) # Convert fp8 to float32 for computation x_float = x_fp8.to(tl.float32) @@ -317,6 +321,7 @@ def dequantize_and_gather_k_cache_triton( block_table: torch.Tensor, block_size: int, offset: int, + use_fnuz: bool = False, ) -> None: TOKEN_FP8_DIM = 448 TOKEN_BF16_DIM = 64 @@ -347,6 +352,7 @@ def dequantize_and_gather_k_cache_triton( output_dim=512, fp8_max=FP8_MAX, n_quant_blocks=7, + use_fnuz=use_fnuz, ) From 69146983af6da94400eca20c1af435fc13650e56 Mon Sep 17 00:00:00 2001 From: ganyi Date: Fri, 1 May 2026 07:14:20 +0000 Subject: [PATCH 2/9] [ROCm][DSv4] Use FNUZ FP8 on gfx942 in fused KV insert kernel ``fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu`` selected its FP8 type and ``kFp8Max`` based purely on the HIP build macro ``HIP_FP8_TYPE_OCP``. That macro is set by the HIP runtime version, not by the target GPU arch -- on a HIP build that defines ``HIP_FP8_TYPE_OCP``, the kernel was using OCP E4M3 / ``448.0`` even on gfx942 (MI300X), whose MFMA instructions only accept FNUZ E4M3. The rest of vLLM's gfx942 path (Triton sparse-MLA, indexer Q quant, ``current_platform.fp8_dtype()``) all use FNUZ on this arch, so the C++ writer was producing K-cache entries the FNUZ readers misinterpret. Gate the OCP branch on ``defined(__gfx950__)`` so: * gfx942 (MI300X) -> ``__hip_fp8_e4m3_fnuz`` + ``kFp8Max = 240.0f`` * gfx950 (MI355X) -> ``__hip_fp8_e4m3`` + ``kFp8Max = 448.0f`` This matches the encoding chosen elsewhere on each arch. Signed-off-by: ganyi Signed-off-by: Markus Hartikainen Co-authored-by: Cursor --- csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index 6108f1ba189b..b90a725a90cb 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -56,7 +56,11 @@ #ifdef USE_ROCM // ROCm-compatible FP8 conversion helpers __device__ __forceinline__ uint8_t rocm_cvt_float_to_fp8_e4m3(float val) { - #if defined(HIP_FP8_TYPE_OCP) + // HIP defines HIP_FP8_TYPE_OCP based on HIP version, not GPU arch. On gfx942 + // mfma only supports FNUZ fp8, and the rest of vLLM's gfx942 path (Triton + // indexer / current_platform.fp8_dtype()) uses FNUZ. Gate OCP on __gfx950__ + // so the K cache encoding matches what the reader expects. + #if defined(HIP_FP8_TYPE_OCP) && defined(__gfx950__) __hip_fp8_e4m3 fp8_val(val); #else __hip_fp8_e4m3_fnuz fp8_val(val); @@ -85,7 +89,9 @@ constexpr int kQuantBlock = 64; constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7 constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad) constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576 -#if defined(USE_ROCM) && !defined(HIP_FP8_TYPE_OCP) +// Match the encoding chosen in rocm_cvt_float_to_fp8_e4m3: FNUZ on gfx942 +// (max 240), OCP on gfx950 (max 448). +#if defined(USE_ROCM) && (!defined(HIP_FP8_TYPE_OCP) || !defined(__gfx950__)) constexpr float kFp8Max = 240.0f; #else constexpr float kFp8Max = 448.0f; From 3b4a8eb74307c88240131d0a1a6d6e6aba7fd262 Mon Sep 17 00:00:00 2001 From: Markus Hartikainen Date: Sat, 16 May 2026 23:10:08 +0300 Subject: [PATCH 3/9] [ROCm][DSv4] Use tl.float8e4b8 for FNUZ on MI300X sparse MLA kernels The DSv4 sparse MLA Triton kernels added in #41812 (and the matching turboquant store/decode kernels) bitcast uint8 to ``tl.float8e4b15`` when ``IS_FNUZ`` is true. ``float8e4b15`` is not a real Triton type; on AMD gfx942 (MI300X) Triton only supports the FP8 dtypes listed in the error from triton/compiler: ('fp8e4b8', 'fp8e4nv', 'fp8e5', 'fp8e5b16') The correct FNUZ E4M3 type is ``tl.float8e4b8`` (bias 8, matches the PyTorch ``torch.float8_e4m3fnuz`` used elsewhere on the MI300 path). The non-FNUZ branch already correctly uses ``tl.float8e4nv``. Without this fix, the very first profile run on MI300X with sparse MLA fails inside the dequant/gather kernel: type fp8e4b15 not supported in this architecture. This swaps all FNUZ branches to ``tl.float8e4b8``. Verified that ``IS_FNUZ`` is gated on ``current_platform.fp8_dtype() == torch.float8_e4m3fnuz`` so it never fires on OCP hardware. Signed-off-by: Markus Hartikainen --- vllm/v1/attention/ops/rocm_aiter_mla_sparse.py | 4 ++-- vllm/v1/attention/ops/triton_turboquant_decode.py | 4 ++-- vllm/v1/attention/ops/triton_turboquant_store.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 80731296fcf0..56ebca36ef3e 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -1225,7 +1225,7 @@ def _sparse_attn_decode_ragged_kernel( other=0, ) if IS_FNUZ: - x_fp8 = x_uint8.to(tl.float8e4b15, bitcast=True) + x_fp8 = x_uint8.to(tl.float8e4b8, bitcast=True) else: x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) encoded_scales = tl.load( @@ -1293,7 +1293,7 @@ def _sparse_attn_decode_ragged_kernel( other=0, ) if IS_FNUZ: - x_fp8 = x_uint8.to(tl.float8e4b15, bitcast=True) + x_fp8 = x_uint8.to(tl.float8e4b8, bitcast=True) else: x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) encoded_scales = tl.load( diff --git a/vllm/v1/attention/ops/triton_turboquant_decode.py b/vllm/v1/attention/ops/triton_turboquant_decode.py index 3adaf2610d8d..ceaaa2f920c0 100644 --- a/vllm/v1/attention/ops/triton_turboquant_decode.py +++ b/vllm/v1/attention/ops/triton_turboquant_decode.py @@ -162,7 +162,7 @@ def _tq_decode_stage1( other=0, ) if FP8_E4B15: - k_float = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32) + k_float = k_raw.to(tl.float8e4b8, bitcast=True).to(tl.float32) else: k_float = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) scores = ( @@ -371,7 +371,7 @@ def _tq_full_dequant_kv( if KEY_FP8: k_raw = tl.load(KV_cache_ptr + slot_base + d_offs, mask=d_mask, other=0) if FP8_E4B15: - k_recon = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32) + k_recon = k_raw.to(tl.float8e4b8, bitcast=True).to(tl.float32) else: k_recon = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) tl.store(K_out_ptr + ko_base + d_offs, k_recon.to(tl.float16), mask=d_mask) diff --git a/vllm/v1/attention/ops/triton_turboquant_store.py b/vllm/v1/attention/ops/triton_turboquant_store.py index 3ad2d41488e7..ea1b1d7a9e42 100644 --- a/vllm/v1/attention/ops/triton_turboquant_store.py +++ b/vllm/v1/attention/ops/triton_turboquant_store.py @@ -189,7 +189,7 @@ def _tq_fused_store_fp8( d_offs = tl.arange(0, BLOCK_D) d_mask = d_offs < D k_vals = tl.load(Key_ptr + base + d_offs, mask=d_mask, other=0.0) - k_fp8 = k_vals.to(tl.float8e4b15) if FP8_E4B15 else k_vals.to(tl.float8e4nv) + k_fp8 = k_vals.to(tl.float8e4b8) if FP8_E4B15 else k_vals.to(tl.float8e4nv) k_bytes = k_fp8.to(tl.uint8, bitcast=True) tl.store(KV_cache_ptr + slot_base + d_offs, k_bytes, mask=d_mask) From b65601efc30954a4a1830e430c658c00960cec71 Mon Sep 17 00:00:00 2001 From: Markus Hartikainen Date: Sun, 17 May 2026 12:53:46 +0300 Subject: [PATCH 4/9] [ROCm][DSv4] Zero ROCm sparse-MLA prefill KV workspace ``DeepseekV4ROCMAiterMLASparseImpl._forward_prefill_attn_impl`` in ``vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse_dsv4.py`` is the actual ROCm path reached from ``DeepseekV4MLAAttention.forward`` at ``deepseek_v4_attention.py:762`` (``current_platform.is_rocm()``). ``DeepseekV4MLAAttention._forward_prefill`` in the same file is dead code on ROCm, so the previous ``kv.zero_()`` patch (commit 36a70373a2) fixed only the generic path. This ROCm-only forward also gets ``kv`` via ``current_workspace_manager().get_simultaneous(...)`` -- uninitialized shared memory reused across requests and layers -- writes only the compressed-K prefix and the SWA window for each chunk row, then reads the entire ``kv.view(-1, 1, head_dim)`` through ragged indices that can land on the holes for very short sequences. The result is exactly the symptom we observe on MI300X DSv4-Flash: 10 identical temperature=0 ``/v1/completions`` calls produce 10 distinct first tokens. Apply the same zero-init here. Cost is one bf16 fill of the workspace tile, dwarfed by the FP8 dequant + sparse attention. Signed-off-by: Markus Hartikainen Co-authored-by: Cursor --- vllm/models/deepseek_v4/amd/rocm.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/models/deepseek_v4/amd/rocm.py b/vllm/models/deepseek_v4/amd/rocm.py index 24a58a51b54c..552da7741b11 100644 --- a/vllm/models/deepseek_v4/amd/rocm.py +++ b/vllm/models/deepseek_v4/amd/rocm.py @@ -789,6 +789,18 @@ def _forward_prefill( kv = workspace_manager.get_simultaneous( ((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), )[0] + # The workspace allocator returns uninitialized memory and is shared + # across requests + other layers. dequantize_and_gather_k_cache only + # writes the compressed-K prefix (rows [0, seq_len/compress_ratio)) + # and the SWA window (rows [N, N+gather_lens)) for each chunk row, + # leaving holes in the M dimension. rocm_sparse_attn_prefill then + # reads ``kv.view(-1, 1, head_dim)`` via ragged indices which can + # reach those holes for very short sequences, causing data-dependent + # non-determinism across otherwise-identical temperature=0 requests. + # Zero once per call so the holes are deterministic (zero attention + # contribution). The cost is one bf16 fill of the workspace tile, + # which is dwarfed by the FP8 dequant + sparse attention themselves. + kv.zero_() for chunk_idx in range(num_chunks): chunk_start = chunk_idx * cls.PREFILL_CHUNK_SIZE chunk_end = min(chunk_start + cls.PREFILL_CHUNK_SIZE, num_prefills) From f677c498a9526b0d99219f21e8a0fd2eac4087ae Mon Sep 17 00:00:00 2001 From: Jin Tao Date: Mon, 18 May 2026 15:15:18 +0000 Subject: [PATCH 5/9] [ROCm][DSv4] Propagate FNUZ vs OCP gating to ROCm prefill+decode paths PR #42893 fixed the C++ SWA-K-cache encoder so it writes FNUZ E4M3 bytes on gfx942 (and OCP on gfx950) and updated the *generic* ``DeepseekV4MLAAttention._forward_prefill`` to call ``dequantize_and_gather_k_cache(..., use_fnuz=is_fp8_fnuz())`` for SWA and ``use_fnuz=False`` for the Triton-OCP-encoded compressed K cache. Two FP8-format mismatches remained on the actual ROCm DSv4 path (``DeepseekV4ROCMAiterMLASparseImpl``): 1. The public ``dequantize_and_gather_k_cache`` wrapper in ``vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py`` did not accept ``use_fnuz`` -- it silently dropped the kwarg when forwarding to ``dequantize_and_gather_k_cache_triton`` (which defaults to False). The ROCm prefill called the wrapper without ``use_fnuz``, so the SWA K cache (FNUZ on gfx942) was being read as OCP, scaling every K vector by ~448/240 in prefill attention. 2. ``_sparse_attn_decode_ragged_kernel`` in ``vllm/v1/attention/ops/rocm_aiter_mla_sparse.py`` decoded both the SWA (FNUZ on gfx942) and the compressed (always OCP) K caches with a single ``IS_FNUZ`` constexpr, so on MI300X the compressed-side branch reinterpreted OCP bytes as FNUZ -- the same encoder/decoder mismatch as (1) in the opposite direction (~240/448) on the decode side. Together these scrambled K vectors going into both prefill and decode attention, producing the GSM8K=0.005 gibberish PR #42893 documented but could not explain with eager-vs-graphs. This commit: * Adds ``use_fnuz`` to the wrapper and forwards it to the Triton implementation (the cuteDSL path is dead on ROCm anyway). * Splits ``_sparse_attn_decode_ragged_kernel``'s ``IS_FNUZ`` into per-cache flags ``IS_FNUZ_MAIN`` (SWA) and ``IS_FNUZ_EXTRA`` (compressed) so each cache is decoded with its own encoder's format. * Wires ``DeepseekV4ROCMAiterMLASparseImpl._forward_prefill`` to pass ``use_fnuz=False`` for the compressed call (Triton-OCP encoder) and ``use_fnuz=current_platform.is_fp8_fnuz()`` for the SWA call (C++ FNUZ-on-gfx942 encoder), matching the asymmetry that PR #42893's "[ROCm][DSv4] Fix compressed K cache dequant to match Triton OCP encoder" introduced for the generic path. Validated on 1 node x 4 x MI300X (gfx942), TP=4, VLLM_ROCM_USE_AITER=1, ``deepseek-ai/DeepSeek-V4-Flash``, both eager and CUDA-graphs ``FULL_AND_PIECEWISE`` configs from PR #42810. GSM8K 5-shot, n=200, num_concurrent=32 against /v1/completions: | Mode | exact_match | Stderr | | ----- | ----------- | -------- | | Eager | 0.955 | +/-0.0147 | | Graph | 0.955 | +/-0.0147 | vs. the pre-fix 0.005 PR #42893 reported on the same configuration. The two modes match each other to all three reported digits on both strict-match and flexible-extract filters. Co-authored-by: Cursor Signed-off-by: Markus Hartikainen --- vllm/models/deepseek_v4/amd/rocm.py | 9 ++++++ .../deepseek_v4/common/ops/cache_utils.py | 30 ++++++++++++++++++- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 18 ++++++++--- 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/vllm/models/deepseek_v4/amd/rocm.py b/vllm/models/deepseek_v4/amd/rocm.py index 552da7741b11..104bfed66eff 100644 --- a/vllm/models/deepseek_v4/amd/rocm.py +++ b/vllm/models/deepseek_v4/amd/rocm.py @@ -12,6 +12,7 @@ DeepseekV4FlashMLASparseBackend, DeepseekV4SparseMLAAttentionImpl, ) +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import ( CommonAttentionMetadata, @@ -809,6 +810,12 @@ def _forward_prefill( assert attn_metadata is not None assert compressed_k_cache is not None block_table = attn_metadata.block_table[num_decodes:] + # The compressed-K encoder (Triton _fused_kv_compress_norm_... + # in fused_compress_quant_cache.py) writes bytes via + # tl.float8e4nv with FP8_MAX=448.0 regardless of platform. + # The SWA-side C++ encoder, by contrast, switches to FNUZ on + # gfx942 (PR #42893), so the two caches need different + # use_fnuz settings even on the same MI300X. dequantize_and_gather_k_cache( kv[:chunk_size], compressed_k_cache, @@ -817,6 +824,7 @@ def _forward_prefill( block_table=block_table[chunk_start:chunk_end], block_size=attn_metadata.block_size // layer.compress_ratio, offset=0, + use_fnuz=False, ) swa_block_table = swa_metadata.block_table[num_decodes:] @@ -828,6 +836,7 @@ def _forward_prefill( block_table=swa_block_table[chunk_start:chunk_end], block_size=swa_metadata.block_size, offset=N, + use_fnuz=current_platform.is_fp8_fnuz(), ) query_start = ( diff --git a/vllm/models/deepseek_v4/common/ops/cache_utils.py b/vllm/models/deepseek_v4/common/ops/cache_utils.py index 13baeae62456..29cec22c864b 100644 --- a/vllm/models/deepseek_v4/common/ops/cache_utils.py +++ b/vllm/models/deepseek_v4/common/ops/cache_utils.py @@ -369,7 +369,28 @@ def dequantize_and_gather_k_cache( block_table: torch.Tensor, block_size: int, offset: int, + use_fnuz: bool = False, ) -> None: + """Dequantize and gather a paged DSv4 K cache. + + ``use_fnuz`` selects which Triton FP8 dtype is used to bitcast the stored + bytes back to floats. It MUST match the encoder that wrote into this + particular cache: + + * ``compressed_k_cache``: Triton encoders in ``cache_utils.py`` / + ``fused_compress_quant_cache.py`` always use ``tl.float8e4nv`` (OCP, + bias 7, FP8_MAX=448) regardless of platform, so the reader must use + ``use_fnuz=False`` everywhere. + * ``swa_k_cache``: written by the C++ kernel + ``fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert``. After PR #42893 + that kernel encodes FNUZ E4M3 on gfx942 (MI300X, FP8_MAX=240) and OCP + on gfx950, so the reader must use + ``use_fnuz=current_platform.is_fp8_fnuz()``. + + Mismatching the two -- e.g. reading FNUZ bytes as OCP -- silently + rescales every K vector by ~448/240 ≈ 1.87 and produces gibberish in + sparse attention. + """ if has_cutedsl(): # lazily import, otherwise some tests fail due to CUDA driver init failure. from vllm.models.deepseek_v4.nvidia.ops.dequant_gather_k_cutedsl import ( @@ -382,7 +403,14 @@ def dequantize_and_gather_k_cache( return dequantize_and_gather_k_cache_triton( - out, k_cache, seq_lens, gather_lens, block_table, block_size, offset + out, + k_cache, + seq_lens, + gather_lens, + block_table, + block_size, + offset, + use_fnuz=use_fnuz, ) diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 56ebca36ef3e..ed5ff03c8706 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -1167,7 +1167,12 @@ def _sparse_attn_decode_ragged_kernel( NOPE_DIM: tl.constexpr, NOPE_BLOCK: tl.constexpr, ROPE_DIM: tl.constexpr, - IS_FNUZ: tl.constexpr, + # `main_cache` is the SWA K-cache (written by the C++ encoder, FNUZ on + # gfx942 / OCP on gfx950). `extra_cache` is the compressed K-cache + # (Triton encoder, OCP on every platform). Reading both with the same + # `IS_FNUZ` would mis-decode one of them by the FNUZ/OCP scale ratio. + IS_FNUZ_MAIN: tl.constexpr, + IS_FNUZ_EXTRA: tl.constexpr, BLOCK_H: tl.constexpr, BLOCK_K: tl.constexpr, ): @@ -1224,7 +1229,7 @@ def _sparse_attn_decode_ragged_kernel( mask=valid[:, None] & nope_mask[None, :], other=0, ) - if IS_FNUZ: + if IS_FNUZ_MAIN: x_fp8 = x_uint8.to(tl.float8e4b8, bitcast=True) else: x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) @@ -1292,7 +1297,7 @@ def _sparse_attn_decode_ragged_kernel( mask=valid[:, None] & nope_mask[None, :], other=0, ) - if IS_FNUZ: + if IS_FNUZ_EXTRA: x_fp8 = x_uint8.to(tl.float8e4b8, bitcast=True) else: x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) @@ -1570,7 +1575,12 @@ def _rocm_sparse_attn_decode_ragged_triton( NOPE_DIM=nope_head_dim, NOPE_BLOCK=triton.next_power_of_2(nope_head_dim), ROPE_DIM=rope_head_dim, - IS_FNUZ=current_platform.is_fp8_fnuz(), + # main_cache = swa_k_cache (C++ encoder, FNUZ on gfx942 / OCP on gfx950). + # extra_cache = compressed kv_cache (Triton encoder, OCP everywhere). + # Reading both with a single IS_FNUZ would mis-decode one of them by + # the FNUZ/OCP scale ratio (~1.87×). + IS_FNUZ_MAIN=current_platform.is_fp8_fnuz(), + IS_FNUZ_EXTRA=False, BLOCK_H=block_h, BLOCK_K=block_k, num_warps=8, From 54536ca5977ab33fda7c6e388929069bc98030df Mon Sep 17 00:00:00 2001 From: Markus Hartikainen Date: Mon, 18 May 2026 18:43:56 +0300 Subject: [PATCH 6/9] [ROCm][DSv4] Revert turboquant fp8e4b15 -> fp8e4b8 changes (NVIDIA-only path) PR review on #42893 (gemini-code-assist) flagged that the three turboquant changes in commit 2bef91e7 ("[ROCm][DSv4] Use tl.float8e4b8 for FNUZ on MI300X sparse MLA kernels") are dead code on MI300X: they sit inside ``if FP8_E4B15:`` branches, and FP8_E4B15 is the constexpr returned by ``_use_fp8_e4b15(device)`` -- which is 1 only when ``torch.cuda.get_device_capability() < (8, 9)``. MI300X (gfx942) reports cap >= (9, x), so FP8_E4B15 = 0 on every AMD platform and the patched FNUZ branch is never executed. More importantly, the changes are *wrong* on the hardware where FP8_E4B15 = 1 -- NVIDIA Ampere/Ada (sm < 8.9). On those cards ``tl.float8e4b15`` (E4M3 with bias 15) is the correct Triton FP8 type for software emulation; ``tl.float8e4b8`` (E4M3 with bias 8) is the AMD-FNUZ-specific type and Triton on NVIDIA Ampere/Ada will reject it with the same "type not supported in this architecture" error the original commit was trying to fix. The original commit message conflated two unrelated gating constexprs (``IS_FNUZ`` in rocm_aiter_mla_sparse.py vs ``FP8_E4B15`` in the turboquant kernels). Only the rocm_aiter_mla_sparse.py hunks of 2bef91e7 are actually correct -- those are gated on ``current_platform.fp8_dtype() == torch.float8_e4m3fnuz`` / ``current_platform.is_fp8_fnuz()`` and are the ones that actually fix the MI300X sparse-MLA decode failure. Revert just the three turboquant lines back to ``tl.float8e4b15`` so the NVIDIA Ampere/Ada FP8 path is preserved. The MI300X fix in ``_sparse_attn_decode_ragged_kernel`` (the dequant/gather kernel cited in the original commit message) is unchanged. Signed-off-by: Markus Hartikainen Co-authored-by: Cursor --- vllm/v1/attention/ops/triton_turboquant_decode.py | 4 ++-- vllm/v1/attention/ops/triton_turboquant_store.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/ops/triton_turboquant_decode.py b/vllm/v1/attention/ops/triton_turboquant_decode.py index ceaaa2f920c0..3adaf2610d8d 100644 --- a/vllm/v1/attention/ops/triton_turboquant_decode.py +++ b/vllm/v1/attention/ops/triton_turboquant_decode.py @@ -162,7 +162,7 @@ def _tq_decode_stage1( other=0, ) if FP8_E4B15: - k_float = k_raw.to(tl.float8e4b8, bitcast=True).to(tl.float32) + k_float = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32) else: k_float = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) scores = ( @@ -371,7 +371,7 @@ def _tq_full_dequant_kv( if KEY_FP8: k_raw = tl.load(KV_cache_ptr + slot_base + d_offs, mask=d_mask, other=0) if FP8_E4B15: - k_recon = k_raw.to(tl.float8e4b8, bitcast=True).to(tl.float32) + k_recon = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32) else: k_recon = k_raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) tl.store(K_out_ptr + ko_base + d_offs, k_recon.to(tl.float16), mask=d_mask) diff --git a/vllm/v1/attention/ops/triton_turboquant_store.py b/vllm/v1/attention/ops/triton_turboquant_store.py index ea1b1d7a9e42..3ad2d41488e7 100644 --- a/vllm/v1/attention/ops/triton_turboquant_store.py +++ b/vllm/v1/attention/ops/triton_turboquant_store.py @@ -189,7 +189,7 @@ def _tq_fused_store_fp8( d_offs = tl.arange(0, BLOCK_D) d_mask = d_offs < D k_vals = tl.load(Key_ptr + base + d_offs, mask=d_mask, other=0.0) - k_fp8 = k_vals.to(tl.float8e4b8) if FP8_E4B15 else k_vals.to(tl.float8e4nv) + k_fp8 = k_vals.to(tl.float8e4b15) if FP8_E4B15 else k_vals.to(tl.float8e4nv) k_bytes = k_fp8.to(tl.uint8, bitcast=True) tl.store(KV_cache_ptr + slot_base + d_offs, k_bytes, mask=d_mask) From 4da0919243df5b96c5d5edd821ad61588adc9fd2 Mon Sep 17 00:00:00 2001 From: Markus Hartikainen Date: Thu, 21 May 2026 10:03:34 +0300 Subject: [PATCH 7/9] [ROCm][DSv4][gfx942] Vendor fp8_mqa_logits kernel to avoid AITER upgrade The AITER wrapper bundled in the currently-pinned aiter wheel launches fp8_mqa_logits with (BLOCK_KV=128, num_stages=2) on gfx942. For the DSv4 sparse indexer shape (NUM_HEADS=64, HEAD_SIZE=128) this double-buffered KV tile + fp32 scores accumulator + Q tile pushes Triton's LDS request to 96 KiB, which exceeds MI300X's 64 KiB per CU. The launch JIT-aborts with OutOfResources on the first inference. The fix is upstreamed as ROCm/aiter#3257 but until vLLM bumps to an AITER version that contains it, this patch ships the same kernel + tile-size logic vendored into vllm/. - Add vllm/v1/attention/ops/triton_fp8_mqa_logits.py with a byte-for-byte copy of AITER's @triton.jit kernel and a Python wrapper that selects (BLOCK_KV=64, num_stages=1) (~33 KiB) when the default tile would not fit on gfx942 (see module docstring for the LDS budget calculation). - Route rocm_fp8_mqa_logits to the vendored kernel on gfx942 when AITER ops are enabled. gfx950+ and CUDA still use the upstream AITER wrapper (which has dedicated Gluon kernels this vendor copy does not include). - Fix a latent broadcasting bug in the torch reference fallback: the per-KV-token scale arrives as [N, 1] (a [N, 4] uint8 buffer view-cast to fp32) and was being multiplied against an [H, M, N] score tensor where PyTorch right-aligns [N, 1] against the M dim. Flatten to [N] so the multiply lines up with the last axis. Also drop a hard-coded 'cuda' device on the index tensors so the fallback works on ROCm with HIP devices. This entire patch is intended to be reverted once vLLM picks up an AITER version that includes ROCm/aiter#3257. Co-authored-by: Cursor Signed-off-by: Markus Hartikainen --- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 30 +- .../v1/attention/ops/triton_fp8_mqa_logits.py | 286 ++++++++++++++++++ 2 files changed, 314 insertions(+), 2 deletions(-) create mode 100644 vllm/v1/attention/ops/triton_fp8_mqa_logits.py diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index ed5ff03c8706..b2c90e48c844 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -504,7 +504,13 @@ def fp8_mqa_logits_torch( ) mask = mask_lo & mask_hi - score = torch.einsum("mhd,nd->hmn", q, k).float() * scale + # ``score`` is [H, M, N]; ``scale`` is the per-KV-token scale, which + # vLLM callers hand us as ``[N, 1]`` (a ``[N, 4]`` uint8 buffer cast + # to fp32). PyTorch right-aligns dimensions for broadcasting, so a + # naked ``score * scale`` would align ``scale``'s leading dim with + # ``score``'s M dim and raise a shape mismatch. Flatten to ``[N]`` so + # broadcasting lines up with the last dim of ``score``. + score = torch.einsum("mhd,nd->hmn", q, k).float() * scale.reshape(-1) logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) logits = logits.masked_fill(~mask, float("-inf")) @@ -557,13 +563,33 @@ def rocm_fp8_mqa_logits( # path after aiter merge this kernel into main from vllm._aiter_ops import rocm_aiter_ops + k_fp8, scale = kv + + # gfx942 (MI300X): the AITER ``fp8_mqa_logits`` wrapper bundled in the + # currently-pinned aiter wheel launches its Triton kernel with + # ``(BLOCK_KV=128, num_stages=2)``, which requests ~96 KiB of LDS for + # the DSv4 sparse indexer shape. MI300X CUs have 64 KiB of LDS, so + # the launch JIT-aborts with ``OutOfResources: shared memory`` on the + # first inference. Route gfx942 callers to a vLLM-vendored copy of + # the same kernel that selects ``(BLOCK_KV=64, num_stages=1)`` when + # the default tile doesn't fit (~33 KiB), matching the fix in + # ROCm/aiter#3257. This entire branch can be removed once vLLM bumps + # to an AITER version that includes that PR. + if _ON_GFX942 and rocm_aiter_ops.is_enabled(): + from vllm.v1.attention.ops.triton_fp8_mqa_logits import ( + fp8_mqa_logits_gfx942, + ) + + return fp8_mqa_logits_gfx942( + q, k_fp8, scale, weights, cu_seqlen_ks, cu_seqlen_ke + ) + aiter_mqa_logits_module = None if rocm_aiter_ops.is_enabled(): aiter_mqa_logits_module = mqa_logits_module() if aiter_mqa_logits_module is not None: fp8_mqa_logits = aiter_mqa_logits_module.fp8_mqa_logits - k_fp8, scale = kv return fp8_mqa_logits(q, k_fp8, scale, weights, cu_seqlen_ks, cu_seqlen_ke) else: return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) diff --git a/vllm/v1/attention/ops/triton_fp8_mqa_logits.py b/vllm/v1/attention/ops/triton_fp8_mqa_logits.py new file mode 100644 index 000000000000..2163283781a8 --- /dev/null +++ b/vllm/v1/attention/ops/triton_fp8_mqa_logits.py @@ -0,0 +1,286 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""FP8 MQA logits kernel (vendored for ROCm gfx942). + +This module ships a vLLM-local copy of AITER's ``fp8_mqa_logits`` Triton +kernel together with the launch-time tile-size selection from AITER PR +ROCm/aiter#3257. It exists solely as a temporary workaround so that +PR vllm-project/vllm#42893 (DeepSeek V4 functional fixes on MI300X) can +land *without* requiring an AITER version bump on the consumer side. + +Background +---------- +AITER's ``aiter.ops.triton.attention.fp8_mqa_logits._fp8_mqa_logits_kernel`` +in versions <= the wheel bundled with vLLM nightly +``a6682d1d259cca69a9ae737ea5608fbbe7520031`` launches with the default +``(BLOCK_KV=128, num_stages=2)`` tile on gfx942 (MI300X) hardware. For the +DeepSeek V4 sparse indexer shape ``(NUM_HEADS=64, HEAD_SIZE=128)`` this +double-buffers a ``128 * 128 * 2 = 32 KiB`` KV tile plus a ``64 * 128 * 4 += 32 KiB`` fp32 scores accumulator plus an ``64 * 128 = 8 KiB`` Q tile, +which Triton 3.5/3.6 routes into LDS and ends up requesting 96 KiB. The +MI300X CU only has 64 KiB of LDS, so the kernel JIT-aborts with +``triton.runtime.errors.OutOfResources: shared memory``. + +AITER ROCm/aiter#3257 (queued for merge) keeps the same ``@triton.jit`` +body but selects ``(BLOCK_KV=64, num_stages=1)`` (~33 KiB) on gfx942 when +the default tile would not fit. Until that PR is consumed by vLLM, we +route gfx942 callers in :func:`rocm_fp8_mqa_logits` to this vendored +implementation, which carries the same logic. On gfx950+ and CUDA the +upstream AITER wrapper continues to be used (it has a dedicated Gluon +kernel that this vendor does not include). + +The ``@triton.jit`` body below is byte-for-byte equivalent to +``aiter.ops.triton._triton_kernels.attention.fp8_mqa_logits``; only the +imports were adjusted to use ``vllm.triton_utils``. +""" + +import torch + +from vllm.triton_utils import tl, triton + +# gfx942 (MI300X) has 64 KiB of LDS per CU. We accept the default +# (BLOCK_KV=128, num_stages=2) tile only when *both* of these hold: +# +# 1. Occupancy gate. With waves_per_eu=2 and num_warps=4 we target two +# workgroups co-resident on a CU -> per-WG LDS budget = 32 KiB. Triton +# keeps Q in registers (loop-invariant) and the fp32 scores accumulator +# in VGPRs (heavy VALU), so only the double-buffered KV tile is +# expected to live in LDS. A 0.9 safety factor leaves headroom for any +# LDS overhead the compiler may add. +# +# 2. Hardware ceiling. Defensive upper bound that also counts Q and +# scores against the 64 KiB CU limit, in case a Triton version (older +# or future) decides to spill them to LDS. False positives here only +# shrink the tile; false negatives are JIT-aborts, so we lean +# conservative. +_GFX942_CU_LDS_BYTES = 64 * 1024 +_GFX942_PER_WG_LDS_BUDGET_BYTES = _GFX942_CU_LDS_BYTES * 9 // 20 # ~28.8 KiB + + +def _gfx942_default_tile_fits_lds(num_heads: int, head_size: int) -> bool: + """Return True iff (BLOCK_KV=128, num_stages=2) fits in MI300X LDS.""" + BLOCK_KV = 128 + NUM_STAGES = 2 + kv_bytes = head_size * BLOCK_KV * NUM_STAGES + scores_bytes = num_heads * BLOCK_KV * 4 + q_bytes = num_heads * head_size + fits_occupancy = kv_bytes < _GFX942_PER_WG_LDS_BUDGET_BYTES + fits_hardware = q_bytes + kv_bytes + scores_bytes <= _GFX942_CU_LDS_BYTES + return fits_occupancy and fits_hardware + + +@triton.jit +def _fp8_mqa_logits_kernel( + Q_ptr, # fp8e4m3 [seq_len, H, D] + KV_ptr, # fp8e4m3 [seq_len_kv, D] + kv_scales_ptr, # fp32 [seq_len_kv] + weights_ptr, # fp32 [seq_len, H] + cu_start_ptr, # int32 [seq_len] + cu_end_ptr, # int32 [seq_len] + logits_ptr, # fp32 [seq_len, seq_len_kv] + seq_len, + seq_len_kv, + NUM_HEADS: tl.constexpr, + HEAD_SIZE: tl.constexpr, + # strides + stride_q_s: tl.int64, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_s: tl.int64, + stride_kv_d: tl.constexpr, + stride_w_s: tl.int64, + stride_w_h: tl.constexpr, + stride_logits_s: tl.int64, + stride_logits_k: tl.int64, + # block sizes + BLOCK_KV: tl.constexpr, +): + row_id = tl.program_id(0) + # go from larger to smaller in terms of work + # to reduce the tail effect + row_id = tl.num_programs(0) - row_id - 1 + tl.assume(row_id >= 0) + tl.assume(stride_q_s > 0) + tl.assume(stride_q_h > 0) + tl.assume(stride_q_d > 0) + tl.assume(stride_kv_s > 0) + tl.assume(stride_kv_d > 0) + tl.assume(stride_w_s > 0) + tl.assume(stride_w_h > 0) + + logits_row_ptrs = logits_ptr + row_id * stride_logits_s + + h_inds = tl.arange(0, NUM_HEADS)[:, None] + d_inds = tl.arange(0, HEAD_SIZE) + + # load Q[BLOCK_Q, NUM_HEADS, HEAD_SIZE] + q_ptrs = ( + Q_ptr + row_id * stride_q_s + h_inds * stride_q_h + d_inds[None, :] * stride_q_d + ) + + q_block = tl.load(q_ptrs, cache_modifier=".cg") + w_ptrs = weights_ptr + row_id * stride_w_s + h_inds * stride_w_h + w_block = tl.load(w_ptrs, cache_modifier=".cg").to(tl.float32) + + # Load start/end for each row in this block + start_ind = tl.load(cu_start_ptr + row_id) + end_ind = tl.load(cu_end_ptr + row_id) + + start_ind = tl.maximum(start_ind, 0) + end_ind = tl.minimum(end_ind, seq_len_kv) + shifted_end = end_ind - start_ind + shifted_unmasked_end = shifted_end // BLOCK_KV * BLOCK_KV + + kv_col_offsets = tl.arange(0, BLOCK_KV) + start_ind + kv_ptrs = ( + KV_ptr + kv_col_offsets[None, :] * stride_kv_s + d_inds[:, None] * stride_kv_d + ) + + kv_scales_ptrs = kv_scales_ptr + kv_col_offsets + + logits_ptrs = logits_row_ptrs + kv_col_offsets * stride_logits_k + + # Loop over KV tiles + for _ in tl.range(0, shifted_unmasked_end, BLOCK_KV): + kv_block = tl.load(kv_ptrs) + kv_scales = tl.load(kv_scales_ptrs) + + # [NUM_HEADS, BLOCK_KV] = [NUM_HEADS, HEAD_SIZE] x [HEAD_SIZE, BLOCK_KV] + scores = tl.dot(q_block, kv_block, input_precision="ieee") + # Multiply by kv_scales (broadcast along rows) + scores = scores * kv_scales[None, :] + # ReLU + scores = tl.maximum(scores, 0.0) + scores = scores * w_block + # [NUM_HEADS, BLOCK_KV] -> [BLOCK_KV, ] + scores = tl.sum(scores, axis=0) + tl.store(logits_ptrs, scores) + + kv_ptrs += BLOCK_KV * stride_kv_s + kv_scales_ptrs += BLOCK_KV + logits_ptrs += BLOCK_KV * stride_logits_k + kv_col_offsets += BLOCK_KV + + # masked load + kv_col_mask = kv_col_offsets < end_ind + kv_block = tl.load(kv_ptrs, mask=kv_col_mask[None, :], other=0.0) + kv_scales = tl.load(kv_scales_ptrs, mask=kv_col_mask, other=0.0) + + # [NUM_HEADS, BLOCK_KV] = [NUM_HEADS, HEAD_SIZE] x [HEAD_SIZE, BLOCK_KV] + scores = tl.dot(q_block, kv_block, input_precision="ieee") + # Multiply by kv_scales (broadcast along rows) + scores = scores * kv_scales[None, :] + # ReLU + scores = tl.maximum(scores, 0.0) + scores = scores * w_block + # [NUM_HEADS, BLOCK_KV] -> [BLOCK_KV, ] + scores = tl.sum(scores, axis=0) + # masked store + in_window = (kv_col_offsets >= start_ind) & (kv_col_offsets < end_ind) + tl.store(logits_ptrs, scores, mask=in_window) + + +def fp8_mqa_logits_gfx942( + q: torch.Tensor, + k_fp8: torch.Tensor, + kv_scales: torch.Tensor, + weights: torch.Tensor, + cu_starts: torch.Tensor, + cu_ends: torch.Tensor, +) -> torch.Tensor: + """Compute FP8 MQA logits on MI300X (gfx942) using the vendored kernel. + + Drop-in replacement for ``aiter.ops.triton.attention.fp8_mqa_logits. + fp8_mqa_logits`` on MI300X. Selects ``(BLOCK_KV, num_stages)`` based on + whether the default tile fits within the 64 KiB LDS budget of a gfx942 + CU (see module docstring). + + Args: + q: Query tensor of shape ``[M, H, D]``, FP8 dtype. + k_fp8: Key tensor of shape ``[N, D]``, FP8 dtype. + kv_scales: K scales of shape ``[N]`` (or ``[N, 1]`` -- viewed as + ``[N]``), float32. + weights: Per-head weights of shape ``[M, H]``, float32. + cu_starts: Start indices (inclusive) of shape ``[M]``, int32. + cu_ends: End indices (exclusive) of shape ``[M]``, int32. + + Returns: + Logits of shape ``[M, N]``, float32 -- positions outside + ``[cu_starts[i], cu_ends[i])`` for row ``i`` are pre-filled with + ``-inf`` so the caller can run a top-k without masking. + """ + seq_len, num_heads, head_size = q.shape + seq_len_kv = k_fp8.shape[0] + assert num_heads & (num_heads - 1) == 0, ( + f"num_heads must be a power of two (got {num_heads})" + ) + assert head_size & (head_size - 1) == 0, ( + f"head_size must be a power of two (got {head_size})" + ) + + # The kernel walks ``kv_scales`` as a 1-D contiguous array of size N + # (it indexes by ``kv_scales_ptr + kv_col_offsets``). The vLLM caller + # passes a ``[N, 4]`` uint8 view-cast-to-float32 which lands as + # ``[N, 1]`` contiguous -- byte-identical to ``[N]`` -- but flatten + # explicitly to keep the kernel's pointer arithmetic intent clear. + kv_scales_1d = kv_scales.reshape(-1) + + # Initialise with -inf so positions outside [cu_starts, cu_ends) read + # as ``-inf`` after the masked store path -- this matches AITER's + # ``fp8_mqa_logits`` semantics and is what the top-k consumer expects. + logits = torch.full( + (seq_len, seq_len_kv), + fill_value=-float("inf"), + dtype=torch.float32, + device=q.device, + ) + + if _gfx942_default_tile_fits_lds(num_heads, head_size): + block_kv = 128 + num_stages = 2 + else: + # DSv4 sparse indexer (NUM_HEADS=64, HEAD_SIZE=128) lands here: + # default tile spills past gfx942's 64 KiB LDS budget. (64, 1) + # needs ~33 KiB and clears the per-WG budget with margin. + block_kv = 64 + num_stages = 1 + + # heuristic for MFMA instruction shape, identical to AITER's choice + matrix_instr_nonkdim = 32 + if seq_len <= 1024: + matrix_instr_nonkdim = 16 + + stride_q_s, stride_q_h, stride_q_d = q.stride() + stride_kv_s, stride_kv_d = k_fp8.stride() + stride_w_s, stride_w_h = weights.stride() + stride_logits_s, stride_logits_k = logits.stride() + + _fp8_mqa_logits_kernel[(seq_len,)]( + Q_ptr=q, + KV_ptr=k_fp8, + kv_scales_ptr=kv_scales_1d, + weights_ptr=weights, + cu_start_ptr=cu_starts, + cu_end_ptr=cu_ends, + logits_ptr=logits, + seq_len=seq_len, + seq_len_kv=seq_len_kv, + NUM_HEADS=num_heads, + HEAD_SIZE=head_size, + stride_q_s=stride_q_s, + stride_q_h=stride_q_h, + stride_q_d=stride_q_d, + stride_kv_s=stride_kv_s, + stride_kv_d=stride_kv_d, + stride_w_s=stride_w_s, + stride_w_h=stride_w_h, + stride_logits_s=stride_logits_s, + stride_logits_k=stride_logits_k, + BLOCK_KV=block_kv, + num_warps=4, + num_stages=num_stages, + waves_per_eu=2, + matrix_instr_nonkdim=matrix_instr_nonkdim, + ) + + return logits From f13e4450ae94e3e168960452f4b57b7ba4e71e2a Mon Sep 17 00:00:00 2001 From: Markus Hartikainen Date: Fri, 22 May 2026 19:13:37 +0300 Subject: [PATCH 8/9] [ROCm][DSv4] Use FNUZ FP8_MAX=224 to match vLLM convention (review #42893) Switch the FNUZ branch of `kFp8Max` to 224.0 (was 240.0, the FNUZ dtype's raw representable max). 224.0 is what the rest of vLLM's FNUZ pipeline uses -- see `vllm/model_executor/layers/quantization/utils/fp8_utils.py:412-417`, which notes that 240.0 hurts dynamic-quant accuracy. The OCP branch (gfx950 + NVIDIA) keeps 448.0. Make the unit test honor the same split: add an optional `use_fnuz` constexpr to `quantize_and_insert_k_kernel` (default False, no production caller affected) and pick the encoding from `current_platform.is_fp8_fnuz()`. Byte-exact comparison now succeeds on both gfx942 and gfx950. Verified: 36/36 unit tests pass on MI300X (gfx942) and MI355X (gfx950). Signed-off-by: Markus Hartikainen Co-authored-by: Cursor --- ...deepseek_v4_qnorm_rope_kv_insert_kernel.cu | 6 +-- ..._fused_deepseek_v4_qnorm_rope_kv_insert.py | 20 ++++++--- .../deepseek_v4/common/ops/cache_utils.py | 41 +++++++++---------- 3 files changed, 38 insertions(+), 29 deletions(-) diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index b90a725a90cb..cc9c3dc56304 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -89,10 +89,10 @@ constexpr int kQuantBlock = 64; constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7 constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad) constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576 -// Match the encoding chosen in rocm_cvt_float_to_fp8_e4m3: FNUZ on gfx942 -// (max 240), OCP on gfx950 (max 448). +// FNUZ on gfx942 / OCP on gfx950. FNUZ uses 224.0 (not the dtype's raw +// 240.0) to match the rest of vLLM's FNUZ pipeline; see fp8_utils.py:412-417. #if defined(USE_ROCM) && (!defined(HIP_FP8_TYPE_OCP) || !defined(__gfx950__)) -constexpr float kFp8Max = 240.0f; +constexpr float kFp8Max = 224.0f; #else constexpr float kFp8Max = 448.0f; #endif diff --git a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py index 13010540d973..63f46204e9ae 100644 --- a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py +++ b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py @@ -23,13 +23,16 @@ dequantize_and_gather_k_cache, quantize_and_insert_k_cache, ) +from vllm.platforms import current_platform # ── Constants matching the kernel ──────────────────────────────────────────── HEAD_DIM = 512 ROPE_DIM = 64 NOPE_DIM = HEAD_DIM - ROPE_DIM # 448 QUANT_BLOCK = 64 -FP8_MAX = 448.0 +# Match the C++ SWA-K encoder: FNUZ on gfx942 (FP8_MAX=224), OCP elsewhere (448). +USE_FNUZ = current_platform.is_fp8_fnuz() +FP8_MAX = 224.0 if USE_FNUZ else 448.0 HEAD_BYTES = NOPE_DIM + ROPE_DIM * 2 + 8 # 448 + 128 + 8 = 584 @@ -198,7 +201,7 @@ def test_kv_path_matches_reference(num_tokens: int, block_size: int): num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device ) quantize_and_insert_k_cache( - kv_ref, k_cache_ref, slot_mapping, block_size=block_size + kv_ref, k_cache_ref, slot_mapping, block_size=block_size, use_fnuz=USE_FNUZ ) # ── Fused path (dummy q, single head) ────────────────────────────────── @@ -229,7 +232,14 @@ def _dequant(k_cache_2d): # gather_lens arg is None (use seq_lens) k_cache_3d = k_cache_2d.view(num_blocks, block_size, HEAD_BYTES) dequantize_and_gather_k_cache( - out, k_cache_3d, seq_lens, None, block_table, block_size, offset=0 + out, + k_cache_3d, + seq_lens, + None, + block_table, + block_size, + offset=0, + use_fnuz=USE_FNUZ, ) return out[0, :num_tokens] @@ -292,7 +302,7 @@ def test_kv_path_with_dp_padding(num_tokens: int, pad: int, block_size: int): num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device ) quantize_and_insert_k_cache( - kv_ref, k_cache_ref, slot_mapping, block_size=block_size + kv_ref, k_cache_ref, slot_mapping, block_size=block_size, use_fnuz=USE_FNUZ ) # Fused: pass full-sized q/kv/positions, shorter slot_mapping. @@ -341,7 +351,7 @@ def test_combined_q_and_kv(num_tokens: int, n_heads: int, block_size: int): num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device ) quantize_and_insert_k_cache( - kv_ref, k_cache_ref, slot_mapping, block_size=block_size + kv_ref, k_cache_ref, slot_mapping, block_size=block_size, use_fnuz=USE_FNUZ ) # Fused single call. diff --git a/vllm/models/deepseek_v4/common/ops/cache_utils.py b/vllm/models/deepseek_v4/common/ops/cache_utils.py index 29cec22c864b..3876e4c02542 100644 --- a/vllm/models/deepseek_v4/common/ops/cache_utils.py +++ b/vllm/models/deepseek_v4/common/ops/cache_utils.py @@ -39,6 +39,7 @@ def quantize_and_insert_k_kernel( block_stride: tl.constexpr, # total bytes per block (padded) fp8_max: tl.constexpr, n_quant_blocks: tl.constexpr, # 8 (7 real + 1 padding) + use_fnuz: tl.constexpr = False, ): """ Quantize K tensor and insert into paged K cache. @@ -49,6 +50,9 @@ def quantize_and_insert_k_kernel( - [64*576 + 64*8, block_stride): Padding One program per token. + + ``use_fnuz=True`` selects FNUZ (``tl.float8e4b8``); default OCP + (``tl.float8e4nv``) matches every production caller. """ pid = tl.program_id(0) @@ -112,8 +116,11 @@ def quantize_and_insert_k_kernel( x_scaled = x / scale x_clamped = tl.clamp(x_scaled, -fp8_max, fp8_max) - # Convert to fp8, then bitcast to uint8 for storage - x_fp8 = x_clamped.to(tl.float8e4nv) + # Convert to fp8 (FNUZ on gfx942, OCP elsewhere), then bitcast to uint8. + if use_fnuz: + x_fp8 = x_clamped.to(tl.float8e4b8) + else: + x_fp8 = x_clamped.to(tl.float8e4nv) x_uint8 = x_fp8.to(tl.uint8, bitcast=True) # Store as uint8 (1 byte each) @@ -145,6 +152,7 @@ def quantize_and_insert_k_cache( slot_mapping: torch.Tensor, # [num_tokens] int64 block_size: int = 64, is_ue8m0: bool = True, + use_fnuz: bool = False, ): """ Quantize K tensor and insert into paged K cache. @@ -155,6 +163,9 @@ def quantize_and_insert_k_cache( - Next 64 * 8 = 512 bytes: Scales - Each token: 8 bytes (uint8 scales, 7 real + 1 padding) - Padded to multiple of 576 + + ``use_fnuz=True`` switches to FNUZ E4M3 (FP8_MAX=224); default OCP + (FP8_MAX=448) matches every production caller. """ assert k.dim() == 2 and k.shape[1] == 512, ( f"K must be [num_tokens, 512], got {k.shape}" @@ -171,7 +182,7 @@ def quantize_and_insert_k_cache( TOKEN_BF16_DIM = 64 TOKEN_SCALE_DIM = 8 QUANT_BLOCK_SIZE = 64 - FP8_MAX = 448.0 + FP8_MAX = 224.0 if use_fnuz else 448.0 # FNUZ value matches fp8_utils.py TOKEN_DATA_SIZE = TOKEN_FP8_DIM + TOKEN_BF16_DIM * 2 grid = (num_tokens,) @@ -191,6 +202,7 @@ def quantize_and_insert_k_cache( block_stride=block_stride, fp8_max=FP8_MAX, n_quant_blocks=8, + use_fnuz=use_fnuz, ) @@ -274,7 +286,7 @@ def _dequantize_and_gather_k_kernel( # Load quantized fp8 values (stored as uint8) x_uint8 = tl.load(token_fp8_ptr + offsets, mask=mask, other=0) - # Bitcast uint8 back to fp8 (FNUZ on gfx942, OCP otherwise) + # Bitcast uint8 back to fp8 (FNUZ on gfx942, OCP elsewhere). if use_fnuz: x_fp8 = x_uint8.to(tl.float8e4b8, bitcast=True) else: @@ -373,23 +385,10 @@ def dequantize_and_gather_k_cache( ) -> None: """Dequantize and gather a paged DSv4 K cache. - ``use_fnuz`` selects which Triton FP8 dtype is used to bitcast the stored - bytes back to floats. It MUST match the encoder that wrote into this - particular cache: - - * ``compressed_k_cache``: Triton encoders in ``cache_utils.py`` / - ``fused_compress_quant_cache.py`` always use ``tl.float8e4nv`` (OCP, - bias 7, FP8_MAX=448) regardless of platform, so the reader must use - ``use_fnuz=False`` everywhere. - * ``swa_k_cache``: written by the C++ kernel - ``fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert``. After PR #42893 - that kernel encodes FNUZ E4M3 on gfx942 (MI300X, FP8_MAX=240) and OCP - on gfx950, so the reader must use - ``use_fnuz=current_platform.is_fp8_fnuz()``. - - Mismatching the two -- e.g. reading FNUZ bytes as OCP -- silently - rescales every K vector by ~448/240 ≈ 1.87 and produces gibberish in - sparse attention. + ``use_fnuz`` MUST match the encoder of the specific cache being read: + ``False`` for ``compressed_k_cache`` (Triton encoder is OCP everywhere), + ``current_platform.is_fp8_fnuz()`` for ``swa_k_cache`` (C++ encoder + writes FNUZ on gfx942 and OCP on gfx950). """ if has_cutedsl(): # lazily import, otherwise some tests fail due to CUDA driver init failure. From 9ae8703063fdcb87914db25d29725d42536a1e5f Mon Sep 17 00:00:00 2001 From: Markus Hartikainen Date: Fri, 22 May 2026 19:14:45 +0300 Subject: [PATCH 9/9] [ROCm][DSv4] Tighten KV-workspace zero comment + FP8 encoding notes (review #42893) Replace the long rationale around ``kv.zero_()`` (in both prefill paths) with a brief TODO that names the proper fix: mask invalid rows in the indexer (score = -inf) or in the sparse-attention kernel (skip indices >= valid_len). The current zero is the minimal interim workaround; the underlying bug is arch-independent (uninitialized workspace + indexer that scores the entire M dim) so the call stays unchanged on every platform until the indexer/kernel fix lands. No behavior change. Also condense the duplicate FNUZ-vs-OCP comments at the dequant call sites and in ``_sparse_attn_decode_ragged_kernel``: the wrapper docstring already explains the asymmetry, so per-call-site repetition was just noise. Signed-off-by: Markus Hartikainen Co-authored-by: Cursor --- vllm/models/deepseek_v4/amd/rocm.py | 22 ++++------------ .../v1/attention/ops/rocm_aiter_mla_sparse.py | 25 ++++++------------- 2 files changed, 12 insertions(+), 35 deletions(-) diff --git a/vllm/models/deepseek_v4/amd/rocm.py b/vllm/models/deepseek_v4/amd/rocm.py index 104bfed66eff..770255e177fd 100644 --- a/vllm/models/deepseek_v4/amd/rocm.py +++ b/vllm/models/deepseek_v4/amd/rocm.py @@ -790,17 +790,10 @@ def _forward_prefill( kv = workspace_manager.get_simultaneous( ((cls.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), )[0] - # The workspace allocator returns uninitialized memory and is shared - # across requests + other layers. dequantize_and_gather_k_cache only - # writes the compressed-K prefix (rows [0, seq_len/compress_ratio)) - # and the SWA window (rows [N, N+gather_lens)) for each chunk row, - # leaving holes in the M dimension. rocm_sparse_attn_prefill then - # reads ``kv.view(-1, 1, head_dim)`` via ragged indices which can - # reach those holes for very short sequences, causing data-dependent - # non-determinism across otherwise-identical temperature=0 requests. - # Zero once per call so the holes are deterministic (zero attention - # contribution). The cost is one bf16 fill of the workspace tile, - # which is dwarfed by the FP8 dequant + sparse attention themselves. + # TODO: workspace is torch.empty() and only the compressed-K prefix + + # SWA window are written per chunk row; the indexer's topK can land in + # the unwritten holes for short sequences. Proper fix is to mask invalid + # rows in the indexer (score = -inf) or in rocm_sparse_attn_prefill. kv.zero_() for chunk_idx in range(num_chunks): chunk_start = chunk_idx * cls.PREFILL_CHUNK_SIZE @@ -810,12 +803,7 @@ def _forward_prefill( assert attn_metadata is not None assert compressed_k_cache is not None block_table = attn_metadata.block_table[num_decodes:] - # The compressed-K encoder (Triton _fused_kv_compress_norm_... - # in fused_compress_quant_cache.py) writes bytes via - # tl.float8e4nv with FP8_MAX=448.0 regardless of platform. - # The SWA-side C++ encoder, by contrast, switches to FNUZ on - # gfx942 (PR #42893), so the two caches need different - # use_fnuz settings even on the same MI300X. + # compressed_k_cache is OCP on every platform (Triton encoder). dequantize_and_gather_k_cache( kv[:chunk_size], compressed_k_cache, diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index b2c90e48c844..1be924eacd0e 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -565,16 +565,11 @@ def rocm_fp8_mqa_logits( k_fp8, scale = kv - # gfx942 (MI300X): the AITER ``fp8_mqa_logits`` wrapper bundled in the - # currently-pinned aiter wheel launches its Triton kernel with - # ``(BLOCK_KV=128, num_stages=2)``, which requests ~96 KiB of LDS for - # the DSv4 sparse indexer shape. MI300X CUs have 64 KiB of LDS, so - # the launch JIT-aborts with ``OutOfResources: shared memory`` on the - # first inference. Route gfx942 callers to a vLLM-vendored copy of - # the same kernel that selects ``(BLOCK_KV=64, num_stages=1)`` when - # the default tile doesn't fit (~33 KiB), matching the fix in - # ROCm/aiter#3257. This entire branch can be removed once vLLM bumps - # to an AITER version that includes that PR. + # gfx942: AITER's bundled fp8_mqa_logits launches with BLOCK_KV=128 + + # num_stages=2 (~96 KiB LDS), exceeding MI300X's 64 KiB LDS so it aborts + # with OutOfResources. Route gfx942 to a vendored copy that drops to + # BLOCK_KV=64 + num_stages=1 (~33 KiB) per ROCm/aiter#3257. Remove this + # branch once vLLM bumps AITER to a version that includes that PR. if _ON_GFX942 and rocm_aiter_ops.is_enabled(): from vllm.v1.attention.ops.triton_fp8_mqa_logits import ( fp8_mqa_logits_gfx942, @@ -1193,10 +1188,8 @@ def _sparse_attn_decode_ragged_kernel( NOPE_DIM: tl.constexpr, NOPE_BLOCK: tl.constexpr, ROPE_DIM: tl.constexpr, - # `main_cache` is the SWA K-cache (written by the C++ encoder, FNUZ on - # gfx942 / OCP on gfx950). `extra_cache` is the compressed K-cache - # (Triton encoder, OCP on every platform). Reading both with the same - # `IS_FNUZ` would mis-decode one of them by the FNUZ/OCP scale ratio. + # SWA K-cache (main): C++ encoder writes FNUZ on gfx942, OCP on gfx950. + # Compressed K-cache (extra): Triton encoder writes OCP everywhere. IS_FNUZ_MAIN: tl.constexpr, IS_FNUZ_EXTRA: tl.constexpr, BLOCK_H: tl.constexpr, @@ -1601,10 +1594,6 @@ def _rocm_sparse_attn_decode_ragged_triton( NOPE_DIM=nope_head_dim, NOPE_BLOCK=triton.next_power_of_2(nope_head_dim), ROPE_DIM=rope_head_dim, - # main_cache = swa_k_cache (C++ encoder, FNUZ on gfx942 / OCP on gfx950). - # extra_cache = compressed kv_cache (Triton encoder, OCP everywhere). - # Reading both with a single IS_FNUZ would mis-decode one of them by - # the FNUZ/OCP scale ratio (~1.87×). IS_FNUZ_MAIN=current_platform.is_fp8_fnuz(), IS_FNUZ_EXTRA=False, BLOCK_H=block_h,