diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 24306a01075..114ae438868 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -524,7 +524,7 @@ def indexer_select( dim=-1) # [b,s,64,64+64] q_pe = q_pe.unsqueeze(2) - q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q) + q_pe = torch_npu.npu_rotary_mul(q_pe, cos_q, sin_q) q_pe = q_pe.squeeze(2) q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] @@ -534,7 +534,7 @@ def indexer_select( dim=-1) # [b,s,64+64] k_pe = k_pe.unsqueeze(2) - k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) + k_pe = torch_npu.npu_rotary_mul(k_pe, cos, sin) k_pe = k_pe.squeeze(2) k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]