Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions tests/v1/e2e/test_async_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ def test_without_spec_decoding(
(False, "mp", True, None, True),
(True, "mp", True, None, True),
(True, "uni", True, None, True),
(False, "ray", False, None, False),
(True, "ray", False, None, True),
(False, "ray", True, None, False),
(True, "ray", True, None, False),
(False, "ray", True, None, True),
(True, "ray", True, None, True),
]

if current_platform.is_rocm():
Expand Down Expand Up @@ -145,6 +151,13 @@ def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.Monke
(True, "uni", True, spec_config_short, False),
(True, "mp", True, spec_config, True),
(True, "uni", True, spec_config_short, True),
(False, "ray", False, None, False),
(False, "ray", False, spec_config, False),
(True, "ray", False, spec_config, True),
(False, "ray", True, spec_config, False),
(True, "ray", True, spec_config, False),
(False, "ray", True, spec_config_short, True),
(True, "ray", True, spec_config, True),
]

run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)
Expand Down
22 changes: 0 additions & 22 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,13 +676,6 @@ def __post_init__(self):
self.model_config, self.load_config
)

executor_backend = self.parallel_config.distributed_executor_backend
executor_supports_async_sched = executor_backend in (
"mp",
"uni",
"external_launcher",
)

if self.scheduler_config.async_scheduling:
# Async scheduling explicitly enabled, hard fail any incompatibilities.
# Currently, async scheduling only supports eagle speculative
Expand All @@ -703,12 +696,6 @@ def __post_init__(self):
"Async scheduling is not compatible with "
"disable_padded_drafter_batch=True."
)
if not executor_supports_async_sched:
raise ValueError(
"Currently, async scheduling only supports `mp`, `uni`, or "
"`external_launcher` distributed executor backend, but you chose "
f"`{executor_backend}`."
)
elif self.scheduler_config.async_scheduling is None:
# Enable async scheduling unless there is an incompatible option.
if (
Expand All @@ -733,15 +720,6 @@ def __post_init__(self):
scope="local",
)
self.scheduler_config.async_scheduling = False
elif not executor_supports_async_sched:
logger.warning_once(
"Async scheduling will be disabled because it is not supported "
"with the `%s` distributed executor backend (only `mp`, `uni`, and "
"`external_launcher` are supported).",
executor_backend,
scope="local",
)
self.scheduler_config.async_scheduling = False
else:
self.scheduler_config.async_scheduling = True

Expand Down
19 changes: 19 additions & 0 deletions vllm/v1/executor/ray_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections.abc import Callable
from concurrent.futures import Future
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any

