Skip to content

[ROCm] Make sparse-MLA decode fallback gather CUDA-graph safe (alt to #42248)#42576

Closed
ChuanLi1101 wants to merge 2 commits into
vllm-project:mainfrom
ChuanLi1101:chuali/42248-graph-safe-gather
Closed

[ROCm] Make sparse-MLA decode fallback gather CUDA-graph safe (alt to #42248)#42576
ChuanLi1101 wants to merge 2 commits into
vllm-project:mainfrom
ChuanLi1101:chuali/42248-graph-safe-gather

Conversation

@ChuanLi1101

Copy link
Copy Markdown
Collaborator

Purpose

Fixes #41962. Alternative to #42248 with a CUDA/HIP-graph-safe approach.

#42248 addresses the same OOM in rocm_forward_decode_fallback by
compacting the SWA / KV cache to only referenced blocks before
rocm_dequantize_blocked_k_cache. Internal MI325 verification confirms
the OOM is fixed in eager mode, but the compaction helper
_gather_referenced_cache_blocks uses boolean masking
(indices[indices >= 0]) and torch.unique — both produce
data-dependent output shapes that HIP/CUDA stream capture rejects with
hipErrorStreamCaptureUnsupported. The PR is unmergeable in graph mode
without --enforce-eager.

This PR resolves #41962 with a per-token gather + dequantization that is
also CUDA/HIP-graph-safe:

  • New _dequantize_referenced_tokens(quant_k_cache, indices, ...) helper
    that gathers and dequantizes per token slot using only operations whose
    output shape is statically derivable from indices.shape (advanced
    indexing on (num_blocks, block_size, ...) views; torch.where over
    the >=0 mask; no boolean indexing, no torch.unique, no Python
    branches on tensor data).
  • Output buffer is bounded by indices.numel() * head_dim * 2 bytes per
    cache, decoupled from num_total_blocks. Preserves the OOM win
    motivating [ROCm] Avoid full KV cache dequant in MLA decode fallback #42248: for typical DSv4-Flash decode shapes this is
    ~10–60× smaller than the original full-pool dequant.
  • rocm_forward_decode_fallback is rewired to use the new helper.
    rocm_ref_sparse_attn_decode's public API is unchanged; its inner
    index_select becomes an identity gather on already-gathered data
    (negligible cost vs. the dequant savings).

@Bortlesboat is credited via Co-authored-by: on the commit since this
PR builds on his identification of the OOM and his test scaffolding from
#42248.

Duplicate-work check

Test Plan

.venv/bin/python -m pytest tests/v1/attention/test_rocm_aiter_mla_sparse_fallback.py -v

Test Result

tests/v1/attention/test_rocm_aiter_mla_sparse_fallback.py::test_forward_decode_fallback_uses_graph_safe_per_token_gather PASSED
tests/v1/attention/test_rocm_aiter_mla_sparse_fallback.py::test_dequantize_referenced_tokens_is_graph_safe PASSED
tests/v1/attention/test_rocm_aiter_mla_sparse_fallback.py::test_dequantize_referenced_tokens_matches_block_level_dequant PASSED

============================== 3 passed in 4.14s ==============================

Three regression tests:

  • End-to-end fallback: locks in that the original (uncompacted)
    cache and indices are passed straight through to the new helper, the
    gathered tensor row-count tracks indices.numel(), and identity-style
    remap with -1 sentinels is preserved for downstream masking.
  • Graph-safe shape contract: _dequantize_referenced_tokens's
    output shape is a function of indices.shape only, not cache size —
    the invariant that prevents the hipErrorStreamCaptureUnsupported
    regression from being reintroduced.
  • Numerical equivalence: per-token dequant matches
    rocm_dequantize_blocked_k_cache bit-for-bit on referenced tokens
    (skipped when torch.float8_e8m0fnu is unavailable).

End-to-end MI325 / MI300 verification on the DSv4-Flash repro from
#41962 is pending — the same vllm serve command from issue #41962
that previously required --enforce-eager should now succeed in graph
mode as well.

pre-commit cannot bootstrap a Windows venv on the local dev host
(known: virtualenv cache layout missing on Windows app-distribution
Python); ruff check and ruff format are clean on the changed files.

AI assistance

This PR was prepared with AI assistance. Per vLLM AI-assisted
contribution rules, the human submitter has reviewed every changed line
and runs the listed tests before requesting review.

Bortlesboat and others added 2 commits May 10, 2026 15:38
Co-authored-by: OpenAI Codex <codex@openai.com>

Signed-off-by: Bortlesboat <bortstheboat@gmail.com>
The previous compaction in `_gather_referenced_cache_blocks` used boolean
masking and `torch.unique` to dedup referenced cache blocks, which produce
data-dependent output shapes. HIP/CUDA stream capture rejects these with
`hipErrorStreamCaptureUnsupported` (vllm-project#41962 verification on MI325 — only
`--enforce-eager` worked, graph mode crashed during CUDAGraph capture).

Replace the host-side block compaction with a per-token gather + dequant
helper `_dequantize_referenced_tokens` that uses only operations whose
output shape is statically derivable from `indices.shape`:

* Advanced indexing on `(num_blocks, block_size, ...)` views with
  `(block_idx, slot_idx)` tensors derived via `clamp_min` + `div` + `%`.
* `torch.where` over the `>=0` mask for index remapping (no boolean
  indexing, no `unique`).
* Output buffer `torch.empty((indices.numel(), head_dim), bf16)` is
  bounded by indices alone, so the OOM win that motivated vllm-project#42248 is
  preserved (vs. the original full-pool dequant of
  `num_total_blocks * block_size * head_dim`).

Tests:
- Updated regression locks in: original (uncompacted) cache and indices
  passed straight through, gathered tensor row-count tracks
  `indices.numel()`, identity-style remap with -1 sentinels preserved.
- Added direct test of `_dequantize_referenced_tokens` on multiple
  index shapes confirming the static-shape contract.
- Added numerical equivalence test against `rocm_dequantize_blocked_k_cache`
  to guard the per-token dequant math (skipped when `torch.float8_e8m0fnu`
  is unavailable).

Co-authored-by: Andrew Barnes <bortstheboat@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@ChuanLi1101 ChuanLi1101 requested a review from tjtanaa as a code owner May 13, 2026 21:54

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added nvidia rocm Related to AMD ROCm v1 labels May 13, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD May 13, 2026
@mergify

mergify Bot commented May 13, 2026

Copy link
Copy Markdown
Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ChuanLi1101.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 13, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a graph-safe per-token dequantization path for the ROCm MLA sparse fallback. By replacing the previous block-level compaction with a method using static shapes and advanced indexing, the implementation now supports HIP graph capture and avoids OOM issues by only dequantizing referenced tokens. A new test suite has been added to ensure numerical correctness and verify that the operations remain graph-safe. I have no feedback to provide.

@ChuanLi1101

Copy link
Copy Markdown
Collaborator Author

Closing as obsolete — and apologies for the noise.

After opening this PR I rebased onto current main and hit a wide
conflict, which led me to look closer at vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
on main. PR #41812 ("[ROCm][DSv4] implement flash sparse mla with
triton kernels") was merged on 2026-05-11, two days after #42248 was
opened, and it removes the entire rocm_dequantize_blocked_k_cache-based
fallback path that #41962 / #42248 / this PR were addressing:

  • rocm_dequantize_blocked_k_cache, rocm_ref_sparse_attn_decode, and
    rocm_forward_decode_fallback no longer exist in
    rocm_aiter_mla_sparse.py.
  • The new rocm_sparse_attn_decode reads the uint8 quant cache directly
    via _rocm_sparse_attn_decode_triton, with no full-pool bf16
    materialization.
  • For the prefill path, dequantize_and_gather_k_cache(out=kv[:chunk_size], ...)
    writes into a chunked, caller-sized output buffer rather than the full
    cache pool.

So the OOM root cause from #41962 (rocm_dequantize_blocked_k_cache
allocating (num_total_blocks, block_size, ...) bf16) is gone in main
by virtue of architectural removal, and any PR rebuilding that function
is fixing dead code. The internal MI325 verification that motivated this
PR (hipErrorStreamCaptureUnsupported from the compaction helper) was
on a base that included #42248's commit but pre-dated #41812.

This is on me — my AGENTS.md duplicate-work check covered open PRs and
the issue thread but not whether main had already obsoleted the
problem. I'll be more careful to verify the target code still exists on
HEAD before opening the next PR.

Filing a follow-up note on #41962 asking the reporter to retest against
a build that includes #41812. Apologies for the churn,
@Bortlesboat / @tjtanaa / @hongxiayang.

@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD May 13, 2026
@github-project-automation github-project-automation Bot moved this to Done in NVIDIA May 13, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

Status: Done
Status: Done

Development

Successfully merging this pull request may close these issues.

[ROCm] DeepSeek-V4-Flash: rocm_dequantize_blocked_k_cache materializes entire KV cache pool causing OOM during decode

2 participants