From 3631fa0c2423b91d0d6dd2e59d365576b0718c5b Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 10 Jul 2025 14:07:40 +0100 Subject: [PATCH 1/2] [KVConnector] Always call connector clear_metadata() at end of step Signed-off-by: Nick Hill Co-authored-by: David Ben-David Signed-off-by: Nick Hill --- .../kv_transfer/kv_connector/v1/base.py | 8 ++--- vllm/v1/executor/multiproc_executor.py | 34 ++++++++----------- vllm/v1/worker/gpu_model_runner.py | 4 --- vllm/v1/worker/gpu_worker.py | 4 +++ 4 files changed, 23 insertions(+), 27 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index b5199d85d5ae..66ab9077cb94 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -57,7 +57,7 @@ class KVConnectorRole(enum.Enum): WORKER = 1 -class KVConnectorMetadata: +class KVConnectorMetadata(ABC): # noqa: B024 """ Abstract Metadata used to communicate between the Scheduler KVConnector and Worker KVConnector. @@ -71,7 +71,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): logger.warning( "Initializing KVConnectorBase_V1. This API is experimental and " "subject to change in the future as we iterate the design.") - self._connector_metadata = KVConnectorMetadata() + self._connector_metadata: Optional[KVConnectorMetadata] = None self._vllm_config = vllm_config self._role = role @@ -102,9 +102,9 @@ def clear_connector_metadata(self) -> None: This function should be called by the model runner every time after the model execution. """ - self._connector_metadata = KVConnectorMetadata() + self._connector_metadata = None - def _get_connector_metadata(self) -> KVConnectorMetadata: + def _get_connector_metadata(self) -> Optional[KVConnectorMetadata]: """Get the connector metadata. This function should only be called inside the connector. diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 52812c5859fa..95ba45147fd8 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -250,28 +250,24 @@ def _aggregate_workers_output( self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput: # aggregate finished_sending, finished_recving from all workers - finished_sending = set[str]() - finished_recving = set[str]() - for output in outputs: - # update finished_sending - for req_id in output.finished_sending or []: - new_count = self._send_remaining_count[req_id] - 1 + def update_finished_set(req_ids: Optional[set[str]], + remaining_count_dict: dict[str, int], + finished_set: set[str]) -> None: + for req_id in req_ids or (): + new_count = remaining_count_dict[req_id] - 1 if new_count == 0: - # got response from all workers, report back to scheduler - finished_sending.add(req_id) - del self._send_remaining_count[req_id] + finished_set.add(req_id) + del remaining_count_dict[req_id] else: - self._send_remaining_count[req_id] = new_count + remaining_count_dict[req_id] = new_count - # update finished_recving - for req_id in output.finished_recving or []: - new_count = self._recv_remaining_count[req_id] - 1 - if new_count == 0: - # got response from all workers, report back to scheduler - finished_recving.add(req_id) - del self._recv_remaining_count[req_id] - else: - self._recv_remaining_count[req_id] = new_count + finished_sending = set[str]() + finished_recving = set[str]() + for output in outputs: + update_finished_set(output.finished_sending, + self._send_remaining_count, finished_sending) + update_finished_set(output.finished_recving, + self._recv_remaining_count, finished_recving) # select output of the worker specified by output_rank output = outputs[self.output_rank] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9cda4dbb9615..e264285859e4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1539,10 +1539,6 @@ def execute_model( attn_metadata, ) - # Clear KVConnector state after all KVs are generated. - if has_kv_transfer_group(): - get_kv_transfer_group().clear_connector_metadata() - self.eplb_step() return ModelRunnerOutput( diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 6b30acee1d90..3c764bcdcb21 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -338,6 +338,10 @@ def execute_model( output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) output.finished_sending = finished_sending output.finished_recving = finished_recving + + # Clear KVConnector state for this step. + get_kv_transfer_group().clear_connector_metadata() + # with a connector, the scheduler expects output from all workers return output From 8b171ccddda3973c16ea551bb0591065e8bbb6ab Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 10 Jul 2025 14:24:49 +0100 Subject: [PATCH 2/2] fix typing Signed-off-by: Nick Hill Co-authored-by: David Ben-David Signed-off-by: Nick Hill --- vllm/distributed/kv_transfer/kv_connector/v1/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 66ab9077cb94..9459ab27aba3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -104,7 +104,7 @@ def clear_connector_metadata(self) -> None: """ self._connector_metadata = None - def _get_connector_metadata(self) -> Optional[KVConnectorMetadata]: + def _get_connector_metadata(self) -> KVConnectorMetadata: """Get the connector metadata. This function should only be called inside the connector. @@ -112,6 +112,9 @@ def _get_connector_metadata(self) -> Optional[KVConnectorMetadata]: Returns: ConnectorMetadata: the connector metadata. """ + + # Should only be called while set to valid metadata. + assert self._connector_metadata is not None return self._connector_metadata def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):