[ROCm] Make sparse-MLA decode fallback gather CUDA-graph safe (alt to #42248) #42576
ChuanLi1101 wants to merge 2 commits into
Co-authored-by: OpenAI Codex <codex@openai.com>
Signed-off-by: Bortlesboat <bortstheboat@gmail.com>
The previous compaction in `_gather_referenced_cache_blocks` used boolean masking and `torch.unique` to dedup referenced cache blocks, both of which produce data-dependent output shapes. HIP/CUDA stream capture rejects these with `hipErrorStreamCaptureUnsupported` (vllm-project#41962 verification on MI325: only `--enforce-eager` worked, graph mode crashed during CUDAGraph capture).

Replace the host-side block compaction with a per-token gather + dequant helper `_dequantize_referenced_tokens` that uses only operations whose output shape is statically derivable from `indices.shape`:

* Advanced indexing on `(num_blocks, block_size, ...)` views with `(block_idx, slot_idx)` tensors derived via `clamp_min` + `div` + `%`.
* `torch.where` over the `>=0` mask for index remapping (no boolean indexing, no `unique`).
* Output buffer `torch.empty((indices.numel(), head_dim), bf16)` is bounded by indices alone, so the OOM win that motivated vllm-project#42248 is preserved (vs. the original full-pool dequant of `num_total_blocks * block_size * head_dim`).

Tests:

- Updated regression locks in: original (uncompacted) cache and indices passed straight through, gathered tensor row-count tracks `indices.numel()`, identity-style remap with -1 sentinels preserved.
- Added direct test of `_dequantize_referenced_tokens` on multiple index shapes confirming the static-shape contract.
- Added numerical equivalence test against `rocm_dequantize_blocked_k_cache` to guard the per-token dequant math (skipped when `torch.float8_e8m0fnu` is unavailable).

Co-authored-by: Andrew Barnes <bortstheboat@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
This pull request has merge conflicts that must be resolved before it can be
Code Review
This pull request introduces a graph-safe per-token dequantization path for the ROCm MLA sparse fallback. By replacing the previous block-level compaction with a method using static shapes and advanced indexing, the implementation now supports HIP graph capture and avoids OOM issues by only dequantizing referenced tokens. A new test suite has been added to ensure numerical correctness and verify that the operations remain graph-safe. I have no feedback to provide.
Closing as obsolete — and apologies for the noise. After opening this PR I rebased onto current
So the OOM root cause from #41962 (
This is on me: my AGENTS.md duplicate-work check covered open PRs and
Filing a follow-up note on #41962 asking the reporter to retest against
Purpose
Fixes #41962. Alternative to #42248 with a CUDA/HIP-graph-safe approach.
#42248 addresses the same OOM in `rocm_forward_decode_fallback` by compacting the SWA / KV cache to only referenced blocks before `rocm_dequantize_blocked_k_cache`. Internal MI325 verification confirms the OOM is fixed in eager mode, but the compaction helper `_gather_referenced_cache_blocks` uses boolean masking (`indices[indices >= 0]`) and `torch.unique`, both of which produce data-dependent output shapes that HIP/CUDA stream capture rejects with `hipErrorStreamCaptureUnsupported`. The PR is unmergeable in graph mode without `--enforce-eager`.

This PR resolves #41962 with a per-token gather + dequantization that is also CUDA/HIP-graph-safe:

- `_dequantize_referenced_tokens(quant_k_cache, indices, ...)` helper that gathers and dequantizes per token slot using only operations whose output shape is statically derivable from `indices.shape` (advanced indexing on `(num_blocks, block_size, ...)` views; `torch.where` over the `>=0` mask; no boolean indexing, no `torch.unique`, no Python branches on tensor data).
- `indices.numel() * head_dim * 2` bytes per cache, decoupled from `num_total_blocks`. Preserves the OOM win motivating #42248: for typical DSv4-Flash decode shapes this is ~10–60× smaller than the original full-pool dequant.
- `rocm_forward_decode_fallback` is rewired to use the new helper. `rocm_ref_sparse_attn_decode`'s public API is unchanged; its inner `index_select` becomes an identity gather on already-gathered data (negligible cost vs. the dequant savings).
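
The static-shape contract described above can be modelled in plain Python. This is only an illustrative sketch, not the PR's torch implementation: `gather_referenced_tokens`, `dequant_bytes`, and every shape below are hypothetical, the dequant step is omitted, and the loop stands in for what the helper does with vectorized `clamp_min`, `div`, `%`, and `torch.where`.

```python
from typing import List, Sequence, Tuple

Row = List[float]


def gather_referenced_tokens(
    cache: Sequence[Sequence[Row]], indices: Sequence[int], block_size: int
) -> Tuple[List[Row], List[int]]:
    """Gather one cache row per entry of `indices` (hypothetical helper).

    `cache` is laid out as (num_blocks, block_size, head_dim) and `indices`
    holds flat token-slot ids, with -1 as the padding sentinel. Both outputs
    always have len(indices) entries, whatever values `indices` contains.
    """
    gathered: List[Row] = []
    remapped: List[int] = []
    for out_pos, idx in enumerate(indices):
        safe = max(idx, 0)              # clamp_min(0): a sentinel reads slot 0
        block_idx = safe // block_size  # div
        slot_idx = safe % block_size    # %
        gathered.append(list(cache[block_idx][slot_idx]))
        # where(idx >= 0, out_pos, -1): identity remap, sentinel preserved
        remapped.append(out_pos if idx >= 0 else -1)
    return gathered, remapped


def dequant_bytes(num_rows: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    """Size of a bf16 (2 bytes per element) buffer of shape (num_rows, head_dim)."""
    return num_rows * head_dim * bytes_per_elem


# Toy cache: 3 blocks of 4 slots, head_dim 2; each row holds its flat slot id.
block_size, num_blocks = 4, 3
cache = [
    [[float(b * block_size + s)] * 2 for s in range(block_size)]
    for b in range(num_blocks)
]
indices = [9, -1, 2, 9]  # duplicates and a sentinel are both allowed
rows, remap = gather_referenced_tokens(cache, indices, block_size)

# Made-up sizes for the byte comparison (not measured DSv4-Flash shapes).
full_bytes = dequant_bytes(1024 * 64, 512)  # num_total_blocks * block_size rows
token_bytes = dequant_bytes(2048, 512)      # indices.numel() rows
print(len(rows), remap, full_bytes // token_bytes)
```

Duplicate indices are gathered twice rather than deduplicated, and sentinel rows are read from slot 0 and masked later through the remapped index. That trades a little redundant work for an output whose size is known before any tensor value is inspected, which is the property stream capture needs.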

@Bortlesboat is credited via `Co-authored-by:` on the commit since this PR builds on his identification of the OOM and his test scaffolding from #42248.
Duplicate-work check
- `gh issue view 41962 --repo vllm-project/vllm --comments`: only the `github-actions` auto-CC and Bortlesboat's #42248 announcement.
- `gh pr list --repo vllm-project/vllm --state open --search '41962 in:body'`: only #42248 (the PR this one is a graph-safe alternative to).
- `gh pr list --repo vllm-project/vllm --state open --search 'rocm sparse mla decode fallback'`: also #42248. #42504 ("[Kernel][DSv4] Add sparse-gather variant of `dequantize_and_gather_k_cache`") is for the DSv4 prefill path on NVIDIA B200 / CUDA, a different code path (the #42504 body explicitly marks #42248 as a different code path).
different graph-mode behavior). Per AGENTS.md: "If your approach is
materially different, explain the difference in the issue."
Test Plan
Test Result
Three regression tests:

- The updated regression test locks in that the original (uncompacted) cache and indices are passed straight through to the new helper, the gathered tensor row-count tracks `indices.numel()`, and identity-style remap with `-1` sentinels is preserved for downstream masking.
- A direct test on multiple index shapes confirms that `_dequantize_referenced_tokens`'s output shape is a function of `indices.shape` only, not cache size. This is the invariant that prevents the `hipErrorStreamCaptureUnsupported` regression from being reintroduced.
- A numerical equivalence test checks that the per-token dequant matches `rocm_dequantize_blocked_k_cache` bit-for-bit on referenced tokens (skipped when `torch.float8_e8m0fnu` is unavailable).

End-to-end MI325 / MI300 verification on the DSv4-Flash repro from #41962 is pending: the same `vllm serve` command from issue #41962 that previously required `--enforce-eager` should now succeed in graph mode as well.

`pre-commit` cannot bootstrap a Windows venv on the local dev host (known: `virtualenv` cache layout missing on Windows app-distribution Python); `ruff check` and `ruff format` are clean on the changed files.

AI assistance
This PR was prepared with AI assistance. Per vLLM AI-assisted
contribution rules, the human submitter has reviewed every changed line
and runs the listed tests before requesting review.