[ROCm][DSv4] Make AITER sparse MLA decode cudagraph-clean (follow-up to #40889) #40892
ChuanLi1101 wants to merge 16 commits into vllm-project:main
Conversation
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: yasong.wang <yasong.wang@inferact.ai>
Signed-off-by: Zhewen Li <zhewenli@inferact.ai>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
…oject#225) Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: ganyi <ygan@amd.com>
Made-with: Cursor
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Replace the PyTorch reference sparse MLA decode with AITER's
persistent-mode ASM kernel (aiter.mla.mla_decode_fwd) on gfx950.
This gives ~2-3x decode speedup at high batch sizes.
Key changes:
- New module: vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py
- AiterSparseScratch: lazy-init persistent-mode metadata buffers,
keyed by (batch, nhead, topk, dtype) so 61 layers share one
allocation per decode step
- aiter_sparse_attn_decode: drop-in replacement handling dual-scope
attention (SWA + extra), LSE-based merging, and attn_sink correction
- Uses FP8/FP8 path only (gfx950 persistent-mode + return_lse
requires FP8)
- Fixed-stride kv_indices layout with -1 sentinels (required by
AITER persistent-mode kernels)
- deepseek_v4_attention.py:
- Add _aiter_scratch / _aiter_extra_scratch fields to __init__
- Gate ROCm decode path: VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE=1
routes to _forward_decode_aiter, otherwise falls back to the
existing PyTorch reference
- Fix missing RoutingMethodType import in fused_moe/oracle/mxfp4.py
Validated numerically (cosine > 0.999) across TP2/TP4/TP8 configs
on MI355X. Micro-benchmarked at 2.4x speedup (b=128, dual-scope).
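For reference, the LSE-based merge of two attention scopes mentioned above can be sketched as follows. This is only the generic log-sum-exp merge math (the attn_sink correction is omitted), with hypothetical tensor names rather than the actual helper in rocm_aiter_dsv4_decode.py:

```python
import torch

def merge_attn_scopes(out_a, lse_a, out_b, lse_b):
    """Merge two per-scope attention outputs via their log-sum-exp statistics.

    out_a / out_b: [tokens, heads, v_dim] outputs, each softmax-normalized
                   within its own scope (e.g. the SWA scope and the extra scope).
    lse_a / lse_b: [tokens, heads] log-sum-exp of the attention logits per scope.
    """
    lse_max = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse_max)            # relative weight of scope A
    w_b = torch.exp(lse_b - lse_max)            # relative weight of scope B
    denom = w_a + w_b
    merged = (out_a * (w_a / denom).unsqueeze(-1)
              + out_b * (w_b / denom).unsqueeze(-1))
    merged_lse = lse_max + torch.log(denom)     # LSE over the combined scope
    return merged, merged_lse
```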
Co-authored-by: Claude
Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
Address review feedback (tjtanaa, vllm-project#40889): on ROCm, DeepSeek sparse attention can only run through AITER, so gating the op with VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE adds no value.
- Remove is_aiter_dsv4_decode_enabled() and the env-var lookup from vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py.
- Simplify the ROCm branch in DeepseekV4MLAAttention._forward_decode to dispatch unconditionally to _forward_decode_aiter.
- Drop the now-unused os import and the env-var mention in the _forward_decode_aiter docstring.
Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
Hoist all per-step allocations on the AITER sparse-MLA decode path into
the existing `AiterSparseScratch` cache so cudagraph capture sees stable
memory layouts. Previously each layer's call site freshly allocated
`qo_indptr`, `kv_indptr`, `kv_indices`, `kv_last_page_lens`, `q_scale`,
`kv_scale`, `q_fp8`, the bf16 output buffer, and intermediate boolean
masks every step, which was incompatible with HIP-graph capture and
generated unnecessary allocator pressure with 61 DSv4 attention layers.
Changes
-------
* `AiterSparseScratch` now caches:
* Static buffers: `qo_indptr` (arange), `kv_last_page_lens` (ones),
`col_arange`, `q_scale`, `kv_scale`.
* Per-step write buffers: `kv_indptr`, `kv_indices_2d`, `valid_mask`,
`valid_lens`, `q_fp8`, `out_buf`.
* `rebuild()` allocates and (where applicable) initialises every buffer
once per `(total_q, h_q, topk, d_qk, d_v, dtype, kvtype)` key and runs
`aiter.get_mla_metadata_v1` against the persistent qo/kv/last-page
tensors.
* `_aiter_decode_one_scope` rewrites the per-step buffers in-place via
`torch.lt(out=)`, `tensor.copy_`, `masked_fill_`, and
`torch.cumsum(out=)` instead of fresh allocations.
* The public `aiter_sparse_attn_decode` signature is unchanged, so
`DeepseekV4MLAAttention._forward_decode_aiter` keeps working as-is.
Follow-up to PR vllm-project#40889; remaining cudagraph blocker is the per-step
`blocked_k.to(fp8_e4m3fn)` cast on the dequantised KV cache, which is
tracked separately together with the dequantise-into-FP8 fast path
suggested in code review.
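As a rough illustration of the in-place rewrite pattern listed under Changes, a minimal sketch with hypothetical buffer names and shapes (the real `AiterSparseScratch` holds more state and also owns the AITER work-plan buffers):

```python
import torch

class ScratchSketch:
    """Pre-allocated per-step buffers, rewritten in place so data_ptr() never changes."""

    def __init__(self, batch: int, topk: int, device: str = "cuda"):
        # Static buffer, initialised once.
        self.col_arange = torch.arange(topk, device=device).expand(batch, topk).contiguous()
        # Per-step write targets (allocated once, overwritten every step).
        self.valid_mask = torch.empty(batch, topk, dtype=torch.bool, device=device)
        self.valid_lens = torch.empty(batch, dtype=torch.int64, device=device)
        self.kv_indptr = torch.zeros(batch + 1, dtype=torch.int64, device=device)
        self.kv_indices_2d = torch.empty(batch, topk, dtype=torch.int64, device=device)

    def write_step(self, topk_indices: torch.Tensor, kv_lens: torch.Tensor) -> None:
        # valid_mask[b, j] = (j < kv_lens[b]), written into the cached buffer.
        torch.lt(self.col_arange, kv_lens.unsqueeze(1), out=self.valid_mask)
        # Number of valid entries per row, again written in place.
        torch.sum(self.valid_mask, dim=1, out=self.valid_lens)
        # kv_indptr[1:] = running total of valid lengths; kv_indptr[0] stays 0.
        torch.cumsum(self.valid_lens, dim=0, out=self.kv_indptr[1:])
        # Copy this step's top-k indices and pad the invalid tail with -1 sentinels.
        self.kv_indices_2d.copy_(topk_indices)
        self.kv_indices_2d.masked_fill_(~self.valid_mask, -1)
```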
Test plan
---------
* `python -c "import ast; ast.parse(open('vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py').read())"` passes.
* On MI355X with `chuali_glm51` container (rocm/vllm-dev:nightly +
torch 2.10/HIP 7.2): manual smoke test pending; will follow up with
cudagraph-mode (non-eager) startup log + decode parity vs eager.
AI assistance: drafted with Cursor agent; human-reviewed before
submission.
Signed-off-by: ChuanLi1101 <chuanli1101@gmail.com>
`aiter.get_mla_metadata_v1` produces a `work_*`/`reduce_*` plan that is
keyed on the *actual* per-batch kv lengths, not just on shapes. The
persistent ASM `mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps` kernel reads
out of bounds (causing a GPU memory access fault) if those buffers are
left stale across steps with different kv lengths.
Fix the cudagraph-clean refactor so the metadata is rewritten in-place
on every per-step call against the current `kv_indptr`. The buffer
sizes returned by `get_mla_metadata_info_v1` are determined by shapes
+ `max_split_per_batch` only, so they remain large enough for any kv
length distribution and the data pointers stay stable for graph capture.
* `AiterSparseScratch.rebuild()` now only allocates buffers and stores
the static gqa/topk/dtype parameters; it no longer requires a
`kv_indptr_seed` and no longer runs the metadata builder itself.
* New `AiterSparseScratch.refresh_metadata()` reruns
`get_mla_metadata_v1` writing into the same `work_*`/`reduce_*` slots.
* `_aiter_decode_one_scope` writes `valid_mask`/`valid_lens`/
`kv_indptr`/`kv_indices_2d`/`q_fp8` directly into scratch every
step, then calls `refresh_metadata()` and `mla.mla_decode_fwd`.
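A schematic of the rebuild-once / refresh-every-step split described above; the `aiter.get_mla_metadata_v1` call is left as a comment because its exact argument list is not reproduced here, and all buffer names and sizes are illustrative:

```python
import torch

class MetadataScratchSketch:
    """rebuild() allocates once; refresh_metadata() rewrites the plan in place."""

    def rebuild(self, batch: int, work_bytes: int, reduce_bytes: int,
                device: str = "cuda") -> None:
        # Buffer sizes depend only on shapes + max_split_per_batch (per
        # get_mla_metadata_info_v1), so a single allocation covers every step.
        self.work_buf = torch.empty(work_bytes, dtype=torch.uint8, device=device)
        self.reduce_buf = torch.empty(reduce_bytes, dtype=torch.uint8, device=device)
        self.kv_indptr = torch.zeros(batch + 1, dtype=torch.int32, device=device)

    def refresh_metadata(self) -> None:
        # Re-run the metadata builder against the *current* kv_indptr every step,
        # writing the new work plan into the same pre-allocated slots so every
        # data_ptr() stays stable for graph capture. The exact arguments of
        # aiter.get_mla_metadata_v1 are deliberately not reproduced here:
        #   aiter.get_mla_metadata_v1(<current kv_indptr + static params>,
        #                             <pre-allocated work_* / reduce_* buffers>)
        pass
```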
Validated with the standalone `bench_remote/_unit_test_cudagraph.py`
harness on MI355X:
- Call 1 (lens=[3,2]): success, scratch key set.
- Call 2 (same lens): rebuild skipped, all data_ptrs stable, output
bit-identical to call 1.
- Call 3 (lens=[4,1]): all data_ptrs still stable, output differs as
expected (max abs diff = 2.39 vs identical-input call), no fault.
- Parity check vs the original non-cudagraph implementation:
max abs diff = 0.000000.
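The kind of checks that harness performs can be sketched as below; the callable and tensor names are hypothetical, and the real harness is bench_remote/_unit_test_cudagraph.py:

```python
def check_stability_and_parity(run_step, ref_step, tracked, lens_schedule):
    """Replay decode steps and assert pointer stability plus numerical parity.

    run_step / ref_step: callables taking a per-batch lens list (hypothetical).
    tracked: the scratch tensors whose data_ptr() must never move.
    """
    baseline = [t.data_ptr() for t in tracked]
    for lens in lens_schedule:
        out = run_step(lens)
        ref = ref_step(lens)
        assert [t.data_ptr() for t in tracked] == baseline, "scratch pointer moved"
        print(f"lens={lens}: max abs diff vs reference = "
              f"{(out - ref).abs().max().item():.6f}")
```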
Signed-off-by: Chuan Li <chuanli1101@gmail.com>
Co-authored-by: Cursor
Signed-off-by: Li <chuali@amd.com>
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request implements support for the DeepseekV4 model family, introducing Multi-Head Latent Attention (MLA) and Multi-Token Predictor (MTP) capabilities. The changes include several optimized CUDA and Triton kernels for fused operations such as RMSNorm, RoPE, and quantization, alongside hypercompressed (mHC) blocks implemented via TileLang. Review feedback identifies a critical performance bottleneck where the entire SWA cache is dequantized every decode step, which should be optimized to handle dequantization on-the-fly. Additionally, the implementation of _dequantize_blocked_k_cache is noted to be incompatible with CUDA graphs due to per-call tensor allocations. A potential runtime error due to a shape mismatch in the attention output copy operation was also flagged, with a recommendation to squeeze the sequence dimension.
blocked_swa = self._dequantize_blocked_k_cache(
    self.swa_cache_layer.kv_cache)
Dequantizing the entire SWA cache every decode step is extremely inefficient and will cause significant performance degradation as the KV cache grows. In vLLM V1, kv_cache typically represents the entire physical block pool. Dequantizing thousands of blocks (e.g., 64MB per layer for 1000 blocks) on every step will likely exceed the latency budget for real-time inference. The AITER kernel should ideally handle dequantization on-the-fly using the provided indices, or at least only dequantize the active sliding window region.
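One hedged sketch of the "dequantize only what is read" alternative the comment suggests, gathering just the referenced blocks; the helper name, argument shapes, and per-block scale layout are assumptions, not the PR's code:

```python
import torch

def dequantize_referenced_blocks(quant_k_cache: torch.Tensor,
                                 block_ids: torch.Tensor,
                                 scales: torch.Tensor) -> torch.Tensor:
    """Dequantize only the KV-cache blocks a decode step actually reads.

    quant_k_cache: [num_blocks, block_size, 1, d]  quantized cache (e.g. fp8)
    block_ids:     [num_active_blocks]             unique block ids used this step
    scales:        [num_blocks]                    per-block dequantization scales
    """
    active = quant_k_cache.index_select(0, block_ids).to(torch.bfloat16)
    active *= scales.index_select(0, block_ids).view(-1, 1, 1, 1)
    return active  # [num_active_blocks, block_size, 1, d]
```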
    scratch=self._aiter_scratch,
    extra_scratch=self._aiter_extra_scratch,
)
output.copy_(attn_out.to(output.dtype))
Potential shape mismatch during copy_. aiter_sparse_attn_decode is called with a 4D query q.unsqueeze(1) (shape [num_tokens, 1, num_heads, head_dim]), so it likely returns a 4D tensor. However, output is a 3D tensor (shape [num_tokens, num_heads, head_dim]). This will cause a RuntimeError during the copy_ operation. You should explicitly squeeze the sequence dimension from the attention output.
- output.copy_(attn_out.to(output.dtype))
+ output.copy_(attn_out.squeeze(1).to(output.dtype))
result = torch.empty(
    (num_blocks, block_size, 1, d),
    dtype=torch.bfloat16,
    device=quant_k_cache.device,
)
This allocation of the result tensor on every call is not cudagraph-clean and contradicts the PR's objective of making the decode path allocation-free. Since the number of blocks in the KV cache is fixed after initialization in vLLM V1, this buffer should be pre-allocated (e.g., as part of AiterSparseScratch or as a persistent buffer in the layer) to ensure pointer stability and avoid runtime overhead during graph replay.
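A minimal sketch of the suggested pre-allocation, assuming the block-pool shape is known at initialization (class and method names are hypothetical):

```python
import torch

class DequantScratch:
    """Allocate the bf16 dequantization target once; reuse it every decode step."""

    def __init__(self, num_blocks: int, block_size: int, d: int, device: str = "cuda"):
        # num_blocks is fixed after KV-cache initialization in vLLM V1, so this
        # buffer (and therefore its data_ptr) never changes across steps.
        self.result = torch.empty(
            (num_blocks, block_size, 1, d),
            dtype=torch.bfloat16,
            device=device,
        )

    def dequantize_into(self, quant_k_cache: torch.Tensor) -> torch.Tensor:
        # Overwrite the persistent buffer in place; a fused kernel could also
        # avoid the temporary produced by .to() here.
        self.result.copy_(quant_k_cache.to(torch.bfloat16))
        return self.result
```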
Fail on 8 x MI325X
Root cause of MI355X gibberish (vllm-project#40889 / vllm-project#40892 baseline):
DeepseekCompressor._forward_old (the ROCm path) rotates the compressed K with the
LAST token's position of each compressed chunk (e.g. positions [3, 7, 11, ...] for
compress_ratio=4), while the NV fused kernels
(_fused_kv_compress_norm_rope_insert_indexer_attn and
_fused_kv_compress_norm_rope_insert_sparse_attn) rotate with the FIRST token's
position of the chunk:
    compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO
This produces a (compress_ratio - 1)-token RoPE phase offset on every cached K
versus the model's training-time convention. Q is rotated at its true token
position (matching NV), so each q.k inner product is rotated by ~(CR-1) extra
positions, which breaks the relative-position semantics of RoPE, corrupts both
rocm_fp8_mqa_logits (the top-k indexer) and the sparse-MLA decode, and produces
incoherent first-token output on plain prompts.
Fix: align the K-cache RoPE position with the NV reference:
    rope_positions = (compressed_positions // CR) * CR
Symptom on MI355X (TP=8, bf16, /data/DeepSeek-V4-Pro):
- Before: LLM().generate("What is 2+2? Answer:") -> " 1 1 1 1 1 1 ..." / " (2 2-2 2 2 2 ..."
- Expected after fix: LLM().generate("What is 2+2? Answer:") -> "4" / "Four" / etc.
Affected paths:
- head_dim == 512 (sparse-attn K cache, used by AITER/ref decode)
- head_dim == 128 (indexer K cache, used by rocm_fp8_mqa_logits)
Both were rotated with the wrong position; the single-line change covers both
because the bug is in the shared apply_gptj_rope_ref call before the head_dim split.
Test plan:
1. wsl -> ssh smci355-ccs-aus-m13-05.cs-aus.dcgpu, container chuali_glm51
2. Apply this commit to /workspace/vllm_dsv4
3. Run bench_remote/_simple_smoke.py (TP=8, eager, bf16)
4. Expect coherent text on prompts: "What is 2+2? Answer:", "The capital of France is", "Q: A robe takes 2 bolts of blue fiber and half that much white...", "Once upon a time,"
5. If coherent, follow up with lm_eval gsm8k to confirm.
Signed-off-by: Chuan Li <chuali@amd.com>
Made-with: Cursor
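To make the position convention in the fix concrete, a tiny standalone arithmetic example for compress_ratio=4 (illustrative only, not the patched code):

```python
import torch

CR = 4  # compress_ratio
positions = torch.arange(12)                        # token positions 0..11
first_token_pos = (positions // CR) * CR            # NV reference / fixed ROCm path
last_token_pos = (positions // CR) * CR + (CR - 1)  # old ROCm behaviour per chunk

print(first_token_pos.unique())  # tensor([0, 4, 8])   -> K rotated at chunk starts
print(last_token_pos.unique())   # tensor([ 3,  7, 11]) -> (CR-1)-token phase offset
```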
Routes the per-layer-sized intermediates inside the AITER sparse MLA decode
path through the existing `current_workspace_manager()` so all 61 DSv4
attention layers reuse the same bf16 + fp8 buffers per step instead of
each layer allocating two fresh ~kv-cache-sized tensors.
Concretely:
* `_dequantize_blocked_k_cache` accepts an optional `out=` bf16 buffer.
* `aiter_sparse_attn_decode` and `_aiter_decode_one_scope` accept
optional `kv_fp8_buf` / `extra_kv_fp8_buf` fp8 buffers and copy the
bf16->fp8 cast into them in place.
* `_forward_decode_aiter` (ROCm path) pulls the 2-or-4 buffers from
`current_workspace_manager().get_simultaneous(...)` so they share a
single workspace allocation, mirroring how prefill already does it.
Without this, every layer per step allocates two fresh per-kv-cache-sized
tensors that go into the cudagraph memory pool, multiplying that pool by
~60x worth of redundant slots on a 61-layer DSv4 model. The buffer sizes
depend only on static kv-cache shape (num_blocks, block_size, head_dim),
so the workspace reaches its max during warmup and stays stable through
capture and `lock_workspace()`.
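As a generic illustration of the shared-workspace idea (this is not vLLM's actual `current_workspace_manager()` / `get_simultaneous()` implementation, only a sketch of why one grow-once allocation can back every layer's intermediates):

```python
import torch

class SharedWorkspaceSketch:
    """Illustrative stand-in for a shared per-step workspace.

    Not vLLM's actual workspace-manager API; it only shows how a single
    grow-once byte buffer can hand out views for the bf16/fp8 intermediates of
    every layer instead of each layer allocating its own kv-cache-sized tensors.
    """

    ALIGN = 256  # byte alignment so the dtype reinterpretation below stays valid

    def __init__(self, device: str = "cuda"):
        self._buf = torch.empty(0, dtype=torch.uint8, device=device)

    def get_simultaneous(self, *requests):
        """requests: (shape, dtype) pairs that must coexist within one step."""
        plan, offset = [], 0
        for shape, dtype in requests:
            nbytes = torch.Size(shape).numel() * torch.empty((), dtype=dtype).element_size()
            plan.append((offset, nbytes, shape, dtype))
            offset += -(-nbytes // self.ALIGN) * self.ALIGN  # round up to ALIGN
        if self._buf.numel() < offset:
            # Grow once during warmup; afterwards data_ptr() stays stable.
            self._buf = torch.empty(offset, dtype=torch.uint8, device=self._buf.device)
        return [self._buf[o:o + n].view(dtype).view(shape)
                for o, n, shape, dtype in plan]
```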
Validated on MI355X with a standalone microbench:
* bit-exact parity with the un-buffered path (`max abs diff = 0.0`)
* `kv_fp8_buf.data_ptr()` stable across 61 simulated "layer" calls
* pointer stable across varying per-step `lens` patterns
* shape / dtype mismatch raises as expected
Stacks on top of vllm-project#40892 (cudagraph-clean AITER decode).
Co-authored-by: Claude
Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
Purpose
Cudagraph-clean follow-up to #40889. The persistent-mode AITER sparse MLA decode kernel (mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps) is the only LSE-returning sparse attention path on MI355X / gfx950, but the original integration allocates fresh per-step indexing tensors (qo_indptr, kv_indptr, kv_indices, kv_last_page_lens, q_scale, kv_scale) and a fresh FP8 query / output buffer on every layer call. That is incompatible with HIP/CUDA-graph capture, so DSv4 currently has to run in eager mode on ROCm.
This PR makes the per-step decode path completely allocation-free and pointer-stable so the decode loop can be wrapped in a HIP graph once the model wires it up.
Changes
- AiterSparseScratch now owns every per-step buffer in addition to the AITER work-plan / reduce buffers:
  - qo_indptr, kv_indptr, kv_indices_2d, kv_last_page_lens
  - valid_mask, valid_lens, col_arange
  - q_fp8 (FP8 query buffer), out_buf (BF16 output)
  - q_scale, kv_scale (constant 1.0 tensors)
- _aiter_decode_one_scope rewrites all of those in-place every step (torch.lt(out=), copy_, masked_fill_, torch.cumsum(out=)).
- AiterSparseScratch.refresh_metadata() re-runs aiter.get_mla_metadata_v1 against the current kv_indptr every step and writes the new work plan into the same work_* / reduce_* slots. The persistent ASM kernel encodes per-batch lengths into that plan, so leaving it stale across steps with different kv lengths causes a GPU memory-access fault.
- rebuild() now only allocates buffers and stores static gqa/topk/dtype parameters; it no longer re-runs the metadata builder itself.
Net effect: across decode steps with the same (total_q, nhead, topk, d_qk, d_v, dtype, kvtype) key, every data_ptr() is stable, so a graph captured on step N can be replayed for any step N+k.
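As a generic illustration of why that pointer stability matters, a minimal torch.cuda.CUDAGraph capture/replay pattern over static buffers (standard PyTorch API; the actual DSv4 graph wiring is left to follow-up work as stated above):

```python
import torch

def build_decode_graph(step_fn, static_inputs):
    """Capture step_fn once over pre-allocated tensors, then replay it each step.

    step_fn must write its result into a pre-allocated output buffer and touch
    only tensors whose data_ptr() never changes -- the property the scratch
    refactor above establishes for the AITER decode path.
    """
    # Warm up on a side stream so lazy allocations happen before capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        step_fn(*static_inputs)
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        step_fn(*static_inputs)  # captured work; no new allocations allowed here
    return graph

# Per decode step: copy fresh values into the static input tensors, call
# graph.replay(), then read the result from the pre-allocated output buffer.
```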
Testing
bench_remote/_unit_test_cudagraph.py (run inside chuali_glm51 on MI355X, GPU 5):
- Call 1, lens=[3, 2] → success, scratch key set.
- Call 2, same lens → rebuild() skipped, all 9 tracked data_ptr()s identical to call 1, output bit-identical.
- Call 3, lens=[4, 1] → all 9 data_ptr()s still identical, no fault, output differs by 2.39 max abs as expected because the valid kv set is different.
- Parity vs the original non-cudagraph implementation: max abs diff = 0.000000.
AI-assisted contribution
This PR was prepared with AI assistance (Cursor agent). The submitter reviewed every changed line and ran the unit-test harness above on MI355X.
Stacking
Depends on #40889 (AITER-accelerated MLA decode for DSv4). All four commits in this PR's range will reduce to two cudagraph-only commits once #40889 lands; happy to rebase on request.