From 3817016ba9e5f86dfc70ef209aafe9fd51fe26d1 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Wed, 10 Dec 2025 10:42:01 -0800 Subject: [PATCH 1/4] Fix a bug when creating meta at prefill side for save to host Signed-off-by: Chendi Xue --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 514b8534aaa6..149c7b9273b0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -238,7 +238,7 @@ def add_new_req( local_physical_block_ids=local_block_ids, remote_block_ids=kv_transfer_params["remote_block_ids"], remote_engine_id=kv_transfer_params["remote_engine_id"], - remote_request_id=kv_transfer_params["remote_request_id"], + remote_request_id=kv_transfer_params.get("remote_request_id", ""), remote_host=kv_transfer_params["remote_host"], remote_port=kv_transfer_params["remote_port"], # P workers don't need to receive tp_size from proxy here. @@ -602,7 +602,6 @@ def update_state_after_alloc( num_external_tokens, params, ) - if not params: return From 74aba668f0527a586bff5046e6e575a8256d6155 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 11 Dec 2025 07:28:51 -0800 Subject: [PATCH 2/4] update add_new_req by setting remote with default value Signed-off-by: Chendi Xue --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 149c7b9273b0..7da1f426ab77 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -236,11 +236,11 @@ def add_new_req( _req = ReqMeta( local_block_ids=local_block_ids, local_physical_block_ids=local_block_ids, - remote_block_ids=kv_transfer_params["remote_block_ids"], - remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_block_ids=kv_transfer_params.get("remote_block_ids", []), + remote_engine_id=kv_transfer_params.get("remote_engine_id", ""), remote_request_id=kv_transfer_params.get("remote_request_id", ""), - remote_host=kv_transfer_params["remote_host"], - remote_port=kv_transfer_params["remote_port"], + remote_host=kv_transfer_params.get("remote_host", ""), + remote_port=kv_transfer_params.get("remote_port", 0), # P workers don't need to receive tp_size from proxy here. tp_size=kv_transfer_params.get("tp_size", 1), ) @@ -602,6 +602,7 @@ def update_state_after_alloc( num_external_tokens, params, ) + if not params: return From 1916c0ba4f4fae4a949561ba231d4372edc4e608 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 11 Dec 2025 07:57:54 -0800 Subject: [PATCH 3/4] Use None as default and ignore mypy Signed-off-by: Chendi Xue --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 7da1f426ab77..186884fffad8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -236,11 +236,11 @@ def add_new_req( _req = ReqMeta( local_block_ids=local_block_ids, local_physical_block_ids=local_block_ids, - remote_block_ids=kv_transfer_params.get("remote_block_ids", []), - remote_engine_id=kv_transfer_params.get("remote_engine_id", ""), - remote_request_id=kv_transfer_params.get("remote_request_id", ""), - remote_host=kv_transfer_params.get("remote_host", ""), - remote_port=kv_transfer_params.get("remote_port", 0), + remote_block_ids=kv_transfer_params.get("remote_block_ids"), # type: ignore + remote_engine_id=kv_transfer_params.get("remote_engine_id"), # type: ignore + remote_request_id=kv_transfer_params.get("remote_request_id"), # type: ignore + remote_host=kv_transfer_params.get("remote_host"), # type: ignore + remote_port=kv_transfer_params.get("remote_port"), # type: ignore # P workers don't need to receive tp_size from proxy here. tp_size=kv_transfer_params.get("tp_size", 1), ) From 831d017c0079d4ef71f47244569dacde8bca597b Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Fri, 12 Dec 2025 11:16:50 -0500 Subject: [PATCH 4/4] [NIXL] Refactor prefill-only ReqMeta into RemoteMeta On the decode side, in reqs_to_save, we do not expect or need any of these remote_ fields - they are for the prefill side only. Make this more clear by putting these fields in their own dataclass which is only present on requests in reqs_to_recv. Signed-off-by: Mark McLoughlin --- .../kv_connector/unit/test_nixl_connector.py | 12 +-- .../kv_connector/v1/nixl_connector.py | 95 +++++++++++-------- 2 files changed, 62 insertions(+), 45 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 53da09cfbc21..66804fa671c7 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -461,7 +461,7 @@ def test_multi_xfer_one_engine( metadata = NixlConnectorMetadata() if num_xfers > 0: num_xfers -= 1 - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id=request_id, local_block_ids=[num_xfers + 1, num_xfers + 2, num_xfers + 3], kv_transfer_params={ @@ -532,7 +532,7 @@ def test_async_load_kv( vllm_config, connector.engine_id ) metadata = NixlConnectorMetadata() - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id="id", local_block_ids=[1, 2, 3], kv_transfer_params={ @@ -588,7 +588,7 @@ def test_concurrent_load_kv( metadata = NixlConnectorMetadata() total_reqs = 5 for i in range(total_reqs): - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id=f"id_{i}", local_block_ids=[1, 2, 3], kv_transfer_params={ @@ -752,7 +752,7 @@ def test_kv_connector_stats(dist_init): # Create transfer metadata request_id = "test_req_for_stats" metadata = NixlConnectorMetadata() - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id=request_id, local_block_ids=[1, 2, 3], kv_transfer_params={ @@ -1515,7 +1515,7 @@ def test_handshake_failure_returns_finished(dist_init): request_id = "test_handshake_fail" metadata = NixlConnectorMetadata() - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id=request_id, local_block_ids=[1, 2, 3], kv_transfer_params={ @@ -1565,7 +1565,7 @@ def test_transfer_setup_failure_returns_finished(dist_init): request_id = "test_transfer_fail" metadata = NixlConnectorMetadata() - metadata.add_new_req( + metadata.add_new_req_to_recv( request_id=request_id, local_block_ids=[7, 8, 9], kv_transfer_params={ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 186884fffad8..fb4b8ac391af 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -202,17 +202,22 @@ def compute_nixl_compatibility_hash( return compat_hash +@dataclass +class RemoteMeta: + block_ids: list[int] + host: str + port: int + engine_id: str + request_id: str + + @dataclass class ReqMeta: local_block_ids: list[int] # To be used when logical block size does not match the kernel block size local_physical_block_ids: list[int] - remote_block_ids: list[int] - remote_host: str - remote_port: int - remote_engine_id: str - remote_request_id: str tp_size: int + remote: RemoteMeta | None = None class NixlConnectorMetadata(KVConnectorMetadata): @@ -223,31 +228,43 @@ def __init__(self): self.reqs_in_batch: set[ReqId] = set() self.reqs_not_processed: set[ReqId] = set() - def add_new_req( + def _add_new_req( self, - request_id: ReqId, local_block_ids: list[int], kv_transfer_params: dict[str, Any], - load_remote_cache: bool = True, - save_to_host: bool = False, - ): - # save and load are mutually exclusive - assert load_remote_cache ^ save_to_host - _req = ReqMeta( + ) -> ReqMeta: + return ReqMeta( local_block_ids=local_block_ids, local_physical_block_ids=local_block_ids, - remote_block_ids=kv_transfer_params.get("remote_block_ids"), # type: ignore - remote_engine_id=kv_transfer_params.get("remote_engine_id"), # type: ignore - remote_request_id=kv_transfer_params.get("remote_request_id"), # type: ignore - remote_host=kv_transfer_params.get("remote_host"), # type: ignore - remote_port=kv_transfer_params.get("remote_port"), # type: ignore # P workers don't need to receive tp_size from proxy here. tp_size=kv_transfer_params.get("tp_size", 1), ) - if save_to_host: - self.reqs_to_save[request_id] = _req - if load_remote_cache: - self.reqs_to_recv[request_id] = _req + + def add_new_req_to_save( + self, + request_id: ReqId, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + ): + self.reqs_to_save[request_id] = self._add_new_req( + local_block_ids, kv_transfer_params + ) + + def add_new_req_to_recv( + self, + request_id: ReqId, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + ): + req = self._add_new_req(local_block_ids, kv_transfer_params) + req.remote = RemoteMeta( + block_ids=kv_transfer_params["remote_block_ids"], + engine_id=kv_transfer_params["remote_engine_id"], + request_id=kv_transfer_params["remote_request_id"], + host=kv_transfer_params["remote_host"], + port=kv_transfer_params["remote_port"], + ) + self.reqs_to_recv[request_id] = req class NixlConnector(KVConnectorBase_V1): @@ -666,22 +683,18 @@ def build_connector_meta( # Loop through scheduled reqs and convert to ReqMeta. for req_id, (req, block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None - meta.add_new_req( + meta.add_new_req_to_recv( request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, - load_remote_cache=True, - save_to_host=False, ) for req_id, (req, block_ids) in self._reqs_need_save.items(): assert req.kv_transfer_params is not None - meta.add_new_req( + meta.add_new_req_to_save( request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, - load_remote_cache=False, - save_to_host=True, ) meta.reqs_to_send = self._reqs_need_send @@ -1124,10 +1137,11 @@ def _background_nixl_handshake( # Do NIXL handshake in background and add to _ready_requests when done. fut = self._handshake_futures.get(remote_engine_id) if fut is None: + assert meta.remote is not None fut = self._handshake_initiation_executor.submit( self._nixl_handshake, - meta.remote_host, - meta.remote_port, + meta.remote.host, + meta.remote.port, meta.tp_size, remote_engine_id, ) @@ -1774,6 +1788,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: # clean up metadata for completed requests meta = self._recving_metadata.pop(req_id, None) assert meta is not None, f"{req_id} not found in recving_metadata list" + assert meta.remote is not None if self.use_host_buffer: self.sync_recved_kv_to_device(req_id, meta) if self.enable_permute_local_kv: @@ -1781,7 +1796,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: # post processing for heteroblocksize block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( - meta.remote_engine_id + meta.remote.engine_id ) if ( not self.use_mla @@ -1916,17 +1931,18 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): meta.local_physical_block_ids = self._logical_to_kernel_block_ids( meta.local_block_ids ) - meta.remote_block_ids = self._logical_to_kernel_block_ids( - meta.remote_block_ids + assert meta.remote is not None + meta.remote.block_ids = self._logical_to_kernel_block_ids( + meta.remote.block_ids ) - remote_engine_id = meta.remote_engine_id + remote_engine_id = meta.remote.engine_id logger.debug( "start_load_kv for request %s from remote engine %s. " "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, remote_engine_id, len(meta.local_physical_block_ids), - len(meta.remote_block_ids), + len(meta.remote.block_ids), ) # always store metadata for failure recovery self._recving_metadata[req_id] = meta @@ -1965,17 +1981,18 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): self._reqs_to_send[req_id] = expiration_time def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): + assert meta.remote is not None logger.debug( "Remote agent %s available, calling _read_blocks for req %s", - meta.remote_engine_id, + meta.remote.engine_id, req_id, ) self._read_blocks( request_id=req_id, - dst_engine_id=meta.remote_engine_id, - remote_request_id=meta.remote_request_id, + dst_engine_id=meta.remote.engine_id, + remote_request_id=meta.remote.request_id, local_block_ids=meta.local_physical_block_ids, - remote_block_ids=meta.remote_block_ids, + remote_block_ids=meta.remote.block_ids, ) def _read_blocks(