[ROCm] Add AITER-accelerated MLA decode for DeepSeek V4 on MI355X#40889
ChuanLi1101 wants to merge 14 commits into vllm-project:main from
Conversation
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: yasong.wang <yasong.wang@inferact.ai>
Signed-off-by: Zhewen Li <zhewenli@inferact.ai>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
…oject#225) Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: ganyi <ygan@amd.com> Made-with: Cursor
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Replace the PyTorch reference sparse MLA decode with AITER's
persistent-mode ASM kernel (aiter.mla.mla_decode_fwd) on gfx950.
This gives ~2-3x decode speedup at high batch sizes.
Key changes:
- New module: vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py
- AiterSparseScratch: lazy-init persistent-mode metadata buffers,
keyed by (batch, nhead, topk, dtype) so 61 layers share one
allocation per decode step
- aiter_sparse_attn_decode: drop-in replacement handling dual-scope
attention (SWA + extra), LSE-based merging, and attn_sink correction
- Uses FP8/FP8 path only (gfx950 persistent-mode + return_lse
requires FP8)
- Fixed-stride kv_indices layout with -1 sentinels (required by
AITER persistent-mode kernels)
- deepseek_v4_attention.py:
- Add _aiter_scratch / _aiter_extra_scratch fields to __init__
- Gate ROCm decode path: VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE=1
routes to _forward_decode_aiter, otherwise falls back to the
existing PyTorch reference
- Fix missing RoutingMethodType import in fused_moe/oracle/mxfp4.py
Validated numerically (cosine > 0.999) across TP2/TP4/TP8 configs
on MI355X. Micro-benchmarked at 2.4x speedup (b=128, dual-scope).
Co-authored-by: Claude
Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
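The scratch-sharing scheme described in the commit message can be sketched in plain Python. This is a hypothetical illustration (class name `ScratchCache` and the list-based "buffers" are stand-ins; the real `AiterSparseScratch` allocates AITER persistent-mode metadata tensors on the GPU):

```python
class ScratchCache:
    """Sketch of lazy, key-based scratch reuse across attention layers.

    Buffers are rebuilt only when the shape/dtype key changes, so all
    layers decoding with the same shapes share one allocation per step.
    """

    def __init__(self):
        self._key = None        # shape/dtype signature of current buffers
        self._buffers = None
        self.rebuilds = 0       # how many times we actually allocated

    def get(self, batch, nhead, topk, dtype):
        key = (batch, nhead, topk, dtype)
        if key != self._key:
            # First call, or shapes changed: allocate once.
            self._buffers = {
                # fixed-stride kv_indices, padded with -1 sentinels
                "kv_indices": [[-1] * topk for _ in range(batch)],
            }
            self._key = key
            self.rebuilds += 1
        return self._buffers


cache = ScratchCache()
for _layer in range(61):        # 61 DSv4 attention layers, one decode step
    buf = cache.get(128, 16, 256, "fp8")
assert cache.rebuilds == 1      # one allocation shared by all layers
```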
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request introduces support for the DeepSeek-V4 model architecture, including the Multi-Head Latent Attention (MLA) with horizontal fusion, a Multi-Token Predictor (MTP) draft model, and specialized kernels for quantization and rotary embeddings. Key additions include a horizontally-fused Q-norm/RoPE/KV-insert kernel, a softplus-sqrt Top-K gating kernel for MoE, and support for MXFP4 quantization. Feedback highlights critical performance concerns on the ROCm path, specifically the inefficient dequantization of the entire physical KV cache and high CPU overhead from multiple kernel launches in the dequantization logic.
```python
blocked_swa = self._dequantize_blocked_k_cache(
    self.swa_cache_layer.kv_cache)
blocked_extra = (
    None if swa_only
    else self._dequantize_blocked_k_cache(kv_cache)
)
```
The ROCm decode path dequantizes the entire physical KV cache (both SWA and compressed portions) into bfloat16 on every generation step for every layer. In vLLM V1, the physical cache tensor contains blocks for all active requests in the system. For large context windows or high max_num_seqs, this will lead to massive memory bandwidth waste and likely cause Out-of-Memory (OOM) errors, as the dequantized cache is ~1.75x larger than the quantized one and is stored in temporary buffers. The AITER kernel should ideally consume the quantized cache directly, or dequantization should be limited to the active blocks identified in the block table for the current batch.
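The reviewer's suggestion to limit dequantization to the blocks referenced by the block table can be sketched in plain Python. All names here are hypothetical (the real cache and block table are torch tensors; plain dicts and lists stand in for them):

```python
def dequantize_active_blocks(quant_cache, block_scales, block_table):
    """Sketch: dequantize only the physical blocks the current batch
    references, instead of the entire physical KV cache.

    quant_cache : block_id -> list of quantized values (stand-in)
    block_scales: block_id -> dequantization scale
    block_table : per-request lists of block ids; -1 marks unused slots
    """
    # Gather the unique physical blocks referenced by any active request.
    active = sorted({b for row in block_table for b in row if b >= 0})
    # Dequantize just those blocks; inactive blocks are never touched,
    # bounding the temporary-buffer size by the active working set.
    return {b: [v * block_scales[b] for v in quant_cache[b]] for b in active}
```

With a physical cache of four blocks but a batch that only touches blocks 0 and 2, only those two are materialized in the temporary output.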
```python
for tile_idx in range(num_tiles):
    cur_nope = input_nope[
        ..., tile_idx * tile_size : (tile_idx + 1) * tile_size
    ].to(torch.bfloat16)
    cur_scales = input_scale[:, :, tile_idx].to(torch.bfloat16).unsqueeze(-1)
    result[
        ..., tile_idx * tile_size : (tile_idx + 1) * tile_size
    ] = (cur_nope * cur_scales).unsqueeze(2)
return result
```
The _dequantize_blocked_k_cache method uses a Python loop to process quantization tiles. This results in 7 separate kernel launches per layer per decode step. With 61 layers, this adds over 400 kernel launches per token generation step, significantly increasing CPU overhead and potentially bottlenecking the generation process. This dequantization logic should be implemented in a single fused Triton kernel to minimize launch overhead.
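Even before a fused Triton kernel lands, the per-tile Python loop can be collapsed into a single vectorized expression. A NumPy sketch of the math (assumed shapes: values `(batch, heads, num_tiles * tile_size)`, per-tile scales `(batch, heads, num_tiles)`; the real code operates on torch tensors):

```python
import numpy as np

def dequantize_tiled(input_nope, input_scale, tile_size):
    """Sketch: one vectorized dequant instead of a num_tiles-long loop.

    Repeating each tile's scale across its tile_size columns lets a
    single elementwise multiply cover every tile at once, so one launch
    replaces num_tiles launches per layer per step.
    """
    scales_full = np.repeat(input_scale, tile_size, axis=-1)
    return input_nope.astype(np.float32) * scales_full
```

The same shape trick (`repeat_interleave` in torch) removes the loop; a Triton kernel would additionally fuse the dtype cast and avoid the materialized scale tensor.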
```python
def is_aiter_dsv4_decode_enabled() -> bool:
    return os.environ.get(
        "VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE", "0"
```
@ChuanLi1101 is there a need for a flag?
We should always enable the op if aiter can be used. Moreover, on ROCm right now, we can only support DeepSeek Sparse Attention through AITER, so there shouldn't be a need for a flag.
```python
q_scale = torch.ones(1, dtype=torch.float32, device=device)
kv_scale = torch.ones(1, dtype=torch.float32, device=device)

_, lse = aiter.mla.mla_decode_fwd(
```
@ChuanLi1101 which version of AITER must we use?
Validated against AITER 0.1.6.post5 on MI355X (the version pinned by the ROCm container chuali_glm51 we used for benchmarking). Concretely we rely on:
- `aiter.mla.mla_decode_fwd(..., return_lse=True, q_scale=, kv_scale=, work_meta_data=, work_indptr=, work_info_set=, reduce_indptr=, reduce_final_map=, reduce_partial_map=)` — the FP8/FP8 persistent-mode signature on gfx950.
- `aiter.get_mla_metadata_info_v1(...)` and `aiter.get_mla_metadata_v1(...)` for the persistent-mode scratch layout.
Happy to add an explicit minimum-version check in _aiter_decode_one_scope (e.g. via aiter.__version__) if you'd like — let me know.
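A minimal sketch of such a version gate, tolerant of post-release suffixes like `0.1.6.post5` (the function name and the minimum tuple are illustrative; in practice it would compare against `aiter.__version__`):

```python
def aiter_meets_min_version(version_str: str, minimum=(0, 1, 6)) -> bool:
    """Hypothetical minimum-version check for the AITER dependency.

    Keeps only the leading digits of each of the first three
    dot-separated fields, so '0.1.6.post5' parses as (0, 1, 6).
    """
    parts = []
    for field in version_str.split(".")[:3]:
        digits = ""
        for ch in field:
            if ch.isdigit():
                digits += ch
            else:
                break  # stop at the first non-digit ('post5' -> '')
        parts.append(int(digits or 0))
    return tuple(parts) >= minimum
```

For anything beyond a sketch, `packaging.version.Version` handles PEP 440 suffixes properly and would be the safer choice.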
Address review feedback (tjtanaa, vllm-project#40889): on ROCm, DeepSeek sparse attention can only run through AITER, so gating the op with VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE adds no value.

- Remove is_aiter_dsv4_decode_enabled() and the env-var lookup from vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py.
- Simplify the ROCm branch in DeepseekV4MLAAttention._forward_decode to dispatch unconditionally to _forward_decode_aiter.
- Drop the now-unused os import and the env-var mention in the _forward_decode_aiter docstring.

Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
Update on review feedback + merge-conflict status
I'd rather not reapply #40871's rebase here — that would likely diverge from what @zyongye / @whx-sjtu intend.

Happy to proceed either way; let me know if maintainers prefer I rebase the full stack now.
```python
"""
AITER-accelerated sparse MLA decode for DeepSeek V4 on ROCm (MI355X / gfx950).

Drop-in replacement for `DeepseekV4MLAAttention._ref_sparse_attn_decode`.
```
We should ship this feature in the enablement PR.
The benefit is two-fold:
- We ship a PR that has better performance
- This will greatly cut down the torch code. The torch reference code is what makes the enablement PR large, and it introduces a lot of unnecessary torch code.
Moving to draft for now. Following up with @hattie / @HexWang to validate baseline generation quality first; will flip back to ready once we have a known-good config and re-run lm_eval.
Root cause of MI355X gibberish (vllm-project#40889 / vllm-project#40892 baseline):

DeepseekCompressor._forward_old (the ROCm path) rotates the compressed K with the LAST token's position of each compressed chunk (e.g. positions [3, 7, 11, ...] for compress_ratio=4), while the NV fused kernels (_fused_kv_compress_norm_rope_insert_indexer_attn and _fused_kv_compress_norm_rope_insert_sparse_attn) rotate with the FIRST token's position of the chunk:

    compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO

This produces a (compress_ratio - 1)-token RoPE phase offset on every cached K versus the model's training-time convention. Q is rotated at its true token position (matching NV), so each q.k inner product is rotated by ~(CR-1) extra positions, which breaks the relative-position semantics of RoPE, corrupts both rocm_fp8_mqa_logits (the top-k indexer) and the sparse-MLA decode, and produces incoherent first-token output on plain prompts.

Fix: align the K-cache RoPE position with the NV reference:

    rope_positions = (compressed_positions // CR) * CR

Symptom on MI355X (TP=8, bf16, /data/DeepSeek-V4-Pro):
- Before: LLM().generate("What is 2+2? Answer:") -> " 1 1 1 1 1 1 ..." / " (2 2-2 2 2 2 ..."
- Expected after fix: LLM().generate("What is 2+2? Answer:") -> "4" / "Four" / etc.

Affected paths:
- head_dim == 512 (sparse-attn K cache, used by AITER/ref decode)
- head_dim == 128 (indexer K cache, used by rocm_fp8_mqa_logits)

Both were rotated with the wrong position; the single-line change covers both because the bug is in the shared apply_gptj_rope_ref call before the head_dim split.

Test plan:
1. wsl -> ssh smci355-ccs-aus-m13-05.cs-aus.dcgpu, container chuali_glm51
2. Apply this commit to /workspace/vllm_dsv4
3. Run bench_remote/_simple_smoke.py (TP=8, eager, bf16)
4. Expect coherent text on prompts: "What is 2+2? Answer:", "The capital of France is", "Q: A robe takes 2 bolts of blue fiber and half that much white...", "Once upon a time,"
5. If coherent, follow up with lm_eval gsm8k to confirm.

Signed-off-by: Chuan Li <chuali@amd.com>
Made-with: Cursor
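The position convention at the heart of this bug is easy to see in isolation. A small sketch of both conventions (function names are illustrative, not from the codebase):

```python
COMPRESS_RATIO = 4  # example ratio from the commit message

def compressed_rope_position(position: int, cr: int = COMPRESS_RATIO) -> int:
    # NV-reference convention: rotate the compressed K with the FIRST
    # token position of its chunk -> positions 0, 4, 8, ...
    return (position // cr) * cr

def buggy_rope_position(position: int, cr: int = COMPRESS_RATIO) -> int:
    # Old ROCm path: rotate with the LAST token position of the chunk
    # -> positions 3, 7, 11, ..., a constant (cr - 1) phase offset.
    return (position // cr) * cr + (cr - 1)
```

For tokens 0..7 the fixed convention yields positions [0, 0, 0, 0, 4, 4, 4, 4] while the buggy one yields [3, 3, 3, 3, 7, 7, 7, 7]: every cached K carries a 3-position RoPE offset relative to Q, which is exactly the (compress_ratio - 1) phase error described above.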
Hoist all per-step allocations on the AITER sparse-MLA decode path into
the existing `AiterSparseScratch` cache so cudagraph capture sees stable
memory layouts. Previously each layer's call site freshly allocated
`qo_indptr`, `kv_indptr`, `kv_indices`, `kv_last_page_lens`, `q_scale`,
`kv_scale`, `q_fp8`, the bf16 output buffer, and intermediate boolean
masks every step, which was incompatible with HIP-graph capture and
generated unnecessary allocator pressure with 61 DSv4 attention layers.
Changes
-------
* `AiterSparseScratch` now caches:
* Static buffers: `qo_indptr` (arange), `kv_last_page_lens` (ones),
`col_arange`, `q_scale`, `kv_scale`.
* Per-step write buffers: `kv_indptr`, `kv_indices_2d`, `valid_mask`,
`valid_lens`, `q_fp8`, `out_buf`.
* `rebuild()` allocates and (where applicable) initialises every buffer
once per `(total_q, h_q, topk, d_qk, d_v, dtype, kvtype)` key and runs
`aiter.get_mla_metadata_v1` against the persistent qo/kv/last-page
tensors.
* `_aiter_decode_one_scope` rewrites the per-step buffers in-place via
`torch.lt(out=)`, `tensor.copy_`, `masked_fill_`, and
`torch.cumsum(out=)` instead of fresh allocations.
* The public `aiter_sparse_attn_decode` signature is unchanged, so
`DeepseekV4MLAAttention._forward_decode_aiter` keeps working as-is.
Follow-up to PR vllm-project#40889; remaining cudagraph blocker is the per-step
`blocked_k.to(fp8_e4m3fn)` cast on the dequantised KV cache, which is
tracked separately together with the dequantise-into-FP8 fast path
suggested in code review.
Test plan
---------
* `python -c "import ast; ast.parse(open('vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py').read())"` passes.
* On MI355X with `chuali_glm51` container (rocm/vllm-dev:nightly +
torch 2.10/HIP 7.2): manual smoke test pending; will follow up with
cudagraph-mode (non-eager) startup log + decode parity vs eager.
AI assistance: drafted with Cursor agent; human-reviewed before
submission.
Signed-off-by: ChuanLi1101 <chuanli1101@gmail.com>
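The in-place rewrite pattern this commit describes — stable buffer addresses across steps, as graph capture requires — can be sketched with NumPy standing in for torch (`np.less(out=)` / `np.cumsum(out=)` mirroring the `torch.lt(out=)` / `torch.cumsum(out=)` calls above; the class and field names are illustrative):

```python
import numpy as np

class StepBuffers:
    """Sketch: persistent per-step buffers rewritten in place each step."""

    def __init__(self, batch: int, topk: int):
        self.col_arange = np.arange(topk)                      # static
        self.valid_mask = np.empty((batch, topk), dtype=bool)  # per-step
        self.kv_indptr = np.empty(batch + 1, dtype=np.int64)   # per-step

    def update(self, valid_lens: np.ndarray) -> None:
        # All writes land in preallocated storage; no new arrays are
        # created per step, so a captured graph keeps seeing the same
        # addresses on replay.
        np.less(self.col_arange[None, :], valid_lens[:, None],
                out=self.valid_mask)
        self.kv_indptr[0] = 0
        np.cumsum(valid_lens, out=self.kv_indptr[1:])
```

The same discipline in torch (plus `copy_` / `masked_fill_` for the index tensors) is what makes the decode path HIP-graph capturable.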
select_mxfp4_moe_backend in mxfp4.py references RoutingMethodType.DeepseekV4, but the symbol is not imported on the hexwang/dsv4_adapt_upstream base, raising NameError: name 'RoutingMethodType' is not defined during model init. Hexiang's recipe happens to skip this code path via the --moe-backend triton_unfused CLI flag, but the LLM offline API takes the same path and trips on it.

Minimal one-line fix: pull RoutingMethodType into the same multi-import block that already imports FusedMoEQuantConfig / FusedMoEQuantDesc / mxfp4_*_moe_quant_config from fused_moe.config.

This fix was originally in vllm-project#40889 but got auto-merged-out during the cherry-pick onto tj/dsv4prrebase (which already had it); reintroducing it when rebasing onto hexwang/dsv4_adapt_upstream, which does not.

Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
Summary
This PR adds an AITER-accelerated sparse MLA decode path for DeepSeek V4 on AMD MI355X (gfx950), building on top of PR #40871 (ROCm DeepSeek V4 support).
The existing ROCm decode path uses a PyTorch reference implementation. This PR replaces it with AITER's persistent-mode ASM kernel (`aiter.mla.mla_decode_fwd`), achieving ~2-3x decode speedup at high batch sizes while maintaining numerical correctness.

Changes

New file: `vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py`
- `AiterSparseScratch`: lazy-initialized persistent-mode metadata buffers. Keyed by `(batch_size, nhead, topk, dtype, kvtype)` so all 61 DSv4 attention layers share one allocation per decode step, eliminating per-layer metadata rebuild overhead.
- `aiter_sparse_attn_decode()`: drop-in replacement for `_ref_sparse_attn_decode`, handling LSE-based merging (`return_lse=True`) and the fixed-stride `kv_indices` layout with `-1` sentinels (required by AITER persistent-mode kernels).

Modified: `vllm/model_executor/layers/deepseek_v4_attention.py`
- Add `_aiter_scratch` / `_aiter_extra_scratch` fields to `__init__`.
- `_forward_decode` now unconditionally dispatches to the new `_forward_decode_aiter()` (DeepSeek sparse attention is AITER-only on ROCm; no env-var flag — see review thread).

Bugfix: `vllm/model_executor/layers/fused_moe/oracle/mxfp4.py`
- Add the missing `RoutingMethodType` import (was causing `NameError` at runtime).

Validation
- Validated against `aiter==0.1.6.post5`.

Usage
(No flag required — the AITER decode path is the only ROCm sparse-attention decode path.)
Test plan
Note on merge conflicts
This PR is stacked on top of #40871; mergify-reported conflicts are inherited from that PR's commits, not from any file changed here. Will rebase once #40871 lands or is rebased.
AI assistance was used in developing this change (Claude). All code has been reviewed and validated by a human.