From 38a281c960840c738135484e0f152967f3771764 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Tue, 10 Mar 2026 04:59:27 -0400 Subject: [PATCH 1/2] Revert "[Bugfix] Quickfix followups to busy loop removal in #28053 (#36068)" This reverts commit 6b625a8807f4c82137c46d58dfb38f8eeef4865c. --- vllm/v1/executor/multiproc_executor.py | 45 +++++++++++++------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 95336034caf7..e455270cbef3 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -158,13 +158,10 @@ def _init_executor(self) -> None: global_start_rank = ( self.local_world_size * self.parallel_config.node_rank_within_dp ) - # When using fork, keep track of socket file descriptors that are - # inherited by the worker, so that we can close them in subsequent + # Keep track of socket file descriptors that are inherited by the + # worker when using fork, so that we can close them in subsequent # workers - inherited_fds: list[int] | None = ( - [] if context.get_start_method() == "fork" else None - ) - + inherited_fds: list[int] = [] for local_rank in range(self.local_world_size): global_rank = global_start_rank + local_rank is_driver_worker = self._is_driver_worker(global_rank) @@ -179,9 +176,13 @@ def _init_executor(self) -> None: inherited_fds=inherited_fds, ) unready_workers.append(unready_worker_handle) - if inherited_fds is not None: - inherited_fds.append(unready_worker_handle.death_writer.fileno()) - inherited_fds.append(unready_worker_handle.ready_pipe.fileno()) + if context.get_start_method() == "fork": + inherited_fds.extend( + [ + unready_worker_handle.death_writer.fileno(), + unready_worker_handle.ready_pipe.fileno(), + ] + ) # Workers must be created before wait_for_ready to avoid # deadlock, since worker.init_device() does a device sync. @@ -453,13 +454,12 @@ def shutdown(self): w.worker_response_mq.shutdown() w.worker_response_mq = None - if rpc_broadcast_mq := getattr(self, "rpc_broadcast_mq", None): - rpc_broadcast_mq.shutdown() + if self.rpc_broadcast_mq is not None: + self.rpc_broadcast_mq.shutdown() self.rpc_broadcast_mq = None - if response_mqs := getattr(self, "response_mqs", None): - for mq in response_mqs: - mq.shutdown() - self.response_mqs = [] + for mq in self.response_mqs: + mq.shutdown() + self.response_mqs = [] def check_health(self) -> None: self.collective_rpc("check_health", timeout=10) @@ -638,16 +638,13 @@ def make_worker_process( input_shm_handle, # Receive SchedulerOutput shared_worker_lock: LockType, is_driver_worker: bool, - inherited_fds: list[int] | None = None, + inherited_fds: list[int], ) -> UnreadyWorkerProcHandle: context = get_mp_context() # Ready pipe to communicate readiness from child to parent ready_reader, ready_writer = context.Pipe(duplex=False) # Death pipe to let child detect parent process exit death_reader, death_writer = context.Pipe(duplex=False) - if inherited_fds is not None: - inherited_fds = inherited_fds.copy() - inherited_fds.extend((ready_reader.fileno(), death_writer.fileno())) process_kwargs = { "vllm_config": vllm_config, "local_rank": local_rank, @@ -659,7 +656,8 @@ def make_worker_process( "shared_worker_lock": shared_worker_lock, "is_driver_worker": is_driver_worker, # Have the worker close parent end of this worker's pipes too - "inherited_fds": inherited_fds if inherited_fds is not None else [], + "inherited_fds": inherited_fds + + [ready_reader.fileno(), death_writer.fileno()], } # Run EngineCore busy loop in background process. proc = context.Process( @@ -703,8 +701,9 @@ def wait_for_ready( unready_proc_handles: list[UnreadyWorkerProcHandle], ) -> list[WorkerProcHandle]: e = Exception( - "WorkerProc initialization failed due to an exception in a " - "background process. See stack trace for root cause." + "WorkerProc initialization failed due to " + "an exception in a background process. " + "See stack trace for root cause." ) pipes = {handle.ready_pipe: handle for handle in unready_proc_handles} @@ -807,7 +806,7 @@ def signal_handler(signum, frame): try: os.close(fd) except Exception as e: - logger.warning("Error closing inherited connection: %s: %s", type(e), e) + logger.warning("Exception closing inherited connection: %s", e) try: # Initialize tracer From 0a802b16c69e6173d0b380d4788f7d798535622a Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Tue, 10 Mar 2026 04:59:36 -0400 Subject: [PATCH 2/2] Revert "[Core] Remove busy loop from idle buffer readers (#28053)" This reverts commit 6f0dd93801163a6418695c6dc0b43c516261f55a. --- .../test_basic_correctness.py | 2 + tests/distributed/test_shm_broadcast.py | 293 +----------------- .../device_communicators/shm_broadcast.py | 258 ++++----------- vllm/envs.py | 5 + vllm/v1/executor/multiproc_executor.py | 166 ++++------ 5 files changed, 140 insertions(+), 584 deletions(-) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 70c58ad96dd7..68b5cd5101d5 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -124,6 +124,8 @@ def test_models( [ ("facebook/opt-125m", "ray", "", "L4", {}), ("facebook/opt-125m", "mp", "", "L4", {}), + ("facebook/opt-125m", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), + ("facebook/opt-125m", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}), ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}), ("facebook/opt-125m", "ray", "", "A100", {}), diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index 7cf3b01e75c7..a7ace62e1b54 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -1,14 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import multiprocessing import random -import threading import time -from unittest import mock -import multiprocess as mp import numpy as np -import pytest import torch.distributed as dist from vllm.distributed.device_communicators.shm_broadcast import MessageQueue @@ -25,14 +22,7 @@ def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]: return [np.random.randint(1, 100, i) for i in sizes] -def distributed_run(fn, world_size, timeout=60): - """Run a function in multiple processes with proper error handling. - - Args: - fn: Function to run in each process - world_size: Number of processes to spawn - timeout: Maximum time in seconds to wait for processes (default: 60) - """ +def distributed_run(fn, world_size): number_of_processes = world_size processes = [] for i in range(number_of_processes): @@ -43,45 +33,19 @@ def distributed_run(fn, world_size, timeout=60): env["LOCAL_WORLD_SIZE"] = str(number_of_processes) env["MASTER_ADDR"] = "localhost" env["MASTER_PORT"] = "12345" - p = mp.Process(target=fn, args=(env,)) + p = multiprocessing.Process(target=fn, args=(env,)) processes.append(p) p.start() - # Monitor processes and fail fast if any process fails - start_time = time.time() - failed_processes = [] - - # Wait for all processes, checking for failures - while time.time() - start_time < timeout: - all_done = True - for i, p in enumerate(processes): - if p.is_alive(): - all_done = False - elif p.exitcode != 0: - # Process failed - failed_processes.append((i, p.exitcode)) - break - - if failed_processes or all_done: - break - time.sleep(0.1) # Check every 100ms + for p in processes: + p.join() - # Check for timeout if no failures detected yet - for i, p in enumerate(processes): - if p.is_alive(): - p.kill() - p.join() - - # Report failures - if failed_processes: - error_msg = "Distributed test failed:\n" - for rank, status in failed_processes: - error_msg += f" Rank {rank}: Exit code {status}\n" - raise AssertionError(error_msg) + for p in processes: + assert p.exitcode == 0 def worker_fn_wrapper(fn): - # `mp.Process` cannot accept environment variables directly + # `multiprocessing.Process` cannot accept environment variables directly # so we need to pass the environment variables as arguments # and update the environment variables in the function def wrapped_fn(env): @@ -151,244 +115,3 @@ def worker_fn(): def test_shm_broadcast(): distributed_run(worker_fn, 4) - - -@worker_fn_wrapper -def worker_fn_test_shutdown_busy(): - rank = dist.get_rank() - writer_rank = 2 - message_queue = MessageQueue.create_from_process_group( - dist.group.WORLD, 40 * 1024, 2, writer_rank - ) - - if not message_queue._is_writer: - # Put into busy mode - message_queue._spin_condition.busy_loop_s = 9999 - - shutdown_event = threading.Event() - - def shutdown_thread(mq, shutdown_event): - shutdown_event.wait() - mq.shutdown() - - threading.Thread( - target=shutdown_thread, args=(message_queue, shutdown_event) - ).start() - - with pytest.raises(TimeoutError): - message_queue.dequeue(timeout=0.01) - - shutdown_event.set() - - with pytest.raises(RuntimeError, match="cancelled"): - message_queue.dequeue(timeout=1) - - assert message_queue.shutting_down - - print(f"torch distributed passed the test! Rank {rank}") - dist.barrier() - - -def test_message_queue_shutdown_busy(caplog_vllm): - distributed_run(worker_fn_test_shutdown_busy, 4) - print(caplog_vllm.text) - - -@worker_fn_wrapper -def worker_fn_test_shutdown_idle(): - rank = dist.get_rank() - writer_rank = 2 - message_queue = MessageQueue.create_from_process_group( - dist.group.WORLD, 40 * 1024, 2, writer_rank - ) - - if not message_queue._is_writer: - # Put into idle mode - message_queue._spin_condition.last_read = 0 - - shutdown_event = threading.Event() - - def shutdown_thread(mq, shutdown_event): - shutdown_event.wait() - mq.shutdown() - - threading.Thread( - target=shutdown_thread, args=(message_queue, shutdown_event) - ).start() - - with pytest.raises(TimeoutError): - message_queue.dequeue(timeout=0.01) - - shutdown_event.set() - - with pytest.raises(RuntimeError, match="cancelled"): - message_queue.dequeue(timeout=1) - - assert message_queue.shutting_down - - print(f"torch distributed passed the test! Rank {rank}") - dist.barrier() - - -def test_message_queue_shutdown_idle(): - distributed_run(worker_fn_test_shutdown_idle, 4) - - -@worker_fn_wrapper -def worker_fn_test_idle_to_busy(): - rank = dist.get_rank() - writer_rank = 2 - message_queue = MessageQueue.create_from_process_group( - dist.group.WORLD, 40 * 1024, 2, writer_rank - ) - - message1 = "hello world" - message2 = np.random.randint(1, 100, 100) - with mock.patch.object( - message_queue._spin_condition, "wait", wraps=message_queue._spin_condition.wait - ) as wrapped_wait: - if not message_queue._is_writer: - # Put into idle mode - message_queue._spin_condition.last_read = 0 - - # no messages, so expect a TimeoutError - with pytest.raises(TimeoutError): - message_queue.dequeue(timeout=0.01) - # wait should only be called once while idle - assert wrapped_wait.call_count == 1 - - # sync with the writer and wait for message1 - dist.barrier() - recv_message = message_queue.dequeue(timeout=5) - assert recv_message == message1 - # second call to wait, with a message read, this puts in a busy spin - assert wrapped_wait.call_count == 2 - - # sync with the writer and wait for message2 - dist.barrier() - recv_message = message_queue.dequeue(timeout=1) - assert np.array_equal(recv_message, message2) - # in busy mode, we expect wait to have been called multiple times - assert wrapped_wait.call_count > 3 - else: - # writer writes two messages in sync with the reader - dist.barrier() - # sleep delays the send to ensure reader enters the read loop - time.sleep(0.1) - message_queue.enqueue(message1) - - dist.barrier() - time.sleep(0.1) - message_queue.enqueue(message2) - - message_queue.shutdown() - assert message_queue.shutting_down - print(f"torch distributed passed the test! Rank {rank}") - - -def test_message_queue_idle_wake(): - distributed_run(worker_fn_test_idle_to_busy, 4) - - -@worker_fn_wrapper -def worker_fn_test_busy_to_idle(): - rank = dist.get_rank() - writer_rank = 2 - message_queue = MessageQueue.create_from_process_group( - dist.group.WORLD, 40 * 1024, 2, writer_rank - ) - - message1 = 12345 - message2 = list(range(3)) - with mock.patch.object( - message_queue._spin_condition, "wait", wraps=message_queue._spin_condition.wait - ) as wrapped_wait: - if not message_queue._is_writer: - # Put into busy mode - message_queue._spin_condition.busy_loop_s = 9999 - - # sync with the writer and wait for message1 - dist.barrier() - recv_message = message_queue.dequeue(timeout=1) - assert recv_message == message1 - # in busy mode, we expect wait to have been called many times - assert wrapped_wait.call_count > 1 - - # simulate busy loop ending - message_queue._spin_condition.busy_loop_s = 0 - # ensure we enter idle mode, then record call count - with pytest.raises(TimeoutError): - message_queue.dequeue(timeout=0.01) - call_count = wrapped_wait.call_count - - # sync with the writer and wait for message2 - dist.barrier() - recv_message = message_queue.dequeue(timeout=1) - assert recv_message == message2 - - # call to wait after idle should only happen once - assert wrapped_wait.call_count == call_count + 1 - else: - # writer writes two messages in sync with the reader - dist.barrier() - # sleep delays the send to ensure reader enters the read loop - time.sleep(0.1) - message_queue.enqueue(message1) - - dist.barrier() - time.sleep(0.1) - message_queue.enqueue(message2) - - message_queue.shutdown() - assert message_queue.shutting_down - print(f"torch distributed passed the test! Rank {rank}") - - -def test_message_queue_busy_to_idle(): - distributed_run(worker_fn_test_busy_to_idle, 4) - - -def test_warning_logs(caplog_vllm): - """ - Test that warning logs are emitted at VLLM_RINGBUFFER_WARNING_INTERVAL intervals - when indefinite=False, and are not emitted when indefinite=True. - """ - - # Patch the warning log interval to every 1 ms during reads - with mock.patch( - "vllm.distributed.device_communicators.shm_broadcast.VLLM_RINGBUFFER_WARNING_INTERVAL", - new=0.001, # 1 ms - ): - writer = MessageQueue( - n_reader=1, - n_local_reader=1, - max_chunk_bytes=1024 * 1024, # 1MB chunks - max_chunks=10, - ) - reader = MessageQueue.create_from_handle(writer.export_handle(), rank=0) - writer.wait_until_ready() - reader.wait_until_ready() - - # We should have at least one warning log here - # "0 seconds" expected due to rounding of 1ms test interval - with pytest.raises(TimeoutError): - reader.dequeue(timeout=0.01, indefinite=False) - assert any( - "No available shared memory broadcast block found in 0 seconds" - in record.message - for record in caplog_vllm.records - ) - caplog_vllm.clear() - - # We should have no warnings this time - with pytest.raises(TimeoutError): - reader.dequeue(timeout=0.01, indefinite=True) - assert all( - "No available shared memory broadcast block found in 0 seconds" - not in record.message - for record in caplog_vllm.records - ) - - # Clean up when done - writer.shutdown() - reader.shutdown() diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 1c5c4e01d8c8..ac46a5667373 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -2,13 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools import pickle -import sys import threading import time from contextlib import contextmanager from dataclasses import dataclass, field from multiprocessing import shared_memory from pickle import PickleBuffer +from threading import Event from typing import TYPE_CHECKING, Any, cast from unittest.mock import patch @@ -18,7 +18,6 @@ from torch.distributed import ProcessGroup from zmq import ( # type: ignore IPV6, # type: ignore - PUB, SUB, SUBSCRIBE, XPUB, @@ -33,7 +32,6 @@ from vllm.utils.network_utils import ( get_ip, get_open_port, - get_open_zmq_inproc_path, get_open_zmq_ipc_path, is_valid_ipv6_address, ) @@ -80,125 +78,50 @@ def to_bytes_big(value: int, size: int) -> bytes: logger = init_logger(__name__) -LONG_WAIT_TIME_LOG_MSG = ( - "No available shared memory broadcast block found " - "in %d seconds. This typically happens " - "when some processes are hanging or doing some " - "time-consuming work (e.g. compilation, " - "weight/kv cache quantization)." -) +def long_wait_time_msg(threshold: int) -> str: + return ( + "No available shared memory broadcast block found " + f"in {threshold} seconds. This typically happens " + "when some processes are hanging or doing some " + "time-consuming work (e.g. compilation, " + "weight/kv cache quantization)." + ) + + +class SpinTimer: + def record_activity(self): + pass + + def spin(self): + sched_yield() -class SpinCondition: +class SpinSleepTimer(SpinTimer): """ - This class implements an interface similar to a threading.Condition. It - allows a writer to notify readers to wake up and read from the shared memory - buffer. This notification is done over a zmq socket. - - For optimal performance under load we don't want the readers to need to poll - the zmq socket for every read. So the `wait` method here will return - immediately when reads are frequent, and will only enter "idle mode" and - await a notification on the zmq socket after a period of inactivity. This - allows the readers to spin quickly, hence "SpinCondition". - - To support clean shutdown, a separate thread in the reader's process must be - able to wake the reader so that it can exit. A separate cancel() method is - implemented with an in-process socket to allow this interruption. + In setups which have long inactivity periods it is desirable to reduce + system power consumption when vllm does nothing. This would lead to more + CPU thermal headroom when a request eventually comes, especially when + multiple GPUs are connected as each GPU would otherwise pin one thread at + 100% CPU usage. + + The simplest solution is to reduce polling frequency when there is no + activity for a certain period of time. """ - def __init__( - self, - is_reader: bool, - context: zmq.Context, - notify_address: str, - busy_loop_s: float = 1, - ): - self.is_reader = is_reader - - if is_reader: - # Time of last shm buffer read - self.last_read = time.monotonic() - - # Time to keep busy-looping on the shm buffer before going idle - self.busy_loop_s = busy_loop_s - - # Readers subscribe to write notifications - self.local_notify_socket: zmq.Socket = context.socket(SUB) - # Set zmq.CONFLATE to only keep the last message that the socket - # receives. This prevents us from piling up notification messages - # under high load when we aren't polling the socket. - self.local_notify_socket.setsockopt(zmq.CONFLATE, 1) - # Subscribe to all messages on the socket - self.local_notify_socket.setsockopt_string(SUBSCRIBE, "") - self.local_notify_socket.connect(notify_address) - - # Readers require a process-local socket to poll for cancellation - cancel_path = get_open_zmq_inproc_path() - self.write_cancel_socket: zmq.Socket = context.socket(zmq.PAIR) - self.write_cancel_socket.bind(cancel_path) - self.read_cancel_socket: zmq.Socket = context.socket(zmq.PAIR) - self.read_cancel_socket.connect(cancel_path) - - # Poller allows waiting on either `.notify()` or `.cancel()` - self.poller = zmq.Poller() - self.poller.register(self.read_cancel_socket, zmq.POLLIN) - self.poller.register(self.local_notify_socket, zmq.POLLIN) - else: - # Writer side publishes write notifications - self.local_notify_socket: zmq.Socket = context.socket(PUB) # type: ignore - # Set high water mark to 1 - we don't need to send a massive amount of - # pings during busy operation. PUB sockets will silently drop subsequent - # messages after the high water mark is reached. - self.local_notify_socket.setsockopt(zmq.SNDHWM, 1) - self.local_notify_socket.bind(notify_address) - - self.last_read = 0 - self.busy_loop_s = 0 - self.read_cancel_socket = None - self.write_cancel_socket = None - self.poller = None - - def record_read(self): - self.last_read = time.monotonic() - - def cancel(self): - # Sends cancellation ping that will cause the reader to wake up. - # This is done from a monitor thread in the same process as the reader. - if self.is_reader: - logger.debug("Canceling waiting reads on SHM Buffer") - self.write_cancel_socket.send(b"\x00") - - def wait(self, timeout_ms: int | None = None) -> None: - """Wait for data on the shared memory buffer. - - Yields the scheduler then returns immediately if it has been less than - self.busy_loop_s since the last read. - - Otherwise, enters idle mode and awaits a socket ping for at most - `timeout_ms` milliseconds, or indefinitely if timeout_ms is None. - """ - assert self.is_reader, "Only readers can wait" + def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1): + self.last_activity = time.monotonic() + self.busy_loop_s = busy_loop_s + self.wait_sleep_s = wait_sleep_s - current_time = time.monotonic() - if current_time <= self.last_read + self.busy_loop_s: - sched_yield() - else: - events = dict(self.poller.poll(timeout=timeout_ms)) - - if self.read_cancel_socket in events: - logger.debug("Poller received cancel event") - elif self.local_notify_socket in events: - logger.debug("Poller received notify event") - # Since zmq.CONFLATE is set, there will only be one notification - # to read from the socket - self.local_notify_socket.recv(flags=zmq.NOBLOCK, copy=False) - else: - logger.debug("Poller timed out") + def record_activity(self): + self.last_activity = time.monotonic() - def notify(self): - """Notifies all readers to wake up""" - assert not self.is_reader, "Only writers can notify" - self.local_notify_socket.send(b"\x00") + def spin(self): + curr_time = time.monotonic() + if curr_time >= self.last_activity + self.busy_loop_s: + time.sleep(self.wait_sleep_s) + else: + sched_yield() class ShmRingBuffer: @@ -342,7 +265,6 @@ class Handle: buffer_handle: tuple[int, int, int, str] | None = None local_subscribe_addr: str | None = None - local_notify_addr: str | None = None remote_subscribe_addr: str | None = None remote_addr_ipv6: bool = False @@ -366,7 +288,7 @@ def __init__( self.n_local_reader = n_local_reader n_remote_reader = n_reader - n_local_reader self.n_remote_reader = n_remote_reader - self.shutting_down = False + context = Context() if n_local_reader > 0: @@ -388,19 +310,11 @@ def __init__( self.local_socket.bind(local_subscribe_addr) self.current_idx = 0 - - # Create the notification side of the SpinCondition - local_notify_addr = get_open_zmq_ipc_path() - self._spin_condition = SpinCondition( - is_reader=False, context=context, notify_address=local_notify_addr - ) else: self.buffer = None # type: ignore local_subscribe_addr = None self.local_socket = None self.current_idx = -1 - local_notify_addr = None - self._spin_condition = None # type: ignore remote_addr_ipv6 = False if n_remote_reader > 0: @@ -427,12 +341,12 @@ def __init__( self.local_reader_rank = -1 # rank does not matter for remote readers self._is_remote_reader = False + self._read_spin_timer = SpinTimer() self.handle = Handle( local_reader_ranks=local_reader_ranks, buffer_handle=self.buffer.handle() if self.buffer is not None else None, local_subscribe_addr=local_subscribe_addr, - local_notify_addr=local_notify_addr, remote_subscribe_addr=remote_subscribe_addr, remote_addr_ipv6=remote_addr_ipv6, ) @@ -465,9 +379,9 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.local_socket.connect(socket_addr) self.remote_socket = None - assert isinstance(handle.local_notify_addr, str) - self._spin_condition = SpinCondition( - is_reader=True, context=context, notify_address=handle.local_notify_addr + + self._read_spin_timer = ( + SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer() ) else: self.buffer = None # type: ignore @@ -485,9 +399,7 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": socket_addr = handle.remote_subscribe_addr logger.debug("Connecting to %s", socket_addr) self.remote_socket.connect(socket_addr) - self._spin_condition = None # type: ignore - self.shutting_down = False return self def wait_until_ready(self): @@ -523,13 +435,6 @@ def wait_until_ready(self): recv = self.remote_socket.recv() assert recv == b"READY" - def shutdown(self): - """If this is an idle reader, wakes it up so it can clean up and shut - down""" - self.shutting_down = True - if self._spin_condition is not None: - self._spin_condition.cancel() - @contextmanager def acquire_write(self, timeout: float | None = None): assert self._is_writer, "Only writers can acquire write" @@ -560,7 +465,7 @@ def acquire_write(self, timeout: float | None = None): # if we wait for a long time, log a message if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: logger.info( - LONG_WAIT_TIME_LOG_MSG, VLLM_RINGBUFFER_WARNING_INTERVAL + long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL) ) n_warning += 1 @@ -598,60 +503,16 @@ def acquire_write(self, timeout: float | None = None): self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks break - class ReadTimeoutWithWarnings: - def __init__(self, timeout: float | None, should_warn: bool) -> None: - self.started = time.monotonic() - self.deadline = sys.maxsize if timeout is None else self.started + timeout - - # if should_warn, we need to wake up periodically to log - self.warning_wait_time_ms: int | None = ( - VLLM_RINGBUFFER_WARNING_INTERVAL * 1000 if should_warn else None - ) - - self._should_warn = should_warn - self.n_warning = 1 - self.timeout = timeout - - def timeout_ms(self) -> int | None: - """Returns a timeout that is: - - min(time to deadline, time to next warning) if we're logging warnings - - time to deadline, if we're not logging warnings - - None if the timeout is None and we're not logging warnings - - raise TimeoutError if we are past the deadline - """ - warning_wait_time = self.warning_wait_time_ms - if self.timeout is None: - return warning_wait_time - - time_left_ms = int((self.deadline - time.monotonic()) * 1000) - if time_left_ms <= 0: - raise TimeoutError - - if warning_wait_time and warning_wait_time < time_left_ms: - return warning_wait_time - - return time_left_ms - - def should_warn(self) -> bool: - """Returns true if it's time to log a warning for a timeout that is not - indefinite""" - if self._should_warn: - elapsed = time.monotonic() - self.started - if elapsed >= VLLM_RINGBUFFER_WARNING_INTERVAL * self.n_warning: - self.n_warning += 1 - return True - return False - @contextmanager def acquire_read( self, timeout: float | None = None, + cancel: Event | None = None, indefinite: bool = False, ): assert self._is_local_reader, "Only readers can acquire read" - read_timeout = self.ReadTimeoutWithWarnings( - timeout=timeout, should_warn=not indefinite - ) + start_time = time.monotonic() + n_warning = 1 with self.buffer.get_metadata(self.current_idx) as metadata_buffer: while True: # Memory fence ensures we see the latest writes from the writer. @@ -668,16 +529,26 @@ def acquire_read( # for readers, `self.current_idx` is the next block to read # if this block is not ready, # we need to wait until it is written - self._spin_condition.wait(timeout_ms=read_timeout.timeout_ms()) - if self.shutting_down: + # Release the processor to other threads + self._read_spin_timer.spin() + + if cancel is not None and cancel.is_set(): raise RuntimeError("cancelled") + # if we time out, raise an exception + elapsed = time.monotonic() - start_time + if timeout is not None and elapsed > timeout: + raise TimeoutError + # if we wait for a long time, log a message - if read_timeout.should_warn(): + if not indefinite and ( + elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning + ): logger.info( - LONG_WAIT_TIME_LOG_MSG, VLLM_RINGBUFFER_WARNING_INTERVAL + long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL) ) + n_warning += 1 continue # found a block that is not read by this reader @@ -694,7 +565,7 @@ def acquire_read( memory_fence() self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks - self._spin_condition.record_read() + self._read_spin_timer.record_activity() break def enqueue(self, obj, timeout: float | None = None): @@ -737,19 +608,18 @@ def oob_callback(buf: PickleBuffer) -> bool: buf[offset:buf_offset] = to_bytes_big(buf_len, 4) buf[buf_offset : (offset := buf_offset + buf_len)] = buffer - self._spin_condition.notify() - if self.n_remote_reader > 0: self.remote_socket.send_multipart(all_buffers, copy=False) def dequeue( self, timeout: float | None = None, + cancel: Event | None = None, indefinite: bool = False, ): """Read from message queue with optional timeout (in seconds)""" if self._is_local_reader: - with self.acquire_read(timeout, indefinite) as buf: + with self.acquire_read(timeout, cancel, indefinite) as buf: overflow = buf[0] == 1 if not overflow: offset = 3 diff --git a/vllm/envs.py b/vllm/envs.py index 716810da1c27..615ea2416b02 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -179,6 +179,7 @@ VLLM_MOONCAKE_BOOTSTRAP_PORT: int = 8998 VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 + VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300 VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None @@ -1338,6 +1339,9 @@ def _get_or_set_default() -> str: "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": lambda: int( os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1") ), + # Reduce CPU usage when vLLM is idle. Enabling this will incur small + # latency penalty when a request eventually comes. + "VLLM_SLEEP_WHEN_IDLE": lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))), # Control the max chunk bytes (in MB) for the rpc message queue. # Object larger than this threshold will be broadcast to worker # processes via zmq. @@ -1754,6 +1758,7 @@ def compile_factors() -> dict[str, object]: "VLLM_HTTP_TIMEOUT_KEEP_ALIVE", "VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS", "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", + "VLLM_SLEEP_WHEN_IDLE", "VLLM_IMAGE_FETCH_TIMEOUT", "VLLM_VIDEO_FETCH_TIMEOUT", "VLLM_AUDIO_FETCH_TIMEOUT", diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index e455270cbef3..2cb9507b8a23 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -105,6 +105,7 @@ def _init_executor(self) -> None: # and ensure workers will be terminated. self._finalizer = weakref.finalize(self, self.shutdown) self.is_failed = False + self.shutdown_event = threading.Event() self.failure_callback: FailureCallback | None = None tp_size, pp_size, pcp_size = self._get_parallel_sizes() @@ -158,31 +159,20 @@ def _init_executor(self) -> None: global_start_rank = ( self.local_world_size * self.parallel_config.node_rank_within_dp ) - # Keep track of socket file descriptors that are inherited by the - # worker when using fork, so that we can close them in subsequent - # workers - inherited_fds: list[int] = [] for local_rank in range(self.local_world_size): global_rank = global_start_rank + local_rank is_driver_worker = self._is_driver_worker(global_rank) - unready_worker_handle = WorkerProc.make_worker_process( - vllm_config=self.vllm_config, - local_rank=local_rank, - rank=global_rank, - distributed_init_method=distributed_init_method, - input_shm_handle=scheduler_output_handle, - shared_worker_lock=shared_worker_lock, - is_driver_worker=is_driver_worker, - inherited_fds=inherited_fds, - ) - unready_workers.append(unready_worker_handle) - if context.get_start_method() == "fork": - inherited_fds.extend( - [ - unready_worker_handle.death_writer.fileno(), - unready_worker_handle.ready_pipe.fileno(), - ] + unready_workers.append( + WorkerProc.make_worker_process( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=global_rank, + distributed_init_method=distributed_init_method, + input_shm_handle=scheduler_output_handle, + shared_worker_lock=shared_worker_lock, + is_driver_worker=is_driver_worker, ) + ) # Workers must be created before wait_for_ready to avoid # deadlock, since worker.init_device() does a device sync. @@ -231,7 +221,6 @@ def _init_executor(self) -> None: for uw in unready_workers: if uw.death_writer is not None: uw.death_writer.close() - uw.death_writer = None self._ensure_worker_termination([uw.proc for uw in unready_workers]) self.output_rank = self._get_output_rank() @@ -267,7 +256,6 @@ def monitor_workers(): died = multiprocessing.connection.wait(sentinels) _self = self_ref() if not _self or getattr(_self, "shutting_down", False): - logger.debug("MultiprocWorkerMonitor: shutdown already initiated") return _self.is_failed = True proc_name = next(h.proc.name for h in workers if h.proc.sentinel == died[0]) @@ -367,6 +355,8 @@ def collective_rpc( # type: ignore[override] if output_rank is not None: response_mqs = (response_mqs[output_rank],) + shutdown_event = self.shutdown_event + def get_response(): responses = [] for mq in response_mqs: @@ -374,7 +364,9 @@ def get_response(): None if deadline is None else (deadline - time.monotonic()) ) try: - status, result = mq.dequeue(timeout=dequeue_timeout) + status, result = mq.dequeue( + timeout=dequeue_timeout, cancel=shutdown_event + ) except TimeoutError as e: raise TimeoutError(f"RPC call to {method} timed out.") from e if status != WorkerProc.ResponseStatus.SUCCESS: @@ -417,26 +409,20 @@ def wait_for_termination(procs, timeout): active_procs = lambda: [proc for proc in worker_procs if proc.is_alive()] # Give processes time to clean themselves up properly first - logger.debug("Worker Termination: allow workers to gracefully shutdown") if wait_for_termination(active_procs(), 4): return # Send SIGTERM if still running - logger.debug("Worker Termination: workers still running sending SIGTERM") for p in active_procs(): p.terminate() if not wait_for_termination(active_procs(), 4): # Send SIGKILL if still running - logger.debug( - "Worker Termination: resorting to SIGKILL to take down workers" - ) for p in active_procs(): p.kill() def shutdown(self): """Properly shut down the executor and its workers""" if not getattr(self, "shutting_down", False): - logger.debug("Triggering shutdown of workers") self.shutting_down = True # Make sure all the worker processes are terminated first. @@ -446,20 +432,12 @@ def shutdown(self): if w.death_writer is not None: w.death_writer.close() w.death_writer = None + w.worker_response_mq = None self._ensure_worker_termination([w.proc for w in workers]) - for w in workers: - # Shutdown response queues - if w.worker_response_mq is not None: - w.worker_response_mq.shutdown() - w.worker_response_mq = None - - if self.rpc_broadcast_mq is not None: - self.rpc_broadcast_mq.shutdown() - self.rpc_broadcast_mq = None - for mq in self.response_mqs: - mq.shutdown() - self.response_mqs = [] + self.shutdown_event.set() + + self.rpc_broadcast_mq = None def check_health(self) -> None: self.collective_rpc("check_health", timeout=10) @@ -638,26 +616,24 @@ def make_worker_process( input_shm_handle, # Receive SchedulerOutput shared_worker_lock: LockType, is_driver_worker: bool, - inherited_fds: list[int], ) -> UnreadyWorkerProcHandle: context = get_mp_context() - # Ready pipe to communicate readiness from child to parent - ready_reader, ready_writer = context.Pipe(duplex=False) - # Death pipe to let child detect parent process exit + # (reader, writer) + reader, writer = context.Pipe(duplex=False) + + # Create death pipe to detect parent process exit death_reader, death_writer = context.Pipe(duplex=False) + process_kwargs = { "vllm_config": vllm_config, "local_rank": local_rank, "rank": rank, "distributed_init_method": distributed_init_method, "input_shm_handle": input_shm_handle, - "ready_pipe": ready_writer, + "ready_pipe": (reader, writer), "death_pipe": death_reader, "shared_worker_lock": shared_worker_lock, "is_driver_worker": is_driver_worker, - # Have the worker close parent end of this worker's pipes too - "inherited_fds": inherited_fds - + [ready_reader.fileno(), death_writer.fileno()], } # Run EngineCore busy loop in background process. proc = context.Process( @@ -668,12 +644,10 @@ def make_worker_process( ) proc.start() - # Close child ends of pipes here in the parent - ready_writer.close() - death_reader.close() + writer.close() # Keep death_writer open in parent - when parent exits, # death_reader in child will get EOFError - return UnreadyWorkerProcHandle(proc, rank, ready_reader, death_writer) + return UnreadyWorkerProcHandle(proc, rank, reader, death_writer) @staticmethod def wait_for_response_handle_ready( @@ -736,41 +710,12 @@ def wait_for_ready( return cast(list[WorkerProcHandle], ready_proc_handles) def shutdown(self): - if self.rpc_broadcast_mq is not None: - self.rpc_broadcast_mq.shutdown() - if self.worker_response_mq is not None: - self.worker_response_mq.shutdown() self.worker.shutdown() self.rpc_broadcast_mq = None self.worker_response_mq = None destroy_model_parallel() destroy_distributed_environment() - def monitor_death_pipe(self, death_pipe, shutdown_requested: threading.Event): - if death_pipe is None: - return - - def death_pipe_monitor(queues_to_shutdown: list[MessageQueue]): - try: - # This will block until parent process exits (pipe closes) - death_pipe.recv() - except EOFError: - logger.info_once("Parent process exited, terminating worker queues") - shutdown_requested.set() - for mq in queues_to_shutdown: - if mq is not None: - mq.shutdown() - except Exception as e: - logger.warning("Death monitoring error: %s", e) - - # Pass queue references directly to avoid gc issues if passing self - Thread( - target=death_pipe_monitor, - args=([self.rpc_broadcast_mq, self.worker_response_mq],), - daemon=True, - name="DeathPipeMonitor", - ).start() - @staticmethod def worker_main(*args, **kwargs): """Worker initialization and execution loops. @@ -779,12 +724,12 @@ def worker_main(*args, **kwargs): # Signal handler used for graceful termination. # SystemExit exception is only raised once to allow this and worker # processes to terminate without error - shutdown_requested = threading.Event() + shutdown_requested = False def signal_handler(signum, frame): nonlocal shutdown_requested - if not shutdown_requested.is_set(): - shutdown_requested.set() + if not shutdown_requested: + shutdown_requested = True logger.debug( "WorkerProc handling signal %d, raising SystemExit", signum ) @@ -795,20 +740,33 @@ def signal_handler(signum, frame): signal.signal(signal.SIGINT, signal_handler) worker = None - ready_writer = kwargs.pop("ready_pipe") - death_pipe = kwargs.pop("death_pipe", None) - - # Close inherited pipes from parent (incl. other worker pipes) - # Explicitly passing in existing pipes and closing them makes the pipe - # behave when using fork. Otherwise, a hidden reference to the pipes - # exist in the child process and prevents EOF closure. - for fd in kwargs.pop("inherited_fds", []): - try: - os.close(fd) - except Exception as e: - logger.warning("Exception closing inherited connection: %s", e) + # tuple[Connection, Connection] + reader, ready_writer = kwargs.pop("ready_pipe") + death_pipe: Connection | None = kwargs.pop("death_pipe", None) + shutdown_event = threading.Event() + # Start death monitoring thread if death_pipe is provided + if death_pipe is not None: + + def monitor_parent_death(): + try: + # This will block until parent process exits (pipe closes) + death_pipe.recv() + except EOFError: + # Parent process has exited, terminate this worker + logger.info_once("Parent process exited, terminating worker") + # Send signal to self to trigger clean shutdown + shutdown_event.set() + except Exception as e: + logger.warning("Death monitoring error: %s", e) + + death_monitor = Thread( + target=monitor_parent_death, daemon=True, name="WorkerDeathMonitor" + ) + death_monitor.start() try: + reader.close() + # Initialize tracer rank = kwargs.get("rank", 0) maybe_init_worker_tracer( @@ -820,8 +778,6 @@ def signal_handler(signum, frame): worker = WorkerProc(*args, **kwargs) assert worker.worker_response_mq is not None - worker.monitor_death_pipe(death_pipe, shutdown_requested) - # Send READY once we know everything is loaded ready_writer.send( { @@ -839,7 +795,7 @@ def signal_handler(signum, frame): ready_writer.close() ready_writer = None - worker.worker_busy_loop() + worker.worker_busy_loop(cancel=shutdown_event) except Exception: # NOTE: if an Exception arises in busy_loop, we send @@ -849,7 +805,7 @@ def signal_handler(signum, frame): if ready_writer is not None: logger.exception("WorkerProc failed to start.") - elif shutdown_requested.is_set(): + elif shutdown_event.is_set(): logger.info("WorkerProc shutting down.") else: logger.exception("WorkerProc failed.") @@ -857,7 +813,7 @@ def signal_handler(signum, frame): # The parent sends a SIGTERM to all worker processes if # any worker dies. Set this value so we don't re-throw # SystemExit() to avoid zmq exceptions in __del__. - shutdown_requested.set() + shutdown_requested = True except SystemExit as e: # SystemExit is raised on SIGTERM or SIGKILL, which usually indicates that @@ -910,12 +866,12 @@ def async_output_busy_loop(self): output = self.async_output_queue.get() self.enqueue_output(output) - def worker_busy_loop(self): + def worker_busy_loop(self, cancel: threading.Event | None = None): """Main busy loop for Multiprocessing Workers""" assert self.rpc_broadcast_mq is not None while True: method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue( - indefinite=True + cancel=cancel, indefinite=True ) try: if isinstance(method, str):