diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 50e83aa2ef20..fb4b641e1376 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -2479,3 +2479,119 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
             remote_tp_size=1,
             expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
         )
+
+
+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper",
+    FakeNixlWrapper,
+)
+def test_mla_broadcast_notif_uses_remote_request_id(default_vllm_config, dist_init):
+    """Key the MLA broadcast notification by the prefill-side request id.
+
+    MLA + remote TP > local TP: the broadcast notification sent to
+    non-read prefill ranks must be keyed by the prefill-side request
+    id (``meta.remote.request_id``), not the local decode request id.
+
+    Prefill ranks key ``_reqs_to_send`` by their own request id, so a
+    broadcast keyed by the decode id is rejected in
+    ``_get_new_notifs`` with "Potentially invalid KV blocks for
+    unrecognized request" and the blocks only release via the abort
+    timeout. See ``_read_blocks_for_req`` in
+    ``vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py``.
+    """
+    decode_tp_size = 1
+    prefill_tp_size = 4
+
+    vllm_config = create_vllm_config()
+    vllm_config.parallel_config.tensor_parallel_size = decode_tp_size
+
+    connector = NixlConnector(
+        vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
+    )
+    connector.connector_worker = FakeNixlConnectorWorker(
+        vllm_config, connector.engine_id, hand_shake_latency=0
+    )
+    worker = connector.connector_worker
+
+    # Force the MLA path; only `self.use_mla` gates the branches we
+    # exercise inside `_read_blocks_for_req`.
+    worker.use_mla = True
+
+    # Manually register the remote (P) engine and pre-populate the
+    # per-rank state the handshake would normally fill in. The real
+    # `_nixl_handshake` is unnecessary here — we only need
+    # `transfer_topo` to know `remote_tp_size`, and `_remote_agents`
+    # / `dst_xfer_side_handles` to be keyed by remote rank.
+    remote_engine_id = "remote_engine"
+    worker.transfer_topo.register_remote_engine(
+        remote_engine_id=remote_engine_id,
+        remote_tp_size=prefill_tp_size,
+        remote_block_size=worker.block_size,
+        remote_block_len=worker.block_size * 4096,
+        remote_physical_blocks_per_logical=1,
+        local_block_len=worker.block_size * 4096,
+    )
+    worker._remote_agents[remote_engine_id] = {
+        rank: f"agent_p{rank}" for rank in range(prefill_tp_size)
+    }
+    worker.dst_xfer_side_handles = {
+        remote_engine_id: {rank: 100 + rank for rank in range(prefill_tp_size)}
+    }
+    # Sanity: D TP=1, P TP=4 => tp_ratio = -4 (P > D).
+    assert worker.transfer_topo.tp_ratio(prefill_tp_size) == -prefill_tp_size
+
+    # Distinct ids on each side — that's the whole point of the bug.
+    decode_req_id = "decode-req-AAAA"
+    prefill_req_id = "prefill-req-BBBB"
+    assert decode_req_id != prefill_req_id
+
+    metadata = NixlConnectorMetadata()
+    metadata.add_new_req_to_recv(
+        request_id=decode_req_id,
+        local_block_ids=([0, 1, 2],),
+        kv_transfer_params={
+            "remote_block_ids": ([10, 11, 12],),
+            "remote_engine_id": remote_engine_id,
+            "remote_request_id": prefill_req_id,
+            "remote_host": "localhost",
+            "remote_port": 1234,
+            "remote_tp_size": prefill_tp_size,
+        },
+    )
+    meta = metadata.reqs_to_recv[decode_req_id]
+
+    # Capture broadcast send_notif calls; stub `_read_blocks` so we
+    # don't need a working xfer path. Real `_read_blocks` emits its
+    # auto-notif via `make_prepped_xfer`, not via `send_notif`, so
+    # any captured `send_notif` here is a broadcast.
+    send_notif_calls: list[tuple[str, bytes]] = []
+    worker.nixl_wrapper.send_notif = (  # type: ignore[method-assign]
+        lambda agent_name, notif_msg: send_notif_calls.append((agent_name, notif_msg))
+    )
+    worker._read_blocks = MagicMock()  # type: ignore[method-assign]
+
+    worker._read_blocks_for_req(decode_req_id, meta)
+
+    # MLA: read once from rank 0 and broadcast to the other ranks.
+    worker._read_blocks.assert_called_once()
+    assert worker._read_blocks.call_args.kwargs["remote_rank"] == 0
+    assert worker._read_blocks.call_args.kwargs["remote_request_id"] == prefill_req_id
+
+    # Broadcast goes to ranks {1, 2, 3} only, never to the read target.
+    expected_recipients = {
+        worker._remote_agents[remote_engine_id][r]
+        for r in range(1, prefill_tp_size)
+    }
+    assert {agent for agent, _ in send_notif_calls} == expected_recipients
+
+    # Every broadcast notif must be keyed by the prefill request id.
+    # Pre-fix this used the *decode* request id, which prefill ranks
+    # didn't recognize.
+    expected_notif = f"{prefill_req_id}:{decode_tp_size}".encode()
+    bad_notif = f"{decode_req_id}:{decode_tp_size}".encode()
+    for agent, notif in send_notif_calls:
+        assert notif == expected_notif, (
+            f"Broadcast notif to {agent!r} must use prefill_req_id; "
+            f"got {notif!r} (expected {expected_notif!r}, "
+            f"buggy form would be {bad_notif!r})"
+        )
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
index bd7ef5973f62..607bf4b988ff 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl/worker.py
@@ -1971,7 +1971,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         if self.use_mla and tp_ratio < 0:
             # ..but we still need to notify the other remote ranks that we
             # have the blocks we need so they can update the request state.
-            notif_id = f"{req_id}:{self.world_size}".encode()
+            notif_id = f"{meta.remote.request_id}:{self.world_size}".encode()
             remote_agents = self._remote_agents[meta.remote.engine_id]
             for rank_to_notify, agent in remote_agents.items():
                 if rank_to_notify != remote_rank: