Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
520 changes: 520 additions & 0 deletions tests/compile/passes/test_qk_norm_rope_kvcache_fusion.py

Large diffs are not rendered by default.

66 changes: 66 additions & 0 deletions vllm/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2128,6 +2128,72 @@ def triton_fp4_gemm_dynamic_quant(
gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
return y

@staticmethod
def fused_qk_norm_rope_and_cache(
qkv: torch.Tensor,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
cos_sin_cache: torch.Tensor,
positions: torch.Tensor,
num_heads_q: int,
num_heads_k: int,
num_heads_v: int,
head_dim: int,
is_neox: bool,
rms_norm_eps: float,
q_out: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
slot_mapping: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
k_out: torch.Tensor | None,
v_out: torch.Tensor | None,
return_kv: bool,
use_shuffle_layout: bool,
block_size: int,
x: int,
rotary_dim: int = 0,
):
# Partial-RoPE support: when rotary_dim < head_dim, the fused kernel
# rotates only the first `rotary_dim` elements of each head and
# leaves the remainder pass-through (e.g. GLM-4.7 has
# partial_rotary_factor=0.5, so head_dim=128 and rotary_dim=64).
# The aiter kernel treats rotary_dim==0 as "full head", so callers
# that don't pass it correctly silently apply full RoPE -> garbage
# outputs for partial-RoPE models.
from aiter.ops.fused_qk_norm_rope_cache_quant import (
fused_qk_norm_rope_cache_pts_quant_shuffle,
)

fused_qk_norm_rope_cache_pts_quant_shuffle(
qkv,
q_weight,
k_weight,
cos_sin_cache,
positions,
qkv.size(0),
num_heads_q,
num_heads_k,
num_heads_v,
head_dim,
is_neox,
rms_norm_eps,
q_out,
k_cache,
v_cache,
slot_mapping,
k_scale,
v_scale,
k_out,
v_out,
return_kv,
use_shuffle_layout,
block_size,
x,
rotary_dim,
)

@staticmethod
def triton_rope_and_cache(
query: torch.Tensor,
Expand Down
4 changes: 3 additions & 1 deletion vllm/compilation/passes/fusion/act_quant_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
if silu_and_mul_nvfp4_quant_supported:
FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501

if current_platform.is_cuda_alike():
if current_platform.is_cuda_alike() and hasattr(
torch.ops._C, "silu_and_mul_per_block_quant"
):
FUSED_OPS[kFp8Dynamic128Sym] = torch.ops._C.silu_and_mul_per_block_quant.default
FUSED_OPS[kFp8Dynamic64Sym] = torch.ops._C.silu_and_mul_per_block_quant.default

Expand Down
47 changes: 46 additions & 1 deletion vllm/compilation/passes/fusion/matcher_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from torch._higher_order_ops import auto_functionalized
from torch._ops import OpOverload

from vllm import ir
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNormGated
from vllm.model_executor.layers.layernorm import RMSNorm, RMSNormGated
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
Expand All @@ -29,6 +30,8 @@
)
from vllm.platforms import current_platform

RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
ROTARY_OP = torch.ops._C.rotary_embedding.default
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default

Expand Down Expand Up @@ -161,6 +164,48 @@ def forward_native(
return result


class MatcherRMSNorm(MatcherCustomOp):
"""Matcher for plain RMS norm (no residual add).

Dispatches through ``vllm.ir.ops.rms_norm`` so the traced pattern
follows the same IR lowering path as the model's ``RMSNorm`` layer
(native / vllm_c / aiter / oink / ...), whichever one the current
``IrOpPriorityConfig`` selects. This keeps the pattern aligned with
whatever impl actually appears in the target graph at runtime; callers
therefore do not need to register per-backend variants.
"""

def __init__(
self,
epsilon: float,
enabled: bool | None = None,
) -> None:
if enabled is None:
enabled = RMSNorm.enabled()

super().__init__(enabled)
self.epsilon = epsilon

def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
return [input, weight]

def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return ir.ops.rms_norm(input, weight, self.epsilon)

def forward_native(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return ir.ops.rms_norm(input, weight, self.epsilon)


class MatcherRMSNormGated(MatcherCustomOp):
"""Matches RMSNormGated with norm_before_gate=True and group_size=None."""

Expand Down
67 changes: 51 additions & 16 deletions vllm/compilation/passes/fusion/qk_norm_rope_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,26 @@
from torch._inductor.pattern_matcher import PatternMatcherPass

import vllm.ir.ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
from .matcher_utils import MatcherRotaryEmbedding
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64

logger = init_logger(__name__)

FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default

# Head dimensions supported by csrc/fused_qknorm_rope_kernel.cu's
# launchFusedQKNormRope and launchFusedQKNormRopeNTokenHeads dispatchers.
# Keep in sync with the switch statements in that file.
SUPPORTED_FUSED_QK_NORM_ROPE_HEAD_DIMS: tuple[int, ...] = (64, 128, 256)

P = ParamSpec("P")


Expand Down Expand Up @@ -58,13 +64,15 @@ def __init__(
eps: float,
is_neox: bool,
rope_flashinfer: bool = False,
match_rocm_aiter_rope: bool = False,
) -> None:
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.eps = eps
self.rmsnorm_matcher = MatcherRMSNorm(eps)
self.is_neox = is_neox
self.rope_flashinfer = rope_flashinfer
self.rope_matcher = MatcherRotaryEmbedding(
Expand All @@ -73,6 +81,7 @@ def __init__(
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
use_flashinfer=self.rope_flashinfer,
match_rocm_aiter=match_rocm_aiter_rope if match_rocm_aiter_rope else None,
)

def get_inputs(self) -> list[torch.Tensor]:
Expand Down Expand Up @@ -186,7 +195,12 @@ def replacement(


class QKNormRoPEFusionPass(VllmPatternMatcherPass):
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists.

Registers patterns for both standard vLLM ops and ROCm AITER ops
(when AITER is enabled), so the fusion fires regardless of which
RMSNorm/RoPE implementation the graph uses.
"""

@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
Expand All @@ -202,7 +216,6 @@ def __init__(self, config: VllmConfig) -> None:
)
return

# use one attn layer to get meta (such as head_dim) for QkNormRopePattern
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
config, Attention
)
Expand All @@ -213,26 +226,48 @@ def __init__(self, config: VllmConfig) -> None:
return
layer = next(iter(attn_layers.values()))

for epsilon in [1e-5, 1e-6]:
for neox in [True, False]:
if RotaryEmbedding.enabled():
for rope_flashinfer in [False, True]:
if layer.head_size not in SUPPORTED_FUSED_QK_NORM_ROPE_HEAD_DIMS:
logger.warning_once(
"QK Norm+RoPE fusion not enabled: layer head_size=%d is not "
"supported by fused_qk_norm_rope kernel (supported: %s). "
"Falling back to unfused QK norm + RoPE path.",
layer.head_size,
SUPPORTED_FUSED_QK_NORM_ROPE_HEAD_DIMS,
)
return

# RMS norm variants are no longer iterated: after the vLLM IR
# migration (#33825), `MatcherRMSNorm` dispatches via
# `ir.ops.rms_norm`, which resolves to the same backend (native /
# vllm_c / aiter / oink / ...) that the model's RMSNorm layer
# picks. The pattern graph tracks the target graph automatically.
aiter_rope_variants = [False]
if rocm_aiter_ops.is_triton_rotary_embed_enabled():
aiter_rope_variants.append(True)

for aiter_rope in aiter_rope_variants:
for epsilon in [1e-5, 1e-6]:
for neox in [True, False]:
if RotaryEmbedding.enabled():
for rope_flashinfer in [False, True]:
QkNormRopePattern(
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
is_neox=neox,
rope_flashinfer=rope_flashinfer,
match_rocm_aiter_rope=aiter_rope,
).register(self.patterns)
else:
QkNormRopePattern(
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
is_neox=neox,
rope_flashinfer=rope_flashinfer,
match_rocm_aiter_rope=aiter_rope,
).register(self.patterns)
else:
QkNormRopePattern(
head_dim=layer.head_size,
num_heads=layer.num_heads,
num_kv_heads=layer.num_kv_heads,
eps=epsilon,
is_neox=neox,
).register(self.patterns)

self.dump_patterns(config, self.patterns)

Expand Down
Loading
Loading