Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docs/features/sleep_mode.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,26 @@ Example:
```python
omni = Omni(model=...,enable_sleep_mode=True)
```

## API Usage

### Sleep
To sleep current model, it can release GPU memory.

```
POST /sleep?level=1
```

### Wake Up
To wake up current sleep model, it can occupancy GPU memory.

```
POST /wake_up
```

### Get Sleep info
To search current model sleep info, about sleep level and whether sleep.

```
GET /sleep_info
```
95 changes: 95 additions & 0 deletions tests/entrypoints/openai_api/test_sleep_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from argparse import Namespace
from http import HTTPStatus

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from vllm_omni.entrypoints.openai.api_server import sleep_router


class FakeEngine:
"""Full-featured stub that records calls and tracks sleep state."""

def __init__(self):
self._is_sleeping = False
self._sleep_level: int | None = None
self.sleep_calls: list[tuple] = [] # [(level, mode), ...]
self.wake_up_calls: list[tuple] = [] # [(tags,), ...]

async def sleep(self, level: int = 1, mode: str = "abort") -> None:
self._is_sleeping = True
self._sleep_level = level
self.sleep_calls.append((level, mode))

async def wake_up(self, tags: list[str] | None = None) -> None:
self._is_sleeping = False
self._sleep_level = None
self.wake_up_calls.append((tags,))

async def is_sleeping(self) -> bool:
return self._is_sleeping

async def sleep_level(self) -> int:
return self._sleep_level


@pytest.fixture
def client() -> TestClient:
app = FastAPI()
app.include_router(sleep_router)
app.state.engine_client = FakeEngine()
app.state.args = Namespace(enable_sleep_mode=True)
return TestClient(app, raise_server_exceptions=False)


def test_sleep(client: TestClient):
assert client.post("/sleep?level=1").status_code == HTTPStatus.OK

client.post("/sleep?level=1")
engine: FakeEngine = client.app.state.engine_client
level, model = engine.sleep_calls[-1]
assert level == 1
assert model == "abort"

assert client.post("/sleep?level=3").status_code == HTTPStatus.BAD_REQUEST
assert client.post("/sleep?level=abc").status_code == HTTPStatus.BAD_REQUEST

client.post("/sleep?level=2&mode=drain")
engine: FakeEngine = client.app.state.engine_client
level, model = engine.sleep_calls[-1]
assert level == 2
assert model == "abort"


def test_wakeup(client: TestClient):
assert client.post("/wake_up").status_code == HTTPStatus.OK

client.post("/wake_up")
engine: FakeEngine = client.app.state.engine_client
assert not engine._is_sleeping

client.post("/wake_up?tags=")
engine: FakeEngine = client.app.state.engine_client
assert engine.wake_up_calls[-1] == ([""],)

client.post("/wake_up?tags=weights")
engine: FakeEngine = client.app.state.engine_client
assert engine.wake_up_calls[-1] == (["weights"],)


def test_sleep_info(client: TestClient):
client.post("/sleep?level=1")
assert client.get("/sleep_info").status_code == HTTPStatus.OK

client.post("/wake_up")
assert client.get("/sleep_info").json() == {"sleep_level": None, "is_sleeping": False}

client.post("/sleep?level=1")
assert client.get("/sleep_info").json() == {"sleep_level": 1, "is_sleeping": True}

client.post("/sleep?level=2")
assert client.get("/sleep_info").json() == {"sleep_level": 2, "is_sleeping": True}
6 changes: 6 additions & 0 deletions vllm_omni/entrypoints/async_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(self, *args: Any, model: str = "", **kwargs: Any) -> None:
self._pause_cond: asyncio.Condition = asyncio.Condition()
self._paused: bool = False
self._is_sleeping: bool = False
self._sleep_level: int = None
self.final_output_task: asyncio.Task | None = None

self.config_path = self.engine.config_path
Expand Down Expand Up @@ -648,6 +649,7 @@ async def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
Best-effort: unsupported stages will emit a TODO result.
"""
self._is_sleeping = True
self._sleep_level = level
await self.collective_rpc(method="sleep", args=(level,))

async def wake_up(self, tags: list[str] | None = None) -> None:
Expand All @@ -656,6 +658,7 @@ async def wake_up(self, tags: list[str] | None = None) -> None:
Best-effort: unsupported stages will emit a TODO result.
"""
self._is_sleeping = False
self._sleep_level = -1
await self.collective_rpc(method="wake_up", args=(tags,))

async def is_sleeping(self) -> bool:
Expand All @@ -666,6 +669,9 @@ async def is_sleeping(self) -> bool:
"""
return self._is_sleeping

async def sleep_level(self) -> int:
return self._sleep_level

async def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into all stages.

Expand Down
83 changes: 83 additions & 0 deletions vllm_omni/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
router = APIRouter()

profiler_router = APIRouter()
sleep_router = APIRouter()


def _should_enable_profiler_endpoints(stage_configs: list | None) -> bool:
Expand Down Expand Up @@ -313,6 +314,11 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None,
logger.warning("Profiler endpoints are enabled. This should ONLY be used for local development!")
app.include_router(profiler_router)

# Conditionally register sleep endpoints
if args.enable_sleep_mode:
logger.info("Sleep endpoints are enabled.")
app.include_router(sleep_router)

vllm_config = await _get_vllm_config(engine_client)

# Check if pure diffusion mode (vllm_config will be None)
Expand Down Expand Up @@ -2405,3 +2411,80 @@ async def stop_profile(raw_request: Request, request: ProfileRequest | None = No
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=f"Failed to stop profiler: {str(e)}"
)


@sleep_router.post("/sleep")
async def sleep(raw_request: Request):
raw_level = raw_request.query_params.get("level", "1")
_VALID_LEVELS = {1, 2}
try:
level = int(raw_level)
except ValueError:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail=f"Invalid 'level' value {raw_level!r}: must be an integer.",
)
if level not in _VALID_LEVELS:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail=f"Invalid 'level' value {level}: must be one of {sorted(_VALID_LEVELS)}.",
)
try:
engine_client = raw_request.app.state.engine_client
await engine_client.sleep(int(level))
return JSONResponse(content="ok")
except Exception as e:
logger.exception("Failed to sleep model: %s", e)
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
detail=f"Failed to sleep model: {e}",
)


@sleep_router.post("/wake_up")
async def wake_up(raw_request: Request):
"""
Wake up the worker from sleep mode. See the sleep function
method for more details.

Args:
tags: An optional list of tags to reallocate the worker memory
for specific memory allocations. Values must be in
`("weights")`. If None, all memory is reallocated.
wake_up should be called with all tags (or None) before the
worker is used again.
"""
tags = raw_request.query_params.getlist("tags")
if tags == []:
tags = None
try:
engine_client = raw_request.app.state.engine_client
await engine_client.wake_up(tags)
return JSONResponse(content="ok")
except Exception as e:
logger.exception("Failed to wake_up model: %s", e)
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
detail=f"Failed to wake_up model: {e}",
)


@sleep_router.get("/sleep_info")
async def sleep_info(raw_request: Request) -> JSONResponse:
"""Return the current sleep level of the engine.

Response body: ``{"sleep_level": int | null}``
- ``null`` engine is awake (not sleeping).
- ``1`` weights are offloaded.
- ``2`` weights offloaded and KV caches saved/reset.
"""
try:
engine_client = raw_request.app.state.engine_client
level = await engine_client.sleep_level()
result = await engine_client.is_sleeping()
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
detail=f"Failed to query sleep level: {e}",
)
return JSONResponse(content={"sleep_level": level, "is_sleeping": result})
Loading