diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 521d6c33dd39..3f4693a90dee 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -124,8 +124,6 @@ def test_models( [ ("facebook/opt-125m", "ray", "", "L4", {}), ("facebook/opt-125m", "mp", "", "L4", {}), - ("facebook/opt-125m", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), - ("facebook/opt-125m", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}), ("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}), ("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}), ("facebook/opt-125m", "ray", "", "A100", {}), diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 052df19e34d7..dc67d0a41315 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -17,6 +17,7 @@ from torch.distributed import ProcessGroup from zmq import ( # type: ignore IPV6, # type: ignore + PUB, SUB, SUBSCRIBE, XPUB, @@ -59,42 +60,6 @@ def long_wait_time_msg(threshold: int) -> str: ) -class SpinTimer: - def record_activity(self): - pass - - def spin(self): - sched_yield() - - -class SpinSleepTimer(SpinTimer): - """ - In setups which have long inactivity periods it is desirable to reduce - system power consumption when vllm does nothing. This would lead to more - CPU thermal headroom when a request eventually comes, especially when - multiple GPUs are connected as each GPU would otherwise pin one thread at - 100% CPU usage. - - The simplest solution is to reduce polling frequency when there is no - activity for a certain period of time. - """ - - def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1): - self.last_activity = time.monotonic() - self.busy_loop_s = busy_loop_s - self.wait_sleep_s = wait_sleep_s - - def record_activity(self): - self.last_activity = time.monotonic() - - def spin(self): - curr_time = time.monotonic() - if curr_time >= self.last_activity + self.busy_loop_s: - time.sleep(self.wait_sleep_s) - else: - sched_yield() - - class ShmRingBuffer: def __init__( self, @@ -236,6 +201,7 @@ class Handle: buffer_handle: tuple[int, int, int, str] | None = None local_subscribe_addr: str | None = None + local_tick_addr: str | None = None remote_subscribe_addr: str | None = None remote_addr_ipv6: bool = False @@ -280,11 +246,20 @@ def __init__( logger.debug("Binding to %s", local_subscribe_addr) self.local_socket.bind(local_subscribe_addr) + self.local_tick_socket = context.socket(PUB) + local_tick_addr = get_open_zmq_ipc_path() + self.local_tick_socket.bind(local_tick_addr) + logger.debug("Binding to %s", local_tick_addr) + self.last_activity = 0.0 + self.current_idx = 0 else: self.buffer = None # type: ignore local_subscribe_addr = None self.local_socket = None + local_tick_addr = None + self.local_tick_socket = None + self.last_activity = 0.0 self.current_idx = -1 remote_addr_ipv6 = False @@ -312,12 +287,12 @@ def __init__( self.local_reader_rank = -1 # rank does not matter for remote readers self._is_remote_reader = False - self._read_spin_timer = SpinTimer() self.handle = Handle( local_reader_ranks=local_reader_ranks, buffer_handle=self.buffer.handle() if self.buffer is not None else None, local_subscribe_addr=local_subscribe_addr, + local_tick_addr=local_tick_addr, remote_subscribe_addr=remote_subscribe_addr, remote_addr_ipv6=remote_addr_ipv6, ) @@ -349,11 +324,15 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": logger.debug("Connecting to %s", socket_addr) self.local_socket.connect(socket_addr) + self.local_tick_socket = context.socket(SUB) + self.local_tick_socket.setsockopt_string(SUBSCRIBE, "") + socket_addr = handle.local_tick_addr + logger.debug("Connecting to %s", socket_addr) + self.local_tick_socket.connect(socket_addr) + self.last_activity = 0.0 + self.remote_socket = None - self._read_spin_timer = ( - SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer() - ) else: self.buffer = None # type: ignore self.current_idx = -1 @@ -362,6 +341,8 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self._is_remote_reader = True self.local_socket = None + self.local_tick_socket = None + self.last_activity = 0.0 self.remote_socket = context.socket(SUB) self.remote_socket.setsockopt_string(SUBSCRIBE, "") @@ -483,8 +464,25 @@ def acquire_read( # if this block is not ready, # we need to wait until it is written - # Release the processor to other threads - self._read_spin_timer.spin() + if time.monotonic() - self.last_activity > 0.1: + poll_timeout_ms = 50 + if timeout is not None: + remaining = start_time + timeout - time.monotonic() + poll_timeout_ms = min( + int(remaining * 1000), poll_timeout_ms + ) + if poll_timeout_ms > 0: + ev = self.local_tick_socket.poll(timeout=poll_timeout_ms) + if ev > 0: + while True: + try: + self.local_tick_socket.recv( + flags=zmq.NOBLOCK, copy=False + ) + except zmq.Again: + break + else: + sched_yield() if cancel is not None and cancel.is_set(): raise RuntimeError("cancelled") @@ -514,7 +512,7 @@ def acquire_read( metadata_buffer[self.local_reader_rank + 1] = 1 self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks - self._read_spin_timer.record_activity() + self.last_activity = time.monotonic() break def enqueue(self, obj, timeout: float | None = None): @@ -557,6 +555,7 @@ def oob_callback(buf: PickleBuffer) -> bool: buf[offset:buf_offset] = to_bytes_big(buf_len, 4) buf[buf_offset : (offset := buf_offset + buf_len)] = buffer + self.local_tick_socket.send(b"\x00") if self.n_remote_reader > 0: self.remote_socket.send_multipart(all_buffers, copy=False) diff --git a/vllm/envs.py b/vllm/envs.py index 8b954fa14f28..d99a1bd47c3b 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -185,7 +185,6 @@ ] = "allgather_reducescatter" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 - VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300 VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None @@ -1325,9 +1324,6 @@ def get_vllm_port() -> int | None: "VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": lambda: int( os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1") ), - # Reduce CPU usage when vLLM is idle. Enabling this will incur small - # latency penalty when a request eventually comes. - "VLLM_SLEEP_WHEN_IDLE": lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))), # Control the max chunk bytes (in MB) for the rpc message queue. # Object larger than this threshold will be broadcast to worker # processes via zmq.