diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index d1e71aba2b1..91a990cd25b 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -149,6 +149,8 @@ steps: timeout_in_minutes: 20 depends_on: image-build commands: + - pytest -s -v tests/diffusion/test_gpu_worker.py + - pytest -s -v tests/e2e/offline_inference/test_rpc_collective.py - pytest -s -v tests/diffusion/test_gpu_diffusion_worker.py agents: queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU diff --git a/tests/e2e/offline_inference/test_rpc_collective.py b/tests/e2e/offline_inference/test_rpc_collective.py new file mode 100755 index 00000000000..5828c1b218b --- /dev/null +++ b/tests/e2e/offline_inference/test_rpc_collective.py @@ -0,0 +1,49 @@ +import asyncio +import os +import sys +from pathlib import Path + +import pytest + +from .utils import create_new_process_for_each_test + +# ruff: noqa: E402 +REPO_ROOT = Path(__file__).resolve().parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.entrypoints.omni import Omni + +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1" + +models = ["/mnt/nvme3n1/n0090/Z-Image-Turbo"] + + +@create_new_process_for_each_test() +@pytest.mark.parametrize("model_name", models) +def test_rpc_collective_omni(model_name: str): + m = Omni(model=model_name, enable_sleep_mode=True) + sleep_results = m.collective_rpc( + method="sleep", + args=(1,), + ) + assert len(sleep_results) == 1 + wake_up_results = m.collective_rpc( + method="wake_up", + args=(["weights"],), + ) + assert len(wake_up_results) == 1 + + +@create_new_process_for_each_test() +@pytest.mark.parametrize("model_name", models) +def test_rpc_collective_async_omni(model_name: str): + async def _run(): + m = AsyncOmni(model=model_name, enable_sleep_mode=True) + sleep_results = await m.collective_rpc(method="sleep", args=(1,)) + assert len(sleep_results) == 1 + wake_up_results = await m.collective_rpc(method="wake_up", args=(["weights"],)) + assert len(wake_up_results) == 1 + + asyncio.run(_run()) diff --git a/vllm_omni/diffusion/worker/gpu_diffusion_worker.py b/vllm_omni/diffusion/worker/gpu_diffusion_worker.py old mode 100644 new mode 100755 diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py old mode 100644 new mode 100755 index a78710fa2fd..f4a665e4b38 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -3,10 +3,10 @@ import asyncio import time import weakref -from collections.abc import AsyncGenerator, Iterable, Sequence +from collections.abc import AsyncGenerator, Callable, Iterable, Sequence from dataclasses import asdict from pprint import pformat -from typing import Any, cast +from typing import Any, TypeVar, cast from vllm.config import VllmConfig from vllm.inputs.preprocess import InputPreprocessor @@ -40,6 +40,8 @@ logger = init_logger(__name__) +_R = TypeVar("_R") + def _weak_close_cleanup_async(stage_list, stage_in_queues, ray_pg, output_handler): """Weak reference cleanup function for AsyncOmni instances.""" @@ -163,6 +165,7 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st "cache_backend": cache_backend, "cache_config": cache_config, "enable_cpu_offload": kwargs.get("enable_cpu_offload", False), + "enable_sleep_mode": kwargs.get("enable_sleep_mode", False), "enforce_eager": kwargs.get("enforce_eager", False), }, "final_output": True, @@ -692,19 +695,140 @@ async def reset_mm_cache(self) -> None: async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: pass + async def collective_rpc( + self, + method: str | Callable[..., _R], + timeout: float | None = None, + args: tuple = (), + kwargs: dict[str, Any] | None = None, + ) -> list[_R]: + """Execute a method on all stage workers via collective RPC. + + Args: + method: Method name (str) or callable to execute on workers + timeout: Optional timeout in seconds + args: Positional arguments for the method + kwargs: Keyword arguments for the method + + Returns: + List of results from each stage + """ + results = [] + for stage in self.stage_list: + result = stage.collective_rpc( + method=method, + timeout=timeout, + args=args, + kwargs=kwargs, + ) + results.append(result) + return results + async def sleep(self, level: int = 1) -> None: - pass + """Put all stage workers to sleep, offloading model weights. + + Args: + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + """ + await self.collective_rpc( + method="sleep", + timeout=None, + args=(), + kwargs={"level": level}, + ) async def wake_up(self, tags: list[str] | None = None) -> None: - pass + """Wake up all stage workers from sleep mode. + + Args: + tags: Optional list of tags to reallocate worker memory for specific + allocations. Values must be in ("weights",). If None, all memory + is reallocated. + """ + await self.collective_rpc( + method="wake_up", + timeout=None, + args=(), + kwargs={"tags": tags}, + ) async def is_sleeping(self) -> bool: """Check whether the engine is sleeping""" return False - async def add_lora(self, lora_request: LoRARequest) -> bool: - """Load a new LoRA adapter into the engine for future requests.""" - return False + async def add_lora(self, lora_request: LoRARequest, lora_scale: float = 1.0) -> bool: + """Load a new LoRA adapter into the engine for future requests. + + Args: + lora_request: LoRA adapter request to load + + Returns: + True if successful on all stages + """ + results = await self.collective_rpc( + method="add_lora", + timeout=None, + args=(), + kwargs={"lora_request": lora_request, "lora_scale": lora_scale}, + ) + return all(results) if isinstance(results, list) else results + + async def remove_lora(self, adapter_id: int) -> bool: + """Remove a LoRA adapter from all stages. + + Args: + adapter_id: The adapter ID to remove + + Returns: + True if successful on all stages + """ + results = await self.collective_rpc( + method="remove_lora", + timeout=None, + args=(), + kwargs={"adapter_id": adapter_id}, + ) + return all(results) if isinstance(results, list) else results + + async def list_loras(self) -> list[int]: + """List all registered LoRA adapter IDs across all stages. + + Returns: + List of unique adapter IDs + """ + results = await self.collective_rpc( + method="list_loras", + timeout=None, + args=(), + kwargs={}, + ) + # Flatten and deduplicate adapter IDs from all stages + if not isinstance(results, list): + return results or [] + merged: set[int] = set() + for part in results: + if isinstance(part, list): + merged.update(part or []) + elif part is not None: + merged.add(part) + return sorted(merged) + + async def pin_lora(self, adapter_id: int) -> bool: + """Prevent a LoRA adapter from being evicted on all stages. + + Args: + adapter_id: The adapter ID to pin + + Returns: + True if successful on all stages + """ + results = await self.collective_rpc( + method="pin_lora", + timeout=None, + args=(), + kwargs={"adapter_id": adapter_id}, + ) + return all(results) if isinstance(results, list) else results async def encode( self, diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py old mode 100644 new mode 100755 index 535f04f7d2e..883eaea4e7e --- a/vllm_omni/entrypoints/async_omni_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -10,7 +10,7 @@ import asyncio import uuid -from collections.abc import AsyncGenerator, Iterable +from collections.abc import AsyncGenerator, Callable, Iterable from concurrent.futures import ThreadPoolExecutor from typing import Any @@ -233,45 +233,81 @@ def is_stopped(self) -> bool: """Check if the engine is stopped.""" return self._closed - async def remove_lora(self, adapter_id: int) -> bool: - """Remove a LoRA""" + async def collective_rpc( + self, + method: str | Callable, + timeout: float | None = None, + args: tuple = (), + kwargs: dict | None = None, + ) -> Any: + """Execute a method on diffusion workers via collective RPC. + + Args: + method: Method name (str) or callable to execute on workers + timeout: Optional timeout in seconds + args: Positional arguments for the method + kwargs: Keyword arguments for the method + + Returns: + Results from the workers + """ loop = asyncio.get_event_loop() results = await loop.run_in_executor( self._executor, self.engine.collective_rpc, - "remove_lora", - None, - (adapter_id,), - {}, - None, + method, + timeout, + args, + kwargs, + ) + return results + + async def remove_lora(self, adapter_id: int) -> bool: + """Remove a LoRA adapter. + + Args: + adapter_id: The adapter ID to remove + + Returns: + True if successful + """ + results = await self.collective_rpc( + method="remove_lora", + timeout=None, + args=(), + kwargs={"adapter_id": adapter_id}, ) return all(results) if isinstance(results, list) else results async def add_lora(self, lora_request: LoRARequest, lora_scale: float = 1.0) -> bool: - """Add a LoRA adapter""" - loop = asyncio.get_event_loop() - results = await loop.run_in_executor( - self._executor, - self.engine.collective_rpc, - "add_lora", - None, - (), - {"lora_request": lora_request, "lora_scale": lora_scale}, - None, + """Add a LoRA adapter. + + Args: + lora_request: LoRA adapter request to load + lora_scale: Scale factor for LoRA weights + + Returns: + True if successful + """ + results = await self.collective_rpc( + method="add_lora", + timeout=None, + args=(), + kwargs={"lora_request": lora_request, "lora_scale": lora_scale}, ) return all(results) if isinstance(results, list) else results async def list_loras(self) -> list[int]: - """List all registered LoRA adapter IDs.""" - loop = asyncio.get_event_loop() - results = await loop.run_in_executor( - self._executor, - self.engine.collective_rpc, - "list_loras", - None, - (), - {}, - None, + """List all registered LoRA adapter IDs. + + Returns: + List of unique adapter IDs + """ + results = await self.collective_rpc( + method="list_loras", + timeout=None, + args=(), + kwargs={}, ) # collective_rpc returns list from workers; flatten unique ids if not isinstance(results, list): @@ -282,15 +318,54 @@ async def list_loras(self) -> list[int]: return sorted(merged) async def pin_lora(self, lora_id: int) -> bool: - """Prevent an adapter from being evicted.""" - loop = asyncio.get_event_loop() - results = await loop.run_in_executor( - self._executor, - self.engine.collective_rpc, - "pin_lora", - None, - (), - {"adapter_id": lora_id}, - None, + """Prevent an adapter from being evicted. + + Args: + lora_id: The adapter ID to pin + + Returns: + True if successful + """ + results = await self.collective_rpc( + method="pin_lora", + timeout=None, + args=(), + kwargs={"adapter_id": lora_id}, + ) + return all(results) if isinstance(results, list) else results + + async def sleep(self, level: int = 1) -> bool: + """Put the worker to sleep, offloading model weights. + + Args: + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + + Returns: + True if successful + """ + results = await self.collective_rpc( + method="sleep", + timeout=None, + args=(), + kwargs={"level": level}, + ) + return all(results) if isinstance(results, list) else results + + async def wake_up(self, tags: list[str] | None = None) -> bool: + """Wake up the worker from sleep mode. + + Args: + tags: Optional list of tags to reallocate worker memory for specific + allocations. Values must be in ("weights",). If None, all memory + is reallocated. + + Returns: + True if successful + """ + results = await self.collective_rpc( + method="wake_up", + timeout=None, + args=(), + kwargs={"tags": tags}, ) return all(results) if isinstance(results, list) else results diff --git a/vllm_omni/entrypoints/async_omni_llm.py b/vllm_omni/entrypoints/async_omni_llm.py old mode 100644 new mode 100755 index 287f12b9ed7..39c06fadc61 --- a/vllm_omni/entrypoints/async_omni_llm.py +++ b/vllm_omni/entrypoints/async_omni_llm.py @@ -185,6 +185,42 @@ def __init__( else: self.profiler = None + async def sleep(self, level: int = 1) -> bool: + """Put the worker to sleep, offloading model weights. + + Args: + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + + Returns: + True if successful + """ + results = await self.engine_core.collective_rpc( + method="sleep", + timeout=None, + args=(), + kwargs={"level": level}, + ) + return all(results) if isinstance(results, list) else results + + async def wake_up(self, tags: list[str] | None = None) -> bool: + """Wake up the worker from sleep mode. + + Args: + tags: Optional list of tags to reallocate worker memory for specific + allocations. Values must be in ("weights",). If None, all memory + is reallocated. + + Returns: + True if successful + """ + results = await self.engine_core.collective_rpc( + method="wake_up", + timeout=None, + args=(), + kwargs={"tags": tags}, + ) + return all(results) if isinstance(results, list) else results + @classmethod @deprecate_kwargs( "disable_log_requests", diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py old mode 100644 new mode 100755 index 97357dc3b33..9ff0be0707c --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -10,7 +10,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import asdict from pprint import pformat -from typing import Any, Literal, overload +from typing import Any, Literal, TypeVar, overload from omegaconf import OmegaConf from tqdm.auto import tqdm @@ -42,10 +42,13 @@ resolve_model_config_path, ) from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams +from vllm_omni.lora.request import LoRARequest from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) +_R = TypeVar("_R") + def _weak_close_cleanup(stage_list, stage_in_queues, ray_pg): """Weak reference cleanup function for OmniBase instances.""" @@ -861,6 +864,153 @@ def _run_generation( except Exception as e: logger.exception(f"[{self._name}] Failed to build/log summary: {e}") + def collective_rpc( + self, + method: str | Callable[..., _R], + timeout: float | None = None, + args: tuple = (), + kwargs: dict[str, Any] | None = None, + ) -> list[_R]: + results = [] + for stage in self.stage_list: + result = stage.collective_rpc( + method=method, + args=args, + timeout=timeout, + kwargs=kwargs, + ) + results.append(result) + return results + @property def _name(self) -> str: return "Orchestrator" + + def sleep(self, level: int = 1) -> list[bool]: + """Put all stage workers to sleep, offloading model weights. + + Args: + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + + Returns: + List of results from each stage + """ + results = [] + for stage in self.stage_list: + result = stage.collective_rpc( + method="sleep", + timeout=None, + args=(), + kwargs={"level": level}, + ) + results.append(result) + return results + + def wake_up(self, tags: list[str] | None = None) -> list[bool]: + """Wake up all stage workers from sleep mode. + + Args: + tags: Optional list of tags to reallocate worker memory for specific + allocations. Values must be in ("weights",). If None, all memory + is reallocated. + + Returns: + List of results from each stage + """ + results = [] + for stage in self.stage_list: + result = stage.collective_rpc( + method="wake_up", + timeout=None, + args=(), + kwargs={"tags": tags}, + ) + results.append(result) + return results + + def add_lora(self, lora_request: LoRARequest, lora_scale: float = 1.0) -> bool: + """Load a new LoRA adapter into the engine for future requests. + + Args: + lora_request: LoRA adapter request to load + lora_scale: Scale factor for LoRA weights + + Returns: + True if successful on all stages + """ + results = [] + for stage in self.stage_list: + result = stage.collective_rpc( + method="add_lora", + timeout=None, + args=(), + kwargs={"lora_request": lora_request, "lora_scale": lora_scale}, + ) + results.append(result) + return all(results) if isinstance(results, list) else results + + def remove_lora(self, adapter_id: int) -> bool: + """Remove a LoRA adapter from all stages. + + Args: + adapter_id: The adapter ID to remove + + Returns: + True if successful on all stages + """ + results = [] + for stage in self.stage_list: + result = stage.collective_rpc( + method="remove_lora", + timeout=None, + args=(), + kwargs={"adapter_id": adapter_id}, + ) + results.append(result) + return all(results) if isinstance(results, list) else results + + def list_loras(self) -> list[int]: + """List all registered LoRA adapter IDs across all stages. + + Returns: + List of unique adapter IDs + """ + results = [] + for stage in self.stage_list: + result = stage.collective_rpc( + method="list_loras", + timeout=None, + args=(), + kwargs={}, + ) + results.append(result) + # Flatten and deduplicate adapter IDs from all stages + if not isinstance(results, list): + return results or [] + merged: set[int] = set() + for part in results: + if isinstance(part, list): + merged.update(part or []) + elif part is not None: + merged.add(part) + return sorted(merged) + + def pin_lora(self, adapter_id: int) -> bool: + """Prevent a LoRA adapter from being evicted on all stages. + + Args: + adapter_id: The adapter ID to pin + + Returns: + True if successful on all stages + """ + results = [] + for stage in self.stage_list: + result = stage.collective_rpc( + method="pin_lora", + timeout=None, + args=(), + kwargs={"adapter_id": adapter_id}, + ) + results.append(result) + return all(results) if isinstance(results, list) else results diff --git a/vllm_omni/entrypoints/omni_diffusion.py b/vllm_omni/entrypoints/omni_diffusion.py old mode 100644 new mode 100755 index 5ad9a91c80d..45966374115 --- a/vllm_omni/entrypoints/omni_diffusion.py +++ b/vllm_omni/entrypoints/omni_diffusion.py @@ -3,7 +3,8 @@ import logging import uuid -from collections.abc import Sequence +from collections.abc import Callable, Sequence +from typing import Any from vllm.logger import init_logger from vllm.transformers_utils.config import get_hf_file_to_dict @@ -12,6 +13,7 @@ from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType +from vllm_omni.lora.request import LoRARequest from vllm_omni.outputs import OmniRequestOutput # TODO configure logging properly @@ -106,6 +108,20 @@ def generate( def _run_engine(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: return self.engine.step(request) + def collective_rpc( + self, + method: str | Callable, + timeout: float | None = None, + args: tuple = (), + kwargs: dict | None = None, + ) -> Any: + return self.engine.collective_rpc( + method, + timeout=timeout, + args=args, + kwargs=kwargs, + ) + def close(self) -> None: self.engine.close() @@ -137,3 +153,111 @@ def stop_profile(self) -> dict: return self.engine.stop_profile() else: raise RuntimeError("Diffusion engine not initialized") + + def sleep(self, level: int = 1) -> bool: + """Put the worker to sleep, offloading model weights. + + Args: + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + + Returns: + True if successful + """ + results = self.collective_rpc( + method="sleep", + timeout=None, + args=(), + kwargs={"level": level}, + ) + return all(results) if isinstance(results, list) else results + + def wake_up(self, tags: list[str] | None = None) -> bool: + """Wake up the worker from sleep mode. + + Args: + tags: Optional list of tags to reallocate worker memory for specific + allocations. Values must be in ("weights",). If None, all memory + is reallocated. + + Returns: + True if successful + """ + results = self.collective_rpc( + method="wake_up", + timeout=None, + args=(), + kwargs={"tags": tags}, + ) + return all(results) if isinstance(results, list) else results + + def add_lora(self, lora_request: LoRARequest, lora_scale: float = 1.0) -> bool: + """Add a LoRA adapter. + + Args: + lora_request: LoRA adapter request to load + lora_scale: Scale factor for LoRA weights + + Returns: + True if successful + """ + results = self.collective_rpc( + method="add_lora", + timeout=None, + args=(), + kwargs={"lora_request": lora_request, "lora_scale": lora_scale}, + ) + return all(results) if isinstance(results, list) else results + + def remove_lora(self, adapter_id: int) -> bool: + """Remove a LoRA adapter. + + Args: + adapter_id: The adapter ID to remove + + Returns: + True if successful + """ + results = self.collective_rpc( + method="remove_lora", + timeout=None, + args=(), + kwargs={"adapter_id": adapter_id}, + ) + return all(results) if isinstance(results, list) else results + + def list_loras(self) -> list[int]: + """List all registered LoRA adapter IDs. + + Returns: + List of unique adapter IDs + """ + results = self.collective_rpc( + method="list_loras", + timeout=None, + args=(), + kwargs={}, + ) + # collective_rpc returns list from workers; flatten unique ids + if not isinstance(results, list): + return results or [] + merged: set[int] = set() + for part in results: + merged.update(part or []) + return sorted(merged) + + def pin_lora(self, lora_id: int) -> bool: + """Prevent an adapter from being evicted. + + Args: + lora_id: The adapter ID to pin + + Returns: + True if successful + """ + results = self.collective_rpc( + method="pin_lora", + timeout=None, + args=(), + kwargs={"adapter_id": lora_id}, + ) + return all(results) if isinstance(results, list) else results diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py old mode 100644 new mode 100755 index 6bb8ac3663f..0b5122993b1 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -241,3 +241,39 @@ def _run_engine(self, *, use_tqdm: bool | Callable[..., tqdm] = True) -> list[Re # This is necessary because some requests may be finished earlier than # its previous requests. return sorted(outputs, key=lambda x: int(x.request_id.split("-")[0])) + + def sleep(self, level: int = 1) -> bool: + """Put the worker to sleep, offloading model weights. + + Args: + level: Sleep level. Level 1 offloads weights, level 2 also saves buffers. + + Returns: + True if successful + """ + results = self.llm_engine.collective_rpc( + method="sleep", + timeout=None, + args=(), + kwargs={"level": level}, + ) + return all(results) if isinstance(results, list) else results + + def wake_up(self, tags: list[str] | None = None) -> bool: + """Wake up the worker from sleep mode. + + Args: + tags: Optional list of tags to reallocate worker memory for specific + allocations. Values must be in ("weights",). If None, all memory + is reallocated. + + Returns: + True if successful + """ + results = self.llm_engine.collective_rpc( + method="wake_up", + timeout=None, + args=(), + kwargs={"tags": tags}, + ) + return all(results) if isinstance(results, list) else results diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py old mode 100644 new mode 100755 index 0a5ee55beea..4512c274797 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -15,9 +15,10 @@ import sys import time import traceback -from collections.abc import Sequence +import uuid +from collections.abc import Callable, Sequence from dataclasses import fields -from typing import Any, Literal, cast +from typing import Any, Literal, TypeVar, cast from vllm import PromptType, RequestOutput from vllm.inputs import TextPrompt @@ -55,6 +56,8 @@ logger = init_logger(__name__) +_R = TypeVar("_R") + def _build_od_config(engine_args: dict[str, Any], model: str) -> dict[str, Any]: """Build OmniDiffusionConfig kwargs from engine args.""" @@ -467,6 +470,64 @@ def process_engine_inputs( stage_list, engine_input_source, prompt, self.requires_multimodal_data ) + def collective_rpc( + self, + method: str | Callable[..., _R], + timeout: float | None = None, + args: tuple = (), + kwargs: dict[str, Any] | None = None, + ) -> list[_R]: + """Execute an RPC call on all workers via the stage engine. + + Args: + method: Name of the worker method to execute, or a callable that + is serialized and sent to all workers to execute. + + If the method is a callable, it should accept an additional + `self` argument, in addition to the arguments passed in `args` + and `kwargs`. The `self` argument will be the worker object. + timeout: Maximum time in seconds to wait for execution. Raises a + [`TimeoutError`][] on timeout. `None` means wait indefinitely. + args: Positional arguments to pass to the worker method. + kwargs: Keyword arguments to pass to the worker method. + + Returns: + A list containing the results from each worker. + + Note: + It is recommended to use this API to only pass control messages, + and set up data-plane communication to pass data. + """ + assert self._in_q is not None and self._out_q is not None, "Queues must be attached before collective_rpc" + + # Submit collective_rpc task to worker + rpc_id = str(uuid.uuid4()) + self._in_q.put( + { + "type": OmniStageTaskType.COLLECTIVE_RPC, + "rpc_id": rpc_id, + "method": method, + "timeout": timeout, + "args": args, + "kwargs": kwargs, + } + ) + + start_time = time.time() + while True: + if timeout is not None and (time.time() - start_time) > timeout: + raise TimeoutError(f"collective_rpc timed out after {timeout} seconds") + + result = self.try_collect() + if result is not None: + if result.get("type") == "collective_rpc_result": + if result.get("rpc_id") == rpc_id: + if "error" in result: + raise RuntimeError(f"collective_rpc failed: {result['error']}") + return result["result"] + + time.sleep(0.001) # Small sleep to avoid busy waiting + def _stage_worker( model: str, @@ -757,6 +818,32 @@ def handle_profiler_task_local(task_type: OmniStageTaskType) -> dict: logger.info("Received shutdown signal") break + if task_type == OmniStageTaskType.COLLECTIVE_RPC: + rpc_id = task.get("rpc_id") + method = task.get("method") + timeout = task.get("timeout") + args = task.get("args") + kwargs = task.get("kwargs") + try: + result = stage_engine.collective_rpc(method, timeout, args, kwargs) + out_q.put( + { + "type": "collective_rpc_result", + "rpc_id": rpc_id, + "result": result, + } + ) + continue + except Exception as e: + out_q.put( + { + "type": "collective_rpc_result", + "rpc_id": rpc_id, + "error": str(e), + } + ) + continue + # Handle profiler control commands if is_profiler_task(task_type): profiler_data = handle_profiler_task_local(task_type) @@ -1362,6 +1449,30 @@ async def generation_single_request(task: dict[str, Any]): } ) + async def execute_rpc(task: dict[str, Any]): + try: + rpc_id = task.get("rpc_id") + method = task.get("method") + timeout = task.get("timeout") + args = task.get("args") + kwargs = task.get("kwargs") + result = await stage_engine.collective_rpc(method, timeout, args, kwargs) + out_q.put( + { + "type": "collective_rpc_result", + "rpc_id": rpc_id, + "result": result, + } + ) + except Exception as e: + out_q.put( + { + "type": "collective_rpc_result", + "rpc_id": rpc_id, + "error": str(e), + } + ) + _batch_gen_t0 = _time.time() while True: try: @@ -1374,6 +1485,9 @@ async def generation_single_request(task: dict[str, Any]): elif task_type == OmniStageTaskType.ABORT: rid = task["request_id"] asyncio.create_task(stage_engine.abort(rid)) + elif task_type == OmniStageTaskType.COLLECTIVE_RPC: + asyncio.create_task(execute_rpc(task)) + continue elif is_profiler_task(task_type): await handle_profiler_task_async(task_type) else: diff --git a/vllm_omni/entrypoints/stage_utils.py b/vllm_omni/entrypoints/stage_utils.py index da837ce9adc..26187c777af 100644 --- a/vllm_omni/entrypoints/stage_utils.py +++ b/vllm_omni/entrypoints/stage_utils.py @@ -14,6 +14,7 @@ class OmniStageTaskType(enum.Enum): GENERATE = "generate" + COLLECTIVE_RPC = "collective_rpc" ABORT = "abort" SHUTDOWN = "shutdown" PROFILER_START = "profiler_start"