Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,9 +1260,6 @@ def create_engine_config(
self.distributed_executor_backend = "mp"
logger.info("Using mp-based distributed executor backend "
"for async scheduling.")
if self.distributed_executor_backend == "uni":
raise ValueError("Async scheduling is not supported with "
"uni-process backend.")
if self.pipeline_parallel_size > 1:
raise ValueError("Async scheduling is not supported with "
"pipeline-parallel-size > 1.")
Expand Down
59 changes: 59 additions & 0 deletions vllm/executor/uniproc_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from queue import Queue
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.distributed.parallel_state import get_world_group
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
run_method)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
Expand Down Expand Up @@ -44,9 +49,14 @@ def _init_executor(self) -> None:
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
self.init_execute_model_thread(kwargs)
self.collective_rpc("init_worker", args=([kwargs], ))
self.collective_rpc("init_device")
self.collective_rpc("load_model")
if self.vllm_config.scheduler_config.async_scheduling:
# Start the execute_model thread only after the distributed
# environment has been initialized.
self._execute_model_thread.start()

def collective_rpc(self,
method: Union[str, Callable],
Expand All @@ -71,6 +81,50 @@ def reinitialize_distributed(
self.shutdown()
return

def _execute_model_loop(self):
    """Worker-thread loop for async scheduling.

    Consumes SchedulerOutput objects from ``self._input_queue`` and runs
    the model for each one, until a ``None`` sentinel is received.
    """
    # A freshly spawned thread starts on the default CUDA device (0);
    # bind it to this rank's device before touching any CUDA state.
    torch.cuda.set_device(get_world_group().local_rank)
    while True:
        try:
            scheduler_output = self._input_queue.get()
            if scheduler_output is None:
                # Sentinel value: shut down the loop.
                break
            super().execute_model(scheduler_output)
        except Exception as e:
            # Forward the failure to the caller blocked on the output
            # queue instead of letting the thread die silently.
            self._output_queue.put(e)

def execute_model(self, scheduler_output: SchedulerOutput):
    """Run one model step, inline or via the async execute_model thread.

    With async scheduling enabled, the scheduler output is handed to the
    worker thread through the input queue and the first item read back
    from the output queue is returned. That item is either a bool flag
    (the d2h copy of sampled tokens has started — safe to schedule the
    next step) or a full ModelRunnerOutput; the caller
    (EngineCore.step_async_in_process) distinguishes the two.
    """
    if self.vllm_config.scheduler_config.async_scheduling:
        self._input_queue.put(scheduler_output)
        output = self._output_queue.get()
        if isinstance(output, ModelRunnerOutput):
            # The execute_model thread just finished a step, and we got
            # a new scheduler_output immediately, so the ModelRunnerOutput
            # arrived first. Block here until the d2h-copy flag of the
            # current step is ready too, keeping the queue in lockstep.
            self._output_queue.get()
        elif isinstance(output, Exception):
            # The worker thread forwards exceptions through the queue;
            # re-raise them in the caller's thread.
            raise output
        return output
    return super().execute_model(scheduler_output)

def init_execute_model_thread(self, kwargs):
    """Set up the queues and worker thread used for async scheduling.

    No-op unless async scheduling is enabled. The output queue is added
    to *kwargs* so the worker's model runner can publish results; the
    thread is only created here — the caller starts it once the
    distributed environment is initialized.
    """
    if not self.vllm_config.scheduler_config.async_scheduling:
        return
    self._input_queue: Queue = Queue()
    self._output_queue: Queue = Queue()
    self._execute_model_thread = Thread(
        name="execute_model_loop",
        target=self._execute_model_loop,
        daemon=True,
    )
    kwargs["output_queue"] = self._output_queue

@property
def max_concurrent_batches(self) -> int:
    """Number of batches that may be in flight at once.

    Async scheduling pipelines two batches (one being scheduled while
    the other executes); otherwise execution is one batch at a time.
    """
    return 2 if self.vllm_config.scheduler_config.async_scheduling else 1


UniProcExecutorAsync = UniProcExecutor

Expand Down Expand Up @@ -124,9 +178,14 @@ def _init_executor(self) -> None:
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
self.init_execute_model_thread(kwargs)
self.collective_rpc("init_worker", args=([kwargs], ))
self.collective_rpc("init_device")
self.collective_rpc("load_model")
if self.vllm_config.scheduler_config.async_scheduling:
# Start the execute_model thread only after the distributed
# environment has been initialized.
self._execute_model_thread.start()

def determine_num_available_blocks(self) -> Tuple[int, int]:
"""
Expand Down
27 changes: 27 additions & 0 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,33 @@ def step_with_batch_queue(

return engine_core_outputs, model_executed

def step_async_in_process(self):
    """Run one engine step with in-process asynchronous scheduling.

    Returns a tuple of (engine core outputs dict, whether any tokens
    were scheduled this step).
    """
    # Check for any requests remaining in the scheduler - unfinished,
    # or finished and not yet removed from the batch.
    batch_queue = self.batch_queue
    assert batch_queue is not None
    if not self.scheduler.has_requests():
        return {}, False
    engine_core_outputs = {}
    model_output = None
    scheduler_output = self.scheduler.schedule()
    if scheduler_output.total_num_scheduled_tokens > 0:
        # Track the in-flight step; its results are folded back in on a
        # later call, once a full ModelRunnerOutput is available.
        batch_queue.appendleft(scheduler_output)  # type: ignore
        model_output = self.execute_model_with_error_logging(
            self.model_executor.execute_model,  # type: ignore
            scheduler_output)
    # In single-process mode, model_output may be a bool flag signalling
    # that it is time to schedule the next step (the real output arrives
    # on a later call); in that case update_from_output must not run.
    if isinstance(model_output, ModelRunnerOutput):
        # A completed step's output is available: pop the oldest
        # in-flight scheduler output and apply the results.
        pre_scheduler_output = batch_queue.pop()
        engine_core_outputs = self.scheduler.update_from_output(
            pre_scheduler_output, model_output)  # type: ignore

    return (engine_core_outputs,
            scheduler_output.total_num_scheduled_tokens > 0)

def shutdown(self):
self.structured_output_manager.clear_backend()
if self.model_executor:
Expand Down
6 changes: 5 additions & 1 deletion vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,13 @@ class InprocClient(EngineCoreClient):

def __init__(self, *args, **kwargs):
    self.engine_core = EngineCore(*args, **kwargs)
    # Cache the flag so get_output() doesn't chase the config chain on
    # every call. (Renamed from the misspelled `async_scehduling`; the
    # split also drops the need for the line-length noqa.)
    scheduler_config = self.engine_core.vllm_config.scheduler_config
    self.async_scheduling = scheduler_config.async_scheduling

def get_output(self) -> "EngineCoreOutputs":
    """Step the engine once and return outputs for client index 0."""
    # With async scheduling, the engine overlaps scheduling of the next
    # step with model execution via the dedicated step variant.
    if self.async_scheduling:
        outputs, _ = self.engine_core.step_async_in_process()
    else:
        outputs, _ = self.engine_core.step()
    return outputs.get(0) or EngineCoreOutputs()

def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
Expand Down
21 changes: 19 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from queue import Queue
from typing import TYPE_CHECKING, Any, Optional, Union, cast

import numpy as np
Expand Down Expand Up @@ -105,6 +106,7 @@ def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
output_queue: Optional[Queue] = None,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
Expand Down Expand Up @@ -319,6 +321,10 @@ def __init__(
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
# When async_scheduling is enabled in single-process mode,
# we use a queue to transfer the model outputs and to
# notify the main thread to schedule the next step.
self.output_queue = output_queue

def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*args,
Expand Down Expand Up @@ -1739,8 +1745,7 @@ def execute_model(
)

self.eplb_step()

return ModelRunnerOutput(
output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
Expand All @@ -1750,6 +1755,14 @@ def execute_model(
kv_connector_output=kv_connector_output,
num_nans_in_logits=num_nans_in_logits,
)
if self.output_queue is None:
return output
# When using output_queue, deepcopy the req_ids and req_id_to_index
# in case the input_batch is modified before the output is consumed.
output.req_ids = deepcopy(self.input_batch.req_ids)
output.req_id_to_index = deepcopy(self.input_batch.req_id_to_index)
self.output_queue.put(output)
return output

def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
if self._draft_token_ids is None:
Expand Down Expand Up @@ -3268,5 +3281,9 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]]
pinned.copy_(sampled_token_ids, non_blocking=True)
self.transfer_event.record()
if self.output_queue is not None:
# Just send a flag so the main thread can schedule the next
# step while the device-to-host copy synchronizes.
self.output_queue.put(True)
self.transfer_event.synchronize()
return pinned.tolist()
5 changes: 4 additions & 1 deletion vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import gc
import os
from contextlib import AbstractContextManager, nullcontext
from queue import Queue
from typing import TYPE_CHECKING, Any, Optional

import torch
Expand Down Expand Up @@ -50,6 +51,7 @@ def __init__(
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
output_queue: Optional[Queue] = None,
):

super().__init__(vllm_config=vllm_config,
Expand Down Expand Up @@ -93,6 +95,7 @@ def __init__(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
self.output_queue = output_queue

def sleep(self, level: int = 1) -> None:
from vllm.device_allocator.cumem import CuMemAllocator
Expand Down Expand Up @@ -199,7 +202,7 @@ def init_device(self):

# Construct the model runner
self.model_runner: GPUModelRunner = GPUModelRunner(
self.vllm_config, self.device)
self.vllm_config, self.device, self.output_queue)

if self.rank == 0:
# If usage stat is enabled, collect relevant info.
Expand Down