[ROCm] Add AITER-accelerated MLA decode for DeepSeek V4 on MI355X #40889

Draft

ChuanLi1101 wants to merge 14 commits into vllm-project:main from ChuanLi1101:rocm/aiter-mla-dsv4-decode

Conversation

@ChuanLi1101 (Collaborator) commented Apr 25, 2026

Summary

This PR adds an AITER-accelerated sparse MLA decode path for DeepSeek V4 on AMD MI355X (gfx950), building on top of PR #40871 (ROCm DeepSeek V4 support).

The existing ROCm decode path uses a PyTorch reference implementation. This PR replaces it with AITER's persistent-mode ASM kernel (aiter.mla.mla_decode_fwd), achieving ~2-3x decode speedup at high batch sizes while maintaining numerical correctness.

Changes

New file: vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py

  • AiterSparseScratch: Lazy-initialized persistent-mode metadata buffers. Keyed by (batch_size, nhead, topk, dtype, kvtype) so all 61 DSv4 attention layers share one allocation per decode step, eliminating per-layer metadata rebuild overhead.
  • aiter_sparse_attn_decode(): Drop-in replacement for _ref_sparse_attn_decode, handling:
    • Dual-scope attention (SWA + extra blocked K) with LSE-based output merging (sketched just after this list)
    • Attention sink correction using LSE values from the kernel
    • FP8/FP8 input casting (required by gfx950 persistent-mode + return_lse=True)
    • Fixed-stride kv_indices layout with -1 sentinels (required by AITER persistent-mode kernels)
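
For context, the LSE-based merge is the standard log-sum-exp recombination of two partial softmax-attention outputs. A minimal sketch (shapes and names are illustrative, not this PR's exact code):

```python
import torch

def merge_by_lse(out_a: torch.Tensor, lse_a: torch.Tensor,
                 out_b: torch.Tensor, lse_b: torch.Tensor) -> torch.Tensor:
    """Merge two partial attention outputs via their log-sum-exp (LSE) values.

    Hypothetical helper: out_*: [tokens, heads, d_v], lse_*: [tokens, heads].
    """
    lse_max = torch.maximum(lse_a, lse_b)   # stabilize the exponentials
    w_a = torch.exp(lse_a - lse_max)        # un-normalized weight, scope A
    w_b = torch.exp(lse_b - lse_max)        # un-normalized weight, scope B
    num = out_a * w_a.unsqueeze(-1) + out_b * w_b.unsqueeze(-1)
    return num / (w_a + w_b).unsqueeze(-1)
```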

Modified: vllm/model_executor/layers/deepseek_v4_attention.py

  • Added _aiter_scratch / _aiter_extra_scratch fields to __init__
  • On ROCm, _forward_decode now unconditionally dispatches to the new _forward_decode_aiter() (DeepSeek sparse attention is AITER-only on ROCm; no env-var flag — see review thread).

Bugfix: vllm/model_executor/layers/fused_moe/oracle/mxfp4.py

  • Added missing RoutingMethodType import (was causing NameError at runtime)

Validation

  • Numerical correctness: Cosine similarity > 0.999 across all TP configurations (TP2/h_q=64, TP4/h_q=32, TP8/h_q=16) for both SWA-only and SWA+extra dual-scope scenarios (checked as sketched after this list)
  • Micro-benchmark: 2.4x speedup at batch_size=128 (dual-scope, FP8) on MI355X
  • E2E smoke test: Server starts and serves requests successfully on MI355X with TP4
  • AITER version: Validated against aiter==0.1.6.post5
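
For reference, the correctness check amounts to a flattened cosine comparison between the AITER output and the PyTorch reference; a minimal sketch (tensor names are illustrative):

```python
import torch

def cosine_check(ref_out: torch.Tensor, test_out: torch.Tensor,
                 threshold: float = 0.999) -> float:
    """Compare two attention outputs by direction after flattening."""
    a = ref_out.flatten().float()
    b = test_out.flatten().float()
    cos = torch.dot(a, b) / (a.norm() * b.norm() + 1e-12)
    assert cos > threshold, f"cosine {cos:.6f} below {threshold}"
    return cos.item()
```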

Usage

vllm serve /path/to/DeepSeek-V4 --tensor-parallel-size 4 --trust-remote-code

(No flag required — the AITER decode path is the only ROCm sparse-attention decode path.)

Test plan

  • Numerical validation against PyTorch reference (cosine > 0.999)
  • Micro-benchmark showing speedup
  • E2E server smoke test (model loads, generates text)
  • Full benchmark suite with concurrent requests
  • CI tests on ROCm

Note on merge conflicts

This PR is stacked on top of #40871; mergify-reported conflicts are inherited from that PR's commits, not from any file changed here. Will rebase once #40871 lands or is rebased.

AI assistance was used in developing this change (Claude). All code has been reviewed and validated by a human.

zyongye and others added 13 commits April 24, 2026 02:58
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: yasong.wang <yasong.wang@inferact.ai>
Signed-off-by: Zhewen Li <zhewenli@inferact.ai>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: ganyi <ygan@amd.com>
Made-with: Cursor
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Replace the PyTorch reference sparse MLA decode with AITER's
persistent-mode ASM kernel (aiter.mla.mla_decode_fwd) on gfx950.
This gives ~2-3x decode speedup at high batch sizes.

Key changes:
- New module: vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py
  - AiterSparseScratch: lazy-init persistent-mode metadata buffers,
    keyed by (batch, nhead, topk, dtype) so 61 layers share one
    allocation per decode step
  - aiter_sparse_attn_decode: drop-in replacement handling dual-scope
    attention (SWA + extra), LSE-based merging, and attn_sink correction
  - Uses FP8/FP8 path only (gfx950 persistent-mode + return_lse
    requires FP8)
  - Fixed-stride kv_indices layout with -1 sentinels (required by
    AITER persistent-mode kernels)

- deepseek_v4_attention.py:
  - Add _aiter_scratch / _aiter_extra_scratch fields to __init__
  - Gate ROCm decode path: VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE=1
    routes to _forward_decode_aiter, otherwise falls back to the
    existing PyTorch reference

- Fix missing RoutingMethodType import in fused_moe/oracle/mxfp4.py

Validated numerically (cosine > 0.999) across TP2/TP4/TP8 configs
on MI355X. Micro-benchmarked at 2.4x speedup (b=128, dual-scope).

Co-authored-by: Claude
Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
@mergify bot (Contributor) commented Apr 25, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ChuanLi1101.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces support for the DeepSeek-V4 model architecture, including the Multi-Head Latent Attention (MLA) with horizontal fusion, a Multi-Token Predictor (MTP) draft model, and specialized kernels for quantization and rotary embeddings. Key additions include a horizontally-fused Q-norm/RoPE/KV-insert kernel, a softplus-sqrt Top-K gating kernel for MoE, and support for MXFP4 quantization. Feedback highlights critical performance concerns on the ROCm path, specifically the inefficient dequantization of the entire physical KV cache and high CPU overhead from multiple kernel launches in the dequantization logic.

Comment on lines +1166 to +1171
blocked_swa = self._dequantize_blocked_k_cache(
    self.swa_cache_layer.kv_cache)
blocked_extra = (
    None if swa_only
    else self._dequantize_blocked_k_cache(kv_cache)
)

critical

The ROCm decode path dequantizes the entire physical KV cache (both SWA and compressed portions) into bfloat16 on every generation step for every layer. In vLLM V1, the physical cache tensor contains blocks for all active requests in the system. For large context windows or high max_num_seqs, this will lead to massive memory bandwidth waste and likely cause Out-of-Memory (OOM) errors, as the dequantized cache is ~1.75x larger than the quantized one and is stored in temporary buffers. The AITER kernel should ideally consume the quantized cache directly, or dequantization should be limited to the active blocks identified in the block table for the current batch.
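
One possible shape of the suggested fix, as a rough sketch (helper and tensor names here are hypothetical, and kv_indices would need remapping into the gathered layout downstream):

```python
import torch

def dequantize_active_blocks(kv_cache: torch.Tensor,
                             block_table: torch.Tensor,
                             dequant_fn):
    """Dequantize only the cache blocks the current batch references.

    Hypothetical helper: kv_cache is [num_blocks, ...], block_table is
    [batch, max_blocks] with -1 padding, dequant_fn is the per-block
    dequantizer. Returns the dequantized subset plus an index map for
    remapping kv_indices into it.
    """
    flat = block_table.reshape(-1)
    active = flat[flat >= 0].unique()      # physical block ids in use
    dense = dequant_fn(kv_cache[active])   # dequantize only those blocks
    remap = torch.full((kv_cache.shape[0],), -1,
                       dtype=torch.long, device=kv_cache.device)
    remap[active] = torch.arange(active.numel(), device=kv_cache.device)
    return dense, remap
```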

Comment on lines +1242 to +1250
for tile_idx in range(num_tiles):
    cur_nope = input_nope[
        ..., tile_idx * tile_size : (tile_idx + 1) * tile_size
    ].to(torch.bfloat16)
    cur_scales = input_scale[:, :, tile_idx].to(torch.bfloat16).unsqueeze(-1)
    result[
        ..., tile_idx * tile_size : (tile_idx + 1) * tile_size
    ] = (cur_nope * cur_scales).unsqueeze(2)
return result

high

The _dequantize_blocked_k_cache method uses a Python loop to process quantization tiles. This results in 7 separate kernel launches per layer per decode step. With 61 layers, this adds over 400 kernel launches per token generation step, significantly increasing CPU overhead and potentially bottlenecking the generation process. This dequantization logic should be implemented in a single fused Triton kernel to minimize launch overhead.
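
As a torch-level illustration only (the review asks for a fused Triton kernel; this merely removes the Python-loop launches, and the tensor shapes are assumptions):

```python
import torch

def dequantize_tiled(input_nope: torch.Tensor,
                     input_scale: torch.Tensor,
                     tile_size: int) -> torch.Tensor:
    """Loop-free rewrite of the tile loop above.

    Assumed shapes: input_nope [B, S, num_tiles * tile_size] (quantized),
    input_scale [B, S, num_tiles]; not necessarily the PR's exact layout.
    """
    b, s, d = input_nope.shape
    num_tiles = d // tile_size
    nope = input_nope.to(torch.bfloat16).view(b, s, num_tiles, tile_size)
    scales = input_scale.to(torch.bfloat16).unsqueeze(-1)  # [B, S, T, 1]
    return (nope * scales).view(b, s, d).unsqueeze(2)      # keep the ref's extra dim
```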


def is_aiter_dsv4_decode_enabled() -> bool:
    return os.environ.get(
        "VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE", "0"
    ) == "1"
@tjtanaa (Collaborator) commented:

@ChuanLi1101 is there a need for a flag?

We should always enable the op if AITER can be used. Moreover, on ROCm right now, we can only support DeepSeek Sparse Attention through AITER, so there shouldn't be a need for a flag.

@ChuanLi1101 (Collaborator, Author) replied:

Agreed, thanks @tjtanaa. Removed the flag in b3a4a44: is_aiter_dsv4_decode_enabled() is gone from rocm_aiter_dsv4_decode.py, and DeepseekV4MLAAttention._forward_decode now always dispatches to _forward_decode_aiter on ROCm. The PR description has been updated accordingly.

q_scale = torch.ones(1, dtype=torch.float32, device=device)
kv_scale = torch.ones(1, dtype=torch.float32, device=device)

_, lse = aiter.mla.mla_decode_fwd(
@tjtanaa (Collaborator) commented:

@ChuanLi1101 which version of AITER must we use?

@ChuanLi1101 (Collaborator, Author) replied:

Validated against AITER 0.1.6.post5 on MI355X (the version pinned by the ROCm container chuali_glm51 we used for benchmarking). Concretely we rely on:

  • aiter.mla.mla_decode_fwd(..., return_lse=True, q_scale=, kv_scale=, work_meta_data=, work_indptr=, work_info_set=, reduce_indptr=, reduce_final_map=, reduce_partial_map=) — the FP8/FP8 persistent-mode signature on gfx950.
  • aiter.get_mla_metadata_info_v1(...) and aiter.get_mla_metadata_v1(...) for the persistent-mode scratch layout.

Happy to add an explicit minimum-version check in _aiter_decode_one_scope (e.g. via aiter.__version__) if you'd like — let me know.
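
For concreteness, such a guard might look like this (a minimal sketch; it assumes aiter exposes __version__, and the threshold is simply the version validated above):

```python
# Hypothetical minimum-version guard for the gfx950 persistent-mode path.
from packaging.version import Version

import aiter

_MIN_AITER = Version("0.1.6.post5")  # version validated in this PR

def check_aiter_version() -> None:
    found = Version(getattr(aiter, "__version__", "0"))
    if found < _MIN_AITER:
        raise RuntimeError(
            f"AITER >= {_MIN_AITER} is required for the persistent-mode "
            f"MLA decode path; found {found}."
        )
```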

Address review feedback (tjtanaa, vllm-project#40889): on ROCm,
DeepSeek sparse attention can only run through AITER, so gating the
op with VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE adds no value.

- Remove is_aiter_dsv4_decode_enabled() and the env-var lookup from
  vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py.
- Simplify the ROCm branch in DeepseekV4MLAAttention._forward_decode
  to dispatch unconditionally to _forward_decode_aiter.
- Drop the now-unused os import and the env-var mention in the
  _forward_decode_aiter docstring.

Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
@ChuanLi1101 (Collaborator, Author) commented:

Update on review feedback + merge-conflict status

  • @tjtanaa flag comment — addressed in b3a4a44. The VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE flag and is_aiter_dsv4_decode_enabled() helper are removed; on ROCm _forward_decode now unconditionally dispatches to _forward_decode_aiter.
  • @tjtanaa AITER version question — replied inline (validated against aiter==0.1.6.post5).

Re: needs-rebase/merge conflicts — this PR is currently stacked on top of #40871 (hexwang/dsv4_adapt_upstream @ 88986f79c), which is itself in MERGEABLE: CONFLICTING / mergeStateStatus: DIRTY. All 23 conflicting paths reported by mergify (e.g. vllm/model_executor/layers/fused_moe/runner/*, vllm/v1/kv_cache_interface.py, vllm/v1/spec_decode/eagle.py, requirements/cuda.txt, …) are introduced by #40871's commits and have nothing to do with this PR's three changed files (rocm_aiter_dsv4_decode.py, deepseek_v4_attention.py, fused_moe/oracle/mxfp4.py).

I'd rather not reapply #40871's rebase here — that would likely diverge from what @zyongye / @whx-sjtu intend. Plan:

  1. Once #40871 ([New Model][ROCm] Add AMD support for DeepSeek V4) is rebased onto current main, I'll rebase this branch onto the new hexwang/dsv4_adapt_upstream head and force-push; remaining conflicts (if any) should be limited to my files.
  2. Alternatively, after #40871 lands on main, this PR becomes a trivial 3-file diff against main.

Happy to proceed either way — let me know if maintainers prefer I rebase the full stack now.

"""
AITER-accelerated sparse MLA decode for DeepSeek V4 on ROCm (MI355X / gfx950).

Drop-in replacement for `DeepseekV4MLAAttention._ref_sparse_attn_decode`.
A collaborator commented:

We should ship this feature in the enablement PR.

The benefit is two-fold:

  • We ship a PR that has better performance
  • This will greatly cut down the torch code. The reference torch implementation is what makes the enablement PR large, and much of it is unnecessary.

@ChuanLi1101 (Collaborator, Author) commented:

Moving to draft for now.

A small lm_eval run plus a plain LLM().generate(...) smoke test on MI355X (TP=8, bf16) shows that both the AITER decode path in this PR and the baseline _ref_sparse_attn_decode path on hexwang/dsv4_adapt produce incoherent text on real prompts (gsm8k limit=5 → exact_match=0.0 on both). The correctness gap is therefore not specific to this PR; it appears to be a pre-existing baseline issue on MI355X.

Following up with @hattie / @HexWang to validate baseline generation quality first; will flip back to ready once we have a known-good config and re-run lm_eval.

ChuanLi1101 added a commit to ChuanLi1101/vllm that referenced this pull request Apr 26, 2026
Root cause of MI355X gibberish (vllm-project#40889 / vllm-project#40892 baseline):
DeepseekCompressor._forward_old (the ROCm path) rotates the
compressed K with the LAST token's position of each compressed
chunk (e.g. positions [3, 7, 11, ...] for compress_ratio=4),
while the NV fused kernels
(_fused_kv_compress_norm_rope_insert_indexer_attn and
_fused_kv_compress_norm_rope_insert_sparse_attn) rotate with the
FIRST token's position of the chunk:
    compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO

This produces a (compress_ratio - 1)-token RoPE phase offset on
every cached K versus the model's training-time convention.
Q is rotated at its true token position (matching NV), so each
q.k inner product is rotated by ~(CR-1) extra positions, which
breaks the relative-position semantics of RoPE, corrupts both
rocm_fp8_mqa_logits (top-k indexer) and the sparse-MLA decode,
and produces incoherent first-token output on plain prompts.

Fix: align the K-cache RoPE position with the NV reference:
    rope_positions = (compressed_positions // CR) * CR
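
A tiny arithmetic illustration of the constant offset (values taken from the description above):

```python
CR = 4  # compress_ratio
last_tok  = [c * CR + CR - 1 for c in range(3)]  # [3, 7, 11]  buggy ROCm rotation
first_tok = [c * CR for c in range(3)]           # [0, 4, 8]   NV reference / this fix
assert all(l - f == CR - 1 for l, f in zip(last_tok, first_tok))
```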

Symptom on MI355X (TP=8, bf16, /data/DeepSeek-V4-Pro):
  Before:
    LLM().generate("What is 2+2? Answer:")
       -> " 1 1 1 1 1 1 ..."  /  " (2 2-2 2 2 2 ..."
  Expected after fix:
    LLM().generate("What is 2+2? Answer:")
       -> "4" / "Four" / etc.

Affected paths:
  - head_dim == 512 (sparse-attn K cache, used by AITER/ref decode)
  - head_dim == 128 (indexer K cache, used by rocm_fp8_mqa_logits)
Both were rotated with the wrong position; the single-line change
covers both because the bug is in the shared apply_gptj_rope_ref
call before the head_dim split.

Test plan:
  1. wsl -> ssh smci355-ccs-aus-m13-05.cs-aus.dcgpu, container chuali_glm51
  2. Apply this commit to /workspace/vllm_dsv4
  3. Run bench_remote/_simple_smoke.py (TP=8, eager, bf16)
  4. Expect coherent text on prompts:
       "What is 2+2? Answer:"
       "The capital of France is"
       "Q: A robe takes 2 bolts of blue fiber and half that much white..."
       "Once upon a time,"
  5. If coherent, follow up with lm_eval gsm8k to confirm.

Signed-off-by: Chuan Li <chuali@amd.com>
Made-with: Cursor
ChuanLi1101 pushed a commit to ChuanLi1101/vllm that referenced this pull request Apr 27, 2026
Hoist all per-step allocations on the AITER sparse-MLA decode path into
the existing `AiterSparseScratch` cache so cudagraph capture sees stable
memory layouts. Previously each layer's call site freshly allocated
`qo_indptr`, `kv_indptr`, `kv_indices`, `kv_last_page_lens`, `q_scale`,
`kv_scale`, `q_fp8`, the bf16 output buffer, and intermediate boolean
masks every step, which was incompatible with HIP-graph capture and
generated unnecessary allocator pressure with 61 DSv4 attention layers.

Changes
-------

* `AiterSparseScratch` now caches:
  * Static buffers: `qo_indptr` (arange), `kv_last_page_lens` (ones),
    `col_arange`, `q_scale`, `kv_scale`.
  * Per-step write buffers: `kv_indptr`, `kv_indices_2d`, `valid_mask`,
    `valid_lens`, `q_fp8`, `out_buf`.
* `rebuild()` allocates and (where applicable) initialises every buffer
  once per `(total_q, h_q, topk, d_qk, d_v, dtype, kvtype)` key and runs
  `aiter.get_mla_metadata_v1` against the persistent qo/kv/last-page
  tensors.
* `_aiter_decode_one_scope` rewrites the per-step buffers in-place via
  `torch.lt(out=)`, `tensor.copy_`, `masked_fill_`, and
  `torch.cumsum(out=)` instead of fresh allocations.
* The public `aiter_sparse_attn_decode` signature is unchanged, so
  `DeepseekV4MLAAttention._forward_decode_aiter` keeps working as-is.

Follow-up to PR vllm-project#40889; remaining cudagraph blocker is the per-step
`blocked_k.to(fp8_e4m3fn)` cast on the dequantised KV cache, which is
tracked separately together with the dequantise-into-FP8 fast path
suggested in code review.

Test plan
---------

* `python -c "import ast; ast.parse(open('vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py').read())"` passes.
* On MI355X with `chuali_glm51` container (rocm/vllm-dev:nightly +
  torch 2.10/HIP 7.2): manual smoke test pending; will follow up with
  cudagraph-mode (non-eager) startup log + decode parity vs eager.

AI assistance: drafted with Cursor agent; human-reviewed before
submission.

Signed-off-by: ChuanLi1101 <chuanli1101@gmail.com>
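
For illustration, the in-place rewrite pattern described in that commit might look like the following cut-down sketch (buffer names follow the commit message; shapes, dtypes, and the surrounding class are assumptions):

```python
import torch

class Scratch:
    """Hypothetical cut-down AiterSparseScratch: allocate once, rewrite in place."""

    def __init__(self, total_q: int, topk: int, device: str = "cuda"):
        self.col_arange = torch.arange(topk, device=device)
        self.valid_mask = torch.empty(total_q, topk, dtype=torch.bool, device=device)
        self.valid_lens = torch.empty(total_q, dtype=torch.int64, device=device)
        self.kv_indptr = torch.zeros(total_q + 1, dtype=torch.int64, device=device)

    def rewrite(self, seq_lens: torch.Tensor, kv_indices_2d: torch.Tensor) -> None:
        # torch.lt(..., out=) fills the cached mask instead of allocating one.
        torch.lt(self.col_arange.unsqueeze(0), seq_lens.unsqueeze(1),
                 out=self.valid_mask)
        # Pad unused slots with the -1 sentinel the AITER kernels expect.
        kv_indices_2d.masked_fill_(~self.valid_mask, -1)
        torch.sum(self.valid_mask, dim=1, out=self.valid_lens)
        # Prefix-sum into the cached indptr; entry 0 stays 0 from init.
        torch.cumsum(self.valid_lens, dim=0, out=self.kv_indptr[1:])
```

Stable buffer addresses are what allow HIP-graph capture to replay the step without re-recording allocations.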
ChuanLi1101 added a commit to ChuanLi1101/vllm that referenced this pull request Apr 28, 2026
select_mxfp4_moe_backend in mxfp4.py references
RoutingMethodType.DeepseekV4 but the symbol is not imported on the
hexwang/dsv4_adapt_upstream base, raising NameError: name
'RoutingMethodType' is not defined during model init. Hexiang's recipe
happens to skip this code path via --moe-backend triton_unfused CLI
flag, but the LLM offline API takes the same path and trips on it.

Minimal one-line fix: pull RoutingMethodType into the same multi-import
block that already imports FusedMoEQuantConfig / FusedMoEQuantDesc /
mxfp4_*_moe_quant_config from fused_moe.config.
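
A sketch of the resulting import block (the module path and neighbouring names follow the commit text; exact formatting is an assumption):

```python
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    FusedMoEQuantDesc,
    RoutingMethodType,  # the previously missing symbol
    # ... plus the existing mxfp4_* quant-config imports
)
```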

This fix was originally in vllm-project#40889 but got auto-merged-out during the
cherry-pick onto tj/dsv4prrebase (which already had it); reintroducing
when rebasing onto hexwang/dsv4_adapt_upstream which does not.

Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor

Labels

ci/build, deepseek, documentation, gpt-oss, kv-connector, needs-rebase, new-model, nvidia, performance, rocm, speculative-decoding, tool-calling, v1
