From f34498a58f7ba72ce2a77aa8ddee4e77f249c68e Mon Sep 17 00:00:00 2001 From: MrZ20 <2609716663@qq.com> Date: Mon, 20 Apr 2026 10:26:16 +0800 Subject: [PATCH 1/2] fix Signed-off-by: MrZ20 <2609716663@qq.com> --- pyproject.toml | 43 +- .../kv_connector/test_mooncake_connector.py | 882 ++++++++---------- .../test_mooncake_layerwise_connector.py | 480 ++++------ .../test_remote_decode_lifecycle.py | 47 +- .../test_remote_prefill_lifecycle.py | 83 +- tests/ut/kv_connector/utils.py | 8 +- .../model_loader/netloader/test_netloader.py | 65 +- .../netloader/test_netloader_elastic.py | 175 ++-- .../netloader/test_netloader_load.py | 16 +- .../netloader/test_netloader_utils.py | 17 +- 10 files changed, 809 insertions(+), 1007 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f44307cd644..1d8c9ea8a85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,16 +50,13 @@ plugins.md029.enabled = false # ol-prefix line-length = 120 # Folder to be modified exclude = [ - # Batch (1) - # "tests/e2e/__init__.py", - # "tests/e2e/310p/", - # "tests/e2e/conftest.py", - # "tests/e2e/doctests/", - # "tests/e2e/model_utils.py", - # "tests/e2e/models/", - # "tests/e2e/multicard/2-cards/", - - # Batch (2) + "tests/e2e/__init__.py", + "tests/e2e/310p/", + "tests/e2e/conftest.py", + "tests/e2e/doctests/", + "tests/e2e/model_utils.py", + "tests/e2e/models/", + "tests/e2e/multicard/2-cards/", "tests/e2e/multicard/4-cards/", "tests/e2e/nightly/multi_node/", "tests/e2e/singlecard/pooling/", @@ -67,11 +64,31 @@ exclude = [ "tests/e2e/utils.py", "tests/e2e/vllm_interface/", "tests/e2e/weekly/", - - # Batch (3) "tests/e2e/nightly/single_node/", - "tests/ut/", + "tests/ut/_310p/", + "tests/ut/attention/", + "tests/ut/base.py", + + "tests/ut/batch_invariant/", + "tests/ut/compilation/", + "tests/ut/conftest.py", + "tests/ut/core/", + "tests/ut/device_allocator/", + "tests/ut/distributed/", + "tests/ut/eplb/", + + "tests/ut/ops/", + "tests/ut/patch/", + "tests/ut/quantization/", + "tests/ut/sample/", + + "tests/ut/spec_decode/", + "tests/ut/test_ascend_config.py", + "tests/ut/test_envs.py", + "tests/ut/test_platform.py", + "tests/ut/test_utils.py", + "tests/ut/worker/", ] [tool.ruff.lint] diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index 45918c4a389..d2d95f85816 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -6,8 +6,8 @@ import time import types import unittest -from collections import defaultdict, deque -from typing import Any, Dict, OrderedDict, Optional +from collections import OrderedDict, defaultdict, deque +from typing import Any from unittest.mock import MagicMock, patch import msgspec @@ -23,36 +23,42 @@ _mock_tp_group = MagicMock(rank_in_group=0, world_size=4) _mock_pcp_group = MagicMock(rank_in_group=0, world_size=1) _mock_dcp_group = MagicMock(rank_in_group=0, world_size=1) +patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_pp_group", return_value=_mock_pp_group).start() +patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_tp_group", return_value=_mock_tp_group).start() patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_pp_group', - return_value=_mock_pp_group).start() + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_tensor_model_parallel_world_size", return_value=4 +).start() patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_tp_group', - return_value=_mock_tp_group).start() + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_tensor_model_parallel_rank", return_value=0 +).start() patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_tensor_model_parallel_world_size', - return_value=4).start() -patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_tensor_model_parallel_rank', - return_value=0).start() -patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_pcp_group', - return_value=_mock_pcp_group).start() -patch('vllm.distributed.parallel_state._DCP', _mock_dcp_group).start() + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_pcp_group", return_value=_mock_pcp_group +).start() +patch("vllm.distributed.parallel_state._DCP", _mock_dcp_group).start() from vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector import ( # noqa: E402 - KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker, - KVConnectorRole, MooncakeAgentMetadata, MooncakeConnector, - MooncakeConnectorMetadata, MooncakeConnectorScheduler, - MooncakeConnectorWorker, ReqMeta, ensure_zmq_recv, ensure_zmq_send, - group_concurrent_contiguous, string_to_int64_hash, zmq_ctx) + KVCacheRecvingThread, + KVCacheSendingThread, + KVCacheTaskTracker, + KVConnectorRole, + MooncakeAgentMetadata, + MooncakeConnector, + MooncakeConnectorMetadata, + MooncakeConnectorScheduler, + MooncakeConnectorWorker, + ReqMeta, + ensure_zmq_recv, + ensure_zmq_send, + group_concurrent_contiguous, + string_to_int64_hash, + zmq_ctx, +) GET_META_MSG = b"get_meta_msg" DONE_RECVING_MSG = b"done_recving_msg" class TestKVCacheTaskTrackerInit(unittest.TestCase): - def test_init_basic_properties(self): tracker = KVCacheTaskTracker() self.assertIsInstance(tracker.done_task_lock, type(threading.Lock())) @@ -61,7 +67,6 @@ def test_init_basic_properties(self): class TestGetAndClearFinishedSingleRequests(unittest.TestCase): - def setUp(self): self.tracker = KVCacheTaskTracker() self.tracker.finished_requests = set() @@ -84,45 +89,40 @@ def test_multiple_requests(self): self.assertSetEqual(result, {"req_1", "req_2", "req_3"}) self.assertEqual(len(self.tracker.finished_requests), 0) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger") def test_concurrent_access(self, mock_logger): from concurrent.futures import ThreadPoolExecutor + self.tracker.finished_requests = {"req_1", "req_2"} with ThreadPoolExecutor(max_workers=3) as executor: - futures = [ - executor.submit(self.tracker.get_and_clear_finished_requests) - for _ in range(3) - ] + futures = [executor.submit(self.tracker.get_and_clear_finished_requests) for _ in range(3)] results = [f.result() for f in futures] self.assertEqual(sum(1 for r in results if r), 1) self.assertEqual(len(self.tracker.finished_requests), 0) class TestKVCacheSendingThreadInit(unittest.TestCase): - def setUp(self): - kv_caches: Dict[str, Any] = {} + kv_caches: dict[str, Any] = {} self.common_args = { - 'tp_rank': 1, - 'prefill_tp_size': 4, - 'local_engine_id': 'engine_1', - 'side_channel_host': 'localhost', - 'side_channel_port': 5555, - 'metadata': MagicMock(), - 'vllm_config': MockVllmConfig(), - 'ready_event': threading.Event(), - 'kv_caches': kv_caches, - 'pcp_rank': 0 + "tp_rank": 1, + "prefill_tp_size": 4, + "local_engine_id": "engine_1", + "side_channel_host": "localhost", + "side_channel_port": 5555, + "metadata": MagicMock(), + "vllm_config": MockVllmConfig(), + "ready_event": threading.Event(), + "kv_caches": kv_caches, + "pcp_rank": 0, } self.threads = [] def tearDown(self): for thread in self.threads: - if hasattr(thread, 'task_tracker') and hasattr( - thread.task_tracker, 'socket'): + if hasattr(thread, "task_tracker") and hasattr(thread.task_tracker, "socket"): thread.task_tracker.socket.close() - if hasattr(thread, 'is_alive') and thread.is_alive(): + if hasattr(thread, "is_alive") and thread.is_alive(): thread.join(timeout=0.1) def test_thread_daemon_property(self): @@ -138,35 +138,32 @@ def test_thread_name_format(self): def test_ready_event_reference(self): custom_event = threading.Event() args = self.common_args.copy() - args['ready_event'] = custom_event + args["ready_event"] = custom_event thread = KVCacheSendingThread(**args) self.threads.append(thread) self.assertIs(thread.ready_event, custom_event) class TestGetAndClearFinishedRequests(unittest.TestCase): - def setUp(self): - kv_caches: Dict[str, Any] = {} + kv_caches: dict[str, Any] = {} self.common_args = { - 'tp_rank': 1, - 'prefill_tp_size': 4, - 'local_engine_id': 'engine_1', - 'side_channel_host': 'localhost', - 'vllm_config': MockVllmConfig(), - 'side_channel_port': 5555, - 'metadata': { - "test": "metadata" - }, - 'ready_event': threading.Event(), - 'kv_caches': kv_caches, - 'pcp_rank': 0 + "tp_rank": 1, + "prefill_tp_size": 4, + "local_engine_id": "engine_1", + "side_channel_host": "localhost", + "vllm_config": MockVllmConfig(), + "side_channel_port": 5555, + "metadata": {"test": "metadata"}, + "ready_event": threading.Event(), + "kv_caches": kv_caches, + "pcp_rank": 0, } self.thread = KVCacheSendingThread(**self.common_args) - @patch.object(KVCacheTaskTracker, 'get_and_clear_finished_requests') + @patch.object(KVCacheTaskTracker, "get_and_clear_finished_requests") def test_get_and_clear_finished_requests(self, mock_get_clear): - expected_requests = {'req1', 'req2'} + expected_requests = {"req1", "req2"} mock_get_clear.return_value = expected_requests result = self.thread.get_and_clear_finished_requests() mock_get_clear.assert_called_once() @@ -174,7 +171,6 @@ def test_get_and_clear_finished_requests(self, mock_get_clear): class TestKVCacheSendingThread(unittest.TestCase): - def test_run_handles_get_meta_and_done_recv_msgs(self): ready_event = threading.Event() metadata = MooncakeAgentMetadata( @@ -186,25 +182,26 @@ def test_run_handles_get_meta_and_done_recv_msgs(self): vllm_config = MockVllmConfig() host = "127.0.0.1" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) + s.bind(("", 0)) base_port = s.getsockname()[1] - thread = KVCacheSendingThread(tp_rank=0, - prefill_tp_size=1, - local_engine_id="engine1", - side_channel_host=host, - side_channel_port=base_port, - metadata=metadata, - vllm_config=vllm_config, - ready_event=ready_event, - kv_caches={}, - pcp_rank=0) + thread = KVCacheSendingThread( + tp_rank=0, + prefill_tp_size=1, + local_engine_id="engine1", + side_channel_host=host, + side_channel_port=base_port, + metadata=metadata, + vllm_config=vllm_config, + ready_event=ready_event, + kv_caches={}, + pcp_rank=0, + ) thread.start() - actual_port = base_port + (thread.pp_rank * thread.tp_size + - thread.tp_rank + - thread.pcp_rank * thread.prefill_tp_size) - self.assertTrue(ready_event.wait(timeout=3), - "Server thread startup timeout") + actual_port = base_port + ( + thread.pp_rank * thread.tp_size + thread.tp_rank + thread.pcp_rank * thread.prefill_tp_size + ) + self.assertTrue(ready_event.wait(timeout=3), "Server thread startup timeout") context = zmq.Context() # type: ignore sock = context.socket(zmq.DEALER) # type: ignore @@ -212,7 +209,7 @@ def test_run_handles_get_meta_and_done_recv_msgs(self): encoder = msgspec.msgpack.Encoder() decoder = msgspec.msgpack.Decoder(type=MooncakeAgentMetadata) - sock.send_multipart([b"", encoder.encode((GET_META_MSG, ))]) + sock.send_multipart([b"", encoder.encode((GET_META_MSG,))]) frames = sock.recv_multipart() self.assertEqual(frames[0], b"") meta = decoder.decode(frames[1]) @@ -221,8 +218,7 @@ def test_run_handles_get_meta_and_done_recv_msgs(self): self.assertEqual(meta.num_blocks, 2) req_id = "request_42" - sock.send_multipart( - [b"", encoder.encode((DONE_RECVING_MSG, req_id, 0))]) + sock.send_multipart([b"", encoder.encode((DONE_RECVING_MSG, req_id, 0))]) frames = sock.recv_multipart() self.assertEqual(frames[0], b"") self.assertEqual(frames[1], b"ACK") @@ -233,12 +229,11 @@ def test_run_handles_get_meta_and_done_recv_msgs(self): class TestKVCacheRecvingThreadBasic(unittest.TestCase): - def setUp(self): self.engine = MagicMock() self.ready_event = threading.Event() self.vllm_config = MockVllmConfig() - self.kv_caches: Dict[str, Any] = {} + self.kv_caches: dict[str, Any] = {} self.thread = KVCacheRecvingThread( tp_rank=0, tp_size=4, @@ -252,7 +247,8 @@ def setUp(self): ready_event=self.ready_event, vllm_config=self.vllm_config, kv_caches=self.kv_caches, - prefill_pp_layer_partition=None) + prefill_pp_layer_partition=None, + ) def test_add_request(self): test_req = { @@ -264,7 +260,7 @@ def test_add_request(self): "remote_handshake_port": 6666, "offset": 0, "tp_num_need_pulls": 2, - "all_task_done": False + "all_task_done": False, } self.thread.add_request( request_id=test_req["request_id"], @@ -275,12 +271,13 @@ def test_add_request(self): remote_handshake_port=test_req["remote_handshake_port"], offset=test_req["offset"], tp_num_need_pulls=test_req["tp_num_need_pulls"], - all_task_done=test_req["all_task_done"]) + all_task_done=test_req["all_task_done"], + ) queued = self.thread.request_queue.get_nowait() self.assertEqual(queued["request_id"], "req1") self.assertEqual(queued["remote_host"], "localhost") - @patch.object(KVCacheTaskTracker, 'get_and_clear_finished_requests') + @patch.object(KVCacheTaskTracker, "get_and_clear_finished_requests") def test_get_finished_requests(self, mock_tracker): mock_tracker.return_value = {"req1", "req2"} result = self.thread.get_and_clear_finished_requests() @@ -288,12 +285,11 @@ def test_get_finished_requests(self, mock_tracker): class TestSocketManagement(unittest.TestCase): - def setUp(self): self.engine = MagicMock() self.ready_event = threading.Event() self.vllm_config = MockVllmConfig() - self.kv_caches: Dict[str, Any] = {} + self.kv_caches: dict[str, Any] = {} self.thread = KVCacheRecvingThread( tp_rank=0, tp_size=4, @@ -307,16 +303,13 @@ def setUp(self): ready_event=self.ready_event, vllm_config=self.vllm_config, kv_caches=self.kv_caches, - prefill_pp_layer_partition=None) + prefill_pp_layer_partition=None, + ) self.thread.remote_sockets = defaultdict(deque) self.thread.remote_poller = MagicMock() - @patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.zmq.Context' - ) - @patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.make_zmq_socket' - ) + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.zmq.Context") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.make_zmq_socket") def test_get_remote_socket(self, mock_make_socket, mock_context): mock_sock = MagicMock() mock_make_socket.return_value = mock_sock @@ -328,11 +321,10 @@ def test_get_remote_socket(self, mock_make_socket, mock_context): self.assertEqual(sock, mock_sock) mock_make_socket.assert_called_once() args, kwargs = mock_make_socket.call_args - self.assertEqual(kwargs.get('path'), 'tcp://test_host:12345') - self.assertEqual(kwargs.get('socket_type'), zmq.REQ) # type: ignore - self.assertFalse(kwargs.get('bind', True)) - self.thread.remote_poller.register.assert_called_with( - mock_sock, zmq.POLLIN) # type: ignore + self.assertEqual(kwargs.get("path"), "tcp://test_host:12345") + self.assertEqual(kwargs.get("socket_type"), zmq.REQ) # type: ignore + self.assertFalse(kwargs.get("bind", True)) + self.thread.remote_poller.register.assert_called_with(mock_sock, zmq.POLLIN) # type: ignore def test_return_socket_to_pool(self): mock_sock = MagicMock() @@ -348,15 +340,12 @@ def test_return_socket_to_pool(self): class TestCoreFunctionality(unittest.TestCase): - def setUp(self): self.engine = MagicMock() self.ready_event = threading.Event() self.mock_queue = MagicMock() self.vllm_config = MockVllmConfig() - self.kv_caches: Dict[str, Any] = { - "layer_0": (MagicMock(), MagicMock()) - } + self.kv_caches: dict[str, Any] = {"layer_0": (MagicMock(), MagicMock())} self.thread = KVCacheRecvingThread( tp_rank=0, tp_size=4, @@ -370,7 +359,8 @@ def setUp(self): ready_event=self.ready_event, vllm_config=self.vllm_config, kv_caches=self.kv_caches, - prefill_pp_layer_partition=None) + prefill_pp_layer_partition=None, + ) self.thread.request_queue = self.mock_queue self.test_req = { "request_id": "req1", @@ -382,17 +372,15 @@ def setUp(self): "remote_transfer_port": 7777, "offset": 0, "tp_num_need_pulls": 2, - "remote_port_send_num": { - 6666: 1 - }, - "all_task_done": False + "remote_port_send_num": {6666: 1}, + "all_task_done": False, } self.thread.task_tracker = MagicMock() self.engine.batch_transfer_sync_read.return_value = 0 self.thread.remote_te_port = {"remote_engine": {6666: 7777}} - @patch.object(KVCacheRecvingThread, '_transfer_kv_cache') - @patch.object(KVCacheRecvingThread, '_send_done_recv_signal') + @patch.object(KVCacheRecvingThread, "_transfer_kv_cache") + @patch.object(KVCacheRecvingThread, "_send_done_recv_signal") def test_handle_request(self, mock_send, mock_transfer): mock_transfer.return_value = None mock_send.return_value = None @@ -403,19 +391,14 @@ def test_handle_request(self, mock_send, mock_transfer): mock_send.assert_called_once_with("req1", "localhost", 6666, {6666: 1}) if not self.thread.task_tracker.update_done_task_count.called: self.thread.task_tracker.update_done_task_count("req1") - self.thread.task_tracker.update_done_task_count.assert_called_once_with( - "req1") + self.thread.task_tracker.update_done_task_count.assert_called_once_with("req1") self.mock_queue.task_done.assert_called_once() - @patch.object(KVCacheRecvingThread, '_get_remote_metadata') + @patch.object(KVCacheRecvingThread, "_get_remote_metadata") def test_transfer_kv_cache(self, mock_get_meta): - with patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config' - ) as mock_config: + with patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config") as mock_config: mock_config.return_value.enable_kv_nz = False - self.thread.kv_caches_base_addr["remote_engine"] = { - 6666: [0x3000, 0x4000] - } + self.thread.kv_caches_base_addr["remote_engine"] = {6666: [0x3000, 0x4000]} self.thread._transfer_kv_cache(self.test_req) self.engine.batch_transfer_sync_read.assert_called_once() call_args, call_kwargs = self.engine.batch_transfer_sync_read.call_args @@ -429,21 +412,18 @@ def test_transfer_kv_cache(self, mock_get_meta): def test_transfer_kv_cache_failure(self): self.engine.batch_transfer_sync_read.return_value = -1 - self.thread.kv_caches_base_addr["remote_engine"] = { - 6666: [0x3000, 0x4000] - } + self.thread.kv_caches_base_addr["remote_engine"] = {6666: [0x3000, 0x4000]} with self.assertRaises(RuntimeError): self.thread._transfer_kv_cache(self.test_req) class TestMetadataHandling(unittest.TestCase): - def setUp(self): self.engine = MagicMock() self.ready_event = threading.Event() self.vllm_config = MockVllmConfig() - self.kv_caches: Dict[str, Any] = {} + self.kv_caches: dict[str, Any] = {} self.thread = KVCacheRecvingThread( tp_rank=0, tp_size=4, @@ -457,49 +437,42 @@ def setUp(self): ready_event=self.ready_event, vllm_config=self.vllm_config, kv_caches=self.kv_caches, - prefill_pp_layer_partition=None) + prefill_pp_layer_partition=None, + ) self.test_metadata = MooncakeAgentMetadata( - engine_id="remote_engine", - te_rpc_port=9090, - kv_caches_base_addr=[0x3000, 0x4000], - num_blocks=2) + engine_id="remote_engine", te_rpc_port=9090, kv_caches_base_addr=[0x3000, 0x4000], num_blocks=2 + ) - @patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.ensure_zmq_send' - ) - @patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.ensure_zmq_recv' - ) + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.ensure_zmq_send") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.ensure_zmq_recv") def test_get_remote_metadata_success(self, mock_recv, mock_send): mock_recv.return_value = msgspec.msgpack.encode(self.test_metadata) - with patch.object(self.thread, '_get_remote_socket') as mock_get_socket, \ - patch.object(self.thread, '_return_remote_socket') as mock_return_socket: + with ( + patch.object(self.thread, "_get_remote_socket") as mock_get_socket, + patch.object(self.thread, "_return_remote_socket") as mock_return_socket, + ): mock_socket = MagicMock() mock_get_socket.return_value = mock_socket self.thread._get_remote_metadata("host1", 5555) mock_get_socket.assert_called_once_with("host1", 5555) - mock_return_socket.assert_called_once_with(mock_socket, "host1", - 5555) - mock_send.assert_called_once_with( - mock_socket, self.thread.encoder.encode((GET_META_MSG, ""))) - mock_recv.assert_called_once_with(mock_socket, - self.thread.remote_poller) - self.assertEqual( - self.thread.kv_caches_base_addr["remote_engine"][5555], - [0x3000, 0x4000]) + mock_return_socket.assert_called_once_with(mock_socket, "host1", 5555) + mock_send.assert_called_once_with(mock_socket, self.thread.encoder.encode((GET_META_MSG, ""))) + mock_recv.assert_called_once_with(mock_socket, self.thread.remote_poller) + self.assertEqual(self.thread.kv_caches_base_addr["remote_engine"][5555], [0x3000, 0x4000]) + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.ensure_zmq_send") @patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.ensure_zmq_send' + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.ensure_zmq_recv", + side_effect=Exception("Network error"), ) - @patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.ensure_zmq_recv', - side_effect=Exception("Network error")) def test_get_remote_metadata_failure(self, mock_recv, mock_send): - with patch.object(self.thread, '_get_remote_socket') as mock_get_socket, \ - patch.object(self.thread, '_return_remote_socket') as mock_return_socket: + with ( + patch.object(self.thread, "_get_remote_socket") as mock_get_socket, + patch.object(self.thread, "_return_remote_socket") as mock_return_socket, + ): mock_socket = MagicMock() mock_get_socket.return_value = mock_socket @@ -511,12 +484,11 @@ def test_get_remote_metadata_failure(self, mock_recv, mock_send): class TestMainThreadLoop(unittest.TestCase): - def setUp(self): self.engine = MagicMock() self.ready_event = threading.Event() self.vllm_config = MockVllmConfig() - self.kv_caches: Dict[str, Any] = {} + self.kv_caches: dict[str, Any] = {} self.thread = KVCacheRecvingThread( tp_rank=0, tp_size=4, @@ -530,10 +502,11 @@ def setUp(self): ready_event=self.ready_event, vllm_config=self.vllm_config, kv_caches=self.kv_caches, - prefill_pp_layer_partition=None) + prefill_pp_layer_partition=None, + ) self.thread.request_queue = queue.Queue() - @patch.object(KVCacheRecvingThread, '_handle_request') + @patch.object(KVCacheRecvingThread, "_handle_request") def test_run_loop_normal(self, mock_handle): test_request = { "request_id": "req1", @@ -545,7 +518,7 @@ def test_run_loop_normal(self, mock_handle): "remote_transfer_port": 7777, "offset": 0, "tp_num_need_pulls": 2, - "all_task_done": False + "all_task_done": False, } self.thread.request_queue.put(test_request) @@ -561,7 +534,6 @@ def test_run_loop_normal(self, mock_handle): class MockVllmConfig: - def __init__(self): self.model_config = MagicMock() self.parallel_config = MagicMock() @@ -574,34 +546,20 @@ def __init__(self): self.parallel_config.data_parallel_size_local = 1 self.parallel_config.pipeline_parallel_size = 1 self.parallel_config.data_parallel_rank_local = 0 - self.model_config.get_num_layers_by_block_type = MagicMock( - return_value=32) + self.model_config.get_num_layers_by_block_type = MagicMock(return_value=32) self.cache_config.block_size = 16 self.kv_transfer_config.kv_port = 5000 - self.kv_transfer_config.kv_role = 'kv_producer' + self.kv_transfer_config.kv_role = "kv_producer" self.kv_transfer_config.get_from_extra_config = MagicMock() self.kv_transfer_config.get_from_extra_config.side_effect = lambda k, d: { - "prefill": { - "tp_size": 2, - "dp_size": 1, - "pp_size": 1 - }, - "decode": { - "tp_size": 2, - "dp_size": 1, - "pp_size": 1 - } + "prefill": {"tp_size": 2, "dp_size": 1, "pp_size": 1}, + "decode": {"tp_size": 2, "dp_size": 1, "pp_size": 1}, }.get(k, d) self.additional_config = {} class MockRequest: - - def __init__(self, - request_id, - prompt_token_ids=None, - kv_transfer_params=None, - status=None): + def __init__(self, request_id, prompt_token_ids=None, kv_transfer_params=None, status=None): self.request_id = request_id self.prompt_token_ids = prompt_token_ids or [1, 2, 3, 4] self.kv_transfer_params = kv_transfer_params or {} @@ -610,7 +568,6 @@ def __init__(self, class TestKVCacheTaskTracker(unittest.TestCase): - def setUp(self): self.tracker = KVCacheTaskTracker() @@ -659,9 +616,12 @@ def test_retrieve_expired_requests(self): self.tracker.add_delayed_request("req_1", current_time - 600) self.tracker.add_delayed_request("req_2", current_time) result = self.tracker._retrieve_expired_requests() - self.assertEqual(result, { - "req_1", - }) + self.assertEqual( + result, + { + "req_1", + }, + ) result_delay = self.tracker.delayed_free_requests self.assertEqual(len(result_delay), 1) self.assertIn("req_2", result_delay) @@ -676,24 +636,25 @@ def test_duplicate_task_update(self): class TestMooncakeConnectorMetadata(unittest.TestCase): - def test_add_new_req(self): meta = MooncakeConnectorMetadata() self.assertEqual(len(meta.requests), 0) self.assertEqual(len(meta.requests_to_send), 0) - meta.add_new_req(request_id="req1", - local_block_ids=[1, 2, 3], - num_external_tokens=48, - kv_transfer_params={ - "remote_block_ids": [4, 5, 6], - "remote_engine_id": "remote_engine", - "remote_host": "localhost", - "remote_port": 5000, - "remote_pcp_size": 1, - "remote_dcp_size": 1, - "remote_ptp_size": 2 - }) + meta.add_new_req( + request_id="req1", + local_block_ids=[1, 2, 3], + num_external_tokens=48, + kv_transfer_params={ + "remote_block_ids": [4, 5, 6], + "remote_engine_id": "remote_engine", + "remote_host": "localhost", + "remote_port": 5000, + "remote_pcp_size": 1, + "remote_dcp_size": 1, + "remote_ptp_size": 2, + }, + ) self.assertEqual(len(meta.requests), 1) req_meta = meta.requests["req1"] @@ -707,15 +668,15 @@ def test_add_new_req(self): class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase): - def setUp(self): config = MockVllmConfig() self.p1 = patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config', - new=MagicMock()) + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config", new=MagicMock() + ) self.p2 = patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config', - new=MagicMock(return_value=MagicMock())) + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config", + new=MagicMock(return_value=MagicMock()), + ) self.p1.start() self.p2.start() self.addCleanup(self.p1.stop) @@ -724,14 +685,12 @@ def setUp(self): def test_get_num_new_matched_tokens(self): request = MockRequest("req1") - tokens, async_flag = self.scheduler.get_num_new_matched_tokens( - request, 0) + tokens, async_flag = self.scheduler.get_num_new_matched_tokens(request, 0) self.assertEqual(tokens, 0) self.assertFalse(async_flag) request.kv_transfer_params = {"do_remote_prefill": True} - tokens, async_flag = self.scheduler.get_num_new_matched_tokens( - request, 0) + tokens, async_flag = self.scheduler.get_num_new_matched_tokens(request, 0) self.assertEqual(tokens, 4) self.assertTrue(async_flag) @@ -746,7 +705,7 @@ def test_build_connector_meta(self): "remote_host": "localhost", "remote_port": 5000, "remote_pcp_size": 1, - "remote_dcp_size": 1 + "remote_dcp_size": 1, } meta = self.scheduler.build_connector_meta(MagicMock()) @@ -758,7 +717,6 @@ def test_build_connector_meta(self): class TestHelperFunctions(unittest.TestCase): - def test_group_concurrent_contiguous(self): src: list[int] = [1, 2, 3, 5, 6] dst: list[int] = [10, 11, 12, 14, 15] @@ -788,14 +746,15 @@ def test_string_to_int64_hash(self): class TestMooncakeConnectorForScheduler(unittest.TestCase): - def test_scheduler_role(self): config = MockVllmConfig() - with patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config' - ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config', - return_value=MagicMock()): + with ( + patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config"), + patch( + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config", + return_value=MagicMock(), + ), + ): connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER) self.assertIsNotNone(connector.connector_scheduler) self.assertIsNone(connector.connector_worker) @@ -803,11 +762,13 @@ def test_scheduler_role(self): @patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens") def test_scheduler_methods(self, mock_method): config = MockVllmConfig() - with patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config' - ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config', - return_value=MagicMock()): + with ( + patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config"), + patch( + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config", + return_value=MagicMock(), + ), + ): connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER) request = MockRequest("req1") connector.get_num_new_matched_tokens(request, 0) @@ -815,7 +776,6 @@ def test_scheduler_methods(self, mock_method): class MockKVCacheBlocks: - def get_unhashed_block_ids(self): return [4, 5, 6] @@ -829,44 +789,46 @@ class MockForwardContext: class TestMooncakeConnector(unittest.TestCase): - def setUp(self): self.config = MockVllmConfig() os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1" def test_scheduler_initialization(self): - with patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config' - ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config', - return_value=MagicMock()): - connector = MooncakeConnector(self.config, - KVConnectorRole.SCHEDULER) + with ( + patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config"), + patch( + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config", + return_value=MagicMock(), + ), + ): + connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER) self.assertIsNotNone(connector.connector_scheduler) self.assertIsNone(connector.connector_worker) @patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens") def test_get_num_new_matched_tokens(self, mock_method): - with patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config' - ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config', - return_value=MagicMock()): - connector = MooncakeConnector(self.config, - KVConnectorRole.SCHEDULER) + with ( + patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config"), + patch( + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config", + return_value=MagicMock(), + ), + ): + connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER) request = MockRequest("req1") connector.get_num_new_matched_tokens(request, 0) mock_method.assert_called_once_with(request, 0) @patch.object(MooncakeConnectorScheduler, "update_state_after_alloc") def test_update_state_after_alloc(self, mock_method): - with patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config' - ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config', - return_value=MagicMock()): - connector = MooncakeConnector(self.config, - KVConnectorRole.SCHEDULER) + with ( + patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config"), + patch( + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config", + return_value=MagicMock(), + ), + ): + connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER) request = MockRequest("req1") blocks = MockKVCacheBlocks() connector.update_state_after_alloc(request, blocks, 3) @@ -874,55 +836,54 @@ def test_update_state_after_alloc(self, mock_method): @patch.object(MooncakeConnectorScheduler, "build_connector_meta") def test_build_connector_meta(self, mock_method): - with patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config' - ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config', - return_value=MagicMock()): - connector = MooncakeConnector(self.config, - KVConnectorRole.SCHEDULER) + with ( + patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config"), + patch( + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config", + return_value=MagicMock(), + ), + ): + connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER) scheduler_output = MockSchedulerOutput() connector.build_connector_meta(scheduler_output) mock_method.assert_called_once_with(scheduler_output) @patch.object(MooncakeConnectorScheduler, "request_finished") def test_request_finished(self, mock_method): - with patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config' - ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config', - return_value=MagicMock()): - connector = MooncakeConnector(self.config, - KVConnectorRole.SCHEDULER) + with ( + patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config"), + patch( + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config", + return_value=MagicMock(), + ), + ): + connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER) request = MockRequest("req1") connector.request_finished(request, [1, 2, 3]) mock_method.assert_called_once_with(request, [1, 2, 3]) class TestMooncakeConnectorScheduler(unittest.TestCase): - def setUp(self): self.config = MockVllmConfig() - with patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config' - ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config', - return_value=MagicMock()): - self.scheduler = MooncakeConnectorScheduler( - self.config, "test_engine") + with ( + patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config"), + patch( + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config", + return_value=MagicMock(), + ), + ): + self.scheduler = MooncakeConnectorScheduler(self.config, "test_engine") def test_get_num_new_matched_tokens_no_remote_prefill(self): request = MockRequest("req1") - tokens, async_flag = self.scheduler.get_num_new_matched_tokens( - request, 0) + tokens, async_flag = self.scheduler.get_num_new_matched_tokens(request, 0) self.assertEqual(tokens, 0) self.assertFalse(async_flag) def test_get_num_new_matched_tokens_with_remote_prefill(self): - request = MockRequest("req1", - kv_transfer_params={"do_remote_prefill": True}) - tokens, async_flag = self.scheduler.get_num_new_matched_tokens( - request, 0) + request = MockRequest("req1", kv_transfer_params={"do_remote_prefill": True}) + tokens, async_flag = self.scheduler.get_num_new_matched_tokens(request, 0) self.assertEqual(tokens, 4) self.assertTrue(async_flag) @@ -933,14 +894,16 @@ def test_update_state_after_alloc_no_remote_prefill(self): self.assertEqual(len(self.scheduler._reqs_need_recv), 0) def test_update_state_after_alloc_with_remote_prefill(self): - request = MockRequest("req1", - kv_transfer_params={ - "do_remote_prefill": True, - "remote_block_ids": [1, 2, 3], - "remote_engine_id": "remote", - "remote_host": "localhost", - "remote_port": 5000 - }) + request = MockRequest( + "req1", + kv_transfer_params={ + "do_remote_prefill": True, + "remote_block_ids": [1, 2, 3], + "remote_engine_id": "remote", + "remote_host": "localhost", + "remote_port": 5000, + }, + ) blocks = MockKVCacheBlocks() self.scheduler.update_state_after_alloc(request, blocks, 3) self.assertEqual(len(self.scheduler._reqs_need_recv), 1) @@ -949,14 +912,12 @@ def test_update_state_after_alloc_with_remote_prefill(self): def test_request_finished_no_remote_decode(self): request = MockRequest("req1") - delay_free, params = self.scheduler.request_finished( - request, [1, 2, 3]) + delay_free, params = self.scheduler.request_finished(request, [1, 2, 3]) self.assertFalse(delay_free) self.assertIsNone(params) class TestUtils(unittest.TestCase): - def test_string_to_int64_hash(self): h1 = string_to_int64_hash("hello") h2 = string_to_int64_hash("hello") @@ -978,38 +939,33 @@ def test_group_empty(self): self.assertEqual(dst_g, []) def test_zmq_ctx_invalid_type(self): - with self.assertRaises(ValueError): - with zmq_ctx("INVALID", "tcp://127.0.0.1:5555"): - pass + with self.assertRaises(ValueError), zmq_ctx("INVALID", "tcp://127.0.0.1:5555"): + pass - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.make_zmq_socket" - ) + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.make_zmq_socket") def test_zmq_ctx_ok(self, mock_make_socket): mock_socket = MagicMock() mock_make_socket.return_value = mock_socket with zmq_ctx(zmq.REQ, "tcp://localhost:1234") as s: # type: ignore self.assertEqual(s, mock_socket) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger") def test_ensure_zmq_send_success(self, mock_logger): mock_socket = MagicMock() ensure_zmq_send(mock_socket, b"hello") mock_socket.send.assert_called_once_with(b"hello") - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger") def test_ensure_zmq_send_retry_and_fail(self, mock_logger): mock_socket = MagicMock() mock_socket.send.side_effect = zmq.ZMQError( # type: ignore - "send failed") + "send failed" + ) with self.assertRaises(RuntimeError): ensure_zmq_send(mock_socket, b"hello", max_retries=2) self.assertEqual(mock_socket.send.call_count, 2) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger") def test_ensure_zmq_recv_success(self, mock_logger): mock_socket = MagicMock() mock_socket.recv.return_value = b"response" @@ -1020,33 +976,26 @@ def test_ensure_zmq_recv_success(self, mock_logger): data = ensure_zmq_recv(mock_socket, mock_poller) self.assertEqual(data, b"response") - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger") def test_ensure_zmq_recv_timeout_and_fail(self, mock_logger): mock_socket = MagicMock() mock_poller = MagicMock() mock_poller.poll.return_value = [] with self.assertRaises(RuntimeError): - ensure_zmq_recv(mock_socket, - mock_poller, - timeout=0.01, - max_retries=2) + ensure_zmq_recv(mock_socket, mock_poller, timeout=0.01, max_retries=2) class MockMooncakeAgentMetadata: - def __init__(self, **kwargs): pass class MockMooncakeConnectorMetadata: - def __init__(self): self.requests = {} class MockKVCacheSendingThread(threading.Thread): - def __init__(self, *args, **kwargs): super().__init__() self.daemon = True @@ -1060,7 +1009,6 @@ def start(self): class MockKVCacheRecvingThread(threading.Thread): - def __init__(self, *args, **kwargs): super().__init__() self.daemon = True @@ -1075,7 +1023,6 @@ def start(self): class MockTensor: - def __init__(self, *args, **kwargs): self.size = MagicMock(return_value=(10, 16, 8, 16)) self.element_size = MagicMock(return_value=4) @@ -1087,7 +1034,6 @@ def __init__(self, *args, **kwargs): class MockTransferEngine: - def initialize(self, *args, **kwargs): return 0 @@ -1116,7 +1062,6 @@ def mock_string_to_int64_hash(s): class TestMooncakeConnectorWorker(unittest.TestCase): - def setUp(self): self.mock_transfer_engine = MagicMock() self.mock_transfer_engine.get_rpc_port.return_value = 9090 @@ -1124,46 +1069,40 @@ def setUp(self): self.mock_transfer_engine.register_memory.return_value = 0 self.patches = [ - patch('torch.Tensor.size', return_value=(10, 16, 8, 16)), - patch('torch.Tensor.element_size', return_value=4), - patch('torch.Tensor.data_ptr', return_value=0x1000), - patch('math.prod', return_value=128), - patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_tensor_model_parallel_rank', - mock_get_tensor_model_parallel_rank), + patch("torch.Tensor.size", return_value=(10, 16, 8, 16)), + patch("torch.Tensor.element_size", return_value=4), + patch("torch.Tensor.data_ptr", return_value=0x1000), + patch("math.prod", return_value=128), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_tp_group', - mock_get_tp_group), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_tensor_model_parallel_rank", + mock_get_tensor_model_parallel_rank, + ), + patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_tp_group", mock_get_tp_group), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_pp_group', - return_value=_mock_pp_group), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_pp_group", + return_value=_mock_pp_group, + ), + patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ip", mock_get_ip), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ip', - mock_get_ip), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.string_to_int64_hash", + mock_string_to_int64_hash, + ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.string_to_int64_hash', - mock_string_to_int64_hash), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.global_te.get_transfer_engine", + return_value=self.mock_transfer_engine, + ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.global_te.get_transfer_engine', - return_value=self.mock_transfer_engine), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.global_te.register_buffer", + return_value=None, + ), + patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.KVCacheSendingThread", MagicMock()), + patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.KVCacheRecvingThread", MagicMock()), + patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger", MagicMock()), + patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.threading.Event", MagicMock()), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.global_te.register_buffer', - return_value=None), - patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.KVCacheSendingThread', - MagicMock()), - patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.KVCacheRecvingThread', - MagicMock()), - patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger', - MagicMock()), - patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.threading.Event', - MagicMock()), - patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config', - return_value=MagicMock()), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config", + return_value=MagicMock(), + ), ] for p in self.patches: @@ -1185,7 +1124,7 @@ def test_register_kv_caches_producer(self): self.assertIsNone(worker.kv_recv_thread) def test_register_kv_caches_consumer(self): - self.vllm_config.kv_transfer_config.kv_role = 'kv_consumer' + self.vllm_config.kv_transfer_config.kv_role = "kv_consumer" worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id) worker.register_kv_caches(self.kv_caches) self.assertIsNone(worker.kv_send_thread) @@ -1210,113 +1149,102 @@ def test_device_id_selection_with_physical_devices(self): self.assertIsNotNone(worker.engine) def test_get_remote_tp_rank(self): - - def get_tp_rank(prefill_tp_size: int, - prefill_pp_size: int, - decode_tp_size: int, - num_kv_heads: int, - tp_num_need_pulls: int, - is_deepseek_mla: bool, - remote_ptp_size: Optional[int] = None): - with patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config', - return_value=MagicMock()), \ - patch.object(self.vllm_config.kv_transfer_config, 'get_from_extra_config', - side_effect=lambda k, d=None: { - "prefill": {"tp_size": prefill_tp_size, "dp_size": 1, "pp_size": prefill_pp_size}, - "decode": {"tp_size": decode_tp_size, "dp_size": 1, "pp_size": 1} - }.get(k, d)): + def get_tp_rank( + prefill_tp_size: int, + prefill_pp_size: int, + decode_tp_size: int, + num_kv_heads: int, + tp_num_need_pulls: int, + is_deepseek_mla: bool, + remote_ptp_size: int | None = None, + ): + with ( + patch( + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config", + return_value=MagicMock(), + ), + patch.object( + self.vllm_config.kv_transfer_config, + "get_from_extra_config", + side_effect=lambda k, d=None: { + "prefill": {"tp_size": prefill_tp_size, "dp_size": 1, "pp_size": prefill_pp_size}, + "decode": {"tp_size": decode_tp_size, "dp_size": 1, "pp_size": 1}, + }.get(k, d), + ), + ): self.vllm_config.model_config.hf_text_config.num_key_value_heads = num_kv_heads self.vllm_config.model_config.is_deepseek_mla = is_deepseek_mla - worker = MooncakeConnectorWorker(self.vllm_config, - self.engine_id) + worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id) worker.tp_num_need_pulls = tp_num_need_pulls worker.use_sparse = 0 - return worker._get_remote_ranks_for_req( - 'test', remote_ptp_size) + return worker._get_remote_ranks_for_req("test", remote_ptp_size) self.assertIn( - get_tp_rank(16, 1, 1, 4, 4, False)[0], - [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]) - self.assertIn( - get_tp_rank(8, 1, 1, 4, 4, False)[0], [[0, 2, 4, 6], [1, 3, 5, 7]]) + get_tp_rank(16, 1, 1, 4, 4, False)[0], [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]] + ) + self.assertIn(get_tp_rank(8, 1, 1, 4, 4, False)[0], [[0, 2, 4, 6], [1, 3, 5, 7]]) self.assertIn(get_tp_rank(4, 1, 1, 4, 4, False)[0], [[0, 1, 2, 3]]) - self.assertIn(get_tp_rank(16, 1, 4, 4, 1, False), - [[[0], [4], [8], [12]], [[1], [5], [9], [13]], - [[2], [6], [10], [14]], [[3], [7], [11], [15]]]) - self.assertIn(get_tp_rank(8, 1, 4, 4, 1, False), - [[[0], [2], [4], [6]], [[1], [3], [5], [7]]]) - self.assertIn(get_tp_rank(4, 2, 2, 4, 2, False), - [[[0, 1, 4, 5], [2, 3, 6, 7]]]) - self.assertIn(get_tp_rank(4, 1, 4, 4, 1, False), - [[[0], [1], [2], [3]]]) self.assertIn( - get_tp_rank(8, 2, 1, 4, 4, False)[0], - [[0, 2, 4, 6, 8, 10, 12, 14], [1, 3, 5, 7, 9, 11, 13, 15]]) - self.assertIn(get_tp_rank(4, 2, 2, 4, 2, False), - [[[0, 1, 4, 5], [2, 3, 6, 7]]]) + get_tp_rank(16, 1, 4, 4, 1, False), + [[[0], [4], [8], [12]], [[1], [5], [9], [13]], [[2], [6], [10], [14]], [[3], [7], [11], [15]]], + ) + self.assertIn(get_tp_rank(8, 1, 4, 4, 1, False), [[[0], [2], [4], [6]], [[1], [3], [5], [7]]]) + self.assertIn(get_tp_rank(4, 2, 2, 4, 2, False), [[[0, 1, 4, 5], [2, 3, 6, 7]]]) + self.assertIn(get_tp_rank(4, 1, 4, 4, 1, False), [[[0], [1], [2], [3]]]) + self.assertIn(get_tp_rank(8, 2, 1, 4, 4, False)[0], [[0, 2, 4, 6, 8, 10, 12, 14], [1, 3, 5, 7, 9, 11, 13, 15]]) + self.assertIn(get_tp_rank(4, 2, 2, 4, 2, False), [[[0, 1, 4, 5], [2, 3, 6, 7]]]) self.assertIn(get_tp_rank(2, 2, 1, 4, 2, False), [[[0, 1, 2, 3]]]) + self.assertIn(get_tp_rank(4, 4, 2, 8, 2, False), [[[0, 1, 4, 5, 8, 9, 12, 13], [2, 3, 6, 7, 10, 11, 14, 15]]]) + self.assertIn(get_tp_rank(4, 2, 1, 4, 4, False)[0], [[0, 1, 2, 3, 4, 5, 6, 7]]) + self.assertIn(get_tp_rank(4, 4, 1, 4, 4, False)[0], [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]) self.assertIn( - get_tp_rank(4, 4, 2, 8, 2, False), - [[[0, 1, 4, 5, 8, 9, 12, 13], [2, 3, 6, 7, 10, 11, 14, 15]]]) - self.assertIn( - get_tp_rank(4, 2, 1, 4, 4, False)[0], [[0, 1, 2, 3, 4, 5, 6, 7]]) - self.assertIn( - get_tp_rank(4, 4, 1, 4, 4, False)[0], - [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]) - self.assertIn(get_tp_rank(8, 2, 4, 4, 1, False), - [[[0, 8], [2, 10], [4, 12], [6, 14]], - [[1, 9], [3, 11], [5, 13], [7, 15]]]) - self.assertIn(get_tp_rank(4, 2, 4, 4, 4, False), - [[[0, 4], [1, 5], [2, 6], [3, 7]]]) + get_tp_rank(8, 2, 4, 4, 1, False), + [[[0, 8], [2, 10], [4, 12], [6, 14]], [[1, 9], [3, 11], [5, 13], [7, 15]]], + ) + self.assertIn(get_tp_rank(4, 2, 4, 4, 4, False), [[[0, 4], [1, 5], [2, 6], [3, 7]]]) self.assertIn( - get_tp_rank(4, 4, 4, 4, 1, False), - [[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]]) + get_tp_rank(4, 4, 4, 4, 1, False), [[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]] + ) self.assertIn( - get_tp_rank(16, 1, 1, 1, 1, - True)[0], [[0], [1], [2], [3], [4], [5], [6], [7], [8], - [9], [10], [11], [12], [13], [14], [15]]) + get_tp_rank(16, 1, 1, 1, 1, True)[0], + [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15]], + ) self.assertIn(get_tp_rank(4, 1, 4, 1, 1, True), [[[0], [1], [2], [3]]]) self.assertIn( - get_tp_rank(8, 2, 1, 1, 1, True)[0], - [[0, 8], [2, 10], [4, 12], [6, 14], [1, 9], [3, 11], [5, 13], - [7, 15]]) + get_tp_rank(8, 2, 1, 1, 1, True)[0], [[0, 8], [2, 10], [4, 12], [6, 14], [1, 9], [3, 11], [5, 13], [7, 15]] + ) self.assertIn( - get_tp_rank(4, 4, 1, 1, 1, True)[0], - [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]) + get_tp_rank(4, 4, 1, 1, 1, True)[0], [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]] + ) self.assertIn( - get_tp_rank(8, 2, 4, 1, 1, True)[0], - [[0, 8], [2, 10], [4, 12], [6, 14], [1, 9], [3, 11], [5, 13], - [7, 15]]) + get_tp_rank(8, 2, 4, 1, 1, True)[0], [[0, 8], [2, 10], [4, 12], [6, 14], [1, 9], [3, 11], [5, 13], [7, 15]] + ) self.assertIn( - get_tp_rank(4, 4, 4, 1, 1, True), - [[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]]) + get_tp_rank(4, 4, 4, 1, 1, True), [[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]] + ) # check remote ptp size - self.assertListEqual(get_tp_rank(16, 1, 2, 4, 2, False, 8), - get_tp_rank(8, 1, 2, 4, 2, False)) - self.assertListEqual(get_tp_rank(8, 1, 2, 4, 2, False, 4), - get_tp_rank(4, 1, 2, 4, 2, False)) - self.assertListEqual(get_tp_rank(4, 1, 2, 4, 1, False, 2), - get_tp_rank(2, 1, 2, 4, 1, False)) + self.assertListEqual(get_tp_rank(16, 1, 2, 4, 2, False, 8), get_tp_rank(8, 1, 2, 4, 2, False)) + self.assertListEqual(get_tp_rank(8, 1, 2, 4, 2, False, 4), get_tp_rank(4, 1, 2, 4, 2, False)) + self.assertListEqual(get_tp_rank(4, 1, 2, 4, 1, False, 2), get_tp_rank(2, 1, 2, 4, 1, False)) def test_get_kv_split_metadata(self): - - def get_kv_split_metadata(use_mla, - pcp_size, - dcp_size, - tp_size, - tp_rank, - pcp_rank, - _prefill_tp_size, - remote_pcp_size, - remote_dcp_size, - remote_port, - remote_block_ids, - local_block_ids, - remote_engine_id, - remote_ptp_size=None): - + def get_kv_split_metadata( + use_mla, + pcp_size, + dcp_size, + tp_size, + tp_rank, + pcp_rank, + _prefill_tp_size, + remote_pcp_size, + remote_dcp_size, + remote_port, + remote_block_ids, + local_block_ids, + remote_engine_id, + remote_ptp_size=None, + ): worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id) worker.use_mla = use_mla @@ -1338,78 +1266,88 @@ def get_kv_split_metadata(use_mla, meta.remote_port = remote_port meta.remote_block_ids = remote_block_ids meta.local_block_ids = local_block_ids - meta.num_external_tokens = pcp_size * dcp_size * len( - local_block_ids) * worker.block_size + meta.num_external_tokens = pcp_size * dcp_size * len(local_block_ids) * worker.block_size meta.num_prompt_blocks = pcp_size * dcp_size * len(local_block_ids) meta.remote_engine_id = remote_engine_id remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = worker._get_kv_split_metadata( - '0', meta) + "0", meta + ) return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list self.assertEqual( - get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1], - [1], 0), - ([[30001], [30002], [30003], [30004], [30005], [30006], [30007], - [30000]], [[], [], [], [], [], [], [], [1]], [[], [], [], [], [], - [], [], [1]])) + get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1], [1], 0), + ( + [[30001], [30002], [30003], [30004], [30005], [30006], [30007], [30000]], + [[], [], [], [], [], [], [], [1]], + [[], [], [], [], [], [], [], [1]], + ), + ) self.assertEqual( - get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1], - [1], 0), - ([[30001], [30002], [30003], [30004], [30005], [30006], [30007], - [30008], [30009], [30010], [30011], [30012], [30013], [30014], - [30015], [30000] - ], [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], - [1]], [[], [], [], [], [], [], [], [], [], [], [], [], [], - [], [], [1]])) + get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1], [1], 0), + ( + [ + [30001], + [30002], + [30003], + [30004], + [30005], + [30006], + [30007], + [30008], + [30009], + [30010], + [30011], + [30012], + [30013], + [30014], + [30015], + [30000], + ], + [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [1]], + [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [1]], + ), + ) self.assertEqual( - get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1], - [1], 0), - ([[30001], [30008], [30009], [30000]], [[], [], [], [1] - ], [[], [], [], [1]])) + get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1], [1], 0), + ([[30001], [30008], [30009], [30000]], [[], [], [], [1]], [[], [], [], [1]]), + ) self.assertEqual( - get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1], - [1], 0), - ([[30001], [30008], [30009], [30000]], [[], [], [], [1] - ], [[], [], [], [1]])) + get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1], [1], 0), + ([[30001], [30008], [30009], [30000]], [[], [], [], [1]], [[], [], [], [1]]), + ) self.assertEqual( - get_kv_split_metadata(True, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1], - [1], 0), - ([[30000], [30008]], [[1], []], [[1], []])) + get_kv_split_metadata(True, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1], [1], 0), + ([[30000], [30008]], [[1], []], [[1], []]), + ) self.assertEqual( - get_kv_split_metadata(False, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1], - [1], 0), - ([[30000], [30008]], [[1], []], [[1], []])) + get_kv_split_metadata(False, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1], [1], 0), + ([[30000], [30008]], [[1], []], [[1], []]), + ) self.assertEqual( - get_kv_split_metadata(True, 1, 2, 8, 0, 0, 8, 2, 2, 30000, - [1, 2, 3], [1, 2, 3, 4, 5], 0), - ([[30000], [30008]], [[1, 2, 3], [4, 5]], [[1, 2, 3], [1, 2]])) + get_kv_split_metadata(True, 1, 2, 8, 0, 0, 8, 2, 2, 30000, [1, 2, 3], [1, 2, 3, 4, 5], 0), + ([[30000], [30008]], [[1, 2, 3], [4, 5]], [[1, 2, 3], [1, 2]]), + ) # check remote ptp size self.assertEqual( - get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1], - [1], 0, 16), - get_kv_split_metadata(True, 1, 1, 8, 1, 0, 16, 1, 8, 30000, [1], - [1], 0) + get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1], [1], 0, 16), + get_kv_split_metadata(True, 1, 1, 8, 1, 0, 16, 1, 8, 30000, [1], [1], 0), ) self.assertEqual( - get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1], - [1], 0, 16), - get_kv_split_metadata(False, 1, 1, 8, 1, 0, 16, 1, 8, 30000, [1], - [1], 0) + get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1], [1], 0, 16), + get_kv_split_metadata(False, 1, 1, 8, 1, 0, 16, 1, 8, 30000, [1], [1], 0), ) self.assertEqual( - get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1], - [1], 0, 16), - get_kv_split_metadata(False, 1, 1, 8, 1, 0, 16, 2, 8, 30000, [1], - [1], 0) + get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1], [1], 0, 16), + get_kv_split_metadata(False, 1, 1, 8, 1, 0, 16, 2, 8, 30000, [1], [1], 0), ) def test_get_tp_num_need_pulls(self): @@ -1427,5 +1365,5 @@ def test_get_tp_num_need_pulls(self): self.assertEqual(tp_num_need_pulls, 1) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py index bdb1b02f465..bb69ca683d8 100644 --- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py +++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py @@ -16,19 +16,29 @@ sys.modules["mooncake.engine"] = fake_engine from vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector import ( # noqa: E402 - KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole, - MooncakeAgentMetadata, MooncakeLayerwiseConnector, - MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler, - MooncakeLayerwiseConnectorWorker, ReqMeta, SendReqInfo, SendTask, - ensure_zmq_recv, ensure_zmq_send, group_concurrent_contiguous, - string_to_int64_hash, zmq_ctx) + KVCacheRecvingLayerThread, + KVCacheSendingLayerThread, + KVConnectorRole, + MooncakeAgentMetadata, + MooncakeLayerwiseConnector, + MooncakeLayerwiseConnectorMetadata, + MooncakeLayerwiseConnectorScheduler, + MooncakeLayerwiseConnectorWorker, + ReqMeta, + SendReqInfo, + SendTask, + ensure_zmq_recv, + ensure_zmq_send, + group_concurrent_contiguous, + string_to_int64_hash, + zmq_ctx, +) GET_META_MSG = b"get_meta_msg" DONE_SENDING_MSG = b"done_sending_msg" class TestKVCacheSendingLayerThread(unittest.TestCase): - def setUp(self): self.engine = MagicMock() self.engine.register_memory.return_value = 0 @@ -36,9 +46,7 @@ def setUp(self): fake_stream = MagicMock(name="FakeStream") fake_stream.synchronize = MagicMock() - self.first_kv_cache = torch.zeros((2, 2, 2, 8), - dtype=torch.float32, - device="cpu") + self.first_kv_cache = torch.zeros((2, 2, 2, 8), dtype=torch.float32, device="cpu") self.ready_event = threading.Event() @@ -57,20 +65,20 @@ def setUp(self): tp_rank=0, pd_head_ratio=1, num_head_replica=1, - kv_cache_base_addr=[1000, 2000, 3000, 4000, 5000, - 6000], # 2 * total_layers + kv_cache_base_addr=[1000, 2000, 3000, 4000, 5000, 6000], # 2 * total_layers use_mla=True, block_len=[1024, 2048], k_buffer=self.fake_k_buffer, v_buffer=self.fake_v_buffer, resharding_stream=fake_resharding_stream, - callback_func=MagicMock()) + callback_func=MagicMock(), + ) self.req_meta_base = ReqMeta( local_block_ids=[5, 8], - remote_tp_size = 8, - remote_pcp_size = 1, - remote_dcp_size = 1, + remote_tp_size=8, + remote_pcp_size=1, + remote_dcp_size=1, token_ids=[1, 2, 3], remote_block_ids=[10, 20], remote_engine_id="remote_engine", @@ -79,28 +87,27 @@ def setUp(self): remote_te_rpc_port=6000, remote_kv_caches_base_addr=[4000, 8000, 14000, 18000], metaserver="http://dummy", - chunk_finish=False) + chunk_finish=False, + ) @patch( "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.npu_stream_switch", - side_effect=lambda *_args, **_kwargs: contextlib.nullcontext()) + side_effect=lambda *_args, **_kwargs: contextlib.nullcontext(), + ) @patch( "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.torch.Tensor.data_ptr", autospec=True, - return_value=0x200000) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.align_memory", - side_effect=lambda x, _align: x) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.torch.npu.synchronize" + return_value=0x200000, ) @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.group_concurrent_contiguous" + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.align_memory", + side_effect=lambda x, _align: x, ) + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.torch.npu.synchronize") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.group_concurrent_contiguous") def test_transfer_pd_gt1_uses_buffers_and_calls_engine( - self, mock_group, _mock_sync, _mock_align, _mock_dataptr, - mock_stream_switch): - + self, mock_group, _mock_sync, _mock_align, _mock_dataptr, mock_stream_switch + ): fake_resharding_stream = MagicMock() thread = KVCacheSendingLayerThread( @@ -116,7 +123,8 @@ def test_transfer_pd_gt1_uses_buffers_and_calls_engine( k_buffer=self.fake_k_buffer, v_buffer=self.fake_v_buffer, resharding_stream=fake_resharding_stream, - callback_func=MagicMock()) + callback_func=MagicMock(), + ) req_meta = self.req_meta_base req_meta.remote_kv_caches_base_addr = [4000, 8000] @@ -137,8 +145,7 @@ def test_transfer_pd_gt1_uses_buffers_and_calls_engine( thread._transfer_kv_cache(send_task) self.engine.batch_transfer_sync_write.assert_called_once() - session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[ - 0] + session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[0] self.assertEqual(session_id, "127.0.0.1:6000") self.assertEqual(len(src_list), 4) @@ -172,20 +179,16 @@ def test_transfer_skips_when_no_local_blocks(self): @patch( "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.group_concurrent_contiguous", - side_effect=group_concurrent_contiguous) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.torch.npu.synchronize" + side_effect=group_concurrent_contiguous, ) + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.torch.npu.synchronize") def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group): - req_meta = self.req_meta_base req_meta.chunk_finish = True req_meta.local_block_ids = [5, 6] req_meta.remote_block_ids = [10, 11] - req_meta.remote_kv_caches_base_addr = [ - 7000, 8000, 9000, 10000, 11000, 12000 - ] + req_meta.remote_kv_caches_base_addr = [7000, 8000, 9000, 10000, 11000, 12000] key = torch.zeros((1, 8), dtype=torch.float32) value = torch.zeros((1, 8), dtype=torch.float32) @@ -204,21 +207,20 @@ def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group): class TestKVCacheRecvingLayerThread(unittest.TestCase): - def setUp(self): - - self.meta = MooncakeAgentMetadata(te_rpc_port=6000, - kv_caches_base_addr=[0x1, 0x2]) + self.meta = MooncakeAgentMetadata(te_rpc_port=6000, kv_caches_base_addr=[0x1, 0x2]) self.ready_event = threading.Event() def test_get_and_clear_finished_requests(self): - th = KVCacheRecvingLayerThread(tp_rank=0, - side_channel_port=5555, - tp_size=2, - pd_head_ratio=1, - local_engine_id="engineA", - metadata=self.meta, - ready_event=self.ready_event) + th = KVCacheRecvingLayerThread( + tp_rank=0, + side_channel_port=5555, + tp_size=2, + pd_head_ratio=1, + local_engine_id="engineA", + metadata=self.meta, + ready_event=self.ready_event, + ) with th.lock: th.done_requests.update({"r1", "r2"}) @@ -229,13 +231,15 @@ def test_get_and_clear_finished_requests(self): self.assertEqual(got2, set()) def test_update_task_aggregates_by_pd_head_ratio(self): - th = KVCacheRecvingLayerThread(tp_rank=0, - side_channel_port=5555, - tp_size=2, - pd_head_ratio=2, - local_engine_id="engineA", - metadata=self.meta, - ready_event=self.ready_event) + th = KVCacheRecvingLayerThread( + tp_rank=0, + side_channel_port=5555, + tp_size=2, + pd_head_ratio=2, + local_engine_id="engineA", + metadata=self.meta, + ready_event=self.ready_event, + ) with th.lock: th.task_tracker["reqX"] = 0 @@ -251,40 +255,28 @@ def test_update_task_aggregates_by_pd_head_ratio(self): self.assertNotIn("reqX", th.task_tracker) self.assertIn("reqX", th.done_requests) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger" - ) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip", - return_value="127.0.0.1") - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.make_zmq_socket" - ) + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip", return_value="127.0.0.1") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.make_zmq_socket") @patch( "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.make_zmq_path", - side_effect=lambda proto, host, port: f"{proto}://{host}:{port}") - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Decoder" - ) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Encoder" - ) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.zmq_ctx" + side_effect=lambda proto, host, port: f"{proto}://{host}:{port}", ) + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Decoder") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Encoder") + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.zmq_ctx") def test_run_loop_handles_meta_done_invalid_unexpected_and_ack( - self, mock_zmq_ctx, mock_Encoder, mock_Decoder, _mock_make_path, - _mock_make_sock, _mock_get_ip, mock_logger): - + self, mock_zmq_ctx, mock_Encoder, mock_Decoder, _mock_make_path, _mock_make_sock, _mock_get_ip, mock_logger + ): enc_inst = MagicMock() enc_inst.encode.return_value = b"ENCODED_META" mock_Encoder.return_value = enc_inst dec_inst = MagicMock() dec_inst.decode.side_effect = [ - (GET_META_MSG, ), + (GET_META_MSG,), (DONE_SENDING_MSG, "reqA", 1), - (b"weird_msg", ), + (b"weird_msg",), ] mock_Decoder.return_value = dec_inst @@ -303,13 +295,15 @@ def test_run_loop_handles_meta_done_invalid_unexpected_and_ack( mock_zmq_ctx.return_value = cm ready_event = threading.Event() - th = KVCacheRecvingLayerThread(tp_rank=1, - side_channel_port=6000, - tp_size=2, - pd_head_ratio=1, - local_engine_id="engineZ", - metadata=self.meta, - ready_event=ready_event) + th = KVCacheRecvingLayerThread( + tp_rank=1, + side_channel_port=6000, + tp_size=2, + pd_head_ratio=1, + local_engine_id="engineZ", + metadata=self.meta, + ready_event=ready_event, + ) with th.lock: th.task_tracker["reqA"] = 0 @@ -344,9 +338,8 @@ def test_run_loop_handles_meta_done_invalid_unexpected_and_ack( @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Encoder") @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.zmq_ctx") def test_run_loop_pd_head_ratio_gt1_requires_multiple_done( - self, mock_zmq_ctx, mock_Encoder, mock_Decoder, _mock_get_ip, - _mock_logger): - + self, mock_zmq_ctx, mock_Encoder, mock_Decoder, _mock_get_ip, _mock_logger + ): enc_inst = MagicMock() enc_inst.encode.return_value = b"ENC" mock_Encoder.return_value = enc_inst @@ -362,7 +355,7 @@ def test_run_loop_pd_head_ratio_gt1_requires_multiple_done( sock.recv_multipart.side_effect = [ [b"ID", b"PAY1"], [b"ID", b"PAY2"], - SystemExit, # 退出循环 + SystemExit, # 退出循环 ] cm = MagicMock() cm.__enter__.return_value = sock @@ -375,7 +368,7 @@ def test_run_loop_pd_head_ratio_gt1_requires_multiple_done( pd_head_ratio=2, local_engine_id="engineY", metadata=self.meta, - ready_event=self.ready_event + ready_event=self.ready_event, ) with th.lock: th.task_tracker["reqB"] = 0 @@ -387,7 +380,6 @@ def test_run_loop_pd_head_ratio_gt1_requires_multiple_done( class MockVllmConfig: - def __init__(self): self.model_config = MagicMock() self.parallel_config = MagicMock() @@ -408,24 +400,13 @@ def __init__(self): self.kv_transfer_config.is_kv_consumer = False self.kv_transfer_config.get_from_extra_config = MagicMock() self.kv_transfer_config.get_from_extra_config.side_effect = lambda k, d: { - "prefill": { - "tp_size": 2, - "dp_size": 1 - }, - "decode": { - "tp_size": 2, - "dp_size": 1 - }, + "prefill": {"tp_size": 2, "dp_size": 1}, + "decode": {"tp_size": 2, "dp_size": 1}, }.get(k, d) class MockRequest: - - def __init__(self, - request_id, - prompt_token_ids=None, - kv_transfer_params=None, - status=None): + def __init__(self, request_id, prompt_token_ids=None, kv_transfer_params=None, status=None): self.request_id = request_id self.prompt_token_ids = prompt_token_ids or [1, 2, 3, 4] self.kv_transfer_params = kv_transfer_params or {} @@ -437,19 +418,20 @@ def __init__(self, class TestMooncakeLayerwiseConnectorMetadata(unittest.TestCase): - def test_add_new_req(self): meta = MooncakeLayerwiseConnectorMetadata() self.assertEqual(len(meta.requests), 0) - meta.add_new_req(request_id="req1", - local_block_ids=[1, 2, 3], - kv_transfer_params={ - "remote_block_ids": [4, 5, 6], - "remote_engine_id": "remote_engine", - "remote_host": "localhost", - "remote_port": 5000 - }) + meta.add_new_req( + request_id="req1", + local_block_ids=[1, 2, 3], + kv_transfer_params={ + "remote_block_ids": [4, 5, 6], + "remote_engine_id": "remote_engine", + "remote_host": "localhost", + "remote_port": 5000, + }, + ) self.assertEqual(len(meta.requests), 1) req_meta = meta.requests["req1"] @@ -462,22 +444,18 @@ def test_add_new_req(self): class TestMooncakeLayerwiseConnectorSchedulerMatchedTokens(unittest.TestCase): - def setUp(self): config = MockVllmConfig() - self.scheduler = MooncakeLayerwiseConnectorScheduler( - config, "test_engine") + self.scheduler = MooncakeLayerwiseConnectorScheduler(config, "test_engine") def test_get_num_new_matched_tokens(self): request = MockRequest("req1") - tokens, async_flag = self.scheduler.get_num_new_matched_tokens( - request, 0) + tokens, async_flag = self.scheduler.get_num_new_matched_tokens(request, 0) self.assertEqual(tokens, 0) self.assertFalse(async_flag) request.kv_transfer_params = {"do_remote_prefill": True} - tokens, async_flag = self.scheduler.get_num_new_matched_tokens( - request, 0) + tokens, async_flag = self.scheduler.get_num_new_matched_tokens(request, 0) self.assertEqual(tokens, 4) self.assertTrue(async_flag) @@ -502,29 +480,27 @@ def test_build_connector_meta(self): class _MockBlocks: - def __init__(self, unhashed, block_ids_tuple=None): self._unhashed = list(unhashed) - self._block_ids_tuple = block_ids_tuple if block_ids_tuple is not None else ( - [1, 2], ) + self._block_ids_tuple = block_ids_tuple if block_ids_tuple is not None else ([1, 2],) def get_unhashed_block_ids(self): return list(self._unhashed) def get_block_ids(self): - return self._block_ids_tuple class _MockSchedulerOutput: - - def __init__(self, - cached_req_ids=None, - cached_new_block_ids=None, - cached_num_computed=None, - new_reqs=None, - num_sched=None, - scheduled_spec_decode_tokens=None): + def __init__( + self, + cached_req_ids=None, + cached_new_block_ids=None, + cached_num_computed=None, + new_reqs=None, + num_sched=None, + scheduled_spec_decode_tokens=None, + ): self.scheduled_cached_reqs = SimpleNamespace( req_ids=cached_req_ids or [], new_block_ids=cached_new_block_ids or [], @@ -536,32 +512,24 @@ def __init__(self, class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): - def setUp(self): self.config = MockVllmConfig() - self.scheduler = MooncakeLayerwiseConnectorScheduler( - self.config, "test_engine") + self.scheduler = MooncakeLayerwiseConnectorScheduler(self.config, "test_engine") def test_get_num_new_matched_tokens_with_prefill_block_aligned(self): - - req = MockRequest("req_prefill", - prompt_token_ids=list(range(32)), - kv_transfer_params={"do_remote_prefill": True}) - tokens, async_flag = self.scheduler.get_num_new_matched_tokens( - req, num_computed_tokens=16) + req = MockRequest( + "req_prefill", prompt_token_ids=list(range(32)), kv_transfer_params={"do_remote_prefill": True} + ) + tokens, async_flag = self.scheduler.get_num_new_matched_tokens(req, num_computed_tokens=16) self.assertEqual(tokens, 16) self.assertTrue(async_flag) def test_update_state_after_alloc_prefill_records_and_resets_flag(self): - req = MockRequest("req_u1", - prompt_token_ids=list(range(24)), - kv_transfer_params={"do_remote_prefill": True}) + req = MockRequest("req_u1", prompt_token_ids=list(range(24)), kv_transfer_params={"do_remote_prefill": True}) req.num_computed_tokens = 0 - blocks = _MockBlocks(unhashed=[4, 5, 6], block_ids_tuple=([4, 5, 6], )) + blocks = _MockBlocks(unhashed=[4, 5, 6], block_ids_tuple=([4, 5, 6],)) - self.scheduler.update_state_after_alloc(req, - blocks, - num_external_tokens=8) + self.scheduler.update_state_after_alloc(req, blocks, num_external_tokens=8) self.assertIn("req_u1", self.scheduler._reqs_need_recv) record = self.scheduler._reqs_need_recv["req_u1"] self.assertIs(record[0], req) @@ -570,16 +538,13 @@ def test_update_state_after_alloc_prefill_records_and_resets_flag(self): self.assertFalse(req.kv_transfer_params.get("do_remote_prefill", True)) def test_update_state_after_alloc_decode_records_send_layerwise(self): - req = MockRequest("req_u2", - prompt_token_ids=list(range(10)), - kv_transfer_params={ - "do_remote_decode": True, - "remote_block_ids": [] - }) - blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], )) - self.scheduler.update_state_after_alloc(req, - blocks, - num_external_tokens=0) + req = MockRequest( + "req_u2", + prompt_token_ids=list(range(10)), + kv_transfer_params={"do_remote_decode": True, "remote_block_ids": []}, + ) + blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9],)) + self.scheduler.update_state_after_alloc(req, blocks, num_external_tokens=0) self.assertIn("req_u2", self.scheduler._reqs_need_send_layerwise) info = self.scheduler._reqs_need_send_layerwise["req_u2"] self.assertEqual(info.local_block_ids, [7, 8, 9]) @@ -587,15 +552,17 @@ def test_update_state_after_alloc_decode_records_send_layerwise(self): def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self): self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True - req = MockRequest("req_b1", - kv_transfer_params={ - "remote_block_ids": [1, 2], - "remote_engine_id": "E", - "remote_host": "H", - "remote_port": 5555, - "remote_te_rpc_port": 6000, - "remote_kv_caches_base_addr": [10, 11], - }) + req = MockRequest( + "req_b1", + kv_transfer_params={ + "remote_block_ids": [1, 2], + "remote_engine_id": "E", + "remote_host": "H", + "remote_port": 5555, + "remote_te_rpc_port": 6000, + "remote_kv_caches_base_addr": [10, 11], + }, + ) self.scheduler._reqs_need_recv["req_b1"] = (req, [], [100, 101]) meta = self.scheduler.build_connector_meta(_MockSchedulerOutput()) self.assertIsInstance(meta, MooncakeLayerwiseConnectorMetadata) @@ -620,7 +587,7 @@ def test_build_connector_meta_accumulates_cached_blocks(self): out = _MockSchedulerOutput( cached_req_ids=["req_b2"], - cached_new_block_ids=[([3, 4], )], + cached_new_block_ids=[([3, 4],)], cached_num_computed=[4], new_reqs=[], num_sched={}, @@ -628,11 +595,8 @@ def test_build_connector_meta_accumulates_cached_blocks(self): meta = self.scheduler.build_connector_meta(out) self.assertEqual(len(meta.requests), 0) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.group_concurrent_contiguous" - ) - def test_build_connector_meta_emits_when_tokens_reach_total( - self, mock_group_concurrent_contiguous): + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.group_concurrent_contiguous") + def test_build_connector_meta_emits_when_tokens_reach_total(self, mock_group_concurrent_contiguous): req_meta = MagicMock(spec=ReqMeta) req_meta.local_block_ids = [1, 2, 3] req_meta.remote_block_ids = [4, 5] @@ -654,18 +618,18 @@ def test_build_connector_meta_emits_when_tokens_reach_total( send_req_info.update_computed_tokens = MagicMock() send_req_info.update_transferred_tokens = MagicMock() send_req_info.unpack = MagicMock( - return_value=( - send_req_info.local_block_ids, - send_req_info.local_transferred_tokens, - send_req_info.local_computed_tokens, - send_req_info.request - ) - ) + return_value=( + send_req_info.local_block_ids, + send_req_info.local_transferred_tokens, + send_req_info.local_computed_tokens, + send_req_info.request, + ) + ) self.scheduler._reqs_need_send_layerwise["req_b3"] = send_req_info out = _MockSchedulerOutput( cached_req_ids=["req_b3"], - cached_new_block_ids=[([50], )], + cached_new_block_ids=[([50],)], cached_num_computed=[8], new_reqs=[MagicMock(req_id="other", num_computed_tokens=0)], num_sched={"req_b3": 4}, @@ -675,14 +639,12 @@ def test_build_connector_meta_emits_when_tokens_reach_total( self.assertIn("req_b3", meta.requests) def test_request_finished_returns_false_none(self): - ok, params = self.scheduler.request_finished(MockRequest("req_fin"), - [1, 2]) + ok, params = self.scheduler.request_finished(MockRequest("req_fin"), [1, 2]) self.assertFalse(ok) self.assertIsNone(params) class TestHelperFunctions(unittest.TestCase): - def test_group_concurrent_contiguous(self): src: list[int] = [1, 2, 3, 5, 6] dst: list[int] = [10, 11, 12, 14, 15] @@ -709,43 +671,35 @@ def test_string_to_int64_hash(self): self.assertNotEqual(hash1, hash3) def test_zmq_ctx_invalid_type(self): - with self.assertRaises(ValueError): - with zmq_ctx("INVALID", "tcp://127.0.0.1:5555"): - pass + with self.assertRaises(ValueError), zmq_ctx("INVALID", "tcp://127.0.0.1:5555"): + pass - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.make_zmq_socket" - ) + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.make_zmq_socket") def test_zmq_ctx_ok(self, mock_make_socket): mock_socket = MagicMock() mock_make_socket.return_value = mock_socket with zmq_ctx(zmq.REQ, "tcp://localhost:1234") as s: # type: ignore self.assertEqual(s, mock_socket) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger" - ) + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger") def test_ensure_zmq_send_success(self, _): mock_socket = MagicMock() path = "127.0.0.1:12345" ensure_zmq_send(mock_socket, b"hello", path) mock_socket.send.assert_called_once_with(b"hello") - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger" - ) + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger") def test_ensure_zmq_send_retry_and_fail(self, _): mock_socket = MagicMock() path = "127.0.0.1:12345" mock_socket.send.side_effect = zmq.ZMQError( # type: ignore - "send failed") + "send failed" + ) with self.assertRaises(RuntimeError): ensure_zmq_send(mock_socket, b"hello", path, max_retries=2) self.assertEqual(mock_socket.send.call_count, 2) - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger" - ) + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger") def test_ensure_zmq_recv_success(self, _): mock_socket = MagicMock() mock_socket.recv.return_value = b"response" @@ -757,44 +711,33 @@ def test_ensure_zmq_recv_success(self, _): data = ensure_zmq_recv(mock_socket, mock_poller, path) self.assertEqual(data, b"response") - @patch( - "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger" - ) + @patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger") def test_ensure_zmq_recv_timeout_and_fail(self, _): mock_socket = MagicMock() mock_poller = MagicMock() mock_poller.poll.return_value = [] path = "127.0.0.1:12345" with self.assertRaises(RuntimeError): - ensure_zmq_recv(mock_socket, - mock_poller, - path, - timeout=0.01, - max_retries=2) + ensure_zmq_recv(mock_socket, mock_poller, path, timeout=0.01, max_retries=2) class TestMooncakeLayerwiseConnectorForScheduler(unittest.TestCase): - def test_scheduler_role(self): config = MockVllmConfig() - connector = MooncakeLayerwiseConnector(config, - KVConnectorRole.SCHEDULER) + connector = MooncakeLayerwiseConnector(config, KVConnectorRole.SCHEDULER) self.assertIsNotNone(connector.connector_scheduler) self.assertIsNone(connector.connector_worker) - @patch.object(MooncakeLayerwiseConnectorScheduler, - "get_num_new_matched_tokens") + @patch.object(MooncakeLayerwiseConnectorScheduler, "get_num_new_matched_tokens") def test_scheduler_methods(self, mock_method): config = MockVllmConfig() - connector = MooncakeLayerwiseConnector(config, - KVConnectorRole.SCHEDULER) + connector = MooncakeLayerwiseConnector(config, KVConnectorRole.SCHEDULER) request = MockRequest("req1") connector.get_num_new_matched_tokens(request, 0) mock_method.assert_called_once_with(request, 0) class MockKVCacheBlocks: - def get_unhashed_block_ids(self): return [4, 5, 6] @@ -808,31 +751,25 @@ class MockForwardContext: class TestMooncakeLayerwiseConnector(unittest.TestCase): - def setUp(self): self.config = MockVllmConfig() os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1" def test_scheduler_initialization(self): - connector = MooncakeLayerwiseConnector(self.config, - KVConnectorRole.SCHEDULER) + connector = MooncakeLayerwiseConnector(self.config, KVConnectorRole.SCHEDULER) self.assertIsNotNone(connector.connector_scheduler) self.assertIsNone(connector.connector_worker) - @patch.object(MooncakeLayerwiseConnectorScheduler, - "get_num_new_matched_tokens") + @patch.object(MooncakeLayerwiseConnectorScheduler, "get_num_new_matched_tokens") def test_get_num_new_matched_tokens(self, mock_method): - connector = MooncakeLayerwiseConnector(self.config, - KVConnectorRole.SCHEDULER) + connector = MooncakeLayerwiseConnector(self.config, KVConnectorRole.SCHEDULER) request = MockRequest("req1") connector.get_num_new_matched_tokens(request, 0) mock_method.assert_called_once_with(request, 0) - @patch.object(MooncakeLayerwiseConnectorScheduler, - "update_state_after_alloc") + @patch.object(MooncakeLayerwiseConnectorScheduler, "update_state_after_alloc") def test_update_state_after_alloc(self, mock_method): - connector = MooncakeLayerwiseConnector(self.config, - KVConnectorRole.SCHEDULER) + connector = MooncakeLayerwiseConnector(self.config, KVConnectorRole.SCHEDULER) request = MockRequest("req1") blocks = MockKVCacheBlocks() connector.update_state_after_alloc(request, blocks, 3) @@ -840,23 +777,20 @@ def test_update_state_after_alloc(self, mock_method): @patch.object(MooncakeLayerwiseConnectorScheduler, "build_connector_meta") def test_build_connector_meta(self, mock_method): - connector = MooncakeLayerwiseConnector(self.config, - KVConnectorRole.SCHEDULER) + connector = MooncakeLayerwiseConnector(self.config, KVConnectorRole.SCHEDULER) scheduler_output = MockSchedulerOutput() connector.build_connector_meta(scheduler_output) mock_method.assert_called_once_with(scheduler_output) @patch.object(MooncakeLayerwiseConnectorScheduler, "request_finished") def test_request_finished(self, mock_method): - connector = MooncakeLayerwiseConnector(self.config, - KVConnectorRole.SCHEDULER) + connector = MooncakeLayerwiseConnector(self.config, KVConnectorRole.SCHEDULER) request = MockRequest("req1") connector.request_finished(request, [1, 2, 3]) mock_method.assert_called_once_with(request, [1, 2, 3]) class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase): - def setUp(self): self.mock_transfer_engine = MagicMock() self.mock_transfer_engine.get_rpc_port.return_value = 9090 @@ -864,46 +798,51 @@ def setUp(self): self.mock_transfer_engine.register_memory.return_value = 0 self.patches = [ - patch('torch.Tensor.size', return_value=(10, 16, 8, 16)), - patch('torch.Tensor.element_size', return_value=4), - patch('torch.Tensor.data_ptr', return_value=0x1000), - patch('math.prod', return_value=128), - patch('random.Random'), + patch("torch.Tensor.size", return_value=(10, 16, 8, 16)), + patch("torch.Tensor.element_size", return_value=4), + patch("torch.Tensor.data_ptr", return_value=0x1000), + patch("math.prod", return_value=128), + patch("random.Random"), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_tensor_model_parallel_rank', - return_value=0), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_tensor_model_parallel_rank", + return_value=0, + ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_tp_group', - return_value=None), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_tp_group", + return_value=None, + ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip', - return_value="127.0.0.1"), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip", + return_value="127.0.0.1", + ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.string_to_int64_hash', - side_effect=lambda s: hash(s)), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.string_to_int64_hash", + side_effect=lambda s: hash(s), + ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.global_te.get_transfer_engine', - return_value=self.mock_transfer_engine), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.global_te.get_transfer_engine", + return_value=self.mock_transfer_engine, + ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.global_te.register_buffer', - return_value=None), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.global_te.register_buffer", + return_value=None, + ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.KVCacheSendingLayerThread', - MagicMock()), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.KVCacheSendingLayerThread", + MagicMock(), + ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.KVCacheRecvingLayerThread', - MagicMock()), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.KVCacheRecvingLayerThread", + MagicMock(), + ), + patch("vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger", MagicMock()), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger', - MagicMock()), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.threading.Event", MagicMock() + ), patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.threading.Event', - MagicMock()), - patch( - 'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ascend_config', - return_value=SimpleNamespace(pd_tp_ratio=1, - num_head_replica=1, - pd_head_ratio=1)), + "vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ascend_config", + return_value=SimpleNamespace(pd_tp_ratio=1, num_head_replica=1, pd_head_ratio=1), + ), ] for p in self.patches: @@ -923,22 +862,18 @@ def tearDown(self): p.stop() # type: ignore def test_register_kv_caches_producer(self): - self.vllm_config.kv_transfer_config.is_kv_producer = True self.vllm_config.kv_transfer_config.is_kv_consumer = False - worker = MooncakeLayerwiseConnectorWorker(self.vllm_config, - self.engine_id) + worker = MooncakeLayerwiseConnectorWorker(self.vllm_config, self.engine_id) worker.register_kv_caches(self.kv_caches) self.assertEqual(len(worker.kv_caches), 1) self.assertIsNotNone(worker.kv_send_layer_thread) self.assertIsNone(worker.kv_recv_layer_thread) def test_register_kv_caches_consumer(self): - self.vllm_config.kv_transfer_config.is_kv_producer = False self.vllm_config.kv_transfer_config.is_kv_consumer = True - worker = MooncakeLayerwiseConnectorWorker(self.vllm_config, - self.engine_id) + worker = MooncakeLayerwiseConnectorWorker(self.vllm_config, self.engine_id) worker.register_kv_caches(self.kv_caches) self.assertIsNone(worker.kv_send_layer_thread) self.assertIsNotNone(worker.kv_recv_layer_thread) @@ -949,8 +884,7 @@ def test_register_kv_caches_mla_case(self): mla_cache2 = MagicMock() mla_cache2.size.return_value = (10, 16, 1, 8) mla_caches = {"layer1": (mla_cache1, mla_cache2)} - worker = MooncakeLayerwiseConnectorWorker(self.vllm_config, - self.engine_id) + worker = MooncakeLayerwiseConnectorWorker(self.vllm_config, self.engine_id) worker.register_kv_caches(mla_caches) self.assertTrue(worker.use_mla) - self.assertEqual(len(worker.block_len), 2) \ No newline at end of file + self.assertEqual(len(worker.block_len), 2) diff --git a/tests/ut/kv_connector/test_remote_decode_lifecycle.py b/tests/ut/kv_connector/test_remote_decode_lifecycle.py index bf44c0fdc85..75de02d7094 100644 --- a/tests/ut/kv_connector/test_remote_decode_lifecycle.py +++ b/tests/ut/kv_connector/test_remote_decode_lifecycle.py @@ -21,10 +21,13 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT from vllm.v1.request import FinishReason, RequestStatus -from tests.ut.kv_connector.utils import (assert_scheduler_empty, - create_model_runner_output, - create_request, create_scheduler, - create_vllm_config) +from tests.ut.kv_connector.utils import ( + assert_scheduler_empty, + create_model_runner_output, + create_request, + create_scheduler, + create_vllm_config, +) def test_basic_lifecycle(): @@ -38,10 +41,7 @@ def test_basic_lifecycle(): NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request = create_request(request_id=1, - max_tokens=1, - num_tokens=NUM_TOKENS, - do_remote_decode=True) + request = create_request(request_id=1, max_tokens=1, num_tokens=NUM_TOKENS, do_remote_decode=True) scheduler.add_request(request) request_id = request.request_id @@ -56,8 +56,7 @@ def test_basic_lifecycle(): model_runner_output = create_model_runner_output(reqs=[request]) # (1c): update_from_output() - engine_core_outputs = scheduler.update_from_output(scheduler_output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output(scheduler_output, model_runner_output) # Ensure the request is finished after 1 tokens. assert request.is_finished() @@ -72,8 +71,7 @@ def test_basic_lifecycle(): assert len(scheduler.waiting) == 0 # ... but blocks should not be freed. - blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks[request_id] + blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[request_id] for block in blocks: assert block.ref_cnt == 1 @@ -102,10 +100,9 @@ def test_basic_lifecycle(): # (3b): execute_model() model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) - from vllm.v1.worker.kv_connector_model_runner_mixin import \ - KVConnectorOutput # type: ignore # noqa - model_runner_output.kv_connector_output = KVConnectorOutput( - finished_sending=[request_id]) + from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput # type: ignore # noqa + + model_runner_output.kv_connector_output = KVConnectorOutput(finished_sending=[request_id]) # (3c): update_from_output() scheduler.update_from_output(scheduler_output, model_runner_output) @@ -129,8 +126,7 @@ def test_prefix_cache_lifecycle(): scheduler.add_request(request_remote_a) scheduler_output = scheduler.schedule() - model_runner_output = create_model_runner_output(reqs=[request_remote_a], - use_eos=True) + model_runner_output = create_model_runner_output(reqs=[request_remote_a], use_eos=True) scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.schedule() scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) @@ -142,9 +138,7 @@ def test_prefix_cache_lifecycle(): NUM_EXTERNAL_FULL_BLOCKS -= 1 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request_remote = create_request(request_id=1, - num_tokens=NUM_TOKENS, - do_remote_decode=True) + request_remote = create_request(request_id=1, num_tokens=NUM_TOKENS, do_remote_decode=True) scheduler.add_request(request_remote) scheduler_output = scheduler.schedule() @@ -152,18 +146,15 @@ def test_prefix_cache_lifecycle(): eco = scheduler.update_from_output(scheduler_output, model_runner_output) kv_transfer_params = eco[0].outputs[0].kv_transfer_params # Ensure we send all block ids, even if there is a cache hit. - assert (len( - kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + - 1)) + assert len(kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + 1) # STEP (2): Ensure it is freed. scheduler_output = scheduler.schedule() scheduler.schedule() model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) - from vllm.v1.worker.kv_connector_model_runner_mixin import \ - KVConnectorOutput # noqa - model_runner_output.kv_connector_output = KVConnectorOutput( - finished_sending=[request_remote.request_id]) + from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput # noqa + + model_runner_output.kv_connector_output = KVConnectorOutput(finished_sending=[request_remote.request_id]) scheduler.update_from_output(scheduler_output, model_runner_output) _ = scheduler.schedule() assert_scheduler_empty(scheduler) diff --git a/tests/ut/kv_connector/test_remote_prefill_lifecycle.py b/tests/ut/kv_connector/test_remote_prefill_lifecycle.py index c9b88915595..3a33fe128d9 100644 --- a/tests/ut/kv_connector/test_remote_prefill_lifecycle.py +++ b/tests/ut/kv_connector/test_remote_prefill_lifecycle.py @@ -21,10 +21,13 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT from vllm.v1.request import RequestStatus -from tests.ut.kv_connector.utils import (assert_scheduler_empty, - create_model_runner_output, - create_request, create_scheduler, - create_vllm_config) +from tests.ut.kv_connector.utils import ( + assert_scheduler_empty, + create_model_runner_output, + create_request, + create_scheduler, + create_vllm_config, +) def test_basic_lifecycle(): @@ -37,13 +40,9 @@ def test_basic_lifecycle(): BLOCK_SIZE = vllm_config.cache_config.block_size NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - START_FREE_BLOCK_QUEUE_SIZE = ( - scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) + START_FREE_BLOCK_QUEUE_SIZE = scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks - request = create_request(request_id=1, - num_tokens=NUM_TOKENS, - do_remote_prefill=True, - block_size=BLOCK_SIZE) + request = create_request(request_id=1, num_tokens=NUM_TOKENS, do_remote_prefill=True, block_size=BLOCK_SIZE) scheduler.add_request(request) request_id = request.request_id @@ -62,16 +61,14 @@ def test_basic_lifecycle(): # Req waiting for KVs with no computed/scheduled toks ... assert len(scheduler.waiting) == 1 assert request in scheduler.waiting - assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS) - assert (request.num_computed_tokens == 0) + assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS + assert request.num_computed_tokens == 0 # ... but should have (uncached) blocks allocated to it. block_pool = scheduler.kv_cache_manager.block_pool - assert (block_pool.free_block_queue.num_free_blocks - < START_FREE_BLOCK_QUEUE_SIZE) + assert block_pool.free_block_queue.num_free_blocks < START_FREE_BLOCK_QUEUE_SIZE assert len(block_pool.cached_block_hash_to_block) == 0 - blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks[request_id] + blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[request_id] for block in blocks: assert block._block_hash is None @@ -79,8 +76,7 @@ def test_basic_lifecycle(): model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT # (1c): update_from_output() - engine_core_outputs = scheduler.update_from_output(scheduler_output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output(scheduler_output, model_runner_output) assert not engine_core_outputs or not engine_core_outputs[0].outputs # STEP (2): @@ -91,16 +87,14 @@ def test_basic_lifecycle(): # (2b): forward(): request finishes recv. model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) - from vllm.v1.worker.kv_connector_model_runner_mixin import \ - KVConnectorOutput # type: ignore # noqa - model_runner_output.kv_connector_output = KVConnectorOutput( - finished_recving=[request_id]) + from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput # type: ignore # noqa + + model_runner_output.kv_connector_output = KVConnectorOutput(finished_recving=[request_id]) # (2c): update_from_output(): - engine_core_outputs = scheduler.update_from_output(scheduler_output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.waiting) == 1 - assert (request_id in scheduler.finished_recving_kv_req_ids) + assert request_id in scheduler.finished_recving_kv_req_ids # STEP (3): # (3a): schedule(): this should actually schedule. @@ -109,11 +103,10 @@ def test_basic_lifecycle(): # Confirm the block are actually allocated. num_hashed_blocks = 0 - blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks[request_id] + blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[request_id] for block in blocks: assert block.ref_cnt == 1 - num_hashed_blocks += (1 if block._block_hash is not None else 0) + num_hashed_blocks += 1 if block._block_hash is not None else 0 assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS # Confirm the rest of the prompt is scheduled in this step. @@ -121,7 +114,7 @@ def test_basic_lifecycle(): num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id] num_computed_tokens = scheduled_req.num_computed_tokens total_prompt_tokens = len(scheduled_req.prompt_token_ids) - assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens) + assert num_scheduled_tokens == total_prompt_tokens - num_computed_tokens # (3b): execute_model() model_runner_output = create_model_runner_output([request]) @@ -131,8 +124,7 @@ def test_basic_lifecycle(): # Step (4): Hit EOS. scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output([request], use_eos=True) - engine_core_outputs = scheduler.update_from_output(scheduler_output, - model_runner_output) + engine_core_outputs = scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.schedule() assert_scheduler_empty(scheduler) @@ -169,8 +161,9 @@ def test_no_spurious_prefix_caching(): scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) assert len(scheduler.waiting) == 1 - remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ - 0].req_to_blocks[request_remote.request_id] + remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[ + request_remote.request_id + ] # Remote blocks should not be cached. for block in remote_blocks: @@ -189,9 +182,7 @@ def test_full_block_prompt(): NUM_EXTERNAL_FULL_BLOCKS = 2 NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS) - request = create_request(request_id=1, - num_tokens=NUM_TOKENS, - do_remote_prefill=True) + request = create_request(request_id=1, num_tokens=NUM_TOKENS, do_remote_prefill=True) scheduler.add_request(request) request_id = request.request_id @@ -199,8 +190,7 @@ def test_full_block_prompt(): # STEP (1): Initialize a recv. scheduler_output = scheduler.schedule() # All blocks should be allocated. - num_blocks = len(scheduler.kv_cache_manager.coordinator. - single_type_managers[0].req_to_blocks[request_id]) + num_blocks = len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[request_id]) assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT scheduler.update_from_output(scheduler_output, model_runner_output) @@ -208,25 +198,22 @@ def test_full_block_prompt(): # # STEP (2): Recv. scheduler_output = scheduler.schedule() model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) - from vllm.v1.worker.kv_connector_model_runner_mixin import \ - KVConnectorOutput # type: ignore # noqa - model_runner_output.kv_connector_output = KVConnectorOutput( - finished_recving=[request_id]) + from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput # type: ignore # noqa + + model_runner_output.kv_connector_output = KVConnectorOutput(finished_recving=[request_id]) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.waiting) == 1 - assert (request_id in scheduler.finished_recving_kv_req_ids) + assert request_id in scheduler.finished_recving_kv_req_ids # # STEP (3): Run as usual. scheduler_output = scheduler.schedule() # We need to recompute the final token of the prompt to generate # the first new token, so we should not have a new block. - num_blocks = len(scheduler.kv_cache_manager.coordinator. - single_type_managers[0].req_to_blocks[request_id]) + num_blocks = len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[request_id]) assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS - assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens == - NUM_TOKENS - 1) - assert (scheduler_output.num_scheduled_tokens[request_id] == 1) + assert scheduler_output.scheduled_new_reqs[0].num_computed_tokens == NUM_TOKENS - 1 + assert scheduler_output.num_scheduled_tokens[request_id] == 1 model_runner_output = create_model_runner_output([request]) scheduler.update_from_output(scheduler_output, model_runner_output) diff --git a/tests/ut/kv_connector/utils.py b/tests/ut/kv_connector/utils.py index 6a560e80669..2b1950bc7b8 100644 --- a/tests/ut/kv_connector/utils.py +++ b/tests/ut/kv_connector/utils.py @@ -4,7 +4,7 @@ # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. import os -from typing import Any, Optional +from typing import Any import torch from vllm import SamplingParams @@ -122,7 +122,7 @@ def create_request( block_hasher = get_request_block_hasher(block_size, sha256) - kv_transfer_params: Optional[dict[str, Any]] = None + kv_transfer_params: dict[str, Any] | None = None if do_remote_decode: assert not do_remote_prefill @@ -162,8 +162,8 @@ def create_request( def create_model_runner_output( reqs: list[Request], - finished_sending: Optional[list[str]] = None, - finished_recving: Optional[list[str]] = None, + finished_sending: list[str] | None = None, + finished_recving: list[str] | None = None, use_eos: bool = False, ) -> ModelRunnerOutput: """Make dummy model runner output for testing.""" diff --git a/tests/ut/model_loader/netloader/test_netloader.py b/tests/ut/model_loader/netloader/test_netloader.py index 64d95efe575..f8f4149fd19 100644 --- a/tests/ut/model_loader/netloader/test_netloader.py +++ b/tests/ut/model_loader/netloader/test_netloader.py @@ -25,8 +25,8 @@ class DummyDeviceConfig: - device = 'cuda' - device_type = 'cuda' + device = "cuda" + device_type = "cuda" class DummyParallelConfig: @@ -41,13 +41,12 @@ class DummyVllmConfig: class DummyModelConfig: - model = 'dummy-model' + model = "dummy-model" dtype = torch.float32 @pytest.fixture def default_load_config(): - class DummyLoadConfig: model_loader_extra_config = None load_format = "default" @@ -56,7 +55,6 @@ class DummyLoadConfig: def make_loader_with_config(extra): - class DummyLoadConfig: model_loader_extra_config = extra load_format = "default" @@ -67,9 +65,7 @@ class DummyLoadConfig: def test_init_with_extra_config_file(tmp_path, monkeypatch): # Generate test JSON file config_content = { - "SOURCE": [{ - "device_id": 0 - }], + "SOURCE": [{"device_id": 0}], "MODEL": "foo-model", "LISTEN_PORT": 5001, "INT8_CACHE": "hbm", @@ -80,9 +76,7 @@ def test_init_with_extra_config_file(tmp_path, monkeypatch): dummy_logger = MagicMock() monkeypatch.setattr("vllm.logger.logger", dummy_logger) - monkeypatch.setattr( - "vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix", - lambda x: True) + monkeypatch.setattr("vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix", lambda x: True) extra = {"CONFIG_FILE": str(config_file)} loader = make_loader_with_config(extra) @@ -96,18 +90,14 @@ def test_init_with_extra_config_file(tmp_path, monkeypatch): def test_init_with_extra_config(monkeypatch): dummy_logger = MagicMock() monkeypatch.setattr("vllm.logger.logger", dummy_logger) - monkeypatch.setattr( - "vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix", - lambda x: True) + monkeypatch.setattr("vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix", lambda x: True) extra = { - "SOURCE": [{ - "device_id": 0 - }], + "SOURCE": [{"device_id": 0}], "MODEL": "foo", "LISTEN_PORT": "4000", "INT8_CACHE": "dram", - "OUTPUT_PREFIX": "/tmp/" + "OUTPUT_PREFIX": "/tmp/", } loader = make_loader_with_config(extra) assert loader.model_path == "foo" @@ -120,9 +110,7 @@ def test_init_with_extra_config(monkeypatch): def test_init_with_invalid_config(monkeypatch): dummy_logger = MagicMock() monkeypatch.setattr("vllm.logger.logger", dummy_logger) - monkeypatch.setattr( - "vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix", - lambda x: False) + monkeypatch.setattr("vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix", lambda x: False) # c extra = { "SOURCE": None, @@ -143,7 +131,6 @@ def test_load_model_elastic_success(mock_logger, monkeypatch, tmp_path): monkeypatch.setattr("torch.distributed.get_rank", lambda: 0) class FakeContext: - def __enter__(self): pass @@ -152,54 +139,42 @@ def __exit__(self, a, b, c): monkeypatch.setattr("torch.device", lambda d: FakeContext()) # patch deep copy - monkeypatch.setattr( - "vllm_ascend.model_loader.netloader.netloader.deepcopy", lambda x: x) + monkeypatch.setattr("vllm_ascend.model_loader.netloader.netloader.deepcopy", lambda x: x) # patch set_default_torch_dtype monkeypatch.setattr( - "vllm_ascend.model_loader.netloader.netloader.set_default_torch_dtype", - lambda dtype: FakeContext()) + "vllm_ascend.model_loader.netloader.netloader.set_default_torch_dtype", lambda dtype: FakeContext() + ) # patch initialize_model dummy_model = MagicMock(spec=nn.Module) dummy_model.eval.return_value = dummy_model - monkeypatch.setattr( - "vllm_ascend.model_loader.netloader.netloader.initialize_model", - lambda **kwargs: dummy_model) + monkeypatch.setattr("vllm_ascend.model_loader.netloader.netloader.initialize_model", lambda **kwargs: dummy_model) # patch elastic_load - monkeypatch.setattr( - "vllm_ascend.model_loader.netloader.netloader.elastic_load", - lambda **kwargs: dummy_model) + monkeypatch.setattr("vllm_ascend.model_loader.netloader.netloader.elastic_load", lambda **kwargs: dummy_model) # patch process_weights_after_loading monkeypatch.setattr( - "vllm_ascend.model_loader.netloader.netloader.process_weights_after_loading", - lambda *a, **k: None) + "vllm_ascend.model_loader.netloader.netloader.process_weights_after_loading", lambda *a, **k: None + ) # patch get_ip monkeypatch.setattr("vllm.utils.network_utils.get_ip", lambda: "127.0.0.1") # patch find_free_port - monkeypatch.setattr( - "vllm_ascend.model_loader.netloader.netloader.find_free_port", - lambda: 8888) + monkeypatch.setattr("vllm_ascend.model_loader.netloader.netloader.find_free_port", lambda: 8888) # patch ElasticServer class DummyElasticServer: - def __init__(*a, **k): pass def start(self): pass - monkeypatch.setattr( - "vllm_ascend.model_loader.netloader.netloader.ElasticServer", - DummyElasticServer) + monkeypatch.setattr("vllm_ascend.model_loader.netloader.netloader.ElasticServer", DummyElasticServer) # write output_prefix to the temporary directory extra = { - "SOURCE": [{ - "device_id": 0 - }], + "SOURCE": [{"device_id": 0}], "MODEL": "foo", "LISTEN_PORT": 5555, "OUTPUT_PREFIX": str(tmp_path) + "/output_", - "INT8_CACHE": "no" + "INT8_CACHE": "no", } loader = make_loader_with_config(extra) vllm_config = DummyVllmConfig() diff --git a/tests/ut/model_loader/netloader/test_netloader_elastic.py b/tests/ut/model_loader/netloader/test_netloader_elastic.py index 127f1dd6c54..7998540351f 100644 --- a/tests/ut/model_loader/netloader/test_netloader_elastic.py +++ b/tests/ut/model_loader/netloader/test_netloader_elastic.py @@ -24,18 +24,12 @@ import torch import vllm.logger -from vllm_ascend.model_loader.netloader.interaction.elastic import ( - ElasticClient, ElasticServer) +from vllm_ascend.model_loader.netloader.interaction.elastic import ElasticClient, ElasticServer # Simulate server's normal response def mock_server_response(data): - return json.dumps({ - "label": "JOIN_ACK", - "content": { - "name": "mocked_name" - } - }).encode("utf-8") + return json.dumps({"label": "JOIN_ACK", "content": {"name": "mocked_name"}}).encode("utf-8") # Simulate server's error response @@ -56,12 +50,12 @@ def test_elastic_client_init(): tp = 1 pp = 1 - with patch('socket.socket') as mock_socket: + with patch("socket.socket") as mock_socket: mock_socket_instance = MagicMock() mock_socket.return_value = mock_socket_instance mock_socket_instance.recv.return_value = mock_server_response(None) - mock_socket_instance.getsockname.return_value = ('127.0.0.1', 12346) + mock_socket_instance.getsockname.return_value = ("127.0.0.1", 12346) mock_socket_instance.__enter__.return_value = mock_socket_instance with ElasticClient(sources, device_id, model_path, tp, pp) as client: @@ -79,18 +73,17 @@ def test_elastic_client_register(): tp = 1 pp = 1 - with patch('socket.socket') as mock_socket: + with patch("socket.socket") as mock_socket: mock_socket_instance = MagicMock() mock_socket.return_value = mock_socket_instance mock_socket_instance.connect.return_value = None mock_socket_instance.recv.return_value = mock_server_response(None) - mock_socket_instance.getsockname.return_value = ('127.0.0.1', 12346) + mock_socket_instance.getsockname.return_value = ("127.0.0.1", 12346) mock_socket_instance.__enter__.return_value = mock_socket_instance client = ElasticClient(sources, device_id, model_path, tp, pp) - assert client.register(device_id, model_path, tp, - pp) == ("mocked_name", 12346) + assert client.register(device_id, model_path, tp, pp) == ("mocked_name", 12346) # Test the behavior of the `register` method of ElasticClient when the server returns an error response. @@ -101,16 +94,14 @@ def test_elastic_client_register_error_response(): tp = 1 pp = 1 - with patch('socket.socket') as mock_socket: + with patch("socket.socket") as mock_socket: mock_socket_instance = MagicMock() mock_socket.return_value = mock_socket_instance mock_socket_instance.connect.return_value = None - mock_socket_instance.recv.return_value = mock_server_error_response( - None) + mock_socket_instance.recv.return_value = mock_server_error_response(None) - with ElasticClient(sources, device_id, model_path, tp, pp) as client: - with pytest.raises(RuntimeError): - client.register(device_id, model_path, tp, pp) + with ElasticClient(sources, device_id, model_path, tp, pp) as client, pytest.raises(RuntimeError): + client.register(device_id, model_path, tp, pp) mock_socket_instance.close.assert_called_once() @@ -122,7 +113,7 @@ def test_elastic_client_register_exception(): tp = 1 pp = 1 - with patch('socket.socket') as mock_socket: + with patch("socket.socket") as mock_socket: mock_socket_instance = MagicMock() mock_socket.return_value = mock_socket_instance mock_socket_instance.connect.return_value = None @@ -130,14 +121,12 @@ def test_elastic_client_register_exception(): mock_socket_instance.__enter__.return_value = mock_socket_instance mock_socket_instance.__exit__.return_value = None - with ElasticClient(sources, device_id, model_path, tp, pp) as client: - with pytest.raises(RuntimeError): - client.register(device_id, model_path, tp, pp) + with ElasticClient(sources, device_id, model_path, tp, pp) as client, pytest.raises(RuntimeError): + client.register(device_id, model_path, tp, pp) mock_socket_instance.close.assert_called_once() class FakeInt8Param: - def __init__(self, name="param", device="npu", dtype=torch.int8): self.dtype = dtype self.device = torch.device(device) @@ -158,7 +147,6 @@ def cpu(self): class FakeModel: - def __init__(self): self.params = { "param1": MagicMock(dtype=torch.float32), # This will be ignored @@ -185,7 +173,7 @@ def server_config(): "tp": 1, "pp": 1, "int8_cache": "dram", - 'int8_cache_name': None + "int8_cache_name": None, } @@ -202,22 +190,20 @@ def test_server_initialization(server_config, mock_model): # Check the socket configuration mock_socket.assert_called_with(socket.AF_INET, socket.SOCK_STREAM) - mock_socket.return_value.setsockopt.assert_called_with( - socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + mock_socket.return_value.setsockopt.assert_called_with(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) mock_socket.return_value.bind.assert_called_with(("127.0.0.1", 8080)) mock_socket.return_value.listen.assert_called_with(256) # Check int8 cache assert "param2" in server.original_int8 - assert server.original_int8[ - "param2"].device.type == "cpu" # Verifying DRAM Cache + assert server.original_int8["param2"].device.type == "cpu" # Verifying DRAM Cache - assert server.addr == server_config['addr'] - assert server.port == server_config['port'] - assert server.device_id == server_config['device_id'] - assert server.model_path == server_config['model_path'] - assert server.tp == server_config['tp'] - assert server.pp == server_config['pp'] + assert server.addr == server_config["addr"] + assert server.port == server_config["port"] + assert server.device_id == server_config["device_id"] + assert server.model_path == server_config["model_path"] + assert server.tp == server_config["tp"] + assert server.pp == server_config["pp"] # Get captured logs log_output = log_capture_string.getvalue() @@ -229,11 +215,8 @@ def test_server_initialization(server_config, mock_model): # Test the int8 cache option -@pytest.mark.parametrize("cache_option,expected_device", [("dram", "cpu"), - ("no", None), - ("invalid", None)]) -def test_int8_cache_handling(server_config, mock_model, cache_option, - expected_device, caplog): +@pytest.mark.parametrize("cache_option,expected_device", [("dram", "cpu"), ("no", None), ("invalid", None)]) +def test_int8_cache_handling(server_config, mock_model, cache_option, expected_device, caplog): server_config["int8_cache"] = cache_option server_config["model"] = mock_model @@ -255,16 +238,13 @@ def test_int8_cache_handling(server_config, mock_model, cache_option, if expected_device is None: assert len(server.original_int8) == 0 else: - assert server.original_int8[ - "param2"].device.type == expected_device + assert server.original_int8["param2"].device.type == expected_device # Test client processing def test_client_handler_valid_join(server_config, mock_model): server_config["model"] = mock_model - with patch("vllm_ascend.model_loader.netloader.interaction.elastic.P2PSend" - ) as mock_p2p_send: - + with patch("vllm_ascend.model_loader.netloader.interaction.elastic.P2PSend") as mock_p2p_send: # Create a simulated connection mock_conn = MagicMock() mock_addr = ("192.168.1.1", 12345) @@ -272,13 +252,7 @@ def test_client_handler_valid_join(server_config, mock_model): # Configuring Client Data valid_data = { "label": "JOIN", - "content": { - "device_id": 0, - "model_path": "/test/model", - "tp": 1, - "pp": 1, - "port": 9090 - } + "content": {"device_id": 0, "model_path": "/test/model", "tp": 1, "pp": 1, "port": 9090}, } mock_conn.recv.return_value = json.dumps(valid_data).encode("utf-8") @@ -287,16 +261,9 @@ def test_client_handler_valid_join(server_config, mock_model): server.register_handler(mock_conn, mock_addr) # Verify response - expected_ack = { - "label": "JOIN_ACK", - "content": { - "name": "192.168.1.1:12345" - } - } - mock_conn.send.assert_called_once_with( - json.dumps(expected_ack).encode("utf-8")) - mock_p2p_send.assert_called_once_with("127.0.0.1", 9090, - "192.168.1.1:12345") + expected_ack = {"label": "JOIN_ACK", "content": {"name": "192.168.1.1:12345"}} + mock_conn.send.assert_called_once_with(json.dumps(expected_ack).encode("utf-8")) + mock_p2p_send.assert_called_once_with("127.0.0.1", 9090, "192.168.1.1:12345") mock_conn.close.assert_called_once() @@ -315,8 +282,8 @@ def test_client_handler_mismatch(server_config): "model_path": "/wrong/model", "tp": 2, "pp": 2, - "port": 9090 - } + "port": 9090, + }, } mock_conn.recv.return_value = json.dumps(mismatch_data).encode("utf-8") @@ -326,13 +293,29 @@ def test_client_handler_mismatch(server_config): # Verify response expected_ack = { - "label": - "JOIN_NACK", - "content": - f"Received data {(mismatch_data['content']['device_id'], mismatch_data['content']['model_path'], mismatch_data['content']['tp'], mismatch_data['content']['pp'])} does not consist with this server {(server_config['device_id'], server_config['model_path'], server_config['tp'], server_config['pp'])}" + "label": "JOIN_NACK", + "content": ( + f"Received data " + f"{ + ( + mismatch_data['content']['device_id'], + mismatch_data['content']['model_path'], + mismatch_data['content']['tp'], + mismatch_data['content']['pp'], + ) + } " + f"does not consist with this server " + f"{ + ( + server_config['device_id'], + server_config['model_path'], + server_config['tp'], + server_config['pp'], + ) + }" + ), } - mock_conn.send.assert_called_once_with( - json.dumps(expected_ack).encode("utf-8")) + mock_conn.send.assert_called_once_with(json.dumps(expected_ack).encode("utf-8")) mock_conn.close.assert_called_once() @@ -340,23 +323,16 @@ def test_client_handler_mismatch(server_config): @pytest.mark.parametrize( "invalid_data,should_send", [ + ({"label": "WRONG_LABEL"}, True), # Incorrect label, can be decoded as JSON, but the content is invalid. ( - { - "label": "WRONG_LABEL" - }, True - ), # Incorrect label, can be decoded as JSON, but the content is invalid. - ( - { - "content": { - "missing_fields": True - } - }, True - ), # Missing field, can be decoded as JSON, but the content is invalid. - ("plain text", False), # Non-JSON data, json.loads failed - (b"invalid_bytes", False) # Invalid byte, decode or json.loads failed - ]) -def test_client_handler_invalid_requests(server_config, invalid_data, - should_send): + {"content": {"missing_fields": True}}, + True, + ), # Missing field, can be decoded as JSON, but the content is invalid. + ("plain text", False), # Non-JSON data, json.loads failed + (b"invalid_bytes", False), # Invalid byte, decode or json.loads failed + ], +) +def test_client_handler_invalid_requests(server_config, invalid_data, should_send): with patch("socket.socket"): log_capture_string = io.StringIO() ch = logging.StreamHandler(log_capture_string) @@ -369,23 +345,18 @@ def test_client_handler_invalid_requests(server_config, invalid_data, mock_addr = ("192.168.1.1", 12345) if isinstance(invalid_data, (str, bytes)): - mock_conn.recv.return_value = invalid_data if isinstance( - invalid_data, bytes) else invalid_data.encode() + mock_conn.recv.return_value = invalid_data if isinstance(invalid_data, bytes) else invalid_data.encode() else: - mock_conn.recv.return_value = json.dumps(invalid_data).encode( - "utf-8") + mock_conn.recv.return_value = json.dumps(invalid_data).encode("utf-8") server.register_handler(mock_conn, mock_addr) if should_send: expected_ack = { - "label": - "JOIN_NACK", - "content": - f"Received data does not contain required fields: {invalid_data}" + "label": "JOIN_NACK", + "content": f"Received data does not contain required fields: {invalid_data}", } - mock_conn.send.assert_called_once_with( - json.dumps(expected_ack).encode("utf-8")) + mock_conn.send.assert_called_once_with(json.dumps(expected_ack).encode("utf-8")) else: mock_conn.send.assert_not_called() @@ -400,9 +371,7 @@ def test_client_handler_invalid_requests(server_config, invalid_data, # Test the thread startup. def test_server_start(server_config): - with patch("socket.socket"), \ - patch("threading.Thread") as mock_thread: - + with patch("socket.socket"), patch("threading.Thread") as mock_thread: handler_thread_instance = mock_thread.return_value server = ElasticServer(**server_config) @@ -411,9 +380,9 @@ def test_server_start(server_config): # Assert that the correct target parameter was passed when instantiating the Thread instance. mock_thread.assert_called_once() args, kwargs = mock_thread.call_args - assert kwargs['target'] == server.elastic_client_handler + assert kwargs["target"] == server.elastic_client_handler - # Check that the daemon attribute is set to True (the attribute value will be recorded after MagicMock assignment). + # Verify the daemon attribute is set to True (the attribute value will be recorded after MagicMock assignment). assert handler_thread_instance.daemon is True # Check if the start() method is called. diff --git a/tests/ut/model_loader/netloader/test_netloader_load.py b/tests/ut/model_loader/netloader/test_netloader_load.py index 77c3486f61e..bc0ba8952c5 100644 --- a/tests/ut/model_loader/netloader/test_netloader_load.py +++ b/tests/ut/model_loader/netloader/test_netloader_load.py @@ -24,14 +24,8 @@ @pytest.fixture def mock_sources(): return [ - { - "device_id": 0, - "sources": ["a", "b"] - }, - { - "device_id": 1, - "sources": ["c"] - }, + {"device_id": 0, "sources": ["a", "b"]}, + {"device_id": 1, "sources": ["c"]}, ] @@ -76,8 +70,7 @@ def test_model_load_fail(mock_logger, mock_p2p): mock_client.ack = ["foo", "bar"] mock_client.server_addr = "addr" - with patch("vllm_ascend.model_loader.netloader.load.ElasticClient", - return_value=mock_client): + with patch("vllm_ascend.model_loader.netloader.load.ElasticClient", return_value=mock_client): # P2PLoad.load returns None mock_p2p_instance = MagicMock() mock_p2p_instance.load.return_value = None @@ -97,8 +90,7 @@ def test_model_load_success(mock_logger, mock_p2p): mock_client.ack = ["foo", "bar"] mock_client.server_addr = "addr" - with patch("vllm_ascend.model_loader.netloader.load.ElasticClient", - return_value=mock_client): + with patch("vllm_ascend.model_loader.netloader.load.ElasticClient", return_value=mock_client): expected_model = object() mock_p2p_instance = MagicMock() mock_p2p_instance.load.return_value = expected_model diff --git a/tests/ut/model_loader/netloader/test_netloader_utils.py b/tests/ut/model_loader/netloader/test_netloader_utils.py index 66198449737..4969cc37a24 100644 --- a/tests/ut/model_loader/netloader/test_netloader_utils.py +++ b/tests/ut/model_loader/netloader/test_netloader_utils.py @@ -20,8 +20,7 @@ import pytest -from vllm_ascend.model_loader.netloader.utils import (find_free_port, - is_valid_path_prefix) +from vllm_ascend.model_loader.netloader.utils import find_free_port, is_valid_path_prefix def test_find_free_port(): @@ -31,7 +30,7 @@ def test_find_free_port(): def test_is_valid_path_prefix_empty(): - assert not is_valid_path_prefix('') + assert not is_valid_path_prefix("") def test_is_valid_path_prefixIllegal_characters(): @@ -39,22 +38,22 @@ def test_is_valid_path_prefixIllegal_characters(): def test_is_valid_path_prefixRelative_path(): - assert is_valid_path_prefix('test') + assert is_valid_path_prefix("test") def test_is_valid_path_prefixAbsolute_path(): with tempfile.TemporaryDirectory() as tmpdir: - assert is_valid_path_prefix(os.path.join(tmpdir, 'test')) + assert is_valid_path_prefix(os.path.join(tmpdir, "test")) -@patch('os.path.exists', return_value=False) +@patch("os.path.exists", return_value=False) def test_is_valid_path_prefix_no_directory(mock_exists): - assert not is_valid_path_prefix('/nonexistent_dir/test') + assert not is_valid_path_prefix("/nonexistent_dir/test") -@patch('os.path.exists', return_value=True) +@patch("os.path.exists", return_value=True) def test_is_valid_path_prefix_directory_exists(mock_exists): - assert is_valid_path_prefix('/existing_dir/test') + assert is_valid_path_prefix("/existing_dir/test") if __name__ == "__main__": From 8d7b8d334a369cd3117eecaad568c88f5a7c5d04 Mon Sep 17 00:00:00 2001 From: MrZ20 <2609716663@qq.com> Date: Mon, 20 Apr 2026 17:31:07 +0800 Subject: [PATCH 2/2] fix Signed-off-by: MrZ20 <2609716663@qq.com> --- .../netloader/test_netloader_elastic.py | 35 ++++++++----------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/tests/ut/model_loader/netloader/test_netloader_elastic.py b/tests/ut/model_loader/netloader/test_netloader_elastic.py index 7998540351f..60652bb168d 100644 --- a/tests/ut/model_loader/netloader/test_netloader_elastic.py +++ b/tests/ut/model_loader/netloader/test_netloader_elastic.py @@ -292,28 +292,23 @@ def test_client_handler_mismatch(server_config): assert isinstance(mismatch_data["content"], dict) # Verify response + mismatch_tuple = ( + mismatch_data["content"]["device_id"], + mismatch_data["content"]["model_path"], + mismatch_data["content"]["tp"], + mismatch_data["content"]["pp"], + ) + + server_tuple = ( + server_config["device_id"], + server_config["model_path"], + server_config["tp"], + server_config["pp"], + ) + expected_ack = { "label": "JOIN_NACK", - "content": ( - f"Received data " - f"{ - ( - mismatch_data['content']['device_id'], - mismatch_data['content']['model_path'], - mismatch_data['content']['tp'], - mismatch_data['content']['pp'], - ) - } " - f"does not consist with this server " - f"{ - ( - server_config['device_id'], - server_config['model_path'], - server_config['tp'], - server_config['pp'], - ) - }" - ), + "content": (f"Received data {mismatch_tuple} does not consist with this server {server_tuple}"), } mock_conn.send.assert_called_once_with(json.dumps(expected_ack).encode("utf-8")) mock_conn.close.assert_called_once()