-
-
Notifications
You must be signed in to change notification settings - Fork 17.9k
Fix DP coordinator ZMQ port TOCTOU #37452
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| import copy | ||
| import multiprocessing | ||
| import multiprocessing.connection | ||
| import time | ||
| import weakref | ||
|
|
||
|
|
@@ -10,7 +11,7 @@ | |
|
|
||
| from vllm.config import ParallelConfig | ||
| from vllm.logger import init_logger | ||
| from vllm.utils.network_utils import make_zmq_socket | ||
| from vllm.utils.network_utils import get_tcp_uri, make_zmq_socket | ||
| from vllm.utils.system_utils import get_mp_context, set_process_title | ||
| from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType | ||
| from vllm.v1.serial_utils import MsgpackDecoder | ||
|
|
@@ -55,6 +56,25 @@ class DPCoordinator: | |
| request wave / running state changes. | ||
| """ | ||
|
|
||
| def _wait_for_zmq_addrs(self, zmq_addr_pipe) -> tuple[str, str, str]: | ||
| try: | ||
| ready = multiprocessing.connection.wait( | ||
| [zmq_addr_pipe, self.proc.sentinel], timeout=30 | ||
| ) | ||
| if not ready: | ||
| raise RuntimeError( | ||
| "DP Coordinator process failed to report ZMQ addresses " | ||
| "during startup." | ||
| ) | ||
| try: | ||
| return zmq_addr_pipe.recv() | ||
| except EOFError: | ||
| raise RuntimeError( | ||
| "DP Coordinator process failed during startup." | ||
| ) from None | ||
|
Comment on lines
+61
to
+74
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is fatal IMO; if the DP Coordinator cannot report ZMQ addresses within 30 seconds, it is reasonable to fail. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @itayalroy @tlrmchlsmth can we make the timeout configurable? The current 30s limit can be too short, for example when
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same issue. |
||
| finally: | ||
| zmq_addr_pipe.close() | ||
|
|
||
| def __init__( | ||
| self, parallel_config: ParallelConfig, enable_wave_coordination: bool = True | ||
| ): | ||
|
|
@@ -66,18 +86,24 @@ def __init__( | |
| # Assume coordinator is colocated with front-end procs when not in | ||
| # either external or hybrid DP LB mode. | ||
| local_only = not parallel_config.local_engines_only | ||
| front_publish_address = get_engine_client_zmq_addr( | ||
| local_only=local_only, host=host | ||
| ) | ||
|
|
||
| local_only_eng = dp_size == parallel_config.data_parallel_size_local | ||
| # NOTE(yongji): handling scaling from intra-node to inter-node | ||
| if parallel_config.enable_elastic_ep: | ||
| local_only_eng = False | ||
| back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) | ||
| back_output_address = get_engine_client_zmq_addr(local_only_eng, host) | ||
|
|
||
| def bind_address(local_only: bool) -> str: | ||
| return ( | ||
| get_engine_client_zmq_addr(local_only=True, host=host) | ||
| if local_only | ||
| else get_tcp_uri(host, 0) | ||
| ) | ||
|
Comment on lines
+94
to
+99
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The inconsistency in address types (IPC/TCP) already exists; the only change is that with TCP we now let the OS assign the port at bind time instead of binding to a pre-chosen port that might already be taken |
||
|
|
||
| front_publish_address = bind_address(local_only) | ||
| back_publish_address = bind_address(local_only_eng) | ||
| back_output_address = bind_address(local_only_eng) | ||
|
|
||
| context = get_mp_context() | ||
| parent_zmq_addr_pipe, child_zmq_addr_pipe = context.Pipe(duplex=False) | ||
| self.proc: multiprocessing.Process = context.Process( | ||
| target=DPCoordinatorProc.run_coordinator, | ||
| name="VLLM_DP_Coordinator", | ||
|
|
@@ -86,11 +112,18 @@ def __init__( | |
| "front_publish_address": front_publish_address, | ||
| "back_output_address": back_output_address, | ||
| "back_publish_address": back_publish_address, | ||
| "zmq_addr_pipe": child_zmq_addr_pipe, | ||
| "enable_wave_coordination": enable_wave_coordination, | ||
| }, | ||
| daemon=True, | ||
| ) | ||
| self.proc.start() | ||
| child_zmq_addr_pipe.close() | ||
| ( | ||
| front_publish_address, | ||
| back_output_address, | ||
| back_publish_address, | ||
| ) = self._wait_for_zmq_addrs(parent_zmq_addr_pipe) | ||
|
Comment on lines
+121
to
+126
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After starting the coordinator process, the parent process retrieves the bound ZMQ addresses using
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No. If |
||
|
|
||
| self.stats_publish_address = front_publish_address | ||
| self.coord_in_address = back_publish_address | ||
|
|
@@ -136,6 +169,7 @@ def run_coordinator( | |
| front_publish_address: str, | ||
| back_output_address: str, | ||
| back_publish_address: str, | ||
| zmq_addr_pipe=None, | ||
| min_stats_update_interval_ms: int = 100, | ||
| enable_wave_coordination: bool = True, | ||
| ): | ||
|
|
@@ -149,15 +183,20 @@ def run_coordinator( | |
| front_publish_address, | ||
| back_output_address, | ||
| back_publish_address, | ||
| zmq_addr_pipe, | ||
| ) | ||
| except KeyboardInterrupt: | ||
| logger.info("DP Coordinator process exiting") | ||
| finally: | ||
| if zmq_addr_pipe is not None: | ||
| zmq_addr_pipe.close() | ||
|
|
||
| def process_input_socket( | ||
| self, | ||
| front_publish_address: str, | ||
| back_output_address: str, | ||
| back_publish_address: str, | ||
| zmq_addr_pipe=None, | ||
| ): | ||
| decoder = MsgpackDecoder(EngineCoreOutputs) | ||
|
|
||
|
|
@@ -191,6 +230,17 @@ def process_input_socket( | |
| bind=True, | ||
| ) as publish_back, | ||
| ): | ||
| if zmq_addr_pipe is not None: | ||
| try: | ||
| zmq_addr_pipe.send( | ||
| ( | ||
| publish_front.getsockopt(zmq.LAST_ENDPOINT).decode(), | ||
| output_back.getsockopt(zmq.LAST_ENDPOINT).decode(), | ||
| publish_back.getsockopt(zmq.LAST_ENDPOINT).decode(), | ||
| ) | ||
| ) | ||
| finally: | ||
| zmq_addr_pipe.close() | ||
| # Wait until all engines subscribe. | ||
| for _ in self.engines: | ||
| if publish_back.recv() != b"\x01": | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.