diff --git a/tests/compile/passes/test_rope_kvcache_fusion.py b/tests/compile/passes/test_rope_kvcache_fusion.py index d074d2a9e319..09679fb41779 100644 --- a/tests/compile/passes/test_rope_kvcache_fusion.py +++ b/tests/compile/passes/test_rope_kvcache_fusion.py @@ -177,7 +177,10 @@ def forward( def ops_in_model_before(self) -> list[torch._ops.OpOverload]: ops = [] if self.enable_rope_custom_op: - ops.append(ROTARY_OP) + if rocm_aiter_ops.is_triton_rotary_embed_enabled(): + ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default) + else: + ops.append(ROTARY_OP) else: ops.append(INDEX_SELECT_OP) ops.append(torch.ops.vllm.unified_kv_cache_update.default) @@ -196,6 +199,7 @@ def ops_in_model_after(self) -> list[torch._ops.OpOverload]: ], ) @pytest.mark.parametrize("enable_rope_custom_op", [True]) # [True, False]) +@pytest.mark.parametrize("enable_aiter_triton_rope", [True, False]) @pytest.mark.parametrize("num_heads", [64]) @pytest.mark.parametrize("num_kv_heads", [8]) @pytest.mark.parametrize("head_size", [64]) @@ -210,6 +214,7 @@ def ops_in_model_after(self) -> list[torch._ops.OpOverload]: def test_rope_kvcache_fusion( attn_backend: AttentionBackendEnum, enable_rope_custom_op: bool, + enable_aiter_triton_rope: bool, num_heads: int, num_kv_heads: int, head_size: int, @@ -245,6 +250,9 @@ def test_rope_kvcache_fusion( with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m: m.setenv("VLLM_ROCM_USE_AITER", "1") + m.setenv( + "VLLM_ROCM_USE_AITER_TRITON_ROPE", "1" if enable_aiter_triton_rope else "0" + ) rocm_aiter_ops.refresh_env_variables() model = QKRoPEKVCacheTestModel( diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 012a3f36798e..3414443e52cb 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -831,6 +831,59 @@ def _rocm_aiter_triton_add_rmsnorm_pad_fake( return out, residual_out +def _triton_rotary_embedding_impl( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, + offsets: torch.Tensor | None = None, +) -> None: + # Modifies query and key in-place + from aiter.ops.triton.rope.rope import ( + rope_cached_thd_positions_offsets_2c_fwd_inplace, + ) + + num_tokens = positions.numel() + cos, sin = cos_sin_cache.chunk(2, dim=-1) + query_shape = query.shape + key_shape = key.shape + rotate_style = 0 if is_neox else 1 + rotary_dim = head_size + + query = query.view(num_tokens, -1, head_size) + key = key.view(num_tokens, -1, head_size) + query_ = query[..., :rotary_dim] + key_ = key[..., :rotary_dim] + positions = positions.view(*query.shape[:1]) + rope_cached_thd_positions_offsets_2c_fwd_inplace( + query_, + key_, + cos, + sin, + positions, + offsets, + rotate_style, + reuse_freqs_front_part=True, + nope_first=False, + ) + query = query.view(query_shape) + key = key.view(key_shape) + + +def _triton_rotary_embedding_fake( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox_style: bool, + offsets: torch.Tensor | None = None, +) -> None: + return + + # Global flag to ensure ops are registered only once _OPS_REGISTERED = False @@ -1178,6 +1231,14 @@ def register_ops_once() -> None: dispatch_key=current_platform.dispatch_key, ) + # Register rocm aiter rotary embedding custom op + direct_register_custom_op( + op_name="rocm_aiter_triton_rotary_embedding", + op_func=_triton_rotary_embedding_impl, + mutates_args=["query", "key"], # These tensors are modified in-place + fake_impl=_triton_rotary_embedding_fake, + ) + _OPS_REGISTERED = True @staticmethod @@ -1220,6 +1281,10 @@ def get_act_mul_fused_fp8_group_quant_op() -> OpOverload: def get_triton_add_rmsnorm_pad_op() -> OpOverload: return torch.ops.vllm.rocm_aiter_triton_add_rmsnorm_pad.default + @staticmethod + def get_triton_rotary_embedding_op() -> OpOverload: + return torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default + @staticmethod def rms_norm( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float @@ -1482,42 +1547,6 @@ def triton_fp4_gemm_dynamic_qaunt( gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y) return y - @staticmethod - def triton_rotary_embed( - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - cos_sin_cache: torch.Tensor, - head_size: int, - rotary_dim: int, - is_neox_style: bool, - ): - from aiter.ops.triton.rope import rope_cached_thd_positions_2c_fwd_inplace - - num_tokens = positions.numel() - cos, sin = cos_sin_cache.chunk(2, dim=-1) - query_shape = query.shape - key_shape = key.shape - rotate_style = 0 if is_neox_style else 1 - - query = query.view(num_tokens, -1, head_size) - key = key.view(num_tokens, -1, head_size) - query_ = query[..., :rotary_dim] - key_ = key[..., :rotary_dim] - positions = positions.view(*query.shape[:1]) - rope_cached_thd_positions_2c_fwd_inplace( - query_, - key_, - cos, - sin, - positions, - rotate_style, - reuse_freqs_front_part=True, - nope_first=False, - ) - query = query.view(query_shape) - key = key.view(key_shape) - @staticmethod def triton_rope_and_cache( query: torch.Tensor, diff --git a/vllm/compilation/passes/fusion/matcher_utils.py b/vllm/compilation/passes/fusion/matcher_utils.py index 6b1b9a73baee..03f680552c58 100644 --- a/vllm/compilation/passes/fusion/matcher_utils.py +++ b/vllm/compilation/passes/fusion/matcher_utils.py @@ -89,10 +89,13 @@ def __init__( num_heads: int, num_kv_heads: int, use_flashinfer: bool = False, + match_rocm_aiter: bool | None = None, enabled: bool | None = None, ) -> None: if enabled is None: enabled = RotaryEmbedding.enabled() + if match_rocm_aiter is None: + match_rocm_aiter = rocm_aiter_ops.is_triton_rotary_embed_enabled() super().__init__(enabled) self.is_neox = is_neox @@ -104,6 +107,8 @@ def __init__( self.rotary_dim = head_size if use_flashinfer: self.rotary_op = FLASHINFER_ROTARY_OP + elif match_rocm_aiter: + self.rotary_op = rocm_aiter_ops.get_triton_rotary_embedding_op() else: self.rotary_op = ROTARY_OP diff --git a/vllm/compilation/passes/utility/scatter_split_replace.py b/vllm/compilation/passes/utility/scatter_split_replace.py index 1826c07f869c..a17a7b336d2d 100644 --- a/vllm/compilation/passes/utility/scatter_split_replace.py +++ b/vllm/compilation/passes/utility/scatter_split_replace.py @@ -60,6 +60,10 @@ class ScatterSplitReplacementPass(VllmInductorPass): def __call__(self, graph: fx.Graph) -> None: count = 0 + target_ops = [torch.ops._C.rotary_embedding.default] + if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"): + target_ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default) + for node in graph.nodes: if not is_func(node, auto_functionalized): continue @@ -67,7 +71,7 @@ def __call__(self, graph: fx.Graph) -> None: kwargs = node.kwargs at_target = node.args[0] - if at_target == torch.ops._C.rotary_embedding.default: + if at_target in target_ops: query = kwargs["query"] key = kwargs["key"] getitem_nodes = {} diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index b1f0779c7f9c..ab6f3da06cdf 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -123,6 +123,8 @@ class PassConfig: """Enable async TP.""" fuse_allreduce_rms: bool = Field(default=None) """Enable flashinfer allreduce fusion.""" + enable_qk_norm_rope_fusion: bool = False + """Enable fused Q/K RMSNorm + RoPE pass.""" # ROCm/AITER specific fusions fuse_act_padding: bool = Field(default=None) @@ -153,8 +155,6 @@ class PassConfig: 8: 1, # 1MB }, }, where key is the device capability""" - enable_qk_norm_rope_fusion: bool = False - """Enable fused Q/K RMSNorm + RoPE pass.""" # TODO(luka) better pass enabling system. @@ -834,23 +834,20 @@ def __post_init__(self) -> None: func if isinstance(func, InductorPass) else CallableInductorPass(func) ) - if self.pass_config.enable_qk_norm_rope_fusion: + if ( + self.pass_config.enable_qk_norm_rope_fusion + and "+rotary_embedding" not in self.custom_ops + ): # TODO(zhuhaoran): support rope native forward match and remove this. # Linked issue: https://github.com/vllm-project/vllm/issues/28042 self.custom_ops.append("+rotary_embedding") - if self.pass_config.fuse_rope_kvcache: - from vllm._aiter_ops import rocm_aiter_ops - - if rocm_aiter_ops.is_triton_rotary_embed_enabled(): - logger.warning( - "Cannot use VLLM_ROCM_USE_AITER_TRITON_ROPE with " - "fuse_rope_kvcache. Disabling fuse_rope_kvcache." - ) - self.pass_config.fuse_rope_kvcache = False - else: - # TODO(Rohan138): support rope native forward match and remove this. - # Linked issue: https://github.com/vllm-project/vllm/issues/28042 - self.custom_ops.append("+rotary_embedding") + if ( + self.pass_config.fuse_rope_kvcache + and "+rotary_embedding" not in self.custom_ops + ): + # TODO(Rohan138): support rope native forward match and remove this. + # Linked issue: https://github.com/vllm-project/vllm/issues/28042 + self.custom_ops.append("+rotary_embedding") if ( is_torch_equal_or_newer("2.9.0.dev") diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 2a0c0679f95d..d7deadd501e8 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -126,14 +126,27 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool: ) +def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool: + """Enable if rotary embedding custom op is active and + use_inductor_graph_partition is enabled. + """ + from vllm._aiter_ops import rocm_aiter_ops + + return ( + rocm_aiter_ops.is_enabled() + and cfg.compilation_config.is_custom_op_enabled("rotary_embedding") + and cfg.compilation_config.use_inductor_graph_partition + ) + + def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: """Enable if using AITER RMSNorm and AITER Triton GEMMs and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion.""" + from vllm._aiter_ops import rocm_aiter_ops return ( - envs.VLLM_ROCM_USE_AITER - and envs.VLLM_ROCM_USE_AITER_RMSNORM - and envs.VLLM_ROCM_USE_AITER_TRITON_GEMM + rocm_aiter_ops.is_rmsnorm_enabled() + and not rocm_aiter_ops.is_triton_gemm_enabled() and cfg.model_config is not None and cfg.model_config.get_hidden_size() == 2880 ) @@ -149,6 +162,7 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: "enable_sp": False, "fuse_gemm_comms": False, "fuse_act_padding": False, + "fuse_rope_kvcache": False, }, "cudagraph_mode": CUDAGraphMode.NONE, "use_inductor_graph_partition": False, @@ -167,6 +181,7 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: "enable_sp": False, "fuse_gemm_comms": False, "fuse_act_padding": enable_norm_pad_fusion, + "fuse_rope_kvcache": enable_rope_kvcache_fusion, }, "cudagraph_mode": CUDAGraphMode.PIECEWISE, "use_inductor_graph_partition": False, @@ -185,6 +200,7 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: "enable_sp": IS_DENSE, "fuse_gemm_comms": IS_DENSE, "fuse_act_padding": enable_norm_pad_fusion, + "fuse_rope_kvcache": enable_rope_kvcache_fusion, }, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, "use_inductor_graph_partition": False, @@ -203,6 +219,7 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool: "enable_sp": IS_DENSE, "fuse_gemm_comms": IS_DENSE, "fuse_act_padding": enable_norm_pad_fusion, + "fuse_rope_kvcache": enable_rope_kvcache_fusion, }, "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE, "use_inductor_graph_partition": False, diff --git a/vllm/envs.py b/vllm/envs.py index e6b824c56f0b..175481cddc65 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -105,7 +105,7 @@ VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False - VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False + VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = True VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_AITER_FP4BMM: bool = True VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False @@ -937,9 +937,9 @@ def _get_or_set_default() -> str: os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in ("true", "1") ), # Whether to use aiter rope. - # By default is disabled. + # By default is enabled. "VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: ( - os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1") + os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "True").lower() in ("true", "1") ), # Whether to use aiter triton fp8 bmm kernel # By default is enabled. diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 1e3063392499..1374334b2cad 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -47,15 +47,20 @@ def __init__( if not hasattr(self, "use_flashinfer"): self.use_flashinfer = False + self.use_aiter = ( + self.enabled() and rocm_aiter_ops.is_triton_rotary_embed_enabled() + ) + if self.use_aiter: + self.rocm_aiter_triton_rotary_embedding = ( + rocm_aiter_ops.get_triton_rotary_embedding_op() + ) + if init_cache: cache = self._compute_cos_sin_cache() if not self.use_flashinfer: cache = cache.to(dtype) self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) - self.is_rocm_triton_rotary_embed_enabled = ( - rocm_aiter_ops.is_triton_rotary_embed_enabled() - ) self.apply_rotary_emb = ApplyRotaryEmb( is_neox_style=self.is_neox_style, @@ -231,15 +236,14 @@ def forward_hip( query: torch.Tensor, key: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - if self.is_rocm_triton_rotary_embed_enabled: + if self.use_aiter: cos_sin_cache = self._match_cos_sin_cache_dtype(query) - rocm_aiter_ops.triton_rotary_embed( + self.rocm_aiter_triton_rotary_embedding( positions, query, key, - cos_sin_cache, self.head_size, - self.rotary_dim, + cos_sin_cache, self.is_neox_style, ) return query, key diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index a8a1d59f1bf0..c20c5717f740 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -494,6 +494,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled() use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enabled() use_aiter_fused_se = rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + use_aiter_triton_rope = rocm_aiter_ops.is_triton_rotary_embed_enabled() if compilation_config.cudagraph_mode.has_full_cudagraphs(): # decode context parallel does not support full cudagraphs @@ -558,6 +559,13 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: and "-grouped_topk" not in compilation_config.custom_ops ): compilation_config.custom_ops.append("+grouped_topk") + # Enable rotary embedding when using AITER if its not disabled by user + if ( + use_aiter_triton_rope + and "+rotary_embedding" not in compilation_config.custom_ops + and "-rotary_embedding" not in compilation_config.custom_ops + ): + compilation_config.custom_ops.append("+rotary_embedding") # Default dispatch to rocm's sparse_attn_indexer implementation compilation_config.custom_ops.append("+sparse_attn_indexer")