[ROCm][Bugfix] Plumb rotary_dim through fused QK-norm+RoPE+KV-cache kernel (enables GLM-4.7-FP8 on top of #42749)#43676
…sion (ROCM_AITER_FA + UNIFIED_ATTN) Patch from vllm-project#42749
Code Review
This pull request introduces a new fusion pass, QkNormRopeKvCacheFusionPass, which fuses QK-norm, RoPE, and KV cache updates into a single AITER HIP kernel on ROCm. This change includes updates to configuration, compilation passes, and attention backends (rocm_aiter_fa, rocm_aiter_unified_attn, and rocm_attn) to support the new fusion, along with comprehensive unit tests. The review feedback suggests replacing assert statements with explicit ValueError raises in the attention backends to ensure robust validation even when Python is run with optimization flags.
assert (
    cos_sin_cache.ndim == 2
    and 0 < rotary_dim <= head_dim
    and rotary_dim % 2 == 0
), (
    f"fused_qk_norm_rope_and_cache: unexpected cos_sin_cache layout "
    f"{tuple(cos_sin_cache.shape)} for head_dim={head_dim}; "
    f"expected shape [max_pos, rotary_dim] with rotary_dim<=head_dim "
    f"and even"
)
Using assert statements for production validation checks is risky because they are stripped out when Python is run with optimization flags (e.g., python -O). To ensure that layout mismatches fail loudly in all environments, replace the assert with an explicit if check that raises a ValueError.
- assert (
-     cos_sin_cache.ndim == 2
-     and 0 < rotary_dim <= head_dim
-     and rotary_dim % 2 == 0
- ), (
-     f"fused_qk_norm_rope_and_cache: unexpected cos_sin_cache layout "
-     f"{tuple(cos_sin_cache.shape)} for head_dim={head_dim}; "
-     f"expected shape [max_pos, rotary_dim] with rotary_dim<=head_dim "
-     f"and even"
- )
+ if not (
+     cos_sin_cache.ndim == 2
+     and 0 < rotary_dim <= head_dim
+     and rotary_dim % 2 == 0
+ ):
+     raise ValueError(
+         f"fused_qk_norm_rope_and_cache: unexpected cos_sin_cache layout "
+         f"{tuple(cos_sin_cache.shape)} for head_dim={head_dim}; "
+         f"expected shape [max_pos, rotary_dim] with rotary_dim<=head_dim "
+         f"and even"
+     )
Addressed in commit 093e48feb — replaced both assert checks with if not (...): raise ValueError(...) and added a comment noting the explicit raise is intentional so the check survives python -O. Thanks for the catch!
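As an illustration of the point raised in this review thread, the layout check can be sketched as a standalone function. The name check_cos_sin_layout and the use of a plain shape tuple instead of a tensor are assumptions made for this example only; the condition and message follow the snippet under review.

```python
def check_cos_sin_layout(cache_shape: tuple[int, ...], head_dim: int) -> int:
    """Return rotary_dim derived from the cache shape, or raise ValueError.

    Hypothetical helper for illustration; it takes a shape tuple rather than
    a tensor so it runs without any GPU library.
    """
    ndim = len(cache_shape)
    rotary_dim = cache_shape[-1] if ndim > 0 else 0
    # Explicit raise rather than assert: assert statements are removed when
    # Python runs with -O, so an assert-based check would silently vanish.
    if not (ndim == 2 and 0 < rotary_dim <= head_dim and rotary_dim % 2 == 0):
        raise ValueError(
            f"unexpected cos_sin_cache layout {tuple(cache_shape)} "
            f"for head_dim={head_dim}; expected shape [max_pos, rotary_dim] "
            f"with rotary_dim<=head_dim and even"
        )
    return rotary_dim
```

With head_dim=128, a shape of (4096, 64) passes and yields 64, while an odd or oversized last dimension, or a cache with the wrong number of dimensions, raises regardless of interpreter flags.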
The aiter fused kernel fused_qk_norm_rope_cache_pts_quant_shuffle accepts a rotary_dim parameter (default 0 -> falls back to HEAD_SIZE). vLLM was never threading it through, so models with partial_rotary_factor < 1 (e.g. GLM-4.7 with factor=0.5 -> rotary_dim=64) silently had full RoPE applied to all 128 head dims, plus out-of-bounds reads on the half-length cos_sin_cache.

Result on GLM-4.7-FP8: gsm8k 0.92 -> 0.79 with fusion enabled, restored to 0.92 with this fix.

The fix derives rotary_dim from the cos_sin_cache tensor itself: RotaryEmbedding._compute_cos_sin_cache builds cache = cat(cos, sin, -1) where cos/sin each have last dim = rotary_dim/2, so cos_sin_cache.shape[-1] == rotary_dim holds for every RoPE variant.

Adds a defensive layout check (if not (...): raise ValueError, not an assert, so the check survives python -O) to fail loud rather than silently produce wrong outputs on future RoPE variants.

Touches the two callers of fused_qk_norm_rope_and_cache:
- vllm/v1/attention/backends/rocm_aiter_fa.py (ROCM_AITER_FA)
- vllm/v1/attention/backends/rocm_attn.py (UNIFIED_ATTN)

and extends the helper signature in vllm/_aiter_ops.py.
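The derivation described in this commit message can be sketched in plain Python. This is a simplified stand-in, not vLLM's code: build_cos_sin_cache and kernel_rotary_dim are hypothetical names, lists replace tensors, and the base value of 10000.0 is an assumed default. It only shows why the last dimension of a cat(cos, sin) cache equals rotary_dim, and how a 0 default would fall back to the head size.

```python
import math


def build_cos_sin_cache(rotary_dim: int, max_pos: int, base: float = 10000.0):
    """Simplified stand-in for a RoPE cache builder (plain lists, no tensors)."""
    # One inverse frequency per rotated pair: rotary_dim / 2 entries.
    inv_freq = [1.0 / base ** (i / rotary_dim) for i in range(0, rotary_dim, 2)]
    cache = []
    for pos in range(max_pos):
        cos = [math.cos(pos * f) for f in inv_freq]
        sin = [math.sin(pos * f) for f in inv_freq]
        cache.append(cos + sin)  # cat(cos, sin, -1): row length == rotary_dim
    return cache


def kernel_rotary_dim(rotary_dim: int, head_size: int) -> int:
    """Models the fallback described above: 0 means 'use the head size'."""
    return rotary_dim if rotary_dim > 0 else head_size


head_dim = 128
cache = build_cos_sin_cache(rotary_dim=64, max_pos=8)
derived = len(cache[0])  # plays the role of cos_sin_cache.shape[-1]

unfixed = kernel_rotary_dim(0, head_dim)        # rotary_dim never passed
fixed = kernel_rotary_dim(derived, head_dim)    # rotary_dim threaded through
```

In this sketch the unfixed call resolves to the full head size while the fixed call resolves to the cache's last dimension, which is the mismatch the commit addresses.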
40e8e6d to
093e48f
Hey @omirosh, I read through your fix on top of my changes. I recently split my original PR into two parts; part 1 covers the fusion infrastructure changes, which your fix touches. Since your cherry-pick, I've added more fixes on top of the existing infrastructure. Would you like me to absorb your fix into my part 1, or would you rather rebase and land the fix once mine goes in?
This pull request has merge conflicts that must be resolved before it can be merged.
Hi @jhu960213, thanks a lot for taking the time to read through my fixes! Absorbing it into your part 1 sounds great to me. That's the cleaner outcome for both of us (one PR instead of a stacked pair, and your part 1 lands GLM-4.7-FP8 working out of the box). I'll keep this PR open until your part 1 lands with the fix included, then close it. Thanks again for the quick turnaround!
Hey @omirosh, I've managed to fold your fix into my part 1 PR, #42749. It requires my PR that was merged into aiter just today, so tip-of-tree aiter (https://github.com/ROCm/aiter) should give you what you need. Could you give my branch a try and see if the error is fixed?
Hi @jhu960213! Thanks! I gave it a try, but unfortunately the output is corrupted. Tested on main aiter + your branch, ROCM_AITER_FA, GLM-4.7-FP8, fp8 KV cache, EP on, fusion confirmed replacing all 92 attention patterns. It also reproduces on Qwen3-32B + bf16 KV (so it is not GLM/partial-rotary/FP8 specific). Both VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT=0 and =1 result in garbage output. I am assuming your aiter PR #3429 addresses the layout flip from #42095 (get_kv_cache_shape changed from [2, num_blocks, B, H, D] to [num_blocks, 2, B, H, D], and rocm_aiter_fa.py switched from kv_cache.unbind(0) to unbind(1))? Could you share the exact aiter SHA and your full server invocation (env + --compilation-config + attn backend) so I can match it bit-for-bit, and say whether you've run gsm8k or similar with fusion on? Happy to share repro logs.
Summary
Follow-up to #42749 (e2e QK-norm + RoPE + KV-cache fusion on ROCm) to make the fused kernel work for partial-RoPE models. Concretely, it enables GLM-4.7-FP8 (partial_rotary_factor=0.5, head_dim=128 → rotary_dim=64), which is currently broken end-to-end when #42749's fusion is activated.

Stacked on top of #42749; the actual change in this PR is the fix(rocm): plumb rotary_dim ... commit, a 3-file, +55-line addition.

The aiter kernel (mrope_utils::fused_mrope_rms_kv_kernel) accepts a rotary_dim parameter that defaults to HEAD_SIZE when not passed. vLLM was never threading it through, so partial-RoPE models silently:
- rotate the Q/K vectors of the second half of every head, and
- read cos_sin_cache[..., 64:128], which is out of bounds on a tensor whose last dim is 64, pulling arbitrary GPU memory into the rotation.

The result on GLM-4.7-FP8 is essentially random Q/K vectors → garbage tokens from the attention layer → catastrophic gsm8k regression (strict-match 0 of 1319 problems correct) the moment fuse_qk_norm_rope_kvcache=true is set.
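To make the failure mode in this summary concrete, here is a minimal sketch of partial versus full rotation on a single head vector. It is not the aiter kernel: apply_rope is a hypothetical helper, the split-half (NeoX-style) pairing and base of 10000.0 are assumptions chosen purely for illustration, and angles are computed directly rather than read from a cache, so the out-of-bounds read is not modelled.

```python
import math


def apply_rope(x, pos, rotary_dim, base=10000.0):
    """Rotate the first rotary_dim entries of one head vector; copy the rest.

    Assumes a split-half (NeoX-style) pairing purely for illustration.
    """
    half = rotary_dim // 2
    out = list(x)
    for i in range(half):
        angle = pos / base ** (2 * i / rotary_dim)
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[i], x[i + half]
        out[i] = a * c - b * s
        out[i + half] = b * c + a * s
    return out


head_dim, rotary_dim = 128, 64
x = [float(i + 1) for i in range(head_dim)]

partial = apply_rope(x, pos=5, rotary_dim=rotary_dim)  # intended behaviour
full = apply_rope(x, pos=5, rotary_dim=head_dim)       # what the bug amounts to

tail_untouched = partial[rotary_dim:] == x[rotary_dim:]
tail_rotated = full[rotary_dim:] != x[rotary_dim:]
```

With rotary_dim=64 the upper 64 entries pass through unchanged; treating rotary_dim as the full head size rotates them as well, which is the first of the two silent effects listed above.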
Fix
Derive rotary_dim from the cos_sin_cache tensor itself: RotaryEmbedding._compute_cos_sin_cache builds cache = cat(cos, sin, dim=-1) where cos and sin each have last dim rotary_dim/2, so cos_sin_cache.shape[-1] == rotary_dim holds for every RoPE variant currently in the tree.

Plumb it through to the aiter kernel call in both attention backends that #42749 introduces as callers of fused_qk_norm_rope_and_cache:
- vllm/_aiter_ops.py: fused_qk_norm_rope_and_cache gains a rotary_dim: int = 0 parameter forwarded to the aiter binding.
- vllm/v1/attention/backends/rocm_aiter_fa.py: derives rotary_dim from cos_sin_cache.shape[-1], passes it to the helper.
- vllm/v1/attention/backends/rocm_attn.py: same change.

Also adds a defensive layout check (if not (...): raise ValueError(...), not an assert, so it survives python -O) on cos_sin_cache.ndim == 2 and 0 < rotary_dim <= head_dim and rotary_dim % 2 == 0, so a future RoPE-cache layout change fails loud instead of silently producing wrong outputs the way this bug did.

Total diff: +55 lines across 3 files, no behaviour change for full-RoPE models (where cos_sin_cache.shape[-1] == head_dim makes the derived rotary_dim equal to the kernel's previous default).

Validation: A vs B2 vs D1
Three runs of the same workload (random_input_len=1000, output_len=100, max_concurrency=4) on the same hardware (TP=4, MI355 / gfx950), with VLLM_ROCM_USE_AITER=1 and FP8 KV cache:

[Results table not recoverable; surviving run-label fragments: "--compilation-config)", "fuse_qk_norm_rope_kvcache=true)"]

¹ Wall-clock GPU activity per decode step, from the rank-0 torch.profiler trace filtered on cat=gpu_user_annotation, event execute_context_0(0)_generation_4(4). Stdev ≤ 0.65 ms across all three runs.
What this shows
- 0.0000 → 0.9477 (+0.9477). D1 vs A: strict-match within stderr (0.9477 vs 0.9484, ∆ ≈ 1×σ). The model is producing correct outputs again on the fusion path.
- (−4.7 %), throughput +3.6 %. The fusion is doing exactly what #42749 intended, just on correct data now.
- This is one cos_sin_cache.shape[-1] read plus the defensive check per fused call. An order of magnitude smaller than the fusion's own +4–5 % win.
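The "within stderr" comparison above can be checked with a short calculation. This sketch assumes a plain binomial standard error over the 1319 gsm8k problems quoted in the summary; the evaluation harness may compute its reported stderr slightly differently, so treat the numbers as an approximation rather than a reproduction of the reported value.

```python
import math

n = 1319                # gsm8k problem count quoted in the summary
acc_baseline = 0.9484   # run A
acc_fixed = 0.9477      # run D1

# Binomial standard error of an accuracy estimate over n independent items.
stderr = math.sqrt(acc_fixed * (1.0 - acc_fixed) / n)
delta = abs(acc_baseline - acc_fixed)
within_one_stderr = delta <= stderr

print(f"stderr ~ {stderr:.4f}, delta = {delta:.4f}, within: {within_one_stderr}")
```

Under that assumption the standard error comes out near 0.006, so the 0.0007 gap between A and D1 sits inside it.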
Kernel-firing pattern: fix does not change what runs
Kernel-event counts from the rank-0 profiler traces, confirming the fix is purely a kwarg correction (per-run counts not recoverable; kernels counted):
- mrope_utils::fused_mrope_rms_kv_kernel (aiter fused)
- vllm::reshape_and_cache_flash_kernel
- triton_red_fused_2 (QK RMSNorm)
- triton_poi_fused_3 (RoPE rotation)
- vllm::rotary_embedding_kernel (unfused fallback)

B2 and D1 have identical fusion firing patterns; only the rotary_dim argument value to fused_mrope_rms_kv_kernel differs. The fix changes the data path inside one kernel, not the graph structure.
Root cause (one-liner for future grep)
mrope_utils::fused_mrope_rms_kv_kernel in aiter contains rotary_dim_ = rotary_dim > 0 ? rotary_dim : HEAD_SIZE. Without this PR, vLLM never passes rotary_dim, so partial-RoPE models (e.g. GLM-4.7's rotary_dim=64, head_dim=128) get full-RoPE applied to all head dims plus OOB reads on the half-length cos_sin_cache. With this PR, rotary_dim is derived from cos_sin_cache.shape[-1] and threaded through every caller of fused_qk_norm_rope_and_cache.

Tested configurations
GLM-4.7-FP8, TP=4, MI355 / gfx950, VLLM_ROCM_USE_AITER=1, FP8 KV cache, --compilation-config '{"splitting_ops": [], "pass_config": {"fuse_qk_norm_rope_kvcache": true}}':
- ROCM_AITER_FA

Test plan
- ROCM_AITER_FA, fusion on, fix applied: strict-match 0.9477 (baseline parity vs A's 0.9484, within stderr).
- identically (9108–9200 invocations vs 0 at baseline): fix is a kwarg correction, not a graph change.
- cos_sin_cache.shape[-1] == head_dim and the new check doesn't fire. Expected to be a no-op since the derived rotary_dim equals the kernel's previous default.
Notes for reviewers
- cos_sin_cache.shape[-1] == head_dim (any full-RoPE model), the derived rotary_dim equals the kernel's previous default, so behaviour is byte-identical to pre-PR for those models.
- if … raise ValueError (not an assert, which would be stripped by python -O): silently using the wrong rotary_dim is exactly how this bug went undetected end-to-end. Failing loud on a future layout change is preferable to another silent 0.92 → 0.00 regression.
- landed together). The branch contains #42749's cherry-pick to make CI green; only the fix(rocm): plumb rotary_dim ... commit is new.

Happy to fold the fix directly into #42749 if the reviewers there prefer one combined commit.