diff --git a/docs/design/module/dit_module.md b/docs/design/module/dit_module.md index e24a75238f2..b0c7e9fc7fb 100644 --- a/docs/design/module/dit_module.md +++ b/docs/design/module/dit_module.md @@ -192,7 +192,7 @@ class _BaseScheduler(SchedulerInterface): self._waiting = deque() self._running = [] self._finished_req_ids = set() - self._max_batch_size = 1 + self.max_num_running_reqs = 1 ``` **Design Features**: @@ -201,7 +201,7 @@ class _BaseScheduler(SchedulerInterface): - **Shared cleanup logic**: Request-id registration, finish handling, and state removal are centralized instead of duplicated in each policy. -- **Current constraint**: `_max_batch_size` remains `1` because the current engine path is still synchronous request-mode execution. +- **Current constraint**: `max_num_running_reqs` remains `1` because the current engine path is still synchronous request-mode execution. #### 2.4 Current `RequestScheduler` Policy diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md index 7e325c1edc8..f0969b677f8 100644 --- a/docs/user_guide/diffusion_features.md +++ b/docs/user_guide/diffusion_features.md @@ -15,6 +15,7 @@ vLLM-Omni supports various advanced features for diffusion models: - Acceleration: **cache methods**, **parallelism methods** - Memory optimization: **cpu offloading**, **quantization** - Extensions: **LoRA inference** +- Execution modes: **step execution** ## Supported Features @@ -64,6 +65,16 @@ Extension methods add specialized capabilities to diffusion models beyond standa | **[LoRA Inference](diffusion/lora.md)** | Enables inference with Low-Rank Adaptation (LoRA) adapters weights | Reinforcement learning extensions | +### Execution Modes + +Execution modes control how the diffusion pipeline processes denoise steps. + +| Method | Description | Best For | +|--------|-------------|----------| +| **[Step Execution](diffusion/step_execution.md)** | Per-step denoise execution with mid-request abort support | Request cancellation between denoise steps, fine-grained execution control | + +**Note:** Step execution is currently supported by QwenImagePipeline only. See [Supported Models](#supported-models) for details. + ### Quantization Methods | Method | Configuration | Description | Best For | @@ -87,28 +98,28 @@ The following tables show which models support each feature: ### ImageGen -| Model | ⚑TeaCache | ⚑Cache-DiT | πŸ”€SP (Ulysses & Ring) | πŸ”€CFG-Parallel | πŸ”€Tensor-Parallel | πŸ”€HSDP | πŸ’ΎCPU Offload (Layerwise) | πŸ’ΎVAE-Patch-Parallel | πŸ’ΎQuantization | -|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:| -| **Bagel** | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | -| **FLUX.1-dev** | ❌ | βœ… | ❌ | βœ… | βœ… | βœ… | ❌ | ❌ | βœ… | -| **FLUX.2-klein** | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | βœ… | -| **FLUX.1-Kontext-dev** | ❌ | ❌ | ❌ | ❌ | βœ… | βœ… | ❌ | ❌ | ❌ | -| **FLUX.2-dev** | ❌ | ❌ | ❌ | ❌ | βœ… | βœ… | ❌ | ❌ | ❌ | -| **GLM-Image** | ❌ | ❌ | ❌ | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | -| **HunyuanImage3** | ❌ | βœ… | ❌ | ❌ | βœ… | ❌ | ❌ | βœ… | ❌ | -| **LongCat-Image** | ❌ | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | -| **LongCat-Image-Edit** | ❌ | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | -| **MammothModa2(T2I)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| **Nextstep_1(T2I)** | ❓ | ❓ | ❌ | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | -| **OmniGen2** | ❌ | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | -| **Ovis-Image** | ❌ | βœ… | ❌ | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | -| **Qwen-Image** | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | βœ… | βœ… | βœ… | -| **Qwen-Image-2512** | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | βœ… | βœ… | βœ… | -| **Qwen-Image-Edit** | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | βœ… | βœ… | ❌ | -| **Qwen-Image-Edit-2509** | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | βœ… | βœ… | ❌ | -| **Qwen-Image-Layered** | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | βœ… | βœ… | ❌ | -| **Stable-Diffusion3.5** | ❌ | βœ… | ❌ | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | -| **Z-Image** | βœ… | βœ… | βœ… | ❓ | βœ… (TP=2 only) | ❌ | ❌ | βœ… | βœ… | +| Model | ⚑TeaCache | ⚑Cache-DiT | πŸ”€SP (Ulysses & Ring) | πŸ”€CFG-Parallel | πŸ”€Tensor-Parallel | πŸ”€HSDP | πŸ’ΎCPU Offload (Layerwise) | πŸ’ΎVAE-Patch-Parallel | πŸ’ΎQuantization | πŸ”„Step Execution | +|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:| +| **Bagel** | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | +| **FLUX.1-dev** | ❌ | βœ… | ❌ | βœ… | βœ… | βœ… | ❌ | ❌ | βœ… | ❌ | +| **FLUX.2-klein** | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | βœ… | ❌ | +| **FLUX.1-Kontext-dev** | ❌ | ❌ | ❌ | ❌ | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | +| **FLUX.2-dev** | ❌ | ❌ | ❌ | ❌ | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | +| **GLM-Image** | ❌ | ❌ | ❌ | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | +| **HunyuanImage3** | ❌ | βœ… | ❌ | ❌ | βœ… | ❌ | ❌ | βœ… | ❌ | ❌ | +| **LongCat-Image** | ❌ | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | +| **LongCat-Image-Edit** | ❌ | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | +| **MammothModa2(T2I)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **Nextstep_1(T2I)** | ❓ | ❓ | ❌ | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | +| **OmniGen2** | ❌ | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **Ovis-Image** | ❌ | βœ… | ❌ | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **Qwen-Image** | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | βœ… | βœ… | βœ… | βœ… | +| **Qwen-Image-2512** | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | βœ… | βœ… | βœ… | βœ… | +| **Qwen-Image-Edit** | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | βœ… | βœ… | ❌ | ❌ | +| **Qwen-Image-Edit-2509** | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | βœ… | βœ… | ❌ | ❌ | +| **Qwen-Image-Layered** | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | βœ… | βœ… | ❌ | ❌ | +| **Stable-Diffusion3.5** | ❌ | βœ… | ❌ | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | +| **Z-Image** | βœ… | βœ… | βœ… | ❓ | βœ… (TP=2 only) | ❌ | ❌ | βœ… | βœ… | ❌ | > Notes: > 1. Nextstep_1(T2I) does not support cache acceleration methods such as TeaCache or Cache-DiT. @@ -116,19 +127,19 @@ The following tables show which models support each feature: ### VideoGen -| Model | ⚑TeaCache | ⚑Cache-DiT | πŸ”€SP (Ulysses & Ring) | πŸ”€CFG-Parallel | πŸ”€Tensor-Parallel | πŸ”€HSDP | πŸ’ΎCPU Offload (Layerwise) | πŸ’ΎVAE-Patch-Parallel | πŸ’ΎQuantization | -|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:| -| **Wan2.2** | ❌ | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | -| **LTX-2** | ❌ | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | -| **Helios** | ❌ | ❌ | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | -| **HunyuanVideo-1.5 T2V I2V** | ❌ | βœ… | ❌ | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | -| **DreamID-Omni** | ❌ | ❌ | ❌ | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | +| Model | ⚑TeaCache | ⚑Cache-DiT | πŸ”€SP (Ulysses & Ring) | πŸ”€CFG-Parallel | πŸ”€Tensor-Parallel | πŸ”€HSDP | πŸ’ΎCPU Offload (Layerwise) | πŸ’ΎVAE-Patch-Parallel | πŸ’ΎQuantization | πŸ”„Step Execution | +|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:| +| **Wan2.2** | ❌ | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | +| **LTX-2** | ❌ | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | +| **Helios** | ❌ | ❌ | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | ❌ | +| **HunyuanVideo-1.5 T2V I2V** | ❌ | βœ… | ❌ | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | +| **DreamID-Omni** | ❌ | ❌ | ❌ | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ### AudioGen -| Model | ⚑TeaCache | ⚑Cache-DiT | πŸ”€SP (Ulysses & Ring) | πŸ”€CFG-Parallel | πŸ”€Tensor-Parallel | πŸ”€HSDP | πŸ’ΎCPU Offload (Layerwise) | πŸ’ΎVAE-Patch-Parallel | πŸ’ΎQuantization | -|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:| -| **Stable-Audio-Open** | ❌ | ❌ | ❓ | ❓ | ❌ | ❌ | ❌ | ❌ | βœ… | +| Model | ⚑TeaCache | ⚑Cache-DiT | πŸ”€SP (Ulysses & Ring) | πŸ”€CFG-Parallel | πŸ”€Tensor-Parallel | πŸ”€HSDP | πŸ’ΎCPU Offload (Layerwise) | πŸ’ΎVAE-Patch-Parallel | πŸ’ΎQuantization | πŸ”„Step Execution | +|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:| +| **Stable-Audio-Open** | ❌ | ❌ | ❓ | ❓ | ❌ | ❌ | ❌ | ❌ | βœ… | ❌ | ## Feature Compatibility @@ -139,21 +150,22 @@ The following tables show which models support each feature: - ❌: No support plan - ❓: Not verified yet and Not Recommended -| | ⚑TeaCache | ⚑Cache-DiT | πŸ”€Ulysses-SP | πŸ”€Ring-Attn | πŸ”€CFG-Parallel | πŸ”€Tensor Parallel | πŸ”€HSDP | πŸ”€Expert Parallel | πŸ’ΎCPU Offloading (Layerwise) | πŸ’ΎCPU Offloading (Module-wise) | πŸ’ΎVAE Patch Parallel | πŸ’ΎFP8 Quant | πŸ”§LoRA Inference | -|---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| -| **⚑TeaCache** | | | | | | | | | | | | | | -| **⚑Cache-DiT** | ❌ | | | | | | | | | | | | | -| **πŸ”€Ulysses-SP** | βœ… | βœ… | | | | | | | | | | | | -| **πŸ”€Ring-Attn** | βœ… | βœ… | βœ… | | | | | | | | | | | -| **πŸ”€CFG-Parallel** | βœ… | βœ… | βœ… | βœ… | | | | | | | | | | -| **πŸ”€Tensor Parallel** | βœ… | βœ… | βœ… | βœ… | βœ… | | | | | | | | | -| **πŸ”€HSDP** | ❓ | ❓ | ❓ | ❓ | ❓ | ❌ | | | | | | | | -| **πŸ”€Expert Parallel** | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | | | | | | | -| **πŸ’ΎCPU Offloading (Layerwise)** | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | | | | | -| **πŸ’ΎCPU Offloading (Module-wise)** | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | ❓ | ❓ | ❌ | | | | | -| **πŸ’ΎVAE Patch Parallel** | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | | | | -| **πŸ’ΎFP8 Quant** | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | ❓ | ❓ | βœ… | βœ… | βœ… | | | -| **πŸ”§LoRA Inference** | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | | +| | ⚑TeaCache | ⚑Cache-DiT | πŸ”€Ulysses-SP | πŸ”€Ring-Attn | πŸ”€CFG-Parallel | πŸ”€Tensor Parallel | πŸ”€HSDP | πŸ”€Expert Parallel | πŸ’ΎCPU Offloading (Layerwise) | πŸ’ΎCPU Offloading (Module-wise) | πŸ’ΎVAE Patch Parallel | πŸ’ΎFP8 Quant | πŸ”§LoRA Inference | πŸ”„Step Execution | +|---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| +| **⚑TeaCache** | | | | | | | | | | | | | | | +| **⚑Cache-DiT** | ❌ | | | | | | | | | | | | | | +| **πŸ”€Ulysses-SP** | βœ… | βœ… | | | | | | | | | | | | | +| **πŸ”€Ring-Attn** | βœ… | βœ… | βœ… | | | | | | | | | | | | +| **πŸ”€CFG-Parallel** | βœ… | βœ… | βœ… | βœ… | | | | | | | | | | | +| **πŸ”€Tensor Parallel** | βœ… | βœ… | βœ… | βœ… | βœ… | | | | | | | | | | +| **πŸ”€HSDP** | ❓ | ❓ | ❓ | ❓ | ❓ | ❌ | | | | | | | | | +| **πŸ”€Expert Parallel** | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | | | | | | | | +| **πŸ’ΎCPU Offloading (Layerwise)** | βœ… | βœ… | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | | | | | | +| **πŸ’ΎCPU Offloading (Module-wise)** | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | ❓ | ❓ | ❌ | | | | | | +| **πŸ’ΎVAE Patch Parallel** | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | ❌ | ❌ | | | | | +| **πŸ’ΎFP8 Quant** | βœ… | βœ… | βœ… | βœ… | βœ… | βœ… | ❓ | ❓ | βœ… | βœ… | βœ… | | | | +| **πŸ”§LoRA Inference** | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | | | +| **πŸ”„Step Execution** | ❌ | ❌ | βœ… | βœ… | βœ… | βœ… | ❓ | ❓ | βœ… | ❓ | βœ… | βœ… | ❌ | | !!! info @@ -162,6 +174,7 @@ The following tables show which models support each feature: 3. CPU Offloading (Layerwise) and CPU Offloading (Module-wise) are not compatible. 4. CPU Offloading (Layerwise) supports single-card for now. 5. Using FP8-Quant as an example of qunatization methods. + 6. Step Execution is not compatible with cache backends (TeaCache, Cache-DiT) or LoRA. ## Learn More @@ -185,6 +198,10 @@ The following tables show which models support each feature: - **[LoRA Inference Guide](diffusion/lora.md)** - Low-Rank Adaptation for style customization and fine-tuning +**Execution Modes:** + +- **[Step Execution Guide](diffusion/step_execution.md)** - Per-step denoise execution with mid-request abort support + **Advanced Topics:** - **[Feature Compatibility](feature_compatibility.md)** - How to combine multiple features for maximum performance diff --git a/tests/diffusion/test_diffusion_scheduler.py b/tests/diffusion/test_diffusion_scheduler.py index 171a6278cd9..4324ba1e630 100644 --- a/tests/diffusion/test_diffusion_scheduler.py +++ b/tests/diffusion/test_diffusion_scheduler.py @@ -1,12 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import queue import threading +from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +import torch -from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.data import DiffusionOutput, DiffusionRequestAbortedError from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.sched import ( @@ -14,6 +17,7 @@ RequestScheduler, Scheduler, SchedulerInterface, + StepScheduler, ) from vllm_omni.diffusion.sched.interface import CachedRequestData, NewRequestData from vllm_omni.inputs.data import OmniDiffusionSamplingParams @@ -29,9 +33,46 @@ def _make_request(req_id: str) -> OmniDiffusionRequest: ) -def _make_request_output(req_id: str, *, error: str | None = None) -> DiffusionOutput: - del req_id - return DiffusionOutput(output=None, error=error) +def _make_request_output(req_id: str, *, error: str | None = None, finished: bool = True): + return SimpleNamespace( + req_id=req_id, + step_index=None, + finished=finished, + result=DiffusionOutput(output=None, error=error), + ) + + +def _make_step_output( + req_id: str, + step_index: int, + *, + finished: bool = False, + error: str | None = None, +): + return SimpleNamespace( + req_id=req_id, + step_index=step_index, + finished=finished, + result=DiffusionOutput(output=None, error=error) if error is not None else None, + ) + + +def _make_step_request( + req_id: str, + *, + num_inference_steps: int = 4, + step_index: int | None = None, + sampling_params: OmniDiffusionSamplingParams | None = None, +) -> OmniDiffusionRequest: + return OmniDiffusionRequest( + prompts=[f"prompt_{req_id}"], + sampling_params=sampling_params + or OmniDiffusionSamplingParams( + num_inference_steps=num_inference_steps, + step_index=step_index, + ), + request_ids=[req_id], + ) def _new_ids(sched_output) -> list[str]: @@ -43,7 +84,7 @@ def _cached_ids(sched_output) -> list[str]: class _StubScheduler(SchedulerInterface): - def __init__(self, request: OmniDiffusionRequest, output: DiffusionOutput) -> None: + def __init__(self, request: OmniDiffusionRequest, output) -> None: self._request = request self._output = output self.initialized_with = None @@ -75,9 +116,10 @@ def schedule(self): is_empty=False, ) - def update_from_output(self, sched_output, output: DiffusionOutput) -> set[str]: + def update_from_output(self, sched_output, output) -> set[str]: del sched_output assert output is self._output + self._state.status = DiffusionRequestStatus.FINISHED_COMPLETED return {self._sched_req_id} def has_requests(self) -> bool: @@ -185,9 +227,14 @@ def test_abort_request_for_waiting_and_running(self) -> None: state_b = self.scheduler.get_request_state(req_id_b) assert state_b.status == DiffusionRequestStatus.FINISHED_ABORTED + first = self.scheduler.schedule() + assert first.finished_req_ids == {req_id_b} # A should still run normally. - output_a = self.scheduler.schedule() - assert _new_ids(output_a) == [req_id_a] + assert _new_ids(first) == [req_id_a] + + # B is already marked finished aborted, scheduling again should not pull it. + second = self.scheduler.schedule() + assert second.finished_req_ids == set() # Abort running request. self.scheduler.finish_requests(req_id_a, DiffusionRequestStatus.FINISHED_ABORTED) @@ -233,33 +280,33 @@ def test_add_req_and_wait_for_response_single_path(self) -> None: engine = DiffusionEngine.__new__(DiffusionEngine) engine.scheduler = RequestScheduler() engine.scheduler.initialize(Mock()) - engine.executor = Mock() - engine._rpc_lock = threading.Lock() + engine._rpc_lock = threading.RLock() + engine.abort_queue = queue.Queue() request = _make_request("engine") - expected = DiffusionOutput(output=None) - engine.executor.add_req.return_value = expected + runner_output = _make_request_output("engine") + engine.execute_fn = Mock(return_value=runner_output) output = engine.add_req_and_wait_for_response(request) - assert output is expected - engine.executor.add_req.assert_called_once_with(request) + assert output is runner_output.result + engine.execute_fn.assert_called_once() def test_supports_scheduler_interface_injection(self) -> None: request = _make_request("engine_iface") - expected = DiffusionOutput(output=None) - scheduler = _StubScheduler(request, expected) + runner_output = _make_request_output("engine_iface") + scheduler = _StubScheduler(request, runner_output) engine = DiffusionEngine.__new__(DiffusionEngine) engine.scheduler = scheduler - engine.executor = Mock() - engine.executor.add_req = Mock(return_value=expected) - engine._rpc_lock = threading.Lock() + engine._rpc_lock = threading.RLock() + engine.abort_queue = queue.Queue() + engine.execute_fn = Mock(return_value=runner_output) output = engine.add_req_and_wait_for_response(request) - assert output is expected - engine.executor.add_req.assert_called_once_with(request) + assert output is runner_output.result + engine.execute_fn.assert_called_once() def test_initializes_injected_scheduler(self) -> None: request = _make_request("init") @@ -289,6 +336,59 @@ def test_scheduler_alias_keeps_default_request_scheduler(self) -> None: assert req_id in finished assert scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED + def test_step_raises_aborted_error(self) -> None: + engine = DiffusionEngine.__new__(DiffusionEngine) + engine.pre_process_func = None + engine.add_req_and_wait_for_response = Mock( + return_value=DiffusionOutput(aborted=True, abort_message="Request req-abort aborted.") + ) + + with pytest.raises(DiffusionRequestAbortedError, match="Request req-abort aborted"): + engine.step(_make_request("req-abort")) + + def test_abort_queue_marks_request_finished_aborted(self) -> None: + engine = DiffusionEngine.__new__(DiffusionEngine) + engine.scheduler = RequestScheduler() + engine.scheduler.initialize(Mock()) + engine.abort_queue = queue.Queue() + + req_id = engine.scheduler.add_request(_make_request("req-abort")) + engine.abort("req-abort") + engine._process_aborts_queue() + + assert engine.scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_ABORTED + + def test_finalize_finished_request_returns_aborted_output(self) -> None: + engine = DiffusionEngine.__new__(DiffusionEngine) + engine.scheduler = RequestScheduler() + engine.scheduler.initialize(Mock()) + + req_id = engine.scheduler.add_request(_make_request("req-finalize")) + engine.scheduler.finish_requests(req_id, DiffusionRequestStatus.FINISHED_ABORTED) + + output = engine._finalize_finished_request(req_id) + + assert output.aborted is True + assert output.abort_message == "Request req-finalize aborted." + + def test_initializes_step_scheduler_when_step_execution_enabled(self) -> None: + od_config = Mock(model_class_name="mock_model") + od_config.step_execution = True + fake_executor = Mock() + fake_executor_cls = Mock(return_value=fake_executor) + + with ( + patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_post_process_func", return_value=None), + patch("vllm_omni.diffusion.diffusion_engine.get_diffusion_pre_process_func", return_value=None), + patch("vllm_omni.diffusion.diffusion_engine.DiffusionExecutor.get_class", return_value=fake_executor_cls), + patch.object(DiffusionEngine, "_dummy_run", return_value=None), + ): + engine = DiffusionEngine(od_config) + + assert isinstance(engine.scheduler, StepScheduler) + assert engine.execute_fn is fake_executor.execute_step + fake_executor_cls.assert_called_once_with(od_config) + def test_dummy_run_raises_on_output_error(self) -> None: engine = DiffusionEngine.__new__(DiffusionEngine) engine.od_config = Mock(model_class_name="mock_model") @@ -297,3 +397,240 @@ def test_dummy_run_raises_on_output_error(self) -> None: with pytest.raises(RuntimeError, match="Dummy run failed: boom"): engine._dummy_run() + + +class TestStepScheduler: + def setup_method(self) -> None: + self.scheduler: StepScheduler = StepScheduler() + self.scheduler.initialize(Mock()) + + def test_single_request_step_lifecycle(self) -> None: + request = _make_step_request("step", num_inference_steps=3) + req_id = self.scheduler.add_request(request) + + first = self.scheduler.schedule() + assert _new_ids(first) == [req_id] + assert _cached_ids(first) == [] + assert first.num_running_reqs == 1 + assert first.num_waiting_reqs == 0 + + finished = self.scheduler.update_from_output(first, _make_step_output(req_id, step_index=1)) + assert finished == set() + assert self.scheduler.get_request_state(req_id).status == DiffusionRequestStatus.RUNNING + assert request.sampling_params.step_index == 1 + assert self.scheduler.has_requests() is True + + second = self.scheduler.schedule() + assert _new_ids(second) == [] + assert _cached_ids(second) == [req_id] + assert second.num_running_reqs == 1 + assert second.num_waiting_reqs == 0 + + finished = self.scheduler.update_from_output(second, _make_step_output(req_id, step_index=2)) + assert finished == set() + assert request.sampling_params.step_index == 2 + + third = self.scheduler.schedule() + assert _new_ids(third) == [] + assert _cached_ids(third) == [req_id] + + finished = self.scheduler.update_from_output( + third, + _make_step_output(req_id, step_index=3, finished=True), + ) + assert finished == {req_id} + assert self.scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED + assert request.sampling_params.step_index == 3 + assert self.scheduler.has_requests() is False + + def test_fifo_single_request_scheduling(self) -> None: + req_id_a = self.scheduler.add_request(_make_step_request("a", num_inference_steps=2)) + req_id_b = self.scheduler.add_request(_make_step_request("b", num_inference_steps=2)) + + first = self.scheduler.schedule() + assert _new_ids(first) == [req_id_a] + assert _cached_ids(first) == [] + assert first.num_running_reqs == 1 + assert first.num_waiting_reqs == 1 + + finished = self.scheduler.update_from_output(first, _make_step_output(req_id_a, step_index=1)) + assert finished == set() + + second = self.scheduler.schedule() + assert _new_ids(second) == [] + assert _cached_ids(second) == [req_id_a] + assert second.num_running_reqs == 1 + assert second.num_waiting_reqs == 1 + + finished = self.scheduler.update_from_output( + second, + _make_step_output(req_id_a, step_index=2, finished=True), + ) + assert finished == {req_id_a} + + third = self.scheduler.schedule() + assert _new_ids(third) == [req_id_b] + assert _cached_ids(third) == [] + assert third.num_running_reqs == 1 + assert third.num_waiting_reqs == 0 + + def test_error_output_marks_finished_error(self) -> None: + req_id = self.scheduler.add_request(_make_step_request("err", num_inference_steps=3)) + + sched_output = self.scheduler.schedule() + assert _new_ids(sched_output) == [req_id] + finished = self.scheduler.update_from_output( + sched_output, + _make_step_output(req_id, step_index=1, finished=True, error="worker failed"), + ) + + assert finished == {req_id} + state = self.scheduler.get_request_state(req_id) + assert state.status == DiffusionRequestStatus.FINISHED_ERROR + assert state.error == "worker failed" + assert self.scheduler.has_requests() is False + + def test_missing_step_index_marks_finished_error(self) -> None: + req_id = self.scheduler.add_request(_make_step_request("missing", num_inference_steps=3)) + + sched_output = self.scheduler.schedule() + finished = self.scheduler.update_from_output( + sched_output, + SimpleNamespace( + req_id=req_id, + step_index=None, + finished=True, + result=None, + ), + ) + + assert finished == {req_id} + state = self.scheduler.get_request_state(req_id) + assert state.status == DiffusionRequestStatus.FINISHED_ERROR + assert state.error == "Missing step_index in RunnerOutput" + + def test_abort_request_for_waiting_and_running(self) -> None: + req_id_a = self.scheduler.add_request(_make_step_request("a", num_inference_steps=2)) + req_id_b = self.scheduler.add_request(_make_step_request("b", num_inference_steps=2)) + + self.scheduler.finish_requests(req_id_b, DiffusionRequestStatus.FINISHED_ABORTED) + assert self.scheduler.get_request_state(req_id_b).status == DiffusionRequestStatus.FINISHED_ABORTED + + running = self.scheduler.schedule() + assert _new_ids(running) == [req_id_a] + + self.scheduler.finish_requests(req_id_a, DiffusionRequestStatus.FINISHED_ABORTED) + assert self.scheduler.get_request_state(req_id_a).status == DiffusionRequestStatus.FINISHED_ABORTED + assert self.scheduler.has_requests() is False + + def test_has_requests_state_transition(self) -> None: + assert self.scheduler.has_requests() is False + + req_id = self.scheduler.add_request(_make_step_request("has", num_inference_steps=2)) + assert self.scheduler.has_requests() is True + + sched_output = self.scheduler.schedule() + assert self.scheduler.has_requests() is True + + finished = self.scheduler.update_from_output( + sched_output, + _make_step_output(req_id, step_index=2, finished=True), + ) + assert finished == {req_id} + assert self.scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED + assert self.scheduler.has_requests() is False + + def test_scheduled_request_aborted_before_update_is_returned_finished(self) -> None: + req_id = self.scheduler.add_request(_make_step_request("abort-late", num_inference_steps=2)) + + sched_output = self.scheduler.schedule() + self.scheduler.finish_requests(req_id, DiffusionRequestStatus.FINISHED_ABORTED) + + finished = self.scheduler.update_from_output( + sched_output, + _make_step_output(req_id, step_index=1), + ) + assert finished == {req_id} + assert self.scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_ABORTED + + def test_preempt_request_preserves_step_index(self) -> None: + request = _make_step_request("preempt", num_inference_steps=3) + req_id = self.scheduler.add_request(request) + + first = self.scheduler.schedule() + assert self.scheduler.update_from_output(first, _make_step_output(req_id, step_index=1)) == set() + assert request.sampling_params.step_index == 1 + + second = self.scheduler.schedule() + assert _cached_ids(second) == [req_id] + assert self.scheduler.preempt_request(req_id) is True + assert self.scheduler.get_request_state(req_id).status == DiffusionRequestStatus.PREEMPTED + assert request.sampling_params.step_index == 1 + + third = self.scheduler.schedule() + assert _cached_ids(third) == [req_id] + assert request.sampling_params.step_index == 1 + + @pytest.mark.parametrize( + ("sampling_params", "expected_steps"), + [ + ( + OmniDiffusionSamplingParams( + timesteps=torch.tensor([1.0, 0.5, 0.0]), + sigmas=[1.0, 0.5, 0.25, 0.0], + num_inference_steps=5, + ), + 3, + ), + ( + OmniDiffusionSamplingParams( + sigmas=[1.0, 0.5], + num_inference_steps=5, + ), + 2, + ), + ( + OmniDiffusionSamplingParams( + num_inference_steps=4, + ), + 4, + ), + ], + ) + def test_total_steps_priority(self, sampling_params: OmniDiffusionSamplingParams, expected_steps: int) -> None: + request = _make_step_request("priority", sampling_params=sampling_params) + req_id = self.scheduler.add_request(request) + + for _ in range(expected_steps - 1): + sched_output = self.scheduler.schedule() + assert sched_output.scheduled_req_ids == [req_id] + next_step = request.sampling_params.step_index + 1 + assert ( + self.scheduler.update_from_output( + sched_output, + _make_step_output(req_id, step_index=next_step), + ) + == set() + ) + + final_output = self.scheduler.schedule() + assert final_output.scheduled_req_ids == [req_id] + assert self.scheduler.update_from_output( + final_output, + _make_step_output(req_id, step_index=expected_steps, finished=True), + ) == {req_id} + assert self.scheduler.get_request_state(req_id).status == DiffusionRequestStatus.FINISHED_COMPLETED + + @pytest.mark.parametrize( + "sampling_params", + [ + OmniDiffusionSamplingParams(num_inference_steps=0), + OmniDiffusionSamplingParams(num_inference_steps=3, step_index=3), + OmniDiffusionSamplingParams(num_inference_steps=3, step_index=-1), + ], + ) + def test_rejects_invalid_initial_step_state(self, sampling_params: OmniDiffusionSamplingParams) -> None: + request = _make_step_request("invalid", sampling_params=sampling_params) + + with pytest.raises(ValueError): + self.scheduler.add_request(request) diff --git a/tests/diffusion/test_diffusion_step_pipeline.py b/tests/diffusion/test_diffusion_step_pipeline.py index ad08487fe9c..68aba9ba3bf 100644 --- a/tests/diffusion/test_diffusion_step_pipeline.py +++ b/tests/diffusion/test_diffusion_step_pipeline.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for step-level diffusion runner and worker execution.""" +"""Tests for step-level diffusion execution across runner / worker / executor / engine.""" import os +import queue +import threading from contextlib import contextmanager from types import SimpleNamespace +from unittest.mock import Mock import pytest import torch @@ -12,6 +15,7 @@ import vllm_omni.diffusion.worker.diffusion_model_runner as model_runner_module from tests.utils import hardware_test from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.comm import RingComm, SeqAllToAll4D from vllm_omni.diffusion.distributed.parallel_state import ( @@ -20,10 +24,13 @@ init_distributed_environment, initialize_model_parallel, ) +from vllm_omni.diffusion.executor.multiproc_executor import MultiprocDiffusionExecutor from vllm_omni.diffusion.ipc import ( pack_diffusion_output_shm, unpack_diffusion_output_shm, ) +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.sched import StepScheduler from vllm_omni.diffusion.sched.interface import ( CachedRequestData, DiffusionSchedulerOutput, @@ -32,6 +39,8 @@ from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker from vllm_omni.diffusion.worker.utils import RunnerOutput +from vllm_omni.engine.async_omni_engine import AsyncOmniEngine +from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.platforms import current_omni_platform pytestmark = [pytest.mark.core_model, pytest.mark.diffusion] @@ -86,6 +95,23 @@ def post_decode(self, state, **kwargs): return DiffusionOutput(output=torch.tensor([state.step_index], dtype=torch.float32)) +class _InterruptingStepPipeline(_StepPipeline): + interrupt = True + + def denoise_step(self, state, **kwargs): + del state, kwargs + self.denoise_calls += 1 + return None + + def step_scheduler(self, state, noise_pred, **kwargs): + del state, noise_pred, kwargs + raise AssertionError("step_scheduler should not run after interrupt") + + def post_decode(self, state, **kwargs): + del state, kwargs + raise AssertionError("post_decode should not run after interrupt") + + class _IdentityNoiseTransformer(torch.nn.Module): def forward(self, x: torch.Tensor, **kwargs): del kwargs @@ -188,6 +214,21 @@ def _make_step_request(num_inference_steps: int = 2): ) +def _assert_aborted_output(output: DiffusionOutput, request_id: str) -> None: + assert output.output is None + assert output.error is None + assert output.aborted is True + assert output.abort_message == f"Request {request_id} aborted." + + +def _make_engine_request(req_id: str = "req-1", num_inference_steps: int = 2) -> OmniDiffusionRequest: + return OmniDiffusionRequest( + prompts=[f"prompt-{req_id}"], + sampling_params=OmniDiffusionSamplingParams(num_inference_steps=num_inference_steps), + request_ids=[req_id], + ) + + def _make_runner(): runner = object.__new__(DiffusionModelRunner) runner.vllm_config = object() @@ -242,6 +283,18 @@ def _make_cached_scheduler_output(sched_req_id="req-1", step_id=1, finished_req_ ) +def _make_engine(scheduler, execute_fn=None) -> DiffusionEngine: + engine = object.__new__(DiffusionEngine) + engine.od_config = SimpleNamespace(model_class_name="QwenImagePipeline") + engine.pre_process_func = None + engine.post_process_func = None + engine.scheduler = scheduler + engine.execute_fn = execute_fn + engine._rpc_lock = threading.RLock() + engine.abort_queue = queue.Queue() + return engine + + def _expected_output_for_mode(mode: str) -> torch.Tensor: if mode == "cfg": return torch.tensor([[3.0]]) @@ -322,6 +375,52 @@ def test_completes_request_and_clears_state(self, monkeypatch): assert runner.pipeline.scheduler_calls == 2 assert runner.pipeline.decode_calls == 1 + def test_rejects_multi_request_step_batch(self): + runner = _make_runner() + req_1 = _make_step_request() + req_2 = _make_step_request() + req_2.request_ids = ["req-2"] + + scheduler_output = DiffusionSchedulerOutput( + step_id=0, + scheduled_new_reqs=[ + NewRequestData(sched_req_id="req-1", req=req_1), + NewRequestData(sched_req_id="req-2", req=req_2), + ], + scheduled_cached_reqs=CachedRequestData.make_empty(), + finished_req_ids=set(), + num_running_reqs=2, + num_waiting_reqs=0, + ) + + with pytest.raises(ValueError, match="batch_size=1"): + DiffusionModelRunner.execute_stepwise(runner, scheduler_output) + + def test_rejects_missing_cached_state(self): + runner = _make_runner() + + with pytest.raises(ValueError, match="Missing cached state"): + DiffusionModelRunner.execute_stepwise(runner, _make_cached_scheduler_output(sched_req_id="req-missing")) + + def test_interrupt_marks_request_finished_and_clears_state(self, monkeypatch): + runner = _make_runner() + runner.pipeline = _InterruptingStepPipeline() + req = _make_step_request() + monkeypatch.setattr(model_runner_module, "set_forward_context", _noop_forward_context) + + output = DiffusionModelRunner.execute_stepwise(runner, _make_scheduler_output(req, step_id=0)) + + assert output.req_id == "req-1" + assert output.step_index == 0 + assert output.finished is True + assert output.result is not None + assert output.result.error == "stepwise denoise interrupted" + assert "req-1" not in runner.state_cache + assert runner.pipeline.prepare_calls == 1 + assert runner.pipeline.denoise_calls == 1 + assert runner.pipeline.scheduler_calls == 0 + assert runner.pipeline.decode_calls == 0 + def test_load_model_rejects_unsupported_step_execution(self, monkeypatch): class _RequestOnlyPipeline: pass @@ -439,6 +538,153 @@ def test_rejects_lora_requests_in_step_mode(self): DiffusionWorker.execute_stepwise(worker, scheduler_output) +@pytest.mark.cpu +class TestExecutor: + """MultiprocDiffusionExecutor.execute_step""" + + def test_execute_step_passes_through_runner_output(self): + executor = object.__new__(MultiprocDiffusionExecutor) + executor._ensure_open = lambda: None + expected = RunnerOutput(req_id="req-step", step_index=1, finished=False, result=None) + executor.collective_rpc = Mock(return_value=expected) + + request = _make_engine_request("req-step", num_inference_steps=2) + scheduler_output = _make_scheduler_output(request, sched_req_id="req-step") + + output = MultiprocDiffusionExecutor.execute_step(executor, scheduler_output) + + assert output is expected + + +@pytest.mark.cpu +class TestEngine: + """Step-execution paths in DiffusionEngine.add_req_and_wait_for_response""" + + @pytest.mark.parametrize( + ("execute_fn", "expected_error"), + [ + ( + lambda _: RunnerOutput( + req_id="req-error", + step_index=1, + finished=True, + result=DiffusionOutput(error="boom"), + ), + "boom", + ), + ( + lambda _: (_ for _ in ()).throw(RuntimeError("gpu on fire")), + "gpu on fire", + ), + ], + ) + def test_step_engine_returns_error(self, execute_fn, expected_error): + scheduler = StepScheduler() + scheduler.initialize(Mock()) + engine = _make_engine(scheduler, execute_fn=execute_fn) + + output = engine.add_req_and_wait_for_response(_make_engine_request("req-error", num_inference_steps=2)) + + assert output.output is None + assert expected_error in output.error + + def test_step_execution_completes(self): + scheduler = StepScheduler() + scheduler.initialize(Mock()) + engine = _make_engine(scheduler) + request = _make_engine_request("req-step", num_inference_steps=2) + + call_count = {"n": 0} + + def execute_fn(_): + call_count["n"] += 1 + finished = call_count["n"] == 2 + return RunnerOutput( + req_id="req-step", + step_index=call_count["n"], + finished=finished, + result=(DiffusionOutput(output=torch.tensor([2.0])) if finished else None), + ) + + engine.execute_fn = execute_fn + + output = engine.add_req_and_wait_for_response(request) + + assert call_count["n"] == 2 + assert output.error is None + assert torch.equal(output.output, torch.tensor([2.0])) + + def test_step_abort_stops_rescheduling_after_first_step(self): + scheduler = StepScheduler() + scheduler.initialize(Mock()) + engine = _make_engine(scheduler) + request = _make_engine_request("req-stop", num_inference_steps=4) + + step = {"n": 0} + + def execute_fn(_): + step["n"] += 1 + engine.abort("req-stop") + return RunnerOutput( + req_id="req-stop", + step_index=1, + finished=False, + result=None, + ) + + engine.execute_fn = execute_fn + + output = engine.add_req_and_wait_for_response(request) + + assert step["n"] == 1 + _assert_aborted_output(output, "req-stop") + + def test_step_abort_after_reschedule_returns_aborted_output(self): + scheduler = StepScheduler() + scheduler.initialize(Mock()) + engine = _make_engine(scheduler) + request = _make_engine_request("req-mid", num_inference_steps=4) + + step = {"n": 0} + + def execute_fn(sched_output): + step["n"] += 1 + if step["n"] == 2: + assert sched_output == _make_cached_scheduler_output("req-mid", step_id=1) + engine.abort("req-mid") + return RunnerOutput( + req_id="req-mid", + step_index=step["n"], + finished=False, + result=None, + ) + + engine.execute_fn = execute_fn + + output = engine.add_req_and_wait_for_response(request) + + assert step["n"] == 2 + _assert_aborted_output(output, "req-mid") + + def test_finished_step_without_result_returns_error(self): + scheduler = StepScheduler() + scheduler.initialize(Mock()) + engine = _make_engine( + scheduler, + execute_fn=lambda _: RunnerOutput( + req_id="req-missing", + step_index=1, + finished=True, + result=None, + ), + ) + + output = engine.add_req_and_wait_for_response(_make_engine_request("req-missing", num_inference_steps=1)) + + assert output.output is None + assert output.error == "Diffusion execution finished without a final output." + + @pytest.mark.cpu class TestIPC: def test_pack_unpack_runner_output_shm(self): @@ -458,6 +704,15 @@ def test_pack_unpack_runner_output_shm(self): class TestSupportedPipelines: """Step-execution protocol checks for supported pipelines.""" + def test_default_stage_config_includes_step_execution(self): + stage_cfg = AsyncOmniEngine._create_default_diffusion_stage_cfg( + { + "step_execution": True, + } + )[0] + + assert stage_cfg["engine_args"]["step_execution"] is True + def test_qwen_image_supports_step_execution(self): from vllm_omni.diffusion.models.interface import SupportsStepExecution, supports_step_execution from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image import QwenImagePipeline diff --git a/tests/diffusion/test_multiproc_engine_concurrency.py b/tests/diffusion/test_multiproc_engine_concurrency.py index adb8dc338c6..517f98ddaa9 100644 --- a/tests/diffusion/test_multiproc_engine_concurrency.py +++ b/tests/diffusion/test_multiproc_engine_concurrency.py @@ -66,7 +66,9 @@ def _make_engine(num_gpus: int = 1): sched.initialize(Mock()) engine.scheduler = sched engine.executor = executor - engine._rpc_lock = threading.Lock() + engine._rpc_lock = threading.RLock() + engine.abort_queue = queue.Queue() + engine.execute_fn = executor.execute_request return engine, executor, req_q, res_q @@ -80,7 +82,7 @@ def _run(): req = req_q.get(timeout=10) method = req.get("method", "") args = req.get("args", ()) - if method == "generate" and args and hasattr(args[0], "request_ids"): + if method in {"generate", "execute_model"} and args and hasattr(args[0], "request_ids"): tag = f"result_for_{args[0].request_ids[0]}" elif args: tag = f"result_for_{args[0]}" @@ -116,11 +118,11 @@ def _controlled(item): return a_enqueued, b_complete -# ──────────────────── bug-reproduction: concurrent add_req ──────────────── +# ───────────────── concurrent request execution ───────────────── -class TestConcurrentAddReqBug: - """Two concurrent ``add_req_and_wait_for_response()`` calls swap results.""" +class TestConcurrentRequestExecution: + """Concurrent request execution should not swap results.""" def test_results_are_correctly_routed(self): engine, executor, req_q, res_q = _make_engine() @@ -151,11 +153,11 @@ def _b(): assert results["B"].error == "result_for_B" -# ──────────────── bug-reproduction: concurrent collective_rpc ───────────── +# ───────────────── concurrent collective RPC ───────────────── -class TestConcurrentCollectiveRpcBug: - """Two concurrent ``collective_rpc()`` calls swap results.""" +class TestConcurrentCollectiveRpc: + """Concurrent ``collective_rpc()`` calls should not swap results.""" def test_results_are_correctly_routed(self): engine, executor, req_q, res_q = _make_engine() @@ -192,11 +194,11 @@ def _b(): assert results["B"].error == "result_for_call_B" -# ──────── bug-reproduction: add_req vs collective_rpc concurrently ──────── +# ──────────── concurrent request execution and collective RPC ──────────── -class TestConcurrentAddReqVsCollectiveRpcBug: - """``add_req`` and ``collective_rpc`` running concurrently swap results.""" +class TestConcurrentRequestExecutionAndCollectiveRpc: + """Request execution and ``collective_rpc()`` should not swap results.""" def test_results_are_correctly_routed(self): engine, executor, req_q, res_q = _make_engine() @@ -205,7 +207,7 @@ def test_results_are_correctly_routed(self): results: dict[str, object] = {} - def _a(): # add_req path + def _a(): # request execution path results["A"] = engine.add_req_and_wait_for_response(_mock_request("A")) def _b(): # collective_rpc path @@ -230,10 +232,10 @@ def _b(): # collective_rpc path assert results["B"].error == "result_for_call_B" -# ─────────────── backward-compatibility (serial) tests ──────────────────── +# ─────────────────────── serial operation coverage ─────────────────────── -class TestSerialOperations: +class TestSerialEngineOperations: """Verify correct behaviour for single-threaded (serial) usage. These tests must pass both **before** and **after** any concurrency fix @@ -385,18 +387,18 @@ def _hanging_dequeue(timeout=None): executor._result_mq.dequeue = _hanging_dequeue - # Thread running add_req β€” acquires the lock, enqueues, then + # Thread running request execution β€” acquires the lock, enqueues, then # blocks on dequeue forever (worker hang). - def _stalled_add_req(): + def _stalled_request_execution(): try: engine.add_req_and_wait_for_response(_mock_request("stalled")) except Exception: pass - t = threading.Thread(target=_stalled_add_req, daemon=True) + t = threading.Thread(target=_stalled_request_execution, daemon=True) t.start() - # Wait until add_req is truly inside the lock and blocking. + # Wait until request execution is truly inside the lock and blocking. add_req_blocked.wait(5) # collective_rpc should time out at lock acquisition, not hang. diff --git a/tests/e2e/online_serving/test_qwen_image_expansion.py b/tests/e2e/online_serving/test_qwen_image_expansion.py index e5bcde417e3..6d6d236016b 100644 --- a/tests/e2e/online_serving/test_qwen_image_expansion.py +++ b/tests/e2e/online_serving/test_qwen_image_expansion.py @@ -28,6 +28,11 @@ def _get_diffusion_feature_cases(model: str): return [ + pytest.param( + OmniServerParams(model=model, server_args=["--step-execution"]), + id="step_execution", + marks=SINGLE_CARD_FEATURE_MARKS, + ), pytest.param( OmniServerParams(model=model, server_args=["--cache-backend", "tea_cache"]), id="cache_tea_cache", diff --git a/tests/entrypoints/test_async_omni_abort.py b/tests/entrypoints/test_async_omni_abort.py new file mode 100644 index 00000000000..71f3e99feb4 --- /dev/null +++ b/tests/entrypoints/test_async_omni_abort.py @@ -0,0 +1,85 @@ +import asyncio +from types import SimpleNamespace + +import pytest + +from vllm_omni.entrypoints.async_omni import AsyncOmni + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def test_generate_accepts_request_after_repeated_cancellations(): + async def run_test(): + submitted_request_ids = [] + aborted_request_batches = [] + + async def fake_add_request_async(*, request_id, prompt, sampling_params_list, final_stage_id): + del prompt, sampling_params_list, final_stage_id + submitted_request_ids.append(request_id) + + async def fake_abort_async(request_ids): + aborted_request_batches.append(list(request_ids)) + + async def fake_process_results(request_id, metrics, final_stage_id_for_e2e, req_start_ts, wall_start_ts): + del metrics, final_stage_id_for_e2e, req_start_ts, wall_start_ts + if request_id.startswith("cancel-"): + await asyncio.Future() + return + yield SimpleNamespace( + stage_id=0, + request_output=SimpleNamespace(outputs=[]), + finished=True, + ) + + async def collect_outputs(request_id): + outputs = [] + async for output in AsyncOmni.generate( + omni, + prompt={"prompt": "prompt"}, + request_id=request_id, + sampling_params_list=[SimpleNamespace()], + output_modalities=["image"], + ): + outputs.append(output) + return outputs + + omni = object.__new__(AsyncOmni) + omni._pause_cond = asyncio.Condition() + omni._paused = False + omni.engine = SimpleNamespace( + num_stages=1, + add_request_async=fake_add_request_async, + abort_async=fake_abort_async, + ) + omni.log_stats = False + omni.request_states = {} + omni._final_output_handler = lambda: None + omni.resolve_sampling_params_list = lambda params: params + omni._compute_final_stage_id = lambda output_modalities: 0 + omni._process_orchestrator_results = fake_process_results + omni._log_summary_and_cleanup = lambda request_id: omni.request_states.pop(request_id, None) + + assert len(await collect_outputs("baseline")) == 1 + + for idx in range(3): + task = asyncio.create_task(collect_outputs(f"cancel-{idx}")) + await asyncio.sleep(0) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + assert len(await collect_outputs("after-cancel")) == 1 + assert submitted_request_ids == [ + "baseline", + "cancel-0", + "cancel-1", + "cancel-2", + "after-cancel", + ] + assert aborted_request_batches == [ + ["cancel-0"], + ["cancel-1"], + ["cancel-2"], + ] + + asyncio.run(run_test()) diff --git a/tests/entrypoints/test_async_omni_diffusion.py b/tests/entrypoints/test_async_omni_diffusion.py index c0eae0992fd..c8aaae4f942 100644 --- a/tests/entrypoints/test_async_omni_diffusion.py +++ b/tests/entrypoints/test_async_omni_diffusion.py @@ -1,9 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import threading +from concurrent.futures import ThreadPoolExecutor +from types import SimpleNamespace +from unittest.mock import Mock + import pytest +import vllm_omni.diffusion.stage_diffusion_client as stage_diffusion_client_module +from vllm_omni.diffusion.data import DiffusionRequestAbortedError +from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion +from vllm_omni.inputs.data import OmniDiffusionSamplingParams pytestmark = [pytest.mark.core_model, pytest.mark.cpu] @@ -13,3 +23,91 @@ def test_get_diffusion_od_config_returns_direct_config(): diffusion.od_config = object() assert diffusion.get_diffusion_od_config() is diffusion.od_config + + +def test_async_omni_diffusion_generate_aborts_engine_on_cancel(): + async def run_test(): + started = threading.Event() + release = threading.Event() + abort = Mock() + + def step(request): + del request + started.set() + release.wait(timeout=5) + return [SimpleNamespace(request_id="req-1")] + + diffusion = object.__new__(AsyncOmniDiffusion) + diffusion.engine = SimpleNamespace(step=step, abort=abort) + diffusion._executor = ThreadPoolExecutor(max_workers=1) + + task = asyncio.create_task( + diffusion.generate( + prompt="hello", + sampling_params=OmniDiffusionSamplingParams(), + request_id="req-1", + ) + ) + try: + assert await asyncio.to_thread(started.wait, 1) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + finally: + release.set() + diffusion._executor.shutdown(wait=True) + + abort.assert_called_once_with("req-1") + + asyncio.run(run_test()) + + +def test_stage_diffusion_client_abort_requests_forwards_to_engine(): + async def run_test(): + aborted_request_ids: list[list[str]] = [] + + async def abort(request_ids): + aborted_request_ids.append(request_ids) + + client = object.__new__(StageDiffusionClient) + client._engine = SimpleNamespace(abort=abort) + client._tasks = {} + + task = asyncio.create_task(asyncio.sleep(60)) + client._tasks["req-1"] = task + + await client.abort_requests_async(["req-1", "req-2"]) + + with pytest.raises(asyncio.CancelledError): + await task + assert client._tasks == {} + assert aborted_request_ids == [["req-1", "req-2"]] + + asyncio.run(run_test()) + + +def test_stage_diffusion_client_run_treats_abort_as_normal_path(monkeypatch): + async def run_test(): + async def generate(prompt, sampling_params, request_id): + del prompt, sampling_params + raise DiffusionRequestAbortedError(f"Request {request_id} aborted.") + + info = Mock() + exception = Mock() + monkeypatch.setattr(stage_diffusion_client_module.logger, "info", info) + monkeypatch.setattr(stage_diffusion_client_module.logger, "exception", exception) + + client = object.__new__(StageDiffusionClient) + client.stage_id = 3 + client._engine = SimpleNamespace(generate=generate) + client._output_queue = asyncio.Queue() + client._tasks = {"req-1": object()} + + await client._run("req-1", "prompt", OmniDiffusionSamplingParams()) + + assert client._output_queue.empty() + assert client._tasks == {} + info.assert_called_once() + exception.assert_not_called() + + asyncio.run(run_test()) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 488378b40ff..12eb5ed3da3 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -492,6 +492,9 @@ class OmniDiffusionConfig: # Step mode settings step_execution: bool = False + # Maximum number of sequences to generate in a batch + max_num_seqs: int = 1 + @property def is_moe(self) -> bool: num_experts = self.tf_model_config.get("num_experts", None) @@ -658,6 +661,8 @@ class DiffusionOutput: trajectory_latents: torch.Tensor | None = None trajectory_decoded: list[torch.Tensor] | None = None error: str | None = None + aborted: bool = False + abort_message: str | None = None post_process_func: Callable[..., Any] | None = None @@ -675,6 +680,10 @@ class DiffusionOutput: peak_memory_mb: float = 0.0 +class DiffusionRequestAbortedError(RuntimeError): + """Raised when a diffusion request ends via user-visible abort.""" + + class AttentionBackendEnum(enum.Enum): FA = enum.auto() SLIDING_TILE_ATTN = enum.auto() diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index ff0f753b404..308c8cef80e 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +import queue import threading import time from collections.abc import Iterable @@ -11,7 +14,11 @@ import torch from vllm.logger import init_logger -from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.data import ( + DiffusionOutput, + DiffusionRequestAbortedError, + OmniDiffusionConfig, +) from vllm_omni.diffusion.executor.abstract import DiffusionExecutor from vllm_omni.diffusion.registry import ( DiffusionModelRegistry, @@ -19,7 +26,9 @@ get_diffusion_pre_process_func, ) from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.sched import RequestScheduler, SchedulerInterface +from vllm_omni.diffusion.sched import RequestScheduler, SchedulerInterface, StepScheduler +from vllm_omni.diffusion.sched.interface import DiffusionRequestStatus +from vllm_omni.diffusion.worker.utils import RunnerOutput from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt from vllm_omni.outputs import OmniRequestOutput @@ -72,9 +81,14 @@ def __init__( executor_class = DiffusionExecutor.get_class(od_config) self.executor = executor_class(od_config) - self.scheduler: SchedulerInterface = scheduler or RequestScheduler() + self.step_execution = bool(getattr(od_config, "step_execution", False)) + self.scheduler: SchedulerInterface = scheduler or ( + StepScheduler() if self.step_execution else RequestScheduler() + ) self.scheduler.initialize(od_config) - self._rpc_lock = threading.Lock() + self._rpc_lock = threading.RLock() + self.abort_queue: queue.Queue[str] = queue.Queue() + self.execute_fn = self.executor.execute_step if self.step_execution else self.executor.execute_request try: self._dummy_run() @@ -98,6 +112,8 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: output = self.add_req_and_wait_for_response(request) exec_total_time = time.perf_counter() - exec_start_time + if output.aborted: + raise DiffusionRequestAbortedError(output.abort_message or "Diffusion request aborted.") if output.error: raise Exception(f"{output.error}") logger.info("Generation completed successfully.") @@ -264,7 +280,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: def make_engine( config: OmniDiffusionConfig, scheduler: SchedulerInterface | None = None, - ) -> "DiffusionEngine": + ) -> DiffusionEngine: """Factory method to create a DiffusionEngine instance. Args: @@ -281,8 +297,11 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus # keep scheduling and executing until the target request is finished while True: + self._process_aborts_queue() sched_output = self.scheduler.schedule() if sched_output.is_empty: + if target_sched_req_id in sched_output.finished_req_ids: + return self._finalize_finished_request(target_sched_req_id) if not self.scheduler.has_requests(): raise RuntimeError("Diffusion scheduler has no runnable requests.") continue @@ -292,21 +311,26 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus # vllm_omni/diffusion/sched/base_scheduler.py), so we directly # take the single scheduled request here. sched_req_id = sched_output.scheduled_req_ids[0] - req = sched_output.scheduled_new_reqs[0].req try: - output = self.executor.add_req(req) + runner_output = self.execute_fn(sched_output) except Exception as exc: - logger.error( - "Execution failed for diffusion request %s", - sched_req_id, - exc_info=True, + logger.error("Execution failed for diffusion request %s", sched_req_id, exc_info=True) + runner_output = RunnerOutput( + req_id=sched_req_id, + step_index=None, + finished=True, + result=DiffusionOutput(error=str(exc)), ) - output = DiffusionOutput(error=str(exc)) - finished_req_ids = self.scheduler.update_from_output(sched_output, output) + self._process_aborts_queue() + + finished_req_ids = self.scheduler.update_from_output(sched_output, runner_output) if target_sched_req_id in finished_req_ids: - self.scheduler.pop_request_state(target_sched_req_id) - return output + return self._finalize_finished_request( + target_sched_req_id, + runner_output=runner_output, + missing_result_error="Diffusion execution finished without a final output.", + ) def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None: """Start or stop torch profiling on all diffusion workers. @@ -437,6 +461,55 @@ def close(self) -> None: self.executor.shutdown() def abort(self, request_id: str | Iterable[str]) -> None: - # TODO implement it - logger.warning("DiffusionEngine abort is not implemented yet") - pass + request_ids = [request_id] if isinstance(request_id, str) else list(request_id) + for req_id in request_ids: + self.abort_queue.put(req_id) + + def _process_aborts_queue(self) -> None: + if self.abort_queue.empty(): + return + + request_ids: list[str] = [] + while not self.abort_queue.empty(): + ids = self.abort_queue.get_nowait() + request_ids.extend((ids,) if isinstance(ids, str) else ids) + + self._abort_requests(request_ids) + + def _abort_requests(self, request_ids: str | Iterable[str]) -> None: + request_ids = [request_ids] if isinstance(request_ids, str) else list(request_ids) + + sched_req_ids: list[str] = [] + for request_id in dict.fromkeys(request_ids): + sched_req_id = self.scheduler.get_sched_req_id(request_id) + if sched_req_id is not None: + sched_req_ids.append(sched_req_id) + + for sched_req_id in dict.fromkeys(sched_req_ids): + if self.scheduler.get_request_state(sched_req_id) is not None: + self.scheduler.finish_requests(sched_req_id, DiffusionRequestStatus.FINISHED_ABORTED) + + def _finalize_finished_request( + self, + sched_req_id: str, + runner_output: RunnerOutput | None = None, + missing_result_error: str = "Diffusion scheduler finished target request without execution output.", + ) -> DiffusionOutput: + state = self.scheduler.get_request_state(sched_req_id) + popped_state = self.scheduler.pop_request_state(sched_req_id) + state = state or popped_state + + if state is None: + raise RuntimeError(f"Diffusion scheduler lost state for request {sched_req_id}.") + + if state.status == DiffusionRequestStatus.FINISHED_ABORTED: + request_id = state.req.request_ids[0] if state.req.request_ids else sched_req_id + return DiffusionOutput( + aborted=True, + abort_message=f"Request {request_id} aborted.", + ) + + if runner_output is not None and runner_output.result is not None: + return runner_output.result + + return DiffusionOutput(error=missing_result_error) diff --git a/vllm_omni/diffusion/executor/abstract.py b/vllm_omni/diffusion/executor/abstract.py index e41f41d119e..564980f6601 100644 --- a/vllm_omni/diffusion/executor/abstract.py +++ b/vllm_omni/diffusion/executor/abstract.py @@ -1,11 +1,17 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Any +from typing import TYPE_CHECKING, Any from vllm.utils.import_utils import resolve_obj_by_qualname from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.request import OmniDiffusionRequest +if TYPE_CHECKING: + from vllm_omni.diffusion.sched.interface import DiffusionSchedulerOutput + from vllm_omni.diffusion.worker.utils import RunnerOutput + class DiffusionExecutor(ABC): """Abstract base class for Diffusion executors.""" @@ -13,7 +19,7 @@ class DiffusionExecutor(ABC): uses_multiproc: bool = False @staticmethod - def get_class(od_config: OmniDiffusionConfig) -> type["DiffusionExecutor"]: + def get_class(od_config: OmniDiffusionConfig) -> type[DiffusionExecutor]: executor_class: type[DiffusionExecutor] distributed_executor_backend = od_config.distributed_executor_backend @@ -63,6 +69,16 @@ def add_req(self, requests: OmniDiffusionRequest) -> DiffusionOutput: """Add requests to the execution queue.""" pass + @abstractmethod + def execute_request(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + """Execute request-mode work from a scheduler output.""" + pass + + @abstractmethod + def execute_step(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + """Execute step-mode work from a scheduler output.""" + pass + @abstractmethod def collective_rpc( self, @@ -71,6 +87,7 @@ def collective_rpc( args: tuple = (), kwargs: dict | None = None, unique_reply_rank: int | None = None, + exec_all_ranks: bool = False, ) -> Any: """Execute a method on workers.""" pass diff --git a/vllm_omni/diffusion/executor/multiproc_executor.py b/vllm_omni/diffusion/executor/multiproc_executor.py index 1756633ba67..e55a464fb4a 100644 --- a/vllm_omni/diffusion/executor/multiproc_executor.py +++ b/vllm_omni/diffusion/executor/multiproc_executor.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import multiprocessing as mp import time import weakref from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING, Any import zmq from vllm.distributed.device_communicators.shm_broadcast import MessageQueue @@ -14,6 +16,10 @@ from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.worker import WorkerProc +if TYPE_CHECKING: + from vllm_omni.diffusion.sched.interface import DiffusionSchedulerOutput + from vllm_omni.diffusion.worker.utils import RunnerOutput + logger = init_logger(__name__) @@ -190,6 +196,61 @@ def add_req(self, request: OmniDiffusionRequest) -> DiffusionOutput: logger.error(f"Generate call failed: {e}") raise + def execute_request(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + """Adapt request-mode scheduler output to worker execute_model RPC.""" + from vllm_omni.diffusion.worker.utils import RunnerOutput + + self._ensure_open() + if scheduler_output.num_scheduled_reqs != 1: + raise ValueError( + f"Request mode currently supports batch_size=1, " + f"but got {scheduler_output.num_scheduled_reqs} scheduled requests." + ) + + new_req = scheduler_output.scheduled_new_reqs[0] + result = self.collective_rpc( + "execute_model", + args=(new_req.req, self.od_config), + unique_reply_rank=0, + exec_all_ranks=True, + ) + if not isinstance(result, DiffusionOutput): + raise RuntimeError(f"Unexpected response type for execute_request: {type(result)!r}") + + return RunnerOutput( + req_id=new_req.sched_req_id, + step_index=None, + finished=True, + result=result, + ) + + def execute_step(self, scheduler_output: DiffusionSchedulerOutput) -> RunnerOutput: + """Forward step-mode scheduler output to worker execute_stepwise RPC.""" + from vllm_omni.diffusion.worker.utils import RunnerOutput + + self._ensure_open() + result = self.collective_rpc( + "execute_stepwise", + args=(scheduler_output,), + unique_reply_rank=0, + exec_all_ranks=True, + ) + + if isinstance(result, RunnerOutput): + return result + # TODO: Remove this fallback; DiffusionOutput cannot faithfully represent + # failed multi-request step batches. + if isinstance(result, DiffusionOutput): + req_id = scheduler_output.scheduled_req_ids[0] if scheduler_output.scheduled_req_ids else "" + return RunnerOutput( + req_id=req_id, + step_index=None, + finished=True, + result=result, + ) + else: + raise RuntimeError(f"Unexpected response type for execute_step: {type(result)!r}") + def collective_rpc( self, method: str, @@ -197,6 +258,7 @@ def collective_rpc( args: tuple = (), kwargs: dict | None = None, unique_reply_rank: int | None = None, + exec_all_ranks: bool = False, ) -> Any: self._ensure_open() @@ -212,7 +274,7 @@ def collective_rpc( "args": args, "kwargs": kwargs, "output_rank": unique_reply_rank if unique_reply_rank is not None else 0, - "exec_all_ranks": unique_reply_rank is None, + "exec_all_ranks": unique_reply_rank is None or exec_all_ranks, } try: @@ -228,6 +290,11 @@ def collective_rpc( try: response = self._result_mq.dequeue(timeout=dequeue_timeout) + try: + unpack_diffusion_output_shm(response) + except Exception as e: + logger.warning("SHM unpack failed (data may already be inline): %s", e) + # Check if response indicates an error if isinstance(response, dict) and response.get("status") == "error": raise RuntimeError( diff --git a/vllm_omni/diffusion/lora/manager.py b/vllm_omni/diffusion/lora/manager.py index 1466a335847..5f75e26cb16 100644 --- a/vllm_omni/diffusion/lora/manager.py +++ b/vllm_omni/diffusion/lora/manager.py @@ -218,10 +218,16 @@ def set_active_adapter(self, lora_request: LoRARequest | None, lora_scale: float lora_scale: The external scale for the LoRA adapter. """ if lora_request is None: + if self._active_adapter_id is None: + logger.debug("No lora_request provided and adapters are already inactive") + return logger.debug("No lora_request provided, deactivating all LoRA adapters") self._deactivate_all_adapters() return elif math.isclose(0.0, lora_scale): + if self._active_adapter_id is None: + logger.debug("Received LoRA scale 0 with adapters already inactive") + return logger.warning("Received a request with LoRA scale 0; deactivating all LoRA adapters") self._deactivate_all_adapters() return @@ -605,6 +611,9 @@ def _activate_adapter(self, adapter_id: int, scale: float) -> None: self._update_adapter_scale(adapter_id, scale) def _deactivate_all_adapters(self) -> None: + if self._active_adapter_id is None: + logger.debug("All adapters already inactive") + return logger.info("Deactivating all adapters: %d layers", len(self._lora_modules)) for lora_layer in self._lora_modules.values(): lora_layer.reset_lora(0) diff --git a/vllm_omni/diffusion/sched/__init__.py b/vllm_omni/diffusion/sched/__init__.py index 650a1a1e6fb..e0263733847 100644 --- a/vllm_omni/diffusion/sched/__init__.py +++ b/vllm_omni/diffusion/sched/__init__.py @@ -10,16 +10,18 @@ SchedulerInterface, ) from vllm_omni.diffusion.sched.request_scheduler import RequestScheduler +from vllm_omni.diffusion.sched.step_scheduler import StepScheduler Scheduler = RequestScheduler __all__ = [ + "DiffusionRequestStatus", "CachedRequestData", "DiffusionRequestState", - "DiffusionRequestStatus", "DiffusionSchedulerOutput", "NewRequestData", + "SchedulerInterface", "RequestScheduler", + "StepScheduler", "Scheduler", - "SchedulerInterface", ] diff --git a/vllm_omni/diffusion/sched/base_scheduler.py b/vllm_omni/diffusion/sched/base_scheduler.py index a59fa50d1ee..6a7ee3d3efd 100644 --- a/vllm_omni/diffusion/sched/base_scheduler.py +++ b/vllm_omni/diffusion/sched/base_scheduler.py @@ -5,13 +5,21 @@ from collections import deque +from vllm.logger import init_logger + from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.sched.interface import ( + CachedRequestData, DiffusionRequestState, DiffusionRequestStatus, + DiffusionSchedulerOutput, + NewRequestData, SchedulerInterface, ) +logger = init_logger(__name__) + class _BaseScheduler(SchedulerInterface): """Shared queue/state bookkeeping for diffusion schedulers.""" @@ -24,8 +32,6 @@ def __init__(self) -> None: self._waiting: deque[str] = deque() self._running: list[str] = [] self._finished_req_ids: set[str] = set() - # The current DiffusionEngine execution mode does not support real - # request batching well, so we keep this fixed at 1 for now. self._max_batch_size: int = 1 def initialize(self, od_config: OmniDiffusionConfig) -> None: @@ -36,8 +42,67 @@ def initialize(self, od_config: OmniDiffusionConfig) -> None: self._waiting.clear() self._running.clear() self._finished_req_ids.clear() + # The current DiffusionEngine execution mode does not support real + # request batching well, so we keep this fixed at 1 for now. + # TODO: Add support for multiple concurrent requests + self.max_num_running_reqs = 1 self._reset_scheduler_state() + def add_request(self, request: OmniDiffusionRequest) -> str: + sched_req_id = self._make_sched_req_id(request) + return self._add_request_with_sched_req_id(sched_req_id, request) + + def _add_request_with_sched_req_id(self, sched_req_id: str, request: OmniDiffusionRequest) -> str: + state = DiffusionRequestState(sched_req_id=sched_req_id, req=request) + self._request_states[sched_req_id] = state + self._register_request_ids(request.request_ids, sched_req_id) + self._waiting.append(sched_req_id) + logger.debug("%s add_request: %s (waiting=%d)", self.__class__.__name__, sched_req_id, len(self._waiting)) + return sched_req_id + + def schedule(self) -> DiffusionSchedulerOutput: + scheduled_new_reqs: list[NewRequestData] = [] + scheduled_cached_req_ids: list[str] = [] + + # First, schedule the RUNNING request(s) + for sched_req_id in self._running: + state = self._request_states.get(sched_req_id) + if state is not None: + scheduled_cached_req_ids.append(sched_req_id) + + # Second, schedule WAITING requests while capacity remains. + while self._waiting and len(self._running) < self.max_num_running_reqs: + sched_req_id = self._waiting[0] + state = self._request_states.get(sched_req_id) + if state is None: + self._waiting.popleft() + continue + if not self._can_schedule_waiting(state): + break + + self._waiting.popleft() + was_new_request = state.status == DiffusionRequestStatus.WAITING + state.status = DiffusionRequestStatus.RUNNING + self._running.append(sched_req_id) + if was_new_request: + scheduled_new_reqs.append(NewRequestData.from_state(state)) + else: + scheduled_cached_req_ids.append(sched_req_id) + + scheduler_output = DiffusionSchedulerOutput( + step_id=self._step_id, + scheduled_new_reqs=scheduled_new_reqs, + scheduled_cached_reqs=CachedRequestData(sched_req_ids=scheduled_cached_req_ids), + finished_req_ids=set(self._finished_req_ids), + num_running_reqs=len(self._running), + num_waiting_reqs=len(self._waiting), + ) + + # update after schedule + self._step_id += 1 + self._finished_req_ids.clear() + return scheduler_output + def has_requests(self) -> bool: return bool(self._waiting or self._running) @@ -121,12 +186,32 @@ def _finish_requests( self._finished_req_ids |= finished_req_ids return finished_req_ids + def _finalize_update_from_output( + self, + sched_output: DiffusionSchedulerOutput, + statuses: dict[str, DiffusionRequestStatus], + errors: dict[str, str | None] | None = None, + ) -> set[str]: + # A scheduled request may be aborted after schedule() but before + # update_from_output() processes the runner output. It is already + # marked finished at that point, but we still need to surface its id + # in this update so the engine can observe the terminal state. + finished_req_ids = { + sched_req_id for sched_req_id in sched_output.scheduled_req_ids if sched_req_id in self._finished_req_ids + } + finished_req_ids |= self._finish_requests(statuses, errors) + return finished_req_ids + def _reset_scheduler_state(self) -> None: """Reset subclass-owned state during initialize()/close().""" def _pop_extra_request_state(self, sched_req_id: str) -> None: """Remove subclass-owned per-request state before popping request state.""" + def _can_schedule_waiting(self, state: DiffusionRequestState) -> bool: + del state + return True + def _register_request_ids(self, request_ids: list[str], sched_req_id: str) -> None: for request_id in request_ids: existing = self._request_id_to_sched_req_id.get(request_id) diff --git a/vllm_omni/diffusion/sched/interface.py b/vllm_omni/diffusion/sched/interface.py index 427cad03d0e..4db6f413558 100644 --- a/vllm_omni/diffusion/sched/interface.py +++ b/vllm_omni/diffusion/sched/interface.py @@ -8,12 +8,16 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from functools import cached_property +from typing import TYPE_CHECKING from vllm.logger import init_logger -from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.data import OmniDiffusionConfig from vllm_omni.diffusion.request import OmniDiffusionRequest +if TYPE_CHECKING: + from vllm_omni.diffusion.worker.utils import RunnerOutput + logger = init_logger(__name__) @@ -141,7 +145,7 @@ def schedule(self) -> DiffusionSchedulerOutput: """Run one scheduling cycle.""" @abstractmethod - def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: DiffusionOutput) -> set[str]: + def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: RunnerOutput) -> set[str]: """Update scheduler state from executor output.""" @abstractmethod diff --git a/vllm_omni/diffusion/sched/request_scheduler.py b/vllm_omni/diffusion/sched/request_scheduler.py index ed8316ee58f..f641648e96a 100644 --- a/vllm_omni/diffusion/sched/request_scheduler.py +++ b/vllm_omni/diffusion/sched/request_scheduler.py @@ -3,103 +3,48 @@ from __future__ import annotations -from vllm.logger import init_logger +from typing import TYPE_CHECKING -from vllm_omni.diffusion.data import DiffusionOutput from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.sched.base_scheduler import _BaseScheduler from vllm_omni.diffusion.sched.interface import ( - CachedRequestData, - DiffusionRequestState, DiffusionRequestStatus, DiffusionSchedulerOutput, - NewRequestData, ) -logger = init_logger(__name__) +if TYPE_CHECKING: + from vllm_omni.diffusion.worker.utils import RunnerOutput class RequestScheduler(_BaseScheduler): """Diffusion scheduler with vLLM-style waiting/running queues.""" def add_request(self, request: OmniDiffusionRequest) -> str: - sched_req_id = self._make_sched_req_id(request) - state = DiffusionRequestState(sched_req_id=sched_req_id, req=request) - self._request_states[sched_req_id] = state - self._register_request_ids(request.request_ids, sched_req_id) - self._waiting.append(sched_req_id) - logger.debug("Scheduler add_request: %s (waiting=%d)", sched_req_id, len(self._waiting)) - return sched_req_id + return super().add_request(request) def schedule(self) -> DiffusionSchedulerOutput: - scheduled_new_reqs: list[NewRequestData] = [] - scheduled_cached_req_ids: list[str] = [] + return super().schedule() - # First, schedule the RUNNING request(s) - for sched_req_id in self._running: - state = self._request_states.get(sched_req_id) - if state is not None: - scheduled_cached_req_ids.append(sched_req_id) - - # Second, schedule WAITING requests while capacity remains. - while self._waiting and len(self._running) < self._max_batch_size: - sched_req_id = self._waiting.popleft() - state = self._request_states.get(sched_req_id) - if state is None: - continue - was_new_request = state.status == DiffusionRequestStatus.WAITING - state.status = DiffusionRequestStatus.RUNNING - self._running.append(sched_req_id) - if was_new_request: - scheduled_new_reqs.append(NewRequestData.from_state(state)) - else: - scheduled_cached_req_ids.append(sched_req_id) - - scheduler_output = DiffusionSchedulerOutput( - step_id=self._step_id, - scheduled_new_reqs=scheduled_new_reqs, - scheduled_cached_reqs=CachedRequestData(sched_req_ids=scheduled_cached_req_ids), - finished_req_ids=set(self._finished_req_ids), - num_running_reqs=len(self._running), - num_waiting_reqs=len(self._waiting), - ) - - self._step_id += 1 - self._finished_req_ids.clear() - return scheduler_output - - def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: DiffusionOutput) -> set[str]: + def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: RunnerOutput) -> set[str]: scheduled_req_ids = sched_output.scheduled_req_ids if not scheduled_req_ids: return set() - # A scheduled request may be aborted after schedule() but before - # update_from_output() processes the runner output. It is already - # marked finished at that point, but we still need to surface its id - # in this update so the engine can observe the terminal state. - finished_req_ids = { - sched_req_id for sched_req_id in scheduled_req_ids if sched_req_id in self._finished_req_ids - } terminal_statuses: dict[str, DiffusionRequestStatus] = {} terminal_errors: dict[str, str | None] = {} - # NOTE: request-mode currently assumes one executor call produces one - # DiffusionOutput for the single scheduled request in this cycle. + result = output.result for sched_req_id in scheduled_req_ids: state = self._request_states.get(sched_req_id) if state is None or state.is_finished(): continue - if output.error: + if result is None: terminal_statuses[sched_req_id] = DiffusionRequestStatus.FINISHED_ERROR - terminal_errors[sched_req_id] = output.error + terminal_errors[sched_req_id] = "No output result" + elif result.error: + terminal_statuses[sched_req_id] = DiffusionRequestStatus.FINISHED_ERROR + terminal_errors[sched_req_id] = result.error else: terminal_statuses[sched_req_id] = DiffusionRequestStatus.FINISHED_COMPLETED terminal_errors[sched_req_id] = None - finished_req_ids |= self._finish_requests(terminal_statuses, terminal_errors) - return finished_req_ids - - def abort_request(self, sched_req_id: str) -> bool: - if self.get_request_state(sched_req_id) is None: - return False - self.finish_requests(sched_req_id, DiffusionRequestStatus.FINISHED_ABORTED) - return True + return self._finalize_update_from_output(sched_output, terminal_statuses, terminal_errors) diff --git a/vllm_omni/diffusion/sched/step_scheduler.py b/vllm_omni/diffusion/sched/step_scheduler.py new file mode 100644 index 00000000000..4d995dcf40a --- /dev/null +++ b/vllm_omni/diffusion/sched/step_scheduler.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from vllm.logger import init_logger + +from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.sched.base_scheduler import _BaseScheduler +from vllm_omni.diffusion.sched.interface import ( + DiffusionRequestStatus, + DiffusionSchedulerOutput, +) + +if TYPE_CHECKING: + from vllm_omni.diffusion.worker.utils import RunnerOutput + +logger = init_logger(__name__) + + +@dataclass +class _StepProgress: + current_step: int + total_steps: int + + +class StepScheduler(_BaseScheduler): + """Placeholder scheduler that advances a request one denoise step per update.""" + + def __init__(self) -> None: + super().__init__() + self._request_progress: dict[str, _StepProgress] = {} + + def _reset_scheduler_state(self) -> None: + self._request_progress.clear() + + def add_request(self, request: OmniDiffusionRequest) -> str: + sched_req_id = self._make_sched_req_id(request) + total_steps = self._get_total_steps(request) + if total_steps <= 0: + raise ValueError(f"Diffusion request {sched_req_id} must have positive total_steps, got {total_steps}") + + current_step = request.sampling_params.step_index or 0 + if current_step < 0 or current_step >= total_steps: + raise ValueError( + f"Diffusion request {sched_req_id} has invalid initial step_index {current_step} " + f"for total_steps={total_steps}" + ) + + request.sampling_params.step_index = current_step + sched_req_id = self._add_request_with_sched_req_id(sched_req_id, request) + self._request_progress[sched_req_id] = _StepProgress(current_step=current_step, total_steps=total_steps) + logger.debug( + "StepScheduler add_request: %s (step=%d/%d, waiting=%d)", + sched_req_id, + current_step, + total_steps, + len(self._waiting), + ) + return sched_req_id + + def schedule(self) -> DiffusionSchedulerOutput: + return super().schedule() + + def update_from_output(self, sched_output: DiffusionSchedulerOutput, output: RunnerOutput) -> set[str]: + scheduled_req_ids = sched_output.scheduled_req_ids + if not scheduled_req_ids: + return set() + + terminal_statuses: dict[str, DiffusionRequestStatus] = {} + terminal_errors: dict[str, str | None] = {} + output_error = output.result.error if output.result is not None else None + for sched_req_id in scheduled_req_ids: + state = self._request_states.get(sched_req_id) + progress = self._request_progress.get(sched_req_id) + if state is None or progress is None or state.is_finished(): + continue + + if output_error is not None: + terminal_statuses[sched_req_id] = DiffusionRequestStatus.FINISHED_ERROR + terminal_errors[sched_req_id] = output_error + continue + + if output.step_index is None: + logger.warning( + "Received RunnerOutput with no step_index for request %s, treating as error", + sched_req_id, + ) + terminal_statuses[sched_req_id] = DiffusionRequestStatus.FINISHED_ERROR + terminal_errors[sched_req_id] = "Missing step_index in RunnerOutput" + continue + + # We assume that the decoding stage is executed immediately after the denoising stage completes. + progress.current_step = output.step_index + state.req.sampling_params.step_index = output.step_index + if output.finished: + terminal_statuses[sched_req_id] = DiffusionRequestStatus.FINISHED_COMPLETED + terminal_errors[sched_req_id] = None + else: + state.error = None + + return self._finalize_update_from_output(sched_output, terminal_statuses, terminal_errors) + + def _pop_extra_request_state(self, sched_req_id: str) -> None: + self._request_progress.pop(sched_req_id, None) + + def _get_total_steps(self, request: OmniDiffusionRequest) -> int: + sampling = request.sampling_params + + if sampling.timesteps is not None: + return self._sequence_length(sampling.timesteps) + if sampling.sigmas is not None: + return len(sampling.sigmas) + return int(sampling.num_inference_steps) + + @staticmethod + def _sequence_length(values: Any) -> int: + ndim = getattr(values, "ndim", None) + if ndim == 0: + return 1 + + shape = getattr(values, "shape", None) + if shape is not None: + return int(shape[0]) + + return len(values) diff --git a/vllm_omni/diffusion/stage_diffusion_client.py b/vllm_omni/diffusion/stage_diffusion_client.py index 5a6fb6371f0..ddad2f9f3f8 100644 --- a/vllm_omni/diffusion/stage_diffusion_client.py +++ b/vllm_omni/diffusion/stage_diffusion_client.py @@ -12,6 +12,7 @@ from vllm.logger import init_logger +from vllm_omni.diffusion.data import DiffusionRequestAbortedError from vllm_omni.engine.stage_init_utils import StageMetadata from vllm_omni.entrypoints.async_omni_diffusion import AsyncOmniDiffusion from vllm_omni.outputs import OmniRequestOutput @@ -74,6 +75,20 @@ async def _run( try: result = await self._engine.generate(prompt, sampling_params, request_id) await self._output_queue.put(result) + except asyncio.CancelledError: + logger.info( + "[StageDiffusionClient] Stage-%s req=%s cancelled", + self.stage_id, + request_id, + ) + raise + except DiffusionRequestAbortedError as e: + logger.info( + "[StageDiffusionClient] Stage-%s req=%s aborted: %s", + self.stage_id, + request_id, + e, + ) except Exception as e: logger.exception( "[StageDiffusionClient] Stage-%s req=%s failed: %s", @@ -138,6 +153,7 @@ async def abort_requests_async(self, request_ids: list[str]) -> None: task = self._tasks.pop(rid, None) if task: task.cancel() + await self._engine.abort(request_ids) async def collective_rpc_async( self, diff --git a/vllm_omni/diffusion/worker/__init__.py b/vllm_omni/diffusion/worker/__init__.py index 8af0283857f..80a7addf3c3 100644 --- a/vllm_omni/diffusion/worker/__init__.py +++ b/vllm_omni/diffusion/worker/__init__.py @@ -2,14 +2,31 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Worker classes for diffusion models.""" -from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner -from vllm_omni.diffusion.worker.diffusion_worker import ( - DiffusionWorker, - WorkerProc, -) +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner + from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker, WorkerProc __all__ = [ "DiffusionModelRunner", "DiffusionWorker", "WorkerProc", ] + + +def __getattr__(name: str) -> Any: + if name == "DiffusionModelRunner": + from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner + + return DiffusionModelRunner + if name in {"DiffusionWorker", "WorkerProc"}: + from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker, WorkerProc + + return { + "DiffusionWorker": DiffusionWorker, + "WorkerProc": WorkerProc, + }[name] + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index a4d87c96e4a..9de3dc867ff 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -860,6 +860,7 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: "max_num_seqs": 1, "parallel_config": parallel_config, "model_class_name": kwargs.get("model_class_name", None), + "step_execution": kwargs.get("step_execution", False), "vae_use_slicing": kwargs.get("vae_use_slicing", False), "vae_use_tiling": kwargs.get("vae_use_tiling", False), "cache_backend": cache_backend, diff --git a/vllm_omni/entrypoints/async_omni_diffusion.py b/vllm_omni/entrypoints/async_omni_diffusion.py index a7a02eded6b..674c3509d22 100644 --- a/vllm_omni/entrypoints/async_omni_diffusion.py +++ b/vllm_omni/entrypoints/async_omni_diffusion.py @@ -18,7 +18,11 @@ from vllm.logger import init_logger from vllm.transformers_utils.config import get_hf_file_to_dict -from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig +from vllm_omni.diffusion.data import ( + DiffusionRequestAbortedError, + OmniDiffusionConfig, + TransformerConfig, +) from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType @@ -308,6 +312,11 @@ async def generate( request, ) result = result[0] + except asyncio.CancelledError: + self.engine.abort(request_id) + raise + except DiffusionRequestAbortedError: + raise except Exception as e: logger.error("Generation failed for request %s: %s", request_id, e) raise RuntimeError(f"Diffusion generation failed: {e}") from e diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index f924d64c391..4e1c8d3a94c 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -267,6 +267,11 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu action="store_true", help="Enable cache-dit summary logging after diffusion forward passes.", ) + omni_config_group.add_argument( + "--step-execution", + action="store_true", + help="Enable per-step diffusion execution so running requests can be aborted between denoise steps.", + ) # VAE memory optimization parameters omni_config_group.add_argument(