[ROCm] Avoid full KV cache dequant in MLA decode fallback #42248

Closed
Bortlesboat wants to merge 1 commit into vllm-project:main from Bortlesboat:codex/vllm-rocm-swa-dequant-slice-41962

@Bortlesboat

Contributor

Fixes #41962.

Summary

  • compact the SWA and extra KV caches to only the physical blocks referenced by the current decode indices before rocm_dequantize_blocked_k_cache
  • remap flattened token indices into the compact cache layout before calling rocm_ref_sparse_attn_decode
  • add a focused regression test proving the fallback no longer sends the full cache pool into dequantization
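
The compaction and remap steps above can be sketched in plain Python. This is only an illustration, not the PR's `_gather_referenced_cache_blocks`: the function name `compact_referenced_blocks`, the list-based cache, the `block_id * block_size + offset` slot layout, and the use of `-1` as padding are all assumptions made for the example.

```python
def compact_referenced_blocks(cache_blocks, token_indices, block_size):
    """Return (compact_cache, remapped_indices).

    cache_blocks: list of blocks, each holding `block_size` token entries.
    token_indices: flattened slot ids, assumed to be
        block_id * block_size + offset, with -1 marking padding.
    """
    # Physical blocks actually referenced by the current decode indices.
    used = sorted({idx // block_size for idx in token_indices if idx >= 0})
    # Map each old physical block id to its position in the compact cache.
    new_pos = {old: new for new, old in enumerate(used)}
    # Only these blocks would be handed to dequantization.
    compact_cache = [cache_blocks[b] for b in used]
    # Rewrite every valid index into the compact layout; keep padding as -1.
    remapped = [
        -1 if idx < 0 else new_pos[idx // block_size] * block_size + idx % block_size
        for idx in token_indices
    ]
    return compact_cache, remapped


# A pool of 8 blocks with 4 tokens each; only blocks 2 and 6 are referenced.
pool = [[f"b{b}t{t}" for t in range(4)] for b in range(8)]
indices = [9, 26, -1, 8]  # block 2/offset 1, block 6/offset 2, padding, block 2/offset 0
compact, remapped = compact_referenced_blocks(pool, indices, block_size=4)
```

Here `compact` holds two blocks instead of eight, and each remapped index still resolves to the same token entry it pointed at in the full pool.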

Duplicate checks

Tests

  • uv run --isolated --with pytest python -m pytest --noconftest tests/v1/attention/test_rocm_aiter_mla_sparse_fallback.py -q -> 1 passed
  • uvx ruff format --check vllm/v1/attention/ops/rocm_aiter_mla_sparse.py tests/v1/attention/test_rocm_aiter_mla_sparse_fallback.py -> passed
  • uvx ruff check vllm/v1/attention/ops/rocm_aiter_mla_sparse.py tests/v1/attention/test_rocm_aiter_mla_sparse_fallback.py -> passed
  • uvx pre-commit run mypy-local --files vllm/v1/attention/ops/rocm_aiter_mla_sparse.py tests/v1/attention/test_rocm_aiter_mla_sparse_fallback.py -> passed
  • SKIP=update-dockerfile-graph uvx pre-commit run --files vllm/v1/attention/ops/rocm_aiter_mla_sparse.py tests/v1/attention/test_rocm_aiter_mla_sparse_fallback.py -> passed; Dockerfile graph hook skipped because no Docker files changed and the local Windows shell does not provide /bin/bash

AI assistance

This draft was prepared with AI assistance from OpenAI Codex. The human submitter should review the changed lines and test outputs before marking it ready for review.

Co-authored-by: OpenAI Codex <codex@openai.com>

Signed-off-by: Bortlesboat <bortstheboat@gmail.com>
mergify bot added the rocm (Related to AMD ROCm) and v1 labels on May 10, 2026
github-project-automation bot moved this to Todo in AMD on May 10, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces an optimization to the ROCm Aiter MLA sparse attention fallback by ensuring that only the cache blocks actually referenced are dequantized. This is implemented through a new helper function, _gather_referenced_cache_blocks, which identifies used blocks, creates a compact cache, and remaps the indices. A new unit test has also been added to verify this behavior. I have no feedback to provide.

ly865623 added a commit to ly865623/vllm that referenced this pull request May 13, 2026
…ache

For DSv4's C4A topk-driven prefill path, the dense
``dequantize_and_gather_k_cache`` materializes the full compressed prefix
per request into a BF16 ``kv`` buffer, but ``flash_mla_sparse_fwd`` then
reads only the topk-selected positions out of it; the entries outside the
topk union are dequantized and discarded.

This commit adds a sparse variant that dequantizes exactly the global
slot ids the caller asks for, behind a ``has_cutedsl()`` dispatcher
matching the dense ``dequantize_and_gather_k_cache``:

  dequantize_and_gather_k_cache_sparse(out, k_cache, slot_indices,
                                       block_size)

* ``slot_indices: [N] int32`` — global slot ids produced by
  ``compute_global_topk_indices_and_lens``; ``-1`` entries are padding
  (no load, no store).
* ``out: [N, head_dim] bf16`` — caller flattens any
  ``(num_queries × top_k)`` layout to N rows.

Two implementations:

* Triton fallback (``_sparse_triton``): 1-program-per-output-row, single
  512-element FP8 tile load + BF16 tail, ``num_warps=1`` to keep the
  vectorized load shape.
* CuteDSL fast path (``_sparse_cutedsl``): reuses
  ``DequantGatherKCacheKernel``'s 4-stage cp.async G2S pipeline +
  bit-manip FP8→BF16 + BF16 UE8M0 scale multiply. The driver loops over
  ``slot_indices`` instead of ``(seq_lens, gather_lens, block_table)``.
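
The contract described above (one output row per slot id, ``-1`` rows skipped entirely) can be written down as a plain-Python reference. This is a sketch under assumptions, not either kernel: the name ``dequant_gather_sparse_ref`` is hypothetical, the slot id is assumed to decompose as ``block * block_size + offset``, and the FP8 tile + BF16 tail + UE8M0 scale format is simplified to integer codes with one float scale per token.

```python
def dequant_gather_sparse_ref(out, k_cache, scales, slot_indices, block_size):
    """Reference semantics only; all layouts here are simplified assumptions.

    out:          N rows, written in place
    k_cache:      k_cache[block][offset] -> list of quantized codes
    scales:       scales[block][offset]  -> one float scale per token
    slot_indices: N global slot ids, -1 meaning padding
    """
    for row, slot in enumerate(slot_indices):
        if slot < 0:
            continue  # padding: no load, no store, row keeps its old contents
        block, offset = divmod(slot, block_size)
        scale = scales[block][offset]
        out[row] = [code * scale for code in k_cache[block][offset]]
    return out


# Two blocks of two tokens each, head_dim = 2.
k_cache = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
scales = [[0.5, 2.0], [1.0, 0.25]]
slot_indices = [3, -1, 0]
out = [[-1.0, -1.0] for _ in slot_indices]  # sentinel shows the skipped row
dequant_gather_sparse_ref(out, k_cache, scales, slot_indices, block_size=2)
```

Only the two requested slots are touched; the padded row keeps its sentinel value, which is the "no load, no store" behaviour.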

Benchmark (B200, sm_100, block_size=64, iters=200). Full data in
``benchmarks/attention_benchmarks/deepseek_v4_kernel/results_sparse/pr_tables.md``.

Table A — sparse Triton vs sparse CuteDSL on PR vllm-project#42236 Table 1 k_len
shapes:

  k_len  | triton_us | cutedsl_us | triton_GB/s | cutedsl_GB/s | speedup
   16384 |     23.17 |      16.19 |      1137.1 |       1627.1 |   1.43x
   32000 |     36.58 |      20.96 |      1406.8 |       2455.0 |   1.75x
  262144 |    227.36 |     110.46 |      1854.0 |       3816.0 |   2.06x

Table B — sparse CuteDSL vs dense CuteDSL at N_dense=32768 (one
realistic C4A prefill chunk at 16K seqlen × cr=4 × 4 reqs):

  fraction | sparse_us | dense_us | speedup vs dense
        3% |     11.65 |    20.35 |             1.75x
       12% |     11.97 |    20.29 |             1.70x
       25% |     12.80 |    20.35 |             1.59x
      100% |     20.80 |    20.35 |             0.98x

Tests
-----
``.venv/bin/python -m pytest tests/kernels/test_compressor_kv_cache.py``
— 55 passed, including 23 new sparse tests covering:
  - contiguous-equiv vs legacy Triton on both impls
  - -1 padding skip on both impls
  - N=0 no-op
  - scattered-permutation invariance
  - CuteDSL-vs-Triton bit-identical on a random scattered batch with
    -1 padding interleaved.

Non-duplicate
-------------
No open PR targets this kernel. Closest neighbours are vllm-project#42248 (ROCm MLA
decode fallback — different code path) and vllm-project#40909 (ROCm AITER buffer
sharing — different optimisation). vllm-project#42236 added the dense CuteDSL kernel
that the sparse path complements; this PR does not touch the dense path.

Follow-up
---------
Wiring this into the C4A prefill loop in
``DeepseekV4MLAAttention._forward_prefill`` requires per-request
``topk_indices`` deduplication, a compact ``kv`` buffer layout, and a
remap step in ``combine_topk_swa_indices``. That integration + end-to-end
TTFT data will land in a follow-up PR.

AI-assisted (Claude). The submitting human reviewed every changed line
and ran the listed test command.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Signed-off-by: ly860 <1475865623@qq.com>
@ChuanLi1101

Collaborator

@Bortlesboat heads-up: I opened #42576 as an alternative approach to the
same #41962 OOM, with attribution to you on the commit (Co-authored-by:).

Reason for a separate PR rather than a follow-up commit on this branch:
internal MI325 verification of this PR confirmed the OOM fix in eager
mode, but exposed a regression under HIP graph capture —
_gather_referenced_cache_blocks uses boolean masking + torch.unique,
which produce data-dependent output shapes that stream capture rejects
(hipErrorStreamCaptureUnsupported). So this PR is currently only
mergeable behind --enforce-eager.

#42576 takes a materially different algorithm (per-token gather +
dequant instead of block-level compaction), which keeps the OOM win and
is also graph-safe (only ops with statically derivable output shapes:
advanced indexing, torch.where over the >=0 mask). Since the
algorithm is different end-to-end I felt a separate PR was clearer than
force-pushing your branch.

Happy to close #42576 in favor of you adopting the same approach here if
you prefer to keep ownership of the fix — just let me know which way you
want to go. Either path resolves #41962.
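
The shape argument above can be illustrated without a GPU. In the sketch below, the block-level approach produces a result whose length depends on the index values, while a per-token gather always produces exactly one row per index. The function names, the list-based cache, the `block * block_size + offset` slot layout, and zero-filling padded rows are assumptions for illustration; this is not the code from #42576.

```python
def referenced_blocks(token_indices, block_size):
    # Analogue of boolean masking + unique: the output length is data-dependent.
    return sorted({idx // block_size for idx in token_indices if idx >= 0})


def gather_dequant_per_token(k_cache, scales, token_indices, block_size, head_dim):
    # One output row per input index, so the output shape is known up front.
    out = []
    for idx in token_indices:
        valid = idx >= 0
        safe = idx if valid else 0  # clamp padding to a legal slot, then mask
        block, offset = divmod(safe, block_size)
        row = [code * scales[block][offset] for code in k_cache[block][offset]]
        out.append(row if valid else [0.0] * head_dim)
    return out


k_cache = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # 2 blocks x 2 tokens x head_dim 2
scales = [[1.0, 1.0], [2.0, 2.0]]
step_a = [0, 1, -1]  # touches one block
step_b = [0, 3, 2]   # touches two blocks

blocks_a = referenced_blocks(step_a, block_size=2)
blocks_b = referenced_blocks(step_b, block_size=2)
rows_a = gather_dequant_per_token(k_cache, scales, step_a, 2, 2)
rows_b = gather_dequant_per_token(k_cache, scales, step_b, 2, 2)
```

`blocks_a` and `blocks_b` differ in length even though both steps have three indices, whereas `rows_a` and `rows_b` always have three rows each.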

@ChuanLi1101

Collaborator

Follow-up on my earlier ping — actually checked main more carefully and
it looks like both this PR and #42576 are obsolete:

PR #41812 ("[ROCm][DSv4] implement flash sparse mla with triton kernels")
was merged on 2026-05-11, the day after this PR was opened. It removes
rocm_dequantize_blocked_k_cache, rocm_ref_sparse_attn_decode, and
rocm_forward_decode_fallback from rocm_aiter_mla_sparse.py. The new
rocm_sparse_attn_decode reads the uint8 quant cache directly through
_rocm_sparse_attn_decode_triton with no full-pool bf16 materialization,
and the prefill path uses dequantize_and_gather_k_cache(out=kv[:chunk_size], ...)
into a chunked, caller-sized buffer.

So the root cause from #41962 is gone on main by way of architectural
removal — neither this PR nor #42576 has live code to merge into, and a
rebase produces structural conflicts (the helpers being modified no
longer exist).

I just closed #42576 as obsolete. You may want to close this one too,
or re-baseline against current main and check whether the OOM still
reproduces on the new triton path before continuing.

Sorry for the noise — should have caught the main refactor on my
duplicate-work check before opening #42576.

@Bortlesboat

Contributor Author

Closing — #41812 (merged May 11) swapped out the Python fallback path I was targeting. The new _rocm_sparse_attn_decode_triton walks referenced cache blocks via ragged indices inside the kernel itself, so the slice-then-dequant approach here doesn't have a place anymore.

github-project-automation bot moved this from Todo to Done in AMD on May 13, 2026

Labels

rocm (Related to AMD ROCm), v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[ROCm] DeepSeek-V4-Flash: rocm_dequantize_blocked_k_cache materializes entire KV cache pool causing OOM during decode

2 participants