diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 97357dc3b33..893428b1b6e 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -618,6 +618,14 @@ def _run_generation_with_generator( # Cleanup when generator is exhausted or closed self.close() + def sleep(self, level: int = 1): + for stage in self.stage_list: + stage.sleep(level=level) + + def wake_up(self): + for stage in self.stage_list: + stage.wake_up() + def _run_generation( self, prompts: OmniPromptType | Sequence[OmniPromptType], diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index fc01c147cd6..c3a583718b1 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -433,6 +433,20 @@ def try_collect(self) -> dict[str, Any] | None: except Exception: return None + def sleep(self, level: int = 1): + if self.engine is not None: + if hasattr(self.engine, "sleep"): + self.engine.sleep(level=level) + elif self._in_q is not None: + self.submit({"type": OmniStageTaskType.SLEEP, "level": level}) + + def wake_up(self): + if self.engine is not None: + if hasattr(self.engine, "wake_up"): + self.engine.wake_up() + elif self._in_q is not None: + self.submit({"type": OmniStageTaskType.WAKE_UP}) + def process_engine_inputs( self, stage_list: list[Any], prompt: OmniTokensPrompt | TextPrompt = None ) -> list[OmniTokensPrompt | TextPrompt]: @@ -779,6 +793,19 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: logger.info("Received shutdown signal") break + if task_type == OmniStageTaskType.SLEEP: + level = task.get("level", 1) + if hasattr(stage_engine, "sleep"): + logger.info(f"[Stage-{stage_id}] Executing real sleep(level={level})") + stage_engine.sleep(level=level) + continue + + if task_type == OmniStageTaskType.WAKE_UP: + if hasattr(stage_engine, "wake_up"): + logger.info(f"[Stage-{stage_id}] Execuring real wake_up()") + stage_engine.wake_up() + continue + # Handle profiler control commands if is_profiler_task(task_type): profiler_data = handle_profiler_task_local(task_type) @@ -1415,6 +1442,17 @@ async def generation_single_request(task: dict[str, Any]): logger.debug("Received shutdown signal") stage_engine.shutdown() break + elif task_type == OmniStageTaskType.SLEEP: + level = task.get("level", 1) + if hasattr(stage_engine, "sleep"): + logger.info(f"[Stage-{stage_id}] Async Worker executing sleep(level={level})") + stage_engine.sleep(level=level) + continue + elif task_type == OmniStageTaskType.WAKE_UP: + if hasattr(stage_engine, "wake_up"): + logger.info(f"[Stage-{stage_id}] Async Worker executing wake_up()") + stage_engine.wake_up() + continue elif task_type == OmniStageTaskType.ABORT: rid = task["request_id"] asyncio.create_task(stage_engine.abort(rid)) diff --git a/vllm_omni/entrypoints/stage_utils.py b/vllm_omni/entrypoints/stage_utils.py index 74ad42f045e..2269ae60c32 100644 --- a/vllm_omni/entrypoints/stage_utils.py +++ b/vllm_omni/entrypoints/stage_utils.py @@ -18,6 +18,8 @@ class OmniStageTaskType(enum.Enum): SHUTDOWN = "shutdown" PROFILER_START = "profiler_start" PROFILER_STOP = "profiler_stop" + SLEEP = "sleep" + WAKE_UP = "wake_up" SHUTDOWN_TASK = {"type": OmniStageTaskType.SHUTDOWN}