import cloudpickle
Expand Down Expand Up @@ -456,6 +457,11 @@ def sample_tokens( # type: ignore[override]

return self._execute_dag(scheduler_output, grammar_output, non_block)

@staticmethod
def _get_async_refs(refs, worker, timeout=None):
    """Block until the prior step's DAG refs complete, then fetch the
    buffered model output from *worker*.

    Returns a new Ray object ref for the worker's stored
    ``execute_model`` output (resolved later by the caller).
    """
    # The previous step must finish before its output can be retrieved.
    ray.get(refs, timeout=timeout)
    fetch = worker.execute_method.remote
    return fetch("get_execute_model_output")

def _execute_dag(
self,
scheduler_output: SchedulerOutput,
Expand All @@ -468,6 +474,19 @@ def _execute_dag(

refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore

if self.scheduler_config.async_scheduling:
assert non_block
assert self.parallel_config.pipeline_parallel_size == 1, (
"Async scheduling is not supported with pipeline parallelism."
)

# Delay getting the model runner output until next step execute_model
# returns.
refs = [
partial(RayDistributedExecutor._get_async_refs, ref, worker)
for ref, worker in zip(refs, self.workers)
]

if not self.has_connector:
# Get output only from a single worker (output_rank)
# When PP is not used, we block here until the result is available.
Expand Down
52 changes: 47 additions & 5 deletions vllm/v1/executor/ray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import os
import time
from collections import defaultdict
from collections import defaultdict, deque
from concurrent.futures import Future
from typing import TYPE_CHECKING, Union

Expand Down Expand Up @@ -50,6 +50,13 @@ def __init__(self, *args, **kwargs) -> None:
# The flag indicates is set_device is called on
# that thread.
self.compiled_dag_cuda_device_set = False
self._execute_model_outputs = deque[
"ModelRunnerOutput"
| "AsyncModelRunnerOutput"
| tuple[
"SchedulerOutput", "GrammarOutput", "IntermediateTensors" | None
]
]()

def adjust_rank(self, rank_mapping: dict[int, int]) -> None:
"""
Expand Down Expand Up @@ -110,7 +117,8 @@ def execute_model_ray(
| tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
) -> Union[
"ModelRunnerOutput",
tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors" | None],
None,
]:
# This method is used by Ray Compiled Graph to execute the model,
# and it needs a special logic of self.setup_device_if_necessary()
Expand Down Expand Up @@ -151,11 +159,33 @@ def execute_model_ray(
output = scheduler_output, grammar_output, None
elif output is None:
output = self.worker.model_runner.sample_tokens(grammar_output)

assert self.vllm_config is not None
if self.vllm_config.scheduler_config.async_scheduling:
self._execute_model_outputs.append(output)
return None

assert not isinstance(output, AsyncModelRunnerOutput)
return output

def get_execute_model_output(
    self,
) -> Union[
    "ModelRunnerOutput",
    tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors" | None],
]:
    """Return the oldest buffered output from a prior ``execute_model_ray`` call.

    With async scheduling enabled, ``execute_model_ray`` appends its result to
    ``self._execute_model_outputs`` instead of returning it; this method is
    invoked on a subsequent step to retrieve that deferred output.

    Raises:
        AssertionError: if async scheduling is not enabled or no output
            has been buffered yet.
    """
    assert (
        self.vllm_config and self.vllm_config.scheduler_config.async_scheduling
    )
    assert self._execute_model_outputs, "No execute_model output available"
    output = self._execute_model_outputs.popleft()

    if isinstance(output, AsyncModelRunnerOutput):
        # Ensure outputs crossing the Ray compiled DAG are serializable:
        # AsyncModelRunnerOutput holds CUDA events and cannot be pickled,
        # so resolve it into a plain ModelRunnerOutput first.
        # (The original duplicated this check and called get_output() twice;
        # the second call would hit the already-resolved output and fail.)
        output = output.get_output()

    return output

def override_env_vars(self, vars: dict[str, str]):
Expand Down Expand Up @@ -191,8 +221,20 @@ def __init__(self, ref_or_refs, aggregator: KVOutputAggregator | None = None):
self.ref_or_refs = ref_or_refs
self.aggregator = aggregator

def is_callable(self, ref_or_refs):
    """Return True if the given ref(s) are deferred callables rather than
    plain Ray object refs.

    For a list, only the first element is probed — the collection is
    assumed to be homogeneous.
    """
    probe = ref_or_refs[0] if isinstance(ref_or_refs, list) else ref_or_refs
    return callable(probe)

def get_refs(self, timeout=None):
    """Materialize the stored refs for resolution.

    When the stored value(s) are deferred callables (async scheduling),
    invoke each with *timeout* to obtain real Ray refs; otherwise pass
    the stored ref(s) through unchanged.
    """
    stored = self.ref_or_refs
    if not self.is_callable(stored):
        return stored
    if isinstance(stored, list):
        return [deferred(timeout) for deferred in stored]
    return stored(timeout)

def result(self, timeout=None):
outputs = ray.get(self.ref_or_refs, timeout=timeout)
outputs = ray.get(self.get_refs(timeout), timeout=timeout)
if self.aggregator is None:
return outputs

Expand Down
Loading