[ROCm][DSv4] Functional fixes for DeepSeek V4 on MI300X (gfx942)#42893

Draft
maeehart wants to merge 9 commits into
vllm-project:mainfrom
maeehart:mahartik/dsv4-rocm-mi300x-fixes

@maeehart
Contributor

@maeehart maeehart commented May 17, 2026

Purpose

Builds on #42810 to bring DeepSeek V4 (Pro and Flash) to a working state on AMD MI300X (gfx942). With main + #42810 alone, the model fails to load on MI300X — both in eager and in cudagraphs (FULL_AND_PIECEWISE) modes — with:

NotImplementedError: "mul_cuda" not implemented for 'Float8_e8m0fnu'

and once that is past, the first sparse-MLA forward errors with:

type fp8e4b15 not supported in this architecture.

These come from a mix of an MX-format scale path that assumes float arithmetic, a HIP build macro that doesn't actually correspond to the GPU arch, a non-existent Triton FP8 dtype, an asymmetric K-cache encoder/decoder pair, and prefill workspaces that aren't zeroed before reuse. After the load-time fixes were in, two more FP8-format mismatches were caught on the actual ROCm DSv4 path (DeepseekV4ROCMAiterMLASparseImpl): the public dequant wrapper silently dropped use_fnuz, and the decode kernel used a single IS_FNUZ for both the FNUZ SWA cache and the always-OCP compressed cache, scaling K vectors by ~448/240 in prefill and ~240/448 in decode. Each issue is fixed in its own commit so they can be reviewed independently.
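To make the encoder/decoder mismatch concrete, here is a minimal pure-Python sketch of decoding one E4M3 byte under the two conventions discussed above: OCP (exponent bias 7, maximum 448) and FNUZ (exponent bias 8, maximum 240). The function name `decode_e4m3` is hypothetical and is not part of vLLM; the real kernels do this through Triton or HIP type casts.

```python
import math


def decode_e4m3(byte: int, fnuz: bool) -> float:
    """Decode a single E4M3 byte as OCP (bias 7) or FNUZ (bias 8).

    Illustrative only: a hypothetical helper, not a vLLM API.
    """
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if fnuz:
        # FNUZ has no negative zero; that bit pattern is its single NaN.
        if byte == 0x80:
            return math.nan
        bias = 8
    else:
        # OCP e4m3fn reserves S.1111.111 for NaN and has no infinities.
        if exp == 0xF and mant == 0x7:
            return math.nan
        bias = 7
    if exp == 0:
        # Subnormal: no implicit leading one.
        return sign * (mant / 8.0) * 2.0 ** (1 - bias)
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - bias)


# The same stored byte decodes to different values under the two formats.
same_byte = 0x38
as_ocp = decode_e4m3(same_byte, fnuz=False)
as_fnuz = decode_e4m3(same_byte, fnuz=True)
```

In this sketch a finite, non-zero byte read with the wrong convention comes out off by a factor of two, which is the kind of systematic K-vector rescaling the description attributes to the FNUZ-vs-OCP mix-ups.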

#42810 fixed the ffn_norm regression and the AITER MHC accuracy issue at high concurrency; this PR is the MI300X-specific complement, and with all of the commits listed below in place GSM8K accuracy matches #42810's reference number to all three reported digits (see test results below).

Changes (8 commits)

  1. [ROCm][DSv4] MI300X (gfx942) support for DeepSeek V4 (authored by ganyi)

    • fp8_utils.py: handle torch.float8_e8m0fnu weight scales in process_fp8_weight_block_strategy by incrementing the UE8M0 exponent byte instead of * 2.0 (which is unimplemented on CUDA/HIP for that dtype). This is the fix for the load-time crash above.
    • fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu: keep kFp8Max consistent with the FP8 dtype actually emitted on each arch.
    • Small companion changes in deepseek_v4_attention.py and cache_utils.py so the MI300X path stays internally consistent.
  2. [ROCm][DSv4] Use FNUZ FP8 on gfx942 in fused KV insert kernel (authored by ganyi)

    • The fused K-cache writer selected between __hip_fp8_e4m3 and __hip_fp8_e4m3_fnuz based purely on HIP_FP8_TYPE_OCP. That macro is a property of the HIP runtime version, not the target arch — on a HIP build that defines it, the kernel writes OCP bytes even on gfx942 hardware that only supports FNUZ MFMA. Gate the OCP path additionally on __gfx950__ so the encoding matches the arch (FNUZ + 240 on gfx942, OCP + 448 on gfx950).
  3. [ROCm][DSv4] Use tl.float8e4b8 for FNUZ on MI300X sparse MLA kernels

    • The DSv4 sparse-MLA Triton kernels added in #41812 ("[ROCm][DSv4] implement flash sparse mla with triton kernels"), specifically _sparse_attn_decode_ragged_kernel, bitcast uint8 -> tl.float8e4b15 when IS_FNUZ is true. float8e4b15 is not a supported Triton type on gfx942: Triton on MI300X only supports fp8e4b8, fp8e4nv, fp8e5, fp8e5b16. The correct FNUZ E4M3 type is tl.float8e4b8 (bias 8, matches the PyTorch torch.float8_e4m3fnuz used elsewhere on the MI300 path). IS_FNUZ here is correctly gated on current_platform.fp8_dtype() == torch.float8_e4m3fnuz / current_platform.is_fp8_fnuz() so it never fires on OCP hardware.
  4. [ROCm][DSv4] Fix compressed K cache dequant to match Triton OCP encoder

    • The two DSv4 K caches have different writers: compressed_k_cache is written by Triton (_fused_kv_compress_norm_rope_insert_sparse_attn) using OCP-style E4M3 encoding, but the C++ dequant in deepseek_v4_attention.py was reading it as FNUZ on gfx942, giving 2x scale mismatch on the compressed K side. This commit makes the dequant match the encoder.
  5. [ROCm][DSv4] Zero prefill-attn KV workspace before gather

  6. [ROCm][DSv4] Zero ROCm sparse-MLA prefill KV workspace

    • Both prefill paths take a torch.empty-backed workspace view from current_workspace_manager().get_simultaneous and then scatter into a subset of its rows. The remaining rows are read by the attention kernel and contain uninitialized memory from earlier layers/requests. Zero the workspace before the scatter so the unused rows are deterministic.
  7. [ROCm][DSv4] Propagate FNUZ vs OCP gating to ROCm prefill+decode paths (authored by jin-amd, validated the GSM8K result reported below)

    • Two FP8-format mismatches that commit 4 fixed for the generic prefill path remained on the actual ROCm DSv4 path (DeepseekV4ROCMAiterMLASparseImpl):
      • The public dequantize_and_gather_k_cache wrapper in cache_utils.py did not accept use_fnuz — it silently dropped the kwarg when forwarding to dequantize_and_gather_k_cache_triton (which defaults to False). The ROCm prefill called the wrapper without use_fnuz, so the SWA K cache (FNUZ on gfx942) was being read as OCP, scaling every K vector by ~448/240 in prefill.
      • _sparse_attn_decode_ragged_kernel in rocm_aiter_mla_sparse.py decoded both the SWA (FNUZ on gfx942) and the compressed (always OCP) K caches with a single IS_FNUZ constexpr, so on MI300X the compressed-side branch reinterpreted OCP bytes as FNUZ — the same encoder/decoder mismatch in the opposite direction (~240/448) on the decode side.
    • Adds use_fnuz to the wrapper and forwards it to the Triton implementation; splits IS_FNUZ into IS_FNUZ_MAIN (SWA) and IS_FNUZ_EXTRA (compressed) so each cache is decoded with its own encoder's format; wires DeepseekV4ROCMAiterMLASparseImpl._forward_prefill to pass use_fnuz=False for the compressed (Triton-OCP) call and use_fnuz=current_platform.is_fp8_fnuz() for the SWA (C++ FNUZ-on-gfx942) call. This is the commit that closes the GSM8K=0.005 accuracy gap originally documented below.
  8. [ROCm][DSv4] Revert turboquant fp8e4b15 -> fp8e4b8 changes (NVIDIA-only path)

    • Drops three drive-by hunks from commit 3 in triton_turboquant_decode.py / triton_turboquant_store.py. Per review (gemini-code-assist), those sit inside if FP8_E4B15: branches whose constexpr is 1 only when torch.cuda.get_device_capability() < (8, 9) — i.e. NVIDIA Ampere/Ada, where tl.float8e4b15 IS the correct Triton FP8 type and tl.float8e4b8 would be rejected. FP8_E4B15 is always 0 on AMD (gfx942 reports cap (9, x)), so the change was both unreachable on MI300X and incorrect on its actual NVIDIA target. The real MI300X sparse-MLA fix in rocm_aiter_mla_sparse.py is gated on the right constexpr (IS_FNUZ from current_platform.is_fp8_fnuz()) and is unaffected.
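The wrapper problem described in commit 7 can be sketched in plain Python. Everything here is a hypothetical stand-in: `dequantize_stub` plays the role of the Triton implementation, `wrapper_before` and `wrapper_after` play the role of the public wrapper, and the decoder only handles positive normal E4M3 bytes. It is a sketch of the pattern under those assumptions, not the code in cache_utils.py.

```python
def _decode_normal_e4m3(byte: int, use_fnuz: bool) -> float:
    """Decode a positive, normal E4M3 byte (hypothetical helper)."""
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    bias = 8 if use_fnuz else 7  # FNUZ uses bias 8, OCP uses bias 7
    return (1.0 + mant / 8.0) * 2.0 ** (exp - bias)


def dequantize_stub(cache: bytes, use_fnuz: bool = False) -> list:
    """Stand-in for the Triton dequant; note the default of False."""
    return [_decode_normal_e4m3(b, use_fnuz) for b in cache]


def wrapper_before(cache: bytes) -> list:
    # No use_fnuz parameter, so the callee's default (OCP) always wins.
    return dequantize_stub(cache)


def wrapper_after(cache: bytes, use_fnuz: bool = False) -> list:
    # Forward the flag so each cache is read in its own encoder's format.
    return dequantize_stub(cache, use_fnuz=use_fnuz)


# Bytes that an FNUZ encoder would have written for 1.0 and 2.0.
swa_cache_fnuz = bytes([0x40, 0x48])

wrong = wrapper_before(swa_cache_fnuz)
right = wrapper_after(swa_cache_fnuz, use_fnuz=True)
```

The design point is that the format flag belongs to the cache, not to the platform as a whole: the caller passes the flag that matches whichever encoder wrote that particular cache.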

Diff stats

csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu        | 12 ++++++-
vllm/model_executor/layers/deepseek_v4_attention.py          | 19 ++++++++++
vllm/model_executor/layers/quantization/utils/fp8_utils.py   | 25 ++++++++++++--
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse_dsv4.py | 21 ++++++++++++
vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py         | 40 ++++++++++++++++++++--
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py               | 22 ++++++++----
6 files changed, 126 insertions(+), 13 deletions(-)

Total: +126 / -13, ROCm-only paths (gated by current_platform.is_rocm() / current_platform.is_fp8_fnuz() / IS_FNUZ / __gfx950__). No CUDA / sm90 path changes.

Test plan

Hardware: 1 node x 4 x MI300X (gfx942), HIP 6.x, vLLM container built from this branch on top of 599e75f43.

Model: deepseek-ai/DeepSeek-V4-Flash, TP=4, VLLM_ROCM_USE_AITER=1.

Both eager and the cudagraphs (FULL_AND_PIECEWISE) config from #42810 are exercised, so reviewers don't have to infer non-eager status from the eager test alone.

Config A: eager mode

vllm serve deepseek-ai/DeepSeek-V4-Flash \
    --tensor-parallel-size 4 \
    --kv-cache-dtype fp8 \
    --max-model-len 32768 \
    --max-num-batched-tokens 8192 \
    --gpu-memory-utilization 0.85 \
    --distributed-executor-backend mp \
    --trust-remote-code \
    --enforce-eager \
    --host 0.0.0.0 --port 8000

Config B: cudagraphs FULL_AND_PIECEWISE (matches #42810's reported config, TP scaled to 4)

vllm serve deepseek-ai/DeepSeek-V4-Flash \
    --tensor-parallel-size 4 \
    --kv-cache-dtype fp8_e4m3 \
    --block-size 256 \
    --max-model-len 32768 \
    --max-num-seqs 256 \
    --distributed-executor-backend mp \
    --trust-remote-code \
    --gpu-memory-utilization 0.6 \
    --moe-backend triton_unfused \
    --tokenizer-mode deepseek_v4 \
    --async-scheduling \
    --enable-prefix-caching \
    --compilation-config '{"mode":3,"cudagraph_mode":"FULL_AND_PIECEWISE"}' \
    --host 0.0.0.0 --port 8000

lm_eval (GSM8K 5-shot, limit=200, run against /v1/completions with num_concurrent=32).

Test results

Before this PR (pure main @ 599e75f43, post-#42810)

Server fails to start during weight loading, identically in both modes:

File "vllm/model_executor/layers/quantization/utils/w8a8_utils.py", line 128,
    in normalize_e4m3fn_to_e4m3fnuz
    weight_scale = weight_scale * 2.0
NotImplementedError: "mul_cuda" not implemented for 'Float8_e8m0fnu'
RuntimeError: Engine core initialization failed.
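The failing `weight_scale * 2.0` in the trace above is a doubling of a power-of-two scale. For a UE8M0 scale, which stores only an 8-bit biased exponent (value 2^(e - 127), with 0xFF reserved for NaN), doubling is the same as adding one to the stored byte. The sketch below shows that identity on plain integers; the function names are hypothetical, and the PR's actual fix operates on torch tensors rather than Python ints.

```python
import math

UE8M0_BIAS = 127
UE8M0_NAN = 0xFF


def ue8m0_to_float(byte: int) -> float:
    """Interpret one UE8M0 byte as a power-of-two scale."""
    if byte == UE8M0_NAN:
        return math.nan
    return 2.0 ** (byte - UE8M0_BIAS)


def double_ue8m0(byte: int) -> int:
    """Return the UE8M0 byte for twice the scale, without float math.

    Hypothetical helper: raises if the result would collide with the
    NaN encoding or the input is already NaN.
    """
    if byte >= UE8M0_NAN - 1:
        raise OverflowError("doubling would leave the finite UE8M0 range")
    return byte + 1


one = 127                      # 2 ** 0
two = double_ue8m0(one)        # exponent byte 128
```

A real implementation would also need to decide what to do at the top of the range; raising is just the simplest choice for a sketch.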

After this PR

| | Config A: eager | Config B: cudagraphs |
| --- | --- | --- |
| Server startup | HTTP 200, 0 errors | HTTP 200, 0 errors |
| Sparse-MLA decode (fp8e4b15 Triton compile) | no error | no error |
| Temp=0 determinism (3 repeats, seed=42) | 3/3 identical | 3/3 identical |
| GSM8K@200, 5-shot, exact_match | 0.955 ± 0.0147 | 0.955 ± 0.0147 |

Both modes go from "crashes at load" to "loads, serves, is deterministic per-mode, and matches #42810's TP=8 V4-Flash reference number (0.95) to all three reported digits". Config A and Config B also match each other to all three digits on both strict-match and flexible-extract filters, confirming there is no eager-vs-cudagraphs accuracy gap on this path.

Related


Co-authored-by: ganyi <ygan@amd.com>
Co-authored-by: Jin Tao <jintao12@amd.com>

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements support for FP8 FNUZ encoding (finite-only with no infinities, a single NaN encoding, and an unsigned zero) on ROCm hardware, specifically targeting the gfx942 (MI300X) architecture. Key changes include updating CUDA and Triton kernels to use FNUZ-specific constants and types (tl.float8e4b8), zero-initializing workspace memory to ensure determinism in sparse attention, and adding specialized handling for UE8M0 scales in quantization utilities. Feedback from the reviewer points out a potential issue where the FP8_E4B15 flag might be incorrectly evaluated for MI300X hardware due to a capability check, which could render the FNUZ fixes ineffective. It is recommended to use current_platform.is_fp8_fnuz() to correctly gate these code paths.

Comment on lines +164 to +165
  if FP8_E4B15:
-     k_float = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32)
+     k_float = k_raw.to(tl.float8e4b8, bitcast=True).to(tl.float32)
Contributor

high

The change to use tl.float8e4b8 is correct for FNUZ hardware. However, this code is gated by if FP8_E4B15:. The FP8_E4B15 flag is determined by _use_fp8_e4b15, which checks if the device capability is less than (8, 9).

For the target hardware of this PR, MI300X (gfx942), the capability is (9, 42), so this condition is false. This means this corrected code path will not be executed on MI300X, and the fix might be ineffective.

To ensure this fix applies correctly on ROCm FNUZ hardware, the condition should be based on current_platform.is_fp8_fnuz().

Contributor Author

Good catch — you're right that FP8_E4B15=0 on MI300X (gfx942 reports capability (9, x) >= (8, 9)), so this branch is unreachable on the target hardware of the PR.

The right answer turned out to be the opposite of "re-gate on current_platform.is_fp8_fnuz()", though: FP8_E4B15=1 is NVIDIA Ampere/Ada (sm < 8.9, software FP8 emulation), where tl.float8e4b15 (E4M3 with bias 15) IS the correct Triton type. tl.float8e4b8 (E4M3 with bias 8) is the AMD-FNUZ-specific type — Triton on Ampere/Ada will reject it with the same "type not supported in this architecture" error the original commit was trying to fix elsewhere.

The original commit (2bef91e) conflated two unrelated gating constexprs:

  • IS_FNUZ in rocm_aiter_mla_sparse.py, correctly gated on current_platform.fp8_dtype() == torch.float8_e4m3fnuz / current_platform.is_fp8_fnuz() — this is the one that actually fixes the MI300X sparse-MLA decode failure cited in the commit message; it stays.
  • FP8_E4B15 in these turboquant kernels, which is 1 only on NVIDIA Ampere/Ada — drive-by, dead on AMD, and wrong on the NVIDIA cards it actually fires on.

Reverted the 3 turboquant lines back to tl.float8e4b15 in 6f7e155 so the Ampere/Ada FP8 path stays correct. The MI300X fix in _sparse_attn_decode_ragged_kernel is unchanged.
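The gating argument in this reply comes down to a tuple comparison. The sketch below assumes only what the thread states: the flag is set when the reported device capability is below (8, 9), and the turboquant kernels pick between the Triton type names float8e4b15 and float8e4nv based on it. The function names are hypothetical and the capability tuples are example inputs, not values queried from hardware.

```python
def use_fp8_e4b15(capability: tuple) -> bool:
    """Mirror of the check described in the thread: capability < (8, 9)."""
    return capability < (8, 9)


def turboquant_fp8_type_name(capability: tuple) -> str:
    """Pick the Triton FP8 type name the turboquant branch would use."""
    return "float8e4b15" if use_fp8_e4b15(capability) else "float8e4nv"


# Example capability tuples (illustrative inputs only).
below_gate = (8, 0)     # compares below (8, 9), so the flag is 1
at_gate = (8, 9)        # not below (8, 9), so the flag is 0
amd_like = (9, 4)       # a (9, x) tuple as described for gfx942

flags = {cap: use_fp8_e4b15(cap) for cap in (below_gate, at_gate, amd_like)}
```

Because Python compares tuples lexicographically, any (9, x) capability is never below (8, 9), which is why the patched branch was unreachable on the PR's target hardware.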

Comment on lines +373 to +374
  if FP8_E4B15:
-     k_recon = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32)
+     k_recon = k_raw.to(tl.float8e4b8, bitcast=True).to(tl.float32)
Contributor

high

Similar to the previous comment, this fix is gated by if FP8_E4B15:, which appears to be false for the target hardware (MI300X). This may render the fix ineffective. The condition should likely be based on current_platform.is_fp8_fnuz() to correctly target FNUZ hardware.

Contributor Author

Good catch — you're right that FP8_E4B15=0 on MI300X (gfx942 reports capability (9, x) >= (8, 9)), so this branch is unreachable on the target hardware of the PR.

The right answer turned out to be the opposite of "re-gate on current_platform.is_fp8_fnuz()", though: FP8_E4B15=1 is NVIDIA Ampere/Ada (sm < 8.9, software FP8 emulation), where tl.float8e4b15 (E4M3 with bias 15) IS the correct Triton type. tl.float8e4b8 (E4M3 with bias 8) is the AMD-FNUZ-specific type — Triton on Ampere/Ada will reject it with the same "type not supported in this architecture" error the original commit was trying to fix elsewhere.

The original commit (2bef91e) conflated two unrelated gating constexprs:

  • IS_FNUZ in rocm_aiter_mla_sparse.py, correctly gated on current_platform.fp8_dtype() == torch.float8_e4m3fnuz / current_platform.is_fp8_fnuz() — this is the one that actually fixes the MI300X sparse-MLA decode failure cited in the commit message; it stays.
  • FP8_E4B15 in these turboquant kernels, which is 1 only on NVIDIA Ampere/Ada — drive-by, dead on AMD, and wrong on the NVIDIA cards it actually fires on.

Reverted the 3 turboquant lines back to tl.float8e4b15 in 6f7e155 so the Ampere/Ada FP8 path stays correct. The MI300X fix in _sparse_attn_decode_ragged_kernel is unchanged.

  d_mask = d_offs < D
  k_vals = tl.load(Key_ptr + base + d_offs, mask=d_mask, other=0.0)
- k_fp8 = k_vals.to(tl.float8e4b15) if FP8_E4B15 else k_vals.to(tl.float8e4nv)
+ k_fp8 = k_vals.to(tl.float8e4b8) if FP8_E4B15 else k_vals.to(tl.float8e4nv)
Contributor

high

The change to use tl.float8e4b8 is correct, but it is conditional on FP8_E4B15. This flag seems to be incorrectly evaluated as false for the target hardware of this PR (MI300X, gfx942), which would make this fix ineffective. The condition should likely be based on current_platform.is_fp8_fnuz() to ensure the correct FP8 format is used on FNUZ-supporting hardware.

Contributor Author

Good catch — you're right that FP8_E4B15=0 on MI300X (gfx942 reports capability (9, x) >= (8, 9)), so this branch is unreachable on the target hardware of the PR.

The right answer turned out to be the opposite of "re-gate on current_platform.is_fp8_fnuz()", though: FP8_E4B15=1 is NVIDIA Ampere/Ada (sm < 8.9, software FP8 emulation), where tl.float8e4b15 (E4M3 with bias 15) IS the correct Triton type. tl.float8e4b8 (E4M3 with bias 8) is the AMD-FNUZ-specific type — Triton on Ampere/Ada will reject it with the same "type not supported in this architecture" error the original commit was trying to fix elsewhere.

The original commit (2bef91e) conflated two unrelated gating constexprs:

  • IS_FNUZ in rocm_aiter_mla_sparse.py, correctly gated on current_platform.fp8_dtype() == torch.float8_e4m3fnuz / current_platform.is_fp8_fnuz() — this is the one that actually fixes the MI300X sparse-MLA decode failure cited in the commit message; it stays.
  • FP8_E4B15 in these turboquant kernels, which is 1 only on NVIDIA Ampere/Ada — drive-by, dead on AMD, and wrong on the NVIDIA cards it actually fires on.

Reverted the 3 turboquant lines back to tl.float8e4b15 in 6f7e155 so the Ampere/Ada FP8 path stays correct. The MI300X fix in _sparse_attn_decode_ragged_kernel is unchanged.

jin-amd added a commit to maeehart/vllm that referenced this pull request May 18, 2026
PR vllm-project#42893 fixed the C++ SWA-K-cache encoder so it writes FNUZ E4M3 bytes
on gfx942 (and OCP on gfx950) and updated the *generic*
``DeepseekV4MLAAttention._forward_prefill`` to call
``dequantize_and_gather_k_cache(..., use_fnuz=is_fp8_fnuz())`` for SWA
and ``use_fnuz=False`` for the Triton-OCP-encoded compressed K cache.
Two FP8-format mismatches remained on the actual ROCm DSv4 path
(``DeepseekV4ROCMAiterMLASparseImpl``):

1. The public ``dequantize_and_gather_k_cache`` wrapper in
   ``vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py`` did not
   accept ``use_fnuz`` -- it silently dropped the kwarg when forwarding
   to ``dequantize_and_gather_k_cache_triton`` (which defaults to
   False). The ROCm prefill called the wrapper without ``use_fnuz``,
   so the SWA K cache (FNUZ on gfx942) was being read as OCP, scaling
   every K vector by ~448/240 in prefill attention.

2. ``_sparse_attn_decode_ragged_kernel`` in
   ``vllm/v1/attention/ops/rocm_aiter_mla_sparse.py`` decoded both the
   SWA (FNUZ on gfx942) and the compressed (always OCP) K caches with
   a single ``IS_FNUZ`` constexpr, so on MI300X the compressed-side
   branch reinterpreted OCP bytes as FNUZ -- the same encoder/decoder
   mismatch as (1) in the opposite direction (~240/448) on the decode
   side.

Together these scrambled K vectors going into both prefill and decode
attention, producing the GSM8K=0.005 gibberish PR vllm-project#42893 documented
but could not explain with eager-vs-graphs.

This commit:

* Adds ``use_fnuz`` to the wrapper and forwards it to the Triton
  implementation (the cuteDSL path is dead on ROCm anyway).
* Splits ``_sparse_attn_decode_ragged_kernel``'s ``IS_FNUZ`` into
  per-cache flags ``IS_FNUZ_MAIN`` (SWA) and ``IS_FNUZ_EXTRA``
  (compressed) so each cache is decoded with its own encoder's format.
* Wires ``DeepseekV4ROCMAiterMLASparseImpl._forward_prefill`` to pass
  ``use_fnuz=False`` for the compressed call (Triton-OCP encoder) and
  ``use_fnuz=current_platform.is_fp8_fnuz()`` for the SWA call (C++
  FNUZ-on-gfx942 encoder), matching the asymmetry that PR vllm-project#42893's
  "[ROCm][DSv4] Fix compressed K cache dequant to match Triton OCP
  encoder" introduced for the generic path.

Validated on 1 node x 4 x MI300X (gfx942), TP=4,
VLLM_ROCM_USE_AITER=1, ``deepseek-ai/DeepSeek-V4-Flash``, both eager
and CUDA-graphs ``FULL_AND_PIECEWISE`` configs from PR vllm-project#42810. GSM8K
5-shot, n=200, num_concurrent=32 against /v1/completions:

| Mode  | exact_match | Stderr   |
| ----- | ----------- | -------- |
| Eager | 0.955       | +/-0.0147 |
| Graph | 0.955       | +/-0.0147 |

vs. the pre-fix 0.005 PR vllm-project#42893 reported on the same configuration.
The two modes match each other to all three reported digits on both
strict-match and flexible-extract filters.

Co-authored-by: Cursor <cursoragent@cursor.com>
@jin-amd

jin-amd commented May 18, 2026

@maeehart The gibberish (GSM8K=0.005) on this PR is caused by two extra FNUZ-vs-OCP gating gaps that aren't covered by f59e042 ("Fix compressed K cache dequant to match Triton OCP encoder").
After fixing them (have pushed the commit) on top of mahartik/dsv4-rocm-mi300x-fixes @ 46ae710, GSM8K 5-shot (n=200, num_concurrent=32, TP=4, VLLM_ROCM_USE_AITER=1) goes to 0.955 ± 0.0147 in both eager and FULL_AND_PIECEWISE modes — matching to all 3 reported digits.

maeehart added a commit to maeehart/vllm that referenced this pull request May 18, 2026
…ly path)

PR review on vllm-project#42893 (gemini-code-assist) flagged that the three
turboquant changes in commit 2bef91e ("[ROCm][DSv4] Use tl.float8e4b8
for FNUZ on MI300X sparse MLA kernels") are dead code on MI300X: they
sit inside ``if FP8_E4B15:`` branches, and FP8_E4B15 is the constexpr
returned by ``_use_fp8_e4b15(device)`` -- which is 1 only when
``torch.cuda.get_device_capability() < (8, 9)``. MI300X (gfx942)
reports cap >= (9, x), so FP8_E4B15 = 0 on every AMD platform and the
patched FNUZ branch is never executed.

More importantly, the changes are *wrong* on the hardware where
FP8_E4B15 = 1 -- NVIDIA Ampere/Ada (sm < 8.9). On those cards
``tl.float8e4b15`` (E4M3 with bias 15) is the correct Triton FP8 type
for software emulation; ``tl.float8e4b8`` (E4M3 with bias 8) is the
AMD-FNUZ-specific type and Triton on NVIDIA Ampere/Ada will reject it
with the same "type not supported in this architecture" error the
original commit was trying to fix.

The original commit message conflated two unrelated gating constexprs
(``IS_FNUZ`` in rocm_aiter_mla_sparse.py vs ``FP8_E4B15`` in the
turboquant kernels). Only the rocm_aiter_mla_sparse.py hunks of
2bef91e are actually correct -- those are gated on
``current_platform.fp8_dtype() == torch.float8_e4m3fnuz`` /
``current_platform.is_fp8_fnuz()`` and are the ones that actually fix
the MI300X sparse-MLA decode failure.

Revert just the three turboquant lines back to ``tl.float8e4b15`` so
the NVIDIA Ampere/Ada FP8 path is preserved. The MI300X fix in
``_sparse_attn_decode_ragged_kernel`` (the dequant/gather kernel cited
in the original commit message) is unchanged.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@tjtanaa
Member

tjtanaa commented May 18, 2026

> @maeehart The gibberish (GSM8K=0.005) on this PR is caused by two extra FNUZ-vs-OCP gating gaps that aren't covered by f59e042 ("Fix compressed K cache dequant to match Triton OCP encoder"). After fixing them (have pushed the commit) on top of mahartik/dsv4-rocm-mi300x-fixes @ 46ae710, GSM8K 5-shot (n=200, num_concurrent=32, TP=4, VLLM_ROCM_USE_AITER=1) goes to 0.955 ± 0.0147 in both eager and FULL_AND_PIECEWISE modes, matching to all 3 reported digits.

Thanks @jin-amd for validating. I was about to comment on this. Right now the MI300X sparse indexer has an issue that is causing accuracy to be 0. Furthermore, @maeehart, always run the full GSM8K dataset to reduce the variance in the accuracy score.

For DeepSeek Sparse Attention models (DSV4, DSV3.2 and GLM5.1), we have to validate with large concurrency, e.g. 128 or 256, to ensure the sparse indexer logic is working correctly.
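The variance point can be checked with a quick calculation. The sketch below assumes the reported "±" values behave like a simple binomial standard error, sqrt(p(1-p)/n); that is an assumption about how the numbers were produced, not a statement about lm_eval internals. The helper name `binomial_stderr` is hypothetical. The sample sizes are the 200-question limit used in the PR's test plan and the 1319-question full GSM8K set.

```python
import math


def binomial_stderr(accuracy: float, n: int) -> float:
    """Standard error of a proportion under a simple binomial model."""
    return math.sqrt(accuracy * (1.0 - accuracy) / n)


# Same observed accuracy, two different sample sizes.
stderr_limit_200 = binomial_stderr(0.955, 200)
stderr_full_set = binomial_stderr(0.955, 1319)
```

Under this model the uncertainty drops from roughly 0.0147 at 200 questions to roughly 0.0057 on the full set, so a regression of a point or two is much easier to see on the full dataset.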

@maeehart
Contributor Author

> > @maeehart The gibberish (GSM8K=0.005) on this PR is caused by two extra FNUZ-vs-OCP gating gaps that aren't covered by f59e042 ("Fix compressed K cache dequant to match Triton OCP encoder"). After fixing them (have pushed the commit) on top of mahartik/dsv4-rocm-mi300x-fixes @ 46ae710, GSM8K 5-shot (n=200, num_concurrent=32, TP=4, VLLM_ROCM_USE_AITER=1) goes to 0.955 ± 0.0147 in both eager and FULL_AND_PIECEWISE modes, matching to all 3 reported digits.
>
> Thanks @jin-amd for validating. I was about to comment on this. Right now the MI300X sparse indexer has an issue that is causing accuracy to be 0. Furthermore, @maeehart, always run the full GSM8K dataset to reduce the variance in the accuracy score.
>
> For DeepSeek Sparse Attention models (DSV4, DSV3.2 and GLM5.1), we have to validate with large concurrency, e.g. 128 or 256, to ensure the sparse indexer logic is working correctly.

On it.

@tjtanaa
Member

tjtanaa commented May 18, 2026

At the same time, I will validate this PR on MI355X and do a more in-depth review.

constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7
constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad)
constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576
// Match the encoding chosen in rocm_cvt_float_to_fp8_e4m3: FNUZ on gfx942
Member

@maeehart This is incorrect for FNUZ; the max is 224.0.

# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm
# platforms that use the torch.float8_e4m3fnuz dtype.
finfo = torch.finfo(fp8_dtype)
fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min
fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max

This is a critical bug that was fixed last year; the FP8 range values were incorrect back then.
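To show how the clamp range feeds into dynamic quantization, here is a small torch-free sketch. It assumes the values given in this comment (224.0 for FNUZ instead of the dtype's nominal 240.0) plus 448.0 as the OCP E4M3 maximum; the function names `fp8_range` and `dynamic_scale` are hypothetical and only mirror the shape of the snippet above.

```python
OCP_E4M3_MAX = 448.0        # nominal max of the OCP E4M3 format
FNUZ_E4M3_NOMINAL_MAX = 240.0  # nominal max of the FNUZ E4M3 format
FNUZ_E4M3_CLAMP_MAX = 224.0    # value the reviewer asks to use on ROCm FNUZ


def fp8_range(is_fnuz: bool) -> tuple:
    """Return (fp8_min, fp8_max) used for clamping before the FP8 cast."""
    fp8_max = FNUZ_E4M3_CLAMP_MAX if is_fnuz else OCP_E4M3_MAX
    return -fp8_max, fp8_max


def dynamic_scale(abs_max: float, is_fnuz: bool) -> float:
    """Per-tensor scale so that abs_max maps onto the clamp maximum."""
    _, fp8_max = fp8_range(is_fnuz)
    return abs_max / fp8_max


scale_ocp = dynamic_scale(448.0, is_fnuz=False)
scale_fnuz = dynamic_scale(448.0, is_fnuz=True)
```

The same activation range therefore produces a different scale on each format, so a kernel constant such as kFp8Max has to agree with whatever range the rest of the quantization path uses.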

constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7
constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad)
constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576
// Match the encoding chosen in rocm_cvt_float_to_fp8_e4m3: FNUZ on gfx942
Member

@maeehart please also make sure all the unit test cases in tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py pass.

Comment thread vllm/models/deepseek_v4/attention.py Outdated
# Zero once per call so the holes are deterministic (and harmless if
# ever indexed). The cost is one bf16 zero per chunk worth of
# workspace, which is dwarfed by the gather + attention themselves.
kv.zero_()
Member

This is rather weird; there is no need to do this on MI355X. I would like to avoid it because it incurs overhead. Please look into the sparse indexer logic for gfx942. I believe fixing the logic there would remove the need to call kv.zero_().

# Zero once per call so the holes are deterministic (zero attention
# contribution). The cost is one bf16 fill of the workspace tile,
# which is dwarfed by the FP8 dequant + sparse attention themselves.
kv.zero_()
Member

This is rather weird; there is no need to do this on MI355X. I would like to avoid it because it incurs overhead. Please look into the sparse indexer logic for gfx942. I believe fixing the logic there would remove the need to call kv.zero_().
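For context on what kv.zero_() changes, the behavior under discussion can be illustrated with a plain list standing in for a reused workspace. This is a toy model with hypothetical names, and it does not settle whether the right fix is zeroing or the sparse indexer logic; it only shows why rows that are not scattered into can carry values from an earlier use.

```python
def gather_into_workspace(workspace: list, rows: list, values: list,
                          zero_first: bool) -> list:
    """Scatter values into selected rows of a reused buffer (toy model)."""
    if zero_first:
        # Analogue of kv.zero_(): clear everything before the scatter.
        for i in range(len(workspace)):
            workspace[i] = 0.0
    for row, value in zip(rows, values):
        workspace[row] = value
    return workspace


# A buffer that still holds data from a previous layer or request.
stale = [9.0, 9.0, 9.0, 9.0]

without_zero = gather_into_workspace(list(stale), [0, 2], [1.0, 2.0],
                                     zero_first=False)
with_zero = gather_into_workspace(list(stale), [0, 2], [1.0, 2.0],
                                  zero_first=True)
```

If the consumer only ever reads the rows that were written, the two variants are equivalent, which is the case the reviewer expects once the indexer logic is correct.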

@maeehart
Contributor Author

High-concurrency GSM8K validation on MI300X (per @tjtanaa's review request)

Hardware: 1 node × 8 × MI300X (gfx942), HIP 6.x, container vllm/vllm-openai-rocm:nightly-32b7177909d1c9928bcedd81de7de5a1fa21d2b3 with this PR's source overlaid (HEAD 6f7e155).

Tests run on chi-mi300x-014 with KEEP=1 so the container survived for diagnosis between attempts. lm_eval was used with local-completions, tokenized_requests=False, tokenizer=Qwen/Qwen3-0.6B (to bypass an unrelated PreTrainedConfig.max_position_embeddings issue between DeepSeek configs and the current transformers version — server-side tokenization is unaffected). GSM8K 5-shot, full 1319-question set, no --limit.

Results

| Model | TP | Mode | Concurrency | flexible-extract | strict-match | Result |
| --- | --- | --- | --- | --- | --- | --- |
| deepseek-ai/DeepSeek-V3.2 | 8 | graph (FULL_AND_PIECEWISE) | 256 | 0.9575 ± 0.0056 | 0.9583 ± 0.0055 | PASS |
| deepseek-ai/DeepSeek-V3.2 | 8 | eager | 256 | 0.9545 ± 0.0057 | 0.9545 ± 0.0057 | PASS |
| zai-org/GLM-5.1-FP8 | 8 | graph (FULL_AND_PIECEWISE) | 256 | 0.9439 ± 0.0063 | 0.9416 ± 0.0065 | PASS |
| zai-org/GLM-5.1-FP8 | 8 | eager | 256 | 0.9401 ± 0.0065 | 0.9409 ± 0.0065 | PASS |
| deepseek-ai/DeepSeek-V4-Flash | 4 | graph (FULL_AND_PIECEWISE) | 256 | | | FAILS on first inference |
| deepseek-ai/DeepSeek-V4-Flash | 4 | eager | 32 | | | FAILS on first inference |

Both DSv3.2 and GLM-5.1-FP8 are stable at conc=256 in both eager and graph mode, and reproduce the expected GSM8K accuracy, which indicates that the sparse indexer logic survives high concurrency on those models. Eager vs graph agree to within stderr on both filters, confirming there's no graph-specific accuracy regression.

DSv4-Flash: hard failure in fp8_mqa_logits (not concurrency-dependent)

DeepSeek-V4-Flash (TP=4, both eager and graph mode) crashes on the very first /v1/completions call — concurrency=1 already trips it. Server-side stack:

File "vllm/model_executor/layers/deepseek_v4_attention.py", line 1208, in forward
  return self.indexer_op(hidden_states, q_quant, k, weights)
File "vllm/model_executor/layers/sparse_attn_indexer.py", line 509, in forward_hip
  return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(...)
File "vllm/v1/attention/ops/rocm_aiter_mla_sparse.py", line 706
  logits = rocm_fp8_mqa_logits(...)
File "vllm/v1/attention/ops/rocm_aiter_mla_sparse.py", line 540
  return fp8_mqa_logits(q, k_fp8, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
File "/usr/local/lib/python3.12/dist-packages/aiter/ops/triton/attention/fp8_mqa_logits.py", line 53
  _fp8_mqa_logits_kernel[(seq_len,)](...)
triton.runtime.errors.OutOfResources: out of resource: shared memory,
  Required: 98304, Hardware limit: 65536.
  Reducing block sizes or `num_stages` may help.

The kernel in aiter.ops.triton.attention.fp8_mqa_logits._fp8_mqa_logits_kernel is requesting 96 KB of LDS, but MI300X (gfx942) only has 64 KB per CU. Mitigations tried:

  1. num_stages=2 → 1 in the call site — Triton still reports Required: 98304 (the kernel isn't multi-stage-pipelinable, so num_stages has no effect on its LDS budget).
  2. Force the torch fallback by removing aiter.ops.triton.attention.fp8_mqa_logits so mqa_logits_module() returns None and rocm_fp8_mqa_logits drops through to fp8_mqa_logits_torch. That makes the OOM go away, but the torch fallback then crashes with a shape mismatch on the very first call:
    RuntimeError: The size of tensor a (11) must match the size of tensor b (2) at non-singleton dimension 1
    
    (raised inside fp8_mqa_logits_torch in rocm_aiter_mla_sparse.py.)

So both the aiter Triton path and the torch fallback are currently broken for DSv4-Flash on MI300X — independent of mode (eager vs graph) and independent of concurrency.
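The shared-memory numbers in the Triton error can be put side by side with a few lines of arithmetic. The sketch uses only the two byte counts from the error message (98304 required, 65536 available); the helper name `lds_report` is hypothetical.

```python
def lds_report(required_bytes: int, limit_bytes: int) -> dict:
    """Summarize an LDS (shared memory) request against a hardware limit."""
    return {
        "required_kib": required_bytes / 1024,
        "limit_kib": limit_bytes / 1024,
        "fits": required_bytes <= limit_bytes,
        "ratio": required_bytes / limit_bytes,
    }


# Values copied from the OutOfResources message above.
report = lds_report(required_bytes=98304, limit_bytes=65536)
```

The request is 96 KiB against a 64 KiB limit, i.e. 1.5 times the budget, so the kernel's footprint would have to shrink by at least a third to launch on this hardware.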

Why this isn't caught by the existing PR test plan

The PR's "Config A / Config B" tests use --limit 200 at num_concurrent=32. We re-ran identical launch configs at num_concurrent=32 and limit=full and still hit the OOM on the first request. The only way the original test could have passed is if the bundled aiter was either at a different version or had a fp8_mqa_logits kernel that fit in 64 KB. The image we used (nightly-32b7177909d1c9928bcedd81de7de5a1fa21d2b3, the one this branch was rebased onto) bundles an aiter wheel whose _fp8_mqa_logits_kernel does not fit on gfx942. Re-running the original limit-200 / conc-32 test in this image hits the same OOM as conc=256.

This means PR42893's reported 0.955 for DSv4-Flash on MI300X is not reproducible in the merged-onto image — the merge target's aiter version regressed the kernel, or the original test used a different aiter.

What this needs

  • A version of aiter.ops.triton.attention.fp8_mqa_logits._fp8_mqa_logits_kernel that fits in 64 KB LDS on gfx942 (e.g., smaller BLOCK_KV, or a gfx942-specific tuning), or
  • A fix to fp8_mqa_logits_torch in rocm_aiter_mla_sparse.py so it can be used as a safe fallback when the aiter kernel doesn't fit.

DSv3.2 and GLM-5.1-FP8 don't hit this kernel and run fine at conc=256 with the patches in this PR — those two results are clean signal for everything else in PR42893.

Launch configs used

DSv3.2 / GLM-5.1-FP8 (graph mode):

vllm serve <model> \
    --tensor-parallel-size 8 \
    --kv-cache-dtype fp8_e4m3 \
    --block-size 256 \
    --max-model-len 32768 \
    --max-num-seqs 256 \
    --distributed-executor-backend mp \
    --trust-remote-code \
    --gpu-memory-utilization 0.6 \
    --moe-backend aiter \
    --enable-prefix-caching \
    --async-scheduling \
    --compilation-config '{"mode":3,"cudagraph_mode":"FULL_AND_PIECEWISE"}' \
    --host 0.0.0.0 --port 8000

DSv4-Flash (graph mode) — same as above with --tensor-parallel-size 4, --moe-backend triton_unfused, --tokenizer-mode deepseek_v4.

@jin-amd

jin-amd commented May 19, 2026

@maeehart @tjtanaa With the AITER fix in ROCm/aiter#3257 (PR already submitted), everything DSv4-related works:

Setup

  • Model: deepseek-ai/DeepSeek-V4-Flash (served via vLLM at http://localhost:8000/v1/completions)
  • Benchmark: gsm8k, 5-shot, generate_until (greedy, temperature=0.0)
  • Sample count: full test set — 1319 / 1319 (limit=null, vs. only 200 in the earlier dsv4 runs)
  • Eval harness: lm_eval 0.4.9
  • Hardware / env: AMD MI300-class GPU (gfx942), ROCm 7.2.53211, PyTorch 2.10, Triton 3.6, Transformers 5.7, Python 3.12
  • Variants: Graph mode, c128 and c256

Accuracy (full 1319 samples)

| Run  | strict-match     | flexible-extract |
|------|------------------|------------------|
| c128 | 0.9522 ± 0.00587 | 0.9515 ± 0.00592 |
| c256 | 0.9560 ± 0.00565 | 0.9560 ± 0.00565 |

@tjtanaa
Member

tjtanaa commented May 19, 2026

Thanks a lot, @jin-amd, for filing the aiter PR. Is there a way to get a stable Triton implementation into vLLM for now as a workaround for gfx942? We still need to wait for your PR to be merged and for vLLM to upgrade to an aiter version that includes your fix.

@maeehart
Contributor Author

MI300X test pass: ready to merge without the AITER bump

Tested this PR on top of nightly
vllm/vllm-openai-rocm:nightly-a6682d1d259cca69a9ae737ea5608fbbe7520031
(today's main) on an 8x MI300X node (gfx942, TP=4 on 4 GPUs).

I pushed one extra commit to this branch
(235a8269d4)
that vendors the fp8_mqa_logits Triton kernel into vLLM with the
(BLOCK_KV=64, num_stages=1) LDS-fit fallback that
ROCm/aiter#3257 ships. With
that in place we can land this PR against the currently-pinned AITER
wheel — the kernel is structurally identical to AITER's, just routed
to a copy that doesn't OOM Triton's LDS on the DSv4 sparse indexer
shape.
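
The LDS-fit fallback described above can be sketched as a simple "first tile that fits" selection. This is only an illustration: `pick_tile` and `CANDIDATE_TILES` are hypothetical names, and the byte counts are the figures quoted in this thread (Required: 98304 for the default tile, roughly 33 KiB for the fallback), not values computed from the kernel.

```python
# Illustrative sketch only: names are hypothetical, LDS figures are the ones
# quoted in this thread rather than derived from the kernel itself.
KIB = 1024

# (BLOCK_KV, num_stages, approximate LDS bytes requested by Triton)
CANDIDATE_TILES = [
    (128, 2, 96 * KIB),  # default tile: Triton reported "Required: 98304"
    (64, 1, 33 * KIB),   # fallback tile: ~33 KiB per the PR description
]


def pick_tile(lds_limit_bytes):
    """Return the first (BLOCK_KV, num_stages) whose LDS request fits."""
    for block_kv, num_stages, lds_bytes in CANDIDATE_TILES:
        if lds_bytes <= lds_limit_bytes:
            return block_kv, num_stages
    raise RuntimeError("no candidate tile fits in the available LDS")


# MI300X (gfx942) exposes 64 KiB of LDS per CU, so the default tile is skipped.
print(pick_tile(64 * KIB))
```

Ordering the candidates from largest to smallest keeps the preferred tile on any device with enough LDS and only degrades where it has to.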

What the vendor commit does

  • New file vllm/v1/attention/ops/triton_fp8_mqa_logits.py with a
    byte-identical @triton.jit copy of AITER's MQA-logits kernel +
    a small Python wrapper that picks BLOCK_KV=64, num_stages=1
    (~33 KiB LDS) when the default BLOCK_KV=128, num_stages=2 tile
    (~96 KiB) doesn't fit in MI300X's 64 KiB per-CU LDS.
  • vllm/v1/attention/ops/rocm_aiter_mla_sparse.py:
    • On gfx942 with AITER enabled, route rocm_fp8_mqa_logits to
      the vendored kernel. gfx950+ and CUDA still use the upstream
      AITER wrapper (which has additional Gluon kernels we don't
      vendor).
    • Fix a latent broadcasting bug in fp8_mqa_logits_torch: the
      per-KV-token scale arrives as [N, 1] (a [N, 4] uint8 buffer
      view-cast to fp32) and was being multiplied against a [H, M, N]
      score tensor; PyTorch right-aligns dims and the leading N
      would line up with M and fail. Flatten to [N] so the multiply
      binds to the last axis. Also drop a hard-coded device="cuda"
      on the index tensors so the fallback runs on HIP.
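
The broadcasting failure in the torch fallback can be reproduced with shapes alone. The sketch below re-implements the right-aligned broadcasting rule in plain Python (it does not call PyTorch); the sizes `H=64, M=11, N=2` are illustrative values chosen to mirror the `(11)` vs `(2)` in the reported error, and `broadcast_shape` is a hypothetical helper.

```python
def broadcast_shape(a, b):
    """Broadcast two shapes using right-aligned (NumPy/PyTorch-style) rules."""
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have the same rank.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for dim, (x, y) in enumerate(zip(a, b)):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"size {x} vs {y} at dimension {dim}")
        out.append(max(x, y))
    return tuple(out)


H, M, N = 64, 11, 2  # illustrative sizes only
score = (H, M, N)

try:
    # [N, 1] is padded to [1, N, 1], so N lines up with M and fails.
    broadcast_shape(score, (N, 1))
except ValueError as err:
    print("unflattened scale fails:", err)

# [N] is padded to [1, 1, N], so the scale binds to the last axis.
print(broadcast_shape(score, (N,)))
```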

The entire vendor patch is meant to be reverted once vLLM bumps to an
AITER version that contains
ROCm/aiter#3257.

GSM8K (1319 problems, 5-shot CoT, num_concurrent=128)

DeepSeek-V4-Flash, TP=4, MI300X (gfx942), fp8_e4m3 KV cache, async_scheduling,
prefix-cache on, FULL_AND_PIECEWISE cudagraphs, gpu-memory-utilization=0.6

| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr |
|-------|---------|------------------|--------|-------------|--------|--------|
| gsm8k |       3 | flexible-extract |      5 | exact_match | 0.9553 | ±0.0057|
| gsm8k |       3 | strict-match     |      5 | exact_match | 0.9560 | ±0.0056|

That matches the published DSv4-Flash GSM8K baseline (~95.5–95.6%);
without the FNUZ FP8 + zero-init fixes in this PR you get either
NaN/garbage decodes or a catastrophic accuracy drop.

Server cold-start to first token: ~170 s (model load 24 s, torch.compile
12 s, CUDA graph capture and FlashInfer autotune for the rest). Full
1319-prompt eval at conc=128: 4 min 19 s end-to-end (~5.1 req/s with
~150–200 generated CoT tokens per request).

One thing reviewers will hit testing this against today's main

(Not caused by this PR — happens with vanilla nightly too.) Today's
main carries a copy of the OpenAI triton_kernels package whose
bitmatrix.make_bitmatrix_metadata kernel does
tl.arange(0, BLOCK_PER_TOK * TOKS_PER_ROW) with BLOCK_PER_TOK = 32
and TOKS_PER_ROW = num_experts_per_tok. For DSv4-Flash
num_experts_per_tok = 6, so this is tl.arange(0, 192) which
isn't a power of 2 and Triton refuses to compile it. There already is
a workaround patch
_patch_make_bitmatrix_metadata in
vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py
but its try:-guarded imports point at
vllm.third_party.triton_kernels.tensor_details.bitmatrix, which
doesn't exist (vLLM only vendors layout.py under that path). So the
patch silently ImportError's out and the original buggy kernel is
used at the first MoE forward pass.

The local fix is a 3-line rewrite of the import paths from
vllm.third_party.triton_kernels.* to triton_kernels.*. I'll send
that as a separate PR; it's needed for the TRITON_UNFUSED MXFP4
backend (the only one supported for DSv4-Flash on gfx942 since the
CK MXFP4 kernels are gfx950-only) but is otherwise orthogonal to
this PR.
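
The arithmetic behind the `tl.arange` failure is easy to check. The sketch below only verifies the power-of-two condition; padding to the next power of two is shown as one conceivable way to get a valid extent and is an assumption, not a description of what `_patch_make_bitmatrix_metadata` actually does.

```python
def is_power_of_two(n):
    """True when n is a positive power of two."""
    return n > 0 and (n & (n - 1)) == 0


def next_power_of_two(n):
    """Smallest power of two that is >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


BLOCK_PER_TOK = 32
num_experts_per_tok = 6  # DSv4-Flash value quoted above

extent = BLOCK_PER_TOK * num_experts_per_tok
print(extent, is_power_of_two(extent), next_power_of_two(extent))
```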

Happy to squash the vendor commit into a single MQA-related commit
instead of keeping it as a separate revertable patch — flag me if
that's preferred.

maeehart added a commit to maeehart/vllm that referenced this pull request May 22, 2026
…llm-project#42893)

Reviewer flagged that ``kv.zero_()`` papers over an algorithmic bug in
the indexer/sparse-attention pipeline rather than fixing it. They are
right -- the proper fix is for the indexer or the sparse-attention
kernel to mask invalid M rows so a topK can never pick a hole. Both
are kernel-level changes that also need to land on the NVIDIA path,
out of scope for this PR.

Update the comments in both ``DeepseekV4MLAAttention._forward_prefill``
and ``DeepseekV4ROCMAiterMLASparseImpl._forward_prefill_attn_impl`` to
record:

* The bug is arch-independent (the workspace allocator is
  ``torch.empty()`` on every platform; the Triton dequant kernel +
  rocm_sparse_attn_prefill share the same hole pattern on gfx942 and
  gfx950).
* The symptom was first observed on MI300X with FNUZ FP8 +
  cudagraphs; it is not specific to gfx942.
* The defensible fix lives in the indexer (score = -inf at invalid
  rows) or attention kernel (treat indices >= valid_len as zero
  contribution), not in the caller.
* ``kv.zero_()`` is the cheap stop-gap: one bf16 fill of the
  workspace tile per layer per chunk, dwarfed by FP8 dequant + sparse
  attention themselves, with no correctness regression on either
  arch.

No behavior change.

Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
@maeehart maeehart requested a review from WoosukKwon as a code owner May 22, 2026 15:45
@maeehart
Contributor Author

maeehart commented May 22, 2026

@tjtanaa thanks for the careful review. I've pushed two follow-up commits addressing all three points.

1. kFp8Max = 240.0f → 224.0f for FNUZ on gfx942 (a928f27fc3)

The FNUZ branch now uses 224.0f to match the rest of vLLM's FNUZ pipeline (vllm/model_executor/layers/quantization/utils/fp8_utils.py:412-417, which flags 240 as a dynamic-quant accuracy hazard). The OCP branch (gfx950 + NVIDIA) keeps 448.0f; on MI355X the kernel build is byte-identical to upstream/main.

2. Unit tests (same commit)

The Triton reference encoder quantize_and_insert_k_cache was hard-coded to OCP (tl.float8e4nv, FP8_MAX=448) regardless of platform, so on gfx942 it could never produce the same bytes as the C++ kernel's FNUZ path. Added a use_fnuz constexpr that selects tl.float8e4b8 + FP8_MAX=224 when set; default is False so no production caller (compressed-K cache) sees a behavior change. The test picks the encoding from current_platform.is_fp8_fnuz() and threads use_fnuz through both the encoder and the dequant call.

Verified post-fix: 36/36 unit tests pass on MI300X (gfx942) and MI355X (gfx950).

3. kv.zero_() in _forward_prefill_attn_impl (7470489ed3)

I checked whether the call is gfx942-specific:

  • The workspace allocator (vllm/v1/worker/workspace.py:176) is torch.empty(..., dtype=uint8) — uninitialized on every platform, no arch-specific path.
  • dequantize_and_gather_k_cache writes the same two row ranges ([0, seq_len/compress_ratio) for compressed-K, [N, N+gather_lens) for the SWA window) on both archs; the rest of the M dim is uninitialized.
  • The FP8 MQA indexer scores the entire M dim, including those holes, and rocm_sparse_attn_prefill is the same Triton kernel on gfx942 and gfx950 (no branch that masks indices >= valid_len).

So the underlying issue is arch-independent — workspace allocator + indexer + sparse-attention. The reason the symptom was loud on MI300X (10 distinct first tokens for the same temperature-0 prompt) and not noticed on MI355X is that FNUZ vs OCP score the garbage rows differently; it's not "gfx950 doesn't have this bug."

I agree the proper fix lives in the indexer (set scores at invalid rows to -inf so a topK can never pick a hole) or in the sparse-attention kernel (treat indices >= valid_len as a free zero contribution), and that fix also needs to land on the NVIDIA path. Both are kernel-level changes that I think belong in a separate PR. Until that lands, kv.zero_() is a cheap, correct stop-gap: one bf16 fill of [PREFILL_CHUNK_SIZE=4, M, 512] per layer per chunk, dwarfed by FP8 dequant + sparse attention themselves. I've trimmed the long rationale at both call sites down to a brief TODO that names the proper fix.
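
The indexer-side fix mentioned above (scores at invalid rows set to -inf so a top-k can never select a hole) can be illustrated in a few lines of plain Python. This is a sketch of the idea only: `topk_valid` is a hypothetical helper, and the real change would live inside the indexer or sparse-attention kernels.

```python
import math


def topk_valid(scores, valid_len, k):
    """Pick up to k highest-scoring row indices, never a row >= valid_len."""
    # Rows at or beyond valid_len hold uninitialized data; force them to -inf.
    masked = [s if i < valid_len else -math.inf for i, s in enumerate(scores)]
    order = sorted(range(len(masked)), key=lambda i: masked[i], reverse=True)
    # Drop masked rows in case k exceeds the number of valid rows.
    return [i for i in order[:k] if masked[i] != -math.inf]


# The last two entries stand in for garbage read from an uninitialized workspace.
scores = [0.2, 0.9, 0.1, 1e30, -5e12]
print(topk_valid(scores, valid_len=3, k=2))
```

Without the mask, the garbage value `1e30` would win the top-k, which is the same kind of nondeterminism the `kv.zero_()` stop-gap works around from the caller side.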

Happy to revisit the gating (e.g. only zero on FNUZ archs) if you'd prefer — just let me know.

maeehart added a commit to maeehart/vllm that referenced this pull request May 22, 2026
…lm-project#42893)

Switch the FNUZ branch of `kFp8Max` to 224.0 (was 240.0, the FNUZ dtype's raw representable max). 224.0 is what the rest of vLLM's FNUZ pipeline uses -- see `vllm/model_executor/layers/quantization/utils/fp8_utils.py:412-417`, which notes that 240.0 hurts dynamic-quant accuracy. The OCP branch (gfx950 + NVIDIA) keeps 448.0.

Make the unit test honor the same split: add an optional `use_fnuz` constexpr to `quantize_and_insert_k_kernel` (default False, no production caller affected) and pick the encoding from `current_platform.is_fp8_fnuz()`. Byte-exact comparison now succeeds on both gfx942 and gfx950.

Verified: 36/36 unit tests pass on MI300X (gfx942) and MI355X (gfx950).

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
maeehart added a commit to maeehart/vllm that referenced this pull request May 22, 2026
…review vllm-project#42893)

Replace the long rationale around ``kv.zero_()`` (in both prefill paths) with a brief TODO that names the proper fix: mask invalid rows in the indexer (score = -inf) or in the sparse-attention kernel (skip indices >= valid_len). The current zero is the minimal interim workaround; the underlying bug is arch-independent (uninitialized workspace + indexer that scores the entire M dim) so the call stays unchanged on every platform until the indexer/kernel fix lands. No behavior change.

Also condense the duplicate FNUZ-vs-OCP comments at the dequant call sites and in ``_sparse_attn_decode_ragged_kernel``: the wrapper docstring already explains the asymmetry, so per-call-site repetition was just noise.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@maeehart maeehart force-pushed the mahartik/dsv4-rocm-mi300x-fixes branch from 40a6ad3 to 7470489 Compare May 22, 2026 16:15
@mergify
Contributor

mergify Bot commented May 22, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @maeehart.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 22, 2026
@tjtanaa
Copy link
Copy Markdown
Member

tjtanaa commented May 25, 2026

@maeehart Please resolve merge conflict and rebase before we further review this PR.

@maeehart
Contributor Author

Rebase status: substantive port required, not a mechanical rebase

Apologies for the delay on the rebase. The mechanical git rebase upstream/main does not work cleanly here, and pushing a half-resolved version would lose actual fixes. Writing up where this stands so we can coordinate.

Why a mechanical rebase does not apply

PR #43385 ([ROCm] [DSv4] [Perf] Support DeepSeek v4 MTP) landed on 2026-05-24 (commit 1806d1adfc) and added ~2300 lines under vllm/models/deepseek_v4/amd/. Combined with the earlier refactoring series #43004, #43039, #43073, #43077, #43149 that split DSv4 into common/, nvidia/ and amd/, the file layout this branch was built on is gone. Specifically:

  • vllm/models/deepseek_v4/attention.py is split between nvidia/ops/attention.py (which now uses impl_cls.forward_mqa(...)) and amd/rocm.py.
  • vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse_dsv4.py is removed.
  • vllm/v1/attention/ops/triton_fp8_mqa_logits.py (vendored copy on this branch) is no longer needed if the upstream amd/rocm.py already imports an equivalent.

Of this branch's 11 commits, 4 (0ae220c631, 6d7df4e1cb, 9a1dbc4ae7, 7470489ed3) target only attention.py or the removed sparse-DSv4 backend file. The other 7 touch files that still exist (csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu, fp8_utils.py, cache_utils.py, rocm_aiter_mla_sparse.py, triton_turboquant_*.py).

Why the dead-file commits still matter

I checked vllm/models/deepseek_v4/amd/rocm.py in upstream/main and it does not handle the FP8 FNUZ vs OCP gating that this PR exists for:

  • grep -E "fnuz|FNUZ|fp8_max|FP8_MAX|float8e4b8|e4m3fnuz|is_fp8_fnuz|gfx942" on amd/rocm.py returns zero matches.
  • grep "use_fnuz" upstream/main -- vllm/models/deepseek_v4/ vllm/v1/attention/ also returns zero matches.

So the FNUZ-aware paths and the workspace-zeroing logic from this branch are still missing on amd/rocm.py. The two dequantize_and_gather_k_cache call sites in amd/rocm.py (around lines 800 and 811) need:

  • compressed-K dequant (line 800): use_fnuz=False always, because the encoder is the Triton _fused_kv_compress_norm_rope_insert_sparse_attn kernel which writes OCP bytes on every platform.
  • SWA-side dequant (line 811): use_fnuz=current_platform.is_fp8_fnuz(), because the encoder is the C++ kernel that switches to FNUZ on gfx942.
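
The per-cache rule in the two bullets above amounts to a small decision function. The sketch below is illustrative: `use_fnuz_for` and the `"compressed"` / `"swa"` labels are hypothetical, and in the real code the platform flag would come from `current_platform.is_fp8_fnuz()`.

```python
def use_fnuz_for(cache_kind, platform_is_fp8_fnuz):
    """Choose the FP8 decode format from the encoder that wrote the cache."""
    if cache_kind == "compressed":
        # Written by the Triton encoder, which emits OCP bytes on every platform.
        return False
    if cache_kind == "swa":
        # Written by the C++ kernel, which switches to FNUZ on gfx942.
        return platform_is_fp8_fnuz
    raise ValueError(f"unknown cache kind: {cache_kind!r}")


for kind in ("compressed", "swa"):
    print(kind, use_fnuz_for(kind, platform_is_fp8_fnuz=True))
```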

Plus the prefill-attn KV workspace and ROCm sparse-MLA prefill workspace need to be zero-initialised before gather, which is what commits 9a1dbc4ae7 and 2e95b52935 did on the old file structure.

Plan

  1. Rebase the 7 still-applicable commits onto current upstream/main (these target files that still exist and apply with auto-merge on the parts that are shared).
  2. Drop the 4 dead-file commits and write a single new commit [ROCm][DSv4] Port MI300X FNUZ + workspace-zeroing fixes onto amd/rocm.py that re-applies their substantive content at the new call sites in amd/rocm.py (and cache_utils.py for the dequantize_and_gather_k_cache signature).
  3. Validate on an MI300X (gfx942) node before force-pushing: lm_eval --tasks gsm8k --limit 200 on DSv4 plus a few API sanity requests, with the amd/rocm.py path enabled.
  4. Force-push and re-ping you for review.

Will get this done over the next session. If you have any preference on whether the port lands as a single squashed-style commit on top or as separate commits per fix, please let me know. I am leaning toward the single commit because it documents the architectural move from attention.py to amd/rocm.py in one place.

Also, if there are other DSv4 cleanups you have queued post-#43385 that conflict with the FNUZ-vs-OCP gating direction, please flag them so I can shape the port accordingly.

cc @Ganyi

ganyi1996ppo and others added 9 commits May 25, 2026 10:28
DSv4 on AMD MI300X (gfx942) hits several FP8-related issues that this
commit addresses:

1. **fp8_utils.py**: ``process_fp8_weight_block_strategy`` calls
   ``normalize_e4m3fn_to_e4m3fnuz``, which doubles ``weight_scale`` via
   ``weight_scale * 2.0``. On models with UE8M0 scales
   (``torch.float8_e8m0fnu``), that ``mul`` is not implemented on
   CUDA/HIP and load aborts with::

     NotImplementedError: "mul_cuda" not implemented for 'Float8_e8m0fnu'

   UE8M0 stores power-of-2 exponent values (2^(exp-127)) with no
   mantissa, so doubling the scale is equivalent to incrementing the
   exponent byte by 1. Handle the UE8M0 case explicitly and fall back
   to the float path otherwise.

2. **fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu**: gate ``kFp8Max``
   to match the FNUZ/OCP path actually taken on each ROCm arch
   (240 on gfx942 FNUZ, 448 on gfx950 OCP).

3. **deepseek_v4_attention.py** + **cache_utils.py**: small MI300X path
   fixes that go with the FNUZ scale handling above.
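
The UE8M0 argument in item 1 above can be checked on plain byte values. This is a minimal sketch, not the code in ``fp8_utils.py``: the helper names are hypothetical, and the overflow guard is an assumption added here because the all-ones byte encodes NaN in E8M0.

```python
def ue8m0_decode(byte):
    """Decode a UE8M0 scale byte: exponent only, value = 2 ** (byte - 127)."""
    if byte == 0xFF:
        return float("nan")  # all-ones encodes NaN in E8M0
    return 2.0 ** (byte - 127)


def ue8m0_double(byte):
    """Double a UE8M0 scale by incrementing its exponent byte."""
    if byte >= 0xFE:
        # Assumption for this sketch: refuse to step into the NaN encoding.
        raise OverflowError("cannot double the largest finite UE8M0 scale")
    return byte + 1


scale = 130  # encodes 2 ** 3 == 8.0
print(ue8m0_decode(scale), ue8m0_decode(ue8m0_double(scale)))
```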

Co-authored-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
``fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu`` selected its FP8
type and ``kFp8Max`` based purely on the HIP build macro
``HIP_FP8_TYPE_OCP``. That macro is set by the HIP runtime version,
not by the target GPU arch -- on a HIP build that defines
``HIP_FP8_TYPE_OCP``, the kernel was using OCP E4M3 / ``448.0`` even on
gfx942 (MI300X), whose MFMA instructions only accept FNUZ E4M3.

The rest of vLLM's gfx942 path (Triton sparse-MLA, indexer Q quant,
``current_platform.fp8_dtype()``) all use FNUZ on this arch, so the C++
writer was producing K-cache entries the FNUZ readers misinterpret.

Gate the OCP branch on ``defined(__gfx950__)`` so:

* gfx942 (MI300X) -> ``__hip_fp8_e4m3_fnuz`` + ``kFp8Max = 240.0f``
* gfx950 (MI355X) -> ``__hip_fp8_e4m3``       + ``kFp8Max = 448.0f``

This matches the encoding chosen elsewhere on each arch.

Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
The DSv4 sparse MLA Triton kernels added in vllm-project#41812 (and the matching
turboquant store/decode kernels) bitcast uint8 to ``tl.float8e4b15``
when ``IS_FNUZ`` is true. ``float8e4b15`` is not a real Triton type;
on AMD gfx942 (MI300X) Triton only supports the FP8 dtypes listed in
the error from triton/compiler:

  ('fp8e4b8', 'fp8e4nv', 'fp8e5', 'fp8e5b16')

The correct FNUZ E4M3 type is ``tl.float8e4b8`` (bias 8, matches the
PyTorch ``torch.float8_e4m3fnuz`` used elsewhere on the MI300 path).
The non-FNUZ branch already correctly uses ``tl.float8e4nv``.

Without this fix, the very first profile run on MI300X with sparse
MLA fails inside the dequant/gather kernel:

  type fp8e4b15 not supported in this architecture.

This swaps all FNUZ branches to ``tl.float8e4b8``. Verified that
``IS_FNUZ`` is gated on ``current_platform.fp8_dtype() ==
torch.float8_e4m3fnuz`` so it never fires on OCP hardware.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
``DeepseekV4ROCMAiterMLASparseImpl._forward_prefill_attn_impl`` in
``vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse_dsv4.py`` is the
actual ROCm path reached from ``DeepseekV4MLAAttention.forward`` at
``deepseek_v4_attention.py:762`` (``current_platform.is_rocm()``).
``DeepseekV4MLAAttention._forward_prefill`` in the same file is dead
code on ROCm, so the previous ``kv.zero_()`` patch (commit 36a7037)
fixed only the generic path.

This ROCm-only forward also gets ``kv`` via
``current_workspace_manager().get_simultaneous(...)`` -- uninitialized
shared memory reused across requests and layers -- writes only the
compressed-K prefix and the SWA window for each chunk row, then reads
the entire ``kv.view(-1, 1, head_dim)`` through ragged indices that
can land on the holes for very short sequences. The result is exactly
the symptom we observe on MI300X DSv4-Flash: 10 identical temperature=0
``/v1/completions`` calls produce 10 distinct first tokens.

Apply the same zero-init here. Cost is one bf16 fill of the workspace
tile, dwarfed by the FP8 dequant + sparse attention.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
PR vllm-project#42893 fixed the C++ SWA-K-cache encoder so it writes FNUZ E4M3 bytes
on gfx942 (and OCP on gfx950) and updated the *generic*
``DeepseekV4MLAAttention._forward_prefill`` to call
``dequantize_and_gather_k_cache(..., use_fnuz=is_fp8_fnuz())`` for SWA
and ``use_fnuz=False`` for the Triton-OCP-encoded compressed K cache.
Two FP8-format mismatches remained on the actual ROCm DSv4 path
(``DeepseekV4ROCMAiterMLASparseImpl``):

1. The public ``dequantize_and_gather_k_cache`` wrapper in
   ``vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py`` did not
   accept ``use_fnuz`` -- it silently dropped the kwarg when forwarding
   to ``dequantize_and_gather_k_cache_triton`` (which defaults to
   False). The ROCm prefill called the wrapper without ``use_fnuz``,
   so the SWA K cache (FNUZ on gfx942) was being read as OCP, scaling
   every K vector by ~448/240 in prefill attention.

2. ``_sparse_attn_decode_ragged_kernel`` in
   ``vllm/v1/attention/ops/rocm_aiter_mla_sparse.py`` decoded both the
   SWA (FNUZ on gfx942) and the compressed (always OCP) K caches with
   a single ``IS_FNUZ`` constexpr, so on MI300X the compressed-side
   branch reinterpreted OCP bytes as FNUZ -- the same encoder/decoder
   mismatch as (1) in the opposite direction (~240/448) on the decode
   side.

Together these scrambled K vectors going into both prefill and decode
attention, producing the GSM8K=0.005 gibberish PR vllm-project#42893 documented
but could not explain with eager-vs-graphs.
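
To see why an encoder/decoder format mismatch corrupts K vectors, it helps to decode the same byte under both E4M3 conventions. The sketch below is a plain-Python reference decoder written for illustration (``decode_e4m3`` is a hypothetical name, not a vLLM function). It uses bias 8 with ``0x80`` as NaN for FNUZ and bias 7 with ``0x7F``/``0xFF`` as NaN for OCP. For normal values the same bit pattern decodes to twice the value under OCP, which is in the same ballpark as the ~448/240 ratio of the two format maxima quoted above.

```python
import math


def decode_e4m3(byte, fnuz):
    """Decode one FP8 E4M3 byte as FNUZ (bias 8) or OCP (bias 7)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if fnuz:
        if byte == 0x80:  # FNUZ has no negative zero; 0x80 is NaN
            return math.nan
        bias = 8
    else:
        if exp == 0xF and man == 0x7:  # OCP E4M3: 0x7F and 0xFF are NaN
            return math.nan
        bias = 7
    if exp == 0:  # subnormal (or zero)
        return sign * (man / 8.0) * 2.0 ** (1 - bias)
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - bias)


print(decode_e4m3(0x7F, fnuz=True))   # largest finite FNUZ value
print(decode_e4m3(0x7E, fnuz=False))  # largest finite OCP value
print(decode_e4m3(0x40, fnuz=True), decode_e4m3(0x40, fnuz=False))
```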

This commit:

* Adds ``use_fnuz`` to the wrapper and forwards it to the Triton
  implementation (the cuteDSL path is dead on ROCm anyway).
* Splits ``_sparse_attn_decode_ragged_kernel``'s ``IS_FNUZ`` into
  per-cache flags ``IS_FNUZ_MAIN`` (SWA) and ``IS_FNUZ_EXTRA``
  (compressed) so each cache is decoded with its own encoder's format.
* Wires ``DeepseekV4ROCMAiterMLASparseImpl._forward_prefill`` to pass
  ``use_fnuz=False`` for the compressed call (Triton-OCP encoder) and
  ``use_fnuz=current_platform.is_fp8_fnuz()`` for the SWA call (C++
  FNUZ-on-gfx942 encoder), matching the asymmetry that PR vllm-project#42893's
  "[ROCm][DSv4] Fix compressed K cache dequant to match Triton OCP
  encoder" introduced for the generic path.

Validated on 1 node x 4 x MI300X (gfx942), TP=4,
VLLM_ROCM_USE_AITER=1, ``deepseek-ai/DeepSeek-V4-Flash``, both eager
and CUDA-graphs ``FULL_AND_PIECEWISE`` configs from PR vllm-project#42810. GSM8K
5-shot, n=200, num_concurrent=32 against /v1/completions:

| Mode  | exact_match | Stderr   |
| ----- | ----------- | -------- |
| Eager | 0.955       | +/-0.0147 |
| Graph | 0.955       | +/-0.0147 |

vs. the pre-fix 0.005 PR vllm-project#42893 reported on the same configuration.
The two modes match each other to all three reported digits on both
strict-match and flexible-extract filters.
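
As a sanity check on the table above, the reported stderr is consistent with the sample standard error of a 0/1 metric. The sketch assumes that is how the harness computes it (standard deviation with an n-1 denominator, divided by sqrt(n)), and that 0.955 at n=200 corresponds to 191 correct answers; ``exact_match_stderr`` is a hypothetical helper.

```python
import math


def exact_match_stderr(num_correct, n):
    """Sample standard error of the mean for a 0/1 (exact-match) metric."""
    p = num_correct / n
    # Sample variance of a Bernoulli outcome is p(1-p) * n/(n-1);
    # dividing by n gives the squared standard error.
    return math.sqrt(p * (1.0 - p) / (n - 1))


# 191 / 200 = 0.955, the accuracy reported for both modes.
print(round(exact_match_stderr(191, 200), 4))
```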

Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
…ly path)

PR review on vllm-project#42893 (gemini-code-assist) flagged that the three
turboquant changes in commit 2bef91e ("[ROCm][DSv4] Use tl.float8e4b8
for FNUZ on MI300X sparse MLA kernels") are dead code on MI300X: they
sit inside ``if FP8_E4B15:`` branches, and FP8_E4B15 is the constexpr
returned by ``_use_fp8_e4b15(device)`` -- which is 1 only when
``torch.cuda.get_device_capability() < (8, 9)``. MI300X (gfx942)
reports cap >= (9, x), so FP8_E4B15 = 0 on every AMD platform and the
patched FNUZ branch is never executed.

More importantly, the changes are *wrong* on the hardware where
FP8_E4B15 = 1 -- NVIDIA Ampere/Ada (sm < 8.9). On those cards
``tl.float8e4b15`` (E4M3 with bias 15) is the correct Triton FP8 type
for software emulation; ``tl.float8e4b8`` (E4M3 with bias 8) is the
AMD-FNUZ-specific type and Triton on NVIDIA Ampere/Ada will reject it
with the same "type not supported in this architecture" error the
original commit was trying to fix.

The original commit message conflated two unrelated gating constexprs
(``IS_FNUZ`` in rocm_aiter_mla_sparse.py vs ``FP8_E4B15`` in the
turboquant kernels). Only the rocm_aiter_mla_sparse.py hunks of
2bef91e are actually correct -- those are gated on
``current_platform.fp8_dtype() == torch.float8_e4m3fnuz`` /
``current_platform.is_fp8_fnuz()`` and are the ones that actually fix
the MI300X sparse-MLA decode failure.

Revert just the three turboquant lines back to ``tl.float8e4b15`` so
the NVIDIA Ampere/Ada FP8 path is preserved. The MI300X fix in
``_sparse_attn_decode_ragged_kernel`` (the dequant/gather kernel cited
in the original commit message) is unchanged.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
The AITER wrapper bundled in the currently-pinned aiter wheel
launches fp8_mqa_logits with (BLOCK_KV=128, num_stages=2) on gfx942.
For the DSv4 sparse indexer shape (NUM_HEADS=64, HEAD_SIZE=128) this
double-buffered KV tile + fp32 scores accumulator + Q tile pushes
Triton's LDS request to 96 KiB, which exceeds MI300X's 64 KiB per
CU. The launch JIT-aborts with OutOfResources on the first
inference. The fix is upstreamed as ROCm/aiter#3257 but until vLLM
bumps to an AITER version that contains it, this patch ships the
same kernel + tile-size logic vendored into vllm/.
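
As a rough sketch of the tile-size fallback, the selection can be pictured as "take the first candidate whose LDS estimate fits the per-CU budget". The snippet below does not recompute the LDS usage; it only uses the two estimates this commit message quotes (96 KiB for the default tile, about 33 KiB for the reduced one). The names `CANDIDATES` and `pick_tile` are hypothetical and do not come from the vendored module.

```python
KIB = 1024
LDS_BUDGET_BYTES = 64 * KIB  # per-CU LDS limit quoted for MI300X

# (BLOCK_KV, num_stages) -> estimated LDS bytes, taken from the commit
# message rather than derived from tile shapes.
CANDIDATES = [
    ((128, 2), 96 * KIB),  # default tile
    ((64, 1), 33 * KIB),   # reduced tile, quoted as "~33 KiB"
]


def pick_tile(candidates, budget_bytes):
    """Return the first (BLOCK_KV, num_stages) whose estimate fits the budget."""
    for config, lds_bytes in candidates:
        if lds_bytes <= budget_bytes:
            return config
    raise RuntimeError("no candidate tile fits in LDS")


chosen = pick_tile(CANDIDATES, LDS_BUDGET_BYTES)
```

Listing the default tile first keeps the larger configuration on any target where it fits and falls back only when the budget is exceeded.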

- Add vllm/v1/attention/ops/triton_fp8_mqa_logits.py with a
  byte-for-byte copy of AITER's @triton.jit kernel and a
  Python wrapper that selects (BLOCK_KV=64, num_stages=1) (~33 KiB)
  when the default tile would not fit on gfx942 (see module
  docstring for the LDS budget calculation).
- Route rocm_fp8_mqa_logits to the vendored kernel on gfx942 when
  AITER ops are enabled. gfx950+ and CUDA still use the upstream
  AITER wrapper (which has dedicated Gluon kernels this vendor
  copy does not include).
- Fix a latent broadcasting bug in the torch reference fallback:
  the per-KV-token scale arrives as [N, 1] (a [N, 4] uint8 buffer
  view-cast to fp32) and was being multiplied against an [H, M, N]
  score tensor where PyTorch right-aligns [N, 1] against the M
  dim. Flatten to [N] so the multiply lines up with the last
  axis. Also drop a hard-coded 'cuda' device on the index tensors
  so the fallback works on ROCm with HIP devices.
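
The broadcasting issue in the last bullet can be reproduced in a few lines. This is a minimal sketch that uses NumPy as a stand-in for PyTorch (both follow the same right-aligned broadcasting rules); the shapes and values are made up for illustration and `M == N` is chosen on purpose so the misaligned multiply does not raise.

```python
import numpy as np

H, M, N = 2, 3, 3  # M == N on purpose: the wrong multiply does not raise
scores = np.ones((H, M, N), dtype=np.float32)
kv_scale = np.array([[1.0], [10.0], [100.0]], dtype=np.float32)  # shape [N, 1]

# Misaligned form: [N, 1] is right-aligned as [1, N, 1], so the scale varies
# along the M axis (query rows) instead of the N axis (KV tokens).
wrong = scores * kv_scale

# Corrected form: flatten to [N] so the scale lines up with the last axis.
right = scores * kv_scale.reshape(-1)
```

When `M != N` the misaligned form fails with a shape error instead, so the silent case is exactly the one where the two dimensions happen to match.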

This entire patch is intended to be reverted once vLLM picks up an
AITER version that includes ROCm/aiter#3257.

Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
…lm-project#42893)

Switch the FNUZ branch of `kFp8Max` to 224.0 (was 240.0, the FNUZ dtype's raw representable max). 224.0 is what the rest of vLLM's FNUZ pipeline uses -- see `vllm/model_executor/layers/quantization/utils/fp8_utils.py:412-417`, which notes that 240.0 hurts dynamic-quant accuracy. The OCP branch (gfx950 + NVIDIA) keeps 448.0.
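
A minimal sketch of the FNUZ/OCP split, assuming a plain absmax scale rather than the kernel's actual scale encoding. Only the two constants (224.0 and 448.0) come from this commit message; `fp8_max` and `absmax_scale` are hypothetical helper names.

```python
FP8_MAX_FNUZ = 224.0  # value this commit adopts for the FNUZ branch
FP8_MAX_OCP = 448.0   # value kept for the OCP branch


def fp8_max(is_fnuz: bool) -> float:
    """Pick the clamp/scale bound for the active FP8 format."""
    return FP8_MAX_FNUZ if is_fnuz else FP8_MAX_OCP


def absmax_scale(values, is_fnuz: bool) -> float:
    """Scale such that max(|values|) / scale equals the chosen FP8 max."""
    amax = max(abs(v) for v in values)
    return amax / fp8_max(is_fnuz)


block = [-3.5, 0.25, 112.0]
scale_fnuz = absmax_scale(block, is_fnuz=True)   # 112 / 224
scale_ocp = absmax_scale(block, is_fnuz=False)   # 112 / 448
```

The encoder and any reference implementation have to agree on the bound, otherwise a byte-exact comparison cannot succeed.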

Make the unit test honor the same split: add an optional `use_fnuz` constexpr to `quantize_and_insert_k_kernel` (default False, no production caller affected) and pick the encoding from `current_platform.is_fp8_fnuz()`. Byte-exact comparison now succeeds on both gfx942 and gfx950.

Verified: 36/36 unit tests pass on MI300X (gfx942) and MI355X (gfx950).

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
…review vllm-project#42893)

Replace the long rationale around ``kv.zero_()`` (in both prefill paths) with a brief TODO that names the proper fix: mask invalid rows in the indexer (score = -inf) or in the sparse-attention kernel (skip indices >= valid_len). The current zero is the minimal interim workaround; the underlying bug is arch-independent (uninitialized workspace + indexer that scores the entire M dim) so the call stays unchanged on every platform until the indexer/kernel fix lands. No behavior change.
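
The first of the two proper fixes named in the TODO (mask invalid rows with score = -inf) can be illustrated with a pure-Python toy. This is only a sketch under the assumption that the indexer picks the top-k scores per query; `mask_invalid_scores`, `top_k_indices`, and the example values are hypothetical and unrelated to the real indexer code.

```python
import math


def mask_invalid_scores(scores, valid_len):
    """Replace scores at positions >= valid_len with -inf."""
    return [s if i < valid_len else -math.inf for i, s in enumerate(scores)]


def top_k_indices(scores, k):
    """Indices of the k largest scores, skipping masked (-inf) entries."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [i for i in order[:k] if scores[i] != -math.inf]


# Positions 3 and 4 stand in for uninitialized workspace rows.
raw = [0.2, 0.9, 0.1, 7.5, 3.3]
picked_unmasked = top_k_indices(raw, k=2)
picked_masked = top_k_indices(mask_invalid_scores(raw, valid_len=3), k=2)
```

Without the mask the stale rows win the top-k; with it only valid positions can be selected, which is the property zeroing the workspace approximates today.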

Also condense the duplicate FNUZ-vs-OCP comments at the dequant call sites and in ``_sparse_attn_decode_ragged_kernel``: the wrapper docstring already explains the asymmetry, so per-call-site repetition was just noise.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@maeehart maeehart force-pushed the mahartik/dsv4-rocm-mi300x-fixes branch from 7470489 to 9ae8703 Compare May 25, 2026 07:34
@maeehart maeehart requested a review from dllehr-amd as a code owner May 25, 2026 07:34
@maeehart
Contributor Author

Rebase done

Force-pushed mahartik/dsv4-rocm-mi300x-fixes onto current upstream/main (HEAD 5c1aec3dc0). PR is now MERGEABLE, DCO is green, all 9 commits carry a Signed-off-by trailer.

What ended up on the branch

git diff upstream/main..HEAD is +415 / -23 across 7 files:

| File | Lines | Purpose |
|---|---|---|
| csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu | +12 / -0 | gate kFp8Max to FNUZ vs OCP per ROCm arch |
| tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py | +12 / -8 | reference path mirrors the new gate |
| vllm/model_executor/layers/quantization/utils/fp8_utils.py | +21 / -4 | UE8M0 scale doubling without mul_cuda on Float8_e8m0fnu |
| vllm/models/deepseek_v4/amd/rocm.py | +9 / -0 | ROCm sparse-MLA prefill workspace zero-init + FNUZ-aware dequant args |
| vllm/models/deepseek_v4/common/ops/cache_utils.py | +37 / -8 | dequantize_and_gather_k_cache(..., use_fnuz=...) plumbed through |
| vllm/v1/attention/ops/rocm_aiter_mla_sparse.py | +35 / -6 | FNUZ vs OCP gate on the ragged decode path + use vendored triton_fp8_mqa_logits on gfx942 |
| vllm/v1/attention/ops/triton_fp8_mqa_logits.py | +286 / -0 | vendored Triton fp8_mqa_logits to keep gfx942 unblocked from AITER bumps |

How the conflict resolved

Of the original 11 commits on the pre-rebase branch, 2 (6d7df4e1cb "Fix compressed K cache dequant" and 9a1dbc4ae7 "Zero prefill-attn KV workspace") only touched the now-removed vllm/models/deepseek_v4/attention.py. Their substance was already absorbed by the other 9 commits when git's rename detection placed the corresponding edits on vllm/models/deepseek_v4/amd/rocm.py:

  • amd/rocm.py:806-815 (compressed-K dequant) ends up with use_fnuz=False and the comment compressed_k_cache is OCP on every platform (Triton encoder). Same intent as the dropped 6d7df4e1cb.
  • amd/rocm.py:797 carries kv.zero_() for the ROCm sparse-MLA prefill workspace. Same intent as the dropped 9a1dbc4ae7.

The other 9 commits applied with one trivial --ours resolution (the inline if current_platform.is_rocm(): dispatch in the old attention.py.forward() is gone now that upstream uses self.impl_cls.forward_mqa(...)).

The needs-rebase label should auto-clear once mergify re-evaluates.

Validation status

I am following up with a re-run of the GSM8K accuracy test on MI300X (gfx942) with the rebased branch and the public vllm/vllm-openai-rocm:nightly image to confirm the FNUZ paths still behave. Will post numbers as a follow-up comment. If you have a specific scenario you want me to cover beyond the standard lm_eval --tasks gsm8k --limit 200 plus a few /v1/completions sanity calls, let me know.

@tjtanaa @Ganyi this is ready for re-review when you have a moment.

@mergify mergify Bot removed the needs-rebase label May 25, 2026
@tjtanaa
Member

tjtanaa commented May 25, 2026

Tested before the latest rebase, on vllm/vllm-openai-rocm:nightly. For any upstream work we have to use the version that is supported upstream. @maeehart, please also pull the latest image; it seems your HIP toolchain is old.

# hipcc --version
HIP version: 7.2.53211-35e8c7bf89
AMD clang version 22.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-7.2.2 26084 f58b06dce1f9c15707c5f808fd002e18c2accf7e)
Target: x86_64-unknown-linux-gnu
Thread model: posix
InstalledDir: /opt/rocm-7.2.2/lib/llvm/bin
Configuration file: /opt/rocm-7.2.2/lib/llvm/bin/clang++.cfg

Launch command with cudagraph fails.

#!/bin/bash

rm -rf /root/.cache/vllm

VLLM_ROCM_USE_AITER=1 \
vllm serve deepseek-ai/DeepSeek-V4-Flash \
  --host localhost \
  --port 8001 \
  --dtype auto \
  --tensor-parallel-size 8 \
  --max-num-seqs 256 \
  --distributed-executor-backend mp \
  --trust-remote-code \
  --gpu-memory-utilization 0.6 \
  --moe-backend triton_unfused \
  --tokenizer-mode deepseek_v4 \
  --reasoning-parser deepseek_v4 \
  --kv-cache-dtype fp8_e4m3 \
  --compilation-config '{"mode":3,"cudagraph_mode": "FULL_AND_PIECEWISE"}'

I encounter this error with both the Flash and Pro models on MI300X.

(Worker_TP0 pid=104638) ERROR 05-25 06:40:20 [multiproc_executor.py:962] triton.compiler.errors.CompilationError: at 10:17:
(Worker_TP0 pid=104638) ERROR 05-25 06:40:20 [multiproc_executor.py:962] def _bitmatrix_metadata_compute_stage2(ColSortedIndx, RowSortedIndx, NonzeroIndx, n_tokens, ColPartialSum, stride_pm,
(Worker_TP0 pid=104638) ERROR 05-25 06:40:20 [multiproc_executor.py:962]                                        stride_pn, ColOffs, TOKS_PER_ROW: tl.constexpr, BLOCK_PER_TOK: tl.constexpr):
(Worker_TP0 pid=104638) ERROR 05-25 06:40:20 [multiproc_executor.py:962]     BLOCK_SIZE: tl.constexpr = BLOCK_PER_TOK * TOKS_PER_ROW
(Worker_TP0 pid=104638) ERROR 05-25 06:40:20 [multiproc_executor.py:962]     tl.static_assert(BLOCK_SIZE <= 32768)
(Worker_TP0 pid=104638) ERROR 05-25 06:40:20 [multiproc_executor.py:962]     if isinstance(n_tokens, tl.tensor) and n_tokens.dtype.is_ptr():
(Worker_TP0 pid=104638) ERROR 05-25 06:40:20 [multiproc_executor.py:962]         n_tokens = tl.load(n_tokens)
(Worker_TP0 pid=104638) ERROR 05-25 06:40:20 [multiproc_executor.py:962]     nonzero_indx_size = n_tokens * TOKS_PER_ROW
(Worker_TP0 pid=104638) ERROR 05-25 06:40:20 [multiproc_executor.py:962]     pid_m = tl.program_id(0)
(Worker_TP0 pid=104638) ERROR 05-25 06:40:20 [multiproc_executor.py:962]     # load column indices
(Worker_TP0 pid=104638) ERROR 05-25 06:40:20 [multiproc_executor.py:962]     offs_local = tl.arange(0, BLOCK_SIZE)

@maeehart
Contributor Author

Post-rebase MI300X validation

Built _C.abi3.so from this branch (mahartik/dsv4-rocm-mi300x-fixes-rebased,
HEAD 9ae87030) inside vllm/vllm-openai-rocm:nightly on a 1-node x 8 x MI300X
host (gfx942), then ran the model with the rebuilt extension and the branch's
python overlay.

Server config: vllm serve deepseek-ai/DeepSeek-V4-Flash --tensor-parallel-size 8 --enforce-eager --kv-cache-dtype fp8 --max-model-len 4096 --max-num-seqs 16 --gpu-memory-utilization 0.92, with VLLM_ROCM_USE_AITER=1 and
VLLM_ROCM_USE_AITER_MLA=1.

lm_eval 5-shot GSM8K, --limit 200, num_concurrent=32, against
/v1/completions:

| Filter | exact_match | Stderr |
|---|---|---|
| flexible-extract | 0.970 | +/-0.0121 |
| strict-match | 0.970 | +/-0.0121 |
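
As a quick plausibility check on the table, the reported standard error is consistent with 200 independent 0/1 outcomes at a mean of 0.970. The sketch below assumes the harness reports the sample standard error of the mean (dividing by n - 1); that is an assumption about lm_eval's convention, and `binomial_stderr` is a hypothetical helper name.

```python
import math


def binomial_stderr(p: float, n: int) -> float:
    """Sample standard error of the mean of n 0/1 outcomes with mean p."""
    return math.sqrt(p * (1.0 - p) / (n - 1))


stderr = binomial_stderr(0.970, 200)
rounded = round(stderr, 4)
```

At `--limit 200` this puts one standard error at roughly 1.2 percentage points, which is worth keeping in mind when comparing against the earlier 0.955 figure.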

Sanity checks before the eval also pass cleanly: "What is 17 plus 25?" -> 42,
"60 miles in 2 hours, average speed?" -> 30 mph, "Capital of France" -> Paris.

This matches and slightly improves on the 0.955 number quoted in the original
commit message of "Propagate FNUZ vs OCP gating to ROCm prefill+decode paths"
on the same single-eager configuration. Branch is good to merge from a
correctness standpoint on MI300X.

Note: the public nightly's prebuilt _C.abi3.so does not include the
csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu FP8_MAX=224 change, so
running this PR via python-overlay-only on top of the nightly produces garbage
(reads FNUZ encoder bytes as OCP). The rebuild is required for an
end-to-end validation.
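
To make the format mismatch concrete, here is a small pure-Python decoder for a single E4M3 byte under the two interpretations. It is a sketch based on the standard format definitions (OCP E4M3FN uses exponent bias 7, FNUZ uses bias 8) and is not taken from vLLM; `decode_e4m3` is a hypothetical name.

```python
def decode_e4m3(byte: int, fnuz: bool) -> float:
    """Decode one E4M3 byte as OCP (bias 7) or FNUZ (bias 8)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if fnuz:
        if byte == 0x80:  # FNUZ: the single NaN encoding, there is no -0
            return float("nan")
        bias = 8
    else:
        if exp == 0xF and man == 0x7:  # OCP E4M3FN: NaN
            return float("nan")
        bias = 7
    if exp == 0:  # subnormal range
        return sign * (man / 8.0) * 2.0 ** (1 - bias)
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - bias)


byte = 0x40
as_ocp = decode_e4m3(byte, fnuz=False)
as_fnuz = decode_e4m3(byte, fnuz=True)
```

The same bit pattern decodes to values that differ by a factor of two, so an encoder and a decoder that disagree on the format cannot round-trip the cache contents.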

dequantize_and_gather_k_cache,
quantize_and_insert_k_cache,
)
from vllm.platforms import current_platform
Member

This test is failing

-Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_matches_reference[16-2048] - AssertionError: RoPE portion not exact: 0.0009765625
==== short test summary info=
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_matches_reference[64-2048]
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_with_dp_padding[16-1-2048]
- AssertionError: RoPE portion not exact: 0.0009765625
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_with_dp_padding[16-5-2048]
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_with_dp_padding[64-1-2048] - AssertionError: Tensor-likes are not equal!
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_with_dp_padding[64-5-2048] - AssertionError: Tensor-likes are not equal!
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_combined_q_and_kv[16-8-2048] - AssertionError: Tensor-likes are not equal!
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_combined_q_and_kv[16-64-2048] - AssertionError: Tensor-likes are not equal!
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_combined_q_and_kv[64-8-2048] - AssertionError: Tensor-likes are not equal!
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_combined_q_and_kv[64-64-2048] - AssertionError: Tensor-likes are not equal!
sys:1: DeprecationWarning: builtin type swigvarlink has no

10 failed, 38 passed, 16 warnings in 16.72s

@mergify
Contributor

mergify Bot commented Jun 3, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @maeehart.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 3, 2026
Labels

deepseek, needs-rebase, rocm, v1
