260 changes: 1 addition & 259 deletions tests/ut/attention/test_mla_v1.py
@@ -11,7 +11,6 @@
AscendMLAImpl, AscendMLAMetadata,
AscendMLAMetadataBuilder,
AscendMLAPrefillMetadata)
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata


class TestAscendMLABackend(TestBase):
@@ -188,8 +187,6 @@ def test_ascend_mla_metadata_builder_default(self):
mock_device = 'cpu'

ascend_config = MagicMock()
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
return_value=ascend_config):
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
@@ -199,44 +196,9 @@
self.assertEqual(
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
self.assertEqual(builder.torchair_graph_enabled, True)

@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_reorder_batch_with_torchair_graph(self, ascend_config):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True

builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)

input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3]

scheduler_output = MagicMock()
scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1}
scheduler_output.scheduled_spec_decode_tokens = {
0: [1],
1: [],
2: [1, 1],
3: []
}

input_batch.swap_states = MagicMock()

modified = builder.reorder_batch(input_batch, scheduler_output)

self.assertFalse(modified)
input_batch.swap_states.assert_not_called()

def test_reorder_batch_without_torchair_graph(self):
def test_reorder_batch(self):
ascend_config = MagicMock()
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = False

mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
@@ -268,128 +230,6 @@ def test_reorder_batch_without_torchair_graph(self):
self.assertTrue(modified)
input_batch.swap_states.assert_called_once_with(1, 2)

@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 64)
self.assertTrue(torch.equal(result[:, :10], block_tables))

@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 64
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 4)
self.assertTrue(torch.equal(result, block_tables[:, :4]))

@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_get_graph_runner_block_tables_from_numpy(self,
mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)

block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

result = builder._get_graph_runner_block_tables(3, block_tables)

self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 64)
self.assertTrue(torch.equal(result[:, :10], block_tables))

@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_build_dummy(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False

mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_vllm_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_device = 'cpu'

builder = AscendMLAMetadataBuilder(mock_vllm_config,
mock_device,
metadata_cls=AscendMLAMetadata)
builder.rope_dim = 64

with patch.object(builder,
"_get_graph_runner_block_tables",
side_effect=lambda x, y: y):
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=3,
num_actual_tokens=3,
decode_token_per_req=1,
actual_seq_lengths_q=[0, 1, 2],
attn_mask=torch.zeros((1, 1), dtype=torch.bool),
spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
)
metadata = builder.build_torchair_graph_dummy(common_attn_metadata)

sin_golden = torch.ones(3,
1,
1,
64,
dtype=torch.float16,
device=mock_device)
cos_golden = torch.ones(3,
1,
1,
64,
dtype=torch.float16,
device=mock_device)

self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_input_tokens, 3)
self.assertEqual(metadata.num_actual_tokens, 3)
self.assertEqual(metadata.num_decodes, 1)
self.assertEqual(metadata.num_decode_tokens, 1)
self.assertEqual(metadata.num_prefills, 0)
self.assertEqual(metadata.attn_state, AscendAttentionState.DecodeOnly)
self.assertIsNone(metadata.prefill)
self.assertIsInstance(metadata.decode, AscendMLADecodeMetadata)
self.assertEqual(metadata.block_tables.shape[0], 3)
self.assertEqual(metadata.block_tables.shape[1], 64)
self.assertEqual(metadata.seq_lens.shape[0], 3)
self.assertEqual(metadata.slot_mapping.shape[0], 3)
self.assertEqual(metadata.query_start_loc.shape[0], 3)
assert torch.equal(sin_golden, metadata.decode.sin)
assert torch.equal(cos_golden, metadata.decode.cos)


class TestAscendMLAImpl(TestBase):

Expand All @@ -401,8 +241,6 @@ class TestAscendMLAImpl(TestBase):
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp):
mock_tp.world_size = 2
ascend_config.torchair_graph_config.enabled = True
ascend_config.torchair_graph_config.enable_kv_nz = False
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
@@ -464,7 +302,6 @@ def test_init(self):
self.assertIsNotNone(self.impl.kv_a_layernorm)
self.assertEqual(self.impl.num_queries_per_kv, 32)
self.assertEqual(self.impl.tp_size, 2)
self.assertTrue(self.impl.torchair_graph_enabled)

def test_v_up_proj_and_o_proj(self):
batch_size = 4
@@ -580,102 +417,10 @@ def test_compute_prefill_context(self, mock_ring, mock_load):
self.assertEqual(out.shape, prefix_out.shape)
self.assertEqual(lse.shape, prefix_lse.shape)

@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
def test_exec_kv(self, mock_kv_cache):
batch_size = 2
hidden = torch.randn(batch_size, 128)
cos = torch.randn(batch_size, 32)
sin = torch.randn(batch_size, 32)
kv_cache = (torch.randn(
4, 8, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
torch.randn(
4, 8,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))
slots = torch.arange(batch_size, dtype=torch.long)

proj_out = torch.randn(
batch_size, self.impl.num_kv_heads, 1,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )

mock_kv_cache.return_value = (torch.randn(batch_size,
self.impl.num_kv_heads, 1,
self.impl.qk_rope_head_dim),
torch.randn(batch_size,
self.impl.num_kv_heads, 1,
self.impl.kv_lora_rank),
None, None)

k_pe, k_nope, kv = self.impl.exec_kv(hidden, cos, sin, kv_cache, slots)

self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden)
mock_kv_cache.assert_called_once()
self.assertEqual(k_pe.shape, (batch_size, self.impl.num_kv_heads, 1,
self.impl.qk_rope_head_dim))
self.assertEqual(
k_nope.shape,
(batch_size, self.impl.num_kv_heads, 1, self.impl.kv_lora_rank))
self.assertEqual(kv.shape,
(batch_size, self.impl.num_kv_heads, 1,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))

@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
def test_exec_kv_prefill(self, mock_kv):
B, N, S, H = 2, self.impl.num_kv_heads, 1, 128
hidden_states = torch.randn(B, N, S, H)
cos = torch.randn(B, S, 32)
sin = torch.randn(B, S, 32)
kv_cache = (
torch.randn(100, 8,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
torch.randn(100, 8,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
)

slots = torch.arange(B * S, dtype=torch.long)

proj_out = torch.randn(
B, N, S, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )

mock_kv.return_value = (None, None,
torch.randn(B, self.impl.num_kv_heads, S,
self.impl.qk_rope_head_dim),
torch.randn(B, self.impl.num_kv_heads, S,
self.impl.kv_lora_rank))

k_pe, k_nope = self.impl.exec_kv_prefill(hidden_states, cos, sin,
kv_cache, slots)

self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden_states)
mock_kv.assert_called_once()

self.assertEqual(
k_pe.shape,
(B, self.impl.num_kv_heads, S, self.impl.qk_rope_head_dim))
self.assertEqual(
k_nope.shape,
(B, self.impl.num_kv_heads, S, self.impl.kv_lora_rank))

@patch("torch_npu.npu_interleave_rope")
def test_rope_single(self, mock_rope):
B, N, D = 2, 16, 1024
x = torch.randn(B, N, D)
cos = torch.randn(B, N, 1, D)
sin = torch.randn(B, N, 1, D)
mock_rope.return_value = x.view(B, N, 1, D)
result = self.impl.rope_single(x, cos, sin)
self.assertEqual(result.shape[0], B)
self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], D)
mock_rope.assert_called_once()

@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj_and_o_proj")
@patch("torch_npu._npu_paged_attention_mla")
def test_forward_decode_without_graph(self, mock_page_attention_mla,
mock_up_proj):
self.impl.running_in_graph = False
self.impl.running_chunkprefilll_with_torchair = False
num_tokens = 100
num_blocks = 256
block_size = 4
@@ -706,9 +451,6 @@ def test_forward_decode_without_graph(self, mock_page_attention_mla,
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._forward_prefill")
@patch("torch_npu._npu_reshape_and_cache")
def test_forward_without_graph(self, _, mock_forward_prefill):
self.impl.running_in_graph = False
self.impl.torchair_graph_enabled = False

num_tokens = 100
num_blocks = 256
block_size = 4
21 changes: 21 additions & 0 deletions tests/ut/test_platform.py
@@ -425,6 +425,27 @@ def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config):
self.assertEqual(result,
"vllm_ascend.attention.mla_v1.AscendMLABackend")

@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_mla_and_torchair(
self, mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = True

mock_get_ascend_config.return_value = mock_config

result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,
dtype="float16",
kv_cache_dtype="float16",
block_size=64,
use_v1=True,
use_mla=True,
)
self.assertEqual(
result,
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend")

@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_and_torchair(self,
mock_get_ascend_config):
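For context, the two platform tests above pin down how the attention backend class path is expected to resolve: plain v1 + MLA resolves to `vllm_ascend.attention.mla_v1.AscendMLABackend`, while enabling `torchair_graph_config` switches the result to `vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend`. The snippet below is a minimal, hypothetical sketch of that dispatch, not the actual `get_attn_backend_cls` implementation; only the two returned strings come from the tests, and the helper name `select_mla_backend` is invented for illustration.

```python
# Hypothetical sketch of the branch the two platform tests exercise.
# Not the vllm_ascend implementation; only the two return strings below
# are taken from the assertions in tests/ut/test_platform.py.

def select_mla_backend(use_v1: bool, use_mla: bool, torchair_enabled: bool) -> str:
    """Return the dotted path of the MLA attention backend class."""
    if not (use_v1 and use_mla):
        raise ValueError("sketch only covers the v1 + MLA cases shown above")
    if torchair_enabled:
        # Expected by test_get_attn_backend_cls_use_v1_mla_and_torchair
        return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
    # Expected by test_get_attn_backend_cls_use_v1_and_mla
    return "vllm_ascend.attention.mla_v1.AscendMLABackend"


# Usage mirroring the two test cases:
assert select_mla_backend(True, True, torchair_enabled=True) == \
    "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
assert select_mla_backend(True, True, torchair_enabled=False) == \
    "vllm_ascend.attention.mla_v1.AscendMLABackend"
```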