diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 095bd4c3dd98..d0af96f5d3ed 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -694,16 +694,18 @@ def test_async_load_kv( ) @pytest.mark.parametrize("local_tp_size", [1, 2]) def test_prefill_tp_size_greater_than_decode_tp_size( - self, local_tp_size: int, default_vllm_config, dist_init + self, local_tp_size: int, default_vllm_config, dist_init, monkeypatch ): """ Verify remote TP > local TP handshake succeeds with different remote configurations. """ + monkeypatch.setattr( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size", + lambda: local_tp_size, + ) vllm_config = create_vllm_config() - local_tp_size = 1 - vllm_config.parallel_config.tensor_parallel_size = local_tp_size connector = NixlConnector( vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16) @@ -738,10 +740,10 @@ def check_handshake(remote_tp_size: int): remote_agents = worker._nixl_handshake( host="localhost", port=1234, - remote_tp_size=2, + remote_tp_size=4, expected_engine_id=worker.REMOTE_ENGINE_ID, ) - check_handshake(2) + check_handshake(4) # NOTE flexibility: a second remote with higher number of ranks is # discovered. This is not a scenario we actively support right now, but @@ -759,9 +761,8 @@ def check_handshake(remote_tp_size: int): "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", FakeNixlWrapper, ) - @pytest.mark.parametrize("local_tp_size", [1, 2]) def test_prefill_tp_size_greater_than_decode_tp_size_mla( - self, local_tp_size: int, default_vllm_config, dist_init + self, default_vllm_config, dist_init ): """ Verify remote TP > local TP handshake succeeds with different diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 9001e31810ff..79a04bcb95e0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1318,12 +1318,12 @@ def _nixl_handshake( f"Expected {expected_engine_id}," f"received {metadata.engine_id}." ) - setup_agent_time = time.perf_counter() # Register Remote agent. remote_agent_name = self.add_remote_agent( metadata, remote_rank, remote_tp_size ) + setup_agent_time = time.perf_counter() logger.debug( "NIXL handshake: add agent took: %s", setup_agent_time - got_metadata_time,