[ROCm][DSv4] Functional fixes for DeepSeek V4 on MI300X (gfx942)#42893
[ROCm][DSv4] Functional fixes for DeepSeek V4 on MI300X (gfx942)#42893maeehart wants to merge 9 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements support for FP8 FNUZ (Finite, No NaN/Inf, Unsigned Zero) encoding on ROCm hardware, specifically targeting the gfx942 (MI300X) architecture. Key changes include updating CUDA and Triton kernels to use FNUZ-specific constants and types (tl.float8e4b8), zero-initializing workspace memory to ensure determinism in sparse attention, and adding specialized handling for UE8M0 scales in quantization utilities. Feedback from the reviewer points out a potential issue where the FP8_E4B15 flag might be incorrectly evaluated for MI300X hardware due to a capability check, which could render the FNUZ fixes ineffective. It is recommended to use current_platform.is_fp8_fnuz() to correctly gate these code paths.
| if FP8_E4B15: | ||
| k_float = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32) | ||
| k_float = k_raw.to(tl.float8e4b8, bitcast=True).to(tl.float32) |
There was a problem hiding this comment.
The change to use tl.float8e4b8 is correct for FNUZ hardware. However, this code is gated by if FP8_E4B15:. The FP8_E4B15 flag is determined by _use_fp8_e4b15, which checks if the device capability is less than (8, 9).
For the target hardware of this PR, MI300X (gfx942), the capability is (9, 42), so this condition is false. This means this corrected code path will not be executed on MI300X, and the fix might be ineffective.
To ensure this fix applies correctly on ROCm FNUZ hardware, the condition should be based on current_platform.is_fp8_fnuz().
There was a problem hiding this comment.
Good catch — you're right that FP8_E4B15=0 on MI300X (gfx942 reports capability (9, x) >= (8, 9)), so this branch is unreachable on the target hardware of the PR.
The right answer turned out to be the opposite of "re-gate on current_platform.is_fp8_fnuz()", though: FP8_E4B15=1 is NVIDIA Ampere/Ada (sm < 8.9, software FP8 emulation), where tl.float8e4b15 (E4M3 with bias 15) IS the correct Triton type. tl.float8e4b8 (E4M3 with bias 8) is the AMD-FNUZ-specific type — Triton on Ampere/Ada will reject it with the same "type not supported in this architecture" error the original commit was trying to fix elsewhere.
The original commit (2bef91e) conflated two unrelated gating constexprs:
IS_FNUZinrocm_aiter_mla_sparse.py, correctly gated oncurrent_platform.fp8_dtype() == torch.float8_e4m3fnuz/current_platform.is_fp8_fnuz()— this is the one that actually fixes the MI300X sparse-MLA decode failure cited in the commit message; it stays.FP8_E4B15in these turboquant kernels, which is1only on NVIDIA Ampere/Ada — drive-by, dead on AMD, and wrong on the NVIDIA cards it actually fires on.
Reverted the 3 turboquant lines back to tl.float8e4b15 in 6f7e155 so the Ampere/Ada FP8 path stays correct. The MI300X fix in _sparse_attn_decode_ragged_kernel is unchanged.
| if FP8_E4B15: | ||
| k_recon = k_raw.to(tl.float8e4b15, bitcast=True).to(tl.float32) | ||
| k_recon = k_raw.to(tl.float8e4b8, bitcast=True).to(tl.float32) |
There was a problem hiding this comment.
There was a problem hiding this comment.
Good catch — you're right that FP8_E4B15=0 on MI300X (gfx942 reports capability (9, x) >= (8, 9)), so this branch is unreachable on the target hardware of the PR.
The right answer turned out to be the opposite of "re-gate on current_platform.is_fp8_fnuz()", though: FP8_E4B15=1 is NVIDIA Ampere/Ada (sm < 8.9, software FP8 emulation), where tl.float8e4b15 (E4M3 with bias 15) IS the correct Triton type. tl.float8e4b8 (E4M3 with bias 8) is the AMD-FNUZ-specific type — Triton on Ampere/Ada will reject it with the same "type not supported in this architecture" error the original commit was trying to fix elsewhere.
The original commit (2bef91e) conflated two unrelated gating constexprs:
IS_FNUZinrocm_aiter_mla_sparse.py, correctly gated oncurrent_platform.fp8_dtype() == torch.float8_e4m3fnuz/current_platform.is_fp8_fnuz()— this is the one that actually fixes the MI300X sparse-MLA decode failure cited in the commit message; it stays.FP8_E4B15in these turboquant kernels, which is1only on NVIDIA Ampere/Ada — drive-by, dead on AMD, and wrong on the NVIDIA cards it actually fires on.
Reverted the 3 turboquant lines back to tl.float8e4b15 in 6f7e155 so the Ampere/Ada FP8 path stays correct. The MI300X fix in _sparse_attn_decode_ragged_kernel is unchanged.
| d_mask = d_offs < D | ||
| k_vals = tl.load(Key_ptr + base + d_offs, mask=d_mask, other=0.0) | ||
| k_fp8 = k_vals.to(tl.float8e4b15) if FP8_E4B15 else k_vals.to(tl.float8e4nv) | ||
| k_fp8 = k_vals.to(tl.float8e4b8) if FP8_E4B15 else k_vals.to(tl.float8e4nv) |
There was a problem hiding this comment.
The change to use tl.float8e4b8 is correct, but it is conditional on FP8_E4B15. This flag seems to be incorrectly evaluated as false for the target hardware of this PR (MI300X, gfx942), which would make this fix ineffective. The condition should likely be based on current_platform.is_fp8_fnuz() to ensure the correct FP8 format is used on FNUZ-supporting hardware.
There was a problem hiding this comment.
Good catch — you're right that FP8_E4B15=0 on MI300X (gfx942 reports capability (9, x) >= (8, 9)), so this branch is unreachable on the target hardware of the PR.
The right answer turned out to be the opposite of "re-gate on current_platform.is_fp8_fnuz()", though: FP8_E4B15=1 is NVIDIA Ampere/Ada (sm < 8.9, software FP8 emulation), where tl.float8e4b15 (E4M3 with bias 15) IS the correct Triton type. tl.float8e4b8 (E4M3 with bias 8) is the AMD-FNUZ-specific type — Triton on Ampere/Ada will reject it with the same "type not supported in this architecture" error the original commit was trying to fix elsewhere.
The original commit (2bef91e) conflated two unrelated gating constexprs:
IS_FNUZinrocm_aiter_mla_sparse.py, correctly gated oncurrent_platform.fp8_dtype() == torch.float8_e4m3fnuz/current_platform.is_fp8_fnuz()— this is the one that actually fixes the MI300X sparse-MLA decode failure cited in the commit message; it stays.FP8_E4B15in these turboquant kernels, which is1only on NVIDIA Ampere/Ada — drive-by, dead on AMD, and wrong on the NVIDIA cards it actually fires on.
Reverted the 3 turboquant lines back to tl.float8e4b15 in 6f7e155 so the Ampere/Ada FP8 path stays correct. The MI300X fix in _sparse_attn_decode_ragged_kernel is unchanged.
PR vllm-project#42893 fixed the C++ SWA-K-cache encoder so it writes FNUZ E4M3 bytes on gfx942 (and OCP on gfx950) and updated the *generic* ``DeepseekV4MLAAttention._forward_prefill`` to call ``dequantize_and_gather_k_cache(..., use_fnuz=is_fp8_fnuz())`` for SWA and ``use_fnuz=False`` for the Triton-OCP-encoded compressed K cache. Two FP8-format mismatches remained on the actual ROCm DSv4 path (``DeepseekV4ROCMAiterMLASparseImpl``): 1. The public ``dequantize_and_gather_k_cache`` wrapper in ``vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py`` did not accept ``use_fnuz`` -- it silently dropped the kwarg when forwarding to ``dequantize_and_gather_k_cache_triton`` (which defaults to False). The ROCm prefill called the wrapper without ``use_fnuz``, so the SWA K cache (FNUZ on gfx942) was being read as OCP, scaling every K vector by ~448/240 in prefill attention. 2. ``_sparse_attn_decode_ragged_kernel`` in ``vllm/v1/attention/ops/rocm_aiter_mla_sparse.py`` decoded both the SWA (FNUZ on gfx942) and the compressed (always OCP) K caches with a single ``IS_FNUZ`` constexpr, so on MI300X the compressed-side branch reinterpreted OCP bytes as FNUZ -- the same encoder/decoder mismatch as (1) in the opposite direction (~240/448) on the decode side. Together these scrambled K vectors going into both prefill and decode attention, producing the GSM8K=0.005 gibberish PR vllm-project#42893 documented but could not explain with eager-vs-graphs. This commit: * Adds ``use_fnuz`` to the wrapper and forwards it to the Triton implementation (the cuteDSL path is dead on ROCm anyway). * Splits ``_sparse_attn_decode_ragged_kernel``'s ``IS_FNUZ`` into per-cache flags ``IS_FNUZ_MAIN`` (SWA) and ``IS_FNUZ_EXTRA`` (compressed) so each cache is decoded with its own encoder's format. * Wires ``DeepseekV4ROCMAiterMLASparseImpl._forward_prefill`` to pass ``use_fnuz=False`` for the compressed call (Triton-OCP encoder) and ``use_fnuz=current_platform.is_fp8_fnuz()`` for the SWA call (C++ FNUZ-on-gfx942 encoder), matching the asymmetry that PR vllm-project#42893's "[ROCm][DSv4] Fix compressed K cache dequant to match Triton OCP encoder" introduced for the generic path. Validated on 1 node x 4 x MI300X (gfx942), TP=4, VLLM_ROCM_USE_AITER=1, ``deepseek-ai/DeepSeek-V4-Flash``, both eager and CUDA-graphs ``FULL_AND_PIECEWISE`` configs from PR vllm-project#42810. GSM8K 5-shot, n=200, num_concurrent=32 against /v1/completions: | Mode | exact_match | Stderr | | ----- | ----------- | -------- | | Eager | 0.955 | +/-0.0147 | | Graph | 0.955 | +/-0.0147 | vs. the pre-fix 0.005 PR vllm-project#42893 reported on the same configuration. The two modes match each other to all three reported digits on both strict-match and flexible-extract filters. Co-authored-by: Cursor <cursoragent@cursor.com>
|
@maeehart The gibberish (GSM8K=0.005) on this PR is caused by two extra FNUZ-vs-OCP gating gaps that aren't covered by f59e042 ("Fix compressed K cache dequant to match Triton OCP encoder"). |
…ly path) PR review on vllm-project#42893 (gemini-code-assist) flagged that the three turboquant changes in commit 2bef91e ("[ROCm][DSv4] Use tl.float8e4b8 for FNUZ on MI300X sparse MLA kernels") are dead code on MI300X: they sit inside ``if FP8_E4B15:`` branches, and FP8_E4B15 is the constexpr returned by ``_use_fp8_e4b15(device)`` -- which is 1 only when ``torch.cuda.get_device_capability() < (8, 9)``. MI300X (gfx942) reports cap >= (9, x), so FP8_E4B15 = 0 on every AMD platform and the patched FNUZ branch is never executed. More importantly, the changes are *wrong* on the hardware where FP8_E4B15 = 1 -- NVIDIA Ampere/Ada (sm < 8.9). On those cards ``tl.float8e4b15`` (E4M3 with bias 15) is the correct Triton FP8 type for software emulation; ``tl.float8e4b8`` (E4M3 with bias 8) is the AMD-FNUZ-specific type and Triton on NVIDIA Ampere/Ada will reject it with the same "type not supported in this architecture" error the original commit was trying to fix. The original commit message conflated two unrelated gating constexprs (``IS_FNUZ`` in rocm_aiter_mla_sparse.py vs ``FP8_E4B15`` in the turboquant kernels). Only the rocm_aiter_mla_sparse.py hunks of 2bef91e are actually correct -- those are gated on ``current_platform.fp8_dtype() == torch.float8_e4m3fnuz`` / ``current_platform.is_fp8_fnuz()`` and are the ones that actually fix the MI300X sparse-MLA decode failure. Revert just the three turboquant lines back to ``tl.float8e4b15`` so the NVIDIA Ampere/Ada FP8 path is preserved. The MI300X fix in ``_sparse_attn_decode_ragged_kernel`` (the dequant/gather kernel cited in the original commit message) is unchanged. Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
Thanks @jin-amd for validating. I am about to comment this. Right now the mi300x sparse indexer has issue. It is causing accuracy to be 0. And furthermore, @maeehart always run the full gsm8k dataset to reduce the variance in the accuracy score. For DeepSeek Sparse Attention model (DSV4, DSV3.2 and GLM5.1) we have to validate with large concurrency e.g. 128 or 256 to ensure the sparse indexer logic is working correctly. |
On it. |
|
At the same time, I will validate this PR on MI355x and do more in depth review. |
| constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7 | ||
| constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad) | ||
| constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576 | ||
| // Match the encoding chosen in rocm_cvt_float_to_fp8_e4m3: FNUZ on gfx942 |
There was a problem hiding this comment.
@maeehart This is incorrect for fnuz. the max is 224.0
vllm/vllm/model_executor/layers/quantization/utils/fp8_utils.py
Lines 412 to 417 in b12745e
This is a critical bug that was fixed last year, the range values of fp8 were incorrect last time.
| constexpr int kNumQuantBlocks = kNopeDim / kQuantBlock; // 7 | ||
| constexpr int kScaleBytesPerToken = kNumQuantBlocks + 1; // 8 (7 real + 1 pad) | ||
| constexpr int kTokenDataBytes = kNopeDim + kRopeDim * 2; // 448 + 128 = 576 | ||
| // Match the encoding chosen in rocm_cvt_float_to_fp8_e4m3: FNUZ on gfx942 |
There was a problem hiding this comment.
@maeehart please also make sure all the unit tests cases in tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py passed.
| # Zero once per call so the holes are deterministic (and harmless if | ||
| # ever indexed). The cost is one bf16 zero per chunk worth of | ||
| # workspace, which is dwarfed by the gather + attention themselves. | ||
| kv.zero_() |
There was a problem hiding this comment.
This is rather weird, there is no need to do this on Mi355x. I would like to avoid doing this as it incurs overhead. Please try to look into the sparse indexer logic of gfx942. I believe fixing the logic there can avoid calling kv.zero_()
| # Zero once per call so the holes are deterministic (zero attention | ||
| # contribution). The cost is one bf16 fill of the workspace tile, | ||
| # which is dwarfed by the FP8 dequant + sparse attention themselves. | ||
| kv.zero_() |
There was a problem hiding this comment.
This is rather weird, there is no need to do this on Mi355x. I would like to avoid doing this as it incurs overhead. Please try to look into the sparse indexer logic of gfx942. I believe fixing the logic there can avoid calling kv.zero_()
|
High-concurrency GSM8K validation on MI300X (per @tjtanaa's review request) Hardware: 1 node × 8 × MI300X (gfx942), HIP 6.x, container Tests run on Results
Both DSv3.2 and GLM-5.1-FP8 are stable at conc=256 in both eager and graph mode, and reproduce the expected GSM8K accuracy → sparse indexer logic survives high concurrency on those models. Eager vs graph agree to within stderr on both filters, confirming there's no graph-specific accuracy regression. DSv4-Flash: hard failure in
|
@maeehart @tjtanaa With this Aiter fix (PR already submitted): ROCm/aiter#3257, everything regarding dsv4 works fine: Setup
Accuracy (full 1319 samples)
|
|
Thanks @jin-amd a lot for filing the aiter PR. Is there a way that we get a stable triton implementation into vLLM for now as workaround for gfx942. We still need to wait for your PR to be merged and for us to upgrade aiter to a version that has your fix. |
MI300X test pass: ready to merge without the AITER bumpTested this PR on top of nightly I pushed one extra commit to this branch What the vendor commit does
The entire vendor patch is meant to be reverted once vLLM bumps to an GSM8K (1319 problems, 5-shot CoT, num_concurrent=128)That matches the published DSv4-Flash GSM8K baseline (~95.5–95.6%); Server cold-start to first token: ~170 s (model load 24 s, torch.compile One thing reviewers will hit testing this against today's
|
…llm-project#42893) Reviewer flagged that ``kv.zero_()`` papers over an algorithmic bug in the indexer/sparse-attention pipeline rather than fixing it. They are right -- the proper fix is for the indexer or the sparse-attention kernel to mask invalid M rows so a topK can never pick a hole. Both are kernel-level changes that also need to land on the NVIDIA path, out of scope for this PR. Update the comments in both ``DeepseekV4MLAAttention._forward_prefill`` and ``DeepseekV4ROCMAiterMLASparseImpl._forward_prefill_attn_impl`` to record: * The bug is arch-independent (the workspace allocator is ``torch.empty()`` on every platform; the Triton dequant kernel + rocm_sparse_attn_prefill share the same hole pattern on gfx942 and gfx950). * The symptom was first observed on MI300X with FNUZ FP8 + cudagraphs; it is not specific to gfx942. * The defensible fix lives in the indexer (score = -inf at invalid rows) or attention kernel (treat indices >= valid_len as zero contribution), not in the caller. * ``kv.zero_()`` is the cheap stop-gap: one bf16 fill of the workspace tile per layer per chunk, dwarfed by FP8 dequant + sparse attention themselves, with no correctness regression on either arch. No behavior change. Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
|
@tjtanaa thanks for the careful review. I've pushed two follow-up commits addressing all three points. 1. The FNUZ branch now uses 2. Unit tests (same commit) The Triton reference encoder Verified post-fix: 36/36 unit tests pass on MI300X (gfx942) and MI355X (gfx950). 3. I checked whether the call is gfx942-specific:
So the underlying issue is arch-independent — workspace allocator + indexer + sparse-attention. The reason the symptom was loud on MI300X (10 distinct first tokens for the same temperature-0 prompt) and not noticed on MI355X is that FNUZ vs OCP score the garbage rows differently; it's not "gfx950 doesn't have this bug." I agree the proper fix lives in the indexer (set scores at invalid rows to Happy to revisit the gating (e.g. only zero on FNUZ archs) if you'd prefer — just let me know. |
…lm-project#42893) Switch the FNUZ branch of `kFp8Max` to 224.0 (was 240.0, the FNUZ dtype's raw representable max). 224.0 is what the rest of vLLM's FNUZ pipeline uses -- see `vllm/model_executor/layers/quantization/utils/fp8_utils.py:412-417`, which notes that 240.0 hurts dynamic-quant accuracy. The OCP branch (gfx950 + NVIDIA) keeps 448.0. Make the unit test honor the same split: add an optional `use_fnuz` constexpr to `quantize_and_insert_k_kernel` (default False, no production caller affected) and pick the encoding from `current_platform.is_fp8_fnuz()`. Byte-exact comparison now succeeds on both gfx942 and gfx950. Verified: 36/36 unit tests pass on MI300X (gfx942) and MI355X (gfx950). Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
…review vllm-project#42893) Replace the long rationale around ``kv.zero_()`` (in both prefill paths) with a brief TODO that names the proper fix: mask invalid rows in the indexer (score = -inf) or in the sparse-attention kernel (skip indices >= valid_len). The current zero is the minimal interim workaround; the underlying bug is arch-independent (uninitialized workspace + indexer that scores the entire M dim) so the call stays unchanged on every platform until the indexer/kernel fix lands. No behavior change. Also condense the duplicate FNUZ-vs-OCP comments at the dequant call sites and in ``_sparse_attn_decode_ragged_kernel``: the wrapper docstring already explains the asymmetry, so per-call-site repetition was just noise. Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
40a6ad3 to
7470489
Compare
|
This pull request has merge conflicts that must be resolved before it can be |
|
@maeehart Please resolve merge conflict and rebase before we further review this PR. |
Rebase status: substantive port required, not a mechanical rebaseApologies for the delay on the rebase. The mechanical Why a mechanical rebase does not applyPR #43385 ([ROCm] [DSv4] [Perf] Support DeepSeek v4 MTP) landed on
Of this branch's 11 commits, 4 ( Why the dead-file commits still matterI checked
So the FNUZ-aware paths and the workspace-zeroing logic from this branch are still missing on
Plus the prefill-attn KV workspace and ROCm sparse-MLA prefill workspace need to be zero-initialised before gather, which is what commits Plan
Will get this done over the next session. If you have any preference on whether the port lands as a single squashed-style commit on top or as separate commits per fix, please let me know. I am leaning toward the single commit because it documents the architectural move from Also, if there are other DSv4 cleanups you have queued post-#43385 that conflict with the FNUZ-vs-OCP gating direction, please flag them so I can shape the port accordingly. cc @Ganyi |
DSv4 on AMD MI300X (gfx942) hits several FP8-related issues that this
commit addresses:
1. **fp8_utils.py**: ``process_fp8_weight_block_strategy`` calls
``normalize_e4m3fn_to_e4m3fnuz`` which doubles ``weight_scale`` by
``weight_scale * 2.0``. On models with UE8M0 scales
(``torch.float8_e8m0fnu``), that ``mul`` is not implemented on
CUDA/HIP and load aborts with::
NotImplementedError: "mul_cuda" not implemented for 'Float8_e8m0fnu'
UE8M0 stores power-of-2 exponent values (2^(exp-127)) with no
mantissa, so doubling the scale is equivalent to incrementing the
exponent byte by 1. Handle the UE8M0 case explicitly and fall back
to the float path otherwise.
2. **fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu**: gate ``kFp8Max``
to match the FNUZ/OCP path actually taken on each ROCm arch
(240 on gfx942 FNUZ, 448 on gfx950 OCP).
3. **deepseek_v4_attention.py** + **cache_utils.py**: small MI300X path
fixes that go with the FNUZ scale handling above.
Co-authored-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
``fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu`` selected its FP8 type and ``kFp8Max`` based purely on the HIP build macro ``HIP_FP8_TYPE_OCP``. That macro is set by the HIP runtime version, not by the target GPU arch -- on a HIP build that defines ``HIP_FP8_TYPE_OCP``, the kernel was using OCP E4M3 / ``448.0`` even on gfx942 (MI300X), whose MFMA instructions only accept FNUZ E4M3. The rest of vLLM's gfx942 path (Triton sparse-MLA, indexer Q quant, ``current_platform.fp8_dtype()``) all use FNUZ on this arch, so the C++ writer was producing K-cache entries the FNUZ readers misinterpret. Gate the OCP branch on ``defined(__gfx950__)`` so: * gfx942 (MI300X) -> ``__hip_fp8_e4m3_fnuz`` + ``kFp8Max = 240.0f`` * gfx950 (MI355X) -> ``__hip_fp8_e4m3`` + ``kFp8Max = 448.0f`` This matches the encoding chosen elsewhere on each arch. Signed-off-by: ganyi <ygan@amd.com> Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
The DSv4 sparse MLA Triton kernels added in vllm-project#41812 (and the matching turboquant store/decode kernels) bitcast uint8 to ``tl.float8e4b15`` when ``IS_FNUZ`` is true. ``float8e4b15`` is not a real Triton type; on AMD gfx942 (MI300X) Triton only supports the FP8 dtypes listed in the error from triton/compiler: ('fp8e4b8', 'fp8e4nv', 'fp8e5', 'fp8e5b16') The correct FNUZ E4M3 type is ``tl.float8e4b8`` (bias 8, matches the PyTorch ``torch.float8_e4m3fnuz`` used elsewhere on the MI300 path). The non-FNUZ branch already correctly uses ``tl.float8e4nv``. Without this fix, the very first profile run on MI300X with sparse MLA fails inside the dequant/gather kernel: type fp8e4b15 not supported in this architecture. This swaps all FNUZ branches to ``tl.float8e4b8``. Verified that ``IS_FNUZ`` is gated on ``current_platform.fp8_dtype() == torch.float8_e4m3fnuz`` so it never fires on OCP hardware. Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
``DeepseekV4ROCMAiterMLASparseImpl._forward_prefill_attn_impl`` in ``vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse_dsv4.py`` is the actual ROCm path reached from ``DeepseekV4MLAAttention.forward`` at ``deepseek_v4_attention.py:762`` (``current_platform.is_rocm()``). ``DeepseekV4MLAAttention._forward_prefill`` in the same file is dead code on ROCm, so the previous ``kv.zero_()`` patch (commit 36a7037) fixed only the generic path. This ROCm-only forward also gets ``kv`` via ``current_workspace_manager().get_simultaneous(...)`` -- uninitialized shared memory reused across requests and layers -- writes only the compressed-K prefix and the SWA window for each chunk row, then reads the entire ``kv.view(-1, 1, head_dim)`` through ragged indices that can land on the holes for very short sequences. The result is exactly the symptom we observe on MI300X DSv4-Flash: 10 identical temperature=0 ``/v1/completions`` calls produce 10 distinct first tokens. Apply the same zero-init here. Cost is one bf16 fill of the workspace tile, dwarfed by the FP8 dequant + sparse attention. Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
PR vllm-project#42893 fixed the C++ SWA-K-cache encoder so it writes FNUZ E4M3 bytes on gfx942 (and OCP on gfx950) and updated the *generic* ``DeepseekV4MLAAttention._forward_prefill`` to call ``dequantize_and_gather_k_cache(..., use_fnuz=is_fp8_fnuz())`` for SWA and ``use_fnuz=False`` for the Triton-OCP-encoded compressed K cache. Two FP8-format mismatches remained on the actual ROCm DSv4 path (``DeepseekV4ROCMAiterMLASparseImpl``): 1. The public ``dequantize_and_gather_k_cache`` wrapper in ``vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py`` did not accept ``use_fnuz`` -- it silently dropped the kwarg when forwarding to ``dequantize_and_gather_k_cache_triton`` (which defaults to False). The ROCm prefill called the wrapper without ``use_fnuz``, so the SWA K cache (FNUZ on gfx942) was being read as OCP, scaling every K vector by ~448/240 in prefill attention. 2. ``_sparse_attn_decode_ragged_kernel`` in ``vllm/v1/attention/ops/rocm_aiter_mla_sparse.py`` decoded both the SWA (FNUZ on gfx942) and the compressed (always OCP) K caches with a single ``IS_FNUZ`` constexpr, so on MI300X the compressed-side branch reinterpreted OCP bytes as FNUZ -- the same encoder/decoder mismatch as (1) in the opposite direction (~240/448) on the decode side. Together these scrambled K vectors going into both prefill and decode attention, producing the GSM8K=0.005 gibberish PR vllm-project#42893 documented but could not explain with eager-vs-graphs. This commit: * Adds ``use_fnuz`` to the wrapper and forwards it to the Triton implementation (the cuteDSL path is dead on ROCm anyway). * Splits ``_sparse_attn_decode_ragged_kernel``'s ``IS_FNUZ`` into per-cache flags ``IS_FNUZ_MAIN`` (SWA) and ``IS_FNUZ_EXTRA`` (compressed) so each cache is decoded with its own encoder's format. * Wires ``DeepseekV4ROCMAiterMLASparseImpl._forward_prefill`` to pass ``use_fnuz=False`` for the compressed call (Triton-OCP encoder) and ``use_fnuz=current_platform.is_fp8_fnuz()`` for the SWA call (C++ FNUZ-on-gfx942 encoder), matching the asymmetry that PR vllm-project#42893's "[ROCm][DSv4] Fix compressed K cache dequant to match Triton OCP encoder" introduced for the generic path. Validated on 1 node x 4 x MI300X (gfx942), TP=4, VLLM_ROCM_USE_AITER=1, ``deepseek-ai/DeepSeek-V4-Flash``, both eager and CUDA-graphs ``FULL_AND_PIECEWISE`` configs from PR vllm-project#42810. GSM8K 5-shot, n=200, num_concurrent=32 against /v1/completions: | Mode | exact_match | Stderr | | ----- | ----------- | -------- | | Eager | 0.955 | +/-0.0147 | | Graph | 0.955 | +/-0.0147 | vs. the pre-fix 0.005 PR vllm-project#42893 reported on the same configuration. The two modes match each other to all three reported digits on both strict-match and flexible-extract filters. Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
…ly path) PR review on vllm-project#42893 (gemini-code-assist) flagged that the three turboquant changes in commit 2bef91e ("[ROCm][DSv4] Use tl.float8e4b8 for FNUZ on MI300X sparse MLA kernels") are dead code on MI300X: they sit inside ``if FP8_E4B15:`` branches, and FP8_E4B15 is the constexpr returned by ``_use_fp8_e4b15(device)`` -- which is 1 only when ``torch.cuda.get_device_capability() < (8, 9)``. MI300X (gfx942) reports cap >= (9, x), so FP8_E4B15 = 0 on every AMD platform and the patched FNUZ branch is never executed. More importantly, the changes are *wrong* on the hardware where FP8_E4B15 = 1 -- NVIDIA Ampere/Ada (sm < 8.9). On those cards ``tl.float8e4b15`` (E4M3 with bias 15) is the correct Triton FP8 type for software emulation; ``tl.float8e4b8`` (E4M3 with bias 8) is the AMD-FNUZ-specific type and Triton on NVIDIA Ampere/Ada will reject it with the same "type not supported in this architecture" error the original commit was trying to fix. The original commit message conflated two unrelated gating constexprs (``IS_FNUZ`` in rocm_aiter_mla_sparse.py vs ``FP8_E4B15`` in the turboquant kernels). Only the rocm_aiter_mla_sparse.py hunks of 2bef91e are actually correct -- those are gated on ``current_platform.fp8_dtype() == torch.float8_e4m3fnuz`` / ``current_platform.is_fp8_fnuz()`` and are the ones that actually fix the MI300X sparse-MLA decode failure. Revert just the three turboquant lines back to ``tl.float8e4b15`` so the NVIDIA Ampere/Ada FP8 path is preserved. The MI300X fix in ``_sparse_attn_decode_ragged_kernel`` (the dequant/gather kernel cited in the original commit message) is unchanged. Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
The AITER wrapper bundled in the currently-pinned aiter wheel launches fp8_mqa_logits with (BLOCK_KV=128, num_stages=2) on gfx942. For the DSv4 sparse indexer shape (NUM_HEADS=64, HEAD_SIZE=128) this double-buffered KV tile + fp32 scores accumulator + Q tile pushes Triton's LDS request to 96 KiB, which exceeds MI300X's 64 KiB per CU. The launch JIT-aborts with OutOfResources on the first inference. The fix is upstreamed as ROCm/aiter#3257 but until vLLM bumps to an AITER version that contains it, this patch ships the same kernel + tile-size logic vendored into vllm/. - Add vllm/v1/attention/ops/triton_fp8_mqa_logits.py with a byte-for-byte copy of AITER's @triton.jit kernel and a Python wrapper that selects (BLOCK_KV=64, num_stages=1) (~33 KiB) when the default tile would not fit on gfx942 (see module docstring for the LDS budget calculation). - Route rocm_fp8_mqa_logits to the vendored kernel on gfx942 when AITER ops are enabled. gfx950+ and CUDA still use the upstream AITER wrapper (which has dedicated Gluon kernels this vendor copy does not include). - Fix a latent broadcasting bug in the torch reference fallback: the per-KV-token scale arrives as [N, 1] (a [N, 4] uint8 buffer view-cast to fp32) and was being multiplied against an [H, M, N] score tensor where PyTorch right-aligns [N, 1] against the M dim. Flatten to [N] so the multiply lines up with the last axis. Also drop a hard-coded 'cuda' device on the index tensors so the fallback works on ROCm with HIP devices. This entire patch is intended to be reverted once vLLM picks up an AITER version that includes ROCm/aiter#3257. Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
…lm-project#42893) Switch the FNUZ branch of `kFp8Max` to 224.0 (was 240.0, the FNUZ dtype's raw representable max). 224.0 is what the rest of vLLM's FNUZ pipeline uses -- see `vllm/model_executor/layers/quantization/utils/fp8_utils.py:412-417`, which notes that 240.0 hurts dynamic-quant accuracy. The OCP branch (gfx950 + NVIDIA) keeps 448.0. Make the unit test honor the same split: add an optional `use_fnuz` constexpr to `quantize_and_insert_k_kernel` (default False, no production caller affected) and pick the encoding from `current_platform.is_fp8_fnuz()`. Byte-exact comparison now succeeds on both gfx942 and gfx950. Verified: 36/36 unit tests pass on MI300X (gfx942) and MI355X (gfx950). Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
…review vllm-project#42893) Replace the long rationale around ``kv.zero_()`` (in both prefill paths) with a brief TODO that names the proper fix: mask invalid rows in the indexer (score = -inf) or in the sparse-attention kernel (skip indices >= valid_len). The current zero is the minimal interim workaround; the underlying bug is arch-independent (uninitialized workspace + indexer that scores the entire M dim) so the call stays unchanged on every platform until the indexer/kernel fix lands. No behavior change. Also condense the duplicate FNUZ-vs-OCP comments at the dequant call sites and in ``_sparse_attn_decode_ragged_kernel``: the wrapper docstring already explains the asymmetry, so per-call-site repetition was just noise. Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
7470489 to
9ae8703
Compare
Rebase doneForce-pushed What ended up on the branch
How the conflict resolvedOf the original 11 commits on the pre-rebase branch, 2 (
The other 9 commits applied with one trivial The Validation statusI am following up with a re-run of the GSM8K accuracy test on MI300X (gfx942) with the rebased branch and the public @tjtanaa @Ganyi this is ready for re-review when you have a moment. |
|
Before the latest rebase , on Launch command with cudagraph fails. I encounter this error on both Flash and Pro model on mi300x. |
Post-rebase MI300X validationBuilt Server config:
Sanity checks before the eval also pass cleanly: "What is 17 plus 25?" -> 42, This matches and slightly improves on the 0.955 number quoted in the original Note: the public nightly's prebuilt |
| dequantize_and_gather_k_cache, | ||
| quantize_and_insert_k_cache, | ||
| ) | ||
| from vllm.platforms import current_platform |
There was a problem hiding this comment.
This test is failing
-Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_matches_reference[16-2048] - AssertionError: RoPE portion not exact: 0.0009765625
==== short test summary info=
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_matches_reference[64-2048]
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_with_dp_padding[16-1-2048]
- AssertionError: RoPE portion not exact: 0.0009765625
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_with_dp_padding[16-5-2048]
AILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_with_dp_padding[64-1-2048] - AssertionError: Tensor-likes are not equal!
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_kv_path_with_dp_padding[64-5-2048] - AssertionError: Tensor-likes are not equal!
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_combined_q_and_kv[16-8-2048] - AssertionError: Tensor-likes are not equal!
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_combined_q_and_kv[16-64-2048] - AssertionError: Tensor-likes are not equal!
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_combined_q_and_kv[64-8-2048] - AssertionError: Tensor-likes are not equal!
FAILED tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py :: test_combined_q_and_kv[64-64-2048] - AssertionError: Tensor-likes are not equal!
sys:1: DeprecationWarning: builtin type swigvarlink has no
10 failed, 38 passed, 16 warnings in 16.72s
|
This pull request has merge conflicts that must be resolved before it can be |
Purpose
Builds on #42810 to bring DeepSeek V4 (Pro and Flash) to a working state on AMD MI300X (gfx942). With main + #42810 alone, the model fails to load on MI300X — both in eager and in cudagraphs (
FULL_AND_PIECEWISE) modes — with:and once that is past, the first sparse-MLA forward errors with:
These come from a mix of an MX-format scale path that assumes float arithmetic, a HIP build macro that doesn't actually correspond to the GPU arch, a non-existent Triton FP8 dtype, an asymmetric K-cache encoder/decoder pair, and prefill workspaces that aren't zeroed before reuse. After the load-time fixes were in, two more FP8-format mismatches were caught on the actual ROCm DSv4 path (
DeepseekV4ROCMAiterMLASparseImpl): the public dequant wrapper silently droppeduse_fnuz, and the decode kernel used a singleIS_FNUZfor both the FNUZ SWA cache and the always-OCP compressed cache, scaling K vectors by ~448/240 in prefill and ~240/448 in decode. Each issue is fixed in its own commit so they can be reviewed independently.#42810 fixed the
ffn_normregression and the AITER MHC accuracy issue at high concurrency; this PR is the MI300X-specific complement, and with all 7 commits in place GSM8K accuracy matches #42810's reference number to all three reported digits (see test results below).Changes (8 commits)
[ROCm][DSv4] MI300X (gfx942) support for DeepSeek V4(authored by ganyi)fp8_utils.py: handletorch.float8_e8m0fnuweight scales inprocess_fp8_weight_block_strategyby incrementing the UE8M0 exponent byte instead of* 2.0(which is unimplemented on CUDA/HIP for that dtype). This is the fix for the load-time crash above.fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu: keepkFp8Maxconsistent with the FP8 dtype actually emitted on each arch.deepseek_v4_attention.pyandcache_utils.pyso the MI300X path stays internally consistent.[ROCm][DSv4] Use FNUZ FP8 on gfx942 in fused KV insert kernel(authored by ganyi)__hip_fp8_e4m3and__hip_fp8_e4m3_fnuzbased purely onHIP_FP8_TYPE_OCP. That macro is a property of the HIP runtime version, not the target arch — on a HIP build that defines it, the kernel writes OCP bytes even on gfx942 hardware that only supports FNUZ MFMA. Gate the OCP path additionally on__gfx950__so the encoding matches the arch (FNUZ + 240on gfx942,OCP + 448on gfx950).[ROCm][DSv4] Use tl.float8e4b8 for FNUZ on MI300X sparse MLA kernels_sparse_attn_decode_ragged_kernel) bitcastuint8 -> tl.float8e4b15whenIS_FNUZis true.float8e4b15is not a real Triton type on gfx942 — Triton on MI300X only supportsfp8e4b8,fp8e4nv,fp8e5,fp8e5b16. The correct FNUZ E4M3 type istl.float8e4b8(bias 8, matches the PyTorchtorch.float8_e4m3fnuzused elsewhere on the MI300 path).IS_FNUZhere is correctly gated oncurrent_platform.fp8_dtype() == torch.float8_e4m3fnuz/current_platform.is_fp8_fnuz()so it never fires on OCP hardware.[ROCm][DSv4] Fix compressed K cache dequant to match Triton OCP encodercompressed_k_cacheis written by Triton (_fused_kv_compress_norm_rope_insert_sparse_attn) using OCP-style E4M3 encoding, but the C++ dequant indeepseek_v4_attention.pywas reading it as FNUZ on gfx942, giving 2x scale mismatch on the compressed K side. This commit makes the dequant match the encoder.[ROCm][DSv4] Zero prefill-attn KV workspace before gather[ROCm][DSv4] Zero ROCm sparse-MLA prefill KV workspacetorch.empty-backed workspace view fromcurrent_workspace_manager().get_simultaneousand then scatter into a subset of its rows. The remaining rows are read by the attention kernel and contain uninitialized memory from earlier layers/requests. Zero the workspace before the scatter so the unused rows are deterministic.[ROCm][DSv4] Propagate FNUZ vs OCP gating to ROCm prefill+decode paths(authored by jin-amd, validated the GSM8K result reported below)multi_query_kv_attention#4 fixed for the generic prefill path remained on the actual ROCm DSv4 path (DeepseekV4ROCMAiterMLASparseImpl):dequantize_and_gather_k_cachewrapper incache_utils.pydid not acceptuse_fnuz— it silently dropped the kwarg when forwarding todequantize_and_gather_k_cache_triton(which defaults toFalse). The ROCm prefill called the wrapper withoutuse_fnuz, so the SWA K cache (FNUZ on gfx942) was being read as OCP, scaling every K vector by ~448/240 in prefill._sparse_attn_decode_ragged_kernelinrocm_aiter_mla_sparse.pydecoded both the SWA (FNUZ on gfx942) and the compressed (always OCP) K caches with a singleIS_FNUZconstexpr, so on MI300X the compressed-side branch reinterpreted OCP bytes as FNUZ — the same encoder/decoder mismatch in the opposite direction (~240/448) on the decode side.use_fnuzto the wrapper and forwards it to the Triton implementation; splitsIS_FNUZintoIS_FNUZ_MAIN(SWA) andIS_FNUZ_EXTRA(compressed) so each cache is decoded with its own encoder's format; wiresDeepseekV4ROCMAiterMLASparseImpl._forward_prefillto passuse_fnuz=Falsefor the compressed (Triton-OCP) call anduse_fnuz=current_platform.is_fp8_fnuz()for the SWA (C++ FNUZ-on-gfx942) call. This is the commit that closes the GSM8K=0.005 accuracy gap originally documented below.[ROCm][DSv4] Revert turboquant fp8e4b15 -> fp8e4b8 changes (NVIDIA-only path)triton_turboquant_decode.py/triton_turboquant_store.py. Per review (gemini-code-assist), those sit insideif FP8_E4B15:branches whose constexpr is1only whentorch.cuda.get_device_capability() < (8, 9)— i.e. NVIDIA Ampere/Ada, wheretl.float8e4b15IS the correct Triton FP8 type andtl.float8e4b8would be rejected.FP8_E4B15is always0on AMD (gfx942 reports cap(9, x)), so the change was both unreachable on MI300X and incorrect on its actual NVIDIA target. The real MI300X sparse-MLA fix inrocm_aiter_mla_sparse.pyis gated on the right constexpr (IS_FNUZfromcurrent_platform.is_fp8_fnuz()) and is unaffected.Diff stats
Total: +126 / -13, ROCm-only paths (gated by
current_platform.is_rocm()/current_platform.is_fp8_fnuz()/IS_FNUZ/__gfx950__). No CUDA / sm90 path changes.Test plan
Hardware: 1 node x 4 x MI300X (gfx942), HIP 6.x, vLLM container built from this branch on top of
599e75f43.Model:
deepseek-ai/DeepSeek-V4-Flash, TP=4,VLLM_ROCM_USE_AITER=1.Both eager and the cudagraphs (
FULL_AND_PIECEWISE) config from #42810 are exercised, so reviewers don't have to infer non-eager status from the eager test alone.Config A: eager mode
vllm serve deepseek-ai/DeepSeek-V4-Flash \ --tensor-parallel-size 4 \ --kv-cache-dtype fp8 \ --max-model-len 32768 \ --max-num-batched-tokens 8192 \ --gpu-memory-utilization 0.85 \ --distributed-executor-backend mp \ --trust-remote-code \ --enforce-eager \ --host 0.0.0.0 --port 8000Config B: cudagraphs
FULL_AND_PIECEWISE(matches #42810's reported config, TP scaled to 4)vllm serve deepseek-ai/DeepSeek-V4-Flash \ --tensor-parallel-size 4 \ --kv-cache-dtype fp8_e4m3 \ --block-size 256 \ --max-model-len 32768 \ --max-num-seqs 256 \ --distributed-executor-backend mp \ --trust-remote-code \ --gpu-memory-utilization 0.6 \ --moe-backend triton_unfused \ --tokenizer-mode deepseek_v4 \ --async-scheduling \ --enable-prefix-caching \ --compilation-config '{"mode":3,"cudagraph_mode":"FULL_AND_PIECEWISE"}' \ --host 0.0.0.0 --port 8000lm_eval(GSM8K 5-shot, limit=200, run against/v1/completionswithnum_concurrent=32).Test results
Before this PR (pure
main@599e75f43, post-#42810)Server fails to start during weight loading, identically in both modes:
After this PR
fp8e4b15Triton compile)exact_matchBoth modes go from "crashes at load" to "loads, serves, is deterministic per-mode, and matches #42810's TP=8 V4-Flash reference number (0.95) to all three reported digits". Config A and Config B also match each other to all three digits on both
strict-matchandflexible-extractfilters, confirming there is no eager-vs-cudagraphs accuracy gap on this path.Related
ffn_normregression [ROCm] [Bugfix] Fix DeepSeek V4 Functionality and Accuracy #42810 fixes)Co-authored-by: ganyi ygan@amd.com
Co-authored-by: Jin Tao jintao12@amd.com