Skip to content

[ROCm][DSv4] Make AITER sparse MLA decode cudagraph-clean (follow-up to #40889)#40892

Draft
ChuanLi1101 wants to merge 16 commits intovllm-project:mainfrom
ChuanLi1101:rocm/aiter-mla-dsv4-decode-cudagraph
Draft

[ROCm][DSv4] Make AITER sparse MLA decode cudagraph-clean (follow-up to #40889)#40892
ChuanLi1101 wants to merge 16 commits intovllm-project:mainfrom
ChuanLi1101:rocm/aiter-mla-dsv4-decode-cudagraph

Conversation

@ChuanLi1101
Copy link
Copy Markdown
Collaborator

Purpose

Cudagraph-clean follow-up to #40889. The persistent-mode AITER sparse MLA decode kernel (mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps) is the only LSE-returning sparse attention path on MI355X / gfx950, but the original integration allocates fresh per-step indexing tensors (qo_indptr, kv_indptr, kv_indices, kv_last_page_lens, q_scale, kv_scale) and a fresh FP8 query / output buffer on every layer call. That is incompatible with HIP/CUDA-graph capture, so DSv4 currently has to run in eager mode on ROCm.

This PR makes the per-step decode path completely allocation-free and pointer-stable so the decode loop can be wrapped in a HIP graph once the model wires it up.

Changes

  • AiterSparseScratch now owns every per-step buffer in addition to the AITER work-plan / reduce buffers:
  • qo_indptr, kv_indptr, kv_indices_2d, kv_last_page_lens
  • valid_mask, valid_lens, col_arange
  • q_fp8 (FP8 query buffer), out_buf (BF16 output)
  • q_scale, kv_scale (constant 1.0 tensors)
  • _aiter_decode_one_scope rewrites all of those in-place every step (torch.lt(out=), copy_, masked_fill_, torch.cumsum(out=)).
  • AiterSparseScratch.refresh_metadata() re-runs aiter.get_mla_metadata_v1 against the current kv_indptr every step and writes the new work plan into the same work_* / reduce_* slots. The persistent ASM kernel encodes per-batch lengths into that plan, so leaving it stale across steps with different kv lengths causes a GPU memory-access fault.
  • rebuild() now only allocates buffers and stores static gqa/topk/dtype parameters; it no longer re-runs the metadata builder itself.

Net effect: across decode steps with the same (total_q, nhead, topk, d_qk, d_v, dtype, kvtype) key, every data_ptr() is stable, so a graph captured on step N can be replayed for any step N+k.

Testing

bench_remote/_unit_test_cudagraph.py (run inside chuali_glm51 on MI355X, GPU 5):

  • Call 1 with lens=[3, 2] → success, scratch key set.
  • Call 2 with the same lensrebuild() skipped, all 9 tracked data_ptr()s identical to call 1, output bit-identical.
  • Call 3 with lens=[4, 1] → all 9 data_ptr()s still identical, no fault, output differs by 2.39 max abs as expected because the valid kv set is different.
  • Parity vs. the original ([ROCm] Add AITER-accelerated MLA decode for DeepSeek V4 on MI355X #40889) implementation on the same inputs: max abs diff = 0.000000.

AI-assisted contribution

This PR was prepared with AI assistance (Cursor agent). The submitter reviewed every changed line and ran the unit-test harness above on MI355X.

Stacking

Depends on #40889 (AITER-accelerated MLA decode for DSv4). All four commits in this PR's range will reduce to two cudagraph-only commits once #40889 lands; happy to rebase on request.

zyongye and others added 16 commits April 24, 2026 02:58
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: yasong.wang <yasong.wang@inferact.ai>
Signed-off-by: Zhewen Li <zhewenli@inferact.ai>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: ganyi <ygan@amd.com>
Made-with: Cursor
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Replace the PyTorch reference sparse MLA decode with AITER's
persistent-mode ASM kernel (aiter.mla.mla_decode_fwd) on gfx950.
This gives ~2-3x decode speedup at high batch sizes.

Key changes:
- New module: vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py
  - AiterSparseScratch: lazy-init persistent-mode metadata buffers,
    keyed by (batch, nhead, topk, dtype) so 61 layers share one
    allocation per decode step
  - aiter_sparse_attn_decode: drop-in replacement handling dual-scope
    attention (SWA + extra), LSE-based merging, and attn_sink correction
  - Uses FP8/FP8 path only (gfx950 persistent-mode + return_lse
    requires FP8)
  - Fixed-stride kv_indices layout with -1 sentinels (required by
    AITER persistent-mode kernels)

- deepseek_v4_attention.py:
  - Add _aiter_scratch / _aiter_extra_scratch fields to __init__
  - Gate ROCm decode path: VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE=1
    routes to _forward_decode_aiter, otherwise falls back to the
    existing PyTorch reference

- Fix missing RoutingMethodType import in fused_moe/oracle/mxfp4.py

Validated numerically (cosine > 0.999) across TP2/TP4/TP8 configs
on MI355X. Micro-benchmarked at 2.4x speedup (b=128, dual-scope).

Co-authored-by: Claude
Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
Address review feedback (tjtanaa, vllm-project#40889): on ROCm,
DeepSeek sparse attention can only run through AITER, so gating the
op with VLLM_ROCM_USE_AITER_MLA_DSV4_DECODE adds no value.

- Remove is_aiter_dsv4_decode_enabled() and the env-var lookup from
  vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py.
- Simplify the ROCm branch in DeepseekV4MLAAttention._forward_decode
  to dispatch unconditionally to _forward_decode_aiter.
- Drop the now-unused os import and the env-var mention in the
  _forward_decode_aiter docstring.

Signed-off-by: Chuan Li <chuan.li@amd.com>
Made-with: Cursor
Hoist all per-step allocations on the AITER sparse-MLA decode path into
the existing `AiterSparseScratch` cache so cudagraph capture sees stable
memory layouts. Previously each layer's call site freshly allocated
`qo_indptr`, `kv_indptr`, `kv_indices`, `kv_last_page_lens`, `q_scale`,
`kv_scale`, `q_fp8`, the bf16 output buffer, and intermediate boolean
masks every step, which was incompatible with HIP-graph capture and
generated unnecessary allocator pressure with 61 DSv4 attention layers.

Changes
-------

* `AiterSparseScratch` now caches:
  * Static buffers: `qo_indptr` (arange), `kv_last_page_lens` (ones),
    `col_arange`, `q_scale`, `kv_scale`.
  * Per-step write buffers: `kv_indptr`, `kv_indices_2d`, `valid_mask`,
    `valid_lens`, `q_fp8`, `out_buf`.
* `rebuild()` allocates and (where applicable) initialises every buffer
  once per `(total_q, h_q, topk, d_qk, d_v, dtype, kvtype)` key and runs
  `aiter.get_mla_metadata_v1` against the persistent qo/kv/last-page
  tensors.
* `_aiter_decode_one_scope` rewrites the per-step buffers in-place via
  `torch.lt(out=)`, `tensor.copy_`, `masked_fill_`, and
  `torch.cumsum(out=)` instead of fresh allocations.
* The public `aiter_sparse_attn_decode` signature is unchanged, so
  `DeepseekV4MLAAttention._forward_decode_aiter` keeps working as-is.

Follow-up to PR vllm-project#40889; remaining cudagraph blocker is the per-step
`blocked_k.to(fp8_e4m3fn)` cast on the dequantised KV cache, which is
tracked separately together with the dequantise-into-FP8 fast path
suggested in code review.

Test plan
---------

* `python -c "import ast; ast.parse(open('vllm/v1/attention/ops/rocm_aiter_dsv4_decode.py').read())"` passes.
* On MI355X with `chuali_glm51` container (rocm/vllm-dev:nightly +
  torch 2.10/HIP 7.2): manual smoke test pending; will follow up with
  cudagraph-mode (non-eager) startup log + decode parity vs eager.

AI assistance: drafted with Cursor agent; human-reviewed before
submission.

Signed-off-by: ChuanLi1101 <chuanli1101@gmail.com>
`aiter.get_mla_metadata_v1` produces a `work_*`/`reduce_*` plan that is
keyed on the *actual* per-batch kv lengths, not just on shapes. The
persistent ASM `mla_a8w8_qh16_qseqlen1_gqaratio16_lse_ps` kernel reads
out of bounds (causing a GPU memory access fault) if those buffers are
left stale across steps with different kv lengths.

Fix the cudagraph-clean refactor so the metadata is rewritten in-place
on every per-step call against the current `kv_indptr`. The buffer
sizes returned by `get_mla_metadata_info_v1` are determined by shapes
+ `max_split_per_batch` only, so they remain large enough for any kv
length distribution and the data pointers stay stable for graph capture.

  * `AiterSparseScratch.rebuild()` now only allocates buffers and stores
    the static gqa/topk/dtype parameters; it no longer requires a
    `kv_indptr_seed` and no longer runs the metadata builder itself.
  * New `AiterSparseScratch.refresh_metadata()` reruns
    `get_mla_metadata_v1` writing into the same `work_*`/`reduce_*` slots.
  * `_aiter_decode_one_scope` writes `valid_mask`/`valid_lens`/
    `kv_indptr`/`kv_indices_2d`/`q_fp8` directly into scratch every
    step, then calls `refresh_metadata()` and `mla.mla_decode_fwd`.

Validated with the standalone `bench_remote/_unit_test_cudagraph.py`
harness on MI355X:
  - Call 1 (lens=[3,2]): success, scratch key set.
  - Call 2 (same lens): rebuild skipped, all data_ptrs stable, output
    bit-identical to call 1.
  - Call 3 (lens=[4,1]): all data_ptrs still stable, output differs as
    expected (max abs diff = 2.39 vs identical-input call), no fault.
  - Parity check vs the original non-cudagraph implementation:
    max abs diff = 0.000000.

Signed-off-by: Chuan Li <chuanli1101@gmail.com>
Co-authored-by: Cursor
Signed-off-by: Li <chuali@amd.com>
@mergify mergify Bot added performance Performance-related issues gpt-oss Related to GPT-OSS models nvidia rocm Related to AMD ROCm speculative-decoding labels Apr 26, 2026
@mergify mergify Bot added the v1 label Apr 26, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Apr 26, 2026
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Apr 26, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ChuanLi1101.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements support for the DeepseekV4 model family, introducing Multi-Head Latent Attention (MLA) and Multi-Token Predictor (MTP) capabilities. The changes include several optimized CUDA and Triton kernels for fused operations such as RMSNorm, RoPE, and quantization, alongside hypercompressed (mHC) blocks implemented via TileLang. Review feedback identifies a critical performance bottleneck where the entire SWA cache is dequantized every decode step, which should be optimized to handle dequantization on-the-fly. Additionally, the implementation of _dequantize_blocked_k_cache is noted to be incompatible with CUDA graphs due to per-call tensor allocations. A potential runtime error due to a shape mismatch in the attention output copy operation was also flagged, with a recommendation to squeeze the sequence dimension.

Comment on lines +1152 to +1153
blocked_swa = self._dequantize_blocked_k_cache(
self.swa_cache_layer.kv_cache)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Dequantizing the entire SWA cache every decode step is extremely inefficient and will cause significant performance degradation as the KV cache grows. In vLLM V1, kv_cache typically represents the entire physical block pool. Dequantizing thousands of blocks (e.g., 64MB per layer for 1000 blocks) on every step will likely exceed the latency budget for real-time inference. The AITER kernel should ideally handle dequantization on-the-fly using the provided indices, or at least only dequantize the active sliding window region.

scratch=self._aiter_scratch,
extra_scratch=self._aiter_extra_scratch,
)
output.copy_(attn_out.to(output.dtype))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Potential shape mismatch during copy_. aiter_sparse_attn_decode is called with a 4D query q.unsqueeze(1) (shape [num_tokens, 1, num_heads, head_dim]), so it likely returns a 4D tensor. However, output is a 3D tensor (shape [num_tokens, num_heads, head_dim]). This will cause a RuntimeError during the copy_ operation. You should explicitly squeeze the sequence dimension from the attention output.

Suggested change
output.copy_(attn_out.to(output.dtype))
output.copy_(attn_out.squeeze(1).to(output.dtype))

Comment on lines +1222 to +1226
result = torch.empty(
(num_blocks, block_size, 1, d),
dtype=torch.bfloat16,
device=quant_k_cache.device,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This allocation of the result tensor on every call is not cudagraph-clean and contradicts the PR's objective of making the decode path allocation-free. Since the number of blocks in the KV cache is fixed after initialization in vLLM V1, this buffer should be pre-allocated (e.g., as part of AiterSparseScratch or as a persistent buffer in the layer) to ensure pointer stability and avoid runtime overhead during graph replay.

@Morxi
Copy link
Copy Markdown

Morxi commented Apr 26, 2026

Fail on 8 x MI325x

DOCKER_BUILDKIT=1 docker build \
    -f docker/Dockerfile.rocm \
    -t vllm-rocm-custom:latest \
    --build-arg ROCM_VERSION=7.2.2 \
    --build-arg PYTHON_VERSION=3.10 \
    .
docker run -it --rm \
    --name vllm-deepseek-v4-1m \
    --group-add=video \
    --cap-add=SYS_PTRACE \
    --security-opt seccomp=unconfined \
    --device /dev/kfd \
    --device /dev/dri \
    --ipc=host \
    --shm-size 32gb \
    -p 8422:8000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    vllm-rocm-custom:latest \
    --model deepseek-ai/DeepSeek-V4-Pro \
        --host 0.0.0.0 \
        --port 8000 \
        --dtype auto \
        --tensor-parallel-size 8\
        --trust-remote-code
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 647, in <ge
nexpr>
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]     layer_fn(prefix=f"{prefix}.{idx}") for idx in range(start_layer, end_layer)
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v4.py", line 638, 
in <lambda>
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]     lambda prefix: DeepseekV4DecoderLayer(
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]                    ^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v4.py", line 489, 
in __init__
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]     self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn")
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v4.py", line 235, 
in __init__
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]     self.experts = SharedFusedMoE(
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]                    ^^^^^^^^^^^^^^^
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/layer.py", line 5
23, in __init__
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]     self.quant_method: FusedMoEMethodBase = _get_quant_method()
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]                                             ^^^^^^^^^^^^^^^^^^^
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/layer.py", line 5
15, in _get_quant_method
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]     quant_method = self.quant_config.get_quant_method(self, prefix)
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v4.py", line 158, 
in get_quant_method
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]     return Mxfp4MoEMethod(layer.moe_config)
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization/mxfp4.py", lin
e 471, in __init__
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]     self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe)
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py",
 line 503, in select_mxfp4_moe_backend
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879]     raise NotImplementedError(
(Worker_TP7 pid=820) ERROR 04-26 02:28:10 [multiproc_executor.py:879] NotImplementedError: No MXFP4 MoE backend supports the deployment configuration.

@ChuanLi1101 ChuanLi1101 marked this pull request as draft April 26, 2026 07:34
@ChuanLi1101
Copy link
Copy Markdown
Collaborator Author

Moving to draft, blocked on #40889 (and on validating MI355X baseline generation quality). See #40889 for context. Will flip back to ready once #40889 is ready and lm_eval is re-run.

ChuanLi1101 added a commit to ChuanLi1101/vllm that referenced this pull request Apr 26, 2026
Root cause of MI355X gibberish (vllm-project#40889 / vllm-project#40892 baseline):
DeepseekCompressor._forward_old (the ROCm path) rotates the
compressed K with the LAST token's position of each compressed
chunk (e.g. positions [3, 7, 11, ...] for compress_ratio=4),
while the NV fused kernels
(_fused_kv_compress_norm_rope_insert_indexer_attn and
_fused_kv_compress_norm_rope_insert_sparse_attn) rotate with the
FIRST token's position of the chunk:
    compressed_pos = (position // COMPRESS_RATIO) * COMPRESS_RATIO

This produces a (compress_ratio - 1)-token RoPE phase offset on
every cached K versus the model's training-time convention.
Q is rotated at its true token position (matching NV), so each
q.k inner product is rotated by ~(CR-1) extra positions, which
breaks the relative-position semantics of RoPE, corrupts both
rocm_fp8_mqa_logits (top-k indexer) and the sparse-MLA decode,
and produces incoherent first-token output on plain prompts.

Fix: align the K-cache RoPE position with the NV reference:
    rope_positions = (compressed_positions // CR) * CR

Symptom on MI355X (TP=8, bf16, /data/DeepSeek-V4-Pro):
  Before:
    LLM().generate("What is 2+2? Answer:")
       -> " 1 1 1 1 1 1 ..."  /  " (2 2-2 2 2 2 ..."
  Expected after fix:
    LLM().generate("What is 2+2? Answer:")
       -> "4" / "Four" / etc.

Affected paths:
  - head_dim == 512 (sparse-attn K cache, used by AITER/ref decode)
  - head_dim == 128 (indexer K cache, used by rocm_fp8_mqa_logits)
Both were rotated with the wrong position; the single-line change
covers both because the bug is in the shared apply_gptj_rope_ref
call before the head_dim split.

Test plan:
  1. wsl -> ssh smci355-ccs-aus-m13-05.cs-aus.dcgpu, container chuali_glm51
  2. Apply this commit to /workspace/vllm_dsv4
  3. Run bench_remote/_simple_smoke.py (TP=8, eager, bf16)
  4. Expect coherent text on prompts:
       "What is 2+2? Answer:"
       "The capital of France is"
       "Q: A robe takes 2 bolts of blue fiber and half that much white..."
       "Once upon a time,"
  5. If coherent, follow up with lm_eval gsm8k to confirm.

Signed-off-by: Chuan Li <chuali@amd.com>
Made-with: Cursor
ChuanLi1101 added a commit to ChuanLi1101/vllm that referenced this pull request Apr 27, 2026
Routes the per-layer-sized intermediates inside the AITER sparse MLA decode
path through the existing `current_workspace_manager()` so all 61 DSv4
attention layers reuse the same bf16 + fp8 buffers per step instead of
each layer allocating two fresh ~kv-cache-sized tensors.

Concretely:
  * `_dequantize_blocked_k_cache` accepts an optional `out=` bf16 buffer.
  * `aiter_sparse_attn_decode` and `_aiter_decode_one_scope` accept
    optional `kv_fp8_buf` / `extra_kv_fp8_buf` fp8 buffers and copy the
    bf16->fp8 cast into them in place.
  * `_forward_decode_aiter` (ROCm path) pulls the 2-or-4 buffers from
    `current_workspace_manager().get_simultaneous(...)` so they share a
    single workspace allocation, mirroring how prefill already does it.

Without this, every layer per step allocates two fresh per-kv-cache-sized
tensors that go into the cudagraph memory pool, multiplying that pool by
~60x worth of redundant slots on a 61-layer DSv4 model. The buffer sizes
depend only on static kv-cache shape (num_blocks, block_size, head_dim),
so the workspace reaches its max during warmup and stays stable through
capture and `lock_workspace()`.

Validated on MI355X with a standalone microbench:
  * bit-exact parity with the un-buffered path (`max abs diff = 0.0`)
  * `kv_fp8_buf.data_ptr()` stable across 61 simulated "layer" calls
  * pointer stable across varying per-step `lens` patterns
  * shape / dtype mismatch raises as expected

Stacks on top of vllm-project#40892 (cudagraph-clean AITER decode).

Co-authored-by: Claude
Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
ChuanLi1101 added a commit to ChuanLi1101/vllm that referenced this pull request Apr 28, 2026
Routes the per-layer-sized intermediates inside the AITER sparse MLA decode
path through the existing `current_workspace_manager()` so all 61 DSv4
attention layers reuse the same bf16 + fp8 buffers per step instead of
each layer allocating two fresh ~kv-cache-sized tensors.

Concretely:
  * `_dequantize_blocked_k_cache` accepts an optional `out=` bf16 buffer.
  * `aiter_sparse_attn_decode` and `_aiter_decode_one_scope` accept
    optional `kv_fp8_buf` / `extra_kv_fp8_buf` fp8 buffers and copy the
    bf16->fp8 cast into them in place.
  * `_forward_decode_aiter` (ROCm path) pulls the 2-or-4 buffers from
    `current_workspace_manager().get_simultaneous(...)` so they share a
    single workspace allocation, mirroring how prefill already does it.

Without this, every layer per step allocates two fresh per-kv-cache-sized
tensors that go into the cudagraph memory pool, multiplying that pool by
~60x worth of redundant slots on a 61-layer DSv4 model. The buffer sizes
depend only on static kv-cache shape (num_blocks, block_size, head_dim),
so the workspace reaches its max during warmup and stays stable through
capture and `lock_workspace()`.

Validated on MI355X with a standalone microbench:
  * bit-exact parity with the un-buffered path (`max abs diff = 0.0`)
  * `kv_fp8_buf.data_ptr()` stable across 61 simulated "layer" calls
  * pointer stable across varying per-step `lens` patterns
  * shape / dtype mismatch raises as expected

Stacks on top of vllm-project#40892 (cudagraph-clean AITER decode).

Co-authored-by: Claude
Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
ChuanLi1101 added a commit to ChuanLi1101/vllm that referenced this pull request Apr 28, 2026
Routes the per-layer-sized intermediates inside the AITER sparse MLA decode
path through the existing `current_workspace_manager()` so all 61 DSv4
attention layers reuse the same bf16 + fp8 buffers per step instead of
each layer allocating two fresh ~kv-cache-sized tensors.

Concretely:
  * `_dequantize_blocked_k_cache` accepts an optional `out=` bf16 buffer.
  * `aiter_sparse_attn_decode` and `_aiter_decode_one_scope` accept
    optional `kv_fp8_buf` / `extra_kv_fp8_buf` fp8 buffers and copy the
    bf16->fp8 cast into them in place.
  * `_forward_decode_aiter` (ROCm path) pulls the 2-or-4 buffers from
    `current_workspace_manager().get_simultaneous(...)` so they share a
    single workspace allocation, mirroring how prefill already does it.

Without this, every layer per step allocates two fresh per-kv-cache-sized
tensors that go into the cudagraph memory pool, multiplying that pool by
~60x worth of redundant slots on a 61-layer DSv4 model. The buffer sizes
depend only on static kv-cache shape (num_blocks, block_size, head_dim),
so the workspace reaches its max during warmup and stays stable through
capture and `lock_workspace()`.

Validated on MI355X with a standalone microbench:
  * bit-exact parity with the un-buffered path (`max abs diff = 0.0`)
  * `kv_fp8_buf.data_ptr()` stable across 61 simulated "layer" calls
  * pointer stable across varying per-step `lens` patterns
  * shape / dtype mismatch raises as expected

Stacks on top of vllm-project#40892 (cudagraph-clean AITER decode).

Co-authored-by: Claude
Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
ChuanLi1101 added a commit to ChuanLi1101/vllm that referenced this pull request Apr 28, 2026
Routes the per-layer-sized intermediates inside the AITER sparse MLA decode
path through the existing `current_workspace_manager()` so all 61 DSv4
attention layers reuse the same bf16 + fp8 buffers per step instead of
each layer allocating two fresh ~kv-cache-sized tensors.

Concretely:
  * `_dequantize_blocked_k_cache` accepts an optional `out=` bf16 buffer.
  * `aiter_sparse_attn_decode` and `_aiter_decode_one_scope` accept
    optional `kv_fp8_buf` / `extra_kv_fp8_buf` fp8 buffers and copy the
    bf16->fp8 cast into them in place.
  * `_forward_decode_aiter` (ROCm path) pulls the 2-or-4 buffers from
    `current_workspace_manager().get_simultaneous(...)` so they share a
    single workspace allocation, mirroring how prefill already does it.

Without this, every layer per step allocates two fresh per-kv-cache-sized
tensors that go into the cudagraph memory pool, multiplying that pool by
~60x worth of redundant slots on a 61-layer DSv4 model. The buffer sizes
depend only on static kv-cache shape (num_blocks, block_size, head_dim),
so the workspace reaches its max during warmup and stays stable through
capture and `lock_workspace()`.

Validated on MI355X with a standalone microbench:
  * bit-exact parity with the un-buffered path (`max abs diff = 0.0`)
  * `kv_fp8_buf.data_ptr()` stable across 61 simulated "layer" calls
  * pointer stable across varying per-step `lens` patterns
  * shape / dtype mismatch raises as expected

Stacks on top of vllm-project#40892 (cudagraph-clean AITER decode).

Co-authored-by: Claude
Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
ChuanLi1101 added a commit to ChuanLi1101/vllm that referenced this pull request Apr 28, 2026
Routes the per-layer-sized intermediates inside the AITER sparse MLA decode
path through the existing `current_workspace_manager()` so all 61 DSv4
attention layers reuse the same bf16 + fp8 buffers per step instead of
each layer allocating two fresh ~kv-cache-sized tensors.

Concretely:
  * `_dequantize_blocked_k_cache` accepts an optional `out=` bf16 buffer.
  * `aiter_sparse_attn_decode` and `_aiter_decode_one_scope` accept
    optional `kv_fp8_buf` / `extra_kv_fp8_buf` fp8 buffers and copy the
    bf16->fp8 cast into them in place.
  * `_forward_decode_aiter` (ROCm path) pulls the 2-or-4 buffers from
    `current_workspace_manager().get_simultaneous(...)` so they share a
    single workspace allocation, mirroring how prefill already does it.

Without this, every layer per step allocates two fresh per-kv-cache-sized
tensors that go into the cudagraph memory pool, multiplying that pool by
~60x worth of redundant slots on a 61-layer DSv4 model. The buffer sizes
depend only on static kv-cache shape (num_blocks, block_size, head_dim),
so the workspace reaches its max during warmup and stays stable through
capture and `lock_workspace()`.

Validated on MI355X with a standalone microbench:
  * bit-exact parity with the un-buffered path (`max abs diff = 0.0`)
  * `kv_fp8_buf.data_ptr()` stable across 61 simulated "layer" calls
  * pointer stable across varying per-step `lens` patterns
  * shape / dtype mismatch raises as expected

Stacks on top of vllm-project#40892 (cudagraph-clean AITER decode).

Co-authored-by: Claude
Signed-off-by: Chuan Li <chuanli@amd.com>
Made-with: Cursor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build deepseek Related to DeepSeek models documentation Improvements or additions to documentation gpt-oss Related to GPT-OSS models kv-connector needs-rebase new-model Requests to new models nvidia performance Performance-related issues rocm Related to AMD ROCm speculative-decoding tool-calling v1

Projects

Status: Todo
Status: No status
Status: No status
Status: To Triage

Development

Successfully merging this pull request may close these issues.

7 participants