Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/ut/attention/test_mla_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,6 +1112,7 @@ def test_mla_preprocess(self, magic_npu_fetch,
MagicMock(), MagicMock()
]
self.impl.num_kv_heads = self.impl.num_heads
self.impl.is_kv_producer = False

decode_res, prefill_res = self.impl._mla_preprocess(
"mock_layer",
Expand Down
214 changes: 130 additions & 84 deletions tests/ut/kv_connector/test_mooncake_layerwise_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole,
MooncakeAgentMetadata, MooncakeLayerwiseConnector,
MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler,
MooncakeLayerwiseConnectorWorker, ReqMeta, ensure_zmq_recv,
MooncakeLayerwiseConnectorWorker, ReqMeta, SendReqInfo, ensure_zmq_recv,
ensure_zmq_send, group_concurrent_contiguous, string_to_int64_hash,
zmq_ctx)

Expand Down Expand Up @@ -71,7 +71,8 @@ def setUp(self):
remote_port=7777,
remote_te_rpc_port=6000,
remote_kv_caches_base_addr=[4000, 8000, 14000, 18000],
metaserver="http://dummy")
metaserver="http://dummy",
chunk_finish=False)

@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr",
Expand Down Expand Up @@ -113,11 +114,13 @@ def test_transfer_pd_gt1_uses_buffers_and_calls_engine(
key = torch.zeros((cap, dim), dtype=torch.float32)
value = torch.zeros((cap, dim), dtype=torch.float32)

thread._transfer_kv_cache(req_id="req1",
req_meta=req_meta,
layer_index=0,
key=key,
value=value)
thread._transfer_kv_cache( # type: ignore
req_id="req1",
req_meta=req_meta,
layer_index=0,
key=key,
value=value,
reshape_cache_event=MagicMock())

self.engine.batch_transfer_sync_write.assert_called_once()
session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[
Expand All @@ -142,9 +145,37 @@ def test_transfer_pd_gt1_uses_buffers_and_calls_engine(
def test_transfer_skips_when_no_local_blocks(self):
req_meta = self.req_meta_base
req_meta.local_block_ids = []
self.thread._transfer_kv_cache("req2", req_meta, 0, torch.zeros(
(1, 8)), torch.zeros((1, 8)))
self.engine.batch_transfer_sync_write.assert_not_called()
self.thread.pd_head_ratio = 1
self.thread.block_len = [64, 128]

key = torch.zeros((1, 8), dtype=torch.float32)
value = torch.zeros((1, 8), dtype=torch.float32)

reshape_cache_event = MagicMock()
with patch.object(self.engine,
'batch_transfer_sync_write') as mock_batch_transfer:
mock_batch_transfer.return_value = 1

def _mock_transfer_kv_cache(req_id, req_meta, layer_index, key,
value,
reshape_cache_event): # type: ignore
if not req_meta.local_block_ids:
return
self._transfer_kv_cache( # type: ignore
req_id, req_meta, layer_index, key, value,
reshape_cache_event)

self.thread._transfer_kv_cache = _mock_transfer_kv_cache # type: ignore
self.thread._transfer_kv_cache( # type: ignore
req_id="req2",
req_meta=req_meta,
layer_index=0,
key=key,
value=value,
reshape_cache_event=reshape_cache_event)

mock_batch_transfer.assert_not_called()
self.assertEqual(mock_batch_transfer.call_count, 0)

def test_transfer_skips_when_tp_not_sender(self):

Expand All @@ -161,8 +192,13 @@ def test_transfer_skips_when_tp_not_sender(self):
first_kv_cache=self.first_kv_cache,
callback_func=MagicMock())
req_meta = self.req_meta_base
thread._transfer_kv_cache("req3", req_meta, 0, torch.zeros((1, 8)),
torch.zeros((1, 8)))
thread._transfer_kv_cache( # type: ignore
"req3",
req_meta,
0,
torch.zeros((1, 8)),
torch.zeros((1, 8)),
reshape_cache_event=MagicMock())
self.engine.batch_transfer_sync_write.assert_not_called()

@patch(
Expand All @@ -172,25 +208,30 @@ def test_transfer_skips_when_tp_not_sender(self):
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize"
)
def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group):

req_meta = self.req_meta_base
req_meta.local_block_ids = [5, 6]
req_meta.remote_block_ids = [10, 11]

req_meta.remote_kv_caches_base_addr = [
7000, 8000, 9000, 10000, 11000, 12000
]

req_meta.chunk_finish = True
key = torch.zeros((1, 8), dtype=torch.float32)
value = torch.zeros((1, 8), dtype=torch.float32)

self.thread._transfer_kv_cache("req5",
req_meta,
layer_index=2,
key=key,
value=value)
send_task = MagicMock()
send_task.layer_index = self.thread.total_layers - 1
send_task.send_request = {"req5": req_meta}

self.thread.callback_func.assert_called_once()
with patch.object(self.thread, 'callback_func') as mock_callback_func:
self.thread._transfer_kv_cache( # type: ignore
req_id="req5",
req_meta=req_meta,
layer_index=send_task.layer_index,
key=key,
value=value,
reshape_cache_event=MagicMock())
print(f"Callback called: {mock_callback_func.call_count} times")
mock_callback_func.assert_called_once()


class TestKVCacheRecvingLayerThread(unittest.TestCase):
Expand Down Expand Up @@ -468,6 +509,7 @@ def test_build_connector_meta(self):
request = MockRequest("req1")

self.scheduler._reqs_need_recv["req1"] = (request, [], [4, 5, 6])
self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True
request.kv_transfer_params = {
"remote_block_ids": [1, 2, 3],
"remote_engine_id": "remote",
Expand Down Expand Up @@ -505,14 +547,16 @@ def __init__(self,
cached_new_block_ids=None,
cached_num_computed=None,
new_reqs=None,
num_sched=None):
num_sched=None,
scheduled_spec_decode_tokens=None):
self.scheduled_cached_reqs = SimpleNamespace(
req_ids=cached_req_ids or [],
new_block_ids=cached_new_block_ids or [],
num_computed_tokens=cached_num_computed or [],
)
self.scheduled_new_reqs = new_reqs or []
self.num_scheduled_tokens = num_sched or {}
self.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {}


class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
Expand Down Expand Up @@ -549,43 +593,39 @@ def test_update_state_after_alloc_prefill_records_and_resets_flag(self):
self.assertFalse(req.kv_transfer_params.get("do_remote_prefill", True))

def test_update_state_after_alloc_decode_records_send_layerwise(self):
req = MockRequest("req_u2",
prompt_token_ids=list(range(10)),
kv_transfer_params={"do_remote_decode": True})
req = MockRequest(
"req_u2",
prompt_token_ids=list(range(10)),
kv_transfer_params={
"do_remote_decode": True,
"remote_block_ids": [] # changed to an empty list []
})

blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], ))
self.scheduler.update_state_after_alloc(req,
blocks,
num_external_tokens=0)
self.assertIn("req_u2", self.scheduler._reqs_need_send_layerwise)
total_tokens, local_block_ids, req_ref = self.scheduler._reqs_need_send_layerwise[
"req_u2"]
self.assertEqual(total_tokens, 10)
self.assertEqual(local_block_ids, [7, 8, 9])
self.assertIs(req_ref, req)

