Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions tests/v1/executor/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import asyncio
import os
from collections.abc import Callable
from concurrent.futures import Future
from typing import Any

import pytest
Expand All @@ -28,7 +27,7 @@ def collective_rpc(
kwargs: dict | None = None,
non_block: bool = False,
unique_reply_rank: int | None = None,
) -> Any | list[Any] | Future[Any | list[Any]]:
) -> list[Any]:
# Drop marker to show that this was run
with open(".marker", "w"):
...
Expand Down
32 changes: 20 additions & 12 deletions tests/v1/kv_connector/unit/test_output_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,14 @@ def test_aggregate_workers_output():
def test_async_aggregate_workers_output():
aggregator = KVOutputAggregator(expected_finished_count=2)

future: Future[list[DummyModelRunnerOutput]] = Future()
result_future = aggregator.async_aggregate(future)
future1: Future[DummyModelRunnerOutput] = Future()
future2: Future[DummyModelRunnerOutput] = Future()
result_future = aggregator.async_aggregate([future1, future2])

output1 = DummyModelRunnerOutput()
output2 = DummyModelRunnerOutput()
future.set_result([output1, output2])
future1.set_result(output1)
future2.set_result(output2)

assert result_future.done()
aggregated = result_future.result()
Expand All @@ -104,14 +106,16 @@ def test_async_aggregate_workers_output():
assert aggregated.finished_recving is None
assert not aggregated.invalid_block_ids

future = Future()
result_future = aggregator.async_aggregate(future)
future1 = Future()
future2 = Future()
result_future = aggregator.async_aggregate([future1, future2])

output1 = DummyModelRunnerOutput(
finished_sending={"req1"}, finished_recving={"req2"}
)
output2 = DummyModelRunnerOutput(invalid_block_ids={1})
future.set_result([output1, output2])
future1.set_result(output1)
future2.set_result(output2)

assert result_future.done()
aggregated = result_future.result()
Expand All @@ -121,12 +125,14 @@ def test_async_aggregate_workers_output():
assert aggregated.finished_recving is None
assert aggregated.invalid_block_ids == {1}

future = Future()
result_future = aggregator.async_aggregate(future)
future1 = Future()
future2 = Future()
result_future = aggregator.async_aggregate([future1, future2])

output1 = DummyModelRunnerOutput(invalid_block_ids={2})
output2 = DummyModelRunnerOutput(finished_sending={"req1"})
future.set_result([output1, output2])
future1.set_result(output1)
future2.set_result(output2)

assert result_future.done()
aggregated = result_future.result()
Expand All @@ -136,14 +142,16 @@ def test_async_aggregate_workers_output():
assert aggregated.finished_recving is None
assert aggregated.invalid_block_ids == {2}

future = Future()
result_future = aggregator.async_aggregate(future)
future1 = Future()
future2 = Future()
result_future = aggregator.async_aggregate([future1, future2])

output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
output2 = DummyModelRunnerOutput(
finished_recving={"req2"}, invalid_block_ids={4, 5}
)
future.set_result([output1, output2])
future1.set_result(output1)
future2.set_result(output2)

