-
-
Notifications
You must be signed in to change notification settings - Fork 18.9k
[BugFix] Use late binding to avoid zmq port conflict race conditions #30520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7c6cf81
78ab95c
eceb313
3e98439
19b52ee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,7 @@ | |
| Iterator, | ||
| Sequence, | ||
| ) | ||
| from typing import Any | ||
| from typing import Any, Literal, overload | ||
| from urllib.parse import urlparse | ||
| from uuid import uuid4 | ||
|
|
||
|
|
@@ -170,7 +170,7 @@ def get_open_ports_list(count: int = 5) -> list[int]: | |
| """Get a list of open ports.""" | ||
| ports = set[int]() | ||
| while len(ports) < count: | ||
| ports.add(get_open_port()) | ||
| ports.add(_get_open_port()) | ||
| return list(ports) | ||
|
|
||
|
|
||
|
|
@@ -254,17 +254,130 @@ def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str: | |
| return f"{scheme}://{host}:{port}" | ||
|
|
||
|
|
||
| def is_wildcard_addr(addr: str) -> bool: | ||
| """Check if an address is a TCP address with wildcard port requiring late binding. | ||
|
|
||
| A wildcard port address has port 0, which tells the OS to assign an available | ||
| port. The host can be specific (e.g., "tcp://192.168.1.5:0") or wildcard | ||
| (e.g., "tcp://*:0"). | ||
|
|
||
| Args: | ||
| addr: Address string to check | ||
|
|
||
| Returns: | ||
| True if the address is a TCP address with wildcard port (:0) | ||
|
|
||
| Examples: | ||
| >>> is_wildcard_addr("tcp://*:0") | ||
| True | ||
| >>> is_wildcard_addr("tcp://192.168.1.5:0") | ||
| True | ||
| >>> is_wildcard_addr("tcp://127.0.0.1:8080") | ||
| False | ||
| >>> is_wildcard_addr("ipc:///tmp/socket") | ||
| False | ||
| """ | ||
| return addr.startswith("tcp://") and ":0" in addr | ||
|
|
||
|
|
||
| def bind_zmq_socket_and_get_address( | ||
| ctx: zmq.asyncio.Context | zmq.Context, # type: ignore[name-defined] | ||
| wildcard_addr: str, | ||
| socket_type: Any, | ||
| **socket_opts: Any, | ||
| ) -> tuple[zmq.Socket | zmq.asyncio.Socket, str]: # type: ignore[name-defined] | ||
| """ | ||
| Bind a ZMQ socket to an address and return the actual bound address. | ||
|
|
||
| For TCP wildcard addresses like "tcp://*:0", binds to let the OS assign | ||
| a port, then discovers the actual port via socket.last_endpoint. | ||
|
|
||
| For IPC addresses, binds directly without modification. | ||
|
|
||
| This eliminates port race conditions by binding immediately and discovering | ||
| the actual assigned port, rather than pre-allocating a port that could be | ||
| stolen before binding. | ||
|
|
||
| Note: This is a convenience wrapper around make_zmq_socket() with | ||
| return_address=True. Prefer using make_zmq_socket() directly for new code. | ||
|
|
||
| Args: | ||
| ctx: ZMQ context (async or sync) | ||
| wildcard_addr: Address to bind (e.g., "tcp://*:0" or "ipc:///tmp/path") | ||
| socket_type: ZMQ socket type constant (zmq.ROUTER, zmq.PULL, etc.) | ||
| **socket_opts: Additional options passed to make_zmq_socket | ||
| (identity, linger, etc.) | ||
|
|
||
| Returns: | ||
| (socket, actual_address) tuple where: | ||
| - socket: The bound ZMQ socket (caller must keep alive) | ||
| - actual_address: Real address with OS-assigned port | ||
|
|
||
| Example: | ||
| >>> ctx = zmq.Context() | ||
| >>> sock, addr = bind_zmq_socket_and_get_address(ctx, "tcp://*:0", zmq.ROUTER) | ||
| >>> print(addr) # "tcp://127.0.0.1:54321" | ||
| """ | ||
| # Use make_zmq_socket with return_address=True to handle both wildcard | ||
| # and non-wildcard addresses uniformly | ||
| return make_zmq_socket( | ||
| ctx, wildcard_addr, socket_type, bind=True, return_address=True, **socket_opts | ||
| ) | ||
|
|
||
|
|
||
| # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501 | ||
| @overload | ||
| def make_zmq_socket( | ||
| ctx: zmq.asyncio.Context | zmq.Context, # type: ignore[name-defined] | ||
| path: str, | ||
| socket_type: Any, | ||
| bind: bool | None = ..., | ||
| identity: bytes | None = ..., | ||
| linger: int | None = ..., | ||
| *, | ||
| return_address: Literal[True], | ||
| ) -> tuple[zmq.Socket | zmq.asyncio.Socket, str]: ... # type: ignore[name-defined] | ||
|
|
||
|
|
||
| @overload | ||
| def make_zmq_socket( | ||
| ctx: zmq.asyncio.Context | zmq.Context, # type: ignore[name-defined] | ||
| path: str, | ||
| socket_type: Any, | ||
| bind: bool | None = ..., | ||
| identity: bytes | None = ..., | ||
| linger: int | None = ..., | ||
| *, | ||
| return_address: Literal[False] = ..., | ||
| ) -> zmq.Socket | zmq.asyncio.Socket: ... # type: ignore[name-defined] | ||
|
|
||
|
|
||
| def make_zmq_socket( | ||
| ctx: zmq.asyncio.Context | zmq.Context, # type: ignore[name-defined] | ||
| path: str, | ||
| socket_type: Any, | ||
| bind: bool | None = None, | ||
| identity: bytes | None = None, | ||
| linger: int | None = None, | ||
| ) -> zmq.Socket | zmq.asyncio.Socket: # type: ignore[name-defined] | ||
| """Make a ZMQ socket with the proper bind/connect semantics.""" | ||
| return_address: bool = False, | ||
| ) -> zmq.Socket | zmq.asyncio.Socket | tuple[zmq.Socket | zmq.asyncio.Socket, str]: # type: ignore[name-defined] | ||
| """Make a ZMQ socket with the proper bind/connect semantics. | ||
|
|
||
| Args: | ||
| ctx: ZMQ context | ||
| path: Socket path/address to bind or connect to | ||
| socket_type: ZMQ socket type | ||
| bind: Whether to bind (True) or connect (False). If None, auto-determined. | ||
| identity: Optional socket identity | ||
| linger: Optional linger value | ||
| return_address: If True, return (socket, actual_address) tuple. | ||
| For wildcard addresses, returns the discovered address. | ||
| For non-wildcard addresses, returns the input path. | ||
|
|
||
| Returns: | ||
| Socket if return_address=False (default for backward compatibility) | ||
| (socket, actual_address) tuple if return_address=True | ||
| """ | ||
| mem = psutil.virtual_memory() | ||
| socket = ctx.socket(socket_type) | ||
|
|
||
|
|
@@ -305,10 +418,27 @@ def make_zmq_socket( | |
|
|
||
| if bind: | ||
| socket.bind(path) | ||
| # For wildcard port addresses, discover the actual bound address. | ||
| if return_address and is_wildcard_addr(path): | ||
| # last_endpoint is bytes like b"tcp://192.168.1.5:54321" or b"tcp://[::1]:54321" | ||
| actual_endpoint = socket.last_endpoint.decode("utf-8") | ||
|
|
||
| # Parse the endpoint to extract host and port | ||
| # Handle both IPv4 and IPv6 formats | ||
| scheme, host, port_str = split_zmq_path(actual_endpoint) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. split_zmq_path 当port是0的时候,解析出来的port是“”,导致报错。port = str(parsed.port or "")改为port = str(parsed.port) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. like this : |
||
| if scheme != "tcp": | ||
| # Shouldn't happen for wildcard TCP addresses, but fallback safely | ||
| actual_address = actual_endpoint | ||
| else: | ||
| # Preserve the host from the bound endpoint | ||
| actual_address = make_zmq_path(scheme, host, int(port_str)) | ||
| else: | ||
| actual_address = path | ||
| else: | ||
| socket.connect(path) | ||
| actual_address = path | ||
|
|
||
| return socket | ||
| return (socket, actual_address) if return_address else socket | ||
|
|
||
|
|
||
| @contextlib.contextmanager | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,13 +4,15 @@ | |
| import multiprocessing | ||
| import time | ||
| import weakref | ||
| from contextlib import ExitStack | ||
| from multiprocessing.connection import Connection | ||
|
|
||
| import msgspec.msgpack | ||
| import zmq | ||
|
|
||
| from vllm.config import ParallelConfig | ||
| from vllm.logger import init_logger | ||
| from vllm.utils.network_utils import make_zmq_socket | ||
| from vllm.utils.network_utils import is_wildcard_addr, make_zmq_socket | ||
| from vllm.utils.system_utils import get_mp_context, set_process_title | ||
| from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType | ||
| from vllm.v1.serial_utils import MsgpackDecoder | ||
|
|
@@ -76,6 +78,9 @@ def __init__( | |
| back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) | ||
| back_output_address = get_engine_client_zmq_addr(local_only_eng, host) | ||
|
|
||
| # Create pipe for late binding address reporting | ||
| parent_conn, child_conn = multiprocessing.Pipe() | ||
|
|
||
| context = get_mp_context() | ||
| self.proc: multiprocessing.Process = context.Process( | ||
| target=DPCoordinatorProc.run_coordinator, | ||
|
|
@@ -85,12 +90,47 @@ def __init__( | |
| "front_publish_address": front_publish_address, | ||
| "back_output_address": back_output_address, | ||
| "back_publish_address": back_publish_address, | ||
| "address_report_pipe": child_conn, # For late binding | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. File descriptor leak from unclosed child pipe connectionsMedium Severity When creating pipes for late binding address reporting, 🔬 Verification TestTest code: import multiprocessing
import os
def count_fds():
"""Count open file descriptors for current process"""
return len(os.listdir(f'/proc/{os.getpid()}/fd'))
def child_func(pipe):
pipe.send("hello")
pipe.close()
# Simulate the leak pattern from the code
initial_fds = count_fds()
print(f"Initial FDs: {initial_fds}")
for i in range(10):
parent_conn, child_conn = multiprocessing.Pipe()
proc = multiprocessing.Process(target=child_func, args=(child_conn,))
proc.start()
# Parent receives and closes parent_conn (like the code does)
parent_conn.recv()
parent_conn.close()
proc.join()
# BUG: child_conn is never closed in parent!
# child_conn.close() # This line is missing
leaked_fds = count_fds()
print(f"FDs after 10 iterations: {leaked_fds}")
print(f"Leaked FDs: {leaked_fds - initial_fds}")Command run: Output: Why this proves the bug: Each iteration leaks one file descriptor because Additional Locations (1) |
||
| "enable_wave_coordination": enable_wave_coordination, | ||
| }, | ||
| daemon=True, | ||
| ) | ||
| self.proc.start() | ||
|
|
||
| # Wait for coordinator to report actual addresses (for late binding) | ||
| needs_late_binding = any( | ||
| is_wildcard_addr(x) | ||
| for x in (front_publish_address, back_publish_address, back_output_address) | ||
| ) | ||
|
|
||
| if needs_late_binding: | ||
| try: | ||
| if not parent_conn.poll(timeout=30.0): # 30 second timeout | ||
| raise TimeoutError( | ||
| "DP Coordinator proc did not report addresses within 30 seconds" | ||
| ) | ||
| addr_report = parent_conn.recv() | ||
| front_publish_address = addr_report.get( | ||
| "front_publish", front_publish_address | ||
| ) | ||
| back_publish_address = addr_report.get( | ||
| "back_publish", back_publish_address | ||
| ) | ||
| back_output_address = addr_report.get( | ||
| "back_output", back_output_address | ||
| ) | ||
| logger.debug( | ||
| "DP Coordinator reported addresses:" | ||
| " front=%s, back_pub=%s, back_out=%s", | ||
| front_publish_address, | ||
| back_publish_address, | ||
| back_output_address, | ||
| ) | ||
| except Exception as e: | ||
| logger.error("Failed to get addresses from DP Coordinator: %s", e) | ||
| raise | ||
| parent_conn.close() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Child process crashes when late binding not neededHigh Severity The parent process always creates a pipe and passes it to the child, but only receives from it when 🔬 Verification TestWhy verification test was not possible: This bug involves multiprocessing race conditions between parent and child processes in Python's multiprocessing framework. Testing would require spawning actual processes and coordinating their timing, which cannot be done in a simple unit test without the full vLLM environment. The bug is confirmed through code analysis: the parent unconditionally passes the pipe to the child, but only calls Additional Locations (1) |
||
|
|
||
| self.stats_publish_address = front_publish_address | ||
| self.coord_in_address = back_publish_address | ||
| self.coord_out_address = back_output_address | ||
|
|
@@ -133,6 +173,7 @@ def run_coordinator( | |
| front_publish_address: str, | ||
| back_output_address: str, | ||
| back_publish_address: str, | ||
| address_report_pipe: Connection | None = None, | ||
| min_stats_update_interval_ms: int = 100, | ||
| enable_wave_coordination: bool = True, | ||
| ): | ||
|
|
@@ -146,6 +187,7 @@ def run_coordinator( | |
| front_publish_address, | ||
| back_output_address, | ||
| back_publish_address, | ||
| address_report_pipe, | ||
| ) | ||
| except KeyboardInterrupt: | ||
| logger.info("DP Coordinator process exiting") | ||
|
|
@@ -155,6 +197,7 @@ def process_input_socket( | |
| front_publish_address: str, | ||
| back_output_address: str, | ||
| back_publish_address: str, | ||
| address_report_pipe: Connection | None = None, | ||
| ): | ||
| decoder = MsgpackDecoder(EngineCoreOutputs) | ||
|
|
||
|
|
@@ -168,26 +211,46 @@ def process_input_socket( | |
| last_stats_wave = -1 | ||
| last_step_counts: list[list[int]] | None = None | ||
|
|
||
| with ( | ||
| make_zmq_socket( | ||
| path=front_publish_address, # IPC | ||
| with ExitStack() as sockets_to_close: | ||
| # Bind sockets with late binding support (auto-discovers wildcard ports). | ||
| publish_front, actual_front = make_zmq_socket( | ||
| ctx=self.ctx, | ||
| path=front_publish_address, | ||
| socket_type=zmq.XPUB, | ||
| bind=True, | ||
| ) as publish_front, | ||
| make_zmq_socket( | ||
| path=back_output_address, # IPC or TCP | ||
| return_address=True, | ||
| ) | ||
| sockets_to_close.enter_context(publish_front) | ||
|
|
||
| output_back, actual_back_out = make_zmq_socket( | ||
| ctx=self.ctx, | ||
| path=back_output_address, | ||
| socket_type=zmq.PULL, | ||
| bind=True, | ||
| ) as output_back, | ||
| make_zmq_socket( | ||
| path=back_publish_address, # IPC or TCP | ||
| return_address=True, | ||
| ) | ||
| sockets_to_close.enter_context(output_back) | ||
|
|
||
| publish_back, actual_back_pub = make_zmq_socket( | ||
| ctx=self.ctx, | ||
| path=back_publish_address, | ||
| socket_type=zmq.XPUB, | ||
| bind=True, | ||
| ) as publish_back, | ||
| ): | ||
| return_address=True, | ||
| ) | ||
| sockets_to_close.enter_context(publish_back) | ||
|
|
||
| # Report actual addresses to parent | ||
| if address_report_pipe is not None: | ||
| address_report_pipe.send( | ||
| { | ||
| "front_publish": actual_front, | ||
| "back_output": actual_back_out, | ||
| "back_publish": actual_back_pub, | ||
| } | ||
| ) | ||
| address_report_pipe.close() | ||
|
|
||
| # Wait until all engines subscribe. | ||
| for _ in self.engines: | ||
| if publish_back.recv() != b"\x01": | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wildcard address check matches IPv6 addresses containing
:0Low Severity
The
is_wildcard_addrfunction uses":0" in addrto check for wildcard port addresses, but this can incorrectly match IPv6 addresses that contain:0in the address portion. For example,tcp://[2001:db8::1:0]:8080would be incorrectly identified as a wildcard because the IPv6 address ends with hex group0. The check should useaddr.endswith(":0")instead, since the port is always at the end of the address string. This causes false positives leading to unnecessary late-binding overhead, though functionality remains correct.🔬 Verification Test
Test code:
Command run:
Output:
Why this proves the bug: The output shows that an IPv6 address with port 8080 (not a wildcard) is incorrectly identified as a wildcard address because
:0appears in the IPv6 portion of the address.