From 3a3a250aa102b5b216cdb49225a4ae7580dae050 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Tue, 10 Mar 2026 13:54:48 -0700 Subject: [PATCH 01/29] Implement RayExecutorV2 & tested on a single-node Signed-off-by: Jeffrey Wang --- .buildkite/test_areas/distributed.yaml | 18 ++ tests/distributed/test_ray_v2_executor.py | 282 ++++++++++++++++++ vllm/envs.py | 7 + vllm/v1/executor/abstract.py | 10 +- vllm/v1/executor/ray_executor_v2.py | 344 ++++++++++++++++++++++ 5 files changed, 659 insertions(+), 2 deletions(-) create mode 100644 tests/distributed/test_ray_v2_executor.py create mode 100644 vllm/v1/executor/ray_executor_v2.py diff --git a/.buildkite/test_areas/distributed.yaml b/.buildkite/test_areas/distributed.yaml index 06a0b5212eeb..5b13cb8fe6b7 100644 --- a/.buildkite/test_areas/distributed.yaml +++ b/.buildkite/test_areas/distributed.yaml @@ -239,3 +239,21 @@ steps: commands: - pytest -v -s distributed/test_pp_cudagraph.py - pytest -v -s distributed/test_pipeline_parallel.py + +- label: RayExecutorV2 (4 GPUs) + timeout_in_minutes: 60 + working_dir: "/vllm-workspace/tests" + num_devices: 4 + source_file_dependencies: + - vllm/v1/executor/ray_executor_v2.py + - vllm/v1/executor/abstract.py + - vllm/v1/executor/multiproc_executor.py + - tests/distributed/test_ray_v2_executor.py + - tests/distributed/test_pipeline_parallel.py + - tests/basic_correctness/test_basic_correctness.py + commands: + - export VLLM_USE_RAY_V2_EXECUTOR_BACKEND=1 + - export NCCL_CUMEM_HOST_ENABLE=0 + - pytest -v -s distributed/test_ray_v2_executor.py + - pytest -v -s distributed/test_pipeline_parallel.py -k "ray" + - TARGET_TEST_SUITE=L4 pytest -v -s basic_correctness/test_basic_correctness.py -k "ray" diff --git a/tests/distributed/test_ray_v2_executor.py b/tests/distributed/test_ray_v2_executor.py new file mode 100644 index 000000000000..929b0c098297 --- /dev/null +++ b/tests/distributed/test_ray_v2_executor.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Integration tests for RayExecutorV2 at the executor level. +Validates executor initialization, placement group support, RPC calls, +and distributed execution with various TP/PP configurations. +""" + +import os +import threading + +import pytest +import ray +import torch +from ray.util.placement_group import PlacementGroup +from ray.util.state import list_actors + +from vllm.config import VllmConfig +from vllm.engine.arg_utils import EngineArgs +from vllm.v1.executor.ray_executor_v2 import RayExecutorV2 + +MODEL = "facebook/opt-125m" +NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0 + + +@pytest.fixture(autouse=True) +def enable_ray_v2_backend(): + """Enable the RayExecutorV2 backend via feature flag for all tests.""" + saved = { + "VLLM_USE_RAY_V2_EXECUTOR_BACKEND": os.environ.get( + "VLLM_USE_RAY_V2_EXECUTOR_BACKEND" + ), + "RAY_RUNTIME_ENV_HOOK": os.environ.get("RAY_RUNTIME_ENV_HOOK"), + } + os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" + # TODO (jeffreywang): Figure out vLLM CI + # -- this is only necessary for Anyscale ray cluster + os.environ.pop("RAY_RUNTIME_ENV_HOOK", None) + try: + yield + finally: + _cleanup_ray_resources() + os.environ.update({k: v for k, v in saved.items() if v is not None}) + for key in (k for k, v in saved.items() if v is None): + os.environ.pop(key, None) + + +def _cleanup_ray_resources(): + dangling_actors = [ + actor + for actor in list_actors(filters=[("state", "=", "ALIVE")]) + if actor.class_name == "RayWorkerProc" + ] + assert not dangling_actors + + for pg_id, pg_info in ray.util.placement_group_table().items(): + if pg_info["state"] == "CREATED": + pg = PlacementGroup(ray.PlacementGroupID(bytes.fromhex(pg_id))) + ray.util.remove_placement_group(pg) + + +def create_vllm_config( + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + max_model_len: int = 256, + gpu_memory_utilization: float = 0.3, + 
placement_group=None, +) -> VllmConfig: + engine_args = EngineArgs( + model=MODEL, + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + distributed_executor_backend="ray", + enforce_eager=True, + ) + vllm_config = engine_args.create_engine_config() + + if placement_group is not None: + vllm_config.parallel_config.placement_group = placement_group + + return vllm_config + + +def ensure_ray_initialized(): + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True) + + +@pytest.fixture +def create_placement_group(request): + ensure_ray_initialized() + num_gpus = request.param + bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_gpus)] + pg = ray.util.placement_group(bundles, strategy="PACK") + ray.get(pg.ready()) + yield pg + ray.util.remove_placement_group(pg) + + +@pytest.fixture +def executor(request): + """Create a RayExecutorV2 and shut it down after the test.""" + executor = RayExecutorV2(vllm_config=request.param) + yield executor + executor.shutdown() + + +def assert_executor(executor, tp_size, pp_size): + """Common assertions for executor initialization tests.""" + world_size = tp_size * pp_size + expected_output_rank = (pp_size - 1) * tp_size + + assert executor.world_size == world_size + assert len(executor.ray_worker_handles) == world_size + assert len(executor.response_mqs) == world_size + assert executor._get_output_rank() == expected_output_rank + + if pp_size > 1: + assert executor.max_concurrent_batches == pp_size + + executor.check_health() + assert not executor.is_failed + + ranks = sorted(h.rank for h in executor.ray_worker_handles) + assert ranks == list(range(world_size)) + + for handle in executor.ray_worker_handles: + assert handle.node_id is not None + + +@pytest.mark.parametrize("tp_size, pp_size", [(1, 1), (2, 1), (4, 1), (2, 2)]) +def test_ray_v2_executor(tp_size, pp_size): + """Validate RayExecutorV2 with 
various TP/PP configs.""" + world_size = tp_size * pp_size + if world_size > NUM_GPUS: + pytest.skip(f"Need at least {world_size} GPUs") + + vllm_config = create_vllm_config( + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + ) + executor = RayExecutorV2(vllm_config=vllm_config) + try: + assert_executor(executor, tp_size, pp_size) + finally: + executor.shutdown() + + +@pytest.mark.parametrize( + "tp_size, pp_size, create_placement_group", + [(2, 1, 2), (4, 1, 4), (2, 2, 4)], + indirect=["create_placement_group"], +) +def test_ray_v2_executor_pg(tp_size, pp_size, create_placement_group): + """Validate RayExecutorV2 with various TP/PP configs using external PG.""" + world_size = tp_size * pp_size + if world_size > NUM_GPUS: + pytest.skip(f"Need at least {world_size} GPUs") + + vllm_config = create_vllm_config( + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + placement_group=create_placement_group, + ) + executor = RayExecutorV2(vllm_config=vllm_config) + try: + assert_executor(executor, tp_size, pp_size) + finally: + executor.shutdown() + + +@pytest.mark.skipif(NUM_GPUS < 2, reason="Need at least 2 GPUs") +@pytest.mark.parametrize( + "executor", + [create_vllm_config(tensor_parallel_size=2)], + indirect=True, +) +def test_ray_v2_executor_failure_callback(executor): + """Validate failure callback registration.""" + callback_invoked = False + + def test_callback(): + nonlocal callback_invoked + callback_invoked = True + + executor.register_failure_callback(test_callback) + assert not callback_invoked + + executor.is_failed = True + executor.register_failure_callback(test_callback) + assert callback_invoked + + +@pytest.mark.skipif(NUM_GPUS < 2, reason="Need at least 2 GPUs") +@pytest.mark.parametrize( + "executor", + [create_vllm_config(tensor_parallel_size=2)], + indirect=True, +) +def test_ray_v2_executor_collective_rpc(executor): + """Validate collective RPC calls through MessageQueue.""" + executor.check_health() + assert not 
executor.is_failed + assert executor.rpc_broadcast_mq is not None + + +@pytest.mark.skipif(NUM_GPUS < 2, reason="Need at least 2 GPUs") +@pytest.mark.parametrize( + "executor", + [create_vllm_config(tensor_parallel_size=2)], + indirect=True, +) +def test_ray_v2_executor_driver_node_rank_0(executor): + """Validate that driver node workers get the lowest ranks.""" + driver_node = ray.get_runtime_context().get_node_id() + + for handle in executor.ray_worker_handles: + assert handle.node_id == driver_node + + rank0_handle = next(h for h in executor.ray_worker_handles if h.rank == 0) + assert rank0_handle.node_id == driver_node + + +@pytest.mark.skipif(NUM_GPUS < 2, reason="Need at least 2 GPUs") +@pytest.mark.parametrize( + "executor", + [create_vllm_config(tensor_parallel_size=2)], + indirect=True, +) +def test_ray_v2_executor_worker_death(executor): + """Validate executor detects worker death via ray.wait().""" + callback_event = threading.Event() + + def on_failure(): + callback_event.set() + + executor.register_failure_callback(on_failure) + assert not executor.is_failed + + # Kill one worker actor externally + victim = executor.ray_worker_handles[1].actor + ray.kill(victim, no_restart=True) + + # Monitor thread should detect the death and invoke callback + assert callback_event.wait(timeout=30) + assert executor.is_failed + assert executor.shutting_down + + +@pytest.mark.skipif(NUM_GPUS < 2, reason="Need at least 2 GPUs") +def test_ray_v2_executor_shutdown(): + """Validate graceful shutdown: ray.kill() terminates all worker actors.""" + executor = RayExecutorV2(vllm_config=create_vllm_config(tensor_parallel_size=2)) + assert executor.rpc_broadcast_mq is not None + assert len(executor.response_mqs) == executor.world_size + + actors = [h.actor for h in executor.ray_worker_handles] + executor.shutdown() + + for actor in actors: + with pytest.raises(ray.exceptions.RayActorError): + ray.get(actor.wait_for_init.remote(), timeout=5) + + assert executor.rpc_broadcast_mq 
is None + assert len(executor.response_mqs) == 0 + + +@pytest.mark.skipif(NUM_GPUS < 2, reason="Need at least 2 GPUs") +@pytest.mark.parametrize( + "executor", + [create_vllm_config(tensor_parallel_size=2)], + indirect=True, +) +def test_ray_v2_run_refs_stored_for_monitoring(executor): + """Validate worker handles store run_ref for monitoring.""" + for handle in executor.ray_worker_handles: + assert handle.run_ref is not None + ready, _ = ray.wait([handle.run_ref], timeout=0) + assert len(ready) == 0, "run_ref should be pending" diff --git a/vllm/envs.py b/vllm/envs.py index 716810da1c27..eeb2419294bc 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -58,6 +58,7 @@ VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto" VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True + VLLM_USE_RAY_V2_EXECUTOR_BACKEND: bool = False VLLM_XLA_USE_SPMD: bool = False VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") @@ -730,6 +731,12 @@ def _get_or_set_default() -> str: "VLLM_USE_RAY_WRAPPED_PP_COMM": lambda: bool( int(os.getenv("VLLM_USE_RAY_WRAPPED_PP_COMM", "1")) ), + # When True and distributed_executor_backend="ray", use RayExecutorV2 + # (MQ-based) instead of RayDistributedExecutor (compiled-graph backend). + # TODO (jeffreywang): Enabled by default in vLLM 0.20.0. + "VLLM_USE_RAY_V2_EXECUTOR_BACKEND": lambda: bool( + int(os.getenv("VLLM_USE_RAY_V2_EXECUTOR_BACKEND", "0")) + ), # Use dedicated multiprocess context for workers. 
# Both spawn and fork work "VLLM_WORKER_MULTIPROC_METHOD": env_with_choices( diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 8e7c48054554..952c37e6749c 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -7,6 +7,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Literal, TypeVar, overload +import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -57,9 +58,14 @@ def get_class(vllm_config: VllmConfig) -> type["Executor"]: ) executor_class = distributed_executor_backend elif distributed_executor_backend == "ray": - from vllm.v1.executor.ray_executor import RayDistributedExecutor + if envs.VLLM_USE_RAY_V2_EXECUTOR_BACKEND: + from vllm.v1.executor.ray_executor_v2 import RayExecutorV2 - executor_class = RayDistributedExecutor + executor_class = RayExecutorV2 + else: + from vllm.v1.executor.ray_executor import RayDistributedExecutor + + executor_class = RayDistributedExecutor elif distributed_executor_backend == "mp": from vllm.v1.executor.multiproc_executor import MultiprocExecutor diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py new file mode 100644 index 000000000000..50cc4ad287b9 --- /dev/null +++ b/vllm/v1/executor/ray_executor_v2.py @@ -0,0 +1,344 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +import threading +import weakref +from collections import defaultdict, deque +from dataclasses import dataclass +from typing import Any + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.distributed.device_communicators.shm_broadcast import ( + Handle, + MessageQueue, +) +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils.network_utils import ( + 
get_distributed_init_method, + get_ip, + get_loopback_ip, + get_open_port, +) +from vllm.v1.executor.multiproc_executor import ( + FutureWrapper, + MultiprocExecutor, + WorkerProc, +) +from vllm.v1.executor.ray_utils import ( + initialize_ray_cluster, + ray, +) + +if ray is not None: + from ray.actor import ActorHandle + from ray.types import ObjectRef + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +else: + ActorHandle = None + +logger = init_logger(__name__) + + +@dataclass +class RayWorkerHandle: + """Handle for a Ray worker actor, compatible with MultiprocExecutor.""" + + actor: ActorHandle + """Ray worker actor""" + + rank: int + """Rank of the worker""" + + local_rank: int + """Local rank of the worker""" + + node_id: str + """Node ID of the worker""" + + bundle_index: int + """Placement group bundle index to schedule the actor on""" + + run_ref: ObjectRef = None + """run() ObjectRef used as a sentinel for health monitoring""" + + +class RayWorkerProc(WorkerProc): + """Worker process that runs inside a Ray actor.""" + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + input_shm_handle: Handle, + is_driver_worker: bool, + ): + super().__init__( + vllm_config=vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + input_shm_handle=input_shm_handle, + shared_worker_lock=None, + is_driver_worker=is_driver_worker, + ) + self.local_rank = local_rank + + def wait_for_init(self) -> dict: + """Respond to the driver's wait_until_ready() barrier.""" + assert self.worker_response_mq is not None + return { + "status": self.READY_STR, + "handle": self.worker_response_mq.export_handle(), + } + + def run(self) -> None: + """Main entry point called via actor.run.remote().""" + try: + assert self.rpc_broadcast_mq is not None + self.rpc_broadcast_mq.wait_until_ready() + assert self.worker_response_mq is not None + 
self.worker_response_mq.wait_until_ready() + + self.worker_busy_loop() + finally: + self.shutdown() + + +class RayExecutorV2(MultiprocExecutor): + """Ray-based distributed executor using MessageQueue communication. + + Inherits from MultiprocExecutor to reuse the MQ-based control plane + and NCCL data plane. Workers are Ray actors. + """ + + uses_ray: bool = True + supports_pp: bool = True + + def __init__(self, vllm_config: VllmConfig): + # Skip MultiprocExecutor.__init__; we monitor via ray.wait() + self.monitor_workers = False + super(MultiprocExecutor, self).__init__(vllm_config) + + def _init_executor(self) -> None: + """Initialize the RayExecutorV2 executor.""" + self._finalizer = weakref.finalize(self, self.shutdown) + self.is_failed = False + self.failure_callback = None + self.shutting_down = False + + # Step 1: Initialize Ray cluster and retrieve placement group + if ray is None: + raise ImportError("Ray is required for RayExecutorV2") + initialize_ray_cluster(self.parallel_config) + placement_group = self.parallel_config.placement_group + + # Disable Ray usage stats collection + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + + tp_size, pp_size, pcp_size = self._get_parallel_sizes() + assert self.world_size == tp_size * pp_size * pcp_size, ( + f"world_size ({self.world_size}) must be equal to the " + f"tensor_parallel_size ({tp_size}) x pipeline" + f"_parallel_size ({pp_size}) x prefill_context" + f"_parallel_size ({pcp_size}). 
" + ) + + # Step 2: Query PG table, sort bundles, assign ranks + pg_table = ray.util.placement_group_table(placement_group) + bundle_to_node = pg_table["bundles_to_node_id"] + + # Prefer driver node; group by node for TP locality + bundle_to_node_id = [] + for i, bundle in enumerate(placement_group.bundle_specs): + ray_device_key = current_platform.ray_device_key + if not ray_device_key: + raise ValueError( + f"current platform {current_platform.device_name}" + " does not support ray." + ) + + if bundle.get(ray_device_key): + node_id = bundle_to_node.get(i) or bundle_to_node.get(str(i)) + bundle_to_node_id.append((i, node_id)) + + bundle_to_node_id = bundle_to_node_id[: self.world_size] + driver_node = ray.get_runtime_context().get_node_id() + + def _sort_key(item): + _, node_id = item + return (0 if node_id == driver_node else 1, node_id) + + bundle_to_node_id.sort(key=_sort_key) + + # Assign each worker a local rank + node_rank_counter: dict[str, int] = defaultdict(int) + bundle_assignments: list[dict[str, Any]] = [] + for rank, (bundle_id, node_id) in enumerate(bundle_to_node_id): + local_rank = node_rank_counter[node_id] + node_rank_counter[node_id] += 1 + bundle_assignments.append( + { + "rank": rank, + "local_rank": local_rank, + "bundle_id": bundle_id, + "node_id": node_id, + } + ) + + # Determine node topology + node_ids = list(dict.fromkeys(a["node_id"] for a in bundle_assignments)) + is_single_node = len(node_ids) == 1 + + # Step 3: Create broadcast MessageQueue + distributed_init_method = get_distributed_init_method( + get_loopback_ip() if is_single_node else get_ip(), get_open_port() + ) + + max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024 + mq_connect_ip = get_ip() + self.rpc_broadcast_mq = MessageQueue( + self.world_size, + self.local_world_size, + max_chunk_bytes=max_chunk_bytes, + connect_ip=mq_connect_ip, + ) + scheduler_output_handle = self.rpc_broadcast_mq.export_handle() + + # Step 4: Spawn RayWorkerProc actors into PG bundles + 
self.ray_worker_handles: list[RayWorkerHandle] = [] + self._ray_actors: list[Any] = [] + + # Create the remote actor + for assignment in bundle_assignments: + is_driver_worker = self._is_driver_worker(assignment["rank"]) + + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=assignment["bundle_id"], + ) + + # Prevent Ray from setting CUDA_VISIBLE_DEVICES + runtime_env = { + "env_vars": { + env_var: "1" + for env_var in current_platform.ray_noset_device_env_vars + }, + } + + actor = ( + ray.remote(RayWorkerProc) + .options( + num_cpus=0, + num_gpus=envs.VLLM_RAY_PER_WORKER_GPUS, + scheduling_strategy=scheduling_strategy, + runtime_env=runtime_env, + ) + .remote( + vllm_config=self.vllm_config, + local_rank=assignment["local_rank"], + rank=assignment["rank"], + distributed_init_method=distributed_init_method, + input_shm_handle=scheduler_output_handle, + is_driver_worker=is_driver_worker, + ) + ) + + handle = RayWorkerHandle( + actor=actor, + rank=assignment["rank"], + local_rank=assignment["local_rank"], + node_id=assignment["node_id"], + bundle_index=assignment["bundle_id"], + ) + self.ray_worker_handles.append(handle) + self._ray_actors.append(actor) + + # Step 5: Collect response MQ handles + init_refs = [h.actor.wait_for_init.remote() for h in self.ray_worker_handles] + init_results = ray.get(init_refs) + + self.response_mqs: list[MessageQueue] = [] + for i, result in enumerate(init_results): + if result["status"] != RayWorkerProc.READY_STR: + raise RuntimeError(f"Worker {i} failed to initialize: {result}") + self.response_mqs.append( + MessageQueue.create_from_handle(result["handle"], 0) + ) + + # Step 6: Start run() before wait_until_ready() to avoid + # deadlock — workers send subscriptions inside run(). 
+ for handle in self.ray_worker_handles: + handle.run_ref = handle.actor.run.remote() + + # Step 7: wait_until_ready() barrier + self.rpc_broadcast_mq.wait_until_ready() + for response_mq in self.response_mqs: + response_mq.wait_until_ready() + + self.futures_queue = deque[tuple[FutureWrapper, Any]]() + self._post_init_executor() + + self.start_worker_monitor() + self.output_rank = self._get_output_rank() + + def start_worker_monitor(self, inline=False) -> None: + """Monitor worker liveness via ray.wait() on run() ObjectRefs.""" + run_refs = [h.run_ref for h in self.ray_worker_handles if h.run_ref is not None] + if not run_refs: + return + + self_ref = weakref.ref(self) + + ref_to_rank = { + h.run_ref: h.rank for h in self.ray_worker_handles if h.run_ref is not None + } + + def monitor_workers(): + done, _ = ray.wait(run_refs, num_returns=1) + executor = self_ref() + if not executor or executor.shutting_down: + return + + dead_ranks = [ref_to_rank[ref] for ref in done if ref in ref_to_rank] + executor.is_failed = True + logger.error( + "RayWorkerProc rank=%s died unexpectedly, shutting down executor.", + dead_ranks, + ) + executor.shutdown() + + if executor.failure_callback is not None: + callback = executor.failure_callback + executor.failure_callback = None + callback() + + threading.Thread( + target=monitor_workers, daemon=True, name="RayWorkerMonitor" + ).start() + + def shutdown(self) -> None: + """Properly shut down the executor and its workers""" + if getattr(self, "shutting_down", False): + return + self.shutting_down = True + + for handle in getattr(self, "ray_worker_handles", []): + try: + ray.kill(handle.actor) + except Exception: + logger.exception("Failed to kill actor rank=%d", handle.rank) + + if rpc_broadcast_mq := getattr(self, "rpc_broadcast_mq", None): + rpc_broadcast_mq.shutdown() + self.rpc_broadcast_mq = None + + for mq in getattr(self, "response_mqs", []): + mq.shutdown() + self.response_mqs = [] From df756649c8f52d0f9744a3ac75969a69d3848095 
Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Wed, 11 Mar 2026 18:18:49 -0700 Subject: [PATCH 02/29] Enable multinode Signed-off-by: Jeffrey Wang --- .buildkite/test_areas/distributed.yaml | 29 ++ tests/distributed/test_mq_tcp_multinode.py | 120 +++++++++ tests/distributed/test_ray_v2_executor.py | 114 ++++++-- .../test_ray_v2_executor_multinode.py | 252 ++++++++++++++++++ vllm/v1/executor/ray_executor_v2.py | 116 +++++--- vllm/v1/executor/ray_utils.py | 41 ++- 6 files changed, 605 insertions(+), 67 deletions(-) create mode 100644 tests/distributed/test_mq_tcp_multinode.py create mode 100644 tests/distributed/test_ray_v2_executor_multinode.py diff --git a/.buildkite/test_areas/distributed.yaml b/.buildkite/test_areas/distributed.yaml index 5b13cb8fe6b7..8fee8dacc266 100644 --- a/.buildkite/test_areas/distributed.yaml +++ b/.buildkite/test_areas/distributed.yaml @@ -180,6 +180,20 @@ steps: commands: - ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 2 $IMAGE_TAG "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=0 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py && VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py" "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' && NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 
'Node count test passed' && python3 ../examples/offline_inference/data_parallel.py -dp=2 -tp=1 --dp-num-nodes=2 --dp-node-rank=1 --dp-master-addr=192.168.10.10 --dp-master-port=12345 --enforce-eager --trust-remote-code" +- label: MessageQueue TCP Multi-Node (2 GPUs) + timeout_in_minutes: 10 + working_dir: "/vllm-workspace/tests" + num_devices: 1 + num_nodes: 2 + no_plugin: true + optional: true + source_file_dependencies: + - vllm/distributed/device_communicators/shm_broadcast.py + - vllm/distributed/parallel_state.py + - tests/distributed/test_mq_tcp_multinode.py + commands: + - ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 1 $IMAGE_TAG "torchrun --nnodes 2 --nproc-per-node=1 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_mq_tcp_multinode.py" "torchrun --nnodes 2 --nproc-per-node=1 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_mq_tcp_multinode.py" + - label: Distributed NixlConnector PD accuracy (4 GPUs) timeout_in_minutes: 30 working_dir: "/vllm-workspace/tests" @@ -257,3 +271,18 @@ steps: - pytest -v -s distributed/test_ray_v2_executor.py - pytest -v -s distributed/test_pipeline_parallel.py -k "ray" - TARGET_TEST_SUITE=L4 pytest -v -s basic_correctness/test_basic_correctness.py -k "ray" + +- label: RayExecutorV2 Multi-Node (8 GPUs) + timeout_in_minutes: 30 + working_dir: "/vllm-workspace/tests" + num_devices: 4 + num_nodes: 2 + no_plugin: true + optional: true + source_file_dependencies: + - vllm/v1/executor/ray_executor_v2.py + - vllm/v1/executor/abstract.py + - vllm/v1/executor/multiproc_executor.py + - tests/distributed/test_ray_v2_executor_multinode.py + commands: + - ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 4 $IMAGE_TAG "VLLM_MULTI_NODE=1 VLLM_USE_RAY_V2_EXECUTOR_BACKEND=1 NCCL_CUMEM_HOST_ENABLE=0 pytest -v -s distributed/test_ray_v2_executor_multinode.py" "echo 'Worker node ready'" diff --git a/tests/distributed/test_mq_tcp_multinode.py 
b/tests/distributed/test_mq_tcp_multinode.py new file mode 100644 index 000000000000..61f0b239b675 --- /dev/null +++ b/tests/distributed/test_mq_tcp_multinode.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Multi-node integration test for MessageQueue TCP fallback. + +Verifies that when writer and readers span separate nodes (Docker containers +with isolated /dev/shm), `create_from_process_group` correctly detects +cross-node ranks via `in_the_same_node_as()` and falls back to ZMQ TCP +transport — and that data actually arrives. +""" + +import numpy as np +import torch.distributed as dist + +from vllm.distributed.device_communicators.shm_broadcast import MessageQueue +from vllm.distributed.parallel_state import in_the_same_node_as + + +def main(): + dist.init_process_group(backend="gloo") + + rank = dist.get_rank() + world_size = dist.get_world_size() + assert world_size >= 2, ( + f"Need at least 2 ranks across nodes, got world_size={world_size}" + ) + + # Verify that in_the_same_node_as detects cross-node correctly + status = in_the_same_node_as(dist.group.WORLD, source_rank=0) + local_count = sum(status) + print( + f"[Rank {rank}] in_the_same_node_as(source=0): {status} " + f"(local={local_count}/{world_size})" + ) + # With 2 Docker containers (1 proc each), rank 0 and rank 1 should be on different nodes. + assert local_count < world_size, ( + f"Expected cross-node ranks but all {world_size} ranks appear local." 
+ ) + + # Create MessageQueue + writer_rank = 0 + mq = MessageQueue.create_from_process_group( + dist.group.WORLD, + max_chunk_bytes=1024 * 1024, # 1 MiB + max_chunks=10, + writer_rank=writer_rank, + ) + + # Verify the transport path selection + if rank == writer_rank: + print( + f"[Rank {rank}] Writer: n_local_reader={mq.n_local_reader}, " + f"n_remote_reader={mq.n_remote_reader}" + ) + assert mq.n_remote_reader > 0, ( + "Writer should have at least 1 remote (TCP) reader in a " + "multi-node setup." + ) + else: + if status[rank]: + assert mq._is_local_reader, ( + f"Rank {rank} is on the same node as writer but is not a " + "local reader." + ) + print(f"[Rank {rank}] Reader: local (shared memory)") + else: + assert mq._is_remote_reader, ( + f"Rank {rank} is on a different node but is not a remote " + "(TCP) reader." + ) + print(f"[Rank {rank}] Reader: remote (TCP)") + + # Test data transfer: simple objects + dist.barrier() + if rank == writer_rank: + mq.enqueue("hello_from_node0") + else: + msg = mq.dequeue(timeout=10) + assert msg == "hello_from_node0" + dist.barrier() + print(f"[Rank {rank}] Simple object test passed") + + # Test data transfer: numpy arrays + np.random.seed(42) + arrays = [np.random.randint(0, 100, size=np.random.randint(100, 5000)) + for _ in range(100)] + + dist.barrier() + if rank == writer_rank: + for arr in arrays: + mq.enqueue(arr) + else: + for i, expected in enumerate(arrays): + received = mq.dequeue(timeout=10) + assert np.array_equal(expected, received), ( + f"Array mismatch at index {i}: " + f"expected shape {expected.shape}, got shape {received.shape}" + ) + dist.barrier() + print(f"[Rank {rank}] Numpy array test passed") + + # Test data transfer: large payload (> max_chunk_bytes) + dist.barrier() + big_array = np.zeros(200_000, dtype=np.int64) # ~1.6 MiB > 1 MiB chunk + if rank == writer_rank: + mq.enqueue(big_array) + else: + received = mq.dequeue(timeout=10) + assert np.array_equal(big_array, received) + dist.barrier() + 
print(f"[Rank {rank}] Large payload test passed") + + # Done -- cleanup + dist.barrier() + print(f"[Rank {rank}] All MessageQueue TCP multi-node tests passed!") + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/distributed/test_ray_v2_executor.py b/tests/distributed/test_ray_v2_executor.py index 929b0c098297..4c799d69ecf9 100644 --- a/tests/distributed/test_ray_v2_executor.py +++ b/tests/distributed/test_ray_v2_executor.py @@ -7,21 +7,23 @@ and distributed execution with various TP/PP configurations. """ +import gc import os import threading +import time +from unittest.mock import patch import pytest import ray -import torch from ray.util.placement_group import PlacementGroup from ray.util.state import list_actors +from vllm import LLM from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.v1.executor.ray_executor_v2 import RayExecutorV2 MODEL = "facebook/opt-125m" -NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0 @pytest.fixture(autouse=True) @@ -32,10 +34,15 @@ def enable_ray_v2_backend(): "VLLM_USE_RAY_V2_EXECUTOR_BACKEND" ), "RAY_RUNTIME_ENV_HOOK": os.environ.get("RAY_RUNTIME_ENV_HOOK"), + "VLLM_ENABLE_V1_MULTIPROCESSING": os.environ.get( + "VLLM_ENABLE_V1_MULTIPROCESSING" + ), } os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" + # TODO (jeffreywang): Is this necessary? + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" # TODO (jeffreywang): Figure out vLLM CI - # -- this is only necessary for Anyscale ray cluster + # This is only necessary for Anyscale ray cluster. 
os.environ.pop("RAY_RUNTIME_ENV_HOOK", None) try: yield @@ -47,11 +54,19 @@ def enable_ray_v2_backend(): def _cleanup_ray_resources(): - dangling_actors = [ - actor - for actor in list_actors(filters=[("state", "=", "ALIVE")]) - if actor.class_name == "RayWorkerProc" - ] + if not ray.is_initialized(): + return + + # Ray actor shutdown is async -- wait until all actors are dead + for _ in range(10): + dangling_actors = [ + actor + for actor in list_actors(filters=[("state", "=", "ALIVE")]) + if actor.class_name == "RayWorkerProc" + ] + if not dangling_actors: + break + time.sleep(1) assert not dangling_actors for pg_id, pg_info in ray.util.placement_group_table().items(): @@ -59,6 +74,8 @@ def _cleanup_ray_resources(): pg = PlacementGroup(ray.PlacementGroupID(bytes.fromhex(pg_id))) ray.util.remove_placement_group(pg) + ray.shutdown() + def create_vllm_config( tensor_parallel_size: int = 1, @@ -134,10 +151,6 @@ def assert_executor(executor, tp_size, pp_size): @pytest.mark.parametrize("tp_size, pp_size", [(1, 1), (2, 1), (4, 1), (2, 2)]) def test_ray_v2_executor(tp_size, pp_size): """Validate RayExecutorV2 with various TP/PP configs.""" - world_size = tp_size * pp_size - if world_size > NUM_GPUS: - pytest.skip(f"Need at least {world_size} GPUs") - vllm_config = create_vllm_config( tensor_parallel_size=tp_size, pipeline_parallel_size=pp_size, @@ -156,10 +169,6 @@ def test_ray_v2_executor(tp_size, pp_size): ) def test_ray_v2_executor_pg(tp_size, pp_size, create_placement_group): """Validate RayExecutorV2 with various TP/PP configs using external PG.""" - world_size = tp_size * pp_size - if world_size > NUM_GPUS: - pytest.skip(f"Need at least {world_size} GPUs") - vllm_config = create_vllm_config( tensor_parallel_size=tp_size, pipeline_parallel_size=pp_size, @@ -172,7 +181,6 @@ def test_ray_v2_executor_pg(tp_size, pp_size, create_placement_group): executor.shutdown() -@pytest.mark.skipif(NUM_GPUS < 2, reason="Need at least 2 GPUs") @pytest.mark.parametrize( "executor", 
[create_vllm_config(tensor_parallel_size=2)], @@ -194,7 +202,6 @@ def test_callback(): assert callback_invoked -@pytest.mark.skipif(NUM_GPUS < 2, reason="Need at least 2 GPUs") @pytest.mark.parametrize( "executor", [create_vllm_config(tensor_parallel_size=2)], @@ -207,7 +214,6 @@ def test_ray_v2_executor_collective_rpc(executor): assert executor.rpc_broadcast_mq is not None -@pytest.mark.skipif(NUM_GPUS < 2, reason="Need at least 2 GPUs") @pytest.mark.parametrize( "executor", [create_vllm_config(tensor_parallel_size=2)], @@ -224,7 +230,6 @@ def test_ray_v2_executor_driver_node_rank_0(executor): assert rank0_handle.node_id == driver_node -@pytest.mark.skipif(NUM_GPUS < 2, reason="Need at least 2 GPUs") @pytest.mark.parametrize( "executor", [create_vllm_config(tensor_parallel_size=2)], @@ -250,7 +255,6 @@ def on_failure(): assert executor.shutting_down -@pytest.mark.skipif(NUM_GPUS < 2, reason="Need at least 2 GPUs") def test_ray_v2_executor_shutdown(): """Validate graceful shutdown: ray.kill() terminates all worker actors.""" executor = RayExecutorV2(vllm_config=create_vllm_config(tensor_parallel_size=2)) @@ -268,7 +272,6 @@ def test_ray_v2_executor_shutdown(): assert len(executor.response_mqs) == 0 -@pytest.mark.skipif(NUM_GPUS < 2, reason="Need at least 2 GPUs") @pytest.mark.parametrize( "executor", [create_vllm_config(tensor_parallel_size=2)], @@ -280,3 +283,70 @@ def test_ray_v2_run_refs_stored_for_monitoring(executor): assert handle.run_ref is not None ready, _ = ray.wait([handle.run_ref], timeout=0) assert len(ready) == 0, "run_ref should be pending" + + +@pytest.mark.parametrize("tp_size, pp_size", [(2, 1), (2, 2)]) +def test_ray_v2_single_node_generation(tp_size, pp_size): + """End-to-end LLM generation with RayExecutorV2.""" + + llm = LLM( + model=MODEL, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + distributed_executor_backend="ray", + enforce_eager=True, + max_model_len=256, + gpu_memory_utilization=0.3, + ) + try: + prompts = [ + 
"Hello, my name is", + "The capital of France is", + "The future of AI is", + ] + outputs = llm.generate(prompts) + + assert len(outputs) == len(prompts) + for output in outputs: + assert len(output.outputs) > 0 + assert len(output.outputs[0].text) > 0 + finally: + llm.llm_engine.model_executor.shutdown() + del llm + gc.collect() + + +@pytest.mark.parametrize("tp_size, pp_size", [(2, 1), (2, 2)]) +def test_ray_v2_single_node_generation_with_pg(tp_size, pp_size): + """E2E LLM generation with a user-provided placement group.""" + ensure_ray_initialized() + bundles = [{"GPU": 1, "CPU": 1} for _ in range(tp_size * pp_size)] + pg = ray.util.placement_group(bundles, strategy="PACK") + ray.get(pg.ready()) + + try: + with patch.object(ray.util, "get_current_placement_group", return_value=pg): + llm = LLM( + model=MODEL, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + distributed_executor_backend="ray", + enforce_eager=True, + max_model_len=256, + gpu_memory_utilization=0.3, + ) + prompts = [ + "Hello, my name is", + "The capital of France is", + "The future of AI is", + ] + outputs = llm.generate(prompts) + + assert len(outputs) == len(prompts) + for output in outputs: + assert len(output.outputs) > 0 + assert len(output.outputs[0].text) > 0 + finally: + llm.llm_engine.model_executor.shutdown() + del llm + gc.collect() diff --git a/tests/distributed/test_ray_v2_executor_multinode.py b/tests/distributed/test_ray_v2_executor_multinode.py new file mode 100644 index 000000000000..09538b0f8557 --- /dev/null +++ b/tests/distributed/test_ray_v2_executor_multinode.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Multi-node integration tests for RayExecutorV2. +Validates executor initialization, worker placement, and end-to-end +generation across multiple nodes with various TP/PP configurations. + +Requires VLLM_MULTI_NODE=1 env var and a multi-node Ray cluster. 
+ +Run: +```sh +VLLM_MULTI_NODE=1 VLLM_USE_RAY_V2_EXECUTOR_BACKEND=1 \ + pytest -v -s distributed/test_ray_v2_executor_multinode.py +``` +""" + +import gc +import os +import time +from unittest.mock import patch + +import pytest +import ray +from ray.util.placement_group import PlacementGroup +from ray.util.state import list_actors + +from vllm import LLM +from vllm.config import VllmConfig +from vllm.engine.arg_utils import EngineArgs +from vllm.v1.executor.ray_executor_v2 import RayExecutorV2 + +MODEL = "facebook/opt-125m" + +VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" + +pytestmark = pytest.mark.skipif( + not VLLM_MULTI_NODE, reason="Need VLLM_MULTI_NODE=1 and a multi-node cluster." +) + + +@pytest.fixture(autouse=True) +def enable_ray_v2_backend(): + """Enable the RayExecutorV2 backend via feature flag for all tests.""" + saved = { + "VLLM_USE_RAY_V2_EXECUTOR_BACKEND": os.environ.get( + "VLLM_USE_RAY_V2_EXECUTOR_BACKEND" + ), + "RAY_RUNTIME_ENV_HOOK": os.environ.get("RAY_RUNTIME_ENV_HOOK"), + "VLLM_ENABLE_V1_MULTIPROCESSING": os.environ.get( + "VLLM_ENABLE_V1_MULTIPROCESSING" + ), + } + os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" + os.environ.pop("RAY_RUNTIME_ENV_HOOK", None) + # Disable multiprocessing to avoid fork-after-Ray-init issues. + # The RayExecutorV2 already handles distribution via Ray actors, + # so the EngineCore can safely run in-process. 
+ os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + try: + yield + finally: + _cleanup_ray_resources() + os.environ.update({k: v for k, v in saved.items() if v is not None}) + for key in (k for k, v in saved.items() if v is None): + os.environ.pop(key, None) + + +def _cleanup_ray_resources(): + if not ray.is_initialized(): + return + + # Wait briefly for async cleanup (del llm triggers GC-based shutdown) + for _ in range(10): + dangling_actors = [ + actor + for actor in list_actors(filters=[("state", "=", "ALIVE")]) + if actor.class_name == "RayWorkerProc" + ] + if not dangling_actors: + break + time.sleep(1) + + for pg_id, pg_info in ray.util.placement_group_table().items(): + if pg_info["state"] == "CREATED": + pg = PlacementGroup(ray.PlacementGroupID(bytes.fromhex(pg_id))) + ray.util.remove_placement_group(pg) + + # Disconnect from Ray so forked subprocesses (EngineCore) don't inherit + # a stale driver connection that can't create placement groups. + ray.shutdown() + + +def create_vllm_config( + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + max_model_len: int = 256, + gpu_memory_utilization: float = 0.3, +) -> VllmConfig: + engine_args = EngineArgs( + model=MODEL, + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + distributed_executor_backend="ray", + enforce_eager=True, + ) + return engine_args.create_engine_config() + + +def assert_executor(executor, tp_size, pp_size): + """Common assertions for executor initialization tests.""" + world_size = tp_size * pp_size + expected_output_rank = (pp_size - 1) * tp_size + + assert executor.world_size == world_size + assert len(executor.ray_worker_handles) == world_size + assert len(executor.response_mqs) == world_size + assert executor._get_output_rank() == expected_output_rank + + if pp_size > 1: + assert executor.max_concurrent_batches == pp_size + + 
executor.check_health() + assert not executor.is_failed + + ranks = sorted(h.rank for h in executor.ray_worker_handles) + assert ranks == list(range(world_size)) + + for handle in executor.ray_worker_handles: + assert handle.node_id is not None + + +def test_ray_v2_multinode_executor_init(): + """Validate RayExecutorV2 initializes correctly across multiple nodes + with TP=4, PP=2 (8 GPUs).""" + vllm_config = create_vllm_config( + tensor_parallel_size=4, + pipeline_parallel_size=2, + ) + executor = RayExecutorV2(vllm_config=vllm_config) + try: + assert_executor(executor, tp_size=4, pp_size=2) + + # Verify workers span multiple nodes + node_ids = {h.node_id for h in executor.ray_worker_handles} + assert len(node_ids) > 1 + + # Verify rank 0 exists and has a valid node_id. + # On clusters where the driver node has GPUs, rank 0 will be there. + # On GPU-less head nodes, rank 0 is on the first GPU node instead. + rank0_handle = next(h for h in executor.ray_worker_handles if h.rank == 0) + assert rank0_handle.node_id is not None + finally: + executor.shutdown() + + +def test_ray_v2_multinode_worker_placement(): + """Verify TP locality: workers in the same TP group share a node.""" + tp_size = 4 + pp_size = 2 + + vllm_config = create_vllm_config( + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + ) + executor = RayExecutorV2(vllm_config=vllm_config) + try: + # Workers are sorted by rank; consecutive tp_size ranks form a TP group + for pp_rank in range(pp_size): + start_rank = pp_rank * tp_size + tp_group_handles = [ + h + for h in executor.ray_worker_handles + if start_rank <= h.rank < start_rank + tp_size + ] + tp_group_nodes = {h.node_id for h in tp_group_handles} + assert len(tp_group_nodes) == 1 + + # Workers should be distributed across > 1 node + all_nodes = {h.node_id for h in executor.ray_worker_handles} + assert len(all_nodes) > 1 + finally: + executor.shutdown() + + +def test_ray_v2_multinode_generation(): + """End-to-end LLM generation with 
TP=4, PP=2 across multiple nodes.""" + llm = LLM( + model=MODEL, + tensor_parallel_size=4, + pipeline_parallel_size=2, + distributed_executor_backend="ray", + enforce_eager=True, + max_model_len=256, + gpu_memory_utilization=0.3, + ) + try: + prompts = [ + "Hello, my name is", + "The capital of France is", + "The future of AI is", + ] + outputs = llm.generate(prompts) + + assert len(outputs) == len(prompts) + for output in outputs: + assert len(output.outputs) > 0 + assert len(output.outputs[0].text) > 0 + finally: + llm.llm_engine.model_executor.shutdown() + del llm + gc.collect() + + +@pytest.mark.parametrize("tp_size, pp_size", [(4, 2), (2, 4)]) +def test_ray_v2_multinode_generation_with_pg(tp_size, pp_size): + """E2E LLM generation with a user-provided placement group across nodes.""" + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True) + + bundles = [{"GPU": 1, "CPU": 1} for _ in range(tp_size * pp_size)] + pg = ray.util.placement_group(bundles, strategy="PACK") + ray.get(pg.ready()) + + try: + with patch.object(ray.util, "get_current_placement_group", return_value=pg): + llm = LLM( + model=MODEL, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + distributed_executor_backend="ray", + enforce_eager=True, + max_model_len=256, + gpu_memory_utilization=0.3, + ) + prompts = [ + "Hello, my name is", + "The capital of France is", + "The future of AI is", + ] + outputs = llm.generate(prompts) + + assert len(outputs) == len(prompts) + for output in outputs: + assert len(output.outputs) > 0 + assert len(output.outputs[0].text) > 0 + finally: + llm.llm_engine.model_executor.shutdown() + del llm + gc.collect() diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index 50cc4ad287b9..a9909fb8bada 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -75,7 +75,9 @@ def __init__( distributed_init_method: str, input_shm_handle: Handle, is_driver_worker: bool, + 
same_node_as_executor: bool = False, ): + self._same_node_as_executor = same_node_as_executor super().__init__( vllm_config=vllm_config, local_rank=local_rank, @@ -87,6 +89,24 @@ def __init__( ) self.local_rank = local_rank + def _init_message_queues( + self, input_shm_handle: Handle, vllm_config: VllmConfig + ) -> None: + """ + Workers on the same node as the executor use shared memory for + both the broadcast (input) MQ and the response MQ. Workers on + different nodes use TCP (n_local_reader=0). + """ + self.rpc_broadcast_mq = MessageQueue.create_from_handle( + input_shm_handle, self.worker.rank + ) + + n_local = 1 if self._same_node_as_executor else 0 + self.worker_response_mq = MessageQueue( + n_reader=1, n_local_reader=n_local, connect_ip=get_ip() + ) + self.peer_response_handles = [] + def wait_for_init(self) -> dict: """Respond to the driver's wait_until_ready() barrier.""" assert self.worker_response_mq is not None @@ -133,7 +153,7 @@ def _init_executor(self) -> None: # Step 1: Initialize Ray cluster and retrieve placement group if ray is None: raise ImportError("Ray is required for RayExecutorV2") - initialize_ray_cluster(self.parallel_config) + initialize_ray_cluster(self.parallel_config, require_gpu_on_driver=False) placement_group = self.parallel_config.placement_group # Disable Ray usage stats collection @@ -195,22 +215,34 @@ def _sort_key(item): node_ids = list(dict.fromkeys(a["node_id"] for a in bundle_assignments)) is_single_node = len(node_ids) == 1 - # Step 3: Create broadcast MessageQueue - distributed_init_method = get_distributed_init_method( - get_loopback_ip() if is_single_node else get_ip(), get_open_port() - ) + # Step 3: Resolve the IP for torch.distributed TCPStore. + # The TCPStore server runs on rank 0's node, so all workers + # must be able to reach this address. 
+ rank0_node_id = bundle_assignments[0]["node_id"] + if is_single_node: + dist_ip = get_loopback_ip() + elif rank0_node_id == driver_node: + dist_ip = get_ip() + else: + node_id_to_ip = { + n["NodeID"]: n["NodeManagerAddress"] for n in ray.nodes() if n["Alive"] + } + dist_ip = node_id_to_ip[rank0_node_id] + distributed_init_method = get_distributed_init_method(dist_ip, get_open_port()) + # Step 4: Create broadcast MessageQueue. + # Workers on the driver node use shared memory; the rest use TCP. max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024 - mq_connect_ip = get_ip() + n_local = sum(1 for a in bundle_assignments if a["node_id"] == driver_node) self.rpc_broadcast_mq = MessageQueue( self.world_size, - self.local_world_size, + n_local, max_chunk_bytes=max_chunk_bytes, - connect_ip=mq_connect_ip, + connect_ip=get_ip(), ) scheduler_output_handle = self.rpc_broadcast_mq.export_handle() - # Step 4: Spawn RayWorkerProc actors into PG bundles + # Step 5: Spawn RayWorkerProc actors into PG bundles self.ray_worker_handles: list[RayWorkerHandle] = [] self._ray_actors: list[Any] = [] @@ -247,6 +279,7 @@ def _sort_key(item): distributed_init_method=distributed_init_method, input_shm_handle=scheduler_output_handle, is_driver_worker=is_driver_worker, + same_node_as_executor=(assignment["node_id"] == driver_node), ) ) @@ -260,7 +293,7 @@ def _sort_key(item): self.ray_worker_handles.append(handle) self._ray_actors.append(actor) - # Step 5: Collect response MQ handles + # Step 6: Collect response MQ handles init_refs = [h.actor.wait_for_init.remote() for h in self.ray_worker_handles] init_results = ray.get(init_refs) @@ -272,12 +305,12 @@ def _sort_key(item): MessageQueue.create_from_handle(result["handle"], 0) ) - # Step 6: Start run() before wait_until_ready() to avoid + # Step 7: Start run() before wait_until_ready() to avoid # deadlock — workers send subscriptions inside run(). 
for handle in self.ray_worker_handles: handle.run_ref = handle.actor.run.remote() - # Step 7: wait_until_ready() barrier + # Step 8: wait_until_ready() barrier self.rpc_broadcast_mq.wait_until_ready() for response_mq in self.response_mqs: response_mq.wait_until_ready() @@ -295,33 +328,47 @@ def start_worker_monitor(self, inline=False) -> None: return self_ref = weakref.ref(self) - ref_to_rank = { h.run_ref: h.rank for h in self.ray_worker_handles if h.run_ref is not None } - def monitor_workers(): - done, _ = ray.wait(run_refs, num_returns=1) + def _should_stop() -> bool: executor = self_ref() - if not executor or executor.shutting_down: - return - - dead_ranks = [ref_to_rank[ref] for ref in done if ref in ref_to_rank] - executor.is_failed = True - logger.error( - "RayWorkerProc rank=%s died unexpectedly, shutting down executor.", - dead_ranks, - ) - executor.shutdown() + return not executor or executor.shutting_down - if executor.failure_callback is not None: - callback = executor.failure_callback - executor.failure_callback = None - callback() + def monitor_workers(): + # TODO (jeffreywang): Is there a better way? + # Poll with timeout; a blocking ray.wait() would segfault + # if Ray is torn down while this thread is waiting. 
+ while not _should_stop() and ray.is_initialized(): + try: + done, _ = ray.wait(run_refs, num_returns=1, timeout=5.0) + except Exception: + return + if not done or _should_stop(): + continue + + dead_ranks = [ref_to_rank[r] for r in done if r in ref_to_rank] + executor = self_ref() + if not executor: + return + executor.is_failed = True + logger.error( + "RayWorkerProc rank=%s died unexpectedly, shutting down executor.", + dead_ranks, + ) + executor.shutdown() + if executor.failure_callback is not None: + callback = executor.failure_callback + executor.failure_callback = None + callback() + return - threading.Thread( + t = threading.Thread( target=monitor_workers, daemon=True, name="RayWorkerMonitor" - ).start() + ) + t.start() + self._monitor_thread = t def shutdown(self) -> None: """Properly shut down the executor and its workers""" @@ -329,6 +376,13 @@ def shutdown(self) -> None: return self.shutting_down = True + # Wait for the monitor thread to exit before tearing down Ray + # resources — it may be inside ray.wait() which would segfault + # if Ray is shut down underneath it. + monitor = getattr(self, "_monitor_thread", None) + if monitor is not None and monitor.is_alive(): + monitor.join(timeout=10) + for handle in getattr(self, "ray_worker_handles", []): try: ray.kill(handle.actor) diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py index dd82cfb99aac..f86db5cfb13a 100644 --- a/vllm/v1/executor/ray_utils.py +++ b/vllm/v1/executor/ray_utils.py @@ -214,13 +214,17 @@ def assert_ray_available(): def _verify_bundles( - placement_group: "PlacementGroup", parallel_config: ParallelConfig, device_str: str + placement_group: "PlacementGroup", + parallel_config: ParallelConfig, + device_str: str, + require_gpu_on_driver: bool = True, ): """Verify a given placement group has bundles located in the right place. There are 2 rules. - Warn if all tensor parallel workers cannot fit in a single node. 
- - Fail if driver node is not included in a placement group. + - Fail if driver node is not included in a placement group + (only when require_gpu_on_driver is True). """ assert ray.is_initialized(), ( "Ray is not initialized although distributed-executor-backend is ray." @@ -237,7 +241,7 @@ def _verify_bundles( node_id_to_bundle[node_id].append(bundles[bundle_idx]) driver_node_id = ray.get_runtime_context().get_node_id() - if driver_node_id not in node_id_to_bundle: + if require_gpu_on_driver and driver_node_id not in node_id_to_bundle: raise RuntimeError( f"driver node id {driver_node_id} is not included in a placement " f"group {placement_group.id}. Node id -> bundles " @@ -352,6 +356,7 @@ def _wait_until_pg_removed(current_placement_group: "PlacementGroup"): def initialize_ray_cluster( parallel_config: ParallelConfig, ray_address: str | None = None, + require_gpu_on_driver: bool = True, ): """Initialize the distributed cluster with Ray. @@ -363,6 +368,10 @@ def initialize_ray_cluster( parallel_config: The configurations for parallel execution. ray_address: The address of the Ray cluster. If None, uses the default Ray cluster address. + require_gpu_on_driver: If True (default), require at least one GPU + on the current (driver) node and pin the first PG bundle to it. + Set to False for executors like RayExecutorV2 where all GPU work + is delegated to remote Ray actors. """ assert_ray_available() from vllm.platforms import current_platform @@ -461,16 +470,18 @@ def initialize_ray_cluster( current_ip = get_ip() current_node_id = ray.get_runtime_context().get_node_id() current_node_resource = available_resources_per_node()[current_node_id] - if current_node_resource.get(device_str, 0) < 1: - raise ValueError( - f"Current node has no {device_str} available. " - f"{current_node_resource=}. vLLM engine cannot start without " - f"{device_str}. Make sure you have at least 1 {device_str} " - f"available in a node {current_node_id=} {current_ip=}." 
- ) - # This way, at least bundle is required to be created in a current - # node. - placement_group_specs[0][f"node:{current_ip}"] = 0.001 + if require_gpu_on_driver: + if current_node_resource.get(device_str, 0) < 1: + raise ValueError( + f"Current node has no {device_str} available. " + f"{current_node_resource=}. vLLM engine cannot start " + f"without {device_str}. Make sure you have at least 1 " + f"{device_str} available in a node " + f"{current_node_id=} {current_ip=}." + ) + # This way, at least bundle is required to be created in a + # current node. + placement_group_specs[0][f"node:{current_ip}"] = 0.001 # By default, Ray packs resources as much as possible. current_placement_group = ray.util.placement_group( @@ -479,7 +490,9 @@ def initialize_ray_cluster( _wait_until_pg_ready(current_placement_group) assert current_placement_group is not None - _verify_bundles(current_placement_group, parallel_config, device_str) + _verify_bundles( + current_placement_group, parallel_config, device_str, require_gpu_on_driver + ) # Set the placement group in the parallel config parallel_config.placement_group = current_placement_group From bbaa21b791d94535cfc744b74fa28b69ca189263 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Mon, 16 Mar 2026 11:46:49 -0700 Subject: [PATCH 03/29] Fix pre-commit Signed-off-by: Jeffrey Wang --- tests/distributed/test_mq_tcp_multinode.py | 17 ++++++++--------- vllm/v1/executor/ray_executor_v2.py | 6 ++++-- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/distributed/test_mq_tcp_multinode.py b/tests/distributed/test_mq_tcp_multinode.py index 61f0b239b675..135ef11d7fa3 100644 --- a/tests/distributed/test_mq_tcp_multinode.py +++ b/tests/distributed/test_mq_tcp_multinode.py @@ -32,7 +32,8 @@ def main(): f"[Rank {rank}] in_the_same_node_as(source=0): {status} " f"(local={local_count}/{world_size})" ) - # With 2 Docker containers (1 proc each), rank 0 and rank 1 should be on different nodes. 
+ # With 2 Docker containers (1 proc each), rank 0 and rank 1 + # should be on different nodes. assert local_count < world_size, ( f"Expected cross-node ranks but all {world_size} ranks appear local." ) @@ -53,20 +54,17 @@ def main(): f"n_remote_reader={mq.n_remote_reader}" ) assert mq.n_remote_reader > 0, ( - "Writer should have at least 1 remote (TCP) reader in a " - "multi-node setup." + "Writer should have at least 1 remote (TCP) reader in a multi-node setup." ) else: if status[rank]: assert mq._is_local_reader, ( - f"Rank {rank} is on the same node as writer but is not a " - "local reader." + f"Rank {rank} is on the same node as writer but is not a local reader." ) print(f"[Rank {rank}] Reader: local (shared memory)") else: assert mq._is_remote_reader, ( - f"Rank {rank} is on a different node but is not a remote " - "(TCP) reader." + f"Rank {rank} is on a different node but is not a remote (TCP) reader." ) print(f"[Rank {rank}] Reader: remote (TCP)") @@ -82,8 +80,9 @@ def main(): # Test data transfer: numpy arrays np.random.seed(42) - arrays = [np.random.randint(0, 100, size=np.random.randint(100, 5000)) - for _ in range(100)] + arrays = [ + np.random.randint(0, 100, size=np.random.randint(100, 5000)) for _ in range(100) + ] dist.barrier() if rank == writer_rank: diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index a9909fb8bada..e342d3c21193 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -105,7 +105,7 @@ def _init_message_queues( self.worker_response_mq = MessageQueue( n_reader=1, n_local_reader=n_local, connect_ip=get_ip() ) - self.peer_response_handles = [] + self.peer_response_handles: list[dict] = [] def wait_for_init(self) -> dict: """Respond to the driver's wait_until_ready() barrier.""" @@ -175,7 +175,9 @@ def _init_executor(self) -> None: # Prefer driver node; group by node for TP locality bundle_to_node_id = [] - for i, bundle in 
enumerate(placement_group.bundle_specs): + bundle_specs = placement_group.bundle_specs + assert bundle_specs is not None + for i, bundle in enumerate(bundle_specs): ray_device_key = current_platform.ray_device_key if not ray_device_key: raise ValueError( From 2541f2d5471ec049c49e72ce74b32bbb247f79c7 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Mon, 16 Mar 2026 13:13:24 -0700 Subject: [PATCH 04/29] Fix RayExecutorV2 monitor thread self-join Signed-off-by: Jeffrey Wang --- tests/distributed/test_ray_v2_executor.py | 17 +++++++++------ vllm/v1/executor/ray_executor_v2.py | 25 +++++++++++++++++------ 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/tests/distributed/test_ray_v2_executor.py b/tests/distributed/test_ray_v2_executor.py index 4c799d69ecf9..31d7b331028d 100644 --- a/tests/distributed/test_ray_v2_executor.py +++ b/tests/distributed/test_ray_v2_executor.py @@ -58,6 +58,7 @@ def _cleanup_ray_resources(): return # Ray actor shutdown is async -- wait until all actors are dead + dangling_actors = [] for _ in range(10): dangling_actors = [ actor @@ -67,14 +68,18 @@ def _cleanup_ray_resources(): if not dangling_actors: break time.sleep(1) - assert not dangling_actors - for pg_id, pg_info in ray.util.placement_group_table().items(): - if pg_info["state"] == "CREATED": - pg = PlacementGroup(ray.PlacementGroupID(bytes.fromhex(pg_id))) - ray.util.remove_placement_group(pg) + # Always clean up PGs and shut down Ray, even if actors are dangling, + # to avoid leaking GPU resources and blocking subsequent tests. 
+ try: + for pg_id, pg_info in ray.util.placement_group_table().items(): + if pg_info["state"] == "CREATED": + pg = PlacementGroup(ray.PlacementGroupID(bytes.fromhex(pg_id))) + ray.util.remove_placement_group(pg) + finally: + ray.shutdown() - ray.shutdown() + assert not dangling_actors def create_vllm_config( diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index e342d3c21193..68b8752d0c26 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -372,18 +372,31 @@ def monitor_workers(): t.start() self._monitor_thread = t + def _join_monitor_thread(self) -> None: + """Wait for the monitor thread to exit. + + Must be called before tearing down Ray resources — the monitor + may be inside ray.wait() which would segfault if Ray is shut + down underneath it. When the monitor itself calls shutdown() + (on worker death), we skip the join because the thread is + about to return anyway. + """ + monitor = getattr(self, "_monitor_thread", None) + if ( + monitor is not None + and monitor.is_alive() + and threading.current_thread() is not monitor + ): + monitor.join(timeout=10) + def shutdown(self) -> None: """Properly shut down the executor and its workers""" if getattr(self, "shutting_down", False): + self._join_monitor_thread() return self.shutting_down = True - # Wait for the monitor thread to exit before tearing down Ray - # resources — it may be inside ray.wait() which would segfault - # if Ray is shut down underneath it. 
- monitor = getattr(self, "_monitor_thread", None) - if monitor is not None and monitor.is_alive(): - monitor.join(timeout=10) + self._join_monitor_thread() for handle in getattr(self, "ray_worker_handles", []): try: From c3ad8e55ee20353f5bbcf0e963769b2eb6c65f88 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 17 Mar 2026 05:38:06 +0000 Subject: [PATCH 05/29] Remove unnecessary changes Signed-off-by: Jeffrey Wang --- .buildkite/test_areas/distributed.yaml | 15 -- tests/distributed/test_ray_v2_executor.py | 34 +-- .../test_ray_v2_executor_multinode.py | 252 ------------------ vllm/v1/executor/ray_executor_v2.py | 1 + 4 files changed, 20 insertions(+), 282 deletions(-) delete mode 100644 tests/distributed/test_ray_v2_executor_multinode.py diff --git a/.buildkite/test_areas/distributed.yaml b/.buildkite/test_areas/distributed.yaml index 8fee8dacc266..145f4cc4418f 100644 --- a/.buildkite/test_areas/distributed.yaml +++ b/.buildkite/test_areas/distributed.yaml @@ -271,18 +271,3 @@ steps: - pytest -v -s distributed/test_ray_v2_executor.py - pytest -v -s distributed/test_pipeline_parallel.py -k "ray" - TARGET_TEST_SUITE=L4 pytest -v -s basic_correctness/test_basic_correctness.py -k "ray" - -- label: RayExecutorV2 Multi-Node (8 GPUs) - timeout_in_minutes: 30 - working_dir: "/vllm-workspace/tests" - num_devices: 4 - num_nodes: 2 - no_plugin: true - optional: true - source_file_dependencies: - - vllm/v1/executor/ray_executor_v2.py - - vllm/v1/executor/abstract.py - - vllm/v1/executor/multiproc_executor.py - - tests/distributed/test_ray_v2_executor_multinode.py - commands: - - ./.buildkite/scripts/run-multi-node-test.sh /vllm-workspace/tests 2 4 $IMAGE_TAG "VLLM_MULTI_NODE=1 VLLM_USE_RAY_V2_EXECUTOR_BACKEND=1 NCCL_CUMEM_HOST_ENABLE=0 pytest -v -s distributed/test_ray_v2_executor_multinode.py" "echo 'Worker node ready'" diff --git a/tests/distributed/test_ray_v2_executor.py b/tests/distributed/test_ray_v2_executor.py index 31d7b331028d..634865b36334 100644 --- 
a/tests/distributed/test_ray_v2_executor.py +++ b/tests/distributed/test_ray_v2_executor.py @@ -33,17 +33,15 @@ def enable_ray_v2_backend(): "VLLM_USE_RAY_V2_EXECUTOR_BACKEND": os.environ.get( "VLLM_USE_RAY_V2_EXECUTOR_BACKEND" ), - "RAY_RUNTIME_ENV_HOOK": os.environ.get("RAY_RUNTIME_ENV_HOOK"), "VLLM_ENABLE_V1_MULTIPROCESSING": os.environ.get( "VLLM_ENABLE_V1_MULTIPROCESSING" ), } os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" - # TODO (jeffreywang): Is this necessary? + # The multiprocess engine forks a subprocess that inherits the Ray + # driver connection, causing hangs. RayExecutorV2 already distributes + # work via Ray actors, so the EngineCore can run safely in-process. os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - # TODO (jeffreywang): Figure out vLLM CI - # This is only necessary for Anyscale ray cluster. - os.environ.pop("RAY_RUNTIME_ENV_HOOK", None) try: yield finally: @@ -57,17 +55,21 @@ def _cleanup_ray_resources(): if not ray.is_initialized(): return - # Ray actor shutdown is async -- wait until all actors are dead + # Ray actor shutdown is async -- wait until all actors are dead. dangling_actors = [] - for _ in range(10): - dangling_actors = [ - actor - for actor in list_actors(filters=[("state", "=", "ALIVE")]) - if actor.class_name == "RayWorkerProc" - ] - if not dangling_actors: - break - time.sleep(1) + try: + for _ in range(10): + dangling_actors = [ + actor + for actor in list_actors(filters=[("state", "=", "ALIVE")]) + if actor.class_name == "RayWorkerProc" + ] + if not dangling_actors: + break + time.sleep(1) + except Exception: + # Tolerate connection errors to the Ray dashboard + pass # Always clean up PGs and shut down Ray, even if actors are dangling, # to avoid leaking GPU resources and blocking subsequent tests. 
@@ -76,6 +78,8 @@ def _cleanup_ray_resources(): if pg_info["state"] == "CREATED": pg = PlacementGroup(ray.PlacementGroupID(bytes.fromhex(pg_id))) ray.util.remove_placement_group(pg) + except Exception: + pass finally: ray.shutdown() diff --git a/tests/distributed/test_ray_v2_executor_multinode.py b/tests/distributed/test_ray_v2_executor_multinode.py deleted file mode 100644 index 09538b0f8557..000000000000 --- a/tests/distributed/test_ray_v2_executor_multinode.py +++ /dev/null @@ -1,252 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -""" -Multi-node integration tests for RayExecutorV2. -Validates executor initialization, worker placement, and end-to-end -generation across multiple nodes with various TP/PP configurations. - -Requires VLLM_MULTI_NODE=1 env var and a multi-node Ray cluster. - -Run: -```sh -VLLM_MULTI_NODE=1 VLLM_USE_RAY_V2_EXECUTOR_BACKEND=1 \ - pytest -v -s distributed/test_ray_v2_executor_multinode.py -``` -""" - -import gc -import os -import time -from unittest.mock import patch - -import pytest -import ray -from ray.util.placement_group import PlacementGroup -from ray.util.state import list_actors - -from vllm import LLM -from vllm.config import VllmConfig -from vllm.engine.arg_utils import EngineArgs -from vllm.v1.executor.ray_executor_v2 import RayExecutorV2 - -MODEL = "facebook/opt-125m" - -VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" - -pytestmark = pytest.mark.skipif( - not VLLM_MULTI_NODE, reason="Need VLLM_MULTI_NODE=1 and a multi-node cluster." 
-) - - -@pytest.fixture(autouse=True) -def enable_ray_v2_backend(): - """Enable the RayExecutorV2 backend via feature flag for all tests.""" - saved = { - "VLLM_USE_RAY_V2_EXECUTOR_BACKEND": os.environ.get( - "VLLM_USE_RAY_V2_EXECUTOR_BACKEND" - ), - "RAY_RUNTIME_ENV_HOOK": os.environ.get("RAY_RUNTIME_ENV_HOOK"), - "VLLM_ENABLE_V1_MULTIPROCESSING": os.environ.get( - "VLLM_ENABLE_V1_MULTIPROCESSING" - ), - } - os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" - os.environ.pop("RAY_RUNTIME_ENV_HOOK", None) - # Disable multiprocessing to avoid fork-after-Ray-init issues. - # The RayExecutorV2 already handles distribution via Ray actors, - # so the EngineCore can safely run in-process. - os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - try: - yield - finally: - _cleanup_ray_resources() - os.environ.update({k: v for k, v in saved.items() if v is not None}) - for key in (k for k, v in saved.items() if v is None): - os.environ.pop(key, None) - - -def _cleanup_ray_resources(): - if not ray.is_initialized(): - return - - # Wait briefly for async cleanup (del llm triggers GC-based shutdown) - for _ in range(10): - dangling_actors = [ - actor - for actor in list_actors(filters=[("state", "=", "ALIVE")]) - if actor.class_name == "RayWorkerProc" - ] - if not dangling_actors: - break - time.sleep(1) - - for pg_id, pg_info in ray.util.placement_group_table().items(): - if pg_info["state"] == "CREATED": - pg = PlacementGroup(ray.PlacementGroupID(bytes.fromhex(pg_id))) - ray.util.remove_placement_group(pg) - - # Disconnect from Ray so forked subprocesses (EngineCore) don't inherit - # a stale driver connection that can't create placement groups. 
- ray.shutdown() - - -def create_vllm_config( - tensor_parallel_size: int = 1, - pipeline_parallel_size: int = 1, - max_model_len: int = 256, - gpu_memory_utilization: float = 0.3, -) -> VllmConfig: - engine_args = EngineArgs( - model=MODEL, - tensor_parallel_size=tensor_parallel_size, - pipeline_parallel_size=pipeline_parallel_size, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - distributed_executor_backend="ray", - enforce_eager=True, - ) - return engine_args.create_engine_config() - - -def assert_executor(executor, tp_size, pp_size): - """Common assertions for executor initialization tests.""" - world_size = tp_size * pp_size - expected_output_rank = (pp_size - 1) * tp_size - - assert executor.world_size == world_size - assert len(executor.ray_worker_handles) == world_size - assert len(executor.response_mqs) == world_size - assert executor._get_output_rank() == expected_output_rank - - if pp_size > 1: - assert executor.max_concurrent_batches == pp_size - - executor.check_health() - assert not executor.is_failed - - ranks = sorted(h.rank for h in executor.ray_worker_handles) - assert ranks == list(range(world_size)) - - for handle in executor.ray_worker_handles: - assert handle.node_id is not None - - -def test_ray_v2_multinode_executor_init(): - """Validate RayExecutorV2 initializes correctly across multiple nodes - with TP=4, PP=2 (8 GPUs).""" - vllm_config = create_vllm_config( - tensor_parallel_size=4, - pipeline_parallel_size=2, - ) - executor = RayExecutorV2(vllm_config=vllm_config) - try: - assert_executor(executor, tp_size=4, pp_size=2) - - # Verify workers span multiple nodes - node_ids = {h.node_id for h in executor.ray_worker_handles} - assert len(node_ids) > 1 - - # Verify rank 0 exists and has a valid node_id. - # On clusters where the driver node has GPUs, rank 0 will be there. - # On GPU-less head nodes, rank 0 is on the first GPU node instead. 
- rank0_handle = next(h for h in executor.ray_worker_handles if h.rank == 0) - assert rank0_handle.node_id is not None - finally: - executor.shutdown() - - -def test_ray_v2_multinode_worker_placement(): - """Verify TP locality: workers in the same TP group share a node.""" - tp_size = 4 - pp_size = 2 - - vllm_config = create_vllm_config( - tensor_parallel_size=tp_size, - pipeline_parallel_size=pp_size, - ) - executor = RayExecutorV2(vllm_config=vllm_config) - try: - # Workers are sorted by rank; consecutive tp_size ranks form a TP group - for pp_rank in range(pp_size): - start_rank = pp_rank * tp_size - tp_group_handles = [ - h - for h in executor.ray_worker_handles - if start_rank <= h.rank < start_rank + tp_size - ] - tp_group_nodes = {h.node_id for h in tp_group_handles} - assert len(tp_group_nodes) == 1 - - # Workers should be distributed across > 1 node - all_nodes = {h.node_id for h in executor.ray_worker_handles} - assert len(all_nodes) > 1 - finally: - executor.shutdown() - - -def test_ray_v2_multinode_generation(): - """End-to-end LLM generation with TP=4, PP=2 across multiple nodes.""" - llm = LLM( - model=MODEL, - tensor_parallel_size=4, - pipeline_parallel_size=2, - distributed_executor_backend="ray", - enforce_eager=True, - max_model_len=256, - gpu_memory_utilization=0.3, - ) - try: - prompts = [ - "Hello, my name is", - "The capital of France is", - "The future of AI is", - ] - outputs = llm.generate(prompts) - - assert len(outputs) == len(prompts) - for output in outputs: - assert len(output.outputs) > 0 - assert len(output.outputs[0].text) > 0 - finally: - llm.llm_engine.model_executor.shutdown() - del llm - gc.collect() - - -@pytest.mark.parametrize("tp_size, pp_size", [(4, 2), (2, 4)]) -def test_ray_v2_multinode_generation_with_pg(tp_size, pp_size): - """E2E LLM generation with a user-provided placement group across nodes.""" - if not ray.is_initialized(): - ray.init(ignore_reinit_error=True) - - bundles = [{"GPU": 1, "CPU": 1} for _ in 
range(tp_size * pp_size)] - pg = ray.util.placement_group(bundles, strategy="PACK") - ray.get(pg.ready()) - - try: - with patch.object(ray.util, "get_current_placement_group", return_value=pg): - llm = LLM( - model=MODEL, - tensor_parallel_size=tp_size, - pipeline_parallel_size=pp_size, - distributed_executor_backend="ray", - enforce_eager=True, - max_model_len=256, - gpu_memory_utilization=0.3, - ) - prompts = [ - "Hello, my name is", - "The capital of France is", - "The future of AI is", - ] - outputs = llm.generate(prompts) - - assert len(outputs) == len(prompts) - for output in outputs: - assert len(output.outputs) > 0 - assert len(output.outputs[0].text) > 0 - finally: - llm.llm_engine.model_executor.shutdown() - del llm - gc.collect() diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index 68b8752d0c26..7ebe35ba05de 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -175,6 +175,7 @@ def _init_executor(self) -> None: # Prefer driver node; group by node for TP locality bundle_to_node_id = [] + assert placement_group is not None bundle_specs = placement_group.bundle_specs assert bundle_specs is not None for i, bundle in enumerate(bundle_specs): From 300d0ae3ecc18a465579ce36eeeca37603d9e3d7 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Tue, 17 Mar 2026 06:51:27 +0000 Subject: [PATCH 06/29] Extract bundle sorting to a utility Signed-off-by: Jeffrey Wang --- tests/utils_/test_ray_utils.py | 94 +++++++++++++++++++++++++++++ vllm/v1/executor/ray_executor_v2.py | 43 +++---------- vllm/v1/executor/ray_utils.py | 38 ++++++++++++ 3 files changed, 142 insertions(+), 33 deletions(-) create mode 100644 tests/utils_/test_ray_utils.py diff --git a/tests/utils_/test_ray_utils.py b/tests/utils_/test_ray_utils.py new file mode 100644 index 000000000000..f52ec9af07b4 --- /dev/null +++ b/tests/utils_/test_ray_utils.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from unittest.mock import MagicMock, patch + +import pytest + +NODE_A = "node_a" +NODE_B = "node_b" +NODE_C = "node_c" + + +@pytest.mark.parametrize( + "bundles_to_node_id,bundle_specs,world_size,expected", + [ + pytest.param( + {0: NODE_C, 1: NODE_A, 2: NODE_B, 3: NODE_C, 4: NODE_A, + 5: NODE_B}, + [{"GPU": 1}] * 6, + 6, + [(1, NODE_A), (4, NODE_A), (2, NODE_B), (5, NODE_B), + (0, NODE_C), (3, NODE_C)], + ), + pytest.param( + {0: NODE_B, 1: NODE_B, 2: NODE_A, 3: NODE_A}, + [{"GPU": 1}] * 4, + 4, + [(2, NODE_A), (3, NODE_A), (0, NODE_B), (1, NODE_B)], + ), + pytest.param( + {0: NODE_C, 1: NODE_B, 2: NODE_C, 3: NODE_B}, + [{"GPU": 1}] * 4, + 4, + [(1, NODE_B), (3, NODE_B), (0, NODE_C), (2, NODE_C)], + ), + pytest.param( + {0: NODE_A, 1: NODE_A, 2: NODE_A}, + [{"GPU": 1}] * 3, + 3, + [(0, NODE_A), (1, NODE_A), (2, NODE_A)], + ), + pytest.param( + {0: NODE_A}, + [{"GPU": 1}], + 1, + [(0, NODE_A)], + ), + pytest.param( + {}, + [], + 0, + [], + ), + pytest.param( + {0: NODE_A, 1: NODE_B, 2: NODE_A, 3: NODE_B}, + [{"GPU": 1}] * 4, + 2, + [(0, NODE_A), (1, NODE_B)], + ), + pytest.param( + {0: NODE_A, 1: NODE_B, 2: NODE_A}, + [{"CPU": 1}, {"GPU": 1}, {"GPU": 1}], + 2, + [(2, NODE_A), (1, NODE_B)], + ), + ], +) +def test_get_bundles_sorted_by_node( + bundles_to_node_id, bundle_specs, world_size, expected +): + mock_pg = MagicMock() + mock_pg.bundle_specs = bundle_specs + + mock_ctx = MagicMock() + mock_ctx.get_node_id.return_value = NODE_A + + with ( + patch( + "vllm.v1.executor.ray_utils.placement_group_table", + return_value={"bundles_to_node_id": bundles_to_node_id}, + ), + patch("vllm.v1.executor.ray_utils.ray") as mock_ray, + patch( + "vllm.v1.executor.ray_utils.current_platform" + ) as mock_platform, + ): + mock_ray.get_runtime_context.return_value = mock_ctx + mock_platform.ray_device_key = "GPU" + + from vllm.v1.executor.ray_utils import get_bundles_sorted_by_node + + result = 
get_bundles_sorted_by_node(mock_pg, world_size) + + assert result == expected diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index 7ebe35ba05de..0fe90cb8fa30 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -27,6 +27,7 @@ WorkerProc, ) from vllm.v1.executor.ray_utils import ( + get_bundles_sorted_by_node, initialize_ray_cluster, ray, ) @@ -170,35 +171,11 @@ def _init_executor(self) -> None: ) # Step 2: Query PG table, sort bundles, assign ranks - pg_table = ray.util.placement_group_table(placement_group) - bundle_to_node = pg_table["bundles_to_node_id"] - - # Prefer driver node; group by node for TP locality - bundle_to_node_id = [] - assert placement_group is not None - bundle_specs = placement_group.bundle_specs - assert bundle_specs is not None - for i, bundle in enumerate(bundle_specs): - ray_device_key = current_platform.ray_device_key - if not ray_device_key: - raise ValueError( - f"current platform {current_platform.device_name}" - " does not support ray." - ) - - if bundle.get(ray_device_key): - node_id = bundle_to_node.get(i) or bundle_to_node.get(str(i)) - bundle_to_node_id.append((i, node_id)) - - bundle_to_node_id = bundle_to_node_id[: self.world_size] + bundle_to_node_id = get_bundles_sorted_by_node( + placement_group, self.world_size + ) driver_node = ray.get_runtime_context().get_node_id() - def _sort_key(item): - _, node_id = item - return (0 if node_id == driver_node else 1, node_id) - - bundle_to_node_id.sort(key=_sort_key) - # Assign each worker a local rank node_rank_counter: dict[str, int] = defaultdict(int) bundle_assignments: list[dict[str, Any]] = [] @@ -340,9 +317,9 @@ def _should_stop() -> bool: return not executor or executor.shutting_down def monitor_workers(): - # TODO (jeffreywang): Is there a better way? - # Poll with timeout; a blocking ray.wait() would segfault - # if Ray is torn down while this thread is waiting. 
+ # Poll with a timeout rather than blocking on ray.wait() + # because a blocking call would segfault if Ray is torn down + # while this thread is inside it. while not _should_stop() and ray.is_initialized(): try: done, _ = ray.wait(run_refs, num_returns=1, timeout=5.0) @@ -378,9 +355,9 @@ def _join_monitor_thread(self) -> None: Must be called before tearing down Ray resources — the monitor may be inside ray.wait() which would segfault if Ray is shut - down underneath it. When the monitor itself calls shutdown() - (on worker death), we skip the join because the thread is - about to return anyway. + down underneath it. When the monitor itself calls shutdown() + on worker death, we skip the join because the thread is about + to return anyway. """ monitor = getattr(self, "_monitor_thread", None) if ( diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py index f86db5cfb13a..783c97a76c7e 100644 --- a/vllm/v1/executor/ray_utils.py +++ b/vllm/v1/executor/ray_utils.py @@ -270,6 +270,44 @@ def _verify_bundles( ) +def get_bundles_sorted_by_node( + placement_group: "PlacementGroup", + world_size: int, +) -> list[tuple[int, str]]: + """ + Return GPU bundle indices paired with node IDs, sorted driver-first. + + This utility has to be invoked from the driver node. + """ + pg_data = placement_group_table(placement_group) + bundle_to_node = pg_data["bundles_to_node_id"] + + ray_device_key = current_platform.ray_device_key + if not ray_device_key: + raise ValueError( + f"current platform {current_platform.device_name}" + " does not support ray." 
+ ) + + bundle_specs = placement_group.bundle_specs + assert bundle_specs is not None + bundle_to_node_id: list[tuple[int, str]] = [] + for i, bundle in enumerate(bundle_specs): + if bundle.get(ray_device_key): + node_id = bundle_to_node.get(i) or bundle_to_node.get(str(i)) + bundle_to_node_id.append((i, node_id)) + + bundle_to_node_id = bundle_to_node_id[:world_size] + driver_node = ray.get_runtime_context().get_node_id() + + def _sort_key(item): + _, node_id = item + return (0 if node_id == driver_node else 1, node_id) + + bundle_to_node_id.sort(key=_sort_key) + return bundle_to_node_id + + def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): """Wait until a placement group is ready. From 11d32eb00743b0665629493674345f6506e1ada4 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Tue, 17 Mar 2026 07:28:06 +0000 Subject: [PATCH 07/29] Fix linter Signed-off-by: Jeffrey Wang --- tests/utils_/test_ray_utils.py | 17 ++++++++++------- vllm/v1/executor/ray_executor_v2.py | 4 +--- vllm/v1/executor/ray_utils.py | 3 +-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/utils_/test_ray_utils.py b/tests/utils_/test_ray_utils.py index f52ec9af07b4..bc462576e823 100644 --- a/tests/utils_/test_ray_utils.py +++ b/tests/utils_/test_ray_utils.py @@ -14,12 +14,17 @@ "bundles_to_node_id,bundle_specs,world_size,expected", [ pytest.param( - {0: NODE_C, 1: NODE_A, 2: NODE_B, 3: NODE_C, 4: NODE_A, - 5: NODE_B}, + {0: NODE_C, 1: NODE_A, 2: NODE_B, 3: NODE_C, 4: NODE_A, 5: NODE_B}, [{"GPU": 1}] * 6, 6, - [(1, NODE_A), (4, NODE_A), (2, NODE_B), (5, NODE_B), - (0, NODE_C), (3, NODE_C)], + [ + (1, NODE_A), + (4, NODE_A), + (2, NODE_B), + (5, NODE_B), + (0, NODE_C), + (3, NODE_C), + ], ), pytest.param( {0: NODE_B, 1: NODE_B, 2: NODE_A, 3: NODE_A}, @@ -80,9 +85,7 @@ def test_get_bundles_sorted_by_node( return_value={"bundles_to_node_id": bundles_to_node_id}, ), patch("vllm.v1.executor.ray_utils.ray") as mock_ray, - patch( - 
"vllm.v1.executor.ray_utils.current_platform" - ) as mock_platform, + patch("vllm.v1.executor.ray_utils.current_platform") as mock_platform, ): mock_ray.get_runtime_context.return_value = mock_ctx mock_platform.ray_device_key = "GPU" diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index 0fe90cb8fa30..a1bd1948e616 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -171,9 +171,7 @@ def _init_executor(self) -> None: ) # Step 2: Query PG table, sort bundles, assign ranks - bundle_to_node_id = get_bundles_sorted_by_node( - placement_group, self.world_size - ) + bundle_to_node_id = get_bundles_sorted_by_node(placement_group, self.world_size) driver_node = ray.get_runtime_context().get_node_id() # Assign each worker a local rank diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py index 783c97a76c7e..e2765b2b8267 100644 --- a/vllm/v1/executor/ray_utils.py +++ b/vllm/v1/executor/ray_utils.py @@ -285,8 +285,7 @@ def get_bundles_sorted_by_node( ray_device_key = current_platform.ray_device_key if not ray_device_key: raise ValueError( - f"current platform {current_platform.device_name}" - " does not support ray." + f"current platform {current_platform.device_name} does not support ray." 
) bundle_specs = placement_group.bundle_specs From 5795f1d4d19d91b0e36ecf05dc8af215fb612778 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Wed, 18 Mar 2026 01:45:58 +0000 Subject: [PATCH 08/29] Enable async scheduling Signed-off-by: Jeffrey Wang --- vllm/config/vllm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index dc776fac1469..96f4c69b0d07 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -687,7 +687,7 @@ def __post_init__(self): "mp", "uni", "external_launcher", - ) + ) or (executor_backend == "ray" and envs.VLLM_USE_RAY_V2_EXECUTOR_BACKEND) if self.scheduler_config.async_scheduling: # Async scheduling explicitly enabled, hard fail any incompatibilities. From 712807466a29cc1d469867588955bd8d920adfe7 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Thu, 19 Mar 2026 03:24:51 +0000 Subject: [PATCH 09/29] Address CR feedback Signed-off-by: Jeffrey Wang --- tests/utils_/test_ray_utils.py | 48 ++++++++++++----- vllm/v1/executor/ray_executor_v2.py | 82 ++++++++++++----------------- vllm/v1/executor/ray_utils.py | 42 ++++++++++++--- 3 files changed, 103 insertions(+), 69 deletions(-) diff --git a/tests/utils_/test_ray_utils.py b/tests/utils_/test_ray_utils.py index bc462576e823..57a4797f0c3d 100644 --- a/tests/utils_/test_ray_utils.py +++ b/tests/utils_/test_ray_utils.py @@ -9,6 +9,18 @@ NODE_B = "node_b" NODE_C = "node_c" +IP_A = "10.0.0.1" +IP_B = "10.0.0.2" +IP_C = "10.0.0.3" + +NODE_ID_TO_IP = {NODE_A: IP_A, NODE_B: IP_B, NODE_C: IP_C} + +MOCK_RAY_NODES = [ + {"NodeID": NODE_A, "NodeManagerAddress": IP_A, "Alive": True}, + {"NodeID": NODE_B, "NodeManagerAddress": IP_B, "Alive": True}, + {"NodeID": NODE_C, "NodeManagerAddress": IP_C, "Alive": True}, +] + @pytest.mark.parametrize( "bundles_to_node_id,bundle_specs,world_size,expected", @@ -18,37 +30,47 @@ [{"GPU": 1}] * 6, 6, [ - (1, NODE_A), - (4, NODE_A), - (2, NODE_B), - (5, NODE_B), - (0, NODE_C), - (3, NODE_C), + (1, NODE_A, IP_A), 
+ (4, NODE_A, IP_A), + (2, NODE_B, IP_B), + (5, NODE_B, IP_B), + (0, NODE_C, IP_C), + (3, NODE_C, IP_C), ], ), pytest.param( {0: NODE_B, 1: NODE_B, 2: NODE_A, 3: NODE_A}, [{"GPU": 1}] * 4, 4, - [(2, NODE_A), (3, NODE_A), (0, NODE_B), (1, NODE_B)], + [ + (2, NODE_A, IP_A), + (3, NODE_A, IP_A), + (0, NODE_B, IP_B), + (1, NODE_B, IP_B), + ], ), pytest.param( {0: NODE_C, 1: NODE_B, 2: NODE_C, 3: NODE_B}, [{"GPU": 1}] * 4, 4, - [(1, NODE_B), (3, NODE_B), (0, NODE_C), (2, NODE_C)], + [ + (1, NODE_B, IP_B), + (3, NODE_B, IP_B), + (0, NODE_C, IP_C), + (2, NODE_C, IP_C), + ], ), pytest.param( {0: NODE_A, 1: NODE_A, 2: NODE_A}, [{"GPU": 1}] * 3, 3, - [(0, NODE_A), (1, NODE_A), (2, NODE_A)], + [(0, NODE_A, IP_A), (1, NODE_A, IP_A), (2, NODE_A, IP_A)], ), pytest.param( {0: NODE_A}, [{"GPU": 1}], 1, - [(0, NODE_A)], + [(0, NODE_A, IP_A)], ), pytest.param( {}, @@ -60,13 +82,14 @@ {0: NODE_A, 1: NODE_B, 2: NODE_A, 3: NODE_B}, [{"GPU": 1}] * 4, 2, - [(0, NODE_A), (1, NODE_B)], + # After sort-then-clip, driver node (NODE_A) bundles are prioritized + [(0, NODE_A, IP_A), (2, NODE_A, IP_A)], ), pytest.param( {0: NODE_A, 1: NODE_B, 2: NODE_A}, [{"CPU": 1}, {"GPU": 1}, {"GPU": 1}], 2, - [(2, NODE_A), (1, NODE_B)], + [(2, NODE_A, IP_A), (1, NODE_B, IP_B)], ), ], ) @@ -88,6 +111,7 @@ def test_get_bundles_sorted_by_node( patch("vllm.v1.executor.ray_utils.current_platform") as mock_platform, ): mock_ray.get_runtime_context.return_value = mock_ctx + mock_ray.nodes.return_value = MOCK_RAY_NODES mock_platform.ray_device_key = "GPU" from vllm.v1.executor.ray_utils import get_bundles_sorted_by_node diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index a1bd1948e616..9e76ef11a320 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os import threading import weakref from collections import 
defaultdict, deque @@ -17,8 +16,6 @@ from vllm.platforms import current_platform from vllm.utils.network_utils import ( get_distributed_init_method, - get_ip, - get_loopback_ip, get_open_port, ) from vllm.v1.executor.multiproc_executor import ( @@ -27,6 +24,7 @@ WorkerProc, ) from vllm.v1.executor.ray_utils import ( + build_actor_name, get_bundles_sorted_by_node, initialize_ray_cluster, ray, @@ -58,9 +56,6 @@ class RayWorkerHandle: node_id: str """Node ID of the worker""" - bundle_index: int - """Placement group bundle index to schedule the actor on""" - run_ref: ObjectRef = None """run() ObjectRef used as a sentinel for health monitoring""" @@ -76,9 +71,10 @@ def __init__( distributed_init_method: str, input_shm_handle: Handle, is_driver_worker: bool, - same_node_as_executor: bool = False, + is_driver_node: bool = False, ): - self._same_node_as_executor = same_node_as_executor + self._is_driver_node = is_driver_node + self.local_rank = local_rank super().__init__( vllm_config=vllm_config, local_rank=local_rank, @@ -88,7 +84,6 @@ def __init__( shared_worker_lock=None, is_driver_worker=is_driver_worker, ) - self.local_rank = local_rank def _init_message_queues( self, input_shm_handle: Handle, vllm_config: VllmConfig @@ -102,9 +97,14 @@ def _init_message_queues( input_shm_handle, self.worker.rank ) - n_local = 1 if self._same_node_as_executor else 0 + n_local = 1 if self._is_driver_node else 0 + # Use ray.util.get_node_ip_address() to get Ray's internal IP. + # get_ip() returns host's external IP which is typically not + # routable between nodes within the cluster. 
self.worker_response_mq = MessageQueue( - n_reader=1, n_local_reader=n_local, connect_ip=get_ip() + n_reader=1, + n_local_reader=n_local, + connect_ip=ray.util.get_node_ip_address(), ) self.peer_response_handles: list[dict] = [] @@ -140,8 +140,6 @@ class RayExecutorV2(MultiprocExecutor): supports_pp: bool = True def __init__(self, vllm_config: VllmConfig): - # Skip MultiprocExecutor.__init__; we monitor via ray.wait() - self.monitor_workers = False super(MultiprocExecutor, self).__init__(vllm_config) def _init_executor(self) -> None: @@ -157,11 +155,6 @@ def _init_executor(self) -> None: initialize_ray_cluster(self.parallel_config, require_gpu_on_driver=False) placement_group = self.parallel_config.placement_group - # Disable Ray usage stats collection - ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") - if ray_usage != "1": - os.environ["RAY_USAGE_STATS_ENABLED"] = "0" - tp_size, pp_size, pcp_size = self._get_parallel_sizes() assert self.world_size == tp_size * pp_size * pcp_size, ( f"world_size ({self.world_size}) must be equal to the " @@ -177,35 +170,23 @@ def _init_executor(self) -> None: # Assign each worker a local rank node_rank_counter: dict[str, int] = defaultdict(int) bundle_assignments: list[dict[str, Any]] = [] - for rank, (bundle_id, node_id) in enumerate(bundle_to_node_id): + for rank, (bundle_id_idx, node_id, node_ip) in enumerate(bundle_to_node_id): local_rank = node_rank_counter[node_id] node_rank_counter[node_id] += 1 bundle_assignments.append( { "rank": rank, "local_rank": local_rank, - "bundle_id": bundle_id, + "bundle_id_idx": bundle_id_idx, "node_id": node_id, + "node_ip": node_ip, } ) - # Determine node topology - node_ids = list(dict.fromkeys(a["node_id"] for a in bundle_assignments)) - is_single_node = len(node_ids) == 1 - # Step 3: Resolve the IP for torch.distributed TCPStore. # The TCPStore server runs on rank 0's node, so all workers # must be able to reach this address. 
- rank0_node_id = bundle_assignments[0]["node_id"] - if is_single_node: - dist_ip = get_loopback_ip() - elif rank0_node_id == driver_node: - dist_ip = get_ip() - else: - node_id_to_ip = { - n["NodeID"]: n["NodeManagerAddress"] for n in ray.nodes() if n["Alive"] - } - dist_ip = node_id_to_ip[rank0_node_id] + dist_ip = bundle_assignments[0]["node_ip"] distributed_init_method = get_distributed_init_method(dist_ip, get_open_port()) # Step 4: Create broadcast MessageQueue. @@ -216,22 +197,22 @@ def _init_executor(self) -> None: self.world_size, n_local, max_chunk_bytes=max_chunk_bytes, - connect_ip=get_ip(), + connect_ip=ray.util.get_node_ip_address(), ) scheduler_output_handle = self.rpc_broadcast_mq.export_handle() # Step 5: Spawn RayWorkerProc actors into PG bundles self.ray_worker_handles: list[RayWorkerHandle] = [] - self._ray_actors: list[Any] = [] + instance_id = self.vllm_config.instance_id # Create the remote actor - for assignment in bundle_assignments: - is_driver_worker = self._is_driver_worker(assignment["rank"]) + for bundle in bundle_assignments: + is_driver_worker = self._is_driver_worker(bundle["rank"]) + is_driver_node = bundle["node_id"] == driver_node scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=placement_group, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=assignment["bundle_id"], + placement_group_bundle_index=bundle["bundle_id_idx"], ) # Prevent Ray from setting CUDA_VISIBLE_DEVICES @@ -242,9 +223,14 @@ def _init_executor(self) -> None: }, } + actor_name = build_actor_name( + instance_id, bundle["rank"], tp_size, pp_size, pcp_size + ) + actor = ( ray.remote(RayWorkerProc) .options( + name=actor_name, num_cpus=0, num_gpus=envs.VLLM_RAY_PER_WORKER_GPUS, scheduling_strategy=scheduling_strategy, @@ -252,24 +238,22 @@ def _init_executor(self) -> None: ) .remote( vllm_config=self.vllm_config, - local_rank=assignment["local_rank"], - rank=assignment["rank"], + local_rank=bundle["local_rank"], + 
rank=bundle["rank"], distributed_init_method=distributed_init_method, input_shm_handle=scheduler_output_handle, is_driver_worker=is_driver_worker, - same_node_as_executor=(assignment["node_id"] == driver_node), + is_driver_node=is_driver_node, ) ) handle = RayWorkerHandle( actor=actor, - rank=assignment["rank"], - local_rank=assignment["local_rank"], - node_id=assignment["node_id"], - bundle_index=assignment["bundle_id"], + rank=bundle["rank"], + local_rank=bundle["local_rank"], + node_id=bundle["node_id"], ) self.ray_worker_handles.append(handle) - self._ray_actors.append(actor) # Step 6: Collect response MQ handles init_refs = [h.actor.wait_for_init.remote() for h in self.ray_worker_handles] @@ -303,7 +287,7 @@ def start_worker_monitor(self, inline=False) -> None: """Monitor worker liveness via ray.wait() on run() ObjectRefs.""" run_refs = [h.run_ref for h in self.ray_worker_handles if h.run_ref is not None] if not run_refs: - return + raise RuntimeError("Ray workers have not started successfully.") self_ref = weakref.ref(self) ref_to_rank = { diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py index e2765b2b8267..9b1dc2939ad7 100644 --- a/vllm/v1/executor/ray_utils.py +++ b/vllm/v1/executor/ray_utils.py @@ -270,12 +270,31 @@ def _verify_bundles( ) +def build_actor_name( + instance_id: str, + rank: int, + tp_size: int, + pp_size: int, + pcp_size: int, +) -> str: + """Build a descriptive Ray actor name for dashboard visibility.""" + name = f"vllm_Worker_{instance_id}" + if tp_size > 1: + name += f"_TP{rank % tp_size}" + if pp_size > 1: + name += f"_PP{(rank // tp_size) % pp_size}" + if pcp_size > 1: + name += f"_PCP{rank // (tp_size * pp_size)}" + return name + + def get_bundles_sorted_by_node( placement_group: "PlacementGroup", world_size: int, -) -> list[tuple[int, str]]: +) -> list[tuple[int, str, str]]: """ - Return GPU bundle indices paired with node IDs, sorted driver-first. 
+ Return GPU bundle indices paired with node IDs and node IPs, + sorted driver-first. This utility has to be invoked from the driver node. """ @@ -288,23 +307,26 @@ def get_bundles_sorted_by_node( f"current platform {current_platform.device_name} does not support ray." ) + node_id_to_ip = { + n["NodeID"]: n["NodeManagerAddress"] for n in ray.nodes() if n["Alive"] + } + bundle_specs = placement_group.bundle_specs assert bundle_specs is not None - bundle_to_node_id: list[tuple[int, str]] = [] + bundle_to_node_id: list[tuple[int, str, str]] = [] for i, bundle in enumerate(bundle_specs): if bundle.get(ray_device_key): - node_id = bundle_to_node.get(i) or bundle_to_node.get(str(i)) - bundle_to_node_id.append((i, node_id)) + node_id = bundle_to_node.get(i) + bundle_to_node_id.append((i, node_id, node_id_to_ip[node_id])) - bundle_to_node_id = bundle_to_node_id[:world_size] driver_node = ray.get_runtime_context().get_node_id() def _sort_key(item): - _, node_id = item + _, node_id, _ = item return (0 if node_id == driver_node else 1, node_id) bundle_to_node_id.sort(key=_sort_key) - return bundle_to_node_id + return bundle_to_node_id[:world_size] def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): @@ -413,6 +435,10 @@ def initialize_ray_cluster( assert_ray_available() from vllm.platforms import current_platform + # Disable Ray usage stats collection + if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + # Prevalidate GPU requirements before Ray processing if current_platform.is_cuda() and parallel_config.world_size > 1: from vllm.utils.torch_utils import cuda_device_count_stateless From e7a3c1f20a543364409023cb9790c9d6fd80c29b Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Thu, 19 Mar 2026 03:40:54 +0000 Subject: [PATCH 10/29] Address test feedback Signed-off-by: Jeffrey Wang --- tests/distributed/test_ray_v2_executor.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git 
a/tests/distributed/test_ray_v2_executor.py b/tests/distributed/test_ray_v2_executor.py index 634865b36334..17709e372270 100644 --- a/tests/distributed/test_ray_v2_executor.py +++ b/tests/distributed/test_ray_v2_executor.py @@ -71,19 +71,8 @@ def _cleanup_ray_resources(): # Tolerate connection errors to the Ray dashboard pass - # Always clean up PGs and shut down Ray, even if actors are dangling, - # to avoid leaking GPU resources and blocking subsequent tests. - try: - for pg_id, pg_info in ray.util.placement_group_table().items(): - if pg_info["state"] == "CREATED": - pg = PlacementGroup(ray.PlacementGroupID(bytes.fromhex(pg_id))) - ray.util.remove_placement_group(pg) - except Exception: - pass - finally: - ray.shutdown() - assert not dangling_actors + ray.shutdown() def create_vllm_config( From ec2730d74e93ee45f426aeab9099335c713544c5 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Thu, 19 Mar 2026 18:06:23 +0000 Subject: [PATCH 11/29] Iterate over world_size Signed-off-by: Jeffrey Wang --- vllm/v1/executor/ray_executor_v2.py | 9 ++++++--- vllm/v1/executor/ray_utils.py | 3 +-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index 9e76ef11a320..d4e3c7066128 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -164,7 +164,7 @@ def _init_executor(self) -> None: ) # Step 2: Query PG table, sort bundles, assign ranks - bundle_to_node_id = get_bundles_sorted_by_node(placement_group, self.world_size) + bundle_to_node_id = get_bundles_sorted_by_node(placement_group) driver_node = ray.get_runtime_context().get_node_id() # Assign each worker a local rank @@ -205,8 +205,11 @@ def _init_executor(self) -> None: self.ray_worker_handles: list[RayWorkerHandle] = [] instance_id = self.vllm_config.instance_id - # Create the remote actor - for bundle in bundle_assignments: + # Create exactly world_size remote actors despite the number of bundles + # in 
the placement group. + for bundle_idx in range(self.world_size): + # Fail fast if the placement group has less than world_size bundles. + bundle = bundle_assignments[bundle_idx] is_driver_worker = self._is_driver_worker(bundle["rank"]) is_driver_node = bundle["node_id"] == driver_node diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py index 9b1dc2939ad7..45a86fdd5b62 100644 --- a/vllm/v1/executor/ray_utils.py +++ b/vllm/v1/executor/ray_utils.py @@ -290,7 +290,6 @@ def build_actor_name( def get_bundles_sorted_by_node( placement_group: "PlacementGroup", - world_size: int, ) -> list[tuple[int, str, str]]: """ Return GPU bundle indices paired with node IDs and node IPs, @@ -326,7 +325,7 @@ def _sort_key(item): return (0 if node_id == driver_node else 1, node_id) bundle_to_node_id.sort(key=_sort_key) - return bundle_to_node_id[:world_size] + return bundle_to_node_id def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): From ca9590065d0298862bbe3ab3f8ffff88625e3ff5 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Thu, 19 Mar 2026 22:13:09 +0000 Subject: [PATCH 12/29] Fix tests and linters Signed-off-by: Jeffrey Wang --- tests/distributed/test_ray_v2_executor.py | 1 - tests/utils_/test_ray_utils.py | 31 ++++------------------- vllm/v1/executor/ray_utils.py | 1 + vllm/v1/worker/worker_base.py | 4 +-- 4 files changed, 8 insertions(+), 29 deletions(-) diff --git a/tests/distributed/test_ray_v2_executor.py b/tests/distributed/test_ray_v2_executor.py index 17709e372270..485c8b4fcb9d 100644 --- a/tests/distributed/test_ray_v2_executor.py +++ b/tests/distributed/test_ray_v2_executor.py @@ -15,7 +15,6 @@ import pytest import ray -from ray.util.placement_group import PlacementGroup from ray.util.state import list_actors from vllm import LLM diff --git a/tests/utils_/test_ray_utils.py b/tests/utils_/test_ray_utils.py index 57a4797f0c3d..0872ae9413f7 100644 --- a/tests/utils_/test_ray_utils.py +++ b/tests/utils_/test_ray_utils.py @@ -5,6 +5,8 
@@ import pytest +from vllm.v1.executor.ray_utils import get_bundles_sorted_by_node + NODE_A = "node_a" NODE_B = "node_b" NODE_C = "node_c" @@ -23,12 +25,11 @@ @pytest.mark.parametrize( - "bundles_to_node_id,bundle_specs,world_size,expected", + "bundles_to_node_id,bundle_specs,expected", [ pytest.param( {0: NODE_C, 1: NODE_A, 2: NODE_B, 3: NODE_C, 4: NODE_A, 5: NODE_B}, [{"GPU": 1}] * 6, - 6, [ (1, NODE_A, IP_A), (4, NODE_A, IP_A), @@ -41,7 +42,6 @@ pytest.param( {0: NODE_B, 1: NODE_B, 2: NODE_A, 3: NODE_A}, [{"GPU": 1}] * 4, - 4, [ (2, NODE_A, IP_A), (3, NODE_A, IP_A), @@ -52,7 +52,6 @@ pytest.param( {0: NODE_C, 1: NODE_B, 2: NODE_C, 3: NODE_B}, [{"GPU": 1}] * 4, - 4, [ (1, NODE_B, IP_B), (3, NODE_B, IP_B), @@ -63,39 +62,21 @@ pytest.param( {0: NODE_A, 1: NODE_A, 2: NODE_A}, [{"GPU": 1}] * 3, - 3, [(0, NODE_A, IP_A), (1, NODE_A, IP_A), (2, NODE_A, IP_A)], ), - pytest.param( - {0: NODE_A}, - [{"GPU": 1}], - 1, - [(0, NODE_A, IP_A)], - ), pytest.param( {}, [], - 0, [], ), - pytest.param( - {0: NODE_A, 1: NODE_B, 2: NODE_A, 3: NODE_B}, - [{"GPU": 1}] * 4, - 2, - # After sort-then-clip, driver node (NODE_A) bundles are prioritized - [(0, NODE_A, IP_A), (2, NODE_A, IP_A)], - ), pytest.param( {0: NODE_A, 1: NODE_B, 2: NODE_A}, [{"CPU": 1}, {"GPU": 1}, {"GPU": 1}], - 2, [(2, NODE_A, IP_A), (1, NODE_B, IP_B)], ), ], ) -def test_get_bundles_sorted_by_node( - bundles_to_node_id, bundle_specs, world_size, expected -): +def test_get_bundles_sorted_by_node(bundles_to_node_id, bundle_specs, expected): mock_pg = MagicMock() mock_pg.bundle_specs = bundle_specs @@ -114,8 +95,6 @@ def test_get_bundles_sorted_by_node( mock_ray.nodes.return_value = MOCK_RAY_NODES mock_platform.ray_device_key = "GPU" - from vllm.v1.executor.ray_utils import get_bundles_sorted_by_node - - result = get_bundles_sorted_by_node(mock_pg, world_size) + result = get_bundles_sorted_by_node(mock_pg) assert result == expected diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py index 
45a86fdd5b62..10ff0ae3a576 100644 --- a/vllm/v1/executor/ray_utils.py +++ b/vllm/v1/executor/ray_utils.py @@ -325,6 +325,7 @@ def _sort_key(item): return (0 if node_id == driver_node else 1, node_id) bundle_to_node_id.sort(key=_sort_key) + return bundle_to_node_id diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index b6ba8adf8336..f733d63da176 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -195,8 +195,8 @@ def __init__( All workers have rpc_rank=0, but they have different ranks in the TP group. """ - self.rpc_rank = rpc_rank - self.global_rank = self.rpc_rank if global_rank is None else global_rank + self.rpc_rank: int = rpc_rank + self.global_rank: int = self.rpc_rank if global_rank is None else global_rank # Initialized after init_worker is called self.worker: WorkerBase From 139c02a28df7219b5b4a9d27cb9b5fc968895f73 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Sun, 22 Mar 2026 21:50:34 +0000 Subject: [PATCH 13/29] Respect VLLM_RAY_BUNDLE_INDICES Signed-off-by: Jeffrey Wang --- tests/distributed/test_ray_v2_executor.py | 46 +++++++++++++++++++++++ vllm/v1/executor/ray_executor_v2.py | 17 ++++++++- vllm/v1/executor/ray_utils.py | 32 ++++++++++++++++ 3 files changed, 93 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_ray_v2_executor.py b/tests/distributed/test_ray_v2_executor.py index 485c8b4fcb9d..1e5456b4f058 100644 --- a/tests/distributed/test_ray_v2_executor.py +++ b/tests/distributed/test_ray_v2_executor.py @@ -313,6 +313,52 @@ def test_ray_v2_single_node_generation(tp_size, pp_size): gc.collect() +@pytest.mark.parametrize( + "bundle_indices, expected_bundle_ids, create_placement_group", + [("2,3", [2, 3], 4), ("3,2", [3, 2], 4)], + indirect=["create_placement_group"], +) +def test_ray_v2_bundle_indices_env( + bundle_indices, expected_bundle_ids, create_placement_group, monkeypatch +): + """Validate explicit VLLM_RAY_BUNDLE_INDICES bundle placement.""" + 
monkeypatch.setenv("VLLM_RAY_BUNDLE_INDICES", bundle_indices) + vllm_config = create_vllm_config( + tensor_parallel_size=2, + placement_group=create_placement_group, + ) + executor = RayExecutorV2(vllm_config=vllm_config) + try: + actual = [ + h.bundle_id_idx + for h in sorted(executor.ray_worker_handles, key=lambda h: h.rank) + ] + assert actual == expected_bundle_ids + assert_executor(executor, tp_size=2, pp_size=1) + finally: + executor.shutdown() + + +@pytest.mark.parametrize( + "bundle_indices, expected_error, create_placement_group", + [ + ("1,1", "cannot have duplicate values,", 4), + ("0,1,2", "must have the same size", 4), + ], + indirect=["create_placement_group"], +) +def test_ray_v2_invalid_bundle_indices( + bundle_indices, expected_error, create_placement_group, monkeypatch +): + """Validate invalid bundle indices are rejected.""" + monkeypatch.setenv("VLLM_RAY_BUNDLE_INDICES", bundle_indices) + vllm_config = create_vllm_config( + tensor_parallel_size=2, placement_group=create_placement_group + ) + with pytest.raises(AssertionError, match=expected_error): + RayExecutorV2(vllm_config=vllm_config) + + @pytest.mark.parametrize("tp_size, pp_size", [(2, 1), (2, 2)]) def test_ray_v2_single_node_generation_with_pg(tp_size, pp_size): """E2E LLM generation with a user-provided placement group.""" diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index d4e3c7066128..8f3e57846202 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -25,6 +25,7 @@ ) from vllm.v1.executor.ray_utils import ( build_actor_name, + get_bundles_for_indices, get_bundles_sorted_by_node, initialize_ray_cluster, ray, @@ -56,6 +57,9 @@ class RayWorkerHandle: node_id: str """Node ID of the worker""" + bundle_id_idx: int = -1 + """Placement group bundle index for the worker""" + run_ref: ObjectRef = None """run() ObjectRef used as a sentinel for health monitoring""" @@ -163,8 +167,16 @@ def _init_executor(self) -> 
None: f"_parallel_size ({pcp_size}). " ) - # Step 2: Query PG table, sort bundles, assign ranks - bundle_to_node_id = get_bundles_sorted_by_node(placement_group) + # Step 2: Build bundle assignments for worker rank placement + # while respecting VLLM_RAY_BUNDLE_INDICES. + if envs.VLLM_RAY_BUNDLE_INDICES: + bundle_to_node_id = get_bundles_for_indices( + placement_group, + list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(","))), + self.world_size, + ) + else: + bundle_to_node_id = get_bundles_sorted_by_node(placement_group) driver_node = ray.get_runtime_context().get_node_id() # Assign each worker a local rank @@ -255,6 +267,7 @@ def _init_executor(self) -> None: rank=bundle["rank"], local_rank=bundle["local_rank"], node_id=bundle["node_id"], + bundle_id_idx=bundle["bundle_id_idx"], ) self.ray_worker_handles.append(handle) diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py index 10ff0ae3a576..6940b6b53933 100644 --- a/vllm/v1/executor/ray_utils.py +++ b/vllm/v1/executor/ray_utils.py @@ -51,6 +51,8 @@ def __init__(self, *args, **kwargs) -> None: # that thread. self.compiled_dag_cuda_device_set = False + rpc_rank: int + def adjust_rank(self, rank_mapping: dict[int, int]) -> None: """ Adjust the rpc_rank based on the given mapping. @@ -288,6 +290,36 @@ def build_actor_name( return name +def get_bundles_for_indices( + placement_group: "PlacementGroup", + bundle_indices: list[int], + world_size: int, +) -> list[tuple[int, str, str]]: + """ + Return GPU bundle indices paired with node IDs and node IPs for + explicit bundle indices specified via VLLM_RAY_BUNDLE_INDICES. 
+ """ + assert len(bundle_indices) == world_size, ( + "VLLM_RAY_BUNDLE_INDICES must have the same size" + f" as the world size, but got {bundle_indices=} " + f"and {world_size=}" + ) + assert len(set(bundle_indices)) == len(bundle_indices), ( + "VLLM_RAY_BUNDLE_INDICES cannot have duplicate values," + f" but got {bundle_indices=}" + ) + + pg_data = placement_group_table(placement_group) + pg_bundle_to_node = pg_data["bundles_to_node_id"] + node_id_to_ip = { + n["NodeID"]: n["NodeManagerAddress"] for n in ray.nodes() if n["Alive"] + } + return [ + (bid, pg_bundle_to_node[bid], node_id_to_ip[pg_bundle_to_node[bid]]) + for bid in bundle_indices + ] + + def get_bundles_sorted_by_node( placement_group: "PlacementGroup", ) -> list[tuple[int, str, str]]: From 7657031a42a985902033828136a47e28f44712d9 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Mon, 23 Mar 2026 04:04:32 +0000 Subject: [PATCH 14/29] Adjust DP rank for ray executor backend Signed-off-by: Jeffrey Wang --- vllm/v1/worker/gpu_worker.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index d101edc18100..c26d0053651d 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -222,9 +222,7 @@ def init_device(self): os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) parallel_config = self.parallel_config if ( - parallel_config.distributed_executor_backend - not in ("ray", "external_launcher") - and parallel_config.data_parallel_backend != "ray" + parallel_config.distributed_executor_backend != "external_launcher" and parallel_config.nnodes_within_dp == 1 ): # Use local DP rank if available, otherwise use global DP rank. 
From 6c1ea7e34f84d63939ebce2ee39baa9dbed97534 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Mon, 23 Mar 2026 06:21:22 +0000 Subject: [PATCH 15/29] Apply DP local-rank device offset for RayExecutorV2 workers Signed-off-by: Jeffrey Wang --- vllm/v1/worker/gpu_worker.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index c26d0053651d..099b8dbeca41 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -222,7 +222,14 @@ def init_device(self): os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) parallel_config = self.parallel_config if ( - parallel_config.distributed_executor_backend != "external_launcher" + parallel_config.distributed_executor_backend + not in ("ray", "external_launcher") + and parallel_config.data_parallel_backend != "ray" + and parallel_config.nnodes_within_dp == 1 + ) or ( + envs.VLLM_USE_RAY_V2_EXECUTOR_BACKEND + and parallel_config.distributed_executor_backend == "ray" + and parallel_config.data_parallel_size > 1 and parallel_config.nnodes_within_dp == 1 ): # Use local DP rank if available, otherwise use global DP rank. 
From d04031790dc180d361054715dc484e3ce149d388 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Mon, 23 Mar 2026 18:45:25 -0700 Subject: [PATCH 16/29] Support DP Signed-off-by: Jeffrey Wang --- vllm/config/parallel.py | 27 +++++++++++++++++++++++---- vllm/v1/engine/core.py | 11 ++++++++++- vllm/v1/executor/ray_executor_v2.py | 16 +++++++++++----- vllm/v1/worker/gpu_worker.py | 6 +----- 4 files changed, 45 insertions(+), 15 deletions(-) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index add011ca40a9..35581f52e37f 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -671,10 +671,29 @@ def __post_init__(self) -> None: ) if not self.enable_elastic_ep: if not self._data_parallel_master_port_list: - self._data_parallel_master_port_list = get_open_ports_list(5) - self.data_parallel_master_port = ( - self._data_parallel_master_port_list.pop() - ) + if ( + envs.VLLM_USE_RAY_V2_EXECUTOR_BACKEND + and self.distributed_executor_backend == "ray" + ): + # Under RayExecutorV2 each DP replica lives in a + # separate process with its own ParallelConfig. + # Random ports would differ across replicas, but + # all workers must join the same torch.distributed + # world. Derive the starting port deterministically + # from data_parallel_rpc_port (coordinated across + # all DP replicas by the launcher). +1 to avoid + # the rpc_port itself (bound by ZMQ over TCP). 
+ self.data_parallel_master_port = ( + self.data_parallel_rpc_port + 1 + ) + else: + self._data_parallel_master_port_list = ( + get_open_ports_list(5) + ) + if self._data_parallel_master_port_list: + self.data_parallel_master_port = ( + self._data_parallel_master_port_list.pop() + ) if not (0 <= self.data_parallel_rank < self.data_parallel_size): raise ValueError( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 421b25c0d0d4..ae0b9ae9dde5 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1039,7 +1039,16 @@ def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs): parallel_config: ParallelConfig = vllm_config.parallel_config data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0 if data_parallel: - parallel_config.data_parallel_rank_local = local_dp_rank + # Prefer VLLM_DP_RANK_LOCAL (set by external launchers + # like DPServer that know the true within-node DP + # position) over local_dp_rank from CoreEngineProcManager + # (which is always 0 when each launcher manages 1 engine). 
+ if envs.VLLM_DP_RANK_LOCAL >= 0: + parallel_config.data_parallel_rank_local = ( + envs.VLLM_DP_RANK_LOCAL + ) + else: + parallel_config.data_parallel_rank_local = local_dp_rank process_title = f"EngineCore_DP{dp_rank}" else: process_title = "EngineCore" diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index 8f3e57846202..ec16c812cc55 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -231,12 +231,18 @@ def _init_executor(self) -> None: ) # Prevent Ray from setting CUDA_VISIBLE_DEVICES - runtime_env = { - "env_vars": { - env_var: "1" - for env_var in current_platform.ray_noset_device_env_vars - }, + env_vars = { + env_var: "1" + for env_var in current_platform.ray_noset_device_env_vars } + # Propagate V2 executor flag and DP local rank to workers + if envs.VLLM_USE_RAY_V2_EXECUTOR_BACKEND: + env_vars["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" + if envs.VLLM_DP_RANK_LOCAL >= 0: + env_vars["VLLM_DP_RANK_LOCAL"] = str( + envs.VLLM_DP_RANK_LOCAL + ) + runtime_env = {"env_vars": env_vars} actor_name = build_actor_name( instance_id, bundle["rank"], tp_size, pp_size, pcp_size diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 099b8dbeca41..be1cfa10ed52 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -223,12 +223,8 @@ def init_device(self): parallel_config = self.parallel_config if ( parallel_config.distributed_executor_backend - not in ("ray", "external_launcher") + not in ("external_launcher",) and parallel_config.data_parallel_backend != "ray" - and parallel_config.nnodes_within_dp == 1 - ) or ( - envs.VLLM_USE_RAY_V2_EXECUTOR_BACKEND - and parallel_config.distributed_executor_backend == "ray" and parallel_config.data_parallel_size > 1 and parallel_config.nnodes_within_dp == 1 ): From a76acc987f3e4e9c79d5439eb302f88a9a3a0bde Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Tue, 24 Mar 2026 03:46:37 +0000 Subject: [PATCH 17/29] Fix 
linter Signed-off-by: Jeffrey Wang --- vllm/config/parallel.py | 8 ++------ vllm/v1/engine/core.py | 4 +--- vllm/v1/executor/ray_executor_v2.py | 7 ++----- vllm/v1/worker/gpu_worker.py | 9 ++++++++- 4 files changed, 13 insertions(+), 15 deletions(-) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 35581f52e37f..7986d412d679 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -683,13 +683,9 @@ def __post_init__(self) -> None: # from data_parallel_rpc_port (coordinated across # all DP replicas by the launcher). +1 to avoid # the rpc_port itself (bound by ZMQ over TCP). - self.data_parallel_master_port = ( - self.data_parallel_rpc_port + 1 - ) + self.data_parallel_master_port = self.data_parallel_rpc_port + 1 else: - self._data_parallel_master_port_list = ( - get_open_ports_list(5) - ) + self._data_parallel_master_port_list = get_open_ports_list(5) if self._data_parallel_master_port_list: self.data_parallel_master_port = ( self._data_parallel_master_port_list.pop() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index ae0b9ae9dde5..ac09aad4eacd 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1044,9 +1044,7 @@ def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs): # position) over local_dp_rank from CoreEngineProcManager # (which is always 0 when each launcher manages 1 engine). 
if envs.VLLM_DP_RANK_LOCAL >= 0: - parallel_config.data_parallel_rank_local = ( - envs.VLLM_DP_RANK_LOCAL - ) + parallel_config.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL else: parallel_config.data_parallel_rank_local = local_dp_rank process_title = f"EngineCore_DP{dp_rank}" diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index ec16c812cc55..f5423a2a9862 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -232,16 +232,13 @@ def _init_executor(self) -> None: # Prevent Ray from setting CUDA_VISIBLE_DEVICES env_vars = { - env_var: "1" - for env_var in current_platform.ray_noset_device_env_vars + env_var: "1" for env_var in current_platform.ray_noset_device_env_vars } # Propagate V2 executor flag and DP local rank to workers if envs.VLLM_USE_RAY_V2_EXECUTOR_BACKEND: env_vars["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" if envs.VLLM_DP_RANK_LOCAL >= 0: - env_vars["VLLM_DP_RANK_LOCAL"] = str( - envs.VLLM_DP_RANK_LOCAL - ) + env_vars["VLLM_DP_RANK_LOCAL"] = str(envs.VLLM_DP_RANK_LOCAL) runtime_env = {"env_vars": env_vars} actor_name = build_actor_name( diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index be1cfa10ed52..910d6c6145b0 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -223,8 +223,15 @@ def init_device(self): parallel_config = self.parallel_config if ( parallel_config.distributed_executor_backend - not in ("external_launcher",) + not in ("ray", "external_launcher") and parallel_config.data_parallel_backend != "ray" + and parallel_config.nnodes_within_dp == 1 + ) or ( + # RayExecutorV2 workers see all GPUs via + # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1 so they + # need the DP local-rank offset applied here. 
+ envs.VLLM_USE_RAY_V2_EXECUTOR_BACKEND + and parallel_config.distributed_executor_backend == "ray" and parallel_config.data_parallel_size > 1 and parallel_config.nnodes_within_dp == 1 ): From c9f0a39617a3e6ab5a1515ce2c4e5a3e58a48af1 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Tue, 24 Mar 2026 18:40:58 -0700 Subject: [PATCH 18/29] Lazily initialize RayWorkerProc Signed-off-by: Jeffrey Wang --- vllm/config/parallel.py | 23 ++---- vllm/v1/engine/core.py | 9 +-- vllm/v1/executor/ray_executor_v2.py | 104 +++++++++++++++++++++------- vllm/v1/worker/gpu_worker.py | 8 --- 4 files changed, 84 insertions(+), 60 deletions(-) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 7986d412d679..add011ca40a9 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -671,25 +671,10 @@ def __post_init__(self) -> None: ) if not self.enable_elastic_ep: if not self._data_parallel_master_port_list: - if ( - envs.VLLM_USE_RAY_V2_EXECUTOR_BACKEND - and self.distributed_executor_backend == "ray" - ): - # Under RayExecutorV2 each DP replica lives in a - # separate process with its own ParallelConfig. - # Random ports would differ across replicas, but - # all workers must join the same torch.distributed - # world. Derive the starting port deterministically - # from data_parallel_rpc_port (coordinated across - # all DP replicas by the launcher). +1 to avoid - # the rpc_port itself (bound by ZMQ over TCP). 
- self.data_parallel_master_port = self.data_parallel_rpc_port + 1 - else: - self._data_parallel_master_port_list = get_open_ports_list(5) - if self._data_parallel_master_port_list: - self.data_parallel_master_port = ( - self._data_parallel_master_port_list.pop() - ) + self._data_parallel_master_port_list = get_open_ports_list(5) + self.data_parallel_master_port = ( + self._data_parallel_master_port_list.pop() + ) if not (0 <= self.data_parallel_rank < self.data_parallel_size): raise ValueError( diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index ac09aad4eacd..421b25c0d0d4 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1039,14 +1039,7 @@ def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs): parallel_config: ParallelConfig = vllm_config.parallel_config data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0 if data_parallel: - # Prefer VLLM_DP_RANK_LOCAL (set by external launchers - # like DPServer that know the true within-node DP - # position) over local_dp_rank from CoreEngineProcManager - # (which is always 0 when each launcher manages 1 engine). 
- if envs.VLLM_DP_RANK_LOCAL >= 0: - parallel_config.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL - else: - parallel_config.data_parallel_rank_local = local_dp_rank + parallel_config.data_parallel_rank_local = local_dp_rank process_title = f"EngineCore_DP{dp_rank}" else: process_title = "EngineCore" diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index f5423a2a9862..4e0cf081c2aa 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os import threading import weakref from collections import defaultdict, deque @@ -65,23 +66,28 @@ class RayWorkerHandle: class RayWorkerProc(WorkerProc): - """Worker process that runs inside a Ray actor.""" + """Worker process that runs inside a Ray actor. + + Initialization is split into two phases: + 1. __init__: lightweight setup, stores init args (no device/model init) + 2. initialize_worker: called after GPU IDs are discovered, completes + the full WorkerProc initialization with the correct local_rank and + CUDA_VISIBLE_DEVICES. + """ def __init__( self, vllm_config: VllmConfig, - local_rank: int, rank: int, distributed_init_method: str, input_shm_handle: Handle, is_driver_worker: bool, is_driver_node: bool = False, ): + # Defer WorkerProc.__init__ until GPU IDs are known. 
self._is_driver_node = is_driver_node - self.local_rank = local_rank - super().__init__( + self._init_kwargs = dict( vllm_config=vllm_config, - local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method, input_shm_handle=input_shm_handle, @@ -89,6 +95,33 @@ def __init__( is_driver_worker=is_driver_worker, ) + def get_node_and_gpu_ids(self) -> tuple[str, list[int]]: + """Return (node_id, gpu_ids) assigned to this actor by Ray.""" + node_id = ray.get_runtime_context().get_node_id() + device_key = current_platform.ray_device_key + if not device_key: + raise RuntimeError( + "current platform %s does not support ray.", + current_platform.device_name, + ) + gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key] + return node_id, [int(x) for x in gpu_ids] + + def initialize_worker(self, local_rank: int, env_vars: dict[str, str]) -> None: + """Complete initialization after GPU assignment is known. + + Sets CUDA_VISIBLE_DEVICES and initializes the underlying WorkerProc + with the correct local_rank. 
+ """ + for key, value in env_vars.items(): + os.environ[key] = value + + self.local_rank = local_rank + super().__init__( + local_rank=local_rank, + **self._init_kwargs, + ) + def _init_message_queues( self, input_shm_handle: Handle, vllm_config: VllmConfig ) -> None: @@ -179,16 +212,11 @@ def _init_executor(self) -> None: bundle_to_node_id = get_bundles_sorted_by_node(placement_group) driver_node = ray.get_runtime_context().get_node_id() - # Assign each worker a local rank - node_rank_counter: dict[str, int] = defaultdict(int) bundle_assignments: list[dict[str, Any]] = [] for rank, (bundle_id_idx, node_id, node_ip) in enumerate(bundle_to_node_id): - local_rank = node_rank_counter[node_id] - node_rank_counter[node_id] += 1 bundle_assignments.append( { "rank": rank, - "local_rank": local_rank, "bundle_id_idx": bundle_id_idx, "node_id": node_id, "node_ip": node_ip, @@ -213,14 +241,13 @@ def _init_executor(self) -> None: ) scheduler_output_handle = self.rpc_broadcast_mq.export_handle() - # Step 5: Spawn RayWorkerProc actors into PG bundles + # Step 5: Spawn RayWorkerProc actors into PG bundles (deferred init). + # Workers are created lightweight here; full initialization happens + # in Step 7 after GPU IDs are discovered. self.ray_worker_handles: list[RayWorkerHandle] = [] instance_id = self.vllm_config.instance_id - # Create exactly world_size remote actors despite the number of bundles - # in the placement group. for bundle_idx in range(self.world_size): - # Fail fast if the placement group has less than world_size bundles. bundle = bundle_assignments[bundle_idx] is_driver_worker = self._is_driver_worker(bundle["rank"]) is_driver_node = bundle["node_id"] == driver_node @@ -230,15 +257,11 @@ def _init_executor(self) -> None: placement_group_bundle_index=bundle["bundle_id_idx"], ) - # Prevent Ray from setting CUDA_VISIBLE_DEVICES + # Prevent Ray from setting CUDA_VISIBLE_DEVICES; we set it + # in initialize_worker after discovering GPU IDs. 
env_vars = { env_var: "1" for env_var in current_platform.ray_noset_device_env_vars } - # Propagate V2 executor flag and DP local rank to workers - if envs.VLLM_USE_RAY_V2_EXECUTOR_BACKEND: - env_vars["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" - if envs.VLLM_DP_RANK_LOCAL >= 0: - env_vars["VLLM_DP_RANK_LOCAL"] = str(envs.VLLM_DP_RANK_LOCAL) runtime_env = {"env_vars": env_vars} actor_name = build_actor_name( @@ -256,7 +279,6 @@ def _init_executor(self) -> None: ) .remote( vllm_config=self.vllm_config, - local_rank=bundle["local_rank"], rank=bundle["rank"], distributed_init_method=distributed_init_method, input_shm_handle=scheduler_output_handle, @@ -268,13 +290,45 @@ def _init_executor(self) -> None: handle = RayWorkerHandle( actor=actor, rank=bundle["rank"], - local_rank=bundle["local_rank"], + local_rank=-1, # Set in Step 7 after GPU ID discovery node_id=bundle["node_id"], bundle_id_idx=bundle["bundle_id_idx"], ) self.ray_worker_handles.append(handle) - # Step 6: Collect response MQ handles + # Step 6: Discover GPU IDs assigned to each worker via Ray runtime context. + worker_node_and_gpu_ids = ray.get( + [h.actor.get_node_and_gpu_ids.remote() for h in self.ray_worker_handles] + ) + + node_workers: dict[str, list[int]] = defaultdict(list) + node_gpus: dict[str, list[int]] = defaultdict(list) + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + # Step 7: Initialize workers with correct local_rank and + # CUDA_VISIBLE_DEVICES. Each worker sees all GPUs assigned to + # this executor on its node; local_rank indexes into that set. 
+ init_worker_refs = [] + for i, (node_id, _) in enumerate(worker_node_and_gpu_ids): + local_rank = node_workers[node_id].index(i) + worker_env_vars = { + current_platform.device_control_env_var: ",".join( + map(str, node_gpus[node_id]) + ), + } + self.ray_worker_handles[i].local_rank = local_rank + init_worker_refs.append( + self.ray_worker_handles[i].actor.initialize_worker.remote( + local_rank, worker_env_vars + ) + ) + ray.get(init_worker_refs) + + # Step 8: Collect response MQ handles init_refs = [h.actor.wait_for_init.remote() for h in self.ray_worker_handles] init_results = ray.get(init_refs) @@ -286,12 +340,12 @@ def _init_executor(self) -> None: MessageQueue.create_from_handle(result["handle"], 0) ) - # Step 7: Start run() before wait_until_ready() to avoid + # Step 9: Start run() before wait_until_ready() to avoid # deadlock — workers send subscriptions inside run(). for handle in self.ray_worker_handles: handle.run_ref = handle.actor.run.remote() - # Step 8: wait_until_ready() barrier + # Step 10: wait_until_ready() barrier self.rpc_broadcast_mq.wait_until_ready() for response_mq in self.response_mqs: response_mq.wait_until_ready() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 910d6c6145b0..d101edc18100 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -226,14 +226,6 @@ def init_device(self): not in ("ray", "external_launcher") and parallel_config.data_parallel_backend != "ray" and parallel_config.nnodes_within_dp == 1 - ) or ( - # RayExecutorV2 workers see all GPUs via - # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1 so they - # need the DP local-rank offset applied here. - envs.VLLM_USE_RAY_V2_EXECUTOR_BACKEND - and parallel_config.distributed_executor_backend == "ray" - and parallel_config.data_parallel_size > 1 - and parallel_config.nnodes_within_dp == 1 ): # Use local DP rank if available, otherwise use global DP rank. 
dp_local_rank = self.parallel_config.data_parallel_rank_local From 29c74260638a6fbce100d0828e54ee499859fd2b Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Wed, 25 Mar 2026 23:15:39 +0000 Subject: [PATCH 19/29] Propagate env var; add tests Signed-off-by: Jeffrey Wang --- .buildkite/test_areas/distributed.yaml | 2 + tests/distributed/ray_v2_utils.py | 32 +++ tests/distributed/test_ray_v2_executor.py | 53 +--- tests/distributed/test_ray_v2_executor_e2e.py | 174 ++++++++++++ .../test_ray_v2_executor_multinode.py | 257 ++++++++++++++++++ vllm/v1/executor/ray_executor.py | 14 +- vllm/v1/executor/ray_executor_v2.py | 37 ++- vllm/v1/executor/ray_utils.py | 11 + 8 files changed, 510 insertions(+), 70 deletions(-) create mode 100644 tests/distributed/ray_v2_utils.py create mode 100644 tests/distributed/test_ray_v2_executor_e2e.py create mode 100644 tests/distributed/test_ray_v2_executor_multinode.py diff --git a/.buildkite/test_areas/distributed.yaml b/.buildkite/test_areas/distributed.yaml index 044da189c58e..4cc9bc0f0280 100644 --- a/.buildkite/test_areas/distributed.yaml +++ b/.buildkite/test_areas/distributed.yaml @@ -308,11 +308,13 @@ steps: - vllm/v1/executor/abstract.py - vllm/v1/executor/multiproc_executor.py - tests/distributed/test_ray_v2_executor.py + - tests/distributed/test_ray_v2_executor_e2e.py - tests/distributed/test_pipeline_parallel.py - tests/basic_correctness/test_basic_correctness.py commands: - export VLLM_USE_RAY_V2_EXECUTOR_BACKEND=1 - export NCCL_CUMEM_HOST_ENABLE=0 - pytest -v -s distributed/test_ray_v2_executor.py + - pytest -v -s distributed/test_ray_v2_executor_e2e.py - pytest -v -s distributed/test_pipeline_parallel.py -k "ray" - TARGET_TEST_SUITE=L4 pytest -v -s basic_correctness/test_basic_correctness.py -k "ray" diff --git a/tests/distributed/ray_v2_utils.py b/tests/distributed/ray_v2_utils.py new file mode 100644 index 000000000000..a490f7675b07 --- /dev/null +++ b/tests/distributed/ray_v2_utils.py @@ -0,0 +1,32 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Shared test utilities for RayExecutorV2 tests.""" + +import os + +import pytest +import ray + + +@pytest.fixture(autouse=True) +def enable_ray_v2_backend(): + saved = { + "VLLM_USE_RAY_V2_EXECUTOR_BACKEND": os.environ.get( + "VLLM_USE_RAY_V2_EXECUTOR_BACKEND" + ), + "VLLM_ENABLE_V1_MULTIPROCESSING": os.environ.get( + "VLLM_ENABLE_V1_MULTIPROCESSING" + ), + } + os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + if ray.is_initialized(): + ray.shutdown() + try: + yield + finally: + if ray.is_initialized(): + ray.shutdown() + os.environ.update({k: v for k, v in saved.items() if v is not None}) + for key in (k for k, v in saved.items() if v is None): + os.environ.pop(key, None) diff --git a/tests/distributed/test_ray_v2_executor.py b/tests/distributed/test_ray_v2_executor.py index 1e5456b4f058..eaf80da05261 100644 --- a/tests/distributed/test_ray_v2_executor.py +++ b/tests/distributed/test_ray_v2_executor.py @@ -8,15 +8,13 @@ """ import gc -import os import threading -import time from unittest.mock import patch import pytest import ray -from ray.util.state import list_actors +from tests.distributed.ray_v2_utils import enable_ray_v2_backend # noqa: F401 from vllm import LLM from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs @@ -25,55 +23,6 @@ MODEL = "facebook/opt-125m" -@pytest.fixture(autouse=True) -def enable_ray_v2_backend(): - """Enable the RayExecutorV2 backend via feature flag for all tests.""" - saved = { - "VLLM_USE_RAY_V2_EXECUTOR_BACKEND": os.environ.get( - "VLLM_USE_RAY_V2_EXECUTOR_BACKEND" - ), - "VLLM_ENABLE_V1_MULTIPROCESSING": os.environ.get( - "VLLM_ENABLE_V1_MULTIPROCESSING" - ), - } - os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" - # The multiprocess engine forks a subprocess that inherits the Ray - # driver connection, causing hangs. 
RayExecutorV2 already distributes - # work via Ray actors, so the EngineCore can run safely in-process. - os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - try: - yield - finally: - _cleanup_ray_resources() - os.environ.update({k: v for k, v in saved.items() if v is not None}) - for key in (k for k, v in saved.items() if v is None): - os.environ.pop(key, None) - - -def _cleanup_ray_resources(): - if not ray.is_initialized(): - return - - # Ray actor shutdown is async -- wait until all actors are dead. - dangling_actors = [] - try: - for _ in range(10): - dangling_actors = [ - actor - for actor in list_actors(filters=[("state", "=", "ALIVE")]) - if actor.class_name == "RayWorkerProc" - ] - if not dangling_actors: - break - time.sleep(1) - except Exception: - # Tolerate connection errors to the Ray dashboard - pass - - assert not dangling_actors - ray.shutdown() - - def create_vllm_config( tensor_parallel_size: int = 1, pipeline_parallel_size: int = 1, diff --git a/tests/distributed/test_ray_v2_executor_e2e.py b/tests/distributed/test_ray_v2_executor_e2e.py new file mode 100644 index 000000000000..7a857e88cd06 --- /dev/null +++ b/tests/distributed/test_ray_v2_executor_e2e.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Orchestration-level integration tests for RayExecutorV2. 
+""" + +import gc +import os + +import ray + +from tests.distributed.ray_v2_utils import enable_ray_v2_backend # noqa: F401 +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.sampling_params import SamplingParams +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.executor.abstract import Executor + +MODEL = "facebook/opt-125m" + + +def _get_env_var(worker, name): + """Called on RayWorkerProc workers via collective_rpc.""" + return os.environ.get(name) + + +@ray.remote(num_cpus=0) +class AsyncLLMActor: + async def start(self, pg, bundle_indices=None, ray_runtime_env=None): + os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" + # VLLM_ALLOW_INSECURE_SERIALIZATION is needed so collective_rpc can + # pickle _get_env_var over the AsyncLLM -> EngineCore ZMQ boundary. + os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + if bundle_indices is not None: + os.environ["VLLM_RAY_BUNDLE_INDICES"] = bundle_indices + else: + os.environ.pop("VLLM_RAY_BUNDLE_INDICES", None) + + engine_args = AsyncEngineArgs( + model=MODEL, + tensor_parallel_size=2, + distributed_executor_backend="ray", + enforce_eager=True, + max_model_len=256, + gpu_memory_utilization=0.8, + ) + vllm_config = engine_args.create_engine_config() + vllm_config.parallel_config.placement_group = pg + if ray_runtime_env is not None: + vllm_config.parallel_config.ray_runtime_env = ray_runtime_env + + executor_class = Executor.get_class(vllm_config) + self.engine = AsyncLLM( + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=False, + log_requests=False, + ) + + async def generate(self, prompt): + params = SamplingParams(max_tokens=16) + result = None + async for output in self.engine.generate( + prompt, params, request_id="test_request_id" + ): + result = output + assert result is not None + return result.outputs[0].text + + async def get_worker_env(self, name): + results = await self.engine.collective_rpc( + _get_env_var, + timeout=10, + args=(name,), + ) + return results + 
+ async def shutdown(self): + if engine := getattr(self, "engine", None): + engine.shutdown() + del self.engine + gc.collect() + + +def test_multi_replicas(): + """Two actors each run AsyncLLM with TP=2 via RayExecutorV2. + + Actor 1 starts first and claims 80% of GPU memory. Without lazy + RayWorkerProc init, actor 2 lands on the *same* two GPUs and fails + because there is not enough free memory. + """ + ray.init(ignore_reinit_error=True) + + pg1 = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK") + pg2 = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK") + ray.get([pg1.ready(), pg2.ready()]) + + actor1 = AsyncLLMActor.remote() # type: ignore[attr-defined] + actor2 = AsyncLLMActor.remote() # type: ignore[attr-defined] + + ray.get(actor1.start.remote(pg1)) + ray.get(actor2.start.remote(pg2)) + + out1, out2 = ray.get( + [ + actor1.generate.remote("Hello world"), + actor2.generate.remote("Hello world"), + ] + ) + assert len(out1) > 0 + assert len(out2) > 0 + + +def test_multi_replicas_with_bundle_indices(): + """Two actors share one 4-GPU placement group with out-of-order + bundle indices: actor 1 gets bundles [2,1], actor 2 gets [0,3]. + """ + ray.init(ignore_reinit_error=True) + + pg = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 4, strategy="PACK") + ray.get(pg.ready()) + + actor1 = AsyncLLMActor.remote() # type: ignore[attr-defined] + actor2 = AsyncLLMActor.remote() # type: ignore[attr-defined] + + ray.get(actor1.start.remote(pg, bundle_indices="2,1")) + ray.get(actor2.start.remote(pg, bundle_indices="0,3")) + + out1, out2 = ray.get( + [ + actor1.generate.remote("Hello world"), + actor2.generate.remote("Hello world"), + ] + ) + assert len(out1) > 0 + assert len(out2) > 0 + + +def test_env_var_and_runtime_env_propagation(): + """ + Verify env vars (NCCL_, HF_) and parallel_config.ray_runtime_env + propagate to RayWorkerProc actors. 
+ """ + sentinel_vars = { + "NCCL_DEBUG": "INFO", + "HF_TOKEN": "test_sentinel_token", + } + for k, v in sentinel_vars.items(): + os.environ[k] = v + + try: + ray.init(ignore_reinit_error=True) + + pg = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK") + ray.get(pg.ready()) + + ray_runtime_env = { + "env_vars": {"RAY_RUNTIME_ENV_MARKER": "ray_runtime_env"}, + } + + actor = AsyncLLMActor.remote() # type: ignore[attr-defined] + ray.get(actor.start.remote(pg, ray_runtime_env=ray_runtime_env)) + + for name, expected in sentinel_vars.items(): + results = ray.get(actor.get_worker_env.remote(name)) + for val in results: + assert val == expected + + results = ray.get(actor.get_worker_env.remote("RAY_RUNTIME_ENV_MARKER")) + for val in results: + assert val == "ray_runtime_env" + + finally: + for k in sentinel_vars: + os.environ.pop(k, None) diff --git a/tests/distributed/test_ray_v2_executor_multinode.py b/tests/distributed/test_ray_v2_executor_multinode.py new file mode 100644 index 000000000000..ed7abd9bc43b --- /dev/null +++ b/tests/distributed/test_ray_v2_executor_multinode.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Multi-node integration tests for RayExecutorV2. +Validates executor initialization, worker placement, and end-to-end +generation across multiple nodes with various TP/PP configurations. + +Requires VLLM_MULTI_NODE=1 env var and a multi-node Ray cluster. 
+ +Run: +```sh +VLLM_MULTI_NODE=1 VLLM_USE_RAY_V2_EXECUTOR_BACKEND=1 \ + pytest -v -s distributed/test_ray_v2_executor_multinode.py +``` +""" + +import gc +import os +import time +from unittest.mock import patch + +import pytest +import ray +from ray.util.placement_group import PlacementGroup +from ray.util.state import list_actors + +from vllm import LLM +from vllm.config import VllmConfig +from vllm.engine.arg_utils import EngineArgs +from vllm.v1.executor.ray_executor_v2 import RayExecutorV2 + +MODEL = "facebook/opt-125m" + +VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" + +pytestmark = pytest.mark.skipif( + not VLLM_MULTI_NODE, reason="Need VLLM_MULTI_NODE=1 and a multi-node cluster." +) + + +@pytest.fixture(autouse=True) +def enable_ray_v2_backend(): + """Enable the RayExecutorV2 backend via feature flag for all tests.""" + saved = { + "VLLM_USE_RAY_V2_EXECUTOR_BACKEND": os.environ.get( + "VLLM_USE_RAY_V2_EXECUTOR_BACKEND" + ), + "VLLM_ENABLE_V1_MULTIPROCESSING": os.environ.get( + "VLLM_ENABLE_V1_MULTIPROCESSING" + ), + } + os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" + # When multiprocessing is enabled (default), the LLM engine runs in a + # forked subprocess and `llm.llm_engine.model_executor` is not directly + # accessible for test cleanup (`executor.shutdown()`). Disabling it + # keeps the engine in-process so tests can call shutdown explicitly. + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + try: + yield + finally: + _cleanup_ray_resources() + os.environ.update({k: v for k, v in saved.items() if v is not None}) + for key in (k for k, v in saved.items() if v is None): + os.environ.pop(key, None) + + +def _cleanup_ray_resources(): + if not ray.is_initialized(): + return + + # Ray actor shutdown is async -- wait until all actors are dead. 
+ dangling_actors = [] + try: + for _ in range(10): + dangling_actors = [ + actor + for actor in list_actors(filters=[("state", "=", "ALIVE")]) + if actor.class_name == "RayWorkerProc" + ] + if not dangling_actors: + break + time.sleep(1) + except Exception: + # Tolerate connection errors to the Ray dashboard + pass + + try: + for pg_id, pg_info in ray.util.placement_group_table().items(): + if pg_info["state"] == "CREATED": + pg = PlacementGroup(ray.PlacementGroupID(bytes.fromhex(pg_id))) + ray.util.remove_placement_group(pg) + except Exception: + pass + finally: + ray.shutdown() + + +def create_vllm_config( + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + max_model_len: int = 256, + gpu_memory_utilization: float = 0.3, +) -> VllmConfig: + engine_args = EngineArgs( + model=MODEL, + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + distributed_executor_backend="ray", + enforce_eager=True, + ) + return engine_args.create_engine_config() + + +def assert_executor(executor, tp_size, pp_size): + """Common assertions for executor initialization tests.""" + world_size = tp_size * pp_size + expected_output_rank = (pp_size - 1) * tp_size + + assert executor.world_size == world_size + assert len(executor.ray_worker_handles) == world_size + assert len(executor.response_mqs) == world_size + assert executor._get_output_rank() == expected_output_rank + + if pp_size > 1: + assert executor.max_concurrent_batches == pp_size + + executor.check_health() + assert not executor.is_failed + + ranks = sorted(h.rank for h in executor.ray_worker_handles) + assert ranks == list(range(world_size)) + + for handle in executor.ray_worker_handles: + assert handle.node_id is not None + + +def test_ray_v2_multinode_executor_init(): + """Validate RayExecutorV2 initializes correctly across multiple nodes + with TP=4, PP=2 (8 GPUs).""" + vllm_config = 
create_vllm_config( + tensor_parallel_size=4, + pipeline_parallel_size=2, + ) + executor = RayExecutorV2(vllm_config=vllm_config) + try: + assert_executor(executor, tp_size=4, pp_size=2) + + # Verify workers span multiple nodes + node_ids = {h.node_id for h in executor.ray_worker_handles} + assert len(node_ids) > 1 + + # Verify rank 0 exists and has a valid node_id. + # On clusters where the driver node has GPUs, rank 0 will be there. + # On GPU-less head nodes, rank 0 is on the first GPU node instead. + rank0_handle = next(h for h in executor.ray_worker_handles if h.rank == 0) + assert rank0_handle.node_id is not None + finally: + executor.shutdown() + + +def test_ray_v2_multinode_worker_placement(): + """Verify TP locality: workers in the same TP group share a node.""" + tp_size = 4 + pp_size = 2 + + vllm_config = create_vllm_config( + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + ) + executor = RayExecutorV2(vllm_config=vllm_config) + try: + # Workers are sorted by rank; consecutive tp_size ranks form a TP group + for pp_rank in range(pp_size): + start_rank = pp_rank * tp_size + tp_group_handles = [ + h + for h in executor.ray_worker_handles + if start_rank <= h.rank < start_rank + tp_size + ] + tp_group_nodes = {h.node_id for h in tp_group_handles} + assert len(tp_group_nodes) == 1 + + # Workers should be distributed across > 1 node + all_nodes = {h.node_id for h in executor.ray_worker_handles} + assert len(all_nodes) > 1 + finally: + executor.shutdown() + + +def test_ray_v2_multinode_generation(): + """End-to-end LLM generation with TP=4, PP=2 across multiple nodes.""" + llm = LLM( + model=MODEL, + tensor_parallel_size=4, + pipeline_parallel_size=2, + distributed_executor_backend="ray", + enforce_eager=True, + max_model_len=256, + gpu_memory_utilization=0.3, + ) + try: + prompts = [ + "Hello, my name is", + "The capital of France is", + "The future of AI is", + ] + outputs = llm.generate(prompts) + + assert len(outputs) == len(prompts) + 
for output in outputs: + assert len(output.outputs) > 0 + assert len(output.outputs[0].text) > 0 + finally: + llm.llm_engine.model_executor.shutdown() + del llm + gc.collect() + + +@pytest.mark.parametrize("tp_size, pp_size", [(4, 2), (2, 4)]) +def test_ray_v2_multinode_generation_with_pg(tp_size, pp_size): + """E2E LLM generation with a user-provided placement group across nodes.""" + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True) + + bundles = [{"GPU": 1, "CPU": 1} for _ in range(tp_size * pp_size)] + pg = ray.util.placement_group(bundles, strategy="PACK") + ray.get(pg.ready()) + + try: + with patch.object(ray.util, "get_current_placement_group", return_value=pg): + llm = LLM( + model=MODEL, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + distributed_executor_backend="ray", + enforce_eager=True, + max_model_len=256, + gpu_memory_utilization=0.3, + ) + prompts = [ + "Hello, my name is", + "The capital of France is", + "The future of AI is", + ] + outputs = llm.generate(prompts) + + assert len(outputs) == len(prompts) + for output in outputs: + assert len(output.outputs) > 0 + assert len(output.outputs[0].text) > 0 + finally: + llm.llm_engine.model_executor.shutdown() + del llm + gc.collect() diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index 1cbc11990e08..3115d0ce65ee 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -23,6 +23,7 @@ from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.executor.abstract import Executor from vllm.v1.executor.ray_utils import ( + WORKER_SPECIFIC_ENV_VARS, FutureWrapper, RayWorkerWrapper, initialize_ray_cluster, @@ -62,17 +63,6 @@ class RayWorkerMetaData: class RayDistributedExecutor(Executor): """Ray-based distributed executor""" - # These env vars are worker-specific, therefore are NOT copied - # from the driver to the workers - WORKER_SPECIFIC_ENV_VARS = { - "VLLM_HOST_IP", - 
"VLLM_HOST_PORT", - "LOCAL_RANK", - "CUDA_VISIBLE_DEVICES", - "HIP_VISIBLE_DEVICES", - "ROCR_VISIBLE_DEVICES", - } - uses_ray: bool = True supports_pp: bool = True @@ -335,7 +325,7 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): # Environment variables to copy from driver to workers env_vars_to_copy = get_env_vars_to_copy( - exclude_vars=self.WORKER_SPECIFIC_ENV_VARS, + exclude_vars=WORKER_SPECIFIC_ENV_VARS, additional_vars=set(current_platform.additional_env_vars), destination="workers", ) diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index 4e0cf081c2aa..b5592df3c200 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy import os import threading import weakref @@ -15,6 +16,7 @@ ) from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.ray.ray_env import get_env_vars_to_copy from vllm.utils.network_utils import ( get_distributed_init_method, get_open_port, @@ -25,6 +27,7 @@ WorkerProc, ) from vllm.v1.executor.ray_utils import ( + WORKER_SPECIFIC_ENV_VARS, build_actor_name, get_bundles_for_indices, get_bundles_sorted_by_node, @@ -247,6 +250,17 @@ def _init_executor(self) -> None: self.ray_worker_handles: list[RayWorkerHandle] = [] instance_id = self.vllm_config.instance_id + # Collect env vars to propagate from driver to workers (NCCL, + # HF, vLLM flags, etc.) — same mechanism as RayDistributedExecutor. 
+ env_vars_to_copy = get_env_vars_to_copy( + exclude_vars=WORKER_SPECIFIC_ENV_VARS, + additional_vars=set(current_platform.additional_env_vars), + destination="RayWorkerProc actors", + ) + self._worker_env_vars = { + name: os.environ[name] for name in env_vars_to_copy if name in os.environ + } + for bundle_idx in range(self.world_size): bundle = bundle_assignments[bundle_idx] is_driver_worker = self._is_driver_worker(bundle["rank"]) @@ -257,12 +271,23 @@ def _init_executor(self) -> None: placement_group_bundle_index=bundle["bundle_id_idx"], ) - # Prevent Ray from setting CUDA_VISIBLE_DEVICES; we set it - # in initialize_worker after discovering GPU IDs. - env_vars = { - env_var: "1" for env_var in current_platform.ray_noset_device_env_vars - } - runtime_env = {"env_vars": env_vars} + # Build runtime_env for the worker actor: + # 1. Start from parallel_config.ray_runtime_env (pip, working_dir, + # etc.) so that packages installed at the job level are + # available inside RayWorkerProc actors. + # 2. Merge in driver env vars (NCCL, HF, vLLM flags) collected + # by get_env_vars_to_copy. + # 3. Prevent Ray from setting CUDA_VISIBLE_DEVICES; we set it + # ourselves in initialize_worker after discovering GPU IDs. 
+ base_runtime_env = self.parallel_config.ray_runtime_env + runtime_env: dict = ( + copy.deepcopy(dict(base_runtime_env)) if base_runtime_env else {} + ) + env_vars = runtime_env.setdefault("env_vars", {}) + env_vars.update(self._worker_env_vars) + env_vars.update( + {v: "1" for v in current_platform.ray_noset_device_env_vars} + ) actor_name = build_actor_name( instance_id, bundle["rank"], tp_size, pp_size, pcp_size diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py index 6940b6b53933..6a997c7f50f0 100644 --- a/vllm/v1/executor/ray_utils.py +++ b/vllm/v1/executor/ray_utils.py @@ -26,6 +26,17 @@ logger = init_logger(__name__) PG_WAIT_TIMEOUT = 1800 +# Env vars that are worker-specific and must NOT be copied from the +# driver to Ray workers — they are set per-worker after GPU discovery. +WORKER_SPECIFIC_ENV_VARS: set[str] = { + "VLLM_HOST_IP", + "VLLM_HOST_PORT", + "LOCAL_RANK", + "CUDA_VISIBLE_DEVICES", + "HIP_VISIBLE_DEVICES", + "ROCR_VISIBLE_DEVICES", +} + try: import ray from ray.util import placement_group_table From aae593883c93366470ceb5359fbcc3393f26b12d Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Wed, 25 Mar 2026 23:28:28 +0000 Subject: [PATCH 20/29] Add nsight profiling and non-GPU device support to RayExecutorV2 - Extract _build_runtime_env() and _get_actor_resource_kwargs() helpers mirroring RayDistributedExecutor's pattern. - Inject nsight config into runtime_env when ray_workers_use_nsight is set. - Use resources={device_key: num} for non-GPU platforms (e.g. TPU) instead of hardcoding num_gpus. - Remove test_ray_v2_executor_multinode.py (moved elsewhere). 
Signed-off-by: Jeffrey Wang --- .../test_ray_v2_executor_multinode.py | 257 ------------------ vllm/v1/executor/ray_executor_v2.py | 51 ++-- 2 files changed, 33 insertions(+), 275 deletions(-) delete mode 100644 tests/distributed/test_ray_v2_executor_multinode.py diff --git a/tests/distributed/test_ray_v2_executor_multinode.py b/tests/distributed/test_ray_v2_executor_multinode.py deleted file mode 100644 index ed7abd9bc43b..000000000000 --- a/tests/distributed/test_ray_v2_executor_multinode.py +++ /dev/null @@ -1,257 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -""" -Multi-node integration tests for RayExecutorV2. -Validates executor initialization, worker placement, and end-to-end -generation across multiple nodes with various TP/PP configurations. - -Requires VLLM_MULTI_NODE=1 env var and a multi-node Ray cluster. - -Run: -```sh -VLLM_MULTI_NODE=1 VLLM_USE_RAY_V2_EXECUTOR_BACKEND=1 \ - pytest -v -s distributed/test_ray_v2_executor_multinode.py -``` -""" - -import gc -import os -import time -from unittest.mock import patch - -import pytest -import ray -from ray.util.placement_group import PlacementGroup -from ray.util.state import list_actors - -from vllm import LLM -from vllm.config import VllmConfig -from vllm.engine.arg_utils import EngineArgs -from vllm.v1.executor.ray_executor_v2 import RayExecutorV2 - -MODEL = "facebook/opt-125m" - -VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" - -pytestmark = pytest.mark.skipif( - not VLLM_MULTI_NODE, reason="Need VLLM_MULTI_NODE=1 and a multi-node cluster." 
-) - - -@pytest.fixture(autouse=True) -def enable_ray_v2_backend(): - """Enable the RayExecutorV2 backend via feature flag for all tests.""" - saved = { - "VLLM_USE_RAY_V2_EXECUTOR_BACKEND": os.environ.get( - "VLLM_USE_RAY_V2_EXECUTOR_BACKEND" - ), - "VLLM_ENABLE_V1_MULTIPROCESSING": os.environ.get( - "VLLM_ENABLE_V1_MULTIPROCESSING" - ), - } - os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" - # When multiprocessing is enabled (default), the LLM engine runs in a - # forked subprocess and `llm.llm_engine.model_executor` is not directly - # accessible for test cleanup (`executor.shutdown()`). Disabling it - # keeps the engine in-process so tests can call shutdown explicitly. - os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - try: - yield - finally: - _cleanup_ray_resources() - os.environ.update({k: v for k, v in saved.items() if v is not None}) - for key in (k for k, v in saved.items() if v is None): - os.environ.pop(key, None) - - -def _cleanup_ray_resources(): - if not ray.is_initialized(): - return - - # Ray actor shutdown is async -- wait until all actors are dead. 
- dangling_actors = [] - try: - for _ in range(10): - dangling_actors = [ - actor - for actor in list_actors(filters=[("state", "=", "ALIVE")]) - if actor.class_name == "RayWorkerProc" - ] - if not dangling_actors: - break - time.sleep(1) - except Exception: - # Tolerate connection errors to the Ray dashboard - pass - - try: - for pg_id, pg_info in ray.util.placement_group_table().items(): - if pg_info["state"] == "CREATED": - pg = PlacementGroup(ray.PlacementGroupID(bytes.fromhex(pg_id))) - ray.util.remove_placement_group(pg) - except Exception: - pass - finally: - ray.shutdown() - - -def create_vllm_config( - tensor_parallel_size: int = 1, - pipeline_parallel_size: int = 1, - max_model_len: int = 256, - gpu_memory_utilization: float = 0.3, -) -> VllmConfig: - engine_args = EngineArgs( - model=MODEL, - tensor_parallel_size=tensor_parallel_size, - pipeline_parallel_size=pipeline_parallel_size, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - distributed_executor_backend="ray", - enforce_eager=True, - ) - return engine_args.create_engine_config() - - -def assert_executor(executor, tp_size, pp_size): - """Common assertions for executor initialization tests.""" - world_size = tp_size * pp_size - expected_output_rank = (pp_size - 1) * tp_size - - assert executor.world_size == world_size - assert len(executor.ray_worker_handles) == world_size - assert len(executor.response_mqs) == world_size - assert executor._get_output_rank() == expected_output_rank - - if pp_size > 1: - assert executor.max_concurrent_batches == pp_size - - executor.check_health() - assert not executor.is_failed - - ranks = sorted(h.rank for h in executor.ray_worker_handles) - assert ranks == list(range(world_size)) - - for handle in executor.ray_worker_handles: - assert handle.node_id is not None - - -def test_ray_v2_multinode_executor_init(): - """Validate RayExecutorV2 initializes correctly across multiple nodes - with TP=4, PP=2 (8 GPUs).""" - vllm_config = 
create_vllm_config( - tensor_parallel_size=4, - pipeline_parallel_size=2, - ) - executor = RayExecutorV2(vllm_config=vllm_config) - try: - assert_executor(executor, tp_size=4, pp_size=2) - - # Verify workers span multiple nodes - node_ids = {h.node_id for h in executor.ray_worker_handles} - assert len(node_ids) > 1 - - # Verify rank 0 exists and has a valid node_id. - # On clusters where the driver node has GPUs, rank 0 will be there. - # On GPU-less head nodes, rank 0 is on the first GPU node instead. - rank0_handle = next(h for h in executor.ray_worker_handles if h.rank == 0) - assert rank0_handle.node_id is not None - finally: - executor.shutdown() - - -def test_ray_v2_multinode_worker_placement(): - """Verify TP locality: workers in the same TP group share a node.""" - tp_size = 4 - pp_size = 2 - - vllm_config = create_vllm_config( - tensor_parallel_size=tp_size, - pipeline_parallel_size=pp_size, - ) - executor = RayExecutorV2(vllm_config=vllm_config) - try: - # Workers are sorted by rank; consecutive tp_size ranks form a TP group - for pp_rank in range(pp_size): - start_rank = pp_rank * tp_size - tp_group_handles = [ - h - for h in executor.ray_worker_handles - if start_rank <= h.rank < start_rank + tp_size - ] - tp_group_nodes = {h.node_id for h in tp_group_handles} - assert len(tp_group_nodes) == 1 - - # Workers should be distributed across > 1 node - all_nodes = {h.node_id for h in executor.ray_worker_handles} - assert len(all_nodes) > 1 - finally: - executor.shutdown() - - -def test_ray_v2_multinode_generation(): - """End-to-end LLM generation with TP=4, PP=2 across multiple nodes.""" - llm = LLM( - model=MODEL, - tensor_parallel_size=4, - pipeline_parallel_size=2, - distributed_executor_backend="ray", - enforce_eager=True, - max_model_len=256, - gpu_memory_utilization=0.3, - ) - try: - prompts = [ - "Hello, my name is", - "The capital of France is", - "The future of AI is", - ] - outputs = llm.generate(prompts) - - assert len(outputs) == len(prompts) - 
for output in outputs: - assert len(output.outputs) > 0 - assert len(output.outputs[0].text) > 0 - finally: - llm.llm_engine.model_executor.shutdown() - del llm - gc.collect() - - -@pytest.mark.parametrize("tp_size, pp_size", [(4, 2), (2, 4)]) -def test_ray_v2_multinode_generation_with_pg(tp_size, pp_size): - """E2E LLM generation with a user-provided placement group across nodes.""" - if not ray.is_initialized(): - ray.init(ignore_reinit_error=True) - - bundles = [{"GPU": 1, "CPU": 1} for _ in range(tp_size * pp_size)] - pg = ray.util.placement_group(bundles, strategy="PACK") - ray.get(pg.ready()) - - try: - with patch.object(ray.util, "get_current_placement_group", return_value=pg): - llm = LLM( - model=MODEL, - tensor_parallel_size=tp_size, - pipeline_parallel_size=pp_size, - distributed_executor_backend="ray", - enforce_eager=True, - max_model_len=256, - gpu_memory_utilization=0.3, - ) - prompts = [ - "Hello, my name is", - "The capital of France is", - "The future of AI is", - ] - outputs = llm.generate(prompts) - - assert len(outputs) == len(prompts) - for output in outputs: - assert len(output.outputs) > 0 - assert len(output.outputs[0].text) > 0 - finally: - llm.llm_engine.model_executor.shutdown() - del llm - gc.collect() diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index b5592df3c200..81855c765612 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -182,6 +182,36 @@ class RayExecutorV2(MultiprocExecutor): def __init__(self, vllm_config: VllmConfig): super(MultiprocExecutor, self).__init__(vllm_config) + def _build_runtime_env(self) -> dict: + """Build a runtime_env dict for RayWorkerProc actors. + + Merges parallel_config.ray_runtime_env, driver env vars from + get_env_vars_to_copy, noset-device flags, and optional Nsight + profiling config. 
+ """ + base = self.parallel_config.ray_runtime_env + runtime_env: dict = copy.deepcopy(dict(base)) if base else {} + + env_vars = runtime_env.setdefault("env_vars", {}) + env_vars.update(self._worker_env_vars) + env_vars.update({v: "1" for v in current_platform.ray_noset_device_env_vars}) + if self.parallel_config.ray_workers_use_nsight: + runtime_env["nsight"] = { + "t": "cuda,cudnn,cublas", + "o": "'worker_process_%p'", + "cuda-graph-trace": "node", + } + return runtime_env + + @staticmethod + def _get_actor_resource_kwargs() -> dict[str, Any]: + """Return Ray actor resource kwargs for the current platform.""" + num_devices = envs.VLLM_RAY_PER_WORKER_GPUS + device_key = current_platform.ray_device_key + if device_key == "GPU": + return {"num_gpus": num_devices} + return {"num_gpus": 0, "resources": {device_key: num_devices}} + def _init_executor(self) -> None: """Initialize the RayExecutorV2 executor.""" self._finalizer = weakref.finalize(self, self.shutdown) @@ -271,23 +301,8 @@ def _init_executor(self) -> None: placement_group_bundle_index=bundle["bundle_id_idx"], ) - # Build runtime_env for the worker actor: - # 1. Start from parallel_config.ray_runtime_env (pip, working_dir, - # etc.) so that packages installed at the job level are - # available inside RayWorkerProc actors. - # 2. Merge in driver env vars (NCCL, HF, vLLM flags) collected - # by get_env_vars_to_copy. - # 3. Prevent Ray from setting CUDA_VISIBLE_DEVICES; we set it - # ourselves in initialize_worker after discovering GPU IDs. 
- base_runtime_env = self.parallel_config.ray_runtime_env - runtime_env: dict = ( - copy.deepcopy(dict(base_runtime_env)) if base_runtime_env else {} - ) - env_vars = runtime_env.setdefault("env_vars", {}) - env_vars.update(self._worker_env_vars) - env_vars.update( - {v: "1" for v in current_platform.ray_noset_device_env_vars} - ) + runtime_env = self._build_runtime_env() + resource_kwargs = self._get_actor_resource_kwargs() actor_name = build_actor_name( instance_id, bundle["rank"], tp_size, pp_size, pcp_size @@ -298,7 +313,7 @@ def _init_executor(self) -> None: .options( name=actor_name, num_cpus=0, - num_gpus=envs.VLLM_RAY_PER_WORKER_GPUS, + **resource_kwargs, scheduling_strategy=scheduling_strategy, runtime_env=runtime_env, ) From 25eaf8e9190e8eb1b8900e25b85a63238f24ade7 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Thu, 26 Mar 2026 00:27:15 +0000 Subject: [PATCH 21/29] Fix AsyncLLMActor async detection in e2e tests Signed-off-by: Jeffrey Wang --- tests/distributed/test_ray_v2_executor_e2e.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/distributed/test_ray_v2_executor_e2e.py b/tests/distributed/test_ray_v2_executor_e2e.py index 7a857e88cd06..dace67011002 100644 --- a/tests/distributed/test_ray_v2_executor_e2e.py +++ b/tests/distributed/test_ray_v2_executor_e2e.py @@ -25,6 +25,9 @@ def _get_env_var(worker, name): @ray.remote(num_cpus=0) class AsyncLLMActor: + async def __init__(self): + self.engine: AsyncLLM + async def start(self, pg, bundle_indices=None, ray_runtime_env=None): os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" # VLLM_ALLOW_INSECURE_SERIALIZATION is needed so collective_rpc can From 6717ca2bd17a579b7b5e4d41d131245c50d535b0 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Thu, 26 Mar 2026 01:23:39 +0000 Subject: [PATCH 22/29] Fix AsyncLLMActor async detection in e2e tests Signed-off-by: Jeffrey Wang --- tests/distributed/test_ray_v2_executor_e2e.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git 
a/tests/distributed/test_ray_v2_executor_e2e.py b/tests/distributed/test_ray_v2_executor_e2e.py index dace67011002..d4cd083b37b3 100644 --- a/tests/distributed/test_ray_v2_executor_e2e.py +++ b/tests/distributed/test_ray_v2_executor_e2e.py @@ -23,12 +23,16 @@ def _get_env_var(worker, name): return os.environ.get(name) -@ray.remote(num_cpus=0) +# Follows the same pattern as Ray Serve LLM (LLMServer / VLLMEngine): +# sync __init__ and start() (AsyncLLM constructor is sync), async methods +# for generate/collective_rpc. max_concurrency=1 gives the actor an +# event loop so async methods work. +@ray.remote(num_cpus=0, max_concurrency=1) class AsyncLLMActor: - async def __init__(self): + def __init__(self): self.engine: AsyncLLM - async def start(self, pg, bundle_indices=None, ray_runtime_env=None): + def start(self, pg, bundle_indices=None, ray_runtime_env=None): os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" # VLLM_ALLOW_INSECURE_SERIALIZATION is needed so collective_rpc can # pickle _get_env_var over the AsyncLLM -> EngineCore ZMQ boundary. 
@@ -77,7 +81,7 @@ async def get_worker_env(self, name): ) return results - async def shutdown(self): + def shutdown(self): if engine := getattr(self, "engine", None): engine.shutdown() del self.engine From cfba15ee68d45a8dcc58dbd3a3b8f98328f9169c Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Thu, 26 Mar 2026 05:08:15 +0000 Subject: [PATCH 23/29] Fix test Signed-off-by: Jeffrey Wang --- tests/distributed/test_ray_v2_executor_e2e.py | 103 ++++++++++-------- 1 file changed, 59 insertions(+), 44 deletions(-) diff --git a/tests/distributed/test_ray_v2_executor_e2e.py b/tests/distributed/test_ray_v2_executor_e2e.py index d4cd083b37b3..1d3558ccb5e9 100644 --- a/tests/distributed/test_ray_v2_executor_e2e.py +++ b/tests/distributed/test_ray_v2_executor_e2e.py @@ -10,38 +10,41 @@ import ray from tests.distributed.ray_v2_utils import enable_ray_v2_backend # noqa: F401 -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.sampling_params import SamplingParams -from vllm.v1.engine.async_llm import AsyncLLM -from vllm.v1.executor.abstract import Executor MODEL = "facebook/opt-125m" def _get_env_var(worker, name): - """Called on RayWorkerProc workers via collective_rpc.""" return os.environ.get(name) -# Follows the same pattern as Ray Serve LLM (LLMServer / VLLMEngine): -# sync __init__ and start() (AsyncLLM constructor is sync), async methods -# for generate/collective_rpc. max_concurrency=1 gives the actor an -# event loop so async methods work. -@ray.remote(num_cpus=0, max_concurrency=1) -class AsyncLLMActor: - def __init__(self): - self.engine: AsyncLLM +def _ray_init(): + """Start Ray with the project root on workers' PYTHONPATH. 
+ Without this, workers cannot unpickle actor classes defined in the + ``tests`` package, causing FunctionActorManager to fall back to + TemporaryActor which drops async method signatures.""" + ray.init( + ignore_reinit_error=True, + runtime_env={"env_vars": {"PYTHONPATH": os.getcwd()}}, + ) + + +class _AsyncLLMActor: def start(self, pg, bundle_indices=None, ray_runtime_env=None): os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" - # VLLM_ALLOW_INSECURE_SERIALIZATION is needed so collective_rpc can - # pickle _get_env_var over the AsyncLLM -> EngineCore ZMQ boundary. + # Needed so collective_rpc can pickle _get_env_var over the + # AsyncLLM -> EngineCore ZMQ boundary. os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" if bundle_indices is not None: os.environ["VLLM_RAY_BUNDLE_INDICES"] = bundle_indices else: os.environ.pop("VLLM_RAY_BUNDLE_INDICES", None) + from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.v1.engine.async_llm import AsyncLLM + from vllm.v1.executor.abstract import Executor + engine_args = AsyncEngineArgs( model=MODEL, tensor_parallel_size=2, @@ -64,6 +67,8 @@ def start(self, pg, bundle_indices=None, ray_runtime_env=None): ) async def generate(self, prompt): + from vllm.sampling_params import SamplingParams + params = SamplingParams(max_tokens=16) result = None async for output in self.engine.generate( @@ -73,13 +78,25 @@ async def generate(self, prompt): assert result is not None return result.outputs[0].text - async def get_worker_env(self, name): - results = await self.engine.collective_rpc( - _get_env_var, - timeout=10, - args=(name,), - ) - return results + async def generate_and_get_worker_envs(self, prompt, env_names): + from vllm.sampling_params import SamplingParams + + params = SamplingParams(max_tokens=16) + result = None + async for output in self.engine.generate( + prompt, params, request_id="test_request_id" + ): + result = output + assert result is not None + text = result.outputs[0].text + + env_results = {} + for 
name in env_names: + vals = await self.engine.collective_rpc( + _get_env_var, timeout=10, args=(name,) + ) + env_results[name] = vals + return text, env_results def shutdown(self): if engine := getattr(self, "engine", None): @@ -88,21 +105,18 @@ def shutdown(self): gc.collect() -def test_multi_replicas(): - """Two actors each run AsyncLLM with TP=2 via RayExecutorV2. +AsyncLLMActor = ray.remote(num_cpus=0, max_concurrency=1)(_AsyncLLMActor) - Actor 1 starts first and claims 80% of GPU memory. Without lazy - RayWorkerProc init, actor 2 lands on the *same* two GPUs and fails - because there is not enough free memory. - """ - ray.init(ignore_reinit_error=True) + +def test_multi_replicas(): + _ray_init() pg1 = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK") pg2 = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK") ray.get([pg1.ready(), pg2.ready()]) - actor1 = AsyncLLMActor.remote() # type: ignore[attr-defined] - actor2 = AsyncLLMActor.remote() # type: ignore[attr-defined] + actor1 = AsyncLLMActor.remote() + actor2 = AsyncLLMActor.remote() ray.get(actor1.start.remote(pg1)) ray.get(actor2.start.remote(pg2)) @@ -118,16 +132,13 @@ def test_multi_replicas(): def test_multi_replicas_with_bundle_indices(): - """Two actors share one 4-GPU placement group with out-of-order - bundle indices: actor 1 gets bundles [2,1], actor 2 gets [0,3]. 
- """ - ray.init(ignore_reinit_error=True) + _ray_init() pg = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 4, strategy="PACK") ray.get(pg.ready()) - actor1 = AsyncLLMActor.remote() # type: ignore[attr-defined] - actor2 = AsyncLLMActor.remote() # type: ignore[attr-defined] + actor1 = AsyncLLMActor.remote() + actor2 = AsyncLLMActor.remote() ray.get(actor1.start.remote(pg, bundle_indices="2,1")) ray.get(actor2.start.remote(pg, bundle_indices="0,3")) @@ -155,25 +166,29 @@ def test_env_var_and_runtime_env_propagation(): os.environ[k] = v try: - ray.init(ignore_reinit_error=True) + _ray_init() pg = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK") ray.get(pg.ready()) ray_runtime_env = { - "env_vars": {"RAY_RUNTIME_ENV_MARKER": "ray_runtime_env"}, + "env_vars": {"RAY_RUNTIME_ENV_TEST": "ray_runtime_env"}, } - actor = AsyncLLMActor.remote() # type: ignore[attr-defined] + actor = AsyncLLMActor.remote() ray.get(actor.start.remote(pg, ray_runtime_env=ray_runtime_env)) + all_env_names = list(sentinel_vars) + ["RAY_RUNTIME_ENV_TEST"] + text, env_results = ray.get( + actor.generate_and_get_worker_envs.remote("Hello world", all_env_names) + ) + assert len(text) > 0 + for name, expected in sentinel_vars.items(): - results = ray.get(actor.get_worker_env.remote(name)) - for val in results: + for val in env_results[name]: assert val == expected - results = ray.get(actor.get_worker_env.remote("RAY_RUNTIME_ENV_MARKER")) - for val in results: + for val in env_results["RAY_RUNTIME_ENV_TEST"]: assert val == "ray_runtime_env" finally: From 476501bcb41344e0ed26e8d7961c9b21ca33d617 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Thu, 26 Mar 2026 07:04:15 +0000 Subject: [PATCH 24/29] Fix test Signed-off-by: Jeffrey Wang --- tests/distributed/test_ray_v2_executor_e2e.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_ray_v2_executor_e2e.py b/tests/distributed/test_ray_v2_executor_e2e.py index 1d3558ccb5e9..a2634ffa5229 
100644 --- a/tests/distributed/test_ray_v2_executor_e2e.py +++ b/tests/distributed/test_ray_v2_executor_e2e.py @@ -6,6 +6,7 @@ import gc import os +import pathlib import ray @@ -24,9 +25,10 @@ def _ray_init(): Without this, workers cannot unpickle actor classes defined in the ``tests`` package, causing FunctionActorManager to fall back to TemporaryActor which drops async method signatures.""" + project_root = str(pathlib.Path(__file__).resolve().parents[2]) ray.init( ignore_reinit_error=True, - runtime_env={"env_vars": {"PYTHONPATH": os.getcwd()}}, + runtime_env={"env_vars": {"PYTHONPATH": project_root}}, ) From c7aa6610d283b33cad366b22db0dc020e79bd984 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Thu, 26 Mar 2026 16:11:59 +0000 Subject: [PATCH 25/29] Fix wrong PYTHONPATH in Ray workers Signed-off-by: Jeffrey Wang --- tests/distributed/test_ray_v2_executor_e2e.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_ray_v2_executor_e2e.py b/tests/distributed/test_ray_v2_executor_e2e.py index a2634ffa5229..155a43ad5d6a 100644 --- a/tests/distributed/test_ray_v2_executor_e2e.py +++ b/tests/distributed/test_ray_v2_executor_e2e.py @@ -173,8 +173,14 @@ def test_env_var_and_runtime_env_propagation(): pg = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK") ray.get(pg.ready()) + # Include the project root so that RayWorkerProc actors can + # unpickle _get_env_var. 
+ project_root = str(pathlib.Path(__file__).resolve().parents[2]) ray_runtime_env = { - "env_vars": {"RAY_RUNTIME_ENV_TEST": "ray_runtime_env"}, + "env_vars": { + "RAY_RUNTIME_ENV_TEST": "ray_runtime_env", + "PYTHONPATH": project_root, + }, } actor = AsyncLLMActor.remote() From e0fd3211e250836d84cb5d506e353147bf3852e1 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Mon, 30 Mar 2026 00:10:14 +0000 Subject: [PATCH 26/29] CR feedback round 1 Signed-off-by: Jeffrey Wang --- vllm/v1/executor/ray_executor_v2.py | 27 +++++++++++++++++++-------- vllm/v1/executor/ray_utils.py | 26 +++++++++++++++++++++++--- 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index 81855c765612..f2aeb965d291 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -67,6 +67,10 @@ class RayWorkerHandle: run_ref: ObjectRef = None """run() ObjectRef used as a sentinel for health monitoring""" + def run(self): + """Start the worker's busy loop""" + self.run_ref = self.actor.run.remote() + class RayWorkerProc(WorkerProc): """Worker process that runs inside a Ray actor. @@ -104,8 +108,7 @@ def get_node_and_gpu_ids(self) -> tuple[str, list[int]]: device_key = current_platform.ray_device_key if not device_key: raise RuntimeError( - "current platform %s does not support ray.", - current_platform.device_name, + f"current platform {current_platform.device_name} does not support ray." 
) gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key] return node_id, [int(x) for x in gpu_ids] @@ -165,6 +168,9 @@ def run(self) -> None: self.worker_response_mq.wait_until_ready() self.worker_busy_loop() + except Exception as e: + logger.exception("RayWorkerProc failed: %s", e) + raise finally: self.shutdown() @@ -180,7 +186,7 @@ class RayExecutorV2(MultiprocExecutor): supports_pp: bool = True def __init__(self, vllm_config: VllmConfig): - super(MultiprocExecutor, self).__init__(vllm_config) + super().__init__(vllm_config) def _build_runtime_env(self) -> dict: """Build a runtime_env dict for RayWorkerProc actors. @@ -221,7 +227,7 @@ def _init_executor(self) -> None: # Step 1: Initialize Ray cluster and retrieve placement group if ray is None: - raise ImportError("Ray is required for RayExecutorV2") + raise ImportError("Using Ray backend requires installation of ray.") initialize_ray_cluster(self.parallel_config, require_gpu_on_driver=False) placement_group = self.parallel_config.placement_group @@ -281,7 +287,7 @@ def _init_executor(self) -> None: instance_id = self.vllm_config.instance_id # Collect env vars to propagate from driver to workers (NCCL, - # HF, vLLM flags, etc.) — same mechanism as RayDistributedExecutor. + # HF, vLLM flags, etc.). 
env_vars_to_copy = get_env_vars_to_copy( exclude_vars=WORKER_SPECIFIC_ENV_VARS, additional_vars=set(current_platform.additional_env_vars), @@ -369,8 +375,9 @@ def _init_executor(self) -> None: ray.get(init_worker_refs) # Step 8: Collect response MQ handles - init_refs = [h.actor.wait_for_init.remote() for h in self.ray_worker_handles] - init_results = ray.get(init_refs) + init_results = ray.get( + [h.actor.wait_for_init.remote() for h in self.ray_worker_handles] + ) self.response_mqs: list[MessageQueue] = [] for i, result in enumerate(init_results): @@ -383,7 +390,7 @@ def _init_executor(self) -> None: # Step 9: Start run() before wait_until_ready() to avoid # deadlock — workers send subscriptions inside run(). for handle in self.ray_worker_handles: - handle.run_ref = handle.actor.run.remote() + handle.run() # Step 10: wait_until_ready() barrier self.rpc_broadcast_mq.wait_until_ready() @@ -419,6 +426,9 @@ def monitor_workers(): try: done, _ = ray.wait(run_refs, num_returns=1, timeout=5.0) except Exception: + logger.exception( + "RayWorkerMonitor: unexpected error, exiting monitor thread" + ) return if not done or _should_stop(): continue @@ -474,6 +484,7 @@ def shutdown(self) -> None: for handle in getattr(self, "ray_worker_handles", []): try: ray.kill(handle.actor) + logger.debug("Killed actor rank=%d", handle.rank) except Exception: logger.exception("Failed to kill actor rank=%d", handle.rank) diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py index 6a997c7f50f0..f0d91c1082b3 100644 --- a/vllm/v1/executor/ray_utils.py +++ b/vllm/v1/executor/ray_utils.py @@ -339,6 +339,26 @@ def get_bundles_sorted_by_node( sorted driver-first. This utility has to be invoked from the driver node. 
+ + Example: 3-node cluster, driver on node-A, PG bundles spread + across nodes: + + Input: [ + (0, node-C), + (1, node-A), + (2, node-B), + (3, node-C), + (4, node-A), + (5, node-B), + ] + Output: [ + (1, node-A), + (4, node-A), + (2, node-B), + (5, node-B), + (0, node-C), + (3, node-C), + ] """ pg_data = placement_group_table(placement_group) bundle_to_node = pg_data["bundles_to_node_id"] @@ -356,10 +376,10 @@ def get_bundles_sorted_by_node( bundle_specs = placement_group.bundle_specs assert bundle_specs is not None bundle_to_node_id: list[tuple[int, str, str]] = [] - for i, bundle in enumerate(bundle_specs): + for bundle_idx, bundle in enumerate(bundle_specs): if bundle.get(ray_device_key): - node_id = bundle_to_node.get(i) - bundle_to_node_id.append((i, node_id, node_id_to_ip[node_id])) + node_id = bundle_to_node.get(bundle_idx) + bundle_to_node_id.append((bundle_idx, node_id, node_id_to_ip[node_id])) driver_node = ray.get_runtime_context().get_node_id() From af21cdd649ef83968e61bb181e4de3f49d4d0266 Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Mon, 30 Mar 2026 21:49:07 +0000 Subject: [PATCH 27/29] CR feedback round 2 Signed-off-by: Jeffrey Wang --- tests/distributed/conftest.py | 29 +++++++++++++++++ tests/distributed/ray_v2_utils.py | 32 ------------------- tests/distributed/test_ray_v2_executor.py | 3 +- tests/distributed/test_ray_v2_executor_e2e.py | 19 +++++++---- vllm/v1/executor/ray_executor_v2.py | 18 +++++++++++ 5 files changed, 61 insertions(+), 40 deletions(-) delete mode 100644 tests/distributed/ray_v2_utils.py diff --git a/tests/distributed/conftest.py b/tests/distributed/conftest.py index 9c146a3323d9..da661c5e13ba 100644 --- a/tests/distributed/conftest.py +++ b/tests/distributed/conftest.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os import random import msgspec @@ -166,3 +167,31 @@ def close(self): self.sub.close() for replay in self.replay_sockets: 
replay.close() + + +@pytest.fixture +def enable_ray_v2_backend(): + """Set env vars for the Ray V2 executor backend and shut down Ray + between tests.""" + import ray + + saved = { + "VLLM_USE_RAY_V2_EXECUTOR_BACKEND": os.environ.get( + "VLLM_USE_RAY_V2_EXECUTOR_BACKEND" + ), + "VLLM_ENABLE_V1_MULTIPROCESSING": os.environ.get( + "VLLM_ENABLE_V1_MULTIPROCESSING" + ), + } + os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" + if ray.is_initialized(): + ray.shutdown() + try: + yield + finally: + if ray.is_initialized(): + ray.shutdown() + os.environ.update({k: v for k, v in saved.items() if v is not None}) + for key in (k for k, v in saved.items() if v is None): + os.environ.pop(key, None) diff --git a/tests/distributed/ray_v2_utils.py b/tests/distributed/ray_v2_utils.py deleted file mode 100644 index a490f7675b07..000000000000 --- a/tests/distributed/ray_v2_utils.py +++ /dev/null @@ -1,32 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Shared test utilities for RayExecutorV2 tests.""" - -import os - -import pytest -import ray - - -@pytest.fixture(autouse=True) -def enable_ray_v2_backend(): - saved = { - "VLLM_USE_RAY_V2_EXECUTOR_BACKEND": os.environ.get( - "VLLM_USE_RAY_V2_EXECUTOR_BACKEND" - ), - "VLLM_ENABLE_V1_MULTIPROCESSING": os.environ.get( - "VLLM_ENABLE_V1_MULTIPROCESSING" - ), - } - os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" - os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" - if ray.is_initialized(): - ray.shutdown() - try: - yield - finally: - if ray.is_initialized(): - ray.shutdown() - os.environ.update({k: v for k, v in saved.items() if v is not None}) - for key in (k for k, v in saved.items() if v is None): - os.environ.pop(key, None) diff --git a/tests/distributed/test_ray_v2_executor.py b/tests/distributed/test_ray_v2_executor.py index eaf80da05261..5daec22df6fc 100644 --- 
a/tests/distributed/test_ray_v2_executor.py +++ b/tests/distributed/test_ray_v2_executor.py @@ -14,12 +14,13 @@ import pytest import ray -from tests.distributed.ray_v2_utils import enable_ray_v2_backend # noqa: F401 from vllm import LLM from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.v1.executor.ray_executor_v2 import RayExecutorV2 +pytestmark = pytest.mark.usefixtures("enable_ray_v2_backend") + MODEL = "facebook/opt-125m" diff --git a/tests/distributed/test_ray_v2_executor_e2e.py b/tests/distributed/test_ray_v2_executor_e2e.py index 155a43ad5d6a..fb5830132698 100644 --- a/tests/distributed/test_ray_v2_executor_e2e.py +++ b/tests/distributed/test_ray_v2_executor_e2e.py @@ -8,9 +8,10 @@ import os import pathlib +import pytest import ray -from tests.distributed.ray_v2_utils import enable_ray_v2_backend # noqa: F401 +pytestmark = pytest.mark.usefixtures("enable_ray_v2_backend") MODEL = "facebook/opt-125m" @@ -32,6 +33,11 @@ def _ray_init(): ) +@pytest.fixture +def ray_init(): + _ray_init() + + class _AsyncLLMActor: def start(self, pg, bundle_indices=None, ray_runtime_env=None): os.environ["VLLM_USE_RAY_V2_EXECUTOR_BACKEND"] = "1" @@ -110,9 +116,7 @@ def shutdown(self): AsyncLLMActor = ray.remote(num_cpus=0, max_concurrency=1)(_AsyncLLMActor) -def test_multi_replicas(): - _ray_init() - +def test_multi_replicas(ray_init): pg1 = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK") pg2 = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK") ray.get([pg1.ready(), pg2.ready()]) @@ -133,9 +137,7 @@ def test_multi_replicas(): assert len(out2) > 0 -def test_multi_replicas_with_bundle_indices(): - _ray_init() - +def test_multi_replicas_with_bundle_indices(ray_init): pg = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 4, strategy="PACK") ray.get(pg.ready()) @@ -168,6 +170,9 @@ def test_env_var_and_runtime_env_propagation(): os.environ[k] = v try: + # Called directly (not via the ray_init fixture) 
because sentinel + # env vars must be in os.environ before ray.init() so that Ray + # worker processes inherit them. _ray_init() pg = ray.util.placement_group([{"GPU": 1, "CPU": 1}] * 2, strategy="PACK") diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index f2aeb965d291..0944c90bf2a6 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -80,6 +80,24 @@ class RayWorkerProc(WorkerProc): 2. initialize_worker: called after GPU IDs are discovered, completes the full WorkerProc initialization with the correct local_rank and CUDA_VISIBLE_DEVICES. + + CUDA_VISIBLE_DEVICES setup flow: + + 1. RayExecutorV2 enables RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES so Ray does + not set CUDA_VISIBLE_DEVICES on RayWorkerProc actors at creation time. + 2. Each actor is scheduled with a placement group and bundle index; Ray resolves + the physical GPU ID for that bundle at placement time. + 3. After placement, the worker discovers that GPU ID and sets + CUDA_VISIBLE_DEVICES before finishing WorkerProc initialization. + + There is no workaround for this unset-and-reset sequence when the placement group + is externally managed: scheduling must complete before CUDA_VISIBLE_DEVICES can + match the GPU tied to the worker's bundle. + + This sequence allows multiple vLLM instances to coexist on the same node: + each instance is unaware which physical devices others hold, and the + externally managed placement group avoids CUDA_VISIBLE_DEVICES conflicts + by binding workers to specific placement group bundles. 
""" def __init__( From ad8f6d080e19335299ee5f0a854c6349cd0e7cec Mon Sep 17 00:00:00 2001 From: Jeffrey Wang Date: Mon, 30 Mar 2026 23:13:35 +0000 Subject: [PATCH 28/29] Only apply blacklist & propagate env with setdefault Signed-off-by: Jeffrey Wang --- tests/test_ray_env_utils.py | 51 +++++++++++++++++++++++++++++ vllm/v1/executor/ray_env_utils.py | 18 ++++++++++ vllm/v1/executor/ray_executor_v2.py | 37 +++++++++++---------- 3 files changed, 88 insertions(+), 18 deletions(-) create mode 100644 tests/test_ray_env_utils.py create mode 100644 vllm/v1/executor/ray_env_utils.py diff --git a/tests/test_ray_env_utils.py b/tests/test_ray_env_utils.py new file mode 100644 index 000000000000..d311de41ba96 --- /dev/null +++ b/tests/test_ray_env_utils.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for vllm.v1.executor.ray_env_utils.""" + +import os +from unittest.mock import patch + +from vllm.v1.executor.ray_env_utils import get_driver_env_vars + +WORKER_VARS: set[str] = { + "CUDA_VISIBLE_DEVICES", + "LOCAL_RANK", +} + + +class TestDefaultPropagation: + """All env vars are propagated unless explicitly excluded.""" + + @patch.dict(os.environ, {"NCCL_DEBUG": "INFO"}, clear=False) + def test_nccl_prefix(self): + assert get_driver_env_vars(WORKER_VARS)["NCCL_DEBUG"] == "INFO" + + @patch.dict(os.environ, {"HF_TOKEN": "secret"}, clear=False) + def test_hf_token(self): + assert "HF_TOKEN" in get_driver_env_vars(WORKER_VARS) + + @patch.dict(os.environ, {"LMCACHE_LOCAL_CPU": "True"}, clear=False) + def test_lmcache_prefix(self): + assert "LMCACHE_LOCAL_CPU" in get_driver_env_vars(WORKER_VARS) + + @patch.dict(os.environ, {"PYTHONHASHSEED": "42"}, clear=False) + def test_pythonhashseed(self): + assert get_driver_env_vars(WORKER_VARS)["PYTHONHASHSEED"] == "42" + + @patch.dict(os.environ, {"MYLIB_FOO": "bar"}, clear=False) + def test_arbitrary_var_propagated(self): + assert 
get_driver_env_vars(WORKER_VARS)["MYLIB_FOO"] == "bar" + + +class TestExclusion: + @patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}, clear=False) + def test_worker_specific_excluded(self): + assert "CUDA_VISIBLE_DEVICES" not in get_driver_env_vars(WORKER_VARS) + + @patch.dict(os.environ, {"LMCACHE_LOCAL_CPU": "True"}, clear=False) + @patch( + "vllm.v1.executor.ray_env_utils.RAY_NON_CARRY_OVER_ENV_VARS", + {"LMCACHE_LOCAL_CPU"}, + ) + def test_non_carry_over_blacklist(self): + assert "LMCACHE_LOCAL_CPU" not in get_driver_env_vars(WORKER_VARS) diff --git a/vllm/v1/executor/ray_env_utils.py b/vllm/v1/executor/ray_env_utils.py new file mode 100644 index 000000000000..6ce12b8ca913 --- /dev/null +++ b/vllm/v1/executor/ray_env_utils.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os + +from vllm.ray.ray_env import RAY_NON_CARRY_OVER_ENV_VARS + + +def get_driver_env_vars( + worker_specific_vars: set[str], +) -> dict[str, str]: + """Return driver env vars to propagate to Ray workers. + + Returns everything from ``os.environ`` except ``worker_specific_vars`` + and user-configured exclusions (``RAY_NON_CARRY_OVER_ENV_VARS``). 
+ """ + exclude_vars = worker_specific_vars | RAY_NON_CARRY_OVER_ENV_VARS + + return {key: value for key, value in os.environ.items() if key not in exclude_vars} diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py index 0944c90bf2a6..ce6085a6de96 100644 --- a/vllm/v1/executor/ray_executor_v2.py +++ b/vllm/v1/executor/ray_executor_v2.py @@ -16,7 +16,6 @@ ) from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.ray.ray_env import get_env_vars_to_copy from vllm.utils.network_utils import ( get_distributed_init_method, get_open_port, @@ -26,6 +25,7 @@ MultiprocExecutor, WorkerProc, ) +from vllm.v1.executor.ray_env_utils import get_driver_env_vars from vllm.v1.executor.ray_utils import ( WORKER_SPECIFIC_ENV_VARS, build_actor_name, @@ -131,12 +131,21 @@ def get_node_and_gpu_ids(self) -> tuple[str, list[int]]: gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key] return node_id, [int(x) for x in gpu_ids] - def initialize_worker(self, local_rank: int, env_vars: dict[str, str]) -> None: + def initialize_worker( + self, + local_rank: int, + env_vars: dict[str, str], + driver_env_vars: dict[str, str] | None = None, + ) -> None: """Complete initialization after GPU assignment is known. - Sets CUDA_VISIBLE_DEVICES and initializes the underlying WorkerProc - with the correct local_rank. + *driver_env_vars* are applied with ``setdefault`` — they fill + in missing vars but never overwrite node-local values. + *env_vars* (e.g. CUDA_VISIBLE_DEVICES) always overwrite. """ + if driver_env_vars: + for key, value in driver_env_vars.items(): + os.environ.setdefault(key, value) for key, value in env_vars.items(): os.environ[key] = value @@ -209,15 +218,13 @@ def __init__(self, vllm_config: VllmConfig): def _build_runtime_env(self) -> dict: """Build a runtime_env dict for RayWorkerProc actors. 
- Merges parallel_config.ray_runtime_env, driver env vars from - get_env_vars_to_copy, noset-device flags, and optional Nsight - profiling config. + Driver env vars are applied separately via initialize_worker + with setdefault semantics. """ base = self.parallel_config.ray_runtime_env runtime_env: dict = copy.deepcopy(dict(base)) if base else {} env_vars = runtime_env.setdefault("env_vars", {}) - env_vars.update(self._worker_env_vars) env_vars.update({v: "1" for v in current_platform.ray_noset_device_env_vars}) if self.parallel_config.ray_workers_use_nsight: runtime_env["nsight"] = { @@ -304,16 +311,10 @@ def _init_executor(self) -> None: self.ray_worker_handles: list[RayWorkerHandle] = [] instance_id = self.vllm_config.instance_id - # Collect env vars to propagate from driver to workers (NCCL, - # HF, vLLM flags, etc.). - env_vars_to_copy = get_env_vars_to_copy( - exclude_vars=WORKER_SPECIFIC_ENV_VARS, - additional_vars=set(current_platform.additional_env_vars), - destination="RayWorkerProc actors", + # Collect driver env vars and apply but don't overwrite node-local values. 
+        self.driver_env_vars = get_driver_env_vars(
+            worker_specific_vars=WORKER_SPECIFIC_ENV_VARS,
+        )
-        self._worker_env_vars = {
-            name: os.environ[name] for name in env_vars_to_copy if name in os.environ
-        }

         for bundle_idx in range(self.world_size):
             bundle = bundle_assignments[bundle_idx]
@@ -387,7 +388,7 @@ def _init_executor(self) -> None:
             self.ray_worker_handles[i].local_rank = local_rank
             init_worker_refs.append(
                 self.ray_worker_handles[i].actor.initialize_worker.remote(
-                    local_rank, worker_env_vars
+                    local_rank, worker_env_vars, self.driver_env_vars
                 )
             )
         ray.get(init_worker_refs)

From 75862045e6ae449e4e221c5c6fb410c0764764e4 Mon Sep 17 00:00:00 2001
From: Jeffrey Wang
Date: Tue, 31 Mar 2026 16:07:04 +0000
Subject: [PATCH 29/29] CR feedback round 3

Signed-off-by: Jeffrey Wang

---
 vllm/v1/executor/ray_executor_v2.py | 26 +++++++++++++++++---------
 vllm/v1/executor/ray_utils.py       |  2 ++
 2 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/vllm/v1/executor/ray_executor_v2.py b/vllm/v1/executor/ray_executor_v2.py
index ce6085a6de96..255bcb5fcd7f 100644
--- a/vllm/v1/executor/ray_executor_v2.py
+++ b/vllm/v1/executor/ray_executor_v2.py
@@ -64,7 +64,7 @@ class RayWorkerHandle:
     bundle_id_idx: int = -1
     """Placement group bundle index for the worker"""

-    run_ref: ObjectRef = None
+    run_ref: ObjectRef | None = None
     """run() ObjectRef used as a sentinel for health monitoring"""

     def run(self):
@@ -207,6 +207,9 @@ class RayExecutorV2(MultiprocExecutor):

     Inherits from MultiprocExecutor to reuse the MQ-based control plane
     and NCCL data plane. Workers are Ray actors.
+
+    Async scheduling is enabled, inherited from MultiprocExecutor.
+    This is critical for RayExecutorV2 to be performant.
""" uses_ray: bool = True @@ -249,6 +252,7 @@ def _init_executor(self) -> None: self.is_failed = False self.failure_callback = None self.shutting_down = False + self.shutdown_lock = threading.Lock() # Step 1: Initialize Ray cluster and retrieve placement group if ray is None: @@ -316,6 +320,9 @@ def _init_executor(self) -> None: worker_specific_vars=WORKER_SPECIFIC_ENV_VARS, ) + runtime_env = self._build_runtime_env() + resource_kwargs = self._get_actor_resource_kwargs() + for bundle_idx in range(self.world_size): bundle = bundle_assignments[bundle_idx] is_driver_worker = self._is_driver_worker(bundle["rank"]) @@ -326,9 +333,6 @@ def _init_executor(self) -> None: placement_group_bundle_index=bundle["bundle_id_idx"], ) - runtime_env = self._build_runtime_env() - resource_kwargs = self._get_actor_resource_kwargs() - actor_name = build_actor_name( instance_id, bundle["rank"], tp_size, pp_size, pcp_size ) @@ -452,7 +456,7 @@ def monitor_workers(): if not done or _should_stop(): continue - dead_ranks = [ref_to_rank[r] for r in done if r in ref_to_rank] + dead_ranks = [ref_to_rank[r] for r in done] executor = self_ref() if not executor: return @@ -492,11 +496,15 @@ def _join_monitor_thread(self) -> None: monitor.join(timeout=10) def shutdown(self) -> None: - """Properly shut down the executor and its workers""" - if getattr(self, "shutting_down", False): - self._join_monitor_thread() + """Properly shut down the executor and its workers.""" + lock = getattr(self, "shutdown_lock", None) + if lock is None: return - self.shutting_down = True + + with lock: + if getattr(self, "shutting_down", False): + return + self.shutting_down = True self._join_monitor_thread() diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py index f0d91c1082b3..7c9a0c1976d7 100644 --- a/vllm/v1/executor/ray_utils.py +++ b/vllm/v1/executor/ray_utils.py @@ -596,6 +596,8 @@ def initialize_ray_cluster( current_ip = get_ip() current_node_id = ray.get_runtime_context().get_node_id() 
     current_node_resource = available_resources_per_node()[current_node_id]
+    # TODO (jeffreywang): require_gpu_on_driver should always be False
+    # after deprecating RayDistributedExecutor.
     if require_gpu_on_driver:
         if current_node_resource.get(device_str, 0) < 1:
             raise ValueError(