From 67623aa66034475c0a7f00a5f02d63e9e7187e1a Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 31 Oct 2025 14:24:54 -0600 Subject: [PATCH 01/27] wip Signed-off-by: Joe Runde --- .../device_communicators/shm_broadcast.py | 219 +++++++++++++++--- vllm/envs.py | 4 - vllm/v1/executor/multiproc_executor.py | 65 ++++-- 3 files changed, 224 insertions(+), 64 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index f92b3d34af0f..211b302b0870 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools +import math import pickle import time from contextlib import contextmanager from dataclasses import dataclass, field from multiprocessing import shared_memory from pickle import PickleBuffer -from threading import Event from typing import TYPE_CHECKING, Any from unittest.mock import patch @@ -17,6 +17,7 @@ from torch.distributed import ProcessGroup from zmq import ( # type: ignore IPV6, # type: ignore + PUB, SUB, SUBSCRIBE, XPUB, @@ -30,6 +31,7 @@ from vllm.utils.network_utils import ( get_ip, get_open_port, + get_open_zmq_inproc_path, get_open_zmq_ipc_path, is_valid_ipv6_address, ) @@ -49,40 +51,140 @@ def to_bytes_big(value: int, size: int) -> bytes: logger = init_logger(__name__) -class SpinTimer: - def record_activity(self): - pass +class SpinCondition: + """ + This class implements an interface similar to a threading.Condition. It + allows a writer to notify readers to wake up and read from the shared memory + buffer. This notification is done over a zmq socket. + + For optimal performance under load we don't want the readers to need to poll + the zmq socket for every read. So the `wait` method here will return + immediately when reads are frequent, and will only enter "idle mode" and + await a notification on the zmq socket after a period of inactivity. This + allows the readers to spin quickly, hence "SpinCondition". + + To support clean shutdown, a separate thread in the reader's process must be + able to wake the reader so that it can exit. A separate cancel() method is + implemented with an in-process socket to allow this interruption. + """ - def spin(self): - sched_yield() + def __init__( + self, + is_reader: bool, + local_notify_socket: zmq.Socket, + cancel_socket: zmq.Socket | None = None, + busy_loop_s: float = 1, + ): + # Writers should have PUB socket, readers have SUB socket + self.local_notify_socket = local_notify_socket + self.is_reader = is_reader + if is_reader: + # Time of last shm buffer read + self.last_read = time.monotonic() -class SpinSleepTimer(SpinTimer): - """ - In setups which have long inactivity periods it is desirable to reduce - system power consumption when vllm does nothing. This would lead to more - CPU thermal headroom when a request eventually comes, especially when - multiple GPUs are connected as each GPU would otherwise pin one thread at - 100% CPU usage. - - The simplest solution is to reduce polling frequency when there is no - activity for a certain period of time. - """ + # Time to keep busy-looping on the shm buffer before going idle + self.busy_loop_s = busy_loop_s - def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1): - self.last_activity = time.monotonic() - self.busy_loop_s = busy_loop_s - self.wait_sleep_s = wait_sleep_s + assert cancel_socket is not None, "Readers require a cancel socket" + self.cancel_socket = cancel_socket - def record_activity(self): - self.last_activity = time.monotonic() + self.poller = zmq.Poller() + self.poller.register(self.cancel_socket, zmq.POLLIN) + self.poller.register(self.local_notify_socket, zmq.POLLIN) - def spin(self): - curr_time = time.monotonic() - if curr_time >= self.last_activity + self.busy_loop_s: - time.sleep(self.wait_sleep_s) else: + self.last_read = 0 + self.busy_loop_s = 0 + self.cancel_socket = None + self.poller = None + + @classmethod + def create_notifier( + cls, context: zmq.Context, notify_address: str + ) -> "SpinCondition": + """Builds the writer-process side of the SpinCondition which can notify + all readers of a write""" + local_notify_socket: zmq.Socket = context.socket(PUB) + # Set high water mark to 1- we don't need to send a massive amount of + # pings during busy operation. PUB sockets will silently drop subsequent + # messages after the high water mark is reached. + local_notify_socket.setsockopt(zmq.SNDHWM, 1) + local_notify_socket.bind(notify_address) + + return cls(is_reader=False, local_notify_socket=local_notify_socket) + + @classmethod + def create_waiter( + cls, context: zmq.Context, notify_address: str + ) -> "SpinCondition": + """Builds the reader-process side of the SpinCondition which can wait + for notifications from the writer""" + local_notify_socket: zmq.Socket = context.socket(SUB) + # Set high water mark to 1- we don't need to store a massive amount of + # notifications during busy operation. SUB sockets will silently drop + # inbound messages after the high water mark is reached. + # TODO: maybe instead zmq.ZMQ_CONFLATE + # local_notify_socket.setsockopt(zmq.CONFLATE, 1) + local_notify_socket.setsockopt(zmq.RCVHWM, 1) + + local_notify_socket.bind(notify_address) + + cancel_path = get_open_zmq_inproc_path() + cancel_socket: zmq.Socket = context.socket(zmq.PAIR) + cancel_socket.bind(cancel_path) + + print("\n\n\n\n CREATING WAITER WITH NOTIFY ADDR", notify_address) + + return cls( + is_reader=True, + local_notify_socket=local_notify_socket, + cancel_socket=cancel_socket, + ) + + def record_read(self): + self.last_read = time.monotonic() + + def cancel(self): + # Sends cancellation ping that will cause the reader to wake up. + # This is done from a monitor thread in the same process as the reader. + if self.is_reader: + self.cancel_socket.send(b"\x00") + + def wait(self, timeout_ms: float | None) -> None: + """Wait for data on the shared memory buffer. + + Yields the scheduler then returns immediately if it has been less than + self.busy_loop_s since the last read. + + Otherwise, enters idle mode and awaits a socket ping for at most + `timeout_ms` milliseconds, or indefinitely if timeout_s is None. + """ + assert self.is_reader, "Only readers can wait" + + current_time = time.monotonic() + if current_time <= self.last_read + self.busy_loop_s: sched_yield() + else: + events = dict(self.poller.poll(timeout=timeout_ms)) + + if self.cancel_socket in events: + # return immediately on cancel + return + + if self.local_notify_socket in events: + # Read all pings off the socket + while True: + try: + self.local_notify_socket.recv(flags=zmq.NOBLOCK, copy=False) + except zmq.Again: + # Return when socket has nothing to read + return + + def notify(self): + """Notifies all readers to wake up""" + assert not self.is_reader, "Only writers can notify" + self.local_notify_socket.send(b"\x00") class ShmRingBuffer: @@ -226,6 +328,7 @@ class Handle: buffer_handle: tuple[int, int, int, str] | None = None local_subscribe_addr: str | None = None + local_notify_addr: str | None = None remote_subscribe_addr: str | None = None remote_addr_ipv6: bool = False @@ -249,7 +352,7 @@ def __init__( self.n_local_reader = n_local_reader n_remote_reader = n_reader - n_local_reader self.n_remote_reader = n_remote_reader - + self.shutting_down = False context = Context() if n_local_reader > 0: @@ -271,11 +374,19 @@ def __init__( self.local_socket.bind(local_subscribe_addr) self.current_idx = 0 + + # Create the notification side of the SpinCondition + local_notify_addr = get_open_zmq_ipc_path() + self._spin_condition: SpinCondition = SpinCondition.create_notifier( + context, local_notify_addr + ) else: self.buffer = None # type: ignore local_subscribe_addr = None self.local_socket = None self.current_idx = -1 + local_notify_addr = None + self._spin_condition = None remote_addr_ipv6 = False if n_remote_reader > 0: @@ -302,12 +413,12 @@ def __init__( self.local_reader_rank = -1 # rank does not matter for remote readers self._is_remote_reader = False - self._read_spin_timer = SpinTimer() self.handle = Handle( local_reader_ranks=local_reader_ranks, buffer_handle=self.buffer.handle() if self.buffer is not None else None, local_subscribe_addr=local_subscribe_addr, + local_notify_addr=local_notify_addr, remote_subscribe_addr=remote_subscribe_addr, remote_addr_ipv6=remote_addr_ipv6, ) @@ -341,8 +452,8 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.remote_socket = None - self._read_spin_timer = ( - SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer() + self._spin_condition = SpinCondition.create_waiter( + context, handle.local_notify_addr ) else: self.buffer = None # type: ignore @@ -361,6 +472,7 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": logger.debug("Connecting to %s", socket_addr) self.remote_socket.connect(socket_addr) + self.shutting_down = False return self def wait_until_ready(self): @@ -396,6 +508,13 @@ def wait_until_ready(self): recv = self.remote_socket.recv() assert recv == b"READY" + def shutdown(self): + """ "If this is an idle reader, wakes it up so it can clean up and shut + down""" + self.shutting_down = True + if self._spin_condition is not None: + self._spin_condition.cancel() + @contextmanager def acquire_write(self, timeout: float | None = None): assert self._is_writer, "Only writers can acquire write" @@ -458,14 +577,19 @@ def acquire_write(self, timeout: float | None = None): def acquire_read( self, timeout: float | None = None, - cancel: Event | None = None, indefinite: bool = False, ): assert self._is_local_reader, "Only readers can acquire read" start_time = time.monotonic() + if not indefinite: + deadline = start_time + timeout + else: + deadline = math.inf n_warning = 1 while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + # print("buzzy loop read", self.current_idx, indefinite, timeout) + read_flag = metadata_buffer[self.local_reader_rank + 1] written_flag = metadata_buffer[0] if not written_flag or read_flag: @@ -477,10 +601,20 @@ def acquire_read( # if this block is not ready, # we need to wait until it is written - # Release the processor to other threads - self._read_spin_timer.spin() + if not indefinite: + print("timeout spinnin") + self._spin_condition.wait( + timeout_ms=min( + VLLM_RINGBUFFER_WARNING_INTERVAL, + deadline - time.monotonic(), + ) + * 1000 + ) + else: + print("no-timeout spin") + self._spin_condition.wait() - if cancel is not None and cancel.is_set(): + if self.shutting_down: raise RuntimeError("cancelled") # if we time out, raise an exception @@ -505,6 +639,7 @@ def acquire_read( # found a block that is not read by this reader # let caller read from the buffer with self.buffer.get_data(self.current_idx) as buf: + logger.info("READ!") yield buf # caller has read from the buffer @@ -512,11 +647,14 @@ def acquire_read( metadata_buffer[self.local_reader_rank + 1] = 1 self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks - self._read_spin_timer.record_activity() + self._spin_condition.record_read() break def enqueue(self, obj, timeout: float | None = None): """Write to message queue with optional timeout (in seconds)""" + + logger.info("WRITE!") + assert self._is_writer, "Only writers can enqueue" all_buffers: list[SizedBuffer] = [b""] total_bytes = 6 # 2 bytes for oob buffer count, 4 for main buffer size @@ -536,10 +674,14 @@ def oob_callback(buf: PickleBuffer) -> bool: ) if self.n_local_reader > 0: if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes: + logger.info("Write over ZMQ!") + with self.acquire_write(timeout) as buf: buf[0] = 1 # overflow self.local_socket.send_multipart(all_buffers, copy=False) else: + logger.info("Write over Buf!") + # Byte 0: 0 # Bytes 1-2: Count of buffers # Then each buffer follows, preceded by 4 bytes containing its length: @@ -555,18 +697,19 @@ def oob_callback(buf: PickleBuffer) -> bool: buf[offset:buf_offset] = to_bytes_big(buf_len, 4) buf[buf_offset : (offset := buf_offset + buf_len)] = buffer + self._spin_condition.notify() + if self.n_remote_reader > 0: self.remote_socket.send_multipart(all_buffers, copy=False) def dequeue( self, timeout: float | None = None, - cancel: Event | None = None, indefinite: bool = False, ): """Read from message queue with optional timeout (in seconds)""" if self._is_local_reader: - with self.acquire_read(timeout, cancel, indefinite) as buf: + with self.acquire_read(timeout, indefinite) as buf: overflow = buf[0] == 1 if not overflow: offset = 3 diff --git a/vllm/envs.py b/vllm/envs.py index 73bb2678ea85..2d96a3f37d9a 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -170,7 +170,6 @@ ] = "allgather_reducescatter" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 - VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300 VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None @@ -1207,9 +1206,6 @@ def get_vllm_port() -> int | None: "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": lambda: int( os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1") ), - # Reduce CPU usage when vLLM is idle. Enabling this will incur small - # latency penalty when a request eventually comes. - "VLLM_SLEEP_WHEN_IDLE": lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))), # Control the max chunk bytes (in MB) for the rpc message queue. # Object larger than this threshold will be broadcast to worker # processes via zmq. diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 4c58d5771c39..e257ac5efde8 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -5,7 +5,6 @@ import pickle import queue import signal -import threading import time import traceback import weakref @@ -62,7 +61,6 @@ def _init_executor(self) -> None: # and ensure workers will be terminated. self._finalizer = weakref.finalize(self, self.shutdown) self.is_failed = False - self.shutdown_event = threading.Event() self.failure_callback: FailureCallback | None = None self.io_thread_pool: ThreadPoolExecutor | None = None @@ -258,11 +256,8 @@ def collective_rpc( def get_response( w: WorkerProcHandle, dequeue_timeout: float | None = None, - cancel_event: threading.Event | None = None, ): - status, result = w.worker_response_mq.dequeue( - timeout=dequeue_timeout, cancel=cancel_event - ) + status, result = w.worker_response_mq.dequeue(timeout=dequeue_timeout) if status != WorkerProc.ResponseStatus.SUCCESS: raise RuntimeError( @@ -279,12 +274,12 @@ def get_response( if self.io_thread_pool is not None: # We must consume worker_response_mq from a single thread. result = self.io_thread_pool.submit( # type: ignore - get_response, w, dequeue_timeout, self.shutdown_event + get_response, w, dequeue_timeout ) if not non_block: result = result.result() elif not non_block: - result = get_response(w, dequeue_timeout, self.shutdown_event) + result = get_response(w, dequeue_timeout) else: raise RuntimeError( "non_block can only be used when max_concurrent_batches > 1" @@ -313,14 +308,18 @@ def wait_for_termination(procs, timeout): time.sleep(0.1) return False + active_procs = lambda: [proc for proc in worker_procs if proc.is_alive()] + + # Give processes time to clean themselves up properly + if wait_for_termination(active_procs(), 4): + return + # Send SIGTERM if still running - active_procs = [proc for proc in worker_procs if proc.is_alive()] - for p in active_procs: + for p in active_procs(): p.terminate() - if not wait_for_termination(active_procs, 4): + if not wait_for_termination(active_procs(), 4): # Send SIGKILL if still running - active_procs = [p for p in active_procs if p.is_alive()] - for p in active_procs: + for p in active_procs(): p.kill() def shutdown(self): @@ -328,6 +327,8 @@ def shutdown(self): if not getattr(self, "shutting_down", False): self.shutting_down = True + logger.info("INITIATE SHUTDOWN") + # Make sure all the worker processes are terminated first. if workers := getattr(self, "workers", None): for w in workers: @@ -338,7 +339,9 @@ def shutdown(self): w.worker_response_mq = None self._ensure_worker_termination([w.proc for w in workers]) - self.shutdown_event.set() + logger.info("SOMETHING SOMETHING SHUTDOWN") + # TODO: no message queues to shut down here, right? (broadcast only) + if self.io_thread_pool is not None: self.io_thread_pool.shutdown(wait=False, cancel_futures=True) del self.io_thread_pool @@ -569,6 +572,7 @@ def signal_handler(signum, frame): nonlocal shutdown_requested if not shutdown_requested: shutdown_requested = True + logger.info("RAISING SYSTEM EXIT") raise SystemExit() # Either SIGTERM or SIGINT will terminate the worker @@ -579,7 +583,7 @@ def signal_handler(signum, frame): # tuple[Connection, Connection] reader, ready_writer = kwargs.pop("ready_pipe") death_pipe = kwargs.pop("death_pipe", None) - shutdown_event = threading.Event() + shutdown = False # Start death monitoring thread if death_pipe is provided if death_pipe is not None: @@ -588,10 +592,19 @@ def monitor_parent_death(): # This will block until parent process exits (pipe closes) death_pipe.recv() except EOFError: + # logger.info("sleepin for a bit...") + # time.sleep(1) + # Parent process has exited, terminate this worker logger.info("Parent process exited, terminating worker") - # Send signal to self to trigger clean shutdown - shutdown_event.set() + nonlocal shutdown + shutdown = True + # Shut down message queues + if worker.rpc_broadcast_mq is not None: + worker.rpc_broadcast_mq.shutdown() + if worker.worker_response_mq is not None: + worker.worker_response_mq.shutdown() + except Exception as e: logger.warning("Death monitoring error: %s", e) @@ -619,7 +632,7 @@ def monitor_parent_death(): ready_writer.close() ready_writer = None - worker.worker_busy_loop(cancel=shutdown_event) + worker.worker_busy_loop() except Exception: # NOTE: if an Exception arises in busy_loop, we send @@ -629,7 +642,7 @@ def monitor_parent_death(): if ready_writer is not None: logger.exception("WorkerProc failed to start.") - elif shutdown_event.is_set(): + elif shutdown: logger.info("WorkerProc shutting down.") else: logger.exception("WorkerProc failed.") @@ -638,14 +651,22 @@ def monitor_parent_death(): # any worker dies. Set this value so we don't re-throw # SystemExit() to avoid zmq exceptions in __del__. shutdown_requested = True - + except SystemExit as e: + # If proper shutdown does not succeed, the worker processes are sent + # a SIGTERM and finally a SIGKILL, each of which should raise a + # SystemExit() exception + logger.warning("WorkerProc failed to shut down properly and was terminated") + raise e finally: if ready_writer is not None: + logger.info("CLOSING WRITER") ready_writer.close() if death_pipe is not None: + logger.info("CLOSING DEATH PIPE") death_pipe.close() # Clean up once worker exits busy loop if worker is not None: + logger.info("SHUTTING DOWN WORKER") worker.shutdown() class ResponseStatus(Enum): @@ -683,11 +704,11 @@ def async_output_busy_loop(self): output = self.async_output_queue.get() self.enqueue_output(output) - def worker_busy_loop(self, cancel: threading.Event | None = None): + def worker_busy_loop(self): """Main busy loop for Multiprocessing Workers""" while True: method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue( - cancel=cancel, indefinite=True + indefinite=True ) try: if isinstance(method, str): From 3ce3a482ec672cb64ba9318ea2c8539dac35f7f2 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 4 Nov 2025 09:13:30 -0700 Subject: [PATCH 02/27] :zap: functional SpinCondition Signed-off-by: Joe Runde --- .../device_communicators/shm_broadcast.py | 121 +++++++----------- vllm/v1/executor/multiproc_executor.py | 4 +- 2 files changed, 48 insertions(+), 77 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 211b302b0870..13a0a82ffb65 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -71,12 +71,10 @@ class SpinCondition: def __init__( self, is_reader: bool, - local_notify_socket: zmq.Socket, - cancel_socket: zmq.Socket | None = None, + context: zmq.Context, + notify_address: str, busy_loop_s: float = 1, ): - # Writers should have PUB socket, readers have SUB socket - self.local_notify_socket = local_notify_socket self.is_reader = is_reader if is_reader: @@ -86,62 +84,42 @@ def __init__( # Time to keep busy-looping on the shm buffer before going idle self.busy_loop_s = busy_loop_s - assert cancel_socket is not None, "Readers require a cancel socket" - self.cancel_socket = cancel_socket - + # Readers subscribe to write notifications + self.local_notify_socket: zmq.Socket = context.socket(SUB) + # Set zmq.CONFLATE to only keep the last message that the socket + # receives. This prevents us from piling up notification messages + # under high load when we aren't polling the socket. + self.local_notify_socket.setsockopt(zmq.CONFLATE, 1) + # Subscribe to all messages on the socket + self.local_notify_socket.setsockopt_string(SUBSCRIBE, "") + self.local_notify_socket.connect(notify_address) + + # Readers require a process-local socket to poll for cancellation + cancel_path = get_open_zmq_inproc_path() + self.write_cancel_socket: zmq.Socket = context.socket(zmq.PAIR) + self.write_cancel_socket.bind(cancel_path) + self.read_cancel_socket: zmq.Socket = context.socket(zmq.PAIR) + self.read_cancel_socket.connect(cancel_path) + + # Poller allows waiting on either `.notify()` or `.cancel()` self.poller = zmq.Poller() - self.poller.register(self.cancel_socket, zmq.POLLIN) + self.poller.register(self.read_cancel_socket, zmq.POLLIN) self.poller.register(self.local_notify_socket, zmq.POLLIN) - else: + # Writer side publishes write notifications + self.local_notify_socket: zmq.Socket = context.socket(PUB) # type: ignore + # Set high water mark to 1- we don't need to send a massive amount of + # pings during busy operation. PUB sockets will silently drop subsequent + # messages after the high water mark is reached. + self.local_notify_socket.setsockopt(zmq.SNDHWM, 1) + self.local_notify_socket.bind(notify_address) + self.last_read = 0 self.busy_loop_s = 0 - self.cancel_socket = None + self.read_cancel_socket = None + self.write_cancel_socket = None self.poller = None - @classmethod - def create_notifier( - cls, context: zmq.Context, notify_address: str - ) -> "SpinCondition": - """Builds the writer-process side of the SpinCondition which can notify - all readers of a write""" - local_notify_socket: zmq.Socket = context.socket(PUB) - # Set high water mark to 1- we don't need to send a massive amount of - # pings during busy operation. PUB sockets will silently drop subsequent - # messages after the high water mark is reached. - local_notify_socket.setsockopt(zmq.SNDHWM, 1) - local_notify_socket.bind(notify_address) - - return cls(is_reader=False, local_notify_socket=local_notify_socket) - - @classmethod - def create_waiter( - cls, context: zmq.Context, notify_address: str - ) -> "SpinCondition": - """Builds the reader-process side of the SpinCondition which can wait - for notifications from the writer""" - local_notify_socket: zmq.Socket = context.socket(SUB) - # Set high water mark to 1- we don't need to store a massive amount of - # notifications during busy operation. SUB sockets will silently drop - # inbound messages after the high water mark is reached. - # TODO: maybe instead zmq.ZMQ_CONFLATE - # local_notify_socket.setsockopt(zmq.CONFLATE, 1) - local_notify_socket.setsockopt(zmq.RCVHWM, 1) - - local_notify_socket.bind(notify_address) - - cancel_path = get_open_zmq_inproc_path() - cancel_socket: zmq.Socket = context.socket(zmq.PAIR) - cancel_socket.bind(cancel_path) - - print("\n\n\n\n CREATING WAITER WITH NOTIFY ADDR", notify_address) - - return cls( - is_reader=True, - local_notify_socket=local_notify_socket, - cancel_socket=cancel_socket, - ) - def record_read(self): self.last_read = time.monotonic() @@ -149,9 +127,10 @@ def cancel(self): # Sends cancellation ping that will cause the reader to wake up. # This is done from a monitor thread in the same process as the reader. if self.is_reader: - self.cancel_socket.send(b"\x00") + logger.debug("Canceling waiting reads on SHM Buffer") + self.write_cancel_socket.send(b"\x00") - def wait(self, timeout_ms: float | None) -> None: + def wait(self, timeout_ms: float | None = None) -> None: """Wait for data on the shared memory buffer. Yields the scheduler then returns immediately if it has been less than @@ -168,18 +147,14 @@ def wait(self, timeout_ms: float | None) -> None: else: events = dict(self.poller.poll(timeout=timeout_ms)) - if self.cancel_socket in events: + if self.read_cancel_socket in events: # return immediately on cancel return if self.local_notify_socket in events: - # Read all pings off the socket - while True: - try: - self.local_notify_socket.recv(flags=zmq.NOBLOCK, copy=False) - except zmq.Again: - # Return when socket has nothing to read - return + # Since zmq.CONFLATE is set, there will only be one notification + # to read from the socket + self.local_notify_socket.recv(flags=zmq.NOBLOCK, copy=False) def notify(self): """Notifies all readers to wake up""" @@ -377,8 +352,8 @@ def __init__( # Create the notification side of the SpinCondition local_notify_addr = get_open_zmq_ipc_path() - self._spin_condition: SpinCondition = SpinCondition.create_notifier( - context, local_notify_addr + self._spin_condition = SpinCondition( + is_reader=False, context=context, notify_address=local_notify_addr ) else: self.buffer = None # type: ignore @@ -386,7 +361,7 @@ def __init__( self.local_socket = None self.current_idx = -1 local_notify_addr = None - self._spin_condition = None + self._spin_condition = None # type: ignore remote_addr_ipv6 = False if n_remote_reader > 0: @@ -451,9 +426,9 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.local_socket.connect(socket_addr) self.remote_socket = None - - self._spin_condition = SpinCondition.create_waiter( - context, handle.local_notify_addr + assert isinstance(handle.local_notify_addr, str) + self._spin_condition = SpinCondition( + is_reader=True, context=context, notify_address=handle.local_notify_addr ) else: self.buffer = None # type: ignore @@ -509,7 +484,7 @@ def wait_until_ready(self): assert recv == b"READY" def shutdown(self): - """ "If this is an idle reader, wakes it up so it can clean up and shut + """If this is an idle reader, wakes it up so it can clean up and shut down""" self.shutting_down = True if self._spin_condition is not None: @@ -581,15 +556,13 @@ def acquire_read( ): assert self._is_local_reader, "Only readers can acquire read" start_time = time.monotonic() - if not indefinite: + if not indefinite and timeout is not None: deadline = start_time + timeout else: deadline = math.inf n_warning = 1 while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: - # print("buzzy loop read", self.current_idx, indefinite, timeout) - read_flag = metadata_buffer[self.local_reader_rank + 1] written_flag = metadata_buffer[0] if not written_flag or read_flag: @@ -602,7 +575,6 @@ def acquire_read( # we need to wait until it is written if not indefinite: - print("timeout spinnin") self._spin_condition.wait( timeout_ms=min( VLLM_RINGBUFFER_WARNING_INTERVAL, @@ -611,7 +583,6 @@ def acquire_read( * 1000 ) else: - print("no-timeout spin") self._spin_condition.wait() if self.shutting_down: diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index e257ac5efde8..f6eeeda24496 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -600,9 +600,9 @@ def monitor_parent_death(): nonlocal shutdown shutdown = True # Shut down message queues - if worker.rpc_broadcast_mq is not None: + if worker is not None and worker.rpc_broadcast_mq is not None: worker.rpc_broadcast_mq.shutdown() - if worker.worker_response_mq is not None: + if worker is not None and worker.worker_response_mq is not None: worker.worker_response_mq.shutdown() except Exception as e: From 21c7b042e814810c7e11f932489d6c659d87be93 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 4 Nov 2025 09:24:43 -0700 Subject: [PATCH 03/27] :art: cleanup Signed-off-by: Joe Runde --- vllm/v1/executor/multiproc_executor.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index f6eeeda24496..e6c022807448 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -572,7 +572,7 @@ def signal_handler(signum, frame): nonlocal shutdown_requested if not shutdown_requested: shutdown_requested = True - logger.info("RAISING SYSTEM EXIT") + logger.debug("Raising SystemExit() while handling signal %d", signum) raise SystemExit() # Either SIGTERM or SIGINT will terminate the worker @@ -653,20 +653,18 @@ def monitor_parent_death(): shutdown_requested = True except SystemExit as e: # If proper shutdown does not succeed, the worker processes are sent - # a SIGTERM and finally a SIGKILL, each of which should raise a + # a SIGTERM and finally a SIGKILL, which should raise a # SystemExit() exception - logger.warning("WorkerProc failed to shut down properly and was terminated") + logger.warning("WorkerProc was terminated") + # SystemExit must never be ignored raise e finally: if ready_writer is not None: - logger.info("CLOSING WRITER") ready_writer.close() if death_pipe is not None: - logger.info("CLOSING DEATH PIPE") death_pipe.close() # Clean up once worker exits busy loop if worker is not None: - logger.info("SHUTTING DOWN WORKER") worker.shutdown() class ResponseStatus(Enum): From b296ad58f6dab6865c26d6b9f5c2120d819efaa8 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 4 Nov 2025 09:26:19 -0700 Subject: [PATCH 04/27] :art: fmt Signed-off-by: Joe Runde --- vllm/distributed/device_communicators/shm_broadcast.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 13a0a82ffb65..9473db243dbf 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -610,7 +610,6 @@ def acquire_read( # found a block that is not read by this reader # let caller read from the buffer with self.buffer.get_data(self.current_idx) as buf: - logger.info("READ!") yield buf # caller has read from the buffer @@ -623,9 +622,6 @@ def acquire_read( def enqueue(self, obj, timeout: float | None = None): """Write to message queue with optional timeout (in seconds)""" - - logger.info("WRITE!") - assert self._is_writer, "Only writers can enqueue" all_buffers: list[SizedBuffer] = [b""] total_bytes = 6 # 2 bytes for oob buffer count, 4 for main buffer size @@ -645,14 +641,10 @@ def oob_callback(buf: PickleBuffer) -> bool: ) if self.n_local_reader > 0: if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes: - logger.info("Write over ZMQ!") - with self.acquire_write(timeout) as buf: buf[0] = 1 # overflow self.local_socket.send_multipart(all_buffers, copy=False) else: - logger.info("Write over Buf!") - # Byte 0: 0 # Bytes 1-2: Count of buffers # Then each buffer follows, preceded by 4 bytes containing its length: From 0b8408232f4c1aa26d9563cfb832c132c4d3cba8 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 4 Nov 2025 10:00:58 -0700 Subject: [PATCH 05/27] :art: cleanup Signed-off-by: Joe Runde --- vllm/v1/executor/multiproc_executor.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index e6c022807448..ed9a96a47ae7 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -327,8 +327,6 @@ def shutdown(self): if not getattr(self, "shutting_down", False): self.shutting_down = True - logger.info("INITIATE SHUTDOWN") - # Make sure all the worker processes are terminated first. if workers := getattr(self, "workers", None): for w in workers: @@ -339,9 +337,6 @@ def shutdown(self): w.worker_response_mq = None self._ensure_worker_termination([w.proc for w in workers]) - logger.info("SOMETHING SOMETHING SHUTDOWN") - # TODO: no message queues to shut down here, right? (broadcast only) - if self.io_thread_pool is not None: self.io_thread_pool.shutdown(wait=False, cancel_futures=True) del self.io_thread_pool @@ -592,9 +587,6 @@ def monitor_parent_death(): # This will block until parent process exits (pipe closes) death_pipe.recv() except EOFError: - # logger.info("sleepin for a bit...") - # time.sleep(1) - # Parent process has exited, terminate this worker logger.info("Parent process exited, terminating worker") nonlocal shutdown From 00c4c3f6785b601529004c5aaee29c5949128f2b Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 13 Nov 2025 14:50:24 -0700 Subject: [PATCH 06/27] :poop: WIP unit tests Signed-off-by: Joe Runde --- tests/distributed/test_shm_broadcast.py | 71 +++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 3 deletions(-) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index a7ace62e1b54..76599865a7d8 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import multiprocessing import random +import threading import time +import multiprocess as mp import numpy as np +import pytest import torch.distributed as dist from vllm.distributed.device_communicators.shm_broadcast import MessageQueue @@ -33,7 +35,7 @@ def distributed_run(fn, world_size): env["LOCAL_WORLD_SIZE"] = str(number_of_processes) env["MASTER_ADDR"] = "localhost" env["MASTER_PORT"] = "12345" - p = multiprocessing.Process(target=fn, args=(env,)) + p = mp.Process(target=fn, args=(env,)) processes.append(p) p.start() @@ -45,7 +47,7 @@ def distributed_run(fn, world_size): def worker_fn_wrapper(fn): - # `multiprocessing.Process` cannot accept environment variables directly + # `mp.Process` cannot accept environment variables directly # so we need to pass the environment variables as arguments # and update the environment variables in the function def wrapped_fn(env): @@ -115,3 +117,66 @@ def worker_fn(): def test_shm_broadcast(): distributed_run(worker_fn, 4) + + +@worker_fn_wrapper +def worker_fn_test_cancel(): + rank = dist.get_rank() + if rank == 0: + port = get_open_port() + ip = "127.0.0.1" + dist.broadcast_object_list([ip, port], src=0) + else: + recv = [None, None] + dist.broadcast_object_list(recv, src=0) + ip, port = recv # type: ignore + + stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size()) + + for pg in [dist.group.WORLD, stateless_pg]: + writer_rank = 2 + message_queue = MessageQueue.create_from_process_group( + pg, 40 * 1024, 2, writer_rank + ) + + if pg == dist.group.WORLD: + dist.barrier() + else: + pg.barrier() + + if rank != writer_rank: + # Put into idle mode + message_queue._spin_condition.last_read = 0 + + shutdown_event = threading.Event() + + def shutdown_thread(mq, shutdown_event): + shutdown_event.wait() + mq.shutdown() + + threading.Thread( + target=shutdown_thread, args=(message_queue, shutdown_event) + ).start() + + with pytest.raises(TimeoutError): + message_queue.dequeue(timeout=0.001) + + shutdown_event.set() + + message_queue.dequeue(timeout=1) + + assert message_queue.shutting_down + else: + # Write nothing + message_queue.shutdown() + + if pg == dist.group.WORLD: + print(f"torch distributed passed the test! Rank {rank}") + dist.barrier() + else: + print(f"StatelessProcessGroup passed the test! Rank {rank}") + pg.barrier() + + +def test_message_queue_shutdown(): + distributed_run(worker_fn_test_cancel, 4) From f55c68ecdfb5e07ca1e9a50052a8d1ccabd5c443 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Mon, 17 Nov 2025 13:20:07 -0700 Subject: [PATCH 07/27] test: flesh out shm_broadcast tests Signed-off-by: Travis Johnson --- tests/distributed/test_shm_broadcast.py | 174 ++++++++++++++++++------ 1 file changed, 133 insertions(+), 41 deletions(-) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index 76599865a7d8..fd0ae0815ba0 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -4,6 +4,7 @@ import random import threading import time +from unittest import mock import multiprocess as mp import numpy as np @@ -120,63 +121,154 @@ def test_shm_broadcast(): @worker_fn_wrapper -def worker_fn_test_cancel(): +def worker_fn_test_shutdown(): rank = dist.get_rank() - if rank == 0: - port = get_open_port() - ip = "127.0.0.1" - dist.broadcast_object_list([ip, port], src=0) - else: - recv = [None, None] - dist.broadcast_object_list(recv, src=0) - ip, port = recv # type: ignore + writer_rank = 2 + message_queue = MessageQueue.create_from_process_group( + dist.group.WORLD, 40 * 1024, 2, writer_rank + ) - stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size()) + if not message_queue._is_writer: + # Put into idle mode + message_queue._spin_condition.last_read = 0 - for pg in [dist.group.WORLD, stateless_pg]: - writer_rank = 2 - message_queue = MessageQueue.create_from_process_group( - pg, 40 * 1024, 2, writer_rank - ) + shutdown_event = threading.Event() + + def shutdown_thread(mq, shutdown_event): + shutdown_event.wait() + mq.shutdown() + + threading.Thread( + target=shutdown_thread, args=(message_queue, shutdown_event) + ).start() + + with pytest.raises(TimeoutError): + message_queue.dequeue(timeout=0.01) + + shutdown_event.set() + + with pytest.raises(RuntimeError, match="cancelled"): + message_queue.dequeue(timeout=1) + + assert message_queue.shutting_down + + print(f"torch distributed passed the test! Rank {rank}") + dist.barrier() + + +def test_message_queue_shutdown(): + distributed_run(worker_fn_test_shutdown, 4) - if pg == dist.group.WORLD: - dist.barrier() - else: - pg.barrier() - if rank != writer_rank: +@worker_fn_wrapper +def worker_fn_test_idle_to_busy(): + rank = dist.get_rank() + writer_rank = 2 + message_queue = MessageQueue.create_from_process_group( + dist.group.WORLD, 40 * 1024, 2, writer_rank + ) + + message1 = "hello world" + message2 = np.random.randint(1, 100, 100) + with mock.patch.object( + message_queue._spin_condition, "wait", wraps=message_queue._spin_condition.wait + ) as wrapped_wait: + if not message_queue._is_writer: # Put into idle mode message_queue._spin_condition.last_read = 0 - shutdown_event = threading.Event() + # no messages, so expect a TimeoutError + with pytest.raises(TimeoutError): + message_queue.dequeue(timeout=0.01) + # wait should only be called once while idle + assert wrapped_wait.call_count == 1 - def shutdown_thread(mq, shutdown_event): - shutdown_event.wait() - mq.shutdown() + # sync with the writer and wait for message1 + dist.barrier() + recv_message = message_queue.dequeue(timeout=5) + assert recv_message == message1 + # second call to wait, with a message read, this puts in a busy spin + assert wrapped_wait.call_count == 2 + + # sync with the writer and wait for message2 + dist.barrier() + recv_message = message_queue.dequeue(timeout=1) + assert np.array_equal(recv_message, message2) + # in busy mode, we expect wait to have been called multiple times + assert wrapped_wait.call_count > 3 + else: + # writer writes two messages in sync with the reader + dist.barrier() + # sleep delays the send to ensure reader enters the read loop + time.sleep(0.1) + message_queue.enqueue(message1) - threading.Thread( - target=shutdown_thread, args=(message_queue, shutdown_event) - ).start() + dist.barrier() + time.sleep(0.1) + message_queue.enqueue(message2) - with pytest.raises(TimeoutError): - message_queue.dequeue(timeout=0.001) + message_queue.shutdown() + assert message_queue.shutting_down + print(f"torch distributed passed the test! Rank {rank}") - shutdown_event.set() - message_queue.dequeue(timeout=1) +def test_message_queue_idle_wake(): + distributed_run(worker_fn_test_idle_to_busy, 4) - assert message_queue.shutting_down - else: - # Write nothing - message_queue.shutdown() - if pg == dist.group.WORLD: - print(f"torch distributed passed the test! Rank {rank}") +@worker_fn_wrapper +def worker_fn_test_busy_to_idle(): + rank = dist.get_rank() + writer_rank = 2 + message_queue = MessageQueue.create_from_process_group( + dist.group.WORLD, 40 * 1024, 2, writer_rank + ) + + message1 = 12345 + message2 = list(range(3)) + with mock.patch.object( + message_queue._spin_condition, "wait", wraps=message_queue._spin_condition.wait + ) as wrapped_wait: + if not message_queue._is_writer: + # Put into busy mode + message_queue._spin_condition.busy_loop_s = 9999 + + # sync with the writer and wait for message1 dist.barrier() + recv_message = message_queue.dequeue(timeout=1) + assert recv_message == message1 + # in busy mode, we expect wait to have been called many times + assert wrapped_wait.call_count > 1 + + # simulate busy loop ending + message_queue._spin_condition.busy_loop_s = 0 + # ensure we enter idle mode, then record call count + with pytest.raises(TimeoutError): + message_queue.dequeue(timeout=0.01) + call_count = wrapped_wait.call_count + + # sync with the writer and wait for message2 + dist.barrier() + recv_message = message_queue.dequeue(timeout=1) + assert recv_message == message2 + + # call to wait after idle should only happen once + assert wrapped_wait.call_count == call_count + 1 else: - print(f"StatelessProcessGroup passed the test! Rank {rank}") - pg.barrier() + # writer writes two messages in sync with the reader + dist.barrier() + # sleep delays the send to ensure reader enters the read loop + time.sleep(0.1) + message_queue.enqueue(message1) + + dist.barrier() + time.sleep(0.1) + message_queue.enqueue(message2) + message_queue.shutdown() + assert message_queue.shutting_down + print(f"torch distributed passed the test! Rank {rank}") -def test_message_queue_shutdown(): - distributed_run(worker_fn_test_cancel, 4) + +def test_message_queue_busy_to_idle(): + distributed_run(worker_fn_test_busy_to_idle, 4) From 0bd12b3573b25393dbb66d74564b2fc8e33ba2d8 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Mon, 17 Nov 2025 16:07:52 -0700 Subject: [PATCH 08/27] fix timeout handling and little refactor Signed-off-by: Travis Johnson --- .../device_communicators/shm_broadcast.py | 42 +++++++++++-------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index b0e0bf7ae124..9e592ae6d96f 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -3,6 +3,7 @@ import functools import math import pickle +import sys import time from contextlib import contextmanager from dataclasses import dataclass, field @@ -140,14 +141,14 @@ def cancel(self): logger.debug("Canceling waiting reads on SHM Buffer") self.write_cancel_socket.send(b"\x00") - def wait(self, timeout_ms: float | None = None) -> None: + def wait(self, timeout_ms: int | None = None) -> None: """Wait for data on the shared memory buffer. Yields the scheduler then returns immediately if it has been less than self.busy_loop_s since the last read. Otherwise, enters idle mode and awaits a socket ping for at most - `timeout_ms` milliseconds, or indefinitely if timeout_s is None. + `timeout_ms` milliseconds, or indefinitely if timeout_ms is None. """ assert self.is_reader, "Only readers can wait" @@ -562,10 +563,20 @@ def acquire_read( ): assert self._is_local_reader, "Only readers can acquire read" start_time = time.monotonic() - if not indefinite and timeout is not None: + if timeout is not None: deadline = start_time + timeout + wait_timeout_ms = ( + VLLM_RINGBUFFER_WARNING_INTERVAL * 1000 + if not indefinite + else sys.maxsize + ) else: deadline = math.inf + # wait_timeout_ms is a constant if timeout is None + wait_timeout_ms = ( + VLLM_RINGBUFFER_WARNING_INTERVAL * 1000 if not indefinite else None + ) + n_warning = 1 while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: @@ -579,27 +590,22 @@ def acquire_read( # for readers, `self.current_idx` is the next block to read # if this block is not ready, # we need to wait until it is written + if timeout is not None: + time_left_ms = int((deadline - time.monotonic()) * 1000) + # if we time out, raise an exception + if time_left_ms <= 0: + raise TimeoutError - if not indefinite: - self._spin_condition.wait( - timeout_ms=min( - VLLM_RINGBUFFER_WARNING_INTERVAL, - deadline - time.monotonic(), - ) - * 1000 - ) - else: - self._spin_condition.wait() + wait_timeout_ms = min(wait_timeout_ms, time_left_ms) + # else: use constant wait_timeout_ms defined outside of loop + + self._spin_condition.wait(timeout_ms=wait_timeout_ms) if self.shutting_down: raise RuntimeError("cancelled") - # if we time out, raise an exception - elapsed = time.monotonic() - start_time - if timeout is not None and elapsed > timeout: - raise TimeoutError - # if we wait for a long time, log a message + elapsed = time.monotonic() - start_time if not indefinite and ( elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning ): From 73c3398b3d5c360ef3d1ffc95996299330b4086c Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Mon, 1 Dec 2025 09:59:13 -0700 Subject: [PATCH 09/27] test: add busy shutdown test Signed-off-by: Travis Johnson --- tests/distributed/test_shm_broadcast.py | 46 +++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index fd0ae0815ba0..b607891e3eec 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -121,7 +121,47 @@ def test_shm_broadcast(): @worker_fn_wrapper -def worker_fn_test_shutdown(): +def worker_fn_test_shutdown_busy(): + rank = dist.get_rank() + writer_rank = 2 + message_queue = MessageQueue.create_from_process_group( + dist.group.WORLD, 40 * 1024, 2, writer_rank + ) + + if not message_queue._is_writer: + # Put into busy mode + message_queue._spin_condition.busy_loop_s = 9999 + + shutdown_event = threading.Event() + + def shutdown_thread(mq, shutdown_event): + shutdown_event.wait() + mq.shutdown() + + threading.Thread( + target=shutdown_thread, args=(message_queue, shutdown_event) + ).start() + + with pytest.raises(TimeoutError): + message_queue.dequeue(timeout=0.01) + + shutdown_event.set() + + with pytest.raises(RuntimeError, match="cancelled"): + message_queue.dequeue(timeout=1) + + assert message_queue.shutting_down + + print(f"torch distributed passed the test! Rank {rank}") + dist.barrier() + + +def test_message_queue_shutdown_busy(): + distributed_run(worker_fn_test_shutdown_busy, 4) + + +@worker_fn_wrapper +def worker_fn_test_shutdown_idle(): rank = dist.get_rank() writer_rank = 2 message_queue = MessageQueue.create_from_process_group( @@ -156,8 +196,8 @@ def shutdown_thread(mq, shutdown_event): dist.barrier() -def test_message_queue_shutdown(): - distributed_run(worker_fn_test_shutdown, 4) +def test_message_queue_shutdown_idle(): + distributed_run(worker_fn_test_shutdown_idle, 4) @worker_fn_wrapper From 7749b21ebe7c8dd04a87177d4b1634d1afa30488 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 21 Jan 2026 14:07:38 -0700 Subject: [PATCH 10/27] :bug: fix uninitialized spin condition Signed-off-by: Joe Runde --- vllm/distributed/device_communicators/shm_broadcast.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 39b4ac8fa32f..dd9ea7e3dd1d 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -486,6 +486,7 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": socket_addr = handle.remote_subscribe_addr logger.debug("Connecting to %s", socket_addr) self.remote_socket.connect(socket_addr) + self._spin_condition = None # type: ignore self.shutting_down = False return self From 9ff2361c756068f0877ea9493226db1841e75e3f Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Fri, 23 Jan 2026 16:08:13 -0700 Subject: [PATCH 11/27] :rewind: revert changes from #32965 Signed-off-by: Joe Runde --- vllm/v1/executor/multiproc_executor.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index e0df520d08f3..2680c59a64a7 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -385,18 +385,14 @@ def wait_for_termination(procs, timeout): time.sleep(0.1) return False - active_procs = lambda: [proc for proc in worker_procs if proc.is_alive()] - - # Give processes time to clean themselves up properly - if wait_for_termination(active_procs(), 4): - return - # Send SIGTERM if still running - for p in active_procs(): + active_procs = [proc for proc in worker_procs if proc.is_alive()] + for p in active_procs: p.terminate() - if not wait_for_termination(active_procs(), 4): + if not wait_for_termination(active_procs, 4): # Send SIGKILL if still running - for p in active_procs(): + active_procs = [p for p in active_procs if p.is_alive()] + for p in active_procs: p.kill() def shutdown(self): @@ -697,7 +693,6 @@ def signal_handler(signum, frame): nonlocal shutdown_requested if not shutdown_requested: shutdown_requested = True - logger.debug("Raising SystemExit() while handling signal %d", signum) raise SystemExit() # Either SIGTERM or SIGINT will terminate the worker @@ -776,13 +771,7 @@ def monitor_parent_death(): # any worker dies. Set this value so we don't re-throw # SystemExit() to avoid zmq exceptions in __del__. shutdown_requested = True - except SystemExit as e: - # If proper shutdown does not succeed, the worker processes are sent - # a SIGTERM and finally a SIGKILL, which should raise a - # SystemExit() exception - logger.warning("WorkerProc was terminated") - # SystemExit must never be ignored - raise e + finally: if ready_writer is not None: ready_writer.close() From 0f0ccf6320329464606a45ac9e9280b3d9633250 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 18 Feb 2026 15:05:08 -0700 Subject: [PATCH 12/27] :recycle: refactor timeout stuff Signed-off-by: Joe Runde --- tests/distributed/test_shm_broadcast.py | 38 ++++++++- .../device_communicators/shm_broadcast.py | 78 ++++++++++++------- 2 files changed, 85 insertions(+), 31 deletions(-) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index b607891e3eec..25b575c22b53 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -156,8 +156,9 @@ def shutdown_thread(mq, shutdown_event): dist.barrier() -def test_message_queue_shutdown_busy(): +def test_message_queue_shutdown_busy(caplog_vllm): distributed_run(worker_fn_test_shutdown_busy, 4) + print(caplog_vllm.text) @worker_fn_wrapper @@ -312,3 +313,38 @@ def worker_fn_test_busy_to_idle(): def test_message_queue_busy_to_idle(): distributed_run(worker_fn_test_busy_to_idle, 4) + + +def test_warning_logs(caplog_vllm): + """ + Test that warning logs are emitted at VLLM_RINGBUFFER_WARNING_INTERVAL intervals + when indefinite=False + """ + + with mock.patch( + "vllm.distributed.device_communicators.shm_broadcast.VLLM_RINGBUFFER_WARNING_INTERVAL", + new=0.001, # 1 ms + ): + writer = MessageQueue( + n_reader=1, + n_local_reader=1, + max_chunk_bytes=1024 * 1024, # 1MB chunks + max_chunks=10, + ) + reader = MessageQueue.create_from_handle(writer.export_handle(), rank=0) + writer.wait_until_ready() + reader.wait_until_ready() + + # Reader times out + with pytest.raises(TimeoutError): + reader.dequeue(timeout=0.01, indefinite=False) + + # Clean up when done + writer.shutdown() + reader.shutdown() + + assert any( + "No available shared memory broadcast block found in 0.001 seconds" + in record.message + for record in caplog_vllm.records + ) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index dd9ea7e3dd1d..c6d2ed34a7b4 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -import math import pickle import sys import threading @@ -593,6 +592,49 @@ def acquire_write(self, timeout: float | None = None): self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks break + class ReadTimeout: + def __init__(self, timeout: float | None, should_warn: bool) -> None: + self.started = time.monotonic() + if timeout is not None: + self.deadline = self.started + timeout + else: + self.deadline = sys.maxsize + + if should_warn: + self.warn_timeout_ms = VLLM_RINGBUFFER_WARNING_INTERVAL * 1000 + else: + self.warn_timeout_ms = sys.maxsize + + self._should_warn = should_warn + self.n_warning = 1 + self.timeout = timeout + + def timeout_ms(self) -> int: + """Returns a timeout that is: + - min(time to deadline, time to next warning) if we're logging warnings + - time to deadline, if we're not logging warnings + - sys.maxsize if the timeout is None and we're not logging warnings + """ + if self.timeout is None: + time_left_ms = sys.maxsize + else: + time_left_ms = int((self.deadline - time.monotonic()) * 1000) + return min(self.warn_timeout_ms, time_left_ms) + + def expired(self) -> bool: + """Returns True if the timeout has expired.""" + return time.monotonic() >= self.deadline + + def should_warn(self) -> bool: + """Returns true if it's time to log a warning for a timeout that is not + indefinite""" + if self._should_warn: + elapsed = time.monotonic() - self.started + if elapsed >= VLLM_RINGBUFFER_WARNING_INTERVAL * self.n_warning: + self.n_warning += 1 + return True + return False + @contextmanager def acquire_read( self, @@ -600,22 +642,8 @@ def acquire_read( indefinite: bool = False, ): assert self._is_local_reader, "Only readers can acquire read" - start_time = time.monotonic() - if timeout is not None: - deadline = start_time + timeout - wait_timeout_ms: int | None = ( - VLLM_RINGBUFFER_WARNING_INTERVAL * 1000 - if not indefinite - else sys.maxsize - ) - else: - deadline = math.inf - # wait_timeout_ms is a constant if timeout is None - wait_timeout_ms = ( - VLLM_RINGBUFFER_WARNING_INTERVAL * 1000 if not indefinite else None - ) + read_timeout = self.ReadTimeout(timeout=timeout, should_warn=not indefinite) - n_warning = 1 while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: # Memory fence ensures we see the latest writes from the writer. @@ -632,28 +660,18 @@ def acquire_read( # for readers, `self.current_idx` is the next block to read # if this block is not ready, # we need to wait until it is written - if timeout is not None: - time_left_ms = int((deadline - time.monotonic()) * 1000) - # if we time out, raise an exception - if time_left_ms <= 0: - raise TimeoutError - wait_timeout_ms = min(cast(int, wait_timeout_ms), time_left_ms) - # else: use constant wait_timeout_ms defined outside of loop - - self._spin_condition.wait(timeout_ms=wait_timeout_ms) + if read_timeout.expired(): + raise TimeoutError + self._spin_condition.wait(timeout_ms=read_timeout.timeout_ms()) if self.shutting_down: raise RuntimeError("cancelled") # if we wait for a long time, log a message - elapsed = time.monotonic() - start_time - if not indefinite and ( - elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning - ): + if read_timeout.should_warn(): logger.info( long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL) ) - n_warning += 1 continue # found a block that is not read by this reader From dbac53aa998481bc09af2aa358f5cdde3d60ec06 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Wed, 18 Feb 2026 15:29:27 -0700 Subject: [PATCH 13/27] refactor: make ReadTimeout class clearer Signed-off-by: Travis Johnson --- .../device_communicators/shm_broadcast.py | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index c6d2ed34a7b4..d21d40939611 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -600,30 +600,35 @@ def __init__(self, timeout: float | None, should_warn: bool) -> None: else: self.deadline = sys.maxsize - if should_warn: - self.warn_timeout_ms = VLLM_RINGBUFFER_WARNING_INTERVAL * 1000 - else: - self.warn_timeout_ms = sys.maxsize + # if should_warn, we need to wake up periodically to log + self.warning_wait_timeout_ms: int | None = ( + VLLM_RINGBUFFER_WARNING_INTERVAL * 1000 if should_warn else None + ) self._should_warn = should_warn self.n_warning = 1 self.timeout = timeout - def timeout_ms(self) -> int: + def timeout_ms(self) -> int | None: """Returns a timeout that is: - min(time to deadline, time to next warning) if we're logging warnings - time to deadline, if we're not logging warnings - - sys.maxsize if the timeout is None and we're not logging warnings + - None if the timeout is None and we're not logging warnings + - raise TimeoutError if we are past the deadline """ if self.timeout is None: - time_left_ms = sys.maxsize + return self.warning_wait_timeout_ms else: - time_left_ms = int((self.deadline - time.monotonic()) * 1000) - return min(self.warn_timeout_ms, time_left_ms) - - def expired(self) -> bool: - """Returns True if the timeout has expired.""" - return time.monotonic() >= self.deadline + time_left = self.deadline - time.monotonic() + if time_left <= 0: + raise TimeoutError + time_left_ms = int(time_left * 1000) + + return ( + time_left_ms + if self.warning_wait_timeout_ms is None + else min(self.warning_wait_timeout_ms, time_left_ms) + ) def should_warn(self) -> bool: """Returns true if it's time to log a warning for a timeout that is not @@ -660,8 +665,6 @@ def acquire_read( # for readers, `self.current_idx` is the next block to read # if this block is not ready, # we need to wait until it is written - if read_timeout.expired(): - raise TimeoutError self._spin_condition.wait(timeout_ms=read_timeout.timeout_ms()) if self.shutting_down: From de3f4a659259c0e8f2e4b0d447a9d7a693f5a16d Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Wed, 18 Feb 2026 15:31:20 -0700 Subject: [PATCH 14/27] rename: ReadTimeoutWithWarnings Signed-off-by: Travis Johnson --- vllm/distributed/device_communicators/shm_broadcast.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index d21d40939611..f9ba9db1990d 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -592,7 +592,7 @@ def acquire_write(self, timeout: float | None = None): self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks break - class ReadTimeout: + class ReadTimeoutWithWarnings: def __init__(self, timeout: float | None, should_warn: bool) -> None: self.started = time.monotonic() if timeout is not None: @@ -647,7 +647,9 @@ def acquire_read( indefinite: bool = False, ): assert self._is_local_reader, "Only readers can acquire read" - read_timeout = self.ReadTimeout(timeout=timeout, should_warn=not indefinite) + read_timeout = self.ReadTimeoutWithWarnings( + timeout=timeout, should_warn=not indefinite + ) while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: From b8b56f07f3e257afa14979559b294501b351dfdb Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 18 Feb 2026 16:13:19 -0700 Subject: [PATCH 15/27] :test_tube: add negative test for warning logs Signed-off-by: Joe Runde --- tests/distributed/test_shm_broadcast.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index 25b575c22b53..d9bbb0bcba2f 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -318,9 +318,10 @@ def test_message_queue_busy_to_idle(): def test_warning_logs(caplog_vllm): """ Test that warning logs are emitted at VLLM_RINGBUFFER_WARNING_INTERVAL intervals - when indefinite=False + when indefinite=False, and are not emitted when indefinite=True. """ + # Patch the warning log interval to every 1 ms during reads with mock.patch( "vllm.distributed.device_communicators.shm_broadcast.VLLM_RINGBUFFER_WARNING_INTERVAL", new=0.001, # 1 ms @@ -335,16 +336,25 @@ def test_warning_logs(caplog_vllm): writer.wait_until_ready() reader.wait_until_ready() - # Reader times out + # We should have at least one warning log here with pytest.raises(TimeoutError): reader.dequeue(timeout=0.01, indefinite=False) - - # Clean up when done - writer.shutdown() - reader.shutdown() - assert any( "No available shared memory broadcast block found in 0.001 seconds" in record.message for record in caplog_vllm.records ) + caplog_vllm.clear() + + # We should have no warnings this time + with pytest.raises(TimeoutError): + reader.dequeue(timeout=0.01, indefinite=True) + assert all( + "No available shared memory broadcast block found in 0.001 seconds" + not in record.message + for record in caplog_vllm.records + ) + + # Clean up when done + writer.shutdown() + reader.shutdown() From 1affbef6f78fbba5cc390fa85741efb96b78dac7 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Thu, 19 Feb 2026 13:59:34 -0700 Subject: [PATCH 16/27] :bug: fix test hangs Signed-off-by: Joe Runde --- tests/distributed/test_shm_broadcast.py | 52 ++++++++++++++++--- .../device_communicators/shm_broadcast.py | 4 +- 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index d9bbb0bcba2f..d26574d5faae 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -25,7 +25,18 @@ def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]: return [np.random.randint(1, 100, i) for i in sizes] -def distributed_run(fn, world_size): +def distributed_run(fn, world_size, timeout=5): + """Run a function in multiple processes with proper error handling. + + Args: + fn: Function to run in each process + world_size: Number of processes to spawn + timeout: Maximum time in seconds to wait for processes (default: 60) + """ + # Use spawn method for better macOS compatibility + # Get the context for spawn method + # ctx = mp.get_context('spawn') + number_of_processes = world_size processes = [] for i in range(number_of_processes): @@ -40,11 +51,40 @@ def distributed_run(fn, world_size): processes.append(p) p.start() - for p in processes: - p.join() - - for p in processes: - assert p.exitcode == 0 + # Join with timeout to detect hangs (parallel timeout for all processes) + start_time = time.time() + failed_processes = [] + + # Wait for all processes with a shared timeout + while time.time() - start_time < timeout: + all_done = True + for p in processes: + if p.is_alive(): + all_done = False + break + if all_done: + break + time.sleep(0.1) # Check every 100ms + + # Check final status of all processes + for i, p in enumerate(processes): + if p.is_alive(): + # Process is still running after timeout - likely hung at barrier + failed_processes.append((i, "timeout")) + p.kill() + p.join() + elif p.exitcode != 0: + failed_processes.append((i, p.exitcode)) + + # Report failures + if failed_processes: + error_msg = "Distributed test failed:\n" + for rank, status in failed_processes: + if status == "timeout": + error_msg += f" Rank {rank}: Timeout (likely hung at barrier)\n" + else: + error_msg += f" Rank {rank}: Exit code {status}\n" + raise AssertionError(error_msg) def worker_fn_wrapper(fn): diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index f9ba9db1990d..7a507cc8454c 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -620,9 +620,9 @@ def timeout_ms(self) -> int | None: return self.warning_wait_timeout_ms else: time_left = self.deadline - time.monotonic() - if time_left <= 0: - raise TimeoutError time_left_ms = int(time_left * 1000) + if time_left_ms <= 0: + raise TimeoutError return ( time_left_ms From 40051c5c91224164d34c7372e82a926c8f66d315 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Thu, 19 Feb 2026 15:15:08 -0700 Subject: [PATCH 17/27] test: distributed_run fail fast Signed-off-by: Travis Johnson --- tests/distributed/test_shm_broadcast.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index d26574d5faae..118cc421b1bc 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -25,7 +25,7 @@ def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]: return [np.random.randint(1, 100, i) for i in sizes] -def distributed_run(fn, world_size, timeout=5): +def distributed_run(fn, world_size, timeout=60): """Run a function in multiple processes with proper error handling. Args: @@ -51,39 +51,36 @@ def distributed_run(fn, world_size, timeout=5): processes.append(p) p.start() - # Join with timeout to detect hangs (parallel timeout for all processes) + # Monitor processes and fail fast if any process fails start_time = time.time() failed_processes = [] - # Wait for all processes with a shared timeout + # Wait for all processes, checking for failures while time.time() - start_time < timeout: all_done = True - for p in processes: + for i, p in enumerate(processes): if p.is_alive(): all_done = False + elif p.exitcode != 0: + # Process failed + failed_processes.append((i, p.exitcode)) break - if all_done: + + if failed_processes or all_done: break time.sleep(0.1) # Check every 100ms - # Check final status of all processes + # Check for timeout if no failures detected yet for i, p in enumerate(processes): if p.is_alive(): - # Process is still running after timeout - likely hung at barrier - failed_processes.append((i, "timeout")) p.kill() p.join() - elif p.exitcode != 0: - failed_processes.append((i, p.exitcode)) # Report failures if failed_processes: error_msg = "Distributed test failed:\n" for rank, status in failed_processes: - if status == "timeout": - error_msg += f" Rank {rank}: Timeout (likely hung at barrier)\n" - else: - error_msg += f" Rank {rank}: Exit code {status}\n" + error_msg += f" Rank {rank}: Exit code {status}\n" raise AssertionError(error_msg) From 18febb3b80c1191456fea8e75b07bd69a80abc6e Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Thu, 19 Feb 2026 15:25:09 -0700 Subject: [PATCH 18/27] refactor: move monitor_parent_death to be a worker method Signed-off-by: Travis Johnson --- vllm/v1/executor/multiproc_executor.py | 68 ++++++++++++++------------ 1 file changed, 37 insertions(+), 31 deletions(-) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 2680c59a64a7..ffbf88b67c7d 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -5,6 +5,7 @@ import pickle import queue import signal +import threading import time import traceback import weakref @@ -673,12 +674,41 @@ def wait_for_ready( return cast(list[WorkerProcHandle], ready_proc_handles) def shutdown(self): + if self.rpc_broadcast_mq is not None: + self.rpc_broadcast_mq.shutdown() + if self.worker_response_mq is not None: + self.worker_response_mq.shutdown() self.worker.shutdown() self.rpc_broadcast_mq = None self.worker_response_mq = None destroy_model_parallel() destroy_distributed_environment() + def monitor_parent_death(self, death_pipe, shutdown_requested: threading.Event): + if death_pipe is None: + return + + death_monitor = None + # Start death monitoring thread if death_pipe is provided + if death_pipe is not None: + + def monitor_parent_death(): + try: + # This will block until parent process exits (pipe closes) + death_pipe.recv() + except EOFError: + # Parent process has exited, terminate this worker + logger.info_once("Parent process exited, terminating worker") + shutdown_requested.set() + self.shutdown() + except Exception as e: + logger.warning("Death monitoring error: %s", e) + + death_monitor = Thread( + target=monitor_parent_death, daemon=True, name="WorkerDeathMonitor" + ) + death_monitor.start() + @staticmethod def worker_main(*args, **kwargs): """Worker initialization and execution loops. @@ -687,12 +717,12 @@ def worker_main(*args, **kwargs): # Signal handler used for graceful termination. # SystemExit exception is only raised once to allow this and worker # processes to terminate without error - shutdown_requested = False + shutdown_requested = threading.Event() def signal_handler(signum, frame): nonlocal shutdown_requested - if not shutdown_requested: - shutdown_requested = True + if not shutdown_requested.is_set(): + shutdown_requested.set() raise SystemExit() # Either SIGTERM or SIGINT will terminate the worker @@ -703,38 +733,14 @@ def signal_handler(signum, frame): # tuple[Connection, Connection] reader, ready_writer = kwargs.pop("ready_pipe") death_pipe: Connection | None = kwargs.pop("death_pipe", None) - shutdown = False - # Start death monitoring thread if death_pipe is provided - if death_pipe is not None: - - def monitor_parent_death(): - try: - # This will block until parent process exits (pipe closes) - death_pipe.recv() - except EOFError: - # Parent process has exited, terminate this worker - logger.info_once("Parent process exited, terminating worker") - nonlocal shutdown - shutdown = True - # Shut down message queues - if worker is not None and worker.rpc_broadcast_mq is not None: - worker.rpc_broadcast_mq.shutdown() - if worker is not None and worker.worker_response_mq is not None: - worker.worker_response_mq.shutdown() - - except Exception as e: - logger.warning("Death monitoring error: %s", e) - - death_monitor = Thread( - target=monitor_parent_death, daemon=True, name="WorkerDeathMonitor" - ) - death_monitor.start() try: reader.close() worker = WorkerProc(*args, **kwargs) assert worker.worker_response_mq is not None + worker.monitor_parent_death(death_pipe, shutdown_requested) + # Send READY once we know everything is loaded ready_writer.send( { @@ -762,7 +768,7 @@ def monitor_parent_death(): if ready_writer is not None: logger.exception("WorkerProc failed to start.") - elif shutdown: + elif shutdown_requested.is_set(): logger.info("WorkerProc shutting down.") else: logger.exception("WorkerProc failed.") @@ -770,7 +776,7 @@ def monitor_parent_death(): # The parent sends a SIGTERM to all worker processes if # any worker dies. Set this value so we don't re-throw # SystemExit() to avoid zmq exceptions in __del__. - shutdown_requested = True + shutdown_requested.set() finally: if ready_writer is not None: From 6038f8923ad28d3dc49a087ea1ed76894be29672 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Thu, 19 Feb 2026 15:43:33 -0700 Subject: [PATCH 19/27] refactor: cleanup new monitor function Signed-off-by: Travis Johnson --- vllm/v1/executor/multiproc_executor.py | 35 +++++++++++--------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index ffbf88b67c7d..c167be550598 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -684,30 +684,23 @@ def shutdown(self): destroy_model_parallel() destroy_distributed_environment() - def monitor_parent_death(self, death_pipe, shutdown_requested: threading.Event): + def monitor_death_pipe(self, death_pipe, shutdown_requested: threading.Event): if death_pipe is None: return - death_monitor = None - # Start death monitoring thread if death_pipe is provided - if death_pipe is not None: + def death_pipe_monitor(): + try: + # This will block until parent process exits (pipe closes) + death_pipe.recv() + except EOFError: + # Parent process has exited, terminate this worker + logger.info_once("Parent process exited, terminating worker") + shutdown_requested.set() + self.shutdown() + except Exception as e: + logger.warning("Death monitoring error: %s", e) - def monitor_parent_death(): - try: - # This will block until parent process exits (pipe closes) - death_pipe.recv() - except EOFError: - # Parent process has exited, terminate this worker - logger.info_once("Parent process exited, terminating worker") - shutdown_requested.set() - self.shutdown() - except Exception as e: - logger.warning("Death monitoring error: %s", e) - - death_monitor = Thread( - target=monitor_parent_death, daemon=True, name="WorkerDeathMonitor" - ) - death_monitor.start() + Thread(target=death_pipe_monitor, daemon=True, name="DeathPipeMonitor").start() @staticmethod def worker_main(*args, **kwargs): @@ -739,7 +732,7 @@ def signal_handler(signum, frame): worker = WorkerProc(*args, **kwargs) assert worker.worker_response_mq is not None - worker.monitor_parent_death(death_pipe, shutdown_requested) + worker.monitor_death_pipe(death_pipe, shutdown_requested) # Send READY once we know everything is loaded ready_writer.send( From 5f14af269bff11643170408564b279d53631f2e0 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Tue, 24 Feb 2026 15:40:03 -0700 Subject: [PATCH 20/27] review: changes from review Signed-off-by: Travis Johnson --- tests/distributed/test_shm_broadcast.py | 4 ---- .../device_communicators/shm_broadcast.py | 22 +++++++++---------- vllm/v1/executor/multiproc_executor.py | 17 +++++++++----- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index 118cc421b1bc..dfc612dcaf5c 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -33,10 +33,6 @@ def distributed_run(fn, world_size, timeout=60): world_size: Number of processes to spawn timeout: Maximum time in seconds to wait for processes (default: 60) """ - # Use spawn method for better macOS compatibility - # Get the context for spawn method - # ctx = mp.get_context('spawn') - number_of_processes = world_size processes = [] for i in range(number_of_processes): diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 5333caba1c65..b30a837f3162 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -624,17 +624,17 @@ def timeout_ms(self) -> int | None: """ if self.timeout is None: return self.warning_wait_timeout_ms - else: - time_left = self.deadline - time.monotonic() - time_left_ms = int(time_left * 1000) - if time_left_ms <= 0: - raise TimeoutError - - return ( - time_left_ms - if self.warning_wait_timeout_ms is None - else min(self.warning_wait_timeout_ms, time_left_ms) - ) + + time_left = self.deadline - time.monotonic() + time_left_ms = int(time_left * 1000) + if time_left_ms <= 0: + raise TimeoutError + + return ( + time_left_ms + if self.warning_wait_timeout_ms is None + else min(self.warning_wait_timeout_ms, time_left_ms) + ) def should_warn(self) -> bool: """Returns true if it's time to log a warning for a timeout that is not diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index f9480302b155..6fbb91f4cdde 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -697,19 +697,26 @@ def monitor_death_pipe(self, death_pipe, shutdown_requested: threading.Event): if death_pipe is None: return - def death_pipe_monitor(): + def death_pipe_monitor(queues_to_shutdown: list[MessageQueue]): try: # This will block until parent process exits (pipe closes) death_pipe.recv() except EOFError: - # Parent process has exited, terminate this worker - logger.info_once("Parent process exited, terminating worker") + logger.info_once("Parent process exited, terminating worker queues") shutdown_requested.set() - self.shutdown() + for mq in queues_to_shutdown: + if mq is not None: + mq.shutdown() except Exception as e: logger.warning("Death monitoring error: %s", e) - Thread(target=death_pipe_monitor, daemon=True, name="DeathPipeMonitor").start() + # Pass queue references directly to avoid gc issues if passing self + Thread( + target=death_pipe_monitor, + args=([self.rpc_broadcast_mq, self.worker_response_mq],), + daemon=True, + name="DeathPipeMonitor", + ).start() @staticmethod def worker_main(*args, **kwargs): From 2f3f98cab470a42a31084b30cfa82536b81c9cff Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Tue, 24 Feb 2026 16:12:29 -0700 Subject: [PATCH 21/27] fix: handle inherited socket connections when forking Signed-off-by: Travis Johnson --- vllm/v1/executor/multiproc_executor.py | 68 +++++++++++++++++--------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 6fbb91f4cdde..42731dfae398 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -144,20 +144,31 @@ def _init_executor(self) -> None: global_start_rank = ( self.local_world_size * self.parallel_config.node_rank_within_dp ) + # Keep track of socket file descriptors that are inherited by the + # worker when using fork, so that we can close them in subsequent + # workers + inherited_fds: list[int] = [] for local_rank in range(self.local_world_size): global_rank = global_start_rank + local_rank is_driver_worker = self._is_driver_worker(global_rank) - unready_workers.append( - WorkerProc.make_worker_process( - vllm_config=self.vllm_config, - local_rank=local_rank, - rank=global_rank, - distributed_init_method=distributed_init_method, - input_shm_handle=scheduler_output_handle, - shared_worker_lock=shared_worker_lock, - is_driver_worker=is_driver_worker, - ) + unready_worker_handle = WorkerProc.make_worker_process( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=global_rank, + distributed_init_method=distributed_init_method, + input_shm_handle=scheduler_output_handle, + shared_worker_lock=shared_worker_lock, + is_driver_worker=is_driver_worker, + inherited_fds=inherited_fds, ) + unready_workers.append(unready_worker_handle) + if context.get_start_method() == "fork": + inherited_fds.extend( + [ + unready_worker_handle.death_writer.fileno(), + unready_worker_handle.ready_pipe.fileno(), + ] + ) # Workers must be created before wait_for_ready to avoid # deadlock, since worker.init_device() does a device sync. @@ -589,24 +600,26 @@ def make_worker_process( input_shm_handle, # Receive SchedulerOutput shared_worker_lock: LockType, is_driver_worker: bool, + inherited_fds: list[int], ) -> UnreadyWorkerProcHandle: context = get_mp_context() - # (reader, writer) - reader, writer = context.Pipe(duplex=False) - - # Create death pipe to detect parent process exit + # Ready pipe to communicate readiness from child to parent + ready_reader, ready_writer = context.Pipe(duplex=False) + # Death pipe to let child detect parent process exit death_reader, death_writer = context.Pipe(duplex=False) - process_kwargs = { "vllm_config": vllm_config, "local_rank": local_rank, "rank": rank, "distributed_init_method": distributed_init_method, "input_shm_handle": input_shm_handle, - "ready_pipe": (reader, writer), + "ready_pipe": ready_writer, "death_pipe": death_reader, "shared_worker_lock": shared_worker_lock, "is_driver_worker": is_driver_worker, + # Have the worker close parent end of this worker's pipes too + "inherited_fds": inherited_fds + + [ready_reader.fileno(), death_writer.fileno()], } # Run EngineCore busy loop in background process. proc = context.Process( @@ -617,10 +630,12 @@ def make_worker_process( ) proc.start() - writer.close() + # Close child ends of pipes here in the parent + ready_writer.close() + death_reader.close() # Keep death_writer open in parent - when parent exits, # death_reader in child will get EOFError - return UnreadyWorkerProcHandle(proc, rank, reader, death_writer) + return UnreadyWorkerProcHandle(proc, rank, ready_reader, death_writer) @staticmethod def wait_for_response_handle_ready( @@ -742,13 +757,20 @@ def signal_handler(signum, frame): signal.signal(signal.SIGINT, signal_handler) worker = None - # tuple[Connection, Connection] - reader, ready_writer = kwargs.pop("ready_pipe") - death_pipe: Connection | None = kwargs.pop("death_pipe", None) + ready_writer = kwargs.pop("ready_pipe") + death_pipe = kwargs.pop("death_pipe", None) + + # Close inherited pipes from parent (incl. other worker pipes) + # Explicitly passing in existing pipes and closing them makes the pipe + # behave when using fork. Otherwise, a hidden reference to the pipes + # exist in the child process and prevents EOF closure. + for fd in kwargs.pop("inherited_fds", []): + try: + os.close(fd) + except Exception as e: + logger.warning("Exception closing inherited connection: %s", e) try: - reader.close() - # Initialize tracer rank = kwargs.get("rank", 0) maybe_init_worker_tracer( From c45326b2ef5707191a27e301d1073a1f63e8bed6 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Wed, 25 Feb 2026 09:51:56 -0700 Subject: [PATCH 22/27] log: add some debug logs to ensure_worker_termination Signed-off-by: Travis Johnson --- vllm/v1/executor/multiproc_executor.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 42731dfae398..49b828883907 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -401,14 +401,19 @@ def wait_for_termination(procs, timeout): active_procs = lambda: [proc for proc in worker_procs if proc.is_alive()] # Give processes time to clean themselves up properly first + logger.debug("Worker Termination: allow workers to gracefully shutdown") if wait_for_termination(active_procs(), 4): return # Send SIGTERM if still running + logger.debug("Worker Termination: workers still running sending SIGTERM") for p in active_procs(): p.terminate() if not wait_for_termination(active_procs(), 4): # Send SIGKILL if still running + logger.debug( + "Worker Termination: resorting to SIGKILL to take down workers" + ) for p in active_procs(): p.kill() From e1565e0c3a81b895c4ba780ef5fefbc0f93ff1e2 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Wed, 25 Feb 2026 12:53:25 -0700 Subject: [PATCH 23/27] fix: shutdown all queues in MultiprocExecutor Signed-off-by: Travis Johnson --- vllm/v1/executor/multiproc_executor.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 49b828883907..b9c5b7c0e0d1 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -217,6 +217,7 @@ def _init_executor(self) -> None: for uw in unready_workers: if uw.death_writer is not None: uw.death_writer.close() + uw.death_writer = None self._ensure_worker_termination([uw.proc for uw in unready_workers]) self.output_rank = self._get_output_rank() @@ -252,6 +253,7 @@ def monitor_workers(): died = multiprocessing.connection.wait(sentinels) _self = self_ref() if not _self or getattr(_self, "shutting_down", False): + logger.debug("MultiprocWorkerMonitor: shutdown already initiated") return _self.is_failed = True proc_name = next(h.proc.name for h in workers if h.proc.sentinel == died[0]) @@ -420,6 +422,7 @@ def wait_for_termination(procs, timeout): def shutdown(self): """Properly shut down the executor and its workers""" if not getattr(self, "shutting_down", False): + logger.debug("Triggering shutdown of workers") self.shutting_down = True # Make sure all the worker processes are terminated first. @@ -429,10 +432,20 @@ def shutdown(self): if w.death_writer is not None: w.death_writer.close() w.death_writer = None - w.worker_response_mq = None self._ensure_worker_termination([w.proc for w in workers]) - self.rpc_broadcast_mq = None + for w in workers: + # Shutdown response queues + if w.worker_response_mq is not None: + w.worker_response_mq.shutdown() + w.worker_response_mq = None + + if self.rpc_broadcast_mq is not None: + self.rpc_broadcast_mq.shutdown() + self.rpc_broadcast_mq = None + for mq in self.response_mqs: + mq.shutdown() + self.response_mqs = [] def check_health(self) -> None: self.collective_rpc("check_health", timeout=10) From f7e34867121fd49fe39631560a94b903737cf59f Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Mon, 2 Mar 2026 10:59:34 -0700 Subject: [PATCH 24/27] log: add logging to SpinCondition wait Signed-off-by: Travis Johnson --- vllm/distributed/device_communicators/shm_broadcast.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index b30a837f3162..8bf9c9cd8057 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -187,13 +187,14 @@ def wait(self, timeout_ms: int | None = None) -> None: events = dict(self.poller.poll(timeout=timeout_ms)) if self.read_cancel_socket in events: - # return immediately on cancel - return - - if self.local_notify_socket in events: + logger.debug("Poller received cancel event") + elif self.local_notify_socket in events: + logger.debug("Poller received notify event") # Since zmq.CONFLATE is set, there will only be one notification # to read from the socket self.local_notify_socket.recv(flags=zmq.NOBLOCK, copy=False) + else: + logger.debug("Poller timed out") def notify(self): """Notifies all readers to wake up""" From 765660f43393d132e7a17317b4110af5cf076655 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 3 Mar 2026 11:26:32 -0800 Subject: [PATCH 25/27] cleanup removed env var Signed-off-by: Nick Hill --- tests/basic_correctness/test_basic_correctness.py | 2 -- vllm/envs.py | 1 - 2 files changed, 3 deletions(-) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 68b5cd5101d5..70c58ad96dd7 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -124,8 +124,6 @@ def test_models( [ ("facebook/opt-125m", "ray", "", "L4", {}), ("facebook/opt-125m", "mp", "", "L4", {}), - ("facebook/opt-125m", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), - ("facebook/opt-125m", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}), ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}), ("facebook/opt-125m", "ray", "", "A100", {}), diff --git a/vllm/envs.py b/vllm/envs.py index 303f772457a5..9f6222803b3a 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1747,7 +1747,6 @@ def compile_factors() -> dict[str, object]: "VLLM_HTTP_TIMEOUT_KEEP_ALIVE", "VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS", "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", - "VLLM_SLEEP_WHEN_IDLE", "VLLM_IMAGE_FETCH_TIMEOUT", "VLLM_VIDEO_FETCH_TIMEOUT", "VLLM_AUDIO_FETCH_TIMEOUT", From 97eb604bbef2988deca0b5fd98def6172309e5a4 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 3 Mar 2026 12:34:07 -0800 Subject: [PATCH 26/27] minor code simplification Signed-off-by: Nick Hill --- .../device_communicators/shm_broadcast.py | 43 ++++++++----------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 8bf9c9cd8057..1c5c4e01d8c8 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -80,14 +80,13 @@ def to_bytes_big(value: int, size: int) -> bytes: logger = init_logger(__name__) -def long_wait_time_msg(threshold: int) -> str: - return ( - "No available shared memory broadcast block found " - f"in {threshold} seconds. This typically happens " - "when some processes are hanging or doing some " - "time-consuming work (e.g. compilation, " - "weight/kv cache quantization)." - ) +LONG_WAIT_TIME_LOG_MSG = ( + "No available shared memory broadcast block found " + "in %d seconds. This typically happens " + "when some processes are hanging or doing some " + "time-consuming work (e.g. compilation, " + "weight/kv cache quantization)." +) class SpinCondition: @@ -147,7 +146,7 @@ def __init__( else: # Writer side publishes write notifications self.local_notify_socket: zmq.Socket = context.socket(PUB) # type: ignore - # Set high water mark to 1- we don't need to send a massive amount of + # Set high water mark to 1 - we don't need to send a massive amount of # pings during busy operation. PUB sockets will silently drop subsequent # messages after the high water mark is reached. self.local_notify_socket.setsockopt(zmq.SNDHWM, 1) @@ -561,7 +560,7 @@ def acquire_write(self, timeout: float | None = None): # if we wait for a long time, log a message if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: logger.info( - long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL) + LONG_WAIT_TIME_LOG_MSG, VLLM_RINGBUFFER_WARNING_INTERVAL ) n_warning += 1 @@ -602,13 +601,10 @@ def acquire_write(self, timeout: float | None = None): class ReadTimeoutWithWarnings: def __init__(self, timeout: float | None, should_warn: bool) -> None: self.started = time.monotonic() - if timeout is not None: - self.deadline = self.started + timeout - else: - self.deadline = sys.maxsize + self.deadline = sys.maxsize if timeout is None else self.started + timeout # if should_warn, we need to wake up periodically to log - self.warning_wait_timeout_ms: int | None = ( + self.warning_wait_time_ms: int | None = ( VLLM_RINGBUFFER_WARNING_INTERVAL * 1000 if should_warn else None ) @@ -623,19 +619,18 @@ def timeout_ms(self) -> int | None: - None if the timeout is None and we're not logging warnings - raise TimeoutError if we are past the deadline """ + warning_wait_time = self.warning_wait_time_ms if self.timeout is None: - return self.warning_wait_timeout_ms + return warning_wait_time - time_left = self.deadline - time.monotonic() - time_left_ms = int(time_left * 1000) + time_left_ms = int((self.deadline - time.monotonic()) * 1000) if time_left_ms <= 0: raise TimeoutError - return ( - time_left_ms - if self.warning_wait_timeout_ms is None - else min(self.warning_wait_timeout_ms, time_left_ms) - ) + if warning_wait_time and warning_wait_time < time_left_ms: + return warning_wait_time + + return time_left_ms def should_warn(self) -> bool: """Returns true if it's time to log a warning for a timeout that is not @@ -681,7 +676,7 @@ def acquire_read( # if we wait for a long time, log a message if read_timeout.should_warn(): logger.info( - long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL) + LONG_WAIT_TIME_LOG_MSG, VLLM_RINGBUFFER_WARNING_INTERVAL ) continue From 8bf1da682338e6985d52e9b9abf93b0ef18c9df1 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Tue, 3 Mar 2026 14:54:18 -0700 Subject: [PATCH 27/27] fix: time in log message now rounds to 0 Signed-off-by: Travis Johnson --- tests/distributed/test_shm_broadcast.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index dfc612dcaf5c..7cf3b01e75c7 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -370,10 +370,11 @@ def test_warning_logs(caplog_vllm): reader.wait_until_ready() # We should have at least one warning log here + # "0 seconds" expected due to rounding of 1ms test interval with pytest.raises(TimeoutError): reader.dequeue(timeout=0.01, indefinite=False) assert any( - "No available shared memory broadcast block found in 0.001 seconds" + "No available shared memory broadcast block found in 0 seconds" in record.message for record in caplog_vllm.records ) @@ -383,7 +384,7 @@ def test_warning_logs(caplog_vllm): with pytest.raises(TimeoutError): reader.dequeue(timeout=0.01, indefinite=True) assert all( - "No available shared memory broadcast block found in 0.001 seconds" + "No available shared memory broadcast block found in 0 seconds" not in record.message for record in caplog_vllm.records )