Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
5376308
[Refactor]6/N Extract common code of class AscendMLAImpl
wjy9595 Dec 24, 2025
d2bff3f
[Refactor]6/N Extract common code of class AscendMLAImpl
wjy9595 Dec 24, 2025
c689c63
[Refactor]6/N Extract common code of class AscendMLAImpl
wjy9595 Dec 24, 2025
cf50e6c
Merge remote-tracking branch 'origin/main'
wjy9595 Dec 24, 2025
5dc9c2c
[Refactor]6/N Extract common code of class AscendMLAImpl
wjy9595 Dec 24, 2025
589b477
Merge remote-tracking branch 'origin/main'
wjy9595 Dec 24, 2025
d5062ab
[Refactor]6/N Extract common code of class AscendMLAImpl
wjy9595 Dec 24, 2025
ebc67de
Merge branch 'main' into main
wujinyuan1 Dec 24, 2025
c5cd3a9
[Refactor]6/N Extract common code of class AscendMLAImpl
wjy9595 Dec 24, 2025
eb98e1b
Merge remote-tracking branch 'origin/main'
wjy9595 Dec 24, 2025
f14901d
[Refactor]6/N Extract common code of class AscendMLAImpl
wjy9595 Dec 24, 2025
0bfc6a7
[Refactor]6/N Extract common code of class AscendMLAImpl
wjy9595 Dec 24, 2025
ae4dcb4
[Refactor]6/N Extract common code of class AscendMLAImpl
wjy9595 Dec 25, 2025
41b0565
[Refactor]6/N Extract common code of class AscendMLAImpl
wjy9595 Dec 25, 2025
08ed935
[Refactor]6/N Extract common code of class AscendMLAImpl
wjy9595 Dec 25, 2025
4a9effe
[Refactor]6/N Extract common code of class AscendMLAImpl
wjy9595 Dec 25, 2025
45b58f8
[Refactor]6/N Extract common code of class AscendMLAImpl
wjy9595 Dec 25, 2025
152a48e
Merge branch 'main' into main
wujinyuan1 Dec 25, 2025
5b3de36
[Refactor]6/N Extract common code of class AscendMLAImpl
wjy9595 Dec 25, 2025
25e452e
Merge remote-tracking branch 'origin/main'
wjy9595 Dec 25, 2025
2a9e99b
Merge branch 'main' into main
wujinyuan1 Dec 26, 2025
583c0af
[Refactor]6/N Extract common code of class AscendMLAImpl
wjy9595 Dec 26, 2025
a5ea83c
Merge remote-tracking branch 'origin/main'
wjy9595 Dec 26, 2025
954f63b
Merge branch 'main' into main
weijinqian0 Dec 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 45 additions & 34 deletions tests/ut/attention/test_mla_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def test_init(self):

@patch('vllm_ascend.attention.mla_cp.get_dcp_group')
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch")
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
def test_mla_preprocess_dcp(self, magic_npu_fetch,
mock_maybe_all_gather_and_maybe_unpad,
mock_get_dcp_group):
Expand Down Expand Up @@ -339,7 +339,7 @@ def mock_all_gather_func(tensor, dim):
@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.mla_cp.get_pcp_group')
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_cp.maybe_npu_prefetch")
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
def test_mla_preprocess_pcp(self, magic_npu_fetch,
mock_maybe_all_gather_and_maybe_unpad,
mock_get_pcp_group,
Expand Down Expand Up @@ -543,8 +543,8 @@ def make_all_gather(ws):
self.impl._v_up_proj.return_value = torch.randn(
B, self.impl.v_head_dim)

result = self.impl._forward_decode_pcp_dcp(q_nope, q_pe, k_nope, k_pe,
BS, attn_metadata)
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
attn_metadata)

self.assertEqual(result.shape[0], B)
self.assertEqual(result.shape[1], self.impl.v_head_dim)
Expand Down Expand Up @@ -578,14 +578,14 @@ def mock_kv_b_proj(kv_c_normed):

def mock_reorg_kvcache(allgatered_kv_c_normed: torch.Tensor,
allgatered_k_pe: torch.Tensor,
padded_local_chunk_seq_lens_lst: list[int],
local_context_lens_allranks: list[list[int]],
sum_seq_len: int, max_seq_len: int,
chunk_size: int, chunk_idx: int, toks: int):
return torch.randn(sum_seq_len, allgatered_kv_c_normed.shape[1],
allgatered_kv_c_normed.shape[2]), torch.randn(
sum_seq_len, allgatered_k_pe.shape[1],
allgatered_k_pe.shape[2])
chunked_context: CPChunkedContextMetadata,
chunk_idx: int, toks: int):
return torch.randn(
chunked_context.cu_seq_lens_lst[chunk_idx][-1],
allgatered_kv_c_normed.shape[1],
allgatered_kv_c_normed.shape[2]), torch.randn(
chunked_context.cu_seq_lens_lst[chunk_idx][-1],
allgatered_k_pe.shape[1], allgatered_k_pe.shape[2])

