Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import gc
from typing import List

import pytest
import torch
from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope

from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton

# Parameter grid swept by the kernel-vs-reference test below.
MROPE_SECTION = [[32, 32, 32]]
DTYPES = [torch.bfloat16, torch.float16]
HEAD_SIZES = [128]
ROTARY_DIMS = [128]
NUM_Q_HEADS = [64]
NUM_K_HEADS = [1]
NUM_TOKENS = [1, 4, 8, 16]
SEEDS = [0]
DEVICES = ["npu:0"]
# Default comparison tolerances (loosened for bf16 in the test itself).
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3


def pytorch_forward_native(q, k, cos, sin, mrope_section, head_size,
                           rotary_dim, mrope_interleaved):
    """Eager PyTorch reference for the mrope rotation, token by token.

    For every token a per-channel cos/sin row is assembled from the three
    positional streams (t/h/w) according to ``mrope_section``, then the
    standard rotate-half formula is applied to both q and k.

    NOTE: q and k are updated in place (the per-token views share storage
    with the inputs); the returned tensors are flattened views of them.
    """
    n_tokens = q.shape[0]
    n_q = q.shape[1] // head_size
    n_kv = k.shape[1] // head_size

    q_view = q.view(n_tokens, n_q, head_size)
    k_view = k.view(n_tokens, n_kv, head_size)

    # (3, tokens, rd//2) -> (tokens, rd//2, 3): put the stream axis last.
    cos_by_tok = cos.permute(1, 2, 0)
    sin_by_tok = sin.permute(1, 2, 0)

    half = rotary_dim // 2

    for t in range(n_tokens):
        c_src = cos_by_tok[t]
        s_src = sin_by_tok[t]

        c_row = torch.zeros(head_size // 2, device=q.device, dtype=q.dtype)
        s_row = torch.zeros(head_size // 2, device=q.device, dtype=q.dtype)

        if mrope_interleaved:
            # Interleaved layout: channel i is assigned to stream i % 3,
            # capped by the corresponding section width (mirrors the
            # triton kernel's mask construction).
            idx = torch.arange(0, head_size // 2, device=q.device)
            h_sel = ((idx % 3) == 1) & (idx <= 3 * mrope_section[1])
            w_sel = ((idx % 3) == 2) & (idx <= 3 * mrope_section[2])
            t_sel = ~(h_sel | w_sel)

            for sel, stream in ((t_sel, 0), (h_sel, 1), (w_sel, 2)):
                c_row[sel] = c_src[sel, stream]
                s_row[sel] = s_src[sel, stream]
        else:
            # Contiguous layout: [ t-section | h-section | w-section ].
            t_end = mrope_section[0]
            h_end = t_end + mrope_section[1]

            if t_end > 0:
                c_row[:t_end] = c_src[:t_end, 0]
                s_row[:t_end] = s_src[:t_end, 0]
            if mrope_section[1] > 0:
                c_row[t_end:h_end] = c_src[t_end:h_end, 1]
                s_row[t_end:h_end] = s_src[t_end:h_end, 1]
            if mrope_section[2] > 0:
                c_row[h_end:half] = c_src[h_end:half, 2]
                s_row[h_end:half] = s_src[h_end:half, 2]

        c = c_row.unsqueeze(0)
        s = s_row.unsqueeze(0)

        q_tok = q_view[t]
        k_tok = k_view[t]
        q_lo, q_hi = q_tok[:, :half], q_tok[:, half:]
        k_lo, k_hi = k_tok[:, :half], k_tok[:, half:]

        # Rotate-half: (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin).
        q_view[t] = torch.cat([q_lo * c - q_hi * s, q_hi * c + q_lo * s],
                              dim=1)
        k_view[t] = torch.cat([k_lo * c - k_hi * s, k_hi * c + k_lo * s],
                              dim=1)

    return q_view.view(n_tokens, -1), k_view.view(n_tokens, -1)


def create_test_data(num_tokens, n_q_head, n_kv_head, rotary_dim, head_size,
                     device, dtype):
    """Build random q/k activations and unit-norm cos/sin tables.

    cos and sin are drawn randomly and then jointly normalized so that
    cos**2 + sin**2 == 1 elementwise, i.e. each (cos, sin) pair encodes a
    valid rotation angle.
    """
    # Draw order (q, k, sin, cos) is fixed so a given seed reproduces
    # identical data.
    q = torch.randn(num_tokens, n_q_head * head_size,
                    dtype=dtype, device=device)
    k = torch.randn(num_tokens, n_kv_head * head_size,
                    dtype=dtype, device=device)

    table_shape = (3, num_tokens, rotary_dim // 2)
    sin = torch.randn(*table_shape, dtype=dtype, device=device)
    cos = torch.randn(*table_shape, dtype=dtype, device=device)

    norm = torch.sqrt(cos**2 + sin**2)
    return q, k, cos / norm, sin / norm


@pytest.mark.parametrize("mrope_section", MROPE_SECTION)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads", NUM_Q_HEADS)
@pytest.mark.parametrize("num_k_heads", NUM_K_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_mrotary_embedding_triton_kernel(
    mrope_section: List[int],
    num_tokens: int,
    num_q_heads: int,
    num_k_heads: int,
    head_size: int,
    rotary_dim: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check the triton mrope kernel against the eager PyTorch reference."""
    torch.manual_seed(seed)
    torch.set_default_device(device)
    init_device_properties_triton()
    if rotary_dim == -1:
        rotary_dim = head_size

    q_trt, k_trt, cos, sin = create_test_data(num_tokens=num_tokens,
                                              n_q_head=num_q_heads,
                                              n_kv_head=num_k_heads,
                                              head_size=head_size,
                                              rotary_dim=rotary_dim,
                                              device=device,
                                              dtype=dtype)
    # The kernel and the reference both mutate their inputs; clone first.
    q_gold = q_trt.clone()
    k_gold = k_trt.clone()

    q_trt, k_trt = triton_mrope(q_trt, k_trt, cos, sin, mrope_section,
                                head_size, rotary_dim, True)
    q_gold, k_gold = pytorch_forward_native(q_gold, k_gold, cos, sin,
                                            mrope_section, head_size,
                                            rotary_dim, True)

    # bf16 carries fewer mantissa bits, so allow a looser tolerance.
    if dtype == torch.bfloat16:
        atol, rtol = 1e-02, 1e-02
    else:
        atol, rtol = DEFAULT_ATOL, DEFAULT_RTOL

    torch.testing.assert_close(q_trt.view(q_gold.size()),
                               q_gold,
                               atol=atol,
                               rtol=rtol)
    torch.testing.assert_close(k_trt.view(k_gold.size()),
                               k_gold,
                               atol=atol,
                               rtol=rtol)

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
42 changes: 42 additions & 0 deletions vllm_ascend/ops/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
YaRNScalingRotaryEmbedding)
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope

from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
Expand Down Expand Up @@ -527,12 +531,50 @@ def forward(self,

class AscendMRotaryEmbedding(MRotaryEmbedding):

def forward_triton(self,
                   positions: torch.Tensor,
                   query: torch.Tensor,
                   key: torch.Tensor | None = None,
                   offsets: torch.Tensor | None = None):
    """Apply multimodal rotary embedding to q/k via the triton kernel.

    Args:
        positions: 2-D multimodal position ids (t/h/w streams stacked on
            dim 0) used to index ``cos_sin_cache``.
        query: query tensor, flattened heads in the last dim.
        key: key tensor; required here despite the Optional annotation
            (kept for signature parity with the base class).
        offsets: unused; kept for signature parity.

    Returns:
        Tuple of rotated (query, key), reshaped to the input shapes.
    """
    assert positions.ndim == 2
    assert key is not None

    self._match_cos_sin_cache_dtype(query)
    # cos/sin depend on `positions`, which changes on every call, so they
    # must be recomputed each time. Do NOT cache them on `self`: stale
    # values from a previous batch would silently corrupt the rotation.
    cos_sin = self.cos_sin_cache[positions]  # type: ignore
    cos, sin = cos_sin.chunk(2, dim=-1)

    query_shape = query.shape
    key_shape = key.shape

    assert self.mrope_section

    q, k = triton_mrope(
        query,
        key,
        cos.contiguous(),
        sin.contiguous(),
        self.mrope_section,
        self.head_size,
        self.rotary_dim,
        self.mrope_interleaved,
    )

    return q.reshape(query_shape), k.reshape(key_shape)

def forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
):
if HAS_TRITON and positions.ndim == 2:
# todo: need cann update in 8.5.0
return self.forward_triton(positions, query, key)

if self.mrope_section != [16, 24, 24] or \
get_ascend_device_type() == AscendDeviceType.A5:
return super().forward_oot(positions, query, key)
Expand Down