Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions tests/basic_correctness/test_basic_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ def test_models(
[
("facebook/opt-125m", "ray", "", "L4", {}),
("facebook/opt-125m", "mp", "", "L4", {}),
("facebook/opt-125m", "ray", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
("facebook/opt-125m", "mp", "", "L4", {"VLLM_SLEEP_WHEN_IDLE": "1"}),
("meta-llama/Llama-3.2-1B-Instruct", "ray", "", "L4", {}),
("meta-llama/Llama-3.2-1B-Instruct", "mp", "", "L4", {}),
("facebook/opt-125m", "ray", "", "A100", {}),
Expand Down
85 changes: 42 additions & 43 deletions vllm/distributed/device_communicators/shm_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from torch.distributed import ProcessGroup
from zmq import ( # type: ignore
IPV6, # type: ignore
PUB,
SUB,
SUBSCRIBE,
XPUB,
Expand Down Expand Up @@ -59,42 +60,6 @@ def long_wait_time_msg(threshold: int) -> str:
)


class SpinTimer:
    """Default spin strategy: always yield the CPU, never sleep.

    Used while polling a shared-memory ring buffer; each call to
    ``spin()`` releases the remainder of the current time slice so
    other runnable threads can make progress.
    """

    def record_activity(self):
        """No-op; the plain timer keeps no activity state."""

    def spin(self):
        """Yield the processor to other runnable threads."""
        sched_yield()


class SpinSleepTimer(SpinTimer):
    """
    In setups which have long inactivity periods it is desirable to reduce
    system power consumption when vllm does nothing. This would lead to more
    CPU thermal headroom when a request eventually comes, especially when
    multiple GPUs are connected as each GPU would otherwise pin one thread at
    100% CPU usage.

    The simplest solution is to reduce polling frequency when there is no
    activity for a certain period of time.
    """

    def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
        """Busy-spin for ``busy_loop_s`` after the last activity, then back
        off by sleeping ``wait_sleep_s`` per spin."""
        self.last_activity = time.monotonic()
        self.busy_loop_s = busy_loop_s
        self.wait_sleep_s = wait_sleep_s

    def record_activity(self):
        """Mark now as the most recent activity, restarting the busy window."""
        self.last_activity = time.monotonic()

    def spin(self):
        """Yield while within the busy window; sleep once it has elapsed."""
        idle_deadline = self.last_activity + self.busy_loop_s
        if time.monotonic() < idle_deadline:
            # Still in the busy-loop window: stay responsive, just yield.
            sched_yield()
        else:
            # Idle for a while: trade a little latency for much less CPU.
            time.sleep(self.wait_sleep_s)


class ShmRingBuffer:
def __init__(
self,
Expand Down Expand Up @@ -236,6 +201,7 @@ class Handle:

buffer_handle: tuple[int, int, int, str] | None = None
local_subscribe_addr: str | None = None
local_tick_addr: str | None = None
remote_subscribe_addr: str | None = None
remote_addr_ipv6: bool = False

Expand Down Expand Up @@ -280,11 +246,20 @@ def __init__(
logger.debug("Binding to %s", local_subscribe_addr)
self.local_socket.bind(local_subscribe_addr)

self.local_tick_socket = context.socket(PUB)
local_tick_addr = get_open_zmq_ipc_path()
self.local_tick_socket.bind(local_tick_addr)
logger.debug("Binding to %s", local_tick_addr)
self.last_activity = 0.0

self.current_idx = 0
else:
self.buffer = None # type: ignore
local_subscribe_addr = None
self.local_socket = None
local_tick_addr = None
self.local_tick_socket = None
self.last_activity = 0.0
self.current_idx = -1

remote_addr_ipv6 = False
Expand Down Expand Up @@ -312,12 +287,12 @@ def __init__(
self.local_reader_rank = -1
# rank does not matter for remote readers
self._is_remote_reader = False
self._read_spin_timer = SpinTimer()

self.handle = Handle(
local_reader_ranks=local_reader_ranks,
buffer_handle=self.buffer.handle() if self.buffer is not None else None,
local_subscribe_addr=local_subscribe_addr,
local_tick_addr=local_tick_addr,
remote_subscribe_addr=remote_subscribe_addr,
remote_addr_ipv6=remote_addr_ipv6,
)
Expand Down Expand Up @@ -349,11 +324,15 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
logger.debug("Connecting to %s", socket_addr)
self.local_socket.connect(socket_addr)

self.local_tick_socket = context.socket(SUB)
self.local_tick_socket.setsockopt_string(SUBSCRIBE, "")
socket_addr = handle.local_tick_addr
logger.debug("Connecting to %s", socket_addr)
self.local_tick_socket.connect(socket_addr)
self.last_activity = 0.0

self.remote_socket = None

self._read_spin_timer = (
SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
)
else:
self.buffer = None # type: ignore
self.current_idx = -1
Expand All @@ -362,6 +341,8 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
self._is_remote_reader = True

self.local_socket = None
self.local_tick_socket = None
self.last_activity = 0.0

self.remote_socket = context.socket(SUB)
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
Expand Down Expand Up @@ -483,8 +464,25 @@ def acquire_read(
# if this block is not ready,
# we need to wait until it is written

# Release the processor to other threads
self._read_spin_timer.spin()
if time.monotonic() - self.last_activity > 0.1:
poll_timeout_ms = 50
if timeout is not None:
remaining = start_time + timeout - time.monotonic()
poll_timeout_ms = min(
int(remaining * 1000), poll_timeout_ms
)
if poll_timeout_ms > 0:
ev = self.local_tick_socket.poll(timeout=poll_timeout_ms)
if ev > 0:
while True:
try:
self.local_tick_socket.recv(
flags=zmq.NOBLOCK, copy=False
)
except zmq.Again:
break
else:
sched_yield()

if cancel is not None and cancel.is_set():
raise RuntimeError("cancelled")
Expand Down Expand Up @@ -514,7 +512,7 @@ def acquire_read(
metadata_buffer[self.local_reader_rank + 1] = 1
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks

self._read_spin_timer.record_activity()
self.last_activity = time.monotonic()
break

def enqueue(self, obj, timeout: float | None = None):
Expand Down Expand Up @@ -557,6 +555,7 @@ def oob_callback(buf: PickleBuffer) -> bool:
buf[offset:buf_offset] = to_bytes_big(buf_len, 4)
buf[buf_offset : (offset := buf_offset + buf_len)] = buffer

self.local_tick_socket.send(b"\x00")
if self.n_remote_reader > 0:
self.remote_socket.send_multipart(all_buffers, copy=False)

Expand Down
4 changes: 0 additions & 4 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@
] = "allgather_reducescatter"
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
VLLM_SLEEP_WHEN_IDLE: bool = False
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300
VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None
Expand Down Expand Up @@ -1325,9 +1324,6 @@ def get_vllm_port() -> int | None:
"VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS": lambda: int(
os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1")
),
# Reduce CPU usage when vLLM is idle. Enabling this will incur small
# latency penalty when a request eventually comes.
"VLLM_SLEEP_WHEN_IDLE": lambda: bool(int(os.getenv("VLLM_SLEEP_WHEN_IDLE", "0"))),
# Control the max chunk bytes (in MB) for the rpc message queue.
# Object larger than this threshold will be broadcast to worker
# processes via zmq.
Expand Down