# mock proj
self.impl.kv_b_proj.side_effect = mock_kv_b_proj
Expand Down Expand Up @@ -679,10 +679,6 @@ def mock_reorg_kvcache(allgatered_kv_c_normed: torch.Tensor,
iters * (1 if dcp_size * pcp_size > 1 else 0))
self.assertEqual(mock_load.call_count, iters)
self.assertEqual(mock_ring.call_count, iters)
self.assertEqual(mock_dcp.all_gather.call_count,
(1 if dcp_size > 1 else 0))
self.assertEqual(mock_pcp.all_gather.call_count,
iters * (1 if pcp_size > 1 else 0))
mock_reorg.reset_mock()
mock_load.reset_mock()
mock_ring.reset_mock()
Expand All @@ -691,7 +687,18 @@ def mock_reorg_kvcache(allgatered_kv_c_normed: torch.Tensor,
self.assertEqual(out.shape, prefix_out.shape)
self.assertEqual(lse.shape, prefix_lse.shape)

def test_reorg_kvcache_with_dcp_pcp(self):
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
def test_reorg_kvcache_with_dcp_pcp(self, mock_dcp, mock_get_dcp_group,
mock_pcp, mock_get_pcp_group):

def mock_all_gather(ws):
return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)

BLOCK_SIZE = 128 # fixed
max_model_len = 4096
max_num_seqs = 25
Expand All @@ -706,6 +713,12 @@ def test_reorg_kvcache_with_dcp_pcp(self):
pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens_of_pcp_dcp = test_case
if pcp_size * dcp_size == 1:
continue
self.impl.dcp_size = dcp_size
self.impl.pcp_size = pcp_size
mock_dcp.all_gather = MagicMock(
side_effect=mock_all_gather(dcp_size))
mock_pcp.all_gather = MagicMock(
side_effect=mock_all_gather(pcp_size))
chunked_prefill_workspace_size = min(
max(8 * max_model_len, 4 * max_num_seqs * BLOCK_SIZE),
128 * 1024)
Expand All @@ -723,27 +736,21 @@ def test_reorg_kvcache_with_dcp_pcp(self):

for i in range(len(chunked_context.seq_tot)):
allgatered_kv_c_normed = torch.randn(
chunked_context.seq_tot[i] * pcp_size * dcp_size,
self.impl.num_heads, self.impl.v_head_dim)
allgatered_k_pe = torch.randn(
chunked_context.seq_tot[i] * pcp_size * dcp_size,
self.impl.num_heads, self.impl.qk_rope_head_dim)
chunked_context.seq_tot[i], self.impl.num_heads,
self.impl.kv_lora_rank)
allgatered_k_pe = torch.randn(chunked_context.seq_tot[i],
self.impl.num_heads,
self.impl.qk_rope_head_dim)
result_kv, result_k_pe = self.impl._reorg_kvcache(
allgatered_kv_c_normed,
allgatered_k_pe,
padded_local_chunk_seq_lens_lst=chunked_context.
padded_local_chunk_seq_lens[i],
local_context_lens_allranks=chunked_context.
local_context_lens_allranks,
sum_seq_len=chunked_context.cu_seq_lens_lst[i][-1],
max_seq_len=chunked_context.max_seq_lens[i],
chunk_size=chunked_context.chunk_size,
chunked_context,
chunk_idx=i,
toks=chunked_context.seq_tot[i],
)
self.assertEqual(result_kv.shape,
(chunked_context.cu_seq_lens_lst[i][-1],
self.impl.num_heads, self.impl.v_head_dim))
self.impl.num_heads, self.impl.kv_lora_rank))
self.assertEqual(
result_k_pe.shape,
(chunked_context.cu_seq_lens_lst[i][-1],
Expand All @@ -754,6 +761,11 @@ def test_reorg_kvcache_with_dcp_pcp(self):
self.assertEqual(result_k_pe.shape[0],
chunked_context.cu_seq_lens_lst[i][-1])

self.assertEqual(mock_dcp.all_gather.call_count,
(1 if dcp_size > 1 else 0))
self.assertEqual(mock_pcp.all_gather.call_count,
(1 if pcp_size > 1 else 0))

def test_out_lse_reshape(self):
test_cases = [10, 1, 128, 512]
for test_case in test_cases:
Expand Down Expand Up @@ -1052,10 +1064,9 @@ def mock_npu_ring_mla_effect(q_nope, q_rope, k_nope, k_rope, value,
attn_metadata.prefill.pcp_metadata.pcp_prefill_mask = torch.triu(
torch.ones(10, 10, dtype=torch.float16), 1)

output = self.impl._forward_prefill_cp(q_nope, q_pe, k_nope,
k_pe, value,
kv_c_and_k_pe_cache,
attn_metadata)
output = self.impl._forward_prefill(q_nope, q_pe, k_nope, k_pe,
value, kv_c_and_k_pe_cache,
attn_metadata)
self.assertEqual(
output.shape,
(seq_len_q, self.impl.num_heads * self.impl.v_head_dim))
Loading
Loading