From dbcc96c469e15d618a2ee696e5f1dc317fed0ce8 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Fri, 2 Jan 2026 15:43:04 +0800 Subject: [PATCH 01/17] Support collective_rpc for OmniStage Signed-off-by: knlnguyen1802 --- vllm_omni/entrypoints/omni_stage.py | 82 +++++++++++++++++++++++++++- vllm_omni/entrypoints/stage_utils.py | 1 + 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 8f3fac1d9b1..b114e9c0f78 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -15,7 +15,8 @@ import sys import traceback from dataclasses import fields -from typing import Any +from typing import Any, Callable, TypeVar +import uuid from vllm.inputs import TextPrompt from vllm.inputs.preprocess import InputPreprocessor @@ -49,6 +50,7 @@ logger = init_logger(__name__) +_R = TypeVar("_R") def _build_od_config(engine_args: dict[str, Any], model: str) -> dict[str, Any]: """Build OmniDiffusionConfig kwargs from engine args.""" @@ -420,6 +422,63 @@ def process_engine_inputs( stage_list, engine_input_source, prompt, self.requires_multimodal_data ) + def collective_rpc( + self, + method: str | Callable[..., _R], + timeout: float | None = None, + args: tuple = (), + kwargs: dict[str, Any] | None = None, + ) -> list[_R]: + """Execute an RPC call on all workers via the stage engine. + + Args: + method: Name of the worker method to execute, or a callable that + is serialized and sent to all workers to execute. + + If the method is a callable, it should accept an additional + `self` argument, in addition to the arguments passed in `args` + and `kwargs`. The `self` argument will be the worker object. + timeout: Maximum time in seconds to wait for execution. Raises a + [`TimeoutError`][] on timeout. `None` means wait indefinitely. + args: Positional arguments to pass to the worker method. + kwargs: Keyword arguments to pass to the worker method. + + Returns: + A list containing the results from each worker. + + Note: + It is recommended to use this API to only pass control messages, + and set up data-plane communication to pass data. + """ + assert self._in_q is not None and self._out_q is not None, "Queues must be attached before collective_rpc" + + # Submit collective_rpc task to worker + rpc_id = str(uuid.uuid4()) + self._in_q.put({ + "type": "collective_rpc", + "rpc_id": rpc_id, + "method": method, + "timeout": timeout, + "args": args, + "kwargs": kwargs, + }) + + # Wait for result from worker + import time + start_time = time.time() + while True: + if timeout is not None and (time.time() - start_time) > timeout: + raise TimeoutError(f"collective_rpc timed out after {timeout} seconds") + + result = self.try_collect() + if result is not None: + if result.get("type") == "collective_rpc_result": + if result.get("rpc_id") == rpc_id: + if "error" in result: + raise RuntimeError(f"collective_rpc failed: {result['error']}") + return result["result"] + + time.sleep(0.001) # Small sleep to avoid busy waiting def _stage_worker( model: str, @@ -636,6 +695,27 @@ def _stage_worker( logger.error("Received shutdown signal") break + if task_type == OmniStageTaskType.COLLECTIVE_RPC: + rpc_id = task.get("rpc_id") + method = task.get("method") + timeout = task.get("timeout") + args = task.get("args") + kwargs = task.get("kwargs") + try: + result = stage_engine.collective_rpc(method, timeout, args, kwargs) + out_q.put({ + "type": "collective_rpc_result", + "rpc_id": rpc_id, + "result": result, + }) + except Exception as e: + out_q.put({ + "type": "collective_rpc_result", + "rpc_id": rpc_id, + "error": str(e), + }) + continue + batch_tasks: list[dict[str, Any]] = [task] start_time = _time.time() if max_batch_size > 1: diff --git a/vllm_omni/entrypoints/stage_utils.py b/vllm_omni/entrypoints/stage_utils.py index 986cdc0b1f4..6f3ace9500e 100644 --- a/vllm_omni/entrypoints/stage_utils.py +++ b/vllm_omni/entrypoints/stage_utils.py @@ -14,6 +14,7 @@ class OmniStageTaskType(enum.Enum): GENERATE = "generate" + COLLECTIVE_RPC = "collective_rpc" ABORT = "abort" SHUTDOWN = "shutdown" From 188d681935cb95a6baca2fb68e96cb77078198b2 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Fri, 2 Jan 2026 16:39:30 +0800 Subject: [PATCH 02/17] Support rpc for omni stage and fix compatible api between OmniDiffusion and OmniLLM Signed-off-by: knlnguyen1802 --- vllm_omni/diffusion/diffusion_engine.py | 5 ++--- vllm_omni/entrypoints/omni_diffusion.py | 15 +++++++++++++++ vllm_omni/entrypoints/omni_stage.py | 2 +- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 9a5e395aebc..12d4db74f81 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -278,7 +278,6 @@ def collective_rpc( timeout: float | None = None, args: tuple = (), kwargs: dict | None = None, - unique_reply_rank: int | None = None, ) -> Any: """Call a method on worker processes and get results immediately. @@ -287,7 +286,6 @@ def collective_rpc( timeout: Optional timeout in seconds args: Positional arguments for the method kwargs: Keyword arguments for the method - unique_reply_rank: If set, only get reply from this rank Returns: Single result if unique_reply_rank is provided, otherwise list of results @@ -297,7 +295,8 @@ def collective_rpc( deadline = None if timeout is None else time.monotonic() + timeout kwargs = kwargs or {} - + unique_reply_rank = kwargs.pop("unique_reply_rank", None) + assert isinstance(method, str) send_method = method diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py index 1d71e23cb42..9a06935e520 100644 --- a/vllm_omni/entrypoints/omni_diffusion.py +++ b/vllm_omni/entrypoints/omni_diffusion.py @@ -3,6 +3,7 @@ import logging from dataclasses import fields +from typing import Any, Callable from vllm.logger import init_logger from vllm.transformers_utils.config import get_hf_file_to_dict @@ -114,6 +115,20 @@ def generate( def _run_engine(self, requests: list[OmniDiffusionRequest]): return self.engine.step(requests) + + def collective_rpc( + self, + method: str | Callable, + timeout: float | None = None, + args: tuple = (), + kwargs: dict | None = None, + ) -> Any: + return self.engine.collective_rpc( + method, + timeout=timeout, + args=args, + kwargs=kwargs, + ) def close(self) -> None: self.engine.close() diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index b114e9c0f78..574d7343f94 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -455,7 +455,7 @@ def collective_rpc( # Submit collective_rpc task to worker rpc_id = str(uuid.uuid4()) self._in_q.put({ - "type": "collective_rpc", + "type": OmniStageTaskType.COLLECTIVE_RPC, "rpc_id": rpc_id, "method": method, "timeout": timeout, From aec5487297c5585f741722646ecf4b178ee35510 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Fri, 2 Jan 2026 16:43:43 +0800 Subject: [PATCH 03/17] Fix pre-commit Signed-off-by: knlnguyen1802 --- vllm_omni/diffusion/diffusion_engine.py | 2 +- vllm_omni/entrypoints/omni_diffusion.py | 5 ++- vllm_omni/entrypoints/omni_stage.py | 56 +++++++++++++++---------- 3 files changed, 37 insertions(+), 26 deletions(-) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 12d4db74f81..3ff66cc377c 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -296,7 +296,7 @@ def collective_rpc( deadline = None if timeout is None else time.monotonic() + timeout kwargs = kwargs or {} unique_reply_rank = kwargs.pop("unique_reply_rank", None) - + assert isinstance(method, str) send_method = method diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py index 9a06935e520..95a1f10bb93 100644 --- a/vllm_omni/entrypoints/omni_diffusion.py +++ b/vllm_omni/entrypoints/omni_diffusion.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging +from collections.abc import Callable from dataclasses import fields -from typing import Any, Callable +from typing import Any from vllm.logger import init_logger from vllm.transformers_utils.config import get_hf_file_to_dict @@ -115,7 +116,7 @@ def generate( def _run_engine(self, requests: list[OmniDiffusionRequest]): return self.engine.step(requests) - + def collective_rpc( self, method: str | Callable, diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 574d7343f94..0fac9ba361b 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -14,9 +14,10 @@ import queue import sys import traceback -from dataclasses import fields -from typing import Any, Callable, TypeVar import uuid +from collections.abc import Callable +from dataclasses import fields +from typing import Any, TypeVar from vllm.inputs import TextPrompt from vllm.inputs.preprocess import InputPreprocessor @@ -52,6 +53,7 @@ _R = TypeVar("_R") + def _build_od_config(engine_args: dict[str, Any], model: str) -> dict[str, Any]: """Build OmniDiffusionConfig kwargs from engine args.""" od_config = engine_args.get("od_config", {}) @@ -451,25 +453,28 @@ def collective_rpc( and set up data-plane communication to pass data. """ assert self._in_q is not None and self._out_q is not None, "Queues must be attached before collective_rpc" - + # Submit collective_rpc task to worker rpc_id = str(uuid.uuid4()) - self._in_q.put({ - "type": OmniStageTaskType.COLLECTIVE_RPC, - "rpc_id": rpc_id, - "method": method, - "timeout": timeout, - "args": args, - "kwargs": kwargs, - }) - + self._in_q.put( + { + "type": OmniStageTaskType.COLLECTIVE_RPC, + "rpc_id": rpc_id, + "method": method, + "timeout": timeout, + "args": args, + "kwargs": kwargs, + } + ) + # Wait for result from worker import time + start_time = time.time() while True: if timeout is not None and (time.time() - start_time) > timeout: raise TimeoutError(f"collective_rpc timed out after {timeout} seconds") - + result = self.try_collect() if result is not None: if result.get("type") == "collective_rpc_result": @@ -480,6 +485,7 @@ def collective_rpc( time.sleep(0.001) # Small sleep to avoid busy waiting + def _stage_worker( model: str, stage_payload: dict[str, Any], @@ -703,17 +709,21 @@ def _stage_worker( kwargs = task.get("kwargs") try: result = stage_engine.collective_rpc(method, timeout, args, kwargs) - out_q.put({ - "type": "collective_rpc_result", - "rpc_id": rpc_id, - "result": result, - }) + out_q.put( + { + "type": "collective_rpc_result", + "rpc_id": rpc_id, + "result": result, + } + ) except Exception as e: - out_q.put({ - "type": "collective_rpc_result", - "rpc_id": rpc_id, - "error": str(e), - }) + out_q.put( + { + "type": "collective_rpc_result", + "rpc_id": rpc_id, + "error": str(e), + } + ) continue batch_tasks: list[dict[str, Any]] = [task] From e9be588dac68e4462c10bb15f7587a9b3a31f61f Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Mon, 5 Jan 2026 11:15:28 +0800 Subject: [PATCH 04/17] Add test Signed-off-by: knlnguyen1802 --- tests/e2e/test_rpc_collective.py | 37 ++++++++++++++++++++++++++++++++ vllm_omni/entrypoints/omni.py | 21 +++++++++++++++++- 2 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 tests/e2e/test_rpc_collective.py diff --git a/tests/e2e/test_rpc_collective.py b/tests/e2e/test_rpc_collective.py new file mode 100644 index 00000000000..2c1c83f3a62 --- /dev/null +++ b/tests/e2e/test_rpc_collective.py @@ -0,0 +1,37 @@ +import os +import sys +from pathlib import Path + +import pytest + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni import Omni + +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" + +diffusion_models = ["Tongyi-MAI/Z-Image-Turbo"] + +omni_models = ["Qwen/Qwen2.5-Omni-3B"] + + +@pytest.mark.parametrize("model_name", omni_models) +def test_omni_model(model_name: str): + m = Omni(model=model_name,init_timeout=3600) + results = m.collective_rpc( + method="sleep", + args=(1,), + ) + assert len(results)==3 + +@pytest.mark.parametrize("model_name", diffusion_models) +def test_diffusion_model(model_name: str): + m = Omni(model=model_name) + results = m.collective_rpc( + method="sleep", + args=(1,), + ) + assert len(results)==1 \ No newline at end of file diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 73365b1dc87..0d5d2ebdae8 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -10,7 +10,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import asdict from pprint import pformat -from typing import Any +from typing import Any, TypeVar from omegaconf import OmegaConf from tqdm.auto import tqdm @@ -42,6 +42,7 @@ logger = init_logger(__name__) +_R = TypeVar("_R") def _weak_close_cleanup(stage_list, stage_in_queues, ray_pg): """Weak reference cleanup function for OmniBase instances.""" @@ -688,6 +689,24 @@ def _run_generation( except Exception as e: logger.exception(f"[{self._name}] Failed to build/log summary: {e}") + def collective_rpc( + self, + method: str | Callable[..., _R], + timeout: float | None = None, + args: tuple = (), + kwargs: dict[str, Any] | None = None, + ) -> list[_R]: + results = [] + for stage in self.stage_list: + result = stage.collective_rpc( + method=method, + args=args, + timeout=timeout, + kwargs=kwargs, + ) + results.append(result) + return results + @property def _name(self) -> str: return "Orchestrator" From 2d17bffd87d2361a9b7700d45e9cbdbe3f0c8253 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Mon, 5 Jan 2026 14:14:52 +0800 Subject: [PATCH 05/17] Fix pre-commit Signed-off-by: knlnguyen1802 --- tests/e2e/test_rpc_collective.py | 7 ++++--- vllm_omni/entrypoints/omni.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/e2e/test_rpc_collective.py b/tests/e2e/test_rpc_collective.py index 2c1c83f3a62..dc66a795f40 100644 --- a/tests/e2e/test_rpc_collective.py +++ b/tests/e2e/test_rpc_collective.py @@ -20,12 +20,13 @@ @pytest.mark.parametrize("model_name", omni_models) def test_omni_model(model_name: str): - m = Omni(model=model_name,init_timeout=3600) + m = Omni(model=model_name, init_timeout=3600) results = m.collective_rpc( method="sleep", args=(1,), ) - assert len(results)==3 + assert len(results) == 3 + @pytest.mark.parametrize("model_name", diffusion_models) def test_diffusion_model(model_name: str): @@ -34,4 +35,4 @@ def test_diffusion_model(model_name: str): method="sleep", args=(1,), ) - assert len(results)==1 \ No newline at end of file + assert len(results) == 1 diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 0d5d2ebdae8..0daf62021e2 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -44,6 +44,7 @@ _R = TypeVar("_R") + def _weak_close_cleanup(stage_list, stage_in_queues, ray_pg): """Weak reference cleanup function for OmniBase instances.""" if stage_list: From f961a44864f406928e843a6c7b24965b50005166 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Tue, 6 Jan 2026 10:40:21 +0800 Subject: [PATCH 06/17] Resolve comment and add test Signed-off-by: knlnguyen1802 --- .buildkite/pipeline.yml | 1 + tests/e2e/test_rpc_collective.py | 6 +++--- vllm_omni/entrypoints/omni_stage.py | 31 ++++++++++++++++++++++++++--- 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 4bb9d4cd69a..5723802c2c9 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -82,6 +82,7 @@ steps: depends_on: image-build commands: - pytest -s -v tests/diffusion/test_gpu_worker.py + - pytest -s -v tests/e2e/test_rpc_collective.py agents: queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU plugins: diff --git a/tests/e2e/test_rpc_collective.py b/tests/e2e/test_rpc_collective.py index dc66a795f40..05757ae6a7f 100644 --- a/tests/e2e/test_rpc_collective.py +++ b/tests/e2e/test_rpc_collective.py @@ -1,7 +1,7 @@ import os import sys from pathlib import Path - +from .offline_inference.utils import create_new_process_for_each_test import pytest # ruff: noqa: E402 @@ -17,7 +17,7 @@ omni_models = ["Qwen/Qwen2.5-Omni-3B"] - +@create_new_process_for_each_test() @pytest.mark.parametrize("model_name", omni_models) def test_omni_model(model_name: str): m = Omni(model=model_name, init_timeout=3600) @@ -27,7 +27,7 @@ def test_omni_model(model_name: str): ) assert len(results) == 3 - +@create_new_process_for_each_test() @pytest.mark.parametrize("model_name", diffusion_models) def test_diffusion_model(model_name: str): m = Omni(model=model_name) diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 736de48abfb..8ffa4739af5 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -15,6 +15,7 @@ import sys import traceback import uuid +import time from collections.abc import Callable from dataclasses import fields from typing import Any, TypeVar @@ -468,9 +469,6 @@ def collective_rpc( } ) - # Wait for result from worker - import time - start_time = time.time() while True: if timeout is not None and (time.time() - start_time) > timeout: @@ -743,6 +741,7 @@ def handle_profiler_task(task_type: OmniStageTaskType) -> None: "result": result, } ) + continue except Exception as e: out_q.put( { @@ -751,6 +750,7 @@ def handle_profiler_task(task_type: OmniStageTaskType) -> None: "error": str(e), } ) + continue # Handle profiler control commands if is_profiler_task(task_type): @@ -1336,6 +1336,31 @@ async def generation_single_request(task: dict[str, Any]): elif task_type == OmniStageTaskType.ABORT: rid = task["request_id"] asyncio.create_task(stage_engine.abort(rid)) + elif task_type == OmniStageTaskType.COLLECTIVE_RPC: + rpc_id = task.get("rpc_id") + method = task.get("method") + timeout = task.get("timeout") + args = task.get("args") + kwargs = task.get("kwargs") + try: + result = stage_engine.collective_rpc(method, timeout, args, kwargs) + out_q.put( + { + "type": "collective_rpc_result", + "rpc_id": rpc_id, + "result": result, + } + ) + continue + except Exception as e: + out_q.put( + { + "type": "collective_rpc_result", + "rpc_id": rpc_id, + "error": str(e), + } + ) + continue elif is_profiler_task(task_type): await handle_profiler_task_async(task_type) else: From 111ca2325feb4235eb14a91b14f608c65d6c6687 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Tue, 6 Jan 2026 11:45:27 +0800 Subject: [PATCH 07/17] Enable sleep mode in test and fix pre-commit Signed-off-by: knlnguyen1802 --- tests/e2e/test_rpc_collective.py | 10 +++++++--- vllm_omni/entrypoints/omni_stage.py | 4 ++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/e2e/test_rpc_collective.py b/tests/e2e/test_rpc_collective.py index 05757ae6a7f..d87ae2b0a51 100644 --- a/tests/e2e/test_rpc_collective.py +++ b/tests/e2e/test_rpc_collective.py @@ -1,9 +1,11 @@ import os import sys from pathlib import Path -from .offline_inference.utils import create_new_process_for_each_test + import pytest +from .offline_inference.utils import create_new_process_for_each_test + # ruff: noqa: E402 REPO_ROOT = Path(__file__).resolve().parents[2] if str(REPO_ROOT) not in sys.path: @@ -17,20 +19,22 @@ omni_models = ["Qwen/Qwen2.5-Omni-3B"] + @create_new_process_for_each_test() @pytest.mark.parametrize("model_name", omni_models) def test_omni_model(model_name: str): - m = Omni(model=model_name, init_timeout=3600) + m = Omni(model=model_name, enable_sleep_mode=True, init_timeout=3600) results = m.collective_rpc( method="sleep", args=(1,), ) assert len(results) == 3 + @create_new_process_for_each_test() @pytest.mark.parametrize("model_name", diffusion_models) def test_diffusion_model(model_name: str): - m = Omni(model=model_name) + m = Omni(model=model_name, enable_sleep_mode=True) results = m.collective_rpc( method="sleep", args=(1,), diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 8ffa4739af5..d15cf41074c 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -13,9 +13,9 @@ import os import queue import sys +import time import traceback import uuid -import time from collections.abc import Callable from dataclasses import fields from typing import Any, TypeVar @@ -751,7 +751,7 @@ def handle_profiler_task(task_type: OmniStageTaskType) -> None: } ) continue - + # Handle profiler control commands if is_profiler_task(task_type): handle_profiler_task(task_type) From 31010c80de6b44e476cab5526c3cbc9a7cc50705 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Wed, 28 Jan 2026 10:50:05 +0800 Subject: [PATCH 08/17] Update function in entrypoints Signed-off-by: knlnguyen1802 --- vllm_omni/entrypoints/async_omni.py | 32 +++++++++++++- vllm_omni/entrypoints/async_omni_diffusion.py | 42 +++++++++++++++++++ vllm_omni/entrypoints/async_omni_llm.py | 36 ++++++++++++++++ vllm_omni/entrypoints/omni.py | 42 +++++++++++++++++++ vllm_omni/entrypoints/omni_diffusion.py | 36 ++++++++++++++++ vllm_omni/entrypoints/omni_llm.py | 36 ++++++++++++++++ 6 files changed, 222 insertions(+), 2 deletions(-) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 3dd687a8b49..c2de265e8f9 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -719,10 +719,38 @@ async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool pass async def sleep(self, level: int = 1) -> None: - pass + """Put all stage workers to sleep, offloading model weights. + + Args: + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + """ + for stage in self.stage_list: + await asyncio.get_event_loop().run_in_executor( + None, + stage.collective_rpc, + "sleep", + None, + (level,), + {}, + ) async def wake_up(self, tags: list[str] | None = None) -> None: - pass + """Wake up all stage workers from sleep mode. + + Args: + tags: Optional list of tags to reallocate worker memory for specific + allocations. Values must be in ("weights",). If None, all memory + is reallocated. + """ + for stage in self.stage_list: + await asyncio.get_event_loop().run_in_executor( + None, + stage.collective_rpc, + "wake_up", + None, + (), + {"tags": tags}, + ) async def is_sleeping(self) -> bool: """Check whether the engine is sleeping""" diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py index 6d34e3445f4..875726402fd 100644 --- a/vllm_omni/entrypoints/async_omni_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -370,3 +370,45 @@ async def pin_lora(self, lora_id: int) -> bool: None, ) return all(results) if isinstance(results, list) else results + + async def sleep(self, level: int = 1) -> bool: + """Put the worker to sleep, offloading model weights. + + Args: + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + + Returns: + True if successful + """ + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + self._executor, + self.engine.collective_rpc, + "sleep", + None, + (level,), + {}, + ) + return all(results) if isinstance(results, list) else results + + async def wake_up(self, tags: list[str] | None = None) -> bool: + """Wake up the worker from sleep mode. + + Args: + tags: Optional list of tags to reallocate worker memory for specific + allocations. Values must be in ("weights",). If None, all memory + is reallocated. + + Returns: + True if successful + """ + loop = asyncio.get_event_loop() + results = await loop.run_in_executor( + self._executor, + self.engine.collective_rpc, + "wake_up", + None, + (tags,), + {}, + ) + return all(results) if isinstance(results, list) else results diff --git a/vllm_omni/entrypoints/async_omni_llm.py b/vllm_omni/entrypoints/async_omni_llm.py index 287f12b9ed7..c4c530eb49a 100644 --- a/vllm_omni/entrypoints/async_omni_llm.py +++ b/vllm_omni/entrypoints/async_omni_llm.py @@ -185,6 +185,42 @@ def __init__( else: self.profiler = None + async def sleep(self, level: int = 1) -> bool: + """Put the worker to sleep, offloading model weights. + + Args: + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + + Returns: + True if successful + """ + results = await self.engine_core.collective_rpc( + method="sleep", + timeout=None, + args=(level,), + kwargs={}, + ) + return all(results) if isinstance(results, list) else results + + async def wake_up(self, tags: list[str] | None = None) -> bool: + """Wake up the worker from sleep mode. + + Args: + tags: Optional list of tags to reallocate worker memory for specific + allocations. Values must be in ("weights",). If None, all memory + is reallocated. + + Returns: + True if successful + """ + results = await self.engine_core.collective_rpc( + method="wake_up", + timeout=None, + args=(tags,), + kwargs={}, + ) + return all(results) if isinstance(results, list) else results + @classmethod @deprecate_kwargs( "disable_log_requests", diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 23610e69f68..6107e283187 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -884,3 +884,45 @@ def collective_rpc( @property def _name(self) -> str: return "Orchestrator" + + def sleep(self, level: int = 1) -> list[bool]: + """Put all stage workers to sleep, offloading model weights. + + Args: + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + + Returns: + List of results from each stage + """ + results = [] + for stage in self.stage_list: + result = stage.collective_rpc( + method="sleep", + timeout=None, + args=(level,), + kwargs={}, + ) + results.append(result) + return results + + def wake_up(self, tags: list[str] | None = None) -> list[bool]: + """Wake up all stage workers from sleep mode. + + Args: + tags: Optional list of tags to reallocate worker memory for specific + allocations. Values must be in ("weights",). If None, all memory + is reallocated. + + Returns: + List of results from each stage + """ + results = [] + for stage in self.stage_list: + result = stage.collective_rpc( + method="wake_up", + timeout=None, + args=(), + kwargs={"tags": tags}, + ) + results.append(result) + return results diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py index 18b26544f36..980dfb20072 100644 --- a/vllm_omni/entrypoints/omni_diffusion.py +++ b/vllm_omni/entrypoints/omni_diffusion.py @@ -184,3 +184,39 @@ def stop_profile(self) -> dict: return self.engine.stop_profile() else: raise RuntimeError("Diffusion engine not initialized") + + def sleep(self, level: int = 1) -> bool: + """Put the worker to sleep, offloading model weights. + + Args: + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + + Returns: + True if successful + """ + results = self.collective_rpc( + method="sleep", + timeout=None, + args=(level,), + kwargs={}, + ) + return all(results) if isinstance(results, list) else results + + def wake_up(self, tags: list[str] | None = None) -> bool: + """Wake up the worker from sleep mode. + + Args: + tags: Optional list of tags to reallocate worker memory for specific + allocations. Values must be in ("weights",). If None, all memory + is reallocated. + + Returns: + True if successful + """ + results = self.collective_rpc( + method="wake_up", + timeout=None, + args=(tags,), + kwargs={}, + ) + return all(results) if isinstance(results, list) else results diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py index 6bb8ac3663f..7503617ae19 100644 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -241,3 +241,39 @@ def _run_engine(self, *, use_tqdm: bool | Callable[..., tqdm] = True) -> list[Re # This is necessary because some requests may be finished earlier than # its previous requests. return sorted(outputs, key=lambda x: int(x.request_id.split("-")[0])) + + def sleep(self, level: int = 1) -> bool: + """Put the worker to sleep, offloading model weights. + + Args: + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + + Returns: + True if successful + """ + results = self.llm_engine.collective_rpc( + method="sleep", + timeout=None, + args=(level,), + kwargs={}, + ) + return all(results) if isinstance(results, list) else results + + def wake_up(self, tags: list[str] | None = None) -> bool: + """Wake up the worker from sleep mode. + + Args: + tags: Optional list of tags to reallocate worker memory for specific + allocations. Values must be in ("weights",). If None, all memory + is reallocated. + + Returns: + True if successful + """ + results = self.llm_engine.collective_rpc( + method="wake_up", + timeout=None, + args=(), + kwargs={"tags": tags}, + ) + return all(results) if isinstance(results, list) else results From 786cdc9ed0764f6574c35b06b14cc847b4b9c98d Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Wed, 28 Jan 2026 11:00:05 +0800 Subject: [PATCH 09/17] Fix rpc_collective for AsyncOmni Signed-off-by: knlnguyen1802 --- vllm_omni/entrypoints/async_omni.py | 139 +++++++++++++++++++++++----- 1 file changed, 118 insertions(+), 21 deletions(-) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index c2de265e8f9..74e18d14901 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -3,10 +3,10 @@ import asyncio import time import weakref -from collections.abc import AsyncGenerator, Iterable +from collections.abc import AsyncGenerator, Callable, Iterable from dataclasses import asdict from pprint import pformat -from typing import Any +from typing import Any, TypeVar from vllm.config import VllmConfig from vllm.inputs.preprocess import InputPreprocessor @@ -39,6 +39,8 @@ logger = init_logger(__name__) +_R = TypeVar("_R") + def _weak_close_cleanup_async(stage_list, stage_in_queues, ray_pg, output_handler): """Weak reference cleanup function for AsyncOmni instances.""" @@ -718,21 +720,49 @@ async def reset_mm_cache(self) -> None: async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: pass - async def sleep(self, level: int = 1) -> None: - """Put all stage workers to sleep, offloading model weights. + async def collective_rpc( + self, + method: str | Callable[..., _R], + timeout: float | None = None, + args: tuple = (), + kwargs: dict[str, Any] | None = None, + ) -> list[_R]: + """Execute a method on all stage workers via collective RPC. Args: - level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + method: Method name (str) or callable to execute on workers + timeout: Optional timeout in seconds + args: Positional arguments for the method + kwargs: Keyword arguments for the method + + Returns: + List of results from each stage """ + results = [] for stage in self.stage_list: - await asyncio.get_event_loop().run_in_executor( + result = await asyncio.get_event_loop().run_in_executor( None, stage.collective_rpc, - "sleep", - None, - (level,), - {}, + method, + timeout, + args, + kwargs, ) + results.append(result) + return results + + async def sleep(self, level: int = 1) -> None: + """Put all stage workers to sleep, offloading model weights. + + Args: + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + """ + await self.collective_rpc( + method="sleep", + timeout=None, + args=(level,), + kwargs={}, + ) async def wake_up(self, tags: list[str] | None = None) -> None: """Wake up all stage workers from sleep mode. @@ -742,23 +772,90 @@ async def wake_up(self, tags: list[str] | None = None) -> None: allocations. Values must be in ("weights",). If None, all memory is reallocated. """ - for stage in self.stage_list: - await asyncio.get_event_loop().run_in_executor( - None, - stage.collective_rpc, - "wake_up", - None, - (), - {"tags": tags}, - ) + await self.collective_rpc( + method="wake_up", + timeout=None, + args=(), + kwargs={"tags": tags}, + ) async def is_sleeping(self) -> bool: """Check whether the engine is sleeping""" return False async def add_lora(self, lora_request: LoRARequest) -> bool: - """Load a new LoRA adapter into the engine for future requests.""" - return False + """Load a new LoRA adapter into the engine for future requests. + + Args: + lora_request: LoRA adapter request to load + + Returns: + True if successful on all stages + """ + results = await self.collective_rpc( + method="add_lora", + timeout=None, + args=(), + kwargs={"lora_request": lora_request}, + ) + return all(results) if isinstance(results, list) else results + + async def remove_lora(self, adapter_id: int) -> bool: + """Remove a LoRA adapter from all stages. + + Args: + adapter_id: The adapter ID to remove + + Returns: + True if successful on all stages + """ + results = await self.collective_rpc( + method="remove_lora", + timeout=None, + args=(adapter_id,), + kwargs={}, + ) + return all(results) if isinstance(results, list) else results + + async def list_loras(self) -> list[int]: + """List all registered LoRA adapter IDs across all stages. + + Returns: + List of unique adapter IDs + """ + results = await self.collective_rpc( + method="list_loras", + timeout=None, + args=(), + kwargs={}, + ) + # Flatten and deduplicate adapter IDs from all stages + if not isinstance(results, list): + return results or [] + merged: set[int] = set() + for part in results: + if isinstance(part, list): + merged.update(part or []) + elif part is not None: + merged.add(part) + return sorted(merged) + + async def pin_lora(self, adapter_id: int) -> bool: + """Prevent a LoRA adapter from being evicted on all stages. + + Args: + adapter_id: The adapter ID to pin + + Returns: + True if successful on all stages + """ + results = await self.collective_rpc( + method="pin_lora", + timeout=None, + args=(), + kwargs={"adapter_id": adapter_id}, + ) + return all(results) if isinstance(results, list) else results async def encode( self, From 2aa65e05e93643dc51b27c7800e101fc74329227 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Wed, 28 Jan 2026 11:35:36 +0800 Subject: [PATCH 10/17] Fix rpc_collective for AsyncOmniDiffusion Signed-off-by: knlnguyen1802 --- vllm_omni/entrypoints/async_omni_diffusion.py | 141 +++++++++++------- 1 file changed, 87 insertions(+), 54 deletions(-) diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py index 875726402fd..46396fae8ab 100644 --- a/vllm_omni/entrypoints/async_omni_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -10,7 +10,7 @@ import asyncio import uuid -from collections.abc import AsyncGenerator, Iterable +from collections.abc import AsyncGenerator, Callable, Iterable from concurrent.futures import ThreadPoolExecutor from dataclasses import fields from typing import Any @@ -309,45 +309,81 @@ def is_stopped(self) -> bool: """Check if the engine is stopped.""" return self._closed - async def remove_lora(self, adapter_id: int) -> bool: - """Remove a LoRA""" + async def collective_rpc( + self, + method: str | Callable, + timeout: float | None = None, + args: tuple = (), + kwargs: dict | None = None, + ) -> Any: + """Execute a method on diffusion workers via collective RPC. + + Args: + method: Method name (str) or callable to execute on workers + timeout: Optional timeout in seconds + args: Positional arguments for the method + kwargs: Keyword arguments for the method + + Returns: + Results from the workers + """ loop = asyncio.get_event_loop() results = await loop.run_in_executor( self._executor, self.engine.collective_rpc, - "remove_lora", - None, - (adapter_id,), - {}, - None, + method, + timeout, + args, + kwargs, + ) + return results + + async def remove_lora(self, adapter_id: int) -> bool: + """Remove a LoRA adapter. + + Args: + adapter_id: The adapter ID to remove + + Returns: + True if successful + """ + results = await self.collective_rpc( + method="remove_lora", + timeout=None, + args=(adapter_id,), + kwargs={}, ) return all(results) if isinstance(results, list) else results async def add_lora(self, lora_request: LoRARequest, lora_scale: float = 1.0) -> bool: - """Add a LoRA adapter""" - loop = asyncio.get_event_loop() - results = await loop.run_in_executor( - self._executor, - self.engine.collective_rpc, - "add_lora", - None, - (), - {"lora_request": lora_request, "lora_scale": lora_scale}, - None, + """Add a LoRA adapter. + + Args: + lora_request: LoRA adapter request to load + lora_scale: Scale factor for LoRA weights + + Returns: + True if successful + """ + results = await self.collective_rpc( + method="add_lora", + timeout=None, + args=(), + kwargs={"lora_request": lora_request, "lora_scale": lora_scale}, ) return all(results) if isinstance(results, list) else results async def list_loras(self) -> list[int]: - """List all registered LoRA adapter IDs.""" - loop = asyncio.get_event_loop() - results = await loop.run_in_executor( - self._executor, - self.engine.collective_rpc, - "list_loras", - None, - (), - {}, - None, + """List all registered LoRA adapter IDs. + + Returns: + List of unique adapter IDs + """ + results = await self.collective_rpc( + method="list_loras", + timeout=None, + args=(), + kwargs={}, ) # collective_rpc returns list from workers; flatten unique ids if not isinstance(results, list): @@ -358,16 +394,19 @@ async def list_loras(self) -> list[int]: return sorted(merged) async def pin_lora(self, lora_id: int) -> bool: - """Prevent an adapter from being evicted.""" - loop = asyncio.get_event_loop() - results = await loop.run_in_executor( - self._executor, - self.engine.collective_rpc, - "pin_lora", - None, - (), - {"adapter_id": lora_id}, - None, + """Prevent an adapter from being evicted. + + Args: + lora_id: The adapter ID to pin + + Returns: + True if successful + """ + results = await self.collective_rpc( + method="pin_lora", + timeout=None, + args=(), + kwargs={"adapter_id": lora_id}, ) return all(results) if isinstance(results, list) else results @@ -380,14 +419,11 @@ async def sleep(self, level: int = 1) -> bool: Returns: True if successful """ - loop = asyncio.get_event_loop() - results = await loop.run_in_executor( - self._executor, - self.engine.collective_rpc, - "sleep", - None, - (level,), - {}, + results = await self.collective_rpc( + method="sleep", + timeout=None, + args=(level,), + kwargs={}, ) return all(results) if isinstance(results, list) else results @@ -402,13 +438,10 @@ async def wake_up(self, tags: list[str] | None = None) -> bool: Returns: True if successful """ - loop = asyncio.get_event_loop() - results = await loop.run_in_executor( - self._executor, - self.engine.collective_rpc, - "wake_up", - None, - (tags,), - {}, + results = await self.collective_rpc( + method="wake_up", + timeout=None, + args=(), + kwargs={"tags": tags}, ) return all(results) if isinstance(results, list) else results From 11058c4c8b344ae5ac2cd2b739dd5b41de263113 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Wed, 28 Jan 2026 06:12:06 +0000 Subject: [PATCH 11/17] Fix bug Signed-off-by: knlnguyen1802 --- .../offline_inference/test_rpc_collective.py | 49 ++++++++++++++++++ vllm_omni/entrypoints/async_omni.py | 13 +++-- vllm_omni/entrypoints/omni_stage.py | 50 ++++++++++--------- 3 files changed, 81 insertions(+), 31 deletions(-) create mode 100755 tests/e2e/offline_inference/test_rpc_collective.py mode change 100644 => 100755 vllm_omni/entrypoints/async_omni.py mode change 100644 => 100755 vllm_omni/entrypoints/omni_stage.py diff --git a/tests/e2e/offline_inference/test_rpc_collective.py b/tests/e2e/offline_inference/test_rpc_collective.py new file mode 100755 index 00000000000..aa2d04ff201 --- /dev/null +++ b/tests/e2e/offline_inference/test_rpc_collective.py @@ -0,0 +1,49 @@ +import os +import sys +from pathlib import Path + +import pytest +import asyncio + +from .utils import create_new_process_for_each_test + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.entrypoints.async_omni import AsyncOmni + +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" + +models = ["/mnt/nvme3n1/n0090/Z-Image-Turbo"] + +@create_new_process_for_each_test() +@pytest.mark.parametrize("model_name", models) +def test_rpc_collective_omni(model_name: str): + m = Omni(model=model_name, enable_sleep_mode=True) + sleep_results = m.collective_rpc( + method="sleep", + args=(1,), + ) + assert len(sleep_results) == 1 + wake_up_results = m.collective_rpc( + method="wake_up", + args=(["weights"],), + ) + assert len(wake_up_results) == 1 + + +@create_new_process_for_each_test() +@pytest.mark.parametrize("model_name", models) +def test_rpc_collective_async_omni(model_name: str): + async def _run(): + m = AsyncOmni(model=model_name, enable_sleep_mode=True) + sleep_results = await m.collective_rpc(method="sleep", args=(1,)) + assert len(sleep_results) == 1 + wake_up_results = await m.collective_rpc(method="wake_up", args=(["weights"],)) + assert len(wake_up_results) == 1 + + asyncio.run(_run()) + diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py old mode 100644 new mode 100755 index 74e18d14901..9309221ef1f --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -166,6 +166,7 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st "cache_backend": cache_backend, "cache_config": cache_config, "enable_cpu_offload": kwargs.get("enable_cpu_offload", False), + "enable_sleep_mode": kwargs.get("enable_sleep_mode", False), "enforce_eager": kwargs.get("enforce_eager", False), }, "final_output": True, @@ -740,13 +741,11 @@ async def collective_rpc( """ results = [] for stage in self.stage_list: - result = await asyncio.get_event_loop().run_in_executor( - None, - stage.collective_rpc, - method, - timeout, - args, - kwargs, + result = stage.collective_rpc( + method=method, + timeout=timeout, + args=args, + kwargs=kwargs, ) results.append(result) return results diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py old mode 100644 new mode 100755 index f8487ce8f5b..7e4c3034e3c --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -1495,6 +1495,30 @@ async def generation_single_request(task: dict[str, Any]): } ) + async def execute_rpc(task: dict[str, Any]): + try: + rpc_id = task.get("rpc_id") + method = task.get("method") + timeout = task.get("timeout") + args = task.get("args") + kwargs = task.get("kwargs") + result = await stage_engine.collective_rpc(method, timeout, args, kwargs) + out_q.put( + { + "type": "collective_rpc_result", + "rpc_id": rpc_id, + "result": result, + } + ) + except Exception as e: + out_q.put( + { + "type": "collective_rpc_result", + "rpc_id": rpc_id, + "error": str(e), + } + ) + _batch_gen_t0 = _time.time() while True: try: @@ -1508,30 +1532,8 @@ async def generation_single_request(task: dict[str, Any]): rid = task["request_id"] asyncio.create_task(stage_engine.abort(rid)) elif task_type == OmniStageTaskType.COLLECTIVE_RPC: - rpc_id = task.get("rpc_id") - method = task.get("method") - timeout = task.get("timeout") - args = task.get("args") - kwargs = task.get("kwargs") - try: - result = stage_engine.collective_rpc(method, timeout, args, kwargs) - out_q.put( - { - "type": "collective_rpc_result", - "rpc_id": rpc_id, - "result": result, - } - ) - continue - except Exception as e: - out_q.put( - { - "type": "collective_rpc_result", - "rpc_id": rpc_id, - "error": str(e), - } - ) - continue + asyncio.create_task(execute_rpc(task)) + continue elif is_profiler_task(task_type): await handle_profiler_task_async(task_type) else: From 475ed7f4d9e85150294c44ba6d31a3fad25c51c7 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Wed, 28 Jan 2026 06:35:06 +0000 Subject: [PATCH 12/17] Update collective rpc for lora Signed-off-by: knlnguyen1802 --- vllm_omni/diffusion/worker/gpu_diffusion_worker.py | 0 vllm_omni/entrypoints/async_omni.py | 4 ++-- vllm_omni/entrypoints/async_omni_diffusion.py | 4 ++-- vllm_omni/entrypoints/async_omni_llm.py | 8 ++++---- vllm_omni/entrypoints/omni.py | 4 ++-- vllm_omni/entrypoints/omni_diffusion.py | 8 ++++---- vllm_omni/entrypoints/omni_llm.py | 4 ++-- 7 files changed, 16 insertions(+), 16 deletions(-) mode change 100644 => 100755 vllm_omni/diffusion/worker/gpu_diffusion_worker.py mode change 100644 => 100755 vllm_omni/entrypoints/async_omni_diffusion.py mode change 100644 => 100755 vllm_omni/entrypoints/async_omni_llm.py mode change 100644 => 100755 vllm_omni/entrypoints/omni.py mode change 100644 => 100755 vllm_omni/entrypoints/omni_diffusion.py mode change 100644 => 100755 vllm_omni/entrypoints/omni_llm.py diff --git a/vllm_omni/diffusion/worker/gpu_diffusion_worker.py b/vllm_omni/diffusion/worker/gpu_diffusion_worker.py old mode 100644 new mode 100755 diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 9309221ef1f..a28f33dedc7 100755 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -759,8 +759,8 @@ async def sleep(self, level: int = 1) -> None: await self.collective_rpc( method="sleep", timeout=None, - args=(level,), - kwargs={}, + args=(), + kwargs={"level":level}, ) async def wake_up(self, tags: list[str] | None = None) -> None: diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py old mode 100644 new mode 100755 index 46396fae8ab..8638f7fdc94 --- a/vllm_omni/entrypoints/async_omni_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -422,8 +422,8 @@ async def sleep(self, level: int = 1) -> bool: results = await self.collective_rpc( method="sleep", timeout=None, - args=(level,), - kwargs={}, + args=(), + kwargs={"level": level}, ) return all(results) if isinstance(results, list) else results diff --git a/vllm_omni/entrypoints/async_omni_llm.py b/vllm_omni/entrypoints/async_omni_llm.py old mode 100644 new mode 100755 index c4c530eb49a..39c06fadc61 --- a/vllm_omni/entrypoints/async_omni_llm.py +++ b/vllm_omni/entrypoints/async_omni_llm.py @@ -197,8 +197,8 @@ async def sleep(self, level: int = 1) -> bool: results = await self.engine_core.collective_rpc( method="sleep", timeout=None, - args=(level,), - kwargs={}, + args=(), + kwargs={"level": level}, ) return all(results) if isinstance(results, list) else results @@ -216,8 +216,8 @@ async def wake_up(self, tags: list[str] | None = None) -> bool: results = await self.engine_core.collective_rpc( method="wake_up", timeout=None, - args=(tags,), - kwargs={}, + args=(), + kwargs={"tags": tags}, ) return all(results) if isinstance(results, list) else results diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py old mode 100644 new mode 100755 index 6107e283187..6a969d14e08 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -899,8 +899,8 @@ def sleep(self, level: int = 1) -> list[bool]: result = stage.collective_rpc( method="sleep", timeout=None, - args=(level,), - kwargs={}, + args=(), + kwargs={"level": level}, ) results.append(result) return results diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py old mode 100644 new mode 100755 index 980dfb20072..5b6df5b3c6e --- a/vllm_omni/entrypoints/omni_diffusion.py +++ b/vllm_omni/entrypoints/omni_diffusion.py @@ -197,8 +197,8 @@ def sleep(self, level: int = 1) -> bool: results = self.collective_rpc( method="sleep", timeout=None, - args=(level,), - kwargs={}, + args=(), + kwargs={"level": level}, ) return all(results) if isinstance(results, list) else results @@ -216,7 +216,7 @@ def wake_up(self, tags: list[str] | None = None) -> bool: results = self.collective_rpc( method="wake_up", timeout=None, - args=(tags,), - kwargs={}, + args=(), + kwargs={"tags": tags}, ) return all(results) if isinstance(results, list) else results diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py old mode 100644 new mode 100755 index 7503617ae19..0b5122993b1 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -254,8 +254,8 @@ def sleep(self, level: int = 1) -> bool: results = self.llm_engine.collective_rpc( method="sleep", timeout=None, - args=(level,), - kwargs={}, + args=(), + kwargs={"level": level}, ) return all(results) if isinstance(results, list) else results From a6d1e9ebbddfcf2d5e79ca516c6fe8c4a2322d8d Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Wed, 28 Jan 2026 14:36:43 +0800 Subject: [PATCH 13/17] Update lora function Signed-off-by: knlnguyen1802 --- vllm_omni/entrypoints/async_omni.py | 8 +- vllm_omni/entrypoints/async_omni_diffusion.py | 4 +- vllm_omni/entrypoints/omni.py | 88 +++++++++++++++++++ vllm_omni/entrypoints/omni_diffusion.py | 73 +++++++++++++++ 4 files changed, 167 insertions(+), 6 deletions(-) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 9309221ef1f..b0eded10e2e 100755 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -782,7 +782,7 @@ async def is_sleeping(self) -> bool: """Check whether the engine is sleeping""" return False - async def add_lora(self, lora_request: LoRARequest) -> bool: + async def add_lora(self, lora_request: LoRARequest, lora_scale: float = 1.0) -> bool: """Load a new LoRA adapter into the engine for future requests. Args: @@ -795,7 +795,7 @@ async def add_lora(self, lora_request: LoRARequest) -> bool: method="add_lora", timeout=None, args=(), - kwargs={"lora_request": lora_request}, + kwargs={"lora_request": lora_request, "lora_scale": lora_scale}, ) return all(results) if isinstance(results, list) else results @@ -811,8 +811,8 @@ async def remove_lora(self, adapter_id: int) -> bool: results = await self.collective_rpc( method="remove_lora", timeout=None, - args=(adapter_id,), - kwargs={}, + args=(), + kwargs={"adapter_id": adapter_id}, ) return all(results) if isinstance(results, list) else results diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py index 46396fae8ab..518854c2f6e 100644 --- a/vllm_omni/entrypoints/async_omni_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -350,8 +350,8 @@ async def remove_lora(self, adapter_id: int) -> bool: results = await self.collective_rpc( method="remove_lora", timeout=None, - args=(adapter_id,), - kwargs={}, + args=(), + kwargs={"adapter_id": adapter_id}, ) return all(results) if isinstance(results, list) else results diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 6107e283187..40544d3911a 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -42,6 +42,7 @@ load_stage_configs_from_yaml, resolve_model_config_path, ) +from vllm_omni.lora.request import LoRARequest from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -926,3 +927,90 @@ def wake_up(self, tags: list[str] | None = None) -> list[bool]: ) results.append(result) return results + + def add_lora(self, lora_request: LoRARequest, lora_scale: float = 1.0) -> bool: + """Load a new LoRA adapter into the engine for future requests. + + Args: + lora_request: LoRA adapter request to load + lora_scale: Scale factor for LoRA weights + + Returns: + True if successful on all stages + """ + results = [] + for stage in self.stage_list: + result = stage.collective_rpc( + method="add_lora", + timeout=None, + args=(), + kwargs={"lora_request": lora_request, "lora_scale": lora_scale}, + ) + results.append(result) + return all(results) if isinstance(results, list) else results + + def remove_lora(self, adapter_id: int) -> bool: + """Remove a LoRA adapter from all stages. + + Args: + adapter_id: The adapter ID to remove + + Returns: + True if successful on all stages + """ + results = [] + for stage in self.stage_list: + result = stage.collective_rpc( + method="remove_lora", + timeout=None, + args=(), + kwargs={"adapter_id": adapter_id}, + ) + results.append(result) + return all(results) if isinstance(results, list) else results + + def list_loras(self) -> list[int]: + """List all registered LoRA adapter IDs across all stages. + + Returns: + List of unique adapter IDs + """ + results = [] + for stage in self.stage_list: + result = stage.collective_rpc( + method="list_loras", + timeout=None, + args=(), + kwargs={}, + ) + results.append(result) + # Flatten and deduplicate adapter IDs from all stages + if not isinstance(results, list): + return results or [] + merged: set[int] = set() + for part in results: + if isinstance(part, list): + merged.update(part or []) + elif part is not None: + merged.add(part) + return sorted(merged) + + def pin_lora(self, adapter_id: int) -> bool: + """Prevent a LoRA adapter from being evicted on all stages. + + Args: + adapter_id: The adapter ID to pin + + Returns: + True if successful on all stages + """ + results = [] + for stage in self.stage_list: + result = stage.collective_rpc( + method="pin_lora", + timeout=None, + args=(), + kwargs={"adapter_id": adapter_id}, + ) + results.append(result) + return all(results) if isinstance(results, list) else results diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py index 980dfb20072..cf5360086ad 100644 --- a/vllm_omni/entrypoints/omni_diffusion.py +++ b/vllm_omni/entrypoints/omni_diffusion.py @@ -13,6 +13,7 @@ from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.lora.request import LoRARequest # TODO configure logging properly logging.basicConfig(level=logging.INFO) @@ -220,3 +221,75 @@ def wake_up(self, tags: list[str] | None = None) -> bool: kwargs={}, ) return all(results) if isinstance(results, list) else results + + def add_lora(self, lora_request: LoRARequest, lora_scale: float = 1.0) -> bool: + """Add a LoRA adapter. + + Args: + lora_request: LoRA adapter request to load + lora_scale: Scale factor for LoRA weights + + Returns: + True if successful + """ + results = self.collective_rpc( + method="add_lora", + timeout=None, + args=(), + kwargs={"lora_request": lora_request, "lora_scale": lora_scale}, + ) + return all(results) if isinstance(results, list) else results + + def remove_lora(self, adapter_id: int) -> bool: + """Remove a LoRA adapter. + + Args: + adapter_id: The adapter ID to remove + + Returns: + True if successful + """ + results = self.collective_rpc( + method="remove_lora", + timeout=None, + args=(), + kwargs={"adapter_id": adapter_id}, + ) + return all(results) if isinstance(results, list) else results + + def list_loras(self) -> list[int]: + """List all registered LoRA adapter IDs. + + Returns: + List of unique adapter IDs + """ + results = self.collective_rpc( + method="list_loras", + timeout=None, + args=(), + kwargs={}, + ) + # collective_rpc returns list from workers; flatten unique ids + if not isinstance(results, list): + return results or [] + merged: set[int] = set() + for part in results: + merged.update(part or []) + return sorted(merged) + + def pin_lora(self, lora_id: int) -> bool: + """Prevent an adapter from being evicted. + + Args: + lora_id: The adapter ID to pin + + Returns: + True if successful + """ + results = self.collective_rpc( + method="pin_lora", + timeout=None, + args=(), + kwargs={"adapter_id": lora_id}, + ) + return all(results) if isinstance(results, list) else results From 15efb565a41a4bd891a378ad131c4eea5de02ba9 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Wed, 28 Jan 2026 14:52:33 +0800 Subject: [PATCH 14/17] Linting Signed-off-by: knlnguyen1802 --- tests/e2e/offline_inference/test_rpc_collective.py | 6 +++--- vllm_omni/diffusion/diffusion_engine.py | 2 +- vllm_omni/entrypoints/async_omni.py | 2 +- vllm_omni/entrypoints/omni_diffusion.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/e2e/offline_inference/test_rpc_collective.py b/tests/e2e/offline_inference/test_rpc_collective.py index aa2d04ff201..5828c1b218b 100755 --- a/tests/e2e/offline_inference/test_rpc_collective.py +++ b/tests/e2e/offline_inference/test_rpc_collective.py @@ -1,9 +1,9 @@ +import asyncio import os import sys from pathlib import Path import pytest -import asyncio from .utils import create_new_process_for_each_test @@ -12,13 +12,14 @@ if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) -from vllm_omni.entrypoints.omni import Omni from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.entrypoints.omni import Omni os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" models = ["/mnt/nvme3n1/n0090/Z-Image-Turbo"] + @create_new_process_for_each_test() @pytest.mark.parametrize("model_name", models) def test_rpc_collective_omni(model_name: str): @@ -46,4 +47,3 @@ async def _run(): assert len(wake_up_results) == 1 asyncio.run(_run()) - diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index c709a7584f3..db1c329af97 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -355,7 +355,7 @@ def collective_rpc( args: Positional arguments for the method kwargs: Keyword arguments for the method unique_reply_rank: If set, only get the result from this rank - + Returns: Single result if unique_reply_rank is provided, otherwise list of results """ diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 0571f58d393..55b2dd3ae54 100755 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -760,7 +760,7 @@ async def sleep(self, level: int = 1) -> None: method="sleep", timeout=None, args=(), - kwargs={"level":level}, + kwargs={"level": level}, ) async def wake_up(self, tags: list[str] | None = None) -> None: diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py index 8fb113e9c8f..8cd379e556f 100755 --- a/vllm_omni/entrypoints/omni_diffusion.py +++ b/vllm_omni/entrypoints/omni_diffusion.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging -from collections.abc import Callable import uuid +from collections.abc import Callable from dataclasses import fields from typing import Any From dc49a9fbd9273fc141e4239e4c24364c8d71cf46 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Wed, 28 Jan 2026 16:51:11 +0800 Subject: [PATCH 15/17] Fix test Signed-off-by: knlnguyen1802 --- .buildkite/pipeline.yml | 2 +- tests/e2e/test_rpc_collective.py | 42 -------------------------------- 2 files changed, 1 insertion(+), 43 deletions(-) delete mode 100644 tests/e2e/test_rpc_collective.py diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 19614be3396..af6607dae27 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -150,7 +150,7 @@ steps: depends_on: image-build commands: - pytest -s -v tests/diffusion/test_gpu_worker.py - - pytest -s -v tests/e2e/test_rpc_collective.py + - pytest -s -v tests/e2e/offline_inference/test_rpc_collective.py - pytest -s -v tests/diffusion/test_gpu_diffusion_worker.py agents: queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU diff --git a/tests/e2e/test_rpc_collective.py b/tests/e2e/test_rpc_collective.py deleted file mode 100644 index d87ae2b0a51..00000000000 --- a/tests/e2e/test_rpc_collective.py +++ /dev/null @@ -1,42 +0,0 @@ -import os -import sys -from pathlib import Path - -import pytest - -from .offline_inference.utils import create_new_process_for_each_test - -# ruff: noqa: E402 -REPO_ROOT = Path(__file__).resolve().parents[2] -if str(REPO_ROOT) not in sys.path: - sys.path.insert(0, str(REPO_ROOT)) - -from vllm_omni import Omni - -os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" - -diffusion_models = ["Tongyi-MAI/Z-Image-Turbo"] - -omni_models = ["Qwen/Qwen2.5-Omni-3B"] - - -@create_new_process_for_each_test() -@pytest.mark.parametrize("model_name", omni_models) -def test_omni_model(model_name: str): - m = Omni(model=model_name, enable_sleep_mode=True, init_timeout=3600) - results = m.collective_rpc( - method="sleep", - args=(1,), - ) - assert len(results) == 3 - - -@create_new_process_for_each_test() -@pytest.mark.parametrize("model_name", diffusion_models) -def test_diffusion_model(model_name: str): - m = Omni(model=model_name, enable_sleep_mode=True) - results = m.collective_rpc( - method="sleep", - args=(1,), - ) - assert len(results) == 1 From 4d796931f84d80c3ceae6b96c05e3d8830993e47 Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Wed, 28 Jan 2026 16:54:59 +0800 Subject: [PATCH 16/17] Linting Signed-off-by: knlnguyen1802 --- vllm_omni/diffusion/diffusion_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index db1c329af97..2e00f41b97f 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -354,7 +354,7 @@ def collective_rpc( timeout: Optional timeout in seconds args: Positional arguments for the method kwargs: Keyword arguments for the method - unique_reply_rank: If set, only get the result from this rank + unique_reply_rank: If set, only get reply from this rank Returns: Single result if unique_reply_rank is provided, otherwise list of results From d1d64909d3a67a19ed6156d644cbeb442264199c Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Wed, 28 Jan 2026 17:03:40 +0800 Subject: [PATCH 17/17] Rebase Signed-off-by: knlnguyen1802 --- vllm_omni/entrypoints/omni.py | 4 ++-- vllm_omni/entrypoints/omni_diffusion.py | 2 +- vllm_omni/entrypoints/omni_stage.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 284558fc5d1..9ff0be0707c 100755 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -10,7 +10,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import asdict from pprint import pformat -from typing import Any, TypeVar, Literal, overload +from typing import Any, Literal, TypeVar, overload from omegaconf import OmegaConf from tqdm.auto import tqdm @@ -41,8 +41,8 @@ load_stage_configs_from_yaml, resolve_model_config_path, ) -from vllm_omni.lora.request import LoRARequest from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams +from vllm_omni.lora.request import LoRARequest from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py index cba773b37ba..45966374115 100755 --- a/vllm_omni/entrypoints/omni_diffusion.py +++ b/vllm_omni/entrypoints/omni_diffusion.py @@ -12,8 +12,8 @@ from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.lora.request import LoRARequest from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType +from vllm_omni.lora.request import LoRARequest from vllm_omni.outputs import OmniRequestOutput # TODO configure logging properly diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index 6efbc3125a2..4512c274797 100755 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -18,7 +18,7 @@ import uuid from collections.abc import Callable, Sequence from dataclasses import fields -from typing import Any, Literal, cast +from typing import Any, Literal, TypeVar, cast from vllm import PromptType, RequestOutput from vllm.inputs import TextPrompt