Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions vllm_omni/entrypoints/omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,14 @@ def _run_generation_with_generator(
# Cleanup when generator is exhausted or closed
self.close()

def sleep(self, level: int = 1):
for stage in self.stage_list:
stage.sleep(level=level)

def wake_up(self):
for stage in self.stage_list:
stage.wake_up()

def _run_generation(
self,
prompts: OmniPromptType | Sequence[OmniPromptType],
Expand Down
38 changes: 38 additions & 0 deletions vllm_omni/entrypoints/omni_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,20 @@ def try_collect(self) -> dict[str, Any] | None:
except Exception:
return None

def sleep(self, level: int = 1):
if self.engine is not None:
if hasattr(self.engine, "sleep"):
self.engine.sleep(level=level)
elif self._in_q is not None:
self.submit({"type": OmniStageTaskType.SLEEP, "level": level})

def wake_up(self):
if self.engine is not None:
if hasattr(self.engine, "wake_up"):
self.engine.wake_up()
elif self._in_q is not None:
self.submit({"type": OmniStageTaskType.WAKE_UP})

def process_engine_inputs(
self, stage_list: list[Any], prompt: OmniTokensPrompt | TextPrompt = None
) -> list[OmniTokensPrompt | TextPrompt]:
Expand Down Expand Up @@ -779,6 +793,19 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict:
logger.info("Received shutdown signal")
break

if task_type == OmniStageTaskType.SLEEP:
level = task.get("level", 1)
if hasattr(stage_engine, "sleep"):
logger.info(f"[Stage-{stage_id}] Executing real sleep(level={level})")
stage_engine.sleep(level=level)
continue

if task_type == OmniStageTaskType.WAKE_UP:
if hasattr(stage_engine, "wake_up"):
logger.info(f"[Stage-{stage_id}] Execuring real wake_up()")
stage_engine.wake_up()
continue

# Handle profiler control commands
if is_profiler_task(task_type):
profiler_data = handle_profiler_task_local(task_type)
Expand Down Expand Up @@ -1415,6 +1442,17 @@ async def generation_single_request(task: dict[str, Any]):
logger.debug("Received shutdown signal")
stage_engine.shutdown()
break
elif task_type == OmniStageTaskType.SLEEP:
level = task.get("level", 1)
if hasattr(stage_engine, "sleep"):
logger.info(f"[Stage-{stage_id}] Async Worker executing sleep(level={level})")
stage_engine.sleep(level=level)
continue
elif task_type == OmniStageTaskType.WAKE_UP:
if hasattr(stage_engine, "wake_up"):
logger.info(f"[Stage-{stage_id}] Async Worker executing wake_up()")
stage_engine.wake_up()
continue
elif task_type == OmniStageTaskType.ABORT:
rid = task["request_id"]
asyncio.create_task(stage_engine.abort(rid))
Expand Down
2 changes: 2 additions & 0 deletions vllm_omni/entrypoints/stage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class OmniStageTaskType(enum.Enum):
SHUTDOWN = "shutdown"
PROFILER_START = "profiler_start"
PROFILER_STOP = "profiler_stop"
SLEEP = "sleep"
WAKE_UP = "wake_up"


SHUTDOWN_TASK = {"type": OmniStageTaskType.SHUTDOWN}
Expand Down