diff --git a/docs/features/sleep_mode.md b/docs/features/sleep_mode.md index 41aa48c1735..5016632cbab 100644 --- a/docs/features/sleep_mode.md +++ b/docs/features/sleep_mode.md @@ -37,3 +37,26 @@ Example: ```python omni = Omni(model=...,enable_sleep_mode=True) ``` + +## API Usage + +### Sleep +To sleep current model, it can release GPU memory. + +``` +POST /sleep?level=1 +``` + +### Wake Up +To wake up current sleep model, it can occupancy GPU memory. + +``` +POST /wake_up +``` + +### Get Sleep info +To search current model sleep info, about sleep level and whether sleep. + +``` +GET /sleep_info +``` diff --git a/tests/entrypoints/openai_api/test_sleep_api.py b/tests/entrypoints/openai_api/test_sleep_api.py new file mode 100644 index 00000000000..3b27a4ced4c --- /dev/null +++ b/tests/entrypoints/openai_api/test_sleep_api.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from argparse import Namespace +from http import HTTPStatus + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from vllm_omni.entrypoints.openai.api_server import sleep_router + + +class FakeEngine: + """Full-featured stub that records calls and tracks sleep state.""" + + def __init__(self): + self._is_sleeping = False + self._sleep_level: int | None = None + self.sleep_calls: list[tuple] = [] # [(level, mode), ...] + self.wake_up_calls: list[tuple] = [] # [(tags,), ...] + + async def sleep(self, level: int = 1, mode: str = "abort") -> None: + self._is_sleeping = True + self._sleep_level = level + self.sleep_calls.append((level, mode)) + + async def wake_up(self, tags: list[str] | None = None) -> None: + self._is_sleeping = False + self._sleep_level = None + self.wake_up_calls.append((tags,)) + + async def is_sleeping(self) -> bool: + return self._is_sleeping + + async def sleep_level(self) -> int: + return self._sleep_level + + +@pytest.fixture +def client() -> TestClient: + app = FastAPI() + app.include_router(sleep_router) + app.state.engine_client = FakeEngine() + app.state.args = Namespace(enable_sleep_mode=True) + return TestClient(app, raise_server_exceptions=False) + + +def test_sleep(client: TestClient): + assert client.post("/sleep?level=1").status_code == HTTPStatus.OK + + client.post("/sleep?level=1") + engine: FakeEngine = client.app.state.engine_client + level, model = engine.sleep_calls[-1] + assert level == 1 + assert model == "abort" + + assert client.post("/sleep?level=3").status_code == HTTPStatus.BAD_REQUEST + assert client.post("/sleep?level=abc").status_code == HTTPStatus.BAD_REQUEST + + client.post("/sleep?level=2&mode=drain") + engine: FakeEngine = client.app.state.engine_client + level, model = engine.sleep_calls[-1] + assert level == 2 + assert model == "abort" + + +def test_wakeup(client: TestClient): + assert client.post("/wake_up").status_code == HTTPStatus.OK + + client.post("/wake_up") + engine: FakeEngine = client.app.state.engine_client + assert not engine._is_sleeping + + client.post("/wake_up?tags=") + engine: FakeEngine = client.app.state.engine_client + assert engine.wake_up_calls[-1] == ([""],) + + client.post("/wake_up?tags=weights") + engine: FakeEngine = client.app.state.engine_client + assert engine.wake_up_calls[-1] == (["weights"],) + + +def test_sleep_info(client: TestClient): + client.post("/sleep?level=1") + assert client.get("/sleep_info").status_code == HTTPStatus.OK + + client.post("/wake_up") + assert client.get("/sleep_info").json() == {"sleep_level": None, "is_sleeping": False} + + client.post("/sleep?level=1") + assert client.get("/sleep_info").json() == {"sleep_level": 1, "is_sleeping": True} + + client.post("/sleep?level=2") + assert client.get("/sleep_info").json() == {"sleep_level": 2, "is_sleeping": True} diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 129ef3c99d8..2af89241e59 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -75,6 +75,7 @@ def __init__(self, *args: Any, model: str = "", **kwargs: Any) -> None: self._pause_cond: asyncio.Condition = asyncio.Condition() self._paused: bool = False self._is_sleeping: bool = False + self._sleep_level: int = None self.final_output_task: asyncio.Task | None = None self.config_path = self.engine.config_path @@ -648,6 +649,7 @@ async def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None: Best-effort: unsupported stages will emit a TODO result. """ self._is_sleeping = True + self._sleep_level = level await self.collective_rpc(method="sleep", args=(level,)) async def wake_up(self, tags: list[str] | None = None) -> None: @@ -656,6 +658,7 @@ async def wake_up(self, tags: list[str] | None = None) -> None: Best-effort: unsupported stages will emit a TODO result. """ self._is_sleeping = False + self._sleep_level = -1 await self.collective_rpc(method="wake_up", args=(tags,)) async def is_sleeping(self) -> bool: @@ -666,6 +669,9 @@ async def is_sleeping(self) -> bool: """ return self._is_sleeping + async def sleep_level(self) -> int: + return self._sleep_level + async def add_lora(self, lora_request: LoRARequest) -> bool: """Load a new LoRA adapter into all stages. diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index defaa9822cc..6c794388730 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -121,6 +121,7 @@ router = APIRouter() profiler_router = APIRouter() +sleep_router = APIRouter() def _should_enable_profiler_endpoints(stage_configs: list | None) -> bool: @@ -313,6 +314,11 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, logger.warning("Profiler endpoints are enabled. This should ONLY be used for local development!") app.include_router(profiler_router) + # Conditionally register sleep endpoints + if args.enable_sleep_mode: + logger.info("Sleep endpoints are enabled.") + app.include_router(sleep_router) + vllm_config = await _get_vllm_config(engine_client) # Check if pure diffusion mode (vllm_config will be None) @@ -2405,3 +2411,80 @@ async def stop_profile(raw_request: Request, request: ProfileRequest | None = No raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=f"Failed to stop profiler: {str(e)}" ) + + +@sleep_router.post("/sleep") +async def sleep(raw_request: Request): + raw_level = raw_request.query_params.get("level", "1") + _VALID_LEVELS = {1, 2} + try: + level = int(raw_level) + except ValueError: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"Invalid 'level' value {raw_level!r}: must be an integer.", + ) + if level not in _VALID_LEVELS: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"Invalid 'level' value {level}: must be one of {sorted(_VALID_LEVELS)}.", + ) + try: + engine_client = raw_request.app.state.engine_client + await engine_client.sleep(int(level)) + return JSONResponse(content="ok") + except Exception as e: + logger.exception("Failed to sleep model: %s", e) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=f"Failed to sleep model: {e}", + ) + + +@sleep_router.post("/wake_up") +async def wake_up(raw_request: Request): + """ + Wake up the worker from sleep mode. See the sleep function + method for more details. + + Args: + tags: An optional list of tags to reallocate the worker memory + for specific memory allocations. Values must be in + `("weights")`. If None, all memory is reallocated. + wake_up should be called with all tags (or None) before the + worker is used again. + """ + tags = raw_request.query_params.getlist("tags") + if tags == []: + tags = None + try: + engine_client = raw_request.app.state.engine_client + await engine_client.wake_up(tags) + return JSONResponse(content="ok") + except Exception as e: + logger.exception("Failed to wake_up model: %s", e) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=f"Failed to wake_up model: {e}", + ) + + +@sleep_router.get("/sleep_info") +async def sleep_info(raw_request: Request) -> JSONResponse: + """Return the current sleep level of the engine. + + Response body: ``{"sleep_level": int | null}`` + - ``null`` engine is awake (not sleeping). + - ``1`` weights are offloaded. + - ``2`` weights offloaded and KV caches saved/reset. + """ + try: + engine_client = raw_request.app.state.engine_client + level = await engine_client.sleep_level() + result = await engine_client.is_sleeping() + except Exception as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=f"Failed to query sleep level: {e}", + ) + return JSONResponse(content={"sleep_level": level, "is_sleeping": result})