Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
459 changes: 459 additions & 0 deletions tests/entrypoints/openai/test_shutdown.py

Large diffs are not rendered by default.

22 changes: 15 additions & 7 deletions tests/entrypoints/test_api_server_process_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update):
finally:
# Always clean up the processes
print("Cleaning up processes...")
manager.close()
manager.shutdown()

# Give processes time to terminate
time.sleep(0.2)
Expand Down Expand Up @@ -111,6 +111,8 @@ def run_with_exception_capture():
wait_for_completion_or_failure(api_server_manager=manager)
except Exception as e:
result["exception"] = e
finally:
manager.shutdown()

# Start a thread to run wait_for_completion_or_failure
wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True)
Expand Down Expand Up @@ -143,7 +145,7 @@ def run_with_exception_capture():
assert not proc.is_alive(), f"Process {i} should not be alive"

finally:
manager.close()
manager.shutdown()
time.sleep(0.2)


Expand Down Expand Up @@ -174,11 +176,14 @@ def test_normal_completion(api_server_args):
# since all processes have already
# terminated, it should return immediately
# with no error
wait_for_completion_or_failure(api_server_manager=manager)
try:
wait_for_completion_or_failure(api_server_manager=manager)
finally:
manager.shutdown()

finally:
# Clean up just in case
manager.close()
manager.shutdown()
time.sleep(0.2)


Expand All @@ -201,7 +206,7 @@ class MockCoordinator:
def __init__(self, proc):
self.proc = proc

def close(self):
def shutdown(self):
if self.proc.is_alive():
self.proc.terminate()
self.proc.join(timeout=0.5)
Expand All @@ -226,6 +231,9 @@ def run_with_exception_capture():
)
except Exception as e:
result["exception"] = e
finally:
manager.shutdown()
mock_coordinator.shutdown()

# Start a thread to run wait_for_completion_or_failure
wait_thread = threading.Thread(target=run_with_exception_capture, daemon=True)
Expand Down Expand Up @@ -259,6 +267,6 @@ def run_with_exception_capture():

finally:
# Clean up
manager.close()
mock_coordinator.close()
manager.shutdown()
mock_coordinator.shutdown()
time.sleep(0.2)
6 changes: 6 additions & 0 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,12 @@ class VllmConfig:
weight_transfer_config: WeightTransferConfig | None = None
"""The configurations for weight transfer during RL training."""

shutdown_timeout: int = Field(default=0, ge=0)
"""Shutdown grace period for in-flight requests. Shutdown will be delayed for
up to this amount of time to allow already-running requests to complete. Any
remaining requests are aborted once the timeout is reached.
"""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
Expand Down
11 changes: 11 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,8 @@ class EngineArgs:
kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
tokens_only: bool = False

shutdown_timeout: int = 0

weight_transfer_config: WeightTransferConfig | None = get_field(
VllmConfig,
"weight_transfer_config",
Expand Down Expand Up @@ -1306,6 +1308,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=False,
action=argparse.BooleanOptionalAction,
)

parser.add_argument(
"--shutdown-timeout",
type=int,
default=0,
help="Shutdown timeout in seconds. 0 = abort, >0 = wait.",
)

return parser

@classmethod
Expand Down Expand Up @@ -1914,6 +1924,7 @@ def create_engine_config(
optimization_level=self.optimization_level,
performance_mode=self.performance_mode,
weight_transfer_config=self.weight_transfer_config,
shutdown_timeout=self.shutdown_timeout,
)

return config
Expand Down
5 changes: 5 additions & 0 deletions vllm/engine/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,11 @@ async def is_paused(self) -> bool:
"""Return whether the engine is currently paused."""
...

@abstractmethod
def shutdown(self, timeout: float | None = None) -> None:
"""Shutdown the engine with optional timeout."""
...

async def scale_elastic_ep(
self, new_data_parallel_size: int, drain_timeout: int = 300
) -> None:
Expand Down
48 changes: 42 additions & 6 deletions vllm/entrypoints/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import argparse
import signal
import time

import uvloop

Expand Down Expand Up @@ -211,8 +212,12 @@ def signal_handler(signum, frame):
try:
engine_manager.join_first()
finally:
timeout = None
if shutdown_requested:
timeout = vllm_config.shutdown_timeout
logger.info("Waiting up to %d seconds for processes to exit", timeout)
engine_manager.shutdown(timeout=timeout)
logger.info("Shutting down.")
engine_manager.close()