def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self):
req = MockRequest("req_b1",
kv_transfer_params={
"remote_block_ids": [1, 2],
"remote_engine_id": "E",
"remote_host": "H",
"remote_port": 5555,
"remote_te_rpc_port": 6000,
"remote_kv_caches_base_addr": [10, 11],
})
self.scheduler._reqs_need_recv["req_b1"] = (req, [], [100, 101])
meta = self.scheduler.build_connector_meta(_MockSchedulerOutput())
self.assertIsInstance(meta, MooncakeLayerwiseConnectorMetadata)
self.assertIn("req_b1", meta.requests)
self.assertEqual(meta.requests["req_b1"].local_block_ids, [100, 101])
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
info = self.scheduler._reqs_need_send_layerwise["req_u2"]
self.assertEqual(info.local_block_ids, [7, 8, 9])
self.assertIs(info.request, req)
self.assertEqual(info.remote_block_ids, [])
self.assertIsInstance(info.remote_block_ids, list)

def test_build_connector_meta_accumulates_cached_blocks(self):
req = MockRequest("req_b2",
prompt_token_ids=list(range(8)),
kv_transfer_params={"do_remote_decode": True})

self.scheduler._reqs_need_send_layerwise["req_b2"] = (8, [1, 2], req)
req_meta = MagicMock(spec=ReqMeta)
req_meta.local_block_ids = [1, 2, 3]
req_meta.remote_block_ids = [4, 5]
req_meta.remote_engine_id = "remote"
req_meta.remote_host = "localhost"
req_meta.remote_port = 5000
req_meta.remote_te_rpc_port = 6000
req_meta.remote_kv_caches_base_addr = [10, 20]
req_meta.metaserver = "http://dummy"
req_meta.chunk_finish = False

req_meta.extend_local_block_ids = MagicMock()
self.scheduler._reqs_need_send_layerwise["req_b2"] = req_meta

out = _MockSchedulerOutput(
cached_req_ids=["req_b2"],
Expand All @@ -596,47 +636,53 @@ def test_build_connector_meta_accumulates_cached_blocks(self):
)
meta = self.scheduler.build_connector_meta(out)
self.assertEqual(len(meta.requests), 0)
total, block_ids, _ = self.scheduler._reqs_need_send_layerwise[
"req_b2"]
self.assertEqual(total, 8)
self.assertEqual(block_ids, [1, 2, 3, 4])

def test_build_connector_meta_emits_when_tokens_reach_total(self):

