Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ steps:
timeout_in_minutes: 20
depends_on: image-build
commands:
- pytest -s -v tests/diffusion/test_gpu_worker.py
- pytest -s -v tests/e2e/offline_inference/test_rpc_collective.py
- pytest -s -v tests/diffusion/test_gpu_diffusion_worker.py
agents:
queue: "gpu_4_queue" # g6.12xlarge instance on AWS, has 4 L4 GPU
Expand Down
49 changes: 49 additions & 0 deletions tests/e2e/offline_inference/test_rpc_collective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import asyncio
import os
import sys
from pathlib import Path

import pytest

from .utils import create_new_process_for_each_test

# ruff: noqa: E402
# Put the repository root on sys.path so the in-tree ``vllm_omni`` package is
# importable when the tests run from a source checkout. This file lives in
# tests/e2e/offline_inference/, so counting up from the file's directory:
# parents[0] = offline_inference, parents[1] = e2e, parents[2] = tests,
# parents[3] = repository root. (parents[2] pointed at ``tests`` instead.)
REPO_ROOT = Path(__file__).resolve().parents[3]
if str(REPO_ROOT) not in sys.path:
    # Prepend so the checkout takes precedence over other sys.path entries.
    sys.path.insert(0, str(REPO_ROOT))

from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.omni import Omni

# Set at import time, i.e. before any test in this module runs.
# NOTE(review): the consumer of this flag is not visible in this file —
# presumably a test helper reads it to release GPU memory between tests; confirm.
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"

# Model locations every test below is parametrized over.
# NOTE(review): this is an absolute path on one specific machine, so the tests
# can only run where that directory exists — confirm this is intended for CI.
models = ["/mnt/nvme3n1/n0090/Z-Image-Turbo"]


@create_new_process_for_each_test()
@pytest.mark.parametrize("model_name", models)
def test_rpc_collective_omni(model_name: str):
    """Drive a sleep -> wake_up cycle through ``Omni.collective_rpc``.

    The engine is built with ``enable_sleep_mode=True``; each RPC is expected
    to come back with exactly one result entry.
    """
    engine = Omni(model=model_name, enable_sleep_mode=True)

    # Same two calls, in the same order: sleep at level 1, then wake the weights.
    rpc_calls = (
        ("sleep", (1,)),
        ("wake_up", (["weights"],)),
    )
    for rpc_method, rpc_args in rpc_calls:
        rpc_results = engine.collective_rpc(method=rpc_method, args=rpc_args)
        assert len(rpc_results) == 1


@create_new_process_for_each_test()
@pytest.mark.parametrize("model_name", models)
def test_rpc_collective_async_omni(model_name: str):
    """Drive a sleep -> wake_up cycle through ``AsyncOmni.collective_rpc``.

    The coroutine below runs on a fresh event loop via ``asyncio.run``; each
    awaited RPC is expected to come back with exactly one result entry.
    """

    async def _sleep_then_wake():
        engine = AsyncOmni(model=model_name, enable_sleep_mode=True)

        slept = await engine.collective_rpc(method="sleep", args=(1,))
        assert len(slept) == 1

        woken = await engine.collective_rpc(method="wake_up", args=(["weights"],))
        assert len(woken) == 1

    asyncio.run(_sleep_then_wake())
Empty file modified vllm_omni/diffusion/worker/gpu_diffusion_worker.py
100644 → 100755
Empty file.
138 changes: 131 additions & 7 deletions vllm_omni/entrypoints/async_omni.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import asyncio
import time
import weakref
from collections.abc import AsyncGenerator, Iterable, Sequence
from collections.abc import AsyncGenerator, Callable, Iterable, Sequence
from dataclasses import asdict
from pprint import pformat
from typing import Any, cast
from typing import Any, TypeVar, cast

from vllm.config import VllmConfig
from vllm.inputs.preprocess import InputPreprocessor
Expand Down Expand Up @@ -40,6 +40,8 @@

logger = init_logger(__name__)

# Result type of a callable passed to ``collective_rpc``; ties the callable's
# return annotation to the element type of the returned list.
_R = TypeVar("_R")


def _weak_close_cleanup_async(stage_list, stage_in_queues, ray_pg, output_handler):
"""Weak reference cleanup function for AsyncOmni instances."""
Expand Down Expand Up @@ -163,6 +165,7 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st
"cache_backend": cache_backend,
"cache_config": cache_config,
"enable_cpu_offload": kwargs.get("enable_cpu_offload", False),
"enable_sleep_mode": kwargs.get("enable_sleep_mode", False),
"enforce_eager": kwargs.get("enforce_eager", False),
},
"final_output": True,
Expand Down Expand Up @@ -692,19 +695,140 @@ async def reset_mm_cache(self) -> None:
async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
    """Placeholder for a prefix-cache reset; no cache is touched here.

    Args:
        reset_running_requests: Accepted but currently ignored.

    Returns:
        Always ``False``, since nothing is reset. Returned explicitly so the
        value matches the ``bool`` annotation instead of falling through
        to ``None``.
    """
    return False

