diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 3b38aa431e5..b865e15847b 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -25,7 +25,10 @@ DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding, YaRNScalingRotaryEmbedding) from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb -from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import (AscendDeviceType, enable_custom_op, @@ -526,10 +529,38 @@ def forward(self, return q_pe, k_pe -_QWEN3_VL_MROPE_SECTION = [24, 20, 20] +class AscendMRotaryEmbedding(MRotaryEmbedding): + def forward_triton(self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None): + assert positions.ndim == 2 + assert key is not None -class AscendMRotaryEmbedding(MRotaryEmbedding): + self._match_cos_sin_cache_dtype(query) + cos_sin = self.cos_sin_cache[positions] # type: ignore + cos, sin = cos_sin.chunk(2, dim=-1) + self.cos = cos.contiguous() + self.sin = sin.contiguous() + query_shape = query.shape + key_shape = key.shape + + assert self.mrope_section + + q, k = triton_mrope( + query, + key, + self.cos, + self.sin, + self.mrope_section, + self.head_size, + self.rotary_dim, + self.mrope_interleaved, + ) + + return q.reshape(query_shape), k.reshape(key_shape) def forward_oot( self, @@ -537,8 +568,8 @@ def forward_oot( query: torch.Tensor, key: torch.Tensor, ): - # use triton mrope for Qwen3-VL - if self.mrope_section == _QWEN3_VL_MROPE_SECTION: + if HAS_TRITON and positions.ndim == 2: + # todo: need cann update in 8.5.0 return self.forward_triton(positions, query, key) if self.mrope_section != [16, 24, 24] or \ @@ -567,35 +598,6 @@ def forward_oot( return query, key - def forward_triton( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor | None = None, - offsets: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - assert positions.ndim == 2 - assert key is not None - - self._match_cos_sin_cache_dtype(query) - cos_sin = self.cos_sin_cache[positions] - cos, sin = cos_sin.chunk(2, dim=-1) - query_shape = query.shape - key_shape = key.shape - - assert self.mrope_section - q, k = triton_mrope( - query, - key, - cos, - sin, - self.mrope_section, - self.head_size, - self.rotary_dim, - self.mrope_interleaved, - ) - return q.reshape(query_shape), k.reshape(key_shape) - class AscendApplyRotaryEmb(ApplyRotaryEmb):