From 951096e2908a3437a1a18e47d6134a475bf185d6 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Wed, 16 Jul 2025 11:35:21 -0700 Subject: [PATCH 01/13] wip Signed-off-by: Kourosh Hakhamaneshi --- vllm/executor/ray_utils.py | 19 +++ vllm/v1/executor/ray_distributed_executor.py | 118 +++++++++++++++++-- 2 files changed, 129 insertions(+), 8 deletions(-) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index c222f1609096..6524e98c52c5 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -17,6 +17,12 @@ from vllm.utils import get_ip from vllm.worker.worker_base import WorkerWrapperBase + +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group) +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT +import copy + if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput @@ -134,6 +140,19 @@ def execute_model_ray( scheduler_output, intermediate_tensors = scheduler_output, None output = self.worker.model_runner.execute_model( scheduler_output, intermediate_tensors) + + logger.info(f"in the ray_utils.py ...") + if has_kv_transfer_group(): + finished_sending, finished_recving = ( + get_kv_transfer_group().get_finished( + scheduler_output.finished_req_ids)) + if finished_sending or finished_recving: + if not output or output is EMPTY_MODEL_RUNNER_OUTPUT: + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.finished_sending = finished_sending + output.finished_recving = finished_recving + logger.info(f"Have succesfully set finished_sending: {finished_sending}, finished_recving: {finished_recving}") + if isinstance(output, IntermediateTensors): output = scheduler_output, output return output diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index daca7c0faf66..f910c6993be7 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -1,14 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from concurrent.futures import Future -from typing import Union +from collections import defaultdict +from concurrent.futures import CancelledError, Future +from typing import Optional, Sequence, Union, cast from vllm.executor.ray_distributed_executor import ( # noqa RayDistributedExecutor as RayDistributedExecutorV0) from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import ModelRunnerOutput +from vllm.logger import init_logger, logger +init_logger("vllm") class FutureWrapper(Future): """A wrapper around a Ray output reference to meet the interface @@ -28,6 +31,19 @@ def result(self, timeout=None): class RayDistributedExecutor(RayDistributedExecutorV0, Executor): """Ray distributed executor using Ray Compiled Graphs.""" + def _init_executor(self) -> None: + super()._init_executor() + + # KV connector setup + self.has_connector = self.vllm_config.kv_transfer_config is not None + + # Complete transfer tracker. Used by to track finished requests + # [req_id -> n_finished_workers] + self._recv_remaining_count = defaultdict[str, + int](lambda: self.parallel_config.world_size) + self._send_remaining_count = defaultdict[str, + int](lambda: self.parallel_config.world_size) + @property def max_concurrent_batches(self) -> int: """Ray distributed executor supports pipeline parallelism, @@ -37,6 +53,80 @@ def max_concurrent_batches(self) -> int: return 2 return self.parallel_config.pipeline_parallel_size + def _aggregate_workers_output( + self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput: + # aggregate finished_sending, finished_recving from all workers + + finished_sending = set[str]() + finished_recving = set[str]() + for output in outputs: + # update finished_sending + for req_id in output.finished_sending or []: + new_count = self._send_remaining_count[req_id] - 1 + if new_count == 0: + # got response from all workers, report back to scheduler + finished_sending.add(req_id) + del self._send_remaining_count[req_id] + else: + self._send_remaining_count[req_id] = new_count + + # update finished_recving + for req_id in output.finished_recving or []: + new_count = self._recv_remaining_count[req_id] - 1 + if new_count == 0: + # got response from all workers, report back to scheduler + finished_recving.add(req_id) + del self._recv_remaining_count[req_id] + else: + self._recv_remaining_count[req_id] = new_count + + # select output of the worker specified by output_rank + output = outputs[0] + + # set the aggregated finished_sending / finished_recving + if finished_sending: + output.finished_sending = finished_sending + if finished_recving: + output.finished_recving = finished_recving + + return output + + def _async_aggregate_workers_output( + self, output_futures: Sequence[Union[Future[ModelRunnerOutput], FutureWrapper]] + ) -> Future[ModelRunnerOutput]: + """Takes a list of futures and returns a single future which resolves + to the respective list of outputs.""" + result_future: Future[ModelRunnerOutput] = Future() + + outputs: list[Optional[ModelRunnerOutput]] = [None + ] * len(output_futures) + + def make_callback(idx): + + def callback(fut): + if result_future.done(): + return + + try: + outputs[idx] = fut.result() + except CancelledError: + result_future.cancel() + except Exception as e: + result_future.set_exception(e) + + # Check if all outputs are ready + if all(outputs): + result_future.set_result( + self._aggregate_workers_output( + cast(list[ModelRunnerOutput], outputs))) + + return callback + + for i, output_future in enumerate(output_futures): + output_future.add_done_callback(make_callback(i)) + + return result_future + def execute_model( self, scheduler_output, @@ -55,10 +145,22 @@ def execute_model( refs = self.forward_dag.execute(scheduler_output) # type: ignore - # When PP is not used, we block here until the result is available. - if self.max_concurrent_batches == 1: - return refs[0].get() + if not self.has_connector: + # get output only from a single worker (output_rank) + # When PP is not used, we block here until the result is available. + if self.max_concurrent_batches == 1: + return refs[0].get() - # When PP is used, we return a FutureWrapper immediately so that - # the scheduler can yield to the next batch. - return FutureWrapper(refs[0]) + # When PP is used, we return a FutureWrapper immediately so that + # the scheduler can yield to the next batch. + return FutureWrapper(refs[0]) + + # get output from all workers when connector is present + if self.max_concurrent_batches == 1: + # Block and get results from all workers + outputs = [ref.get() for ref in refs] + return self._aggregate_workers_output(outputs) + else: + # Return a future that will aggregate outputs from all workers + output_futures = [FutureWrapper(ref) for ref in refs] + return self._async_aggregate_workers_output(output_futures) From b629b86b68d66d90d3cc0fee3ff122bcd51be8bd Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Wed, 16 Jul 2025 12:00:34 -0700 Subject: [PATCH 02/13] wip Signed-off-by: Kourosh Hakhamaneshi --- vllm/executor/ray_utils.py | 18 ------------ vllm/v1/worker/gpu_model_runner.py | 44 ++++++++++++++++++++++++++---- vllm/v1/worker/gpu_worker.py | 16 ----------- 3 files changed, 38 insertions(+), 40 deletions(-) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 6524e98c52c5..433af4cd4e52 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -18,11 +18,6 @@ from vllm.worker.worker_base import WorkerWrapperBase -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group) -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT -import copy - if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput @@ -140,19 +135,6 @@ def execute_model_ray( scheduler_output, intermediate_tensors = scheduler_output, None output = self.worker.model_runner.execute_model( scheduler_output, intermediate_tensors) - - logger.info(f"in the ray_utils.py ...") - if has_kv_transfer_group(): - finished_sending, finished_recving = ( - get_kv_transfer_group().get_finished( - scheduler_output.finished_req_ids)) - if finished_sending or finished_recving: - if not output or output is EMPTY_MODEL_RUNNER_OUTPUT: - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.finished_sending = finished_sending - output.finished_recving = finished_recving - logger.info(f"Have succesfully set finished_sending: {finished_sending}, finished_recving: {finished_recving}") - if isinstance(output, IntermediateTensors): output = scheduler_output, output return output diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index af216539c900..047ce6b5a1fd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy import gc import time import weakref @@ -1233,6 +1234,8 @@ def _pool( hidden_states: torch.Tensor, num_scheduled_tokens: int, num_scheduled_tokens_np: np.ndarray, + finished_sending: Optional[set[str]], + finished_recving: Optional[set[str]], ) -> ModelRunnerOutput: assert self.input_batch.num_reqs ==\ len(self.input_batch.pooling_params), \ @@ -1267,6 +1270,8 @@ def _pool( logprobs=None, prompt_logprobs_dict={}, pooler_output=pooler_output, + finished_sending=finished_sending, + finished_recving=finished_recving, ) @torch.inference_mode() @@ -1277,12 +1282,12 @@ def execute_model( ) -> Union[ModelRunnerOutput, IntermediateTensors]: self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: - if has_kv_transfer_group(): - with set_forward_context(None, self.vllm_config): - self.maybe_setup_kv_connector(scheduler_output) + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + + return self.kv_connector_no_forward(scheduler_output) - # Return empty ModelRunnerOutput if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT # Prepare the decoder inputs. (attn_metadata, attention_cuda_graphs, logits_indices, @@ -1375,6 +1380,8 @@ def execute_model( ) self.maybe_wait_for_kv_save() + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) if self.use_aux_hidden_state_outputs: hidden_states, aux_hidden_states = model_output @@ -1400,7 +1407,7 @@ def execute_model( else: if self.input_batch.pooling_params: return self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np) + num_scheduled_tokens_np, finished_sending, finished_recving) sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) @@ -1694,6 +1701,31 @@ def maybe_wait_for_kv_save() -> None: if has_kv_transfer_group(): get_kv_transfer_group().wait_for_save() + @staticmethod + def get_finished_kv_transfers( + scheduler_output: "SchedulerOutput", + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + if has_kv_transfer_group(): + return get_kv_transfer_group().get_finished( + scheduler_output.finished_req_ids) + return None, None + + def kv_connector_no_forward( + self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: + # KV send/recv even if no work to do. + with set_forward_context(None, self.vllm_config): + self.maybe_setup_kv_connector(scheduler_output) + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) + + if not finished_sending and not finished_recving: + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.finished_sending = finished_sending + output.finished_recving = finished_recving + return output + def propose_ngram_draft_token_ids( self, sampled_token_ids: list[list[int]], diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 6458b55777a4..5fab2a7382dc 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -332,22 +332,6 @@ def execute_model( output = EMPTY_MODEL_RUNNER_OUTPUT assert isinstance(output, ModelRunnerOutput) - if has_kv_transfer_group(): - finished_sending, finished_recving = ( - get_kv_transfer_group().get_finished( - scheduler_output.finished_req_ids)) - if finished_sending or finished_recving: - if output is EMPTY_MODEL_RUNNER_OUTPUT: - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.finished_sending = finished_sending - output.finished_recving = finished_recving - - # Clear KVConnector state for this step. - get_kv_transfer_group().clear_connector_metadata() - - # with a connector, the scheduler expects output from all workers - return output - # return output only from the driver worker return output if self.is_driver_worker else None From c0f9c927a2eec840472b1bc33216ecc945ce9e86 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Wed, 16 Jul 2025 14:27:55 -0700 Subject: [PATCH 03/13] wip Signed-off-by: Kourosh Hakhamaneshi --- .../kv_transfer/kv_connector/utils.py | 91 +++++++++++++++ .../kv_transfer/kv_connector/v1/base.py | 2 +- vllm/v1/executor/multiproc_executor.py | 86 +------------- vllm/v1/executor/ray_distributed_executor.py | 110 ++++-------------- 4 files changed, 119 insertions(+), 170 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 5cbc8ca31752..9fb5da0ffc1e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -3,12 +3,16 @@ """ KV cache helper for store. """ +from concurrent.futures import CancelledError, Future +from collections import defaultdict +from typing import Sequence, Optional, cast import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger +from vllm.v1.outputs import ModelRunnerOutput logger = init_logger(__name__) @@ -107,3 +111,90 @@ def get_kv_connector_cache_layout(): "layout to HND for better xfer performance.") return "HND" return "NHD" + + + + +class KVOutputAggregator: + """Utility class to aggregate the output of all workers into a single output corresponding to Rank 0 for scheduler.""" + + def __init__(self, world_size: int): + self.world_size = world_size + # Complete transfer tracker. Used by to track finished requests + # [req_id -> n_finished_workers] + self._recv_remaining_count = defaultdict[str, + int](lambda: world_size) + self._send_remaining_count = defaultdict[str, + int](lambda: world_size) + + def aggregate(self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput: + # aggregate finished_sending, finished_recving from all workers + + finished_sending = set[str]() + finished_recving = set[str]() + for output in outputs: + # update finished_sending + for req_id in output.finished_sending or []: + new_count = self._send_remaining_count[req_id] - 1 + if new_count == 0: + # got response from all workers, report back to scheduler + finished_sending.add(req_id) + del self._send_remaining_count[req_id] + else: + self._send_remaining_count[req_id] = new_count + + # update finished_recving + for req_id in output.finished_recving or []: + new_count = self._recv_remaining_count[req_id] - 1 + if new_count == 0: + # got response from all workers, report back to scheduler + finished_recving.add(req_id) + del self._recv_remaining_count[req_id] + else: + self._recv_remaining_count[req_id] = new_count + + # select output of the worker specified by output_rank + output = outputs[0] + + # set the aggregated finished_sending / finished_recving + if finished_sending: + output.finished_sending = finished_sending + if finished_recving: + output.finished_recving = finished_recving + + return output + + + def async_aggregate(self, output_futures: Sequence[Future[ModelRunnerOutput]]) -> Future[ModelRunnerOutput]: + """Takes a list of futures and returns a single future which resolves + to the respective list of outputs.""" + result_future: Future[ModelRunnerOutput] = Future() + + outputs: list[Optional[ModelRunnerOutput]] = [None + ] * len(output_futures) + + def make_callback(idx): + + def callback(fut): + if result_future.done(): + return + + try: + outputs[idx] = fut.result() + except CancelledError: + result_future.cancel() + except Exception as e: + result_future.set_exception(e) + + # Check if all outputs are ready + if all(outputs): + result_future.set_result( + self.aggregate( + cast(list[ModelRunnerOutput], outputs))) + + return callback + + for i, output_future in enumerate(output_futures): + output_future.add_done_callback(make_callback(i)) + + return result_future \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 9459ab27aba3..e1245775bea3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -194,7 +194,7 @@ def get_finished( """ Notifies worker-side connector ids of requests that have finished generating tokens on the worker. - The scheduler process (via the MultiprocExecutor) will use this output + The scheduler process (via the Executors) will use this output to track which workers are done. Returns: diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index d29da55ce885..c90af3daccd8 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -9,8 +9,7 @@ import time import traceback import weakref -from collections import defaultdict -from concurrent.futures import CancelledError, Future, ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass from enum import Enum, auto from functools import partial @@ -35,6 +34,7 @@ from vllm.v1.executor.abstract import Executor, FailureCallback from vllm.v1.outputs import ModelRunnerOutput from vllm.worker.worker_base import WorkerWrapperBase +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator logger = init_logger(__name__) @@ -118,13 +118,7 @@ def _init_executor(self) -> None: self.output_rank = self._get_output_rank() self.has_connector = self.vllm_config.kv_transfer_config is not None - - # Complete transfer tracker. Used by to track finished requests - # [req_id -> n_finished_workers] - self._recv_remaining_count = defaultdict[str, - int](lambda: self.world_size) - self._send_remaining_count = defaultdict[str, - int](lambda: self.world_size) + self.kv_output_aggregator = KVOutputAggregator(self.parallel_config.world_size) def start_worker_monitor(self): workers = self.workers @@ -186,8 +180,8 @@ def execute_model( # aggregate all workers output to a single output if non_block: - return self._async_aggregate_workers_output(outputs) - return self._aggregate_workers_output(outputs) + return self.kv_output_aggregator.async_aggregate(outputs) + return self.kv_output_aggregator.aggregate(outputs) def collective_rpc(self, method: Union[str, Callable], @@ -246,76 +240,6 @@ def get_response(w: WorkerProcHandle, except TimeoutError as e: raise TimeoutError(f"RPC call to {method} timed out.") from e - def _aggregate_workers_output( - self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput: - # aggregate finished_sending, finished_recving from all workers - - def update_finished_set(req_ids: Optional[set[str]], - remaining_count_dict: dict[str, int], - finished_set: set[str]) -> None: - for req_id in req_ids or (): - new_count = remaining_count_dict[req_id] - 1 - if new_count == 0: - finished_set.add(req_id) - del remaining_count_dict[req_id] - else: - remaining_count_dict[req_id] = new_count - - finished_sending = set[str]() - finished_recving = set[str]() - for output in outputs: - update_finished_set(output.finished_sending, - self._send_remaining_count, finished_sending) - update_finished_set(output.finished_recving, - self._recv_remaining_count, finished_recving) - - # select output of the worker specified by output_rank - output = outputs[self.output_rank] - - # set the aggregated finished_sending / finished_recving - if finished_sending: - output.finished_sending = finished_sending - if finished_recving: - output.finished_recving = finished_recving - - return output - - def _async_aggregate_workers_output( - self, output_futures: list[Future[ModelRunnerOutput]] - ) -> (Future[ModelRunnerOutput]): - """Takes a list of futures and returns a single future which resolves - to the respective list of outputs.""" - result_future: Future[ModelRunnerOutput] = Future() - - outputs: list[Optional[ModelRunnerOutput]] = [None - ] * len(output_futures) - - def make_callback(idx): - - def callback(fut): - if result_future.done(): - return - - try: - outputs[idx] = fut.result() - except CancelledError: - result_future.cancel() - except Exception as e: - result_future.set_exception(e) - - # this check assumes io_thread_pool uses a single thread - if all(outputs): - result_future.set_result( - self._aggregate_workers_output( - cast(list[ModelRunnerOutput], outputs))) - - return callback - - for i, output_future in enumerate(output_futures): - output_future.add_done_callback(make_callback(i)) - - return result_future - @staticmethod def _ensure_worker_termination(worker_procs: list[BaseProcess]): """Ensure that all worker processes are terminated. Assumes workers have diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index f910c6993be7..42716c1af0c0 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -1,17 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import defaultdict -from concurrent.futures import CancelledError, Future -from typing import Optional, Sequence, Union, cast +from concurrent.futures import Future +from typing import Union from vllm.executor.ray_distributed_executor import ( # noqa RayDistributedExecutor as RayDistributedExecutorV0) from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import ModelRunnerOutput -from vllm.logger import init_logger, logger +from vllm.logger import init_logger +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator -init_logger("vllm") +logger = init_logger(__name__) class FutureWrapper(Future): """A wrapper around a Ray output reference to meet the interface @@ -27,6 +27,20 @@ def result(self, timeout=None): raise NotImplementedError("timeout is not supported") return self.ref.get() +class FutureWrapperWithAggregation(Future): + def __init__(self, refs, aggregator: KVOutputAggregator): + super().__init__() + self.refs = refs + self.aggregator = aggregator + + def result(self, timeout=None): + if timeout is not None: + raise NotImplementedError("timeout is not supported") + + # get all refs and aggregate the outputs and return the first one + outputs = [ref.get() for ref in self.refs] + return self.aggregator.aggregate(outputs) + class RayDistributedExecutor(RayDistributedExecutorV0, Executor): """Ray distributed executor using Ray Compiled Graphs.""" @@ -36,13 +50,7 @@ def _init_executor(self) -> None: # KV connector setup self.has_connector = self.vllm_config.kv_transfer_config is not None - - # Complete transfer tracker. Used by to track finished requests - # [req_id -> n_finished_workers] - self._recv_remaining_count = defaultdict[str, - int](lambda: self.parallel_config.world_size) - self._send_remaining_count = defaultdict[str, - int](lambda: self.parallel_config.world_size) + self.kv_output_aggregator = KVOutputAggregator(self.parallel_config.world_size) @property def max_concurrent_batches(self) -> int: @@ -53,79 +61,6 @@ def max_concurrent_batches(self) -> int: return 2 return self.parallel_config.pipeline_parallel_size - def _aggregate_workers_output( - self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput: - # aggregate finished_sending, finished_recving from all workers - - finished_sending = set[str]() - finished_recving = set[str]() - for output in outputs: - # update finished_sending - for req_id in output.finished_sending or []: - new_count = self._send_remaining_count[req_id] - 1 - if new_count == 0: - # got response from all workers, report back to scheduler - finished_sending.add(req_id) - del self._send_remaining_count[req_id] - else: - self._send_remaining_count[req_id] = new_count - - # update finished_recving - for req_id in output.finished_recving or []: - new_count = self._recv_remaining_count[req_id] - 1 - if new_count == 0: - # got response from all workers, report back to scheduler - finished_recving.add(req_id) - del self._recv_remaining_count[req_id] - else: - self._recv_remaining_count[req_id] = new_count - - # select output of the worker specified by output_rank - output = outputs[0] - - # set the aggregated finished_sending / finished_recving - if finished_sending: - output.finished_sending = finished_sending - if finished_recving: - output.finished_recving = finished_recving - - return output - - def _async_aggregate_workers_output( - self, output_futures: Sequence[Union[Future[ModelRunnerOutput], FutureWrapper]] - ) -> Future[ModelRunnerOutput]: - """Takes a list of futures and returns a single future which resolves - to the respective list of outputs.""" - result_future: Future[ModelRunnerOutput] = Future() - - outputs: list[Optional[ModelRunnerOutput]] = [None - ] * len(output_futures) - - def make_callback(idx): - - def callback(fut): - if result_future.done(): - return - - try: - outputs[idx] = fut.result() - except CancelledError: - result_future.cancel() - except Exception as e: - result_future.set_exception(e) - - # Check if all outputs are ready - if all(outputs): - result_future.set_result( - self._aggregate_workers_output( - cast(list[ModelRunnerOutput], outputs))) - - return callback - - for i, output_future in enumerate(output_futures): - output_future.add_done_callback(make_callback(i)) - - return result_future def execute_model( self, @@ -159,8 +94,7 @@ def execute_model( if self.max_concurrent_batches == 1: # Block and get results from all workers outputs = [ref.get() for ref in refs] - return self._aggregate_workers_output(outputs) + return self.kv_output_aggregator.aggregate(outputs) else: # Return a future that will aggregate outputs from all workers - output_futures = [FutureWrapper(ref) for ref in refs] - return self._async_aggregate_workers_output(output_futures) + return FutureWrapperWithAggregation(refs, self.kv_output_aggregator) From 80d861e8a78d176e009e0449c6ca5c4a9d88d642 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Wed, 16 Jul 2025 14:38:11 -0700 Subject: [PATCH 04/13] wip Signed-off-by: Kourosh Hakhamaneshi --- vllm/v1/executor/ray_distributed_executor.py | 51 +++++++++----------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index 42716c1af0c0..32a8aa573319 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -1,45 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from concurrent.futures import Future -from typing import Union +from concurrent.futures import Future +from typing import Optional, Union +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.executor.ray_distributed_executor import ( # noqa RayDistributedExecutor as RayDistributedExecutorV0) +from vllm.logger import init_logger from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import ModelRunnerOutput -from vllm.logger import init_logger -from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator logger = init_logger(__name__) + class FutureWrapper(Future): """A wrapper around a Ray output reference to meet the interface - of .execute_model(). + of .execute_model(): The top level (core busy loop) expects .result() api + to block and return a single output. + + If aggregator is provided, the outputs from all workers are aggregated upon + the result() call. If not only the first worker's output is returned. """ - def __init__(self, ref): - super().__init__() - self.ref = ref - - def result(self, timeout=None): - if timeout is not None: - raise NotImplementedError("timeout is not supported") - return self.ref.get() - -class FutureWrapperWithAggregation(Future): - def __init__(self, refs, aggregator: KVOutputAggregator): + def __init__(self, refs, aggregator: Optional[KVOutputAggregator] = None): super().__init__() self.refs = refs self.aggregator = aggregator - + def result(self, timeout=None): if timeout is not None: raise NotImplementedError("timeout is not supported") - # get all refs and aggregate the outputs and return the first one - outputs = [ref.get() for ref in self.refs] - return self.aggregator.aggregate(outputs) + if self.aggregator is None: + return self.refs[0].get() + else: + outputs = [ref.get() for ref in self.refs] + return self.aggregator.aggregate(outputs) class RayDistributedExecutor(RayDistributedExecutorV0, Executor): @@ -47,10 +44,11 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor): def _init_executor(self) -> None: super()._init_executor() - + # KV connector setup self.has_connector = self.vllm_config.kv_transfer_config is not None - self.kv_output_aggregator = KVOutputAggregator(self.parallel_config.world_size) + self.kv_output_aggregator = KVOutputAggregator( + self.parallel_config.world_size) @property def max_concurrent_batches(self) -> int: @@ -61,7 +59,6 @@ def max_concurrent_batches(self) -> int: return 2 return self.parallel_config.pipeline_parallel_size - def execute_model( self, scheduler_output, @@ -81,20 +78,20 @@ def execute_model( refs = self.forward_dag.execute(scheduler_output) # type: ignore if not self.has_connector: - # get output only from a single worker (output_rank) + # Get output only from a single worker (output_rank) # When PP is not used, we block here until the result is available. if self.max_concurrent_batches == 1: return refs[0].get() # When PP is used, we return a FutureWrapper immediately so that # the scheduler can yield to the next batch. - return FutureWrapper(refs[0]) + return FutureWrapper(refs) - # get output from all workers when connector is present + # Get output from all workers when connector is present if self.max_concurrent_batches == 1: # Block and get results from all workers outputs = [ref.get() for ref in refs] return self.kv_output_aggregator.aggregate(outputs) else: # Return a future that will aggregate outputs from all workers - return FutureWrapperWithAggregation(refs, self.kv_output_aggregator) + return FutureWrapper(refs, self.kv_output_aggregator) From 1c63f8ea1c4de68583807d4c72dcbffe94df3751 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Wed, 16 Jul 2025 14:52:53 -0700 Subject: [PATCH 05/13] wip Signed-off-by: Kourosh Hakhamaneshi --- .../kv_transfer/kv_connector/utils.py | 84 +++++++++---------- vllm/v1/executor/multiproc_executor.py | 10 ++- vllm/v1/executor/ray_distributed_executor.py | 8 +- 3 files changed, 50 insertions(+), 52 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 9fb5da0ffc1e..a06e3e94b58a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -3,16 +3,18 @@ """ KV cache helper for store. """ -from concurrent.futures import CancelledError, Future from collections import defaultdict -from typing import Sequence, Optional, cast +from collections.abc import Sequence +from concurrent.futures import CancelledError, Future +from typing import Optional, cast + import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import ModelRunnerOutput logger = init_logger(__name__) @@ -113,59 +115,53 @@ def get_kv_connector_cache_layout(): return "NHD" - - class KVOutputAggregator: - """Utility class to aggregate the output of all workers into a single output corresponding to Rank 0 for scheduler.""" - + """Utility class to aggregate the output of all workers into a single + output corresponding to Rank 0 for scheduler.""" + def __init__(self, world_size: int): self.world_size = world_size # Complete transfer tracker. Used by to track finished requests # [req_id -> n_finished_workers] - self._recv_remaining_count = defaultdict[str, - int](lambda: world_size) - self._send_remaining_count = defaultdict[str, - int](lambda: world_size) - - def aggregate(self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput: + self._recv_remaining_count = defaultdict[str, int](lambda: world_size) + self._send_remaining_count = defaultdict[str, int](lambda: world_size) + + def aggregate(self, + outputs: list[ModelRunnerOutput], + output_rank: int = 0) -> ModelRunnerOutput: # aggregate finished_sending, finished_recving from all workers - finished_sending = set[str]() - finished_recving = set[str]() - for output in outputs: - # update finished_sending - for req_id in output.finished_sending or []: - new_count = self._send_remaining_count[req_id] - 1 + def update_finished_set(req_ids: Optional[set[str]], + remaining_count_dict: dict[str, int], + finished_set: set[str]) -> None: + for req_id in req_ids or (): + new_count = remaining_count_dict[req_id] - 1 if new_count == 0: - # got response from all workers, report back to scheduler - finished_sending.add(req_id) - del self._send_remaining_count[req_id] + finished_set.add(req_id) + del remaining_count_dict[req_id] else: - self._send_remaining_count[req_id] = new_count + remaining_count_dict[req_id] = new_count - # update finished_recving - for req_id in output.finished_recving or []: - new_count = self._recv_remaining_count[req_id] - 1 - if new_count == 0: - # got response from all workers, report back to scheduler - finished_recving.add(req_id) - del self._recv_remaining_count[req_id] - else: - self._recv_remaining_count[req_id] = new_count + finished_sending = set[str]() + finished_recving = set[str]() + for output in outputs: + update_finished_set(output.finished_sending, + self._send_remaining_count, finished_sending) + update_finished_set(output.finished_recving, + self._recv_remaining_count, finished_recving) # select output of the worker specified by output_rank - output = outputs[0] + output = outputs[output_rank] # set the aggregated finished_sending / finished_recving - if finished_sending: - output.finished_sending = finished_sending - if finished_recving: - output.finished_recving = finished_recving + output.finished_sending = finished_sending if finished_sending else None + output.finished_recving = finished_recving if finished_recving else None return output - - - def async_aggregate(self, output_futures: Sequence[Future[ModelRunnerOutput]]) -> Future[ModelRunnerOutput]: + + def async_aggregate(self, + output_futures: Sequence[Future[ModelRunnerOutput]], + output_rank: int = 0) -> Future[ModelRunnerOutput]: """Takes a list of futures and returns a single future which resolves to the respective list of outputs.""" result_future: Future[ModelRunnerOutput] = Future() @@ -186,15 +182,15 @@ def callback(fut): except Exception as e: result_future.set_exception(e) - # Check if all outputs are ready + # this check assumes io_thread_pool uses a single thread if all(outputs): result_future.set_result( - self.aggregate( - cast(list[ModelRunnerOutput], outputs))) + self.aggregate(cast(list[ModelRunnerOutput], outputs), + output_rank)) return callback for i, output_future in enumerate(output_futures): output_future.add_done_callback(make_callback(i)) - return result_future \ No newline at end of file + return result_future diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index c90af3daccd8..87a607a341b5 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -26,6 +26,7 @@ destroy_model_parallel) from vllm.distributed.device_communicators.shm_broadcast import (Handle, MessageQueue) +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.executor.multiproc_worker_utils import ( _add_prefix, set_multiprocessing_worker_envs) from vllm.logger import init_logger @@ -34,7 +35,6 @@ from vllm.v1.executor.abstract import Executor, FailureCallback from vllm.v1.outputs import ModelRunnerOutput from vllm.worker.worker_base import WorkerWrapperBase -from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator logger = init_logger(__name__) @@ -118,7 +118,8 @@ def _init_executor(self) -> None: self.output_rank = self._get_output_rank() self.has_connector = self.vllm_config.kv_transfer_config is not None - self.kv_output_aggregator = KVOutputAggregator(self.parallel_config.world_size) + self.kv_output_aggregator = KVOutputAggregator( + self.parallel_config.world_size) def start_worker_monitor(self): workers = self.workers @@ -180,8 +181,9 @@ def execute_model( # aggregate all workers output to a single output if non_block: - return self.kv_output_aggregator.async_aggregate(outputs) - return self.kv_output_aggregator.aggregate(outputs) + return self.kv_output_aggregator.async_aggregate( + outputs, self.output_rank) + return self.kv_output_aggregator.aggregate(outputs, self.output_rank) def collective_rpc(self, method: Union[str, Callable], diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index 32a8aa573319..79d4a757e2e3 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -36,7 +36,7 @@ def result(self, timeout=None): return self.refs[0].get() else: outputs = [ref.get() for ref in self.refs] - return self.aggregator.aggregate(outputs) + return self.aggregator.aggregate(outputs, output_rank=0) class RayDistributedExecutor(RayDistributedExecutorV0, Executor): @@ -92,6 +92,6 @@ def execute_model( # Block and get results from all workers outputs = [ref.get() for ref in refs] return self.kv_output_aggregator.aggregate(outputs) - else: - # Return a future that will aggregate outputs from all workers - return FutureWrapper(refs, self.kv_output_aggregator) + + # Return a future that will aggregate outputs from all workers + return FutureWrapper(refs, self.kv_output_aggregator) From 913cd5229c6af5aefb9e010f59f1c03078ab157c Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Wed, 16 Jul 2025 15:11:54 -0700 Subject: [PATCH 06/13] wip Signed-off-by: Kourosh Hakhamaneshi --- .../unit/test_output_aggreagator.py | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 tests/v1/kv_connector/unit/test_output_aggreagator.py diff --git a/tests/v1/kv_connector/unit/test_output_aggreagator.py b/tests/v1/kv_connector/unit/test_output_aggreagator.py new file mode 100644 index 000000000000..cad73f68e9f1 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_output_aggreagator.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from concurrent.futures import Future +from typing import Optional + +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.v1.outputs import ModelRunnerOutput + + +class DummyModelRunnerOutput(ModelRunnerOutput): + + def __init__(self, + finished_sending: Optional[set[str]] = None, + finished_recving: Optional[set[str]] = None): + self.finished_sending = finished_sending + self.finished_recving = finished_recving + + +def test_aggregate_workers_output(): + aggregator = KVOutputAggregator(world_size=2) + + output1 = DummyModelRunnerOutput(finished_sending={'req1'}, + finished_recving={'req2'}) + output2 = DummyModelRunnerOutput(finished_sending=None, + finished_recving=None) + + aggregated = aggregator.aggregate([output1, output2]) + + assert aggregated is output1 + assert aggregated.finished_sending is None + assert aggregated.finished_recving is None + + output1 = DummyModelRunnerOutput(finished_sending=None, + finished_recving=None) + output2 = DummyModelRunnerOutput(finished_sending={'req1'}, + finished_recving=None) + + aggregated = aggregator.aggregate([output1, output2]) + + assert aggregated is output1 + assert aggregated.finished_sending == {'req1'} + assert aggregated.finished_recving is None + + output1 = DummyModelRunnerOutput(finished_sending=None, + finished_recving=None) + output2 = DummyModelRunnerOutput(finished_sending={'req1'}, + finished_recving={'req2'}) + + aggregated = aggregator.aggregate([output1, output2]) + + assert aggregated is output1 + assert aggregated.finished_sending is None + assert aggregated.finished_recving == {'req2'} + + +def test_async_aggregate_workers_output(): + aggregator = KVOutputAggregator(world_size=2) + + future1: Future[DummyModelRunnerOutput] = Future() + future2: Future[DummyModelRunnerOutput] = Future() + result_future = aggregator.async_aggregate([future1, future2]) + + output1 = DummyModelRunnerOutput(finished_sending={'req1'}, + finished_recving={'req2'}) + output2 = DummyModelRunnerOutput(finished_sending=None, + finished_recving=None) + future1.set_result(output1) + future2.set_result(output2) + + assert result_future.done() + aggregated = result_future.result() + assert aggregated is output1 + assert aggregated.finished_sending is None + assert aggregated.finished_recving is None + + future1 = Future() + future2 = Future() + result_future = aggregator.async_aggregate([future1, future2]) + + output1 = DummyModelRunnerOutput(finished_sending=None, + finished_recving=None) + output2 = DummyModelRunnerOutput(finished_sending={'req1'}, + finished_recving=None) + future1.set_result(output1) + future2.set_result(output2) + + assert result_future.done() + aggregated = result_future.result() + assert aggregated is output1 + assert aggregated.finished_sending == {'req1'} + assert aggregated.finished_recving is None + + future1 = Future() + future2 = Future() + result_future = aggregator.async_aggregate([future1, future2]) + + output1 = DummyModelRunnerOutput(finished_sending=None, + finished_recving=None) + output2 = DummyModelRunnerOutput(finished_sending={'req1'}, + finished_recving={'req2'}) + future1.set_result(output1) + future2.set_result(output2) + + assert result_future.done() + aggregated = result_future.result() + assert aggregated is output1 + assert aggregated.finished_sending is None + assert aggregated.finished_recving == {'req2'} From 9d4c583307f77db64fae21ef6534f08462eb5a7d Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Wed, 16 Jul 2025 15:15:06 -0700 Subject: [PATCH 07/13] wip Signed-off-by: Kourosh Hakhamaneshi --- vllm/executor/ray_utils.py | 1 - vllm/v1/worker/gpu_model_runner.py | 4 ++-- vllm/v1/worker/gpu_worker.py | 5 +---- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 433af4cd4e52..c222f1609096 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -17,7 +17,6 @@ from vllm.utils import get_ip from vllm.worker.worker_base import WorkerWrapperBase - if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 047ce6b5a1fd..7317552bf6af 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1288,7 +1288,6 @@ def execute_model( return self.kv_connector_no_forward(scheduler_output) - # Prepare the decoder inputs. (attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, @@ -1407,7 +1406,8 @@ def execute_model( else: if self.input_batch.pooling_params: return self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np, finished_sending, finished_recving) + num_scheduled_tokens_np, finished_sending, + finished_recving) sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 5fab2a7382dc..c375962ce5bb 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A GPU worker class.""" -import copy import gc import os from typing import TYPE_CHECKING, Any, Optional @@ -15,9 +14,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) -from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, - get_kv_transfer_group, - has_kv_transfer_group) +from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest From 2013ef6b77d6fb6aaa2df0264ba8e8e0d8925239 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Wed, 16 Jul 2025 22:23:53 -0700 Subject: [PATCH 08/13] wip Signed-off-by: Kourosh Hakhamaneshi --- vllm/v1/executor/ray_distributed_executor.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index 79d4a757e2e3..5249a41ed6c0 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -15,7 +15,7 @@ class FutureWrapper(Future): - """A wrapper around a Ray output reference to meet the interface + """A wrapper around Ray output reference to meet the interface of .execute_model(): The top level (core busy loop) expects .result() api to block and return a single output. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7317552bf6af..07ce01d5e2d6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1556,6 +1556,8 @@ def execute_model( logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[], + finished_sending=finished_sending, + finished_recving=finished_recving, num_nans_in_logits=num_nans_in_logits, ) From 7e4bf720d2f11d145785d73110f70071941558c3 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Thu, 17 Jul 2025 09:36:27 -0700 Subject: [PATCH 09/13] wip Signed-off-by: Kourosh Hakhamaneshi --- vllm/v1/worker/gpu_worker.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index c375962ce5bb..4de1a5e59b22 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A GPU worker class.""" +import copy import gc import os from typing import TYPE_CHECKING, Any, Optional @@ -326,7 +327,10 @@ def execute_model( assert isinstance(output, IntermediateTensors) get_pp_group().send_tensor_dict(output.tensors, all_gather_group=get_tp_group()) - output = EMPTY_MODEL_RUNNER_OUTPUT + empty_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + empty_output.finished_sending = output.finished_sending + empty_output.finished_recving = output.finished_recving + output = empty_output assert isinstance(output, ModelRunnerOutput) # return output only from the driver worker From 5e97ce6bab25bfde6704bcdfd307c7033e1fe29c Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Thu, 17 Jul 2025 09:42:48 -0700 Subject: [PATCH 10/13] wip Signed-off-by: Kourosh Hakhamaneshi --- tests/v1/executor/test_multiproc_executor.py | 127 ------------------ .../kv_transfer/kv_connector/utils.py | 4 + 2 files changed, 4 insertions(+), 127 deletions(-) delete mode 100644 tests/v1/executor/test_multiproc_executor.py diff --git a/tests/v1/executor/test_multiproc_executor.py b/tests/v1/executor/test_multiproc_executor.py deleted file mode 100644 index c1425d82becf..000000000000 --- a/tests/v1/executor/test_multiproc_executor.py +++ /dev/null @@ -1,127 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import threading -from collections import defaultdict -from concurrent.futures import Future -from typing import Optional - -from vllm.v1.executor.multiproc_executor import MultiprocExecutor -from vllm.v1.outputs import ModelRunnerOutput - - -class DummyMultiprocExecutor(MultiprocExecutor): - - def __init__(self, output_rank, world_size): - # Manually initialize minimal required fields - self.output_rank = output_rank - self.world_size = world_size - self._send_remaining_count = defaultdict[str, - int](lambda: self.world_size) - self._recv_remaining_count = defaultdict[str, - int](lambda: self.world_size) - self.io_thread_pool = None - self.shutdown_event = threading.Event() - - -class DummyModelRunnerOutput(ModelRunnerOutput): - - def __init__(self, - finished_sending: Optional[set[str]] = None, - finished_recving: Optional[set[str]] = None): - self.finished_sending = finished_sending - self.finished_recving = finished_recving - - -def test_aggregate_workers_output(): - executor = DummyMultiprocExecutor(output_rank=0, world_size=2) - - output1 = DummyModelRunnerOutput(finished_sending={'req1'}, - finished_recving={'req2'}) - output2 = DummyModelRunnerOutput(finished_sending=None, - finished_recving=None) - - aggregated = executor._aggregate_workers_output([output1, output2]) - - assert aggregated is output1 - assert aggregated.finished_sending is None - assert aggregated.finished_recving is None - - output1 = DummyModelRunnerOutput(finished_sending=None, - finished_recving=None) - output2 = DummyModelRunnerOutput(finished_sending={'req1'}, - finished_recving=None) - - aggregated = executor._aggregate_workers_output([output1, output2]) - - assert aggregated is output1 - assert aggregated.finished_sending == {'req1'} - assert aggregated.finished_recving is None - - output1 = DummyModelRunnerOutput(finished_sending=None, - finished_recving=None) - output2 = DummyModelRunnerOutput(finished_sending={'req1'}, - finished_recving={'req2'}) - - aggregated = executor._aggregate_workers_output([output1, output2]) - - assert aggregated is output1 - assert aggregated.finished_sending is None - assert aggregated.finished_recving == {'req2'} - - -def test_async_aggregate_workers_output(): - executor = DummyMultiprocExecutor(output_rank=0, world_size=2) - - future1: Future[DummyModelRunnerOutput] = Future() - future2: Future[DummyModelRunnerOutput] = Future() - result_future = executor._async_aggregate_workers_output( - [future1, future2]) - - output1 = DummyModelRunnerOutput(finished_sending={'req1'}, - finished_recving={'req2'}) - output2 = DummyModelRunnerOutput(finished_sending=None, - finished_recving=None) - future1.set_result(output1) - future2.set_result(output2) - - assert result_future.done() - aggregated = result_future.result() - assert aggregated is output1 - assert aggregated.finished_sending is None - assert aggregated.finished_recving is None - - future1 = Future() - future2 = Future() - result_future = executor._async_aggregate_workers_output( - [future1, future2]) - - output1 = DummyModelRunnerOutput(finished_sending=None, - finished_recving=None) - output2 = DummyModelRunnerOutput(finished_sending={'req1'}, - finished_recving=None) - future1.set_result(output1) - future2.set_result(output2) - - assert result_future.done() - aggregated = result_future.result() - assert aggregated is output1 - assert aggregated.finished_sending == {'req1'} - assert aggregated.finished_recving is None - - future1 = Future() - future2 = Future() - result_future = executor._async_aggregate_workers_output( - [future1, future2]) - - output1 = DummyModelRunnerOutput(finished_sending=None, - finished_recving=None) - output2 = DummyModelRunnerOutput(finished_sending={'req1'}, - finished_recving={'req2'}) - future1.set_result(output1) - future2.set_result(output2) - - assert result_future.done() - aggregated = result_future.result() - assert aggregated is output1 - assert aggregated.finished_sending is None - assert aggregated.finished_recving == {'req2'} diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index a06e3e94b58a..252cf99f9c02 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -154,6 +154,10 @@ def update_finished_set(req_ids: Optional[set[str]], output = outputs[output_rank] # set the aggregated finished_sending / finished_recving + # if output.finished_sending/recving is not empty, but the other ranks + # still have unfinished send/recv, we want to set the aggregated + # finished_sending/recving to None until all ranks have finished + # send/recv output.finished_sending = finished_sending if finished_sending else None output.finished_recving = finished_recving if finished_recving else None From f04be9f84d89e9761a323092b898c647202b9e16 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Thu, 17 Jul 2025 10:30:33 -0700 Subject: [PATCH 11/13] wip Signed-off-by: Kourosh Hakhamaneshi --- tests/v1/kv_connector/unit/test_nixl_connector.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index c4f558b7acdb..0fc618105580 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -378,10 +378,14 @@ def test_concurrent_load_kv( raise TimeoutError("Took too long to complete async handshake.") +# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which +# we put here is important. First run ray, it will clean up the resources, then +# run mp tests. +@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", FakeNixlWrapper) -def test_abort_timeout_on_prefiller(monkeypatch): +def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): """ Test lifecycle of an aborted Remote Prefill request hitting the timeout. -----> P @@ -404,6 +408,7 @@ def test_abort_timeout_on_prefiller(monkeypatch): enforce_eager=True, gpu_memory_utilization=0.5, kv_transfer_config=kv_transfer_config, + distributed_executor_backend=distributed_executor_backend, ) remote_prefill_opts = { "do_remote_decode": True, From 23409ae349b5509aa0b454122e6cebe6aefd1d7e Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Fri, 18 Jul 2025 10:20:10 -0700 Subject: [PATCH 12/13] fixed ray tests Signed-off-by: Kourosh Hakhamaneshi --- .../kv_connector/unit/test_nixl_connector.py | 110 ++++++------------ vllm/mocks/__init__.py | 0 vllm/mocks/mock_nixl_connector.py | 76 ++++++++++++ 3 files changed, 112 insertions(+), 74 deletions(-) create mode 100644 vllm/mocks/__init__.py create mode 100644 vllm/mocks/mock_nixl_connector.py diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 0fc618105580..da31e6d551ba 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +import tempfile +import textwrap import time -import uuid -from collections import defaultdict -from typing import Optional from unittest.mock import patch import pytest +import ray from vllm import LLM from vllm.config import KVTransferConfig @@ -15,11 +16,32 @@ KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, NixlConnectorWorker) from vllm.forward_context import ForwardContext +from vllm.mocks.mock_nixl_connector import FakeNixlWrapper from vllm.sampling_params import SamplingParams from .utils import create_request, create_scheduler, create_vllm_config +def _make_stub_pkg() -> str: + """Return a directory that makes + `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.""" + td = tempfile.mkdtemp() + pkg_root = os.path.join(td, "nixl", "_api") + os.makedirs(pkg_root, exist_ok=True) + + stub = textwrap.dedent("""\ + # Forward the real FakeNixlWrapper that the driver already defined. + print("In fake package") + from vllm.mocks.mock_nixl_connector import FakeNixlWrapper as nixl_agent + """) + with open(os.path.join(pkg_root, "__init__.py"), "w") as f: + f.write(stub) + + # touch parent package + open(os.path.join(td, "nixl", "__init__.py"), "w").close() + return td + + def test_basic_interface(): """Unit test for basic NixlConnector interface functionality.""" @@ -87,77 +109,6 @@ def test_prompt_less_than_block_size(): assert len(scheduler_output.scheduled_new_reqs) == 1 -class FakeNixlWrapper: - """Mock implementation of NixlWrapper for testing. - - We don't inherit from nixl._api.nixl_agent because nixl may not be - installed. - """ - - AGENT_METADATA = b"fake_agent_metadata" - REMOTE_AGENT_NAME = "remote_agent" - - def __init__(self, agent_name: str, *args, **kwargs): - self._cycles_before_xfer_done = 0 - self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict( - lambda: 0) - - def get_reg_descs(self, caches_data, memory_type: str) -> list: - return [str(uuid.uuid4()) for _ in caches_data] - - def register_memory(self, descs) -> None: - pass - - def get_xfer_descs(self, blocks_data, memory_type: str) -> list: - return [str(uuid.uuid4()) for _ in blocks_data] - - def prep_xfer_dlist(self, agent_name: str, descs: list) -> int: - return uuid.uuid4().int - - def get_agent_metadata(self) -> bytes: - return self.AGENT_METADATA - - def add_remote_agent(self, agent_metadata: bytes) -> str: - return self.REMOTE_AGENT_NAME - - def get_new_notifs(self) -> dict[str, list[bytes]]: - # Used to collect done_sending, which we don't test yet. - return {} - - def check_xfer_state(self, handle: int) -> str: - if self._check_xfer_state_cycles[ - handle] >= self._cycles_before_xfer_done: - return "DONE" - self._check_xfer_state_cycles[handle] += 1 - return "PROC" - - def release_xfer_handle(self, handle: int) -> None: - pass - - def send_notif(self, agent_name: str, notif_msg: bytes) -> None: - pass - - def make_prepped_xfer(self, - xfer_type: str, - local_xfer_side_handle: int, - local_block_descs_ids: list[int], - remote_xfer_side_handle: int, - remote_block_descs_ids: list[int], - notif_msg: Optional[bytes] = None) -> int: - return uuid.uuid4().int - - def transfer(self, handle: int) -> str: - return "PROC" - - ############################################################ - # Follow are for changing the behavior during testing. - ############################################################ - - def set_cycles_before_xfer_done(self, cycles: int): - """Set the number of cycles before a transfer is considered done.""" - self._cycles_before_xfer_done = cycles - - class FakeNixlConnectorWorker(NixlConnectorWorker): REMOTE_ENGINE_ID = "remote_engine" @@ -403,6 +354,17 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): timeout = 6 monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout)) + + # Build runtime_env only if we’re using Ray + if distributed_executor_backend == "ray": + runtime_env = { + "working_dir": _make_stub_pkg(), # ship stub package + "env_vars": { + "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout), + }, + } + ray.init(runtime_env=runtime_env) + llm = LLM( model=model_name, enforce_eager=True, diff --git a/vllm/mocks/__init__.py b/vllm/mocks/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/mocks/mock_nixl_connector.py b/vllm/mocks/mock_nixl_connector.py new file mode 100644 index 000000000000..54e2c5ee3b0a --- /dev/null +++ b/vllm/mocks/mock_nixl_connector.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import uuid +from collections import defaultdict +from typing import Optional + + +class FakeNixlWrapper: + """Mock implementation of NixlWrapper for testing. + + We don't inherit from nixl._api.nixl_agent because nixl may not be + installed. + """ + + AGENT_METADATA = b"fake_agent_metadata" + REMOTE_AGENT_NAME = "remote_agent" + + def __init__(self, agent_name: str, *args, **kwargs): + self._cycles_before_xfer_done = 0 + self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict( + lambda: 0) + + def get_reg_descs(self, caches_data, memory_type: str) -> list: + return [str(uuid.uuid4()) for _ in caches_data] + + def register_memory(self, descs) -> None: + pass + + def get_xfer_descs(self, blocks_data, memory_type: str) -> list: + return [str(uuid.uuid4()) for _ in blocks_data] + + def prep_xfer_dlist(self, agent_name: str, descs: list) -> int: + return uuid.uuid4().int + + def get_agent_metadata(self) -> bytes: + return self.AGENT_METADATA + + def add_remote_agent(self, agent_metadata: bytes) -> str: + return self.REMOTE_AGENT_NAME + + def get_new_notifs(self) -> dict[str, list[bytes]]: + # Used to collect done_sending, which we don't test yet. + return {} + + def check_xfer_state(self, handle: int) -> str: + if self._check_xfer_state_cycles[ + handle] >= self._cycles_before_xfer_done: + return "DONE" + self._check_xfer_state_cycles[handle] += 1 + return "PROC" + + def release_xfer_handle(self, handle: int) -> None: + pass + + def send_notif(self, agent_name: str, notif_msg: bytes) -> None: + pass + + def make_prepped_xfer(self, + xfer_type: str, + local_xfer_side_handle: int, + local_block_descs_ids: list[int], + remote_xfer_side_handle: int, + remote_block_descs_ids: list[int], + notif_msg: Optional[bytes] = None) -> int: + return uuid.uuid4().int + + def transfer(self, handle: int) -> str: + return "PROC" + + ############################################################ + # Follow are for changing the behavior during testing. + ############################################################ + + def set_cycles_before_xfer_done(self, cycles: int): + """Set the number of cycles before a transfer is considered done.""" + self._cycles_before_xfer_done = cycles From c70c5c1dcccad0564a43e817c84827f5f8acf6a7 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Fri, 18 Jul 2025 18:11:24 -0700 Subject: [PATCH 13/13] addressing ci Signed-off-by: Kourosh Hakhamaneshi --- tests/v1/kv_connector/unit/test_nixl_connector.py | 4 ++-- vllm/distributed/kv_transfer/kv_connector/utils.py | 1 - vllm/sequence.py | 6 ++++++ vllm/v1/executor/ray_distributed_executor.py | 6 +++--- vllm/v1/worker/gpu_model_runner.py | 3 +++ vllm/v1/worker/gpu_worker.py | 11 ++++++++--- 6 files changed, 22 insertions(+), 9 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index da31e6d551ba..a0dfd54fb825 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -331,8 +331,8 @@ def test_concurrent_load_kv( # NOTE: resource cleanup in mp backend is a bit finicky, so the order in which # we put here is important. First run ray, it will clean up the resources, then -# run mp tests. -@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) +# the rest of the tests. +@pytest.mark.parametrize("distributed_executor_backend", ["ray", None]) @patch( "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", FakeNixlWrapper) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 252cf99f9c02..c179d6cc29b7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -120,7 +120,6 @@ class KVOutputAggregator: output corresponding to Rank 0 for scheduler.""" def __init__(self, world_size: int): - self.world_size = world_size # Complete transfer tracker. Used by to track finished requests # [req_id -> n_finished_workers] self._recv_remaining_count = defaultdict[str, int](lambda: world_size) diff --git a/vllm/sequence.py b/vllm/sequence.py index ffe890eb2dab..8786448646a9 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1198,9 +1198,15 @@ class IntermediateTensors: """For all pipeline stages except the last, we need to return the hidden states and residuals to be sent to the next stage. This data structure contains the hidden states and residuals for a request. + + Each stage also needs to handle its own finished_sending and + finished_recving in case of kv transfer. """ tensors: dict[str, torch.Tensor] + # [req_ids] + finished_sending: Optional[set[str]] = None + finished_recving: Optional[set[str]] = None def __init__(self, tensors): # manually define this function, so that diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index 5249a41ed6c0..7716e0d0f998 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -34,9 +34,9 @@ def result(self, timeout=None): if self.aggregator is None: return self.refs[0].get() - else: - outputs = [ref.get() for ref in self.refs] - return self.aggregator.aggregate(outputs, output_rank=0) + + outputs = [ref.get() for ref in self.refs] + return self.aggregator.aggregate(outputs, output_rank=0) class RayDistributedExecutor(RayDistributedExecutorV0, Executor): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 40a41be5c1c9..3628443a408f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1393,6 +1393,9 @@ def execute_model( if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. if not broadcast_pp_output: + if finished_sending or finished_recving: + hidden_states.finished_sending = finished_sending + hidden_states.finished_recving = finished_recving return hidden_states assert isinstance(hidden_states, IntermediateTensors) get_pp_group().send_tensor_dict(hidden_states.tensors, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 4de1a5e59b22..c5f5f259365c 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -327,9 +327,14 @@ def execute_model( assert isinstance(output, IntermediateTensors) get_pp_group().send_tensor_dict(output.tensors, all_gather_group=get_tp_group()) - empty_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - empty_output.finished_sending = output.finished_sending - empty_output.finished_recving = output.finished_recving + + # In case of PP with kv transfer, we need to pass through the + # finished_sending and finished_recving buffers. + empty_output = EMPTY_MODEL_RUNNER_OUTPUT + if output.finished_sending or output.finished_recving: + empty_output = copy.copy(empty_output) + empty_output.finished_sending = output.finished_sending + empty_output.finished_recving = output.finished_recving output = empty_output assert isinstance(output, ModelRunnerOutput)