44 changes: 22 additions & 22 deletions tests/ut/ops/test_rotary_embedding.py
@@ -157,6 +157,28 @@ def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
args, kwargs = mock_npu_rotary.call_args
self.assertFalse(args[-1])

@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_rotary_dim_less_than_head_size(
self, mock_npu_rotary, mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False

# test case when rotary_dim < head_size
org_rotary_dim = self.layer.rotary_dim
self.layer.rotary_dim = self.layer.head_size // 2

result_q, result_k = self.layer.forward(self.positions, self.query,
self.key)

mock_npu_rotary.assert_called_once()
self.assertEqual(result_q.shape, self.query.shape)
self.assertEqual(result_k.shape, self.key.shape)

# restore rotary_dim
self.layer.rotary_dim = org_rotary_dim


class MockRopeModule:

@@ -207,28 +229,6 @@ def test_native_rope_deepseek_forward_base(self, mock_npuplatform):
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape

@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_cache_handling(
self, mock_npuplatform, mock_rope_forward_oot):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
self.layer.max_seq_len = 1024
# Test cache situation is true
with patch.object(self.layer, "_set_cos_sin_cache") as mock_set_cache:
mock_rope_forward_oot.return_value = (self.query, self.key)

q_pe, k_pe = self.layer.forward(self.positions,
self.query,
self.key,
max_seq_len=2048)
mock_set_cache.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape

@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
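The new test_rope_forward_oot_rotary_dim_less_than_head_size case pins down the partial-rotary path: when rotary_dim < head_size, only the first rotary_dim channels of each head are rotated and the rest pass through, so the outputs keep the input shapes. A minimal neox-style sketch of that behaviour (generic RoPE math for illustration, not the torch_npu kernel the test mocks):

```python
import torch

def partial_rope(x: torch.Tensor, positions: torch.Tensor, rotary_dim: int,
                 base: float = 10000.0) -> torch.Tensor:
    """Rotate the first `rotary_dim` channels of each head (neox style); the
    remaining head_size - rotary_dim channels pass through unchanged."""
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    freqs = torch.outer(positions.float(), inv_freq)            # [num_tokens, rotary_dim // 2]
    cos = torch.cat([freqs, freqs], dim=-1).cos()[:, None, :]   # broadcast over heads
    sin = torch.cat([freqs, freqs], dim=-1).sin()[:, None, :]
    half = rotary_dim // 2
    rotated = torch.cat([-x_rot[..., half:], x_rot[..., :half]], dim=-1)
    return torch.cat([x_rot * cos + rotated * sin, x_pass], dim=-1)

q = torch.randn(4, 8, 128)                  # [num_tokens, num_heads, head_size]
out = partial_rope(q, torch.arange(4), 64)  # rotary_dim = head_size // 2
assert out.shape == q.shape
```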
51 changes: 25 additions & 26 deletions tests/ut/torchair/ops/test_torchair_rotary_embedding.py
@@ -5,8 +5,9 @@

from tests.ut.base import TestBase
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
custom_rotary_embedding_enabled, native_rope_deepseek_forward,
rope_forward_oot, rotate_half, yarn_find_correction_dim, yarn_get_mscale)
_set_cos_sin_cache, custom_rotary_embedding_enabled,
native_rope_deepseek_forward, rope_forward_oot, rotate_half,
yarn_find_correction_dim, yarn_get_mscale)


class TestCustomRotaryEmbeddingEnabled(TestBase):
@@ -200,46 +201,44 @@ def __init__(self, max_seq_len=2048, is_neox_style=True):
self.sin_cached = None
self.rotary_dim = 1
self.base = 1
self.beta_fast = 32
self.beta_slow = 1
self.max_position_embeddings = 4096
self.mscale = 1.0
self.scaling_factor = 40

def register_buffer(self):
pass

class TestNativeRopeDeepseekForward(TestBase):

@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot):
class TestSetSinCosCache(TestBase):

def test_set_cos_sin_cache(self):
module = MockRopeModule()
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)

mock_rope_forward_oot.return_value = (query, key)
with patch.object(module, "register_buffer") as mock_register_buffer:
_set_cos_sin_cache(module,
1024,
device="cpu",
dtype=torch.bfloat16)

q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
mock_register_buffer.assert_called()

assert q_pe.shape == query.shape
assert k_pe.shape == key.shape

@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding._set_cos_sin_cache'
)
class TestNativeRopeDeepseekForward(TestBase):

@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_cache_handling(
self, mock_rope_forward_oot, mock_set_cache):
# Test cache situation is true
module = MockRopeModule(max_seq_len=1024)
def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot):
module = MockRopeModule()
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)

mock_rope_forward_oot.return_value = (query, key)

q_pe, k_pe = native_rope_deepseek_forward(module,
positions,
query,
key,
max_seq_len=2048)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)

assert q_pe.shape == query.shape
assert k_pe.shape == key.shape
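Because _set_cos_sin_cache is now a module-level helper imported by the test file (see the updated import block above), it can be exercised in isolation. Below is a hypothetical companion test, not part of the PR: it assumes it lives in the same test module (so MockRopeModule is already in scope) and only asserts the "inv_freq" buffer, since that is the only register_buffer call visible in this diff.

```python
import torch
from unittest.mock import patch

from tests.ut.base import TestBase
from vllm_ascend.torchair.ops.torchair_rotary_embedding import _set_cos_sin_cache


class TestSetSinCosCacheBuffers(TestBase):
    # Capture what register_buffer receives instead of only asserting that it
    # was called. MockRopeModule is the test helper defined earlier in this file.
    def test_inv_freq_is_registered(self):
        module = MockRopeModule()
        captured = {}

        def fake_register_buffer(name, tensor, persistent=True):
            captured[name] = tensor

        with patch.object(module, "register_buffer",
                          side_effect=fake_register_buffer):
            _set_cos_sin_cache(module, 1024, device="cpu", dtype=torch.bfloat16)

        self.assertIn("inv_freq", captured)
        self.assertEqual(captured["inv_freq"].dim(), 1)
```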
18 changes: 7 additions & 11 deletions vllm_ascend/ops/rotary_embedding.py
@@ -168,8 +168,10 @@ def __init__(
super(DeepseekScalingRotaryEmbedding,
self).__init__(head_size, rotary_dim, max_position_embeddings,
base, is_neox_style, dtype)
self.max_seq_len = max_position_embeddings
self._set_cos_sin_cache(seq_len=max_position_embeddings,

# NOTE: For ascend friendly computing, reorder sin and cos cache
self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
self._set_cos_sin_cache(self.max_seq_len,
device=NPUPlatform.device_type,
dtype=dtype)

@@ -275,8 +277,7 @@ def _apply_rotary_pos_emb(self,

return q_embed, k_embed

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
def _set_cos_sin_cache(self, max_seq_len, device, dtype):
dim = self.rotary_dim

freq_extra = 1.0 / (self.base**(
@@ -297,9 +298,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
inv_freq_mask) + freq_extra * inv_freq_mask
self.register_buffer("inv_freq", inv_freq, persistent=False)

t = torch.arange(seq_len * self.scaling_factor,
device=device,
dtype=torch.float32)
t = torch.arange(max_seq_len, device=device, dtype=torch.float32)

freqs = torch.outer(t, inv_freq)
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
@@ -317,10 +316,7 @@ def forward(self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
max_seq_len: Optional[int] = None):
if max_seq_len is not None and max_seq_len > self.max_seq_len:
self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
offsets: Optional[torch.Tensor] = None):
if len(key.shape) == 2:
key = key[:, None, :]
# Note: we implement the non neox_style method with shuffle the last dim and neox style
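Net effect of the rotary_embedding.py hunks above (the torchair files below mirror it): DeepseekScalingRotaryEmbedding.forward no longer takes max_seq_len and never regrows the cos/sin cache; instead the cache is built once in __init__ for the full scaled range, ceil(max_position_embeddings * scaling_factor). A standalone sketch of the new sizing, with illustrative constants and the YaRN frequency blending omitted:

```python
import math
import torch

# Illustrative constants, not taken from a real model config.
max_position_embeddings = 4096
scaling_factor = 40
rotary_dim = 64
base = 10000.0
mscale = 1.0

# New behaviour: size the cache once, up front, for the whole scaled range.
max_seq_len = math.ceil(max_position_embeddings * scaling_factor)   # 163840

# Plain RoPE frequencies; the real _set_cos_sin_cache blends extrapolated and
# interpolated frequencies (YaRN) before this point.
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))

# Old code: t = torch.arange(seq_len * scaling_factor, ...) and the cache was
# rebuilt whenever forward() saw a larger max_seq_len. New code: one arange.
t = torch.arange(max_seq_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * mscale
sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * mscale

print(cos_cached.shape)  # torch.Size([163840, 64])
```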
21 changes: 7 additions & 14 deletions vllm_ascend/torchair/ops/torchair_rotary_embedding.py
@@ -93,10 +93,7 @@ def native_rope_deepseek_forward(self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
max_seq_len: Optional[int] = None):
if max_seq_len is not None and max_seq_len > self.max_seq_len:
_set_cos_sin_cache(self, max_seq_len, query.device, query.dtype)
offsets: Optional[torch.Tensor] = None):
if len(key.shape) == 2:
key = key[:, None, :]
# Note: we implement the non neox_style method with shuffle the last dim and neox style
@@ -211,8 +208,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
return q_embed, k_embed


def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
def _set_cos_sin_cache(self, max_seq_len, device, dtype):
dim = self.rotary_dim

freq_extra = 1.0 / (self.base**(
@@ -232,9 +228,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
self.register_buffer("inv_freq", inv_freq, persistent=False)

t = torch.arange(seq_len * self.scaling_factor,
device=device,
dtype=torch.float32)
t = torch.arange(max_seq_len, device=device, dtype=torch.float32)

freqs = torch.outer(t, inv_freq)
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
@@ -365,8 +359,7 @@ def deepseek_rope_init_func(
super(DeepseekScalingRotaryEmbedding,
self).__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
self.max_seq_len = max_position_embeddings
_set_cos_sin_cache(self,
max_position_embeddings,
dtype=dtype,
device="npu")

# NOTE: For ascend friendly computing, reorder sin and cos cache
self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
_set_cos_sin_cache(self, self.max_seq_len, dtype=dtype, device="npu")
8 changes: 2 additions & 6 deletions vllm_ascend/torchair/torchair_mla.py
@@ -1198,9 +1198,7 @@ def forward(
else:
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
attn_metadata.decode.input_positions,
decode_q_pe.contiguous(),
decode_k_pe,
max_seq_len=attn_metadata.decode.max_seq_lens)
decode_q_pe.contiguous(), decode_k_pe)
if has_prefill:
assert attn_metadata.prefill is not None
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
@@ -1225,9 +1223,7 @@
else:
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
attn_metadata.prefill.input_positions,
prefill_q_pe.contiguous(),
prefill_k_pe,
max_seq_len=attn_metadata.prefill.max_seq_lens)
prefill_q_pe.contiguous(), prefill_k_pe)

assert len(
kv_cache