Skip to content
Open
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
e3ce8aa
Add fused QK norm + RoPE + KV cache fusion infrastructure
jhu960213 May 13, 2026
034e4de
Add fused QK norm + RoPE + KV cache support for ROCM_AITER_FA and ROC…
jhu960213 May 13, 2026
cc9e330
Add unit test for QK norm + RoPE + KV cache fusion pass
jhu960213 May 13, 2026
803b694
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 May 15, 2026
5697b3a
Fix matcher_utils.py merge conflict resolution
jhu960213 May 15, 2026
f61f6a2
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 May 15, 2026
d7f4c57
Dropped hip_ convention for the fused kernel and removed auto enable …
jhu960213 May 20, 2026
439e265
Revamped fused qk norm rope kvcacheactivation from compilation and pa…
jhu960213 May 20, 2026
b96c04b
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 May 20, 2026
9eba884
Moved the caching of CPU scalar copies of k/v scales into the Attenti…
jhu960213 May 21, 2026
14b30bb
Fixed ruff auto staged
jhu960213 May 25, 2026
74c9c6a
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 May 25, 2026
039055f
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 May 26, 2026
72801be
Restaged ruff formatted file
jhu960213 May 26, 2026
5b18c74
Removed MatcherRMSNorm from matcher utils since we use vllm ir forwar…
jhu960213 May 26, 2026
f95b157
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 May 26, 2026
a96f91b
Ruff strip trailing whitespace from merg confict resolution
jhu960213 May 26, 2026
65d599f
Clean up code
jhu960213 May 27, 2026
ae4a653
Parametrized custom ops list in UT for qk norm rope kvcache
jhu960213 May 28, 2026
322b9e9
Restaging ruff formatted
jhu960213 May 28, 2026
8180635
Cleaned up more code
jhu960213 May 28, 2026
32406f2
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 May 28, 2026
4da1ddf
Removed the no ops from ROCM_ATTN for part1 and added to unified attn…
jhu960213 May 28, 2026
2f45e60
Added partial rotary embedding support
jhu960213 May 29, 2026
9ed3810
Added partial rotary embedding triage in qk norm rope cache unit test
jhu960213 Jun 3, 2026
e3eeaa6
Removed comments etc
jhu960213 Jun 4, 2026
b102dfc
Renamed unit test to align with rocm aiter convention
jhu960213 Jun 4, 2026
ea7dc41
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 Jun 4, 2026
45309aa
Removed some more comments
jhu960213 Jun 4, 2026
e4dc639
Some more cleanup
jhu960213 Jun 4, 2026
93ac152
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 Jun 9, 2026
98b313a
Parmetrized token num and shuffle write for this fusion and also lose…
jhu960213 Jun 10, 2026
9b3d9ce
Turned shuffle kv cache write off when qk norm + rope cache pts quant…
jhu960213 Jun 10, 2026
b00eb44
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 Jun 10, 2026
d67effd
Added another pattern and replacement logic for fp8 KV + ROCM_AITER_U…
jhu960213 Jun 11, 2026
79492b7
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 Jun 15, 2026
4f813fd
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 Jun 16, 2026
d0f260d
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 Jun 17, 2026
0014ca6
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 Jun 22, 2026
f8fc70c
Used correct per tensor quant call for both pattern and replacement f…
jhu960213 Jun 22, 2026
4288bd4
Fixed fp8 quant query pattern matching logic for unified attn
jhu960213 Jun 23, 2026
0f89797
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 Jun 25, 2026
b30d0cf
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 Jun 25, 2026
dcfc7f8
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 Jun 26, 2026
0d709e0
Merge branch 'main' into jhu96/optimize-qwen30b-part1
dllehr-amd Jun 29, 2026
66b4875
Added fix for kv scales device to host synching when model is loaded …
jhu960213 Jun 29, 2026
5dc9a1a
Added guard for this fusion to only fire when head dims are in the su…
jhu960213 Jun 29, 2026
322d6ed
Zero out q and k outs when fusion is skipped during profiling runs
jhu960213 Jun 29, 2026
f9407b7
Added a guard to ensure that layers only fire with this fusion if the…
jhu960213 Jun 29, 2026
c5ec548
Refactored this fusion to use a common helper func in _aiter_ops.py f…
jhu960213 Jun 29, 2026
cfa1208
Fixed fp8 quant query pattern match issue in unified attention
jhu960213 Jun 30, 2026
49e3d15
Merge branch 'main' into jhu96/optimize-qwen30b-part1
jhu960213 Jun 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
431 changes: 431 additions & 0 deletions tests/compile/passes/test_qk_norm_rope_kvcache_fusion.py

Large diffs are not rendered by default.

