From 3fd785ecef008408b51ddc32da4a7a865ee8c434 Mon Sep 17 00:00:00 2001 From: Integration Build Date: Thu, 21 May 2026 05:48:11 +0000 Subject: [PATCH 1/2] cherry-pick(vllm#42749): AMD Kernel: e2e QK Norm + RoPE + KV Cache fusion (ROCM_AITER_FA + UNIFIED_ATTN) Patch from https://github.com/vllm-project/vllm/pull/42749 --- .../test_qk_norm_rope_kvcache_fusion.py | 520 ++++++++++++++++++ vllm/_aiter_ops.py | 57 ++ .../passes/fusion/act_quant_fusion.py | 4 +- .../passes/fusion/matcher_utils.py | 47 +- .../passes/fusion/qk_norm_rope_fusion.py | 67 ++- .../fusion/qk_norm_rope_kvcache_fusion.py | 398 ++++++++++++++ vllm/compilation/passes/pass_manager.py | 9 + vllm/config/compilation.py | 38 +- vllm/config/vllm.py | 104 ++++ vllm/v1/attention/backend.py | 37 ++ vllm/v1/attention/backends/rocm_aiter_fa.py | 89 +++ .../backends/rocm_aiter_unified_attn.py | 17 + vllm/v1/attention/backends/rocm_attn.py | 98 ++++ 13 files changed, 1465 insertions(+), 20 deletions(-) create mode 100644 tests/compile/passes/test_qk_norm_rope_kvcache_fusion.py create mode 100644 vllm/compilation/passes/fusion/qk_norm_rope_kvcache_fusion.py diff --git a/tests/compile/passes/test_qk_norm_rope_kvcache_fusion.py b/tests/compile/passes/test_qk_norm_rope_kvcache_fusion.py new file mode 100644 index 000000000000..06ec9c4000f1 --- /dev/null +++ b/tests/compile/passes/test_qk_norm_rope_kvcache_fusion.py @@ -0,0 +1,520 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os + +import pytest +import torch + +import vllm.config +from tests.compile.backend import TestBackend +from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata +from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops +from vllm.compilation.passes.fusion.matcher_utils import ROTARY_OP +from vllm.compilation.passes.fusion.qk_norm_rope_kvcache_fusion import ( + QkNormRopeKvCacheFusionPass, +) +from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass +from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass +from vllm.compilation.passes.utility.scatter_split_replace import ( + ScatterSplitReplacementPass, +) +from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass +from vllm.config import ( + CacheConfig, + CompilationConfig, + CompilationMode, + ModelConfig, + PassConfig, + VllmConfig, +) +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform +from vllm.v1.attention.backend import ( + AttentionBackend, + CommonAttentionMetadata, +) +from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.kv_cache_interface import AttentionSpec + +INDEX_SELECT_OP = torch.ops.aten.index.Tensor +FP8_DTYPE = current_platform.fp8_dtype() + + +class QKNormRoPEKVCacheTestModel(torch.nn.Module): + """Minimal model that reproduces the QK-norm + RoPE + KV cache update + pattern matched by QkNormRopeKvCacheFusionPass: + + q, k, v = split(qkv) + q = rms_norm(q.view(heads, dim), q_weight).view(flat) + k = rms_norm(k.view(heads, dim), k_weight).view(flat) + q, k = rotary_emb(positions, q, k) + q = q.view(num_heads, head_dim) + k = k.view(num_kv_heads, head_dim) + v = v.view(num_kv_heads, head_dim) + dummy = unified_kv_cache_update(k, v, layer_name) + """ + + def __init__( + self, + vllm_config: VllmConfig, + attn_backend: AttentionBackendEnum, + num_heads: int, + num_kv_heads: int, + head_size: int, + is_neox: bool, + rms_norm_eps: float, + dtype: torch.dtype, + device: torch.device, + prefix: str = "model.layers.0.self_attn.attn", + ): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_size = head_size + self.block_size = vllm_config.cache_config.block_size + self.q_size = num_heads * head_size + self.kv_size = num_kv_heads * head_size + self.is_neox = is_neox + self.dtype = dtype + self.device = device + self.layer_name = prefix + + self.q_norm = RMSNorm(head_size, eps=rms_norm_eps) + self.k_norm = RMSNorm(head_size, eps=rms_norm_eps) + + self.rotary_emb = RotaryEmbedding( + head_size, + rotary_dim=head_size, + max_position_embeddings=4096, + base=10000, + is_neox_style=is_neox, + dtype=self.dtype, + ) + + self.enable_rope_custom_op = self.rotary_emb.enabled() + + self.attn = Attention( + num_heads=num_heads, + head_size=head_size, + scale=1.0 / head_size**0.5, + num_kv_heads=num_kv_heads, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + prefix=prefix, + attn_backend=attn_backend.get_class(), + ) + self.attn_backend: type[AttentionBackend] = self.attn.get_attn_backend() + assert not self.attn_backend.forward_includes_kv_cache_update, ( + f"Attention backend {self.attn_backend} does not support " + "fuse_qk_norm_rope_kvcache." + ) + kv_cache_dtype_str = vllm_config.cache_config.cache_dtype + self.kv_cache_dtype = ( + FP8_DTYPE if kv_cache_dtype_str.startswith("fp8") else self.dtype + ) + + if self.kv_cache_dtype != self.dtype: + self.attn._k_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + self.attn._v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + self.attn._k_scale_float = 1.0 + self.attn._v_scale_float = 1.0 + else: + self.attn._k_scale = self.attn._k_scale.to(device) + self.attn._v_scale = self.attn._v_scale.to(device) + + self.builder = self.attn.attn_backend.get_builder_cls()( + kv_cache_spec=AttentionSpec( + block_size=self.block_size, + num_kv_heads=self.num_kv_heads, + head_size=head_size, + dtype=self.kv_cache_dtype, + ), + layer_names=[self.attn.layer_name], + vllm_config=vllm_config, + device=device, + ) + + def build_attn_metadata(self, batch_size: int) -> CommonAttentionMetadata: + batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size) + common_attn_metadata = create_common_attn_metadata( + batch_spec, self.block_size, self.device, arange_block_indices=True + ) + + max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size + num_blocks = batch_size * max_blocks + + attn_backend = self.attn.attn_backend + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size + ) + try: + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(kv_cache_shape))) + + kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) + inv_order = [ + kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) + ] + + raw_tensor = torch.zeros( + 2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size, + dtype=self.kv_cache_dtype, + device=self.device, + ) + raw_tensor = raw_tensor.view(kv_cache_shape) + kv_cache = raw_tensor.permute(*inv_order) + + # Store as a bare tensor (not wrapped in a list) to match production + # `bind_kv_cache` behavior. `get_attention_context` returns this + # attribute directly to the fused/unfused `do_kv_cache_update` impls, + # which call `kv_cache.unbind(0)` and therefore require a tensor. + self.attn.kv_cache = kv_cache + + attn_metadata = self.builder.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) + + return attn_metadata + + def forward( + self, qkv: torch.Tensor, positions: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + qkv = qkv.clone() + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # QK-norm: RMSNorm on per-head Q and K + q = q.view(-1, self.num_heads, self.head_size) + q = self.q_norm(q) + q = q.view(-1, self.q_size) + + k = k.view(-1, self.num_kv_heads, self.head_size) + k = self.k_norm(k) + k = k.view(-1, self.kv_size) + + # RoPE + q, k = self.rotary_emb(positions, q, k) + + # Final views + KV cache update + q = q.view(-1, self.num_heads, self.head_size) + k = k.view(-1, self.num_kv_heads, self.head_size) + v = v.view(-1, self.num_kv_heads, self.head_size) + kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update( + k, v, self.layer_name + ) + return q, k, v, kv_cache_dummy_dep + + def ops_in_model_before(self) -> list[torch._ops.OpOverload]: + # Note: RMSNorm is no longer asserted here. After the vLLM IR + # migration (#33825), `RMSNorm` dispatches through `ir.ops.rms_norm` + # which resolves via `IrOpPriorityConfig`. The op that actually + # appears in the pre-pass graph depends on the platform's priority + # (native / vllm_c / aiter / oink / ...) and is outside the scope of + # this fusion test. + ops: list[torch._ops.OpOverload] = [] + # RoPE is not yet IR-migrated, so its custom op still surfaces + # directly in the graph based on `enable_rope_custom_op`. + if self.enable_rope_custom_op: + if rocm_aiter_ops.is_triton_rotary_embed_enabled(): + ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default) + else: + ops.append(ROTARY_OP) + else: + ops.append(INDEX_SELECT_OP) + ops.append(torch.ops.vllm.unified_kv_cache_update.default) + return ops + + def ops_in_model_after(self) -> list[torch._ops.OpOverload]: + return [torch.ops.vllm.fused_qk_norm_rope_and_unified_kv_cache_update.default] + + +@pytest.mark.parametrize( + "attn_backend", + [ + AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + ], +) +@pytest.mark.parametrize("enable_aiter_triton_rope", [True, False]) +@pytest.mark.parametrize("num_heads", [64]) +@pytest.mark.parametrize("num_kv_heads", [8]) +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("is_neox", [True, False]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) +@pytest.mark.parametrize("rms_norm_eps", [1e-5, 1e-6]) +@pytest.mark.skipif( + not is_aiter_found_and_supported(), + reason="Only test on ROCm with AITER installed and supported", +) +def test_qk_norm_rope_kvcache_fusion( + attn_backend: AttentionBackendEnum, + enable_aiter_triton_rope: bool, + num_heads: int, + num_kv_heads: int, + head_size: int, + block_size: int, + is_neox: bool, + dtype: torch.dtype, + kv_cache_dtype: str, + rms_norm_eps: float, + monkeypatch: pytest.MonkeyPatch, +): + device = os.environ.get("VLLM_TEST_CUDA_DEVICE", "cuda") + torch.set_default_device(device) + torch.set_default_dtype(dtype) + torch.manual_seed(0) + + # Note: `+rms_norm` toggles between RMSNorm.forward_custom and + # forward_native, but both paths now dispatch through `ir.ops.rms_norm` + # (post #33825), so the graph is identical either way. We always enable + # it here to exercise the "custom op on" flavor. + custom_ops: list[str] = ["+rotary_embedding", "+rms_norm"] + + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + cache_config=CacheConfig( + block_size=block_size, + cache_dtype=kv_cache_dtype, + ), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=custom_ops, + pass_config=PassConfig( + fuse_qk_norm_rope_kvcache=True, + eliminate_noops=True, + ), + ), + ) + + with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m: + m.setenv("VLLM_ROCM_USE_AITER", "1") + m.setenv( + "VLLM_ROCM_USE_AITER_TRITON_ROPE", + "1" if enable_aiter_triton_rope else "0", + ) + rocm_aiter_ops.refresh_env_variables() + + model = QKNormRoPEKVCacheTestModel( + vllm_config=vllm_config, + attn_backend=attn_backend, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + is_neox=is_neox, + rms_norm_eps=rms_norm_eps, + dtype=dtype, + device=torch.get_default_device(), + ) + + fusion_pass = QkNormRopeKvCacheFusionPass(vllm_config) + passes = [ + NoOpEliminationPass(vllm_config), + SplitCoalescingPass(vllm_config), + ScatterSplitReplacementPass(vllm_config), + fusion_pass, + PostCleanupPass(vllm_config), + ] + backend = TestBackend(*passes) + + T = 5 + + qkv = torch.randn( + T, + num_heads * head_size + 2 * num_kv_heads * head_size, + dtype=dtype, + ) + pos = torch.arange(T, dtype=torch.long) + + qkv_unfused = qkv.clone() + pos_unfused = pos.clone() + + # Run unfused (eager) forward + with set_forward_context(None, vllm_config): + forward_context = get_forward_context() + attn_metadata = model.build_attn_metadata(T) + forward_context.slot_mapping = { + model.layer_name: attn_metadata.slot_mapping + } + q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused) + attn_layer = forward_context.no_compile_layers[model.layer_name] + kv_cache_unfused = attn_layer.kv_cache + del dummy + + # Run fused (compiled) forward + torch._dynamo.mark_dynamic(qkv, 0) + torch._dynamo.mark_dynamic(pos, 0) + with set_forward_context(None, vllm_config): + model_fused = torch.compile(model, backend=backend) + forward_context = get_forward_context() + attn_metadata = model_fused.build_attn_metadata(T) + forward_context.slot_mapping = { + model.layer_name: attn_metadata.slot_mapping + } + q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos) + attn_layer = forward_context.no_compile_layers[model.layer_name] + kv_cache_fused = attn_layer.kv_cache + del dummy + + assert fusion_pass.matched_count == 1 + + backend.check_before_ops(model.ops_in_model_before()) + backend.check_after_ops(model.ops_in_model_after()) + + ATOL, RTOL = (1e-2, 1e-2) + is_fp8_cache = model.kv_cache_dtype != dtype + + torch.testing.assert_close(q_unfused, q_fused, atol=ATOL, rtol=RTOL) + + if not is_fp8_cache: + # The AITER PTS kernel populates k_out only for non-FP8 caches. + # With FP8, the kernel writes quantized K directly to the cache + # and may leave k_out uninitialised. In production this is fine + # because downstream attention reads K from the cache. + torch.testing.assert_close(k_unfused, k_fused, atol=ATOL, rtol=RTOL) + + torch.testing.assert_close(v_unfused, v_fused, atol=ATOL, rtol=RTOL) + + uses_interleaved_v = getattr(model.attn.impl, "_use_interleaved_v_cache", False) + cache_atol = 5e-2 if is_fp8_cache else ATOL + cache_rtol = 1.0 if is_fp8_cache else RTOL + + # K-cache: same layout for both paths, always compare directly. + torch.testing.assert_close( + kv_cache_unfused[0].view(dtype), + kv_cache_fused[0].view(dtype), + atol=cache_atol, + rtol=cache_rtol, + ) + + if uses_interleaved_v: + # The fused AITER kernel writes V-cache in interleaved layout + # [blocks, heads, block_size/x, head_dim, x] while the unfused + # write_to_paged_cache uses standard [blocks, heads, head_dim, + # block_size]. Transform interleaved → standard before comparing. + # + # split_kv_cache views the raw [n, BS, H, D] as [n, H, D, BS]. + # In that view the interleaved data is laid out as + # [BS//x, D, x] per (block, head), so: + # reshape → [n, H, BS//x, D, x] + # permute → [n, H, D, BS//x, x] + # reshape → [n, H, D, BS] (standard layout) + x_il = 16 // kv_cache_fused.element_size() + n_blk = kv_cache_fused.shape[1] + + v_unfused_view = kv_cache_unfused[1].view( + n_blk, num_kv_heads, head_size, block_size + ) + v_fused_view = kv_cache_fused[1].view( + n_blk, num_kv_heads, head_size, block_size + ) + v_fused_std = ( + v_fused_view.reshape( + n_blk, num_kv_heads, block_size // x_il, head_size, x_il + ) + .permute(0, 1, 3, 2, 4) + .contiguous() + .reshape(n_blk, num_kv_heads, head_size, block_size) + ) + torch.testing.assert_close( + v_unfused_view.view(dtype), + v_fused_std.view(dtype), + atol=cache_atol, + rtol=cache_rtol, + ) + else: + torch.testing.assert_close( + kv_cache_unfused[1].view(dtype), + kv_cache_fused[1].view(dtype), + atol=cache_atol, + rtol=cache_rtol, + ) + + +@pytest.mark.skipif( + not is_aiter_found_and_supported(), + reason="Only test on ROCm with AITER installed and supported", +) +def test_qk_norm_rope_kvcache_pattern_match_smoke( + monkeypatch: pytest.MonkeyPatch, +): + """Minimal smoke test for the QK-norm+RoPE+KVCache pattern matcher. + + Verifies that the fusion pass finds and replaces the unfused pattern + exactly once. Skips the full accuracy + KV cache comparison done by + ``test_qk_norm_rope_kvcache_fusion`` so it runs in a few seconds and + is suitable for iterating on the matcher itself. + """ + device = os.environ.get("VLLM_TEST_CUDA_DEVICE", "cuda") + dtype = torch.bfloat16 + torch.set_default_device(device) + torch.set_default_dtype(dtype) + torch.manual_seed(0) + + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + cache_config=CacheConfig(block_size=16, cache_dtype="auto"), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=["+rotary_embedding", "+rms_norm"], + pass_config=PassConfig( + fuse_qk_norm_rope_kvcache=True, + eliminate_noops=True, + ), + ), + ) + + with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m: + m.setenv("VLLM_ROCM_USE_AITER", "1") + m.setenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "0") + rocm_aiter_ops.refresh_env_variables() + + model = QKNormRoPEKVCacheTestModel( + vllm_config=vllm_config, + attn_backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, + num_heads=64, + num_kv_heads=8, + head_size=64, + is_neox=True, + rms_norm_eps=1e-5, + dtype=dtype, + device=torch.get_default_device(), + ) + + fusion_pass = QkNormRopeKvCacheFusionPass(vllm_config) + backend = TestBackend( + NoOpEliminationPass(vllm_config), + SplitCoalescingPass(vllm_config), + ScatterSplitReplacementPass(vllm_config), + fusion_pass, + PostCleanupPass(vllm_config), + ) + + T = 5 + qkv = torch.randn(T, 64 * 64 + 2 * 8 * 64, dtype=dtype) + pos = torch.arange(T, dtype=torch.long) + torch._dynamo.mark_dynamic(qkv, 0) + torch._dynamo.mark_dynamic(pos, 0) + + with set_forward_context(None, vllm_config): + forward_context = get_forward_context() + attn_metadata = model.build_attn_metadata(T) + forward_context.slot_mapping = { + model.layer_name: attn_metadata.slot_mapping + } + model_fused = torch.compile(model, backend=backend) + model_fused(qkv, pos) + + assert fusion_pass.matched_count == 1, ( + f"Expected matched_count == 1, got {fusion_pass.matched_count}" + ) + # Verify the fused op ended up in the post-pass graph. We skip + # `check_before_ops` here because the pre-pass RMS-norm impl depends + # on `IrOpPriorityConfig` (native / vllm_c / aiter / ...), which is + # orthogonal to what this smoke test is validating. + backend.check_after_ops(model.ops_in_model_after()) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index ce4fc3cfbadb..939d6ed9e0a1 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -2128,6 +2128,63 @@ def triton_fp4_gemm_dynamic_quant( gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y) return y + @staticmethod + def fused_qk_norm_rope_and_cache( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + positions: torch.Tensor, + num_heads_q: int, + num_heads_k: int, + num_heads_v: int, + head_dim: int, + is_neox: bool, + rms_norm_eps: float, + q_out: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + slot_mapping: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + k_out: torch.Tensor | None, + v_out: torch.Tensor | None, + return_kv: bool, + use_shuffle_layout: bool, + block_size: int, + x: int, + ): + from aiter.ops.fused_qk_norm_rope_cache_quant import ( + fused_qk_norm_rope_cache_pts_quant_shuffle, + ) + + fused_qk_norm_rope_cache_pts_quant_shuffle( + qkv, + q_weight, + k_weight, + cos_sin_cache, + positions, + qkv.size(0), + num_heads_q, + num_heads_k, + num_heads_v, + head_dim, + is_neox, + rms_norm_eps, + q_out, + k_cache, + v_cache, + slot_mapping, + k_scale, + v_scale, + k_out, + v_out, + return_kv, + use_shuffle_layout, + block_size, + x, + ) + @staticmethod def triton_rope_and_cache( query: torch.Tensor, diff --git a/vllm/compilation/passes/fusion/act_quant_fusion.py b/vllm/compilation/passes/fusion/act_quant_fusion.py index e35fc5cd4084..7612112596b4 100644 --- a/vllm/compilation/passes/fusion/act_quant_fusion.py +++ b/vllm/compilation/passes/fusion/act_quant_fusion.py @@ -39,7 +39,9 @@ if silu_and_mul_nvfp4_quant_supported: FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501 -if current_platform.is_cuda_alike(): +if current_platform.is_cuda_alike() and hasattr( + torch.ops._C, "silu_and_mul_per_block_quant" +): FUSED_OPS[kFp8Dynamic128Sym] = torch.ops._C.silu_and_mul_per_block_quant.default FUSED_OPS[kFp8Dynamic64Sym] = torch.ops._C.silu_and_mul_per_block_quant.default diff --git a/vllm/compilation/passes/fusion/matcher_utils.py b/vllm/compilation/passes/fusion/matcher_utils.py index 9f25a6805e93..ed7e2b87b003 100644 --- a/vllm/compilation/passes/fusion/matcher_utils.py +++ b/vllm/compilation/passes/fusion/matcher_utils.py @@ -7,10 +7,11 @@ from torch._higher_order_ops import auto_functionalized from torch._ops import OpOverload +from vllm import ir from vllm._aiter_ops import rocm_aiter_ops from vllm.config import get_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNormGated +from vllm.model_executor.layers.layernorm import RMSNorm, RMSNormGated from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -29,6 +30,8 @@ ) from vllm.platforms import current_platform +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default ROTARY_OP = torch.ops._C.rotary_embedding.default FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default @@ -161,6 +164,48 @@ def forward_native( return result +class MatcherRMSNorm(MatcherCustomOp): + """Matcher for plain RMS norm (no residual add). + + Dispatches through ``vllm.ir.ops.rms_norm`` so the traced pattern + follows the same IR lowering path as the model's ``RMSNorm`` layer + (native / vllm_c / aiter / oink / ...), whichever one the current + ``IrOpPriorityConfig`` selects. This keeps the pattern aligned with + whatever impl actually appears in the target graph at runtime; callers + therefore do not need to register per-backend variants. + """ + + def __init__( + self, + epsilon: float, + enabled: bool | None = None, + ) -> None: + if enabled is None: + enabled = RMSNorm.enabled() + + super().__init__(enabled) + self.epsilon = epsilon + + def inputs(self) -> list[torch.Tensor]: + input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16) + weight = self.empty(16) + return [input, weight] + + def forward_custom( + self, + input: torch.Tensor, + weight: torch.Tensor, + ) -> torch.Tensor: + return ir.ops.rms_norm(input, weight, self.epsilon) + + def forward_native( + self, + input: torch.Tensor, + weight: torch.Tensor, + ) -> torch.Tensor: + return ir.ops.rms_norm(input, weight, self.epsilon) + + class MatcherRMSNormGated(MatcherCustomOp): """Matches RMSNormGated with norm_before_gate=True and group_size=None.""" diff --git a/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py b/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py index b7e747a784eb..28cb89f96f71 100644 --- a/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py +++ b/vllm/compilation/passes/fusion/qk_norm_rope_fusion.py @@ -11,6 +11,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass import vllm.ir.ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention @@ -18,13 +19,18 @@ from ..inductor_pass import enable_fake_mode from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass -from .matcher_utils import MatcherRotaryEmbedding +from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64 logger = init_logger(__name__) FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default +# Head dimensions supported by csrc/fused_qknorm_rope_kernel.cu's +# launchFusedQKNormRope and launchFusedQKNormRopeNTokenHeads dispatchers. +# Keep in sync with the switch statements in that file. +SUPPORTED_FUSED_QK_NORM_ROPE_HEAD_DIMS: tuple[int, ...] = (64, 128, 256) + P = ParamSpec("P") @@ -58,6 +64,7 @@ def __init__( eps: float, is_neox: bool, rope_flashinfer: bool = False, + match_rocm_aiter_rope: bool = False, ) -> None: self.num_heads = num_heads self.num_kv_heads = num_kv_heads @@ -65,6 +72,7 @@ def __init__( self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.eps = eps + self.rmsnorm_matcher = MatcherRMSNorm(eps) self.is_neox = is_neox self.rope_flashinfer = rope_flashinfer self.rope_matcher = MatcherRotaryEmbedding( @@ -73,6 +81,7 @@ def __init__( num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, use_flashinfer=self.rope_flashinfer, + match_rocm_aiter=match_rocm_aiter_rope if match_rocm_aiter_rope else None, ) def get_inputs(self) -> list[torch.Tensor]: @@ -186,7 +195,12 @@ def replacement( class QKNormRoPEFusionPass(VllmPatternMatcherPass): - """Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists.""" + """Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists. + + Registers patterns for both standard vLLM ops and ROCm AITER ops + (when AITER is enabled), so the fusion fires regardless of which + RMSNorm/RoPE implementation the graph uses. + """ @enable_fake_mode def __init__(self, config: VllmConfig) -> None: @@ -202,7 +216,6 @@ def __init__(self, config: VllmConfig) -> None: ) return - # use one attn layer to get meta (such as head_dim) for QkNormRopePattern attn_layers: dict[str, Attention] = get_layers_from_vllm_config( config, Attention ) @@ -213,26 +226,48 @@ def __init__(self, config: VllmConfig) -> None: return layer = next(iter(attn_layers.values())) - for epsilon in [1e-5, 1e-6]: - for neox in [True, False]: - if RotaryEmbedding.enabled(): - for rope_flashinfer in [False, True]: + if layer.head_size not in SUPPORTED_FUSED_QK_NORM_ROPE_HEAD_DIMS: + logger.warning_once( + "QK Norm+RoPE fusion not enabled: layer head_size=%d is not " + "supported by fused_qk_norm_rope kernel (supported: %s). " + "Falling back to unfused QK norm + RoPE path.", + layer.head_size, + SUPPORTED_FUSED_QK_NORM_ROPE_HEAD_DIMS, + ) + return + + # RMS norm variants are no longer iterated: after the vLLM IR + # migration (#33825), `MatcherRMSNorm` dispatches via + # `ir.ops.rms_norm`, which resolves to the same backend (native / + # vllm_c / aiter / oink / ...) that the model's RMSNorm layer + # picks. The pattern graph tracks the target graph automatically. + aiter_rope_variants = [False] + if rocm_aiter_ops.is_triton_rotary_embed_enabled(): + aiter_rope_variants.append(True) + + for aiter_rope in aiter_rope_variants: + for epsilon in [1e-5, 1e-6]: + for neox in [True, False]: + if RotaryEmbedding.enabled(): + for rope_flashinfer in [False, True]: + QkNormRopePattern( + head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon, + is_neox=neox, + rope_flashinfer=rope_flashinfer, + match_rocm_aiter_rope=aiter_rope, + ).register(self.patterns) + else: QkNormRopePattern( head_dim=layer.head_size, num_heads=layer.num_heads, num_kv_heads=layer.num_kv_heads, eps=epsilon, is_neox=neox, - rope_flashinfer=rope_flashinfer, + match_rocm_aiter_rope=aiter_rope, ).register(self.patterns) - else: - QkNormRopePattern( - head_dim=layer.head_size, - num_heads=layer.num_heads, - num_kv_heads=layer.num_kv_heads, - eps=epsilon, - is_neox=neox, - ).register(self.patterns) self.dump_patterns(config, self.patterns) diff --git a/vllm/compilation/passes/fusion/qk_norm_rope_kvcache_fusion.py b/vllm/compilation/passes/fusion/qk_norm_rope_kvcache_fusion.py new file mode 100644 index 000000000000..ac564fc39882 --- /dev/null +++ b/vllm/compilation/passes/fusion/qk_norm_rope_kvcache_fusion.py @@ -0,0 +1,398 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch._inductor.pattern_matcher as pm +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.fx_passes.post_grad import view_to_reshape +from torch._inductor.pattern_matcher import PatternMatcherPass + +from vllm._aiter_ops import rocm_aiter_ops +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config.utils import Range +from vllm.logger import init_logger +from vllm.model_executor.layers.attention.attention import ( + Attention, + get_attention_context, +) +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.utils.torch_utils import direct_register_custom_op + +from ..inductor_pass import enable_fake_mode +from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass +from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding +from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64 + +logger = init_logger(__name__) + + +# --------------------------------------------------------------------------- +# Custom op: fused QK-norm + RoPE + KV cache update +# --------------------------------------------------------------------------- + + +def fused_qk_norm_rope_and_unified_kv_cache_update_impl( + q_out: torch.Tensor, + k_out: torch.Tensor, + qkv: torch.Tensor, + positions: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + rms_norm_eps: float, + cos_sin_cache: torch.Tensor, + is_neox: bool, + layer_name: str = "", +) -> torch.Tensor: + _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name) + if layer_slot_mapping is not None: + attn_layer.impl.do_qk_norm_rope_kvcache_update( + attn_layer, + qkv, + q_out, + k_out, + positions, + q_weight, + k_weight, + rms_norm_eps, + cos_sin_cache, + is_neox, + kv_cache, + layer_slot_mapping, + ) + + return torch.empty(0, device=qkv.device, dtype=qkv.dtype) + + +def fused_qk_norm_rope_and_unified_kv_cache_update_fake( + q_out: torch.Tensor, + k_out: torch.Tensor, + qkv: torch.Tensor, + positions: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + rms_norm_eps: float, + cos_sin_cache: torch.Tensor, + is_neox: bool, + layer_name: str = "", +) -> torch.Tensor: + return torch.empty(0, device=qkv.device, dtype=qkv.dtype) + + +direct_register_custom_op( + op_name="fused_qk_norm_rope_and_unified_kv_cache_update", + op_func=fused_qk_norm_rope_and_unified_kv_cache_update_impl, + mutates_args=["q_out", "k_out"], + fake_impl=fused_qk_norm_rope_and_unified_kv_cache_update_fake, +) + + +# --------------------------------------------------------------------------- +# Pattern: QK-norm + RoPE + unified_kv_cache_update +# --------------------------------------------------------------------------- + + +class QkNormRopeKvCachePattern: + """ + Match the unfused sequence: + q, k, v = split(qkv, ...) + q = rms_norm(q.view(heads), q_weight).view(flat) + k = rms_norm(k.view(heads), k_weight).view(flat) + q, k = rotary_embedding(positions, q, k, cos_sin_cache, is_neox) + q = q.view(num_heads, head_dim) + k = k.view(num_kv_heads, head_dim) + v = v.view(num_kv_heads, head_dim) + dummy = unified_kv_cache_update(k, v, layer_name) + + Replace with: + q_out = empty(...) + k_out = empty(...) + dummy = fused_qk_norm_rope_and_unified_kv_cache_update( + q_out, k_out, qkv, positions, q_weight, k_weight, + eps, cos_sin_cache, is_neox, layer_name) + v = split(qkv, ...)[2].view(num_kv_heads, head_dim) + """ + + FUSED_OP = torch.ops.vllm.fused_qk_norm_rope_and_unified_kv_cache_update.default + + def __init__( + self, + layer: Attention, + eps: float, + is_neox: bool, + rope_flashinfer: bool = False, + match_rocm_aiter_rope: bool = False, + ) -> None: + self.layer_name = layer.layer_name + self.num_heads = layer.num_heads + self.num_kv_heads = layer.num_kv_heads + self.head_size = layer.head_size + self.head_size_v = layer.head_size_v + self.eps = eps + self.is_neox = is_neox + self.rope_flashinfer = rope_flashinfer + + self.q_size = self.num_heads * self.head_size + self.k_size = self.num_kv_heads * self.head_size + self.v_size = self.num_kv_heads * self.head_size_v + + self.rmsnorm_matcher = MatcherRMSNorm(eps) + self.rope_matcher = MatcherRotaryEmbedding( + is_neox=is_neox, + head_size=self.head_size, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + use_flashinfer=rope_flashinfer, + match_rocm_aiter=match_rocm_aiter_rope if match_rocm_aiter_rope else None, + ) + + def get_inputs(self) -> list[torch.Tensor]: + T = 5 + L = 4096 + qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size) + positions = empty_i64(T) + q_weight = empty_bf16(1, self.head_size) + k_weight = empty_bf16(1, self.head_size) + if self.rope_flashinfer: + cos_sin_cache = empty_fp32(L, self.head_size) + else: + cos_sin_cache = empty_bf16(L, self.head_size) + return [qkv, positions, q_weight, k_weight, cos_sin_cache] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + qkv: torch.Tensor, + positions: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + + q_by_head = q.view(-1, self.q_size // self.head_size, self.head_size) + q_normed = self.rmsnorm_matcher(q_by_head, q_weight) + q_flat = q_normed.view(-1, self.q_size) + + k_by_head = k.view(-1, self.k_size // self.head_size, self.head_size) + k_normed = self.rmsnorm_matcher(k_by_head, k_weight) + k_flat = k_normed.view(-1, self.k_size) + + q_rope, k_rope = self.rope_matcher(positions, q_flat, k_flat, cos_sin_cache) + + q_rope = q_rope.view(-1, self.num_heads, self.head_size) + k_rope = k_rope.view(-1, self.num_kv_heads, self.head_size) + v = v.view(-1, self.num_kv_heads, self.head_size_v) + dummy = torch.ops.vllm.unified_kv_cache_update(k_rope, v, self.layer_name) + return dummy, q_rope, k_rope, v + + def replacement( + qkv: torch.Tensor, + positions: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + cos_sin_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q_out = torch.empty( + qkv.shape[0], + self.num_heads, + self.head_size, + device=qkv.device, + dtype=qkv.dtype, + ) + k_out = torch.empty( + qkv.shape[0], + self.num_kv_heads, + self.head_size, + device=qkv.device, + dtype=qkv.dtype, + ) + _, _, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + v = v.view(qkv.shape[0], self.num_kv_heads, self.head_size_v) + + results = auto_functionalized( + self.FUSED_OP, + q_out=q_out, + k_out=k_out, + qkv=qkv, + positions=positions, + q_weight=q_weight, + k_weight=k_weight, + rms_norm_eps=self.eps, + cos_sin_cache=cos_sin_cache, + is_neox=self.is_neox, + layer_name=self.layer_name, + ) + + # results[0] = dummy, results[1] = q_out, results[2] = k_out + return results[0], results[1], results[2], v + + def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule: + gm = pm.fwd_only(*args, **kwargs) + view_to_reshape(gm) + return gm + + pm.register_replacement( + pattern, + replacement, + self.get_inputs(), + fwd_and_view_to_reshape, + pm_pass, + ) + + +# --------------------------------------------------------------------------- +# Pass class +# --------------------------------------------------------------------------- + + +class QkNormRopeKvCacheFusionPass(VllmPatternMatcherPass): + """ + Fuse QK-norm + RoPE + KV cache update into a single AITER HIP kernel. + + Supersedes both QKNormRoPEFusionPass and RopeKVCacheFusionPass for + attention layers that support the combined operation, eliminating two + separate kernel launches and the intermediate memory traffic. + """ + + @enable_fake_mode + def __init__(self, config: VllmConfig) -> None: + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="qk_norm_rope_kvcache_fusion_pass" + ) + + cc = config.compilation_config + self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num + + dtype = config.model_config.dtype + if dtype not in (torch.bfloat16, torch.float16): + logger.warning_once( + "QK Norm+RoPE+KVCache fusion not enabled: unsupported dtype %s", dtype + ) + return + + attn_layers = get_layers_from_vllm_config(config, Attention) + + rope_custom_enabled = cc.is_custom_op_enabled("rotary_embedding") + rms_custom_enabled = cc.is_custom_op_enabled("rms_norm") + logger.debug( + "QkNormRopeKvCacheFusionPass init: " + "RotaryEmbedding.enabled()=%s, rope_custom_enabled=%s, " + "RMSNorm custom_op_enabled=%s", + RotaryEmbedding.enabled(), + rope_custom_enabled, + rms_custom_enabled, + ) + + # RMS norm variants are no longer iterated: after the vLLM IR + # migration (#33825), `MatcherRMSNorm` dispatches via + # `ir.ops.rms_norm`, which resolves to the same backend (native / + # vllm_c / aiter / oink / ...) that the model's RMSNorm layer + # picks. The pattern graph tracks the target graph automatically. + aiter_rope_variants = [False] + if rocm_aiter_ops.is_triton_rotary_embed_enabled(): + aiter_rope_variants.append(True) + + for _, layer in attn_layers.items(): + if not layer.impl.fused_qk_norm_rope_kvcache_supported(): + continue + layer.impl.set_fused_kv_cache_layout() + for aiter_rope in aiter_rope_variants: + for epsilon in [1e-5, 1e-6]: + for neox in [True, False]: + if RotaryEmbedding.enabled(): + for rope_flashinfer in [False, True]: + try: + QkNormRopeKvCachePattern( + layer=layer, + eps=epsilon, + is_neox=neox, + rope_flashinfer=rope_flashinfer, + match_rocm_aiter_rope=aiter_rope, + ).register(self.patterns) + except RuntimeError as e: + if "Duplicate pattern" in str(e): + logger.debug( + "Skipping duplicate pattern: " + "aiter_rope=%s eps=%s neox=%s fi=%s", + aiter_rope, + epsilon, + neox, + rope_flashinfer, + ) + else: + raise + else: + try: + QkNormRopeKvCachePattern( + layer=layer, + eps=epsilon, + is_neox=neox, + match_rocm_aiter_rope=aiter_rope, + ).register(self.patterns) + except RuntimeError as e: + if "Duplicate pattern" in str(e): + logger.debug( + "Skipping duplicate pattern: " + "aiter_rope=%s eps=%s neox=%s fi=N/A", + aiter_rope, + epsilon, + neox, + ) + else: + raise + + # Backends that set _use_interleaved_v_cache (e.g. ROCM_ATTN) + # require a consistent V-cache layout across ALL compile ranges. + # If max_token_num is too small, unfused ranges would write + # standard-layout V while the attention kernel reads interleaved, + # corrupting long-sequence generation. Force fusion to cover all + # ranges so both write and read paths agree on the layout. + max_batched = config.scheduler_config.max_num_batched_tokens + needs_full_coverage = any( + getattr(layer.impl, "_use_interleaved_v_cache", False) + for _, layer in attn_layers.items() + if layer.impl.fused_qk_norm_rope_kvcache_supported() + ) + if ( + needs_full_coverage + and max_batched is not None + and self.max_token_num < max_batched + ): + logger.info( + "Raising rope_kvcache_fusion_max_token_num from %d to %d " + "to maintain consistent interleaved V-cache layout across " + "all compile ranges (required by attention backend).", + self.max_token_num, + max_batched, + ) + self.max_token_num = max_batched + + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph) -> None: + _orig_fx_to_pat = pm.fx_to_pattern + + def _relaxed_fx_to_pattern(*a, **kw): + kw["ignore_types"] = (int, torch.SymInt) + return _orig_fx_to_pat(*a, **kw) + + pm.fx_to_pattern = _relaxed_fx_to_pattern + try: + self.matched_count = self.patterns.apply(graph) + finally: + pm.fx_to_pattern = _orig_fx_to_pat + + logger.info( + "QK-Norm+RoPE+KVCache fusion: replaced %s pattern(s) " + "with AITER fused_qk_norm_rope_cache_pts_quant_shuffle", + self.matched_count, + ) + + def is_applicable_for_range(self, compile_range: Range) -> bool: + return compile_range.end <= self.max_token_num + + def uuid(self) -> str: + return VllmInductorPass.hash_source(self, QkNormRopeKvCachePattern) diff --git a/vllm/compilation/passes/pass_manager.py b/vllm/compilation/passes/pass_manager.py index 9c86518a946e..693a111e5894 100644 --- a/vllm/compilation/passes/pass_manager.py +++ b/vllm/compilation/passes/pass_manager.py @@ -35,6 +35,7 @@ from .fusion.mla_attn_quant_fusion import MLAAttnQuantFusionPass from .fusion.mla_rope_kvcache_cat_fusion import MLARoPEKVCacheCatFusionPass from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass + from .fusion.qk_norm_rope_kvcache_fusion import QkNormRopeKvCacheFusionPass from .fusion.rms_quant_fusion import RMSNormQuantFusionPass from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass from .fusion.sequence_parallelism import SequenceParallelismPass @@ -169,6 +170,14 @@ def configure(self, config: VllmConfig) -> None: if rocm_aiter_ops.is_enabled(): self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)] + if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled(): + self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)] + + if self.pass_config.fuse_qk_norm_rope_kvcache: + self.passes += [SplitCoalescingPass(config)] + self.passes += [ScatterSplitReplacementPass(config)] + self.passes += [QkNormRopeKvCacheFusionPass(config)] + if ( self.pass_config.fuse_mla_dual_rms_norm and rocm_aiter_ops.is_enabled() diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 7b9478035ece..91e44c4bf987 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -148,10 +148,16 @@ class PassConfig: """Fuse paired q/kv RMS norms in MLA attention.""" fuse_rope_kvcache: bool = None # type: ignore[assignment] """Fuse the QK rope + KV cache ops.""" + fuse_qk_norm_rope_kvcache: bool = Field(default=None) # type: ignore[assignment] + """Fuse QK RMSNorm + RoPE + KV cache update into a single AITER HIP + kernel. Supersedes both enable_qk_norm_rope_fusion and fuse_rope_kvcache + for layers that support it. Auto-enabled at O1+ on ROCm for models + with QK-norm (e.g. Qwen3-MoE).""" rope_kvcache_fusion_max_token_num: int = 256 """The threshold for ROCm AITER RoPE+KVCache fusion e.g. for small batch decode. Larger batch sizes e.g. during prefill will use the unfused kernels. + Also applies to the fused QK-Norm+RoPE+KVCache pass. """ fi_allreduce_fusion_max_size_mb: float | None = None @@ -230,6 +236,8 @@ def compute_hash(self) -> str: "fuse_act_padding", "fuse_mla_dual_rms_norm", "fuse_rope_kvcache", + "fuse_qk_norm_rope_kvcache", + "enable_qk_norm_rope_fusion", "fuse_rope_kvcache_cat_mla", mode="wrap", ) @@ -288,6 +296,12 @@ def __post_init__(self) -> None: "The fusion will be disabled." ) self.fuse_rope_kvcache = False + if self.fuse_qk_norm_rope_kvcache and not current_platform.is_rocm(): + logger.warning_once( + "QK-Norm+RoPE+KVCache fusion requires ROCm with AITER. " + "The fusion will be disabled." + ) + self.fuse_qk_norm_rope_kvcache = False if self.fuse_rope_kvcache_cat_mla and not current_platform.is_cuda_alike(): logger.warning_once( "MLA KV cache update with RoPE fusion enabled but the " @@ -302,10 +316,13 @@ def log_enabled_passes(self) -> None: after all defaults are finalized. TODO also log the compile ranges for which this is enabled. """ + fusion_prefixes = ("fuse_", "enable_") enabled_fusions = [ - f.name[len("fuse_") :] + f.name[len(prefix) :] for f in fields(self) # type: ignore[arg-type] - if getattr(self, f.name) and f.name.startswith("fuse_") + if getattr(self, f.name) + for prefix in fusion_prefixes + if f.name.startswith(prefix) ] if enabled_fusions: @@ -946,6 +963,7 @@ def __post_init__(self) -> None: # TODO(zhuhaoran): support rope native forward match and remove this. # Linked issue: https://github.com/vllm-project/vllm/issues/28042 self.custom_ops.append("+rotary_embedding") + if ( self.pass_config.fuse_rope_kvcache and "+rotary_embedding" not in self.custom_ops @@ -954,6 +972,12 @@ def __post_init__(self) -> None: # Linked issue: https://github.com/vllm-project/vllm/issues/28042 self.custom_ops.append("+rotary_embedding") + if ( + self.pass_config.fuse_qk_norm_rope_kvcache + and "+rotary_embedding" not in self.custom_ops + ): + self.custom_ops.append("+rotary_embedding") + if ( is_torch_equal_or_newer("2.9.0.dev") and "combo_kernels" not in self.inductor_compile_config @@ -1136,6 +1160,16 @@ def set_splitting_ops_for_v1( "to enable RoPE+KV cache fusion." ) self.pass_config.fuse_rope_kvcache = False + if self.pass_config.fuse_qk_norm_rope_kvcache: + logger.warning_once( + "fuse_qk_norm_rope_kvcache is enabled, but " + "splitting_ops is None and Inductor graph partition " + "is not enabled. Disabling fuse_qk_norm_rope_kvcache. " + "Please either set splitting_ops to an empty list [] " + "or set use_inductor_graph_partition to True " + "to enable QK-Norm+RoPE+KV cache fusion." + ) + self.pass_config.fuse_qk_norm_rope_kvcache = False self.splitting_ops.append("vllm::unified_kv_cache_update") self.splitting_ops.append("vllm::unified_mla_kv_cache_update") diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index f009dd6f154e..11fc8af23f64 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -185,6 +185,71 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool: return rocm_aiter_ops.is_enabled() and check_aiter_fused_qk_rmsnorm() +# Architectures where QK-norm is hardcoded (not a config option). +# The fused_qk_norm_rope kernel matches the pattern: +# split(QKV) -> view -> RMSNorm(Q) -> view -> RMSNorm(K) -> view -> RoPE +# Add model_type values here when new architectures hardcode QK-norm +# with vllm's standard RMSNorm. +_QK_NORM_MODEL_TYPES = frozenset( + { + "qwen3", + "qwen3_moe", + } +) + + +def enable_qk_norm_rope_kvcache(cfg: "VllmConfig") -> bool: + """Enable fused QK-norm + RoPE + KV cache update for models with + QK-norm on ROCm with AITER. Requires rotary embedding custom op. + + Note: this callable does not check use_inductor_graph_partition; + if the user enables fuse_qk_norm_rope_kvcache without it, + CompilationConfig disables the fusion at compile-time with a + warning (mirroring the fuse_rope_kvcache precedent).""" + from vllm._aiter_ops import rocm_aiter_ops + from vllm.platforms import current_platform + + if not current_platform.is_rocm(): + return False + if not rocm_aiter_ops.is_enabled(): + return False + if cfg.model_config is None: + return False + hf_config = cfg.model_config.hf_text_config + has_qk_norm = getattr(hf_config, "qk_norm", False) + model_type = getattr(hf_config, "model_type", "") + if not has_qk_norm and model_type not in _QK_NORM_MODEL_TYPES: + return False + return cfg.compilation_config.is_custom_op_enabled("rotary_embedding") + + +def enable_qk_norm_rope(cfg: "VllmConfig") -> bool: + """Enable QK-norm + RoPE fusion for models with QK-norm layers. + + Detection uses two strategies: + 1. Explicit config: checks ``hf_config.qk_norm`` for models that + expose QK-norm as a config option (e.g. BAGEL uses the Qwen2 + architecture with ``qk_norm=True`` injected at load time). + 2. Architecture: checks ``hf_config.model_type`` against + ``_QK_NORM_MODEL_TYPES`` for models that hardcode QK-norm + in their architecture without a config.json field + (e.g. Qwen3 and Qwen3-MoE always use QK-norm). + """ + from vllm.platforms import current_platform + + if not current_platform.is_cuda_alike(): + return False + if cfg.model_config is None: + return False + hf_config = cfg.model_config.hf_text_config + # Some models inject qk_norm at load time (e.g. BAGEL on Qwen2 arch) + if getattr(hf_config, "qk_norm", False): + return True + # Qwen3/Qwen3-MoE hardcode QK-norm - no config.json field exists + model_type = getattr(hf_config, "model_type", "") + return model_type in _QK_NORM_MODEL_TYPES + + OPTIMIZATION_LEVEL_00 = { "compilation_config": { "pass_config": { @@ -197,6 +262,8 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool: "fuse_act_padding": False, "fuse_mla_dual_rms_norm": False, "fuse_rope_kvcache": False, + "fuse_qk_norm_rope_kvcache": False, + "enable_qk_norm_rope_fusion": False, "fuse_rope_kvcache_cat_mla": False, }, "cudagraph_mode": CUDAGraphMode.NONE, @@ -218,6 +285,8 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool: "fuse_act_padding": enable_norm_pad_fusion, "fuse_mla_dual_rms_norm": enable_mla_dual_rms_norm_fusion, "fuse_rope_kvcache": False, + "fuse_qk_norm_rope_kvcache": enable_qk_norm_rope_kvcache, + "enable_qk_norm_rope_fusion": enable_qk_norm_rope, "fuse_rope_kvcache_cat_mla": False, }, "cudagraph_mode": CUDAGraphMode.PIECEWISE, @@ -239,6 +308,8 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool: "fuse_act_padding": enable_norm_pad_fusion, "fuse_mla_dual_rms_norm": enable_mla_dual_rms_norm_fusion, "fuse_rope_kvcache": enable_rope_kvcache_fusion, + "fuse_qk_norm_rope_kvcache": enable_qk_norm_rope_kvcache, + "enable_qk_norm_rope_fusion": enable_qk_norm_rope, "fuse_rope_kvcache_cat_mla": enable_rope_kvcache_mla_fusion, }, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, @@ -260,6 +331,8 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool: "fuse_act_padding": enable_norm_pad_fusion, "fuse_mla_dual_rms_norm": enable_mla_dual_rms_norm_fusion, "fuse_rope_kvcache": enable_rope_kvcache_fusion, + "fuse_qk_norm_rope_kvcache": enable_qk_norm_rope_kvcache, + "enable_qk_norm_rope_fusion": enable_qk_norm_rope, "fuse_rope_kvcache_cat_mla": enable_rope_kvcache_mla_fusion, }, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, @@ -1115,6 +1188,22 @@ def has_blocked_weights(): "optimization level defaults." ) + # Fusion flags may have been auto-enabled by optimization-level + # callables above. CompilationConfig.__post_init__ already ran + # (during construction) when these flags were still None, so we + # must apply the dependent settings here. + pass_config = self.compilation_config.pass_config + if ( + pass_config.enable_qk_norm_rope_fusion + and "+rotary_embedding" not in self.compilation_config.custom_ops + ): + self.compilation_config.custom_ops.append("+rotary_embedding") + if ( + pass_config.fuse_qk_norm_rope_kvcache + and "+rotary_embedding" not in self.compilation_config.custom_ops + ): + self.compilation_config.custom_ops.append("+rotary_embedding") + if ( self.compilation_config.cudagraph_mode.requires_piecewise_compilation() and self.compilation_config.mode != CompilationMode.VLLM_COMPILE @@ -1846,6 +1935,21 @@ def _set_compile_ranges(self): compile_range_end, ) + if compilation_config.pass_config.fuse_qk_norm_rope_kvcache: + max_token_num = ( + compilation_config.pass_config.rope_kvcache_fusion_max_token_num + ) + if max_token_num is not None: + if compile_range_end is not None and max_token_num < compile_range_end: + computed_compile_ranges_endpoints.append(max_token_num) + else: + logger.debug( + "Max num batched tokens below qk_norm+rope+kvcache " + "fusion threshold, fusion enabled for " + "num_tokens <= %d.", + compile_range_end, + ) + if compilation_config.pass_config.fuse_minimax_qk_norm: from vllm.compilation.passes.fusion.minimax_qk_norm_fusion import ( MAX_TOKEN_NUM, diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index d83489238d33..60f8d6cd5e9f 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -812,6 +812,20 @@ def fused_output_quant_supported(self, quant_key: "QuantKey"): """ return False + def fused_qk_norm_rope_kvcache_supported(self): + """ + Does this attention implementation support fused QKNorm+RoPE+KVCache fusion. + This is used by the QkNormRopeKvCachePattern to only fuse the QKNorm ops + with the RoPE ops and the KV cache update for implementations that support it. + """ + return False + + def set_fused_kv_cache_layout(self): + """Called by the fusion pass after confirming this layer will use + the fused kernel. Backends that need to adjust their KV cache read + path (e.g. permute strides) should override this.""" + pass + def fused_rope_kvcache_supported(self): """ Does this attention implementation support RoPE+KVCache fusion. @@ -820,6 +834,29 @@ def fused_rope_kvcache_supported(self): """ return False + def do_qk_norm_rope_kvcache_update( + self, + layer: AttentionLayer, + qkv: torch.Tensor, + q_out: torch.Tensor, + k_out: torch.Tensor, + positions: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + rms_norm_eps: float, + cos_sin_cache: torch.Tensor, + is_neox: bool, + kv_cache: torch.Tensor, + layer_slot_mapping: torch.Tensor, + ): + """ + If `fused_qk_norm_rope_kvcache_supported` returns True, this method + will be called by the fused custom op. Applies QK-norm + RoPE and + writes K/V to the KV cache. Results are written to the pre-allocated + q_out and k_out tensors; V is split from QKV at the graph level. + """ + raise NotImplementedError + def do_rope_and_kv_cache_update( self, layer: AttentionLayer, diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 5dbedc86bc02..67e8bdce2244 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with AiterFlashAttention.""" +import math from dataclasses import dataclass from typing import ClassVar @@ -816,6 +817,11 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self._cached_k_scale_val: float | None = None + self._cached_k_scale_cpu: torch.Tensor | None = None + self._cached_v_scale_val: float | None = None + self._cached_v_scale_cpu: torch.Tensor | None = None + if attn_type != AttentionType.DECODER: raise NotImplementedError( "Only decoder self-attention is supported for " @@ -1434,6 +1440,89 @@ def fused_rope_kvcache_supported(self): and not rocm_aiter_ops.is_shuffle_kv_cache_enabled() ) + def fused_qk_norm_rope_kvcache_supported(self): + # Fusion is supported in both shuffle and non-shuffle KV cache layouts. + return rocm_aiter_ops.is_enabled() + + def set_fused_kv_cache_layout(self): + # No-op: this backend uses the AITER flash attention kernel for + # decode, which reads V in the same layout the AITER fused write + # produces, so no layout adjustment is needed. + pass + + def do_qk_norm_rope_kvcache_update( + self, + layer: AttentionLayer, + qkv: torch.Tensor, + q_out: torch.Tensor, + k_out: torch.Tensor, + positions: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + rms_norm_eps: float, + cos_sin_cache: torch.Tensor, + is_neox: bool, + kv_cache: torch.Tensor, + layer_slot_mapping: torch.Tensor, + ): + key_cache, value_cache = kv_cache.unbind(0) + + is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + if is_fp8_kv_cache: + key_cache = key_cache.view(current_platform.fp8_dtype()) + value_cache = value_cache.view(current_platform.fp8_dtype()) + + num_heads_q = self.num_heads + num_heads_k = self.num_kv_heads + num_heads_v = self.num_kv_heads + head_dim = self.head_size + use_shuffle_layout = rocm_aiter_ops.is_shuffle_kv_cache_enabled() + x = 16 // key_cache.element_size() + block_size = key_cache.shape[1] + + # Cache CPU scalar tensors for scales so the C++ kernel's .item() + # call doesn't trigger a device-to-host sync during CUDA graph capture. + k_scale_val = layer._k_scale_float + v_scale_val = layer._v_scale_float + if self._cached_k_scale_val is None or ( + self._cached_k_scale_val != k_scale_val + and not (math.isnan(self._cached_k_scale_val) and math.isnan(k_scale_val)) + ): + self._cached_k_scale_cpu = torch.tensor(k_scale_val, dtype=torch.float32) + self._cached_k_scale_val = k_scale_val + if self._cached_v_scale_val is None or ( + self._cached_v_scale_val != v_scale_val + and not (math.isnan(self._cached_v_scale_val) and math.isnan(v_scale_val)) + ): + self._cached_v_scale_cpu = torch.tensor(v_scale_val, dtype=torch.float32) + self._cached_v_scale_val = v_scale_val + + rocm_aiter_ops.fused_qk_norm_rope_and_cache( + qkv=qkv, + q_weight=q_weight, + k_weight=k_weight, + cos_sin_cache=cos_sin_cache, + positions=positions, + num_heads_q=num_heads_q, + num_heads_k=num_heads_k, + num_heads_v=num_heads_v, + head_dim=head_dim, + is_neox=is_neox, + rms_norm_eps=rms_norm_eps, + q_out=q_out, + k_cache=key_cache, + v_cache=value_cache, + slot_mapping=layer_slot_mapping, + k_scale=self._cached_k_scale_cpu, + v_scale=self._cached_v_scale_cpu, + k_out=k_out, + v_out=None, + return_kv=True, + use_shuffle_layout=use_shuffle_layout, + block_size=block_size, + x=x, + ) + def do_rope_and_kv_cache_update( self, layer: AttentionLayer, diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py index f56b58c43e7f..8a44cb3c1e9f 100644 --- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py +++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py @@ -263,6 +263,23 @@ def do_kv_cache_update( def fused_rope_kvcache_supported(self): return rocm_aiter_ops.is_enabled() + def fused_qk_norm_rope_kvcache_supported(self): + # Opt in even though the parent RocmAttentionImpl returns False: + # this backend uses the AITER triton unified attention kernel for + # decode (rather than the custom HIP ASM paged attention kernel), + # so it does not require the interleaved V-cache read path that + # is being added separately for ROCM_ATTN. The inherited + # do_qk_norm_rope_kvcache_update body writes K/V in standard + # layout (use_shuffle_layout=False unless globally enabled), which + # matches what this backend reads. + return rocm_aiter_ops.is_enabled() + + def set_fused_kv_cache_layout(self): + # No-op: this backend uses AITER flash/unified attention for decode, + # not the C++ HIP ASM paged attention kernel, so it does not need + # the interleaved V-cache read path that the parent would enable. + pass + def do_rope_and_kv_cache_update( self, layer: AttentionLayer, diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index d533268e2176..f18767bfec33 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with PagedAttention and Triton prefix prefill.""" +import math from dataclasses import dataclass from typing import ClassVar @@ -303,6 +304,11 @@ def __init__( f"num_heads: {num_heads}." ) + self._cached_k_scale_val: float | None = None + self._cached_k_scale_cpu: torch.Tensor | None = None + self._cached_v_scale_val: float | None = None + self._cached_v_scale_cpu: torch.Tensor | None = None + def _forward_encoder_attention( self, query: torch.Tensor, @@ -499,6 +505,98 @@ def do_kv_cache_update( layer._v_scale, ) + def fused_qk_norm_rope_kvcache_supported(self): + # Gated off for ROCM_ATTN itself in this PR. ROCM_ATTN engages + # the fused QK-norm + RoPE + KV-cache kernel only when the V-cache + # is also written in interleaved layout, which requires kernel + # changes to the custom HIP paged-attention decode kernel that are + # split out into a follow-up PR. The do_qk_norm_rope_kvcache_update + # body below is shipped here so RocmAiterUnifiedAttentionImpl + # (which uses the AITER triton unified attention kernel for decode + # and therefore does NOT need an interleaved V layout) can inherit + # it and opt in via its own supported() override. + return False + + def set_fused_kv_cache_layout(self): + # No-op until ROCM_ATTN itself opts in to the fused path + # (see the follow-up PR that adds USE_INTERLEAVED_V_CACHE in + # csrc/rocm/attention.cu and the matching INTERLEAVED_V_KX path + # in prefix_prefill). + pass + + def do_qk_norm_rope_kvcache_update( + self, + layer: AttentionLayer, + qkv: torch.Tensor, + q_out: torch.Tensor, + k_out: torch.Tensor, + positions: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + rms_norm_eps: float, + cos_sin_cache: torch.Tensor, + is_neox: bool, + kv_cache: torch.Tensor, + layer_slot_mapping: torch.Tensor, + ): + key_cache, value_cache = kv_cache.unbind(0) + + is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + if is_fp8_kv_cache: + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + + num_heads_q = self.num_heads + num_heads_k = self.num_kv_heads + num_heads_v = self.num_kv_heads + head_dim = self.head_size + use_shuffle_layout = rocm_aiter_ops.is_shuffle_kv_cache_enabled() + block_size = key_cache.shape[1] + x = 16 // key_cache.element_size() + + # Use CPU scalar tensors for scales so the C++ kernel's .item() + # call doesn't trigger a device-to-host sync during CUDA graph capture. + k_scale_val = layer._k_scale_float + v_scale_val = layer._v_scale_float + if self._cached_k_scale_val is None or ( + self._cached_k_scale_val != k_scale_val + and not (math.isnan(self._cached_k_scale_val) and math.isnan(k_scale_val)) + ): + self._cached_k_scale_cpu = torch.tensor(k_scale_val, dtype=torch.float32) + self._cached_k_scale_val = k_scale_val + if self._cached_v_scale_val is None or ( + self._cached_v_scale_val != v_scale_val + and not (math.isnan(self._cached_v_scale_val) and math.isnan(v_scale_val)) + ): + self._cached_v_scale_cpu = torch.tensor(v_scale_val, dtype=torch.float32) + self._cached_v_scale_val = v_scale_val + + rocm_aiter_ops.fused_qk_norm_rope_and_cache( + qkv=qkv, + q_weight=q_weight, + k_weight=k_weight, + cos_sin_cache=cos_sin_cache, + positions=positions, + num_heads_q=num_heads_q, + num_heads_k=num_heads_k, + num_heads_v=num_heads_v, + head_dim=head_dim, + is_neox=is_neox, + rms_norm_eps=rms_norm_eps, + q_out=q_out, + k_cache=key_cache, + v_cache=value_cache, + slot_mapping=layer_slot_mapping, + k_scale=self._cached_k_scale_cpu, + v_scale=self._cached_v_scale_cpu, + k_out=k_out, + v_out=None, + return_kv=True, + use_shuffle_layout=use_shuffle_layout, + block_size=block_size, + x=x, + ) + def fused_rope_kvcache_supported(self): return rocm_aiter_ops.is_enabled() From 093e48feb32c3f0745f4b1857c89d0a6e4e844e2 Mon Sep 17 00:00:00 2001 From: Olga Miroshnichenko Date: Tue, 26 May 2026 08:41:11 -0500 Subject: [PATCH 2/2] fix(rocm): plumb rotary_dim through fused QK-norm+RoPE+KV-cache kernel The aiter fused kernel fused_qk_norm_rope_cache_pts_quant_shuffle accepts a rotary_dim parameter (default 0 -> falls back to HEAD_SIZE). vLLM was never threading it through, so models with partial_rotary_factor < 1 (e.g. GLM-4.7 with factor=0.5 -> rotary_dim=64) silently had full RoPE applied to all 128 head dims, plus out-of-bounds reads on the half-length cos_sin_cache. Result on GLM-4.7-FP8: gsm8k 0.92 -> 0.79 with fusion enabled, restored to 0.92 with this fix. The fix derives rotary_dim from the cos_sin_cache tensor itself: RotaryEmbedding._compute_cos_sin_cache builds cache = cat(cos, sin, -1) where cos/sin each have last dim = rotary_dim/2, so cos_sin_cache.shape[-1] == rotary_dim holds for every RoPE variant. Adds a defensive layout check (if not (...): raise ValueError, not an assert, so the check survives python -O) to fail loud rather than silently produce wrong outputs on future RoPE variants. Touches the two callers of fused_qk_norm_rope_and_cache: - vllm/v1/attention/backends/rocm_aiter_fa.py (ROCM_AITER_FA) - vllm/v1/attention/backends/rocm_attn.py (UNIFIED_ATTN) and extends the helper signature in vllm/_aiter_ops.py. --- vllm/_aiter_ops.py | 9 ++++++++ vllm/v1/attention/backends/rocm_aiter_fa.py | 25 +++++++++++++++++++++ vllm/v1/attention/backends/rocm_attn.py | 25 +++++++++++++++++++++ 3 files changed, 59 insertions(+) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 939d6ed9e0a1..52f709df9978 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -2153,7 +2153,15 @@ def fused_qk_norm_rope_and_cache( use_shuffle_layout: bool, block_size: int, x: int, + rotary_dim: int = 0, ): + # Partial-RoPE support: when rotary_dim < head_dim, the fused kernel + # rotates only the first `rotary_dim` elements of each head and + # leaves the remainder pass-through (e.g. GLM-4.7 has + # partial_rotary_factor=0.5, so head_dim=128 and rotary_dim=64). + # The aiter kernel treats rotary_dim==0 as "full head", so callers + # that don't pass it correctly silently apply full RoPE -> garbage + # outputs for partial-RoPE models. from aiter.ops.fused_qk_norm_rope_cache_quant import ( fused_qk_norm_rope_cache_pts_quant_shuffle, ) @@ -2183,6 +2191,7 @@ def fused_qk_norm_rope_and_cache( use_shuffle_layout, block_size, x, + rotary_dim, ) @staticmethod diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 67e8bdce2244..3cafe2abf45f 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -1480,6 +1480,30 @@ def do_qk_norm_rope_kvcache_update( x = 16 // key_cache.element_size() block_size = key_cache.shape[1] + # Partial-RoPE: derive rotary_dim from the cos_sin_cache layout. + # RotaryEmbedding builds cache = cat(cos, sin, dim=-1) where + # cos/sin each have last dim = rotary_dim/2, so + # cos_sin_cache.shape[-1] == rotary_dim exactly. Without this the + # aiter fused kernel defaults rotary_dim to head_dim and applies + # full RoPE to partial-RoPE models (e.g. GLM-4.7 has + # partial_rotary_factor=0.5). + rotary_dim = cos_sin_cache.shape[-1] + # Fail loud if a future model arrives with an unexpected layout: + # silently using the wrong rotary_dim produces correct-looking but + # numerically wrong outputs (gsm8k 0.92 -> 0.79 regression cause). + # Use an explicit raise (not assert) so the check survives `python -O`. + if not ( + cos_sin_cache.ndim == 2 + and 0 < rotary_dim <= head_dim + and rotary_dim % 2 == 0 + ): + raise ValueError( + f"fused_qk_norm_rope_and_cache: unexpected cos_sin_cache " + f"layout {tuple(cos_sin_cache.shape)} for " + f"head_dim={head_dim}; expected shape [max_pos, rotary_dim] " + f"with rotary_dim<=head_dim and even" + ) + # Cache CPU scalar tensors for scales so the C++ kernel's .item() # call doesn't trigger a device-to-host sync during CUDA graph capture. k_scale_val = layer._k_scale_float @@ -1521,6 +1545,7 @@ def do_qk_norm_rope_kvcache_update( use_shuffle_layout=use_shuffle_layout, block_size=block_size, x=x, + rotary_dim=rotary_dim, ) def do_rope_and_kv_cache_update( diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index f18767bfec33..6557b2182aa8 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -554,6 +554,30 @@ def do_qk_norm_rope_kvcache_update( block_size = key_cache.shape[1] x = 16 // key_cache.element_size() + # Partial-RoPE: derive rotary_dim from the cos_sin_cache layout. + # RotaryEmbedding builds cache = cat(cos, sin, dim=-1) where + # cos/sin each have last dim = rotary_dim/2, so + # cos_sin_cache.shape[-1] == rotary_dim exactly. Without this the + # aiter fused kernel defaults rotary_dim to head_dim and applies + # full RoPE to partial-RoPE models (e.g. GLM-4.7 has + # partial_rotary_factor=0.5). + rotary_dim = cos_sin_cache.shape[-1] + # Fail loud if a future model arrives with an unexpected layout: + # silently using the wrong rotary_dim produces correct-looking but + # numerically wrong outputs (gsm8k 0.92 -> 0.79 regression cause). + # Use an explicit raise (not assert) so the check survives `python -O`. + if not ( + cos_sin_cache.ndim == 2 + and 0 < rotary_dim <= head_dim + and rotary_dim % 2 == 0 + ): + raise ValueError( + f"fused_qk_norm_rope_and_cache: unexpected cos_sin_cache " + f"layout {tuple(cos_sin_cache.shape)} for " + f"head_dim={head_dim}; expected shape [max_pos, rotary_dim] " + f"with rotary_dim<=head_dim and even" + ) + # Use CPU scalar tensors for scales so the C++ kernel's .item() # call doesn't trigger a device-to-host sync during CUDA graph capture. k_scale_val = layer._k_scale_float @@ -595,6 +619,7 @@ def do_qk_norm_rope_kvcache_update( use_shuffle_layout=use_shuffle_layout, block_size=block_size, x=x, + rotary_dim=rotary_dim, ) def fused_rope_kvcache_supported(self):