18 changes: 16 additions & 2 deletions vllm_ascend/ops/register_custom_ops.py
@@ -14,7 +14,8 @@
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import npu_stream_switch, prefetch_stream

from typing import Optional, Tuple
from vllm_ascend.ops.triton.rope import rope_forward_triton

def _maybe_chunk_residual_impl(x: torch.Tensor,
residual: torch.Tensor) -> torch.Tensor:
@@ -302,7 +303,15 @@ def _quantize_impl_fake(in_tensor: torch.Tensor, input_scale: torch.Tensor,
input_offset: torch.Tensor) -> torch.Tensor:
return torch_npu.npu_quantize(in_tensor, input_scale_reciprocal,
input_offset, torch.qint8, -1, False)

def _rope_forward_triton_fake(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
rope_dim: int = -1,
is_neox_style: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(q), torch.empty_like(k)

direct_register_custom_op(op_name="maybe_chunk_residual",
op_func=_maybe_chunk_residual_impl,
@@ -369,3 +378,8 @@ def _quantize_impl_fake(in_tensor: torch.Tensor, input_scale: torch.Tensor,
fake_impl=_quantize_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="rope_forward_triton",
op_func=rope_forward_triton,
fake_impl=_rope_forward_triton_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
Comment on lines +381 to +385
Contributor

critical

The function rope_forward_triton is used here as the op_func for the custom operator, but it is not defined or imported within this file. This will likely result in a NameError when this module is imported. To fix this, you should import it from vllm_ascend.ops.triton.rope at the top of the file.
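As a side note on the registration discussed above: the fake_impl is only consulted for shape and dtype propagation during tracing, which is why _rope_forward_triton_fake can simply return empty_like tensors. A minimal standalone sketch of that contract (tensor shapes below are illustrative assumptions, not taken from the PR):

import torch

def _rope_forward_triton_fake(q, k, cos, sin, rope_dim=-1, is_neox_style=True):
    # Fake/meta path: only output shapes and dtypes matter for graph tracing;
    # no Triton kernel is launched here.
    return torch.empty_like(q), torch.empty_like(k)

# Illustrative shapes (assumed): (num_tokens, num_heads, head_size) for q and k.
q = torch.empty(16, 8, 128)
k = torch.empty(16, 8, 128)
cos = torch.empty(16, 64)
sin = torch.empty(16, 64)
q_out, k_out = _rope_forward_triton_fake(q, k, cos, sin, rope_dim=128)
assert q_out.shape == q.shape and k_out.dtype == k.dtype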

61 changes: 38 additions & 23 deletions vllm_ascend/ops/rotary_embedding.py
@@ -187,6 +187,7 @@ def _rope_forward_oot(
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
cos, sin = get_cos_and_sin_slice()
Collaborator

There might be problems with putting this index_select here. Please contact @Angazenn to make sure of this.

Collaborator

get_cos_and_sin_slice does not do an index_select, so I think it's OK to move it here.

Collaborator

fine

# adopt custom kernel path for rotary_embedding
if _custom_rotary_embedding_enabled(
query, is_neox_style, self.head_size) and get_ascend_device_type(
@@ -204,7 +205,6 @@
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
cos, sin = get_cos_and_sin_slice()
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
-1] == 128 and cos is not None and sin is not None:
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
@@ -217,28 +217,43 @@
query, key = torch_npu.npu_apply_rotary_pos_emb(
query, key, cos, sin)
elif self.rotary_dim < self.head_size:
num_tokens = query.shape[0]
query = query.view(num_tokens, -1, self.head_size)
key = key.view(num_tokens, -1, self.head_size)
q_rot = query[..., :self.rotary_dim]
q_pass = query[..., self.rotary_dim:]
k_rot = key[..., :self.rotary_dim]
k_pass = key[..., self.rotary_dim:]
q_rot = q_rot.contiguous().view(num_tokens, -1)
k_rot = k_rot.contiguous().view(num_tokens, -1)
torch_npu._npu_rotary_embedding(
positions,
q_rot,
k_rot,
self.head_size,
self.cos_sin_cache,
is_neox_style,
)
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
return q, k
if HAS_TRITON:

cos = cos.view(-1, self.rotary_dim)
sin = sin.view(-1, self.rotary_dim)
q = query.contiguous().view(query.shape[0], -1,
self.head_size)
k = key.contiguous().view(key.shape[0], -1, self.head_size)
query, key = torch.ops.vllm.rope_forward_triton(q,
k,
cos,
sin,
rope_dim=self.rotary_dim,
is_neox_style=True)
return query.view(query_shape), key.view(key_shape)
else:
num_tokens = query.shape[0]
query = query.view(num_tokens, -1, self.head_size)
key = key.view(num_tokens, -1, self.head_size)
q_rot = query[..., :self.rotary_dim]
q_pass = query[..., self.rotary_dim:]
k_rot = key[..., :self.rotary_dim]
k_pass = key[..., self.rotary_dim:]
q_rot = q_rot.contiguous().view(num_tokens, -1)
k_rot = k_rot.contiguous().view(num_tokens, -1)
torch_npu._npu_rotary_embedding(
positions,
q_rot,
k_rot,
self.head_size,
self.cos_sin_cache,
is_neox_style,
)
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
return q, k
else:
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
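For context on the rotary_dim < self.head_size fallback above: it rotates only the first rotary_dim channels of each head and passes the remaining channels through untouched. Below is a minimal plain-PyTorch sketch of that idea, assuming a neox-style rotation and a (num_tokens, rope_dim // 2) cos/sin layout; the helper name and shapes are purely illustrative, and the actual kernel layout may differ.

import torch

def partial_rope_reference(x: torch.Tensor,      # (num_tokens, num_heads, head_size)
                           cos: torch.Tensor,    # assumed (num_tokens, rope_dim // 2)
                           sin: torch.Tensor,    # assumed (num_tokens, rope_dim // 2)
                           rope_dim: int) -> torch.Tensor:
    # Split each head into the rotated slice and the untouched pass-through slice.
    x_rot, x_pass = x[..., :rope_dim], x[..., rope_dim:]
    half = rope_dim // 2
    x1, x2 = x_rot[..., :half], x_rot[..., half:]
    cos = cos.unsqueeze(1)   # broadcast over the head dimension
    sin = sin.unsqueeze(1)
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((rotated, x_pass), dim=-1)

# Illustrative call with assumed shapes.
q = torch.randn(16, 8, 128)
cos, sin = torch.randn(16, 32), torch.randn(16, 32)
out = partial_rope_reference(q, cos, sin, rope_dim=64)
assert out.shape == q.shape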
15 changes: 9 additions & 6 deletions vllm_ascend/ops/triton/rope.py
@@ -15,7 +15,8 @@
# This file is a part of the vllm-ascend project.
#
from vllm.triton_utils import tl, triton

import torch
from typing import Tuple
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num


@@ -157,12 +158,14 @@ def _triton_rope(
mask=second_k_mask)


def rope_forward_triton(q,
k,
cos,
sin,
def rope_forward_triton(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
rope_dim: int = -1,
is_neox_style: bool = True):
is_neox_style: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
if not q.is_contiguous():
q = q.contiguous()
if not k.is_contiguous():
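One note on the .contiguous() guards in rope_forward_triton above: Triton-style kernels usually compute flat, stride-based offsets, so sliced (non-packed) views need to be materialised before being handed to the kernel. A small illustration of the layouts involved (shapes are illustrative only):

import torch

x = torch.randn(4, 8, 128)
view = x[..., :64]              # slicing the last dim yields a non-contiguous view
print(view.is_contiguous())     # False
packed = view.contiguous()      # materialise a packed copy before kernel launch
print(packed.is_contiguous())   # True
print(packed.stride())          # row-major strides for shape (4, 8, 64): (512, 64, 1)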