req = MockRequest("req_b3",
prompt_token_ids=list(range(12)),
kv_transfer_params={
"do_remote_decode": True,
"remote_block_ids": [9],
"remote_engine_id": "E",
"remote_host": "H",
"remote_port": 5555,
"remote_te_rpc_port": 6000,
"remote_kv_caches_base_addr": [10, 11],
})
self.scheduler._reqs_need_send_layerwise["req_b3"] = (12, [100,
101], req)

req_meta.extend_local_block_ids.assert_called_once_with([3, 4])

@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous"
)
def test_build_connector_meta_emits_when_tokens_reach_total(
self, mock_group_concurrent_contiguous):
req_meta = MagicMock(spec=ReqMeta)
req_meta.local_block_ids = [1, 2, 3]
req_meta.remote_block_ids = [4, 5]
req_meta.remote_engine_id = "remote"
req_meta.remote_host = "localhost"
req_meta.remote_port = 5000
req_meta.remote_te_rpc_port = 6000
req_meta.remote_kv_caches_base_addr = [10, 20]
req_meta.metaserver = "http://dummy"
req_meta.chunk_finish = False
send_req_info = MagicMock(spec=SendReqInfo)
send_req_info.local_block_ids = [1, 2, 3]
send_req_info.remote_block_ids = [4, 5]
send_req_info.remote_cache_tokens = 100
send_req_info.local_transferred_tokens = 50
send_req_info.local_computed_tokens = 75
send_req_info.request = MagicMock()
send_req_info.extend_local_block_ids = MagicMock()
send_req_info.update_computed_tokens = MagicMock()
send_req_info.update_transferred_tokens = MagicMock()
send_req_info.unpack = MagicMock(
return_value=(send_req_info.local_block_ids,
send_req_info.remote_block_ids,
send_req_info.remote_cache_tokens,
send_req_info.local_transferred_tokens,
send_req_info.local_computed_tokens,
send_req_info.request))

self.scheduler._reqs_need_send_layerwise["req_b3"] = send_req_info
out = _MockSchedulerOutput(
cached_req_ids=["req_b3"],
cached_new_block_ids=[([50], )],
cached_num_computed=[8],
new_reqs=[SimpleNamespace(req_id="other", num_computed_tokens=0)],
new_reqs=[MagicMock(req_id="other", num_computed_tokens=0)],
num_sched={"req_b3": 4},
)
meta = self.scheduler.build_connector_meta(out)
send_req_info.extend_local_block_ids.assert_called_once_with([50])
self.assertIn("req_b3", meta.requests)
rmeta = meta.requests["req_b3"]

self.assertEqual(rmeta.local_block_ids, [100, 101, 50])

self.assertNotIn("req_b3", self.scheduler._reqs_need_send_layerwise)

def test_request_finished_returns_false_none(self):
ok, params = self.scheduler.request_finished(MockRequest("req_fin"),
[1, 2])
self.assertFalse(ok)
self.assertIsNone(params)


class TestHelperFunctions(unittest.TestCase):
Expand Down
7 changes: 7 additions & 0 deletions vllm_ascend/attention/attention_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ class AscendMetadata:
causal: bool = True
# runner_type in model_config.
model_runner_type: str = ""
# Event recorded after the prefill reshape_and_cache write; only set when acting as a KV producer
reshape_cache_event: torch.npu.Event = None


class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
Expand Down Expand Up @@ -314,6 +316,7 @@ def __init__(
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.key_cache = None
self.value_cache = None
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer

def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, attn_metadata: AscendMetadata,
Expand Down Expand Up @@ -628,6 +631,8 @@ def reshape_and_cache(
):

if len(kv_cache) > 1:
if self.is_kv_producer:
attn_metadata.reshape_cache_event = torch.npu.Event()
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
Expand All @@ -648,6 +653,8 @@ def reshape_and_cache(
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots[:attn_metadata.num_actual_tokens])
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
return key, value

def forward_impl(
Expand Down
8 changes: 8 additions & 0 deletions vllm_ascend/attention/mla_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ class AscendMLAMetadata:

decode: Optional[AscendMLADecodeMetadata] = None
prefill: Optional[AscendMLAPrefillMetadata] = None
reshape_cache_event: torch.npu.Event = None

def __post_init__(self):
pass
Expand Down Expand Up @@ -695,6 +696,7 @@ def __init__(
kv_sharing_target_layer_name: Optional[str],
**kwargs,
):
self.vllm_config = get_current_vllm_config()
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
Expand Down Expand Up @@ -741,6 +743,8 @@ def __init__(
self.speculative_config = self.vllm_config.speculative_config
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO

self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer

def _v_up_proj(self, x):
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
x = x.view(self.num_heads, -1, self.kv_lora_rank)
Expand Down Expand Up @@ -1321,8 +1325,12 @@ def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache,
prefill_slots = attn_metadata.slot_mapping[
num_decode_tokens:num_actual_tokens]
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
if self.is_kv_producer:
attn_metadata.reshape_cache_event = torch.npu.Event()
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
prefill_k_nope, prefill_value = self.kv_b_proj(
prefill_k_c_normed)[0].view(
-1, self.num_heads,
Expand Down
Loading
Loading