diff --git a/tests/entrypoints/openai/test_shutdown.py b/tests/entrypoints/openai/test_shutdown.py index a2ac49bcb0b2..43f57719a383 100644 --- a/tests/entrypoints/openai/test_shutdown.py +++ b/tests/entrypoints/openai/test_shutdown.py @@ -1,14 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Integration tests for shutdown behavior, timeout, and signal handling.""" +import asyncio import signal import subprocess import sys import time +from dataclasses import dataclass, field +import httpx import openai +import psutil import pytest +from tests.utils import RemoteOpenAIServer from vllm.platforms import current_platform from vllm.utils.network_utils import get_open_port @@ -18,6 +24,101 @@ _IS_ROCM = current_platform.is_rocm() _SERVER_STARTUP_TIMEOUT = 120 _PROCESS_EXIT_TIMEOUT = 15 +_SHUTDOWN_DETECTION_TIMEOUT = 10 +_CHILD_CLEANUP_TIMEOUT = 10 + + +def _get_child_pids(parent_pid: int) -> list[int]: + try: + parent = psutil.Process(parent_pid) + return [c.pid for c in parent.children(recursive=True)] + except psutil.NoSuchProcess: + return [] + + +async def _assert_children_cleaned_up( + child_pids: list[int], + timeout: float = _CHILD_CLEANUP_TIMEOUT, +): + """Wait for child processes to exit and fail if any remain.""" + if not child_pids: + return + + deadline = time.time() + timeout + while time.time() < deadline: + still_alive = [] + for pid in child_pids: + try: + p = psutil.Process(pid) + if p.is_running() and p.status() != psutil.STATUS_ZOMBIE: + still_alive.append(pid) + except psutil.NoSuchProcess: + pass + if not still_alive: + return + await asyncio.sleep(0.5) + + pytest.fail( + f"Child processes {still_alive} still alive after {timeout}s. " + f"Process cleanup may not be working correctly." + ) + + +@dataclass +class ShutdownState: + got_503: bool = False + got_500: bool = False + requests_after_sigterm: int = 0 + aborted_requests: int = 0 + connection_errors: int = 0 + stop_requesting: bool = False + errors: list[str] = field(default_factory=list) + + +async def _concurrent_request_loop( + client: openai.AsyncOpenAI, + state: ShutdownState, + sigterm_sent: asyncio.Event | None = None, + concurrency: int = 10, +): + """Run multiple concurrent requests to keep the server busy.""" + + async def single_request(): + while not state.stop_requesting: + try: + response = await client.completions.create( + model=MODEL_NAME, + prompt="Write a story: ", + max_tokens=200, + ) + if sigterm_sent is not None and sigterm_sent.is_set(): + state.requests_after_sigterm += 1 + # Check if any choice has finish_reason='abort' + if any(choice.finish_reason == "abort" for choice in response.choices): + state.aborted_requests += 1 + except openai.APIStatusError as e: + if e.status_code == 503: + state.got_503 = True + elif e.status_code == 500: + state.got_500 = True + else: + state.errors.append(f"API error: {e}") + except (openai.APIConnectionError, httpx.RemoteProtocolError): + state.connection_errors += 1 + if sigterm_sent is not None and sigterm_sent.is_set(): + break + except Exception as e: + state.errors.append(f"Unexpected error: {e}") + break + await asyncio.sleep(0.01) + + tasks = [asyncio.create_task(single_request()) for _ in range(concurrency)] + try: + await asyncio.gather(*tasks, return_exceptions=True) + finally: + for t in tasks: + if not t.done(): + t.cancel() @pytest.mark.asyncio @@ -103,3 +204,361 @@ async def test_shutdown_on_engine_failure(): return_code = proc.wait(timeout=_PROCESS_EXIT_TIMEOUT) assert return_code is not None + + +@pytest.mark.asyncio +async def test_wait_timeout_completes_requests(): + """Verify wait timeout: new requests rejected, in-flight requests complete.""" + server_args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "256", + "--enforce-eager", + "--gpu-memory-utilization", + "0.05", + "--max-num-seqs", + "4", + "--shutdown-timeout", + "30", + ] + + with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server: + client = remote_server.get_async_client() + proc = remote_server.proc + child_pids = _get_child_pids(proc.pid) + + state = ShutdownState() + sigterm_sent = asyncio.Event() + + request_task = asyncio.create_task( + _concurrent_request_loop(client, state, sigterm_sent, concurrency=10) + ) + + await asyncio.sleep(0.5) + proc.send_signal(signal.SIGTERM) + sigterm_sent.set() + + try: + await asyncio.wait_for(request_task, timeout=_SHUTDOWN_DETECTION_TIMEOUT) + except asyncio.TimeoutError: + pass + finally: + state.stop_requesting = True + if not request_task.done(): + request_task.cancel() + await asyncio.gather(request_task, return_exceptions=True) + + # wait timeout should complete in-flight requests + assert state.requests_after_sigterm > 0, ( + f"Wait timeout should complete in-flight requests. " + f"503: {state.got_503}, 500: {state.got_500}, " + f"conn_errors: {state.connection_errors}, errors: {state.errors}" + ) + # server must stop accepting new requests (503, 500, or connection close) + assert state.got_503 or state.got_500 or state.connection_errors > 0, ( + f"Server should stop accepting requests. " + f"completed: {state.requests_after_sigterm}, errors: {state.errors}" + ) + + await _assert_children_cleaned_up(child_pids) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("wait_for_engine_idle", [0.0, 2.0]) +async def test_abort_timeout_exits_quickly(wait_for_engine_idle: float): + server_args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "256", + "--enforce-eager", + "--gpu-memory-utilization", + "0.05", + "--max-num-seqs", + "4", + "--shutdown-timeout", + "0", + ] + + with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server: + proc = remote_server.proc + child_pids = _get_child_pids(proc.pid) + + if wait_for_engine_idle > 0: + client = remote_server.get_async_client() + # Send requests to ensure engine is fully initialized + for _ in range(2): + await client.completions.create( + model=MODEL_NAME, + prompt="Test request: ", + max_tokens=10, + ) + # Wait for engine to become idle + await asyncio.sleep(wait_for_engine_idle) + + start_time = time.time() + proc.send_signal(signal.SIGTERM) + + # abort timeout (0) should exit promptly + for _ in range(20): + if proc.poll() is not None: + break + time.sleep(0.1) + + if proc.poll() is None: + proc.kill() + proc.wait(timeout=5) + pytest.fail("Process did not exit after SIGTERM with abort timeout") + + exit_time = time.time() - start_time + assert exit_time < 2, f"Default shutdown took too long: {exit_time:.1f}s" + assert proc.returncode in (0, -15, None), f"Unexpected: {proc.returncode}" + + await _assert_children_cleaned_up(child_pids) + + +@pytest.mark.asyncio +async def test_wait_timeout_with_short_duration(): + """Verify server exits cleanly with a short wait timeout.""" + wait_timeout = 3 + server_args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "256", + "--enforce-eager", + "--gpu-memory-utilization", + "0.05", + "--max-num-seqs", + "4", + "--shutdown-timeout", + str(wait_timeout), + ] + + with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server: + client = remote_server.get_async_client() + proc = remote_server.proc + child_pids = _get_child_pids(proc.pid) + + state = ShutdownState() + request_task = asyncio.create_task( + _concurrent_request_loop(client, state, concurrency=3) + ) + + await asyncio.sleep(0.5) + + start_time = time.time() + proc.send_signal(signal.SIGTERM) + + # server should exit within wait_timeout + buffer + max_wait = wait_timeout + 15 + for _ in range(int(max_wait * 10)): + if proc.poll() is not None: + break + time.sleep(0.1) + + exit_time = time.time() - start_time + + state.stop_requesting = True + if not request_task.done(): + request_task.cancel() + await asyncio.gather(request_task, return_exceptions=True) + + if proc.poll() is None: + proc.kill() + proc.wait(timeout=5) + pytest.fail(f"Process did not exit within {max_wait}s after SIGTERM") + + assert exit_time < wait_timeout + 10, ( + f"Took too long to exit ({exit_time:.1f}s), expected <{wait_timeout + 10}s" + ) + assert proc.returncode in (0, -15, None), f"Unexpected: {proc.returncode}" + + await _assert_children_cleaned_up(child_pids) + + +@pytest.mark.asyncio +async def test_abort_timeout_fails_inflight_requests(): + """Verify abort timeout (0) immediately aborts in-flight requests.""" + server_args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "256", + "--enforce-eager", + "--gpu-memory-utilization", + "0.05", + "--max-num-seqs", + "4", + "--shutdown-timeout", + "0", + ] + + with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server: + client = remote_server.get_async_client() + proc = remote_server.proc + child_pids = _get_child_pids(proc.pid) + + state = ShutdownState() + sigterm_sent = asyncio.Event() + + request_task = asyncio.create_task( + _concurrent_request_loop(client, state, sigterm_sent, concurrency=10) + ) + + await asyncio.sleep(0.5) + + proc.send_signal(signal.SIGTERM) + sigterm_sent.set() + + try: + await asyncio.wait_for(request_task, timeout=5) + except asyncio.TimeoutError: + pass + finally: + state.stop_requesting = True + if not request_task.done(): + request_task.cancel() + await asyncio.gather(request_task, return_exceptions=True) + + # With abort timeout (0), requests should be aborted (finish_reason='abort') + # or rejected (connection errors or API errors) + assert ( + state.aborted_requests > 0 + or state.connection_errors > 0 + or state.got_500 + or state.got_503 + ), ( + f"Abort timeout should cause request aborts or failures. " + f"aborted: {state.aborted_requests}, " + f"503: {state.got_503}, 500: {state.got_500}, " + f"conn_errors: {state.connection_errors}, " + f"completed: {state.requests_after_sigterm}" + ) + + # Verify fast shutdown + start_time = time.time() + for _ in range(100): + if proc.poll() is not None: + break + time.sleep(0.1) + + exit_time = time.time() - start_time + assert exit_time < 10, f"Abort timeout shutdown took too long: {exit_time:.1f}s" + + await _assert_children_cleaned_up(child_pids) + + +@pytest.mark.asyncio +async def test_request_rejection_during_shutdown(): + """Verify new requests are rejected with error during shutdown.""" + server_args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "256", + "--enforce-eager", + "--gpu-memory-utilization", + "0.05", + "--max-num-seqs", + "4", + "--shutdown-timeout", + "30", + ] + + with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server: + client = remote_server.get_async_client() + proc = remote_server.proc + child_pids = _get_child_pids(proc.pid) + + proc.send_signal(signal.SIGTERM) + + await asyncio.sleep(1.0) + + # Try to send new requests - they should be rejected + rejected_count = 0 + for _ in range(10): + try: + await client.completions.create( + model=MODEL_NAME, prompt="Hello", max_tokens=10 + ) + except ( + openai.APIStatusError, + openai.APIConnectionError, + httpx.RemoteProtocolError, + ): + rejected_count += 1 + await asyncio.sleep(0.1) + + assert rejected_count > 0, ( + f"Expected requests to be rejected during shutdown, " + f"but {rejected_count} were rejected out of 10" + ) + + await _assert_children_cleaned_up(child_pids) + + +@pytest.mark.asyncio +async def test_multi_api_server_shutdown(): + """Verify shutdown works with multiple API servers.""" + server_args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "256", + "--enforce-eager", + "--gpu-memory-utilization", + "0.05", + "--max-num-seqs", + "4", + "--shutdown-timeout", + "30", + "--api-server-count", + "2", + ] + + with RemoteOpenAIServer(MODEL_NAME, server_args, auto_port=True) as remote_server: + client = remote_server.get_async_client() + proc = remote_server.proc + child_pids = _get_child_pids(proc.pid) + + assert len(child_pids) >= 2, ( + f"Expected at least 2 child processes, got {len(child_pids)}" + ) + + state = ShutdownState() + sigterm_sent = asyncio.Event() + + # Start concurrent requests across both API servers + request_task = asyncio.create_task( + _concurrent_request_loop(client, state, sigterm_sent, concurrency=8) + ) + + await asyncio.sleep(0.5) + + # Send SIGTERM to parent - should propagate to all children + proc.send_signal(signal.SIGTERM) + sigterm_sent.set() + + try: + await asyncio.wait_for(request_task, timeout=_SHUTDOWN_DETECTION_TIMEOUT) + except asyncio.TimeoutError: + pass + finally: + state.stop_requesting = True + if not request_task.done(): + request_task.cancel() + await asyncio.gather(request_task, return_exceptions=True) + + for _ in range(300): # up to 30 seconds + if proc.poll() is not None: + break + time.sleep(0.1) + + if proc.poll() is None: + proc.kill() + proc.wait(timeout=5) + pytest.fail("Process did not exit after SIGTERM") + + await _assert_children_cleaned_up(child_pids) diff --git a/tests/entrypoints/test_api_server_process_manager.py b/tests/entrypoints/test_api_server_process_manager.py index 3fadbf2ef0dd..3820fdefb194 100644 --- a/tests/entrypoints/test_api_server_process_manager.py +++ b/tests/entrypoints/test_api_server_process_manager.py @@ -79,7 +79,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update): finally: # Always clean up the processes print("Cleaning up processes...") - manager.close() + manager.shutdown() # Give processes time to terminate time.sleep(0.2) @@ -111,6 +111,8 @@ def run_with_exception_capture(): wait_for_completion_or_failure(api_server_manager=manager) except Exception as e: result["exception"] = e + finally: + manager.shutdown() # Start a thread to run wait_for_completion_or_failure wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True) @@ -143,7 +145,7 @@ def run_with_exception_capture(): assert not proc.is_alive(), f"Process {i} should not be alive" finally: - manager.close() + manager.shutdown() time.sleep(0.2) @@ -174,11 +176,14 @@ def test_normal_completion(api_server_args): # since all processes have already # terminated, it should return immediately # with no error - wait_for_completion_or_failure(api_server_manager=manager) + try: + wait_for_completion_or_failure(api_server_manager=manager) + finally: + manager.shutdown() finally: # Clean up just in case - manager.close() + manager.shutdown() time.sleep(0.2) @@ -201,7 +206,7 @@ class MockCoordinator: def __init__(self, proc): self.proc = proc - def close(self): + def shutdown(self): if self.proc.is_alive(): self.proc.terminate() self.proc.join(timeout=0.5) @@ -226,6 +231,9 @@ def run_with_exception_capture(): ) except Exception as e: result["exception"] = e + finally: + manager.shutdown() + mock_coordinator.shutdown() # Start a thread to run wait_for_completion_or_failure wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True) @@ -259,6 +267,6 @@ def run_with_exception_capture(): finally: # Clean up - manager.close() - mock_coordinator.close() + manager.shutdown() + mock_coordinator.shutdown() time.sleep(0.2) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 4df1015c0e13..a7c43135388c 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -327,6 +327,12 @@ class VllmConfig: weight_transfer_config: WeightTransferConfig | None = None """The configurations for weight transfer during RL training.""" + shutdown_timeout: int = Field(default=0, ge=0) + """Shutdown grace period for in-flight requests. Shutdown will be delayed for + up to this amount of time to allow already-running requests to complete. Any + remaining requests are aborted once the timeout is reached. + """ + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 93384fd78cd7..ebbe39de1d17 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -607,6 +607,8 @@ class EngineArgs: kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend tokens_only: bool = False + shutdown_timeout: int = 0 + weight_transfer_config: WeightTransferConfig | None = get_field( VllmConfig, "weight_transfer_config", @@ -1306,6 +1308,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=False, action=argparse.BooleanOptionalAction, ) + + parser.add_argument( + "--shutdown-timeout", + type=int, + default=0, + help="Shutdown timeout in seconds. 0 = abort, >0 = wait.", + ) + return parser @classmethod @@ -1914,6 +1924,7 @@ def create_engine_config( optimization_level=self.optimization_level, performance_mode=self.performance_mode, weight_transfer_config=self.weight_transfer_config, + shutdown_timeout=self.shutdown_timeout, ) return config diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index ea2bf5303b5f..0b3b29cd6c1f 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -200,6 +200,11 @@ async def is_paused(self) -> bool: """Return whether the engine is currently paused.""" ... + @abstractmethod + def shutdown(self, timeout: float | None = None) -> None: + """Shutdown the engine with optional timeout.""" + ... + async def scale_elastic_ep( self, new_data_parallel_size: int, drain_timeout: int = 300 ) -> None: diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 944fb88a00cd..04a07ea84428 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -3,6 +3,7 @@ import argparse import signal +import time import uvloop @@ -211,8 +212,12 @@ def signal_handler(signum, frame): try: engine_manager.join_first() finally: + timeout = None + if shutdown_requested: + timeout = vllm_config.shutdown_timeout + logger.info("Waiting up to %d seconds for processes to exit", timeout) + engine_manager.shutdown(timeout=timeout) logger.info("Shutting down.") - engine_manager.close() def run_multi_api_server(args: argparse.Namespace): @@ -229,6 +234,19 @@ def run_multi_api_server(args: argparse.Namespace): if num_api_servers > 1: setup_multiprocess_prometheus() + shutdown_requested = False + + # Catch SIGTERM and SIGINT to allow graceful shutdown. + def signal_handler(signum, frame): + nonlocal shutdown_requested + logger.debug("Received %d signal.", signum) + if not shutdown_requested: + shutdown_requested = True + raise SystemExit + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + listen_address, sock = setup_server(args) engine_args = vllm.AsyncEngineArgs.from_cli_args(args) @@ -290,11 +308,29 @@ def run_multi_api_server(args: argparse.Namespace): api_server_manager = APIServerProcessManager(**api_server_manager_kwargs) # Wait for API servers - wait_for_completion_or_failure( - api_server_manager=api_server_manager, - engine_manager=local_engine_manager, - coordinator=coordinator, - ) + try: + wait_for_completion_or_failure( + api_server_manager=api_server_manager, + engine_manager=local_engine_manager, + coordinator=coordinator, + ) + finally: + timeout = shutdown_by = None + if shutdown_requested: + timeout = vllm_config.shutdown_timeout + shutdown_by = time.monotonic() + timeout + logger.info("Waiting up to %d seconds for processes to exit", timeout) + + def to_timeout(deadline: float | None) -> float | None: + return ( + deadline if deadline is None else max(deadline - time.monotonic(), 0.0) + ) + + api_server_manager.shutdown(timeout=timeout) + if local_engine_manager: + local_engine_manager.shutdown(timeout=to_timeout(shutdown_by)) + if coordinator: + coordinator.shutdown(timeout=to_timeout(shutdown_by)) def run_api_server_worker_proc( diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index b442fc70cdb0..8caeb80836f9 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -4,6 +4,7 @@ import asyncio import signal import socket +from functools import partial from typing import Any import uvicorn @@ -91,12 +92,10 @@ async def serve_http( ) ) + shutdown_event = asyncio.Event() + def signal_handler() -> None: - # prevents the uvicorn signal handler to exit early - server_task.cancel() - watchdog_task.cancel() - if ssl_cert_refresher: - ssl_cert_refresher.stop() + shutdown_event.set() async def dummy_shutdown() -> None: pass @@ -104,6 +103,24 @@ async def dummy_shutdown() -> None: loop.add_signal_handler(signal.SIGINT, signal_handler) loop.add_signal_handler(signal.SIGTERM, signal_handler) + async def handle_shutdown() -> None: + await shutdown_event.wait() + + engine_client = app.state.engine_client + timeout = engine_client.vllm_config.shutdown_timeout + + await loop.run_in_executor( + None, partial(engine_client.shutdown, timeout=timeout) + ) + + server.should_exit = True + server_task.cancel() + watchdog_task.cancel() + if ssl_cert_refresher: + ssl_cert_refresher.stop() + + shutdown_task = loop.create_task(handle_shutdown()) + try: await server_task return dummy_shutdown() @@ -120,6 +137,7 @@ async def dummy_shutdown() -> None: logger.info("Shutting down FastAPI HTTP server.") return server.shutdown() finally: + shutdown_task.cancel() watchdog_task.cancel() diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 07c98513af68..969b441dab78 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -238,6 +238,8 @@ class EngineCoreRequestType(enum.Enum): UTILITY = b"\x03" # Sentinel used within EngineCoreProc. EXECUTOR_FAILED = b"\x04" + # Sentinel to wake up input_queue.get() during shutdown. + WAKEUP = b"\x05" class ReconfigureDistributedRequest(msgspec.Struct): diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 6be0a07baeb2..a9c42e78e53b 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -264,16 +264,15 @@ def from_engine_args( def __del__(self): self.shutdown() - def shutdown(self): + def shutdown(self, timeout: float | None = None) -> None: """Shutdown, cleaning up the background proc and IPC.""" - shutdown_prometheus() if renderer := getattr(self, "renderer", None): renderer.shutdown() if engine_core := getattr(self, "engine_core", None): - engine_core.shutdown() + engine_core.shutdown(timeout=timeout) handler = getattr(self, "output_handler", None) if handler is not None: diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 44a346350fc8..0d07f29a5cb4 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -104,8 +104,10 @@ def get_engine_socket_addresses(self) -> tuple[str, str]: """Returns tuple of ZMQ input address, output address.""" return self.coord_in_address, self.coord_out_address - def close(self): - self._finalizer() + def shutdown(self, timeout: float | None = None) -> None: + """Shutdown coordinator process with configurable timeout.""" + if self._finalizer.detach() is not None: + shutdown([self.proc], timeout=timeout) class EngineState: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d8e002da5c1d..4bbaafed3ed5 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -9,6 +9,7 @@ from collections.abc import Callable, Generator from concurrent.futures import Future from contextlib import ExitStack, contextmanager +from enum import IntEnum from functools import partial from inspect import isclass, signature from logging import DEBUG @@ -61,6 +62,7 @@ from vllm.v1.engine.utils import ( EngineHandshakeMetadata, EngineZmqAddresses, + SignalCallback, get_device_indices, ) from vllm.v1.executor import Executor @@ -767,6 +769,12 @@ def _eep_send_engine_core_notification( raise NotImplementedError +class EngineShutdownState(IntEnum): + RUNNING = 0 + REQUESTED = 1 + SHUTTING_DOWN = 2 + + class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" @@ -794,6 +802,7 @@ def __init__( self.engine_index = engine_index identity = self.engine_index.to_bytes(length=2, byteorder="little") self.engines_running = False + self.shutdown_state = EngineShutdownState.RUNNING with self._perform_handshakes( handshake_address, @@ -1024,25 +1033,11 @@ def startup_handshake( def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs): """Launch EngineCore busy loop in background process.""" - # Signal handler used for graceful termination. - # SystemExit exception is only raised once to allow this and worker - # processes to terminate without error - shutdown_requested = False - # Ensure we can serialize transformer config after spawning maybe_register_config_serialize_by_value() - def signal_handler(signum, frame): - nonlocal shutdown_requested - if not shutdown_requested: - shutdown_requested = True - raise SystemExit() - - # Either SIGTERM or SIGINT will terminate the engine_core - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) - engine_core: EngineCoreProc | None = None + signal_callback: SignalCallback | None = None try: vllm_config: VllmConfig = kwargs["vllm_config"] parallel_config: ParallelConfig = vllm_config.parallel_config @@ -1090,6 +1085,22 @@ def signal_handler(signum, frame): engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs) assert engine_core is not None + + def wakeup_engine(): + # Wakes up idle engine via input_queue when shutdown is requested + # Not safe in a signal handler - we may interrupt the main thread + # while it is holding the non-reentrant input_queue.mutex + engine_core.input_queue.put_nowait((EngineCoreRequestType.WAKEUP, None)) + + signal_callback = SignalCallback(wakeup_engine) + + def signal_handler(signum, frame): + engine_core.shutdown_state = EngineShutdownState.REQUESTED + signal_callback.trigger() + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + engine_core.run_busy_loop() except SystemExit: @@ -1103,6 +1114,10 @@ def signal_handler(signum, frame): engine_core._send_engine_dead() raise e finally: + signal.signal(signal.SIGTERM, signal.SIG_DFL) + signal.signal(signal.SIGINT, signal.SIG_DFL) + if signal_callback is not None: + signal_callback.stop() if engine_core is not None: engine_core.shutdown() @@ -1117,21 +1132,25 @@ def has_work(self) -> bool: or bool(self.batch_queue) ) + def is_running(self) -> bool: + """Returns true if shutdown has not been requested.""" + return self.shutdown_state == EngineShutdownState.RUNNING + def run_busy_loop(self): """Core busy loop of the EngineCore.""" - - # Loop until process is sent a SIGINT or SIGTERM - while True: + while self._handle_shutdown(): # 1) Poll the input queue until there is work to do. self._process_input_queue() # 2) Step the engine core and return the outputs. self._process_engine_step() + raise SystemExit + def _process_input_queue(self): """Exits when an engine step needs to be performed.""" waited = False - while not self.has_work(): + while not self.has_work() and self.is_running(): # Notify callbacks waiting for engine to become idle. self._notify_idle_state_callbacks() if self.input_queue.empty(): @@ -1183,18 +1202,60 @@ def _notify_idle_state_callbacks(self) -> None: callback = self._idle_state_callbacks.pop() callback(self) + def _handle_shutdown(self) -> bool: + # Check if shutdown was requested and handle it + if self.shutdown_state == EngineShutdownState.RUNNING: + return True + + if self.shutdown_state == EngineShutdownState.REQUESTED: + shutdown_timeout = self.vllm_config.shutdown_timeout + + logger.info("Shutdown initiated (timeout=%d)", shutdown_timeout) + + if shutdown_timeout == 0: + num_requests = self.scheduler.get_num_unfinished_requests() + if num_requests > 0: + logger.info("Aborting %d requests", num_requests) + aborted_reqs = self.scheduler.finish_requests( + None, RequestStatus.FINISHED_ABORTED + ) + self._send_abort_outputs(aborted_reqs) + else: + num_requests = self.scheduler.get_num_unfinished_requests() + if num_requests > 0: + logger.info( + "Draining %d in-flight requests (timeout=%ds)", + num_requests, + shutdown_timeout, + ) + + self.shutdown_state = EngineShutdownState.SHUTTING_DOWN + + # Exit when no work remaining + if not self.has_work(): + logger.info("Shutdown complete") + return False + + return True + def _handle_client_request( self, request_type: EngineCoreRequestType, request: Any ) -> None: """Dispatch request from client.""" - if request_type == EngineCoreRequestType.ADD: + if request_type == EngineCoreRequestType.WAKEUP: + return + elif request_type == EngineCoreRequestType.ADD: req, request_wave = request + if self._reject_add_in_shutdown(req): + return self.add_request(req, request_wave) elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) elif request_type == EngineCoreRequestType.UTILITY: client_idx, call_id, method_name, args = request + if self._reject_utility_in_shutdown(client_idx, call_id, method_name): + return output = UtilityOutput(call_id) # Lazily look-up utility method so that failure will be handled/returned. get_result = lambda: (method := getattr(self, method_name)) and method( @@ -1211,6 +1272,27 @@ def _handle_client_request( "Unrecognized input request type encountered: %s", request_type ) + def _reject_add_in_shutdown(self, request: Request) -> bool: + if self.shutdown_state == EngineShutdownState.RUNNING: + return False + + logger.info("Rejecting request %s (server shutting down)", request.request_id) + self._send_abort_outputs_to_client([request.request_id], request.client_index) + return True + + def _reject_utility_in_shutdown( + self, client_idx: int, call_id: int, method_name: str + ) -> bool: + if self.shutdown_state == EngineShutdownState.RUNNING: + return False + + logger.warning("Rejecting utility call %s (server shutting down)", method_name) + output = UtilityOutput(call_id, failure_message="Server shutting down") + self.output_queue.put_nowait( + (client_idx, EngineCoreOutputs(utility_output=output)) + ) + return True + @staticmethod def _invoke_utility_method( name: str, get_result: Callable, output: UtilityOutput, enqueue_output: Callable @@ -1424,22 +1506,7 @@ def _handle_request_preproc_error(self, request: EngineCoreRequest) -> None: logger.exception( "Unexpected error pre-processing request %s", request.request_id ) - self.output_queue.put_nowait( - ( - request.client_index, - EngineCoreOutputs( - engine_index=self.engine_index, - finished_requests={request.request_id}, - outputs=[ - EngineCoreOutput( - request_id=request.request_id, - new_token_ids=[], - finish_reason=FinishReason.ERROR, - ) - ], - ), - ) - ) + self._send_error_outputs_to_client([request.request_id], request.client_index) def pause_scheduler( self, mode: PauseMode = "abort", clear_cache: bool = True @@ -1482,6 +1549,26 @@ def engine_idle_callback(engine: "EngineCoreProc", future: Future[Any]) -> None: self._idle_state_callbacks.append(partial(engine_idle_callback, future=future)) return future + def _send_finish_outputs_to_client( + self, req_ids: list[str], client_index: int, finish_reason: FinishReason + ) -> None: + outputs = [ + EngineCoreOutput(req_id, [], finish_reason=finish_reason) + for req_id in req_ids + ] + eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs) + self.output_queue.put_nowait((client_index, eco)) + + def _send_abort_outputs_to_client( + self, req_ids: list[str], client_index: int + ) -> None: + self._send_finish_outputs_to_client(req_ids, client_index, FinishReason.ABORT) + + def _send_error_outputs_to_client( + self, req_ids: list[str], client_index: int + ) -> None: + self._send_finish_outputs_to_client(req_ids, client_index, FinishReason.ERROR) + def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None: # TODO(nick) this will be moved inside the scheduler if aborted_reqs: @@ -1490,12 +1577,7 @@ def _send_abort_outputs(self, aborted_reqs: list[tuple[str, int]]) -> None: for req_id, client_index in aborted_reqs: by_client[client_index].add(req_id) for client_index, req_ids in by_client.items(): - outputs = [ - EngineCoreOutput(req_id, [], finish_reason=FinishReason.ABORT) - for req_id in req_ids - ] - eco = EngineCoreOutputs(finished_requests=req_ids, outputs=outputs) - self.output_queue.put_nowait((client_index, eco)) + self._send_abort_outputs_to_client(list(req_ids), client_index) class DPEngineCoreProc(EngineCoreProc): @@ -1613,7 +1695,7 @@ def run_busy_loop(self): """Core busy loop of the EngineCore for data parallel case.""" # Loop until process is sent a SIGINT or SIGTERM - while True: + while self._handle_shutdown(): # 1) Poll the input queue until there is work to do. self._process_input_queue() @@ -1661,6 +1743,8 @@ def run_busy_loop(self): self.current_wave += 1 self.step_counter = 0 + raise SystemExit + def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: # Optimization - only perform finish-sync all-reduce every 32 steps. self.step_counter += 1 diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 7e1f1cf418bf..4ff51103a681 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -127,7 +127,7 @@ def make_async_mp_client( return AsyncMPClient(*client_args) @abstractmethod - def shutdown(self): ... + def shutdown(self, timeout: float | None = None) -> None: ... def get_output(self) -> EngineCoreOutputs: raise NotImplementedError @@ -297,7 +297,7 @@ def abort_requests(self, request_ids: list[str]) -> None: if len(request_ids) > 0: self.engine_core.abort_requests(request_ids) - def shutdown(self) -> None: + def shutdown(self, timeout: float | None = None) -> None: self.engine_core.shutdown() def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None: @@ -389,9 +389,9 @@ def __call__(self): self.engine_dead = True if self.engine_manager is not None: - self.engine_manager.close() + self.engine_manager.shutdown() if self.coordinator is not None: - self.coordinator.close() + self.coordinator.shutdown() if isinstance(self.output_socket, zmq.asyncio.Socket): # Async case. @@ -636,9 +636,12 @@ def __init__( if not success: self._finalizer() - def shutdown(self): - # Terminate background resources. - self._finalizer() + def shutdown(self, timeout: float | None = None) -> None: + """Shutdown engine manager under timeout and clean up resources.""" + self._finalizer.detach() + if self.resources.engine_manager is not None: + self.resources.engine_manager.shutdown(timeout=timeout) + self.resources() def _format_exception(self, e: Exception) -> Exception: """If errored, use EngineDeadError so root cause is clear.""" diff --git a/vllm/v1/engine/launch.py b/vllm/v1/engine/launch.py index c3d9f32f39b5..2d92db4c93d9 100644 --- a/vllm/v1/engine/launch.py +++ b/vllm/v1/engine/launch.py @@ -119,6 +119,9 @@ async def resume_generation(self) -> None: async def is_paused(self) -> bool: return False + def shutdown(self, timeout: float | None = None) -> None: + pass + async def encode( self, prompt: PromptType | ProcessorInputs, diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index a7d3c10b5752..321f84ea2a54 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -3,6 +3,7 @@ import contextlib import os +import threading import weakref from collections.abc import Callable, Iterator from dataclasses import dataclass @@ -151,11 +152,12 @@ def __init__( finally: # Kill other procs if not all are running. if self.finished_procs(): - self.close() + self.shutdown() - def close(self): - """Shutdown all procs.""" - self._finalizer() + def shutdown(self, timeout: float | None = None) -> None: + """Shutdown engine core processes with configurable timeout.""" + if self._finalizer.detach() is not None: + shutdown(self.processes, timeout=timeout) def join_first(self): """Wait for any process to exit.""" @@ -173,6 +175,33 @@ def finished_procs(self) -> dict[str, int]: } +class SignalCallback: + """Safely trigger a callback from signal handler context via a dedicated thread.""" + + def __init__(self, callback: Callable[[], None]): + self._callback = callback + self._event = threading.Event() + self._stopped = False + self._thread = threading.Thread( + target=self._run, + daemon=True, + name="signal-callback", + ) + self._thread.start() + + def _run(self): + self._event.wait() + if not self._stopped: + self._callback() + + def trigger(self): + self._event.set() + + def stop(self): + self._stopped = True + self._event.set() + + @contextlib.contextmanager def set_device_control_env_var( vllm_config: VllmConfig, local_dp_rank: int @@ -768,7 +797,7 @@ def scale_down_elastic_ep( def get_run_refs(self): return self.run_refs - def close(self): + def shutdown(self, timeout: float | None = None) -> None: import ray for actor in self.local_engine_actors + self.remote_engine_actors: diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 3d065927ed7e..970465089e10 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -220,8 +220,10 @@ def __init__( # The extra processes are managed by their owners self._finalizer = weakref.finalize(self, shutdown, self.processes) - def close(self) -> None: - self._finalizer() + def shutdown(self, timeout: float | None = None) -> None: + """Shutdown API server processes with configurable timeout""" + if self._finalizer.detach() is not None: + shutdown(self.processes, timeout=timeout) def wait_for_completion_or_failure( @@ -288,25 +290,30 @@ def wait_for_completion_or_failure( except Exception as e: logger.exception("Exception occurred while running API servers: %s", str(e)) raise - finally: - logger.info("Terminating remaining processes ...") - api_server_manager.close() - if coordinator: - coordinator.close() - if engine_manager: - engine_manager.close() # Note(rob): shutdown function cannot be a bound method, # else the gc cannot collect the object. -def shutdown(procs: list[BaseProcess]): +def shutdown(procs: list[BaseProcess], timeout: float | None = None) -> None: + """Shutdown processes with timeout. + + Args: + procs: List of processes to shutdown + timeout: Maximum time in seconds to wait for graceful shutdown + """ + if timeout is None: + timeout = 0.0 + + # Allow at least 5 seconds for remaining procs to terminate. + timeout = max(timeout, 5.0) + # Shutdown the process. for proc in procs: if proc.is_alive(): proc.terminate() - # Allow 5 seconds for remaining procs to terminate. - deadline = time.monotonic() + 5 + # Allow time for remaining procs to terminate. + deadline = time.monotonic() + timeout for proc in procs: remaining = deadline - time.monotonic() if remaining <= 0: