diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 68b5cd5101d5..70c58ad96dd7 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -124,8 +124,6 @@ def test_models( [ ("facebook/opt-125m", "ray", "", "L4", {}), ("facebook/opt-125m", "mp", "", "L4", {}), - ("facebook/opt-125m", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), - ("facebook/opt-125m", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}), ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}), ("facebook/opt-125m", "ray", "", "A100", {}), diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index a7ace62e1b54..7cf3b01e75c7 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -1,11 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import multiprocessing import random +import threading import time +from unittest import mock +import multiprocess as mp import numpy as np +import pytest import torch.distributed as dist from vllm.distributed.device_communicators.shm_broadcast import MessageQueue @@ -22,7 +25,14 @@ def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]: return [np.random.randint(1, 100, i) for i in sizes] -def distributed_run(fn, world_size): +def distributed_run(fn, world_size, timeout=60): + """Run a function in multiple processes with proper error handling. + + Args: + fn: Function to run in each process + world_size: Number of processes to spawn + timeout: Maximum time in seconds to wait for processes (default: 60) + """ number_of_processes = world_size processes = [] for i in range(number_of_processes): @@ -33,19 +43,45 @@ def distributed_run(fn, world_size): env["LOCAL_WORLD_SIZE"] = str(number_of_processes) env["MASTER_ADDR"] = "localhost" env["MASTER_PORT"] = "12345" - p = multiprocessing.Process(target=fn, args=(env,)) + p = mp.Process(target=fn, args=(env,)) processes.append(p) p.start() - for p in processes: - p.join() + # Monitor processes and fail fast if any process fails + start_time = time.time() + failed_processes = [] + + # Wait for all processes, checking for failures + while time.time() - start_time < timeout: + all_done = True + for i, p in enumerate(processes): + if p.is_alive(): + all_done = False + elif p.exitcode != 0: + # Process failed + failed_processes.append((i, p.exitcode)) + break + + if failed_processes or all_done: + break + time.sleep(0.1) # Check every 100ms - for p in processes: - assert p.exitcode == 0 + # Check for timeout if no failures detected yet + for i, p in enumerate(processes): + if p.is_alive(): + p.kill() + p.join() + + # Report failures + if failed_processes: + error_msg = "Distributed test failed:\n" + for rank, status in failed_processes: + error_msg += f" Rank {rank}: Exit code {status}\n" + raise AssertionError(error_msg) def worker_fn_wrapper(fn): - # `multiprocessing.Process` cannot accept environment variables directly + # `mp.Process` cannot accept environment variables directly # so we need to pass the environment variables as arguments # and update the environment variables in the function def wrapped_fn(env): @@ -115,3 +151,244 @@ def worker_fn(): def test_shm_broadcast(): distributed_run(worker_fn, 4) + + +@worker_fn_wrapper +def worker_fn_test_shutdown_busy(): + rank = dist.get_rank() + writer_rank = 2 + message_queue = MessageQueue.create_from_process_group( + dist.group.WORLD, 40 * 1024, 2, writer_rank + ) + + if not message_queue._is_writer: + # Put into busy mode + message_queue._spin_condition.busy_loop_s = 9999 + + shutdown_event = threading.Event() + + def shutdown_thread(mq, shutdown_event): + shutdown_event.wait() + mq.shutdown() + + threading.Thread( + target=shutdown_thread, args=(message_queue, shutdown_event) + ).start() + + with pytest.raises(TimeoutError): + message_queue.dequeue(timeout=0.01) + + shutdown_event.set() + + with pytest.raises(RuntimeError, match="cancelled"): + message_queue.dequeue(timeout=1) + + assert message_queue.shutting_down + + print(f"torch distributed passed the test! Rank {rank}") + dist.barrier() + + +def test_message_queue_shutdown_busy(caplog_vllm): + distributed_run(worker_fn_test_shutdown_busy, 4) + print(caplog_vllm.text) + + +@worker_fn_wrapper +def worker_fn_test_shutdown_idle(): + rank = dist.get_rank() + writer_rank = 2 + message_queue = MessageQueue.create_from_process_group( + dist.group.WORLD, 40 * 1024, 2, writer_rank + ) + + if not message_queue._is_writer: + # Put into idle mode + message_queue._spin_condition.last_read = 0 + + shutdown_event = threading.Event() + + def shutdown_thread(mq, shutdown_event): + shutdown_event.wait() + mq.shutdown() + + threading.Thread( + target=shutdown_thread, args=(message_queue, shutdown_event) + ).start() + + with pytest.raises(TimeoutError): + message_queue.dequeue(timeout=0.01) + + shutdown_event.set() + + with pytest.raises(RuntimeError, match="cancelled"): + message_queue.dequeue(timeout=1) + + assert message_queue.shutting_down + + print(f"torch distributed passed the test! Rank {rank}") + dist.barrier() + + +def test_message_queue_shutdown_idle(): + distributed_run(worker_fn_test_shutdown_idle, 4) + + +@worker_fn_wrapper +def worker_fn_test_idle_to_busy(): + rank = dist.get_rank() + writer_rank = 2 + message_queue = MessageQueue.create_from_process_group( + dist.group.WORLD, 40 * 1024, 2, writer_rank + ) + + message1 = "hello world" + message2 = np.random.randint(1, 100, 100) + with mock.patch.object( + message_queue._spin_condition, "wait", wraps=message_queue._spin_condition.wait + ) as wrapped_wait: + if not message_queue._is_writer: + # Put into idle mode + message_queue._spin_condition.last_read = 0 + + # no messages, so expect a TimeoutError + with pytest.raises(TimeoutError): + message_queue.dequeue(timeout=0.01) + # wait should only be called once while idle + assert wrapped_wait.call_count == 1 + + # sync with the writer and wait for message1 + dist.barrier() + recv_message = message_queue.dequeue(timeout=5) + assert recv_message == message1 + # second call to wait, with a message read, this puts in a busy spin + assert wrapped_wait.call_count == 2 + + # sync with the writer and wait for message2 + dist.barrier() + recv_message = message_queue.dequeue(timeout=1) + assert np.array_equal(recv_message, message2) + # in busy mode, we expect wait to have been called multiple times + assert wrapped_wait.call_count > 3 + else: + # writer writes two messages in sync with the reader + dist.barrier() + # sleep delays the send to ensure reader enters the read loop + time.sleep(0.1) + message_queue.enqueue(message1) + + dist.barrier() + time.sleep(0.1) + message_queue.enqueue(message2) + + message_queue.shutdown() + assert message_queue.shutting_down + print(f"torch distributed passed the test! Rank {rank}") + + +def test_message_queue_idle_wake(): + distributed_run(worker_fn_test_idle_to_busy, 4) + + +@worker_fn_wrapper +def worker_fn_test_busy_to_idle(): + rank = dist.get_rank() + writer_rank = 2 + message_queue = MessageQueue.create_from_process_group( + dist.group.WORLD, 40 * 1024, 2, writer_rank + ) + + message1 = 12345 + message2 = list(range(3)) + with mock.patch.object( + message_queue._spin_condition, "wait", wraps=message_queue._spin_condition.wait + ) as wrapped_wait: + if not message_queue._is_writer: + # Put into busy mode + message_queue._spin_condition.busy_loop_s = 9999 + + # sync with the writer and wait for message1 + dist.barrier() + recv_message = message_queue.dequeue(timeout=1) + assert recv_message == message1 + # in busy mode, we expect wait to have been called many times + assert wrapped_wait.call_count > 1 + + # simulate busy loop ending + message_queue._spin_condition.busy_loop_s = 0 + # ensure we enter idle mode, then record call count + with pytest.raises(TimeoutError): + message_queue.dequeue(timeout=0.01) + call_count = wrapped_wait.call_count + + # sync with the writer and wait for message2 + dist.barrier() + recv_message = message_queue.dequeue(timeout=1) + assert recv_message == message2 + + # call to wait after idle should only happen once + assert wrapped_wait.call_count == call_count + 1 + else: + # writer writes two messages in sync with the reader + dist.barrier() + # sleep delays the send to ensure reader enters the read loop + time.sleep(0.1) + message_queue.enqueue(message1) + + dist.barrier() + time.sleep(0.1) + message_queue.enqueue(message2) + + message_queue.shutdown() + assert message_queue.shutting_down + print(f"torch distributed passed the test! Rank {rank}") + + +def test_message_queue_busy_to_idle(): + distributed_run(worker_fn_test_busy_to_idle, 4) + + +def test_warning_logs(caplog_vllm): + """ + Test that warning logs are emitted at VLLM_RINGBUFFER_WARNING_INTERVAL intervals + when indefinite=False, and are not emitted when indefinite=True. + """ + + # Patch the warning log interval to every 1 ms during reads + with mock.patch( + "vllm.distributed.device_communicators.shm_broadcast.VLLM_RINGBUFFER_WARNING_INTERVAL", + new=0.001, # 1 ms + ): + writer = MessageQueue( + n_reader=1, + n_local_reader=1, + max_chunk_bytes=1024 * 1024, # 1MB chunks + max_chunks=10, + ) + reader = MessageQueue.create_from_handle(writer.export_handle(), rank=0) + writer.wait_until_ready() + reader.wait_until_ready() + + # We should have at least one warning log here + # "0 seconds" expected due to rounding of 1ms test interval + with pytest.raises(TimeoutError): + reader.dequeue(timeout=0.01, indefinite=False) + assert any( + "No available shared memory broadcast block found in 0 seconds" + in record.message + for record in caplog_vllm.records + ) + caplog_vllm.clear() + + # We should have no warnings this time + with pytest.raises(TimeoutError): + reader.dequeue(timeout=0.01, indefinite=True) + assert all( + "No available shared memory broadcast block found in 0 seconds" + not in record.message + for record in caplog_vllm.records + ) + + # Clean up when done + writer.shutdown() + reader.shutdown() diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index ac46a5667373..1c5c4e01d8c8 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -2,13 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools import pickle +import sys import threading import time from contextlib import contextmanager from dataclasses import dataclass, field from multiprocessing import shared_memory from pickle import PickleBuffer -from threading import Event from typing import TYPE_CHECKING, Any, cast from unittest.mock import patch @@ -18,6 +18,7 @@ from torch.distributed import ProcessGroup from zmq import ( # type: ignore IPV6, # type: ignore + PUB, SUB, SUBSCRIBE, XPUB, @@ -32,6 +33,7 @@ from vllm.utils.network_utils import ( get_ip, get_open_port, + get_open_zmq_inproc_path, get_open_zmq_ipc_path, is_valid_ipv6_address, ) @@ -78,50 +80,125 @@ def to_bytes_big(value: int, size: int) -> bytes: logger = init_logger(__name__) -def long_wait_time_msg(threshold: int) -> str: - return ( - "No available shared memory broadcast block found " - f"in {threshold} seconds. This typically happens " - "when some processes are hanging or doing some " - "time-consuming work (e.g. compilation, " - "weight/kv cache quantization)." - ) - - -class SpinTimer: - def record_activity(self): - pass - - def spin(self): - sched_yield() +LONG_WAIT_TIME_LOG_MSG = ( + "No available shared memory broadcast block found " + "in %d seconds. This typically happens " + "when some processes are hanging or doing some " + "time-consuming work (e.g. compilation, " + "weight/kv cache quantization)." +) -class SpinSleepTimer(SpinTimer): +class SpinCondition: """ - In setups which have long inactivity periods it is desirable to reduce - system power consumption when vllm does nothing. This would lead to more - CPU thermal headroom when a request eventually comes, especially when - multiple GPUs are connected as each GPU would otherwise pin one thread at - 100% CPU usage. - - The simplest solution is to reduce polling frequency when there is no - activity for a certain period of time. + This class implements an interface similar to a threading.Condition. It + allows a writer to notify readers to wake up and read from the shared memory + buffer. This notification is done over a zmq socket. + + For optimal performance under load we don't want the readers to need to poll + the zmq socket for every read. So the `wait` method here will return + immediately when reads are frequent, and will only enter "idle mode" and + await a notification on the zmq socket after a period of inactivity. This + allows the readers to spin quickly, hence "SpinCondition". + + To support clean shutdown, a separate thread in the reader's process must be + able to wake the reader so that it can exit. A separate cancel() method is + implemented with an in-process socket to allow this interruption. """ - def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1): - self.last_activity = time.monotonic() - self.busy_loop_s = busy_loop_s - self.wait_sleep_s = wait_sleep_s - - def record_activity(self): - self.last_activity = time.monotonic() - - def spin(self): - curr_time = time.monotonic() - if curr_time >= self.last_activity + self.busy_loop_s: - time.sleep(self.wait_sleep_s) + def __init__( + self, + is_reader: bool, + context: zmq.Context, + notify_address: str, + busy_loop_s: float = 1, + ): + self.is_reader = is_reader + + if is_reader: + # Time of last shm buffer read + self.last_read = time.monotonic() + + # Time to keep busy-looping on the shm buffer before going idle + self.busy_loop_s = busy_loop_s + + # Readers subscribe to write notifications + self.local_notify_socket: zmq.Socket = context.socket(SUB) + # Set zmq.CONFLATE to only keep the last message that the socket + # receives. This prevents us from piling up notification messages + # under high load when we aren't polling the socket. + self.local_notify_socket.setsockopt(zmq.CONFLATE, 1) + # Subscribe to all messages on the socket + self.local_notify_socket.setsockopt_string(SUBSCRIBE, "") + self.local_notify_socket.connect(notify_address) + + # Readers require a process-local socket to poll for cancellation + cancel_path = get_open_zmq_inproc_path() + self.write_cancel_socket: zmq.Socket = context.socket(zmq.PAIR) + self.write_cancel_socket.bind(cancel_path) + self.read_cancel_socket: zmq.Socket = context.socket(zmq.PAIR) + self.read_cancel_socket.connect(cancel_path) + + # Poller allows waiting on either `.notify()` or `.cancel()` + self.poller = zmq.Poller() + self.poller.register(self.read_cancel_socket, zmq.POLLIN) + self.poller.register(self.local_notify_socket, zmq.POLLIN) else: + # Writer side publishes write notifications + self.local_notify_socket: zmq.Socket = context.socket(PUB) # type: ignore + # Set high water mark to 1 - we don't need to send a massive amount of + # pings during busy operation. PUB sockets will silently drop subsequent + # messages after the high water mark is reached. + self.local_notify_socket.setsockopt(zmq.SNDHWM, 1) + self.local_notify_socket.bind(notify_address) + + self.last_read = 0 + self.busy_loop_s = 0 + self.read_cancel_socket = None + self.write_cancel_socket = None + self.poller = None + + def record_read(self): + self.last_read = time.monotonic() + + def cancel(self): + # Sends cancellation ping that will cause the reader to wake up. + # This is done from a monitor thread in the same process as the reader. + if self.is_reader: + logger.debug("Canceling waiting reads on SHM Buffer") + self.write_cancel_socket.send(b"\x00") + + def wait(self, timeout_ms: int | None = None) -> None: + """Wait for data on the shared memory buffer. + + Yields the scheduler then returns immediately if it has been less than + self.busy_loop_s since the last read. + + Otherwise, enters idle mode and awaits a socket ping for at most + `timeout_ms` milliseconds, or indefinitely if timeout_ms is None. + """ + assert self.is_reader, "Only readers can wait" + + current_time = time.monotonic() + if current_time <= self.last_read + self.busy_loop_s: sched_yield() + else: + events = dict(self.poller.poll(timeout=timeout_ms)) + + if self.read_cancel_socket in events: + logger.debug("Poller received cancel event") + elif self.local_notify_socket in events: + logger.debug("Poller received notify event") + # Since zmq.CONFLATE is set, there will only be one notification + # to read from the socket + self.local_notify_socket.recv(flags=zmq.NOBLOCK, copy=False) + else: + logger.debug("Poller timed out") + + def notify(self): + """Notifies all readers to wake up""" + assert not self.is_reader, "Only writers can notify" + self.local_notify_socket.send(b"\x00") class ShmRingBuffer: @@ -265,6 +342,7 @@ class Handle: buffer_handle: tuple[int, int, int, str] | None = None local_subscribe_addr: str | None = None + local_notify_addr: str | None = None remote_subscribe_addr: str | None = None remote_addr_ipv6: bool = False @@ -288,7 +366,7 @@ def __init__( self.n_local_reader = n_local_reader n_remote_reader = n_reader - n_local_reader self.n_remote_reader = n_remote_reader - + self.shutting_down = False context = Context() if n_local_reader > 0: @@ -310,11 +388,19 @@ def __init__( self.local_socket.bind(local_subscribe_addr) self.current_idx = 0 + + # Create the notification side of the SpinCondition + local_notify_addr = get_open_zmq_ipc_path() + self._spin_condition = SpinCondition( + is_reader=False, context=context, notify_address=local_notify_addr + ) else: self.buffer = None # type: ignore local_subscribe_addr = None self.local_socket = None self.current_idx = -1 + local_notify_addr = None + self._spin_condition = None # type: ignore remote_addr_ipv6 = False if n_remote_reader > 0: @@ -341,12 +427,12 @@ def __init__( self.local_reader_rank = -1 # rank does not matter for remote readers self._is_remote_reader = False - self._read_spin_timer = SpinTimer() self.handle = Handle( local_reader_ranks=local_reader_ranks, buffer_handle=self.buffer.handle() if self.buffer is not None else None, local_subscribe_addr=local_subscribe_addr, + local_notify_addr=local_notify_addr, remote_subscribe_addr=remote_subscribe_addr, remote_addr_ipv6=remote_addr_ipv6, ) @@ -379,9 +465,9 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.local_socket.connect(socket_addr) self.remote_socket = None - - self._read_spin_timer = ( - SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer() + assert isinstance(handle.local_notify_addr, str) + self._spin_condition = SpinCondition( + is_reader=True, context=context, notify_address=handle.local_notify_addr ) else: self.buffer = None # type: ignore @@ -399,7 +485,9 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": socket_addr = handle.remote_subscribe_addr logger.debug("Connecting to %s", socket_addr) self.remote_socket.connect(socket_addr) + self._spin_condition = None # type: ignore + self.shutting_down = False return self def wait_until_ready(self): @@ -435,6 +523,13 @@ def wait_until_ready(self): recv = self.remote_socket.recv() assert recv == b"READY" + def shutdown(self): + """If this is an idle reader, wakes it up so it can clean up and shut + down""" + self.shutting_down = True + if self._spin_condition is not None: + self._spin_condition.cancel() + @contextmanager def acquire_write(self, timeout: float | None = None): assert self._is_writer, "Only writers can acquire write" @@ -465,7 +560,7 @@ def acquire_write(self, timeout: float | None = None): # if we wait for a long time, log a message if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: logger.info( - long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL) + LONG_WAIT_TIME_LOG_MSG, VLLM_RINGBUFFER_WARNING_INTERVAL ) n_warning += 1 @@ -503,16 +598,60 @@ def acquire_write(self, timeout: float | None = None): self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks break + class ReadTimeoutWithWarnings: + def __init__(self, timeout: float | None, should_warn: bool) -> None: + self.started = time.monotonic() + self.deadline = sys.maxsize if timeout is None else self.started + timeout + + # if should_warn, we need to wake up periodically to log + self.warning_wait_time_ms: int | None = ( + VLLM_RINGBUFFER_WARNING_INTERVAL * 1000 if should_warn else None + ) + + self._should_warn = should_warn + self.n_warning = 1 + self.timeout = timeout + + def timeout_ms(self) -> int | None: + """Returns a timeout that is: + - min(time to deadline, time to next warning) if we're logging warnings + - time to deadline, if we're not logging warnings + - None if the timeout is None and we're not logging warnings + - raise TimeoutError if we are past the deadline + """ + warning_wait_time = self.warning_wait_time_ms + if self.timeout is None: + return warning_wait_time + + time_left_ms = int((self.deadline - time.monotonic()) * 1000) + if time_left_ms <= 0: + raise TimeoutError + + if warning_wait_time and warning_wait_time < time_left_ms: + return warning_wait_time + + return time_left_ms + + def should_warn(self) -> bool: + """Returns true if it's time to log a warning for a timeout that is not + indefinite""" + if self._should_warn: + elapsed = time.monotonic() - self.started + if elapsed >= VLLM_RINGBUFFER_WARNING_INTERVAL * self.n_warning: + self.n_warning += 1 + return True + return False + @contextmanager def acquire_read( self, timeout: float | None = None, - cancel: Event | None = None, indefinite: bool = False, ): assert self._is_local_reader, "Only readers can acquire read" - start_time = time.monotonic() - n_warning = 1 + read_timeout = self.ReadTimeoutWithWarnings( + timeout=timeout, should_warn=not indefinite + ) with self.buffer.get_metadata(self.current_idx) as metadata_buffer: while True: # Memory fence ensures we see the latest writes from the writer. @@ -529,26 +668,16 @@ def acquire_read( # for readers, `self.current_idx` is the next block to read # if this block is not ready, # we need to wait until it is written + self._spin_condition.wait(timeout_ms=read_timeout.timeout_ms()) - # Release the processor to other threads - self._read_spin_timer.spin() - - if cancel is not None and cancel.is_set(): + if self.shutting_down: raise RuntimeError("cancelled") - # if we time out, raise an exception - elapsed = time.monotonic() - start_time - if timeout is not None and elapsed > timeout: - raise TimeoutError - # if we wait for a long time, log a message - if not indefinite and ( - elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning - ): + if read_timeout.should_warn(): logger.info( - long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL) + LONG_WAIT_TIME_LOG_MSG, VLLM_RINGBUFFER_WARNING_INTERVAL ) - n_warning += 1 continue # found a block that is not read by this reader @@ -565,7 +694,7 @@ def acquire_read( memory_fence() self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks - self._read_spin_timer.record_activity() + self._spin_condition.record_read() break def enqueue(self, obj, timeout: float | None = None): @@ -608,18 +737,19 @@ def oob_callback(buf: PickleBuffer) -> bool: buf[offset:buf_offset] = to_bytes_big(buf_len, 4) buf[buf_offset : (offset := buf_offset + buf_len)] = buffer + self._spin_condition.notify() + if self.n_remote_reader > 0: self.remote_socket.send_multipart(all_buffers, copy=False) def dequeue( self, timeout: float | None = None, - cancel: Event | None = None, indefinite: bool = False, ): """Read from message queue with optional timeout (in seconds)""" if self._is_local_reader: - with self.acquire_read(timeout, cancel, indefinite) as buf: + with self.acquire_read(timeout, indefinite) as buf: overflow = buf[0] == 1 if not overflow: offset = 3 diff --git a/vllm/envs.py b/vllm/envs.py index 02fcd998a031..598545d23cd1 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -179,7 +179,6 @@ VLLM_MOONCAKE_BOOTSTRAP_PORT: int = 8998 VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 - VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300 VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None @@ -1338,9 +1337,6 @@ def _get_or_set_default() -> str: "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": lambda: int( os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1") ), - # Reduce CPU usage when vLLM is idle. Enabling this will incur small - # latency penalty when a request eventually comes. - "VLLM_SLEEP_WHEN_IDLE": lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))), # Control the max chunk bytes (in MB) for the rpc message queue. # Object larger than this threshold will be broadcast to worker # processes via zmq. @@ -1751,7 +1747,6 @@ def compile_factors() -> dict[str, object]: "VLLM_HTTP_TIMEOUT_KEEP_ALIVE", "VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS", "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", - "VLLM_SLEEP_WHEN_IDLE", "VLLM_IMAGE_FETCH_TIMEOUT", "VLLM_VIDEO_FETCH_TIMEOUT", "VLLM_AUDIO_FETCH_TIMEOUT", diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index e3376ba2d1f7..ec215d8e525b 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -104,7 +104,6 @@ def _init_executor(self) -> None: # and ensure workers will be terminated. self._finalizer = weakref.finalize(self, self.shutdown) self.is_failed = False - self.shutdown_event = threading.Event() self.failure_callback: FailureCallback | None = None tp_size, pp_size, pcp_size = self._get_parallel_sizes() @@ -158,20 +157,31 @@ def _init_executor(self) -> None: global_start_rank = ( self.local_world_size * self.parallel_config.node_rank_within_dp ) + # Keep track of socket file descriptors that are inherited by the + # worker when using fork, so that we can close them in subsequent + # workers + inherited_fds: list[int] = [] for local_rank in range(self.local_world_size): global_rank = global_start_rank + local_rank is_driver_worker = self._is_driver_worker(global_rank) - unready_workers.append( - WorkerProc.make_worker_process( - vllm_config=self.vllm_config, - local_rank=local_rank, - rank=global_rank, - distributed_init_method=distributed_init_method, - input_shm_handle=scheduler_output_handle, - shared_worker_lock=shared_worker_lock, - is_driver_worker=is_driver_worker, - ) + unready_worker_handle = WorkerProc.make_worker_process( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=global_rank, + distributed_init_method=distributed_init_method, + input_shm_handle=scheduler_output_handle, + shared_worker_lock=shared_worker_lock, + is_driver_worker=is_driver_worker, + inherited_fds=inherited_fds, ) + unready_workers.append(unready_worker_handle) + if context.get_start_method() == "fork": + inherited_fds.extend( + [ + unready_worker_handle.death_writer.fileno(), + unready_worker_handle.ready_pipe.fileno(), + ] + ) # Workers must be created before wait_for_ready to avoid # deadlock, since worker.init_device() does a device sync. @@ -220,6 +230,7 @@ def _init_executor(self) -> None: for uw in unready_workers: if uw.death_writer is not None: uw.death_writer.close() + uw.death_writer = None self._ensure_worker_termination([uw.proc for uw in unready_workers]) self.output_rank = self._get_output_rank() @@ -255,6 +266,7 @@ def monitor_workers(): died = multiprocessing.connection.wait(sentinels) _self = self_ref() if not _self or getattr(_self, "shutting_down", False): + logger.debug("MultiprocWorkerMonitor: shutdown already initiated") return _self.is_failed = True proc_name = next(h.proc.name for h in workers if h.proc.sentinel == died[0]) @@ -354,8 +366,6 @@ def collective_rpc( # type: ignore[override] if output_rank is not None: response_mqs = (response_mqs[output_rank],) - shutdown_event = self.shutdown_event - def get_response(): responses = [] for mq in response_mqs: @@ -363,9 +373,7 @@ def get_response(): None if deadline is None else (deadline - time.monotonic()) ) try: - status, result = mq.dequeue( - timeout=dequeue_timeout, cancel=shutdown_event - ) + status, result = mq.dequeue(timeout=dequeue_timeout) except TimeoutError as e: raise TimeoutError(f"RPC call to {method} timed out.") from e if status != WorkerProc.ResponseStatus.SUCCESS: @@ -408,20 +416,26 @@ def wait_for_termination(procs, timeout): active_procs = lambda: [proc for proc in worker_procs if proc.is_alive()] # Give processes time to clean themselves up properly first + logger.debug("Worker Termination: allow workers to gracefully shutdown") if wait_for_termination(active_procs(), 4): return # Send SIGTERM if still running + logger.debug("Worker Termination: workers still running sending SIGTERM") for p in active_procs(): p.terminate() if not wait_for_termination(active_procs(), 4): # Send SIGKILL if still running + logger.debug( + "Worker Termination: resorting to SIGKILL to take down workers" + ) for p in active_procs(): p.kill() def shutdown(self): """Properly shut down the executor and its workers""" if not getattr(self, "shutting_down", False): + logger.debug("Triggering shutdown of workers") self.shutting_down = True # Make sure all the worker processes are terminated first. @@ -431,12 +445,20 @@ def shutdown(self): if w.death_writer is not None: w.death_writer.close() w.death_writer = None - w.worker_response_mq = None self._ensure_worker_termination([w.proc for w in workers]) - self.shutdown_event.set() - - self.rpc_broadcast_mq = None + for w in workers: + # Shutdown response queues + if w.worker_response_mq is not None: + w.worker_response_mq.shutdown() + w.worker_response_mq = None + + if self.rpc_broadcast_mq is not None: + self.rpc_broadcast_mq.shutdown() + self.rpc_broadcast_mq = None + for mq in self.response_mqs: + mq.shutdown() + self.response_mqs = [] def check_health(self) -> None: self.collective_rpc("check_health", timeout=10) @@ -609,24 +631,26 @@ def make_worker_process( input_shm_handle, # Receive SchedulerOutput shared_worker_lock: LockType, is_driver_worker: bool, + inherited_fds: list[int], ) -> UnreadyWorkerProcHandle: context = get_mp_context() - # (reader, writer) - reader, writer = context.Pipe(duplex=False) - - # Create death pipe to detect parent process exit + # Ready pipe to communicate readiness from child to parent + ready_reader, ready_writer = context.Pipe(duplex=False) + # Death pipe to let child detect parent process exit death_reader, death_writer = context.Pipe(duplex=False) - process_kwargs = { "vllm_config": vllm_config, "local_rank": local_rank, "rank": rank, "distributed_init_method": distributed_init_method, "input_shm_handle": input_shm_handle, - "ready_pipe": (reader, writer), + "ready_pipe": ready_writer, "death_pipe": death_reader, "shared_worker_lock": shared_worker_lock, "is_driver_worker": is_driver_worker, + # Have the worker close parent end of this worker's pipes too + "inherited_fds": inherited_fds + + [ready_reader.fileno(), death_writer.fileno()], } # Run EngineCore busy loop in background process. proc = context.Process( @@ -637,10 +661,12 @@ def make_worker_process( ) proc.start() - writer.close() + # Close child ends of pipes here in the parent + ready_writer.close() + death_reader.close() # Keep death_writer open in parent - when parent exits, # death_reader in child will get EOFError - return UnreadyWorkerProcHandle(proc, rank, reader, death_writer) + return UnreadyWorkerProcHandle(proc, rank, ready_reader, death_writer) @staticmethod def wait_for_response_handle_ready( @@ -703,12 +729,41 @@ def wait_for_ready( return cast(list[WorkerProcHandle], ready_proc_handles) def shutdown(self): + if self.rpc_broadcast_mq is not None: + self.rpc_broadcast_mq.shutdown() + if self.worker_response_mq is not None: + self.worker_response_mq.shutdown() self.worker.shutdown() self.rpc_broadcast_mq = None self.worker_response_mq = None destroy_model_parallel() destroy_distributed_environment() + def monitor_death_pipe(self, death_pipe, shutdown_requested: threading.Event): + if death_pipe is None: + return + + def death_pipe_monitor(queues_to_shutdown: list[MessageQueue]): + try: + # This will block until parent process exits (pipe closes) + death_pipe.recv() + except EOFError: + logger.info_once("Parent process exited, terminating worker queues") + shutdown_requested.set() + for mq in queues_to_shutdown: + if mq is not None: + mq.shutdown() + except Exception as e: + logger.warning("Death monitoring error: %s", e) + + # Pass queue references directly to avoid gc issues if passing self + Thread( + target=death_pipe_monitor, + args=([self.rpc_broadcast_mq, self.worker_response_mq],), + daemon=True, + name="DeathPipeMonitor", + ).start() + @staticmethod def worker_main(*args, **kwargs): """Worker initialization and execution loops. @@ -717,12 +772,12 @@ def worker_main(*args, **kwargs): # Signal handler used for graceful termination. # SystemExit exception is only raised once to allow this and worker # processes to terminate without error - shutdown_requested = False + shutdown_requested = threading.Event() def signal_handler(signum, frame): nonlocal shutdown_requested - if not shutdown_requested: - shutdown_requested = True + if not shutdown_requested.is_set(): + shutdown_requested.set() logger.debug( "WorkerProc handling signal %d, raising SystemExit", signum ) @@ -733,33 +788,20 @@ def signal_handler(signum, frame): signal.signal(signal.SIGINT, signal_handler) worker = None - # tuple[Connection, Connection] - reader, ready_writer = kwargs.pop("ready_pipe") - death_pipe: Connection | None = kwargs.pop("death_pipe", None) - shutdown_event = threading.Event() - # Start death monitoring thread if death_pipe is provided - if death_pipe is not None: - - def monitor_parent_death(): - try: - # This will block until parent process exits (pipe closes) - death_pipe.recv() - except EOFError: - # Parent process has exited, terminate this worker - logger.info_once("Parent process exited, terminating worker") - # Send signal to self to trigger clean shutdown - shutdown_event.set() - except Exception as e: - logger.warning("Death monitoring error: %s", e) - - death_monitor = Thread( - target=monitor_parent_death, daemon=True, name="WorkerDeathMonitor" - ) - death_monitor.start() + ready_writer = kwargs.pop("ready_pipe") + death_pipe = kwargs.pop("death_pipe", None) + + # Close inherited pipes from parent (incl. other worker pipes) + # Explicitly passing in existing pipes and closing them makes the pipe + # behave when using fork. Otherwise, a hidden reference to the pipes + # exist in the child process and prevents EOF closure. + for fd in kwargs.pop("inherited_fds", []): + try: + os.close(fd) + except Exception as e: + logger.warning("Exception closing inherited connection: %s", e) try: - reader.close() - # Initialize tracer rank = kwargs.get("rank", 0) maybe_init_worker_tracer( @@ -771,6 +813,8 @@ def monitor_parent_death(): worker = WorkerProc(*args, **kwargs) assert worker.worker_response_mq is not None + worker.monitor_death_pipe(death_pipe, shutdown_requested) + # Send READY once we know everything is loaded ready_writer.send( { @@ -788,7 +832,7 @@ def monitor_parent_death(): ready_writer.close() ready_writer = None - worker.worker_busy_loop(cancel=shutdown_event) + worker.worker_busy_loop() except Exception: # NOTE: if an Exception arises in busy_loop, we send @@ -798,7 +842,7 @@ def monitor_parent_death(): if ready_writer is not None: logger.exception("WorkerProc failed to start.") - elif shutdown_event.is_set(): + elif shutdown_requested.is_set(): logger.info("WorkerProc shutting down.") else: logger.exception("WorkerProc failed.") @@ -806,7 +850,7 @@ def monitor_parent_death(): # The parent sends a SIGTERM to all worker processes if # any worker dies. Set this value so we don't re-throw # SystemExit() to avoid zmq exceptions in __del__. - shutdown_requested = True + shutdown_requested.set() except SystemExit as e: # SystemExit is raised on SIGTERM or SIGKILL, which usually indicates that @@ -859,12 +903,12 @@ def async_output_busy_loop(self): output = self.async_output_queue.get() self.enqueue_output(output) - def worker_busy_loop(self, cancel: threading.Event | None = None): + def worker_busy_loop(self): """Main busy loop for Multiprocessing Workers""" assert self.rpc_broadcast_mq is not None while True: method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue( - cancel=cancel, indefinite=True + indefinite=True ) try: if isinstance(method, str):