Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions vllm_ascend/ops/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
YaRNScalingRotaryEmbedding)
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope

from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
Expand Down Expand Up @@ -525,6 +526,9 @@ def forward(self,
return q_pe, k_pe


# mrope_section signature used to detect Qwen3-VL models: forward_oot compares
# self.mrope_section against this list and, on a match, dispatches to the
# Triton-based mrope path (forward_triton) instead of the default kernel.
_QWEN3_VL_MROPE_SECTION = [24, 20, 20]


class AscendMRotaryEmbedding(MRotaryEmbedding):

def forward_oot(
Expand All @@ -533,6 +537,10 @@ def forward_oot(
query: torch.Tensor,
key: torch.Tensor,
):
# use triton mrope for Qwen3-VL
if self.mrope_section == _QWEN3_VL_MROPE_SECTION:
return self.forward_triton(positions, query, key)

if self.mrope_section != [16, 24, 24] or \
get_ascend_device_type() == AscendDeviceType.A5:
return super().forward_oot(positions, query, key)
Expand All @@ -559,6 +567,35 @@ def forward_oot(

return query, key

def forward_triton(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert positions.ndim == 2
assert key is not None

self._match_cos_sin_cache_dtype(query)
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
key_shape = key.shape

assert self.mrope_section
q, k = triton_mrope(
query,
key,
cos,
sin,
self.mrope_section,
self.head_size,
self.rotary_dim,
self.mrope_interleaved,
)
return q.reshape(query_shape), k.reshape(key_shape)


class AscendApplyRotaryEmb(ApplyRotaryEmb):

Expand Down
Loading