-
-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[DO NOT MERGE] Reapply "[BugFix] Fix engine hanging after KV cache initialization failure #35478" #36650
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DO NOT MERGE] Reapply "[BugFix] Fix engine hanging after KV cache initialization failure #35478" #36650
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,5 +1,6 @@ | ||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||||
| import contextlib | ||||||
| import os | ||||||
| import queue | ||||||
| import signal | ||||||
|
|
@@ -119,9 +120,18 @@ def __init__( | |||||
| self._eep_scale_up_before_kv_init() | ||||||
|
|
||||||
| # Setup KV Caches and update CacheConfig after profiling. | ||||||
| num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches( | ||||||
| vllm_config | ||||||
| ) | ||||||
| try: | ||||||
| num_gpu_blocks, num_cpu_blocks, kv_cache_config = ( | ||||||
| self._initialize_kv_caches(vllm_config) | ||||||
| ) | ||||||
| except Exception: | ||||||
| logger.exception( | ||||||
| "EngineCore failed during KV cache initialization; " | ||||||
| "shutting down executor." | ||||||
| ) | ||||||
| self.model_executor.shutdown() | ||||||
| raise | ||||||
|
|
||||||
| if kv_cache_config.kv_cache_groups: | ||||||
| vllm_config.cache_config.block_size = min( | ||||||
| g.kv_cache_spec.block_size for g in kv_cache_config.kv_cache_groups | ||||||
|
|
@@ -971,29 +981,49 @@ def _perform_handshake( | |||||
| addresses = self.startup_handshake( | ||||||
| handshake_socket, local_client, headless, parallel_config_to_update | ||||||
| ) | ||||||
| yield addresses | ||||||
|
|
||||||
| # Send ready message. | ||||||
| num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks | ||||||
| # We pass back the coordinator stats update address here for the | ||||||
| # external LB case for our colocated front-end to use (coordinator | ||||||
| # only runs with rank 0). | ||||||
| dp_stats_address = self.frontend_stats_publish_address | ||||||
|
|
||||||
| # Include config hash for DP configuration validation | ||||||
| ready_msg = { | ||||||
| "status": "READY", | ||||||
| "local": local_client, | ||||||
| "headless": headless, | ||||||
| "num_gpu_blocks": num_gpu_blocks, | ||||||
| "dp_stats_address": dp_stats_address, | ||||||
| } | ||||||
| if vllm_config.parallel_config.data_parallel_size > 1: | ||||||
| ready_msg["parallel_config_hash"] = ( | ||||||
| vllm_config.parallel_config.compute_hash() | ||||||
| ) | ||||||
| exc_during_init = False | ||||||
| try: | ||||||
| yield addresses | ||||||
| except Exception: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar to the previous comment, using a bare `except Exception:` here is very broad. Consider catching more specific exceptions, or adding a comment to justify the broad catch if it is a deliberate choice for robust error signaling during critical initialization.
Suggested change
|
||||||
| exc_during_init = True | ||||||
| raise | ||||||
| finally: | ||||||
| if exc_during_init: | ||||||
| # Send FAILED status so the front-end detects init | ||||||
| # failure immediately via ZMQ instead of waiting for | ||||||
| # process sentinel (which may be delayed by cleanup). | ||||||
| with contextlib.suppress(Exception): | ||||||
| handshake_socket.send( | ||||||
| msgspec.msgpack.encode( | ||||||
| { | ||||||
| "status": "FAILED", | ||||||
| "local": local_client, | ||||||
| "headless": headless, | ||||||
| } | ||||||
| ) | ||||||
| ) | ||||||
| else: | ||||||
| # Send ready message. | ||||||
| num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks | ||||||
| # We pass back the coordinator stats update address | ||||||
| # here for the external LB case for our colocated | ||||||
| # front-end to use (coordinator only runs with rank 0). | ||||||
| dp_stats_address = self.frontend_stats_publish_address | ||||||
|
|
||||||
| # Include config hash for DP configuration validation | ||||||
| ready_msg = { | ||||||
| "status": "READY", | ||||||
| "local": local_client, | ||||||
| "headless": headless, | ||||||
| "num_gpu_blocks": num_gpu_blocks, | ||||||
| "dp_stats_address": dp_stats_address, | ||||||
| } | ||||||
| if vllm_config.parallel_config.data_parallel_size > 1: | ||||||
| ready_msg["parallel_config_hash"] = ( | ||||||
| vllm_config.parallel_config.compute_hash() | ||||||
| ) | ||||||
|
|
||||||
| handshake_socket.send(msgspec.msgpack.encode(ready_msg)) | ||||||
| handshake_socket.send(msgspec.msgpack.encode(ready_msg)) | ||||||
|
|
||||||
| @staticmethod | ||||||
| def startup_handshake( | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Catching a broad `Exception` can mask specific underlying issues, making debugging more challenging. While the intent here is likely to catch any failure during KV cache initialization, it's generally recommended to catch more specific exceptions if possible. If the exact exceptions are unknown, or if the intent is truly to catch all exceptions for a critical shutdown, consider adding a comment explaining this design choice. For example, if there are known exceptions related to memory allocation or hardware, catching those specifically would provide clearer error messages.