assert result_future.done()
aggregated = result_future.result()
Expand Down
43 changes: 29 additions & 14 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,24 +221,39 @@ def update_finished_set(

def async_aggregate(
self,
output_future: Future[Sequence[ModelRunnerOutput | None]],
output_futures: Sequence[Future[ModelRunnerOutput | None]],
output_rank: int = 0,
) -> Future[ModelRunnerOutput | None]:
"""Takes a future that resolves to a list of outputs and returns a future
which resolves to a single aggregated output."""
"""Takes a list of futures and returns a single future which resolves
to the respective list of outputs."""
result_future: Future[ModelRunnerOutput | None] = Future()

def callback(fut):
if result_future.done():
return
try:
result_future.set_result(self.aggregate(fut.result(), output_rank))
except CancelledError:
result_future.cancel()
except Exception as e:
result_future.set_exception(e)

output_future.add_done_callback(callback)
outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
remaining = len(output_futures)

def make_callback(idx):
def callback(fut):
if result_future.done():
return

try:
outputs[idx] = fut.result()
except CancelledError:
result_future.cancel()
except Exception as e:
result_future.set_exception(e)

# this check assumes io_thread_pool uses a single thread
nonlocal remaining
remaining -= 1
if not remaining:
result_future.set_result(self.aggregate(outputs, output_rank))

return callback

for i, output_future in enumerate(output_futures):
output_future.add_done_callback(make_callback(i))
Comment on lines +231 to +255
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of `async_aggregate` is not thread-safe: the shared `remaining` counter is read and modified without a lock. The comment on line 246 correctly notes that this relies on a single-threaded I/O pool, but that assumption makes the design fragile. If the `ThreadPoolExecutor` in `MultiprocExecutor` is ever configured with more than one worker, this will introduce a race condition, leading to incorrect behavior. To make the implementation robust and thread-safe, protect the shared `remaining` counter with a lock.

        from threading import Lock
        outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
        remaining = len(output_futures)
        lock = Lock()

        def make_callback(idx):
            def callback(fut):
                if result_future.done():
                    return

                try:
                    outputs[idx] = fut.result()
                except CancelledError:
                    result_future.cancel()
                except Exception as e:
                    result_future.set_exception(e)

                with lock:
                    # This check is now thread-safe.
                    nonlocal remaining
                    remaining -= 1
                    if not remaining:
                        if not result_future.done():
                            result_future.set_result(self.aggregate(outputs, output_rank))

            return callback

        for i, output_future in enumerate(output_futures):
            output_future.add_done_callback(make_callback(i))


return result_future


Expand Down
4 changes: 2 additions & 2 deletions vllm/v1/executor/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def collective_rpc(
args: tuple = (),
kwargs: dict | None = None,
non_block: Literal[True] = True,
) -> Future[list[_R]]:
) -> list[Future[_R]]:
pass

@abstractmethod
Expand Down Expand Up @@ -219,7 +219,7 @@ def sample_tokens(

def sample_tokens(
self, grammar_output: GrammarOutput | None, non_block: bool = False
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
output = self.collective_rpc( # type: ignore[call-overload]
"sample_tokens", args=(grammar_output,), non_block=non_block
)
Expand Down
128 changes: 61 additions & 67 deletions vllm/v1/executor/multiproc_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
import time
import traceback
import weakref
from collections import deque
from collections.abc import Callable
from concurrent.futures import Future, InvalidStateError
from contextlib import suppress
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from functools import cached_property, partial
Expand Down Expand Up @@ -56,30 +54,6 @@
logger = init_logger(__name__)


class FutureWrapper(Future):
def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]):
self.futures_queue = futures_queue
super().__init__()

def result(self, timeout=None):
if timeout is not None:
raise RuntimeError("timeout not implemented")
# Drain any futures ahead of us in the queue.
while not self.done():
future, get_response = self.futures_queue.pop()
future.wait_for_response(get_response)
return super().result()

def wait_for_response(self, get_response: Callable):
try:
response = get_response()
with suppress(InvalidStateError):
self.set_result(response)
except Exception as e:
with suppress(InvalidStateError):
self.set_exception(e)


class MultiprocExecutor(Executor):
supports_pp: bool = True

Expand All @@ -90,6 +64,7 @@ def _init_executor(self) -> None:
self.is_failed = False
self.shutdown_event = threading.Event()
self.failure_callback: FailureCallback | None = None
self.io_thread_pool: ThreadPoolExecutor | None = None

self.world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
Expand Down Expand Up @@ -157,7 +132,12 @@ def _init_executor(self) -> None:
uw.death_writer.close()
self._ensure_worker_termination([uw.proc for uw in unready_workers])

self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue.
# _async_aggregate_workers_output also assumes a single IO thread.
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io"
)

