diff --git a/tests/entrypoints/openai/test_shutdown.py b/tests/entrypoints/openai/test_shutdown.py index 43f57719a383..a2ac49bcb0b2 100644 --- a/tests/entrypoints/openai/test_shutdown.py +++ b/tests/entrypoints/openai/test_shutdown.py @@ -1,20 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Integration tests for shutdown behavior, timeout, and signal handling.""" -import asyncio import signal import subprocess import sys import time -from dataclasses import dataclass, field -import httpx import openai -import psutil import pytest -from tests.utils import RemoteOpenAIServer from vllm.platforms import current_platform from vllm.utils.network_utils import get_open_port @@ -24,101 +18,6 @@ _IS_ROCM = current_platform.is_rocm() _SERVER_STARTUP_TIMEOUT = 120 _PROCESS_EXIT_TIMEOUT = 15 -_SHUTDOWN_DETECTION_TIMEOUT = 10 -_CHILD_CLEANUP_TIMEOUT = 10 - - -def _get_child_pids(parent_pid: int) -> list[int]: - try: - parent = psutil.Process(parent_pid) - return [c.pid for c in parent.children(recursive=True)] - except psutil.NoSuchProcess: - return [] - - -async def _assert_children_cleaned_up( - child_pids: list[int], - timeout: float = _CHILD_CLEANUP_TIMEOUT, -): - """Wait for child processes to exit and fail if any remain.""" - if not child_pids: - return - - deadline = time.time() + timeout - while time.time() < deadline: - still_alive = [] - for pid in child_pids: - try: - p = psutil.Process(pid) - if p.is_running() and p.status() != psutil.STATUS_ZOMBIE: - still_alive.append(pid) - except psutil.NoSuchProcess: - pass - if not still_alive: - return - await asyncio.sleep(0.5) - - pytest.fail( - f"Child processes {still_alive} still alive after {timeout}s. " - f"Process cleanup may not be working correctly." - ) - - -@dataclass -class ShutdownState: - got_503: bool = False - got_500: bool = False - requests_after_sigterm: int = 0 - aborted_requests: int = 0 - connection_errors: int = 0 - stop_requesting: bool = False - errors: list[str] = field(default_factory=list) - - -async def _concurrent_request_loop( - client: openai.AsyncOpenAI, - state: ShutdownState, - sigterm_sent: asyncio.Event | None = None, - concurrency: int = 10, -): - """Run multiple concurrent requests to keep the server busy.""" - - async def single_request(): - while not state.stop_requesting: - try: - response = await client.completions.create( - model=MODEL_NAME, - prompt="Write a story: ", - max_tokens=200, - ) - if sigterm_sent is not None and sigterm_sent.is_set(): - state.requests_after_sigterm += 1 - # Check if any choice has finish_reason='abort' - if any(choice.finish_reason == "abort" for choice in response.choices): - state.aborted_requests += 1 - except openai.APIStatusError as e: - if e.status_code == 503: - state.got_503 = True - elif e.status_code == 500: - state.got_500 = True - else: - state.errors.append(f"API error: {e}") - except (openai.APIConnectionError, httpx.RemoteProtocolError): - state.connection_errors += 1 - if sigterm_sent is not None and sigterm_sent.is_set(): - break - except Exception as e: - state.errors.append(f"Unexpected error: {e}") - break - await asyncio.sleep(0.01) - - tasks = [asyncio.create_task(single_request()) for _ in range(concurrency)] - try: - await asyncio.gather(*tasks, return_exceptions=True) - finally: - for t in tasks: - if not t.done(): - t.cancel() @pytest.mark.asyncio @@ -204,361 +103,3 @@ async def test_shutdown_on_engine_failure(): return_code = proc.wait(timeout=_PROCESS_EXIT_TIMEOUT) assert return_code is not None - - -@pytest.mark.asyncio -async def test_wait_timeout_completes_requests(): - """Verify wait timeout: new requests rejected, in-flight requests complete.""" - server_args = [ - "--dtype", - "bfloat16", - "--max-model-len", - "256", - "--enforce-eager", - "--gpu-memory-utilization", - "0.05", - "--max-num-seqs", - "4", - "--shutdown-timeout", - "30", - ] - - with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server: - client = remote_server.get_async_client() - proc = remote_server.proc - child_pids = _get_child_pids(proc.pid) - - state = ShutdownState() - sigterm_sent = asyncio.Event() - - request_task = asyncio.create_task( - _concurrent_request_loop(client, state, sigterm_sent, concurrency=10) - ) - - await asyncio.sleep(0.5) - proc.send_signal(signal.SIGTERM) - sigterm_sent.set() - - try: - await asyncio.wait_for(request_task, timeout=_SHUTDOWN_DETECTION_TIMEOUT) - except asyncio.TimeoutError: - pass - finally: - state.stop_requesting = True - if not request_task.done(): - request_task.cancel() - await asyncio.gather(request_task, return_exceptions=True) - - # wait timeout should complete in-flight requests - assert state.requests_after_sigterm > 0, ( - f"Wait timeout should complete in-flight requests. " - f"503: {state.got_503}, 500: {state.got_500}, " - f"conn_errors: {state.connection_errors}, errors: {state.errors}" - ) - # server must stop accepting new requests (503, 500, or connection close) - assert state.got_503 or state.got_500 or state.connection_errors > 0, ( - f"Server should stop accepting requests. " - f"completed: {state.requests_after_sigterm}, errors: {state.errors}" - ) - - await _assert_children_cleaned_up(child_pids) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("wait_for_engine_idle", [0.0, 2.0]) -async def test_abort_timeout_exits_quickly(wait_for_engine_idle: float): - server_args = [ - "--dtype", - "bfloat16", - "--max-model-len", - "256", - "--enforce-eager", - "--gpu-memory-utilization", - "0.05", - "--max-num-seqs", - "4", - "--shutdown-timeout", - "0", - ] - - with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server: - proc = remote_server.proc - child_pids = _get_child_pids(proc.pid) - - if wait_for_engine_idle > 0: - client = remote_server.get_async_client() - # Send requests to ensure engine is fully initialized - for _ in range(2): - await client.completions.create( - model=MODEL_NAME, - prompt="Test request: ", - max_tokens=10, - ) - # Wait for engine to become idle - await asyncio.sleep(wait_for_engine_idle) - - start_time = time.time() - proc.send_signal(signal.SIGTERM) - - # abort timeout (0) should exit promptly - for _ in range(20): - if proc.poll() is not None: - break - time.sleep(0.1) - - if proc.poll() is None: - proc.kill() - proc.wait(timeout=5) - pytest.fail("Process did not exit after SIGTERM with abort timeout") - - exit_time = time.time() - start_time - assert exit_time < 2, f"Default shutdown took too long: {exit_time:.1f}s" - assert proc.returncode in (0, -15, None), f"Unexpected: {proc.returncode}" - - await _assert_children_cleaned_up(child_pids) - - -@pytest.mark.asyncio -async def test_wait_timeout_with_short_duration(): - """Verify server exits cleanly with a short wait timeout.""" - wait_timeout = 3 - server_args = [ - "--dtype", - "bfloat16", - "--max-model-len", - "256", - "--enforce-eager", - "--gpu-memory-utilization", - "0.05", - "--max-num-seqs", - "4", - "--shutdown-timeout", - str(wait_timeout), - ] - - with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server: - client = remote_server.get_async_client() - proc = remote_server.proc - child_pids = _get_child_pids(proc.pid) - - state = ShutdownState() - request_task = asyncio.create_task( - _concurrent_request_loop(client, state, concurrency=3) - ) - - await asyncio.sleep(0.5) - - start_time = time.time() - proc.send_signal(signal.SIGTERM) - - # server should exit within wait_timeout + buffer - max_wait = wait_timeout + 15 - for _ in range(int(max_wait * 10)): - if proc.poll() is not None: - break - time.sleep(0.1) - - exit_time = time.time() - start_time - - state.stop_requesting = True - if not request_task.done(): - request_task.cancel() - await asyncio.gather(request_task, return_exceptions=True) - - if proc.poll() is None: - proc.kill() - proc.wait(timeout=5) - pytest.fail(f"Process did not exit within {max_wait}s after SIGTERM") - - assert exit_time < wait_timeout + 10, ( - f"Took too long to exit ({exit_time:.1f}s), expected <{wait_timeout + 10}s" - ) - assert proc.returncode in (0, -15, None), f"Unexpected: {proc.returncode}" - - await _assert_children_cleaned_up(child_pids) - - -@pytest.mark.asyncio -async def test_abort_timeout_fails_inflight_requests(): - """Verify abort timeout (0) immediately aborts in-flight requests.""" - server_args = [ - "--dtype", - "bfloat16", - "--max-model-len", - "256", - "--enforce-eager", - "--gpu-memory-utilization", - "0.05", - "--max-num-seqs", - "4", - "--shutdown-timeout", - "0", - ] - - with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server: - client = remote_server.get_async_client() - proc = remote_server.proc - child_pids = _get_child_pids(proc.pid) - - state = ShutdownState() - sigterm_sent = asyncio.Event() - - request_task = asyncio.create_task( - _concurrent_request_loop(client, state, sigterm_sent, concurrency=10) - ) - - await asyncio.sleep(0.5) - - proc.send_signal(signal.SIGTERM) - sigterm_sent.set() - - try: - await asyncio.wait_for(request_task, timeout=5) - except asyncio.TimeoutError: - pass - finally: - state.stop_requesting = True - if not request_task.done(): - request_task.cancel() - await asyncio.gather(request_task, return_exceptions=True) - - # With abort timeout (0), requests should be aborted (finish_reason='abort') - # or rejected (connection errors or API errors) - assert ( - state.aborted_requests > 0 - or state.connection_errors > 0 - or state.got_500 - or state.got_503 - ), ( - f"Abort timeout should cause request aborts or failures. " - f"aborted: {state.aborted_requests}, " - f"503: {state.got_503}, 500: {state.got_500}, " - f"conn_errors: {state.connection_errors}, " - f"completed: {state.requests_after_sigterm}" - ) - - # Verify fast shutdown - start_time = time.time() - for _ in range(100): - if proc.poll() is not None: - break - time.sleep(0.1) - - exit_time = time.time() - start_time - assert exit_time < 10, f"Abort timeout shutdown took too long: {exit_time:.1f}s" - - await _assert_children_cleaned_up(child_pids) - - -@pytest.mark.asyncio -async def test_request_rejection_during_shutdown(): - """Verify new requests are rejected with error during shutdown.""" - server_args = [ - "--dtype", - "bfloat16", - "--max-model-len", - "256", - "--enforce-eager", - "--gpu-memory-utilization", - "0.05", - "--max-num-seqs", - "4", - "--shutdown-timeout", - "30", - ] - - with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server: - client = remote_server.get_async_client() - proc = remote_server.proc - child_pids = _get_child_pids(proc.pid) - - proc.send_signal(signal.SIGTERM) - - await asyncio.sleep(1.0) - - # Try to send new requests - they should be rejected - rejected_count = 0 - for _ in range(10): - try: - await client.completions.create( - model=MODEL_NAME, prompt="Hello", max_tokens=10 - ) - except ( - openai.APIStatusError, - openai.APIConnectionError, - httpx.RemoteProtocolError, - ): - rejected_count += 1 - await asyncio.sleep(0.1) - - assert rejected_count > 0, ( - f"Expected requests to be rejected during shutdown, " - f"but {rejected_count} were rejected out of 10" - ) - - await _assert_children_cleaned_up(child_pids) - - -@pytest.mark.asyncio -async def test_multi_api_server_shutdown(): - """Verify shutdown works with multiple API servers.""" - server_args = [ - "--dtype", - "bfloat16", - "--max-model-len", - "256", - "--enforce-eager", - "--gpu-memory-utilization", - "0.05", - "--max-num-seqs", - "4", - "--shutdown-timeout", - "30", - "--api-server-count", - "2", - ] - - with RemoteOpenAIServer(MODEL_NAME, server_args, auto_port=True) as remote_server: - client = remote_server.get_async_client() - proc = remote_server.proc - child_pids = _get_child_pids(proc.pid) - - assert len(child_pids) >= 2, ( - f"Expected at least 2 child processes, got {len(child_pids)}" - ) - - state = ShutdownState() - sigterm_sent = asyncio.Event() - - # Start concurrent requests across both API servers - request_task = asyncio.create_task( - _concurrent_request_loop(client, state, sigterm_sent, concurrency=8) - ) - - await asyncio.sleep(0.5) - - # Send SIGTERM to parent - should propagate to all children - proc.send_signal(signal.SIGTERM) - sigterm_sent.set() - - try: - await asyncio.wait_for(request_task, timeout=_SHUTDOWN_DETECTION_TIMEOUT) - except asyncio.TimeoutError: - pass - finally: - state.stop_requesting = True - if not request_task.done(): - request_task.cancel() - await asyncio.gather(request_task, return_exceptions=True) - - for _ in range(300): # up to 30 seconds - if proc.poll() is not None: - break - time.sleep(0.1) - - if proc.poll() is None: - proc.kill() - proc.wait(timeout=5) - pytest.fail("Process did not exit after SIGTERM") - - await _assert_children_cleaned_up(child_pids) diff --git a/tests/entrypoints/test_api_server_process_manager.py b/tests/entrypoints/test_api_server_process_manager.py index 3820fdefb194..3fadbf2ef0dd 100644 --- a/tests/entrypoints/test_api_server_process_manager.py +++ b/tests/entrypoints/test_api_server_process_manager.py @@ -79,7 +79,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update): finally: # Always clean up the processes print("Cleaning up processes...") - manager.shutdown() + manager.close() # Give processes time to terminate time.sleep(0.2) @@ -111,8 +111,6 @@ def run_with_exception_capture(): wait_for_completion_or_failure(api_server_manager=manager) except Exception as e: result["exception"] = e - finally: - manager.shutdown() # Start a thread to run wait_for_completion_or_failure wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True) @@ -145,7 +143,7 @@ def run_with_exception_capture(): assert not proc.is_alive(), f"Process {i} should not be alive" finally: - manager.shutdown() + manager.close() time.sleep(0.2) @@ -176,14 +174,11 @@ def test_normal_completion(api_server_args): # since all processes have already # terminated, it should return immediately # with no error - try: - wait_for_completion_or_failure(api_server_manager=manager) - finally: - manager.shutdown() + wait_for_completion_or_failure(api_server_manager=manager) finally: # Clean up just in case - manager.shutdown() + manager.close() time.sleep(0.2) @@ -206,7 +201,7 @@ class MockCoordinator: def __init__(self, proc): self.proc = proc - def shutdown(self): + def close(self): if self.proc.is_alive(): self.proc.terminate() self.proc.join(timeout=0.5) @@ -231,9 +226,6 @@ def run_with_exception_capture(): ) except Exception as e: result["exception"] = e - finally: - manager.shutdown() - mock_coordinator.shutdown() # Start a thread to run wait_for_completion_or_failure wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True) @@ -267,6 +259,6 @@ def run_with_exception_capture(): finally: # Clean up - manager.shutdown() - mock_coordinator.shutdown() + manager.close() + mock_coordinator.close() time.sleep(0.2) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index dc776fac1469..f078ae994783 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -327,12 +327,6 @@ class VllmConfig: weight_transfer_config: WeightTransferConfig | None = None """The configurations for weight transfer during RL training.""" - shutdown_timeout: int = Field(default=0, ge=0) - """Shutdown grace period for in-flight requests. Shutdown will be delayed for - up to this amount of time to allow already-running requests to complete. Any - remaining requests are aborted once the timeout is reached. - """ - def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 700713e32dd1..56bbb7bf54e3 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -606,8 +606,6 @@ class EngineArgs: kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend tokens_only: bool = False - shutdown_timeout: int = 0 - weight_transfer_config: WeightTransferConfig | None = get_field( VllmConfig, "weight_transfer_config", @@ -1310,14 +1308,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=False, action=argparse.BooleanOptionalAction, ) - - parser.add_argument( - "--shutdown-timeout", - type=int, - default=0, - help="Shutdown timeout in seconds. 0 = abort, >0 = wait.", - ) - return parser @classmethod @@ -1926,7 +1916,6 @@ def create_engine_config( optimization_level=self.optimization_level, performance_mode=self.performance_mode, weight_transfer_config=self.weight_transfer_config, - shutdown_timeout=self.shutdown_timeout, ) return config diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 0b3b29cd6c1f..ea2bf5303b5f 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -200,11 +200,6 @@ async def is_paused(self) -> bool: """Return whether the engine is currently paused.""" ... - @abstractmethod - def shutdown(self, timeout: float | None = None) -> None: - """Shutdown the engine with optional timeout.""" - ... - async def scale_elastic_ep( self, new_data_parallel_size: int, drain_timeout: int = 300 ) -> None: diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 677c6ea0f333..9e3988b15bee 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -3,7 +3,6 @@ import argparse import signal -import time import uvloop @@ -212,12 +211,8 @@ def signal_handler(signum, frame): try: engine_manager.join_first() finally: - timeout = None - if shutdown_requested: - timeout = vllm_config.shutdown_timeout - logger.info("Waiting up to %d seconds for processes to exit", timeout) - engine_manager.shutdown(timeout=timeout) logger.info("Shutting down.") + engine_manager.close() def run_multi_api_server(args: argparse.Namespace): @@ -228,19 +223,6 @@ def run_multi_api_server(args: argparse.Namespace): if num_api_servers > 1: setup_multiprocess_prometheus() - shutdown_requested = False - - # Catch SIGTERM and SIGINT to allow graceful shutdown. - def signal_handler(signum, frame): - nonlocal shutdown_requested - logger.debug("Received %d signal.", signum) - if not shutdown_requested: - shutdown_requested = True - raise SystemExit - - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) - listen_address, sock = setup_server(args) engine_args = vllm.AsyncEngineArgs.from_cli_args(args) @@ -302,29 +284,11 @@ def signal_handler(signum, frame): api_server_manager = APIServerProcessManager(**api_server_manager_kwargs) # Wait for API servers - try: - wait_for_completion_or_failure( - api_server_manager=api_server_manager, - engine_manager=local_engine_manager, - coordinator=coordinator, - ) - finally: - timeout = shutdown_by = None - if shutdown_requested: - timeout = vllm_config.shutdown_timeout - shutdown_by = time.monotonic() + timeout - logger.info("Waiting up to %d seconds for processes to exit", timeout) - - def to_timeout(deadline: float | None) -> float | None: - return ( - deadline if deadline is None else max(deadline - time.monotonic(), 0.0) - ) - - api_server_manager.shutdown(timeout=timeout) - if local_engine_manager: - local_engine_manager.shutdown(timeout=to_timeout(shutdown_by)) - if coordinator: - coordinator.shutdown(timeout=to_timeout(shutdown_by)) + wait_for_completion_or_failure( + api_server_manager=api_server_manager, + engine_manager=local_engine_manager, + coordinator=coordinator, + ) def run_api_server_worker_proc( diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 8caeb80836f9..b442fc70cdb0 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -4,7 +4,6 @@ import asyncio import signal import socket -from functools import partial from typing import Any import uvicorn @@ -92,10 +91,12 @@ async def serve_http( ) ) - shutdown_event = asyncio.Event() - def signal_handler() -> None: - shutdown_event.set() + # prevents the uvicorn signal handler to exit early + server_task.cancel() + watchdog_task.cancel() + if ssl_cert_refresher: + ssl_cert_refresher.stop() async def dummy_shutdown() -> None: pass @@ -103,24 +104,6 @@ async def dummy_shutdown() -> None: loop.add_signal_handler(signal.SIGINT, signal_handler) loop.add_signal_handler(signal.SIGTERM, signal_handler) - async def handle_shutdown() -> None: - await shutdown_event.wait() - - engine_client = app.state.engine_client - timeout = engine_client.vllm_config.shutdown_timeout - - await loop.run_in_executor( - None, partial(engine_client.shutdown, timeout=timeout) - ) - - server.should_exit = True - server_task.cancel() - watchdog_task.cancel() - if ssl_cert_refresher: - ssl_cert_refresher.stop() - - shutdown_task = loop.create_task(handle_shutdown()) - try: await server_task return dummy_shutdown() @@ -137,7 +120,6 @@ async def handle_shutdown() -> None: logger.info("Shutting down FastAPI HTTP server.") return server.shutdown() finally: - shutdown_task.cancel() watchdog_task.cancel() diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index d76948bc277d..33e39a3590ce 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -226,8 +226,6 @@ class EngineCoreRequestType(enum.Enum): UTILITY = b"\x03" # Sentinel used within EngineCoreProc. EXECUTOR_FAILED = b"\x04" - # Sentinel to wake up input_queue.get() during shutdown. - WAKEUP = b"\x05" class ReconfigureDistributedRequest(msgspec.Struct): diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index a9c42e78e53b..6be0a07baeb2 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -264,15 +264,16 @@ def from_engine_args( def __del__(self): self.shutdown() - def shutdown(self, timeout: float | None = None) -> None: + def shutdown(self): """Shutdown, cleaning up the background proc and IPC.""" + shutdown_prometheus() if renderer := getattr(self, "renderer", None): renderer.shutdown() if engine_core := getattr(self, "engine_core", None): - engine_core.shutdown(timeout=timeout) + engine_core.shutdown() handler = getattr(self, "output_handler", None) if handler is not None: diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 0d07f29a5cb4..44a346350fc8 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -104,10 +104,8 @@ def get_engine_socket_addresses(self) -> tuple[str, str]: """Returns tuple of ZMQ input address, output address.""" return self.coord_in_address, self.coord_out_address - def shutdown(self, timeout: float | None = None) -> None: - """Shutdown coordinator process with configurable timeout.""" - if self._finalizer.detach() is not None: - shutdown([self.proc], timeout=timeout) + def close(self): + self._finalizer() class EngineState: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index c68ac66adea9..6d57fce0229d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -9,7 +9,6 @@ from collections.abc import Callable, Generator from concurrent.futures import Future from contextlib import ExitStack, contextmanager -from enum import IntEnum from functools import partial from inspect import isclass, signature from logging import DEBUG @@ -62,7 +61,6 @@ from vllm.v1.engine.utils import ( EngineHandshakeMetadata, EngineZmqAddresses, - SignalCallback, get_device_indices, ) from vllm.v1.executor import Executor @@ -773,12 +771,6 @@ def _eep_send_engine_core_notification( raise NotImplementedError -class EngineShutdownState(IntEnum): - RUNNING = 0 - REQUESTED = 1 - SHUTTING_DOWN = 2 - - class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" @@ -806,7 +798,6 @@ def __init__( self.engine_index = engine_index identity = self.engine_index.to_bytes(length=2, byteorder="little") self.engines_running = False - self.shutdown_state = EngineShutdownState.RUNNING with self._perform_handshakes( handshake_address, @@ -1037,11 +1028,25 @@ def startup_handshake( def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs): """Launch EngineCore busy loop in background process.""" + # Signal handler used for graceful termination. + # SystemExit exception is only raised once to allow this and worker + # processes to terminate without error + shutdown_requested = False + # Ensure we can serialize transformer config after spawning maybe_register_config_serialize_by_value() + def signal_handler(signum, frame): + nonlocal shutdown_requested + if not shutdown_requested: + shutdown_requested = True + raise SystemExit() + + # Either SIGTERM or SIGINT will terminate the engine_core + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + engine_core: EngineCoreProc | None = None - signal_callback: SignalCallback | None = None try: vllm_config: VllmConfig = kwargs["vllm_config"] parallel_config: ParallelConfig = vllm_config.parallel_config @@ -1089,22 +1094,6 @@ def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs): engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs) assert engine_core is not None - - def wakeup_engine(): - # Wakes up idle engine via input_queue when shutdown is requested - # Not safe in a signal handler - we may interrupt the main thread - # while it is holding the non-reentrant input_queue.mutex - engine_core.input_queue.put_nowait((EngineCoreRequestType.WAKEUP, None)) - - signal_callback = SignalCallback(wakeup_engine) - - def signal_handler(signum, frame): - engine_core.shutdown_state = EngineShutdownState.REQUESTED - signal_callback.trigger() - - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) - engine_core.run_busy_loop() except SystemExit: @@ -1118,10 +1107,6 @@ def signal_handler(signum, frame): engine_core._send_engine_dead() raise e finally: - signal.signal(signal.SIGTERM, signal.SIG_DFL) - signal.signal(signal.SIGINT, signal.SIG_DFL) - if signal_callback is not None: - signal_callback.stop() if engine_core is not None: engine_core.shutdown() @@ -1136,25 +1121,21 @@ def has_work(self) -> bool: or bool(self.batch_queue) ) - def is_running(self) -> bool: - """Returns true if shutdown has not been requested.""" - return self.shutdown_state == EngineShutdownState.RUNNING - def run_busy_loop(self): """Core busy loop of the EngineCore.""" - while self._handle_shutdown(): + + # Loop until process is sent a SIGINT or SIGTERM + while True: # 1) Poll the input queue until there is work to do. self._process_input_queue() # 2) Step the engine core and return the outputs. self._process_engine_step() - raise SystemExit - def _process_input_queue(self): """Exits when an engine step needs to be performed.""" waited = False - while not self.has_work() and self.is_running(): + while not self.has_work(): # Notify callbacks waiting for engine to become idle. self._notify_idle_state_callbacks() if self.input_queue.empty(): @@ -1206,60 +1187,18 @@ def _notify_idle_state_callbacks(self) -> None: callback = self._idle_state_callbacks.pop() callback(self) - def _handle_shutdown(self) -> bool: - # Check if shutdown was requested and handle it - if self.shutdown_state == EngineShutdownState.RUNNING: - return True - - if self.shutdown_state == EngineShutdownState.REQUESTED: - shutdown_timeout = self.vllm_config.shutdown_timeout - - logger.info("Shutdown initiated (timeout=%d)", shutdown_timeout) - - if shutdown_timeout == 0: - num_requests = self.scheduler.get_num_unfinished_requests() - if num_requests > 0: - logger.info("Aborting %d requests", num_requests) - aborted_reqs = self.scheduler.finish_requests( - None, RequestStatus.FINISHED_ABORTED - ) - self._send_abort_outputs(aborted_reqs) - else: - num_requests = self.scheduler.get_num_unfinished_requests() - if num_requests > 0: - logger.info( - "Draining %d in-flight requests (timeout=%ds)", - num_requests, - shutdown_timeout, - ) - - self.shutdown_state = EngineShutdownState.SHUTTING_DOWN - - # Exit when no work remaining - if not self.has_work(): - logger.info("Shutdown complete") - return False - - return True - def _handle_client_request( self, request_type: EngineCoreRequestType, request: Any ) -> None: """Dispatch request from client.""" - if request_type == EngineCoreRequestType.WAKEUP: - return - elif request_type == EngineCoreRequestType.ADD: + if request_type == EngineCoreRequestType.ADD: req, request_wave = request - if self._reject_add_in_shutdown(req): - return self.add_request(req, request_wave) elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) elif request_type == EngineCoreRequestType.UTILITY: client_idx, call_id, method_name, args = request - if self._reject_utility_in_shutdown(client_idx, call_id, method_name): - return output = UtilityOutput(call_id) # Lazily look-up utility method so that failure will be handled/returned. get_result = lambda: (method := getattr(self, method_name)) and method( @@ -1276,27 +1215,6 @@ def _handle_client_request( "Unrecognized input request type encountered: %s", request_type ) - def _reject_add_in_shutdown(self, request: Request) -> bool: - if self.shutdown_state == EngineShutdownState.RUNNING: - return False - - logger.info("Rejecting request %s (server shutting down)", request.request_id) - self._send_abort_outputs_to_client([request.request_id], request.client_index) - return True - - def _reject_utility_in_shutdown( - self, client_idx: int, call_id: int, method_name: str - ) -> bool: - if self.shutdown_state == EngineShutdownState.RUNNING: - return False - - logger.warning("Rejecting utility call %s (server shutting down)", method_name) - output = UtilityOutput(call_id, failure_message="Server shutting down") - self.output_queue.put_nowait( - (client_idx, EngineCoreOutputs(utility_output=output)) - ) - return True - @staticmethod def _invoke_utility_method( name: str, get_result: Callable, output: UtilityOutput, enqueue_output: Callable @@ -1510,7 +1428,22 @@ def _handle_request_preproc_error(self, request: EngineCoreRequest) -> None: logger.exception( "Unexpected error pre-processing request %s", request.request_id ) - self._send_error_outputs_to_client([request.request_id], request.client_index) + self.output_queue.put_nowait( + ( + request.client_index, + EngineCoreOutputs( + engine_index=self.engine_index, + finished_requests={request.request_id}, + outputs=[ + EngineCoreOutput( + request_id=request.request_id, + new_token_ids=[], + finish_reason=FinishReason.ERROR, + ) + ], + ), + ) + ) def pause_scheduler( self, mode: PauseMode = "abort", clear_cache: bool = True @@ -1553,26 +1486,6 @@ def engine_idle_callback(engine: "EngineCoreProc", future: Future[Any]) -> None: self._idle_state_callbacks.append(partial(engine_idle_callback, future=future)) return future - def _send_finish_outputs_to_client( - self, req_ids: list[str], client_index: int, finish_reason: FinishReason - ) -> None: - outputs = [ - EngineCoreOutput(req_id, [], finish_reason=finish_reason) - for req_id in req_ids - ] - eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs) - self.output_queue.put_nowait((client_index, eco)) - - def _send_abort_outputs_to_client( - self, req_ids: list[str], client_index: int - ) -> None: - self._send_finish_outputs_to_client(req_ids, client_index, FinishReason.ABORT) - - def _send_error_outputs_to_client( - self, req_ids: list[str], client_index: int - ) -> None: - self._send_finish_outputs_to_client(req_ids, client_index, FinishReason.ERROR) - def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None: # TODO(nick) this will be moved inside the scheduler if aborted_reqs: @@ -1581,7 +1494,12 @@ def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None: for req_id, client_index in aborted_reqs: by_client[client_index].add(req_id) for client_index, req_ids in by_client.items(): - self._send_abort_outputs_to_client(list(req_ids), client_index) + outputs = [ + EngineCoreOutput(req_id, [], finish_reason=FinishReason.ABORT) + for req_id in req_ids + ] + eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs) + self.output_queue.put_nowait((client_index, eco)) class DPEngineCoreProc(EngineCoreProc): @@ -1699,7 +1617,7 @@ def run_busy_loop(self): """Core busy loop of the EngineCore for data parallel case.""" # Loop until process is sent a SIGINT or SIGTERM - while self._handle_shutdown(): + while True: # 1) Poll the input queue until there is work to do. self._process_input_queue() @@ -1747,8 +1665,6 @@ def run_busy_loop(self): self.current_wave += 1 self.step_counter = 0 - raise SystemExit - def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: # Optimization - only perform finish-sync all-reduce every 32 steps. self.step_counter += 1 diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index c1b9b8ac42b1..f199e3b8d733 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -128,7 +128,7 @@ def make_async_mp_client( return AsyncMPClient(*client_args) @abstractmethod - def shutdown(self, timeout: float | None = None) -> None: ... + def shutdown(self): ... def get_output(self) -> EngineCoreOutputs: raise NotImplementedError @@ -298,7 +298,7 @@ def abort_requests(self, request_ids: list[str]) -> None: if len(request_ids) > 0: self.engine_core.abort_requests(request_ids) - def shutdown(self, timeout: float | None = None) -> None: + def shutdown(self) -> None: self.engine_core.shutdown() def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None: @@ -390,9 +390,9 @@ def __call__(self): self.engine_dead = True if self.engine_manager is not None: - self.engine_manager.shutdown() + self.engine_manager.close() if self.coordinator is not None: - self.coordinator.shutdown() + self.coordinator.close() if isinstance(self.output_socket, zmq.asyncio.Socket): # Async case. @@ -568,7 +568,10 @@ def __init__( ) with launch_core_engines( - vllm_config, executor_class, log_stats, addresses + vllm_config, + executor_class, + log_stats, + addresses, ) as (engine_manager, coordinator, addresses): self.resources.coordinator = coordinator self.resources.engine_manager = engine_manager @@ -634,12 +637,9 @@ def __init__( if not success: self._finalizer() - def shutdown(self, timeout: float | None = None) -> None: - """Shutdown engine manager under timeout and clean up resources.""" - if self._finalizer.detach() is not None: - if self.resources.engine_manager is not None: - self.resources.engine_manager.shutdown(timeout=timeout) - self.resources() + def shutdown(self): + # Terminate background resources. + self._finalizer() def _format_exception(self, e: Exception) -> Exception: """If errored, use EngineDeadError so root cause is clear.""" @@ -683,7 +683,7 @@ def monitor_engine_cores(): sentinels = [proc.sentinel for proc in engine_processes] died = multiprocessing.connection.wait(sentinels) _self = self_ref() - if not _self or not _self._finalizer.alive or _self.resources.engine_dead: + if not _self or _self.resources.engine_dead: return _self.resources.engine_dead = True proc_name = next( diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 321f84ea2a54..a7d3c10b5752 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -3,7 +3,6 @@ import contextlib import os -import threading import weakref from collections.abc import Callable, Iterator from dataclasses import dataclass @@ -152,12 +151,11 @@ def __init__( finally: # Kill other procs if not all are running. if self.finished_procs(): - self.shutdown() + self.close() - def shutdown(self, timeout: float | None = None) -> None: - """Shutdown engine core processes with configurable timeout.""" - if self._finalizer.detach() is not None: - shutdown(self.processes, timeout=timeout) + def close(self): + """Shutdown all procs.""" + self._finalizer() def join_first(self): """Wait for any process to exit.""" @@ -175,33 +173,6 @@ def finished_procs(self) -> dict[str, int]: } -class SignalCallback: - """Safely trigger a callback from signal handler context via a dedicated thread.""" - - def __init__(self, callback: Callable[[], None]): - self._callback = callback - self._event = threading.Event() - self._stopped = False - self._thread = threading.Thread( - target=self._run, - daemon=True, - name="signal-callback", - ) - self._thread.start() - - def _run(self): - self._event.wait() - if not self._stopped: - self._callback() - - def trigger(self): - self._event.set() - - def stop(self): - self._stopped = True - self._event.set() - - @contextlib.contextmanager def set_device_control_env_var( vllm_config: VllmConfig, local_dp_rank: int @@ -797,7 +768,7 @@ def scale_down_elastic_ep( def get_run_refs(self): return self.run_refs - def shutdown(self, timeout: float | None = None) -> None: + def close(self): import ray for actor in self.local_engine_actors + self.remote_engine_actors: diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 970465089e10..3d065927ed7e 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -220,10 +220,8 @@ def __init__( # The extra processes are managed by their owners self._finalizer = weakref.finalize(self, shutdown, self.processes) - def shutdown(self, timeout: float | None = None) -> None: - """Shutdown API server processes with configurable timeout""" - if self._finalizer.detach() is not None: - shutdown(self.processes, timeout=timeout) + def close(self) -> None: + self._finalizer() def wait_for_completion_or_failure( @@ -290,30 +288,25 @@ def wait_for_completion_or_failure( except Exception as e: logger.exception("Exception occurred while running API servers: %s", str(e)) raise + finally: + logger.info("Terminating remaining processes ...") + api_server_manager.close() + if coordinator: + coordinator.close() + if engine_manager: + engine_manager.close() # Note(rob): shutdown function cannot be a bound method, # else the gc cannot collect the object. -def shutdown(procs: list[BaseProcess], timeout: float | None = None) -> None: - """Shutdown processes with timeout. - - Args: - procs: List of processes to shutdown - timeout: Maximum time in seconds to wait for graceful shutdown - """ - if timeout is None: - timeout = 0.0 - - # Allow at least 5 seconds for remaining procs to terminate. - timeout = max(timeout, 5.0) - +def shutdown(procs: list[BaseProcess]): # Shutdown the process. for proc in procs: if proc.is_alive(): proc.terminate() - # Allow time for remaining procs to terminate. - deadline = time.monotonic() + timeout + # Allow 5 seconds for remaining procs to terminate. + deadline = time.monotonic() + 5 for proc in procs: remaining = deadline - time.monotonic() if remaining <= 0: