Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 24 additions & 26 deletions vllm_ascend/ops/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import math
from typing import Optional, Tuple

import einops
import torch
import torch_npu
from vllm.model_executor.layers.rotary_embedding import (
Expand All @@ -28,7 +27,8 @@

from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
get_ascend_device_type, has_rope, is_vl_model)
get_ascend_device_type, has_rope, is_vl_model,
vllm_version_is)

# Currently, rope ops used on npu requires detached cos && sin as inputs.
# However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
Expand Down Expand Up @@ -580,37 +580,35 @@ def forward_oot(
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
head_dim = x.shape[-1]

origin_dtype = x.dtype
if self.enable_fp32_compute:
x = x.float()
cos = cos.float()
sin = sin.float()
if vllm_version_is('0.13.0'):
origin_shape = x.shape
origin_dtype = x.dtype
if len(origin_shape) == 3:
x = x.unsqueeze(0)
if self.enable_fp32_compute:
x = x.float()
cos = cos.float()
sin = sin.float()
else:
x, cos, sin, origin_shape, origin_dtype = self._pre_process(
x, cos, sin)

head_dim = x.shape[-1]
# cos, sin: [seq_len, head_dim // 2]
cos = torch.cat((cos, cos), dim=-1)
sin = torch.cat((sin, sin), dim=-1)
# cos, sin: [1, seq_len, 1, head_dim]
cos = cos.reshape(1, -1, 1, head_dim)
sin = sin.reshape(1, -1, 1, head_dim)

if len(x.shape) == 3:
# x: [seq_len, num_heads, head_size]
x = x.unsqueeze(0)
# x: [1, seq_len, num_heads, head_size]
output = torch_npu.npu_rotary_mul(x, cos, sin).squeeze(0)
output = torch_npu.npu_rotary_mul(x, cos, sin)

if vllm_version_is('0.13.0'):
if len(origin_shape) == 3:
output = output.squeeze(0)
if self.enable_fp32_compute:
output = output.to(origin_dtype)
else:
assert len(x.shape) == 4
# x: [2 * b, s, head, head_dim]
qk = einops.rearrange(
x, "(two b) s head head_dim -> b s two head head_dim", two=2)
# q, k: [b, s, head, head_dim]
q, k = qk[:, :, 0], qk[:, :, 1]
q = torch_npu.npu_rotary_mul(q, cos, sin)
k = torch_npu.npu_rotary_mul(k, cos, sin)
output = torch.cat([q, k], dim=0)

This comment was marked as resolved.


if self.enable_fp32_compute:
output = output.to(origin_dtype)
output = self._post_process(output, origin_shape, origin_dtype)

return output