diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 673f7fe1dcc0..c377645ced49 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -440,56 +440,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
                 "Rank %s, get_finished: %s requests done sending "
                 "and %s requests done recving", self.rank,
                 len(done_sending), len(done_recving))
-
-        if self.world_size == 1:
-            return done_sending, done_recving
-
-        # In TP>1 setup, each rank exchanges KVs with its counterpart
-        # ranks independently. get_finished() runs in a worker creates
-        # the done_sending and done_recving sets that are sent to the
-        # scheduler via ModelRunnerOutput by Rank 0. To avoid race
-        # ensure trnxs are done before adding to finished, Ranks 1 to
-        # N-1 communicate to Rank 0 once their transaction is done.
-        # Rank 0 only returns finished once all ranks are complete.
-        if self.rank == 0:
-            for req_id in done_sending:
-                self._done_sending_count[req_id] += 1
-            for req_id in done_recving:
-                self._done_recving_count[req_id] += 1
-
-            # Update the counts of how many ranks have finished.
-            # Get notifies from other ranks that txns are done.
-            other_ranks_finished_ids: list[str] = []
-            for i in range(1, self.world_size):
-                other_ranks_finished_ids.extend(
-                    self.tp_group.recv_object(src=i))
-            for req_id in other_ranks_finished_ids:
-                if (req_id in self._done_recving_count
-                        or req_id in self._recving_transfers):
-                    self._done_recving_count[req_id] += 1
-                else:
-                    self._done_sending_count[req_id] += 1
-
-            # Return ids that have finished on all ranks to the scheduler.
-            all_done_sending: set[str] = set()
-            all_done_recving: set[str] = set()
-            for req_id in list(self._done_recving_count.keys()):
-                if self._done_recving_count[req_id] == self.world_size:
-                    self._done_recving_count.pop(req_id)
-                    all_done_recving.add(req_id)
-            for req_id in list(self._done_sending_count.keys()):
-                if self._done_sending_count[req_id] == self.world_size:
-                    self._done_sending_count.pop(req_id)
-                    all_done_sending.add(req_id)
-
-            return all_done_sending, all_done_recving
-
-        else:
-            finished_req_ids = list(done_recving.union(done_sending))
-            self.tp_group.send_object(finished_req_ids, dst=0)
-
-            # NOTE(rob): unused as only Rank 0 sends to sched.
-            return done_sending, done_recving
+        return done_sending, done_recving
 
     def _get_new_notifs(self) -> set[str]:
         """Get req_ids which got a remote xfer message."""
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index cb125bf4bf17..17c27b941bd3 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -8,6 +8,7 @@
 import time
 import traceback
 import weakref
+from collections import defaultdict
 from concurrent.futures import Future
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -61,6 +62,12 @@ def _init_executor(self) -> None:
         # Set multiprocessing envs that are common to V0 and V1
         set_multiprocessing_worker_envs(self.parallel_config)
 
+        # Track requests finishing in each worker
+        self._done_sending_count: dict[str, int] = defaultdict(
+            lambda: tensor_parallel_size)
+        self._done_recving_count: dict[str, int] = defaultdict(
+            lambda: tensor_parallel_size)
+
         # Multiprocessing-based executor does not support multi-node setting.
         # Since it only works for single node, we can use the loopback address
         # 127.0.0.1 for communication.
@@ -139,7 +146,7 @@ def register_failure_callback(self, callback: FailureCallback):
             else:
                 self.failure_callback = callback
 
-    def execute_model(
+    def execute_model_non_pdisagg(
         self,
         scheduler_output,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
@@ -149,6 +156,36 @@ def execute_model(
             timeout=EXECUTE_MODEL_TIMEOUT_S)
         return output
 
+    def execute_model(
+        self,
+        scheduler_output,
+    ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
+        outputs: list[ModelRunnerOutput] = self.collective_rpc(
+            "execute_model",
+            args=(scheduler_output, ),
+            timeout=EXECUTE_MODEL_TIMEOUT_S)
+        rank0_outputs = outputs[0]
+        if len(outputs) == 1:
+            return rank0_outputs
+
+        done_sending, done_recving = set(), set()
+        for output in outputs:
+            for req_id in output.finished_sending or ():
+                if self._done_sending_count[req_id] == 1:
+                    del self._done_sending_count[req_id]
+                    done_sending.add(req_id)
+                else:
+                    self._done_sending_count[req_id] -= 1
+            for req_id in output.finished_recving or ():
+                if self._done_recving_count[req_id] == 1:
+                    del self._done_recving_count[req_id]
+                    done_recving.add(req_id)
+                else:
+                    self._done_recving_count[req_id] -= 1
+        rank0_outputs.finished_recving = done_recving
+        rank0_outputs.finished_sending = done_sending
+        return rank0_outputs
+
     def collective_rpc(self,
                        method: Union[str, Callable],
                        timeout: Optional[float] = None,
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 68c4e94fcd73..44e95cd19d99 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """A GPU worker class."""
+import copy
 import gc
 import os
 from typing import TYPE_CHECKING, Optional
@@ -7,6 +8,7 @@
 import torch
 import torch.distributed
 import torch.nn as nn
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
 
 import vllm.envs as envs
 from vllm.config import VllmConfig
@@ -266,7 +268,13 @@ def execute_model(
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
         output = self.model_runner.execute_model(scheduler_output)
-        return output if self.is_driver_worker else None
+        if not self.is_driver_worker:
+            # No need to transfer entire output for non-zero ranks.
+            new_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+            new_output.finished_recving = output.finished_recving
+            new_output.finished_sending = output.finished_sending
+            output = new_output
+        return output
 
     def profile(self, is_start: bool = True):
         if self.profiler is None:
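The diff above replaces the connector-side rank synchronization (ranks 1..N-1 pushing their finished request ids to rank 0 over the TP group) with aggregation inside MultiprocExecutor.execute_model(): every worker now returns its own finished_sending/finished_recving sets, and the executor releases a request id to the scheduler only after every tensor-parallel rank has reported it. Below is a minimal standalone sketch of that counting scheme under the same assumptions; the FinishedAggregator class and its names are illustrative only, not vLLM code.

    # Sketch of the executor-side aggregation over per-rank "finished" sets.
    # FinishedAggregator is hypothetical and exists only for illustration.
    from collections import defaultdict


    class FinishedAggregator:
        def __init__(self, world_size: int):
            # Each request id must be reported by `world_size` ranks before
            # it can be released to the scheduler.
            self._pending: dict[str, int] = defaultdict(lambda: world_size)

        def aggregate(self, per_rank_finished: list[set[str]]) -> set[str]:
            released: set[str] = set()
            for finished in per_rank_finished:
                for req_id in finished:
                    if self._pending[req_id] == 1:
                        # Last outstanding rank just reported: release it.
                        del self._pending[req_id]
                        released.add(req_id)
                    else:
                        self._pending[req_id] -= 1
            return released


    agg = FinishedAggregator(world_size=2)
    assert agg.aggregate([{"req-1"}, set()]) == set()       # only rank 0 done
    assert agg.aggregate([set(), {"req-1"}]) == {"req-1"}   # rank 1 done later

Compared with the removed connector-side path, the per-step collective_rpc result is now the only cross-rank communication, avoiding the extra recv_object/send_object round trip on the TP group.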