Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/utils/network_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def split_zmq_path(path: str) -> tuple[str, str, str]:

scheme = parsed.scheme
host = parsed.hostname or ""
port = str(parsed.port or "")
port = "" if parsed.port is None else str(parsed.port)
if host.startswith("[") and host.endswith("]"):
host = host[1:-1] # Remove brackets for IPv6 address

Expand Down
64 changes: 57 additions & 7 deletions vllm/v1/engine/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import multiprocessing
import multiprocessing.connection
import time
import weakref

Expand All @@ -10,7 +11,7 @@

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.network_utils import get_tcp_uri, make_zmq_socket
from vllm.utils.system_utils import get_mp_context, set_process_title
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
from vllm.v1.serial_utils import MsgpackDecoder
Expand Down Expand Up @@ -55,6 +56,25 @@ class DPCoordinator:
request wave / running state changes.
"""

def _wait_for_zmq_addrs(self, zmq_addr_pipe) -> tuple[str, str, str]:
try:
ready = multiprocessing.connection.wait(
Comment thread
itayalroy marked this conversation as resolved.
[zmq_addr_pipe, self.proc.sentinel], timeout=30
)
if not ready:
raise RuntimeError(
"DP Coordinator process failed to report ZMQ addresses "
"during startup."
)
try:
return zmq_addr_pipe.recv()
except EOFError:
raise RuntimeError(
"DP Coordinator process failed during startup."
) from None
Comment on lines +61 to +74

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The _wait_for_zmq_addrs method includes a timeout of 30 seconds. If the DP Coordinator process fails to report ZMQ addresses within this time, a RuntimeError is raised. However, there's no mechanism to handle or retry this failure. Consider adding a retry mechanism or a more robust error handling strategy to improve the resilience of the system. This is a critical issue because a failure here will prevent the engine from starting up.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fatal IMO, if the DP Coordinator cannot report ZMQ addresses within 30 seconds it is reasonable to fail

@zch42 zch42 May 6, 2026

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@itayalroy @tlrmchlsmth can we make the timeout configurable? The current 30s limit can be too short, for example when spawn is forced, the child process will re-import many modules

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same issue. from vllm.v1.engine import coordinator takes 70+ seconds to import.

finally:
zmq_addr_pipe.close()

def __init__(
self, parallel_config: ParallelConfig, enable_wave_coordination: bool = True
):
Expand All @@ -66,18 +86,24 @@ def __init__(
# Assume coordinator is colocated with front-end procs when not in
# either external or hybrid DP LB mode.
local_only = not parallel_config.local_engines_only
front_publish_address = get_engine_client_zmq_addr(
local_only=local_only, host=host
)

local_only_eng = dp_size == parallel_config.data_parallel_size_local
# NOTE(yongji): handling scaling from intra-node to inter-node
if parallel_config.enable_elastic_ep:
local_only_eng = False
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)

def bind_address(local_only: bool) -> str:
return (
get_engine_client_zmq_addr(local_only=True, host=host)
if local_only
else get_tcp_uri(host, 0)
)
Comment on lines +94 to +99

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The bind_address function uses get_engine_client_zmq_addr when local_only is true, which returns an IPC path. However, when local_only is false, it uses get_tcp_uri with port 0, which requests the OS to assign a port. This inconsistency in address types (IPC vs. TCP) could lead to unexpected behavior or configuration issues. Ensure that the address type is consistent based on the deployment environment or configuration. This is a high severity issue because it can lead to connectivity problems.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The inconsistency in address types (IPC/TCP) already exists, the only change is that with TCP we now let the OS assign the port on bind time instead of binding to a pre-chosen port that might be already taken


front_publish_address = bind_address(local_only)
back_publish_address = bind_address(local_only_eng)
back_output_address = bind_address(local_only_eng)

context = get_mp_context()
parent_zmq_addr_pipe, child_zmq_addr_pipe = context.Pipe(duplex=False)
self.proc: multiprocessing.Process = context.Process(
target=DPCoordinatorProc.run_coordinator,
name="VLLM_DP_Coordinator",
Expand All @@ -86,11 +112,18 @@ def __init__(
"front_publish_address": front_publish_address,
"back_output_address": back_output_address,
"back_publish_address": back_publish_address,
"zmq_addr_pipe": child_zmq_addr_pipe,
"enable_wave_coordination": enable_wave_coordination,
},
daemon=True,
)
self.proc.start()
child_zmq_addr_pipe.close()
(
front_publish_address,
back_output_address,
back_publish_address,
) = self._wait_for_zmq_addrs(parent_zmq_addr_pipe)
Comment on lines +121 to +126

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

After starting the coordinator process, the parent process retrieves the bound ZMQ addresses using self._wait_for_zmq_addrs. However, if self._wait_for_zmq_addrs fails, the addresses used to initialize self.stats_publish_address and self.coord_in_address will be the original, unbound addresses. This could lead to the parent process attempting to communicate with the coordinator on the wrong ports. This is a critical issue because it can lead to communication failures between the parent and coordinator processes.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. If _wait_for_zmq_addrs() fails, we raise an exception, so we never proceed using wrong ports.


self.stats_publish_address = front_publish_address
self.coord_in_address = back_publish_address
Expand Down Expand Up @@ -136,6 +169,7 @@ def run_coordinator(
front_publish_address: str,
back_output_address: str,
back_publish_address: str,
zmq_addr_pipe=None,
min_stats_update_interval_ms: int = 100,
enable_wave_coordination: bool = True,
):
Expand All @@ -149,15 +183,20 @@ def run_coordinator(
front_publish_address,
back_output_address,
back_publish_address,
zmq_addr_pipe,
)
except KeyboardInterrupt:
logger.info("DP Coordinator process exiting")
finally:
if zmq_addr_pipe is not None:
zmq_addr_pipe.close()

def process_input_socket(
self,
front_publish_address: str,
back_output_address: str,
back_publish_address: str,
zmq_addr_pipe=None,
):
decoder = MsgpackDecoder(EngineCoreOutputs)

Expand Down Expand Up @@ -191,6 +230,17 @@ def process_input_socket(
bind=True,
) as publish_back,
):
if zmq_addr_pipe is not None:
try:
zmq_addr_pipe.send(
(
publish_front.getsockopt(zmq.LAST_ENDPOINT).decode(),
output_back.getsockopt(zmq.LAST_ENDPOINT).decode(),
publish_back.getsockopt(zmq.LAST_ENDPOINT).decode(),
)
)
finally:
zmq_addr_pipe.close()
# Wait until all engines subscribe.
for _ in self.engines:
if publish_back.recv() != b"\x01":
Expand Down
Loading