Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions vllm/v1/executor/uniproc_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
Comment thread
njhill marked this conversation as resolved.
from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor
from concurrent.futures import Future
from functools import cached_property
from multiprocessing import Lock
from typing import Any
Expand All @@ -23,6 +23,25 @@
logger = init_logger(__name__)


class AsyncOutputFuture(Future):
def __init__(self, async_output: AsyncModelRunnerOutput, single_value: bool):
self.async_output = async_output
self.single_value = single_value
super().__init__()

def result(self, timeout=None):
if timeout is not None:
raise RuntimeError("timeout not implemented")

if not super().done():
try:
output = self.async_output.get_output()
self.set_result(output if self.single_value else [output])
except Exception as e:
self.set_exception(e)
return super().result()
Comment thread
njhill marked this conversation as resolved.


class UniProcExecutor(Executor):
def _init_executor(self) -> None:
"""Initialize the worker and load the model."""
Expand All @@ -37,12 +56,6 @@ def _init_executor(self) -> None:
shared_worker_lock=Lock(),
)

self.async_output_thread: ThreadPoolExecutor | None = None
if self.max_concurrent_batches > 1:
self.async_output_thread = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="WorkerAsyncOutput"
)

self.driver_worker.init_worker(all_kwargs=[kwargs])
self.driver_worker.init_device()

Expand Down Expand Up @@ -83,15 +96,7 @@ def collective_rpc( # type: ignore[override]
try:
result = run_method(self.driver_worker, method, args, kwargs)
if isinstance(result, AsyncModelRunnerOutput):
if (async_thread := self.async_output_thread) is not None:
if single_value:
return async_thread.submit(result.get_output)

def get_output_list() -> list[Any]:
return [result.get_output()]

return async_thread.submit(get_output_list)
result = result.get_output()
return AsyncOutputFuture(result, single_value)
future = Future[Any]()
future.set_result(result if single_value else [result])
except Exception as e:
Expand Down
Loading