diff --git a/csrc/ops.h b/csrc/ops.h
index f8bdc61aaa8e..1401a6756659 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -92,11 +92,13 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
 void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                         torch::Tensor& weight, double epsilon);
 
+#ifndef USE_ROCM
 void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q,
                         int64_t num_heads_k, int64_t num_heads_v,
                         int64_t head_dim, double eps, torch::Tensor& q_weight,
                         torch::Tensor& k_weight, torch::Tensor& cos_sin_cache,
                         bool is_neox, torch::Tensor& position_ids);
+#endif
 
 void apply_repetition_penalties_(torch::Tensor& logits,
                                  const torch::Tensor& prompt_mask,
diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py
index 126ad35e527a..ca468724b674 100644
--- a/vllm/compilation/fix_functionalization.py
+++ b/vllm/compilation/fix_functionalization.py
@@ -133,7 +133,10 @@ def __call__(self, graph: torch.fx.Graph):
                     ),
                 )
             # Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
-            elif at_target == torch.ops._C.fused_qk_norm_rope.default:
+            elif (
+                current_platform.is_cuda()
+                and at_target == torch.ops._C.fused_qk_norm_rope.default
+            ):
                 mutated_args = {1: "qkv"}
                 args = (
                     "qkv",
diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py
index 0c2210d72ce0..8edf6c9572a9 100644
--- a/vllm/compilation/pass_manager.py
+++ b/vllm/compilation/pass_manager.py
@@ -17,9 +17,9 @@
     from .activation_quant_fusion import ActivationQuantFusionPass
     from .fusion import RMSNormQuantFusionPass
     from .fusion_attn import AttnFusionPass
-    from .qk_norm_rope_fusion import QKNormRoPEFusionPass
 
     if current_platform.is_cuda():
+        from .qk_norm_rope_fusion import QKNormRoPEFusionPass
         from .collective_fusion import AllReduceFusionPass, AsyncTPPass
 
     from .fix_functionalization import FixFunctionalizationPass