[ROCm] Avoid full KV cache dequant in MLA decode fallback #42248
Bortlesboat wants to merge 1 commit into
Co-authored-by: OpenAI Codex <codex@openai.com> Signed-off-by: Bortlesboat <bortstheboat@gmail.com>
Code Review
This pull request introduces an optimization to the ROCm Aiter MLA sparse attention fallback by ensuring that only the cache blocks actually referenced are dequantized. This is implemented through a new helper function, _gather_referenced_cache_blocks, which identifies used blocks, creates a compact cache, and remaps the indices. A new unit test has also been added to verify this behavior. I have no feedback to provide.
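The gather-and-remap idea described in the review can be sketched in plain Python. This is only an illustration under stated assumptions, not the PR's code: the function name `gather_referenced_blocks_sketch`, the list-of-lists layout, and the use of `-1` as a padding marker are all hypothetical, and a real implementation would operate on tensors.

```python
def gather_referenced_blocks_sketch(cache_blocks, block_indices, pad=-1):
    """Keep only the cache blocks that are referenced and remap the indices.

    cache_blocks:  list of per-block payloads (stand-in for the KV cache).
    block_indices: list of rows of block ids; `pad` marks unused entries
                   (assumed convention, not confirmed by the source).
    """
    # 1. Identify the blocks that are actually referenced.
    used = sorted({b for row in block_indices for b in row if b != pad})
    # 2. Build the compact cache holding only those blocks.
    compact = [cache_blocks[b] for b in used]
    # 3. Remap old block ids to positions in the compact cache.
    remap = {old: new for new, old in enumerate(used)}
    remapped = [
        [remap[b] if b != pad else pad for b in row] for row in block_indices
    ]
    return compact, remapped


cache = ["blk0", "blk1", "blk2", "blk3", "blk4", "blk5"]
indices = [[4, 1, -1], [1, 5, -1]]
compact, remapped = gather_referenced_blocks_sketch(cache, indices)
```

Only the compact cache then needs to be dequantized, so the cost scales with the number of referenced blocks rather than with the size of the whole cache.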
…ache
For DSv4's C4A topk-driven prefill path, the dense
``dequantize_and_gather_k_cache`` materializes the full compressed prefix
per request into a BF16 ``kv`` buffer, but ``flash_mla_sparse_fwd`` then
reads only the topk-selected positions out of it, so the entries outside
the topk union are dequantized and discarded.
This commit adds a sparse variant that dequantizes exactly the global
slot ids the caller asks for, behind a ``has_cutedsl()`` dispatcher
matching the dense ``dequantize_and_gather_k_cache``:
dequantize_and_gather_k_cache_sparse(out, k_cache, slot_indices,
block_size)
* ``slot_indices: [N] int32`` — global slot ids produced by
``compute_global_topk_indices_and_lens``; ``-1`` entries are padding
(no load, no store).
* ``out: [N, head_dim] bf16`` — caller flattens any
``(num_queries × top_k)`` layout to N rows.
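The contract above can be illustrated with a slow reference loop in plain Python. This is a sketch under assumptions, not either of the kernels: the function name is hypothetical, a global slot id is assumed to decompose as `block * block_size + offset`, the cache is modeled as nested lists, and `dequant` is a placeholder for the real FP8-to-BF16 conversion with scale.

```python
def sparse_gather_reference(out, k_cache, slot_indices, block_size, dequant=float):
    """Write one dequantized row into `out` per non-padding slot id.

    out:          preallocated list of N rows (modified in place).
    k_cache:      k_cache[block][offset] is one row of head_dim values.
    slot_indices: N global slot ids; -1 means padding (no load, no store).
    dequant:      placeholder element-wise conversion (assumption).
    """
    for row, slot in enumerate(slot_indices):
        if slot == -1:
            continue  # padding: leave out[row] untouched
        # Assumed slot layout: slot = block * block_size + offset.
        block, offset = divmod(slot, block_size)
        out[row] = [dequant(v) for v in k_cache[block][offset]]


# Two blocks of two slots each, head_dim = 2.
k_cache = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
out = [[0.0, 0.0] for _ in range(3)]
sparse_gather_reference(out, k_cache, [3, -1, 0], block_size=2)
```

Rows that correspond to `-1` keep whatever the caller placed in `out`, which matches the "no load, no store" wording of the description.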
Two implementations:
* Triton fallback (``_sparse_triton``): 1-program-per-output-row, single
512-element FP8 tile load + BF16 tail, ``num_warps=1`` to keep the
vectorized load shape.
* CuteDSL fast path (``_sparse_cutedsl``): reuses
``DequantGatherKCacheKernel``'s 4-stage cp.async G2S pipeline +
bit-manip FP8→BF16 + BF16 UE8M0 scale multiply. The driver loops over
``slot_indices`` instead of ``(seq_lens, gather_lens, block_table)``.
Benchmark (B200, sm_100, block_size=64, iters=200). Full data in
``benchmarks/attention_benchmarks/deepseek_v4_kernel/results_sparse/pr_tables.md``.
Table A — sparse Triton vs sparse CuteDSL on PR vllm-project#42236 Table 1 k_len
shapes:
k_len | triton_us | cutedsl_us | triton_GB/s | cutedsl_GB/s | speedup
16384 | 23.17 | 16.19 | 1137.1 | 1627.1 | 1.43x
32000 | 36.58 | 20.96 | 1406.8 | 2455.0 | 1.75x
262144 | 227.36 | 110.46 | 1854.0 | 3816.0 | 2.06x
Table B — sparse CuteDSL vs dense CuteDSL at N_dense=32768 (one
realistic C4A prefill chunk at 16K seqlen × cr=4 × 4 reqs):
fraction | sparse_us | dense_us | speedup vs dense
3% | 11.65 | 20.35 | 1.75x
12% | 11.97 | 20.29 | 1.70x
25% | 12.80 | 20.35 | 1.59x
100% | 20.80 | 20.35 | 0.98x
Tests
-----
``.venv/bin/python -m pytest tests/kernels/test_compressor_kv_cache.py``
— 55 passed, including 23 new sparse tests covering:
- contiguous-equiv vs legacy Triton on both impls
- -1 padding skip on both impls
- N=0 no-op
- scattered-permutation invariance
- CuteDSL-vs-Triton bit-identical on a random scattered batch with
-1 padding interleaved.
Non-duplicate
-------------
No open PR targets this kernel. Closest neighbours are vllm-project#42248 (ROCm MLA
decode fallback — different code path) and vllm-project#40909 (ROCm AITER buffer
sharing — different optimisation). vllm-project#42236 added the dense CuteDSL kernel
that the sparse path complements; this PR does not touch the dense path.
Follow-up
---------
Wiring this into the C4A prefill loop in
``DeepseekV4MLAAttention._forward_prefill`` requires per-request
``topk_indices`` deduplication, a compact ``kv`` buffer layout, and a
remap step in ``combine_topk_swa_indices``. That integration and the
end-to-end TTFT data will land in a follow-up PR.
AI-assisted (Claude). The submitting human reviewed every changed line
and ran the listed test command.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: ly860 <1475865623@qq.com>
@Bortlesboat heads-up: I opened #42576 as an alternative approach to the Reason for a separate PR rather than a follow-up commit on this branch: #42576 takes a materially different algorithm (per-token gather + Happy to close #42576 in favor of you adopting the same approach here if |
Follow-up on my earlier ping — actually checked PR #41812 ("[ROCm][DSv4] implement flash sparse mla with triton kernels") So the root cause from #41962 is gone on I just closed #42576 as obsolete. You may want to close this one too, Sorry for the noise — should have caught the |
Closing — #41812 (merged May 11) swapped out the Python fallback path I was targeting. The new |
Fixes #41962.
Summary
- rocm_dequantize_blocked_k_cache
- rocm_ref_sparse_attn_decode
Duplicate checks
- gh issue view 41962 --repo vllm-project/vllm --comments: issue has only the ROCm auto-CC comment
- gh pr list --repo vllm-project/vllm --state open --search "41962 in:body": no open PRs
- gh pr list --repo vllm-project/vllm --state open --search "rocm_dequantize_blocked_k_cache OOM": no open PRs
- gh pr list --repo vllm-project/vllm --state open --search "DeepSeek-V4-Flash decode fallback KV cache OOM": found #41136 ([ROCm] DeepSeekV4-Flash-Base model enablement on ROCm with triton & torchfallback) and #38476 ([Feature] TRITON_MLA_SPARSE backend for SM8x/11x/12x DSA Sparse MLA Support), but neither is a duplicate of this current-main fallback compaction fix. #41136 is broader DeepSeekV4 ROCm enablement/fallback work, and #38476 is a separate sparse MLA backend effort.
Tests
- uv run --isolated --with pytest python -m pytest --noconftest tests/v1/attention/test_rocm_aiter_mla_sparse_fallback.py -q -> 1 passed
- uvx ruff format --check vllm/v1/attention/ops/rocm_aiter_mla_sparse.py tests/v1/attention/test_rocm_aiter_mla_sparse_fallback.py -> passed
- uvx ruff check vllm/v1/attention/ops/rocm_aiter_mla_sparse.py tests/v1/attention/test_rocm_aiter_mla_sparse_fallback.py -> passed
- uvx pre-commit run mypy-local --files vllm/v1/attention/ops/rocm_aiter_mla_sparse.py tests/v1/attention/test_rocm_aiter_mla_sparse_fallback.py -> passed
- SKIP=update-dockerfile-graph uvx pre-commit run --files vllm/v1/attention/ops/rocm_aiter_mla_sparse.py tests/v1/attention/test_rocm_aiter_mla_sparse_fallback.py -> passed; Dockerfile graph hook skipped because no Docker files changed and the local Windows shell does not provide /bin/bash
AI assistance
This draft was prepared with AI assistance from OpenAI Codex. The human submitter should review the changed lines and test outputs before marking it ready for review.