diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index fbb2d32a229d..fb61865063c4 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -17,6 +17,7 @@
 from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
 from contextlib import asynccontextmanager
 from http import HTTPStatus
+from multiprocessing import shared_memory
 from typing import Annotated, Any, Literal
 
 import model_hosting_container_standards.sagemaker as sagemaker_standards
@@ -132,6 +133,17 @@
 _running_tasks: set[asyncio.Task] = set()
 
 
+def set_sleep_signal(value: int = 1, shared_memory_name: str = "sleep_signal") -> None:
+    try:
+        shm = shared_memory.SharedMemory(name=shared_memory_name, create=False, size=4)
+    except Exception:
+        shm = shared_memory.SharedMemory(name=shared_memory_name, create=True, size=4)
+
+    if shm is not None:
+        shm.buf[0:4] = value.to_bytes(4, "little")
+        shm.close()
+
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     try:
@@ -1082,6 +1094,8 @@ async def reset_mm_cache(raw_request: Request):
 
 @router.post("/sleep")
 async def sleep(raw_request: Request):
+    set_sleep_signal(1)
+
     # get POST params
     level = raw_request.query_params.get("level", "1")
     await engine_client(raw_request).sleep(int(level))
@@ -1091,6 +1105,8 @@ async def sleep(raw_request: Request):
 
 @router.post("/wake_up")
 async def wake_up(raw_request: Request):
+    set_sleep_signal(0)
+
     tags = raw_request.query_params.getlist("tags")
     if tags == []:
         # set to None to wake up all tags if no tags are provided
diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py
index 82166dc97839..4274be10e76b 100644
--- a/vllm/v1/core/sched/utils.py
+++ b/vllm/v1/core/sched/utils.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
+from multiprocessing import shared_memory
 
 import torch
 
@@ -69,4 +70,18 @@ def check_stop(
     ):
         request.status = RequestStatus.FINISHED_LENGTH_CAPPED
         return True
+
+    # Check if the model is sleeping
+    sleep_signal = 0
+    shared_memory_name = "sleep_signal"
+    try:
+        shm = shared_memory.SharedMemory(name=shared_memory_name)
+        sleep_signal = int.from_bytes(shm.buf[0:4], "little")
+        shm.close()
+    except Exception:
+        pass
+    if sleep_signal == 1:
+        request.status = RequestStatus.FINISHED_STOPPED
+        return True
+
     return False
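
For reference, here is a minimal standalone sketch of the signalling pattern the patch relies on: one process publishes a 4-byte little-endian flag in a named `multiprocessing.shared_memory` segment, and another process polls it. The segment name `"sleep_signal"` and the byte layout mirror the patch; the helper names `write_flag`/`read_flag` and the demo at the bottom are illustrative only, not part of the change.

```python
from multiprocessing import shared_memory

SEGMENT = "sleep_signal"  # segment name used by the patch


def write_flag(value: int, name: str = SEGMENT) -> None:
    # Attach to the segment if it already exists, otherwise create it.
    try:
        shm = shared_memory.SharedMemory(name=name)
    except FileNotFoundError:
        shm = shared_memory.SharedMemory(name=name, create=True, size=4)
    shm.buf[0:4] = value.to_bytes(4, "little")
    shm.close()


def read_flag(name: str = SEGMENT) -> int:
    # A missing segment is treated as "no signal", matching the
    # scheduler-side check, which defaults sleep_signal to 0.
    try:
        shm = shared_memory.SharedMemory(name=name)
    except FileNotFoundError:
        return 0
    value = int.from_bytes(shm.buf[0:4], "little")
    shm.close()
    return value


if __name__ == "__main__":
    write_flag(1)
    assert read_flag() == 1
    write_flag(0)
    assert read_flag() == 0
    # Free the OS resources once no process needs the flag anymore.
    shm = shared_memory.SharedMemory(name=SEGMENT)
    shm.unlink()
    shm.close()
```

Assuming a server on the default `localhost:8000`, the new code path can be exercised through the existing endpoints (the flag is set before the engine is put to sleep, and cleared on wake-up):

```bash
curl -X POST 'http://localhost:8000/sleep?level=1'
curl -X POST 'http://localhost:8000/wake_up'
```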