Skip to content
Merged
10 changes: 9 additions & 1 deletion tests/compile/passes/test_rope_kvcache_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,10 @@ def forward(
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
ops = []
if self.enable_rope_custom_op:
ops.append(ROTARY_OP)
if rocm_aiter_ops.is_triton_rotary_embed_enabled():
ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)
else:
ops.append(ROTARY_OP)
else:
ops.append(INDEX_SELECT_OP)
ops.append(torch.ops.vllm.unified_kv_cache_update.default)
Expand All @@ -196,6 +199,7 @@ def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
],
)
@pytest.mark.parametrize("enable_rope_custom_op", [True]) # [True, False])
@pytest.mark.parametrize("enable_aiter_triton_rope", [True, False])
@pytest.mark.parametrize("num_heads", [64])
@pytest.mark.parametrize("num_kv_heads", [8])
@pytest.mark.parametrize("head_size", [64])
Expand All @@ -210,6 +214,7 @@ def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
def test_rope_kvcache_fusion(
attn_backend: AttentionBackendEnum,
enable_rope_custom_op: bool,
enable_aiter_triton_rope: bool,
num_heads: int,
num_kv_heads: int,
head_size: int,
Expand Down Expand Up @@ -245,6 +250,9 @@ def test_rope_kvcache_fusion(

with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_ROCM_USE_AITER", "1")
m.setenv(
"VLLM_ROCM_USE_AITER_TRITON_ROPE", "1" if enable_aiter_triton_rope else "0"
)
rocm_aiter_ops.refresh_env_variables()

model = QKRoPEKVCacheTestModel(
Expand Down
101 changes: 65 additions & 36 deletions vllm/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,59 @@ def _rocm_aiter_triton_add_rmsnorm_pad_fake(
return out, residual_out


def _triton_rotary_embedding_impl(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    offsets: torch.Tensor | None = None,
) -> None:
    """Apply rotary positional embedding to ``query`` and ``key`` in-place.

    Backs the ``vllm.rocm_aiter_triton_rotary_embedding`` custom op; the
    Triton kernel writes through views of ``query``/``key``, so this function
    returns nothing and mutates its arguments.

    Args:
        positions: Token position ids; ``numel()`` gives the token count.
        query: Query tensor, viewable as ``(num_tokens, num_heads, head_size)``.
        key: Key tensor, viewable as ``(num_tokens, num_kv_heads, head_size)``.
        head_size: Per-head dimension.
        cos_sin_cache: ``[cos | sin]`` concatenated along the last dim.
        is_neox: True for NeoX (rotate-half) style, False for GPT-J
            (interleaved) style.
        offsets: Optional per-token position offsets forwarded to the kernel.
    """
    from aiter.ops.triton.rope.rope import (
        rope_cached_thd_positions_offsets_2c_fwd_inplace,
    )

    num_tokens = positions.numel()
    # The cache stores cos and sin halves concatenated on the last dimension.
    cos, sin = cos_sin_cache.chunk(2, dim=-1)
    # Kernel convention: 0 = NeoX rotate-half, 1 = GPT-J interleaved.
    rotate_style = 0 if is_neox else 1
    # Full-head rotation: rotary_dim == head_size makes the slices below span
    # the whole head; kept so a future partial-RoPE caller only changes this.
    rotary_dim = head_size

    # Views share storage with the originals, so the in-place kernel call
    # mutates the caller's query/key directly; no view-back is needed.
    query_view = query.view(num_tokens, -1, head_size)
    key_view = key.view(num_tokens, -1, head_size)
    rope_cached_thd_positions_offsets_2c_fwd_inplace(
        query_view[..., :rotary_dim],
        key_view[..., :rotary_dim],
        cos,
        sin,
        positions.view(num_tokens),
        offsets,
        rotate_style,
        reuse_freqs_front_part=True,
        nope_first=False,
    )


def _triton_rotary_embedding_fake(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox_style: bool,
offsets: torch.Tensor | None = None,
) -> None:
return


# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False

Expand Down Expand Up @@ -1178,6 +1231,14 @@ def register_ops_once() -> None:
dispatch_key=current_platform.dispatch_key,
)

# Register rocm aiter rotary embedding custom op
direct_register_custom_op(
op_name="rocm_aiter_triton_rotary_embedding",
op_func=_triton_rotary_embedding_impl,
mutates_args=["query", "key"], # These tensors are modified in-place
fake_impl=_triton_rotary_embedding_fake,
)

_OPS_REGISTERED = True

@staticmethod
Expand Down Expand Up @@ -1220,6 +1281,10 @@ def get_act_mul_fused_fp8_group_quant_op() -> OpOverload:
def get_triton_add_rmsnorm_pad_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_triton_add_rmsnorm_pad.default

@staticmethod
def get_triton_rotary_embedding_op() -> OpOverload:
return torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default

@staticmethod
def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
Expand Down Expand Up @@ -1482,42 +1547,6 @@ def triton_fp4_gemm_dynamic_qaunt(
gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
return y

@staticmethod
def triton_rotary_embed(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    head_size: int,
    rotary_dim: int,
    is_neox_style: bool,
):
    """Apply rotary positional embedding to query/key in-place via AITER.

    Mutates ``query`` and ``key`` through views; returns nothing.

    positions: token position ids; numel() gives the token count.
    query/key: viewable as (num_tokens, num_heads, head_size).
    cos_sin_cache: [cos | sin] concatenated along the last dimension.
    head_size: per-head dimension used to reshape query/key.
    rotary_dim: leading slice of each head that is rotated (supports
        partial RoPE when rotary_dim < head_size).
    is_neox_style: True for NeoX rotate-half, False for GPT-J interleaved.
    """
    from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace

    num_tokens = positions.numel()
    # Split the cache into its cos and sin halves.
    cos, sin = cos_sin_cache.chunk(2, dim=-1)
    query_shape = query.shape
    key_shape = key.shape
    # Kernel convention: 0 = NeoX rotate-half, 1 = GPT-J interleaved.
    rotate_style = 0 if is_neox_style else 1

    # Reshape to (num_tokens, heads, head_size); views alias the original
    # storage, so the in-place kernel below mutates the caller's tensors.
    query = query.view(num_tokens, -1, head_size)
    key = key.view(num_tokens, -1, head_size)
    # Only the first rotary_dim elements of each head are rotated.
    query_ = query[..., :rotary_dim]
    key_ = key[..., :rotary_dim]
    positions = positions.view(*query.shape[:1])
    rope_cached_thd_positions_2c_fwd_inplace(
        query_,
        key_,
        cos,
        sin,
        positions,
        rotate_style,
        reuse_freqs_front_part=True,
        nope_first=False,
    )
    # Restore original shapes (local rebinding only; the storage was already
    # updated in-place by the kernel above).
    query = query.view(query_shape)
    key = key.view(key_shape)

@staticmethod
def triton_rope_and_cache(
query: torch.Tensor,
Expand Down
5 changes: 5 additions & 0 deletions vllm/compilation/passes/fusion/matcher_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,13 @@ def __init__(
num_heads: int,
num_kv_heads: int,
use_flashinfer: bool = False,
match_rocm_aiter: bool | None = None,
enabled: bool | None = None,
) -> None:
if enabled is None:
enabled = RotaryEmbedding.enabled()
if match_rocm_aiter is None:
match_rocm_aiter = rocm_aiter_ops.is_triton_rotary_embed_enabled()

super().__init__(enabled)
self.is_neox = is_neox
Expand All @@ -104,6 +107,8 @@ def __init__(
self.rotary_dim = head_size
if use_flashinfer:
self.rotary_op = FLASHINFER_ROTARY_OP
elif match_rocm_aiter:
self.rotary_op = rocm_aiter_ops.get_triton_rotary_embedding_op()
else:
self.rotary_op = ROTARY_OP

Expand Down
6 changes: 5 additions & 1 deletion vllm/compilation/passes/utility/scatter_split_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,18 @@ class ScatterSplitReplacementPass(VllmInductorPass):
def __call__(self, graph: fx.Graph) -> None:
count = 0

target_ops = [torch.ops._C.rotary_embedding.default]
if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
target_ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)

for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue

kwargs = node.kwargs
at_target = node.args[0]

if at_target == torch.ops._C.rotary_embedding.default:
if at_target in target_ops:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = {}
Expand Down
29 changes: 13 additions & 16 deletions vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ class PassConfig:
"""Enable async TP."""
fuse_allreduce_rms: bool = Field(default=None)
"""Enable flashinfer allreduce fusion."""
enable_qk_norm_rope_fusion: bool = False
"""Enable fused Q/K RMSNorm + RoPE pass."""

# ROCm/AITER specific fusions
fuse_act_padding: bool = Field(default=None)
Expand Down Expand Up @@ -153,8 +155,6 @@ class PassConfig:
8: 1, # 1MB
},
}, where key is the device capability"""
enable_qk_norm_rope_fusion: bool = False
"""Enable fused Q/K RMSNorm + RoPE pass."""

# TODO(luka) better pass enabling system.

Expand Down Expand Up @@ -834,23 +834,20 @@ def __post_init__(self) -> None:
func if isinstance(func, InductorPass) else CallableInductorPass(func)
)

if self.pass_config.enable_qk_norm_rope_fusion:
if (
self.pass_config.enable_qk_norm_rope_fusion
and "+rotary_embedding" not in self.custom_ops
):
# TODO(zhuhaoran): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self.custom_ops.append("+rotary_embedding")
if self.pass_config.fuse_rope_kvcache:
from vllm._aiter_ops import rocm_aiter_ops

if rocm_aiter_ops.is_triton_rotary_embed_enabled():
logger.warning(
"Cannot use VLLM_ROCM_USE_AITER_TRITON_ROPE with "
"fuse_rope_kvcache. Disabling fuse_rope_kvcache."
)
self.pass_config.fuse_rope_kvcache = False
else:
# TODO(Rohan138): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self.custom_ops.append("+rotary_embedding")
if (
self.pass_config.fuse_rope_kvcache
and "+rotary_embedding" not in self.custom_ops
):
# TODO(Rohan138): support rope native forward match and remove this.
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
self.custom_ops.append("+rotary_embedding")

if (
is_torch_equal_or_newer("2.9.0.dev")
Expand Down
23 changes: 20 additions & 3 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,27 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
)


def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool:
"""Enable if rotary embedding custom op is active and
use_inductor_graph_partition is enabled.
"""
from vllm._aiter_ops import rocm_aiter_ops

return (
rocm_aiter_ops.is_enabled()
and cfg.compilation_config.is_custom_op_enabled("rotary_embedding")
and cfg.compilation_config.use_inductor_graph_partition
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So are you guys using inductor graph partition on rocm by default? Otherwise we should also return true here if dynamo partition is used and the kv cache op is not in splitting ops

Copy link
Copy Markdown
Contributor Author

@Rohan138 Rohan138 Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Somehow GH ate my original PR comment that explained this)

This PR is necessary but not sufficient to actually enable this fusion by default. We also need:

return true if dynamo partition and kv cache op not in splitting ops

https://github.com/vllm-project/vllm/blob/main/vllm/config/compilation.py#1001 is called in https://github.com/vllm-project/vllm/blob/main/vllm/config/vllm.py#L961 after the defaults are set in https://github.com/vllm-project/vllm/blob/main/vllm/config/vllm.py#L807. So if inductor partition is not enabled, we would return true for this, then append kv cache to splitting ops, which would silently break the fusion.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Links are broken but I know what you mean - but if splitting_ops=[] is passed kvcache won't be added so it should still work. So this check should be if inductor_partition or len(splitting_ops)==0

)


def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
    """Enable if using AITER RMSNorm and AITER Triton GEMMs
    and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion."""
    from vllm._aiter_ops import rocm_aiter_ops

    return (
        rocm_aiter_ops.is_rmsnorm_enabled()
        # NOTE(review): the docstring and the previous env-var check
        # (VLLM_ROCM_USE_AITER_TRITON_GEMM, asserted positively) both require
        # Triton GEMMs to be ENABLED; the added `not` inverted the condition
        # and silently flipped when this fusion activates. Restored to the
        # positive check — confirm with the PR author if the inversion was
        # intentional (if so, the docstring must be updated to match).
        and rocm_aiter_ops.is_triton_gemm_enabled()
        and cfg.model_config is not None
        and cfg.model_config.get_hidden_size() == 2880
    )
Expand All @@ -149,6 +162,7 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
"enable_sp": False,
"fuse_gemm_comms": False,
"fuse_act_padding": False,
"fuse_rope_kvcache": False,
},
"cudagraph_mode": CUDAGraphMode.NONE,
"use_inductor_graph_partition": False,
Expand All @@ -167,6 +181,7 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
"enable_sp": False,
"fuse_gemm_comms": False,
"fuse_act_padding": enable_norm_pad_fusion,
"fuse_rope_kvcache": enable_rope_kvcache_fusion,
},
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
"use_inductor_graph_partition": False,
Expand All @@ -185,6 +200,7 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
"enable_sp": IS_DENSE,
"fuse_gemm_comms": IS_DENSE,
"fuse_act_padding": enable_norm_pad_fusion,
"fuse_rope_kvcache": enable_rope_kvcache_fusion,
},
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
"use_inductor_graph_partition": False,
Expand All @@ -203,6 +219,7 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
"enable_sp": IS_DENSE,
"fuse_gemm_comms": IS_DENSE,
"fuse_act_padding": enable_norm_pad_fusion,
"fuse_rope_kvcache": enable_rope_kvcache_fusion,
},
"cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
"use_inductor_graph_partition": False,
Expand Down
6 changes: 3 additions & 3 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = True
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_FP4BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
Expand Down Expand Up @@ -937,9 +937,9 @@ def _get_or_set_default() -> str:
os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in ("true", "1")
),
# Whether to use aiter rope.
# By default is disabled.
# By default is enabled.
"VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1")
os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "True").lower() in ("true", "1")
),
# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.
Expand Down
18 changes: 11 additions & 7 deletions vllm/model_executor/layers/rotary_embedding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,20 @@ def __init__(
if not hasattr(self, "use_flashinfer"):
self.use_flashinfer = False

self.use_aiter = (
self.enabled() and rocm_aiter_ops.is_triton_rotary_embed_enabled()
)
if self.use_aiter:
self.rocm_aiter_triton_rotary_embedding = (
rocm_aiter_ops.get_triton_rotary_embedding_op()
)

if init_cache:
cache = self._compute_cos_sin_cache()
if not self.use_flashinfer:
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.is_rocm_triton_rotary_embed_enabled = (
rocm_aiter_ops.is_triton_rotary_embed_enabled()
)

self.apply_rotary_emb = ApplyRotaryEmb(
is_neox_style=self.is_neox_style,
Expand Down Expand Up @@ -231,15 +236,14 @@ def forward_hip(
query: torch.Tensor,
key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
if self.is_rocm_triton_rotary_embed_enabled:
if self.use_aiter:
cos_sin_cache = self._match_cos_sin_cache_dtype(query)
rocm_aiter_ops.triton_rotary_embed(
self.rocm_aiter_triton_rotary_embedding(
positions,
query,
key,
cos_sin_cache,
self.head_size,
self.rotary_dim,
cos_sin_cache,
self.is_neox_style,
)
return query, key
Expand Down
Loading