diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index c68ac66adea9..60a9bfc9187d 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import contextlib
 import os
 import queue
 import signal
@@ -119,9 +120,18 @@ def __init__(
         self._eep_scale_up_before_kv_init()
 
         # Setup KV Caches and update CacheConfig after profiling.
-        num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
-            vllm_config
-        )
+        try:
+            num_gpu_blocks, num_cpu_blocks, kv_cache_config = (
+                self._initialize_kv_caches(vllm_config)
+            )
+        except Exception:
+            logger.exception(
+                "EngineCore failed during KV cache initialization; "
+                "shutting down executor."
+            )
+            self.model_executor.shutdown()
+            raise
+
         if kv_cache_config.kv_cache_groups:
             vllm_config.cache_config.block_size = min(
                 g.kv_cache_spec.block_size for g in kv_cache_config.kv_cache_groups
@@ -971,29 +981,49 @@ def _perform_handshake(
         addresses = self.startup_handshake(
             handshake_socket, local_client, headless, parallel_config_to_update
         )
-        yield addresses
-
-        # Send ready message.
-        num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
-        # We pass back the coordinator stats update address here for the
-        # external LB case for our colocated front-end to use (coordinator
-        # only runs with rank 0).
-        dp_stats_address = self.frontend_stats_publish_address
-
-        # Include config hash for DP configuration validation
-        ready_msg = {
-            "status": "READY",
-            "local": local_client,
-            "headless": headless,
-            "num_gpu_blocks": num_gpu_blocks,
-            "dp_stats_address": dp_stats_address,
-        }
-        if vllm_config.parallel_config.data_parallel_size > 1:
-            ready_msg["parallel_config_hash"] = (
-                vllm_config.parallel_config.compute_hash()
-            )
+        exc_during_init = False
+        try:
+            yield addresses
+        except Exception:
+            exc_during_init = True
+            raise
+        finally:
+            if exc_during_init:
+                # Send FAILED status so the front-end detects init
+                # failure immediately via ZMQ instead of waiting for
+                # process sentinel (which may be delayed by cleanup).
+                with contextlib.suppress(Exception):
+                    handshake_socket.send(
+                        msgspec.msgpack.encode(
+                            {
+                                "status": "FAILED",
+                                "local": local_client,
+                                "headless": headless,
+                            }
+                        )
+                    )
+            else:
+                # Send ready message.
+                num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
+                # We pass back the coordinator stats update address
+                # here for the external LB case for our colocated
+                # front-end to use (coordinator only runs with rank 0).
+                dp_stats_address = self.frontend_stats_publish_address
+
+                # Include config hash for DP configuration validation
+                ready_msg = {
+                    "status": "READY",
+                    "local": local_client,
+                    "headless": headless,
+                    "num_gpu_blocks": num_gpu_blocks,
+                    "dp_stats_address": dp_stats_address,
+                }
+                if vllm_config.parallel_config.data_parallel_size > 1:
+                    ready_msg["parallel_config_hash"] = (
+                        vllm_config.parallel_config.compute_hash()
+                    )
 
-        handshake_socket.send(msgspec.msgpack.encode(ready_msg))
+                handshake_socket.send(msgspec.msgpack.encode(ready_msg))
 
     @staticmethod
     def startup_handshake(
diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py
index 321f84ea2a54..0a9d9c922502 100644
--- a/vllm/v1/engine/utils.py
+++ b/vllm/v1/engine/utils.py
@@ -1130,6 +1130,11 @@ def wait_for_engine_startup(
                 start_pending[0 if local else 1] -= 1
                 engine.state = CoreEngineState.READY
+            elif status == "FAILED":
+                raise RuntimeError(
+                    f"Engine core {eng_index} reported initialization failure. "
+                    "See root cause above."
+                )
             else:
                 raise RuntimeError(
                     f"Unexpected {status} message for "