# Patch summary for vllm_ascend/ops/rotary_embedding.py:
#   * Imports `triton_mrope` from
#     `vllm.model_executor.layers.rotary_embedding.mrope`.
#   * Adds the module-level constant `_QWEN3_VL_MROPE_SECTION = [24, 20, 20]`.
#   * `AscendMRotaryEmbedding.forward_oot`: adds a new first check. When
#     `self.mrope_section == _QWEN3_VL_MROPE_SECTION` it returns
#     `self.forward_triton(positions, query, key)` before the existing
#     `[16, 24, 24]` / `AscendDeviceType.A5` check is evaluated. The inline
#     comment in the patch labels this path "use triton mrope for Qwen3-VL".
#   * Adds `AscendMRotaryEmbedding.forward_triton(positions, query, key=None,
#     offsets=None)`, which:
#       - asserts `positions.ndim == 2` and `key is not None`;
#       - calls `self._match_cos_sin_cache_dtype(query)`;
#       - indexes `self.cos_sin_cache` with `positions` and splits the result
#         into two chunks along the last dim (`cos`, `sin`);
#       - asserts `self.mrope_section` is truthy;
#       - calls `triton_mrope(query, key, cos, sin, self.mrope_section,
#         self.head_size, self.rotary_dim, self.mrope_interleaved)`;
#       - returns both outputs reshaped to the shapes `query` and `key` had on
#         entry.
# NOTE(review): `offsets` is accepted by `forward_triton` but is not used in
#   the added body.
# NOTE(review): the preconditions are enforced with `assert`, and `forward_oot`
#   routes every call whose section is [24, 20, 20] here. Presumably
#   `positions` is always 2-D for such models -- confirm against the callers.
# NOTE(review): this copy of the patch has its line breaks collapsed onto a
#   single line; restore the original newlines before applying it.
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 28bde0747d7..3b38aa431e5 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -25,6 +25,7 @@ DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding, YaRNScalingRotaryEmbedding) from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb +from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import (AscendDeviceType, enable_custom_op, @@ -525,6 +526,9 @@ def forward(self, return q_pe, k_pe +_QWEN3_VL_MROPE_SECTION = [24, 20, 20] + + class AscendMRotaryEmbedding(MRotaryEmbedding): def forward_oot( @@ -533,6 +537,10 @@ def forward_oot( query: torch.Tensor, key: torch.Tensor, ): + # use triton mrope for Qwen3-VL + if self.mrope_section == _QWEN3_VL_MROPE_SECTION: + return self.forward_triton(positions, query, key) + if self.mrope_section != [16, 24, 24] or \ get_ascend_device_type() == AscendDeviceType.A5: return super().forward_oot(positions, query, key) @@ -559,6 +567,35 @@ def forward_oot( return query, key + def forward_triton( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + assert positions.ndim == 2 + assert key is not None + + self._match_cos_sin_cache_dtype(query) + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + query_shape = query.shape + key_shape = key.shape + + assert self.mrope_section + q, k = triton_mrope( + query, + key, + cos, + sin, + self.mrope_section, + self.head_size, + self.rotary_dim, + self.mrope_interleaved, + ) + return q.reshape(query_shape), k.reshape(key_shape) + class AscendApplyRotaryEmb(ApplyRotaryEmb):