From 7375b5d390ca553b422e6a759129c137c74f7d8d Mon Sep 17 00:00:00 2001 From: jader Date: Tue, 3 Mar 2026 13:45:44 +0800 Subject: [PATCH 01/14] [Refactor] Refactor Diffusion Scheduler/Executor Boundaries and Request State Flow Refactor diffusion runtime boundaries to separate scheduler state management from multiprocess IPC execution. Core goals: - Make Scheduler a pure request-state scheduler (waiting/running/finished) without owning IPC queues. - Make MultiprocDiffusionExecutor a pure IPC runtime (broadcast/result queues + worker lifecycle). - Let DiffusionEngine explicitly drive add_request -> schedule -> execute -> update_from_output. - Consolidate cross-API concurrency control into DiffusionEngine._rpc_lock. Main code changes: - scheduler.py: introduce request status/state output types and pure scheduling APIs; remove scheduler-side IPC ownership. - diffusion_engine.py: engine owns scheduler and _rpc_lock; refactor add_req_and_wait_for_response to scheduler-driven flow. - multiproc_executor.py: executor directly manages IPC queues and worker lifecycle; decouple from scheduler internals. - tests: add diffusion scheduler tests; rename/refactor multiproc concurrency test to engine-focused variant. Test plan: - pytest -m diffusion tests/diffusion/test_diffusion_scheduler.py - pytest -m diffusion tests/diffusion/test_multiproc_engine_concurrency.py Signed-off-by: jader --- tests/diffusion/test_diffusion_scheduler.py | 149 ++++++++++ ...y => test_multiproc_engine_concurrency.py} | 177 ++++++------ vllm_omni/diffusion/diffusion_engine.py | 67 ++++- .../diffusion/executor/multiproc_executor.py | 171 ++++++++---- vllm_omni/diffusion/scheduler.py | 259 +++++++++++++----- 5 files changed, 598 insertions(+), 225 deletions(-) create mode 100644 tests/diffusion/test_diffusion_scheduler.py rename tests/diffusion/{test_multiproc_executor_concurrency.py => test_multiproc_engine_concurrency.py} (72%) diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py new file mode 100644 index 00000000000..09d8f6c6de9 --- /dev/null +++ b/tests/diffusion/test_diffusion_scheduler.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import threading +from unittest.mock import Mock + +import pytest + +from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.diffusion_engine import DiffusionEngine +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.scheduler import DiffusionRequestStatus, Scheduler +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +pytestmark = [pytest.mark.diffusion] + + +def _make_scheduler() -> Scheduler: + scheduler = Scheduler() + scheduler.initialize(Mock()) + return scheduler + + +def _make_request(req_id: str) -> OmniDiffusionRequest: + return OmniDiffusionRequest( + prompts=[f"prompt_{req_id}"], + sampling_params=OmniDiffusionSamplingParams(num_inference_steps=1), + request_ids=[req_id], + ) + + +def test_single_request_success_lifecycle() -> None: + scheduler = _make_scheduler() + + req_id = scheduler.add_request(_make_request("a")) + assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.WAITING + + sched_output = scheduler.schedule() + assert len(sched_output.req_states) == 1 + assert sched_output.req_states[0].req_id == req_id + assert sched_output.num_running_reqs == 1 + assert sched_output.num_waiting_reqs == 0 + + finished = scheduler.update_from_output(sched_output, DiffusionOutput(output=None)) + assert finished == {req_id} + assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED + assert scheduler.has_requests() is False + + +def test_error_output_marks_finished_error() -> None: + scheduler = _make_scheduler() + req_id = scheduler.add_request(_make_request("err")) + + sched_output = scheduler.schedule() + finished = scheduler.update_from_output(sched_output, DiffusionOutput(error="worker failed")) + + assert finished == {req_id} + state = scheduler.get_request_state(req_id) + assert state.status == DiffusionRequestStatus.FINISHED_ERROR + assert state.error == "worker failed" + + +def test_empty_output_without_error_marks_completed() -> None: + scheduler = _make_scheduler() + req_id = scheduler.add_request(_make_request("empty")) + + sched_output = scheduler.schedule() + finished = scheduler.update_from_output(sched_output, DiffusionOutput(output=None, error=None)) + + assert finished == {req_id} + assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED + + +def test_fifo_single_request_scheduling() -> None: + scheduler = _make_scheduler() + req_id_a = scheduler.add_request(_make_request("a")) + req_id_b = scheduler.add_request(_make_request("b")) + + first = scheduler.schedule() + assert [s.req_id for s in first.req_states] == [req_id_a] + assert first.num_running_reqs == 1 + assert first.num_waiting_reqs == 1 + + # Request A is still running; scheduling again should not pull B. + second = scheduler.schedule() + assert [s.req_id for s in second.req_states] == [req_id_a] + assert second.num_running_reqs == 1 + assert second.num_waiting_reqs == 1 + + scheduler.update_from_output(first, DiffusionOutput(output=None)) + + third = scheduler.schedule() + assert [s.req_id for s in third.req_states] == [req_id_b] + assert third.num_running_reqs == 1 + assert third.num_waiting_reqs == 0 + + +def test_abort_request_for_waiting_and_running() -> None: + scheduler = _make_scheduler() + req_id_a = scheduler.add_request(_make_request("a")) + req_id_b = scheduler.add_request(_make_request("b")) + + # Abort waiting request. + assert scheduler.abort_request(req_id_b) is True + state_b = scheduler.get_request_state(req_id_b) + assert state_b.status == DiffusionRequestStatus.FINISHED_ABORTED + + # A should still run normally. + output_a = scheduler.schedule() + assert [s.req_id for s in output_a.req_states] == [req_id_a] + + # Abort running request. + assert scheduler.abort_request(req_id_a) is True + state_a = scheduler.get_request_state(req_id_a) + assert state_a.status == DiffusionRequestStatus.FINISHED_ABORTED + + assert scheduler.has_requests() is False + assert scheduler.schedule().req_states == [] + + +def test_has_requests_state_transition() -> None: + scheduler = _make_scheduler() + assert scheduler.has_requests() is False + + req_id = scheduler.add_request(_make_request("has")) + assert scheduler.has_requests() is True + + sched_output = scheduler.schedule() + assert scheduler.has_requests() is True + + scheduler.update_from_output(sched_output, DiffusionOutput(output=None)) + assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED + assert scheduler.has_requests() is False + + +def test_engine_add_req_and_wait_for_response_single_path() -> None: + engine = DiffusionEngine.__new__(DiffusionEngine) + engine.scheduler = _make_scheduler() + engine.executor = Mock() + engine._rpc_lock = threading.Lock() + + request = _make_request("engine") + expected = DiffusionOutput(output=None) + engine.executor.add_req.return_value = expected + + output = engine.add_req_and_wait_for_response(request) + + assert output is expected + engine.executor.add_req.assert_called_once_with(request) diff --git a/tests/diffusion/test_multiproc_executor_concurrency.py b/tests/diffusion/test_multiproc_engine_concurrency.py similarity index 72% rename from tests/diffusion/test_multiproc_executor_concurrency.py rename to tests/diffusion/test_multiproc_engine_concurrency.py index 76caa984465..90289727fcf 100644 --- a/tests/diffusion/test_multiproc_executor_concurrency.py +++ b/tests/diffusion/test_multiproc_engine_concurrency.py @@ -9,6 +9,7 @@ import torch from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.executor.multiproc_executor import MultiprocDiffusionExecutor from vllm_omni.diffusion.scheduler import Scheduler @@ -30,42 +31,43 @@ def _mock_request(tag: str) -> Mock: return req -def _make_scheduler(): - """Create a ``Scheduler`` whose *mq* / *result_mq* are backed by - plain ``queue.Queue`` objects (thread-safe, no real IPC). +def _make_executor(num_gpus: int = 1): + """Create a ``MultiprocDiffusionExecutor`` without launching workers. - Returns ``(scheduler, request_queue, result_queue)``. + Returns ``(executor, request_queue, result_queue)``. """ - sched = Scheduler() - sched.num_workers = 1 - sched._lock = threading.Lock() + od_cfg = Mock() + od_cfg.num_gpus = num_gpus + + with patch.object(MultiprocDiffusionExecutor, "_init_executor"): + executor = MultiprocDiffusionExecutor(od_cfg) req_q: queue.Queue = queue.Queue() res_q: queue.Queue = queue.Queue() - mock_mq = Mock() - mock_mq.enqueue = req_q.put + mock_broadcast_mq = Mock() + mock_broadcast_mq.enqueue = req_q.put mock_rmq = Mock() - mock_rmq.dequeue = lambda timeout=None: res_q.get(timeout=timeout if timeout else 10) - - sched.mq = mock_mq - sched.result_mq = mock_rmq - return sched, req_q, res_q + mock_rmq.dequeue = lambda timeout=None: res_q.get(timeout=timeout if timeout is not None else 10) - -def _make_executor(scheduler): - """Create a ``MultiprocDiffusionExecutor`` wired to *scheduler* - without launching real worker processes. - """ - od_cfg = Mock() - od_cfg.num_gpus = 1 - with patch.object(MultiprocDiffusionExecutor, "_init_executor"): - executor = MultiprocDiffusionExecutor(od_cfg) - executor.scheduler = scheduler + executor._broadcast_mq = mock_broadcast_mq + executor._result_mq = mock_rmq executor._closed = False executor._processes = [] - return executor + return executor, req_q, res_q + + +def _make_engine(num_gpus: int = 1): + """Create a lightweight ``DiffusionEngine`` wired to mocked executor.""" + executor, req_q, res_q = _make_executor(num_gpus) + engine = DiffusionEngine.__new__(DiffusionEngine) + sched = Scheduler() + sched.initialize(Mock()) + engine.scheduler = sched + engine.executor = executor + engine._rpc_lock = threading.Lock() + return engine, executor, req_q, res_q def _start_worker(req_q, res_q, count=2): @@ -91,8 +93,8 @@ def _run(): return t -def _inject_interleave(scheduler): - """Monkey-patch ``scheduler.mq.enqueue`` so that: +def _inject_interleave(executor): + """Monkey-patch ``executor._broadcast_mq.enqueue`` so that: * The thread named **thread_a** *blocks* after its enqueue until the thread named **thread_b** has finished entirely. @@ -102,7 +104,7 @@ def _inject_interleave(scheduler): """ a_enqueued = threading.Event() b_complete = threading.Event() - orig_enqueue = scheduler.mq.enqueue # points to req_q.put + orig_enqueue = executor._broadcast_mq.enqueue # points to req_q.put def _controlled(item): orig_enqueue(item) @@ -110,7 +112,7 @@ def _controlled(item): a_enqueued.set() # tell B: "A has enqueued" b_complete.wait(5) # block A until B finishes - scheduler.mq.enqueue = _controlled + executor._broadcast_mq.enqueue = _controlled return a_enqueued, b_complete @@ -118,21 +120,21 @@ def _controlled(item): class TestConcurrentAddReqBug: - """Two concurrent ``Scheduler.add_req()`` calls swap results.""" + """Two concurrent ``add_req_and_wait_for_response()`` calls swap results.""" def test_results_are_correctly_routed(self): - sched, req_q, res_q = _make_scheduler() - a_enqueued, b_complete = _inject_interleave(sched) + engine, executor, req_q, res_q = _make_engine() + a_enqueued, b_complete = _inject_interleave(executor) wt = _start_worker(req_q, res_q, count=2) results: dict[str, DiffusionOutput] = {} def _a(): - results["A"] = sched.add_req(_mock_request("A")) + results["A"] = engine.add_req_and_wait_for_response(_mock_request("A")) def _b(): a_enqueued.wait(5) # wait for A to enqueue - results["B"] = sched.add_req(_mock_request("B")) + results["B"] = engine.add_req_and_wait_for_response(_mock_request("B")) b_complete.set() # release A ta = threading.Thread(target=_a, name="thread_a") @@ -156,15 +158,14 @@ class TestConcurrentCollectiveRpcBug: """Two concurrent ``collective_rpc()`` calls swap results.""" def test_results_are_correctly_routed(self): - sched, req_q, res_q = _make_scheduler() - executor = _make_executor(sched) - a_enqueued, b_complete = _inject_interleave(sched) + engine, executor, req_q, res_q = _make_engine() + a_enqueued, b_complete = _inject_interleave(executor) wt = _start_worker(req_q, res_q, count=2) results: dict[str, object] = {} def _a(): - results["A"] = executor.collective_rpc( + results["A"] = engine.collective_rpc( "ping", args=("call_A",), unique_reply_rank=0, @@ -172,7 +173,7 @@ def _a(): def _b(): a_enqueued.wait(5) - results["B"] = executor.collective_rpc( + results["B"] = engine.collective_rpc( "ping", args=("call_B",), unique_reply_rank=0, @@ -198,19 +199,18 @@ class TestConcurrentAddReqVsCollectiveRpcBug: """``add_req`` and ``collective_rpc`` running concurrently swap results.""" def test_results_are_correctly_routed(self): - sched, req_q, res_q = _make_scheduler() - executor = _make_executor(sched) - a_enqueued, b_complete = _inject_interleave(sched) + engine, executor, req_q, res_q = _make_engine() + a_enqueued, b_complete = _inject_interleave(executor) wt = _start_worker(req_q, res_q, count=2) results: dict[str, object] = {} def _a(): # add_req path - results["A"] = sched.add_req(_mock_request("A")) + results["A"] = engine.add_req_and_wait_for_response(_mock_request("A")) def _b(): # collective_rpc path a_enqueued.wait(5) - results["B"] = executor.collective_rpc( + results["B"] = engine.collective_rpc( "ping", args=("call_B",), unique_reply_rank=0, @@ -241,31 +241,30 @@ class TestSerialOperations: """ def test_serial_add_req_returns_correct_result(self): - sched, req_q, res_q = _make_scheduler() + engine, _, req_q, res_q = _make_engine() wt = _start_worker(req_q, res_q, count=1) - result = sched.add_req(_mock_request("X")) + result = engine.add_req_and_wait_for_response(_mock_request("X")) wt.join(5) assert isinstance(result, DiffusionOutput) assert result.error == "result_for_X" def test_serial_add_req_multiple_sequential(self): - sched, req_q, res_q = _make_scheduler() + engine, _, req_q, res_q = _make_engine() wt = _start_worker(req_q, res_q, count=3) for tag in ("one", "two", "three"): - out = sched.add_req(_mock_request(tag)) + out = engine.add_req_and_wait_for_response(_mock_request(tag)) assert out.error == f"result_for_{tag}" wt.join(5) def test_serial_collective_rpc_single_rank(self): - sched, req_q, res_q = _make_scheduler() - executor = _make_executor(sched) + engine, _, req_q, res_q = _make_engine() wt = _start_worker(req_q, res_q, count=1) - result = executor.collective_rpc( + result = engine.collective_rpc( "ping", args=("Y",), unique_reply_rank=0, @@ -278,27 +277,24 @@ def test_serial_collective_rpc_all_ranks(self): """``collective_rpc`` without *unique_reply_rank* collects ``num_gpus`` responses. """ - sched, req_q, res_q = _make_scheduler() - executor = _make_executor(sched) - executor.od_config.num_gpus = 2 + engine, _, _, res_q = _make_engine(num_gpus=2) # Pre-populate two results (simulating two workers replying) res_q.put(_tagged_output("rank0")) res_q.put(_tagged_output("rank1")) - results = executor.collective_rpc("ping", args=("multi",)) + results = engine.collective_rpc("ping", args=("multi",)) assert len(results) == 2 assert results[0].error == "rank0" assert results[1].error == "rank1" def test_serial_add_req_then_collective_rpc(self): - sched, req_q, res_q = _make_scheduler() - executor = _make_executor(sched) + engine, _, req_q, res_q = _make_engine() wt = _start_worker(req_q, res_q, count=2) - gen_out = sched.add_req(_mock_request("gen")) - rpc_out = executor.collective_rpc( + gen_out = engine.add_req_and_wait_for_response(_mock_request("gen")) + rpc_out = engine.collective_rpc( "ping", args=("rpc",), unique_reply_rank=0, @@ -310,29 +306,30 @@ def test_serial_add_req_then_collective_rpc(self): def test_serial_add_req_error_propagation(self): """``add_req`` should raise when the worker reports an error.""" - sched, _, res_q = _make_scheduler() + engine, _, _, res_q = _make_engine() # Put an error response directly res_q.put({"status": "error", "error": "boom"}) - with pytest.raises(RuntimeError, match="worker error"): - sched.add_req(_mock_request("fail")) + out = engine.add_req_and_wait_for_response(_mock_request("fail")) + + assert isinstance(out, DiffusionOutput) + assert out.error is not None + assert "boom" in out.error def test_serial_collective_rpc_error_propagation(self): """``collective_rpc`` should raise when the worker reports an error.""" - sched, _, res_q = _make_scheduler() - executor = _make_executor(sched) + engine, _, _, res_q = _make_engine() res_q.put({"status": "error", "error": "kaboom"}) with pytest.raises(RuntimeError, match="kaboom"): - executor.collective_rpc("bad", unique_reply_rank=0) + engine.collective_rpc("bad", unique_reply_rank=0) def test_collective_rpc_closed_executor_raises(self): - sched, _, _ = _make_scheduler() - executor = _make_executor(sched) + engine, executor, _, _ = _make_engine() executor._closed = True with pytest.raises(RuntimeError, match="closed"): - executor.collective_rpc("anything") + engine.collective_rpc("anything") # ─────────── timeout regression: RPC must not block on a stalled lock ───── @@ -340,23 +337,22 @@ def test_collective_rpc_closed_executor_raises(self): class TestCollectiveRpcTimeoutWhileLockHeld: """``collective_rpc(timeout=...)`` must honour its timeout even when - another thread holds ``scheduler._lock`` indefinitely (e.g. a stalled + another thread holds ``engine._rpc_lock`` indefinitely (e.g. a stalled ``add_req`` waiting on an unresponsive worker). """ def test_rpc_times_out_when_lock_held_directly(self): """Simplest case: lock is manually held by another thread.""" - sched, req_q, res_q = _make_scheduler() - executor = _make_executor(sched) + engine, _, _, _ = _make_engine() stall_started = threading.Event() def _hold_lock(): - sched._lock.acquire() + engine._rpc_lock.acquire() stall_started.set() # Hold the lock far longer than the RPC timeout. threading.Event().wait(30) - sched._lock.release() + engine._rpc_lock.release() stall_thread = threading.Thread(target=_hold_lock, daemon=True) stall_thread.start() @@ -364,24 +360,23 @@ def _hold_lock(): # collective_rpc should raise TimeoutError, NOT block forever. with pytest.raises(TimeoutError): - executor.collective_rpc("health", timeout=0.5) + engine.collective_rpc("health", timeout=0.5) def test_rpc_times_out_when_add_req_stalled_on_worker(self): """Real-world scenario the bot flagged: - ``add_req`` holds ``_lock`` while blocked on ``result_mq.dequeue()`` - because the worker never replies. A concurrent - ``collective_rpc(timeout=...)`` must still time out instead of - hanging forever waiting for the lock. + ``add_req`` holds ``_rpc_lock`` while blocked on + ``executor._result_mq.dequeue()`` because the worker never replies. + A concurrent ``collective_rpc(timeout=...)`` must still time out + instead of hanging forever waiting for the lock. """ - sched, req_q, res_q = _make_scheduler() - executor = _make_executor(sched) + engine, executor, _, _ = _make_engine() add_req_blocked = threading.Event() # Patch dequeue: signal once entered, then block indefinitely # (simulates a worker that never sends a result). - orig_dequeue = sched.result_mq.dequeue + orig_dequeue = executor._result_mq.dequeue def _hanging_dequeue(timeout=None): add_req_blocked.set() @@ -389,13 +384,13 @@ def _hanging_dequeue(timeout=None): threading.Event().wait(30) return orig_dequeue(timeout=timeout) - sched.result_mq.dequeue = _hanging_dequeue + executor._result_mq.dequeue = _hanging_dequeue # Thread running add_req — acquires the lock, enqueues, then # blocks on dequeue forever (worker hang). def _stalled_add_req(): try: - sched.add_req(_mock_request("stalled")) + engine.add_req_and_wait_for_response(_mock_request("stalled")) except Exception: pass @@ -407,23 +402,19 @@ def _stalled_add_req(): # collective_rpc should time out at lock acquisition, not hang. with pytest.raises(TimeoutError): - executor.collective_rpc("health_check", timeout=0.5) + engine.collective_rpc("health_check", timeout=0.5) def test_rpc_without_timeout_still_waits_for_lock(self): """When no timeout is given, ``collective_rpc`` should still wait for the lock (blocking) — existing behaviour preserved. """ - sched, req_q, res_q = _make_scheduler() - executor = _make_executor(sched) - - lock_released = threading.Event() + engine, _, _, res_q = _make_engine() def _hold_and_release(): - sched._lock.acquire() + engine._rpc_lock.acquire() # Hold for a short time then release. threading.Event().wait(0.3) - sched._lock.release() - lock_released.set() + engine._rpc_lock.release() # Pre-populate a result so collective_rpc succeeds after lock. res_q.put(_tagged_output("ok")) @@ -431,8 +422,8 @@ def _hold_and_release(): t = threading.Thread(target=_hold_and_release, daemon=True) t.start() - # No timeout → should block until lock is released, then succeed. - result = executor.collective_rpc( + # No timeout -> should block until lock is released, then succeed. + result = engine.collective_rpc( "ping", args=("wait",), unique_reply_rank=0, diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index b80d96bfe3e..5f53c3d7dec 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +import threading import time from collections.abc import Iterable from typing import Any @@ -11,7 +12,7 @@ import torch from vllm.logger import init_logger -from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.executor.abstract import DiffusionExecutor from vllm_omni.diffusion.registry import ( DiffusionModelRegistry, @@ -19,6 +20,7 @@ get_diffusion_pre_process_func, ) from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.scheduler import Scheduler from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt from vllm_omni.outputs import OmniRequestOutput @@ -67,6 +69,9 @@ def __init__(self, od_config: OmniDiffusionConfig): executor_class = DiffusionExecutor.get_class(od_config) self.executor = executor_class(od_config) + self.scheduler = Scheduler() + self.scheduler.initialize(od_config) + self._rpc_lock = threading.Lock() try: self._dummy_run() @@ -260,8 +265,28 @@ def make_engine(config: OmniDiffusionConfig) -> "DiffusionEngine": """ return DiffusionEngine(config) - def add_req_and_wait_for_response(self, request: OmniDiffusionRequest): - return self.executor.add_req(request) + def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> DiffusionOutput: + with self._rpc_lock: + target_req_id = self.scheduler.add_request(request) + + while True: + sched_output = self.scheduler.schedule() + if not sched_output.req_states: + if not self.scheduler.has_requests(): + raise RuntimeError("Diffusion scheduler has no runnable requests.") + continue + + req_state = sched_output.req_states[0] + try: + output = self.executor.add_req(req_state.req) + except Exception as exc: + logger.error("Execution failed for diffusion request %s", req_state.req_id, exc_info=True) + output = DiffusionOutput(error=str(exc)) + + finished_req_ids = self.scheduler.update_from_output(sched_output, output) + if target_req_id in finished_req_ids: + self.scheduler.pop_request_state(target_req_id) + return output def start_profile(self, trace_filename: str | None = None) -> None: """ @@ -447,15 +472,37 @@ def collective_rpc( Single result if unique_reply_rank is provided, otherwise list of results """ assert isinstance(method, str), "Only string method names are supported for now" - return self.executor.collective_rpc( - method=method, - timeout=timeout, - args=args, - kwargs=kwargs, - unique_reply_rank=unique_reply_rank, - ) + + deadline = None if timeout is None else time.monotonic() + timeout + acquired = False + try: + if deadline is None: + self._rpc_lock.acquire() + acquired = True + else: + lock_timeout = max(0, deadline - time.monotonic()) + acquired = self._rpc_lock.acquire(timeout=lock_timeout) + if not acquired: + raise TimeoutError(f"RPC call to {method} timed out waiting for engine lock.") + + rpc_timeout = None if deadline is None else max(0, deadline - time.monotonic()) + if deadline is not None and rpc_timeout <= 0: + raise TimeoutError(f"RPC call to {method} timed out.") + + return self.executor.collective_rpc( + method=method, + timeout=rpc_timeout, + args=args, + kwargs=kwargs, + unique_reply_rank=unique_reply_rank, + ) + finally: + if acquired: + self._rpc_lock.release() def close(self) -> None: + if hasattr(self, "scheduler"): + self.scheduler.close() if hasattr(self, "executor"): self.executor.shutdown() diff --git a/vllm_omni/diffusion/executor/multiproc_executor.py b/vllm_omni/diffusion/executor/multiproc_executor.py index 24d0741a504..bb9e8d1c4d8 100644 --- a/vllm_omni/diffusion/executor/multiproc_executor.py +++ b/vllm_omni/diffusion/executor/multiproc_executor.py @@ -4,12 +4,14 @@ from dataclasses import dataclass from typing import Any +import zmq +from vllm.distributed.device_communicators.shm_broadcast import MessageQueue from vllm.logger import init_logger from vllm_omni.diffusion.data import SHUTDOWN_MESSAGE, DiffusionOutput from vllm_omni.diffusion.executor.abstract import DiffusionExecutor +from vllm_omni.diffusion.ipc import unpack_diffusion_output_shm from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.scheduler import Scheduler from vllm_omni.diffusion.worker import WorkerProc logger = init_logger(__name__) @@ -21,18 +23,30 @@ class BackgroundResources: Used as a finalizer for clean shutdown. """ - scheduler: Scheduler | None = None + broadcast_mq: MessageQueue | None = None + result_mq: MessageQueue | None = None + num_workers: int = 0 processes: list[mp.Process] | None = None def __call__(self): """Clean up background resources.""" - if self.scheduler is not None: + if self.broadcast_mq is not None: try: - for _ in range(self.scheduler.num_workers): - self.scheduler.mq.enqueue(SHUTDOWN_MESSAGE) - self.scheduler.close() + for _ in range(self.num_workers): + self.broadcast_mq.enqueue(SHUTDOWN_MESSAGE) except Exception as exc: logger.warning("Failed to send shutdown signal: %s", exc) + + for queue, label in ((self.broadcast_mq, "broadcast"), (self.result_mq, "result")): + if queue is None: + continue + try: + close_fn = getattr(queue, "close", None) + if callable(close_fn): + close_fn() + except Exception as exc: + logger.warning("Failed to close %s queue: %s", label, exc) + if self.processes: for proc in self.processes: if not proc.is_alive(): @@ -51,24 +65,42 @@ def _init_executor(self) -> None: self._processes: list[mp.Process] = [] self._closed = False - # Initialize scheduler - self.scheduler = Scheduler() - self.scheduler.initialize(self.od_config) - broadcast_handle = self.scheduler.get_broadcast_handle() + num_workers = self.od_config.num_gpus + self._broadcast_mq = self._init_broadcast_queue(num_workers) + broadcast_handle = self._broadcast_mq.export_handle() # Launch workers processes, result_handle = self._launch_workers(broadcast_handle) - - if result_handle is not None: - self.scheduler.initialize_result_queue(result_handle) - else: - logger.error("Failed to get result queue handle from workers") - + self._result_mq = self._init_result_queue(result_handle) self._processes = processes - self.resources = BackgroundResources(scheduler=self.scheduler, processes=self._processes) + self.resources = BackgroundResources( + broadcast_mq=self._broadcast_mq, + result_mq=self._result_mq, + num_workers=num_workers, + processes=self._processes, + ) self._finalizer = weakref.finalize(self, self.resources) + def _init_broadcast_queue(self, num_workers: int) -> MessageQueue: + return MessageQueue( + n_reader=num_workers, + n_local_reader=num_workers, + local_reader_ranks=list(range(num_workers)), + ) + + def _init_result_queue(self, result_handle) -> MessageQueue | None: + if result_handle is None: + logger.error("Failed to get result queue handle from workers") + return None + return MessageQueue.create_from_handle(result_handle, 0) + + def _ensure_open(self) -> None: + if self._closed: + raise RuntimeError("DiffusionExecutor is closed.") + if self._result_mq is None: + raise RuntimeError("Result queue not initialized") + def _launch_workers(self, broadcast_handle): od_config = self.od_config logger.info("Starting server...") @@ -134,7 +166,43 @@ def _launch_workers(self, broadcast_handle): return processes, result_handle def add_req(self, request: OmniDiffusionRequest) -> DiffusionOutput: - return self.scheduler.add_req(request) + self._ensure_open() + deadline = None + rpc_request = { + "type": "rpc", + "method": "generate", + "args": (request,), + "kwargs": {}, + "output_rank": 0, + "exec_all_ranks": True, + } + + try: + self._broadcast_mq.enqueue(rpc_request) + + try: + response = self._result_mq.dequeue(timeout=deadline) + except zmq.error.Again as exc: + raise TimeoutError("Generate call timed out.") from exc + except TimeoutError as exc: + raise TimeoutError("Generate call timed out.") from exc + + try: + unpack_diffusion_output_shm(response) + except Exception as e: + logger.warning("SHM unpack failed (data may already be inline): %s", e) + + if isinstance(response, dict) and response.get("status") == "error": + raise RuntimeError( + f"Worker failed with error '{response.get('error')}', " + "please check the stack trace above for the root cause" + ) + if not isinstance(response, DiffusionOutput): + raise RuntimeError(f"Unexpected response type for generate: {type(response)!r}") + return response + except Exception as e: + logger.error(f"Generate call failed: {e}") + raise def collective_rpc( self, @@ -144,8 +212,7 @@ def collective_rpc( kwargs: dict | None = None, unique_reply_rank: int | None = None, ) -> Any: - if self._closed: - raise RuntimeError("DiffusionExecutor is closed.") + self._ensure_open() deadline = None if timeout is None else time.monotonic() + timeout kwargs = kwargs or {} @@ -160,44 +227,32 @@ def collective_rpc( } try: - # Acquire lock with timeout awareness so that a stalled add_req - # (holding the lock while blocked on dequeue) does not prevent - # this RPC from honouring its own timeout. - lock_timeout = None if deadline is None else max(0, deadline - time.monotonic()) - acquired = self.scheduler._lock.acquire(timeout=lock_timeout if lock_timeout is not None else -1) - if not acquired: - raise TimeoutError(f"RPC call to {method} timed out waiting for scheduler lock.") - try: - # Broadcast RPC request to all workers via unified message queue - self.scheduler.mq.enqueue(rpc_request) - - # Determine which workers we expect responses from - num_responses = 1 if unique_reply_rank is not None else self.od_config.num_gpus - - responses = [] - for _ in range(num_responses): - dequeue_timeout = None if deadline is None else max(0, deadline - time.monotonic()) - try: - if self.scheduler.result_mq is None: - raise RuntimeError("Result queue not initialized") - - response = self.scheduler.result_mq.dequeue(timeout=dequeue_timeout) - - # Check if response indicates an error - if isinstance(response, dict) and response.get("status") == "error": - raise RuntimeError( - f"Worker failed with error '{response.get('error')}', " - "please check the stack trace above for the root cause" - ) - - responses.append(response) - except TimeoutError as e: - raise TimeoutError(f"RPC call to {method} timed out.") from e - - return responses[0] if unique_reply_rank is not None else responses - finally: - self.scheduler._lock.release() - + # Broadcast RPC request to all workers via unified message queue + self._broadcast_mq.enqueue(rpc_request) + + # Determine which workers we expect responses from + num_responses = 1 if unique_reply_rank is not None else self.od_config.num_gpus + + responses = [] + for _ in range(num_responses): + dequeue_timeout = None if deadline is None else max(0, deadline - time.monotonic()) + try: + response = self._result_mq.dequeue(timeout=dequeue_timeout) + + # Check if response indicates an error + if isinstance(response, dict) and response.get("status") == "error": + raise RuntimeError( + f"Worker failed with error '{response.get('error')}', " + "please check the stack trace above for the root cause" + ) + + responses.append(response) + except zmq.error.Again as e: + raise TimeoutError(f"RPC call to {method} timed out.") from e + except TimeoutError as e: + raise TimeoutError(f"RPC call to {method} timed out.") from e + + return responses[0] if unique_reply_rank is not None else responses except Exception as e: logger.error(f"RPC call failed: {e}") raise diff --git a/vllm_omni/diffusion/scheduler.py b/vllm_omni/diffusion/scheduler.py index 5f1d01c1282..6e3b21ccbe5 100644 --- a/vllm_omni/diffusion/scheduler.py +++ b/vllm_omni/diffusion/scheduler.py @@ -1,86 +1,217 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import threading +from __future__ import annotations + +import enum +import uuid +from collections import deque +from dataclasses import dataclass -import zmq -from vllm.distributed.device_communicators.shm_broadcast import MessageQueue from vllm.logger import init_logger from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.ipc import unpack_diffusion_output_shm from vllm_omni.diffusion.request import OmniDiffusionRequest logger = init_logger(__name__) -class Scheduler: - def initialize(self, od_config: OmniDiffusionConfig): - existing_mq = getattr(self, "mq", None) - if existing_mq is not None and not existing_mq.closed: - logger.warning("SyncSchedulerClient is already initialized. Re-initializing.") - self.close() +class DiffusionRequestStatus(enum.IntEnum): + """Request status tracked by diffusion scheduler.""" + + WAITING = enum.auto() + RUNNING = enum.auto() + PREEMPTED = enum.auto() + + # if any status is after FINISHED_COMPLETED, it is considered finished + FINISHED_COMPLETED = enum.auto() + FINISHED_ABORTED = enum.auto() + FINISHED_ERROR = enum.auto() + + @staticmethod + def is_finished(status: DiffusionRequestStatus) -> bool: + return status >= DiffusionRequestStatus.FINISHED_COMPLETED - self.num_workers = od_config.num_gpus - self.od_config = od_config - self._lock = threading.Lock() - - # Initialize single MessageQueue for all message types (generation & RPC) - # Assuming all readers are local for now as per current launch_engine implementation - self.mq = MessageQueue( - n_reader=self.num_workers, - n_local_reader=self.num_workers, - local_reader_ranks=list(range(self.num_workers)), - ) - self.result_mq = None +@dataclass +class DiffusionRequestState: + """Scheduler-owned state for one queued OmniDiffusionRequest.""" - def initialize_result_queue(self, handle): - # Initialize MessageQueue for receiving results - # We act as rank 0 reader for this queue - self.result_mq = MessageQueue.create_from_handle(handle, rank=0) - logger.info("SyncScheduler initialized result MessageQueue") + req_id: str # unique request ID generated by scheduler + req: OmniDiffusionRequest + status: DiffusionRequestStatus = DiffusionRequestStatus.WAITING + error: str | None = None - def get_broadcast_handle(self): - return self.mq.export_handle() - def add_req(self, request: OmniDiffusionRequest) -> DiffusionOutput: - """Sends a request to the scheduler and waits for the response.""" - with self._lock: - try: - # Prepare RPC request for generation - rpc_request = { - "type": "rpc", - "method": "generate", - "args": (request,), - "kwargs": {}, - "output_rank": 0, - "exec_all_ranks": True, - } +@dataclass +class DiffusionSchedulerOutput: + """Output of a single scheduling cycle.""" - # Broadcast RPC request to all workers - self.mq.enqueue(rpc_request) + step_id: int + req_states: list[DiffusionRequestState] + finished_req_ids: set[str] + num_running_reqs: int + num_waiting_reqs: int - # Wait for result from Rank 0 (or whoever sends it) - if self.result_mq is None: - raise RuntimeError("Result queue not initialized") - output = self.result_mq.dequeue() +class Scheduler: + """Diffusion scheduler with vLLM-style waiting/running queues. + + NOTE: Currently, each OmniDiffusionRequest is already pre-batched upstream. + Scheduler only handles request-level state transitions and never merges + multiple requests into a new batch. + """ + + def __init__(self) -> None: + self.od_config: OmniDiffusionConfig | None = None + self._request_states: dict[str, DiffusionRequestState] = {} + self._step_id: int = 0 + + self._waiting: deque[str] = deque() + self._running: list[str] = [] + self._finished_req_ids: set[str] = set() + + def initialize(self, od_config: OmniDiffusionConfig) -> None: + self.od_config = od_config + self._request_states.clear() + self._step_id = 0 + + self._waiting.clear() + self._running.clear() + self._finished_req_ids.clear() + + def add_request(self, request: OmniDiffusionRequest) -> str: + req_id = self._make_req_id(request) + state = DiffusionRequestState(req_id=req_id, req=request) + self._request_states[req_id] = state + self._waiting.append(req_id) + logger.debug("Scheduler add_request: %s (waiting=%d)", req_id, len(self._waiting)) + return req_id + + def schedule(self) -> DiffusionSchedulerOutput: + # Single-request scheduling: do not build multiple request batches. + if not self._running and self._waiting: + req_id = self._waiting.popleft() + state = self._request_states.get(req_id) + if state is not None: + state.status = DiffusionRequestStatus.RUNNING + self._running.append(req_id) + + running_states: list[DiffusionRequestState] = [] + for req_id in self._running: + state = self._request_states.get(req_id) + if state is not None: + running_states.append(state) + + scheduler_output = DiffusionSchedulerOutput( + step_id=self._step_id, + req_states=running_states, + finished_req_ids=set(self._finished_req_ids), + num_running_reqs=len(self._running), + num_waiting_reqs=len(self._waiting), + ) + # update after schedule + self._step_id += 1 + self._finished_req_ids.clear() + return scheduler_output + + def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: DiffusionOutput) -> set[str]: + scheduled_req_ids = {state.req_id for state in sched_output.req_states} + if not scheduled_req_ids: + return set() + + completed_req_ids: set[str] = set() + if output.error: + for req_id in scheduled_req_ids: + state = self._request_states.get(req_id) + if state is None: + continue + state.status = DiffusionRequestStatus.FINISHED_ERROR + state.error = output.error + completed_req_ids.add(req_id) + else: + for req_id in scheduled_req_ids: + state = self._request_states.get(req_id) + if state is None: + continue + state.status = DiffusionRequestStatus.FINISHED_COMPLETED + state.error = None + completed_req_ids.add(req_id) + + if completed_req_ids: + self._running = [req_id for req_id in self._running if req_id not in completed_req_ids] + for req_id in completed_req_ids: try: - unpack_diffusion_output_shm(output) - except Exception as e: - logger.warning("SHM unpack failed (data may already be inline): %s", e) - - # {"status": "error", "error": str(e)} - if isinstance(output, dict) and output.get("status") == "error": - raise RuntimeError("worker error") - return output - except zmq.error.Again: - logger.error("Timeout waiting for response from scheduler.") - raise TimeoutError("Scheduler did not respond in time.") - - def close(self): - """Closes the socket and terminates the context.""" - self.mq = None - self.result_mq = None + self._waiting.remove(req_id) + except ValueError: + pass + self._finished_req_ids |= completed_req_ids + + return completed_req_ids + + def abort_request(self, req_id: str) -> bool: + if req_id not in self._request_states: + return False + self.finish_request(req_id, DiffusionRequestStatus.FINISHED_ABORTED) + self._finished_req_ids.add(req_id) + return True + + def has_requests(self) -> bool: + return bool(self._waiting or self._running) + + def get_request_state(self, req_id: str) -> DiffusionRequestState | None: + return self._request_states.get(req_id) + + def pop_request_state(self, req_id: str) -> DiffusionRequestState | None: + return self._request_states.pop(req_id, None) + + def preempt_request(self, req_id: str) -> bool: + if req_id not in self._request_states: + return False + if req_id in self._running: + self._running.remove(req_id) + self._waiting.appendleft(req_id) + self._request_states[req_id].status = DiffusionRequestStatus.PREEMPTED + return True + return False + + def finish_request(self, req_id: str, status: DiffusionRequestStatus) -> None: + assert DiffusionRequestStatus.is_finished(status) + state = self._request_states.get(req_id) + if state is None: + return + + state.status = status + if req_id in self._running: + self._running.remove(req_id) + try: + self._waiting.remove(req_id) + except ValueError: + pass + + def close(self) -> None: + self._request_states.clear() + self._waiting.clear() + self._running.clear() + self._finished_req_ids.clear() + + def _make_req_id(self, request: OmniDiffusionRequest) -> str: + """ + Generate a unique request ID for the given request. + If the request already has request IDs, use the first one as the base. + + NOTE: OmniDiffusionRequest already contain multiple prompts/outputs + as a pre-batched request object. + """ + if request.request_ids: + base = request.request_ids[0] + else: + base = f"req_{uuid.uuid4().hex[:8]}" + + req_id = base + suffix = 1 + while req_id in self._request_states: + req_id = f"{base}#{suffix}" + suffix += 1 + return req_id From 3c6509984c390d1627ca2bd96c39fb177356c0f0 Mon Sep 17 00:00:00 2001 From: jader Date: Tue, 3 Mar 2026 20:22:12 +0800 Subject: [PATCH 02/14] [Bugfix] Handle output error in DiffusionEngine's dummy run Signed-off-by: jader --- tests/diffusion/test_diffusion_scheduler.py | 10 ++++++++++ vllm_omni/diffusion/diffusion_engine.py | 4 +++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py index 09d8f6c6de9..ef88f99a587 100644 --- a/tests/diffusion/test_diffusion_scheduler.py +++ b/tests/diffusion/test_diffusion_scheduler.py @@ -147,3 +147,13 @@ def test_engine_add_req_and_wait_for_response_single_path() -> None: assert output is expected engine.executor.add_req.assert_called_once_with(request) + + +def test_engine_dummy_run_raises_on_output_error() -> None: + engine = DiffusionEngine.__new__(DiffusionEngine) + engine.od_config = Mock(model_class_name="mock_model") + engine.pre_process_func = None + engine.add_req_and_wait_for_response = Mock(return_value=DiffusionOutput(error="boom")) + + with pytest.raises(RuntimeError, match="Dummy run failed: boom"): + engine._dummy_run() diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 5f53c3d7dec..fafcd088a44 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -449,7 +449,9 @@ def _dummy_run(self): ) logger.info("dummy run to warm up the model") request = self.pre_process_func(req) if self.pre_process_func is not None else req - self.add_req_and_wait_for_response(request) + output = self.add_req_and_wait_for_response(request) + if output.error: + raise RuntimeError(f"Dummy run failed: {output.error}") def collective_rpc( self, From 0d5c838eb08bff629e9d0d3cd433198ceac8f43e Mon Sep 17 00:00:00 2001 From: jader Date: Mon, 9 Mar 2026 06:28:53 +0000 Subject: [PATCH 03/14] refactor: remove dead timeout handling from add_req Signed-off-by: jader --- vllm_omni/diffusion/executor/multiproc_executor.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/vllm_omni/diffusion/executor/multiproc_executor.py b/vllm_omni/diffusion/executor/multiproc_executor.py index bb9e8d1c4d8..b2e2c2ef9d9 100644 --- a/vllm_omni/diffusion/executor/multiproc_executor.py +++ b/vllm_omni/diffusion/executor/multiproc_executor.py @@ -167,7 +167,6 @@ def _launch_workers(self, broadcast_handle): def add_req(self, request: OmniDiffusionRequest) -> DiffusionOutput: self._ensure_open() - deadline = None rpc_request = { "type": "rpc", "method": "generate", @@ -179,13 +178,7 @@ def add_req(self, request: OmniDiffusionRequest) -> DiffusionOutput: try: self._broadcast_mq.enqueue(rpc_request) - - try: - response = self._result_mq.dequeue(timeout=deadline) - except zmq.error.Again as exc: - raise TimeoutError("Generate call timed out.") from exc - except TimeoutError as exc: - raise TimeoutError("Generate call timed out.") from exc + response = self._result_mq.dequeue() try: unpack_diffusion_output_shm(response) From 7b3fb6b2a3d3b32c5911141687a2e3aa6873a960 Mon Sep 17 00:00:00 2001 From: jader Date: Tue, 10 Mar 2026 06:46:24 +0000 Subject: [PATCH 04/14] test: update pytestmark to include core_model and cpu for diffusion tests Signed-off-by: jader --- tests/diffusion/test_diffusion_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py index ef88f99a587..2565ab9e360 100644 --- a/tests/diffusion/test_diffusion_scheduler.py +++ b/tests/diffusion/test_diffusion_scheduler.py @@ -12,7 +12,7 @@ from vllm_omni.diffusion.scheduler import DiffusionRequestStatus, Scheduler from vllm_omni.inputs.data import OmniDiffusionSamplingParams -pytestmark = [pytest.mark.diffusion] +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] def _make_scheduler() -> Scheduler: From eb76c3d02b9a14760d3a7c4d7180fcc11fff2ef3 Mon Sep 17 00:00:00 2001 From: jader Date: Tue, 10 Mar 2026 14:45:49 +0000 Subject: [PATCH 05/14] refactor: rename DiffusionRequestState.req_id to sched_req_id Signed-off-by: jader --- vllm_omni/diffusion/diffusion_engine.py | 12 ++- vllm_omni/diffusion/scheduler.py | 124 ++++++++++++++---------- 2 files changed, 80 insertions(+), 56 deletions(-) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index fafcd088a44..1189de43eae 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -267,7 +267,7 @@ def make_engine(config: OmniDiffusionConfig) -> "DiffusionEngine": def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> DiffusionOutput: with self._rpc_lock: - target_req_id = self.scheduler.add_request(request) + target_sched_req_id = self.scheduler.add_request(request) while True: sched_output = self.scheduler.schedule() @@ -280,12 +280,16 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus try: output = self.executor.add_req(req_state.req) except Exception as exc: - logger.error("Execution failed for diffusion request %s", req_state.req_id, exc_info=True) + logger.error( + "Execution failed for diffusion request %s", + req_state.sched_req_id, + exc_info=True, + ) output = DiffusionOutput(error=str(exc)) finished_req_ids = self.scheduler.update_from_output(sched_output, output) - if target_req_id in finished_req_ids: - self.scheduler.pop_request_state(target_req_id) + if target_sched_req_id in finished_req_ids: + self.scheduler.pop_request_state(target_sched_req_id) return output def start_profile(self, trace_filename: str | None = None) -> None: diff --git a/vllm_omni/diffusion/scheduler.py b/vllm_omni/diffusion/scheduler.py index 6e3b21ccbe5..d5b1de152ac 100644 --- a/vllm_omni/diffusion/scheduler.py +++ b/vllm_omni/diffusion/scheduler.py @@ -37,7 +37,7 @@ def is_finished(status: DiffusionRequestStatus) -> bool: class DiffusionRequestState: """Scheduler-owned state for one queued OmniDiffusionRequest.""" - req_id: str # unique request ID generated by scheduler + sched_req_id: str # unique scheduler request ID generated by scheduler req: OmniDiffusionRequest status: DiffusionRequestStatus = DiffusionRequestStatus.WAITING error: str | None = None @@ -81,25 +81,29 @@ def initialize(self, od_config: OmniDiffusionConfig) -> None: self._finished_req_ids.clear() def add_request(self, request: OmniDiffusionRequest) -> str: - req_id = self._make_req_id(request) - state = DiffusionRequestState(req_id=req_id, req=request) - self._request_states[req_id] = state - self._waiting.append(req_id) - logger.debug("Scheduler add_request: %s (waiting=%d)", req_id, len(self._waiting)) - return req_id + sched_req_id = self._make_sched_req_id(request) + state = DiffusionRequestState(sched_req_id=sched_req_id, req=request) + self._request_states[sched_req_id] = state + self._waiting.append(sched_req_id) + logger.debug( + "Scheduler add_request: %s (waiting=%d)", + sched_req_id, + len(self._waiting), + ) + return sched_req_id def schedule(self) -> DiffusionSchedulerOutput: # Single-request scheduling: do not build multiple request batches. if not self._running and self._waiting: - req_id = self._waiting.popleft() - state = self._request_states.get(req_id) + sched_req_id = self._waiting.popleft() + state = self._request_states.get(sched_req_id) if state is not None: state.status = DiffusionRequestStatus.RUNNING - self._running.append(req_id) + self._running.append(sched_req_id) running_states: list[DiffusionRequestState] = [] - for req_id in self._running: - state = self._request_states.get(req_id) + for sched_req_id in self._running: + state = self._request_states.get(sched_req_id) if state is not None: running_states.append(state) @@ -116,77 +120,93 @@ def schedule(self) -> DiffusionSchedulerOutput: self._finished_req_ids.clear() return scheduler_output - def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: DiffusionOutput) -> set[str]: - scheduled_req_ids = {state.req_id for state in sched_output.req_states} - if not scheduled_req_ids: + def update_from_output( + self, + sched_output: DiffusionSchedulerOutput, + output: DiffusionOutput, + ) -> set[str]: + scheduled_sched_req_ids = {state.sched_req_id for state in sched_output.req_states} + if not scheduled_sched_req_ids: return set() - completed_req_ids: set[str] = set() + completed_sched_req_ids: set[str] = set() if output.error: - for req_id in scheduled_req_ids: - state = self._request_states.get(req_id) + for sched_req_id in scheduled_sched_req_ids: + state = self._request_states.get(sched_req_id) if state is None: continue state.status = DiffusionRequestStatus.FINISHED_ERROR state.error = output.error - completed_req_ids.add(req_id) + completed_sched_req_ids.add(sched_req_id) else: - for req_id in scheduled_req_ids: - state = self._request_states.get(req_id) + for sched_req_id in scheduled_sched_req_ids: + state = self._request_states.get(sched_req_id) if state is None: continue state.status = DiffusionRequestStatus.FINISHED_COMPLETED state.error = None - completed_req_ids.add(req_id) + completed_sched_req_ids.add(sched_req_id) - if completed_req_ids: - self._running = [req_id for req_id in self._running if req_id not in completed_req_ids] - for req_id in completed_req_ids: + if completed_sched_req_ids: + self._running = [ + sched_req_id for sched_req_id in self._running if sched_req_id not in completed_sched_req_ids + ] + for sched_req_id in completed_sched_req_ids: try: - self._waiting.remove(req_id) + self._waiting.remove(sched_req_id) except ValueError: pass - self._finished_req_ids |= completed_req_ids + self._finished_req_ids |= completed_sched_req_ids - return completed_req_ids + return completed_sched_req_ids - def abort_request(self, req_id: str) -> bool: - if req_id not in self._request_states: + def abort_request(self, sched_req_id: str) -> bool: + if sched_req_id not in self._request_states: return False - self.finish_request(req_id, DiffusionRequestStatus.FINISHED_ABORTED) - self._finished_req_ids.add(req_id) + self.finish_request(sched_req_id, DiffusionRequestStatus.FINISHED_ABORTED) + self._finished_req_ids.add(sched_req_id) return True def has_requests(self) -> bool: return bool(self._waiting or self._running) - def get_request_state(self, req_id: str) -> DiffusionRequestState | None: - return self._request_states.get(req_id) + def get_request_state( + self, + sched_req_id: str, + ) -> DiffusionRequestState | None: + return self._request_states.get(sched_req_id) - def pop_request_state(self, req_id: str) -> DiffusionRequestState | None: - return self._request_states.pop(req_id, None) + def pop_request_state( + self, + sched_req_id: str, + ) -> DiffusionRequestState | None: + return self._request_states.pop(sched_req_id, None) - def preempt_request(self, req_id: str) -> bool: - if req_id not in self._request_states: + def preempt_request(self, sched_req_id: str) -> bool: + if sched_req_id not in self._request_states: return False - if req_id in self._running: - self._running.remove(req_id) - self._waiting.appendleft(req_id) - self._request_states[req_id].status = DiffusionRequestStatus.PREEMPTED + if sched_req_id in self._running: + self._running.remove(sched_req_id) + self._waiting.appendleft(sched_req_id) + self._request_states[sched_req_id].status = DiffusionRequestStatus.PREEMPTED return True return False - def finish_request(self, req_id: str, status: DiffusionRequestStatus) -> None: + def finish_request( + self, + sched_req_id: str, + status: DiffusionRequestStatus, + ) -> None: assert DiffusionRequestStatus.is_finished(status) - state = self._request_states.get(req_id) + state = self._request_states.get(sched_req_id) if state is None: return state.status = status - if req_id in self._running: - self._running.remove(req_id) + if sched_req_id in self._running: + self._running.remove(sched_req_id) try: - self._waiting.remove(req_id) + self._waiting.remove(sched_req_id) except ValueError: pass @@ -196,7 +216,7 @@ def close(self) -> None: self._running.clear() self._finished_req_ids.clear() - def _make_req_id(self, request: OmniDiffusionRequest) -> str: + def _make_sched_req_id(self, request: OmniDiffusionRequest) -> str: """ Generate a unique request ID for the given request. If the request already has request IDs, use the first one as the base. @@ -209,9 +229,9 @@ def _make_req_id(self, request: OmniDiffusionRequest) -> str: else: base = f"req_{uuid.uuid4().hex[:8]}" - req_id = base + sched_req_id = base suffix = 1 - while req_id in self._request_states: - req_id = f"{base}#{suffix}" + while sched_req_id in self._request_states: + sched_req_id = f"{base}#{suffix}" suffix += 1 - return req_id + return sched_req_id From b29219a6ca546c98422b3873b906229d24708b80 Mon Sep 17 00:00:00 2001 From: jader Date: Wed, 18 Mar 2026 11:39:32 +0000 Subject: [PATCH 06/14] refactor(diffusion): update design for future feature Signed-off-by: jader --- tests/diffusion/test_diffusion_scheduler.py | 362 ++++++++++++------ .../test_multiproc_engine_concurrency.py | 4 +- vllm_omni/diffusion/diffusion_engine.py | 48 +-- vllm_omni/diffusion/sched/__init__.py | 25 ++ vllm_omni/diffusion/sched/base_scheduler.py | 141 +++++++ vllm_omni/diffusion/sched/interface.py | 130 ++++++- .../diffusion/sched/request_scheduler.py | 104 +++++ vllm_omni/diffusion/scheduler.py | 237 ------------ 8 files changed, 666 insertions(+), 385 deletions(-) create mode 100644 vllm_omni/diffusion/sched/__init__.py create mode 100644 vllm_omni/diffusion/sched/base_scheduler.py create mode 100644 vllm_omni/diffusion/sched/request_scheduler.py delete mode 100644 vllm_omni/diffusion/scheduler.py diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py index 2565ab9e360..db61c8f65d6 100644 --- a/tests/diffusion/test_diffusion_scheduler.py +++ b/tests/diffusion/test_diffusion_scheduler.py @@ -2,23 +2,23 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import threading -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest from vllm_omni.diffusion.data import DiffusionOutput from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.scheduler import DiffusionRequestStatus, Scheduler +from vllm_omni.diffusion.sched import ( + DiffusionRequestStatus, + RequestScheduler, + Scheduler, + SchedulerInterface, +) +from vllm_omni.diffusion.sched.interface import CachedRequestData, NewRequestData from vllm_omni.inputs.data import OmniDiffusionSamplingParams -pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] - - -def _make_scheduler() -> Scheduler: - scheduler = Scheduler() - scheduler.initialize(Mock()) - return scheduler +pytestmark = [pytest.mark.diffusion] def _make_request(req_id: str) -> OmniDiffusionRequest: @@ -29,131 +29,271 @@ def _make_request(req_id: str) -> OmniDiffusionRequest: ) -def test_single_request_success_lifecycle() -> None: - scheduler = _make_scheduler() - - req_id = scheduler.add_request(_make_request("a")) - assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.WAITING - - sched_output = scheduler.schedule() - assert len(sched_output.req_states) == 1 - assert sched_output.req_states[0].req_id == req_id - assert sched_output.num_running_reqs == 1 - assert sched_output.num_waiting_reqs == 0 - - finished = scheduler.update_from_output(sched_output, DiffusionOutput(output=None)) - assert finished == {req_id} - assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED - assert scheduler.has_requests() is False - - -def test_error_output_marks_finished_error() -> None: - scheduler = _make_scheduler() - req_id = scheduler.add_request(_make_request("err")) - - sched_output = scheduler.schedule() - finished = scheduler.update_from_output(sched_output, DiffusionOutput(error="worker failed")) - - assert finished == {req_id} - state = scheduler.get_request_state(req_id) - assert state.status == DiffusionRequestStatus.FINISHED_ERROR - assert state.error == "worker failed" - - -def test_empty_output_without_error_marks_completed() -> None: - scheduler = _make_scheduler() - req_id = scheduler.add_request(_make_request("empty")) - - sched_output = scheduler.schedule() - finished = scheduler.update_from_output(sched_output, DiffusionOutput(output=None, error=None)) +def _make_request_output(req_id: str, *, error: str | None = None) -> DiffusionOutput: + del req_id + return DiffusionOutput(output=None, error=error) - assert finished == {req_id} - assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED +def _new_ids(sched_output) -> list[str]: + return [req.sched_req_id for req in sched_output.scheduled_new_reqs] -def test_fifo_single_request_scheduling() -> None: - scheduler = _make_scheduler() - req_id_a = scheduler.add_request(_make_request("a")) - req_id_b = scheduler.add_request(_make_request("b")) - first = scheduler.schedule() - assert [s.req_id for s in first.req_states] == [req_id_a] - assert first.num_running_reqs == 1 - assert first.num_waiting_reqs == 1 +def _cached_ids(sched_output) -> list[str]: + return list(sched_output.scheduled_cached_reqs.sched_req_ids) + - # Request A is still running; scheduling again should not pull B. - second = scheduler.schedule() - assert [s.req_id for s in second.req_states] == [req_id_a] - assert second.num_running_reqs == 1 - assert second.num_waiting_reqs == 1 +class _StubScheduler(SchedulerInterface): + def __init__(self, request: OmniDiffusionRequest, output: DiffusionOutput) -> None: + self._request = request + self._output = output + self.initialized_with = None + self._sched_req_id = request.request_ids[0] + self._state = None + self._scheduled = False - scheduler.update_from_output(first, DiffusionOutput(output=None)) + def initialize(self, od_config) -> None: + self.initialized_with = od_config - third = scheduler.schedule() - assert [s.req_id for s in third.req_states] == [req_id_b] - assert third.num_running_reqs == 1 - assert third.num_waiting_reqs == 0 + def add_request(self, request: OmniDiffusionRequest) -> str: + assert request is self._request + self._state = Mock(sched_req_id=self._sched_req_id, req=request) + return self._sched_req_id + + def schedule(self): + if self._scheduled or self._state is None: + return Mock( + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), + scheduled_req_ids=[], + is_empty=True, + ) + self._scheduled = True + return Mock( + scheduled_new_reqs=[NewRequestData.from_state(self._state)], + scheduled_cached_reqs=CachedRequestData.make_empty(), + scheduled_req_ids=[self._state.sched_req_id], + is_empty=False, + ) + + def update_from_output(self, sched_output, output: DiffusionOutput) -> set[str]: + del sched_output + assert output is self._output + return {self._sched_req_id} + + def has_requests(self) -> bool: + return not self._scheduled + + def get_request_state(self, sched_req_id: str): + del sched_req_id + return self._state + + def get_sched_req_id(self, request_id: str) -> str | None: + if request_id in self._request.request_ids: + return self._sched_req_id + return None + + def pop_request_state(self, sched_req_id: str): + del sched_req_id + return self._state + + def preempt_request(self, sched_req_id: str) -> bool: + del sched_req_id + return False + + def finish_requests(self, sched_req_ids, status) -> None: + del sched_req_ids, status + return None + + def close(self) -> None: + return None + + +class TestRequestScheduler: + def setup_method(self) -> None: + self.scheduler: RequestScheduler = RequestScheduler() + self.scheduler.initialize(Mock()) + + def test_single_request_success_lifecycle(self) -> None: + req_id = self.scheduler.add_request(_make_request("a")) + assert self.scheduler.get_request_state(req_id).status == DiffusionRequestStatus.WAITING + + sched_output = self.scheduler.schedule() + assert _new_ids(sched_output) == [req_id] + assert _cached_ids(sched_output) == [] + assert sched_output.num_running_reqs == 1 + assert sched_output.num_waiting_reqs == 0 + + finished = self.scheduler.update_from_output(sched_output, _make_request_output(req_id)) + assert finished == {req_id} + assert self.scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED + assert self.scheduler.has_requests() is False + def test_error_output_marks_finished_error(self) -> None: + req_id = self.scheduler.add_request(_make_request("err")) + + sched_output = self.scheduler.schedule() + finished = self.scheduler.update_from_output( + sched_output, + _make_request_output(req_id, error="worker failed"), + ) + + assert finished == {req_id} + state = self.scheduler.get_request_state(req_id) + assert state.status == DiffusionRequestStatus.FINISHED_ERROR + assert state.error == "worker failed" + + def test_empty_output_without_error_marks_completed(self) -> None: + req_id = self.scheduler.add_request(_make_request("empty")) + + sched_output = self.scheduler.schedule() + finished = self.scheduler.update_from_output(sched_output, _make_request_output(req_id)) + + assert finished == {req_id} + assert self.scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED + + def test_fifo_single_request_scheduling(self) -> None: + req_id_a = self.scheduler.add_request(_make_request("a")) + req_id_b = self.scheduler.add_request(_make_request("b")) + + first = self.scheduler.schedule() + assert _new_ids(first) == [req_id_a] + assert _cached_ids(first) == [] + assert first.num_running_reqs == 1 + assert first.num_waiting_reqs == 1 + + # Request A is still running; scheduling again should not pull B. + second = self.scheduler.schedule() + assert _new_ids(second) == [] + assert _cached_ids(second) == [req_id_a] + assert second.num_running_reqs == 1 + assert second.num_waiting_reqs == 1 + + self.scheduler.update_from_output(first, _make_request_output(req_id_a)) + + third = self.scheduler.schedule() + assert _new_ids(third) == [req_id_b] + assert _cached_ids(third) == [] + assert third.num_running_reqs == 1 + assert third.num_waiting_reqs == 0 + + def test_abort_request_for_waiting_and_running(self) -> None: + req_id_a = self.scheduler.add_request(_make_request("a")) + req_id_b = self.scheduler.add_request(_make_request("b")) -def test_abort_request_for_waiting_and_running() -> None: - scheduler = _make_scheduler() - req_id_a = scheduler.add_request(_make_request("a")) - req_id_b = scheduler.add_request(_make_request("b")) + # Abort waiting request. + self.scheduler.finish_requests(req_id_b, DiffusionRequestStatus.FINISHED_ABORTED) + state_b = self.scheduler.get_request_state(req_id_b) + assert state_b.status == DiffusionRequestStatus.FINISHED_ABORTED + + # A should still run normally. + output_a = self.scheduler.schedule() + assert _new_ids(output_a) == [req_id_a] - # Abort waiting request. - assert scheduler.abort_request(req_id_b) is True - state_b = scheduler.get_request_state(req_id_b) - assert state_b.status == DiffusionRequestStatus.FINISHED_ABORTED + # Abort running request. + self.scheduler.finish_requests(req_id_a, DiffusionRequestStatus.FINISHED_ABORTED) + state_a = self.scheduler.get_request_state(req_id_a) + assert state_a.status == DiffusionRequestStatus.FINISHED_ABORTED + + assert self.scheduler.has_requests() is False + assert self.scheduler.schedule().scheduled_req_ids == [] + + def test_has_requests_state_transition(self) -> None: + assert self.scheduler.has_requests() is False - # A should still run normally. - output_a = scheduler.schedule() - assert [s.req_id for s in output_a.req_states] == [req_id_a] + req_id = self.scheduler.add_request(_make_request("has")) + assert self.scheduler.has_requests() is True - # Abort running request. - assert scheduler.abort_request(req_id_a) is True - state_a = scheduler.get_request_state(req_id_a) - assert state_a.status == DiffusionRequestStatus.FINISHED_ABORTED + sched_output = self.scheduler.schedule() + assert self.scheduler.has_requests() is True + + self.scheduler.update_from_output(sched_output, _make_request_output(req_id)) + assert self.scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED + assert self.scheduler.has_requests() is False + + def test_request_id_mapping_lifecycle(self) -> None: + request = OmniDiffusionRequest( + prompts=["prompt_map_a", "prompt_map_b"], + sampling_params=OmniDiffusionSamplingParams(num_inference_steps=1), + request_ids=["map-a", "map-b"], + ) + + sched_req_id = self.scheduler.add_request(request) + + assert self.scheduler.get_sched_req_id("map-a") == sched_req_id + assert self.scheduler.get_sched_req_id("map-b") == sched_req_id + + self.scheduler.pop_request_state(sched_req_id) + + assert self.scheduler.get_sched_req_id("map-a") is None + assert self.scheduler.get_sched_req_id("map-b") is None - assert scheduler.has_requests() is False - assert scheduler.schedule().req_states == [] +class TestDiffusionEngine: + def test_add_req_and_wait_for_response_single_path(self) -> None: + engine = DiffusionEngine.__new__(DiffusionEngine) + engine.scheduler = RequestScheduler() + engine.scheduler.initialize(Mock()) + engine.executor = Mock() + engine._rpc_lock = threading.Lock() -def test_has_requests_state_transition() -> None: - scheduler = _make_scheduler() - assert scheduler.has_requests() is False + request = _make_request("engine") + expected = DiffusionOutput(output=None) + engine.executor.add_req.return_value = expected - req_id = scheduler.add_request(_make_request("has")) - assert scheduler.has_requests() is True + output = engine.add_req_and_wait_for_response(request) - sched_output = scheduler.schedule() - assert scheduler.has_requests() is True + assert output is expected + engine.executor.add_req.assert_called_once_with(request) - scheduler.update_from_output(sched_output, DiffusionOutput(output=None)) - assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED - assert scheduler.has_requests() is False + def test_supports_scheduler_interface_injection(self) -> None: + request = _make_request("engine_iface") + expected = DiffusionOutput(output=None) + scheduler = _StubScheduler(request, expected) + + engine = DiffusionEngine.__new__(DiffusionEngine) + engine.scheduler = scheduler + engine.executor = Mock() + engine.executor.add_req = Mock(return_value=expected) + engine._rpc_lock = threading.Lock() + + output = engine.add_req_and_wait_for_response(request) + + assert output is expected + engine.executor.add_req.assert_called_once_with(request) + def test_initializes_injected_scheduler(self) -> None: + request = _make_request("init") + scheduler = _StubScheduler(request, DiffusionOutput(output=None)) + od_config = Mock(model_class_name="mock_model") + fake_executor_cls = Mock(return_value=Mock()) -def test_engine_add_req_and_wait_for_response_single_path() -> None: - engine = DiffusionEngine.__new__(DiffusionEngine) - engine.scheduler = _make_scheduler() - engine.executor = Mock() - engine._rpc_lock = threading.Lock() + with ( + patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_post_process_func", return_value=None), + patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_pre_process_func", return_value=None), + patch("vllm_omni.diffusion.diffusion_engine.DiffusionExecutor.get_class", return_value=fake_executor_cls), + patch.object(DiffusionEngine, "_dummy_run", return_value=None), + ): + DiffusionEngine(od_config, scheduler=scheduler) - request = _make_request("engine") - expected = DiffusionOutput(output=None) - engine.executor.add_req.return_value = expected + assert scheduler.initialized_with is od_config + fake_executor_cls.assert_called_once_with(od_config) - output = engine.add_req_and_wait_for_response(request) + def test_scheduler_alias_keeps_default_request_scheduler(self) -> None: + scheduler = Scheduler() + scheduler.initialize(Mock()) - assert output is expected - engine.executor.add_req.assert_called_once_with(request) + req_id = scheduler.add_request(_make_request("alias")) + sched_output = scheduler.schedule() + finished = scheduler.update_from_output(sched_output, _make_request_output(req_id)) + assert req_id in finished + assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED -def test_engine_dummy_run_raises_on_output_error() -> None: - engine = DiffusionEngine.__new__(DiffusionEngine) - engine.od_config = Mock(model_class_name="mock_model") - engine.pre_process_func = None - engine.add_req_and_wait_for_response = Mock(return_value=DiffusionOutput(error="boom")) + def test_dummy_run_raises_on_output_error(self) -> None: + engine = DiffusionEngine.__new__(DiffusionEngine) + engine.od_config = Mock(model_class_name="mock_model") + engine.pre_process_func = None + engine.add_req_and_wait_for_response = Mock(return_value=DiffusionOutput(error="boom")) - with pytest.raises(RuntimeError, match="Dummy run failed: boom"): - engine._dummy_run() + with pytest.raises(RuntimeError, match="Dummy run failed: boom"): + engine._dummy_run() diff --git a/tests/diffusion/test_multiproc_engine_concurrency.py b/tests/diffusion/test_multiproc_engine_concurrency.py index 90289727fcf..0465a1932b2 100644 --- a/tests/diffusion/test_multiproc_engine_concurrency.py +++ b/tests/diffusion/test_multiproc_engine_concurrency.py @@ -11,7 +11,7 @@ from vllm_omni.diffusion.data import DiffusionOutput from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.executor.multiproc_executor import MultiprocDiffusionExecutor -from vllm_omni.diffusion.scheduler import Scheduler +from vllm_omni.diffusion.sched import RequestScheduler pytestmark = [pytest.mark.diffusion, pytest.mark.core_model, pytest.mark.cpu] @@ -62,7 +62,7 @@ def _make_engine(num_gpus: int = 1): """Create a lightweight ``DiffusionEngine`` wired to mocked executor.""" executor, req_q, res_q = _make_executor(num_gpus) engine = DiffusionEngine.__new__(DiffusionEngine) - sched = Scheduler() + sched = RequestScheduler() sched.initialize(Mock()) engine.scheduler = sched engine.executor = executor diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 1189de43eae..53a567f6180 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -7,7 +7,6 @@ from collections.abc import Iterable from typing import Any -import numpy as np import PIL.Image import torch from vllm.logger import init_logger @@ -20,7 +19,7 @@ get_diffusion_pre_process_func, ) from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.scheduler import Scheduler +from vllm_omni.diffusion.sched import RequestScheduler, SchedulerInterface from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt from vllm_omni.outputs import OmniRequestOutput @@ -34,13 +33,6 @@ def supports_image_input(model_class_name: str) -> bool: return bool(getattr(model_cls, "support_image_input", False)) -def supports_audio_input(model_class_name: str) -> bool: - model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) - if model_cls is None: - return False - return bool(getattr(model_cls, "support_audio_input", False)) - - def image_color_format(model_class_name: str) -> str: model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) return getattr(model_cls, "color_format", "RGB") @@ -56,7 +48,11 @@ def supports_audio_output(model_class_name: str) -> bool: class DiffusionEngine: """The diffusion engine for vLLM-Omni diffusion models.""" - def __init__(self, od_config: OmniDiffusionConfig): + def __init__( + self, + od_config: OmniDiffusionConfig, + scheduler: SchedulerInterface | None = None, + ): """Initialize the diffusion engine. Args: @@ -69,7 +65,7 @@ def __init__(self, od_config: OmniDiffusionConfig): executor_class = DiffusionExecutor.get_class(od_config) self.executor = executor_class(od_config) - self.scheduler = Scheduler() + self.scheduler: SchedulerInterface = scheduler or RequestScheduler() self.scheduler.initialize(od_config) self._rpc_lock = threading.Lock() @@ -254,7 +250,10 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: return results @staticmethod - def make_engine(config: OmniDiffusionConfig) -> "DiffusionEngine": + def make_engine( + config: OmniDiffusionConfig, + scheduler: SchedulerInterface | None = None, + ) -> "DiffusionEngine": """Factory method to create a DiffusionEngine instance. Args: @@ -263,26 +262,28 @@ def make_engine(config: OmniDiffusionConfig) -> "DiffusionEngine": Returns: An instance of DiffusionEngine. """ - return DiffusionEngine(config) + return DiffusionEngine(config, scheduler=scheduler) def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> DiffusionOutput: with self._rpc_lock: target_sched_req_id = self.scheduler.add_request(request) + # keep scheduling and executing until the target request is finished while True: sched_output = self.scheduler.schedule() - if not sched_output.req_states: + if sched_output.is_empty: if not self.scheduler.has_requests(): raise RuntimeError("Diffusion scheduler has no runnable requests.") continue - req_state = sched_output.req_states[0] + sched_req_id = sched_output.scheduled_req_ids[0] + req = sched_output.scheduled_new_reqs[0].req try: - output = self.executor.add_req(req_state.req) + output = self.executor.add_req(req) except Exception as exc: logger.error( "Execution failed for diffusion request %s", - req_state.sched_req_id, + sched_req_id, exc_info=True, ) output = DiffusionOutput(error=str(exc)) @@ -423,18 +424,7 @@ def _dummy_run(self): else: dummy_image = None - if supports_audio_input(self.od_config.model_class_name): - audio_sr = 16000 - audio_duration_sec = 4 - audio_array = np.random.randn(audio_sr * audio_duration_sec).astype(np.float32) - dummy_audio = audio_array[audio_sr * 1 : audio_sr * 3] - else: - dummy_audio = None - - prompt: OmniTextPrompt = { - "prompt": "dummy run", - "multi_modal_data": {"image": dummy_image, "audio": dummy_audio}, - } + prompt: OmniTextPrompt = {"prompt": "dummy run", "multi_modal_data": {"image": dummy_image}} req = OmniDiffusionRequest( prompts=[prompt], sampling_params=OmniDiffusionSamplingParams( diff --git a/vllm_omni/diffusion/sched/__init__.py b/vllm_omni/diffusion/sched/__init__.py new file mode 100644 index 00000000000..650a1a1e6fb --- /dev/null +++ b/vllm_omni/diffusion/sched/__init__.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.sched.interface import ( + CachedRequestData, + DiffusionRequestState, + DiffusionRequestStatus, + DiffusionSchedulerOutput, + NewRequestData, + SchedulerInterface, +) +from vllm_omni.diffusion.sched.request_scheduler import RequestScheduler + +Scheduler = RequestScheduler + +__all__ = [ + "CachedRequestData", + "DiffusionRequestState", + "DiffusionRequestStatus", + "DiffusionSchedulerOutput", + "NewRequestData", + "RequestScheduler", + "Scheduler", + "SchedulerInterface", +] diff --git a/vllm_omni/diffusion/sched/base_scheduler.py b/vllm_omni/diffusion/sched/base_scheduler.py new file mode 100644 index 00000000000..ef45d868755 --- /dev/null +++ b/vllm_omni/diffusion/sched/base_scheduler.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from collections import deque + +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.sched.interface import ( + DiffusionRequestState, + DiffusionRequestStatus, + SchedulerInterface, +) + + +class _BaseScheduler(SchedulerInterface): + """Shared queue/state bookkeeping for diffusion schedulers.""" + + def __init__(self) -> None: + self.od_config: OmniDiffusionConfig | None = None + self._request_states: dict[str, DiffusionRequestState] = {} + self._request_id_to_sched_req_id: dict[str, str] = {} + self._step_id: int = 0 + self._waiting: deque[str] = deque() + self._running: list[str] = [] + self._finished_req_ids: set[str] = set() + # currently used by vllm_omni/entrypoints/omni_stage.py, + # can't be used for real multi-step scheduling without proper architectural changes, + # so we keep it fixed at 1 for now. + self._max_batch_size: int = 1 + + def initialize(self, od_config: OmniDiffusionConfig) -> None: + self.od_config = od_config + self._request_states.clear() + self._request_id_to_sched_req_id.clear() + self._step_id = 0 + self._waiting.clear() + self._running.clear() + self._finished_req_ids.clear() + self._reset_scheduler_state() + + def has_requests(self) -> bool: + return bool(self._waiting or self._running) + + def get_request_state(self, sched_req_id: str) -> DiffusionRequestState | None: + return self._request_states.get(sched_req_id) + + def get_sched_req_id(self, request_id: str) -> str | None: + return self._request_id_to_sched_req_id.get(request_id) + + def pop_request_state(self, sched_req_id: str) -> DiffusionRequestState | None: + self._pop_extra_request_state(sched_req_id) + state = self._request_states.pop(sched_req_id, None) + if state is not None: + self._unregister_request_ids(state.req.request_ids, sched_req_id) + return state + + def preempt_request(self, sched_req_id: str) -> bool: + if sched_req_id not in self._request_states: + return False + if sched_req_id in self._running: + self._running.remove(sched_req_id) + self._waiting.appendleft(sched_req_id) + self._request_states[sched_req_id].status = DiffusionRequestStatus.PREEMPTED + return True + return False + + def finish_requests(self, sched_req_ids: str | list[str], status: DiffusionRequestStatus) -> None: + assert DiffusionRequestStatus.is_finished(status) + if isinstance(sched_req_ids, str): + sched_req_ids = [sched_req_ids] + self._finish_requests({sched_req_id: status for sched_req_id in sched_req_ids}) + + def close(self) -> None: + self._request_states.clear() + self._request_id_to_sched_req_id.clear() + self._waiting.clear() + self._running.clear() + self._finished_req_ids.clear() + self._reset_scheduler_state() + + def _finish_requests( + self, + statuses: dict[str, DiffusionRequestStatus], + errors: dict[str, str | None] | None = None, + ) -> set[str]: + if not statuses: + return set() + + finished_req_ids: set[str] = set() + running_to_remove: set[str] = set() + waiting_to_remove: set[str] = set() + + for sched_req_id, status in statuses.items(): + assert DiffusionRequestStatus.is_finished(status) + state = self._request_states.get(sched_req_id) + if state is None or state.is_finished(): + continue + + finished_req_ids.add(sched_req_id) + if sched_req_id in self._running: + running_to_remove.add(sched_req_id) + if sched_req_id in self._waiting: + waiting_to_remove.add(sched_req_id) + + if running_to_remove: + self._running = [sched_req_id for sched_req_id in self._running if sched_req_id not in running_to_remove] + if waiting_to_remove: + self._waiting = deque( + sched_req_id for sched_req_id in self._waiting if sched_req_id not in waiting_to_remove + ) + + for sched_req_id in finished_req_ids: + state = self._request_states[sched_req_id] + status = statuses[sched_req_id] + state.status = status + if status == DiffusionRequestStatus.FINISHED_ERROR: + state.error = None if errors is None else errors.get(sched_req_id) + else: + state.error = None + + self._finished_req_ids |= finished_req_ids + return finished_req_ids + + def _reset_scheduler_state(self) -> None: + """Reset subclass-owned state during initialize()/close().""" + + def _pop_extra_request_state(self, sched_req_id: str) -> None: + """Remove subclass-owned per-request state before popping request state.""" + + def _register_request_ids(self, request_ids: list[str], sched_req_id: str) -> None: + for request_id in request_ids: + existing = self._request_id_to_sched_req_id.get(request_id) + if existing is not None and existing != sched_req_id: + raise ValueError(f"request_id {request_id!r} is already mapped to active sched_req_id {existing!r}.") + self._request_id_to_sched_req_id[request_id] = sched_req_id + + def _unregister_request_ids(self, request_ids: list[str], sched_req_id: str) -> None: + for request_id in request_ids: + if self._request_id_to_sched_req_id.get(request_id) == sched_req_id: + self._request_id_to_sched_req_id.pop(request_id, None) diff --git a/vllm_omni/diffusion/sched/interface.py b/vllm_omni/diffusion/sched/interface.py index d8b3a39c54b..fd566d3b228 100644 --- a/vllm_omni/diffusion/sched/interface.py +++ b/vllm_omni/diffusion/sched/interface.py @@ -4,10 +4,18 @@ from __future__ import annotations import enum +import uuid +from abc import ABC, abstractmethod from dataclasses import dataclass +from functools import cached_property +from vllm.logger import init_logger + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.request import OmniDiffusionRequest +logger = init_logger(__name__) + class DiffusionRequestStatus(enum.IntEnum): """Request status tracked by diffusion scheduler.""" @@ -30,6 +38,9 @@ def is_finished(status: DiffusionRequestStatus) -> bool: class DiffusionRequestState: """Scheduler-owned state for one queued OmniDiffusionRequest.""" + # Unique scheduler-owned request ID. + # NOTE: This identifies one OmniDiffusionRequest, which may contain multiple request_ids. + # TODO: Align this with OmniDiffusionRequest.request_ids once scheduler batching is supported. sched_req_id: str req: OmniDiffusionRequest status: DiffusionRequestStatus = DiffusionRequestStatus.WAITING @@ -40,15 +51,122 @@ def is_finished(self) -> bool: @dataclass -class DiffusionSchedulerOutput: - """Output of a single scheduling cycle. +class NewRequestData: + """Full request payload for a newly scheduled diffusion request.""" + + sched_req_id: str + req: OmniDiffusionRequest + + @classmethod + def from_state(cls, state: DiffusionRequestState) -> NewRequestData: + return cls(sched_req_id=state.sched_req_id, req=state.req) + + +@dataclass +class CachedRequestData: + """Cached diffusion requests that only need their scheduler ids resent.""" + + sched_req_ids: list[str] + + @classmethod + def make_empty(cls) -> CachedRequestData: + return cls(sched_req_ids=[]) - Kept intentionally small so step-execution components can share a stable - transport shape while scheduler policy continues to evolve. - """ + +@dataclass +class DiffusionSchedulerOutput: + """Output of a single scheduling cycle.""" step_id: int - req_states: list[DiffusionRequestState] + scheduled_new_reqs: list[NewRequestData] + scheduled_cached_reqs: CachedRequestData finished_req_ids: set[str] num_running_reqs: int num_waiting_reqs: int + + @cached_property + def scheduled_req_ids(self) -> list[str]: + """ + All scheduled request ids in this cycle, including both new and cached ones. + NOTE: + This id is generated and owned by the scheduler, + and may be different from the OmniDiffusionRequest.request_ids. + """ + return [ + *(req.sched_req_id for req in self.scheduled_new_reqs), + *self.scheduled_cached_reqs.sched_req_ids, + ] + + @property + def num_scheduled_reqs(self) -> int: + return len(self.scheduled_req_ids) + + @property + def is_empty(self) -> bool: + return self.num_scheduled_reqs == 0 + + +class SchedulerInterface(ABC): + """Abstract lifecycle contract for diffusion schedulers.""" + + def _make_sched_req_id(self, request: OmniDiffusionRequest) -> str: + """ + Generate a unique scheduler request ID for the given request. + The default implementation uses the first request_id from the request if available, + otherwise generates a random one. + """ + if request.request_ids: + base = request.request_ids[0] + else: + logger.warning("Request has no request_ids, generating a random one. Request: %s", request) + base = f"req_{uuid.uuid4().hex[:8]}" + + sched_req_id = base + suffix = 1 + while self.get_request_state(sched_req_id) is not None: + sched_req_id = f"{base}#{suffix}" + suffix += 1 + return sched_req_id + + @abstractmethod + def initialize(self, od_config: OmniDiffusionConfig) -> None: + """Initialize or reset scheduler state.""" + + @abstractmethod + def add_request(self, request: OmniDiffusionRequest) -> str: + """Add a request and return the scheduler-owned request id.""" + + @abstractmethod + def schedule(self) -> DiffusionSchedulerOutput: + """Run one scheduling cycle.""" + + @abstractmethod + def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: DiffusionOutput) -> set[str]: + """Update scheduler state from executor output.""" + + def has_requests(self) -> bool: + """Return whether the scheduler still owns runnable requests.""" + + @abstractmethod + def get_request_state(self, sched_req_id: str) -> DiffusionRequestState | None: + """Return request state if present.""" + + @abstractmethod + def get_sched_req_id(self, request_id: str) -> str | None: + """Resolve a public request_id to the active scheduler request id.""" + + @abstractmethod + def pop_request_state(self, sched_req_id: str) -> DiffusionRequestState | None: + """Remove and return request state if present.""" + + @abstractmethod + def preempt_request(self, sched_req_id: str) -> bool: + """Preempt a running request back to waiting.""" + + @abstractmethod + def finish_requests(self, sched_req_ids: str | list[str], status: DiffusionRequestStatus) -> None: + """Mark one or more requests finished.""" + + @abstractmethod + def close(self) -> None: + """Release scheduler-owned state.""" diff --git a/vllm_omni/diffusion/sched/request_scheduler.py b/vllm_omni/diffusion/sched/request_scheduler.py new file mode 100644 index 00000000000..fe48bd62453 --- /dev/null +++ b/vllm_omni/diffusion/sched/request_scheduler.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from vllm.logger import init_logger + +from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.sched.base_scheduler import _BaseScheduler +from vllm_omni.diffusion.sched.interface import ( + CachedRequestData, + DiffusionRequestState, + DiffusionRequestStatus, + DiffusionSchedulerOutput, + NewRequestData, +) + +logger = init_logger(__name__) + + +class RequestScheduler(_BaseScheduler): + """Diffusion scheduler with vLLM-style waiting/running queues.""" + + def add_request(self, request: OmniDiffusionRequest) -> str: + sched_req_id = self._make_sched_req_id(request) + state = DiffusionRequestState(sched_req_id=sched_req_id, req=request) + self._request_states[sched_req_id] = state + self._register_request_ids(request.request_ids, sched_req_id) + self._waiting.append(sched_req_id) + logger.debug("Scheduler add_request: %s (waiting=%d)", sched_req_id, len(self._waiting)) + return sched_req_id + + def schedule(self) -> DiffusionSchedulerOutput: + scheduled_new_reqs: list[NewRequestData] = [] + scheduled_cached_req_ids: list[str] = [] + + # First, schedule the RUNNING request(s) + for sched_req_id in self._running: + state = self._request_states.get(sched_req_id) + if state is not None: + scheduled_cached_req_ids.append(sched_req_id) + + # Second, schedule WAITING requests while capacity remains. + # RequestScheduler only allows one active request at a time. + while self._waiting and not self._running: + sched_req_id = self._waiting.popleft() + state = self._request_states.get(sched_req_id) + if state is None: + continue + was_new_request = state.status == DiffusionRequestStatus.WAITING + state.status = DiffusionRequestStatus.RUNNING + self._running.append(sched_req_id) + if was_new_request: + scheduled_new_reqs.append(NewRequestData.from_state(state)) + else: + scheduled_cached_req_ids.append(sched_req_id) + + scheduler_output = DiffusionSchedulerOutput( + step_id=self._step_id, + scheduled_new_reqs=scheduled_new_reqs, + scheduled_cached_reqs=CachedRequestData(sched_req_ids=scheduled_cached_req_ids), + finished_req_ids=set(self._finished_req_ids), + num_running_reqs=len(self._running), + num_waiting_reqs=len(self._waiting), + ) + + self._step_id += 1 + self._finished_req_ids.clear() + return scheduler_output + + def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: DiffusionOutput) -> set[str]: + scheduled_req_ids = sched_output.scheduled_req_ids + if not scheduled_req_ids: + return set() + + # A scheduled request may be aborted after schedule() but before + # update_from_output() processes the runner output. It is already + # marked finished at that point, but we still need to surface its id + # in this update so the engine can observe the terminal state. + finished_req_ids = { + sched_req_id for sched_req_id in scheduled_req_ids if sched_req_id in self._finished_req_ids + } + terminal_statuses: dict[str, DiffusionRequestStatus] = {} + terminal_errors: dict[str, str | None] = {} + for sched_req_id in scheduled_req_ids: + state = self._request_states.get(sched_req_id) + if state is None or state.is_finished(): + continue + if output.error: + terminal_statuses[sched_req_id] = DiffusionRequestStatus.FINISHED_ERROR + terminal_errors[sched_req_id] = output.error + else: + terminal_statuses[sched_req_id] = DiffusionRequestStatus.FINISHED_COMPLETED + terminal_errors[sched_req_id] = None + + finished_req_ids |= self._finish_requests(terminal_statuses, terminal_errors) + return finished_req_ids + + def abort_request(self, sched_req_id: str) -> bool: + if self.get_request_state(sched_req_id) is None: + return False + self.finish_requests(sched_req_id, DiffusionRequestStatus.FINISHED_ABORTED) + return True diff --git a/vllm_omni/diffusion/scheduler.py b/vllm_omni/diffusion/scheduler.py deleted file mode 100644 index d5b1de152ac..00000000000 --- a/vllm_omni/diffusion/scheduler.py +++ /dev/null @@ -1,237 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from __future__ import annotations - -import enum -import uuid -from collections import deque -from dataclasses import dataclass - -from vllm.logger import init_logger - -from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig -from vllm_omni.diffusion.request import OmniDiffusionRequest - -logger = init_logger(__name__) - - -class DiffusionRequestStatus(enum.IntEnum): - """Request status tracked by diffusion scheduler.""" - - WAITING = enum.auto() - RUNNING = enum.auto() - PREEMPTED = enum.auto() - - # if any status is after FINISHED_COMPLETED, it is considered finished - FINISHED_COMPLETED = enum.auto() - FINISHED_ABORTED = enum.auto() - FINISHED_ERROR = enum.auto() - - @staticmethod - def is_finished(status: DiffusionRequestStatus) -> bool: - return status >= DiffusionRequestStatus.FINISHED_COMPLETED - - -@dataclass -class DiffusionRequestState: - """Scheduler-owned state for one queued OmniDiffusionRequest.""" - - sched_req_id: str # unique scheduler request ID generated by scheduler - req: OmniDiffusionRequest - status: DiffusionRequestStatus = DiffusionRequestStatus.WAITING - error: str | None = None - - -@dataclass -class DiffusionSchedulerOutput: - """Output of a single scheduling cycle.""" - - step_id: int - req_states: list[DiffusionRequestState] - finished_req_ids: set[str] - num_running_reqs: int - num_waiting_reqs: int - - -class Scheduler: - """Diffusion scheduler with vLLM-style waiting/running queues. - - NOTE: Currently, each OmniDiffusionRequest is already pre-batched upstream. - Scheduler only handles request-level state transitions and never merges - multiple requests into a new batch. - """ - - def __init__(self) -> None: - self.od_config: OmniDiffusionConfig | None = None - self._request_states: dict[str, DiffusionRequestState] = {} - self._step_id: int = 0 - - self._waiting: deque[str] = deque() - self._running: list[str] = [] - self._finished_req_ids: set[str] = set() - - def initialize(self, od_config: OmniDiffusionConfig) -> None: - self.od_config = od_config - self._request_states.clear() - self._step_id = 0 - - self._waiting.clear() - self._running.clear() - self._finished_req_ids.clear() - - def add_request(self, request: OmniDiffusionRequest) -> str: - sched_req_id = self._make_sched_req_id(request) - state = DiffusionRequestState(sched_req_id=sched_req_id, req=request) - self._request_states[sched_req_id] = state - self._waiting.append(sched_req_id) - logger.debug( - "Scheduler add_request: %s (waiting=%d)", - sched_req_id, - len(self._waiting), - ) - return sched_req_id - - def schedule(self) -> DiffusionSchedulerOutput: - # Single-request scheduling: do not build multiple request batches. - if not self._running and self._waiting: - sched_req_id = self._waiting.popleft() - state = self._request_states.get(sched_req_id) - if state is not None: - state.status = DiffusionRequestStatus.RUNNING - self._running.append(sched_req_id) - - running_states: list[DiffusionRequestState] = [] - for sched_req_id in self._running: - state = self._request_states.get(sched_req_id) - if state is not None: - running_states.append(state) - - scheduler_output = DiffusionSchedulerOutput( - step_id=self._step_id, - req_states=running_states, - finished_req_ids=set(self._finished_req_ids), - num_running_reqs=len(self._running), - num_waiting_reqs=len(self._waiting), - ) - - # update after schedule - self._step_id += 1 - self._finished_req_ids.clear() - return scheduler_output - - def update_from_output( - self, - sched_output: DiffusionSchedulerOutput, - output: DiffusionOutput, - ) -> set[str]: - scheduled_sched_req_ids = {state.sched_req_id for state in sched_output.req_states} - if not scheduled_sched_req_ids: - return set() - - completed_sched_req_ids: set[str] = set() - if output.error: - for sched_req_id in scheduled_sched_req_ids: - state = self._request_states.get(sched_req_id) - if state is None: - continue - state.status = DiffusionRequestStatus.FINISHED_ERROR - state.error = output.error - completed_sched_req_ids.add(sched_req_id) - else: - for sched_req_id in scheduled_sched_req_ids: - state = self._request_states.get(sched_req_id) - if state is None: - continue - state.status = DiffusionRequestStatus.FINISHED_COMPLETED - state.error = None - completed_sched_req_ids.add(sched_req_id) - - if completed_sched_req_ids: - self._running = [ - sched_req_id for sched_req_id in self._running if sched_req_id not in completed_sched_req_ids - ] - for sched_req_id in completed_sched_req_ids: - try: - self._waiting.remove(sched_req_id) - except ValueError: - pass - self._finished_req_ids |= completed_sched_req_ids - - return completed_sched_req_ids - - def abort_request(self, sched_req_id: str) -> bool: - if sched_req_id not in self._request_states: - return False - self.finish_request(sched_req_id, DiffusionRequestStatus.FINISHED_ABORTED) - self._finished_req_ids.add(sched_req_id) - return True - - def has_requests(self) -> bool: - return bool(self._waiting or self._running) - - def get_request_state( - self, - sched_req_id: str, - ) -> DiffusionRequestState | None: - return self._request_states.get(sched_req_id) - - def pop_request_state( - self, - sched_req_id: str, - ) -> DiffusionRequestState | None: - return self._request_states.pop(sched_req_id, None) - - def preempt_request(self, sched_req_id: str) -> bool: - if sched_req_id not in self._request_states: - return False - if sched_req_id in self._running: - self._running.remove(sched_req_id) - self._waiting.appendleft(sched_req_id) - self._request_states[sched_req_id].status = DiffusionRequestStatus.PREEMPTED - return True - return False - - def finish_request( - self, - sched_req_id: str, - status: DiffusionRequestStatus, - ) -> None: - assert DiffusionRequestStatus.is_finished(status) - state = self._request_states.get(sched_req_id) - if state is None: - return - - state.status = status - if sched_req_id in self._running: - self._running.remove(sched_req_id) - try: - self._waiting.remove(sched_req_id) - except ValueError: - pass - - def close(self) -> None: - self._request_states.clear() - self._waiting.clear() - self._running.clear() - self._finished_req_ids.clear() - - def _make_sched_req_id(self, request: OmniDiffusionRequest) -> str: - """ - Generate a unique request ID for the given request. - If the request already has request IDs, use the first one as the base. - - NOTE: OmniDiffusionRequest already contain multiple prompts/outputs - as a pre-batched request object. - """ - if request.request_ids: - base = request.request_ids[0] - else: - base = f"req_{uuid.uuid4().hex[:8]}" - - sched_req_id = base - suffix = 1 - while sched_req_id in self._request_states: - sched_req_id = f"{base}#{suffix}" - suffix += 1 - return sched_req_id From d8285ea7560e52f43588906cad07e53e325b31dd Mon Sep 17 00:00:00 2001 From: JiangJie Zhang <76905040+yJader@users.noreply.github.com> Date: Wed, 18 Mar 2026 23:38:23 +0800 Subject: [PATCH 07/14] fix: Add missing @abstractmethod Co-authored-by: SYLAR <125541396+lishunyang12@users.noreply.github.com> Signed-off-by: JiangJie Zhang <76905040+yJader@users.noreply.github.com> --- vllm_omni/diffusion/sched/interface.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm_omni/diffusion/sched/interface.py b/vllm_omni/diffusion/sched/interface.py index fd566d3b228..427cad03d0e 100644 --- a/vllm_omni/diffusion/sched/interface.py +++ b/vllm_omni/diffusion/sched/interface.py @@ -144,13 +144,14 @@ def schedule(self) -> DiffusionSchedulerOutput: def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: DiffusionOutput) -> set[str]: """Update scheduler state from executor output.""" - def has_requests(self) -> bool: - """Return whether the scheduler still owns runnable requests.""" - @abstractmethod def get_request_state(self, sched_req_id: str) -> DiffusionRequestState | None: """Return request state if present.""" + @abstractmethod + def has_requests(self) -> bool: + """Return whether the scheduler still owns runnable requests.""" + @abstractmethod def get_sched_req_id(self, request_id: str) -> str | None: """Resolve a public request_id to the active scheduler request id.""" From f53ab95409eab80dea1aefe1c5c0023d3e53e9e8 Mon Sep 17 00:00:00 2001 From: jader Date: Wed, 18 Mar 2026 16:58:41 +0000 Subject: [PATCH 08/14] docs(diffusion/scheduler): update comments for clarity on request handling Signed-off-by: jader --- vllm_omni/diffusion/diffusion_engine.py | 4 ++++ vllm_omni/diffusion/sched/base_scheduler.py | 5 ++--- vllm_omni/diffusion/sched/request_scheduler.py | 5 +++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 53a567f6180..63867325a99 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -276,6 +276,10 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus raise RuntimeError("Diffusion scheduler has no runnable requests.") continue + # NOTE: add_req_and_wait_for_response() is synchronous, and + # the scheduler currently enforces _max_batch_size = 1 (see + # vllm_omni/diffusion/sched/base_scheduler.py), so we directly + # take the single scheduled request here. sched_req_id = sched_output.scheduled_req_ids[0] req = sched_output.scheduled_new_reqs[0].req try: diff --git a/vllm_omni/diffusion/sched/base_scheduler.py b/vllm_omni/diffusion/sched/base_scheduler.py index ef45d868755..a59fa50d1ee 100644 --- a/vllm_omni/diffusion/sched/base_scheduler.py +++ b/vllm_omni/diffusion/sched/base_scheduler.py @@ -24,9 +24,8 @@ def __init__(self) -> None: self._waiting: deque[str] = deque() self._running: list[str] = [] self._finished_req_ids: set[str] = set() - # currently used by vllm_omni/entrypoints/omni_stage.py, - # can't be used for real multi-step scheduling without proper architectural changes, - # so we keep it fixed at 1 for now. + # The current DiffusionEngine execution mode does not support real + # request batching well, so we keep this fixed at 1 for now. self._max_batch_size: int = 1 def initialize(self, od_config: OmniDiffusionConfig) -> None: diff --git a/vllm_omni/diffusion/sched/request_scheduler.py b/vllm_omni/diffusion/sched/request_scheduler.py index fe48bd62453..ed8316ee58f 100644 --- a/vllm_omni/diffusion/sched/request_scheduler.py +++ b/vllm_omni/diffusion/sched/request_scheduler.py @@ -42,8 +42,7 @@ def schedule(self) -> DiffusionSchedulerOutput: scheduled_cached_req_ids.append(sched_req_id) # Second, schedule WAITING requests while capacity remains. - # RequestScheduler only allows one active request at a time. - while self._waiting and not self._running: + while self._waiting and len(self._running) < self._max_batch_size: sched_req_id = self._waiting.popleft() state = self._request_states.get(sched_req_id) if state is None: @@ -83,6 +82,8 @@ def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: Dif } terminal_statuses: dict[str, DiffusionRequestStatus] = {} terminal_errors: dict[str, str | None] = {} + # NOTE: request-mode currently assumes one executor call produces one + # DiffusionOutput for the single scheduled request in this cycle. for sched_req_id in scheduled_req_ids: state = self._request_states.get(sched_req_id) if state is None or state.is_finished(): From 2f25a9ed6512eae9118b383647da66c9d19988ec Mon Sep 17 00:00:00 2001 From: jader Date: Fri, 20 Mar 2026 04:28:23 +0000 Subject: [PATCH 09/14] feat: add request_ids to OmniDiffusionRequest in dummy run Signed-off-by: jader --- vllm_omni/diffusion/diffusion_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 63867325a99..37d9894f20d 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -431,6 +431,7 @@ def _dummy_run(self): prompt: OmniTextPrompt = {"prompt": "dummy run", "multi_modal_data": {"image": dummy_image}} req = OmniDiffusionRequest( prompts=[prompt], + request_ids=["dummy_req_id"], sampling_params=OmniDiffusionSamplingParams( height=height, width=width, From d0c016364af22674bf34f6cffdf91a4099204973 Mon Sep 17 00:00:00 2001 From: jader Date: Fri, 20 Mar 2026 06:13:42 +0000 Subject: [PATCH 10/14] fix(diffusion): restore audio warmup support Signed-off-by: jader --- vllm_omni/diffusion/diffusion_engine.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index 37d9894f20d..d51a6cec3ad 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -7,6 +7,7 @@ from collections.abc import Iterable from typing import Any +import numpy as np import PIL.Image import torch from vllm.logger import init_logger @@ -33,6 +34,13 @@ def supports_image_input(model_class_name: str) -> bool: return bool(getattr(model_cls, "support_image_input", False)) +def supports_audio_input(model_class_name: str) -> bool: + model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) + if model_cls is None: + return False + return bool(getattr(model_cls, "support_audio_input", False)) + + def image_color_format(model_class_name: str) -> str: model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name) return getattr(model_cls, "color_format", "RGB") @@ -428,7 +436,18 @@ def _dummy_run(self): else: dummy_image = None - prompt: OmniTextPrompt = {"prompt": "dummy run", "multi_modal_data": {"image": dummy_image}} + if supports_audio_input(self.od_config.model_class_name): + audio_sr = 16000 + audio_duration_sec = 4 + audio_array = np.random.randn(audio_sr * audio_duration_sec).astype(np.float32) + dummy_audio = audio_array[audio_sr * 1 : audio_sr * 3] + else: + dummy_audio = None + + prompt: OmniTextPrompt = { + "prompt": "dummy run", + "multi_modal_data": {"image": dummy_image, "audio": dummy_audio}, + } req = OmniDiffusionRequest( prompts=[prompt], request_ids=["dummy_req_id"], From 8f42b638a0d0283464c251afe99f3a8b1493e0d3 Mon Sep 17 00:00:00 2001 From: JiangJie Zhang <76905040+yJader@users.noreply.github.com> Date: Fri, 20 Mar 2026 14:14:28 +0800 Subject: [PATCH 11/14] Update tests/diffusion/test_diffusion_scheduler.py Co-authored-by: Didan Deng <33117903+wtomin@users.noreply.github.com> Signed-off-by: JiangJie Zhang <76905040+yJader@users.noreply.github.com> --- tests/diffusion/test_diffusion_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py index db61c8f65d6..171a6278cd9 100644 --- a/tests/diffusion/test_diffusion_scheduler.py +++ b/tests/diffusion/test_diffusion_scheduler.py @@ -18,7 +18,7 @@ from vllm_omni.diffusion.sched.interface import CachedRequestData, NewRequestData from vllm_omni.inputs.data import OmniDiffusionSamplingParams -pytestmark = [pytest.mark.diffusion] +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] def _make_request(req_id: str) -> OmniDiffusionRequest: From 324476b566bfffa992f4c981039f6552b7d06776 Mon Sep 17 00:00:00 2001 From: jader Date: Sun, 22 Mar 2026 04:42:41 +0000 Subject: [PATCH 12/14] docs(diffusion/scheduler): update diffusion scheduler design document Signed-off-by: jader --- docs/design/module/dit_module.md | 129 ++++++++++++++++++++----------- 1 file changed, 84 insertions(+), 45 deletions(-) diff --git a/docs/design/module/dit_module.md b/docs/design/module/dit_module.md index 4bc3f2b6cee..e24a75238f2 100644 --- a/docs/design/module/dit_module.md +++ b/docs/design/module/dit_module.md @@ -127,80 +127,117 @@ def step(self, requests: list[OmniDiffusionRequest]): ## 2. Scheduler -**Location**: `vllm_omni/diffusion/scheduler.py` +**Location**: `vllm_omni/diffusion/sched/` ### Architecture -The `Scheduler` is implemented as a **Singleton** pattern to ensure a single coordination point across the system, i.e., only one scheduler instance exists for coordination. +The scheduler is a **request-state scheduler**. It owns request lifecycle management and scheduling decisions, while execution stays in `DiffusionEngine` and the executor. ### Key Components -#### 2.1 Message Queue System +#### 2.1 Scheduler Interface ```python -class Scheduler: - def initialize(self, od_config: OmniDiffusionConfig): - # Broadcast queue: scheduler -> all workers - self.mq = MessageQueue( - n_reader=od_config.num_gpus, - n_local_reader=od_config.num_gpus, - local_reader_ranks=list(range(od_config.num_gpus)), - ) - - # Result queue: rank 0 worker -> scheduler - self.result_mq = None # Initialized later +class SchedulerInterface(ABC): + def add_request(self, request: OmniDiffusionRequest) -> str: ... + def schedule(self) -> DiffusionSchedulerOutput: ... + def update_from_output( + self, + sched_output: DiffusionSchedulerOutput, + output: DiffusionOutput, + ) -> set[str]: ... ``` -**Communication Pattern**: +**Responsibilities**: -- **Broadcast Queue**: One-to-many communication (scheduler → all workers) +- **Lifecycle contract**: Defines how the engine adds requests, triggers one scheduling cycle, and feeds executor results back. -- **Result Queue**: One-to-one communication (rank 0 → scheduler) +- **Stable boundary**: `DiffusionSchedulerOutput` is the only scheduling result consumed by `DiffusionEngine`. -- **Shared Memory**: Uses `MessageQueue` (ZMQ-based) for efficient IPC +- **Pluggability**: Different scheduler policies can reuse the same engine integration path. -#### 2.2 Request Distribution +#### 2.2 Request State Model ```python -def add_req(self, requests: list[OmniDiffusionRequest]) -> DiffusionOutput: - # Broadcast request to all workers - self.mq.enqueue(requests) +class DiffusionRequestStatus(enum.IntEnum): + WAITING = ... + RUNNING = ... + PREEMPTED = ... + FINISHED_COMPLETED = ... + FINISHED_ABORTED = ... + FINISHED_ERROR = ... + +@dataclass +class DiffusionRequestState: + sched_req_id: str + req: OmniDiffusionRequest + status: DiffusionRequestStatus = DiffusionRequestStatus.WAITING +``` - # Wait for result from Rank 0 - output = self.result_mq.dequeue() - return output +**Design Features**: + +- **Scheduler-owned ID**: Each `OmniDiffusionRequest` is tracked by an internal `sched_req_id`, separated from public `request_id` values. + +- **Explicit lifecycle**: Requests move through waiting, running, optional preemption, and terminal states. + +- **Centralized error handling**: Completion, abort, and error states are all normalized in the scheduler layer. + +#### 2.3 Shared Bookkeeping in `_BaseScheduler` + +```python +class _BaseScheduler(SchedulerInterface): + def __init__(self) -> None: + self._request_states = {} + self._request_id_to_sched_req_id = {} + self._waiting = deque() + self._running = [] + self._finished_req_ids = set() + self._max_batch_size = 1 ``` **Design Features**: -- **Broadcast Model**: All workers receive the same request (for tensor parallelism) +- **Common state storage**: Shared request maps and waiting/running sets live in the base class. -- **Single Response**: Only rank 0 sends results back (avoids duplicate outputs) +- **Shared cleanup logic**: Request-id registration, finish handling, and state removal are centralized instead of duplicated in each policy. -- **Synchronous**: Blocks until result is received (can be made async) +- **Current constraint**: `_max_batch_size` remains `1` because the current engine path is still synchronous request-mode execution. -#### 2.3 Singleton Pattern +#### 2.4 Current `RequestScheduler` Policy ```python -class Scheduler: - _instance = None +class RequestScheduler(_BaseScheduler): + def schedule(self) -> DiffusionSchedulerOutput: + # 1. keep existing RUNNING requests in the scheduling result + # 2. pull WAITING requests while capacity remains + # 3. move newly admitted requests into RUNNING +``` + +**Behavior**: + +- **FIFO request scheduling**: Waiting requests are promoted in queue order. - def __new__(cls, *args, **kwargs): - if not cls._instance: - cls._instance = super().__new__(cls) - return cls._instance +- **Single-request admission**: The current policy only admits one active request at a time. -# Global singleton instance -scheduler = Scheduler() +- **Executor result feedback**: `update_from_output()` converts executor output into `FINISHED_COMPLETED` or `FINISHED_ERROR` and returns finished scheduler ids. + +#### 2.5 Engine-Driven Execution Loop + +```python +sched_req_id = scheduler.add_request(request) +while True: + sched_output = scheduler.schedule() + output = executor.add_req(req) + finished_req_ids = scheduler.update_from_output(sched_output, output) ``` -**Benefits**: +**Design Decisions**: -- **Single Point of Control**: Ensures consistent state +- **Separation of concerns**: Scheduler manages state and policy; executor handles runtime execution. -- **Easy Access**: Global `scheduler` instance accessible everywhere +- **No scheduler-owned IPC**: Scheduler no longer talks to workers directly. -- **Resource Management**: Centralized queue management +- **Conservative concurrency**: The current request-mode implementation still allows only one active request at a time. --- @@ -880,8 +917,9 @@ def initialize_model_parallel( └─> Model-specific transformations 3. Scheduling - └─> scheduler.add_req(requests) - └─> Broadcast via MessageQueue to all workers + └─> scheduler.add_request(request) + └─> scheduler.schedule() + └─> DiffusionEngine submits scheduled request to executor.add_req(req) 4. Worker Execution └─> WorkerProc.worker_busy_loop() @@ -895,8 +933,9 @@ def initialize_model_parallel( └─> vae.decode() 5. Result Collection - └─> Rank 0 sends DiffusionOutput via result queue - └─> Scheduler receives and returns + └─> Executor returns DiffusionOutput + └─> scheduler.update_from_output(...) + └─> DiffusionEngine pops finished request state 6. Post-processing └─> post_process_func(output) From df3ac75405c35049e38710e5c5571b8537e7f444 Mon Sep 17 00:00:00 2001 From: jader Date: Mon, 23 Mar 2026 04:31:10 +0000 Subject: [PATCH 13/14] fix(diffusion): update request handling to support new request structure Co-authored-by: asukaqaq-s <1311722138@qq.com> Signed-off-by: jader --- .../diffusion/test_diffusion_step_pipeline.py | 27 +++++++++---- .../worker/diffusion_model_runner.py | 38 +++++++++++-------- .../diffusion/worker/diffusion_worker.py | 2 +- 3 files changed, 42 insertions(+), 25 deletions(-) diff --git a/tests/diffusion/test_diffusion_step_pipeline.py b/tests/diffusion/test_diffusion_step_pipeline.py index 121fbc0d5fa..ad08487fe9c 100644 --- a/tests/diffusion/test_diffusion_step_pipeline.py +++ b/tests/diffusion/test_diffusion_step_pipeline.py @@ -25,10 +25,9 @@ unpack_diffusion_output_shm, ) from vllm_omni.diffusion.sched.interface import ( - DiffusionRequestState as SchedulerRequestState, -) -from vllm_omni.diffusion.sched.interface import ( + CachedRequestData, DiffusionSchedulerOutput, + NewRequestData, ) from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker @@ -224,7 +223,19 @@ def _make_distributed_runner(mode: str, device: torch.device): def _make_scheduler_output(req, sched_req_id="req-1", step_id=0, finished_req_ids=None): return DiffusionSchedulerOutput( step_id=step_id, - req_states=[SchedulerRequestState(sched_req_id=sched_req_id, req=req)], + scheduled_new_reqs=[NewRequestData(sched_req_id=sched_req_id, req=req)], + scheduled_cached_reqs=CachedRequestData.make_empty(), + finished_req_ids=set() if finished_req_ids is None else set(finished_req_ids), + num_running_reqs=1, + num_waiting_reqs=0, + ) + + +def _make_cached_scheduler_output(sched_req_id="req-1", step_id=1, finished_req_ids=None): + return DiffusionSchedulerOutput( + step_id=step_id, + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData(sched_req_ids=[sched_req_id]), finished_req_ids=set() if finished_req_ids is None else set(finished_req_ids), num_running_reqs=1, num_waiting_reqs=0, @@ -297,7 +308,7 @@ def test_completes_request_and_clears_state(self, monkeypatch): assert first.result is None assert "req-1" in runner.state_cache - second = DiffusionModelRunner.execute_stepwise(runner, _make_scheduler_output(req, step_id=1)) + second = DiffusionModelRunner.execute_stepwise(runner, _make_cached_scheduler_output(step_id=1)) assert second.req_id == "req-1" assert second.step_index == 2 assert second.finished is True @@ -369,7 +380,7 @@ def test_delegates_to_model_runner(self): worker = object.__new__(DiffusionWorker) expected = RunnerOutput(req_id="req-1", step_index=1, finished=False, result=None) scheduler_output = SimpleNamespace( - req_states=[ + scheduled_new_reqs=[ SimpleNamespace( req=SimpleNamespace( sampling_params=SimpleNamespace(lora_request=None), @@ -389,7 +400,7 @@ def test_delegates_to_model_runner(self): def test_clears_active_lora_before_stepwise_execution(self): worker = object.__new__(DiffusionWorker) scheduler_output = SimpleNamespace( - req_states=[ + scheduled_new_reqs=[ SimpleNamespace( req=SimpleNamespace( sampling_params=SimpleNamespace(lora_request=None), @@ -413,7 +424,7 @@ def set_active_adapter(self, adapter): def test_rejects_lora_requests_in_step_mode(self): worker = object.__new__(DiffusionWorker) scheduler_output = SimpleNamespace( - req_states=[ + scheduled_new_reqs=[ SimpleNamespace( req=SimpleNamespace( sampling_params=SimpleNamespace(lora_request=object()), diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index e3c27f94545..853ee937d48 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -262,24 +262,30 @@ def supports_step_mode(self) -> bool: """Return whether current pipeline supports step execution.""" return self.pipeline is not None and supports_step_execution(self.pipeline) - def _update_states(self, scheduler_output: DiffusionSchedulerOutput) -> DiffusionRequestState: + def _update_states(self, scheduler_output: DiffusionSchedulerOutput) -> tuple[DiffusionRequestState, bool]: """Step-before update: cleanup finished requests and get/create one running state.""" for req_id in scheduler_output.finished_req_ids: self.state_cache.pop(req_id, None) - req_states = scheduler_output.req_states - if len(req_states) != 1: - raise ValueError(f"Step mode currently supports batch_size=1, but got {len(req_states)} req_states.") + if scheduler_output.num_scheduled_reqs != 1: + raise ValueError( + "Step mode currently supports batch_size=1, " + f"but got {scheduler_output.num_scheduled_reqs} scheduled requests." + ) - # TODO: remove req state from SchedulerOutput - # Stepwise mode currently trusts runner-owned cached state more than - # re-validating scheduler-provided request content on every step. - sched_req_state = req_states[0] - req_id = sched_req_state.sched_req_id - if req_id in self.state_cache: - return self.state_cache[req_id] + if scheduler_output.scheduled_new_reqs: + new_req_data = scheduler_output.scheduled_new_reqs[0] + req_id = new_req_data.sched_req_id + req = new_req_data.req + if req_id in self.state_cache: + raise ValueError(f"Received duplicate new-request payload for cached request {req_id}.") + else: + req_id = scheduler_output.scheduled_cached_reqs.sched_req_ids[0] + state = self.state_cache.get(req_id) + if state is None: + raise ValueError(f"Missing cached state for request {req_id}.") + return state, False - req = sched_req_state.req request_ids = req.request_ids or [req_id] if len(request_ids) != len(req.prompts): raise ValueError( @@ -292,7 +298,7 @@ def _update_states(self, scheduler_output: DiffusionSchedulerOutput) -> Diffusio prompts=req.prompts, ) self.state_cache[req_id] = state - return state + return state, True def _update_states_after(self, state: DiffusionRequestState, finished: bool) -> None: """Step-after update: clear cached state for completed request.""" @@ -313,9 +319,9 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> Runner use_hsdp = self.od_config.parallel_config.use_hsdp grad_context = torch.no_grad() if use_hsdp else torch.inference_mode() with grad_context: - state = self._update_states(scheduler_output) + state, is_new_request = self._update_states(scheduler_output) - if state.new_request: + if is_new_request: # TODO: support kv manager recv # TODO: support cache backend if state.sampling.generator is None and state.sampling.seed is not None: @@ -329,7 +335,7 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> Runner with set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config): # step0/new request: encode - if state.new_request: + if is_new_request: self.pipeline.prepare_encode(state) noise_pred = self.pipeline.denoise_step(state) diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py index 3ca12d8e6a2..1b9b9f19542 100644 --- a/vllm_omni/diffusion/worker/diffusion_worker.py +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -204,7 +204,7 @@ def execute_stepwise(self, scheduler_output: DiffusionSchedulerOutput) -> Runner # adapter first so worker-local LoRA state cannot leak in. self.lora_manager.set_active_adapter(None) - if any(req_state.req.sampling_params.lora_request is not None for req_state in scheduler_output.req_states): + if any(new_req.req.sampling_params.lora_request is not None for new_req in scheduler_output.scheduled_new_reqs): raise ValueError("Step mode does not support LoRA yet.") return self.model_runner.execute_stepwise(scheduler_output) From b9be68a057d5300ac7eaeb88eefac893c55e327b Mon Sep 17 00:00:00 2001 From: jader Date: Mon, 23 Mar 2026 06:44:24 +0000 Subject: [PATCH 14/14] fix(diffusion/executor): improve shutdown signal handling and resource cleanup Signed-off-by: jader --- .../diffusion/executor/multiproc_executor.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/vllm_omni/diffusion/executor/multiproc_executor.py b/vllm_omni/diffusion/executor/multiproc_executor.py index b2e2c2ef9d9..eacdb2b628c 100644 --- a/vllm_omni/diffusion/executor/multiproc_executor.py +++ b/vllm_omni/diffusion/executor/multiproc_executor.py @@ -34,18 +34,11 @@ def __call__(self): try: for _ in range(self.num_workers): self.broadcast_mq.enqueue(SHUTDOWN_MESSAGE) - except Exception as exc: - logger.warning("Failed to send shutdown signal: %s", exc) - for queue, label in ((self.broadcast_mq, "broadcast"), (self.result_mq, "result")): - if queue is None: - continue - try: - close_fn = getattr(queue, "close", None) - if callable(close_fn): - close_fn() + self.broadcast_mq = None + self.result_mq = None except Exception as exc: - logger.warning("Failed to close %s queue: %s", label, exc) + logger.warning("Failed to send shutdown signal: %s", exc) if self.processes: for proc in self.processes: @@ -258,4 +251,10 @@ def check_health(self) -> None: def shutdown(self) -> None: self._closed = True - self._finalizer() + try: + self._finalizer() + finally: + self._broadcast_mq = None + self._result_mq = None + self.resources = None + self._processes = []