self.output_rank = self._get_output_rank()
self.has_connector = self.vllm_config.kv_transfer_config is not None
Expand Down Expand Up @@ -215,13 +195,14 @@ def _execute_with_aggregation(
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
if not self.has_connector:
# get output only from a single worker (output_rank)
return self.collective_rpc(
(output,) = self.collective_rpc(
method,
args=args,
unique_reply_rank=self.output_rank,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
)
return output

# get output from all workers
outputs = self.collective_rpc(
Expand All @@ -242,21 +223,20 @@ def execute_dummy_batch(self) -> None:

def take_draft_token_ids(self) -> DraftTokenIds | None:
# OPTIMIZATION: Get output only from a single worker (output_rank)
return self.collective_rpc(
outputs = self.collective_rpc(
"take_draft_token_ids", unique_reply_rank=self.output_rank
)
return outputs[0]

def collective_rpc( # type: ignore[override]
def collective_rpc(
self,
method: str | Callable,
timeout: float | None = None,
args: tuple = (),
kwargs: dict | None = None,
non_block: bool = False,
unique_reply_rank: int | None = None,
) -> Any | list[Any] | Future[Any | list[Any]]:
"""Returns single result if unique_reply_rank is provided, otherwise list."""

) -> list[Any]:
if self.is_failed:
raise RuntimeError("Executor failed.")

Expand All @@ -266,52 +246,63 @@ def collective_rpc( # type: ignore[override]
# NOTE: If the args are heterogeneous, then we pack them into a list,
# and unpack them in the method of every worker, because every worker
# knows their own rank.
try:
if isinstance(method, str):
send_method = method
else:
send_method = cloudpickle.dumps(
method, protocol=pickle.HIGHEST_PROTOCOL
)
self.rpc_broadcast_mq.enqueue(
(send_method, args, kwargs, unique_reply_rank)
)

if isinstance(method, str):
send_method = method
else:
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, unique_reply_rank))
workers = (
(self.workers[unique_reply_rank],)
if unique_reply_rank is not None
else self.workers
)
responses = []

workers = (
(self.workers[unique_reply_rank],)
if unique_reply_rank is not None
else self.workers
)
def get_response(
w: WorkerProcHandle,
dequeue_timeout: float | None = None,
cancel_event: threading.Event | None = None,
):
status, result = w.worker_response_mq.dequeue(
timeout=dequeue_timeout, cancel=cancel_event
)

shutdown_event = self.shutdown_event
if status != WorkerProc.ResponseStatus.SUCCESS:
raise RuntimeError(
f"Worker failed with error '{result}', please check the"
" stack trace above for the root cause"
)
return result

def get_response():
responses = []
for w in workers:
dequeue_timeout = (
None if deadline is None else (deadline - time.monotonic())
)
try:
status, result = w.worker_response_mq.dequeue(
timeout=dequeue_timeout, cancel=shutdown_event

if self.io_thread_pool is not None:
# We must consume worker_response_mq from a single thread.
result = self.io_thread_pool.submit( # type: ignore
get_response, w, dequeue_timeout, self.shutdown_event
)
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e
if status != WorkerProc.ResponseStatus.SUCCESS:
if not non_block:
result = result.result()
elif not non_block:
result = get_response(w, dequeue_timeout, self.shutdown_event)
else:
raise RuntimeError(
f"Worker failed with error '{result}', please check the"
" stack trace above for the root cause"
"non_block can only be used when max_concurrent_batches > 1"
)
responses.append(result)
return responses[0] if unique_reply_rank is not None else responses

if non_block:
future = FutureWrapper(self.futures_queue)
self.futures_queue.appendleft((future, get_response))
return future

# First drain any pending futures in the queue.
while self.futures_queue:
future, get_fut_response = self.futures_queue.pop()
future.wait_for_response(get_fut_response)

return get_response()
return responses
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e

@staticmethod
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
Expand Down Expand Up @@ -357,6 +348,9 @@ def shutdown(self):
self._ensure_worker_termination([w.proc for w in workers])

self.shutdown_event.set()
if self.io_thread_pool is not None:
self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
del self.io_thread_pool

self.rpc_broadcast_mq = None

Expand Down
Loading