diff --git a/vllm/compilation/passes/utility/fix_functionalization.py b/vllm/compilation/passes/utility/fix_functionalization.py index c7df5f92e874..1b656d0c890e 100644 --- a/vllm/compilation/passes/utility/fix_functionalization.py +++ b/vllm/compilation/passes/utility/fix_functionalization.py @@ -37,6 +37,14 @@ def __call__(self, graph: torch.fx.Graph) -> None: self.nodes_to_remove: list[torch.fx.Node] = [] count = 0 + + rope_targets = [torch.ops._C.rotary_embedding.default] + + if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"): + rope_targets.append( + torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default + ) + for node in graph.nodes: if not is_func(node, auto_functionalized): continue # Avoid deep if-elif nesting @@ -44,7 +52,7 @@ def __call__(self, graph: torch.fx.Graph) -> None: kwargs = node.kwargs at_target = node.args[0] - if at_target == torch.ops._C.rotary_embedding.default: + if at_target in rope_targets: query = kwargs["query"] key = kwargs["key"] getitem_nodes = self.getitem_users(node)