async def collective_rpc(
    self,
    method: str | Callable[..., _R],
    timeout: float | None = None,
    args: tuple = (),
    kwargs: dict[str, Any] | None = None,
) -> list[_R]:
    """Invoke ``method`` on every stage and gather what each stage returns.

    Stages are visited one after another in ``self.stage_list`` order, and
    each stage's own ``collective_rpc`` is called directly (it is not awaited).

    Args:
        method: Method name (str) or callable to execute on workers.
        timeout: Optional timeout in seconds, forwarded to each stage.
        args: Positional arguments for the method.
        kwargs: Keyword arguments for the method.

    Returns:
        One entry per stage, in ``self.stage_list`` order.
    """
    return [
        stage.collective_rpc(
            method=method,
            timeout=timeout,
            args=args,
            kwargs=kwargs,
        )
        for stage in self.stage_list
    ]

async def sleep(self, level: int = 1) -> None:
    """Put all stage workers to sleep, offloading model weights.

    The request is broadcast through ``collective_rpc`` with ``level`` passed
    as a keyword argument; the per-stage results are discarded.

    Args:
        level: Sleep level. Level 1 offloads weights, level 2 also saves buffers.
    """
    # The docstring must be the first statement: a leftover ``pass`` ahead of it
    # turned the string into a dead expression and left ``__doc__`` empty.
    await self.collective_rpc(
        method="sleep",
        timeout=None,
        args=(),
        kwargs={"level": level},
    )

async def wake_up(self, tags: list[str] | None = None) -> None:
pass
"""Wake up all stage workers from sleep mode.

Args:
tags: Optional list of tags to reallocate worker memory for specific
allocations. Values must be in ("weights",). If None, all memory
is reallocated.
"""
await self.collective_rpc(
method="wake_up",
timeout=None,
args=(),
kwargs={"tags": tags},
)

async def is_sleeping(self) -> bool:
    """Report whether the engine is asleep; this implementation always answers ``False``."""
    return False

async def add_lora(self, lora_request: LoRARequest, lora_scale: float = 1.0) -> bool:
    """Load a new LoRA adapter into the engine for future requests.

    The request is broadcast to every stage through ``collective_rpc``.

    Args:
        lora_request: LoRA adapter request to load.
        lora_scale: Forwarded unchanged to each stage's ``add_lora`` as the
            ``lora_scale`` keyword argument.

    Returns:
        True only if every stage reported success. A stage whose result is
        itself a list/tuple counts as successful only when that sequence is
        non-empty and all of its entries are truthy.
    """
    results = await self.collective_rpc(
        method="add_lora",
        timeout=None,
        args=(),
        kwargs={"lora_request": lora_request, "lora_scale": lora_scale},
    )
    if not isinstance(results, list):
        # Kept from the original: pass a non-list result straight through.
        return results
    # A plain ``all(results)`` would treat a per-stage list such as ``[False]``
    # as success because any non-empty list is truthy, so look inside it.
    return all(
        (bool(part) and all(part)) if isinstance(part, (list, tuple)) else bool(part)
        for part in results
    )

async def remove_lora(self, adapter_id: int) -> bool:
    """Remove a LoRA adapter from all stages.

    The request is broadcast to every stage through ``collective_rpc``.

    Args:
        adapter_id: The adapter ID to remove.

    Returns:
        True only if every stage reported success. A stage whose result is
        itself a list/tuple counts as successful only when that sequence is
        non-empty and all of its entries are truthy.
    """
    results = await self.collective_rpc(
        method="remove_lora",
        timeout=None,
        args=(),
        kwargs={"adapter_id": adapter_id},
    )
    if not isinstance(results, list):
        # Kept from the original: pass a non-list result straight through.
        return results
    # A plain ``all(results)`` would treat a per-stage list such as ``[False]``
    # as success because any non-empty list is truthy, so look inside it.
    return all(
        (bool(part) and all(part)) if isinstance(part, (list, tuple)) else bool(part)
        for part in results
    )

async def list_loras(self) -> list[int]:
    """Collect the LoRA adapter IDs registered on any stage.

    Each stage is queried through ``collective_rpc``. A stage may answer with
    a list of IDs, a single ID, or ``None``; all answers are merged.

    Returns:
        The distinct adapter IDs in ascending order.
    """
    per_stage = await self.collective_rpc(
        method="list_loras",
        timeout=None,
        args=(),
        kwargs={},
    )
    if not isinstance(per_stage, list):
        # Non-list answer: hand it back as-is, or an empty list if it is falsy.
        return per_stage or []

    adapter_ids: set[int] = set()
    for answer in per_stage:
        if answer is None:
            continue
        if isinstance(answer, list):
            adapter_ids.update(answer)
        else:
            adapter_ids.add(answer)
    return sorted(adapter_ids)

async def pin_lora(self, adapter_id: int) -> bool:
    """Prevent a LoRA adapter from being evicted on all stages.

    The request is broadcast to every stage through ``collective_rpc``.

    Args:
        adapter_id: The adapter ID to pin.

    Returns:
        True only if every stage reported success. A stage whose result is
        itself a list/tuple counts as successful only when that sequence is
        non-empty and all of its entries are truthy.
    """
    results = await self.collective_rpc(
        method="pin_lora",
        timeout=None,
        args=(),
        kwargs={"adapter_id": adapter_id},
    )
    if not isinstance(results, list):
        # Kept from the original: pass a non-list result straight through.
        return results
    # A plain ``all(results)`` would treat a per-stage list such as ``[False]``
    # as success because any non-empty list is truthy, so look inside it.
    return all(
        (bool(part) and all(part)) if isinstance(part, (list, tuple)) else bool(part)
        for part in results
    )

async def encode(
self,
Expand Down
Loading