Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,16 +694,18 @@ def test_async_load_kv(
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size(
self, local_tp_size: int, default_vllm_config, dist_init
self, local_tp_size: int, default_vllm_config, dist_init, monkeypatch
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations.
"""
monkeypatch.setattr(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size",
lambda: local_tp_size,
)

vllm_config = create_vllm_config()
local_tp_size = 1
vllm_config.parallel_config.tensor_parallel_size = local_tp_size

connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
Expand Down Expand Up @@ -738,10 +740,10 @@ def check_handshake(remote_tp_size: int):
remote_agents = worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=2,
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

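This test doesn't actually work when `local_tp_size == remote_tp_size`, so the remote TP size here is raised from 2 to 4 (the test is parametrized over `local_tp_size` values of 1 and 2).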

remote_tp_size=4,
expected_engine_id=worker.REMOTE_ENGINE_ID,
)
check_handshake(2)
check_handshake(4)

# NOTE flexibility: a second remote with higher number of ranks is
# discovered. This is not a scenario we actively support right now, but
Expand All @@ -759,9 +761,8 @@ def check_handshake(remote_tp_size: int):
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
self, local_tp_size: int, default_vllm_config, dist_init
self, default_vllm_config, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1318,12 +1318,12 @@ def _nixl_handshake(
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)
setup_agent_time = time.perf_counter()

# Register Remote agent.
remote_agent_name = self.add_remote_agent(
metadata, remote_rank, remote_tp_size
)
setup_agent_time = time.perf_counter()
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
Expand Down
Loading