def run_multi_api_server(args: argparse.Namespace):
Expand All @@ -229,6 +234,19 @@ def run_multi_api_server(args: argparse.Namespace):
if num_api_servers > 1:
setup_multiprocess_prometheus()

shutdown_requested = False

# Catch SIGTERM and SIGINT to allow graceful shutdown.
def signal_handler(signum, frame):
nonlocal shutdown_requested
logger.debug("Received %d signal.", signum)
if not shutdown_requested:
shutdown_requested = True
raise SystemExit

signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)

listen_address, sock = setup_server(args)

engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
Expand Down Expand Up @@ -290,11 +308,29 @@ def run_multi_api_server(args: argparse.Namespace):
api_server_manager = APIServerProcessManager(**api_server_manager_kwargs)

# Wait for API servers
wait_for_completion_or_failure(
api_server_manager=api_server_manager,
engine_manager=local_engine_manager,
coordinator=coordinator,
)
try:
wait_for_completion_or_failure(
api_server_manager=api_server_manager,
engine_manager=local_engine_manager,
coordinator=coordinator,
)
finally:
timeout = shutdown_by = None
if shutdown_requested:
timeout = vllm_config.shutdown_timeout
shutdown_by = time.monotonic() + timeout
logger.info("Waiting up to %d seconds for processes to exit", timeout)

def to_timeout(deadline: float | None) -> float | None:
return (
deadline if deadline is None else max(deadline - time.monotonic(), 0.0)
)

api_server_manager.shutdown(timeout=timeout)
if local_engine_manager:
local_engine_manager.shutdown(timeout=to_timeout(shutdown_by))
if coordinator:
coordinator.shutdown(timeout=to_timeout(shutdown_by))


def run_api_server_worker_proc(
Expand Down
28 changes: 23 additions & 5 deletions vllm/entrypoints/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import asyncio
import signal
import socket
from functools import partial
from typing import Any

import uvicorn
Expand Down Expand Up @@ -91,19 +92,35 @@ async def serve_http(
)
)

shutdown_event = asyncio.Event()

def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
server_task.cancel()
watchdog_task.cancel()
if ssl_cert_refresher:
ssl_cert_refresher.stop()
shutdown_event.set()

async def dummy_shutdown() -> None:
pass

loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)

async def handle_shutdown() -> None:
await shutdown_event.wait()

engine_client = app.state.engine_client
timeout = engine_client.vllm_config.shutdown_timeout

await loop.run_in_executor(
None, partial(engine_client.shutdown, timeout=timeout)
)

server.should_exit = True
server_task.cancel()
watchdog_task.cancel()
if ssl_cert_refresher:
ssl_cert_refresher.stop()

shutdown_task = loop.create_task(handle_shutdown())

try:
await server_task
return dummy_shutdown()
Expand All @@ -120,6 +137,7 @@ async def dummy_shutdown() -> None:
logger.info("Shutting down FastAPI HTTP server.")
return server.shutdown()
finally:
shutdown_task.cancel()
watchdog_task.cancel()


Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ class EngineCoreRequestType(enum.Enum):
UTILITY = b"\x03"
# Sentinel used within EngineCoreProc.
EXECUTOR_FAILED = b"\x04"
# Sentinel to wake up input_queue.get() during shutdown.
WAKEUP = b"\x05"


class ReconfigureDistributedRequest(msgspec.Struct):
Expand Down
5 changes: 2 additions & 3 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,16 +264,15 @@ def from_engine_args(
def __del__(self):
self.shutdown()

def shutdown(self):
def shutdown(self, timeout: float | None = None) -> None:
"""Shutdown, cleaning up the background proc and IPC."""

shutdown_prometheus()

if renderer := getattr(self, "renderer", None):
renderer.shutdown()

if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown()
engine_core.shutdown(timeout=timeout)

handler = getattr(self, "output_handler", None)
if handler is not None:
Expand Down
6 changes: 4 additions & 2 deletions vllm/v1/engine/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,10 @@ def get_engine_socket_addresses(self) -> tuple[str, str]:
"""Returns tuple of ZMQ input address, output address."""
return self.coord_in_address, self.coord_out_address

def close(self):
self._finalizer()
def shutdown(self, timeout: float | None = None) -> None:
"""Shutdown coordinator process with configurable timeout."""
if self._finalizer.detach() is not None:
shutdown([self.proc], timeout=timeout)


class EngineState:
Expand Down
Loading