Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class KVConnectorRole(enum.Enum):
WORKER = 1


class KVConnectorMetadata:
class KVConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
Scheduler KVConnector and Worker KVConnector.
Expand All @@ -71,7 +71,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
logger.warning(
"Initializing KVConnectorBase_V1. This API is experimental and "
"subject to change in the future as we iterate the design.")
self._connector_metadata = KVConnectorMetadata()
self._connector_metadata: Optional[KVConnectorMetadata] = None
self._vllm_config = vllm_config
self._role = role

Expand Down Expand Up @@ -102,7 +102,7 @@ def clear_connector_metadata(self) -> None:
This function should be called by the model runner every time
after the model execution.
"""
self._connector_metadata = KVConnectorMetadata()
self._connector_metadata = None

def _get_connector_metadata(self) -> KVConnectorMetadata:
"""Get the connector metadata.
Expand All @@ -112,6 +112,9 @@ def _get_connector_metadata(self) -> KVConnectorMetadata:
Returns:
ConnectorMetadata: the connector metadata.
"""

# Should only be called while set to valid metadata.
assert self._connector_metadata is not None
return self._connector_metadata

def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
Expand Down
34 changes: 15 additions & 19 deletions vllm/v1/executor/multiproc_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,28 +250,24 @@ def _aggregate_workers_output(
self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
# aggregate finished_sending, finished_recving from all workers

finished_sending = set[str]()
finished_recving = set[str]()
for output in outputs:
# update finished_sending
for req_id in output.finished_sending or []:
new_count = self._send_remaining_count[req_id] - 1
def update_finished_set(req_ids: Optional[set[str]],
remaining_count_dict: dict[str, int],
finished_set: set[str]) -> None:
for req_id in req_ids or ():
new_count = remaining_count_dict[req_id] - 1
if new_count == 0:
# got response from all workers, report back to scheduler
finished_sending.add(req_id)
del self._send_remaining_count[req_id]
finished_set.add(req_id)
del remaining_count_dict[req_id]
else:
self._send_remaining_count[req_id] = new_count
remaining_count_dict[req_id] = new_count

# update finished_recving
for req_id in output.finished_recving or []:
new_count = self._recv_remaining_count[req_id] - 1
if new_count == 0:
# got response from all workers, report back to scheduler
finished_recving.add(req_id)
del self._recv_remaining_count[req_id]
else:
self._recv_remaining_count[req_id] = new_count
finished_sending = set[str]()
finished_recving = set[str]()
for output in outputs:
update_finished_set(output.finished_sending,
self._send_remaining_count, finished_sending)
update_finished_set(output.finished_recving,
self._recv_remaining_count, finished_recving)

# select output of the worker specified by output_rank
output = outputs[self.output_rank]
Expand Down
4 changes: 0 additions & 4 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,10 +1539,6 @@ def execute_model(
attn_metadata,
)

# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()

self.eplb_step()

return ModelRunnerOutput(
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,10 @@ def execute_model(
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving

# Clear KVConnector state for this step.
get_kv_transfer_group().clear_connector_metadata()

# with a connector, the scheduler expects output from all workers
return output

Expand Down