59 changes: 59 additions & 0 deletions vllm/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2145,6 +2145,65 @@ def triton_fp4_gemm_dynamic_quant(
gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
return y

@staticmethod
def fused_qk_norm_rope_and_cache(
qkv: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
positions: torch.Tensor,
num_heads_q: int,
num_heads_k: int,
num_heads_v: int,
head_dim: int,
is_neox: bool,
rms_norm_eps: float,
q_out: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
slot_mapping: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
k_out: torch.Tensor | None,
v_out: torch.Tensor | None,
return_kv: bool,
use_shuffle_layout: bool,
block_size: int,
x: int,
rotary_dim: int = 0,
):
from aiter.ops.fused_qk_norm_rope_cache_quant import (
fused_qk_norm_rope_cache_pts_quant_shuffle,
)

fused_qk_norm_rope_cache_pts_quant_shuffle(
qkv,
q_weight,
k_weight,
cos_sin_cache,
positions,
qkv.size(0),
num_heads_q,
num_heads_k,
num_heads_v,
head_dim,
is_neox,
rms_norm_eps,
q_out,
k_cache,
v_cache,
slot_mapping,
k_scale,
v_scale,
k_out,
v_out,
return_kv,
use_shuffle_layout,
block_size,
x,
rotary_dim,
)

@staticmethod
def triton_rope_and_cache(
query: torch.Tensor,
Expand Down
27 changes: 25 additions & 2 deletions vllm/compilation/passes/fusion/qk_norm_rope_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@

FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default

# Head dimensions supported by csrc/fused_qknorm_rope_kernel.cu's
# launchFusedQKNormRope and launchFusedQKNormRopeNTokenHeads dispatchers.
# Keep in sync with the switch statements in that file.
SUPPORTED_FUSED_QK_NORM_ROPE_HEAD_DIMS: tuple[int, ...] = (64, 128, 256)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be in qk_norm_rope_kvcache_fusion.py? Or is the fused kernel not have the head dim issue? Also will this be restrictive to the existing qk_norm_rope usecase? I figured we'd hit an error by now if we were restricted to just the 3 head dim sizes?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, for the QK Norm + Rope fusion, this fusion pass uses the

void launchFusedQKNormRope(void* qkv, int const num_tokens,
custom kernel here, which is also registered at
def fused_qk_norm_rope(
. Since I saw that the dispatch mechanism currently only supported the 3 head dims, I thought I'd be defensive and add a guard at the pass level. We don't match and replace into a kernel, which we'd TORCH_CHECK: here anyways at runtime? I think its cause the replacement kernel only supports 3 head dims, and the author didn't add this guard before bc the matcher would match and replace kernels (for diff models with unsupported head dims) only to be caught at runtime? Instead of catching it at the pass level?

I added SUPPORTED_FUSED_QK_NORM_ROPE_HEAD_DIMS so the pass skips pattern registration when head_size isn’t supported, instead of matching/replacing and failing at runtime. Previously, there was no pass-level guard.... we only had the C++ switch. This doesn’t apply to fuse_qk_norm_rope_kvcache (AITER KV-cache fusion uses a different kernel).


P = ParamSpec("P")


Expand Down Expand Up @@ -186,7 +191,12 @@ def replacement(


class QKNormRoPEFusionPass(VllmPatternMatcherPass):
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists.

Registers patterns for both standard vLLM ops and ROCm AITER ops
(when AITER is enabled), so the fusion fires regardless of which
RMSNorm/RoPE implementation the graph uses.
"""

@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
Expand All @@ -202,7 +212,6 @@ def __init__(self, config: VllmConfig) -> None:
)
return

# use one attn layer to get meta (such as head_dim) for QkNormRopePattern
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
config, Attention
)
Expand All @@ -213,6 +222,20 @@ def __init__(self, config: VllmConfig) -> None:
return
layer = next(iter(attn_layers.values()))

if layer.head_size not in SUPPORTED_FUSED_QK_NORM_ROPE_HEAD_DIMS:
logger.warning_once(
"QK Norm+RoPE fusion not enabled: layer head_size=%d is not "
"supported by fused_qk_norm_rope kernel (supported: %s). "
"Falling back to unfused QK norm + RoPE path.",
layer.head_size,
SUPPORTED_FUSED_QK_NORM_ROPE_HEAD_DIMS,
)
return

# RMS norm variants are no longer iterated: after the vLLM IR migration (#33825)
Comment thread
jhu960213 marked this conversation as resolved.
Outdated
# AITER rope variants are also not iterated: `MatcherRotaryEmbedding`
# auto-detects via `rocm_aiter_ops.is_triton_rotary_embed_enabled()`
# and selects the right rotary op.
for epsilon in [1e-5, 1e-6]:
for neox in [True, False]:
if RotaryEmbedding.enabled():
Expand Down
Loading
Loading