Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,18 +634,45 @@ def _perform_handshakes(
vllm_config,
vllm_config.parallel_config,
)

def handle_err(zmq_socket: zmq.Socket, err: Exception):
if zmq_socket is not None:
# Send failure message to front-end.
logger.exception("EngineCore failed to start.")
zmq_socket.send(
msgspec.msgpack.encode(
{
"status": "FAILED",
"local": is_local,
"headless": headless,
"error_msg": str(err),
}
)
)

if client_handshake_address is None:
with handshake as addresses:
yield addresses
with handshake as (addresses, zmq_socket):
try:
yield addresses
except Exception as e:
handle_err(zmq_socket, e)
raise
else:
assert local_client
local_handshake = self._perform_handshake(
input_ctx, client_handshake_address, identity, True, False, vllm_config
)
with handshake as addresses, local_handshake as client_addresses:
with (
handshake as (addresses, _),
local_handshake as (client_addresses, zmq_socket),
):
addresses.inputs = client_addresses.inputs
addresses.outputs = client_addresses.outputs
yield addresses
try:
yield addresses
except Exception as e:
handle_err(zmq_socket, e)
raise

# Update config which may have changed from the handshake
vllm_config.__post_init__()
Expand All @@ -660,7 +687,7 @@ def _perform_handshake(
headless: bool,
vllm_config: VllmConfig,
parallel_config_to_update: Optional[ParallelConfig] = None,
) -> Generator[EngineZmqAddresses, None, None]:
) -> Generator[tuple[EngineZmqAddresses, zmq.Socket], None, None]:
with make_zmq_socket(
ctx,
handshake_address,
Expand All @@ -673,7 +700,7 @@ def _perform_handshake(
addresses = self.startup_handshake(
handshake_socket, local_client, headless, parallel_config_to_update
)
yield addresses
yield addresses, handshake_socket

# Send ready message.
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
Expand Down
30 changes: 25 additions & 5 deletions vllm/v1/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,14 +804,16 @@ def wait_for_engine_startup(
and not parallel_config.data_parallel_external_lb
)

proc_manager_poll = zmq.Poller()
if proc_manager is not None:
for sentinel in proc_manager.sentinels():
poller.register(sentinel, zmq.POLLIN)
proc_manager_poll.register(sentinel, zmq.POLLIN)
if coord_process is not None:
poller.register(coord_process.sentinel, zmq.POLLIN)
while any(conn_pending) or any(start_pending):
events = poller.poll(STARTUP_POLL_PERIOD_MS)
if not events:
proc_manager_events = proc_manager_poll.poll(STARTUP_POLL_PERIOD_MS)
if not events and not proc_manager_events:
Comment on lines 813 to +816
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The revised loop structure for polling has several critical issues:

  1. IndexError: If proc_manager_events is non-empty but events is empty, the check if len(events) > 1 or events[0][0] != handshake_socket: on line 782 will raise an IndexError when accessing events[0]. You should guard this access, for example with if events:.
  2. Process Hang: In the same scenario (a worker dies silently, so proc_manager_events is non-empty and events is empty), if the IndexError is avoided, the code will proceed to handshake_socket.recv_multipart() on line 791. Since there's no message, this will block indefinitely, causing a hang. Process exit events should be handled before trying to receive messages.
  3. Incomplete Failure Detection: The check if len(proc_manager_events) > 1: on line 868 is incorrect. If a single worker process fails silently, len(proc_manager_events) will be 1, and the failure will go undetected within the loop. This should likely be if proc_manager_events:.

These issues suggest that the loop logic needs to be restructured to correctly handle process exits from both coord_process and worker processes, especially when they fail without sending a ZMQ message.

if any(conn_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) to connect.",
Expand All @@ -823,9 +825,9 @@ def wait_for_engine_startup(
*start_pending,
)
continue
if len(events) > 1 or events[0][0] != handshake_socket:
# One of the local core processes exited.
finished = proc_manager.finished_procs() if proc_manager else {}
if len(events) > 1 or (len(events) == 1 and events[0][0] != handshake_socket):
# coord_process processes exited.
finished = {}
if coord_process is not None and coord_process.exitcode is not None:
finished[coord_process.name] = coord_process.exitcode
raise RuntimeError(
Expand Down Expand Up @@ -924,13 +926,31 @@ def wait_for_engine_startup(

start_pending[0 if local else 1] -= 1
engine.state = CoreEngineState.READY
elif status == "FAILED" and engine.state == CoreEngineState.CONNECTED:
# One of the local core processes exited.
finished = proc_manager.finished_procs() if proc_manager else {}
raise RuntimeError(
"Engine core initialization failed. "
"See root cause above. "
f"Failed core proc(s): {finished}. "
"Recived error message from failed "
f"engine {eng_index}: {msg['error_msg']}"
)
else:
raise RuntimeError(
f"Unexpected {status} message for "
f"{'local' if local else 'remote'} engine "
f"{eng_index} in {engine.state} state."
)

if len(proc_manager_events) > 1:
# One or more local core processes exited but we didn't receive any msg.
finished = proc_manager.finished_procs() if proc_manager else {}
raise RuntimeError(
"Engine core initialization failed. "
"See root cause above. "
f"Failed core proc(s): {finished}"
)
logger.debug(
"%s from %s core engine process %s.",
status,
Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/executor/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
ExecutorWithExternalLauncher as ExecutorWithExternalLauncherV0,
)
from vllm.executor.uniproc_executor import UniProcExecutor as UniProcExecutorV0 # noqa
from vllm.logger import init_logger
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

logger = init_logger(__name__)
FailureCallback = Callable[[], None]


Expand Down
23 changes: 18 additions & 5 deletions vllm/v1/executor/multiproc_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ class WorkerProc:
"""Wrapper that runs one Worker in a separate process."""

READY_STR = "READY"
FAILED_INIT_STR = "FAILED_INIT"

def __init__(
self,
Expand Down Expand Up @@ -505,9 +506,9 @@ def make_worker_process(
def wait_for_ready(
unready_proc_handles: list[UnreadyWorkerProcHandle],
) -> list[WorkerProcHandle]:
e = Exception(
err_msg = (
"WorkerProc initialization failed due to "
"an exception in a background process. "
"an exception {}in a background process. "
"See stack trace for root cause."
)

Expand All @@ -523,8 +524,13 @@ def wait_for_ready(
# Wait until the WorkerProc is ready.
unready_proc_handle = pipes.pop(pipe)
response: dict[str, Any] = pipe.recv()
if response["status"] != "READY":
raise e
status = response["status"]
if status == WorkerProc.FAILED_INIT_STR:
raise Exception(
err_msg.format(f"(err: {response['error_msg']}) ")
)
elif status != WorkerProc.READY_STR:
raise Exception(err_msg.format(""))

# Extract the message queue handle.
worker_response_mq = MessageQueue.create_from_handle(
Expand All @@ -537,6 +543,7 @@ def wait_for_ready(
)

except EOFError:
e = Exception(err_msg.format(""))
e.__suppress_context__ = True
raise e from None

Expand Down Expand Up @@ -619,13 +626,19 @@ def monitor_parent_death():

worker.worker_busy_loop(cancel=shutdown_event)

except Exception:
except Exception as e:
# NOTE: if an Exception arises in busy_loop, we send
# a FAILURE message over the MQ RPC to notify the Executor,
# which triggers system shutdown.
# TODO(rob): handle case where the MQ itself breaks.

if ready_writer is not None:
ready_writer.send(
{
"status": WorkerProc.FAILED_INIT_STR,
"error_msg": str(e),
}
)
logger.exception("WorkerProc failed to start.")
elif shutdown_event.is_set():
logger.info("WorkerProc shutting down.")
Expand Down