diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh
index f56f23b5deb..2707dfefe83 100755
--- a/.buildkite/scripts/hardware_ci/run-amd-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh
@@ -64,12 +64,13 @@ while true; do
done
echo "--- Pulling container"
-## Temporary change to use AMD Docker Hub to store the vllm-ci image
+## Temporary change to use AMD Docker Hub to store the vllm-omni image
# to bypass the rate limit issue with ECR Public Gallery.
+# Images are now stored in a separate repository for vllm-omni, instead of vllm-ci.
# TODO: @tjtanaa point back to ECR Public Gallery
# once the amd agents are configured to use ECR Public Gallery.
# image_name="public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:${BUILDKITE_COMMIT}-rocm-omni"
-image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}-rocm-omni"
+image_name="rocm/vllm-omni:${BUILDKITE_COMMIT}"
container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
# TODO: @tjtanaa uncomment this once the amd agents are configured to use ECR Public Gallery.
diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml
index e175385ff0d..0b9a3f47aba 100644
--- a/.buildkite/test-amd.yaml
+++ b/.buildkite/test-amd.yaml
@@ -117,3 +117,17 @@ steps:
- export VLLM_LOGGING_LEVEL=DEBUG
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py
+
+
+- label: "Omni Sleep Mode Test"
+ timeout_in_minutes: 40
+ agent_pool: mi325_2
+ depends_on: amd-build
+ mirror_hardwares: [amdproduction]
+ grade: Blocking
+ commands:
+ - export GPU_ARCHS=gfx942
+ - export VLLM_LOGGING_LEVEL=DEBUG
+ - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ - export VLLM_TEST_CLEAN_GPU_MEMORY="1"
+ - pytest -s -v tests/e2e/offline_inference/test_omni_sleep_mode.py -m "advanced_model and omni and MI325" --run-level "advanced_model"
diff --git a/.buildkite/test-merge.yml b/.buildkite/test-merge.yml
index 2a6cb6488a0..785cc58fab9 100644
--- a/.buildkite/test-merge.yml
+++ b/.buildkite/test-merge.yml
@@ -360,6 +360,46 @@ steps:
path: /mnt/hf-cache
type: DirectoryOrCreate
+ - label: "Omni Sleep Mode Test with H100"
+ timeout_in_minutes: 30
+ depends_on: upload-merge-pipeline
+ commands:
+ - export VLLM_TEST_CLEAN_GPU_MEMORY="1"
+ - pytest -s -v tests/e2e/offline_inference/test_omni_sleep_mode.py -m "advanced_model and H100 and omni" --run-level "advanced_model"
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ resources:
+ limits:
+ nvidia.com/gpu: 2
+ volumeMounts:
+ - name: devshm
+ mountPath: /dev/shm
+ - name: hf-cache
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
+
- label: "Voxtral-TTS E2E Test"
timeout_in_minutes: 20
depends_on: upload-merge-pipeline
diff --git a/.buildkite/test-nightly.yml b/.buildkite/test-nightly.yml
index c33d7b4d10d..fb21a2acaac 100644
--- a/.buildkite/test-nightly.yml
+++ b/.buildkite/test-nightly.yml
@@ -151,6 +151,44 @@ steps:
type: DirectoryOrCreate
+ - label: ":full_moon: Omni Multi-Replica Startup Test with 4x H100"
+ timeout_in_minutes: 45
+ commands:
+ - pytest -s -v tests/e2e/online_serving/test_qwen3_omni_multi_replicas.py -m "core_model" --run-level "core_model"
+ agents:
+ queue: "mithril-h100-pool"
+ plugins:
+ - kubernetes:
+ podSpec:
+ containers:
+ - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ resources:
+ limits:
+ nvidia.com/gpu: 4
+ volumeMounts:
+ - name: devshm
+ mountPath: /dev/shm
+ - name: hf-cache
+ mountPath: /root/.cache/huggingface
+ env:
+ - name: HF_HOME
+ value: /root/.cache/huggingface
+ - name: HF_TOKEN
+ valueFrom:
+ secretKeyRef:
+ name: hf-token-secret
+ key: token
+ nodeSelector:
+ node.kubernetes.io/instance-type: gpu-h100-sxm
+ volumes:
+ - name: devshm
+ emptyDir:
+ medium: Memory
+ - name: hf-cache
+ hostPath:
+ path: /mnt/hf-cache
+ type: DirectoryOrCreate
+
- group: ":card_index_dividers: TTS Model Test"
key: nightly-tts-test-group
depends_on: upload-nightly-pipeline
diff --git a/.buildkite/test-template-amd-omni.j2 b/.buildkite/test-template-amd-omni.j2
index f4c386a5fe3..78f47d1aec0 100644
--- a/.buildkite/test-template-amd-omni.j2
+++ b/.buildkite/test-template-amd-omni.j2
@@ -3,7 +3,7 @@
Last synced: 2025-12-15
Modifications: Removed unused CUDA/NVIDIA logic, keeping only AMD tests
#}
-{% set docker_image_amd = "rocm/vllm-ci:$BUILDKITE_COMMIT-rocm-omni" %}
+{% set docker_image_amd = "rocm/vllm-omni:$BUILDKITE_COMMIT" %}
{% set default_working_dir = "/app/vllm-omni" %}
- group: "AMD Tests"
diff --git a/.gitignore b/.gitignore
index 35dc7571ee2..a631aa05f74 100644
--- a/.gitignore
+++ b/.gitignore
@@ -174,6 +174,7 @@ CLAUDE.md
# Codex
AGENTS.md
+.codex
.codex/
# cursor
diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md
index 6c209e5659a..a7b508302d2 100644
--- a/docs/contributing/profiling.md
+++ b/docs/contributing/profiling.md
@@ -138,9 +138,27 @@ Single-stage diffusion serving with torch profiler:
vllm serve Wan-AI/Wan2.2-I2V-A14B-Diffusers \
--omni \
--port 8091 \
- --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}'
+ --profiler-config '{
+ "profiler": "torch",
+ "torch_profiler_dir": "/tmp/vllm_profile_wan22_i2v",
+ "torch_profiler_with_stack": true,
+ "torch_profiler_with_flops": false,
+ "torch_profiler_use_gzip": true,
+ "torch_profiler_dump_cuda_time_total": true,
+ "torch_profiler_record_shapes": true,
+ "torch_profiler_with_memory": true,
+ "delay_iterations": 0,
+ "max_iterations": 0
+ }'
```
+Useful optional fields:
+
+- `torch_profiler_with_stack`: export `by_stack` operator views and stack text files.
+- `torch_profiler_record_shapes`: export `by_shape` operator views.
+- `torch_profiler_with_memory`: dump `memory_snapshot_rank*.pickle` when the backend supports memory history.
+- `torch_profiler_use_gzip`: write the trace as `trace_rank*.json.gz`.
+
Single-stage diffusion serving with Nsight Systems:
```bash
@@ -177,8 +195,11 @@ For mixed-stage pipelines, use explicit `stages` and pass the same stage list to
Torch profiler output:
-- Chrome/Perfetto traces under `torch_profiler_dir`
-- Optional aggregated CUDA-time tables under the same directory
+- Chrome/Perfetto trace: `trace_rank*.json` or `trace_rank*.json.gz`
+- Excel workbook: `ops_rank*.xlsx` with `summary`, and optional `by_shape` / `by_stack` sheets
+- Stack exports: `stacks_cpu_rank*.txt` and `stacks_cuda_rank*.txt` when stack capture is enabled
+- Memory snapshot: `memory_snapshot_rank*.pickle` when memory capture is enabled and supported by the backend
+- Optional aggregated CUDA-time tables under the same session directory
CUDA profiler / Nsight Systems output:
diff --git a/docs/features/sleep_mode.md b/docs/features/sleep_mode.md
index 41aa48c1735..b4eec162d31 100644
--- a/docs/features/sleep_mode.md
+++ b/docs/features/sleep_mode.md
@@ -1,39 +1,243 @@
-# Sleep Mode
+# Sleep Mode & ACK Protocol
-vLLM-Omni’s **Sleep Mode** allows you to temporarily release most GPU memory used by a model—such as model weights and key-value (KV) caches (for autoregressive models)—**without stopping the server or unloading the Docker container**.
+vLLM-Omni’s **Sleep Mode** allows you to temporarily release most GPU memory used by a model—such as model weights and key-value (KV) caches—**without stopping the server or unloading the Docker container**.
-This feature is inherited from [vLLM’s Sleep Mode](https://blog.vllm.ai/2025/10/26/sleep-mode.html), which provides zero-reload model switching for multi-model serving.
-
-It is especially useful in **RLHF**, **training**, or **cost-saving scenarios**, where GPU resources must be freed between inference workloads.
+This feature is inherited from [vLLM’s Sleep Mode](https://blog.vllm.ai/2025/10/26/sleep-mode.html) and extended with the **Omni ACK Protocol** to support multi-stage pipelines and heterogeneous hardware backends (NVIDIA, AMD, Intel, Huawei). It is especially useful in **RLHF**, **dynamic model switching**, or **cost-saving scenarios**.
---
-## Omni Model
+## 1. Feature Documentation
-Omni model inherit the feature from vLLM' Sleep Mode
+### Overview
+Omni Sleep Mode provides a mechanism to "sleep" specific model stages. When a stage enters sleep, its physical VRAM is reclaimed by the system, while the process state is preserved for rapid "wake-up" without full re-initialization.
-This means:
+### Sleep Levels
+We support two levels of hibernation to balance recovery speed and memory efficiency:
-- Support both Level 1 and Level 2 sleep, allow to release and reset both model weights and KV Cache
+| Level | Name | Mechanism | Recovery Speed | Memory Freed |
+| :--- | :--- | :--- | :--- | :--- |
+| **Level 1** | **Weight Offloading** | Offloads weights to Host CPU RAM. | **Fast** (DMA) | Substantial |
+| **Level 2** | **Full De-mapping** | Physically releases memory pages via VRAM scavenging. | **Moderate** | **Maximum** (up to 95%+) |
-## Diffusion Model Extension
+### Supported Platforms
-We added Sleep Mode support for **diffusion models**, which previously lacked this functionality.
-In diffusion pipelines, this currently only offloads **model weight memory**, as these models typically do not use KV caches.
+Omni Sleep Mode is optimized for high-performance computing backends:
-This means:
+* **NVIDIA**: Supported via Virtual Memory Management (VMM).
+* **AMD (ROCm)**: Fully supported with physical page de-mapping.
+* **Intel XPU**: Supported via Level Zero memory management.
+* **Huawei NPU**: Supported via Ascend memory scavenging.
-- Diffusion models can now enter Level 1 sleep.
-- Pipeline states (e.g., noise schedulers, buffers) remain intact after waking.
-- Useful for releasing VRAM between image generation or training cycles.
+### Hardware Requirements
+* **Memory Considerations**: System RAM must be sufficient to hold offloaded weights during sleep.
+* **TP Support**: Tensor Parallel groups synchronize sleep/wake transitions across all workers.
---
-## Enable sleep mode
-To enable sleep mode, set the `enable_sleep_mode` in `engine_args` to `True`
+## 2. Usage Examples
+
+### Python API Example
+You can programmatically control the lifecycle of stages using the `AsyncOmni` engine.
-Example:
```python
-omni = Omni(model=...,enable_sleep_mode=True)
+
+import asyncio
+from vllm_omni.entrypoints.async_omni import AsyncOmni
+
+async def run_sleep_demo():
+ # 1. initialization
+ engine = AsyncOmni(
+ model="ByteDance-Seed/BAGEL-7B-MoT",
+ enable_sleep_mode=True
+ )
+
+ # 2. sleep mode level2
+ acks = await engine.sleep(stage_ids=[0], level=2)
+ print(f"Freed {acks[0].freed_bytes / 1024**3:.2f} GiB on Stage 0")
+
+ # 3. wake up
+ await engine.wake_up(stage_ids=[0])
+
+if __name__ == "__main__":
+ asyncio.run(run_sleep_demo())
+
+```
+
+### Server Command Example
+Start the server with sleep mode enabled:
+
+Option 1: the `vllm serve` CLI:
+
+```
+
+vllm serve ByteDance-Seed/BAGEL-7B-MoT \
+--omni \
+--enable-sleep-mode \
+--trust-remote-code \
+--gpu-memory-utilization 0.7
+
+```
+
+Option 2: the Python module entry point:
+
+```
+
+python3 -m vllm_omni.entrypoints.openai.api_server \
+ --model ByteDance-Seed/BAGEL-7B-MoT \
+ --omni \
+ --enable-sleep-mode \
+ --trust-remote-code \
+--gpu-memory-utilization 0.7
+
+```
+
+
+
+
+### Test Scenarios & Commands
+
+#### Scenario 1: LLM Engine Sleep
+
+Objective: Verify VRAM reclamation for Stage 0 (Thinker).
+
+Trigger sleep (Level 1 or Level 2) via client:
+
```
+
+curl -X POST http://localhost:8000/v1/omni/sleep \
+ -H "Content-Type: application/json" \
+ -d '{"stage_ids": [0], "level": 2}'
+
+```
+
+Tip: Open a new terminal and run `rocm-smi` or `nvidia-smi` to observe the immediate drop in VRAM usage.
+
+
+
+#### Scenario 2: Diffusion Sleep
+Objective: Verify VRAM reclamation for Stage 1 (Diffusion).
+
+Trigger sleep (Level 1 or Level 2) via client:
+
+```
+
+curl -X POST http://localhost:8000/v1/omni/sleep \
+ -H "Content-Type: application/json" \
+ -d '{"stage_ids": [1], "level": 2}'
+
+```
+
+
+
+#### Scenario 3: Multi-Stage Coordinated Stress Test
+Objective: Test concurrent sleep and rapid wake-up across multiple stages.
+
+Concurrent Sleep (Stage 0 & 1):
+
+```
+
+curl -X POST http://localhost:8000/v1/omni/sleep \
+ -H "Content-Type: application/json" \
+ -d '{"stage_ids": [0, 1], "level": 2}'
+
+```
+
+
+Rapid Wake-up:
+
+```
+
+curl -X POST http://localhost:8000/v1/omni/wakeup \
+ -H "Content-Type: application/json" \
+ -d '{"stage_ids": [0, 1]}'
+
+```
+
+
+#### Scenario 4: Full Lifecycle Memory Audit & Functional Integrity
+Objective: Audit the complete flow from Sleep to Wake-up followed by an Inference validation.
+
+Check Initial State: Observe baseline VRAM usage.
+
+Trigger Deep Sleep (Level 2):
+
+```
+
+curl -X POST http://localhost:8000/v1/omni/sleep \
+ -H "Content-Type: application/json" \
+ -d '{"stage_ids": [0], "level": 2}'
+
+```
+
+Wake-up Model:
+
+```
+
+curl -X POST http://localhost:8000/v1/omni/wakeup \
+ -H "Content-Type: application/json" \
+ -d '{"stage_ids": [0]}'
+
+```
+
+Verify Functional Integrity (Inference):
+Ensure the model still generates valid output after reloading weights.
+
+```
+
+curl -X POST http://localhost:8000/v1/images/generations \
+ -H "Content-Type: application/json" \
+ -d '{
+ "prompt": "A huge swimming pool, with many people swimming.",
+ "model": "ByteDance-Seed/BAGEL-7B-MoT",
+ "response_format": "b64_json",
+ "extra_body": {"sampling_params": {"num_inference_steps": 4, "seed": 42}}
+ }' > post.json
+
+```
+
+
+
+
+## 3. API Reference
+
+
+### Methods
+
+| Method | Arguments | Return Type | Description |
+| :--- | :--- | :--- | :--- |
+| **sleep** | `stage_ids: List[int], level: int` | `List[OmniACK]` | Triggers hibernation for specified stages. |
+| **wake_up** | `stage_ids: List[int]` | `List[OmniACK]` | Reloads weights and re-maps memory. |
+
+
+
+### OmniACK Dataclass Fields
+
+| Field | Type | Description |
+| :--- | :--- | :--- |
+| **task_id** | `str` | Unique identifier for the operation. |
+| **status** | `str` | `SUCCESS` or `ERROR`. |
+| **stage_id** | `int` | The ID of the stage that responded. |
+| **rank** | `int` | The rank ID within the Tensor Parallel group. |
+| **freed_bytes** | `int` | Actual amount of physical VRAM reclaimed. |
+| **metadata** | `dict` | Additional platform-specific metrics. |
+
+#### Metadata Field Analysis
+The metadata field is a dynamic dictionary containing hardware-specific telemetry and audit data, primarily used for verifying memory reclamation on various backends (e.g., AMD ROCm, NVIDIA CUDA).
+
+```
+"metadata": {
+ "source": "Platform_AMD_Instinct_MI300X",
+ "total_freed_gib": "78.57",
+ "rank_residual_gib": "2.07"
+}
+```
+
+#### Core Utility:
+**VRAM Reclamation Audit (total_freed_gib)**: Converts raw freed_bytes into human-readable GiB. It serves as the primary metric to verify that Level 2 sleep has successfully purged model weights from VRAM.
+
+**Residual & Fragmentation Monitoring (rank_residual_gib)**: Reports the remaining VRAM footprint after memory de-mapping. A low residual value (e.g., 2.07 GiB) confirms a successful "clean" state, ensuring the device is ready for high-memory co-located tasks like training or diffusion pipelines.
+
+**Backend Traceability (source)**: Identifies the underlying hardware driver or audit source. This is critical for debugging synchronization issues in multi-stage, distributed environments.
+
+**Performance Analytics (Roadmap)**: Future updates will include latency_ms (context-switch overhead) and cuda_graph_recalled (graph engine status) to optimize performance in high-frequency sleep/wake scenarios.
diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md
index 7bdeede446a..be5b2d10ca3 100644
--- a/docs/user_guide/diffusion_features.md
+++ b/docs/user_guide/diffusion_features.md
@@ -140,7 +140,7 @@ The following tables show which models support each feature:
|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:|
| **Wan2.2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (encode/decode) | ❌ | ❌ |
| **Wan2.1-VACE** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ |
-| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
+| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| **Helios** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ |
| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
diff --git a/docs/user_guide/examples/online_serving/text_to_video.md b/docs/user_guide/examples/online_serving/text_to_video.md
index 00a9c167239..b918aac19d0 100644
--- a/docs/user_guide/examples/online_serving/text_to_video.md
+++ b/docs/user_guide/examples/online_serving/text_to_video.md
@@ -288,6 +288,14 @@ vllm serve Lightricks/LTX-2 --omni --port 8098 \
--enforce-eager --flow-shift 1.0 --boundary-ratio 1.0
```
+For multi-GPU memory reduction, you can enable HSDP:
+
+```bash
+vllm serve Lightricks/LTX-2 --omni --port 8098 \
+ --enforce-eager --flow-shift 1.0 --boundary-ratio 1.0 \
+ --use-hsdp --hsdp-shard-size 2
+```
+
#### Start with Optimization Presets
Use the LTX-2 startup script with built-in optimization presets:
diff --git a/examples/offline_inference/glm_image/end2end.py b/examples/offline_inference/glm_image/end2end.py
index 13bcd23f55a..7ae9478ba75 100644
--- a/examples/offline_inference/glm_image/end2end.py
+++ b/examples/offline_inference/glm_image/end2end.py
@@ -57,22 +57,22 @@
GLM_IMAGE_VISION_VOCAB_SIZE = 16512 # top_k should be vision_vocab_size
-def compute_max_tokens(height: int, width: int, factor: int = 32) -> int:
+def compute_max_tokens(height: int, width: int, factor: int = 32, is_i2i: bool = False) -> int:
"""
Compute max_new_tokens for GLM-Image AR generation.
- GLM-Image generates tokens in this order for text-to-image:
- 1. Small preview image (half resolution in each dimension)
- 2. Large target image (full resolution)
- 3. EOS token
+ GLM-Image generation differs by mode:
+ - text-to-image (t2i): small preview + large target + EOS
+ - image-to-image (i2i): large target + EOS
Args:
height: Target image height in pixels
width: Target image width in pixels
factor: Downsampling factor (32 for GLM-Image AR output)
+ is_i2i: Whether the request is image-to-image mode
Returns:
- Total number of tokens to generate (small + large + EOS)
+ Total number of tokens to generate for the specified mode
"""
# Large image tokens (target resolution)
token_h = height // factor
@@ -80,11 +80,15 @@ def compute_max_tokens(height: int, width: int, factor: int = 32) -> int:
large_tokens = token_h * token_w
# Small preview tokens (half resolution in each dimension)
- small_h = token_h // 2
- small_w = token_w // 2
- small_tokens = small_h * small_w
+ import math
- # Total: small + large + EOS
+ ratio = token_h / token_w if token_w > 0 else 1.0
+ small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2)))
+ small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2)))
+ small_tokens = small_token_h * small_token_w
+
+ if is_i2i:
+ return large_tokens + 1
return small_tokens + large_tokens + 1
@@ -282,14 +286,18 @@ def main(args: argparse.Namespace) -> None:
# Compute max_tokens dynamically based on target image size
target_height = prompt_dict.get("height", 1024)
target_width = prompt_dict.get("width", 1024)
- calculated_max_tokens = compute_max_tokens(target_height, target_width)
+ is_i2i = source_image is not None
+ calculated_max_tokens = compute_max_tokens(target_height, target_width, is_i2i=is_i2i)
# Use calculated value unless user explicitly specified a different value
# Default args.max_tokens is 16384 (very large), so prefer calculated value
effective_max_tokens = calculated_max_tokens if args.max_tokens == 16384 else args.max_tokens
if args.verbose:
- print(f"AR max_tokens: {effective_max_tokens} (calculated: {calculated_max_tokens}, arg: {args.max_tokens})")
+ print(
+ f"AR max_tokens: {effective_max_tokens} "
+ f"(calculated: {calculated_max_tokens}, arg: {args.max_tokens}, mode: {'i2i' if is_i2i else 't2i'})"
+ )
# IMPORTANT: GLM-Image AR model requires these exact sampling parameters
# from generation_config.json for proper image token generation.
@@ -303,6 +311,12 @@ def main(args: argparse.Namespace) -> None:
stop_token_ids=[GLM_IMAGE_EOS_TOKEN_ID], # 16385, CRITICAL for stopping
seed=args.seed,
detokenize=False,
+ # Keep target size available in runner/model for deterministic M-RoPE
+ # decode grids in t2i (no mm_features available in this path).
+ extra_args={
+ "target_h": int(target_height),
+ "target_w": int(target_width),
+ },
)
# For diffusion stage, sampling_params contains diffusion-specific parameters
diff --git a/examples/offline_inference/hunyuan_image3/README.md b/examples/offline_inference/hunyuan_image3/README.md
index da28a44d9e6..3cd8fa01b2e 100644
--- a/examples/offline_inference/hunyuan_image3/README.md
+++ b/examples/offline_inference/hunyuan_image3/README.md
@@ -1,25 +1,161 @@
-# HunyuanImage-3.0 Image-to-Text Inference
+# HunyuanImage-3.0-Instruct
-This example demonstrates how to run HunyuanImage-3.0 Image-to-Text with the vLLM-Omni.
+## Set up
-## Local CLI Usage
+Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup.
-Download the example image:
+## Run examples
+
+**Note**: These examples work with the default configuration on **8x NVIDIA L40S (48GB)**. For different GPU setups, modify the stage configuration to adjust device allocation and memory utilization.
+
+Get into the hunyuan_image3 folder:
+
+```bash
+cd examples/offline_inference/hunyuan_image3
+```
+
+### Modality Control
+
+HunyuanImage-3.0-Instruct supports multiple modality modes. You can control the mode using the `--modality` argument:
+
+#### Text to Image (text2img)
+
+- **Pipeline**: Text → AR (CoT + latent tokens) → DiT (denoise) → VAE Decode → Image
+- **Stages Used**: Stage 0 (AR) + Stage 1 (DiT)
+- **KV Transfer**: AR sends KV cache to DiT for conditioned generation
+- **Default Config**: `hunyuan_image3_t2i.yaml`
+
+```bash
+python end2end.py --model tencent/HunyuanImage-3.0-Instruct \
+ --modality text2img \
+ --prompts "A cute cat sitting on a windowsill watching the sunset"
+```
+
+#### Image to Image (img2img)
+
+- **Pipeline**: Image + Text → AR (CoT + recaption + latent) → DiT → Edited Image
+- **Stages Used**: Stage 0 (AR) + Stage 1 (DiT)
+- **KV Transfer**: AR sends KV cache to DiT
+- **Default Config**: `hunyuan_image3_it2i.yaml`
+
+```bash
+python end2end.py --model tencent/HunyuanImage-3.0-Instruct \
+ --modality img2img \
+ --image-path /path/to/image.png \
+ --prompts "Make the petals neon pink"
+```
+
+#### Image to Text (img2text)
+
+- **Pipeline**: Image + Question → AR → Text description
+- **Stages Used**: Stage 0 (AR) only
+- **Default Config**: `hunyuan_image3_i2t.yaml`
```bash
-wget https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg
+python end2end.py --model tencent/HunyuanImage-3.0-Instruct \
+ --modality img2text \
+ --image-path /path/to/image.jpg \
+ --prompts "Describe the content of the picture."
```
-Run example:
+#### Text to Text (text2text)
+
+- **Pipeline**: Text → AR → Text
+- **Stages Used**: Stage 0 (AR) only
+- **Default Config**: `hunyuan_image3_t2t.yaml`
```bash
-python image_to_text.py \
- --image cherry_blossom.jpg \
- --prompt "<|startoftext|>You are an assistant that understands images and outputs text.
Describe the content of the picture."
+python end2end.py --model tencent/HunyuanImage-3.0-Instruct \
+ --modality text2text \
+ --prompts "What is the capital of France?"
```
-Key arguments:
+### Inference Steps & Guidance
+
+Control generation quality for image modalities:
+
+```bash
+python end2end.py --modality text2img \
+ --steps 50 \
+ --guidance-scale 5.0 \
+ --height 1024 --width 1024 \
+ --prompts "A photo-realistic sunset over the ocean"
+```
+
+### Key Arguments
+
+#### 📌 Command Line Arguments (end2end.py)
+
+| Argument | Type | Default | Description |
+| :--------------------- | :----- | :----------------------------------- | :----------------------------------------------------------- |
+| `--model` | string | `tencent/HunyuanImage-3.0-Instruct` | Model path or name |
+| `--modality` | choice | `text2img` | Modality: `text2img`, `img2img`, `img2text`, `text2text` |
+| `--prompts` | list | `None` | Input text prompts |
+| `--image-path` | string | `None` | Input image path (for `img2img`/`img2text`) |
+| `--output` | string | `.` | Output directory for saved images |
+| `--steps` | int | `50` | Number of inference steps |
+| `--guidance-scale` | float | `5.0` | Classifier-free guidance scale |
+| `--seed` | int | `42` | Random seed |
+| `--height` | int | `1024` | Output image height |
+| `--width` | int | `1024` | Output image width |
+| `--bot-task` | string | auto | Override prompt task (e.g. `it2i_think`, `t2i_recaption`) |
+| `--sys-type` | string | auto | Override system prompt type (e.g. `en_unified`, `en_vanilla`) |
+| `--stage-configs-path` | string | auto | Custom stage config YAML path |
+| `--enforce-eager` | flag | `False` | Disable torch.compile |
+| `--init-timeout` | int | `300` | Initialization timeout (seconds) |
+
+------
+
+#### ⚙️ Stage Configurations
+
+| Config YAML | Modality | Stages | GPUs | Description |
+| :---------------------------------- | :-------- | :----- | :----- | :------------------------------------ |
+| `hunyuan_image3_t2i.yaml` | text2img | 2 | 8 | T2I with AR→DiT, 4 GPU each |
+| `hunyuan_image3_it2i.yaml` | img2img | 2 | 8 | IT2I with AR→DiT, 4 GPU each |
+| `hunyuan_image3_i2t.yaml` | img2text | 1 | 4 | I2T (AR only) |
+| `hunyuan_image3_t2t.yaml` | text2text | 1 | 4 | T2T (AR only) |
+| `hunyuan_image3_t2i_2gpu.yaml` | text2img | 2 | 2 | T2I for 2-GPU setups |
+| `hunyuan_image3_moe.yaml` | text2img | 2 | 8 | T2I with MoE AR→DiT KV reuse |
+| `hunyuan_image3_moe_dit_2gpu_fp8.yaml` | text2img | 2 | 2 | T2I with FP8 quantization |
+
+------
+
+## Using MoE Config
+
+The `hunyuan_image3_moe.yaml` config enables AR→DiT KV cache reuse with 8 GPUs (4 for AR + 4 for DiT).
+
+```bash
+python end2end.py --model tencent/HunyuanImage-3.0-Instruct \
+ --modality text2img \
+ --stage-configs-path hunyuan_image3_moe.yaml \
+ --prompts "A cute cat"
+```
+
+------
+
+## Prompt Format
+
+HunyuanImage-3.0 uses a pretrain template format:
+
+```
+<|startoftext|>{system_prompt}{<img>}{trigger_tag}{user_prompt}
+```
+
+- `<img>`: Placeholder for each input image (auto-inserted by `prompt_utils.py`)
+- Trigger tags: `<think>` (CoT), `<recaption>` (recaptioning)
+- System prompt: Auto-selected based on task
+
+The `prompt_utils.build_prompt()` handles this formatting automatically.
+
+------
+
+## FAQ
+
+- **OOM errors**: Decrease `gpu_memory_utilization` in the YAML stage config, or use a smaller `max_num_batched_tokens`.
+- **Custom image sizes**: Use `--height` and `--width` flags (multiples of 16 recommended).
-- `--model`: Model used. Default is: tencent/HunyuanImage-3.0-Instruct (Optional).
-- `--image`: Path to input image (required).
-- `--prompt`: Text description used to guide image understanding (required).
+| Stage | VRAM (approx) |
+| :---------------- | :------------------- |
+| Stage 0 (AR) | ~15 GiB + KV Cache |
+| Stage 1 (DiT) | ~30 GiB |
+| Total (8-GPU) | ~45 GiB + KV Cache |
diff --git a/examples/offline_inference/hunyuan_image3/end2end.py b/examples/offline_inference/hunyuan_image3/end2end.py
new file mode 100644
index 00000000000..3c1ae386678
--- /dev/null
+++ b/examples/offline_inference/hunyuan_image3/end2end.py
@@ -0,0 +1,262 @@
+"""
+HunyuanImage-3.0-Instruct unified end-to-end inference script.
+
+Supports all modalities through a single entry point:
+ - text2img: Text → AR → DiT → Image
+ - img2img: Text+Image → AR → DiT → Edited Image (IT2I)
+ - img2text: Image+Text → AR → Text description (I2T)
+ - text2text: Text → AR → Text (comprehension, no image)
+
+Usage:
+ python end2end.py --modality text2img --prompts "A cute cat"
+ python end2end.py --modality img2img --image-path input.png --prompts "Make it snowy"
+ python end2end.py --modality img2text --image-path input.png --prompts "Describe this image"
+"""
+
+import argparse
+import os
+
+from vllm_omni.diffusion.models.hunyuan_image3.system_prompt import (
+ get_system_prompt,
+)
+from vllm_omni.entrypoints.omni import Omni
+from vllm_omni.inputs.data import OmniPromptType
+
+# Prompt presets keyed by task name.
+# task → (sys_type, bot_task, trigger_tag)
+#   sys_type:    system-prompt type handed to get_system_prompt()
+#   bot_task:    bot-task value handed to get_system_prompt() (None for t2t / i2t)
+#   trigger_tag: mode tag placed right before the user prompt:
+#                "<think>" (CoT), "<recaption>" (recaptioning), or None
+_TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = {
+    "t2t": ("en_unified", None, None),
+    "i2t": ("en_unified", None, None),
+    "it2i_think": ("en_unified", "think", "<think>"),
+    "it2i_recaption": ("en_unified", "recaption", "<recaption>"),
+    "t2i_think": ("en_unified", "think", "<think>"),
+    "t2i_recaption": ("en_unified", "recaption", "<recaption>"),
+    "t2i_vanilla": ("en_vanilla", "image", None),
+}
+
+# Modality → default build_prompt() task (used when --bot-task is not given)
+_MODALITY_TASK_MAP = {
+    "text2img": "t2i_think",
+    "img2img": "it2i_think",
+    "img2text": "i2t",
+    "text2text": "t2t",
+}
+
+
+def build_prompt(
+    user_prompt: str,
+    task: str = "it2i_think",
+    sys_type: str | None = None,
+    custom_system_prompt: str | None = None,
+) -> str:
+    """Build a HunyuanImage-3.0 prompt using pretrain template format.
+
+    Layout: ``<|startoftext|>{system_prompt}{<img>}{trigger_tag}{user_prompt}``,
+    where the system prompt, image placeholder and trigger tag are each
+    included only when they apply to ``task``.
+
+    Args:
+        user_prompt: Raw user instruction or question; always appended last.
+        task: Key into ``_TASK_PRESETS``.
+        sys_type: Optional override for the preset's system-prompt type.
+        custom_system_prompt: Forwarded unchanged to ``get_system_prompt()``.
+
+    Returns:
+        The concatenated prompt string.
+
+    Raises:
+        ValueError: If ``task`` is not a key of ``_TASK_PRESETS``.
+    """
+    if task not in _TASK_PRESETS:
+        raise ValueError(f"Unknown task {task!r}. Choose from: {sorted(_TASK_PRESETS)}")
+
+    preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task]
+    effective_sys_type = sys_type or preset_sys_type
+
+    system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt)
+    sys_text = system_prompt.strip() if system_prompt else ""
+
+    # Only image-conditioned tasks (i2t*, it2i*) carry an image placeholder.
+    has_image_input = task.startswith(("i2t", "it2i"))
+
+    parts = ["<|startoftext|>"]
+    if sys_text:
+        parts.append(sys_text)
+    if has_image_input:
+        parts.append("<img>")
+    if trigger_tag:
+        parts.append(trigger_tag)
+    parts.append(user_prompt)
+
+    return "".join(parts)
+
+
+# Modality → default stage config
+# main() falls back to this YAML name when --stage-configs-path is not given.
+_MODALITY_DEFAULT_CONFIG = {
+    "text2img": "hunyuan_image3_t2i.yaml",
+    "img2img": "hunyuan_image3_it2i.yaml",
+    "img2text": "hunyuan_image3_i2t.yaml",
+    "text2text": "hunyuan_image3_t2t.yaml",
+}
+
+
+def parse_args():
+    """Declare the script's command-line interface and parse ``sys.argv``."""
+    cli = argparse.ArgumentParser(description="HunyuanImage-3.0-Instruct end-to-end inference.")
+
+    # Model, task selection and I/O locations
+    cli.add_argument("--model", default="tencent/HunyuanImage-3.0-Instruct", help="Model name or local path.")
+    cli.add_argument(
+        "--modality",
+        default="text2img",
+        choices=["text2img", "img2img", "img2text", "text2text"],
+        help="Modality mode to control stage execution.",
+    )
+    cli.add_argument("--prompts", nargs="+", default=None, help="Input text prompts.")
+    cli.add_argument("--image-path", type=str, default=None, help="Path to input image (for img2img/img2text).")
+    cli.add_argument("--output", type=str, default=".", help="Output directory to save results.")
+
+    # Generation parameters, registered in this order: (flag, type, default, help)
+    generation_options = (
+        ("--steps", int, 50, "Number of inference steps."),
+        ("--guidance-scale", float, 5.0, "Classifier-free guidance scale."),
+        ("--seed", int, 42, "Random seed."),
+        ("--height", int, 1024, "Output image height."),
+        ("--width", int, 1024, "Output image width."),
+    )
+    for flag, value_type, default_value, help_text in generation_options:
+        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
+
+    # Prompt configuration
+    bot_task_help = "Override prompt task (e.g. it2i_think, t2i_recaption). Default: auto from modality."
+    cli.add_argument("--bot-task", type=str, default=None, help=bot_task_help)
+    sys_type_help = "Override system prompt type (e.g. en_unified, en_vanilla)."
+    cli.add_argument("--sys-type", type=str, default=None, help=sys_type_help)
+
+    # Omni init args
+    cli.add_argument(
+        "--stage-configs-path",
+        type=str,
+        default=None,
+        help="Custom stage config YAML path.",
+    )
+    cli.add_argument("--log-stats", action="store_true", default=False)
+    cli.add_argument(
+        "--init-timeout",
+        type=int,
+        default=300,
+        help="Initialization timeout in seconds.",
+    )
+    cli.add_argument("--enforce-eager", action="store_true", help="Disable torch.compile.")
+
+    return cli.parse_args()
+
+
+def main():
+    """Run one end-to-end HunyuanImage-3.0 generation driven by CLI arguments.
+
+    Steps: parse arguments, validate and prepare all inputs (prompts, optional
+    input image, formatted requests), build the ``Omni`` engine, apply CLI
+    overrides to the diffusion sampling params, generate, then print any text
+    output and save any returned images as ``output_<idx>_<j>.png`` under
+    ``--output``.
+
+    Inputs are prepared before the engine is built so that a bad
+    ``--image-path`` or ``--bot-task`` fails before model initialization.
+
+    Raises:
+        ValueError: If an image modality is selected and ``--image-path`` is
+            missing or does not exist, or if the task is unknown to
+            ``build_prompt()``.
+    """
+    args = parse_args()
+    os.makedirs(args.output, exist_ok=True)
+
+    # Determine task for prompt formatting
+    task = args.bot_task or _MODALITY_TASK_MAP[args.modality]
+
+    # Determine stage config
+    stage_configs_path = args.stage_configs_path or _MODALITY_DEFAULT_CONFIG[args.modality]
+
+    # Prepare prompts; fall back to a default when --prompts is omitted.
+    prompts = args.prompts
+    if not prompts:
+        print("[Info] No prompts provided, using default.")
+        prompts = ["A cute cat"]
+
+    # Load image if needed
+    input_image = None
+    if args.modality in ("img2img", "img2text"):
+        if not args.image_path or not os.path.exists(args.image_path):
+            raise ValueError(f"--image-path required for {args.modality}, got: {args.image_path}")
+        from PIL import Image
+
+        input_image = Image.open(args.image_path).convert("RGB")
+
+    # Format prompts: one request dict per user prompt
+    formatted_prompts: list[OmniPromptType] = []
+    for p in prompts:
+        formatted_text = build_prompt(p, task=task, sys_type=args.sys_type)
+
+        prompt_dict: dict = {"prompt": formatted_text}
+
+        if args.modality == "text2img":
+            prompt_dict["modalities"] = ["image"]
+        elif args.modality == "img2img":
+            prompt_dict["modalities"] = ["image"]
+            prompt_dict["multi_modal_data"] = {"image": input_image}
+            # Edited output keeps the input image's size.
+            prompt_dict["height"] = input_image.height
+            prompt_dict["width"] = input_image.width
+        elif args.modality == "img2text":
+            prompt_dict["modalities"] = ["text"]
+            prompt_dict["multi_modal_data"] = {"image": input_image}
+        elif args.modality == "text2text":
+            prompt_dict["modalities"] = ["text"]
+
+        formatted_prompts.append(prompt_dict)
+
+    # Build Omni
+    omni_kwargs = {
+        "model": args.model,
+        "stage_configs_path": stage_configs_path,
+        "log_stats": args.log_stats,
+        "init_timeout": args.init_timeout,
+        "enforce_eager": args.enforce_eager,
+    }
+    if args.modality in ("text2img", "img2img"):
+        omni_kwargs["mode"] = "text-to-image"
+
+    omni = Omni(**omni_kwargs)
+
+    # Build sampling params from defaults
+    params_list = list(omni.default_sampling_params_list)
+
+    # Override diffusion params if applicable
+    from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+    for sp in params_list:
+        if isinstance(sp, OmniDiffusionSamplingParams):
+            sp.num_inference_steps = args.steps
+            sp.guidance_scale = args.guidance_scale
+            if args.seed is not None:
+                sp.seed = args.seed
+            # --height/--width only apply to text2img; img2img takes its size from the input image.
+            if args.modality in ("text2img",):
+                sp.height = args.height
+                sp.width = args.width
+
+    # Print configuration
+    print(f"\n{'=' * 60}")
+    print("HunyuanImage-3.0 Generation Configuration:")
+    print(f" Model: {args.model}")
+    print(f" Modality: {args.modality}")
+    print(f" Stage config: {stage_configs_path}")
+    print(f" Num stages: {omni.num_stages}")
+    if args.modality in ("text2img", "img2img"):
+        print(f" Inference steps: {args.steps}")
+        print(f" Guidance scale: {args.guidance_scale}")
+        print(f" Seed: {args.seed}")
+    if args.modality == "text2img":
+        print(f" Output size: {args.width}x{args.height}")
+    if args.image_path:
+        print(f" Input image: {args.image_path}")
+    print(f" Prompts: {prompts}")
+    print(f"{'=' * 60}\n")
+
+    # Generate
+    omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))
+
+    # Process outputs
+    img_idx = 0
+    for req_output in omni_outputs:
+        # Text output (AR stage or text-only)
+        ro = getattr(req_output, "request_output", None)
+        if ro and getattr(ro, "outputs", None):
+            txt = "".join(getattr(o, "text", "") or "" for o in ro.outputs)
+            if txt:
+                print(f"[Output] Text:\n{txt}")
+
+        # Image output (DiT stage); fall back to images on the nested request output.
+        images = getattr(req_output, "images", None)
+        if not images and ro and hasattr(ro, "images"):
+            images = ro.images
+
+        if images:
+            for j, img in enumerate(images):
+                save_path = os.path.join(args.output, f"output_{img_idx}_{j}.png")
+                img.save(save_path)
+                print(f"[Output] Saved image to {save_path}")
+            # NOTE(review): indentation was lost in this copy of the patch; the counter is
+            # assumed to advance once per output that produced images — confirm.
+            img_idx += 1
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/offline_inference/hunyuan_image3/image_to_text.py b/examples/offline_inference/hunyuan_image3/image_to_text.py
deleted file mode 100644
index d40134ac0a0..00000000000
--- a/examples/offline_inference/hunyuan_image3/image_to_text.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import argparse
-import os
-
-from PIL import Image
-
-from vllm_omni.entrypoints.omni import Omni
-
-"""
-The tencent/HunyuanImage-3.0-Instruct base model uses the tencent/Hunyuan-A13B-Instruct backbone. It utilizes two tokenizer delimiter templates:
-
-1) Pretrained template (default for gen_text mode), which concatenates system, image
- tokens, and user question WITHOUT role delimiters:
-"<|startoftext|>{system_prompt}{image_tokens}{user_question}"
-
- Example (before image token expansion):
-"<|startoftext|>You are an assistant that understands images and outputs text.
Describe the content of the picture."
-
-2) Instruct template, which uses explicit role prefixes and separators.
-"""
-
-
-def parse_args() -> argparse.Namespace:
- parser = argparse.ArgumentParser(description="Generate text from image using HunyuanImage-3.0-Instruct.")
- parser.add_argument(
- "--model",
- default="tencent/HunyuanImage-3.0-Instruct",
- help="Model name or local path.",
- )
- parser.add_argument(
- "--image",
- type=str,
- required=True,
- help="Path to input image file (PNG, JPG, etc.).",
- )
- parser.add_argument(
- "--prompt",
- type=str,
- required=True,
- help="Pretrain template prompt: <|startoftext|>{system}
{question}",
- )
- parser.add_argument(
- "--enable-diffusion-pipeline-profiler",
- action="store_true",
- help="Enable diffusion pipeline profiler to display stage durations.",
- )
- return parser.parse_args()
-
-
-def load_image(image_path: str) -> Image.Image:
- """Load an image from file path."""
- if not os.path.exists(image_path):
- raise FileNotFoundError(f"Image file not found: {image_path}")
- return Image.open(image_path).convert("RGB")
-
-
-def main(args: argparse.Namespace) -> None:
- omni = Omni(
- model=args.model,
- enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
- mode="image-to-text",
- )
-
- prompt = "<|startoftext|>You are an assistant that understands images and outputs text.
" + args.prompt
-
- prompt_dict = {
- "prompt": prompt,
- "modalities": ["text"],
- }
-
- # Add image input if provided
- if args.image:
- if not os.path.exists(args.image):
- raise FileNotFoundError(f"Input image not found: {args.image}")
-
- input_image = load_image(args.image)
- prompt_dict["multi_modal_data"] = {"image": input_image}
-
- prompts = [prompt_dict]
- omni_outputs = omni.generate(prompts=prompts)
-
- prompt_text = omni_outputs[0].request_output.prompt
- generated_text = omni_outputs[0].request_output.outputs[0].text
- print(f"Prompt: {prompt_text}")
- print(f"Text: {generated_text}")
-
-
-if __name__ == "__main__":
- args = parse_args()
- main(args)
diff --git a/examples/offline_inference/hunyuan_image3/prompt_utils.py b/examples/offline_inference/hunyuan_image3/prompt_utils.py
deleted file mode 100644
index a5ef8e15369..00000000000
--- a/examples/offline_inference/hunyuan_image3/prompt_utils.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-Prompt construction utilities for HunyuanImage-3.0-Instruct examples.
-
-Wraps system_prompt.get_system_prompt() with task-aware presets so that
-examples and tests don't need to manually concatenate system prompts,
-
, , and tags.
-
-Usage:
- from prompt_utils import build_prompt
-
- # IT2I (image editing, think+recaption mode)
- prompt = build_prompt("Make the petals neon pink", task="it2i_think")
-
- # I2T (image understanding)
- prompt = build_prompt("Describe the content of the picture.", task="i2t")
-"""
-
-from __future__ import annotations
-
-from vllm_omni.diffusion.models.hunyuan_image3.system_prompt import (
- get_system_prompt,
-)
-
-# task → (sys_type, bot_task, trigger_tag)
-# trigger_tag: "", "", or None
-_TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = {
- # Pure text generation (text → text, no image)
- "t2t": ("en_unified", None, None),
- # Image understanding (image → text)
- "i2t": ("en_unified", None, None),
- # Image editing (image+text → image), think+recaption mode
- "it2i_think": ("en_unified", "think", ""),
- # Image editing, recaption-only mode
- "it2i_recaption": ("en_unified", "recaption", ""),
- # Text-to-image, think mode
- "t2i_think": ("en_unified", "think", ""),
- # Text-to-image, recaption mode
- "t2i_recaption": ("en_unified", "recaption", ""),
- # Text-to-image, vanilla (no CoT)
- "t2i_vanilla": ("en_vanilla", "image", None),
-}
-
-
-def build_prompt(
- user_prompt: str,
- task: str = "it2i_think",
- sys_type: str | None = None,
- custom_system_prompt: str | None = None,
-) -> str:
- """Build a complete HunyuanImage-3.0 prompt with auto-selected system
- prompt and mode trigger tags.
-
- Args:
- user_prompt: The user's raw instruction or question.
- task: One of the preset task keys (see _TASK_PRESETS).
- sys_type: Override the preset's sys_type for get_system_prompt().
- custom_system_prompt: Custom system prompt text (used when
- sys_type="custom").
-
- Returns:
- Fully formatted prompt string ready for Omni.generate().
- """
- if task not in _TASK_PRESETS:
- raise ValueError(f"Unknown task {task!r}. Choose from: {sorted(_TASK_PRESETS)}")
-
- preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task]
- effective_sys_type = sys_type or preset_sys_type
-
- system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt)
- sys_text = system_prompt.strip() if system_prompt else ""
-
- has_image_input = task.startswith("i2t") or task.startswith("it2i")
-
- parts = ["<|startoftext|>"]
- if sys_text:
- parts.append(sys_text)
- # Instruct conversation template: \n\nUser: ... \n\nAssistant:
- parts.append("\n\nUser: ")
- if has_image_input:
- parts.append("
")
- parts.append(user_prompt)
- parts.append("\n\nAssistant: ")
- if trigger_tag:
- parts.append(trigger_tag)
-
- return "".join(parts)
diff --git a/pyproject.toml b/pyproject.toml
index 9b034a7c8e9..012bcd47c41 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -237,3 +237,4 @@ ue = "ue"
semantics = "semantics"
fullset = "fullset"
Vai = "Vai"
+CANN = "CANN"
diff --git a/recipes/README.md b/recipes/README.md
index 5b3dfb5430b..01ecc41f185 100644
--- a/recipes/README.md
+++ b/recipes/README.md
@@ -27,6 +27,8 @@ recipes/
- [`Qwen/Qwen3-Omni.md`](./Qwen/Qwen3-Omni.md): online serving recipe for
multimodal chat on `1x A100 80GB`
+- [`Wan-AI/Wan2.2-I2V.md`](./Wan-AI/Wan2.2-I2V.md): image-to-video serving
+ recipe for Wan2.2 14B on `8x Ascend NPU (A2/A3)`
Within a single recipe file, include different hardware support sections such
as `GPU`, `ROCm`, and `NPU`, and add concrete tested configurations like
diff --git a/recipes/Wan-AI/Wan2.2-I2V.md b/recipes/Wan-AI/Wan2.2-I2V.md
new file mode 100644
index 00000000000..99ceac3cebe
--- /dev/null
+++ b/recipes/Wan-AI/Wan2.2-I2V.md
@@ -0,0 +1,136 @@
+# Wan2.2 Image To Video
+
+## Summary
+
+- Vendor: Wan-AI
+- Model: `Wan-AI/Wan2.2-I2V-A14B-Diffusers`
+- Task: Image-to-video generation
+- Mode: Online serving with the OpenAI-compatible API
+- Maintainer: Community
+
+## When to use this recipe
+
+Use this recipe when you want to deploy the Wan2.2 14B image-to-video model
+with vLLM-Omni using multi-card parallelism. Two configurations are provided:
+
+1. **Distilled model (no negative-prompt / CFG computation)** — higher
+ throughput, recommended when using a distilled checkpoint that does not
+ require classifier-free guidance.
+2. **Official open-source model (with CFG)** — uses `--cfg 2` to run negative
+ and positive samples in parallel for the original released weights.
+
+## References
+
+- Upstream model card: <https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers>
+
+## Hardware Support
+
+## NPU
+
+### 8x Ascend A2 / A3
+
+#### Environment
+
+- OS: Linux
+- Python: 3.10+
+- Driver / runtime: Ascend NPU driver with CANN toolkit
+- Recommended operator library: **mindie-sd** (Ascend high-performance fused
+ operators — enables `adalayernorm` and other fused kernels automatically upon
+ installation)
+- vLLM version: Match the repository requirements for your checkout
+- vLLM-Omni version or commit: Use the commit you are deploying from
+
+#### Prerequisites
+
+Install the **mindie-sd** operator library to enable Ascend-optimized fused
+operators (`adalayernorm`, etc.):
+
+```bash
+git clone https://gitcode.com/Ascend/MindIE-SD.git && cd MindIE-SD
+
+# Comment out the tik_ops build step (not needed for this use case)
+sed -i 's|^\(\s*\)source ${current_script_dir}/build_tik_ops.sh|\1# source ${current_script_dir}/build_tik_ops.sh|' build/build_ops.sh
+
+python setup.py bdist_wheel
+cd dist
+pip install mindiesd-*.whl
+```
+
+After installation, enable the Laser Attention kernel for significant
+long-sequence speedups (up to ~40% at 720p in tested workloads):
+
+```bash
+export MINDIE_SD_FA_TYPE=ascend_laser_attention
+```
+
+When using HSDP with FSDP2, set the following environment variable to work
+around a PyTorch NPU multi-stream memory reuse issue
+([pytorch/pytorch#147168](https://github.com/pytorch/pytorch/issues/147168)).
+This issue has been fixed on CUDA but still applies to NPU:
+
+```bash
+export MULTI_STREAM_MEMORY_REUSE=2
+```
+
+#### Command
+
+**Distilled model (no CFG, recommended for distilled checkpoints):**
+
+```bash
+export MINDIE_SD_FA_TYPE=ascend_laser_attention
+export MULTI_STREAM_MEMORY_REUSE=2
+
+vllm serve \
+ --omni Wan-AI/Wan2.2-I2V-A14B-Diffusers \
+ --use-hsdp \
+ --usp 8 \
+ --vae-patch-parallel-size 8 \
+ --vae-use-tiling
+```
+
+**Official open-source model (with CFG):**
+
+```bash
+export MINDIE_SD_FA_TYPE=ascend_laser_attention
+export MULTI_STREAM_MEMORY_REUSE=2
+
+vllm serve \
+ --omni Wan-AI/Wan2.2-I2V-A14B-Diffusers \
+ --use-hsdp \
+ --usp 4 \
+ --cfg 2 \
+ --vae-patch-parallel-size 8 \
+ --vae-use-tiling
+```
+
+> **Why the difference?** With `--cfg 2`, two copies of the input (positive and
+> negative prompts) are processed in parallel, effectively doubling the compute
+> for the DiT backbone. USP is therefore halved from 8 to 4 so that the total
+> parallelism across the 8 cards remains balanced (`usp * cfg = 8`).
+
+#### Verification
+
+After the server is ready, see
+[`examples/online_serving/image_to_video/README.md`](../../examples/online_serving/image_to_video/README.md)
+for complete client examples and request formats.
+
+#### Notes
+
+- **Key flags:**
+ - `--omni` — enables vLLM-Omni diffusion serving.
+ - `--use-hsdp` — enables Hybrid Sharded Data Parallelism for the DiT model
+ weights.
+ - `--usp <N>` — Unified Sequence Parallelism degree.
+ - `--cfg <N>` — Classifier-Free Guidance parallelism; set to 2 for models
+ that require negative-prompt computation, omit for distilled models.
+ - `--vae-patch-parallel-size 8` — parallelizes VAE decoding across all 8
+ cards.
+ - `--vae-use-tiling` — enables tiled VAE decoding to reduce peak memory.
+- **Performance tips:**
+ - Installing mindie-sd and enabling Laser Attention
+ (`MINDIE_SD_FA_TYPE=ascend_laser_attention`) provides up to ~40%
+ performance improvement at 720p resolution due to long-sequence attention
+ optimization.
+- **Known limitations:**
+ - `MULTI_STREAM_MEMORY_REUSE=2` is required on NPU when using HSDP/FSDP2
+ due to a multi-stream memory reuse bug. This is not needed on CUDA.
diff --git a/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json b/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
index 5ec7f1cc2b6..cdd0cac2c03 100644
--- a/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
+++ b/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json
@@ -105,22 +105,6 @@
}
},
"benchmark_params": [
- {
- "name": "512x512_steps20",
- "dataset": "random",
- "task": "t2i",
- "width": 512,
- "height": 512,
- "num-inference-steps": 20,
- "num-prompts": 10,
- "max-concurrency": 1,
- "enable-negative-prompt": true,
- "baseline": {
- "throughput_qps": 0.1,
- "latency_mean": 2.7,
- "peak_memory_mb_mean": 61000
- }
- },
{
"name": "1536x1536_steps35",
"dataset": "random",
diff --git a/tests/diffusion/models/dmd2/__init__.py b/tests/diffusion/models/dmd2/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py b/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py
new file mode 100644
index 00000000000..e270390bd99
--- /dev/null
+++ b/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py
@@ -0,0 +1,180 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline, LTX2T2VDMD2Pipeline
+from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline
+from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+# DMD2 subclass → immediate base pipeline whose __init__ loads model weights (mocked in tests).
+# The keys also parametrize the `pipeline` fixture, whose `ids` list is positional, so the
+# insertion order here must stay aligned with it.
+_DMD2_BASE = {
+    WanT2VDMD2Pipeline: Wan22Pipeline,
+    WanI2VDMD2Pipeline: Wan22I2VPipeline,
+    LTX2T2VDMD2Pipeline: LTX2Pipeline,
+    LTX2I2VDMD2Pipeline: LTX2ImageToVideoPipeline,
+}
+
+
+def _make_pipeline(cls):
+    """Run the DMD2 __init__ with the base pipeline mocked out (no model weights loaded).
+
+    Args:
+        cls: One of the DMD2 pipeline classes used as keys of ``_DMD2_BASE``.
+
+    Returns:
+        An instance of ``cls`` built while its base class ``__init__`` was
+        replaced by a stub.
+    """
+
+    base = _DMD2_BASE[cls]
+    od_config = MagicMock()
+    od_config.model = "/nonexistent"  # placeholder path; presumably never opened since base init is stubbed
+
+    # Stand-in for the base __init__: it only records od_config on the instance.
+    def _mock_base_init(self, *a, **kw):
+        self.od_config = od_config
+
+    with patch.object(base, "__init__", _mock_base_init):
+        # Allocate without running any constructor, initialise nn.Module bookkeeping by hand,
+        # then run the DMD2 subclass __init__ against the patched base.
+        pipeline = object.__new__(cls)
+        torch.nn.Module.__init__(pipeline)
+        cls.__init__(pipeline, od_config=od_config)
+    return pipeline
+
+
+def _make_request(prompts=None, **sp_kwargs) -> OmniDiffusionRequest:
+    """Build a request whose sampling params are created from ``sp_kwargs``.
+
+    Args:
+        prompts: Prompt list for the request; a falsy value (``None`` or an
+            empty list) is replaced by a single default dict prompt.
+        **sp_kwargs: Forwarded to ``OmniDiffusionSamplingParams``.
+    """
+    sp = OmniDiffusionSamplingParams(**sp_kwargs)
+    return OmniDiffusionRequest(
+        prompts=prompts or [{"prompt": "a cat dancing"}],
+        sampling_params=sp,
+    )
+
+
+@pytest.fixture(
+    params=list(_DMD2_BASE.keys()),
+    ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v"],  # positional: must follow _DMD2_BASE key order
+)
+def pipeline(request):
+    """Parametrized fixture: one mocked instance per DMD2 pipeline class."""
+    return _make_pipeline(request.param)
+
+
+# ---------------------------------------------------------------------------
+# num_inference_steps
+# ---------------------------------------------------------------------------
+
+
+def test_num_inference_steps_forced_to_dmd2_value(pipeline):
+    """A request asking for 40 steps must end up with the pipeline's own step count."""
+    req = _make_request(num_inference_steps=40)
+    pipeline._sanitize_dmd2_request(req)
+    assert req.sampling_params.num_inference_steps == pipeline.num_inference_steps
+
+
+def test_num_inference_steps_already_correct(pipeline):
+    """A request that already uses the pipeline's step count keeps that value."""
+    req = _make_request(num_inference_steps=pipeline.num_inference_steps)
+    pipeline._sanitize_dmd2_request(req)
+    assert req.sampling_params.num_inference_steps == pipeline.num_inference_steps
+
+
+# ---------------------------------------------------------------------------
+# guidance_scale
+# ---------------------------------------------------------------------------
+
+
+def test_guidance_scale_forced_to_one(pipeline):
+    """A user-supplied guidance scale is replaced and its 'provided' flag is cleared."""
+    req = _make_request(guidance_scale=5.0, guidance_scale_provided=True)
+    pipeline._sanitize_dmd2_request(req)
+    # NOTE(review): the test name says "one", but the check is against
+    # pipeline.dmd2_guidance_scale — presumably 1.0; confirm.
+    assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale
+    assert req.sampling_params.guidance_scale_provided is False
+
+
+def test_guidance_scale_already_correct(pipeline):
+    """A request already at the pipeline's DMD2 guidance scale keeps that value."""
+    req = _make_request(guidance_scale=pipeline.dmd2_guidance_scale, guidance_scale_provided=False)
+    pipeline._sanitize_dmd2_request(req)
+    assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale
+
+
+def test_guidance_scale_provided_flag_cleared(pipeline):
+    """guidance_scale_provided=True must be cleared even if scale is already dmd2_guidance_scale."""
+    # The scale itself needs no correction here; only the flag is under test.
+    req = _make_request(guidance_scale=pipeline.dmd2_guidance_scale, guidance_scale_provided=True)
+    pipeline._sanitize_dmd2_request(req)
+    assert req.sampling_params.guidance_scale_provided is False
+
+
+def test_guidance_scale_2_cleared(pipeline):
+    """A user-supplied guidance_scale_2 is reset to None."""
+    req = _make_request(guidance_scale_2=3.0)
+    pipeline._sanitize_dmd2_request(req)
+    assert req.sampling_params.guidance_scale_2 is None
+
+
+def test_guidance_scale_2_unset_unchanged(pipeline):
+    """When guidance_scale_2 is not passed, it is still None after sanitization."""
+    req = _make_request()
+    pipeline._sanitize_dmd2_request(req)
+    assert req.sampling_params.guidance_scale_2 is None
+
+
+def test_true_cfg_scale_cleared(pipeline):
+    """A user-supplied true_cfg_scale is reset to None."""
+    req = _make_request(true_cfg_scale=2.0)
+    pipeline._sanitize_dmd2_request(req)
+    assert req.sampling_params.true_cfg_scale is None
+
+
+def test_do_classifier_free_guidance_forced_false(pipeline):
+    """do_classifier_free_guidance=True on the request is turned off."""
+    req = _make_request(do_classifier_free_guidance=True)
+    pipeline._sanitize_dmd2_request(req)
+    assert req.sampling_params.do_classifier_free_guidance is False
+
+
+def test_is_cfg_negative_forced_false(pipeline):
+    """is_cfg_negative=True on the request is turned off."""
+    req = _make_request(is_cfg_negative=True)
+    pipeline._sanitize_dmd2_request(req)
+    assert req.sampling_params.is_cfg_negative is False
+
+
+def test_negative_prompt_stripped_from_prompt_dict(pipeline):
+    """The negative_prompt key is removed from a dict prompt; the positive prompt is kept."""
+    req = _make_request(prompts=[{"prompt": "a cat", "negative_prompt": "blurry"}])
+    pipeline._sanitize_dmd2_request(req)
+    assert "negative_prompt" not in req.prompts[0]
+    assert req.prompts[0]["prompt"] == "a cat"
+
+
+def test_no_negative_prompt_unchanged(pipeline):
+    """A dict prompt without negative_prompt comes out exactly as it went in."""
+    req = _make_request(prompts=[{"prompt": "a cat"}])
+    pipeline._sanitize_dmd2_request(req)
+    assert req.prompts[0] == {"prompt": "a cat"}
+
+
+def test_string_prompt_not_mutated(pipeline):
+    """String prompts (not dicts) must pass through unchanged."""
+    # Guards the non-dict branch: sanitization must not assume every prompt is a dict.
+    req = _make_request(prompts=["a cat dancing"])
+    pipeline._sanitize_dmd2_request(req)
+    assert req.prompts == ["a cat dancing"]
+
+
+def test_multiple_prompts_all_sanitized(pipeline):
+    """negative_prompt is removed from every prompt in the request, not just the first."""
+    req = _make_request(
+        prompts=[
+            {"prompt": "a cat", "negative_prompt": "blurry"},
+            {"prompt": "a dog", "negative_prompt": "ugly"},
+        ]
+    )
+    pipeline._sanitize_dmd2_request(req)
+    for p in req.prompts:
+        assert "negative_prompt" not in p
+
+
+# ---------------------------------------------------------------------------
+# Clean request — nothing changes
+# ---------------------------------------------------------------------------
+
+
+def test_clean_request_no_changes(pipeline):
+    """A request that already satisfies every DMD2 constraint keeps all of those values."""
+    req = _make_request(
+        guidance_scale=pipeline.dmd2_guidance_scale,
+        guidance_scale_provided=False,
+        do_classifier_free_guidance=False,
+        is_cfg_negative=False,
+    )
+    pipeline._sanitize_dmd2_request(req)
+    assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale
+    assert req.sampling_params.guidance_scale_provided is False
+    # guidance_scale_2 and true_cfg_scale were never passed, so they are expected to stay None.
+    assert req.sampling_params.guidance_scale_2 is None
+    assert req.sampling_params.true_cfg_scale is None
+    assert req.sampling_params.do_classifier_free_guidance is False
+    assert req.sampling_params.is_cfg_negative is False
diff --git a/tests/diffusion/models/dmd2/test_dmd2_scheduler.py b/tests/diffusion/models/dmd2/test_dmd2_scheduler.py
new file mode 100644
index 00000000000..32d00dbf18e
--- /dev/null
+++ b/tests/diffusion/models/dmd2/test_dmd2_scheduler.py
@@ -0,0 +1,93 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline, LTX2T2VDMD2Pipeline
+from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline
+from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+# Timestep schedule every DMD2 pipeline's scheduler is expected to report in these tests.
+_DMD2_TIMESTEPS = [999, 937, 833, 624]
+
+# DMD2 subclass → immediate base pipeline whose __init__ loads model weights (mocked in tests).
+# The keys also parametrize the `pipeline` fixture, whose `ids` list is positional, so the
+# insertion order here must stay aligned with it.
+_DMD2_BASE = {
+    WanT2VDMD2Pipeline: Wan22Pipeline,
+    WanI2VDMD2Pipeline: Wan22I2VPipeline,
+    LTX2T2VDMD2Pipeline: LTX2Pipeline,
+    LTX2I2VDMD2Pipeline: LTX2ImageToVideoPipeline,
+}
+
+
+def _make_pipeline(cls):
+    """Run the DMD2 __init__ (including __init_dmd2__) with the base pipeline mocked.
+
+    Args:
+        cls: One of the DMD2 pipeline classes used as keys of ``_DMD2_BASE``.
+
+    Returns:
+        An instance of ``cls`` built while its base class ``__init__`` was
+        replaced by a stub.
+    """
+
+    base = _DMD2_BASE[cls]
+    od_config = MagicMock()
+    od_config.model = "/nonexistent"  # placeholder path; presumably never opened since base init is stubbed
+
+    def _mock_base_init(self, *a, **kw):
+        self.od_config = od_config  # __init_dmd2__ needs this
+
+    with patch.object(base, "__init__", _mock_base_init):
+        # Allocate without running any constructor, initialise nn.Module bookkeeping by hand,
+        # then run the DMD2 subclass __init__ against the patched base.
+        pipeline = object.__new__(cls)
+        torch.nn.Module.__init__(pipeline)
+        cls.__init__(pipeline, od_config=od_config)
+    return pipeline
+
+
+def _make_request(**sp_kwargs) -> OmniDiffusionRequest:
+    """Build a single-prompt request; ``sp_kwargs`` go to ``OmniDiffusionSamplingParams``."""
+    sp = OmniDiffusionSamplingParams(**sp_kwargs)
+    return OmniDiffusionRequest(prompts=[{"prompt": "a cat"}], sampling_params=sp)
+
+
+@pytest.fixture(
+    params=list(_DMD2_BASE.keys()),
+    ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v"],  # positional: must follow _DMD2_BASE key order
+)
+def pipeline(request):
+    """Parametrized fixture: one mocked instance per DMD2 pipeline class."""
+    return _make_pipeline(request.param)
+
+
+# ---------------------------------------------------------------------------
+# forward() timestep injection
+# ---------------------------------------------------------------------------
+
+
+def _fake_parent_forward(self, req, *args, num_inference_steps=40, **kwargs):
+    """Stub that calls set_timesteps as the real parent does."""
+    # Defaults to 40 steps when the caller passes none; the tests then check which
+    # schedule the scheduler actually ends up with.
+    self.scheduler.set_timesteps(num_inference_steps, device="cpu")
+    return MagicMock()
+
+
+def test_forward_timesteps_match_dmd2_schedule(pipeline):
+    """After forward() runs, scheduler.timesteps must equal the DMD2 training schedule."""
+    parent = _DMD2_BASE[type(pipeline)]
+
+    # Baseline: even when asked directly for 40 steps, the scheduler is expected to
+    # report the fixed DMD2 schedule rather than a 40-step one.
+    pipeline.scheduler.set_timesteps(40, device="cpu")
+    default_timesteps = pipeline.scheduler.timesteps.long().tolist()
+    assert default_timesteps == _DMD2_TIMESTEPS, (
+        "DMD2EulerScheduler should always return DMD2 timesteps regardless of num_steps"
+    )
+
+    # Swap the parent's forward for a stub that only calls set_timesteps.
+    with patch.object(parent, "forward", _fake_parent_forward):
+        pipeline.forward(_make_request())
+
+    assert pipeline.scheduler.timesteps.long().tolist() == _DMD2_TIMESTEPS
+
+
+def test_forward_timesteps_idempotent_across_calls(pipeline):
+    """Successive forward() calls must not cause scheduler state to drift."""
+    parent = _DMD2_BASE[type(pipeline)]
+
+    # Two back-to-back calls through the stubbed parent; only the final schedule is checked.
+    with patch.object(parent, "forward", _fake_parent_forward):
+        pipeline.forward(_make_request())
+        pipeline.forward(_make_request())
+
+    assert pipeline.scheduler.timesteps.long().tolist() == _DMD2_TIMESTEPS
diff --git a/tests/diffusion/models/ltx2/test_ltx2_hsdp.py b/tests/diffusion/models/ltx2/test_ltx2_hsdp.py
new file mode 100644
index 00000000000..4dd07e1bf82
--- /dev/null
+++ b/tests/diffusion/models/ltx2/test_ltx2_hsdp.py
@@ -0,0 +1,25 @@
+import pytest
+import torch.nn as nn
+
+from vllm_omni.diffusion.models.ltx2.ltx2_transformer import LTX2VideoTransformer3DModel
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+def test_ltx2_exposes_hsdp_shard_conditions_for_transformer_blocks():
+ model = object.__new__(LTX2VideoTransformer3DModel)
+ nn.Module.__init__(model)
+ model.transformer_blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
+ model.norm_out = nn.LayerNorm(4)
+
+ conditions = getattr(model, "_hsdp_shard_conditions", None)
+
+ assert conditions is not None
+ assert len(conditions) == 1
+
+ matched = []
+ for name, module in model.named_modules():
+ if any(cond(name, module) for cond in conditions):
+ matched.append(name)
+
+ assert matched == ["transformer_blocks.0", "transformer_blocks.1"]
diff --git a/tests/diffusion/models/wan2_2/__init__.py b/tests/diffusion/models/wan2_2/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py
deleted file mode 100644
index 64c2b271c9c..00000000000
--- a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py
+++ /dev/null
@@ -1,150 +0,0 @@
-from types import SimpleNamespace
-
-import PIL.Image
-import pytest
-import torch
-from torch import nn
-
-from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
- WAN22_MAX_SEQUENCE_LENGTH,
- Wan22Pipeline,
-)
-from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import (
- Wan22I2VPipeline,
-)
-from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_ti2v import (
- Wan22TI2VPipeline,
-)
-from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_vace import (
- Wan22VACEPipeline,
-)
-
-pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-
-
-class _RejectingTextEncoder:
- dtype = torch.float32
-
- def __call__(self, *args, **kwargs):
- raise AssertionError("text encoder should not run for prompts that exceed max_sequence_length")
-
-
-class _FakeTokenBatch:
- def __init__(self, total_sequence_length: int):
- attention_mask = torch.ones((1, total_sequence_length), dtype=torch.long)
- self.input_ids = attention_mask.clone()
- self.attention_mask = attention_mask
-
-
-class _FakeTokenizer:
- def __init__(self, total_sequence_length: int):
- self.total_sequence_length = total_sequence_length
-
- def __call__(self, *args, **kwargs):
- return _FakeTokenBatch(self.total_sequence_length)
-
-
-PIPELINE_CASES = [
- pytest.param(Wan22Pipeline, id="wan22-t2v"),
- pytest.param(Wan22I2VPipeline, id="wan22-i2v"),
- pytest.param(Wan22TI2VPipeline, id="wan22-ti2v"),
- pytest.param(Wan22VACEPipeline, id="wan22-vace"),
-]
-
-
-def _make_pipeline(pipeline_class: type, *, total_sequence_length: int):
- pipeline = object.__new__(pipeline_class)
- nn.Module.__init__(pipeline)
- pipeline.device = torch.device("cpu")
- pipeline.text_encoder = _RejectingTextEncoder()
- pipeline.tokenizer = _FakeTokenizer(total_sequence_length)
- pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH
- return pipeline
-
-
-@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES)
-def test_encode_prompt_rejects_prompt_longer_than_default_max_sequence_length(pipeline_class: type):
- pipeline = _make_pipeline(pipeline_class, total_sequence_length=WAN22_MAX_SEQUENCE_LENGTH + 1)
-
- with pytest.raises(ValueError, match=r"got 513 tokens, but `max_sequence_length` is 512"):
- pipeline.encode_prompt(prompt="prompt")
-
-
-@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES)
-def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length(pipeline_class: type):
- pipeline = _make_pipeline(pipeline_class, total_sequence_length=17)
-
- with pytest.raises(ValueError, match=r"got 17 tokens, but `max_sequence_length` is 16"):
- pipeline.encode_prompt(prompt="prompt", max_sequence_length=16)
-
-
-def _sampling_params(**overrides):
- defaults = dict(
- height=None,
- width=None,
- num_frames=None,
- num_inference_steps=None,
- generator=None,
- guidance_scale_provided=False,
- guidance_scale_2=None,
- boundary_ratio=None,
- num_outputs_per_prompt=0,
- max_sequence_length=None,
- seed=None,
- extra_args={},
- prompt_embeds=None,
- negative_prompt_embeds=None,
- )
- defaults.update(overrides)
- return SimpleNamespace(**defaults)
-
-
-@pytest.mark.parametrize(
- ("pipeline_class", "prompt_value", "forward_kwargs"),
- [
- pytest.param(Wan22Pipeline, "prompt", {}, id="wan22-t2v"),
- pytest.param(
- Wan22I2VPipeline,
- {"prompt": "prompt", "multi_modal_data": {"image": PIL.Image.new("RGB", (64, 64))}},
- {"image": PIL.Image.new("RGB", (64, 64))},
- id="wan22-i2v",
- ),
- pytest.param(
- Wan22TI2VPipeline,
- {"prompt": "prompt", "multi_modal_data": {"image": PIL.Image.new("RGB", (64, 64))}},
- {"image": PIL.Image.new("RGB", (64, 64))},
- id="wan22-ti2v",
- ),
- pytest.param(Wan22VACEPipeline, "prompt", {}, id="wan22-vace"),
- ],
-)
-def test_forward_defaults_to_wan22_tokenizer_max_length(
- pipeline_class: type,
- prompt_value,
- forward_kwargs,
-):
- pipeline = object.__new__(pipeline_class)
- nn.Module.__init__(pipeline)
- pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH
- pipeline.boundary_ratio = None
- pipeline.vae_scale_factor_temporal = 4
- pipeline.vae_scale_factor_spatial = 8
- pipeline.transformer_config = SimpleNamespace(patch_size=(1, 2, 2))
-
- captured = {}
-
- def _fake_check_inputs(*args, **kwargs):
- captured["max_sequence_length"] = kwargs["max_sequence_length"]
- raise RuntimeError("stop after capture")
-
- pipeline.check_inputs = _fake_check_inputs
-
- req = SimpleNamespace(
- prompts=[prompt_value],
- sampling_params=_sampling_params(),
- )
-
- with pytest.raises(RuntimeError, match="stop after capture"):
- pipeline.forward(req, **forward_kwargs)
-
- assert captured["max_sequence_length"] == WAN22_MAX_SEQUENCE_LENGTH
diff --git a/tests/diffusion/test_diffusion_worker.py b/tests/diffusion/test_diffusion_worker.py
index e2bd7ef8a32..fc08c5f7f03 100644
--- a/tests/diffusion/test_diffusion_worker.py
+++ b/tests/diffusion/test_diffusion_worker.py
@@ -16,7 +16,7 @@
from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker
-pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
+pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.gpu]
@pytest.fixture
@@ -81,17 +81,31 @@ def test_load_weights_empty_iterable(self, mocker: MockerFixture, mock_gpu_worke
class TestDiffusionWorkerSleep:
"""Test DiffusionWorker.sleep method."""
+ @pytest.fixture(autouse=True)
+ def setup_allocator(self, mocker: MockerFixture):
+ """
+ Patch CuMemAllocator for every test in this class and provide safe default return values.
+ """
+ self.mock_allocator_class = mocker.patch("vllm.device_allocator.cumem.CuMemAllocator")
+ self.mock_allocator = mocker.Mock()
+ self.mock_allocator_class.get_instance.return_value = self.mock_allocator
+ self.mock_allocator.get_current_usage.return_value = 4 * 1024**3
+ self.mock_allocator.sleep = mocker.Mock()
+
def test_sleep_level_1(self, mocker: MockerFixture, mock_gpu_worker):
"""Test sleep mode level 1 (offload weights only)."""
mock_allocator_class = mocker.patch("vllm.device_allocator.cumem.CuMemAllocator")
- mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform")
+ mock_platform = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform")
+ mock_platform.get_free_memory.side_effect = [10 * 1024**3, 12 * 1024**3]
+ mock_platform.get_device_total_memory.return_value = 80 * 1024**3
mock_get_process_memory = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.get_process_gpu_memory")
# Setup process-scoped memory mocks
# Before sleep: 3GB used
# After sleep: 1GB used (freed 2GB)
+ initial_usage = 3 * 1024**3
mock_get_process_memory.side_effect = [
- 3 * 1024**3,
+ initial_usage,
1 * 1024**3,
]
@@ -99,25 +113,29 @@ def test_sleep_level_1(self, mocker: MockerFixture, mock_gpu_worker):
mock_allocator = mocker.Mock()
mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator)
mock_allocator.sleep = mocker.Mock()
+ mock_allocator.get_current_usage.return_value = initial_usage
# Call sleep with level 1
result = mock_gpu_worker.sleep(level=1)
# Verify sleep was called with correct tags
mock_allocator.sleep.assert_called_once_with(offload_tags=("weights",))
- assert result is True
+ assert bool(result) is True
# Verify buffers were NOT saved (level 1 doesn't save buffers)
assert len(mock_gpu_worker._sleep_saved_buffers) == 0
def test_sleep_level_2(self, mocker: MockerFixture, mock_gpu_worker):
"""Test sleep mode level 2 (offload all, save buffers)."""
mock_allocator_class = mocker.patch("vllm.device_allocator.cumem.CuMemAllocator")
- mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform")
+ mock_platform = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform")
+ mock_platform.get_free_memory.side_effect = [5 * 1024**3, 10 * 1024**3]
+ mock_platform.get_device_total_memory.return_value = 80 * 1024**3
mock_get_process_memory = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.get_process_gpu_memory")
# Setup process-scoped memory mocks
+ initial_usage = 5 * 1024**3
mock_get_process_memory.side_effect = [
- 5 * 1024**3, # Before sleep
+ initial_usage, # Before sleep
1 * 1024**3, # After sleep (freed 4GB)
]
@@ -125,6 +143,7 @@ def test_sleep_level_2(self, mocker: MockerFixture, mock_gpu_worker):
mock_allocator = mocker.Mock()
mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator)
mock_allocator.sleep = mocker.Mock()
+ mock_allocator.get_current_usage.return_value = initial_usage
# Mock pipeline buffers
mock_buffer1 = torch.randn(10, 10)
@@ -141,7 +160,7 @@ def test_sleep_level_2(self, mocker: MockerFixture, mock_gpu_worker):
# Verify sleep was called with empty tags (offload all)
mock_allocator.sleep.assert_called_once_with(offload_tags=tuple())
- assert result is True
+ assert bool(result) is True
# Verify buffers were saved
assert len(mock_gpu_worker._sleep_saved_buffers) == 2
@@ -151,22 +170,26 @@ def test_sleep_level_2(self, mocker: MockerFixture, mock_gpu_worker):
def test_sleep_memory_freed_validation(self, mocker: MockerFixture, mock_gpu_worker):
"""Test that sleep validates memory was actually freed."""
mock_allocator_class = mocker.patch("vllm.device_allocator.cumem.CuMemAllocator")
- mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform")
+ mock_platform = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform")
+ mock_platform.get_free_memory.return_value = 10 * 1024**3
+ mock_platform.get_device_total_memory.return_value = 80 * 1024**3
mock_get_process_memory = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.get_process_gpu_memory")
# Simulate process memory increase (should trigger assertion error)
+ initial_usage = 1 * 1024**3
mock_get_process_memory.side_effect = [
- 1 * 1024**3, # Before sleep: 1GB used
+ initial_usage, # Before sleep: 1GB used
3 * 1024**3, # After sleep: 3GB used (negative freed)
]
mock_allocator = mocker.Mock()
mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator)
mock_allocator.sleep = mocker.Mock()
+ mock_allocator.get_current_usage.return_value = initial_usage
# This should raise an assertion error
- with pytest.raises(AssertionError, match="Memory usage increased after sleeping"):
- mock_gpu_worker.sleep(level=1)
+ result = mock_gpu_worker.sleep(level=1)
+ assert result == initial_usage
def test_sleep_falls_back_to_device_memory_when_nvml_unavailable(self, mocker: MockerFixture, mock_gpu_worker):
"""Test sleep uses device-scoped fallback when NVML is unavailable."""
@@ -184,11 +207,12 @@ def test_sleep_falls_back_to_device_memory_when_nvml_unavailable(self, mocker: M
mock_allocator = mocker.Mock()
mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator)
mock_allocator.sleep = mocker.Mock()
+ mock_allocator.get_current_usage.return_value = 2 * 1024**3
result = mock_gpu_worker.sleep(level=1)
mock_allocator.sleep.assert_called_once_with(offload_tags=("weights",))
- assert result is True
+ assert bool(result) is True
class TestDiffusionWorkerWakeUp:
@@ -202,6 +226,7 @@ def test_wake_up_without_buffers(self, mocker: MockerFixture, mock_gpu_worker):
mock_allocator = mocker.Mock()
mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator)
mock_allocator.wake_up = mocker.Mock()
+ mock_allocator.get_current_usage.return_value = 10 * 1024**3
# Ensure no saved buffers
mock_gpu_worker._sleep_saved_buffers = {}
@@ -211,7 +236,7 @@ def test_wake_up_without_buffers(self, mocker: MockerFixture, mock_gpu_worker):
# Verify allocator.wake_up was called
mock_allocator.wake_up.assert_called_once_with(["weights"])
- assert result is True
+ assert bool(result) is True
def test_wake_up_with_buffers(self, mocker: MockerFixture, mock_gpu_worker):
"""Test wake_up with saved buffers (level 2 sleep)."""
@@ -221,6 +246,7 @@ def test_wake_up_with_buffers(self, mocker: MockerFixture, mock_gpu_worker):
mock_allocator = mocker.Mock()
mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator)
mock_allocator.wake_up = mocker.Mock()
+ mock_allocator.get_current_usage.return_value = 10 * 1024**3
# Create saved buffers
saved_buffer1 = torch.randn(10, 10)
@@ -255,7 +281,7 @@ def test_wake_up_with_buffers(self, mocker: MockerFixture, mock_gpu_worker):
# Verify saved buffers were cleared
assert len(mock_gpu_worker._sleep_saved_buffers) == 0
- assert result is True
+ assert bool(result) is True
def test_wake_up_partial_buffer_restore(self, mocker: MockerFixture, mock_gpu_worker):
"""Test wake_up only restores buffers that were saved."""
@@ -265,6 +291,7 @@ def test_wake_up_partial_buffer_restore(self, mocker: MockerFixture, mock_gpu_wo
mock_allocator = mocker.Mock()
mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator)
mock_allocator.wake_up = mocker.Mock()
+ mock_allocator.get_current_usage.return_value = 10 * 1024**3
# Only save buffer1, not buffer2
saved_buffer1 = torch.randn(10, 10)
@@ -293,4 +320,4 @@ def test_wake_up_partial_buffer_restore(self, mocker: MockerFixture, mock_gpu_wo
# buffer2 should NOT be restored since it wasn't saved
mock_buffer2.data.copy_.assert_not_called()
- assert result is True
+ assert bool(result) is True
diff --git a/tests/diffusion/test_multiproc_engine_concurrency.py b/tests/diffusion/test_multiproc_engine_concurrency.py
index 4bc3e05fe91..9ec06e8107d 100644
--- a/tests/diffusion/test_multiproc_engine_concurrency.py
+++ b/tests/diffusion/test_multiproc_engine_concurrency.py
@@ -1,17 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
+import multiprocessing as mp
import queue
import threading
+import time
from types import SimpleNamespace
+from unittest.mock import MagicMock, Mock
import pytest
import torch
+import zmq
+from vllm.v1.engine.exceptions import EngineDeadError
from vllm_omni.diffusion.data import DiffusionOutput
from vllm_omni.diffusion.diffusion_engine import DiffusionEngine
from vllm_omni.diffusion.executor.multiproc_executor import MultiprocDiffusionExecutor
from vllm_omni.diffusion.sched import RequestScheduler
+from vllm_omni.diffusion.stage_diffusion_proc import StageDiffusionProc
+from vllm_omni.outputs import OmniRequestOutput
pytestmark = [pytest.mark.diffusion, pytest.mark.core_model, pytest.mark.cpu]
@@ -51,6 +59,8 @@ def _make_executor(num_gpus: int = 1):
executor._result_mq = mock_rmq
executor._closed = False
executor._processes = []
+ executor.is_failed = False
+ executor._failure_callbacks = []
return executor, req_q, res_q
@@ -334,8 +344,9 @@ def test_collective_rpc_closed_executor_raises(self):
class TestCollectiveRpcTimeoutWhileLockHeld:
"""``collective_rpc(timeout=...)`` must honour its timeout even when
- another thread holds ``engine._rpc_lock`` indefinitely (e.g. a stalled
- ``add_req`` waiting on an unresponsive worker).
+ another thread holds ``engine._rpc_lock`` indefinitely (e.g. request
+ execution stalled on ``add_req_and_wait_for_response`` → ``execute_fn``
+ → ``collective_rpc`` while blocked on an unresponsive worker).
"""
def test_rpc_times_out_when_lock_held_directly(self):
@@ -359,10 +370,10 @@ def _hold_lock():
with pytest.raises(TimeoutError):
engine.collective_rpc("health", timeout=0.5)
- def test_rpc_times_out_when_add_req_stalled_on_worker(self):
+ def test_rpc_times_out_when_request_execution_stalled_on_worker(self):
"""Real-world scenario the bot flagged:
- ``add_req`` holds ``_rpc_lock`` while blocked on
+ The scheduler/execute path holds ``_rpc_lock`` while blocked on
``executor._result_mq.dequeue()`` because the worker never replies.
A concurrent ``collective_rpc(timeout=...)`` must still time out
instead of hanging forever waiting for the lock.
@@ -428,3 +439,353 @@ def _hold_and_release():
t.join(5)
assert result.error == "ok"
+
+
+# ───────── error handling: EngineDeadError propagation through layers ─────
+
+
+class TestMultiprocExecutorRaisesEngineDeadError:
+ """``collective_rpc`` raises ``EngineDeadError`` when the engine is failed."""
+
+ def test_collective_rpc_raises_when_is_failed(self):
+ executor = object.__new__(MultiprocDiffusionExecutor)
+ executor._closed = False
+ executor._broadcast_mq = MagicMock()
+ executor._result_mq = MagicMock()
+ executor._result_mq.dequeue = MagicMock(side_effect=TimeoutError)
+ executor.is_failed = True
+
+ with pytest.raises(EngineDeadError):
+ executor.collective_rpc(
+ "generate",
+ args=(MagicMock(),),
+ unique_reply_rank=0,
+ exec_all_ranks=True,
+ )
+
+ def test_collective_rpc_raises_mid_dequeue_when_is_failed(self):
+ """Worker dies while we are polling the dequeue loop."""
+ executor, _, res_q = _make_executor()
+
+ call_count = 0
+ orig_dequeue = executor._result_mq.dequeue
+
+ def _dying_dequeue(timeout=None):
+ nonlocal call_count
+ call_count += 1
+ if call_count == 1:
+ executor.is_failed = True
+ raise TimeoutError
+ return orig_dequeue(timeout=timeout)
+
+ executor._result_mq.dequeue = _dying_dequeue
+
+ with pytest.raises(EngineDeadError):
+ executor.collective_rpc(
+ "generate",
+ args=(MagicMock(),),
+ unique_reply_rank=0,
+ exec_all_ranks=True,
+ )
+
+
+class TestDiffusionEngineDeadErrorPassthrough:
+ """``DiffusionEngine.add_req_and_wait_for_response`` re-raises
+ ``EngineDeadError`` from executor and wraps other errors."""
+
+ def test_engine_dead_error_propagates(self):
+ engine, executor, _, _ = _make_engine()
+ engine.execute_fn = Mock(side_effect=EngineDeadError())
+
+ with pytest.raises(EngineDeadError):
+ engine.add_req_and_wait_for_response(_mock_request("dead"))
+
+ def test_runtime_error_wrapped_in_output(self):
+ engine, executor, _, _ = _make_engine()
+ engine.execute_fn = Mock(side_effect=RuntimeError("gpu fault"))
+
+ out = engine.add_req_and_wait_for_response(_mock_request("fault"))
+ assert isinstance(out, DiffusionOutput)
+ assert "gpu fault" in out.error
+
+
+class TestStageDiffusionClientErrorPropagation:
+ """Error surface behaviour of ``StageDiffusionClient``.
+
+ Uses ``object.__new__`` to construct a client without spawning a real
+ subprocess, then manually sets the fields needed for each test.
+ """
+
+ def _make_client(self, *, engine_dead=False, proc_alive=True):
+ from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient
+
+ client = object.__new__(StageDiffusionClient)
+ client.stage_id = 0
+ client.final_output = True
+ client.final_output_type = "image"
+ client.default_sampling_params = None
+ client.custom_process_input_func = None
+ client.engine_input_source = None
+
+ client._output_queue = asyncio.Queue()
+ client._rpc_results = {}
+ client._pending_rpcs = set()
+ client._tasks = {}
+ client._shutting_down = False
+ client._engine_dead = engine_dead
+ client._owns_process = True
+ client._proc = MagicMock(
+ is_alive=MagicMock(return_value=proc_alive),
+ exitcode=1,
+ )
+ client._request_socket = MagicMock()
+ client._response_socket = MagicMock()
+ client._encoder = MagicMock()
+ client._decoder = MagicMock()
+
+ return client
+
+ @pytest.mark.asyncio
+ async def test_add_request_raises_when_dead(self):
+ client = self._make_client(engine_dead=True)
+
+ with pytest.raises(EngineDeadError):
+ await client.add_request_async("req-3", "test prompt", None)
+
+ def test_check_health_raises_when_dead(self):
+ client = self._make_client(engine_dead=True)
+
+ with pytest.raises(EngineDeadError):
+ client.check_health()
+
+ def test_check_health_ok_when_alive(self):
+ client = self._make_client()
+ client.check_health()
+
+ def test_get_output_raises_engine_dead_when_dead(self):
+ """When ``_engine_dead`` is True and the output queue is empty,
+ ``get_diffusion_output_nowait`` must raise ``EngineDeadError``."""
+ client = self._make_client(engine_dead=True)
+ # Make recv raise zmq.Again so _drain_responses finds nothing to drain (no real ZMQ socket)
+ client._response_socket.recv.side_effect = zmq.Again
+
+ with pytest.raises(EngineDeadError):
+ client.get_diffusion_output_nowait()
+
+ def test_get_output_returns_none_when_alive_and_empty(self):
+ """When the engine is alive and the queue is empty, return None."""
+ client = self._make_client()
+ client._response_socket.recv.side_effect = zmq.Again
+
+ assert client.get_diffusion_output_nowait() is None
+
+ def test_check_health_raises_when_proc_dead(self):
+ """``check_health`` detects a dead subprocess via ``_proc.is_alive()``
+ and raises ``EngineDeadError``, setting ``_engine_dead`` as a
+ side effect."""
+ client = self._make_client(proc_alive=False)
+
+ with pytest.raises(EngineDeadError, match="not alive"):
+ client.check_health()
+
+ assert client._engine_dead is True
+
+ def test_get_output_raises_when_proc_dead(self):
+ """When the subprocess has died (non-signal exit) and the output
+ queue is empty, ``get_diffusion_output_nowait`` must raise
+ ``EngineDeadError`` with the exit code."""
+ client = self._make_client(proc_alive=False)
+ client._response_socket.recv.side_effect = zmq.Again
+
+ with pytest.raises(EngineDeadError, match="exit code"):
+ client.get_diffusion_output_nowait()
+
+ assert client._engine_dead is True
+
+ def test_get_output_returns_none_on_signal_death(self):
+ """When the subprocess was killed by a signal (exit code > 128),
+ ``get_diffusion_output_nowait`` returns ``None`` and sets
+ ``_shutting_down`` instead of raising."""
+ client = self._make_client(proc_alive=False)
+ client._proc.exitcode = 137 # SIGKILL (128 + 9)
+ client._response_socket.recv.side_effect = zmq.Again
+
+ result = client.get_diffusion_output_nowait()
+
+ assert result is None
+ assert client._shutting_down is True
+ assert client._engine_dead is True
+
+
+# ───────── monitor thread & death sentinel integration tests ─────────
+
+
+def _poll_flag(get_flag, *, timeout=5.0, interval=0.05) -> bool:
+ """Poll until ``get_flag()`` returns True or *timeout* elapses."""
+ deadline = time.monotonic() + timeout
+ while time.monotonic() < deadline:
+ if get_flag():
+ return True
+ time.sleep(interval)
+ return False
+
+
+def _make_short_lived_process() -> mp.Process:
+ """Spawn a real subprocess that exits immediately.
+
+ The process is created from the ``"fork"`` context (POSIX only) so
+ that it can use a plain ``lambda`` as its target — ``"spawn"`` would
+ fail to pickle it.
+ """
+ ctx = mp.get_context("fork")
+ p = ctx.Process(target=lambda: None, name="ShortLivedWorker-0")
+ p.start()
+ return p
+
+
+class TestMultiprocExecutorWorkerMonitor:
+ """Integration tests for ``start_worker_monitor``.
+
+ Uses real short-lived subprocesses so that OS-level sentinel fd
+ readiness is exercised end-to-end.
+ """
+
+ def test_worker_monitor_sets_is_failed_and_calls_callbacks_on_death(self):
+ """When a worker process dies, the monitor thread must:
+ 1. Set ``is_failed = True``
+ 2. Call ``shutdown()`` (which sets ``_closed = True``)
+ 3. Invoke all registered failure callbacks
+ """
+ executor = object.__new__(MultiprocDiffusionExecutor)
+ executor._closed = False
+ executor.is_failed = False
+ executor._failure_callbacks = []
+ executor._broadcast_mq = None
+ executor._result_mq = None
+ executor.resources = None
+ # Use a no-op so shutdown() doesn't crash on None resources.
+ executor._finalizer = lambda: None
+
+ proc = _make_short_lived_process()
+ executor._processes = [proc]
+
+ callback_called = threading.Event()
+ executor.register_failure_callback(callback_called.set)
+
+ executor.start_worker_monitor()
+
+ # Wait for the process to exit and the monitor to react.
+ proc.join(5)
+ assert _poll_flag(lambda: executor.is_failed), "is_failed was not set"
+ assert executor._closed, "shutdown() was not called"
+ assert callback_called.wait(timeout=2), "failure callback was not invoked"
+
+ def test_worker_monitor_noop_when_already_closed(self):
+ """If ``_closed`` is already True when the process dies (orderly
+ shutdown), the monitor must *not* set ``is_failed``."""
+ executor = object.__new__(MultiprocDiffusionExecutor)
+ executor._closed = True # already shut down
+ executor.is_failed = False
+ executor._failure_callbacks = []
+ executor._broadcast_mq = None
+ executor._result_mq = None
+ executor.resources = None
+ executor._finalizer = lambda: None
+
+ proc = _make_short_lived_process()
+ executor._processes = [proc]
+
+ executor.start_worker_monitor()
+ proc.join(5)
+
+ # Give the monitor thread a chance to run (it should early-return).
+ time.sleep(0.3)
+ assert not executor.is_failed, "is_failed should remain False on orderly shutdown"
+
+
+class TestStageDiffusionClientProcMonitor:
+ """Integration test for ``StageDiffusionClient._start_proc_monitor``.
+
+ Uses a real short-lived subprocess to verify the sentinel-based
+ detection pipeline.
+ """
+
+ def test_proc_monitor_sets_engine_dead_on_process_death(self):
+ """When the subprocess dies, the monitor thread must set
+ ``_engine_dead = True``."""
+ from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient
+
+ client = object.__new__(StageDiffusionClient)
+ client.stage_id = 0
+ client._shutting_down = False
+ client._engine_dead = False
+
+ proc = _make_short_lived_process()
+ client._proc = proc
+
+ client._start_proc_monitor()
+ proc.join(5)
+
+ assert _poll_flag(lambda: client._engine_dead), "_engine_dead was not set"
+
+
+class TestDrainResponsesDeathSentinel:
+ """Tests for death sentinel and error routing in
+ ``StageDiffusionClient._drain_responses()``.
+ """
+
+ def _make_client(self):
+ from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient
+
+ client = object.__new__(StageDiffusionClient)
+ client.stage_id = 0
+ client._engine_dead = False
+ client._shutting_down = False
+ client._output_queue = asyncio.Queue()
+ client._rpc_results = {}
+ client._pending_rpcs = set()
+ client._response_socket = MagicMock()
+ client._decoder = MagicMock()
+ return client
+
+ def test_drain_responses_sets_engine_dead_on_death_sentinel(self):
+ """When ``_drain_responses`` receives the ``DIFFUSION_PROC_DEAD``
+ sentinel, it must set ``_engine_dead = True`` and stop draining
+ (decoder is never called)."""
+ client = self._make_client()
+
+ # First recv returns the death sentinel, second would be a normal
+ # message but should never be reached.
+ client._response_socket.recv.side_effect = [
+ StageDiffusionProc.DIFFUSION_PROC_DEAD,
+ b"should-not-be-reached",
+ ]
+
+ client._drain_responses()
+
+ assert client._engine_dead is True
+ client._decoder.decode.assert_not_called()
+
+ def test_drain_responses_routes_error_as_omni_request_output(self):
+ """When ``_drain_responses`` receives a ``{"type": "error"}`` message
+ with a ``request_id``, it must place an ``OmniRequestOutput`` with
+ the error on ``_output_queue``."""
+ client = self._make_client()
+
+ error_msg = {
+ "type": "error",
+ "request_id": "req-fail",
+ "error": "gpu fault",
+ }
+ # First recv returns the encoded error, second raises zmq.Again.
+ client._response_socket.recv.side_effect = [b"encoded-error", zmq.Again]
+ client._decoder.decode.return_value = error_msg
+
+ client._drain_responses()
+
+ assert not client._output_queue.empty()
+ output = client._output_queue.get_nowait()
+ assert isinstance(output, OmniRequestOutput)
+ assert output.request_id == "req-fail"
+ assert output.error == "gpu fault"
+ assert output.finished is True
diff --git a/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py b/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py
index 24daa8ccf54..3aa5da85c24 100644
--- a/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py
+++ b/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py
@@ -567,7 +567,6 @@ def test_wan22_i2v_diffusers_offline_generates_video(
@pytest.mark.benchmark
@pytest.mark.diffusion
@hardware_test(res={"cuda": "H100"}, num_cards=2)
-@pytest.mark.skip(reason="issue: #2874")
@pytest.mark.parametrize("omni_server", SERVER_CASES, indirect=True)
def test_wan22_i2v_online_serving_generates_video(
omni_server,
diff --git a/tests/e2e/offline_inference/test_omni_sleep_mode.py b/tests/e2e/offline_inference/test_omni_sleep_mode.py
new file mode 100644
index 00000000000..2ad0b53b010
--- /dev/null
+++ b/tests/e2e/offline_inference/test_omni_sleep_mode.py
@@ -0,0 +1,159 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""
+End-to-end tests for Omni Sleep Mode across various model architectures.
+"""
+
+import pytest
+import torch
+from vllm import SamplingParams
+
+from tests.helpers.mark import hardware_test
+from vllm_omni.entrypoints.async_omni import AsyncOmni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+MODEL = "ByteDance-Seed/BAGEL-7B-MoT"
+MODEL_DIFF = "riverclouds/qwen_image_random"
+
+
+def get_ack_info(ack, key, default=None):
+ if hasattr(ack, key):
+ return getattr(ack, key)
+ if isinstance(ack, dict):
+ return ack.get(key, default)
+ return default
+
+
+def get_dynamic_devices(stage_idx, num_stages, tp_size):
+ total_gpus = torch.cuda.device_count()
+ gpus_per_stage = tp_size
+ start_idx = stage_idx * gpus_per_stage
+ if start_idx + gpus_per_stage > total_gpus:
+ start_idx = start_idx % total_gpus
+ device_ids = [str(start_idx + i) for i in range(gpus_per_stage)]
+ return ",".join(device_ids)
+
+
+# Test 1: Diffusion Model (2-Stage BAGEL)
+@pytest.mark.advanced_model
+@pytest.mark.omni
+@pytest.mark.parametrize("tp_size", [1, 2])
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
+@pytest.mark.asyncio
+async def test_diffusion_model_sleep_tp(tp_size):
+ num_gpus = torch.cuda.device_count()
+ if num_gpus < tp_size:
+ pytest.skip(f"Skipping TP={tp_size}")
+
+ engine_args = {
+ "model": MODEL,
+ "enable_sleep_mode": True,
+ "tensor_parallel_size": tp_size,
+ "enforce_eager": True,
+ "trust_remote_code": True,
+ "dtype": "bfloat16",
+ "gpu_memory_utilization": 0.5,
+ }
+
+ engine = AsyncOmni(**engine_args, stage_init_timeout=1200)
+ try:
+ # BAGEL runs two stages, so pass one sampling-params object per stage (LLM first, then diffusion)
+ diff_sp = OmniDiffusionSamplingParams(num_inference_steps=2, height=256, width=256)
+ llm_sp = SamplingParams()
+
+ # Warmup
+ async for _ in engine.generate("test", sampling_params=[llm_sp, diff_sp]):
+ pass
+
+ # Sleep all
+ acks = await engine.sleep(level=2)
+ statuses = [get_ack_info(ack, "status") for ack in acks]
+ assert all(s == "SUCCESS" for s in statuses), f"Sleep failed. Statuses: {statuses}"
+
+ # Wakeup & Verify
+ await engine.wake_up()
+ async for _ in engine.generate("verify", sampling_params=[llm_sp, diff_sp]):
+ pass
+
+ print(f"Diffusion TP={tp_size} Lifecycle OK")
+ finally:
+ engine.shutdown()
+
+
+# Test 2: Multi-stage Manual Config
+@pytest.mark.advanced_model
+@pytest.mark.omni
+@pytest.mark.parametrize("tp_size", [1, 2])
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
+@pytest.mark.asyncio
+async def test_multistage_sleep_h100(tp_size):
+ num_gpus = torch.cuda.device_count()
+ if num_gpus < tp_size * 2:
+ pytest.skip("Not enough GPUs")
+
+ stages = []
+ for i in range(2):
+ devs = get_dynamic_devices(i, 2, tp_size)
+ stages.append(
+ {
+ "stage_id": i,
+ "stage_type": "llm" if i == 0 else "diffusion",
+ "runtime": {"process": True, "devices": devs},
+ "engine_args": {
+ "model": MODEL,
+ "model_stage": "thinker" if i == 0 else "base",
+ "tensor_parallel_size": tp_size,
+ "gpu_memory_utilization": 0.4,
+ "dtype": "bfloat16",
+ "enable_sleep_mode": True,
+ "trust_remote_code": True,
+ },
+ }
+ )
+
+ connectors = [{"src_stage_id": 0, "dst_stage_id": 1, "connector_type": "queue"}]
+
+ engine = AsyncOmni(
+ model=MODEL, stages=stages, connectors=connectors, enable_sleep_mode=True, stage_init_timeout=1200
+ )
+ try:
+ sp = OmniDiffusionSamplingParams(num_inference_steps=2)
+ async for _ in engine.generate("warmup", sampling_params=[SamplingParams(), sp]):
+ pass
+
+ acks = await engine.sleep(stage_ids=[0, 1], level=2)
+ assert len(acks) == 2 * tp_size
+
+ await engine.wake_up(stage_ids=[0, 1])
+ async for _ in engine.generate("verify", sampling_params=[SamplingParams(), sp]):
+ pass
+ finally:
+ engine.shutdown()
+
+
+# Test 3: Pure Diffusion Single-Stage
+@pytest.mark.advanced_model
+@pytest.mark.omni
+@pytest.mark.parametrize("tp_size", [1, 2])
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
+@pytest.mark.asyncio
+async def test_pure_diffusion_scenario(tp_size):
+ engine_args = {
+ "model": MODEL_DIFF,
+ "enable_sleep_mode": True,
+ "tensor_parallel_size": tp_size,
+ "enforce_eager": True,
+ "dtype": "bfloat16",
+ "gpu_memory_utilization": 0.5,
+ }
+
+ engine = AsyncOmni(**engine_args, stage_init_timeout=1200)
+ try:
+ await engine.sleep(level=1)
+ await engine.wake_up()
+ async for _ in engine.generate("test", sampling_params=[SamplingParams()]):
+ pass
+ print("Pure Diffusion OK")
+ finally:
+ engine.shutdown()
diff --git a/tests/e2e/online_serving/test_qwen3_omni_multi_replicas.py b/tests/e2e/online_serving/test_qwen3_omni_multi_replicas.py
new file mode 100644
index 00000000000..5b51615feaf
--- /dev/null
+++ b/tests/e2e/online_serving/test_qwen3_omni_multi_replicas.py
@@ -0,0 +1,137 @@
+"""
+Core-model CI guard for Qwen3-Omni multi-replica stage-pool routing on 4 GPUs.
+"""
+
+import os
+
+import pytest
+
+from tests.helpers.mark import hardware_test
+from tests.helpers.media import generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video
+from tests.helpers.runtime import OmniResponse, OmniServerParams, dummy_messages_from_mix_data
+from tests.helpers.stage_config import get_deploy_config_path
+
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
+
+MODEL = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
+MULTI_REPLICA_DEPLOY = get_deploy_config_path("ci/qwen3_omni_moe_multi_replicas_4gpu.yaml")
+ROUTE_STRESS_REQUESTS = 6
+MIXED_MODAL_REQUESTS = 4
+
+test_params = [
+ OmniServerParams(
+ model=MODEL,
+ stage_config_path=MULTI_REPLICA_DEPLOY,
+ server_args=["--disable-log-stats"],
+ )
+]
+
+
+def _system_prompt() -> dict[str, object]:  # system-role message holding one text part
+ return {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": (
+ "You are Qwen, a virtual human developed by the Qwen Team, "
+ "Alibaba Group, capable of perceiving auditory and visual inputs, "
+ "as well as generating text and speech."
+ ),
+ }
+ ],
+ }
+
+
+def _text_messages() -> list[dict[str, object]]:  # system prompt plus one text-only question
+ return dummy_messages_from_mix_data(
+ system_prompt=_system_prompt(),
+ content_text="What is the capital of China? Answer in one short sentence.",
+ )
+
+
+def _mixed_messages() -> list[dict[str, object]]:  # video, image and audio sent as base64 data URLs
+ video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
+ image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
+ audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(5, 1)['base64']}"
+ return dummy_messages_from_mix_data(
+ system_prompt=_system_prompt(),
+ video_data_url=video_data_url,
+ image_data_url=image_data_url,
+ audio_data_url=audio_data_url,
+ content_text="What is recited in the audio? What is in this image? Describe the video briefly.",
+ )
+
+
+def _assert_batch_size(responses: list[OmniResponse], expected: int) -> None:  # count, success, latency
+ assert len(responses) == expected, f"Expected {expected} responses, got {len(responses)}"
+ assert all(resp.success for resp in responses), "At least one request failed"
+ assert all(resp.e2e_latency is not None and resp.e2e_latency > 0 for resp in responses), "Missing request latency"
+
+
+def _assert_text_outputs(responses: list[OmniResponse]) -> None:  # text present, audio absent
+ assert all(resp.text_content is not None for resp in responses), "Missing text output"
+ assert all(resp.audio_bytes is None for resp in responses), "Text-only request unexpectedly produced audio"
+
+
+def _assert_audio_outputs(responses: list[OmniResponse], *, expect_text: bool) -> None:  # audio > 128 bytes
+ assert all(resp.audio_bytes is not None and len(resp.audio_bytes) > 128 for resp in responses), (
+ "Missing audio output"
+ )
+ if expect_text:  # text must accompany the audio
+ assert all(resp.text_content is not None for resp in responses), "Missing text output"
+ else:  # text must be None, empty, or whitespace-only
+ assert all(not (resp.text_content or "").strip() for resp in responses), (
+ "Audio-only request unexpectedly produced text"
+ )
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_text_only_batch_uses_multi_replica_talker(omni_server, openai_client) -> None:
+ request_config = {
+ "model": omni_server.model,
+ "messages": _text_messages(),
+ "stream": False,
+ "modalities": ["text"],  # text only; absence of audio is asserted below
+ }
+
+ responses = openai_client.send_omni_request(request_config, request_num=ROUTE_STRESS_REQUESTS)
+ _assert_batch_size(responses, ROUTE_STRESS_REQUESTS)  # NOTE(review): replica routing itself is not inspected
+ _assert_text_outputs(responses)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_text_to_audio_stream_batch_uses_multi_replica_vocoder(omni_server, openai_client) -> None:
+ request_config = {
+ "model": omni_server.model,
+ "messages": _text_messages(),
+ "stream": True,
+ "modalities": ["audio"],  # audio only; text is asserted to be empty below
+ }
+
+ responses = openai_client.send_omni_request(request_config, request_num=ROUTE_STRESS_REQUESTS)
+ _assert_batch_size(responses, ROUTE_STRESS_REQUESTS)  # NOTE(review): replica routing itself is not inspected
+ _assert_audio_outputs(responses, expect_text=False)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_mixed_modal_stream_batch_generates_text_and_audio(omni_server, openai_client) -> None:
+ request_config = {
+ "model": omni_server.model,
+ "messages": _mixed_messages(),
+ "stream": True,  # no "modalities" key is sent; both text and audio are asserted below
+ }
+
+ responses = openai_client.send_omni_request(request_config, request_num=MIXED_MODAL_REQUESTS)
+ _assert_batch_size(responses, MIXED_MODAL_REQUESTS)
+ _assert_audio_outputs(responses, expect_text=True)
diff --git a/tests/engine/test_async_omni_engine_input.py b/tests/engine/test_async_omni_engine_input.py
index 3700e426d42..d86483f1561 100644
--- a/tests/engine/test_async_omni_engine_input.py
+++ b/tests/engine/test_async_omni_engine_input.py
@@ -59,7 +59,7 @@ def test_build_add_request_message_preserves_additional_information(mocker: Mock
assert request.additional_information is not None
assert request.additional_information.entries["text"].list_data == ["hello world"]
assert request.additional_information.entries["speaker"].list_data == ["vivian"]
- output_processor.add_request.assert_called_once()
+ output_processor.add_request.assert_not_called()
def test_build_add_request_message_with_resumable_streaming(mocker: MockerFixture):
diff --git a/tests/engine/test_async_omni_engine_outputs.py b/tests/engine/test_async_omni_engine_outputs.py
index ef3cfab3bf8..47b3d5e9f14 100644
--- a/tests/engine/test_async_omni_engine_outputs.py
+++ b/tests/engine/test_async_omni_engine_outputs.py
@@ -63,3 +63,39 @@ async def test_try_get_output_async_raises_after_orchestrator_dies(mocker: Mocke
with pytest.raises(RuntimeError, match="Orchestrator died unexpectedly"):
await engine.try_get_output_async()
+
+
+def test_fatal_error_message_surfaces_through_try_get_output(mocker: MockerFixture):
+ """When the orchestrator thread crashes, it enqueues a fatal error message.
+
+ ``try_get_output`` must return this message so the caller
+ (``OmniBase._handle_output_message``) can detect the fatal flag.
+ """
+ fatal_msg = {"type": "error", "error": "Orchestrator thread crashed", "fatal": True}
+
+ mock_queue = mocker.MagicMock()
+ mock_queue.sync_q.get.return_value = fatal_msg  # sync path is stubbed through sync_q.get
+
+ engine = _make_engine(mock_queue, mocker, thread_alive=False)  # presumably a dead orchestrator thread — see _make_engine
+
+ msg = engine.try_get_output()
+ assert msg is not None
+ assert msg["type"] == "error"
+ assert msg["fatal"] is True
+ assert "crashed" in msg["error"]
+
+
+@pytest.mark.asyncio
+async def test_fatal_error_message_surfaces_through_try_get_output_async(mocker: MockerFixture):
+ """Async variant of the fatal error message test."""
+ fatal_msg = {"type": "error", "error": "Orchestrator thread crashed", "fatal": True}
+
+ mock_queue = mocker.MagicMock()
+ mock_queue.sync_q.get_nowait.return_value = fatal_msg  # async path is stubbed through get_nowait, not get
+
+ engine = _make_engine(mock_queue, mocker, thread_alive=False)
+
+ msg = await engine.try_get_output_async()
+ assert msg is not None
+ assert msg["type"] == "error"
+ assert msg["fatal"] is True
diff --git a/tests/engine/test_async_omni_engine_stage_init.py b/tests/engine/test_async_omni_engine_stage_init.py
index e0999ca8796..ed1d15b3d1e 100644
--- a/tests/engine/test_async_omni_engine_stage_init.py
+++ b/tests/engine/test_async_omni_engine_stage_init.py
@@ -1,15 +1,138 @@
import importlib
import os
import threading
+import time
import types
import pytest
from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
+from vllm_omni.engine.stage_init_utils import (
+ LogicalStageInitPlan,
+ ReplicaInitPlan,
+ build_stage0_input_processor,
+ compute_replica_layout,
+)
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+def _make_llm_metadata(  # SimpleNamespace stand-in for LLM stage metadata
+ stage_id: int,
+ *,
+ replica_id: int = 0,
+ final_output: bool = False,
+ final_output_type: str | None = None,
+ is_comprehension: bool = False,
+):
+ return types.SimpleNamespace(
+ stage_id=stage_id,
+ stage_type="llm",
+ runtime_cfg={},
+ prompt_expand_func=None,
+ final_output=final_output,
+ final_output_type=final_output_type,
+ default_sampling_params=types.SimpleNamespace(name=f"sp-{stage_id}-{replica_id}"),
+ custom_process_input_func=None,
+ engine_input_source=[] if stage_id == 0 else [stage_id - 1],  # stage 0 has no source; others use the previous stage
+ engine_output_type="token_ids",
+ replica_id=replica_id,
+ is_comprehension=is_comprehension,
+ )
+
+
+def _make_diffusion_metadata(stage_id: int, *, replica_id: int = 0, final_output_type: str = "image"):
+ return types.SimpleNamespace(
+ stage_id=stage_id,
+ stage_type="diffusion",
+ runtime_cfg={"devices": str(replica_id)},  # replica id doubles as the device string
+ prompt_expand_func=None,
+ final_output=True,  # always final here, unlike the LLM helper where it is a parameter
+ final_output_type=final_output_type,
+ default_sampling_params=types.SimpleNamespace(name=f"dsp-{stage_id}-{replica_id}"),
+ custom_process_input_func=None,
+ engine_input_source=[],
+ cfg_kv_collect_func=None,
+ replica_id=replica_id,
+ )
+
+
+def _make_llm_plan(  # LogicalStageInitPlan with num_replicas local LLM replicas
+ stage_idx: int,
+ *,
+ configured_stage_id: int,
+ vllm_config: object,
+ num_replicas: int = 1,
+ final_output: bool = False,
+ final_output_type: str | None = None,
+ is_comprehension: bool = False,
+):
+ replicas: list[ReplicaInitPlan] = []
+ for replica_id in range(num_replicas):
+ stage_cfg = types.SimpleNamespace(
+ stage_id=configured_stage_id,
+ stage_type="llm",
+ runtime=types.SimpleNamespace(devices=str(replica_id)),  # replica id doubles as the device string
+ engine_args={},
+ )
+ replicas.append(
+ ReplicaInitPlan(
+ replica_id=replica_id,
+ num_replicas=num_replicas,
+ launch_mode="local",
+ stage_cfg=stage_cfg,
+ metadata=_make_llm_metadata(
+ configured_stage_id,
+ replica_id=replica_id,
+ final_output=final_output,
+ final_output_type=final_output_type,
+ is_comprehension=is_comprehension and replica_id == 0,  # only replica 0 can be marked comprehension
+ ),
+ stage_connector_spec={},
+ omni_kv_connector=(None, None, None),
+ stage_vllm_config=vllm_config,  # the same config object is shared by every replica
+ executor_class=object,
+ )
+ )
+ return LogicalStageInitPlan(
+ stage_idx=stage_idx,
+ configured_stage_id=configured_stage_id,
+ replicas=replicas,
+ )
+
+
+def _make_diffusion_plan(  # LogicalStageInitPlan with num_replicas local diffusion replicas
+ stage_idx: int,
+ *,
+ configured_stage_id: int,
+ num_replicas: int = 1,
+):
+ replicas: list[ReplicaInitPlan] = []
+ for replica_id in range(num_replicas):
+ stage_cfg = types.SimpleNamespace(
+ stage_id=configured_stage_id,
+ stage_type="diffusion",
+ runtime=types.SimpleNamespace(devices=str(replica_id)),  # replica id doubles as the device string
+ engine_args={},
+ )
+ replicas.append(
+ ReplicaInitPlan(
+ replica_id=replica_id,
+ num_replicas=num_replicas,
+ launch_mode="local",
+ stage_cfg=stage_cfg,
+ metadata=_make_diffusion_metadata(configured_stage_id, replica_id=replica_id),
+ stage_connector_spec={},
+ omni_kv_connector=(None, None, None),  # stage_vllm_config / executor_class are not passed here
+ )
+ )
+ return LogicalStageInitPlan(
+ stage_idx=stage_idx,
+ configured_stage_id=configured_stage_id,
+ replicas=replicas,
+ )
+
+
def test_stage_engine_core_client_module_reload_keeps_forward_refs_deferred():
"""Regression test for forward references in make_async_mp_client."""
import vllm_omni.engine.stage_engine_core_client as client_mod
@@ -21,64 +144,68 @@ def test_stage_engine_core_client_module_reload_keeps_forward_refs_deferred():
)
-def test_initialize_stages_restores_device_visibility_after_diffusion_init(monkeypatch):
- """Regression test for stage device env leakage across stage init.
+def test_compute_replica_layout_splits_diffusion_devices_by_world_size():
+ stage_cfg = types.SimpleNamespace(
+ stage_id=0,
+ stage_type="diffusion",
+ engine_args={"parallel_config": {"tensor_parallel_size": 2}},
+ runtime={"devices": "0,1,2,3", "num_replicas": 2},  # four devices shared by two replicas
+ )
+
+ replicas_per_stage, replica_devices_map = compute_replica_layout([stage_cfg])
+
+ assert replicas_per_stage == [2]
+ assert replica_devices_map == {0: ["0,1", "2,3"]}  # each replica receives two consecutive devices
+
+
+def test_collect_initialized_clients_for_cleanup_deduplicates_clients():
+ shared = types.SimpleNamespace(name="shared")  # appears in both inputs below
+ extra = types.SimpleNamespace(name="extra")
+
+ cleanup_clients = AsyncOmniEngine._collect_initialized_clients_for_cleanup(
+ stage_pools=[types.SimpleNamespace(clients=[shared, None])],  # the None entry must not survive
+ initialized_clients_by_stage={0: [shared], 1: [extra]},
+ )
+
+ assert cleanup_clients == [shared, extra]  # one entry per client, no None
+
+
+def test_initialize_stages_rejects_replicas_in_single_stage_mode():
+ engine = object.__new__(AsyncOmniEngine)  # bypass __init__; attributes are set by hand below
+ engine.single_stage_mode = True
+ engine.stage_configs = [types.SimpleNamespace(stage_id=0, runtime={"num_replicas": 2})]
+
+ with pytest.raises(ValueError, match="single_stage_mode does not support num_replicas > 1 yet"):
+ engine._validate_single_stage_mode_replica_constraints()  # the validator is called directly, not _initialize_stages
+
- Diffusion init mutates process-level CUDA visibility. Ensure AsyncOmniEngine
- restores the previous value after diffusion stage setup.
- """
+def test_initialize_diffusion_replica_restores_device_visibility_after_local_init(monkeypatch):
import vllm_omni.engine.async_omni_engine as engine_mod
from vllm_omni.platforms import current_omni_platform
engine = object.__new__(AsyncOmniEngine)
engine.model = "dummy-model"
- engine.config_path = "dummy-config"
engine.num_stages = 1
- engine.async_chunk = False
engine.diffusion_batch_size = 1
engine.single_stage_mode = False
- engine._single_stage_id_filter = None
engine._omni_master_server = None
- engine.stage_configs = [types.SimpleNamespace(stage_id=0, stage_type="diffusion")]
+ engine.stage_configs = []
+
+ plan = _make_diffusion_plan(0, configured_stage_id=0).replicas[0]
env_var = current_omni_platform.device_control_env_var
old_env = os.environ.get(env_var)
os.environ[env_var] = "0,1"
- diffusion_client = types.SimpleNamespace(is_comprehension=False)
-
- metadata = types.SimpleNamespace(
- stage_id=0,
- stage_type="diffusion",
- runtime_cfg={"devices": "1"},
- prompt_expand_func=None,
- )
-
- monkeypatch.setattr(engine_mod, "prepare_engine_environment", lambda: None)
- monkeypatch.setattr(engine_mod, "load_omni_transfer_config_for_model", lambda *_: None)
- monkeypatch.setattr(engine_mod, "extract_stage_metadata", lambda _cfg: metadata)
- monkeypatch.setattr(engine_mod, "get_stage_connector_spec", lambda **_: {})
- monkeypatch.setattr(engine_mod, "resolve_omni_kv_config_for_stage", lambda *_: (None, None, None))
-
def _fake_setup_stage_devices(_stage_id, _runtime_cfg):
- # Simulate diffusion setup mutating process-global visibility.
current_omni_platform.set_device_control_env_var("1")
monkeypatch.setattr(engine_mod, "setup_stage_devices", _fake_setup_stage_devices)
monkeypatch.setattr(engine_mod, "inject_kv_stage_info", lambda *_: None)
- monkeypatch.setattr(engine_mod, "initialize_diffusion_stage", lambda *_, **__: diffusion_client)
- monkeypatch.setattr(
- engine_mod,
- "finalize_initialized_stages",
- lambda stage_clients, _input_processor: (
- stage_clients,
- [types.SimpleNamespace()],
- [{"final_output_type": "image"}],
- ),
- )
+ monkeypatch.setattr(engine_mod, "initialize_diffusion_stage", lambda *_, **__: types.SimpleNamespace())
try:
- engine._initialize_stages(stage_init_timeout=1)
+ engine._initialize_diffusion_replica(plan, stage_init_timeout=1, stage_launch_lock=threading.Lock())
assert os.environ.get(env_var) == "0,1"
finally:
if old_env is None:
@@ -87,290 +214,273 @@ def _fake_setup_stage_devices(_stage_id, _runtime_cfg):
os.environ[env_var] = old_env
-def test_initialize_stages_passes_stage_init_timeout_to_diffusion_handshake(monkeypatch):
- """Regression test for stage_init_timeout passing to complete_diffusion_handshake
- in the diffusion stage path.
- """
- import vllm_omni.diffusion.data as diffusion_data_mod
- import vllm_omni.diffusion.stage_diffusion_client as client_mod
+def test_initialize_diffusion_replica_passes_stage_init_timeout_and_inline_flag(monkeypatch):
+ import vllm_omni.engine.async_omni_engine as engine_mod
+
+ engine = object.__new__(AsyncOmniEngine)  # bypass __init__; attributes are set by hand below
+ engine.model = "dummy-model"
+ engine.num_stages = 1
+ engine.diffusion_batch_size = 4  # expected to arrive as batch_size in the captured call
+ engine.single_stage_mode = False
+ engine._omni_master_server = None
+ engine.stage_configs = []
+
+ plan = _make_diffusion_plan(0, configured_stage_id=0).replicas[0]
+
+ captured: dict[str, object] = {}
+
+ monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None)
+ monkeypatch.setattr(engine_mod, "inject_kv_stage_info", lambda *_: None)
+
+ def _capture_initialize_diffusion_stage(  # records the arguments instead of launching a stage
+ stage_id, _model, _stage_cfg, _metadata, *, stage_init_timeout, batch_size, use_inline
+ ):
+ captured["stage_id"] = stage_id
+ captured["stage_init_timeout"] = stage_init_timeout
+ captured["batch_size"] = batch_size
+ captured["use_inline"] = use_inline
+ return types.SimpleNamespace()
+
+ monkeypatch.setattr(engine_mod, "initialize_diffusion_stage", _capture_initialize_diffusion_stage)
+
+ engine._initialize_diffusion_replica(plan, stage_init_timeout=302, stage_launch_lock=threading.Lock())
+
+ assert captured == {
+ "stage_id": 0,
+ "stage_init_timeout": 302,
+ "batch_size": 4,
+ "use_inline": True,  # NOTE(review): what selects inline mode is not visible in this test — confirm
+ }
+
+
+def test_initialize_stages_exposes_logical_stage_views_and_builds_top_level_input_processor(monkeypatch):
import vllm_omni.engine.async_omni_engine as engine_mod
- from vllm_omni.platforms import current_omni_platform
engine = object.__new__(AsyncOmniEngine)
- engine.log_stats = False
engine.model = "dummy-model"
engine.config_path = "dummy-config"
engine.num_stages = 2
engine.async_chunk = False
engine.diffusion_batch_size = 1
engine.single_stage_mode = False
+ engine._single_stage_id_filter = None
engine._omni_master_server = None
- engine.stage_configs = [types.SimpleNamespace(stage_id=0, stage_type="diffusion", engine_args={})]
+ engine.stage_configs = [types.SimpleNamespace(), types.SimpleNamespace()]
- metadata = types.SimpleNamespace(
- stage_id=0,
- stage_type="diffusion",
- runtime_cfg={"devices": "0"},
- prompt_expand_func=None,
+ cfg0 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64))
+ cfg1 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64))
+ stage_plans = [
+ _make_llm_plan(0, configured_stage_id=0, vllm_config=cfg0, num_replicas=2, is_comprehension=True),
+ _make_llm_plan(1, configured_stage_id=1, vllm_config=cfg1, final_output=True),
+ ]
+
+ stage0_client_r0 = types.SimpleNamespace(
+ stage_type="llm",
+ is_comprehension=True,
+ final_output=False,
+ final_output_type=None,
+ default_sampling_params=types.SimpleNamespace(name="sp0"),
+ )
+ stage0_client_r1 = types.SimpleNamespace(
+ stage_type="llm",
+ is_comprehension=False,
+ final_output=False,
+ final_output_type=None,
+ default_sampling_params=types.SimpleNamespace(name="sp0r1"),
+ )
+ stage1_client_r0 = types.SimpleNamespace(
+ stage_type="llm",
+ is_comprehension=False,
final_output=True,
- final_output_type="image",
- default_sampling_params=None,
- custom_process_input_func=None,
- engine_input_source=None,
- cfg_kv_collect_func=None,
+ final_output_type=None,
+ default_sampling_params=types.SimpleNamespace(name="sp1"),
)
+ initialized_clients = {
+ 0: [stage0_client_r0, stage0_client_r1],
+ 1: [stage1_client_r0],
+ }
- captured_timeout = None
- device_env_var = current_omni_platform.device_control_env_var
- prev_device_env = os.environ.get(device_env_var)
- os.environ[device_env_var] = "0"
+ stage0_output_processor = object()
+ stage1_output_processor = object()
+ top_level_input_processor = object()
monkeypatch.setattr(engine_mod, "prepare_engine_environment", lambda: None)
monkeypatch.setattr(engine_mod, "load_omni_transfer_config_for_model", lambda *_: None)
- monkeypatch.setattr(engine_mod, "extract_stage_metadata", lambda _cfg: metadata)
- monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None)
+ monkeypatch.setattr(engine_mod, "compute_replica_layout", lambda _cfgs: ([2, 1], {}))
+ monkeypatch.setattr(engine, "_build_logical_stage_init_plans", lambda *_: (stage_plans, None))
+ monkeypatch.setattr(engine, "_initialize_stage_replicas", lambda *_: initialized_clients)
monkeypatch.setattr(
engine_mod,
- "finalize_initialized_stages",
- lambda stage_clients, _input_processor: (
- stage_clients,
- [types.SimpleNamespace()],
- [{"final_output_type": "image"}],
- ),
+ "build_llm_stage_output_processor",
+ lambda plan, _cfg: stage0_output_processor if plan.stage_idx == 0 else stage1_output_processor,
)
- monkeypatch.setattr(
- diffusion_data_mod.OmniDiffusionConfig,
- "from_kwargs",
- classmethod(lambda cls, **kwargs: types.SimpleNamespace(parallel_config=types.SimpleNamespace(world_size=1))),
- )
- monkeypatch.setattr(
- client_mod,
- "spawn_diffusion_proc",
- lambda model, od_cfg: (object(), "ipc://handshake", "ipc://request", "ipc://response"),
- )
-
- def _capture_handshake_timeout(proc, handshake_address, handshake_timeout):
- nonlocal captured_timeout
- captured_timeout = handshake_timeout
+ monkeypatch.setattr(engine_mod, "build_stage0_input_processor", lambda _cfg: top_level_input_processor)
- monkeypatch.setattr(client_mod, "complete_diffusion_handshake", _capture_handshake_timeout)
- monkeypatch.setattr(
- client_mod.zmq,
- "Context",
- lambda: types.SimpleNamespace(socket=lambda _: types.SimpleNamespace(connect=lambda _: None)),
- )
-
- try:
- engine._initialize_stages(stage_init_timeout=302)
- finally:
- if prev_device_env is None:
- os.environ.pop(device_env_var, None)
- else:
- os.environ[device_env_var] = prev_device_env
+ engine._initialize_stages(stage_init_timeout=1)
- assert captured_timeout == 302
+ assert len(engine.stage_pools) == 2
+ assert engine.input_processor is top_level_input_processor
+ assert engine.stage_clients == [stage0_client_r0, stage1_client_r0]
+ assert engine.stage_vllm_configs == [cfg0, cfg1]
+ assert engine.output_processors == [stage0_output_processor, stage1_output_processor]
+ assert engine.default_sampling_params_list == [
+ stage0_client_r0.default_sampling_params,
+ stage1_client_r0.default_sampling_params,
+ ]
+ assert engine.stage_metadata == [
+ {"final_output": False, "final_output_type": None, "stage_type": "llm"},
+ {"final_output": True, "final_output_type": None, "stage_type": "llm"},
+ ]
-def test_initialize_stages_exposes_logical_stage_views_and_shares_stage_output_processor(monkeypatch):
+def test_build_logical_stage_init_plans_applies_replica_device_splits(monkeypatch):
import vllm_omni.engine.async_omni_engine as engine_mod
engine = object.__new__(AsyncOmniEngine)
engine.model = "dummy-model"
- engine.config_path = "dummy-config"
- engine.num_stages = 2
engine.async_chunk = False
- engine.diffusion_batch_size = 1
engine.single_stage_mode = False
engine._single_stage_id_filter = None
- engine._omni_master_server = None
engine.stage_configs = [
- types.SimpleNamespace(stage_id=0, stage_type="llm", engine_args={}, runtime=types.SimpleNamespace()),
- types.SimpleNamespace(stage_id=1, stage_type="llm", engine_args={}, runtime=types.SimpleNamespace()),
+ types.SimpleNamespace(stage_id=0, stage_type="llm", engine_args={}, runtime=types.SimpleNamespace(devices="0")),
+ types.SimpleNamespace(
+ stage_id=1, stage_type="llm", engine_args={}, runtime=types.SimpleNamespace(devices="1,2,3")
+ ),
]
- stage0_client_r0 = types.SimpleNamespace(is_comprehension=False, stage_type="llm")
- stage0_client_r1 = types.SimpleNamespace(is_comprehension=False, stage_type="llm")
- stage1_client_r0 = types.SimpleNamespace(is_comprehension=False, stage_type="llm")
-
- stage0_proc_r0 = object()
- stage1_proc_r0 = object()
-
- cfg0_r0 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64))
- cfg0_r1 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64))
- cfg1_r0 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64))
-
- metadata_map = {
- 0: types.SimpleNamespace(
- stage_id=0,
- stage_type="llm",
- runtime_cfg={},
- prompt_expand_func=None,
- final_output=False,
- final_output_type=None,
- default_sampling_params=types.SimpleNamespace(),
- custom_process_input_func=None,
- engine_input_source=[],
- engine_output_type="token_ids",
- ),
- 1: types.SimpleNamespace(
- stage_id=1,
- stage_type="llm",
- runtime_cfg={},
- prompt_expand_func=None,
- final_output=True,
- final_output_type=None,
- default_sampling_params=types.SimpleNamespace(),
- custom_process_input_func=None,
- engine_input_source=[0],
- engine_output_type="token_ids",
- ),
+ metadata_by_stage = {
+ 0: _make_llm_metadata(0),
+ 1: _make_llm_metadata(1),
}
- monkeypatch.setattr(engine_mod, "prepare_engine_environment", lambda: None)
- monkeypatch.setattr(engine_mod, "load_omni_transfer_config_for_model", lambda *_: None)
- monkeypatch.setattr(engine_mod, "get_stage_connector_spec", lambda **_: {})
- monkeypatch.setattr(engine_mod, "resolve_omni_kv_config_for_stage", lambda *_: (None, None, None))
- monkeypatch.setattr(engine_mod, "compute_replica_layout", lambda _cfgs: ([2, 1], {}, 3))
monkeypatch.setattr(
engine_mod,
"extract_stage_metadata",
- lambda cfg: types.SimpleNamespace(**metadata_map[cfg.stage_id].__dict__),
+ lambda cfg: types.SimpleNamespace(**metadata_by_stage[cfg.stage_id].__dict__),
)
+ monkeypatch.setattr(engine_mod, "get_stage_connector_spec", lambda **_: {})
+ monkeypatch.setattr(engine_mod, "resolve_omni_kv_config_for_stage", lambda *_: (None, None, None))
+ monkeypatch.setattr(engine_mod, "build_engine_args_dict", lambda *_, **__: {})
monkeypatch.setattr(
engine_mod,
- "finalize_initialized_stages",
- lambda stage_clients, _input_processor: (
- stage_clients,
- [types.SimpleNamespace(), types.SimpleNamespace()],
- [{"final_output_type": None}, {"final_output_type": None}],
- ),
+ "build_vllm_config",
+ lambda stage_cfg, *_args, **_kwargs: (types.SimpleNamespace(tag=f"cfg-{stage_cfg.stage_id}"), object),
)
- started_by_stage = {
- 0: [
- types.SimpleNamespace(stage_id=0, metadata=types.SimpleNamespace(replica_id=0), vllm_config=cfg0_r0),
- types.SimpleNamespace(stage_id=0, metadata=types.SimpleNamespace(replica_id=1), vllm_config=cfg0_r1),
- ],
- 1: [
- types.SimpleNamespace(stage_id=1, metadata=types.SimpleNamespace(replica_id=0), vllm_config=cfg1_r0),
- ],
- }
+ stage_plans, prompt_expand_func = engine._build_logical_stage_init_plans(
+ omni_transfer_config=None,
+ replicas_per_stage=[1, 3],
+ replica_devices_map={1: ["1", "2", "3"]},
+ )
- attach_outputs = {
- (0, 0): (stage0_client_r0, stage0_proc_r0, cfg0_r0, object()),
- (0, 1): (stage0_client_r1, None, cfg0_r1, None),
- (1, 0): (stage1_client_r0, stage1_proc_r0, cfg1_r0, None),
- }
+ assert prompt_expand_func is None
+ assert [plan.configured_stage_id for plan in stage_plans] == [0, 1]
+ assert [replica.stage_cfg.runtime.devices for replica in stage_plans[1].replicas] == ["1", "2", "3"]
+ assert [replica.replica_id for replica in stage_plans[1].replicas] == [0, 1, 2]
+ assert all(replica.num_replicas == 3 for replica in stage_plans[1].replicas)
- launch_counters = {0: 0, 1: 0}
- def _fake_launch_llm_stage(stage_cfg, metadata, *_args, **_kwargs):
- idx = metadata.stage_id
- launch_idx = launch_counters[idx]
- launch_counters[idx] += 1
- return started_by_stage[idx][launch_idx]
+def test_initialize_stage_replicas_collects_results_by_stage_and_replica_id(monkeypatch):
+ engine = object.__new__(AsyncOmniEngine)
- def _fake_attach_llm_stage(started):
- return attach_outputs[(started.stage_id, started.metadata.replica_id)]
+ cfg0 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64))
+ cfg1 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64))
+ stage_plans = [
+ _make_llm_plan(0, configured_stage_id=0, vllm_config=cfg0, num_replicas=2),
+ _make_llm_plan(1, configured_stage_id=1, vllm_config=cfg1, num_replicas=2),
+ ]
- monkeypatch.setattr(engine, "_launch_llm_stage", _fake_launch_llm_stage)
- monkeypatch.setattr(engine, "_attach_llm_stage", _fake_attach_llm_stage)
+ clients = {
+ (0, 0): types.SimpleNamespace(name="stage0-replica0"),
+ (0, 1): types.SimpleNamespace(name="stage0-replica1"),
+ (1, 0): types.SimpleNamespace(name="stage1-replica0"),
+ (1, 1): types.SimpleNamespace(name="stage1-replica1"),
+ }
- engine._initialize_stages(stage_init_timeout=1)
+ def _initialize_replica(plan, _stage_init_timeout, _stage_launch_lock):
+ time.sleep(0.02 * (3 - plan.metadata.stage_id - plan.replica_id))
+ return clients[(plan.metadata.stage_id, plan.replica_id)]
- assert len(engine.stage_pools) == 2
- assert len(engine.stage_clients) == 2
- assert len(engine.stage_vllm_configs) == 2
- assert len(engine.output_processors) == 2
+ monkeypatch.setattr(engine, "_initialize_replica", _initialize_replica)
- assert engine.stage_clients[0] is stage0_client_r0
- assert engine.stage_clients[1] is stage1_client_r0
- assert engine.stage_vllm_configs[0] is cfg0_r0
- assert engine.stage_vllm_configs[1] is cfg1_r0
+ initialized_clients = engine._initialize_stage_replicas(stage_plans, stage_init_timeout=123)
- stage0_pool = engine.stage_pools[0]
- assert engine.output_processors[0] is stage0_pool.output_processor
+ assert initialized_clients == {
+ 0: [clients[(0, 0)], clients[(0, 1)]],
+ 1: [clients[(1, 0)], clients[(1, 1)]],
+ }
-def test_launch_llm_stage_passes_stage_init_timeout_to_complete_stage_handshake(monkeypatch):
- """Regression test for stage_init_timeout reaching complete_stage_handshake
- in the LLM stage path.
- """
+def test_initialize_stages_cleans_up_successful_replicas_after_partial_multi_replica_failure(monkeypatch):
import vllm_omni.engine.async_omni_engine as engine_mod
- from vllm_omni.platforms import current_omni_platform
engine = object.__new__(AsyncOmniEngine)
- engine.log_stats = False
engine.model = "dummy-model"
+ engine.config_path = "dummy-config"
+ engine.num_stages = 1
+ engine.async_chunk = False
+ engine.diffusion_batch_size = 1
engine.single_stage_mode = False
+ engine._single_stage_id_filter = None
engine._omni_master_server = None
- engine.stage_configs = []
+ engine.stage_configs = [types.SimpleNamespace()]
- metadata = types.SimpleNamespace(stage_id=0, runtime_cfg={"devices": "0"})
- fake_vllm_config = types.SimpleNamespace()
- fake_addresses = types.SimpleNamespace()
- fake_proc = types.SimpleNamespace()
+ cfg0 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64))
+ stage_plans = [_make_llm_plan(0, configured_stage_id=0, vllm_config=cfg0, num_replicas=2)]
+ initialized_client = types.SimpleNamespace(shutdown=lambda: None)
- captured_timeout = None
+ monkeypatch.setattr(engine_mod, "prepare_engine_environment", lambda: None)
+ monkeypatch.setattr(engine_mod, "load_omni_transfer_config_for_model", lambda *_: None)
+ monkeypatch.setattr(engine_mod, "compute_replica_layout", lambda _cfgs: ([2], {}))
+ monkeypatch.setattr(engine, "_build_logical_stage_init_plans", lambda *_: (stage_plans, None))
- device_env_var = current_omni_platform.device_control_env_var
- prev_device_env = os.environ.get(device_env_var)
- os.environ[device_env_var] = "0"
+ def _initialize_replica(plan, _stage_init_timeout, _stage_launch_lock):
+ if plan.replica_id == 0:
+ return initialized_client
+ time.sleep(0.05)
+ raise RuntimeError("replica launch failed")
- monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None)
- monkeypatch.setattr(engine_mod, "build_engine_args_dict", lambda *_, **__: {})
- monkeypatch.setattr(engine_mod, "build_vllm_config", lambda *_, **__: (fake_vllm_config, object))
- monkeypatch.setattr(engine_mod, "acquire_device_locks", lambda *_: [])
- monkeypatch.setattr(
- engine_mod,
- "spawn_stage_core",
- lambda **_: (fake_addresses, fake_proc, "ipc://handshake"),
- )
+ monkeypatch.setattr(engine, "_initialize_replica", _initialize_replica)
- def _capture_stage_timeout(_proc, _handshake_addr, _addresses, _vllm_cfg, handshake_timeout):
- nonlocal captured_timeout
- captured_timeout = handshake_timeout
+ captured_cleanup: list[list[object]] = []
- monkeypatch.setattr(engine_mod, "complete_stage_handshake", _capture_stage_timeout)
+ def _capture_shutdown(clients):
+ captured_cleanup.append(list(clients))
- try:
- engine._launch_llm_stage(
- stage_cfg=types.SimpleNamespace(engine_args={}),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=302,
- llm_stage_launch_lock=threading.Lock(),
- )
- finally:
- if prev_device_env is None:
- os.environ.pop(device_env_var, None)
- else:
- os.environ[device_env_var] = prev_device_env
+ monkeypatch.setattr(engine, "_shutdown_initialized_clients", _capture_shutdown)
- assert captured_timeout == 302
+ with pytest.raises(RuntimeError, match="replica launch failed"):
+ engine._initialize_stages(stage_init_timeout=1)
+
+ assert captured_cleanup == [[initialized_client]]
-def test_launch_llm_stage_releases_launch_lock_before_complete_stage_handshake(monkeypatch):
- """Regression test for parallel LLM stage startup during handshake wait."""
+def test_initialize_llm_replica_passes_stage_init_timeout_to_complete_stage_handshake(monkeypatch):
import vllm_omni.engine.async_omni_engine as engine_mod
from vllm_omni.platforms import current_omni_platform
engine = object.__new__(AsyncOmniEngine)
- engine.log_stats = False
engine.model = "dummy-model"
engine.single_stage_mode = False
engine._omni_master_server = None
engine.stage_configs = []
fake_vllm_config = types.SimpleNamespace()
- fake_addresses = types.SimpleNamespace()
- shared_launch_lock = threading.Lock()
- counter_lock = threading.Lock()
- first_handshake_started = threading.Event()
- second_stage_spawned = threading.Event()
- allow_first_handshake_to_finish = threading.Event()
- launch_errors: list[BaseException] = []
- spawn_count = 0
+ fake_addresses = types.SimpleNamespace(inputs=["in"], outputs=["out"], frontend_stats_publish_address=None)
+ fake_proc = types.SimpleNamespace()
+ captured_timeout: int | None = None
+
+ plan = ReplicaInitPlan(
+ replica_id=0,
+ num_replicas=1,
+ launch_mode="local",
+ stage_cfg=types.SimpleNamespace(engine_args={}, runtime=types.SimpleNamespace(devices="0")),
+ metadata=types.SimpleNamespace(stage_id=0, runtime_cfg={"devices": "0"}),
+ stage_connector_spec={},
+ omni_kv_connector=(None, None, None),
+ stage_vllm_config=fake_vllm_config,
+ executor_class=object,
+ )
device_env_var = current_omni_platform.device_control_env_var
prev_device_env = os.environ.get(device_env_var)
@@ -378,127 +488,52 @@ def test_launch_llm_stage_releases_launch_lock_before_complete_stage_handshake(m
monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None)
monkeypatch.setattr(engine_mod, "build_engine_args_dict", lambda *_, **__: {})
- monkeypatch.setattr(engine_mod, "build_vllm_config", lambda *_, **__: (fake_vllm_config, object))
monkeypatch.setattr(engine_mod, "acquire_device_locks", lambda *_: [])
+ monkeypatch.setattr(engine_mod, "spawn_stage_core", lambda **_: (fake_addresses, fake_proc, "ipc://handshake"))
- def _spawn_stage_core(**_):
- nonlocal spawn_count
- with counter_lock:
- spawn_count += 1
- call_idx = spawn_count
- if call_idx == 2:
- second_stage_spawned.set()
- return fake_addresses, types.SimpleNamespace(), f"ipc://handshake-{call_idx}"
-
- def _complete_stage_handshake(_proc, handshake_address, _addresses, _vllm_cfg, _timeout):
- if handshake_address == "ipc://handshake-1":
- first_handshake_started.set()
- assert second_stage_spawned.wait(timeout=1), (
- "second stage did not reach spawn_stage_core while first stage waited in handshake"
- )
- assert allow_first_handshake_to_finish.wait(timeout=1), (
- "second stage did not enter handshake while first stage was still waiting"
- )
- else:
- allow_first_handshake_to_finish.set()
-
- monkeypatch.setattr(engine_mod, "spawn_stage_core", _spawn_stage_core)
- monkeypatch.setattr(engine_mod, "complete_stage_handshake", _complete_stage_handshake)
+ def _capture_stage_timeout(_proc, _handshake_addr, _addresses, _vllm_cfg, handshake_timeout):
+ nonlocal captured_timeout
+ captured_timeout = handshake_timeout
- def _launch_stage(stage_id: int) -> None:
- metadata = types.SimpleNamespace(stage_id=stage_id, runtime_cfg={"devices": str(stage_id)})
- try:
- engine._launch_llm_stage(
- stage_cfg=types.SimpleNamespace(engine_args={}),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=302,
- llm_stage_launch_lock=shared_launch_lock,
- )
- except BaseException as exc: # pragma: no cover - surfaced through assertion below
- launch_errors.append(exc)
+ monkeypatch.setattr(engine_mod, "complete_stage_handshake", _capture_stage_timeout)
+ monkeypatch.setattr(
+ engine_mod.StageEngineCoreClientBase,
+ "make_async_mp_client",
+ staticmethod(lambda **_: types.SimpleNamespace(shutdown=lambda: None)),
+ )
try:
- first_thread = threading.Thread(target=_launch_stage, args=(0,))
- first_thread.start()
- assert first_handshake_started.wait(timeout=1), "first stage never entered handshake"
-
- second_thread = threading.Thread(target=_launch_stage, args=(1,))
- second_thread.start()
-
- first_thread.join(timeout=3)
- second_thread.join(timeout=3)
+ engine._initialize_llm_replica(plan, 302, threading.Lock())
finally:
if prev_device_env is None:
os.environ.pop(device_env_var, None)
else:
os.environ[device_env_var] = prev_device_env
- assert not first_thread.is_alive()
- assert not second_thread.is_alive()
- assert second_stage_spawned.is_set()
- assert not launch_errors
-
-
-def test_attach_llm_stage_uses_omni_input_preprocessor(monkeypatch):
- """Regression test for GLM-Image t2i preprocessing path.
-
- Stage-0 InputProcessor must use OmniInputPreprocessor so text prompts with
- mm_processor_kwargs go through multimodal preprocessing.
- """
- import vllm_omni.engine.async_omni_engine as engine_mod
-
- class DummyStageEngineCoreClient:
- def __init__(self, **kwargs):
- self.kwargs = kwargs
+ assert captured_timeout == 302
- def shutdown(self):
- return None
- class DummyOutputProcessor:
- def __init__(self, **kwargs):
- self.kwargs = kwargs
+def test_build_stage0_input_processor_uses_omni_input_preprocessor(monkeypatch):
+ import vllm_omni.engine.stage_init_utils as init_mod
class DummyInputProcessor:
def __init__(self, vllm_config):
self.vllm_config = vllm_config
self.renderer = object()
- self.input_preprocessor = object()
+ self.input_preprocessor = None
class DummyOmniInputPreprocessor:
def __init__(self, vllm_config, renderer=None):
self.vllm_config = vllm_config
self.renderer = renderer
- monkeypatch.setattr(
- engine_mod.StageEngineCoreClientBase,
- "make_async_mp_client",
- staticmethod(lambda **kwargs: DummyStageEngineCoreClient(**kwargs)),
- )
- monkeypatch.setattr(engine_mod, "MultimodalOutputProcessor", DummyOutputProcessor)
- monkeypatch.setattr(engine_mod, "InputProcessor", DummyInputProcessor)
- monkeypatch.setattr(engine_mod, "OmniInputPreprocessor", DummyOmniInputPreprocessor)
+ monkeypatch.setattr(init_mod, "InputProcessor", DummyInputProcessor)
+ monkeypatch.setattr(init_mod, "OmniInputPreprocessor", DummyOmniInputPreprocessor)
- started = types.SimpleNamespace(
- stage_id=0,
- metadata=types.SimpleNamespace(stage_id=0, engine_output_type="token_ids"),
- vllm_config=types.SimpleNamespace(model_config=types.SimpleNamespace(skip_tokenizer_init=True)),
- executor_class=object,
- engine_manager=object(),
- coordinator=object(),
- proc=None,
- addresses=types.SimpleNamespace(
- inputs=["inproc://input"],
- outputs=["inproc://output"],
- frontend_stats_publish_address=None,
- ),
+ input_processor = build_stage0_input_processor(
+ types.SimpleNamespace(model_config=types.SimpleNamespace(try_get_generation_config=lambda: {}))
)
- engine = object.__new__(AsyncOmniEngine)
-
- _stage_client, _out_proc, _vllm_cfg, input_processor = engine._attach_llm_stage(started)
-
- assert input_processor is not None
assert isinstance(input_processor.input_preprocessor, DummyOmniInputPreprocessor)
assert input_processor.input_preprocessor.renderer is input_processor.renderer
diff --git a/tests/engine/test_orchestrator.py b/tests/engine/test_orchestrator.py
index b012728cce5..88251ff7410 100644
--- a/tests/engine/test_orchestrator.py
+++ b/tests/engine/test_orchestrator.py
@@ -2,6 +2,7 @@
import asyncio
import concurrent.futures
+import logging
import queue
import threading
import time
@@ -14,7 +15,7 @@
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
-from vllm_omni.engine.orchestrator import Orchestrator
+from vllm_omni.engine.orchestrator import Orchestrator, OrchestratorRequestState
from vllm_omni.engine.stage_pool import StagePool
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
@@ -88,6 +89,25 @@ def push_diffusion_output(self, output) -> None:
self._diffusion_outputs.put_nowait(output)
+class FakeCollectiveRpcStageClient(FakeStageClient):
+ def __init__(self, *args, rpc_result: Any = None, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.rpc_result = rpc_result
+ self.collective_rpc_calls: list[tuple[str, float | None, tuple[Any, ...], dict[str, Any]]] = []
+
+ async def collective_rpc_async(
+ self,
+ *,
+ method: str,
+ timeout: float | None = None,
+ args: tuple[Any, ...] = (),
+ kwargs: dict[str, Any] | None = None,
+ ) -> Any:
+ normalized_kwargs = dict(kwargs or {})
+ self.collective_rpc_calls.append((method, timeout, args, normalized_kwargs))
+ return self.rpc_result
+
+
class FakeOutputProcessor:
def __init__(self, *, request_outputs: list[object] | None = None) -> None:
self.request_outputs = list(request_outputs or [])
@@ -281,6 +301,18 @@ async def _get_output_message(orchestrator_fixture: OrchestratorFixture, *, time
return msg
+async def _get_rpc_message(orchestrator_fixture: OrchestratorFixture, *, timeout: float = 2.0) -> dict:
+ deadline = time.monotonic() + timeout
+ rpc_sync_q = orchestrator_fixture.queues[2].sync_q
+ while True:
+ if time.monotonic() >= deadline:
+ raise AssertionError("Timed out waiting for orchestrator rpc output")
+ try:
+ return rpc_sync_q.get_nowait()
+ except queue.Empty:
+ await asyncio.sleep(0.01)
+
+
async def _enqueue_add_request(
orchestrator_fixture: OrchestratorFixture,
*,
@@ -691,3 +723,207 @@ async def test_multi_replica_shutdown_all_replicas(orchestrator_factory) -> None
assert not orchestrator_fixture.thread.is_alive()
for client in [stage0_r0, stage0_r1, stage1]:
assert client.shutdown_calls == 1
+
+
+@pytest.mark.asyncio
+async def test_stage_pool_submit_update_reuses_existing_binding() -> None:
+ """A request admitted to one replica must keep using that replica on updates."""
+ stage0_r0 = FakeStageClient(stage_type="llm", final_output=False)
+ stage0_r1 = FakeStageClient(stage_type="llm", final_output=False)
+ pool = StagePool(
+ 0,
+ [stage0_r0, stage0_r1],
+ output_processor=FakeOutputProcessor(),
+ stage_vllm_config=SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)),
+ )
+
+ req0_state = OrchestratorRequestState(
+ request_id="req-0",
+ sampling_params_list=[_sampling_params()],
+ final_stage_id=0,
+ )
+ req1_state = OrchestratorRequestState(
+ request_id="req-1",
+ sampling_params_list=[_sampling_params()],
+ final_stage_id=0,
+ )
+
+ await pool.submit_initial("req-0", req0_state, SimpleNamespace(request_id="req-0", prompt_token_ids=[1, 2]))
+ await pool.submit_update("req-0", req0_state, SimpleNamespace(request_id="req-0", prompt_token_ids=[3]))
+ await pool.submit_initial("req-1", req1_state, SimpleNamespace(request_id="req-1", prompt_token_ids=[4, 5]))
+ await pool.submit_update("req-1", req1_state, SimpleNamespace(request_id="req-1", prompt_token_ids=[6]))
+
+ assert pool.get_bound_replica_id("req-0") == 0
+ assert pool.get_bound_replica_id("req-1") == 1
+ assert len(stage0_r0.add_request_calls) == 2
+ assert len(stage0_r1.add_request_calls) == 2
+ assert stage0_r0.add_request_calls[0][0].request_id == "req-0"
+ assert stage0_r0.add_request_calls[1][0].request_id == "req-0"
+ assert stage0_r1.add_request_calls[0][0].request_id == "req-1"
+ assert stage0_r1.add_request_calls[1][0].request_id == "req-1"
+
+
+@pytest.mark.asyncio
+async def test_stage_pool_submit_initial_rolls_back_output_processor_when_client_submit_fails() -> None:
+ class FailingStageClient(FakeStageClient):
+ async def add_request_async(self, *args, **_kwargs) -> None:
+ raise RuntimeError("submit failed")
+
+ class TrackingOutputProcessor(FakeOutputProcessor):
+ def __init__(self) -> None:
+ super().__init__()
+ self.added_request_ids: list[str] = []
+ self.removed_request_ids: list[str] = []
+
+ def add_request(self, request, *_args, **_kwargs) -> None:
+ self.added_request_ids.append(request.request_id)
+
+ def remove_request(self, request_id: str) -> None:
+ self.removed_request_ids.append(request_id)
+
+ client = FailingStageClient(stage_type="llm", final_output=False)
+ output_processor = TrackingOutputProcessor()
+ pool = StagePool(
+ 0,
+ [client],
+ output_processor=output_processor,
+ stage_vllm_config=SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)),
+ )
+ req_state = OrchestratorRequestState(
+ request_id="req-0",
+ sampling_params_list=[_sampling_params()],
+ final_stage_id=0,
+ )
+
+ with pytest.raises(RuntimeError, match="submit failed"):
+ await pool.submit_initial("req-0", req_state, SimpleNamespace(request_id="req-0", prompt_token_ids=[1, 2]))
+
+ assert output_processor.added_request_ids == ["req-0"]
+ assert output_processor.removed_request_ids == ["req-0"]
+ assert pool.get_bound_replica_id("req-0") is None
+
+
+@pytest.mark.asyncio
+async def test_stage_pool_abort_requests_logs_when_binding_is_missing(caplog) -> None:
+ stage0 = FakeStageClient(stage_type="llm", final_output=False)
+ pool = StagePool(
+ 0,
+ [stage0],
+ output_processor=FakeOutputProcessor(),
+ stage_vllm_config=SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)),
+ )
+
+ target_logger = logging.getLogger("vllm_omni.engine.stage_pool")
+ target_logger.addHandler(caplog.handler)
+ prev_level = target_logger.level
+ target_logger.setLevel(logging.DEBUG)
+ try:
+ await pool.abort_requests(["missing-req"])
+ finally:
+ target_logger.removeHandler(caplog.handler)
+ target_logger.setLevel(prev_level)
+
+ assert not stage0.abort_calls
+ assert "abort: no binding for req=missing-req in stage-0" in caplog.text
+
+
+@pytest.mark.asyncio
+async def test_collective_rpc_ignores_invalid_stage_ids(orchestrator_factory, caplog) -> None:
+ stage0 = FakeCollectiveRpcStageClient(stage_type="llm", final_output=True, rpc_result={"stage": 0})
+ stage1 = FakeCollectiveRpcStageClient(stage_type="llm", final_output=True, rpc_result={"stage": 1})
+ stage_pools = _build_stage_pools(
+ [[stage0], [stage1]],
+ output_processors=[FakeOutputProcessor(), FakeOutputProcessor()],
+ stage_vllm_configs=[
+ SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)),
+ SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)),
+ ],
+ )
+ orchestrator_fixture = orchestrator_factory([], stage_pools=stage_pools)
+
+ try:
+ target_logger = logging.getLogger("vllm_omni.engine.orchestrator")
+ target_logger.addHandler(caplog.handler)
+ prev_level = target_logger.level
+ target_logger.setLevel(logging.WARNING)
+ try:
+ orchestrator_fixture.request_sync_q.put_nowait(
+ {
+ "type": "collective_rpc",
+ "rpc_id": "rpc-1",
+ "method": "list_loras",
+ "stage_ids": [99, 1],
+ }
+ )
+
+ msg = await _get_rpc_message(orchestrator_fixture)
+ finally:
+ target_logger.removeHandler(caplog.handler)
+ target_logger.setLevel(prev_level)
+
+ assert msg["type"] == "collective_rpc_result"
+ assert msg["rpc_id"] == "rpc-1"
+ assert msg["stage_ids"] == [1]
+ assert msg["results"] == [{"stage": 1}]
+ assert not stage0.collective_rpc_calls
+ assert len(stage1.collective_rpc_calls) == 1
+ assert "collective_rpc: ignoring invalid stage_id 99" in caplog.text
+ finally:
+ await _shutdown_orchestrator(orchestrator_fixture)
+
+
+@pytest.mark.asyncio
+async def test_multi_replica_cfg_companion_inherits_parent_affinity(orchestrator_factory) -> None:
+ """CFG companions should be routed to the same stage-0 replica as their parent."""
+ stage0_r0 = FakeStageClient(stage_type="llm", final_output=False)
+ stage0_r1 = FakeStageClient(stage_type="llm", final_output=False)
+ default_vllm_cfg = SimpleNamespace(model_config=SimpleNamespace(max_model_len=64))
+ stage_pools = _build_stage_pools(
+ [[stage0_r0, stage0_r1]],
+ output_processors=[FakeOutputProcessor()],
+ stage_vllm_configs=[default_vllm_cfg],
+ )
+ orchestrator_fixture = orchestrator_factory([], stage_pools=stage_pools)
+
+ try:
+ # Consume replica-0 first so the parent request binds to replica-1.
+ await _enqueue_add_request(
+ orchestrator_fixture,
+ request_id="warmup",
+ prompt=SimpleNamespace(request_id="warmup", prompt_token_ids=[0]),
+ original_prompt={"prompt": "warmup"},
+ sampling_params_list=[_sampling_params()],
+ final_stage_id=0,
+ )
+ await _wait_for(lambda: len(stage0_r0.add_request_calls) == 1)
+
+ await _enqueue_add_request(
+ orchestrator_fixture,
+ request_id="parent",
+ prompt=SimpleNamespace(request_id="parent", prompt_token_ids=[1, 2]),
+ original_prompt={"prompt": "parent"},
+ sampling_params_list=[_sampling_params()],
+ final_stage_id=0,
+ )
+ await _wait_for(lambda: len(stage0_r1.add_request_calls) == 1)
+
+ orchestrator_fixture.request_sync_q.put_nowait(
+ {
+ "type": "add_companion_request",
+ "companion_id": "parent-neg",
+ "parent_id": "parent",
+ "role": "negative",
+ "prompt": SimpleNamespace(request_id="parent-neg", prompt_token_ids=[9]),
+ "companion_prompt_text": {"prompt": "negative"},
+ "sampling_params_list": [_sampling_params()],
+ }
+ )
+ await _wait_for(lambda: len(stage0_r1.add_request_calls) == 2)
+
+ assert stage_pools[0].get_bound_replica_id("parent") == 1
+ assert stage_pools[0].get_bound_replica_id("parent-neg") == 1
+ assert len(stage0_r0.add_request_calls) == 1
+ assert stage0_r1.add_request_calls[0][0].request_id == "parent"
+ assert stage0_r1.add_request_calls[1][0].request_id == "parent-neg"
+ finally:
+ await _shutdown_orchestrator(orchestrator_fixture)
diff --git a/tests/engine/test_orchestrator_error_handling.py b/tests/engine/test_orchestrator_error_handling.py
new file mode 100644
index 00000000000..18099c01640
--- /dev/null
+++ b/tests/engine/test_orchestrator_error_handling.py
@@ -0,0 +1,160 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for error propagation paths within the Orchestrator.
+
+Covers:
+- EngineDeadError from an LLM stage poll → fatal error broadcast + shutdown
+- Diffusion stage error output (OmniRequestOutput.from_error) → routed correctly
+"""
+
+from __future__ import annotations
+
+import asyncio
+import queue
+import time
+from types import SimpleNamespace
+
+import pytest
+from vllm.v1.engine.exceptions import EngineDeadError
+
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.outputs import OmniRequestOutput
+
+from .test_orchestrator import (
+ FakeStageClient,
+ OrchestratorFixture,
+ _build_harness,
+ _enqueue_add_request,
+ _wait_for,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+def _sampling_params(max_tokens: int = 4):
+ from vllm.sampling_params import SamplingParams
+
+ return SamplingParams(max_tokens=max_tokens)
+
+
+async def _get_any_output_message(fixture: OrchestratorFixture, *, timeout: float = 2.0) -> dict:
+ """Like _get_output_message but returns any message type (including errors)."""
+ deadline = time.monotonic() + timeout
+ while True:
+ if time.monotonic() >= deadline:
+ raise AssertionError("Timed out waiting for orchestrator output")
+ try:
+ return fixture.output_sync_q.get_nowait()
+ except queue.Empty:
+ await asyncio.sleep(0.01)
+
+
+@pytest.fixture
+def orchestrator_factory():
+ fixtures: list[OrchestratorFixture] = []
+
+ def _factory(*args, **kwargs) -> OrchestratorFixture:
+ fixture = _build_harness(*args, **kwargs)
+ fixtures.append(fixture)
+ return fixture
+
+ yield _factory
+
+ for fixture in fixtures:
+ if fixture.thread.is_alive():
+ fixture.request_sync_q.put_nowait({"type": "shutdown"})
+ fixture.thread.join(timeout=5)
+ for q in fixture.queues:
+ q.close()
+
+
+# ───────── EngineDeadError from LLM stage poll ─────────
+
+
+class FakeDeadLLMStageClient(FakeStageClient):
+ """LLM stage client that raises EngineDeadError on get_output_async."""
+
+ async def get_output_async(self):
+ raise EngineDeadError("Stage-0 engine core is dead")
+
+
+@pytest.mark.asyncio
+async def test_engine_dead_error_broadcasts_fatal_and_shuts_down(orchestrator_factory) -> None:
+ """When a stage raises EngineDeadError during poll, the orchestrator must:
+ 1. Enqueue a fatal error message for each affected request
+ 2. Shut itself down (thread exits)
+ """
+ stage0 = FakeDeadLLMStageClient(stage_type="llm", final_output=True)
+ orchestrator_fixture = orchestrator_factory([stage0])
+ request = SimpleNamespace(request_id="req-dead", prompt_token_ids=[1, 2])
+
+ try:
+ await _enqueue_add_request(
+ orchestrator_fixture,
+ request_id="req-dead",
+ prompt=request,
+ original_prompt={"prompt": "hello"},
+ sampling_params_list=[_sampling_params()],
+ final_stage_id=0,
+ )
+
+ # Collect the fatal error message.
+ msg = await _get_any_output_message(orchestrator_fixture)
+
+ assert msg["type"] == "error"
+ assert msg["fatal"] is True
+ assert msg["request_id"] == "req-dead"
+ assert "Stage-0 engine core is dead" in msg["error"]
+
+ # The orchestrator thread should exit after the fatal error.
+ orchestrator_fixture.thread.join(timeout=5)
+ assert not orchestrator_fixture.thread.is_alive()
+
+ # Request state should be cleaned up.
+ assert "req-dead" not in orchestrator_fixture.orchestrator.request_states
+ finally:
+ if orchestrator_fixture.thread.is_alive():
+ orchestrator_fixture.request_sync_q.put_nowait({"type": "shutdown"})
+ orchestrator_fixture.thread.join(timeout=5)
+
+
+# ───────── Diffusion stage error output routing ─────────
+
+
+@pytest.mark.asyncio
+async def test_diffusion_error_output_routed_as_finished(orchestrator_factory) -> None:
+ """When a diffusion stage returns an OmniRequestOutput with a non-None
+ error, the orchestrator must route it as an error message and clean up
+ the request state.
+ """
+ stage0 = FakeStageClient(stage_type="diffusion", final_output=True, final_output_type="image")
+ orchestrator_fixture = orchestrator_factory([stage0])
+ params = OmniDiffusionSamplingParams()
+
+ try:
+ await _enqueue_add_request(
+ orchestrator_fixture,
+ request_id="req-err",
+ prompt={"prompt": "draw a cat"},
+ original_prompt={"prompt": "draw a cat"},
+ sampling_params_list=[params],
+ final_stage_id=0,
+ )
+
+ await _wait_for(lambda: len(stage0.add_request_calls) == 1)
+
+ # Push an error output from the diffusion stage.
+ stage0.push_diffusion_output(OmniRequestOutput.from_error("req-err", "gpu fault"))
+
+ msg = await _get_any_output_message(orchestrator_fixture)
+
+ assert msg["type"] == "error"
+ assert msg["request_id"] == "req-err"
+ assert msg["stage_id"] == 0
+ assert msg["error"] == "gpu fault"
+
+ # Request state should be cleaned up.
+ await _wait_for(lambda: "req-err" not in orchestrator_fixture.orchestrator.request_states)
+ finally:
+ orchestrator_fixture.request_sync_q.put_nowait({"type": "shutdown"})
+ orchestrator_fixture.thread.join(timeout=5)
diff --git a/tests/engine/test_orchestrator_kv_sender_info.py b/tests/engine/test_orchestrator_kv_sender_info.py
index a030ab3478b..2fa9878fb84 100644
--- a/tests/engine/test_orchestrator_kv_sender_info.py
+++ b/tests/engine/test_orchestrator_kv_sender_info.py
@@ -153,7 +153,7 @@ def test_forward_to_diffusion_attaches_kv_sender_info():
)
output = SimpleNamespace(request_id="req-1", finished=True)
- asyncio.run(Orchestrator._forward_to_next_stage(orchestrator, "req-1", sender_pool, output, req_state))
+ asyncio.run(Orchestrator._forward_to_next_stage(orchestrator, "req-1", sender_pool.stage_id, output, req_state))
assert diffusion_stage.calls[0]["request_id"] == "req-1"
assert diffusion_stage.calls[0]["kv_sender_info"] == {
@@ -182,7 +182,7 @@ def test_forward_to_diffusion_uses_engine_input_source_for_kv_sender_info():
)
output = SimpleNamespace(request_id="req-3", finished=True)
- asyncio.run(Orchestrator._forward_to_next_stage(orchestrator, "req-3", previous_pool, output, req_state))
+ asyncio.run(Orchestrator._forward_to_next_stage(orchestrator, "req-3", previous_pool.stage_id, output, req_state))
assert diffusion_stage.calls[0]["kv_sender_info"] == {
0: {"host": "10.0.0.2", "zmq_port": 50151},
diff --git a/tests/engine/test_single_stage_mode.py b/tests/engine/test_single_stage_mode.py
index 28ccccaa2b5..f183b769b43 100644
--- a/tests/engine/test_single_stage_mode.py
+++ b/tests/engine/test_single_stage_mode.py
@@ -1,20 +1,8 @@
-"""Unit tests for AsyncOmniEngine single-stage mode and OmniMasterServer.
-
-These tests cover:
-- OmniMasterServer address pre-allocation & ZMQ registration handshake
-- AsyncOmniEngine single_stage_mode detection / _single_stage_id_filter setup
-- _initialize_stages stage routing (local launch vs. remote-wait) in
- single_stage_mode
-- _create_remote_llm_stage delegation to connect_remote_engine_cores
-- _launch_llm_stage delegation to launch_omni_core_engines in
- single_stage_mode
-
-All tests run without real hardware by mocking ZMQ, vllm_config, and the
-heavy initialization helpers.
-"""
+"""Unit tests for AsyncOmniEngine single-stage mode and OmniMasterServer."""
from __future__ import annotations
+import os
import threading
from contextlib import contextmanager
from types import SimpleNamespace
@@ -25,6 +13,7 @@
from vllm.v1.engine.utils import EngineZmqAddresses
from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
+from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClientBase
from vllm_omni.engine.stage_engine_startup import (
OmniMasterServer,
StageAllocation,
@@ -32,21 +21,17 @@
connect_remote_engine_cores,
launch_omni_core_engines,
)
-from vllm_omni.engine.stage_init_utils import StartedLlmStage
+from vllm_omni.engine.stage_init_utils import LogicalStageInitPlan, ReplicaInitPlan
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
-# ---------------------------------------------------------------------------
-# Helpers
-# ---------------------------------------------------------------------------
-
-
def _make_stage_cfg(stage_id: int, stage_type: str = "llm"):
"""Return a lightweight stage config mock."""
return SimpleNamespace(
stage_id=stage_id,
stage_type=stage_type,
+ runtime=SimpleNamespace(devices="0"),
engine_args=SimpleNamespace(
async_chunk=False,
model_stage=None,
@@ -55,67 +40,108 @@ def _make_stage_cfg(stage_id: int, stage_type: str = "llm"):
)
-def _make_started_llm_stage(stage_id: int) -> StartedLlmStage:
- """Return a minimal StartedLlmStage for mocking."""
- addresses = SimpleNamespace(
- inputs=["tcp://127.0.0.1:5000"],
- outputs=["tcp://127.0.0.1:5001"],
- frontend_stats_publish_address=None,
+def _make_llm_plan(
+ stage_idx: int,
+ *,
+ configured_stage_id: int,
+ launch_mode: str,
+ vllm_config: Any | None = None,
+) -> LogicalStageInitPlan:
+ stage_cfg = _make_stage_cfg(configured_stage_id)
+ metadata = SimpleNamespace(
+ stage_id=configured_stage_id,
+ stage_type="llm",
+ runtime_cfg={"devices": "0"},
+ prompt_expand_func=None,
+ final_output=False,
+ final_output_type=None,
+ default_sampling_params=SimpleNamespace(),
+ custom_process_input_func=None,
+ engine_input_source=[],
+ engine_output_type="token_ids",
+ replica_id=0,
)
- return StartedLlmStage(
- stage_id=stage_id,
- metadata=SimpleNamespace(stage_id=stage_id),
- vllm_config=SimpleNamespace(),
- executor_class=SimpleNamespace(),
- engine_manager=SimpleNamespace(),
- coordinator=SimpleNamespace(),
- addresses=addresses,
+ return LogicalStageInitPlan(
+ stage_idx=stage_idx,
+ configured_stage_id=configured_stage_id,
+ replicas=[
+ ReplicaInitPlan(
+ replica_id=0,
+ num_replicas=1,
+ launch_mode=launch_mode,
+ stage_cfg=stage_cfg,
+ metadata=metadata,
+ stage_connector_spec={},
+ omni_kv_connector=(None, None, None),
+ stage_vllm_config=vllm_config
+ or SimpleNamespace(parallel_config=SimpleNamespace(data_parallel_size_local=1)),
+ executor_class=object,
+ )
+ ],
+ )
+
+
+def _make_diffusion_plan(
+ stage_idx: int,
+ *,
+ configured_stage_id: int,
+ launch_mode: str,
+) -> LogicalStageInitPlan:
+ stage_cfg = _make_stage_cfg(configured_stage_id, stage_type="diffusion")
+ metadata = SimpleNamespace(
+ stage_id=configured_stage_id,
+ stage_type="diffusion",
+ runtime_cfg={"devices": "0"},
+ prompt_expand_func=None,
+ final_output=True,
+ final_output_type="image",
+ default_sampling_params=SimpleNamespace(),
+ custom_process_input_func=None,
+ engine_input_source=[],
+ cfg_kv_collect_func=None,
+ replica_id=0,
+ )
+ return LogicalStageInitPlan(
+ stage_idx=stage_idx,
+ configured_stage_id=configured_stage_id,
+ replicas=[
+ ReplicaInitPlan(
+ replica_id=0,
+ num_replicas=1,
+ launch_mode=launch_mode,
+ stage_cfg=stage_cfg,
+ metadata=metadata,
+ stage_connector_spec={},
+ omni_kv_connector=(None, None, None),
+ )
+ ],
)
# ---------------------------------------------------------------------------
-# OmniMasterServer – address pre-allocation
+# OmniMasterServer address pre-allocation
# ---------------------------------------------------------------------------
class TestOmniMasterServerAllocation:
- """Test address pre-allocation in OmniMasterServer.__init__."""
-
def test_public_address_and_port_properties_expose_registration_endpoint(self):
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=15000,
- stage_ids=[0],
- )
+ server = OmniMasterServer(master_address="127.0.0.1", master_port=15000, stage_ids=[0])
assert server.address == "127.0.0.1"
assert server.port == 15000
def test_allocations_created_for_each_stage_id(self):
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=15000,
- stage_ids=[0, 1, 2],
- )
+ server = OmniMasterServer(master_address="127.0.0.1", master_port=15000, stage_ids=[0, 1, 2])
assert set(server._allocations.keys()) == {0, 1, 2}
def test_each_allocation_is_stage_allocation(self):
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=15000,
- stage_ids=[0, 1],
- )
+ server = OmniMasterServer(master_address="127.0.0.1", master_port=15000, stage_ids=[0, 1])
for sid in (0, 1):
- alloc = server._allocations[sid]
- assert isinstance(alloc, StageAllocation)
+ assert isinstance(server._allocations[sid], StageAllocation)
def test_allocation_addresses_reference_master_address(self):
- server = OmniMasterServer(
- master_address="192.168.1.10",
- master_port=20000,
- stage_ids=[0],
- )
+ server = OmniMasterServer(master_address="192.168.1.10", master_port=20000, stage_ids=[0])
alloc = server._allocations[0]
- for addr in (
+ for address in (
alloc.handshake_bind_address,
alloc.handshake_connect_address,
alloc.input_bind_address,
@@ -123,73 +149,48 @@ def test_allocation_addresses_reference_master_address(self):
alloc.output_bind_address,
alloc.output_connect_address,
):
- assert "192.168.1.10" in addr, f"Expected master address in {addr}"
+ assert "192.168.1.10" in address
def test_port_uniqueness_within_single_allocation(self):
- """Each allocation uses three distinct ports."""
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=15001,
- stage_ids=[0],
- )
+ server = OmniMasterServer(master_address="127.0.0.1", master_port=15001, stage_ids=[0])
alloc = server._allocations[0]
- hs_port = int(alloc.handshake_bind_address.split(":")[-1])
- inp_port = int(alloc.input_bind_address.split(":")[-1])
- out_port = int(alloc.output_bind_address.split(":")[-1])
- assert len({hs_port, inp_port, out_port}) == 3, "Expected three distinct ports per stage allocation"
+ handshake_port = int(alloc.handshake_bind_address.split(":")[-1])
+ input_port = int(alloc.input_bind_address.split(":")[-1])
+ output_port = int(alloc.output_bind_address.split(":")[-1])
+ assert len({handshake_port, input_port, output_port}) == 3
def test_get_zmq_addresses_returns_bind_addresses(self):
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=15002,
- stage_ids=[0],
- )
+ server = OmniMasterServer(master_address="127.0.0.1", master_port=15002, stage_ids=[0])
alloc = server._allocations[0]
zmq_addrs = server.get_zmq_addresses(0)
assert zmq_addrs.inputs == [alloc.input_bind_address]
assert zmq_addrs.outputs == [alloc.output_bind_address]
def test_get_engine_zmq_addresses_returns_connect_addresses(self):
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=15003,
- stage_ids=[0],
- )
+ server = OmniMasterServer(master_address="127.0.0.1", master_port=15003, stage_ids=[0])
alloc = server._allocations[0]
- engine_addrs = server.get_engine_zmq_addresses(0)
- assert engine_addrs.inputs == [alloc.input_connect_address]
- assert engine_addrs.outputs == [alloc.output_connect_address]
+ zmq_addrs = server.get_engine_zmq_addresses(0)
+ assert zmq_addrs.inputs == [alloc.input_connect_address]
+ assert zmq_addrs.outputs == [alloc.output_connect_address]
def test_get_allocation_returns_correct_object(self):
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=15004,
- stage_ids=[3],
- )
+ server = OmniMasterServer(master_address="127.0.0.1", master_port=15004, stage_ids=[3])
assert server.get_allocation(3) is server._allocations[3]
# ---------------------------------------------------------------------------
-# OmniMasterServer – ZMQ registration flow
+# OmniMasterServer registration flow
# ---------------------------------------------------------------------------
class TestOmniMasterServerRegistration:
- """Test that the server correctly handles a stage registration."""
-
def test_registration_reply_contains_handshake_address(self):
- """A DEALER client that sends a registration msg gets the handshake
- address back from the ROUTER registration socket."""
import msgspec
import zmq
from vllm.utils.network_utils import get_open_port
master_port = get_open_port()
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=master_port,
- stage_ids=[0],
- )
+ server = OmniMasterServer(master_address="127.0.0.1", master_port=master_port, stage_ids=[0])
server.start()
expected_hs = server._allocations[0].handshake_connect_address
@@ -198,8 +199,7 @@ def test_registration_reply_contains_handshake_address(self):
sock = ctx.socket(zmq.DEALER)
sock.connect(f"tcp://127.0.0.1:{master_port}")
sock.send(msgspec.msgpack.encode({"stage_id": 0}))
- if not sock.poll(timeout=5_000):
- pytest.fail("No reply received from OmniMasterServer within 5 s")
+ assert sock.poll(timeout=5_000)
reply = msgspec.msgpack.decode(sock.recv())
assert reply["handshake_address"] == expected_hs
finally:
@@ -208,92 +208,65 @@ def test_registration_reply_contains_handshake_address(self):
server.stop()
def test_server_handles_unknown_stage_id_gracefully(self):
- """A registration for an unrecognised stage_id must not crash the server."""
import msgspec
import zmq
from vllm.utils.network_utils import get_open_port
master_port = get_open_port()
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=master_port,
- stage_ids=[0],
- )
+ server = OmniMasterServer(master_address="127.0.0.1", master_port=master_port, stage_ids=[0])
server.start()
ctx = zmq.Context()
+ bad_sock = None
+ good_sock = None
try:
bad_sock = ctx.socket(zmq.DEALER)
bad_sock.connect(f"tcp://127.0.0.1:{master_port}")
- # Send unknown stage_id=99
bad_sock.send(msgspec.msgpack.encode({"stage_id": 99}))
- # Server should NOT reply for an unknown id; wait briefly
- has_reply = bad_sock.poll(timeout=500)
- assert not has_reply, "Server should not reply to unknown stage_id"
- # Then register the valid stage so the server thread can exit
+ assert not bad_sock.poll(timeout=500)
+
good_sock = ctx.socket(zmq.DEALER)
good_sock.connect(f"tcp://127.0.0.1:{master_port}")
good_sock.send(msgspec.msgpack.encode({"stage_id": 0}))
- good_sock.poll(timeout=2_000)
+ assert good_sock.poll(timeout=2_000)
+ good_sock.recv()
finally:
- for s in (bad_sock, good_sock):
- try:
- s.close(linger=0)
- except Exception:
- pass
+ for sock in (bad_sock, good_sock):
+ if sock is not None:
+ sock.close(linger=0)
ctx.term()
server.stop()
def test_registration_stores_stage_config(self):
- """Stage registration should persist the sender's stage config."""
import msgspec
import zmq
from vllm.utils.network_utils import get_open_port
master_port = get_open_port()
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=master_port,
- stage_ids=[0],
- )
+ server = OmniMasterServer(master_address="127.0.0.1", master_port=master_port, stage_ids=[0])
server.start()
- payload = {
- "stage_id": 0,
- "stage_config": {
- "stage_id": 0,
- "stage_type": "llm",
- "engine_args": {"model": "fake-model"},
- },
- }
-
+ stage_config = {"stage_id": 0, "stage_type": "llm"}
ctx = zmq.Context()
try:
sock = ctx.socket(zmq.DEALER)
sock.connect(f"tcp://127.0.0.1:{master_port}")
- sock.send(msgspec.msgpack.encode(payload))
+ sock.send(msgspec.msgpack.encode({"stage_id": 0, "stage_config": stage_config}))
assert sock.poll(timeout=5_000)
sock.recv()
-
- stored = server.get_stage_config(0, timeout_s=0.1)
- assert stored == payload["stage_config"]
+ assert server.get_stage_config(0) == stage_config
finally:
sock.close(linger=0)
ctx.term()
server.stop()
def test_registration_stores_coordinator_addresses(self):
- """Stage registration should persist optional coordinator addresses."""
import msgspec
import zmq
from vllm.utils.network_utils import get_open_port
master_port = get_open_port()
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=master_port,
- stage_ids=[0],
- )
+ server = OmniMasterServer(master_address="127.0.0.1", master_port=master_port, stage_ids=[0])
server.start()
payload = {
@@ -303,7 +276,6 @@ def test_registration_stores_coordinator_addresses(self):
"coordinator_output": "tcp://127.0.0.1:31002",
"frontend_stats_publish_address": "tcp://127.0.0.1:31003",
}
-
ctx = zmq.Context()
try:
sock = ctx.socket(zmq.DEALER)
@@ -311,9 +283,7 @@ def test_registration_stores_coordinator_addresses(self):
sock.send(msgspec.msgpack.encode(payload))
assert sock.poll(timeout=5_000)
sock.recv()
-
- stored = server.get_stage_coordinator_addresses(0, timeout_s=0.1)
- assert stored == StageCoordinatorAddresses(
+ assert server.get_stage_coordinator_addresses(0) == StageCoordinatorAddresses(
coordinator_input=payload["coordinator_input"],
coordinator_output=payload["coordinator_output"],
frontend_stats_publish_address=payload["frontend_stats_publish_address"],
@@ -327,57 +297,47 @@ def test_stop_joins_server_thread(self):
from vllm.utils.network_utils import get_open_port
master_port = get_open_port()
- server = OmniMasterServer(
- master_address="127.0.0.1",
- master_port=master_port,
- stage_ids=[], # no stages → thread exits immediately
- )
+ server = OmniMasterServer(master_address="127.0.0.1", master_port=master_port, stage_ids=[])
server.start()
+
assert server._thread is not None
server.stop()
- # Thread should have exited (joined with timeout=10 inside stop())
assert not server._thread.is_alive()
# ---------------------------------------------------------------------------
-# AsyncOmniEngine – single_stage_mode detection in __init__
+# AsyncOmniEngine single_stage_mode detection in __init__
# ---------------------------------------------------------------------------
class TestSingleStageModeDetection:
- """Test __init__ single_stage_mode / _single_stage_id_filter setup.
-
- We bypass the real __init__ by patching _resolve_stage_configs and
- the orchestrator thread, so no actual engines are started.
- """
-
- def _make_engine_no_thread(self, mocker: MockerFixture, **kwargs: Any) -> AsyncOmniEngine:
- """Create an AsyncOmniEngine without starting the orchestrator thread."""
- stage_cfg = _make_stage_cfg(0)
- mock_stage_configs = [stage_cfg]
+ def _make_engine_no_thread(
+ self,
+ mocker: MockerFixture,
+ *,
+ stage_cfgs: list[Any] | None = None,
+ **kwargs: Any,
+ ) -> AsyncOmniEngine:
+ mock_stage_configs = stage_cfgs or [_make_stage_cfg(0)]
mocker.patch.object(
AsyncOmniEngine,
"_resolve_stage_configs",
return_value=("/fake/path", mock_stage_configs),
)
- mocker.patch.object(
- AsyncOmniEngine,
- "_bootstrap_orchestrator",
- )
+ mocker.patch.object(AsyncOmniEngine, "_bootstrap_orchestrator")
mock_thread_cls = mocker.patch("threading.Thread")
mock_future_cls = mocker.patch("concurrent.futures.Future")
mock_future = mocker.Mock()
- mock_future.result.return_value = mocker.Mock() # simulates a loop
+ mock_future.result.return_value = mocker.Mock()
mock_future_cls.return_value = mock_future
mock_thread = mocker.Mock()
mock_thread.is_alive.return_value = False
mock_thread_cls.return_value = mock_thread
- engine = AsyncOmniEngine(model="fake-model", **kwargs)
- return engine
+ return AsyncOmniEngine(model="fake-model", **kwargs)
def test_explicit_single_stage_mode_true(self, mocker: MockerFixture):
engine = self._make_engine_no_thread(
@@ -407,9 +367,7 @@ def test_stage_id_kwarg_sets_filter(self, mocker: MockerFixture):
assert engine._single_stage_id_filter == 1
def test_no_stage_id_no_single_stage_mode(self, mocker: MockerFixture):
- engine = self._make_engine_no_thread(
- mocker,
- )
+ engine = self._make_engine_no_thread(mocker)
assert engine.single_stage_mode is False
assert engine._single_stage_id_filter is None
@@ -420,6 +378,7 @@ def test_single_stage_mode_without_stage_id_has_no_filter(self, mocker: MockerFi
omni_master_address="127.0.0.1",
omni_master_port=20003,
)
+ assert engine.single_stage_mode is True
assert engine._single_stage_id_filter is None
def test_master_address_and_port_stored(self, mocker: MockerFixture):
@@ -433,899 +392,489 @@ def test_master_address_and_port_stored(self, mocker: MockerFixture):
assert engine._omni_master_port == 12345
def test_omni_master_server_starts_as_none(self, mocker: MockerFixture):
- engine = self._make_engine_no_thread(
- mocker,
- )
+ engine = self._make_engine_no_thread(mocker)
assert engine._omni_master_server is None
# ---------------------------------------------------------------------------
-# AsyncOmniEngine – _initialize_stages stage routing
+# AsyncOmniEngine single-stage initialization paths
# ---------------------------------------------------------------------------
-class TestInitializeStagesRouting:
- """Verify that _initialize_stages routes each stage to the correct launch
- function depending on single_stage_mode and _single_stage_id_filter."""
-
- _COMMON_PATCHES = [
- "vllm_omni.engine.async_omni_engine.prepare_engine_environment",
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- ]
-
- def _build_engine_skeleton(
- self,
- stage_cfgs: list[Any],
- single_stage_mode: bool,
- stage_id_filter: int | None,
- omni_master_address: str = "127.0.0.1",
- omni_master_port: int = 25000,
+class TestSingleStageInitialization:
+ def _build_engine(
+ self, stage_cfgs: list[Any], *, single_stage_mode: bool, stage_id_filter: int | None
) -> AsyncOmniEngine:
- """Build a bare AsyncOmniEngine without launching any threads."""
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
- engine.config_path = "/fake"
- engine.stage_configs = stage_cfgs
+ engine.config_path = "/fake/stages.yaml"
engine.num_stages = len(stage_cfgs)
- engine.async_chunk = False
+ engine.stage_configs = stage_cfgs
engine.single_stage_mode = single_stage_mode
engine._single_stage_id_filter = stage_id_filter
- engine._omni_master_address = omni_master_address
- engine._omni_master_port = omni_master_port
+ engine._omni_master_address = "127.0.0.1"
+ engine._omni_master_port = 26000
engine._omni_master_server = None
- engine._llm_stage_launch_lock = __import__("threading").Lock()
- engine.diffusion_batch_size = 1
- engine.stage_clients = []
- engine.stage_vllm_configs = []
- engine.output_processors = []
- engine.input_processor = None
- engine.supported_tasks = ("generate",)
- engine.default_sampling_params_list = []
- engine.stage_metadata = []
- engine.prompt_expand_func = None
+ engine.async_chunk = False
+ engine.diffusion_batch_size = 2
return engine
- def _fake_metadata(self, mocker: MockerFixture, stage_id: int, stage_type: str = "llm") -> Any:
- meta = mocker.Mock()
- meta.stage_id = stage_id
- meta.stage_type = stage_type
- meta.runtime_cfg = {}
- meta.prompt_expand_func = None
- meta.engine_output_type = None
- meta.is_comprehension = False
- meta.final_output = True if stage_id == 0 else False
- meta.final_output_type = None
- return meta
-
- def _run_initialize_stages_mocked(
- self,
- mocker: MockerFixture,
- engine: AsyncOmniEngine,
- stage_cfgs: list[Any],
- *,
- launch_side_effect: Any = None,
- remote_side_effect: Any = None,
- attach_result: Any = None,
- ) -> tuple[Any, Any]:
- """Execute _initialize_stages with all heavy helpers mocked.
-
- Returns (mock_launch_llm_stage, mock_create_remote_llm_stage).
- """
- started_by_stage: dict[int, StartedLlmStage] = {
- cfg.stage_id: _make_started_llm_stage(cfg.stage_id)
- for cfg in stage_cfgs
- if getattr(cfg, "stage_type", "llm") != "diffusion"
- }
-
- default_attach = (mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock())
+ def test_build_logical_stage_init_plans_marks_non_matching_stage_remote(self, mocker: MockerFixture):
+ import vllm_omni.engine.async_omni_engine as engine_mod
- mock_launch = mocker.Mock(
- side_effect=launch_side_effect
- or (lambda cfg, meta, spec, timeout, llm_stage_launch_lock, kv: started_by_stage[meta.stage_id])
- )
- mock_remote = mocker.Mock(
- side_effect=remote_side_effect or (lambda cfg, meta, spec, timeout, srv: started_by_stage[meta.stage_id])
+ stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11)]
+ engine = self._build_engine(stage_cfgs, single_stage_mode=True, stage_id_filter=7)
+
+ monkeypatch = pytest.MonkeyPatch()
+ monkeypatch.setattr(
+ engine_mod,
+ "extract_stage_metadata",
+ lambda cfg: SimpleNamespace(
+ stage_id=cfg.stage_id,
+ stage_type=getattr(cfg, "stage_type", "llm"),
+ prompt_expand_func=None,
+ runtime_cfg={},
+ ),
)
- mock_attach = mocker.Mock(return_value=attach_result or default_attach)
+ monkeypatch.setattr(engine_mod, "get_stage_connector_spec", lambda **_: {})
+ monkeypatch.setattr(engine_mod, "resolve_omni_kv_config_for_stage", lambda *_: (None, None, None))
+ monkeypatch.setattr(engine_mod, "build_engine_args_dict", lambda *_, **__: {})
+ monkeypatch.setattr(engine_mod, "build_vllm_config", lambda *_, **__: (SimpleNamespace(), object))
+ try:
+ stage_plans, _ = engine._build_logical_stage_init_plans(None, [1, 1], {})
+ finally:
+ monkeypatch.undo()
+
+ assert [plan.replicas[0].launch_mode for plan in stage_plans] == ["local", "remote"]
+
+ def test_start_omni_master_server_uses_configured_stage_ids(self, mocker: MockerFixture):
+ import vllm_omni.engine.async_omni_engine as engine_mod
+ engine = self._build_engine([], single_stage_mode=True, stage_id_filter=7)
mock_oms = mocker.Mock(spec=OmniMasterServer)
- mock_oms.get_zmq_addresses.side_effect = lambda sid: mocker.Mock()
+ mocker.patch.object(engine_mod, "OmniMasterServer", return_value=mock_oms)
- finalized = (
- [mocker.Mock() for _ in stage_cfgs],
- [mocker.Mock() for _ in stage_cfgs],
- [{"final_output": True, "final_output_type": None, "stage_type": "llm"} for _ in stage_cfgs],
- )
+ stage_plans = [
+ _make_llm_plan(0, configured_stage_id=7, launch_mode="local"),
+ _make_diffusion_plan(1, configured_stage_id=11, launch_mode="remote"),
+ ]
- mocker.patch.object(engine, "_launch_llm_stage", mock_launch)
- mocker.patch.object(engine, "_create_remote_llm_stage", mock_remote)
- mocker.patch.object(engine, "_attach_llm_stage", mock_attach)
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.OmniMasterServer",
- return_value=mock_oms,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.prepare_engine_environment",
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(
- mocker,
- cfg.stage_id,
- getattr(cfg, "stage_type", "llm"),
- ),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
+ engine._start_omni_master_server(stage_plans)
+
+ engine_mod.OmniMasterServer.assert_called_once_with(
+ master_address="127.0.0.1",
+ master_port=26000,
+ stage_ids=[7, 11],
)
+ mock_oms.start.assert_called_once()
- engine._initialize_stages(stage_init_timeout=60)
+ def test_start_omni_master_server_duplicate_stage_ids_raise(self):
+ engine = self._build_engine([], single_stage_mode=True, stage_id_filter=7)
+ stage_plans = [
+ _make_llm_plan(0, configured_stage_id=7, launch_mode="local"),
+ _make_llm_plan(1, configured_stage_id=7, launch_mode="remote"),
+ ]
- return mock_launch, mock_remote
+ with pytest.raises(ValueError, match="Duplicate stage_id"):
+ engine._start_omni_master_server(stage_plans)
- # -- single-stage mode: stage matches filter → local launch ---------------
+ def test_start_omni_master_server_missing_address_raises(self):
+ engine = self._build_engine([], single_stage_mode=True, stage_id_filter=7)
+ engine._omni_master_address = None
- def test_matching_stage_uses_launch_llm_stage(self, mocker: MockerFixture):
- """stage_id == _single_stage_id_filter → _launch_llm_stage is called."""
- stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
+ with pytest.raises(ValueError, match="requires both"):
+ engine._start_omni_master_server([_make_llm_plan(0, configured_stage_id=7, launch_mode="local")])
- launched_ids = [c.args[1].stage_id for c in mock_launch.call_args_list]
- assert 0 in launched_ids, "_launch_llm_stage should be called for stage 0"
+ def test_build_logical_stage_init_plans_clears_runtime_cfg_in_single_stage_mode(self, mocker: MockerFixture):
+ import vllm_omni.engine.async_omni_engine as engine_mod
- def test_non_matching_stage_uses_create_remote_llm_stage(self, mocker: MockerFixture):
- """stage_id != _single_stage_id_filter → _create_remote_llm_stage is called."""
- stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
+ engine = self._build_engine([_make_stage_cfg(7)], single_stage_mode=True, stage_id_filter=7)
- remote_ids = [c.args[1].stage_id for c in mock_remote.call_args_list]
- assert 1 in remote_ids, "_create_remote_llm_stage should be called for stage 1"
+ monkeypatch = pytest.MonkeyPatch()
+ monkeypatch.setattr(
+ engine_mod,
+ "extract_stage_metadata",
+ lambda cfg: SimpleNamespace(
+ stage_id=cfg.stage_id,
+ stage_type="llm",
+ prompt_expand_func=None,
+ runtime_cfg={"devices": "0"},
+ ),
+ )
+ monkeypatch.setattr(engine_mod, "get_stage_connector_spec", lambda **_: {})
+ monkeypatch.setattr(engine_mod, "resolve_omni_kv_config_for_stage", lambda *_: (None, None, None))
+ monkeypatch.setattr(engine_mod, "build_engine_args_dict", lambda *_, **__: {})
+ monkeypatch.setattr(
+ engine_mod,
+ "build_vllm_config",
+ lambda *_, **__: (SimpleNamespace(parallel_config=SimpleNamespace(data_parallel_size_local=1)), object),
+ )
+ try:
+ stage_plans, _ = engine._build_logical_stage_init_plans(None, [1], {})
+ finally:
+ monkeypatch.undo()
- def test_filter_1_routes_correctly(self, mocker: MockerFixture):
- """With filter=1, stage 0 is remote and stage 1 is local."""
- stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=1)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
+ assert stage_plans[0].replicas[0].metadata.runtime_cfg is None
- launched_ids = [c.args[1].stage_id for c in mock_launch.call_args_list]
- remote_ids = [c.args[1].stage_id for c in mock_remote.call_args_list]
- assert 1 in launched_ids, "stage 1 should be launched locally with filter=1"
- assert 0 in remote_ids, "stage 0 should use remote path with filter=1"
+ def test_initialize_stages_calls_master_server_only_in_single_stage_mode(self, mocker: MockerFixture):
+ import vllm_omni.engine.async_omni_engine as engine_mod
- def test_no_filter_all_stages_use_launch_path(self, mocker: MockerFixture):
- """single_stage_mode=True but no filter → all stages use _launch_llm_stage."""
- stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=None)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
+ stage_cfgs = [_make_stage_cfg(0)]
+ engine = self._build_engine(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
+ stage_plan = _make_llm_plan(0, configured_stage_id=0, launch_mode="local")
+ client = SimpleNamespace(
+ stage_type="llm",
+ is_comprehension=False,
+ final_output=True,
+ final_output_type=None,
+ default_sampling_params=SimpleNamespace(),
+ )
+
+ mocker.patch.object(engine_mod, "prepare_engine_environment")
+ mocker.patch.object(engine_mod, "load_omni_transfer_config_for_model", return_value=None)
+ mocker.patch.object(engine_mod, "compute_replica_layout", return_value=([1], {}))
+ mocker.patch.object(engine, "_build_logical_stage_init_plans", return_value=([stage_plan], None))
+ mock_start = mocker.patch.object(engine, "_start_omni_master_server")
+ mocker.patch.object(engine, "_initialize_stage_replicas", return_value={0: [client]})
+ mocker.patch.object(engine_mod, "build_stage0_input_processor", return_value=object())
+ mocker.patch.object(engine_mod, "build_llm_stage_output_processor", return_value=object())
- assert mock_remote.call_count == 0, "No remote launches without a filter"
- launched_ids = [c.args[1].stage_id for c in mock_launch.call_args_list]
- assert set(launched_ids) == {0, 1}
+ engine._initialize_stages(stage_init_timeout=60)
+ mock_start.assert_called_once()
+
+ engine = self._build_engine(stage_cfgs, single_stage_mode=False, stage_id_filter=None)
+ mocker.patch.object(engine_mod, "prepare_engine_environment")
+ mocker.patch.object(engine_mod, "load_omni_transfer_config_for_model", return_value=None)
+ mocker.patch.object(engine_mod, "compute_replica_layout", return_value=([1], {}))
+ mocker.patch.object(engine, "_build_logical_stage_init_plans", return_value=([stage_plan], None))
+ mock_start = mocker.patch.object(engine, "_start_omni_master_server")
+ mocker.patch.object(engine, "_initialize_stage_replicas", return_value={0: [client]})
+ mocker.patch.object(engine_mod, "build_stage0_input_processor", return_value=object())
+ mocker.patch.object(engine_mod, "build_llm_stage_output_processor", return_value=object())
- def test_non_single_stage_mode_never_calls_create_remote(self, mocker: MockerFixture):
- """Outside single_stage_mode, _create_remote_llm_stage must not be called."""
- stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=False, stage_id_filter=None)
- mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs)
+ engine._initialize_stages(stage_init_timeout=60)
+ mock_start.assert_not_called()
- assert mock_remote.call_count == 0
+ def test_initialize_stages_stops_master_server_and_shuts_down_initialized_clients_on_failure(
+ self,
+ mocker: MockerFixture,
+ ):
+ import vllm_omni.engine.async_omni_engine as engine_mod
- def test_omni_master_server_started_in_single_stage_mode(self, mocker: MockerFixture):
- """OmniMasterServer.start() must be called when single_stage_mode=True."""
stage_cfgs = [_make_stage_cfg(0)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_oms = mocker.Mock(spec=OmniMasterServer)
- mock_oms.get_zmq_addresses.return_value = mocker.Mock()
- finalized = (
- [mocker.Mock()],
- [mocker.Mock()],
- [{"final_output": True, "final_output_type": None, "stage_type": "llm"}],
- )
+ engine = self._build_engine(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
+ stage_plan = _make_llm_plan(0, configured_stage_id=0, launch_mode="local")
+ initialized_client = mocker.Mock()
+ mock_master = mocker.Mock(spec=OmniMasterServer)
+
+ mocker.patch.object(engine_mod, "prepare_engine_environment")
+ mocker.patch.object(engine_mod, "load_omni_transfer_config_for_model", return_value=None)
+ mocker.patch.object(engine_mod, "compute_replica_layout", return_value=([1], {}))
+ mocker.patch.object(engine, "_build_logical_stage_init_plans", return_value=([stage_plan], None))
+
+ def _start_master(_plans):
+ engine._omni_master_server = mock_master
+
+ mocker.patch.object(engine, "_start_omni_master_server", side_effect=_start_master)
+ mocker.patch.object(engine, "_initialize_stage_replicas", return_value={0: [initialized_client]})
+ mocker.patch.object(engine_mod, "build_stage0_input_processor", return_value=object())
+ mocker.patch.object(engine, "_assemble_stage_pools", side_effect=RuntimeError("assemble failed"))
+ mock_shutdown = mocker.patch.object(engine, "_shutdown_initialized_clients")
+
+ with pytest.raises(RuntimeError, match="assemble failed"):
+ engine._initialize_stages(stage_init_timeout=60)
- mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0))
- mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(0))
- mocker.patch.object(
- engine,
- "_attach_llm_stage",
- return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.OmniMasterServer",
- return_value=mock_oms,
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
- )
+ mock_shutdown.assert_called_once_with([initialized_client])
+ mock_master.stop.assert_called_once()
- engine._initialize_stages(stage_init_timeout=60)
- mock_oms.start.assert_called_once()
+class TestSingleStageReplicaInitialization:
+ def test_initialize_llm_replica_remote_uses_connect_remote_engine_cores(self, mocker: MockerFixture):
+ import vllm_omni.engine.async_omni_engine as engine_mod
- def test_omni_master_server_uses_configured_stage_ids(self, mocker: MockerFixture):
- """Configured stage IDs, not list indexes, should drive pre-allocation."""
- stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=7)
- mock_oms = mocker.Mock(spec=OmniMasterServer)
- mock_oms.get_zmq_addresses.return_value = mocker.Mock()
- finalized = (
- [mocker.Mock(), mocker.Mock()],
- [mocker.Mock(), mocker.Mock()],
- [{"final_output": False, "final_output_type": None, "stage_type": "llm"} for _ in stage_cfgs],
- )
+ engine = object.__new__(AsyncOmniEngine)
+ engine.model = "fake-model"
+ engine.single_stage_mode = True
+ engine._omni_master_server = mocker.Mock(spec=OmniMasterServer)
+ engine._omni_master_server.get_stage_config.return_value = {"stage_id": 7, "stage_type": "llm"}
- mocker.patch.object(
- engine,
- "_launch_llm_stage",
- side_effect=[_make_started_llm_stage(7), _make_started_llm_stage(11)],
- )
- mocker.patch.object(
- engine,
- "_create_remote_llm_stage",
- return_value=_make_started_llm_stage(11),
- )
- mocker.patch.object(
- engine,
- "_attach_llm_stage",
- return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
- )
- mock_oms_cls = mocker.patch(
- "vllm_omni.engine.async_omni_engine.OmniMasterServer",
- return_value=mock_oms,
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
+ fake_vllm_config = SimpleNamespace(parallel_config=SimpleNamespace(data_parallel_size_local=1))
+ fake_addresses = SimpleNamespace(
+ inputs=["tcp://in"], outputs=["tcp://out"], frontend_stats_publish_address=None
)
+ fake_manager = mocker.Mock()
+ fake_coordinator = mocker.Mock()
- engine._initialize_stages(stage_init_timeout=60)
+ @contextmanager
+ def _fake_connect(**kwargs):
+ yield fake_manager, fake_coordinator, fake_addresses
- mock_oms_cls.assert_called_once_with(
- master_address=engine._omni_master_address,
- master_port=engine._omni_master_port,
- stage_ids=[7, 11],
- )
+ plan = _make_llm_plan(0, configured_stage_id=7, launch_mode="remote", vllm_config=fake_vllm_config).replicas[0]
+ sentinel_client = SimpleNamespace()
- def test_single_stage_filter_uses_configured_stage_ids(self, mocker: MockerFixture):
- """Local/remote dispatch should compare against configured stage IDs."""
- stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=7)
- mock_oms = mocker.Mock(spec=OmniMasterServer)
- finalized = (
- [mocker.Mock(), mocker.Mock()],
- [mocker.Mock(), mocker.Mock()],
- [{"final_output": False, "final_output_type": None, "stage_type": "llm"} for _ in stage_cfgs],
- )
-
- mock_launch = mocker.patch.object(
- engine,
- "_launch_llm_stage",
- side_effect=[_make_started_llm_stage(7)],
- )
- mock_remote = mocker.patch.object(
- engine,
- "_create_remote_llm_stage",
- return_value=_make_started_llm_stage(11),
- )
+ mock_connect = mocker.patch.object(engine_mod, "connect_remote_engine_cores", side_effect=_fake_connect)
mocker.patch.object(
- engine,
- "_attach_llm_stage",
- return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.OmniMasterServer",
- return_value=mock_oms,
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
+ StageEngineCoreClientBase,
+ "make_async_mp_client",
+ side_effect=lambda **_: sentinel_client,
)
- engine._initialize_stages(stage_init_timeout=60)
+ result = engine._initialize_llm_replica(plan, stage_init_timeout=60, llm_stage_launch_lock=threading.Lock())
- assert [call.args[1].stage_id for call in mock_launch.call_args_list] == [7]
- assert [call.args[1].stage_id for call in mock_remote.call_args_list] == [11]
+ assert result is sentinel_client
+ engine._omni_master_server.get_stage_config.assert_called_once_with(7, timeout_s=60)
+ assert fake_vllm_config.parallel_config.data_parallel_size_local == 0
+ assert mock_connect.call_args.kwargs["stage_id"] == 7
- def test_omni_master_server_preallocates_diffusion_stage_ids(self, mocker: MockerFixture):
- """Diffusion stages should also receive OmniMasterServer allocations."""
- stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11, stage_type="diffusion")]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=7)
- mock_oms = mocker.Mock(spec=OmniMasterServer)
- finalized = (
- [mocker.Mock(), mocker.Mock()],
- [mocker.Mock(), mocker.Mock()],
- [
- {"final_output": False, "final_output_type": None, "stage_type": "llm"},
- {"final_output": False, "final_output_type": None, "stage_type": "diffusion"},
- ],
- )
+ def test_initialize_llm_replica_remote_missing_registered_stage_config_raises(self, mocker: MockerFixture):
+ engine = object.__new__(AsyncOmniEngine)
+ engine.single_stage_mode = True
+ engine._omni_master_server = mocker.Mock(spec=OmniMasterServer)
+ engine._omni_master_server.get_stage_config.return_value = None
- mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(7))
- mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(7))
- mocker.patch.object(engine, "_launch_diffusion_stage", return_value=mocker.Mock())
- mocker.patch.object(
- engine,
- "_create_remote_diffusion_stage",
- return_value=mocker.Mock(),
- )
- mocker.patch.object(
- engine,
- "_attach_llm_stage",
- return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
- )
- mock_oms_cls = mocker.patch(
- "vllm_omni.engine.async_omni_engine.OmniMasterServer",
- return_value=mock_oms,
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(
- mocker,
- cfg.stage_id,
- getattr(cfg, "stage_type", "llm"),
- ),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
- )
+ plan = _make_llm_plan(0, configured_stage_id=7, launch_mode="remote").replicas[0]
- engine._initialize_stages(stage_init_timeout=60)
+ with pytest.raises(ValueError, match="registered without stage config"):
+ engine._initialize_llm_replica(plan, stage_init_timeout=60, llm_stage_launch_lock=threading.Lock())
- mock_oms_cls.assert_called_once_with(
- master_address=engine._omni_master_address,
- master_port=engine._omni_master_port,
- stage_ids=[7, 11],
- )
+ def test_initialize_llm_replica_remote_attach_failure_cleans_up_started_resources(self, mocker: MockerFixture):
+ import vllm_omni.engine.async_omni_engine as engine_mod
- def test_duplicate_llm_stage_ids_raise(self, mocker: MockerFixture):
- """Duplicate configured LLM stage IDs should fail fast."""
- stage_cfgs = [_make_stage_cfg(3), _make_stage_cfg(3)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=3)
+ engine = object.__new__(AsyncOmniEngine)
+ engine.single_stage_mode = True
+ engine._omni_master_server = mocker.Mock(spec=OmniMasterServer)
+ engine._omni_master_server.get_stage_config.return_value = {"stage_id": 7, "stage_type": "llm"}
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
+ fake_vllm_config = SimpleNamespace(parallel_config=SimpleNamespace(data_parallel_size_local=1))
+ fake_addresses = SimpleNamespace(
+ inputs=["tcp://in"], outputs=["tcp://out"], frontend_stats_publish_address=None
)
- with pytest.raises(ValueError, match="Duplicate stage_id"):
- engine._initialize_stages(stage_init_timeout=60)
+ fake_manager = mocker.Mock()
+ fake_coordinator = mocker.Mock()
- def test_omni_master_server_not_started_in_normal_mode(self, mocker: MockerFixture):
- """OmniMasterServer must NOT be instantiated outside single_stage_mode."""
- stage_cfgs = [_make_stage_cfg(0)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=False, stage_id_filter=None)
- finalized = (
- [mocker.Mock()],
- [mocker.Mock()],
- [{"final_output": True, "final_output_type": None, "stage_type": "llm"}],
- )
+ @contextmanager
+ def _fake_connect(**kwargs):
+ yield fake_manager, fake_coordinator, fake_addresses
- mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0))
+ plan = _make_llm_plan(0, configured_stage_id=7, launch_mode="remote", vllm_config=fake_vllm_config).replicas[0]
+ mocker.patch.object(engine_mod, "connect_remote_engine_cores", side_effect=_fake_connect)
mocker.patch.object(
- engine,
- "_attach_llm_stage",
- return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
- )
- mock_oms_cls = mocker.patch("vllm_omni.engine.async_omni_engine.OmniMasterServer")
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
+ StageEngineCoreClientBase,
+ "make_async_mp_client",
+ side_effect=RuntimeError("attach failed"),
)
- engine._initialize_stages(stage_init_timeout=60)
-
- mock_oms_cls.assert_not_called()
+ with pytest.raises(RuntimeError, match="attach failed"):
+ engine._initialize_llm_replica(plan, stage_init_timeout=60, llm_stage_launch_lock=threading.Lock())
- def test_single_stage_mode_missing_master_address_raises(self, mocker: MockerFixture):
- """single_stage_mode without master address/port raises ValueError."""
- stage_cfgs = [_make_stage_cfg(0)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- engine._omni_master_address = None # missing
- engine._omni_master_port = None
+ fake_manager.shutdown.assert_called_once()
+ fake_coordinator.shutdown.assert_called_once()
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- with pytest.raises(ValueError, match="omni_master_address"):
- engine._initialize_stages(stage_init_timeout=60)
+ def test_initialize_llm_replica_single_stage_local_uses_launch_omni_core_engines(self, mocker: MockerFixture):
+ import vllm_omni.engine.async_omni_engine as engine_mod
+ from vllm_omni.platforms import current_omni_platform
- def test_matching_diffusion_stage_uses_local_registered_launch(self, mocker: MockerFixture):
- """A local diffusion stage should use the registered single-stage launch path."""
- stage_cfgs = [_make_stage_cfg(0, stage_type="diffusion"), _make_stage_cfg(1)]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_oms = mocker.Mock(spec=OmniMasterServer)
- diffusion_client = mocker.Mock(stage_type="diffusion")
- finalized = (
- [diffusion_client, mocker.Mock()],
- [mocker.Mock(), mocker.Mock()],
- [
- {"final_output": False, "final_output_type": None, "stage_type": "diffusion"},
- {"final_output": False, "final_output_type": None, "stage_type": "llm"},
- ],
- )
+ engine = object.__new__(AsyncOmniEngine)
+ engine.model = "fake-model"
+ engine.single_stage_mode = True
+ engine._omni_master_server = mocker.Mock(spec=OmniMasterServer)
+ engine.stage_configs = []
- mock_local_diff = mocker.patch.object(
- engine,
- "_launch_diffusion_stage",
- return_value=diffusion_client,
- )
- mock_remote_diff = mocker.patch.object(engine, "_create_remote_diffusion_stage")
- mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(1))
- mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(1))
- mocker.patch.object(
- engine,
- "_attach_llm_stage",
- return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.OmniMasterServer",
- return_value=mock_oms,
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(
- mocker,
- cfg.stage_id,
- getattr(cfg, "stage_type", "llm"),
- ),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
+ fake_vllm_config = SimpleNamespace(parallel_config=SimpleNamespace())
+ fake_addresses = SimpleNamespace(
+ inputs=["tcp://in"], outputs=["tcp://out"], frontend_stats_publish_address=None
)
- engine._initialize_stages(stage_init_timeout=60)
+ @contextmanager
+ def _fake_launch(**kwargs):
+ yield mocker.Mock(), None, fake_addresses
- assert mock_local_diff.call_count == 1
- assert mock_local_diff.call_args.args[1].stage_id == 0
- mock_remote_diff.assert_not_called()
+ plan = _make_llm_plan(0, configured_stage_id=3, launch_mode="local", vllm_config=fake_vllm_config).replicas[0]
+ sentinel_client = SimpleNamespace()
- def test_non_matching_diffusion_stage_uses_remote_diffusion_client(self, mocker: MockerFixture):
- """A non-local diffusion stage should attach via the remote diffusion path."""
- stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1, stage_type="diffusion")]
- engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0)
- mock_oms = mocker.Mock(spec=OmniMasterServer)
- remote_diffusion_client = mocker.Mock(stage_type="diffusion")
- finalized = (
- [mocker.Mock(), remote_diffusion_client],
- [mocker.Mock(), mocker.Mock()],
- [
- {"final_output": False, "final_output_type": None, "stage_type": "llm"},
- {"final_output": False, "final_output_type": None, "stage_type": "diffusion"},
- ],
- )
+ device_env_var = current_omni_platform.device_control_env_var
+ prev_device_env = os.environ.get(device_env_var)
+ os.environ[device_env_var] = "0"
- mock_local_diff = mocker.patch.object(engine, "_launch_diffusion_stage")
- mock_remote_diff = mocker.patch.object(
- engine,
- "_create_remote_diffusion_stage",
- return_value=remote_diffusion_client,
- )
- mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0))
- mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(0))
+ mocker.patch.object(engine_mod, "setup_stage_devices")
+ mocker.patch.object(engine_mod, "build_engine_args_dict", return_value={})
+ mocker.patch.object(engine_mod, "acquire_device_locks", return_value=[])
+ mocker.patch.object(engine_mod, "release_device_locks")
+ mock_launch = mocker.patch.object(engine_mod, "launch_omni_core_engines", side_effect=_fake_launch)
mocker.patch.object(
- engine,
- "_attach_llm_stage",
- return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.OmniMasterServer",
- return_value=mock_oms,
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model",
- return_value=None,
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.get_stage_connector_spec",
- return_value={},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage",
- return_value=(None, None, None),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.extract_stage_metadata",
- side_effect=lambda cfg: self._fake_metadata(
- mocker,
- cfg.stage_id,
- getattr(cfg, "stage_type", "llm"),
- ),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.finalize_initialized_stages",
- return_value=finalized,
+ StageEngineCoreClientBase,
+ "make_async_mp_client",
+ side_effect=lambda **_: sentinel_client,
)
+ try:
+ result = engine._initialize_llm_replica(plan, stage_init_timeout=60, llm_stage_launch_lock=threading.Lock())
+ finally:
+ if prev_device_env is None:
+ os.environ.pop(device_env_var, None)
+ else:
+ os.environ[device_env_var] = prev_device_env
- engine._initialize_stages(stage_init_timeout=60)
+ assert result is sentinel_client
+ assert mock_launch.call_args.kwargs["stage_id"] == 3
+ assert mock_launch.call_args.kwargs["stage_config"] is plan.stage_cfg
+
+ def test_initialize_diffusion_replica_remote_uses_from_addresses(self, mocker: MockerFixture):
+ import vllm_omni.engine.async_omni_engine as engine_mod
+
+ engine = object.__new__(AsyncOmniEngine)
+ engine.single_stage_mode = True
+ engine.diffusion_batch_size = 4
+ engine._omni_master_server = mocker.Mock(spec=OmniMasterServer)
+ engine._omni_master_server.get_stage_config.return_value = {"stage_id": 11, "stage_type": "diffusion"}
+ engine._omni_master_server.get_zmq_addresses.return_value = SimpleNamespace(
+ inputs=["tcp://in"],
+ outputs=["tcp://out"],
+ )
- mock_local_diff.assert_not_called()
- assert mock_remote_diff.call_count == 1
- assert mock_remote_diff.call_args.args[0].stage_id == 1
+ remote_metadata = _make_diffusion_plan(1, configured_stage_id=11, launch_mode="remote").replicas[0].metadata
+ plan = _make_diffusion_plan(1, configured_stage_id=11, launch_mode="remote").replicas[0]
+ sentinel_client = SimpleNamespace()
+ mocker.patch.object(engine_mod, "extract_stage_metadata", return_value=remote_metadata)
+ mock_from_addresses = mocker.patch.object(
+ engine_mod.StageDiffusionClient, "from_addresses", return_value=sentinel_client
+ )
-# ---------------------------------------------------------------------------
-# AsyncOmniEngine – _launch_diffusion_stage
-# ---------------------------------------------------------------------------
+ result = engine._initialize_diffusion_replica(plan, stage_init_timeout=60, stage_launch_lock=threading.Lock())
+ assert result is sentinel_client
+ engine._omni_master_server.get_stage_config.assert_called_once_with(11, timeout_s=60)
+ mock_from_addresses.assert_called_once()
-class TestLaunchDiffusionStage:
- """Test local diffusion stage launch wiring."""
+ def test_initialize_diffusion_replica_single_stage_local_registers_with_master(self, mocker: MockerFixture):
+ import vllm_omni.engine.async_omni_engine as engine_mod
+ from vllm_omni.platforms import current_omni_platform
- def test_registers_stage_with_public_master_properties(self, mocker: MockerFixture):
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
+ engine.single_stage_mode = True
engine.diffusion_batch_size = 4
+ engine.stage_configs = []
+ engine._omni_master_server = mocker.Mock(spec=OmniMasterServer)
+ engine._omni_master_server.address = "127.0.0.1"
+ engine._omni_master_server.port = 25000
- stage_cfg = _make_stage_cfg(5, stage_type="diffusion")
- metadata = mocker.Mock(stage_id=5)
- omni_master_server = mocker.Mock(spec=OmniMasterServer)
- omni_master_server.address = "127.0.0.1"
- omni_master_server.port = 25000
-
+ plan = _make_diffusion_plan(0, configured_stage_id=5, launch_mode="local").replicas[0]
+ sentinel_client = SimpleNamespace()
proc = mocker.Mock()
- diffusion_client = mocker.Mock()
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_diffusion_config",
- return_value="diffusion-config",
- )
- mock_register = mocker.patch(
- "vllm_omni.engine.async_omni_engine.register_stage_with_omni_master",
- return_value=(
- "tcp://127.0.0.1:25001",
- "tcp://127.0.0.1:25002",
- "tcp://127.0.0.1:25003",
- ),
- )
- mock_spawn = mocker.patch(
- "vllm_omni.engine.async_omni_engine.spawn_diffusion_proc",
+ device_env_var = current_omni_platform.device_control_env_var
+ prev_device_env = os.environ.get(device_env_var)
+ os.environ[device_env_var] = "0"
+
+ mocker.patch.object(engine_mod, "setup_stage_devices")
+ mocker.patch.object(engine_mod, "inject_kv_stage_info")
+ mocker.patch.object(engine_mod, "build_diffusion_config", return_value="diffusion-config")
+ mock_register = mocker.patch.object(
+ engine_mod,
+ "register_stage_with_omni_master",
+ return_value=("tcp://hs", "tcp://req", "tcp://resp"),
+ )
+ mock_spawn = mocker.patch.object(
+ engine_mod,
+ "spawn_diffusion_proc",
return_value=(proc, None, None, None),
)
- mock_handshake = mocker.patch("vllm_omni.engine.async_omni_engine.complete_diffusion_handshake")
- mock_from_addresses = mocker.patch(
- "vllm_omni.engine.async_omni_engine.StageDiffusionClient.from_addresses",
- return_value=diffusion_client,
+ mock_handshake = mocker.patch.object(engine_mod, "complete_diffusion_handshake")
+ mock_from_addresses = mocker.patch.object(
+ engine_mod.StageDiffusionClient,
+ "from_addresses",
+ return_value=sentinel_client,
)
- result = engine._launch_diffusion_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- omni_master_server=omni_master_server,
- )
+ try:
+ result = engine._initialize_diffusion_replica(
+ plan, stage_init_timeout=60, stage_launch_lock=threading.Lock()
+ )
+ finally:
+ if prev_device_env is None:
+ os.environ.pop(device_env_var, None)
+ else:
+ os.environ[device_env_var] = prev_device_env
+ assert result is sentinel_client
mock_register.assert_called_once_with(
omni_master_address="127.0.0.1",
omni_master_port=25000,
omni_stage_id=5,
- omni_stage_config=stage_cfg,
+ omni_stage_config=plan.stage_cfg,
return_addresses=True,
)
mock_spawn.assert_called_once_with(
"fake-model",
"diffusion-config",
- handshake_address="tcp://127.0.0.1:25001",
- request_address="tcp://127.0.0.1:25002",
- response_address="tcp://127.0.0.1:25003",
+ handshake_address="tcp://hs",
+ request_address="tcp://req",
+ response_address="tcp://resp",
)
- mock_handshake.assert_called_once_with(proc, "tcp://127.0.0.1:25001")
+ mock_handshake.assert_called_once_with(proc, "tcp://hs")
mock_from_addresses.assert_called_once_with(
- metadata,
- request_address="tcp://127.0.0.1:25002",
- response_address="tcp://127.0.0.1:25003",
+ plan.metadata,
+ request_address="tcp://req",
+ response_address="tcp://resp",
proc=proc,
batch_size=4,
)
- assert result is diffusion_client
-
-# ---------------------------------------------------------------------------
-# AsyncOmniEngine – _create_remote_llm_stage
-# ---------------------------------------------------------------------------
-
-
-class TestCreateRemoteLlmStage:
- """Test _create_remote_llm_stage delegates correctly."""
+ def test_initialize_diffusion_replica_local_failure_terminates_proc(self, mocker: MockerFixture):
+ import vllm_omni.engine.async_omni_engine as engine_mod
+ from vllm_omni.platforms import current_omni_platform
- def _engine(self, mocker: MockerFixture) -> AsyncOmniEngine:
engine = object.__new__(AsyncOmniEngine)
engine.model = "fake-model"
engine.single_stage_mode = True
- engine._single_stage_id_filter = 0
+ engine.diffusion_batch_size = 4
+ engine.stage_configs = []
engine._omni_master_server = mocker.Mock(spec=OmniMasterServer)
- engine._omni_master_server.get_zmq_addresses.return_value = mocker.Mock()
- engine._omni_master_server.get_allocation.return_value = mocker.Mock()
- engine._omni_master_server.get_stage_config.return_value = {
- "stage_id": 0,
- "stage_type": "llm",
- "engine_args": {},
- }
- return engine
-
- def _mock_build_and_connect(self, mocker: MockerFixture, stage_id: int):
- fake_vllm_config = mocker.Mock()
- fake_executor_cls = mocker.Mock()
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
-
- eng_mgr = mocker.Mock()
- coordinator = mocker.Mock()
+ engine._omni_master_server.address = "127.0.0.1"
+ engine._omni_master_server.port = 25000
- @contextmanager
- def fake_connect_cm(*args, **kwargs):
- yield eng_mgr, coordinator, fake_addresses
-
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": stage_id},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- )
- mock_connect = mocker.patch(
- "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores",
- return_value=fake_connect_cm(),
- )
-
- return mock_connect, fake_vllm_config, fake_executor_cls, fake_addresses
-
- def test_returns_started_llm_stage_with_correct_stage_id(self, mocker: MockerFixture):
- engine = self._engine(mocker)
- stage_cfg = _make_stage_cfg(1)
- metadata = mocker.Mock(stage_id=1)
- omni_ms = engine._omni_master_server
- omni_ms.get_stage_config.return_value = {
- "stage_id": 1,
- "stage_type": "llm",
- "engine_args": {},
- }
-
- self._mock_build_and_connect(mocker, 1)
- result = engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
- )
- assert isinstance(result, StartedLlmStage)
- assert result.stage_id == 1
-
- def test_connect_remote_engine_cores_called_with_stage_id(self, mocker: MockerFixture):
- engine = self._engine(mocker)
- stage_cfg = _make_stage_cfg(2)
- metadata = mocker.Mock(stage_id=2)
- omni_ms = engine._omni_master_server
- omni_ms.get_zmq_addresses.return_value = mocker.Mock(inputs=["x"], outputs=["y"])
- omni_ms.get_stage_config.return_value = {
- "stage_id": 2,
- "stage_type": "llm",
- "engine_args": {},
- }
-
- fake_vllm_config = mocker.Mock()
- fake_executor_cls = mocker.Mock()
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
+ plan = _make_diffusion_plan(0, configured_stage_id=5, launch_mode="local").replicas[0]
+ proc = mocker.Mock()
- @contextmanager
- def fake_connect_cm(*args, **kwargs):
- yield mocker.Mock(), mocker.Mock(), fake_addresses
+ device_env_var = current_omni_platform.device_control_env_var
+ prev_device_env = os.environ.get(device_env_var)
+ os.environ[device_env_var] = "0"
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 2},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- )
- mock_connect = mocker.patch(
- "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores",
- return_value=fake_connect_cm(),
+ mocker.patch.object(engine_mod, "setup_stage_devices")
+ mocker.patch.object(engine_mod, "inject_kv_stage_info")
+ mocker.patch.object(engine_mod, "build_diffusion_config", return_value="diffusion-config")
+ mocker.patch.object(
+ engine_mod,
+ "register_stage_with_omni_master",
+ return_value=("tcp://hs", "tcp://req", "tcp://resp"),
)
-
- engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
+ mocker.patch.object(
+ engine_mod,
+ "spawn_diffusion_proc",
+ return_value=(proc, None, None, None),
)
+ mocker.patch.object(engine_mod, "complete_diffusion_handshake", side_effect=RuntimeError("handshake failed"))
+ mock_terminate = mocker.patch.object(engine_mod, "terminate_alive_proc")
- mock_connect.assert_called_once()
- _, kwargs = mock_connect.call_args
- assert kwargs.get("stage_id") == 2 or mock_connect.call_args.args[-1] == 2
- omni_ms.get_stage_config.assert_called_once_with(2, timeout_s=60)
-
- def test_missing_registered_stage_config_raises_value_error(self, mocker: MockerFixture):
- engine = self._engine(mocker)
- stage_cfg = _make_stage_cfg(3)
- metadata = mocker.Mock(stage_id=3)
- omni_ms = engine._omni_master_server
- omni_ms.get_stage_config.return_value = None
-
- mock_build_args = mocker.patch("vllm_omni.engine.async_omni_engine.build_engine_args_dict")
- with pytest.raises(
- ValueError,
- match="Remote stage 3 registered without stage config",
- ):
- engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
- )
+ try:
+ with pytest.raises(RuntimeError, match="handshake failed"):
+ engine._initialize_diffusion_replica(plan, stage_init_timeout=60, stage_launch_lock=threading.Lock())
+ finally:
+ if prev_device_env is None:
+ os.environ.pop(device_env_var, None)
+ else:
+ os.environ[device_env_var] = prev_device_env
- mock_build_args.assert_not_called()
-
- def test_exception_during_connect_closes_started_stage(self, mocker: MockerFixture):
- """If an error occurs after StartedLlmStage creation, close_started_llm_stage is called."""
- engine = self._engine(mocker)
- stage_cfg = _make_stage_cfg(1)
- metadata = mocker.Mock(stage_id=1)
- omni_ms = engine._omni_master_server
- omni_ms.get_stage_config.return_value = {
- "stage_id": 1,
- "stage_type": "llm",
- "engine_args": {},
- }
+ mock_terminate.assert_called_once_with(proc)
- @contextmanager
- def boom(*args, **kwargs):
- yield mocker.Mock(), mocker.Mock(), mocker.Mock()
- raise RuntimeError("handshake failed")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 1},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores",
- return_value=boom(),
- )
- mock_close = mocker.patch("vllm_omni.engine.async_omni_engine.close_started_llm_stage")
- with pytest.raises(RuntimeError, match="handshake failed"):
- engine._create_remote_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- omni_master_server=omni_ms,
- )
- mock_close.assert_called_once()
+# ---------------------------------------------------------------------------
+# Stage engine startup helpers
+# ---------------------------------------------------------------------------
class TestConnectRemoteEngineCoresCoordinator:
- """Test coordinator launch parity with launch_core_engines."""
-
@staticmethod
def _build_vllm_config(
mocker: MockerFixture, *, dp_rank: int = 0, offline_mode: bool = False, needs_dp_coordinator: bool = True
@@ -1347,7 +896,8 @@ def test_uses_registered_coordinator_addresses(self, mocker: MockerFixture):
omni_master_server = mocker.Mock(spec=OmniMasterServer)
omni_master_server.get_zmq_addresses.return_value = EngineZmqAddresses(
- inputs=["tcp://client-in"], outputs=["tcp://client-out"]
+ inputs=["tcp://client-in"],
+ outputs=["tcp://client-out"],
)
omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001")
omni_master_server.get_stage_coordinator_addresses.return_value = StageCoordinatorAddresses(
@@ -1360,10 +910,7 @@ def test_uses_registered_coordinator_addresses(self, mocker: MockerFixture):
def fake_socket_ctx(*args, **kwargs):
yield mocker.Mock()
- mocker.patch(
- "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
- return_value=fake_socket_ctx(),
- )
+ mocker.patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx())
mock_wait = mocker.patch("vllm_omni.engine.stage_engine_startup._wait_for_omni_engine_startup")
with connect_remote_engine_cores(
vllm_config=vllm_config,
@@ -1379,16 +926,12 @@ def fake_socket_ctx(*args, **kwargs):
mock_wait.assert_called_once()
def test_defaults_to_no_coordinator_addresses_when_none_registered(self, mocker: MockerFixture):
- vllm_config = self._build_vllm_config(
- mocker,
- dp_rank=0,
- offline_mode=False,
- needs_dp_coordinator=True,
- )
+ vllm_config = self._build_vllm_config(mocker, dp_rank=0, offline_mode=False, needs_dp_coordinator=True)
omni_master_server = mocker.Mock(spec=OmniMasterServer)
omni_master_server.get_zmq_addresses.return_value = EngineZmqAddresses(
- inputs=["tcp://client-in"], outputs=["tcp://client-out"]
+ inputs=["tcp://client-in"],
+ outputs=["tcp://client-out"],
)
omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001")
omni_master_server.get_stage_coordinator_addresses.return_value = StageCoordinatorAddresses()
@@ -1397,10 +940,7 @@ def test_defaults_to_no_coordinator_addresses_when_none_registered(self, mocker:
def fake_socket_ctx(*args, **kwargs):
yield mocker.Mock()
- mocker.patch(
- "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
- return_value=fake_socket_ctx(),
- )
+ mocker.patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx())
mocker.patch("vllm_omni.engine.stage_engine_startup._wait_for_omni_engine_startup")
with connect_remote_engine_cores(
vllm_config=vllm_config,
@@ -1414,8 +954,6 @@ def fake_socket_ctx(*args, **kwargs):
class TestLaunchOmniCoreEngines:
- """Tests for local omni engine launch wiring."""
-
def test_registers_stage_once_and_reuses_handshake_for_all_local_engines(self, mocker: MockerFixture):
parallel_config = mocker.Mock(
data_parallel_size_local=2,
@@ -1440,10 +978,7 @@ def fake_socket_ctx(*args, **kwargs):
"vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
return_value="tcp://127.0.0.1:26001",
)
- mocker.patch(
- "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
- return_value=fake_socket_ctx(),
- )
+ mocker.patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx())
mock_manager_cls = mocker.patch(
"vllm_omni.engine.stage_engine_startup.CoreEngineProcManager",
return_value=local_engine_manager,
@@ -1459,6 +994,7 @@ def fake_socket_ctx(*args, **kwargs):
) as (yielded_manager, yielded_coordinator, yielded_addresses):
assert yielded_manager is local_engine_manager
assert yielded_coordinator is None
+ assert yielded_addresses is not None
mock_register.assert_called_once_with(
omni_master_address="127.0.0.1",
@@ -1467,15 +1003,11 @@ def fake_socket_ctx(*args, **kwargs):
omni_stage_config=stage_config,
coordinator=None,
)
- mock_manager_cls.assert_called_once()
manager_kwargs = mock_manager_cls.call_args.kwargs
assert manager_kwargs["local_engine_count"] == 2
assert manager_kwargs["start_index"] == 3
assert manager_kwargs["local_start_index"] == 0
- assert manager_kwargs["vllm_config"] is vllm_config
- assert manager_kwargs["local_client"] is True
assert manager_kwargs["handshake_address"] == "tcp://127.0.0.1:26001"
- assert manager_kwargs["executor_class"] is not None
def test_registers_stage_with_coordinator_when_started(self, mocker: MockerFixture):
parallel_config = mocker.Mock(
@@ -1483,15 +1015,19 @@ def test_registers_stage_with_coordinator_when_started(self, mocker: MockerFixtu
data_parallel_size=2,
data_parallel_rank=0,
)
- vllm_config = mocker.Mock(parallel_config=parallel_config)
- vllm_config.needs_dp_coordinator = True
- vllm_config.model_config = mocker.Mock(is_moe=False)
+ vllm_config = mocker.Mock(
+ parallel_config=parallel_config,
+ needs_dp_coordinator=True,
+ model_config=mocker.Mock(is_moe=False),
+ cache_config=mocker.Mock(),
+ )
omni_master_server = mocker.Mock(spec=OmniMasterServer)
omni_master_server.address = "127.0.0.1"
omni_master_server.port = 26000
omni_master_server.get_zmq_addresses.return_value = EngineZmqAddresses(
- inputs=["tcp://client-in"], outputs=["tcp://client-out"]
+ inputs=["tcp://client-in"],
+ outputs=["tcp://client-out"],
)
omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001")
@@ -1509,11 +1045,8 @@ def fake_socket_ctx(*args, **kwargs):
"vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master",
return_value="tcp://127.0.0.1:26001",
)
+ mocker.patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx())
mocker.patch(
- "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx",
- return_value=fake_socket_ctx(),
- )
- mock_manager_cls = mocker.patch(
"vllm_omni.engine.stage_engine_startup.CoreEngineProcManager",
return_value=mocker.Mock(),
)
@@ -1525,8 +1058,11 @@ def fake_socket_ctx(*args, **kwargs):
omni_master_server=omni_master_server,
stage_id=7,
stage_config={"stage_id": 7},
- ):
- pass
+ ) as (_, yielded_coordinator, yielded_addresses):
+ assert yielded_coordinator is coordinator
+ assert yielded_addresses.coordinator_input == "tcp://coord-in"
+ assert yielded_addresses.coordinator_output == "tcp://coord-out"
+ assert yielded_addresses.frontend_stats_publish_address == "tcp://stats"
mock_register.assert_called_once_with(
omni_master_address="127.0.0.1",
@@ -1535,312 +1071,4 @@ def fake_socket_ctx(*args, **kwargs):
omni_stage_config={"stage_id": 7},
coordinator=coordinator,
)
- manager_kwargs = mock_manager_cls.call_args.kwargs
- assert manager_kwargs["log_stats"] is False
mock_wait.assert_called_once()
-
-
-# ---------------------------------------------------------------------------
-# AsyncOmniEngine – _launch_llm_stage single_stage_mode codepath
-# ---------------------------------------------------------------------------
-
-
-class TestLaunchLlmStageSingleStageMode:
- """Test that _launch_llm_stage selects launch_omni_core_engines when
- single_stage_mode=True and _omni_master_server is set."""
-
- def _build_engine_with_oms(self, mocker: MockerFixture) -> AsyncOmniEngine:
- engine = object.__new__(AsyncOmniEngine)
- engine.model = "fake-model"
- engine.single_stage_mode = True
- engine._single_stage_id_filter = 0
- engine._llm_stage_launch_lock = threading.Lock()
- engine.stage_configs = []
- mock_oms = mocker.Mock(spec=OmniMasterServer)
- mock_oms.address = "127.0.0.1"
- mock_oms.port = 25000
- alloc = mocker.Mock()
- alloc.handshake_bind_address = "tcp://127.0.0.1:25001"
- mock_oms.get_allocation.return_value = alloc
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
- mock_oms.get_zmq_addresses.return_value = fake_addresses
- engine._omni_master_server = mock_oms
- return engine
-
- def _mock_launch_omni(self, mocker: MockerFixture, stage_id: int):
- fake_vllm_config = mocker.Mock()
- fake_executor_cls = mocker.Mock()
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
-
- eng_mgr = mocker.Mock()
-
- @contextmanager
- def fake_launch_omni(*args, **kwargs):
- yield eng_mgr, None, fake_addresses
-
- mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": stage_id},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.acquire_device_locks",
- return_value=[],
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
- return mocker.patch(
- "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
- return_value=fake_launch_omni(),
- )
-
- def test_launch_omni_core_engines_used_in_single_stage_mode(self, mocker: MockerFixture):
- """single_stage_mode + _omni_master_server → launch_omni_core_engines."""
- engine = self._build_engine_with_oms(mocker)
- metadata = mocker.Mock(stage_id=0, runtime_cfg={})
- stage_cfg = _make_stage_cfg(0)
-
- mock_launch_omni = self._mock_launch_omni(mocker, 0)
- result = engine._launch_llm_stage(
- stage_cfg=stage_cfg,
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
-
- mock_launch_omni.assert_called_once()
- assert mock_launch_omni.call_args.kwargs["stage_config"] is stage_cfg
- assert isinstance(result, StartedLlmStage)
- assert result.stage_id == 0
-
- def test_spawn_stage_core_used_in_normal_mode(self, mocker: MockerFixture):
- """~single_stage_mode → spawn_stage_core + complete_stage_handshake."""
- engine = object.__new__(AsyncOmniEngine)
- engine.model = "fake-model"
- engine.single_stage_mode = False
- engine._omni_master_server = None
- engine._llm_stage_launch_lock = threading.Lock()
- engine.stage_configs = []
-
- fake_vllm_config = mocker.Mock()
- fake_executor_cls = mocker.Mock()
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
-
- fake_proc = mocker.Mock()
- fake_handshake_address = "ipc:///tmp/fake-handshake"
- stage_init_timeout = 60
-
- mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(fake_vllm_config, fake_executor_cls),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.acquire_device_locks",
- return_value=[],
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
- mock_spawn = mocker.patch(
- "vllm_omni.engine.async_omni_engine.spawn_stage_core",
- return_value=(fake_addresses, fake_proc, fake_handshake_address),
- )
- mock_handshake = mocker.patch("vllm_omni.engine.async_omni_engine.complete_stage_handshake")
- mock_omni = mocker.patch("vllm_omni.engine.async_omni_engine.launch_omni_core_engines")
- metadata = mocker.Mock(stage_id=0, runtime_cfg={})
- result = engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=stage_init_timeout,
- llm_stage_launch_lock=threading.Lock(),
- )
-
- mock_spawn.assert_called_once_with(
- vllm_config=fake_vllm_config,
- executor_class=fake_executor_cls,
- log_stats=False,
- )
- mock_handshake.assert_called_once_with(
- fake_proc,
- fake_handshake_address,
- fake_addresses,
- fake_vllm_config,
- stage_init_timeout,
- )
- mock_omni.assert_not_called()
- assert isinstance(result, StartedLlmStage)
- assert result.proc is fake_proc
-
- def test_launch_omni_passes_stage_id_and_master_server(self, mocker: MockerFixture):
- """launch_omni_core_engines receives the correct stage_id and omni_master_server."""
- engine = self._build_engine_with_oms(mocker)
- metadata = mocker.Mock(stage_id=0, runtime_cfg={})
-
- captured_kwargs: dict[str, Any] = {}
-
- @contextmanager
- def capturing_launch(*args, **kwargs):
- captured_kwargs.update(kwargs)
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
- yield mocker.Mock(), None, fake_addresses
-
- mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.acquire_device_locks",
- return_value=[],
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
- side_effect=capturing_launch,
- )
-
- engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
-
- assert captured_kwargs.get("stage_id") == 0
- assert captured_kwargs.get("omni_master_server") is engine._omni_master_server
-
- def test_launch_omni_context_exits_before_stage_cleanup_on_error(self, mocker: MockerFixture):
- """Errors after entering the omni launch context still unwind it first."""
- engine = self._build_engine_with_oms(mocker)
- metadata = mocker.Mock(stage_id=0, runtime_cfg={})
-
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
-
- events: list[str] = []
-
- @contextmanager
- def fake_launch_omni(*args, **kwargs):
- try:
- yield mocker.Mock(), None, fake_addresses
- finally:
- events.append("launch_exit")
-
- mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.acquire_device_locks",
- return_value=[],
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
- return_value=fake_launch_omni(),
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.logger.info", side_effect=RuntimeError("boom"))
- mock_close_stage = mocker.patch(
- "vllm_omni.engine.async_omni_engine.close_started_llm_stage",
- side_effect=lambda _started: events.append("stage_close"),
- )
- with pytest.raises(RuntimeError, match="boom"):
- engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
-
- mock_close_stage.assert_called_once()
- assert events == ["launch_exit", "stage_close"]
-
- def test_base_exception_propagates_without_started_stage_cleanup(self, mocker: MockerFixture):
- """BaseException subclasses should bypass the Exception cleanup path."""
- engine = self._build_engine_with_oms(mocker)
- metadata = mocker.Mock(stage_id=0, runtime_cfg={})
-
- fake_addresses = mocker.Mock()
- fake_addresses.inputs = ["tcp://127.0.0.1:5000"]
- fake_addresses.outputs = ["tcp://127.0.0.1:5001"]
- fake_addresses.frontend_stats_publish_address = None
-
- events: list[str] = []
-
- class FatalLaunchInterrupt(BaseException):
- pass
-
- @contextmanager
- def fake_launch_omni(*args, **kwargs):
- try:
- yield mocker.Mock(), None, fake_addresses
- finally:
- events.append("launch_exit")
-
- mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_engine_args_dict",
- return_value={"model": "fake", "stage_id": 0},
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.build_vllm_config",
- return_value=(mocker.Mock(), mocker.Mock()),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.acquire_device_locks",
- return_value=[],
- )
- mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks")
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.launch_omni_core_engines",
- return_value=fake_launch_omni(),
- )
- mocker.patch(
- "vllm_omni.engine.async_omni_engine.logger.info",
- side_effect=FatalLaunchInterrupt("stop"),
- )
- mock_close_stage = mocker.patch("vllm_omni.engine.async_omni_engine.close_started_llm_stage")
- with pytest.raises(FatalLaunchInterrupt, match="stop"):
- engine._launch_llm_stage(
- stage_cfg=_make_stage_cfg(0),
- metadata=metadata,
- stage_connector_spec={},
- stage_init_timeout=60,
- llm_stage_launch_lock=threading.Lock(),
- )
-
- mock_close_stage.assert_not_called()
- assert events == ["launch_exit"]
diff --git a/tests/engine/test_stage_engine_core_client.py b/tests/engine/test_stage_engine_core_client.py
new file mode 100644
index 00000000000..dde0927af2d
--- /dev/null
+++ b/tests/engine/test_stage_engine_core_client.py
@@ -0,0 +1,46 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for StageEngineCoreClient.check_health().
+
+Uses object.__new__ to construct a minimal client — check_health only
+touches self.resources, self.stage_id, and self._proc.
+"""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import pytest
+from vllm.v1.engine.exceptions import EngineDeadError
+
+from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClient
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+def _make_client(*, engine_dead=False, proc_alive=True):
+ client = object.__new__(StageEngineCoreClient)
+ client.stage_id = 0
+ client.resources = SimpleNamespace(engine_dead=engine_dead)
+ client._proc = MagicMock(is_alive=MagicMock(return_value=proc_alive), exitcode=1)
+ return client
+
+
+def test_check_health_passes_when_alive():
+ client = _make_client(engine_dead=False, proc_alive=True)
+ client.check_health() # no exception
+
+
+def test_check_health_raises_when_resources_engine_dead():
+ client = _make_client(engine_dead=True, proc_alive=True)
+ with pytest.raises(EngineDeadError, match="engine core is dead"):
+ client.check_health()
+
+
+def test_check_health_raises_when_proc_not_alive():
+ client = _make_client(engine_dead=False, proc_alive=False)
+ with pytest.raises(EngineDeadError, match="not alive"):
+ client.check_health()
+ # Verify that check_health() also set resources.engine_dead as a side effect
+ assert client.resources.engine_dead is True
diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py
index 607b3eaa813..8a373f74d27 100644
--- a/tests/entrypoints/openai_api/test_image_server.py
+++ b/tests/entrypoints/openai_api/test_image_server.py
@@ -407,6 +407,29 @@ def test_health_endpoint_no_engine():
assert data["status"] == "unhealthy"
+def test_health_endpoint_dead_engine():
+ """/health returns 503 when engine_client.check_health() raises EngineDeadError."""
+ from unittest.mock import AsyncMock
+
+ from fastapi import FastAPI
+ from vllm.v1.engine.exceptions import EngineDeadError
+
+ from vllm_omni.entrypoints.openai.api_server import router
+
+ app = FastAPI()
+ app.include_router(router)
+
+ dead_engine = AsyncMock()
+ dead_engine.check_health = AsyncMock(side_effect=EngineDeadError())
+ app.state.engine_client = dead_engine
+
+ client = TestClient(app)
+ response = client.get("/health")
+ assert response.status_code == 503
+ data = response.json()
+ assert data["status"] == "unhealthy"
+
+
def test_models_endpoint(test_client):
"""Test /v1/models endpoint for diffusion mode"""
response = test_client.get("/v1/models")
diff --git a/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py
index 4190b1fbb13..0a9d45a8ccb 100644
--- a/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py
+++ b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py
@@ -100,6 +100,10 @@ def mock_request(mocker: MockerFixture):
request.stop_token_ids = None
request.frequency_penalty = None
request.presence_penalty = None
+ # Must be real Python objects (not MagicMock) so the code's explicit-field
+ # and extra_body checks work correctly.
+ request.model_fields_set = set()
+ request.extra_body = {}
return request
@@ -150,6 +154,7 @@ def test_preserves_yaml_defaults_when_no_request_params(serving_chat, mock_reque
def test_request_temperature_overrides_yaml_default(serving_chat, mock_request):
"""Test that request temperature overrides YAML default."""
mock_request.temperature = 0.8
+ mock_request.model_fields_set = {"temperature"}
result = serving_chat._build_sampling_params_list_from_request(mock_request)
@@ -162,6 +167,7 @@ def test_request_temperature_overrides_yaml_default(serving_chat, mock_request):
def test_request_top_p_overrides_yaml_default(serving_chat, mock_request):
"""Test that request top_p overrides YAML default."""
mock_request.top_p = 0.95
+ mock_request.model_fields_set = {"top_p"}
result = serving_chat._build_sampling_params_list_from_request(mock_request)
@@ -173,6 +179,7 @@ def test_request_top_p_overrides_yaml_default(serving_chat, mock_request):
def test_request_max_tokens_overrides_yaml_default(serving_chat, mock_request):
"""Test that request max_tokens overrides YAML default."""
mock_request.max_tokens = 100
+ mock_request.model_fields_set = {"max_tokens"}
result = serving_chat._build_sampling_params_list_from_request(mock_request)
@@ -189,6 +196,7 @@ def test_max_tokens_uses_yaml_default_when_not_specified(serving_chat, mock_requ
def test_request_seed_overrides_yaml_default(serving_chat, mock_request):
"""Test that request seed overrides YAML default."""
mock_request.seed = 123
+ mock_request.model_fields_set = {"seed"}
result = serving_chat._build_sampling_params_list_from_request(mock_request)
@@ -200,6 +208,7 @@ def test_request_seed_overrides_yaml_default(serving_chat, mock_request):
def test_request_frequency_penalty_overrides(serving_chat, mock_request):
"""Test that request frequency_penalty is applied."""
mock_request.frequency_penalty = 0.5
+ mock_request.model_fields_set = {"frequency_penalty"}
result = serving_chat._build_sampling_params_list_from_request(mock_request)
@@ -209,6 +218,7 @@ def test_request_frequency_penalty_overrides(serving_chat, mock_request):
def test_request_presence_penalty_overrides(serving_chat, mock_request):
"""Test that request presence_penalty is applied."""
mock_request.presence_penalty = 0.3
+ mock_request.model_fields_set = {"presence_penalty"}
result = serving_chat._build_sampling_params_list_from_request(mock_request)
@@ -235,6 +245,7 @@ def test_multiple_params_override_together(serving_chat, mock_request):
mock_request.temperature = 0.7
mock_request.top_p = 0.85
mock_request.seed = 999
+ mock_request.model_fields_set = {"max_tokens", "temperature", "top_p", "seed"}
result = serving_chat._build_sampling_params_list_from_request(mock_request)
@@ -275,6 +286,7 @@ def test_apply_request_overrides_applies_values(serving_chat, mock_request, defa
"""Test that _apply_request_overrides applies non-None request values."""
mock_request.temperature = 0.8
mock_request.seed = 123
+ mock_request.model_fields_set = {"temperature", "seed"}
result = serving_chat._apply_request_overrides(default_comprehension_params, mock_request)
@@ -304,6 +316,8 @@ def test_apply_overrides_empty_stop_list_preserves_default(serving_chat, mocker)
request.stop_token_ids = None
request.frequency_penalty = None
request.presence_penalty = None
+ request.model_fields_set = {"stop"}
+ request.extra_body = {}
result = serving_chat._apply_request_overrides(default_params, request)
@@ -325,6 +339,8 @@ def test_apply_overrides_nonempty_stop_list_overrides_default(serving_chat, mock
request.stop_token_ids = None
request.frequency_penalty = None
request.presence_penalty = None
+ request.model_fields_set = {"stop"}
+ request.extra_body = {}
result = serving_chat._apply_request_overrides(default_params, request)
@@ -367,6 +383,8 @@ def test_apply_overrides_nonempty_stop_token_ids_overrides_default(serving_chat,
request.stop_token_ids = [100] # non-empty list — should override
request.frequency_penalty = None
request.presence_penalty = None
+ request.model_fields_set = {"stop_token_ids"}
+ request.extra_body = {}
result = serving_chat._apply_request_overrides(default_params, request)
@@ -392,6 +410,8 @@ def test_apply_overrides_mixed_empty_and_nonempty_lists(serving_chat, mocker):
request.stop_token_ids = [100, 200] # non-empty — SHOULD override
request.frequency_penalty = None
request.presence_penalty = None
+ request.model_fields_set = {"temperature", "stop", "stop_token_ids"}
+ request.extra_body = {}
result = serving_chat._apply_request_overrides(default_params, request)
@@ -415,6 +435,8 @@ def test_apply_overrides_none_scalar_still_preserves_default(serving_chat, mocke
request.stop_token_ids = None
request.frequency_penalty = None
request.presence_penalty = None
+ request.model_fields_set = set()
+ request.extra_body = {}
result = serving_chat._apply_request_overrides(default_params, request)
@@ -442,6 +464,8 @@ def test_apply_overrides_both_lists_empty_preserves_defaults(serving_chat, mocke
request.stop_token_ids = []
request.frequency_penalty = None
request.presence_penalty = None
+ request.model_fields_set = {"stop", "stop_token_ids"}
+ request.extra_body = {}
result = serving_chat._apply_request_overrides(default_params, request)
@@ -511,3 +535,165 @@ def test_get_comprehension_stage_index_raises_when_not_found(mocker: MockerFixtu
with pytest.raises(ValueError, match="No comprehension stage"):
instance._get_comprehension_stage_index()
+
+
+# =============================================================================
+# Tests for _resolve_height_width_from_extra_body
+# =============================================================================
+
+
+class TestResolveHeightWidth:
+ def test_explicit_height_width(self):
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"height": 512, "width": 768})
+ assert h == 512
+ assert w == 768
+
+ def test_size_string(self):
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "768x512"})
+ assert w == 768
+ assert h == 512
+
+ def test_size_string_uppercase(self):
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "768X512"})
+ assert w == 768
+ assert h == 512
+
+ def test_size_fallback_when_height_missing(self):
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "512x512", "width": 1024})
+ # No height -> size fallback fires and sets BOTH dims, overriding the explicit width=1024
+ assert h == 512
+ assert w == 512
+
+ def test_empty_extra_body(self):
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({})
+ assert h is None
+ assert w is None
+
+ def test_invalid_size_format_ignored(self):
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "invalid"})
+ assert h is None
+ assert w is None
+
+
+# =============================================================================
+# Tests for _apply_request_overrides with GLM-Image (max_tokens computation)
+# =============================================================================
+
+
+class TestApplyRequestOverridesGLMImage:
+ """Test dynamic max_tokens computation for GLM-Image AR stage."""
+
+ @pytest.fixture
+ def glm_serving_chat(self, mock_engine_client, mocker: MockerFixture):
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ instance = object.__new__(OmniOpenAIServingChat)
+ instance.engine_client = mock_engine_client
+ # Mock the image extraction to return no reference images (t2i by default)
+ instance._extract_diffusion_prompt_and_images_from_messages = mocker.MagicMock(return_value=("a cat", []))
+ return instance
+
+ @pytest.fixture
+ def glm_request(self, mocker: MockerFixture):
+ req = mocker.MagicMock()
+ req.temperature = None
+ req.top_p = None
+ req.top_k = None
+ req.max_tokens = None
+ req.min_tokens = None
+ req.seed = None
+ req.ignore_eos = None
+ req.stop = None
+ req.stop_token_ids = None
+ req.frequency_penalty = None
+ req.presence_penalty = None
+ req.extra_body = {"height": 1024, "width": 1024}
+ req.model_fields_set = set()
+ return req
+
+ def test_t2i_computes_max_tokens(self, glm_serving_chat, glm_request, default_comprehension_params):
+ """t2i mode: max_tokens computed from height/width, no reference images."""
+ result = glm_serving_chat._apply_request_overrides(default_comprehension_params, glm_request)
+ # t2i 1024x1024 = 256 + 1024 + 1 = 1281
+ assert result.max_tokens == 1281
+ assert result.extra_args["target_h"] == 1024
+ assert result.extra_args["target_w"] == 1024
+
+ def test_i2i_computes_fewer_tokens(
+ self, glm_serving_chat, glm_request, default_comprehension_params, mocker: MockerFixture
+ ):
+ """i2i mode: max_tokens should be smaller than t2i for same dimensions."""
+ # Make it detect reference images
+ glm_serving_chat._extract_diffusion_prompt_and_images_from_messages = mocker.MagicMock(
+ return_value=("edit this", ["fake_image"])
+ )
+
+ result = glm_serving_chat._apply_request_overrides(default_comprehension_params, glm_request)
+ # i2i 1024x1024 = 1024 + 1 = 1025
+ assert result.max_tokens == 1025
+
+ def test_dynamic_max_tokens_overrides_user_value(self, glm_serving_chat, glm_request, default_comprehension_params):
+ """When height/width are provided, dynamic computation overrides user max_tokens."""
+ glm_request.max_tokens = 500
+ glm_request.model_fields_set = {"max_tokens"}
+
+ result = glm_serving_chat._apply_request_overrides(default_comprehension_params, glm_request)
+ # Dynamic computation from height/width always wins when present
+ assert result.max_tokens == 1281
+
+ def test_no_height_width_preserves_default(
+ self, glm_serving_chat, mocker: MockerFixture, default_comprehension_params
+ ):
+ """When no height/width in extra_body, keep YAML default max_tokens."""
+ req = mocker.MagicMock()
+ req.temperature = None
+ req.top_p = None
+ req.top_k = None
+ req.max_tokens = None
+ req.min_tokens = None
+ req.seed = None
+ req.ignore_eos = None
+ req.stop = None
+ req.stop_token_ids = None
+ req.frequency_penalty = None
+ req.presence_penalty = None
+ req.extra_body = {}
+ req.model_fields_set = set()
+
+ result = glm_serving_chat._apply_request_overrides(default_comprehension_params, req)
+ assert result.max_tokens == 2048 # YAML default
+
+ def test_size_string_parsed_for_glm_image(
+ self, glm_serving_chat, mocker: MockerFixture, default_comprehension_params
+ ):
+ """'size' in extra_body is parsed as fallback for height/width."""
+ req = mocker.MagicMock()
+ req.temperature = None
+ req.top_p = None
+ req.top_k = None
+ req.max_tokens = None
+ req.min_tokens = None
+ req.seed = None
+ req.ignore_eos = None
+ req.stop = None
+ req.stop_token_ids = None
+ req.frequency_penalty = None
+ req.presence_penalty = None
+ req.extra_body = {"size": "512x512"}
+ req.model_fields_set = set()
+
+ result = glm_serving_chat._apply_request_overrides(default_comprehension_params, req)
+ # 512x512 t2i = 256 + 256 + 1 = 513
+ assert result.max_tokens == 513
diff --git a/tests/entrypoints/openai_api/test_video_server.py b/tests/entrypoints/openai_api/test_video_server.py
index 6157d82313e..1cc5616657d 100644
--- a/tests/entrypoints/openai_api/test_video_server.py
+++ b/tests/entrypoints/openai_api/test_video_server.py
@@ -133,7 +133,6 @@ def _wait_for_status(client: TestClient, video_id: str, status: str, timeout_s:
last_payload = None
while time.time() < deadline:
response = client.get(f"/v1/videos/{video_id}")
- assert response.status_code == 200
last_payload = response.json()
if last_payload["status"] == status:
return last_payload
@@ -644,7 +643,7 @@ def test_invalid_lora_returns_400(test_client):
assert response.status_code == 200
video_id = response.json()["id"]
failed = _wait_for_status(test_client, video_id, VideoGenerationStatus.FAILED.value)
- assert failed["error"]["code"] == "HTTPException"
+ assert failed["error"]["code"] == 400
assert "lora object" in failed["error"]["message"].lower()
diff --git a/tests/entrypoints/test_omni_entrypoints.py b/tests/entrypoints/test_omni_entrypoints.py
index 3cffcd37df4..289cf2673fb 100644
--- a/tests/entrypoints/test_omni_entrypoints.py
+++ b/tests/entrypoints/test_omni_entrypoints.py
@@ -4,11 +4,13 @@
from collections.abc import Callable
from types import SimpleNamespace
from typing import Any
+from unittest.mock import MagicMock
import pytest
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.omni import Omni
@@ -163,6 +165,15 @@ def _patch_engine(monkeypatch: pytest.MonkeyPatch, engine: FakeAsyncOmniEngine)
monkeypatch.setattr("vllm_omni.entrypoints.omni_base.omni_snapshot_download", lambda model: model)
+def _make_base():
+ from vllm_omni.entrypoints.omni_base import OmniBase
+
+ obj = object.__new__(OmniBase)
+ obj.engine = MagicMock()
+ obj.request_states = {}
+ return obj
+
+
def _stage_spec(
stage_id: int,
*,
@@ -687,3 +698,255 @@ def test_omni_forces_final_only_on_llm_stages(monkeypatch: pytest.MonkeyPatch):
assert submitted_params[1].output_kind == RequestOutputKind.FINAL_ONLY
assert submitted_params[2].output_kind == original_diffusion_output_kind
assert len(outputs) == 2
+
+
+def test_fatal_error_raises_engine_dead():
+ base = _make_base()
+ msg = {"type": "error", "error": "orchestrator crashed", "fatal": True}
+
+ with pytest.raises(EngineDeadError, match="orchestrator crashed"):
+ base._handle_output_message(msg)
+
+
+def test_non_fatal_error_raises_runtime():
+ base = _make_base()
+ msg = {"type": "error", "error": "something wrong"}
+
+ with pytest.raises(RuntimeError, match="something wrong"):
+ base._handle_output_message(msg)
+
+
+def test_async_omni_errored_property_alive():
+ omni = object.__new__(AsyncOmni)
+ omni.engine = SimpleNamespace(
+ is_alive=lambda: True,
+ stage_clients=[SimpleNamespace(is_comprehension=False)],
+ )
+
+ assert omni.errored is False
+
+
+def test_async_omni_errored_property_dead_engine():
+ omni = object.__new__(AsyncOmni)
+ omni.engine = SimpleNamespace(
+ is_alive=lambda: False,
+ stage_clients=[SimpleNamespace(is_comprehension=False)],
+ )
+
+ assert omni.errored is True
+
+
+def test_async_omni_errored_property_dead_stage():
+ omni = object.__new__(AsyncOmni)
+ dead_stage = SimpleNamespace(is_comprehension=False, _engine_dead=True)
+ omni.engine = SimpleNamespace(
+ is_alive=lambda: True,
+ stage_clients=[dead_stage],
+ )
+
+ assert omni.errored is True
+
+
+def _enqueue_stage_error(
+ engine: FakeAsyncOmniEngine,
+ msg,
+ *,
+ error_text: str,
+ kill_engine: bool = False,
+):
+ """Enqueue a stage error output, optionally killing the engine."""
+ if kill_engine:
+ engine._alive = False
+ engine.output_q.put_nowait(
+ {
+ "type": "output",
+ "request_id": msg["request_id"],
+ "stage_id": 0,
+ "engine_outputs": SimpleNamespace(
+ payload="",
+ finished=True,
+ images=[],
+ stage_durations={},
+ error=error_text,
+ ),
+ "finished": False,
+ }
+ )
+
+
+@pytest.mark.asyncio
+async def test_async_omni_propagates_engine_dead_error(monkeypatch: pytest.MonkeyPatch):
+ """When the engine is dead and an error output arrives, ``generate()``
+ must raise ``EngineDeadError`` (not plain ``RuntimeError``)."""
+
+ engine = FakeAsyncOmniEngine(
+ stage_metadata=THREE_STAGE_META,
+ on_add_request=lambda eng, msg: _enqueue_stage_error(eng, msg, error_text="worker OOM", kill_engine=True),
+ )
+ _patch_engine(monkeypatch, engine)
+
+ app = AsyncOmni("dummy-model")
+ try:
+ with pytest.raises(EngineDeadError, match="worker OOM"):
+ async for _ in app.generate(prompt="hello", request_id="req-dead"):
+ pass
+ finally:
+ app.shutdown()
+
+
+@pytest.mark.asyncio
+async def test_async_omni_propagates_engine_generate_error(monkeypatch: pytest.MonkeyPatch):
+ """When the engine is alive but a stage error occurs, ``generate()``
+ must raise ``EngineGenerateError`` (recoverable, not ``EngineDeadError``)."""
+
+ engine = FakeAsyncOmniEngine(
+ stage_metadata=THREE_STAGE_META,
+ on_add_request=lambda eng, msg: _enqueue_stage_error(eng, msg, error_text="diffusion step failed"),
+ )
+ _patch_engine(monkeypatch, engine)
+
+ app = AsyncOmni("dummy-model")
+ try:
+ with pytest.raises(EngineGenerateError):
+ async for _ in app.generate(prompt="hello", request_id="req-recover"):
+ pass
+ finally:
+ app.shutdown()
+
+
+# ───────── OmniBase.check_health() aggregation ─────────
+
+
+def test_check_health_passes_when_all_healthy():
+ base = _make_base()
+ healthy_stage = MagicMock()
+ healthy_stage.check_health = MagicMock()
+ base.engine.is_alive.return_value = True
+ base.engine.stage_clients = [healthy_stage]
+ base.check_health() # should not raise
+
+
+def test_check_health_raises_when_stage_dead():
+ base = _make_base()
+ dead_stage = MagicMock()
+ dead_stage.check_health = MagicMock(side_effect=EngineDeadError("Stage-1 dead"))
+ base.engine.is_alive.return_value = True
+ base.engine.stage_clients = [dead_stage]
+ with pytest.raises(EngineDeadError, match="Stage-1 dead"):
+ base.check_health()
+
+
+def test_check_health_raises_when_orchestrator_dead():
+ base = _make_base()
+ base.engine.is_alive.return_value = False
+ base.engine.stage_clients = []
+ with pytest.raises(EngineDeadError, match="not alive"):
+ base.check_health()
+
+
+# ───────── OmniBase.errored property ─────────
+
+
+def test_omni_base_errored_false_when_alive():
+ base = _make_base()
+ base.engine.is_alive.return_value = True
+ base.engine.stage_clients = [SimpleNamespace()]
+ assert base.errored is False
+
+
+def test_omni_base_errored_true_when_orchestrator_dead():
+ base = _make_base()
+ base.engine.is_alive.return_value = False
+ base.engine.stage_clients = []
+ assert base.errored is True
+
+
+def test_omni_base_errored_true_when_stage_engine_dead():
+ base = _make_base()
+ base.engine.is_alive.return_value = True
+ dead_stage = SimpleNamespace(_engine_dead=True)
+ base.engine.stage_clients = [dead_stage]
+ assert base.errored is True
+
+
+def test_omni_base_errored_true_when_stage_resources_engine_dead():
+ base = _make_base()
+ base.engine.is_alive.return_value = True
+ dead_stage = SimpleNamespace(resources=SimpleNamespace(engine_dead=True))
+ base.engine.stage_clients = [dead_stage]
+ assert base.errored is True
+
+
+# ───────── Omni (sync) EngineDeadError / EngineGenerateError ─────────
+
+
+def test_omni_propagates_engine_dead_error(monkeypatch: pytest.MonkeyPatch):
+ """When the engine is dead and a stage error output arrives,
+ ``Omni.generate()`` must raise ``EngineDeadError``."""
+ engine = FakeAsyncOmniEngine(
+ stage_metadata=THREE_STAGE_META,
+ on_add_request=lambda eng, msg: _enqueue_stage_error(eng, msg, error_text="worker OOM", kill_engine=True),
+ )
+ _patch_engine(monkeypatch, engine)
+
+ app = Omni("dummy-model")
+ try:
+ with pytest.raises(EngineDeadError, match="worker OOM"):
+ list(app.generate(["hello"], py_generator=False, use_tqdm=False))
+ finally:
+ app.shutdown()
+
+
+def test_omni_propagates_engine_generate_error(monkeypatch: pytest.MonkeyPatch):
+ """When the engine is alive but a stage error occurs,
+ ``Omni.generate()`` must raise ``EngineGenerateError`` (recoverable)."""
+ engine = FakeAsyncOmniEngine(
+ stage_metadata=THREE_STAGE_META,
+ on_add_request=lambda eng, msg: _enqueue_stage_error(eng, msg, error_text="diffusion step failed"),
+ )
+ _patch_engine(monkeypatch, engine)
+
+ app = Omni("dummy-model")
+ try:
+ with pytest.raises(EngineGenerateError):
+ list(app.generate(["hello"], py_generator=False, use_tqdm=False))
+ finally:
+ app.shutdown()
+
+
+def test_omni_errored_property_alive(monkeypatch: pytest.MonkeyPatch):
+ """Omni.errored (inherited from OmniBase) returns False when healthy."""
+ engine = FakeAsyncOmniEngine(stage_metadata=THREE_STAGE_META)
+ _patch_engine(monkeypatch, engine)
+
+ app = Omni("dummy-model")
+ try:
+ assert app.errored is False
+ finally:
+ app.shutdown()
+
+
+def test_omni_errored_property_dead_engine(monkeypatch: pytest.MonkeyPatch):
+ """Omni.errored returns True when the orchestrator is dead."""
+ engine = FakeAsyncOmniEngine(stage_metadata=THREE_STAGE_META)
+ _patch_engine(monkeypatch, engine)
+
+ app = Omni("dummy-model")
+ try:
+ engine._alive = False
+ assert app.errored is True
+ finally:
+ app.shutdown()
+
+
+def test_omni_errored_property_dead_stage(monkeypatch: pytest.MonkeyPatch):
+ """Omni.errored returns True when a stage client is marked dead."""
+ engine = FakeAsyncOmniEngine(stage_metadata=THREE_STAGE_META)
+ _patch_engine(monkeypatch, engine)
+
+ app = Omni("dummy-model")
+ try:
+ engine.stage_clients[0]._engine_dead = True
+ assert app.errored is True
+ finally:
+ app.shutdown()
diff --git a/tests/entrypoints/test_omni_sleep_mode.py b/tests/entrypoints/test_omni_sleep_mode.py
new file mode 100644
index 00000000000..aa7be1ba0f7
--- /dev/null
+++ b/tests/entrypoints/test_omni_sleep_mode.py
@@ -0,0 +1,336 @@
+import asyncio
+import logging
+import os
+
+import pytest
+import torch
+from vllm import SamplingParams
+
+from tests.helpers.mark import hardware_test
+from vllm_omni.entrypoints.async_omni import AsyncOmni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+from vllm_omni.platforms import current_omni_platform
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("OmniTest")
+pytestmark = [pytest.mark.advanced_model]
+
+
+def clean_gpu_envs():
+    """Remove device-visibility variables from ``os.environ``.
+
+    Clearing them ensures tests run on all available devices. Variables that
+    are not set are ignored.
+    """
+    device_visibility_vars = [
+        "CUDA_VISIBLE_DEVICES",  # NVIDIA
+        "HIP_VISIBLE_DEVICES",  # AMD ROCm
+        "ZE_AFFINITY_MASK",  # Intel XPU
+        "ONEAPI_DEVICE_SELECTOR",  # Intel OneAPI
+        "ASCEND_RT_VISIBLE_DEVICES",  # Huawei NPU (CANN)
+    ]
+    for key in device_visibility_vars:
+        os.environ.pop(key, None)
+
+
+def get_vram_info(device_id: int) -> dict:
+    """Return a memory snapshot for one GPU, in GiB.
+
+    Values come from ``torch.cuda.memory_reserved`` / ``memory_allocated``.
+    NOTE(review): those counters describe the caching allocator of the calling
+    process; stages started with ``"process": True`` presumably allocate in
+    their own processes and would not be reflected here — confirm.
+
+    Args:
+        device_id: Index of the device to query.
+
+    Returns:
+        ``{"reserved": float, "allocated": float}``. Both values are ``0.0``
+        when the query raises; the error is logged as a warning, not re-raised.
+    """
+    try:
+        if current_omni_platform.is_rocm():
+            num_gpus = torch.cuda.device_count()
+            # On ROCm an out-of-range index silently falls back to device 0.
+            safe_id = device_id if device_id < num_gpus else 0
+            torch.cuda.synchronize(safe_id)
+            return {
+                "reserved": torch.cuda.memory_reserved(safe_id) / 1024**3,
+                "allocated": torch.cuda.memory_allocated(safe_id) / 1024**3,
+            }
+        else:
+            with torch.cuda.device(device_id):
+                torch.cuda.synchronize()
+                return {
+                    "reserved": torch.cuda.memory_reserved() / 1024**3,
+                    "allocated": torch.cuda.memory_allocated() / 1024**3,
+                }
+    except Exception as e:
+        # Best-effort probe: callers receive zeros instead of an exception.
+        logger.warning(f"memory skip ({device_id}): {e}")
+        return {"reserved": 0.0, "allocated": 0.0}
+
+
+def get_ack_info(ack, key, default=None):
+    """Read ``key`` from an ACK that may be a dict or an attribute-style object.
+
+    Dicts are looked up by key *first*: checking ``hasattr`` first would let a
+    key such as ``"items"`` or ``"get"`` resolve to the dict method of the same
+    name instead of the stored value.
+
+    Args:
+        ack: A dict, or any object exposing the field as an attribute.
+        key: Field name to read.
+        default: Value returned when the field is absent.
+
+    Returns:
+        The field value, or ``default`` when ``ack`` does not carry ``key``.
+    """
+    if isinstance(ack, dict):
+        return ack.get(key, default)
+    if hasattr(ack, key):
+        return getattr(ack, key)
+    return default
+
+
+@pytest.fixture(scope="function")
+async def llm_engine():
+ if current_omni_platform.is_rocm():
+ clean_gpu_envs()
+ model_name = "ByteDance-Seed/BAGEL-7B-MoT"
+ common_args = {
+ "worker_type": "ar",
+ "enable_sleep_mode": True,
+ "dtype": "bfloat16",
+ "trust_remote_code": True,
+ "max_model_len": 2048,
+ "max_num_batched_tokens": 8192,
+ "enforce_eager": True,
+ }
+ stages = [
+ {
+ "stage_id": 0,
+ "stage_type": "llm",
+ "runtime": {"process": True, "devices": "0", "max_batch_size": 1},
+ "engine_args": {**common_args, "model_stage": "thinker", "gpu_memory_utilization": 0.1},
+ },
+ {
+ "stage_id": 1,
+ "stage_type": "llm",
+ "engine_input_source": [0],
+ "runtime": {"process": True, "devices": "1", "max_batch_size": 1, "connector_type": "queue"},
+ "engine_args": {**common_args, "model_stage": "talker", "gpu_memory_utilization": 0.1},
+ },
+ ]
+ connectors = [{"src_stage_id": 0, "dst_stage_id": 1, "connector_type": "queue"}]
+ engine = AsyncOmni(model=model_name, stages=stages, connectors=connectors, init_timeout=600, enable_sleep_mode=True)
+ yield engine
+ engine.shutdown()
+
+
+@pytest.fixture(scope="function")
+async def diffusion_engine():
+ if current_omni_platform.is_rocm():
+ clean_gpu_envs()
+ model_name = "ByteDance-Seed/BAGEL-7B-MoT"
+ stages = [
+ {
+ "stage_id": 0,
+ "stage_type": "diffusion",
+ "runtime": {"process": True, "devices": "0,1", "max_batch_size": 1},
+ "engine_args": {
+ "model_stage": "base",
+ "gpu_memory_utilization": 0.1,
+ "model_class_name": "BagelPipeline",
+ "enable_sleep_mode": True,
+ "enforce_eager": True,
+ "max_num_batched_tokens": 8192,
+ "parallel_config": {
+ "tensor_parallel_size": 2,
+ },
+ },
+ "final_output": True,
+ "final_output_type": "image",
+ }
+ ]
+ engine = AsyncOmni(model=model_name, stages=stages, init_timeout=600, enable_sleep_mode=True)
+ yield engine
+ engine.shutdown()
+
+
+class TestOmniSleepMode:
+    # The llm_engine fixture places stage 0 on device "0" and stage 1 on
+    # device "1", so two cards are required even though only stage 0 sleeps.
+    @pytest.mark.asyncio
+    @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2)
+    async def test_llm_sleep_ack(self, llm_engine: AsyncOmni):
+        """LLM Thinker (GPU0) Signal and Physical Recycling Audit
+
+        Sleeps stage 0 at level 2, then checks that every ACK reports
+        ``SUCCESS`` and that the ACKs together report more than 5 GiB freed.
+        """
+        try:
+            acks = await llm_engine.sleep(stage_ids=[0], level=2)
+            # Verification signal successful
+            assert all(get_ack_info(ack, "status") == "SUCCESS" for ack in acks)
+            # Verify physical recycling volume
+            total_freed_bytes = sum(get_ack_info(ack, "freed_bytes", 0) for ack in acks)
+            freed_gib = total_freed_bytes / 1024**3
+            logger.info(f"Thinker VRAM physically reclaimed: {freed_gib:.2f} GiB")
+            assert freed_gib > 5.0
+        finally:
+            llm_engine.shutdown()
+
+    @pytest.mark.asyncio
+    @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2)
+    async def test_diffusion_sleep_handshake(self, diffusion_engine: AsyncOmni):
+        """Diffusion Worker stage signal loop
+
+        Sleeps stage 0 at level 2, checks that at least one ACK came back and
+        that every ACK reports ``SUCCESS``, then wakes the stage up again.
+        """
+        try:
+            logger.info("Starting Diffusion Worker Handshake Test")
+            acks = await diffusion_engine.sleep(stage_ids=[0], level=2)
+
+            assert len(acks) >= 1, "Expected at least 1 ACK from Diffusion Workers"
+            # Use the module-level helper (dict or object ACKs) rather than a
+            # local copy, matching the other tests in this file.
+            assert all(get_ack_info(ack, "status") == "SUCCESS" for ack in acks)
+            logger.info(f"Success: Received {len(acks)} Diffusion Worker ACKs")
+            logger.info("Testing auto-wakeup before test end...")
+            await diffusion_engine.wake_up(stage_ids=[0])
+            logger.info("Test logic finished, triggering manual shutdown...")
+        finally:
+            diffusion_engine.shutdown()
+            logger.info("Manual shutdown executed. Test should exit now.")
+
+ @pytest.mark.asyncio
+ @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2)
+ async def test_cross_device_cleanup(self, diffusion_engine: AsyncOmni):
+ """Physical recycling audit: leveraging deterministic data returned by Workers"""
+ try:
+ acks = await diffusion_engine.sleep(stage_ids=[0], level=1)
+ # Sum up the release amounts reported by all Workers.
+ total_freed_bytes = sum(get_ack_info(ack, "freed_bytes", 0) for ack in acks)
+ freed_gb = total_freed_bytes / 1024**3
+ logger.info("Physical reclamation summary from workers:")
+ logger.info(f"- Total Workers: {len(acks)}")
+ logger.info(f"- Total Freed: {freed_gb:.2f} GiB")
+ assert freed_gb > 14.0
+ logger.info("SUCCESS: 100% weights offloaded.")
+ finally:
+ diffusion_engine.shutdown()
+
+    @pytest.mark.asyncio
+    @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2)
+    async def test_diffusion_integrity_bit_level(self, diffusion_engine: AsyncOmni):
+        """Diffusion still produces images after a level-2 sleep / wake-up cycle.
+
+        Generates once as a baseline, sleeps and wakes stage 0, then generates
+        again with the same prompt and seed. The check is structural: the
+        post-wake-up output must exist, contain the same number of images as
+        the baseline, and its first image must not be ``None``. Pixel contents
+        are not compared.
+        """
+        try:
+            prompt = "A huge swimming pool, with many people swimming."
+            sp = OmniDiffusionSamplingParams(num_inference_steps=4, height=512, width=512, seed=42)
+            llm_sp = SamplingParams()
+
+            # Baseline Generation
+            logger.info("Running Baseline Generation...")
+            base_output = None
+            async for output in diffusion_engine.generate(prompt, request_id="base", sampling_params_list=[llm_sp, sp]):
+                base_output = output
+            assert base_output is not None and len(base_output.images) > 0
+            logger.info("Baseline Generation successful.")
+            # Sleep Level 2
+            logger.info("Entering Deep Sleep (VRAM Scavenging)...")
+            await diffusion_engine.sleep(stage_ids=[0], level=2)
+            # Wake-up
+            logger.info("Waking up (Reloading Weights)...")
+            await diffusion_engine.wake_up(stage_ids=[0])
+
+            await asyncio.sleep(2.0)
+            import gc
+
+            gc.collect()
+
+            logger.info("Running Post-Wakeup Generation...")
+            post_output = None
+            async for output in diffusion_engine.generate(prompt, request_id="post", sampling_params_list=[llm_sp, sp]):
+                post_output = output
+            # Assert result consistency
+            assert post_output is not None
+            assert len(base_output.images) == len(post_output.images)
+            assert post_output.images[0] is not None
+            logger.info("SUCCESS: Diffusion integrity verified after Sleep/Wake cycle.")
+        except Exception as e:
+            logger.error(f"Integrity test failed: {e}")
+            # Bare raise keeps the original traceback intact.
+            raise
+        finally:
+            logger.info("Triggering mandatory cleanup...")
+            diffusion_engine.shutdown()
+            logger.info("Cleanup complete, test exiting.")
+
+    @pytest.mark.asyncio
+    @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2)
+    async def test_coordinated_cross_device(self, llm_engine: AsyncOmni, diffusion_engine: AsyncOmni):
+        """Heterogeneous Coordinated Cleanup Test (Talker and Diffusion on GPU 1)
+
+        Wakes LLM stage 1 and diffusion stage 0, records the reserved memory
+        that ``get_vram_info`` reports for GPU 1, puts both stages to sleep at
+        level 2 one after the other, and compares the reading afterwards.
+        """
+        device_id = 1
+        try:
+            logger.info(f"Waking up both engines on GPU {device_id}...")
+            await llm_engine.wake_up(stage_ids=[1])
+            await diffusion_engine.wake_up(stage_ids=[0])
+
+            get_vram_info(device_id)
+            torch.cuda.empty_cache()
+            await asyncio.sleep(2)
+
+            initial_vram = get_vram_info(device_id)["reserved"]
+            logger.info(f"GPU {device_id} Peak Pressure: {initial_vram:.2f} GiB")
+
+            # coordinated sleep
+            logger.info("Issuing concurrent SLEEP commands...")
+            await llm_engine.sleep(stage_ids=[1], level=2)
+            await asyncio.sleep(1.0)
+            await diffusion_engine.sleep(stage_ids=[0], level=2)
+
+            await asyncio.sleep(3.0)
+            torch.cuda.empty_cache()
+
+            final_vram = get_vram_info(device_id)["reserved"]
+            logger.info(f"GPU {device_id} Final VRAM after coordinated sleep: {final_vram:.2f} GiB")
+
+            assert initial_vram - final_vram > 15.0 or final_vram < 8.0
+            logger.info(f"SUCCESS: Heterogeneous VRAM drop verified on GPU {device_id}.")
+        except Exception as e:
+            logger.error(f"Coordinated test failed: {e}")
+            # Bare raise keeps the original traceback intact.
+            raise
+        finally:
+            logger.info("Triggering mandatory cleanup for both engines...")
+            llm_engine.shutdown()
+            diffusion_engine.shutdown()
+            logger.info("All engines scavenged. Ready for next test.")
+
+    @pytest.mark.asyncio
+    @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2)
+    async def test_diffusion_vram_lifecycle_audit(self, diffusion_engine: AsyncOmni):
+        """Diffusion memory loop: Active -> Deep Sleep -> Active -> inference sanity check
+
+        Reads reserved memory for GPU 1 via ``get_vram_info`` before sleeping,
+        while asleep, and after wake-up, sums ``freed_bytes`` across the sleep
+        ACKs, and finishes with a short generation to confirm the stage still
+        yields an image.
+        """
+        device_id = 1
+        try:
+            get_vram_info(device_id)
+            torch.cuda.empty_cache()
+            vram_initial = get_vram_info(device_id)["reserved"]
+            logger.info(f"Diffusion Initial VRAM: {vram_initial:.2f} GiB")
+
+            # Sleep
+            logger.info("Triggering Level 2 Deep Sleep (Full Weight Offloading)...")
+            acks = await diffusion_engine.sleep(stage_ids=[0], level=2)
+
+            # ACKs may be dicts: plain getattr() would read 0 for every one of
+            # them, so go through the dict/object-aware helper.
+            reported_freed_bytes = sum(get_ack_info(ack, "freed_bytes", 0) for ack in acks)
+            reported_freed_gib = reported_freed_bytes / 1024**3
+            logger.info(f"Worker internally reported freed: {reported_freed_gib:.2f} GiB")
+
+            await asyncio.sleep(2)
+            get_vram_info(device_id)
+            torch.cuda.empty_cache()
+
+            vram_sleeping = get_vram_info(device_id)["reserved"]
+            logger.info(f"External VRAM measurement during Sleep: {vram_sleeping:.2f} GiB")
+
+            assert reported_freed_gib > 14.0 or vram_sleeping < 5.0, (
+                f"Reclamation failed. Reported: {reported_freed_gib:.2f}G, Measured: {vram_sleeping:.2f}G"
+            )
+
+            # wake-up
+            logger.info("Triggering Wake-up (Reloading weights to GPU)...")
+            await diffusion_engine.wake_up(stage_ids=[0])
+
+            await asyncio.sleep(2)
+            get_vram_info(device_id)
+            torch.cuda.empty_cache()
+            vram_restored = get_vram_info(device_id)["reserved"]
+            logger.info(f"VRAM after Wake-up: {vram_restored:.2f} GiB")
+
+            assert abs(vram_restored - vram_initial) < 3.0, "VRAM failed to restore to initial levels"
+
+            # inference sanity check
+            logger.info("Running post-lifecycle inference smoke test...")
+            prompt = "A futuristic lab with glowing lights, high quality."
+            sp = OmniDiffusionSamplingParams(num_inference_steps=2, height=512, width=512, seed=42)
+            llm_sp = SamplingParams()
+
+            base_img_found = False
+            async for output in diffusion_engine.generate(
+                prompt, request_id="lifecycle-check", sampling_params_list=[llm_sp, sp]
+            ):
+                if output.images and output.images[0] is not None:
+                    base_img_found = True
+
+            assert base_img_found, "Inference failed after Wake-up cycle!"
+            logger.info("SUCCESS: Full Diffusion Lifecycle (Active -> Sleep -> Active -> Generate) audited.")
+
+        except Exception as e:
+            logger.error(f"Lifecycle audit failed: {e}")
+            # Bare raise keeps the original traceback intact.
+            raise
+        finally:
+            logger.info("Cleaning up engine and scavenging processes...")
+            diffusion_engine.shutdown()
+            await asyncio.sleep(1)
diff --git a/tests/entrypoints/test_utils.py b/tests/entrypoints/test_utils.py
index 248629d51df..4252d5e837e 100644
--- a/tests/entrypoints/test_utils.py
+++ b/tests/entrypoints/test_utils.py
@@ -358,6 +358,48 @@ def test_load_and_resolve_with_kwargs(self):
assert len(stage_configs) == 1
assert "dtype" in stage_configs[0]["engine_args"]
+ def test_stage_configs_path_promotes_new_deploy_yaml_without_expanding_replicas(
+ self, tmp_path, mocker: MockerFixture
+ ):
+ deploy_path = tmp_path / "qwen3_multi.yaml"
+ deploy_path.write_text(
+ 'stages:\n - stage_id: 0\n devices: "0"\n - stage_id: 1\n devices: "1,2,3"\n num_replicas: 3\n',
+ encoding="utf-8",
+ )
+
+ returned_stage_configs = [
+ create_config({"stage_id": 0, "runtime": {"devices": "0"}, "engine_args": {"model": "dummy"}}),
+ create_config(
+ {
+ "stage_id": 1,
+ "runtime": {"devices": "1,2,3", "num_replicas": 3},
+ "engine_args": {"model": "dummy"},
+ }
+ ),
+ ]
+ load_stage_configs = mocker.patch(
+ "vllm_omni.entrypoints.utils.load_stage_configs_from_model",
+ return_value=returned_stage_configs,
+ )
+
+ config_path, stage_configs = load_and_resolve_stage_configs(
+ model="dummy-model",
+ stage_configs_path=str(deploy_path),
+ kwargs={},
+ )
+
+ load_stage_configs.assert_called_once_with(
+ "dummy-model",
+ base_engine_args={},
+ deploy_config_path=str(deploy_path),
+ stage_overrides=None,
+ cli_explicit_keys=None,
+ )
+ assert config_path == str(deploy_path)
+ assert len(stage_configs) == 2
+ assert stage_configs[1].runtime.num_replicas == 3
+ assert stage_configs[1].runtime.devices == "1,2,3"
+
class TestLoadStageConfigsFromYaml:
"""Regression tests for stage-config loading and merging."""
diff --git a/tests/helpers/media.py b/tests/helpers/media.py
index 3c45c2a9d95..4463acbbbb6 100644
--- a/tests/helpers/media.py
+++ b/tests/helpers/media.py
@@ -544,7 +544,12 @@ def cosine_similarity_text(text1, text2, n: int = 3):
norm2 = sum(b * b for b in vec2) ** 0.5
if norm1 == 0 or norm2 == 0:
return 0.0
- return dot_product / (norm1 * norm2)
+ cosine = dot_product / (norm1 * norm2)
+ # Down-weight when lengths differ: repeated/hallucinated transcripts stay
+ # high in bag-of-ngrams cosine (e.g. ABCABCABC vs ABC) but should score low.
+ len1, len2 = len(text1), len(text2)
+ length_harmony = (2.0 * min(len1, len2)) / (len1 + len2)
+ return cosine * length_harmony
def _merge_base64_audio_to_segment(base64_list: list[str]):
diff --git a/tests/helpers/stage_config.py b/tests/helpers/stage_config.py
index 29a80372ecf..8310cedceff 100644
--- a/tests/helpers/stage_config.py
+++ b/tests/helpers/stage_config.py
@@ -411,6 +411,42 @@ def delete_by_path(config_dict: dict, path: str) -> None:
},
},
},
+ "qwen3_omni_moe_multi_replicas_4gpu": {
+ "base_config": "qwen3_omni_moe.yaml",
+ "async_chunk": True,
+ "stages": [
+ {
+ "stage_id": 0,
+ "devices": "0",
+ "gpu_memory_utilization": 0.85,
+ "max_num_seqs": 6,
+ "max_model_len": 32768,
+ "mm_processor_cache_gb": 0,
+ "load_format": "dummy",
+ "default_sampling_params": {"max_tokens": 150, "ignore_eos": False},
+ },
+ {
+ "stage_id": 1,
+ "devices": "1,2,3",
+ "num_replicas": 3,
+ "gpu_memory_utilization": 0.6,
+ "max_num_seqs": 2,
+ "max_model_len": 32768,
+ "load_format": "dummy",
+ "default_sampling_params": {"max_tokens": 1000},
+ },
+ {
+ "stage_id": 2,
+ "devices": "1,2,3",
+ "num_replicas": 3,
+ "gpu_memory_utilization": 0.1,
+ "max_num_seqs": 2,
+ "max_num_batched_tokens": 65536,
+ "load_format": "dummy",
+ "default_sampling_params": {"max_tokens": 2000},
+ },
+ ],
+ },
# Single-stage thinker-only topology for the abort test.
"qwen2_5_omni_thinker_only": {
"async_chunk": False,
diff --git a/tests/model_executor/models/glm_image/test_glm_image_ar.py b/tests/model_executor/models/glm_image/test_glm_image_ar.py
new file mode 100644
index 00000000000..32a016b2a67
--- /dev/null
+++ b/tests/model_executor/models/glm_image/test_glm_image_ar.py
@@ -0,0 +1,352 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for GLM-Image AR model: DataParser, processor, and M-RoPE."""
+
+import importlib.util
+import os
+import sys
+import types
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+# ---------------------------------------------------------------------------
+# Load target classes via importlib to avoid requiring transformers.models.glm_image
+# (which may not exist in CI). This follows the same pattern as
+# tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py.
+# ---------------------------------------------------------------------------
+
+_BASE = os.path.join(
+ os.path.dirname(__file__),
+ os.pardir,
+ os.pardir,
+ os.pardir,
+ os.pardir,
+ "vllm_omni",
+ "model_executor",
+ "models",
+ "glm_image",
+)
+
+
+def _load_module(name: str, filename: str):
+ path = os.path.abspath(os.path.join(_BASE, filename))
+ spec = importlib.util.spec_from_file_location(name, path)
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod)
+ return mod
+
+
+def _build_mock_modules() -> dict[str, object]:
+ """Build the dict of modules to inject into sys.modules."""
+ # Stub transformers.models.glm_image submodules
+ glm_image_mod = types.ModuleType("transformers.models.glm_image")
+ glm_config_mod = types.ModuleType("transformers.models.glm_image.configuration_glm_image")
+ glm_config_mod.GlmImageConfig = type("GlmImageConfig", (), {})
+ glm_config_mod.GlmImageTextConfig = type("GlmImageTextConfig", (), {})
+ glm_config_mod.GlmImageVisionConfig = type("GlmImageVisionConfig", (), {})
+ glm_config_mod.GlmImageVQVAEConfig = type("GlmImageVQVAEConfig", (), {})
+ glm_proc_mod = types.ModuleType("transformers.models.glm_image.processing_glm_image")
+ glm_proc_mod.GlmImageProcessor = type("GlmImageProcessor", (), {})
+
+ # vllm_omni submodules needed by the import chain
+ vllm_omni_mod = MagicMock()
+ vllm_omni_models = types.ModuleType("vllm_omni.model_executor.models")
+ vllm_omni_glm_image_pkg = types.ModuleType("vllm_omni.model_executor.models.glm_image")
+ vllm_omni_glm_image_pkg.__path__ = [os.path.abspath(_BASE)]
+ vllm_omni_output = MagicMock()
+
+ return {
+ "transformers.models.glm_image": glm_image_mod,
+ "transformers.models.glm_image.configuration_glm_image": glm_config_mod,
+ "transformers.models.glm_image.processing_glm_image": glm_proc_mod,
+ "vllm_omni": vllm_omni_mod,
+ "vllm_omni.model_executor": types.ModuleType("vllm_omni.model_executor"),
+ "vllm_omni.model_executor.models": vllm_omni_models,
+ "vllm_omni.model_executor.models.glm_image": vllm_omni_glm_image_pkg,
+ "vllm_omni.model_executor.models.output_templates": vllm_omni_output,
+ }
+
+
+def _load_target_classes():
+ """Load the glm_image_ar module with mocked dependencies."""
+ mocks = _build_mock_modules()
+ with patch.dict(sys.modules, mocks):
+ mod = _load_module(
+ "vllm_omni.model_executor.models.glm_image.glm_image_ar",
+ "glm_image_ar.py",
+ )
+ sys.modules["vllm_omni.model_executor.models.glm_image.glm_image_ar"] = mod
+ return mod
+
+
+_ar_mod = _load_target_classes()
+
+GlmImageDataParser = _ar_mod.GlmImageDataParser
+GlmImageMultiModalProcessor = _ar_mod.GlmImageMultiModalProcessor
+GlmImageForConditionalGeneration = _ar_mod.GlmImageForConditionalGeneration
+GlmImageRotaryEmbedding = _ar_mod.GlmImageRotaryEmbedding
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+# =============================================================================
+# Helper: Minimal config for testing
+# =============================================================================
+
+
+def _make_hf_config(**overrides):
+ """Create a minimal GlmImageConfig-like object for testing."""
+ defaults = {
+ "image_token_id": 167855,
+ "image_start_token_id": 16384,
+ "image_end_token_id": 16385,
+ "grid_bos_token_id": None,
+ "grid_eos_token_id": None,
+ }
+ defaults.update(overrides)
+ from types import SimpleNamespace
+
+ return SimpleNamespace(**defaults)
+
+
+# =============================================================================
+# Tests for GlmImageDataParser
+# =============================================================================
+
+
+class TestGlmImageDataParser:
+ """Test that img2img key is normalized to image in the data parser."""
+
+ def test_img2img_normalized_to_image(self):
+ parser = GlmImageDataParser.__new__(GlmImageDataParser)
+ parser._expected_hidden_size = 4096
+ # The _get_subparsers should include img2img
+ subparsers = parser._get_subparsers()
+ assert "img2img" in subparsers
+ assert subparsers["img2img"] == parser._parse_image_data
+
+    def test_parse_mm_data_normalizes_img2img(self):
+        """``parse_mm_data`` must forward ``image`` (not ``img2img``) to its parent.
+
+        The parent class's ``parse_mm_data`` is stubbed so the test records
+        exactly what the GLM-Image parser hands over. Stubbing the method on
+        ``GlmImageDataParser`` itself would replace the code under test.
+        """
+        parser = GlmImageDataParser.__new__(GlmImageDataParser)
+        parser._expected_hidden_size = 4096
+        # NOTE(review): assumes the direct base class is the one whose
+        # parse_mm_data the override delegates to — confirm against glm_image_ar.
+        parent_cls = GlmImageDataParser.__mro__[1]
+
+        calls = []
+
+        def mock_parse(self, mm_data, *args, **kwargs):
+            calls.append(mm_data)
+            return MagicMock()
+
+        with patch.object(parent_cls, "parse_mm_data", mock_parse):
+            parser.parse_mm_data({"img2img": "fake_image"})
+
+        # Assert unconditionally: an empty ``calls`` list must fail the test
+        # rather than silently skip the checks.
+        assert calls, "parent parse_mm_data was never called"
+        assert "image" in calls[0]
+        assert "img2img" not in calls[0]
+
+
+# =============================================================================
+# Tests for _build_generation_grids
+# =============================================================================
+
+
+class TestBuildGenerationGrids:
+ """Test M-RoPE grid construction for t2i mode."""
+
+ @pytest.fixture
+ def processor(self):
+ """Create a minimal processor instance with mocked info."""
+ proc = object.__new__(GlmImageMultiModalProcessor)
+ proc.info = MagicMock()
+ return proc
+
+ def test_1024x1024(self, processor):
+ kwargs = {"target_h": 1024, "target_w": 1024}
+ grids = processor._build_generation_grids(kwargs)
+ # token_h = 32, token_w = 32
+ # ratio = 1.0, small_h = 16, small_w = 16
+ assert grids.shape == (2, 3)
+ assert grids[0].tolist() == [1, 32, 32] # large
+ assert grids[1].tolist() == [1, 16, 16] # small
+
+ def test_512x512(self, processor):
+ kwargs = {"target_h": 512, "target_w": 512}
+ grids = processor._build_generation_grids(kwargs)
+ assert grids.shape == (2, 3)
+ assert grids[0].tolist() == [1, 16, 16]
+ # small: ratio=1.0, small_h=int(sqrt(1)*16)=16, small_w=16
+ assert grids[1].tolist() == [1, 16, 16]
+
+ def test_non_square(self, processor):
+ kwargs = {"target_h": 1024, "target_w": 512}
+ grids = processor._build_generation_grids(kwargs)
+ # token_h = 32, token_w = 16, ratio = 2.0
+ # small_h = int(sqrt(2)*16) = 22, small_w = int(sqrt(0.5)*16) = 11
+ assert grids[0].tolist() == [1, 32, 16]
+ assert grids[1].tolist() == [1, 22, 11]
+
+ def test_defaults_to_1024_when_no_target(self, processor):
+ kwargs = {}
+ grids = processor._build_generation_grids(kwargs)
+ assert grids[0].tolist() == [1, 32, 32]
+
+ def test_height_width_fallback(self, processor):
+ kwargs = {"height": 512, "width": 512}
+ grids = processor._build_generation_grids(kwargs)
+ assert grids[0].tolist() == [1, 16, 16]
+
+ def test_aligned_to_factor(self, processor):
+ # 1000 not aligned to 32, should be rounded down to 992
+ kwargs = {"target_h": 1000, "target_w": 1000}
+ grids = processor._build_generation_grids(kwargs)
+ # 1000 // 32 = 31
+ assert grids[0].tolist() == [1, 31, 31]
+
+
+# =============================================================================
+# Tests for get_mrope_input_positions
+# =============================================================================
+
+
+class TestGetMropeInputPositions:
+ """Test M-RoPE position ID computation."""
+
+ @pytest.fixture
+ def model(self):
+ """Create a minimal model instance for M-RoPE testing."""
+ model = object.__new__(GlmImageForConditionalGeneration)
+ model.config = _make_hf_config()
+ return model
+
+ def test_pure_text(self, model):
+ """Pure text tokens: all 3 dimensions get same sequential positions."""
+ input_tokens = [100, 101, 102, 103]
+ positions, delta = model.get_mrope_input_positions(input_tokens)
+ assert positions.shape == (3, 4)
+ # All three dims should be [0, 1, 2, 3]
+ for dim in range(3):
+ assert positions[dim].tolist() == [0, 1, 2, 3]
+ assert delta == 0 # max(3) + 1 - seq_len(4) = 0
+
+ def test_t2i_with_target_size(self, model):
+ """t2i with explicit target_h/target_w: grids built from them."""
+ input_tokens = [100, 101, 102, 16384] # text + image_start_token_id (16384)
+ kwargs = {"target_h": 256, "target_w": 256}
+
+ positions, delta = model.get_mrope_input_positions(input_tokens, **kwargs)
+ # 256/32=8 -> grids = [[1,8,8], [1,16,16]] (small uses factor//2=16 base)
+ # Decode order (reversed): grid[-1]=[1,16,16]=256, grid[-2]=[1,8,8]=64, EOS=1
+ total_decode = 256 + 64 + 1 # 321
+ assert positions.shape == (3, 4 + total_decode)
+ # delta = max_position + 1 - seq_len
+ # Positions advance by max(h,w) per grid: max(16,16)=16, max(8,8)=8
+ # max_pos = seq_len(4) + 16 + 8 = 28, then EOS at 28
+ # delta = 28 + 1 - 4 = 25
+ assert delta == 25
+
+ def test_t2i_1024_default_grids(self, model):
+ """t2i with default 1024x1024 grids when no explicit target size."""
+ # Prompt ending with image_start_token_id but no image_end_token_id
+ input_tokens = [100, 101, 16384]
+ # No target_h/target_w, no mrope_image_grid_thw
+ # Falls back to token parsing then to default [[1,32,32], [1,16,16]]
+ positions, delta = model.get_mrope_input_positions(input_tokens)
+ assert positions.shape[0] == 3
+
+ def test_i2i_with_mrope_grid(self, model):
+ """i2i: mrope_image_grid_thw contains source + target grids."""
+ # Source image tokens: [16384, 167855*4, 16385] + text + 16384(bos)
+ source_grid = [1, 2, 2] # 2x2 = 4 image tokens
+ target_grid = [1, 32, 32] # 32x32 = 1024 tokens
+ mrope_grid = torch.tensor([source_grid, target_grid], dtype=torch.long)
+
+ # input_tokens: text + 16384 (image start) + 4 * 167855 (image token) + 16385 (image end) + 16384 (bos)
+ input_tokens = [100, 101, 16384] + [167855] * 4 + [16385, 16384]
+
+ positions, delta = model.get_mrope_input_positions(input_tokens, mrope_image_grid_thw=mrope_grid)
+
+ # 1 source image (num_complete_images=1), 1 target grid (num_decode_grids=1)
+ # Prefill covers all input tokens
+ # Decode covers: 32*32 + 1(EOS) = 1025 tokens
+ assert positions.shape[0] == 3
+
+ def test_position_delta_non_negative(self, model):
+ """mrope_position_delta should be non-negative for valid inputs."""
+ input_tokens = [100, 16384]
+ kwargs = {"target_h": 64, "target_w": 64}
+ positions, delta = model.get_mrope_input_positions(input_tokens, **kwargs)
+ assert delta >= 0
+
+
+# =============================================================================
+# Tests for GlmImageRotaryEmbedding._apply_mrope
+# =============================================================================
+
+
+class TestGlmImageRotaryEmbedding:
+ """Test M-RoPE section interleaving in the rotary embedding."""
+
+ @pytest.fixture
+ def rotary_emb(self):
+ # mrope_section=[8,12,12] sums to 32, so rotary_dim//2 must be >= 32
+ # -> head_dim=64 gives rotary_dim=64, rotary_dim//2=32
+ return GlmImageRotaryEmbedding(head_dim=64, mrope_section=[8, 12, 12])
+
+ def test_apply_mrope_shape(self, rotary_emb):
+ """Output shape matches [num_tokens, rotary_dim // 2]."""
+ freqs = torch.randn(3, 5, 32) # 3 dims, 5 tokens, rotary_dim//2=32
+ result = rotary_emb._apply_mrope(freqs)
+ assert result.shape == (5, 32)
+
+ def test_apply_mrope_interleaving(self, rotary_emb):
+ """Verify that M-RoPE correctly interleaves T/H/W sections."""
+ # mrope_section = [8, 12, 12] splits dim 32 into 3 chunks: [8, 12, 12]
+ # chunk 0 (size 8): dim 0 % 3 = 0 (temporal)
+ # chunk 1 (size 12): dim 1 % 3 = 1 (height)
+ # chunk 2 (size 12): dim 2 % 3 = 2 (width)
+ freqs = torch.ones(3, 1, 32)
+ freqs[0, :, :] = 1.0 # temporal
+ freqs[1, :, :] = 2.0 # height
+ freqs[2, :, :] = 3.0 # width
+
+ result = rotary_emb._apply_mrope(freqs)
+ assert result.shape == (1, 32)
+ assert (result[0, :8] == 1.0).all() # chunk 0: temporal
+ assert (result[0, 8:20] == 2.0).all() # chunk 1: height
+ assert (result[0, 20:32] == 3.0).all() # chunk 2: width
+
+ def test_forward_1d_positions(self, rotary_emb):
+ """Forward with 1D positions (text-only) produces correct shapes."""
+ positions = torch.arange(10) # [10]
+ q = torch.randn(10, 64)
+ k = torch.randn(10, 64)
+ q_out, k_out = rotary_emb(positions, q, k)
+ assert q_out.shape == (10, 64)
+ assert k_out.shape == (10, 64)
+
+ def test_forward_3d_positions(self, rotary_emb):
+ """Forward with 3D M-RoPE positions produces correct shapes."""
+ positions = torch.arange(30).reshape(3, 10) # [3, 10]
+ q = torch.randn(10, 64)
+ k = torch.randn(10, 64)
+ q_out, k_out = rotary_emb(positions, q, k)
+ assert q_out.shape == (10, 64)
+ assert k_out.shape == (10, 64)
+
+ def test_forward_preserves_dtype(self, rotary_emb):
+ """Output dtype matches input dtype."""
+ positions = torch.arange(5)
+ q = torch.randn(5, 64, dtype=torch.float32)
+ k = torch.randn(5, 64, dtype=torch.float32)
+ q_out, k_out = rotary_emb(positions, q, k)
+ assert q_out.dtype == torch.float32
+ assert k_out.dtype == torch.float32
diff --git a/tests/model_executor/stage_input_processors/test_glm_image.py b/tests/model_executor/stage_input_processors/test_glm_image.py
new file mode 100644
index 00000000000..88352cac248
--- /dev/null
+++ b/tests/model_executor/stage_input_processors/test_glm_image.py
@@ -0,0 +1,389 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for GLM-Image stage input processor."""
+
+from types import SimpleNamespace
+
+import pytest
+import torch
+
+from vllm_omni.model_executor.stage_input_processors.glm_image import (
+ _first_source_image,
+ _has_source_image,
+ _parse_generated_tokens,
+ _upsample_token_ids,
+ ar2diffusion,
+ compute_max_tokens,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+# =============================================================================
+# Helpers
+# =============================================================================
+
+
+def _source_output(token_ids: list[int], mm_output: dict | None = None):
+ """Create a minimal AR output mock."""
+ return SimpleNamespace(
+ outputs=[SimpleNamespace(token_ids=token_ids)],
+ multimodal_output=mm_output,
+ )
+
+
+# =============================================================================
+# Tests for _has_source_image
+# =============================================================================
+
+
+class TestHasSourceImage:
+ def test_none_input(self):
+ assert _has_source_image(None) is False
+
+ def test_non_dict_input(self):
+ assert _has_source_image("not_a_dict") is False
+
+ def test_empty_dict(self):
+ assert _has_source_image({}) is False
+
+ def test_image_key_present(self):
+ from PIL import Image
+
+ img = Image.new("RGB", (64, 64))
+ assert _has_source_image({"image": img}) is True
+
+ def test_image_key_none(self):
+ assert _has_source_image({"image": None}) is False
+
+ def test_img2img_key_present(self):
+ from PIL import Image
+
+ img = Image.new("RGB", (64, 64))
+ assert _has_source_image({"img2img": img}) is True
+
+ def test_images_key_list(self):
+ from PIL import Image
+
+ imgs = [Image.new("RGB", (64, 64))]
+ assert _has_source_image({"images": imgs}) is True
+
+ def test_images_key_empty_list(self):
+ assert _has_source_image({"images": []}) is False
+
+ def test_images_key_single(self):
+ from PIL import Image
+
+ img = Image.new("RGB", (64, 64))
+ assert _has_source_image({"images": img}) is True
+
+
+# =============================================================================
+# Tests for _first_source_image
+# =============================================================================
+
+
+class TestFirstSourceImage:
+ def test_none_input(self):
+ assert _first_source_image(None) is None
+
+ def test_non_dict_input(self):
+ assert _first_source_image("not_a_dict") is None
+
+ def test_image_key_single(self):
+ from PIL import Image
+
+ img = Image.new("RGB", (64, 64))
+ assert _first_source_image({"image": img}) is img
+
+ def test_image_key_list(self):
+ from PIL import Image
+
+ img = Image.new("RGB", (64, 64))
+ assert _first_source_image({"image": [img]}) is img
+
+ def test_image_key_empty_list(self):
+ assert _first_source_image({"image": []}) is None
+
+ def test_img2img_key_single(self):
+ from PIL import Image
+
+ img = Image.new("RGB", (64, 64))
+ assert _first_source_image({"img2img": img}) is img
+
+ def test_images_key_list(self):
+ from PIL import Image
+
+ imgs = [Image.new("RGB", (64, 64))]
+ assert _first_source_image({"images": imgs}) is imgs[0]
+
+ def test_images_key_empty_list(self):
+ assert _first_source_image({"images": []}) is None
+
+ def test_images_key_single_not_list(self):
+ from PIL import Image
+
+ img = Image.new("RGB", (64, 64))
+ assert _first_source_image({"images": img}) is img
+
+
+# =============================================================================
+# Tests for compute_max_tokens
+# =============================================================================
+
+
+class TestComputeMaxTokens:
+ def test_t2i_1024x1024(self):
+ # t2i: small_tokens + large_tokens + 1 (EOS)
+ # token_h = 1024/32 = 32, token_w = 1024/32 = 32
+ # large = 32*32 = 1024
+ # ratio = 1.0, small_h = sqrt(1)*16 = 16, small_w = sqrt(1)*16 = 16, small = 256
+ # total = 256 + 1024 + 1 = 1281
+ result = compute_max_tokens(1024, 1024, is_i2i=False)
+ assert result == 1281
+
+ def test_i2i_1024x1024(self):
+ # i2i: large_tokens + 1 (EOS)
+ # large = 32*32 = 1024, total = 1025
+ result = compute_max_tokens(1024, 1024, is_i2i=True)
+ assert result == 1025
+
+ def test_t2i_512x512(self):
+ # token_h = 16, token_w = 16, large = 256
+ # ratio = 1.0, small_h = 16, small_w = 16, small = 256
+ # total = 256 + 256 + 1 = 513
+ result = compute_max_tokens(512, 512, is_i2i=False)
+ assert result == 513
+
+ def test_i2i_512x512(self):
+ # large = 256, total = 257
+ result = compute_max_tokens(512, 512, is_i2i=True)
+ assert result == 257
+
+ def test_non_square_t2i(self):
+ # 1024x512: token_h=32, token_w=16, large=512
+ # ratio = 32/16 = 2.0
+ # small_h = max(1, int(sqrt(2)*16)) = 22, small_w = max(1, int(sqrt(0.5)*16)) = 11
+ # small = 22*11 = 242
+ # total = 242 + 512 + 1 = 755
+ result = compute_max_tokens(1024, 512, is_i2i=False)
+ assert result == 242 + 512 + 1
+
+ def test_custom_factor(self):
+ # factor=16, 512x512: token_h=32, token_w=32, large=1024
+ # ratio=1.0, small_h=8, small_w=8, small=64
+ # total = 64 + 1024 + 1 = 1089
+ result = compute_max_tokens(512, 512, factor=16, is_i2i=False)
+ assert result == 1089
+
+ def test_i2i_smaller_than_t2i(self):
+ t2i = compute_max_tokens(1024, 1024, is_i2i=False)
+ i2i = compute_max_tokens(1024, 1024, is_i2i=True)
+ assert i2i < t2i
+
+
+# =============================================================================
+# Tests for _upsample_token_ids
+# =============================================================================
+
+
+class TestUpsampleTokenIds:
+ def test_2x2_to_4x4(self):
+ tokens = torch.tensor([1, 2, 3, 4])
+ result = _upsample_token_ids(tokens, 2, 2)
+ assert result.shape == (16,) # 4 * 4 = 16 (2x each dim)
+
+ def test_1x1_to_2x2(self):
+ tokens = torch.tensor([7])
+ result = _upsample_token_ids(tokens, 1, 1)
+ assert result.shape == (4,) # 2 * 2
+ assert (result == 7).all()
+
+ def test_4x4_to_8x8(self):
+ tokens = torch.arange(16, dtype=torch.long)
+ result = _upsample_token_ids(tokens, 4, 4)
+ assert result.shape == (64,)
+
+ def test_preserves_dtype(self):
+ tokens = torch.tensor([1, 2, 3, 4], dtype=torch.long)
+ result = _upsample_token_ids(tokens, 2, 2)
+ assert result.dtype == torch.long
+
+
+# =============================================================================
+# Tests for _parse_generated_tokens
+# =============================================================================
+
+
+class TestParseGeneratedTokens:
+ def test_t2i_standard(self):
+ # 1024x1024, t2i: small(256) + large(1024) + EOS
+ # Generate 256 + 1024 + 1 = 1281 tokens, last is EOS (16385)
+ large_tokens = list(range(1024))
+ small_tokens = list(range(1000, 1256))
+ eos = [16385]
+ token_ids = small_tokens + large_tokens + eos
+
+ prior, h, w = _parse_generated_tokens(token_ids, 1024, 1024, is_i2i=False)
+ assert h == 1024
+ assert w == 1024
+ # Prior tokens should be upsampled: 1024 tokens -> 4*1024 = 4096
+ assert prior.shape[0] == 1024 * 4
+
+ def test_i2i_standard(self):
+ # 1024x1024, i2i: large(1024) + EOS
+ large_tokens = list(range(1024))
+ eos = [16385]
+ token_ids = large_tokens + eos
+
+ prior, h, w = _parse_generated_tokens(token_ids, 1024, 1024, is_i2i=True)
+ assert h == 1024
+ assert w == 1024
+ assert prior.shape[0] == 1024 * 4
+
+ def test_i2i_without_eos(self):
+ # i2i without EOS marker
+ large_tokens = list(range(1024))
+ prior, h, w = _parse_generated_tokens(large_tokens, 1024, 1024, is_i2i=True)
+ assert h == 1024
+ assert w == 1024
+
+ def test_i2i_too_few_tokens_raises(self):
+ with pytest.raises(ValueError, match="i2i token parse failed"):
+ _parse_generated_tokens([1, 2, 3], 1024, 1024, is_i2i=True)
+
+ def test_t2i_too_few_tokens_raises(self):
+ # Only large tokens, no small preview
+ large_tokens = list(range(1024))
+ with pytest.raises(ValueError, match="t2i token parse failed"):
+ _parse_generated_tokens(large_tokens, 1024, 1024, is_i2i=False)
+
+ def test_i2i_t2i_style_layout_fallback(self):
+ # i2i but got t2i-style (small + large) tokens
+ small_tokens = list(range(256))
+ large_tokens = list(range(1024))
+ token_ids = small_tokens + large_tokens
+
+ prior, h, w = _parse_generated_tokens(token_ids, 1024, 1024, is_i2i=True)
+ # Should extract the large portion
+ assert h == 1024
+ assert w == 1024
+
+
+# =============================================================================
+# Tests for ar2diffusion
+# =============================================================================
+
+
+class TestAr2Diffusion:
+ def test_basic_t2i(self):
+ """Test basic text-to-image pipeline: AR -> Diffusion."""
+ # 1024x1024 t2i: small(256) + large(1024) + EOS
+ token_ids = list(range(256)) + list(range(1024)) + [16385]
+ source_outputs = [_source_output(token_ids)]
+
+ prompt = {"prompt": "a cat", "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024}}
+
+ result = ar2diffusion(source_outputs, prompt=[prompt])
+ assert len(result) == 1
+ assert result[0]["prompt"] == "a cat"
+ assert result[0]["height"] == 1024
+ assert result[0]["width"] == 1024
+ assert "prior_token_ids" in result[0]["extra"]
+
+ def test_i2i_with_mm_output(self):
+ """Test image-to-image with prior_token_image_ids from AR model."""
+ token_ids = list(range(1024)) + [16385]
+ mm_output = {"prior_token_image_ids": torch.tensor([1, 2, 3])}
+ source_outputs = [_source_output(token_ids, mm_output)]
+
+ from PIL import Image
+
+ img = Image.new("RGB", (64, 64))
+ prompt = {
+ "prompt": "edit this",
+ "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024},
+ "multi_modal_data": {"image": img},
+ }
+
+ result = ar2diffusion(source_outputs, prompt=[prompt])
+ assert len(result) == 1
+ assert result[0]["extra"]["prior_token_image_ids"] is not None
+
+ def test_i2i_detected_via_modalities(self):
+ """Test i2i mode detected via modalities field."""
+ token_ids = list(range(1024)) + [16385]
+ source_outputs = [_source_output(token_ids)]
+
+ prompt = {
+ "prompt": "edit this",
+ "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024},
+ "modalities": ["img2img"],
+ }
+
+ result = ar2diffusion(source_outputs, prompt=[prompt])
+ assert len(result) == 1
+
+ def test_empty_source_outputs_returns_empty_list(self):
+ assert ar2diffusion([], prompt={}) == []
+
+ def test_default_dimensions(self):
+ """When no height/width in prompt, defaults to 1024x1024."""
+ token_ids = list(range(256)) + list(range(1024)) + [16385]
+ source_outputs = [_source_output(token_ids)]
+
+ prompt = {"prompt": "test"}
+ result = ar2diffusion(source_outputs, prompt=[prompt])
+ assert result[0]["height"] == 1024
+ assert result[0]["width"] == 1024
+
+ def test_requires_multimodal_data_with_pil_image(self):
+ """Test that pil_image is included when requires_multimodal_data=True."""
+ token_ids = list(range(256)) + list(range(1024)) + [16385]
+ source_outputs = [_source_output(token_ids)]
+
+ from PIL import Image
+
+ img = Image.new("RGB", (64, 64))
+ prompt = {
+ "prompt": "test",
+ "multi_modal_data": {"image": img},
+ }
+
+ result = ar2diffusion(source_outputs, prompt=[prompt], requires_multimodal_data=True)
+ assert result[0]["pil_image"] is img
+
+ def test_extra_params_passed_through(self):
+ """Test that seed, num_inference_steps, guidance_scale, negative_prompt are passed."""
+ token_ids = list(range(256)) + list(range(1024)) + [16385]
+ source_outputs = [_source_output(token_ids)]
+
+ prompt = {
+ "prompt": "test",
+ "seed": 42,
+ "num_inference_steps": 50,
+ "guidance_scale": 7.5,
+ "negative_prompt": "blurry",
+ }
+
+ result = ar2diffusion(source_outputs, prompt=[prompt])
+ assert result[0]["seed"] == 42
+ assert result[0]["num_inference_steps"] == 50
+ assert result[0]["guidance_scale"] == 7.5
+ assert result[0]["negative_prompt"] == "blurry"
+
+ def test_batch_requests(self):
+ """Test processing multiple requests in a batch."""
+ tokens1 = list(range(256)) + list(range(1024)) + [16385]
+ tokens2 = list(range(256)) + list(range(1024)) + [16385]
+ source_outputs = [_source_output(tokens1), _source_output(tokens2)]
+
+ prompts = [
+ {"prompt": "first", "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024}},
+ {"prompt": "second", "mm_processor_kwargs": {"target_h": 512, "target_w": 512}},
+ ]
+
+ result = ar2diffusion(source_outputs, prompt=prompts)
+ assert len(result) == 2
+ assert result[0]["prompt"] == "first"
+ assert result[1]["prompt"] == "second"
diff --git a/tests/model_executor/stage_input_processors/test_mimo_audio_llm2code2wav.py b/tests/model_executor/stage_input_processors/test_mimo_audio_llm2code2wav.py
new file mode 100644
index 00000000000..1ea0ccfa708
--- /dev/null
+++ b/tests/model_executor/stage_input_processors/test_mimo_audio_llm2code2wav.py
@@ -0,0 +1,71 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+import logging
+from types import SimpleNamespace
+
+import pytest
+import torch
+
+from vllm_omni.model_executor.stage_input_processors import mimo_audio as sip
+from vllm_omni.model_executor.stage_input_processors.mimo_audio import (
+ MAX_CODE2WAV_TOKENS,
+ llm2code2wav,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+def _make_source_outputs(codec_codes: torch.Tensor, request_id: str = "req-0"):
+ """Build minimal source_outputs with one talker output carrying codec_codes."""
+ output = SimpleNamespace(multimodal_output={"code_predictor_codes": codec_codes})
+ talker_output = SimpleNamespace(outputs=[output], request_id=request_id)
+ return [talker_output]
+
+
+def test_llm2code2wav_truncates_when_flat_exceeds_max(caplog):
+ """Flat codec sequences longer than MAX_CODE2WAV_TOKENS must be truncated, not passed through."""
+ # prepend_and_flatten_colmajor produces 36 ids per (8, 4) codec frame:
+ # pad adds one row -> (9, 4) per frame, permuted and flattened.
+ # Pick enough frames to comfortably exceed the cap.
+ frames = (MAX_CODE2WAV_TOKENS // 36) + 100
+ codec_codes = torch.ones(frames, 1, 8, 4, dtype=torch.long)
+
+ source_outputs = _make_source_outputs(codec_codes, request_id="req-long")
+
+ # Attach caplog's handler directly to the module logger so the warning is
+ # captured regardless of propagation (vllm's logger configuration can
+ # interact badly with caplog.at_level's default root-handler path).
+ target_logger = logging.getLogger("vllm_omni.model_executor.stage_input_processors.mimo_audio")
+ target_logger.addHandler(caplog.handler)
+ prev_level = target_logger.level
+ target_logger.setLevel(logging.WARNING)
+ try:
+ prompts = llm2code2wav(source_outputs)
+ finally:
+ target_logger.removeHandler(caplog.handler)
+ target_logger.setLevel(prev_level)
+
+ assert len(prompts) == 1
+ assert len(prompts[0]["prompt_token_ids"]) == MAX_CODE2WAV_TOKENS
+ assert any("truncating" in rec.getMessage() for rec in caplog.records), (
+ f"Expected a 'truncating' warning; captured records: {[r.getMessage() for r in caplog.records]}"
+ )
+
+
+def test_llm2code2wav_short_sequence_unchanged():
+ """Short codec sequences are returned without truncation."""
+ codec_codes = torch.ones(4, 1, 8, 4, dtype=torch.long)
+ source_outputs = _make_source_outputs(codec_codes, request_id="req-short")
+
+ prompts = llm2code2wav(source_outputs)
+
+ assert len(prompts) == 1
+ # 4 frames, each padded with one row and flattened col-major → well below the cap
+ assert 0 < len(prompts[0]["prompt_token_ids"]) <= MAX_CODE2WAV_TOKENS
+
+
+def test_llm2code2wav_truncation_boundary_constant_matches_yaml():
+ """MAX_CODE2WAV_TOKENS must match the stage-1 max_model_len in mimo_audio.yaml and end2end.py."""
+ assert sip.MAX_CODE2WAV_TOKENS == 18192
diff --git a/tests/profile/test_omni_torch_profiler.py b/tests/profile/test_omni_torch_profiler.py
new file mode 100644
index 00000000000..3920078af4d
--- /dev/null
+++ b/tests/profile/test_omni_torch_profiler.py
@@ -0,0 +1,582 @@
+# tests/profile/test_omni_torch_profiler.py
+from __future__ import annotations
+
+import gzip
+from dataclasses import dataclass
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+from openpyxl import load_workbook
+
+import vllm_omni.profiler.omni_torch_profiler as profiler_mod
+from vllm_omni.profiler.omni_torch_profiler import OmniTorchProfilerWrapper
+
+
+@pytest.fixture(autouse=True)
+def patch_worker_profiler_init(monkeypatch):
+ def fake_init(self, profiler_config):
+ self.profiler_config = profiler_config
+
+ monkeypatch.setattr(
+ profiler_mod.WorkerProfiler,
+ "__init__",
+ fake_init,
+ )
+
+
+@dataclass
+class DummyProfilerConfig:
+ torch_profiler_dir: str
+ torch_profiler_use_gzip: bool = False
+ torch_profiler_record_shapes: bool = True
+ torch_profiler_with_memory: bool = True
+ torch_profiler_with_stack: bool = True
+ torch_profiler_with_flops: bool = False
+ torch_profiler_dump_cuda_time_total: bool = False
+
+
+class FakeEvent:
+ def __init__(
+ self,
+ *,
+ name: str = "aten::mm",
+ count: int = 1,
+ input_shapes=None,
+ stack=None,
+ self_cpu_time_total: float = 10.0,
+ cpu_time_total: float = 12.0,
+ self_cuda_time_total: float = 20.0,
+ cuda_time_total: float = 25.0,
+ self_xpu_time_total: float = 0.0,
+ xpu_time_total: float = 0.0,
+ self_cpu_memory_usage: int = 128,
+ cpu_memory_usage: int = 256,
+ self_cuda_memory_usage: int = 1024,
+ cuda_memory_usage: int = 2048,
+ self_xpu_memory_usage: int = 0,
+ xpu_memory_usage: int = 0,
+ device_type: str = "CUDA",
+ node_id: int = 0,
+ overload_name: str = "",
+ is_async: bool = False,
+ is_legacy: bool = False,
+ ):
+ self.key = name
+ self.name = name
+ self.count = count
+ self.input_shapes = input_shapes if input_shapes is not None else [[2, 2], [2, 2]]
+ self.stack = stack if stack is not None else ["frame_a", "frame_b"]
+ self.self_cpu_time_total = self_cpu_time_total
+ self.cpu_time_total = cpu_time_total
+ self.self_cuda_time_total = self_cuda_time_total
+ self.cuda_time_total = cuda_time_total
+ self.self_xpu_time_total = self_xpu_time_total
+ self.xpu_time_total = xpu_time_total
+ self.self_cpu_memory_usage = self_cpu_memory_usage
+ self.cpu_memory_usage = cpu_memory_usage
+ self.self_cuda_memory_usage = self_cuda_memory_usage
+ self.cuda_memory_usage = cuda_memory_usage
+ self.self_xpu_memory_usage = self_xpu_memory_usage
+ self.xpu_memory_usage = xpu_memory_usage
+ self.device_type = device_type
+ self.node_id = node_id
+ self.overload_name = overload_name
+ self.is_async = is_async
+ self.is_legacy = is_legacy
+
+
+class FakeEventList(list):
+ def table(self, sort_by=None, row_limit=-1):
+ return f"fake_table(sort_by={sort_by}, row_limit={row_limit}, len={len(self)})"
+
+
+class FakeTorchProfiler:
+ def __init__(self, on_trace_ready=None):
+ self.started = False
+ self.stopped = False
+ self.on_trace_ready = on_trace_ready
+ self.exported_traces = []
+ self.exported_stacks = []
+
+ def start(self):
+ self.started = True
+
+ def stop(self):
+ self.stopped = True
+ if self.on_trace_ready is not None:
+ self.on_trace_ready(self)
+
+ def export_chrome_trace(self, path):
+ Path(path).write_text('{"traceEvents": []}')
+ self.exported_traces.append(path)
+
+ def export_stacks(self, path, metric):
+ Path(path).write_text(f"metric={metric}\nstack_line_1\nstack_line_2\n")
+ self.exported_stacks.append((path, metric))
+
+ def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
+ if group_by_input_shape:
+ return FakeEventList(
+ [
+ FakeEvent(
+ name="aten::bmm",
+ input_shapes=[[4, 8, 16], [4, 16, 32]],
+ )
+ ]
+ )
+ if group_by_stack_n:
+ return FakeEventList(
+ [
+ FakeEvent(
+ name="aten::all_reduce",
+ stack=["python_a", "python_b", "python_c"],
+ )
+ ]
+ )
+ return FakeEventList(
+ [
+ FakeEvent(name="aten::mm"),
+ FakeEvent(name="nccl:all_reduce"),
+ ]
+ )
+
+
+@pytest.fixture
+def fake_config(tmp_path):
+ return DummyProfilerConfig(torch_profiler_dir=str(tmp_path))
+
+
+@pytest.fixture
+def fake_profiler_factory(monkeypatch):
+ created = {}
+
+ def fake_profile(*args, **kwargs):
+ profiler = FakeTorchProfiler(on_trace_ready=kwargs.get("on_trace_ready"))
+ created["profiler"] = profiler
+ created["args"] = args
+ created["kwargs"] = kwargs
+ return profiler
+
+ monkeypatch.setattr(profiler_mod.torch.profiler, "profile", fake_profile)
+ return created
+
+
+@pytest.fixture
+def wrapper(fake_config, fake_profiler_factory):
+ return OmniTorchProfilerWrapper(
+ profiler_config=fake_config,
+ worker_name="worker0",
+ local_rank=0,
+ activities=["CPU", "CUDA"],
+ )
+
+
+def test_set_trace_filename_creates_timestamped_session_dir(wrapper, monkeypatch, tmp_path):
+ class FixedDatetime:
+ @classmethod
+ def now(cls):
+ class _Now:
+ def strftime(self, fmt):
+ return "20260403-034200"
+
+ return _Now()
+
+ monkeypatch.setattr(profiler_mod, "datetime", FixedDatetime)
+
+ wrapper.set_trace_filename("stage_0_llm_1234567890")
+
+ session_dir = Path(wrapper._session_dir)
+ assert session_dir.exists()
+ assert session_dir.parent == tmp_path
+ assert session_dir.name == "20260403-034200_stage_0_llm_1234567890"
+
+
+def test_set_trace_filename_with_full_path_creates_timestamped_leaf(wrapper, monkeypatch, tmp_path):
+ class FixedDatetime:
+ @classmethod
+ def now(cls):
+ class _Now:
+ def strftime(self, fmt):
+ return "20260403-111111"
+
+ return _Now()
+
+ monkeypatch.setattr(profiler_mod, "datetime", FixedDatetime)
+
+ target = tmp_path / "nested" / "stage_x"
+ wrapper.set_trace_filename(str(target))
+
+ session_dir = Path(wrapper._session_dir)
+ assert session_dir.exists()
+ assert session_dir.parent == target.parent
+ assert session_dir.name == "20260403-111111_stage_x"
+
+
+def test_on_trace_ready_exports_trace_json(wrapper):
+ wrapper.set_trace_filename("case_trace")
+
+ wrapper._on_trace_ready(wrapper.profiler)
+
+ trace_path = Path(wrapper._trace_path)
+ assert trace_path.exists()
+ assert trace_path.name == "trace_rank0.json"
+ assert trace_path.read_text() == '{"traceEvents": []}'
+
+
+def test_on_trace_ready_exports_gzip_trace(fake_config, fake_profiler_factory, monkeypatch):
+ fake_config.torch_profiler_use_gzip = True
+
+ wrapper = OmniTorchProfilerWrapper(
+ profiler_config=fake_config,
+ worker_name="worker0",
+ local_rank=0,
+ activities=["CPU", "CUDA"],
+ )
+ wrapper.set_trace_filename("case_gzip")
+
+ def fake_popen(cmd):
+ assert cmd[:2] == ["gzip", "-f"]
+ src = Path(cmd[2])
+ gz_path = src.with_suffix(src.suffix + ".gz")
+ gz_path.write_bytes(gzip.compress(src.read_bytes()))
+ src.unlink()
+
+ class DummyProc:
+ pass
+
+ return DummyProc()
+
+ monkeypatch.setattr(profiler_mod.subprocess, "Popen", fake_popen)
+
+ wrapper._on_trace_ready(wrapper.profiler)
+
+ assert wrapper._trace_path.endswith(".json.gz")
+ gz_path = Path(wrapper._trace_path)
+ assert gz_path.exists()
+ assert gzip.decompress(gz_path.read_bytes()) == b'{"traceEvents": []}'
+
+
+def test_start_enables_memory_history(wrapper, monkeypatch):
+ calls = []
+
+ monkeypatch.setattr(profiler_mod.torch.cuda, "is_available", lambda: True)
+
+ def fake_record_memory_history(*args, **kwargs):
+ calls.append((args, kwargs))
+
+ monkeypatch.setattr(
+ profiler_mod.torch.cuda.memory,
+ "_record_memory_history",
+ fake_record_memory_history,
+ )
+
+ wrapper.set_trace_filename("case_memory_start")
+ wrapper._start()
+
+ assert wrapper.profiler.started is True
+ assert wrapper._memory_history_enabled is True
+ assert len(calls) == 1
+ assert calls[0][1]["enabled"] == "all"
+ assert calls[0][1]["context"] == "all"
+ assert calls[0][1]["stacks"] == "python"
+ assert calls[0][1]["max_entries"] == 100000
+ assert calls[0][1]["clear_history"] is True
+
+
+def test_start_skips_memory_history_when_memory_disabled(fake_config, fake_profiler_factory, monkeypatch):
+ fake_config.torch_profiler_with_memory = False
+
+ wrapper = OmniTorchProfilerWrapper(
+ profiler_config=fake_config,
+ worker_name="worker0",
+ local_rank=0,
+ activities=["CPU", "CUDA"],
+ )
+
+ called = {"n": 0}
+
+ monkeypatch.setattr(profiler_mod.torch.cuda, "is_available", lambda: True)
+
+ def fake_record_memory_history(*args, **kwargs):
+ called["n"] += 1
+
+ monkeypatch.setattr(
+ profiler_mod.torch.cuda.memory,
+ "_record_memory_history",
+ fake_record_memory_history,
+ )
+
+ wrapper.set_trace_filename("case_skip_memory")
+ wrapper._start()
+
+ assert called["n"] == 0
+ assert wrapper._memory_history_enabled is False
+
+
+def test_try_dump_memory_snapshot_writes_pickle(wrapper, monkeypatch):
+ wrapper.set_trace_filename("case_snapshot")
+ wrapper._memory_history_enabled = True
+ wrapper._memory_history_backend = "CUDA"
+ wrapper._memory_history_module = profiler_mod.torch.cuda.memory
+
+ disable_calls = []
+
+ def fake_record_memory_history(*args, **kwargs):
+ disable_calls.append((args, kwargs))
+
+ def fake_dump_snapshot(path):
+ Path(path).write_bytes(b"fake pickle bytes")
+
+ monkeypatch.setattr(
+ profiler_mod.torch.cuda.memory,
+ "_record_memory_history",
+ fake_record_memory_history,
+ )
+ monkeypatch.setattr(
+ profiler_mod.torch.cuda.memory,
+ "_dump_snapshot",
+ fake_dump_snapshot,
+ )
+
+ wrapper._try_dump_memory_snapshot()
+
+ snapshot = Path(wrapper._artifact_paths["memory_snapshot"])
+ assert snapshot.exists()
+ assert snapshot.name == "memory_snapshot_rank0.pickle"
+ assert snapshot.read_bytes() == b"fake pickle bytes"
+ assert wrapper._memory_history_enabled is False
+
+ assert disable_calls[-1][1]["enabled"] is None
+
+
+def test_stop_always_dumps_memory_snapshot_on_success_path(wrapper, monkeypatch):
+ wrapper.set_trace_filename("case_stop")
+
+ record_calls = []
+ dump_calls = []
+
+ monkeypatch.setattr(profiler_mod.torch.cuda, "is_available", lambda: True)
+
+ def fake_record_memory_history(*args, **kwargs):
+ record_calls.append((args, kwargs))
+
+ def fake_dump_snapshot(path):
+ dump_calls.append(path)
+ Path(path).write_bytes(b"snapshot-bytes")
+
+ monkeypatch.setattr(
+ profiler_mod.torch.cuda.memory,
+ "_record_memory_history",
+ fake_record_memory_history,
+ )
+ monkeypatch.setattr(
+ profiler_mod.torch.cuda.memory,
+ "_dump_snapshot",
+ fake_dump_snapshot,
+ )
+
+ wrapper._start()
+ wrapper._stop()
+
+ session_dir = Path(wrapper._session_dir)
+
+ assert wrapper.profiler.started is True
+ assert wrapper.profiler.stopped is True
+ assert (session_dir / "memory_snapshot_rank0.pickle").exists()
+ assert len(dump_calls) == 1
+ assert record_calls[0][1]["enabled"] == "all"
+ assert record_calls[-1][1]["enabled"] is None
+
+
+def test_on_stop_hook_generates_stack_and_excel_artifacts(wrapper):
+ # _on_stop_hook() must write the per-rank stack text files and the ops
+ # workbook, and must NOT write any of the ops_*_rank0.txt table files.
+ wrapper.set_trace_filename("case_artifacts")
+ wrapper._on_stop_hook()
+
+ session_dir = Path(wrapper._session_dir)
+
+ assert not (session_dir / "ops_summary_rank0.txt").exists()
+ assert not (session_dir / "ops_by_shape_rank0.txt").exists()
+ assert not (session_dir / "ops_by_stack_rank0.txt").exists()
+ assert (session_dir / "stacks_cpu_rank0.txt").exists()
+ assert (session_dir / "stacks_cuda_rank0.txt").exists()
+ assert (session_dir / "ops_rank0.xlsx").exists()
+
+
+def test_excel_contains_expected_sheets(wrapper):
+ # The workbook written by _on_stop_hook() must contain the "summary",
+ # "by_shape" and "by_stack" sheets.
+ wrapper.set_trace_filename("case_excel")
+ wrapper._on_stop_hook()
+
+ xlsx_path = Path(wrapper._session_dir) / "ops_rank0.xlsx"
+ wb = load_workbook(xlsx_path)
+
+ assert "summary" in wb.sheetnames
+ assert "by_shape" in wb.sheetnames
+ assert "by_stack" in wb.sheetnames
+
+
+def test_excel_summary_has_expected_columns(wrapper):
+ # The first row of the "summary" sheet must carry every expected column header.
+ wrapper.set_trace_filename("case_excel_columns")
+ wrapper._on_stop_hook()
+
+ xlsx_path = Path(wrapper._session_dir) / "ops_rank0.xlsx"
+ wb = load_workbook(xlsx_path)
+ ws = wb["summary"]
+
+ # Read only row 1 (the header row) and collect its cell values.
+ headers = [cell.value for cell in next(ws.iter_rows(min_row=1, max_row=1))]
+ assert "name" in headers
+ assert "count" in headers
+ assert "self_cpu_time_total_us" in headers
+ assert "self_cuda_time_total_us" in headers
+ assert "self_cpu_memory_usage_bytes" in headers
+ assert "self_cuda_memory_usage_bytes" in headers
+ assert "input_shapes" in headers
+ assert "stack" in headers
+
+
+def test_get_results_returns_all_artifact_paths(wrapper, monkeypatch):
+ # After a _start()/_stop() cycle, get_results() must expose the trace, table,
+ # session_dir, ops and memory_snapshot keys, and the session_dir, ops, table
+ # and memory_snapshot paths must exist on disk.
+ wrapper.set_trace_filename("case_results")
+
+ # Report CUDA as available and stub the memory-history hooks; the dump stub
+ # writes a real file so the existence check at the end can pass.
+ monkeypatch.setattr(profiler_mod.torch.cuda, "is_available", lambda: True)
+ monkeypatch.setattr(
+ profiler_mod.torch.cuda.memory,
+ "_record_memory_history",
+ lambda *args, **kwargs: None,
+ )
+ monkeypatch.setattr(
+ profiler_mod.torch.cuda.memory,
+ "_dump_snapshot",
+ lambda path: Path(path).write_bytes(b"snapshot"),
+ )
+
+ wrapper._start()
+ wrapper._stop()
+
+ results = wrapper.get_results()
+
+ assert "trace" in results
+ assert "table" in results
+ assert "session_dir" in results
+ assert "ops" in results
+ assert "memory_snapshot" in results
+ assert Path(results["session_dir"]).exists()
+ assert Path(results["ops"]).exists()
+ assert Path(results["table"]).exists()
+ # The "table" entry is expected to point at the Excel workbook.
+ assert Path(results["table"]).name == "ops_rank0.xlsx"
+ assert Path(results["memory_snapshot"]).exists()
+
+
+def test_start_uses_xpu_memory_history_when_available(wrapper, monkeypatch):
+ # When the backend resolver returns ("XPU", module), _start() must remember
+ # that backend/module on the wrapper and call the module's
+ # _record_memory_history with enabled="all".
+ calls = []
+
+ def fake_record_memory_history(*args, **kwargs):
+ calls.append((args, kwargs))
+
+ fake_memory_module = SimpleNamespace(
+ _record_memory_history=fake_record_memory_history,
+ )
+ # Force the resolver result instead of depending on real XPU hardware.
+ monkeypatch.setattr(
+ wrapper,
+ "_resolve_memory_history_backend",
+ lambda: ("XPU", fake_memory_module),
+ )
+
+ wrapper.set_trace_filename("case_xpu_memory_start")
+ wrapper._start()
+
+ assert wrapper._memory_history_enabled is True
+ assert wrapper._memory_history_backend == "XPU"
+ assert wrapper._memory_history_module is fake_memory_module
+ assert calls[0][1]["enabled"] == "all"
+
+
+def test_start_uses_npu_memory_history_when_available(wrapper, monkeypatch):
+ # Same contract as the XPU variant of this test, with the resolver
+ # returning ("NPU", module): the backend/module are stored on the wrapper
+ # and _record_memory_history is called with enabled="all".
+ calls = []
+
+ def fake_record_memory_history(*args, **kwargs):
+ calls.append((args, kwargs))
+
+ fake_memory_module = SimpleNamespace(
+ _record_memory_history=fake_record_memory_history,
+ )
+ # Force the resolver result instead of depending on real NPU hardware.
+ monkeypatch.setattr(
+ wrapper,
+ "_resolve_memory_history_backend",
+ lambda: ("NPU", fake_memory_module),
+ )
+
+ wrapper.set_trace_filename("case_npu_memory_start")
+ wrapper._start()
+
+ assert wrapper._memory_history_enabled is True
+ assert wrapper._memory_history_backend == "NPU"
+ assert wrapper._memory_history_module is fake_memory_module
+ assert calls[0][1]["enabled"] == "all"
+
+
+def test_start_skips_memory_history_when_backend_api_missing(wrapper, monkeypatch):
+ # If the resolved backend module has no _record_memory_history attribute,
+ # _start() must leave memory history disabled and record no backend/module.
+ # An empty namespace stands in for a backend module lacking the API.
+ fake_memory_module = SimpleNamespace()
+ monkeypatch.setattr(
+ wrapper,
+ "_resolve_memory_history_backend",
+ lambda: ("XPU", fake_memory_module),
+ )
+
+ wrapper.set_trace_filename("case_missing_memory_api")
+ wrapper._start()
+
+ assert wrapper._memory_history_enabled is False
+ assert wrapper._memory_history_backend is None
+ assert wrapper._memory_history_module is None
+
+
+def test_try_dump_memory_snapshot_uses_resolved_backend_module(wrapper):
+ # With history already marked enabled on a non-CUDA backend module,
+ # _try_dump_memory_snapshot() must dump through that module, end with a
+ # _record_memory_history(enabled=None) call, and reset the wrapper's
+ # memory-history state.
+ wrapper.set_trace_filename("case_xpu_snapshot")
+ # Set the state directly rather than going through _start().
+ wrapper._memory_history_enabled = True
+ wrapper._memory_history_backend = "XPU"
+
+ calls = []
+
+ def fake_record_memory_history(*args, **kwargs):
+ calls.append((args, kwargs))
+
+ def fake_dump_snapshot(path):
+ Path(path).write_bytes(b"xpu snapshot bytes")
+
+ wrapper._memory_history_module = SimpleNamespace(
+ _record_memory_history=fake_record_memory_history,
+ _dump_snapshot=fake_dump_snapshot,
+ )
+
+ wrapper._try_dump_memory_snapshot()
+
+ snapshot = Path(wrapper._artifact_paths["memory_snapshot"])
+ assert snapshot.exists()
+ # The bytes on disk prove the fake module's dump function was the one used.
+ assert snapshot.read_bytes() == b"xpu snapshot bytes"
+ assert calls[-1][1]["enabled"] is None
+ assert wrapper._memory_history_enabled is False
+ assert wrapper._memory_history_backend is None
+ assert wrapper._memory_history_module is None
+
+
+def test_event_list_to_rows_contains_expected_fields(wrapper):
+ # _event_list_to_rows() must turn one profiler event into one row dict with
+ # the name, count, timing, memory, input-shape and stack fields.
+ # NOTE(review): the expected 10.0 / 20.0 / 128 / 1024 values are presumably
+ # FakeEvent's defaults, which are defined outside this view - confirm.
+ rows = wrapper._event_list_to_rows(
+ [
+ FakeEvent(
+ name="aten::linear",
+ input_shapes=[[8, 16], [16, 32]],
+ stack=["f1", "f2"],
+ )
+ ]
+ )
+
+ assert len(rows) == 1
+ row = rows[0]
+ assert row["name"] == "aten::linear"
+ assert row["count"] == 1
+ assert row["self_cpu_time_total_us"] == 10.0
+ assert row["self_cuda_time_total_us"] == 20.0
+ assert row["self_cpu_memory_usage_bytes"] == 128
+ assert row["self_cuda_memory_usage_bytes"] == 1024
+ # Shapes are expected as a string, and stack frames joined by newlines.
+ assert "[[8, 16], [16, 32]]" == row["input_shapes"]
+ assert row["stack"] == "f1\nf2"
diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py
index 4a271426ff3..778314f7a90 100644
--- a/tests/test_config_factory.py
+++ b/tests/test_config_factory.py
@@ -961,6 +961,36 @@ def test_per_stage_model_arch_flows_through_merge(self, tmp_path):
# Stage 1 uses its per-stage override
assert stages[1].yaml_engine_args["model_arch"] == "Qwen3TTSCode2Wav"
+ def test_subtalker_sampling_params_deep_merge_preserves_base_keys(self):
+ """Verify subtalker sampling params participate in stage deep-merge."""
+ from vllm_omni.config.stage_config import _deep_merge_stage
+
+ # Base stage defines all four sampling keys.
+ base = {
+ "stage_id": 0,
+ "subtalker_sampling_params": {
+ "do_sample": True,
+ "temperature": 0.9,
+ "top_k": 50,
+ "top_p": 1.0,
+ },
+ }
+ # Overlay overrides only two of them; the other two must survive the merge.
+ overlay = {
+ "stage_id": 0,
+ "subtalker_sampling_params": {
+ "temperature": 0.7,
+ "top_k": 32,
+ },
+ }
+
+ merged = _deep_merge_stage(base, overlay)
+
+ # A shallow replace would drop do_sample and top_p; a deep merge keeps them.
+ assert merged["subtalker_sampling_params"] == {
+ "do_sample": True,
+ "temperature": 0.7,
+ "top_k": 32,
+ "top_p": 1.0,
+ }
+
class TestBaseConfigInheritance:
"""Test deploy YAML base_config inheritance."""
diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py
index b2d61931558..a74c9ffc2d2 100644
--- a/tests/worker/test_omni_gpu_model_runner.py
+++ b/tests/worker/test_omni_gpu_model_runner.py
@@ -41,7 +41,17 @@ def __init__(self):
class DummyTalkerMTP(torch.nn.Module):
"""A fake talker_mtp module for deterministic CPU testing."""
- def forward(self, req_input_ids, req_embeds, last_talker_hidden, text_step):
+ def forward(
+ self,
+ req_input_ids,
+ req_embeds,
+ last_talker_hidden,
+ text_step,
+ do_sample=None,
+ temperature=None,
+ top_k=None,
+ top_p=None,
+ ):
# Deterministic behavior:
# - output embeds = input embeds + 1
# - output codes = [[0], [1], ...]
@@ -51,6 +61,36 @@ def forward(self, req_input_ids, req_embeds, last_talker_hidden, text_step):
return new_embeds, codes
+class CaptureTalkerMTP(torch.nn.Module):
+ """A fake talker_mtp module that records sampling kwargs."""
+
+ def __init__(self):
+ super().__init__()
+ # One dict per forward() call, in call order.
+ self.calls = []
+
+ def forward(
+ self,
+ req_input_ids,
+ req_embeds,
+ last_talker_hidden,
+ text_step,
+ do_sample=None,
+ temperature=None,
+ top_k=None,
+ top_p=None,
+ ):
+ # Record only the four sampling kwargs; the positional inputs are not stored.
+ self.calls.append(
+ {
+ "do_sample": do_sample,
+ "temperature": temperature,
+ "top_k": top_k,
+ "top_p": top_p,
+ }
+ )
+ # Return the embeddings unchanged plus an all-zero int64 code tensor with
+ # one row per row of req_embeds and a single column.
+ codes = torch.zeros((req_embeds.shape[0], 1), dtype=torch.int64)
+ return req_embeds, codes
+
+
@contextmanager
def _noop_forward_context(*args, **kwargs):
"""A no-op context manager to replace vLLM forward context in CPU tests."""
@@ -80,7 +120,7 @@ def _make_runner(req_ids=("r1", "r2"), hidden_size=4):
runner.talker_mtp = DummyTalkerMTP()
runner.model = SimpleNamespace(talker_mtp_output_key="code_predictor_codes")
- runner.vllm_config = object()
+ runner.vllm_config = SimpleNamespace(model_config=SimpleNamespace())
# Provide a minimal implementation that returns the expected 4-tuple.
def _determine_batch_execution_and_padding(**kwargs):
@@ -168,6 +208,43 @@ def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch):
assert torch.allclose(inputs_embeds, before)
+def test_talker_mtp_forward_passes_qwen3_tts_subtalker_sampling_params_to_talker(monkeypatch):
+ # subtalker_sampling_params set on vllm_config.model_config must reach
+ # talker_mtp.forward() as the do_sample / temperature / top_k / top_p kwargs.
+ import vllm_omni.worker.gpu_model_runner as mod
+
+ monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context)
+
+ runner = _make_runner(req_ids=("r1",), hidden_size=4)
+ # Swap in the capturing fake so the kwargs of the call can be inspected.
+ runner.talker_mtp = CaptureTalkerMTP()
+ runner.vllm_config = SimpleNamespace(
+ model_config=SimpleNamespace(
+ subtalker_sampling_params={
+ "do_sample": False,
+ "temperature": 0.2,
+ "top_k": 9,
+ "top_p": 0.55,
+ }
+ )
+ )
+
+ def fake_determine(self, num_tokens, num_reqs, num_scheduled_tokens_np, max_num_scheduled_tokens, use_cascade_attn):
+ # Minimal 5-tuple stand-in; only batch_desc.num_tokens is populated.
+ batch_desc = SimpleNamespace(num_tokens=int(num_tokens))
+ return (False, batch_desc, None, None, None)
+
+ # Bind the fake to the runner so it is called like a method.
+ monkeypatch.setattr(runner, "_determine_batch_execution_and_padding", fake_determine.__get__(runner, type(runner)))
+
+ inputs_embeds = torch.zeros((2, 4), dtype=torch.float32)
+ OmniGPUModelRunner._talker_mtp_forward(runner, ["r1"], inputs_embeds)
+
+ # Exactly one forward call, carrying exactly the configured values.
+ assert runner.talker_mtp.calls == [
+ {
+ "do_sample": False,
+ "temperature": 0.2,
+ "top_k": 9,
+ "top_p": 0.55,
+ }
+ ]
+
+
def test_update_intermediate_buffer_writes_to_buffer_and_setattr(monkeypatch):
"""Validate that _update_intermediate_buffer writes to model_intermediate_buffer
(forward path) and mirrors to additional_information_cpu setattr (backward compat)."""
diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py
index 588efabfc4f..96a34a8d79f 100644
--- a/vllm_omni/config/model.py
+++ b/vllm_omni/config/model.py
@@ -109,9 +109,11 @@ class OmniModelConfig(ModelConfig):
"extra": {},
}
)
+ subtalker_sampling_params: dict[str, Any] | None = None
omni_kv_config: dict | None = None
codec_frame_rate_hz: float | None = None
task_type: str | None = None
+ enable_sleep_mode: bool = False
@property
def registry(self):
diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py
index ab975cafc3b..8c248059280 100644
--- a/vllm_omni/config/stage_config.py
+++ b/vllm_omni/config/stage_config.py
@@ -394,6 +394,7 @@ class StageDeployConfig:
output_connectors: dict[str, str] | None = None
input_connectors: dict[str, str] | None = None
default_sampling_params: dict[str, Any] | None = None
+ subtalker_sampling_params: dict[str, Any] | None = None
engine_extras: dict[str, Any] = field(default_factory=dict)
@@ -472,7 +473,7 @@ def _parse_stage_deploy(stage_data: dict[str, Any]) -> StageDeployConfig:
return StageDeployConfig(**kwargs)
-_DEEP_MERGE_KEYS = frozenset({"default_sampling_params", "engine_extras", "engine_args"})
+_DEEP_MERGE_KEYS = frozenset({"default_sampling_params", "subtalker_sampling_params", "engine_extras", "engine_args"})
def _deep_merge_stage(base: dict, overlay: dict) -> dict:
diff --git a/vllm_omni/deploy/qwen3_tts.yaml b/vllm_omni/deploy/qwen3_tts.yaml
index 32dceebd805..5948007a97b 100644
--- a/vllm_omni/deploy/qwen3_tts.yaml
+++ b/vllm_omni/deploy/qwen3_tts.yaml
@@ -43,6 +43,11 @@ stages:
max_tokens: 4096
seed: 42
repetition_penalty: 1.05
+ subtalker_sampling_params:
+ do_sample: true
+ temperature: 0.9
+ top_k: 50
+ top_p: 1.0
- stage_id: 1
max_num_seqs: 1
diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py
index 0a19eb11974..012c0130c7c 100644
--- a/vllm_omni/diffusion/data.py
+++ b/vllm_omni/diffusion/data.py
@@ -353,6 +353,8 @@ def __getattr__(self, item: str) -> Any:
@dataclass
class OmniDiffusionConfig:
# Model and path configuration (for convenience)
+ stage_id: int = 0
+
model: str | None = None
model_class_name: str | None = None
@@ -508,6 +510,8 @@ class OmniDiffusionConfig:
# Step mode settings
step_execution: bool = False
+ # Sleep mode settings
+ enable_sleep_mode: bool = False
# Maximum number of sequences to generate in a batch
max_num_seqs: int = 1
@@ -797,5 +801,43 @@ def __str__(self):
return self.name.lower()
+@dataclass
+class OmniACK:
+ """
+ Handshake payload from Workers to Orchestrator.
+ """
+
+ # Only task_id and status are required; every other field has a default.
+ # NOTE(review): the set of valid `status` strings is not visible here - confirm
+ # against the worker/orchestrator code.
+ task_id: str
+ status: str
+ stage_id: int | None = None
+ rank: int | None = None
+ freed_bytes: int = 0
+ # The bare string that follows `metadata` documents that field.
+ metadata: dict[str, Any] = field(default_factory=dict)
+ """
+ Additional telemetry such as:
+ - max_contiguous_block: for fragmentation analysis.
+ - cuda_graph_recalled: boolean if graphs were successfully destroyed/rebuilt.
+ - latency_ms: time taken for the D2H/H2D transfer.
+ """
+ error_msg: str | None = None
+
+
+@dataclass
+class OmniSleepTask:
+ """Structured sleep instruction."""
+
+ task_id: str
+ # NOTE(review): the meaning of each sleep level (default 2) is not defined in
+ # this view - confirm against the worker implementation.
+ level: int = 2
+ # default_factory gives every task its own dict instead of a shared mutable default.
+ metadata: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class OmniWakeTask:
+ """Structured wake-up instruction."""
+
+ task_id: str
+ # NOTE(review): presumably tags=None means "wake everything" - confirm against
+ # the code that consumes this task.
+ tags: list[str] | None = None
+
+
# Special message broadcast via scheduler queues to signal worker shutdown.
SHUTDOWN_MESSAGE = {"type": "shutdown"}
diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py
index fe940d623e5..abaf5989598 100644
--- a/vllm_omni/diffusion/diffusion_engine.py
+++ b/vllm_omni/diffusion/diffusion_engine.py
@@ -14,6 +14,7 @@
import PIL.Image
import torch
from vllm.logger import init_logger
+from vllm.v1.engine.exceptions import EngineDeadError
from vllm_omni.diffusion.data import (
DiffusionOutput,
@@ -122,7 +123,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]:
if output.aborted:
raise DiffusionRequestAbortedError(output.abort_message or "Diffusion request aborted.")
if output.error:
- raise RuntimeError(f"{output.error}")
+ raise RuntimeError(output.error)
logger.info("Generation completed successfully.")
if output.output is None:
@@ -358,6 +359,8 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus
sched_req_id = sched_output.scheduled_req_ids[0]
try:
runner_output = self.execute_fn(sched_output)
+ except EngineDeadError:
+ raise
except Exception as exc:
logger.error("Execution failed for diffusion request %s", sched_req_id, exc_info=True)
runner_output = RunnerOutput(
diff --git a/vllm_omni/diffusion/executor/multiproc_executor.py b/vllm_omni/diffusion/executor/multiproc_executor.py
index e55a464fb4a..dcb35cfde1f 100644
--- a/vllm_omni/diffusion/executor/multiproc_executor.py
+++ b/vllm_omni/diffusion/executor/multiproc_executor.py
@@ -1,14 +1,18 @@
from __future__ import annotations
import multiprocessing as mp
+import multiprocessing.connection
+import threading
import time
import weakref
+from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import zmq
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.logger import init_logger
+from vllm.v1.engine.exceptions import EngineDeadError
from vllm_omni.diffusion.data import SHUTDOWN_MESSAGE, DiffusionOutput
from vllm_omni.diffusion.executor.abstract import DiffusionExecutor
@@ -22,6 +26,8 @@
logger = init_logger(__name__)
+_DEQUEUE_TIMEOUT_S = 5.0
+
@dataclass
class BackgroundResources:
@@ -36,10 +42,14 @@ class BackgroundResources:
def __call__(self):
"""Clean up background resources."""
+ if hasattr(self, "wake_events") and self.wake_events:
+ for ev in self.wake_events:
+ ev.set()
+
if self.broadcast_mq is not None:
try:
for _ in range(self.num_workers):
- self.broadcast_mq.enqueue(SHUTDOWN_MESSAGE)
+ self.broadcast_mq.enqueue(SHUTDOWN_MESSAGE, timeout=1.0)
self.broadcast_mq = None
self.result_mq = None
@@ -63,13 +73,17 @@ class MultiprocDiffusionExecutor(DiffusionExecutor):
def _init_executor(self) -> None:
self._processes: list[mp.Process] = []
self._closed = False
+ self.is_failed = False
+ self._failure_callbacks: list[Callable[[], None]] = []
num_workers = self.od_config.num_gpus
+ self.wake_events = [mp.Event() for _ in range(num_workers)]
+
self._broadcast_mq = self._init_broadcast_queue(num_workers)
broadcast_handle = self._broadcast_mq.export_handle()
# Launch workers
- processes, result_handle = self._launch_workers(broadcast_handle)
+ processes, result_handle = self._launch_workers(broadcast_handle, self.wake_events)
self._result_mq = self._init_result_queue(result_handle)
self._processes = processes
@@ -81,6 +95,8 @@ def _init_executor(self) -> None:
)
self._finalizer = weakref.finalize(self, self.resources)
+ self.start_worker_monitor()
+
def _init_broadcast_queue(self, num_workers: int) -> MessageQueue:
return MessageQueue(
n_reader=num_workers,
@@ -100,7 +116,24 @@ def _ensure_open(self) -> None:
if self._result_mq is None:
raise RuntimeError("Result queue not initialized")
- def _launch_workers(self, broadcast_handle):
+ def _dequeue_one_with_failure_polling(self, deadline: float | None, method: str) -> Any:
+ """Block until one result message, polling ``is_failed`` between chunk timeouts."""
+ # `deadline` is compared against time.monotonic(); None means no overall limit.
+ # `method` is only used to build the TimeoutError message.
+ while True:
+ if deadline is None:
+ chunk_timeout = _DEQUEUE_TIMEOUT_S
+ else:
+ remaining = deadline - time.monotonic()
+ if remaining <= 0:
+ raise TimeoutError(f"RPC call to {method} timed out.")
+ # Never wait past the caller's deadline.
+ chunk_timeout = min(_DEQUEUE_TIMEOUT_S, remaining)
+ try:
+ return self._result_mq.dequeue(timeout=chunk_timeout)
+ except (TimeoutError, zmq.error.Again):
+ # One chunk elapsed with no message: give up if the executor has been
+ # marked failed, otherwise loop and wait for the next chunk.
+ if self.is_failed:
+ raise EngineDeadError()
+ continue
+
+ def _launch_workers(self, broadcast_handle, wake_events):
od_config = self.od_config
logger.info("Starting server...")
@@ -126,6 +159,7 @@ def _launch_workers(self, broadcast_handle):
od_config,
writer,
broadcast_handle,
+ wake_events[i],
worker_extension_cls,
custom_pipeline_args,
),
@@ -164,6 +198,49 @@ def _launch_workers(self, broadcast_handle):
return processes, result_handle
+    def start_worker_monitor(self) -> None:
+        """Start a daemon thread that watches the worker processes for exit.
+
+        If any worker exits while the executor is still open, the thread marks
+        the executor as failed, shuts it down and invokes every registered
+        failure callback to inform the engine. Does nothing when there are no
+        worker processes.
+        """
+        # Monitors worker process liveness. If any die unexpectedly,
+        # logs an error, shuts down the executor and invokes the failure
+        # callback to inform the engine.
+        # Snapshot processes and sentinels so the thread does not depend on
+        # later changes to ``self._processes``.
+        processes = list(self._processes)
+        sentinels = [p.sentinel for p in processes]
+        if not sentinels:
+            return
+
+        # Keep only a weak reference inside the thread. The thread blocks until
+        # a worker exits, so a strong reference to ``self`` in its closure would
+        # keep the executor alive indefinitely and the ``weakref.finalize``
+        # cleanup registered in ``_init_executor`` could never run.
+        self_ref = weakref.ref(self)
+
+        def _monitor() -> None:
+            try:
+                finished = multiprocessing.connection.wait(sentinels)
+            except OSError:
+                return
+
+            executor = self_ref()
+            # Executor already collected or deliberately shut down: nothing to report.
+            if executor is None or executor._closed:
+                return
+
+            dead = [p.name for p, sentinel in zip(processes, sentinels) if sentinel in finished]
+            if dead:
+                logger.error(
+                    "Diffusion worker(s) died unexpectedly: %s",
+                    dead,
+                )
+            # Set the flag before shutdown so pollers of ``is_failed`` see it.
+            executor.is_failed = True
+
+            executor.shutdown()
+
+            # Iterate over a copy so a concurrent registration cannot disturb the loop.
+            for cb in list(executor._failure_callbacks):
+                try:
+                    cb()
+                except Exception:
+                    logger.exception("failure_callback raised")
+
+        t = threading.Thread(target=_monitor, daemon=True, name="diffusion-worker-monitor")
+        t.start()
+
+    def register_failure_callback(
+        self,
+        callback: Callable[[], None],
+    ) -> None:
+        """Register a callback invoked when a worker process dies.
+
+        If a worker failure has already been recorded (``is_failed`` is set),
+        the callback is invoked immediately instead of being stored, because
+        the monitor thread only walks the callback list once.
+        """
+        if self.is_failed:
+            # Same error handling as the monitor thread: a faulty callback is
+            # logged and does not propagate to the caller.
+            try:
+                callback()
+            except Exception:
+                logger.exception("failure_callback raised")
+            return
+        self._failure_callbacks.append(callback)
+
def add_req(self, request: OmniDiffusionRequest) -> DiffusionOutput:
self._ensure_open()
rpc_request = {
@@ -286,27 +363,21 @@ def collective_rpc(
responses = []
for _ in range(num_responses):
- dequeue_timeout = None if deadline is None else max(0, deadline - time.monotonic())
+ response = self._dequeue_one_with_failure_polling(deadline, method)
+
try:
- response = self._result_mq.dequeue(timeout=dequeue_timeout)
-
- try:
- unpack_diffusion_output_shm(response)
- except Exception as e:
- logger.warning("SHM unpack failed (data may already be inline): %s", e)
-
- # Check if response indicates an error
- if isinstance(response, dict) and response.get("status") == "error":
- raise RuntimeError(
- f"Worker failed with error '{response.get('error')}', "
- "please check the stack trace above for the root cause"
- )
-
- responses.append(response)
- except zmq.error.Again as e:
- raise TimeoutError(f"RPC call to {method} timed out.") from e
- except TimeoutError as e:
- raise TimeoutError(f"RPC call to {method} timed out.") from e
+ unpack_diffusion_output_shm(response)
+ except Exception as e:
+ logger.warning("SHM unpack failed (data may already be inline): %s", e)
+
+ # Check if response indicates an error
+ if isinstance(response, dict) and response.get("status") == "error":
+ raise RuntimeError(
+ f"Worker failed with error '{response.get('error')}', "
+ "please check the stack trace above for the root cause"
+ )
+
+ responses.append(response)
return responses[0] if unique_reply_rank is not None else responses
except Exception as e:
@@ -314,10 +385,13 @@ def collective_rpc(
raise
def check_health(self) -> None:
- # Simple check if processes are alive
+ # Raises EngineDeadError if a failure was already recorded or if any worker
+ # process is no longer alive; _ensure_open() runs first.
+ self._ensure_open()
+ if self.is_failed:
+ raise EngineDeadError()
for p in self._processes:
if not p.is_alive():
- raise RuntimeError(f"Worker process {p.name} is dead")
+ # Record the failure so later checks raise without re-probing processes.
+ self.is_failed = True
+ raise EngineDeadError(f"Worker process {p.name} is dead")
def shutdown(self) -> None:
self._closed = True
diff --git a/vllm_omni/diffusion/inline_stage_diffusion_client.py b/vllm_omni/diffusion/inline_stage_diffusion_client.py
index 5cdd6e4fabc..0eec8fee996 100644
--- a/vllm_omni/diffusion/inline_stage_diffusion_client.py
+++ b/vllm_omni/diffusion/inline_stage_diffusion_client.py
@@ -34,6 +34,7 @@ class InlineStageDiffusionClient:
"""Runs DiffusionEngine in a thread executor inside the Orchestrator."""
stage_type: str = "diffusion"
+ replica_id: int = 0
def __init__(
self,
@@ -45,6 +46,7 @@ def __init__(
self.model = model
self.od_config = od_config
self.stage_id = metadata.stage_id
+ self.replica_id = getattr(metadata, "replica_id", 0)
self.final_output = metadata.final_output
self.final_output_type = metadata.final_output_type
self.default_sampling_params = metadata.default_sampling_params
@@ -62,8 +64,9 @@ def __init__(
self._shutting_down = False
logger.info(
- "[InlineStageDiffusionClient] Stage-%s initialized inline (batch_size=%d)",
+ "[InlineStageDiffusionClient] stage-%s [rep-%s] initialized inline (batch_size=%d)",
self.stage_id,
+ self.replica_id,
self.batch_size,
)
@@ -82,6 +85,12 @@ async def add_request_async(
sampling_params: OmniDiffusionSamplingParams,
kv_sender_info: dict[int, dict[str, Any]] | None = None,
) -> None:
+ logger.info(
+ "[InlineStageDiffusionClient] stage-%s [rep-%s] add request: %s",
+ self.stage_id,
+ self.replica_id,
+ request_id,
+ )
task = asyncio.create_task(
self._dispatch_request(
request_id,
@@ -135,6 +144,13 @@ async def add_batch_request_async(
sampling_params: OmniDiffusionSamplingParams,
kv_sender_info: dict[int, dict[str, Any]] | None = None,
) -> None:
+ logger.info(
+ "[InlineStageDiffusionClient] stage-%s [rep-%s] add batch request: %s (%d prompts)",
+ self.stage_id,
+ self.replica_id,
+ request_id,
+ len(prompts),
+ )
task = asyncio.create_task(
self._dispatch_batch(
request_id,
@@ -254,7 +270,7 @@ async def collective_rpc_async(
is_start = args[0] if args else True
profile_prefix = args[1] if len(args) > 1 else None
if is_start and profile_prefix is None:
- profile_prefix = f"stage_{self.stage_id}_diffusion_{int(time.time())}"
+ profile_prefix = f"stage_{self.stage_id}_rep_{self.replica_id}_diffusion_{int(time.time())}"
return await loop.run_in_executor(
self._executor,
self._engine.profile,
diff --git a/vllm_omni/diffusion/models/dmd2/__init__.py b/vllm_omni/diffusion/models/dmd2/__init__.py
new file mode 100644
index 00000000000..d0c8219d4d1
--- /dev/null
+++ b/vllm_omni/diffusion/models/dmd2/__init__.py
@@ -0,0 +1,8 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm_omni.diffusion.models.dmd2.mixin import DMD2PipelineMixin
+
+__all__ = [
+ "DMD2PipelineMixin",
+]
diff --git a/vllm_omni/diffusion/models/dmd2/mixin.py b/vllm_omni/diffusion/models/dmd2/mixin.py
new file mode 100644
index 00000000000..60c4b95baff
--- /dev/null
+++ b/vllm_omni/diffusion/models/dmd2/mixin.py
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from __future__ import annotations
+
+import logging
+import os
+
+from vllm_omni.diffusion.data import DiffusionOutput
+from vllm_omni.diffusion.models.schedulers import DMD2EulerScheduler
+from vllm_omni.diffusion.models.utils import _load_json
+from vllm_omni.diffusion.request import OmniDiffusionRequest
+
+logger = logging.getLogger(__name__)
+
+
+class DMD2PipelineMixin:
+    """Mixin for FastGen DMD2-distilled models. Must appear before the base pipeline in MRO."""
+
+    def __init_dmd2__(self) -> None:
+        """Call after super().__init__() to apply DMD2 scheduler and read model_index.
+
+        Reads the optional ``dmd2_*`` keys from ``model_index.json`` and falls
+        back to built-in defaults for any key that is missing, or for all of
+        them when the file cannot be loaded.
+        """
+        # An existing local path is read from disk; anything else is passed to
+        # _load_json with local_files_only=False.
+        local_files_only = os.path.exists(self.od_config.model)
+        try:
+            model_index = _load_json(self.od_config.model, "model_index.json", local_files_only)
+        except Exception as exc:
+            # Best effort: the defaults below still apply, but report it so a
+            # missing or corrupt model_index.json is not silently masked.
+            logger.warning(
+                "DMD2: could not load model_index.json from %s (%s); using default DMD2 settings.",
+                self.od_config.model,
+                exc,
+            )
+            model_index = {}
+
+        dmd2_timesteps = model_index.get("dmd2_denoising_timesteps", [999, 937, 833, 624])
+        self.num_inference_steps = model_index.get("dmd2_num_inference_steps", 4)
+        shift = model_index.get("dmd2_scheduler_shift", 1.0)
+        self.dmd2_guidance_scale = model_index.get("dmd2_guidance_scale", 1.0)
+
+        # Replace whatever scheduler the base pipeline created.
+        self.scheduler = DMD2EulerScheduler(
+            num_train_timesteps=1000,
+            shift=shift,
+            dmd2_timesteps=dmd2_timesteps,
+        )
+
+    def _sanitize_dmd2_request(self, req: OmniDiffusionRequest) -> None:
+        """Force the DMD2 step count and neutralise CFG-related fields in-place.
+
+        Mutates req.sampling_params and req.prompts. A warning is logged for
+        each user-supplied value that is overridden or dropped.
+        """
+        sp = req.sampling_params
+
+        if sp.num_inference_steps and sp.num_inference_steps != self.num_inference_steps:
+            logger.warning(
+                "DMD2: ignoring num_inference_steps=%d, forcing %d.",
+                sp.num_inference_steps,
+                self.num_inference_steps,
+            )
+        sp.num_inference_steps = self.num_inference_steps
+
+        if sp.guidance_scale_provided and sp.guidance_scale != self.dmd2_guidance_scale:
+            logger.warning(
+                "DMD2: ignoring guidance_scale=%.2f, forcing %.2f.",
+                sp.guidance_scale,
+                self.dmd2_guidance_scale,
+            )
+        sp.guidance_scale = self.dmd2_guidance_scale
+        sp.guidance_scale_provided = False
+
+        if sp.guidance_scale_2 is not None:
+            logger.warning("DMD2: ignoring guidance_scale_2.")
+            sp.guidance_scale_2 = None
+
+        if sp.true_cfg_scale is not None:
+            logger.warning("DMD2: ignoring true_cfg_scale.")
+            sp.true_cfg_scale = None
+
+        sp.do_classifier_free_guidance = False
+        sp.is_cfg_negative = False
+
+        # Rebuild the prompt list, dropping "negative_prompt" from dict prompts;
+        # affected dicts are copied rather than modified.
+        fixed = []
+        for p in req.prompts:
+            if isinstance(p, dict) and "negative_prompt" in p:
+                logger.warning("DMD2: ignoring negative_prompt.")
+                p = {k: v for k, v in p.items() if k != "negative_prompt"}
+            fixed.append(p)
+        req.prompts = fixed
+
+    def forward(self, req: OmniDiffusionRequest, **kwargs) -> DiffusionOutput:
+        """Sanitize the request, then delegate to the next ``forward`` in the MRO.
+
+        Any caller-supplied ``guidance_scale`` / ``num_inference_steps`` kwargs
+        are discarded and replaced by the DMD2 values.
+        """
+        self._sanitize_dmd2_request(req)
+        # Drop caller values so the explicit keywords below cannot collide with them.
+        kwargs.pop("guidance_scale", None)
+        kwargs.pop("num_inference_steps", None)
+        return super().forward(
+            req,
+            guidance_scale=self.dmd2_guidance_scale,
+            num_inference_steps=self.num_inference_steps,
+            **kwargs,
+        )
diff --git a/vllm_omni/diffusion/models/ltx2/__init__.py b/vllm_omni/diffusion/models/ltx2/__init__.py
index 9f9d70f0106..2a78b61baeb 100644
--- a/vllm_omni/diffusion/models/ltx2/__init__.py
+++ b/vllm_omni/diffusion/models/ltx2/__init__.py
@@ -4,12 +4,14 @@
from vllm_omni.diffusion.models.ltx2.ltx2_transformer import LTX2VideoTransformer3DModel
from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import (
LTX2Pipeline,
+ LTX2T2VDMD2Pipeline,
LTX2TwoStagesPipeline,
create_transformer_from_config,
get_ltx2_post_process_func,
load_transformer_config,
)
from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import (
+ LTX2I2VDMD2Pipeline,
LTX2ImageToVideoPipeline,
LTX2ImageToVideoTwoStagesPipeline,
)
@@ -17,7 +19,9 @@
__all__ = [
"LTX2Pipeline",
+ "LTX2T2VDMD2Pipeline",
"LTX2ImageToVideoPipeline",
+ "LTX2I2VDMD2Pipeline",
"LTX2LatentUpsamplePipeline",
"LTX2TwoStagesPipeline",
"LTX2ImageToVideoTwoStagesPipeline",
diff --git a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py
index 95ef919c24e..0ae71aa2d71 100644
--- a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py
+++ b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py
@@ -41,6 +41,7 @@
from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.layer import Attention
+from vllm_omni.diffusion.distributed.hsdp_utils import is_transformer_block_module
from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelInput, SequenceParallelOutput
from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available
@@ -1265,6 +1266,7 @@ class LTX2VideoTransformer3DModel(nn.Module):
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["LTX2VideoTransformerBlock"]
_layerwise_offload_blocks_attrs = ["transformer_blocks"]
+ _hsdp_shard_conditions = [is_transformer_block_module]
_sp_plan: dict[str, Any] | None = None
@staticmethod
diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py
index c60b192f0a5..f06ffab165a 100644
--- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py
+++ b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py
@@ -33,6 +33,7 @@
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
+from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.lora.request import LoRARequest
@@ -1304,3 +1305,11 @@ def forward(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
+
+
+class LTX2T2VDMD2Pipeline(DMD2PipelineMixin, LTX2Pipeline):
+ """LTX-2 T2V pipeline for FastGen DMD2-distilled models."""
+
+ # DMD2PipelineMixin is listed first so its forward() runs before LTX2Pipeline's.
+ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
+ super().__init__(od_config=od_config, prefix=prefix)
+ # Run the DMD2 setup only after the base pipeline has been constructed.
+ self.__init_dmd2__()
diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py
index 65e7454b73f..50a71a54b61 100644
--- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py
+++ b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py
@@ -25,6 +25,7 @@
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
+from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.lora.request import LoRARequest
@@ -889,3 +890,11 @@ def forward(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
+
+
+class LTX2I2VDMD2Pipeline(DMD2PipelineMixin, LTX2ImageToVideoPipeline):
+ """LTX-2 I2V pipeline for FastGen DMD2-distilled models."""
+
+ # DMD2PipelineMixin is listed first so its forward() runs before the base pipeline's.
+ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
+ super().__init__(od_config=od_config, prefix=prefix)
+ # Run the DMD2 setup only after the base pipeline has been constructed.
+ self.__init_dmd2__()
diff --git a/vllm_omni/diffusion/models/magi_human/pipeline_magi_human.py b/vllm_omni/diffusion/models/magi_human/pipeline_magi_human.py
index 881c72edc6d..c1abdf91f04 100644
--- a/vllm_omni/diffusion/models/magi_human/pipeline_magi_human.py
+++ b/vllm_omni/diffusion/models/magi_human/pipeline_magi_human.py
@@ -48,6 +48,7 @@
)
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
from vllm_omni.diffusion.models.t5_encoder.t5_gemma_encoder import T5GemmaEncoderModelTP
+from vllm_omni.diffusion.models.utils import _load_json
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import (
DiffusionPipelineProfilerMixin,
)
@@ -1640,20 +1641,6 @@ def post_process(output):
# ===========================================================================
-def _load_json(model_path: str, filename: str, local_files_only: bool = True) -> dict:
- """Load a JSON config file from a local path or HuggingFace Hub repo."""
- if local_files_only:
- path = os.path.join(model_path, *filename.split("/"))
- with open(path) as f:
- return json.load(f)
- else:
- from huggingface_hub import hf_hub_download
-
- cached = hf_hub_download(repo_id=model_path, filename=filename)
- with open(cached) as f:
- return json.load(f)
-
-
def _resolve_subdir(
model_path: str,
subfolder: str,
diff --git a/vllm_omni/diffusion/models/schedulers/__init__.py b/vllm_omni/diffusion/models/schedulers/__init__.py
index 6f8df78ebf0..e683ed27203 100644
--- a/vllm_omni/diffusion/models/schedulers/__init__.py
+++ b/vllm_omni/diffusion/models/schedulers/__init__.py
@@ -1,10 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm_omni.diffusion.models.schedulers.scheduling_dmd2_euler import DMD2EulerScheduler
from vllm_omni.diffusion.models.schedulers.scheduling_flow_unipc_multistep import (
FlowUniPCMultistepScheduler,
)
__all__ = [
+ "DMD2EulerScheduler",
"FlowUniPCMultistepScheduler",
]
diff --git a/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py b/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py
new file mode 100644
index 00000000000..01447a41d77
--- /dev/null
+++ b/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py
@@ -0,0 +1,23 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from __future__ import annotations
+
+import torch
+from diffusers import FlowMatchEulerDiscreteScheduler
+
+
+class DMD2EulerScheduler(FlowMatchEulerDiscreteScheduler):
+    """Euler scheduler that always uses the fixed DMD2 training timestep schedule.
+
+    The schedule is supplied once at construction time and re-applied on every
+    ``set_timesteps`` call, regardless of the arguments the caller passes.
+    """
+
+    def __init__(self, *args, dmd2_timesteps: list[int], **kwargs):
+        """Create the scheduler.
+
+        Args:
+            *args: Forwarded to ``FlowMatchEulerDiscreteScheduler.__init__``.
+            dmd2_timesteps: Keyword-only timestep schedule. Stored by
+                reference (not copied) and used by ``set_timesteps``.
+            **kwargs: Forwarded to ``FlowMatchEulerDiscreteScheduler.__init__``.
+        """
+        super().__init__(*args, **kwargs)
+        self._dmd2_timesteps = dmd2_timesteps
+
+    def set_timesteps(
+        self,
+        num_inference_steps: int | None = None,
+        device: str | torch.device | None = None,
+        **kwargs,
+    ) -> None:
+        """Install the stored DMD2 schedule.
+
+        Args:
+            num_inference_steps: Accepted but not used.
+            device: Passed through to the parent ``set_timesteps``.
+            **kwargs: Accepted but not used; nothing from it is forwarded.
+        """
+        # Only the stored schedule and the device reach the parent call.
+        super().set_timesteps(timesteps=self._dmd2_timesteps, device=device)
diff --git a/vllm_omni/diffusion/models/utils.py b/vllm_omni/diffusion/models/utils.py
new file mode 100644
index 00000000000..ba0d8dda20c
--- /dev/null
+++ b/vllm_omni/diffusion/models/utils.py
@@ -0,0 +1,21 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from __future__ import annotations
+
+import json
+import os
+
+
+def _load_json(model_path: str, filename: str, local_files_only: bool = True) -> dict:
+    """Load a JSON config file from a local path or HuggingFace Hub repo.
+
+    Args:
+        model_path: Local model directory when ``local_files_only`` is true,
+            otherwise the Hub ``repo_id`` passed to ``hf_hub_download``.
+        filename: Path of the JSON file relative to ``model_path``, written
+            with ``/`` separators (e.g. ``"model_index.json"``).
+        local_files_only: Read from the local filesystem instead of
+            resolving the file through ``hf_hub_download``.
+
+    Returns:
+        The decoded JSON document.
+
+    Raises:
+        OSError: If the resolved file cannot be opened.
+        json.JSONDecodeError: If the file content is not valid JSON.
+        Any error raised by ``hf_hub_download`` propagates unchanged.
+    """
+    if local_files_only:
+        # Split on "/" so the relative path is joined with the platform's
+        # native separator.
+        path = os.path.join(model_path, *filename.split("/"))
+    else:
+        # Imported lazily: only this branch needs huggingface_hub.
+        from huggingface_hub import hf_hub_download
+
+        path = hf_hub_download(repo_id=model_path, filename=filename)
+
+    # JSON text is UTF-8; do not depend on the locale's default encoding.
+    with open(path, encoding="utf-8") as f:
+        return json.load(f)
diff --git a/vllm_omni/diffusion/models/wan2_2/__init__.py b/vllm_omni/diffusion/models/wan2_2/__init__.py
index d418001d952..97808df29d8 100644
--- a/vllm_omni/diffusion/models/wan2_2/__init__.py
+++ b/vllm_omni/diffusion/models/wan2_2/__init__.py
@@ -3,6 +3,7 @@
from .pipeline_wan2_2 import (
Wan22Pipeline,
+ WanT2VDMD2Pipeline,
create_transformer_from_config,
get_wan22_post_process_func,
get_wan22_pre_process_func,
@@ -11,6 +12,7 @@
)
from .pipeline_wan2_2_i2v import (
Wan22I2VPipeline,
+ WanI2VDMD2Pipeline,
get_wan22_i2v_post_process_func,
get_wan22_i2v_pre_process_func,
)
@@ -28,6 +30,7 @@
from .wan2_2_vace_transformer import VaceWanTransformerBlock, WanVACETransformer3DModel
__all__ = [
+ "WanT2VDMD2Pipeline",
"Wan22Pipeline",
"get_wan22_post_process_func",
"get_wan22_pre_process_func",
@@ -35,6 +38,7 @@
"load_transformer_config",
"create_transformer_from_config",
"Wan22I2VPipeline",
+ "WanI2VDMD2Pipeline",
"get_wan22_i2v_post_process_func",
"get_wan22_i2v_pre_process_func",
"Wan22TI2VPipeline",
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
index 652425d5097..c74b72441d4 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
@@ -22,6 +22,7 @@
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
+from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero
from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler
from vllm_omni.diffusion.models.wan2_2.scheduling_wan_euler import WanEulerScheduler
@@ -29,16 +30,12 @@
from vllm_omni.diffusion.postprocess import interpolate_video_tensor
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
-from vllm_omni.diffusion.utils.prompt_utils import (
- validate_prompt_sequence_lengths,
-)
from vllm_omni.inputs.data import OmniTextPrompt
from vllm_omni.platforms import current_omni_platform
logger = logging.getLogger(__name__)
DEBUG_PERF = False
WAN_SAMPLE_SOLVER_CHOICES = {"unipc", "euler"}
-WAN22_MAX_SEQUENCE_LENGTH = 512
def build_wan_scheduler(sample_solver: str, flow_shift: float) -> Any:
@@ -293,7 +290,6 @@ def __init__(
pass
self.boundary_ratio = od_config.boundary_ratio
- self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH
# Determine which transformers to load based on boundary_ratio
# boundary_ratio=1.0: only load transformer_2 (low-noise stage only)
@@ -565,7 +561,6 @@ def forward(
negative_prompt_embeds=negative_prompt_embeds,
guidance_scale_2=guidance_high if boundary_ratio is not None else None,
boundary_ratio=boundary_ratio,
- max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
)
if num_frames % self.vae_scale_factor_temporal != 1:
@@ -599,7 +594,7 @@ def forward(
negative_prompt=negative_prompt,
do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0,
num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1,
- max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
+ max_sequence_length=req.sampling_params.max_sequence_length or 512,
device=device,
dtype=dtype,
)
@@ -832,20 +827,6 @@ def encode_prompt(
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_clean = [self._prompt_clean(p) for p in prompt]
batch_size = len(prompt_clean)
- text_inputs_untruncated = self.tokenizer(
- prompt_clean,
- padding=True,
- truncation=False,
- add_special_tokens=True,
- return_attention_mask=True,
- return_tensors="pt",
- )
- validate_prompt_sequence_lengths(
- text_inputs_untruncated.attention_mask,
- max_sequence_length=max_sequence_length,
- supported_max_sequence_length=self.tokenizer_max_length,
- error_context="for Wan2.2 text encoding",
- )
text_inputs = self.tokenizer(
prompt_clean,
@@ -874,24 +855,8 @@ def encode_prompt(
if do_classifier_free_guidance:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
- negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt]
- neg_text_inputs_untruncated = self.tokenizer(
- negative_prompt_clean,
- padding=True,
- truncation=False,
- add_special_tokens=True,
- return_attention_mask=True,
- return_tensors="pt",
- )
- validate_prompt_sequence_lengths(
- neg_text_inputs_untruncated.attention_mask,
- max_sequence_length=max_sequence_length,
- supported_max_sequence_length=self.tokenizer_max_length,
- prompt_name="negative_prompt",
- error_context="for Wan2.2 text encoding",
- )
neg_text_inputs = self.tokenizer(
- negative_prompt_clean,
+ [self._prompt_clean(p) for p in negative_prompt],
padding="max_length",
max_length=max_sequence_length,
truncation=True,
@@ -963,7 +928,6 @@ def check_inputs(
negative_prompt_embeds=None,
guidance_scale_2=None,
boundary_ratio=None,
- max_sequence_length=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -990,10 +954,18 @@ def check_inputs(
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
- if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
- raise ValueError(
- f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
- )
-
if boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.")
+
+
+# ---------------------------------------------------------------------------
+# DMD2-distilled variant
+# ---------------------------------------------------------------------------
+
+
+class WanT2VDMD2Pipeline(DMD2PipelineMixin, Wan22Pipeline):
+    """Wan 2.x T2V pipeline for FastGen DMD2-distilled models.
+
+    Combines ``DMD2PipelineMixin`` with ``Wan22Pipeline``. Construction runs
+    the inherited ``__init__`` chain first and then calls ``__init_dmd2__()``.
+
+    NOTE(review): ``__init_dmd2__`` is presumably provided by
+    ``DMD2PipelineMixin``; what it configures is not visible here -- confirm.
+
+    Args:
+        od_config: Diffusion config, forwarded unchanged to the next
+            ``__init__`` in the MRO.
+        prefix: Forwarded unchanged to the next ``__init__`` in the MRO.
+    """
+
+    def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
+        # Inherited constructor chain first, then the DMD2 hook.
+        super().__init__(od_config=od_config, prefix=prefix)
+        # The name ends with a double underscore, so it is not name-mangled.
+        self.__init_dmd2__()
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
index 95d1e08bbc7..726d8853a7c 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
@@ -23,10 +23,11 @@
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
+from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero
+from vllm_omni.diffusion.models.utils import _load_json
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
- WAN22_MAX_SEQUENCE_LENGTH,
build_wan_scheduler,
create_transformer_from_config,
load_transformer_config,
@@ -37,9 +38,6 @@
from vllm_omni.diffusion.postprocess import interpolate_video_tensor
from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin
from vllm_omni.diffusion.request import OmniDiffusionRequest
-from vllm_omni.diffusion.utils.prompt_utils import (
- validate_prompt_sequence_lengths,
-)
from vllm_omni.inputs.data import OmniTextPrompt
from vllm_omni.platforms import current_omni_platform
@@ -47,29 +45,6 @@
DEBUG_PERF = False
-def _load_model_index(model: str, local_files_only: bool) -> dict:
- """Load model_index.json from local path or HF Hub."""
- if local_files_only:
- model_index_path = os.path.join(model, "model_index.json")
- if os.path.exists(model_index_path):
- import json
-
- with open(model_index_path) as f:
- return json.load(f)
- else:
- try:
- import json
-
- from huggingface_hub import hf_hub_download
-
- model_index_path = hf_hub_download(model, "model_index.json")
- with open(model_index_path) as f:
- return json.load(f)
- except Exception:
- pass
- return {}
-
-
def get_wan22_i2v_post_process_func(
od_config: OmniDiffusionConfig,
):
@@ -200,7 +175,10 @@ def __init__(
]
# Load model_index.json to detect available components
- model_index = _load_model_index(model, local_files_only)
+ try:
+ model_index = _load_json(model, "model_index.json", local_files_only)
+ except Exception:
+ model_index = {}
# Check if this is a two-stage model (MoE with transformer_2)
self.has_transformer_2 = "transformer_2" in model_index
@@ -218,7 +196,6 @@ def __init__(
# Text encoder
self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
- self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH
self.text_encoder = UMT5EncoderModel.from_pretrained(
model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only
).to(self.device)
@@ -474,7 +451,6 @@ def forward(
image_embeds=image_embeds,
guidance_scale_2=guidance_high if boundary_ratio is not None else None,
boundary_ratio=boundary_ratio,
- max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
)
# Adjust num_frames to be compatible with VAE temporal scaling
@@ -503,7 +479,7 @@ def forward(
negative_prompt=negative_prompt,
do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0,
num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1,
- max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
+ max_sequence_length=req.sampling_params.max_sequence_length or 512,
device=device,
dtype=dtype,
)
@@ -708,20 +684,6 @@ def encode_prompt(
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_clean = [self._prompt_clean(p) for p in prompt]
batch_size = len(prompt_clean)
- text_inputs_untruncated = self.tokenizer(
- prompt_clean,
- padding=True,
- truncation=False,
- add_special_tokens=True,
- return_attention_mask=True,
- return_tensors="pt",
- )
- validate_prompt_sequence_lengths(
- text_inputs_untruncated.attention_mask,
- max_sequence_length=max_sequence_length,
- supported_max_sequence_length=self.tokenizer_max_length,
- error_context="for Wan2.2 text encoding",
- )
text_inputs = self.tokenizer(
prompt_clean,
@@ -750,24 +712,8 @@ def encode_prompt(
if do_classifier_free_guidance:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
- negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt]
- neg_text_inputs_untruncated = self.tokenizer(
- negative_prompt_clean,
- padding=True,
- truncation=False,
- add_special_tokens=True,
- return_attention_mask=True,
- return_tensors="pt",
- )
- validate_prompt_sequence_lengths(
- neg_text_inputs_untruncated.attention_mask,
- max_sequence_length=max_sequence_length,
- supported_max_sequence_length=self.tokenizer_max_length,
- prompt_name="negative_prompt",
- error_context="for Wan2.2 text encoding",
- )
neg_text_inputs = self.tokenizer(
- negative_prompt_clean,
+ [self._prompt_clean(p) for p in negative_prompt],
padding="max_length",
max_length=max_sequence_length,
truncation=True,
@@ -911,7 +857,6 @@ def check_inputs(
image_embeds=None,
guidance_scale_2=None,
boundary_ratio=None,
- max_sequence_length=None,
):
if image is None and image_embeds is None:
raise ValueError("Provide either `image` or `image_embeds`. Cannot leave both undefined.")
@@ -933,11 +878,6 @@ def check_inputs(
if prompt is None and prompt_embeds is None:
raise ValueError("Provide either `prompt` or `prompt_embeds`.")
- if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
- raise ValueError(
- f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
- )
-
if boundary_ratio is None and guidance_scale_2 is not None:
raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.")
@@ -945,3 +885,16 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights using AutoWeightsLoader for vLLM integration."""
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
+
+
+# ---------------------------------------------------------------------------
+# DMD2-distilled variant
+# ---------------------------------------------------------------------------
+
+
+class WanI2VDMD2Pipeline(DMD2PipelineMixin, Wan22I2VPipeline):
+    """Wan 2.x I2V pipeline for FastGen DMD2-distilled models.
+
+    Combines ``DMD2PipelineMixin`` with ``Wan22I2VPipeline``. Construction
+    runs the inherited ``__init__`` chain first and then calls
+    ``__init_dmd2__()``.
+
+    NOTE(review): ``__init_dmd2__`` is presumably provided by
+    ``DMD2PipelineMixin``; what it configures is not visible here -- confirm.
+
+    Args:
+        od_config: Diffusion config, forwarded unchanged to the next
+            ``__init__`` in the MRO.
+        prefix: Forwarded unchanged to the next ``__init__`` in the MRO.
+    """
+
+    def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
+        # Inherited constructor chain first, then the DMD2 hook.
+        super().__init__(od_config=od_config, prefix=prefix)
+        # The name ends with a double underscore, so it is not name-mangled.
+        self.__init_dmd2__()
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
index dba76ba8af8..8170a8f5ab5 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
@@ -37,7 +37,6 @@
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin
from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
- WAN22_MAX_SEQUENCE_LENGTH,
build_wan_scheduler,
create_transformer_from_config,
load_transformer_config,
@@ -47,9 +46,6 @@
)
from vllm_omni.diffusion.postprocess import interpolate_video_tensor
from vllm_omni.diffusion.request import OmniDiffusionRequest
-from vllm_omni.diffusion.utils.prompt_utils import (
- validate_prompt_sequence_lengths,
-)
from vllm_omni.inputs.data import OmniTextPrompt
from vllm_omni.platforms import current_omni_platform
@@ -189,7 +185,6 @@ def __init__(
# Text encoder
self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
- self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH
self.text_encoder = UMT5EncoderModel.from_pretrained(
model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only
).to(self.device)
@@ -377,7 +372,6 @@ def forward(
width=width,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
- max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
)
# Adjust num_frames to be compatible with VAE temporal scaling
@@ -401,7 +395,7 @@ def forward(
negative_prompt=negative_prompt,
do_classifier_free_guidance=guidance_scale > 1.0,
num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1,
- max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
+ max_sequence_length=req.sampling_params.max_sequence_length or 512,
device=device,
dtype=dtype,
)
@@ -551,20 +545,6 @@ def encode_prompt(
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_clean = [self._prompt_clean(p) for p in prompt]
batch_size = len(prompt_clean)
- text_inputs_untruncated = self.tokenizer(
- prompt_clean,
- padding=True,
- truncation=False,
- add_special_tokens=True,
- return_attention_mask=True,
- return_tensors="pt",
- )
- validate_prompt_sequence_lengths(
- text_inputs_untruncated.attention_mask,
- max_sequence_length=max_sequence_length,
- supported_max_sequence_length=self.tokenizer_max_length,
- error_context="for Wan2.2 text encoding",
- )
text_inputs = self.tokenizer(
prompt_clean,
@@ -593,24 +573,8 @@ def encode_prompt(
if do_classifier_free_guidance:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
- negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt]
- neg_text_inputs_untruncated = self.tokenizer(
- negative_prompt_clean,
- padding=True,
- truncation=False,
- add_special_tokens=True,
- return_attention_mask=True,
- return_tensors="pt",
- )
- validate_prompt_sequence_lengths(
- neg_text_inputs_untruncated.attention_mask,
- max_sequence_length=max_sequence_length,
- supported_max_sequence_length=self.tokenizer_max_length,
- prompt_name="negative_prompt",
- error_context="for Wan2.2 text encoding",
- )
neg_text_inputs = self.tokenizer(
- negative_prompt_clean,
+ [self._prompt_clean(p) for p in negative_prompt],
padding="max_length",
max_length=max_sequence_length,
truncation=True,
@@ -735,7 +699,6 @@ def check_inputs(
width,
prompt_embeds=None,
negative_prompt_embeds=None,
- max_sequence_length=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -751,11 +714,6 @@ def check_inputs(
if prompt is None and prompt_embeds is None:
raise ValueError("Provide either `prompt` or `prompt_embeds`.")
- if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length:
- raise ValueError(
- f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}"
- )
-
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load weights using AutoWeightsLoader for vLLM integration."""
loader = AutoWeightsLoader(self)
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py
index 11408e2d24b..0458f88597e 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py
@@ -243,7 +243,6 @@ def check_inputs(
video=None,
mask=None,
reference_images=None,
- max_sequence_length=None,
):
super().check_inputs(
prompt=prompt,
@@ -252,7 +251,6 @@ def check_inputs(
width=width,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
- max_sequence_length=max_sequence_length,
)
# VACE-specific: validate video/mask/reference_images consistency
@@ -549,7 +547,6 @@ def forward(
video=source_video,
mask=source_mask,
reference_images=reference_images,
- max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
)
device = self.device
@@ -568,7 +565,7 @@ def forward(
negative_prompt=negative_prompt,
do_classifier_free_guidance=guidance_scale > 1.0,
num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1,
- max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length,
+ max_sequence_length=req.sampling_params.max_sequence_length or 512,
device=device,
dtype=dtype,
)
diff --git a/vllm_omni/diffusion/offloader/layerwise_backend.py b/vllm_omni/diffusion/offloader/layerwise_backend.py
index 7876b009475..d1216e2f28c 100644
--- a/vllm_omni/diffusion/offloader/layerwise_backend.py
+++ b/vllm_omni/diffusion/offloader/layerwise_backend.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
from itertools import chain
from typing import Any
diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py
index 0bf8c04517b..4001109cc9e 100644
--- a/vllm_omni/diffusion/registry.py
+++ b/vllm_omni/diffusion/registry.py
@@ -83,6 +83,16 @@
"pipeline_ltx2_image2video",
"LTX2ImageToVideoTwoStagesPipeline",
),
+ "LTX2T2VDMD2Pipeline": (
+ "ltx2",
+ "pipeline_ltx2",
+ "LTX2T2VDMD2Pipeline",
+ ),
+ "LTX2I2VDMD2Pipeline": (
+ "ltx2",
+ "pipeline_ltx2_image2video",
+ "LTX2I2VDMD2Pipeline",
+ ),
"StableAudioPipeline": (
"stable_audio",
"pipeline_stable_audio",
@@ -93,6 +103,16 @@
"pipeline_wan2_2_i2v",
"Wan22I2VPipeline",
),
+ "WanT2VDMD2Pipeline": (
+ "wan2_2",
+ "pipeline_wan2_2",
+ "WanT2VDMD2Pipeline",
+ ),
+ "WanI2VDMD2Pipeline": (
+ "wan2_2",
+ "pipeline_wan2_2_i2v",
+ "WanI2VDMD2Pipeline",
+ ),
"LongCatImagePipeline": (
"longcat_image",
"pipeline_longcat_image",
@@ -357,8 +377,12 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) -
"LTX2TwoStagesPipeline": "get_ltx2_post_process_func",
"LTX2ImageToVideoPipeline": "get_ltx2_post_process_func",
"LTX2ImageToVideoTwoStagesPipeline": "get_ltx2_post_process_func",
+ "LTX2T2VDMD2Pipeline": "get_ltx2_post_process_func",
+ "LTX2I2VDMD2Pipeline": "get_ltx2_post_process_func",
"StableAudioPipeline": "get_stable_audio_post_process_func",
"WanImageToVideoPipeline": "get_wan22_i2v_post_process_func",
+ "WanT2VDMD2Pipeline": "get_wan22_post_process_func",
+ "WanI2VDMD2Pipeline": "get_wan22_i2v_post_process_func",
"LongCatImagePipeline": "get_longcat_image_post_process_func",
"BagelPipeline": "get_bagel_post_process_func",
"LongCatImageEditPipeline": "get_longcat_image_post_process_func",
@@ -390,6 +414,8 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) -
"WanPipeline": "get_wan22_pre_process_func",
"WanVACEPipeline": "get_wan22_vace_pre_process_func",
"WanImageToVideoPipeline": "get_wan22_i2v_pre_process_func",
+ "WanT2VDMD2Pipeline": "get_wan22_pre_process_func",
+ "WanI2VDMD2Pipeline": "get_wan22_i2v_pre_process_func",
"OmniGen2Pipeline": "get_omnigen2_pre_process_func",
"HeliosPipeline": "get_helios_pre_process_func",
"HeliosPyramidPipeline": "get_helios_pre_process_func",
diff --git a/vllm_omni/diffusion/stage_diffusion_client.py b/vllm_omni/diffusion/stage_diffusion_client.py
index f952f32d9ce..dacbe387add 100644
--- a/vllm_omni/diffusion/stage_diffusion_client.py
+++ b/vllm_omni/diffusion/stage_diffusion_client.py
@@ -8,15 +8,20 @@
from __future__ import annotations
import asyncio
+import multiprocessing.connection
import time
import uuid
+import weakref
from dataclasses import fields, is_dataclass
+from threading import Thread
from typing import TYPE_CHECKING, Any
import zmq
from vllm.logger import init_logger
+from vllm.v1.engine.exceptions import EngineDeadError
from vllm_omni.diffusion.stage_diffusion_proc import (
+ StageDiffusionProc,
complete_diffusion_handshake,
spawn_diffusion_proc,
)
@@ -62,6 +67,7 @@ class StageDiffusionClient:
"""
stage_type: str = "diffusion"
+ replica_id: int = 0
def __init__(
self,
@@ -107,12 +113,13 @@ def _initialize_client(
batch_size: int,
) -> None:
self.stage_id = metadata.stage_id
+ self.replica_id = getattr(metadata, "replica_id", 0)
self.final_output = metadata.final_output
self.final_output_type = metadata.final_output_type
self.default_sampling_params = metadata.default_sampling_params
- self.requires_multimodal_data = metadata.requires_multimodal_data
- self.custom_process_input_func = metadata.custom_process_input_func
- self.engine_input_source = metadata.engine_input_source
+ self.requires_multimodal_data = getattr(metadata, "requires_multimodal_data", False)
+ self.custom_process_input_func = getattr(metadata, "custom_process_input_func", None)
+ self.engine_input_source = getattr(metadata, "engine_input_source", [])
self._proc = proc
self._owns_process = proc is not None
@@ -130,14 +137,53 @@ def _initialize_client(
self._pending_rpcs: set[str] = set()
self._tasks: dict[str, asyncio.Task] = {}
self._shutting_down = False
+ self._engine_dead: bool = False
+
+ # Background thread to detect silent process death (SIGKILL, segfault)
+ # where the subprocess cannot send the ZMQ death sentinel.
+ # Mirrors MPClient.start_engine_core_monitor() in vLLM.
+ self._start_proc_monitor()
logger.info(
- "[StageDiffusionClient] Stage-%s initialized (owns_process=%s, batch_size=%d)",
+ "[StageDiffusionClient] stage-%s [rep-%s] initialized (owns_process=%s, batch_size=%d)",
self.stage_id,
+ self.replica_id,
self._owns_process,
batch_size,
)
+ # ------------------------------------------------------------------
+ # Process monitor (mirrors vLLM's MPClient.start_engine_core_monitor)
+ # ------------------------------------------------------------------
+
+    def _start_proc_monitor(self) -> None:
+        """Start a daemon thread that watches the subprocess sentinel.
+
+        When the subprocess dies without sending the ZMQ death sentinel
+        (e.g. SIGKILL, segfault), this thread sets ``_engine_dead`` so
+        subsequent calls raise ``EngineDeadError``.
+
+        Nothing is started when ``self._proc`` is ``None``: there is no
+        process handle to watch in that case.
+        """
+        proc = self._proc
+        if proc is None:
+            # Without a handle there is no sentinel; previously a thread was
+            # spawned only to die on a swallowed AttributeError.
+            return
+
+        # Weak reference so the monitor thread does not keep the client alive.
+        self_ref = weakref.ref(self)
+
+        def _monitor() -> None:
+            try:
+                # Blocks until the process sentinel becomes ready, i.e. the
+                # subprocess has exited.
+                multiprocessing.connection.wait([proc.sentinel])
+            except Exception:
+                # Best effort: the monitor is only a fallback detector, so a
+                # failed wait must not surface as a crash.
+                logger.debug(
+                    "[StageDiffusionClient] process monitor could not wait on the subprocess sentinel.",
+                    exc_info=True,
+                )
+                return
+            client = self_ref()
+            # Stay quiet if the client is gone, is shutting down on purpose,
+            # or has already been marked dead through another path.
+            if client is None or client._shutting_down or client._engine_dead:
+                return
+            client._engine_dead = True
+            logger.error(
+                "[StageDiffusionClient] stage-%s [rep-%s] StageDiffusionProc died unexpectedly (exit code %s).",
+                client.stage_id,
+                client.replica_id,
+                proc.exitcode,
+            )
+
+        Thread(target=_monitor, daemon=True, name="DiffusionProcMonitor").start()
+
+
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@@ -150,6 +196,16 @@ def _drain_responses(self) -> None:
except zmq.Again:
break
+ # Check for the death sentinel (raw bytes, not msgpack-encoded).
+ if raw == StageDiffusionProc.DIFFUSION_PROC_DEAD:
+ self._engine_dead = True
+ logger.error(
+ "[StageDiffusionClient] stage-%s [rep-%s] received DIFFUSION_PROC_DEAD sentinel from subprocess.",
+ self.stage_id,
+ self.replica_id,
+ )
+ break
+
msg = self._decoder.decode(raw)
msg_type = msg.get("type")
@@ -162,8 +218,9 @@ def _drain_responses(self) -> None:
rpc_id = msg.get("rpc_id")
error_msg = msg.get("error")
logger.error(
- "[StageDiffusionClient] Stage-%s subprocess error for %s: %s",
+ "[StageDiffusionClient] stage-%s [rep-%s] subprocess error for %s: %s",
self.stage_id,
+ self.replica_id,
rpc_id or req_id,
error_msg,
)
@@ -173,13 +230,10 @@ def _drain_responses(self) -> None:
"error": True,
"reason": error_msg,
}
- elif req_id is not None:
- error_output = OmniRequestOutput.from_diffusion(
- request_id=req_id,
- images=[],
- )
- error_output.error = error_msg
- self._output_queue.put_nowait(error_output)
+ # Route request errors as error outputs so the Orchestrator
+ # sees the request complete (instead of hanging forever).
+ if req_id is not None:
+ self._output_queue.put_nowait(OmniRequestOutput.from_error(req_id, error_msg))
# Fields that are subprocess-local and cannot be serialized across
# process boundaries. They are recreated in the subprocess with
@@ -243,6 +297,14 @@ async def add_request_async(
sampling_params: OmniDiffusionSamplingParams,
kv_sender_info: dict[int, dict[str, Any]] | None = None,
) -> None:
+ if self._engine_dead:
+ raise EngineDeadError()
+ logger.info(
+ "[StageDiffusionClient] stage-%s [rep-%s] add request: %s",
+ self.stage_id,
+ self.replica_id,
+ request_id,
+ )
self._request_socket.send(
self._encoder.encode(
{
@@ -270,6 +332,15 @@ async def add_batch_request_async(
and the combined result is placed on the output queue with a single
*request_id*.
"""
+ if self._engine_dead:
+ raise EngineDeadError()
+ logger.info(
+ "[StageDiffusionClient] stage-%s [rep-%s] add batch request: %s (%d prompts)",
+ self.stage_id,
+ self.replica_id,
+ request_id,
+ len(prompts),
+ )
task = asyncio.create_task(
self._run_batch(
request_id,
@@ -302,11 +373,13 @@ async def _run_batch(
)
except Exception as e:
logger.exception(
- "[StageDiffusionClient] Stage-%s batch req=%s failed: %s",
+ "[StageDiffusionClient] stage-%s [rep-%s] batch req=%s failed: %s",
self.stage_id,
+ self.replica_id,
request_id,
e,
)
+ await self._output_queue.put(OmniRequestOutput.from_error(request_id, str(e)))
finally:
self._tasks.pop(request_id, None)
@@ -315,7 +388,10 @@ def get_diffusion_output_nowait(self) -> OmniRequestOutput | None:
try:
return self._output_queue.get_nowait()
except asyncio.QueueEmpty:
+ if self._engine_dead:
+ raise EngineDeadError()
if not self._shutting_down and self._owns_process and self._proc is not None and not self._proc.is_alive():
+ self._engine_dead = True
exitcode = self._proc.exitcode
# One final drain – the last ZMQ frame may have arrived
# between the first drain and the is_alive() check.
@@ -329,7 +405,7 @@ def get_diffusion_output_nowait(self) -> OmniRequestOutput | None:
logger.warning("StageDiffusionProc was killed by signal %d; treating as external shutdown.", sig)
self._shutting_down = True
return None
- raise RuntimeError(f"StageDiffusionProc died unexpectedly (exit code {exitcode})")
+ raise EngineDeadError(f"StageDiffusionProc died unexpectedly (exit code {exitcode})")
return None
async def abort_requests_async(self, request_ids: list[str]) -> None:
@@ -350,13 +426,16 @@ async def collective_rpc_async(
kwargs: dict[str, Any] | None = None,
) -> Any:
"""Forward control RPCs to the diffusion subprocess."""
+ if self._engine_dead:
+ raise EngineDeadError()
+
# Inject a default profile_prefix that includes stage_id when profiling.
if method == "profile":
args_list = list(args)
is_start = args_list[0] if args_list else True
profile_prefix = args_list[1] if len(args_list) > 1 else None
if is_start and profile_prefix is None:
- profile_prefix = f"stage_{self.stage_id}_diffusion_{int(time.time())}"
+ profile_prefix = f"stage_{self.stage_id}_rep_{self.replica_id}_diffusion_{int(time.time())}"
if len(args_list) > 1:
args_list[1] = profile_prefix
else:
@@ -387,8 +466,9 @@ async def collective_rpc_async(
self._drain_responses()
if rpc_id in self._rpc_results:
return self._rpc_results.pop(rpc_id)
- if self._owns_process and self._proc is not None and not self._proc.is_alive():
- raise RuntimeError(
+ if self._engine_dead or (self._owns_process and self._proc is not None and not self._proc.is_alive()):
+ self._engine_dead = True
+ raise EngineDeadError(
f"StageDiffusionProc died while waiting for "
f"collective_rpc '{method}' (exit code {self._proc.exitcode})"
)
@@ -398,6 +478,19 @@ async def collective_rpc_async(
finally:
self._pending_rpcs.discard(rpc_id)
+ def check_health(self) -> None:
+ """Raise ``EngineDeadError`` if the diffusion engine is dead; mirrors ``check_health`` on vLLM's ``EngineClient``.
+ Latches ``_engine_dead`` once ``_proc`` is found not alive, so later calls raise without re-checking the process.
+ NOTE(review): ``_shutting_down`` / ``_owns_process`` are not consulted, unlike ``get_diffusion_output_nowait`` - confirm.
+ """
+ if self._engine_dead:
+ raise EngineDeadError(f"Stage-{self.stage_id} diffusion subprocess is dead")
+ if self._proc is not None and not self._proc.is_alive():
+ self._engine_dead = True
+ raise EngineDeadError(
+ f"Stage-{self.stage_id} diffusion subprocess is not alive (exit code: {self._proc.exitcode})."
+ )
+
def shutdown(self) -> None:
self._shutting_down = True
try:
diff --git a/vllm_omni/diffusion/stage_diffusion_proc.py b/vllm_omni/diffusion/stage_diffusion_proc.py
index eced444fd32..d36c5c644c7 100644
--- a/vllm_omni/diffusion/stage_diffusion_proc.py
+++ b/vllm_omni/diffusion/stage_diffusion_proc.py
@@ -46,6 +46,8 @@ class StageDiffusionProc:
and ZMQ-based communication with StageDiffusionClient.
"""
+ DIFFUSION_PROC_DEAD = b"DIFFUSION_PROC_DEAD"
+
def __init__(self, model: str, od_config: OmniDiffusionConfig) -> None:
self._model = model
self._od_config = od_config
@@ -450,6 +452,16 @@ async def _dispatch_batch(
elif msg_type == "shutdown":
break
+ except Exception:
+ # Send the death sentinel so the client can detect the
+ # fatal failure promptly (mirrors EngineCoreProc._send_engine_dead).
+ try:
+ response_socket.setsockopt(zmq.LINGER, 4000)
+ await response_socket.send(StageDiffusionProc.DIFFUSION_PROC_DEAD)
+ except Exception:
+ logger.warning("Failed to send DIFFUSION_PROC_DEAD sentinel to client.")
+ raise
+
finally:
for task in tasks.values():
task.cancel()
diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py
index 160309e0d8d..621f7eebd86 100644
--- a/vllm_omni/diffusion/worker/diffusion_worker.py
+++ b/vllm_omni/diffusion/worker/diffusion_worker.py
@@ -27,7 +27,10 @@
from vllm_omni.diffusion.data import (
DiffusionOutput,
+ OmniACK,
OmniDiffusionConfig,
+ OmniSleepTask,
+ OmniWakeTask,
)
from vllm_omni.diffusion.distributed.parallel_state import (
destroy_distributed_env,
@@ -77,6 +80,7 @@ def __init__(
self.model_runner: DiffusionModelRunner | None = None
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
self.lora_manager: DiffusionLoRAManager | None = None
+ self.stage_id = getattr(od_config, "stage_id", 0)
self.init_device()
# Create model runner
self.model_runner = DiffusionModelRunner(
@@ -159,8 +163,10 @@ def _create_profiler(self) -> WorkerProfiler | None:
def _get_profiler(self) -> WorkerProfiler | None:
return getattr(self, "profiler", None)
- def load_model(self, load_format: str = "default", custom_pipeline_name: str | None = None) -> None:
+ def load_model(self, load_format: str = "default", custom_pipeline_name: str | None = None, **kwargs) -> None:
"""Load the diffusion model using DiffusionModelRunner."""
+ load_format = kwargs.get("load_format", load_format)
+ custom_pipeline_name = kwargs.get("custom_pipeline_name", custom_pipeline_name)
with (
set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config),
set_current_vllm_config(self.vllm_config),
@@ -170,6 +176,8 @@ def load_model(self, load_format: str = "default", custom_pipeline_name: str | N
load_format=load_format,
custom_pipeline_name=custom_pipeline_name,
)
+ current_omni_platform.synchronize()
+ gc.collect()
process_memory = get_process_gpu_memory(self.local_rank)
if process_memory is not None:
logger.info(
@@ -284,77 +292,176 @@ def sleep(self, level: int = 1) -> bool:
"""
from vllm.device_allocator.cumem import CuMemAllocator
- process_memory_before_sleep = get_process_gpu_memory(self.local_rank)
- free_bytes_before_sleep = None
- if process_memory_before_sleep is None:
- free_bytes_before_sleep = current_omni_platform.get_free_memory()
+ allocator = CuMemAllocator.get_instance()
+
+ usage_before = allocator.get_current_usage()
- # Save the buffers before level 2 sleep
if level == 2 and self.model_runner is not None:
+ if hasattr(self.model_runner, "graph_runners"):
+ self.model_runner.graph_runners.clear()
+ logger.info(f"[Worker {self.rank}] CUDA Graphs cleared.")
model = self.model_runner.pipeline
self._sleep_saved_buffers = {name: buffer.cpu().clone() for name, buffer in model.named_buffers()}
- allocator = CuMemAllocator.get_instance()
- allocator.sleep(offload_tags=("weights",) if level == 1 else tuple())
- process_memory_after_sleep = get_process_gpu_memory(self.local_rank)
- if process_memory_before_sleep is not None and process_memory_after_sleep is not None:
- freed_bytes = process_memory_before_sleep - process_memory_after_sleep
- used_bytes = process_memory_after_sleep
- accounting_scope = "process-scoped"
+ free_mem_before = current_omni_platform.get_free_memory()
+
+ # Level 1 offloads only the "weights" pool; level 2 passes no offload tags, so nothing is kept for restore.
+ offload_tags = ("weights",) if level == 1 else tuple()
+ allocator.sleep(offload_tags=offload_tags)
+
+ current_omni_platform.empty_cache()
+ current_omni_platform.synchronize()
+
+ free_mem_after = current_omni_platform.get_free_memory()
+ try:
+ total_mem = current_omni_platform.get_device_total_memory()
+ except (NotImplementedError, AttributeError):
+ total_mem = torch.cuda.get_device_properties(self.device).total_memory
+
+ phys_freed_bytes = max(0, free_mem_after - free_mem_before)
+ phys_used_bytes = total_mem - free_mem_after
+
+ if usage_before > 0:
+ logger.info(
+ f"[Diffusion Worker {self.rank}] Sleep Level {level}: "
+ f"physically freed {phys_freed_bytes / GiB_bytes:.2f} GiB, "
+ f"{phys_used_bytes / GiB_bytes:.2f} GiB is still in use."
+ )
else:
- free_bytes_after_sleep = current_omni_platform.get_free_memory()
- assert free_bytes_before_sleep is not None
- device_id = self.device.index if self.device.index is not None else 0
- total = current_omni_platform.get_device_total_memory(device_id)
- freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
- used_bytes = total - free_bytes_after_sleep
- accounting_scope = "device-scoped fallback"
- assert freed_bytes >= 0, "Memory usage increased after sleeping."
- logger.info(
- "Sleep mode (%s) freed %.2f GiB memory, %.2f GiB memory is still in use.",
- accounting_scope,
- freed_bytes / GiB_bytes,
- used_bytes / GiB_bytes,
- )
- return True
+ logger.info(f"[Worker {self.rank}] Sleep Level {level} completed (GPU was already empty).")
+ logger.info(f"[Worker {self.rank}] Memory usage before sleep: {usage_before / GiB_bytes:.2f} GiB.")
+ return usage_before
def wake_up(self, tags: list[str] | None = None) -> bool:
"""
- Wake up the worker from sleep mode. See the sleep function
- method for more details.
+ Wake up the worker from sleep mode.
+
+ Re-activates the memory allocator for the specified tags and restores
+ model buffers from CPU back to GPU if they were saved during Level 2 sleep.
Args:
- tags: An optional list of tags to reallocate the worker memory
- for specific memory allocations. Values must be in
- `("weights")`. If None, all memory is reallocated.
- wake_up should be called with all tags (or None) before the
- worker is used again.
+ tags: List of memory pool tags to re-activate (e.g., ["weights"]
+ to match Level 1 sleep). If None, all pools are re-activated.
"""
from vllm.device_allocator.cumem import CuMemAllocator
allocator = CuMemAllocator.get_instance()
allocator.wake_up(tags)
-
- # Restore the buffers after level 2 sleep
+ current_omni_platform.synchronize()
if len(self._sleep_saved_buffers) and self.model_runner is not None:
model = self.model_runner.pipeline
for name, buffer in model.named_buffers():
if name in self._sleep_saved_buffers:
buffer.data.copy_(self._sleep_saved_buffers[name].data)
self._sleep_saved_buffers = {}
+ logger.info(f"[Worker {self.rank}] Buffers restored from CPU.")
+ logger.info(f"[Worker {self.rank}] Wake-up complete.")
return True
+    def handle_sleep_task(self, task: OmniSleepTask) -> OmniACK | None:
+        """Sleep per ``task``; rank 0 returns the ACK (other ranks None), any failure returns an ERROR ACK."""
+        from vllm_omni.platforms import current_omni_platform
+        try:
+            if isinstance(task, dict):
+                task = OmniSleepTask(**task)
+            logger.info(f"[Worker {self.rank}] Handshake Received: Task {task.task_id}")
+
+            current_omni_platform.synchronize()
+            usage_before = current_omni_platform.get_current_memory_usage(self.device)
+            self.sleep(level=task.level)
+            current_omni_platform.synchronize()
+            usage_after = current_omni_platform.get_current_memory_usage(self.device)
+            real_freed = max(0, usage_before - usage_after)
+            logger.info(f"[Worker {self.rank}] Preparing ACK: freed_bytes={real_freed / GiB_bytes:.2f} GiB.")
+
+            # Sum the per-rank byte counts; int64 keeps them exact (float32 rounds integers above 2**24).
+            if torch.distributed.is_initialized():
+                t_freed = torch.tensor([int(real_freed)], dtype=torch.int64, device=self.device)
+                torch.distributed.all_reduce(t_freed)
+                real_freed = int(t_freed.item())
+
+            if self.rank != 0:
+                return None
+            ack = OmniACK(
+                task_id=task.task_id,
+                status="SUCCESS",
+                stage_id=self.stage_id,
+                rank=self.rank,
+                freed_bytes=real_freed,
+                # Metadata that RL callers need back in the ACK.
+                metadata={
+                    "source": f"Platform_{current_omni_platform.get_device_name()}",
+                    "total_freed_gib": f"{real_freed / GiB_bytes:.2f}",
+                    "rank_residual_gib": f"{usage_after / GiB_bytes:.2f}",
+                },
+            )
+            logger.info(f"[Worker {self.rank}] ACK emitted. Freed {real_freed / GiB_bytes:.2f} GiB.")
+            return ack
+        except Exception as e:
+            logger.error(f"Sleep failed: {e}", exc_info=True)
+            if torch.distributed.is_initialized():
+                try:
+                    torch.distributed.barrier()  # NOTE(review): peers issue all_reduce on success - confirm pairing
+                except Exception:
+                    pass
+            failed_id = task.get("task_id") if isinstance(task, dict) else task.task_id  # dict if conversion failed
+            return OmniACK(task_id=failed_id, status="ERROR", error_msg=str(e))
+
+    def handle_wake_task(self, task: OmniWakeTask) -> OmniACK | None:
+        """Wake per ``task``; rank 0 returns the ACK (other ranks None), any failure returns an ERROR ACK."""
+        from vllm_omni.platforms import current_omni_platform
+        try:
+            if isinstance(task, dict):
+                task = OmniWakeTask(**task)
+            logger.info(f"[Worker {self.rank}] Responding to Wake-up Task: {task.task_id}")
+            self.wake_up(tags=task.tags)
+
+            logger.info(f"[Worker {self.rank}] wake_up logic finished, entering barrier...")
+            if torch.distributed.is_initialized():
+                torch.distributed.barrier()
+
+            current_omni_platform.synchronize()
+            usage_now = current_omni_platform.get_current_memory_usage(self.device)
+            current_used_gib = usage_now / GiB_bytes
+
+            if self.rank != 0:
+                return None
+            logger.info(f"[Worker {self.rank}] PASSED barrier, about to return to loop.")
+            return OmniACK(
+                task_id=task.task_id,
+                status="SUCCESS",
+                stage_id=self.stage_id,
+                rank=self.rank,
+                metadata={
+                    "state": "WARM",
+                    "source": f"Platform_{current_omni_platform.get_device_name()}",
+                    "current_vram_gib": f"{current_used_gib:.2f}",
+                },
+            )
+        except Exception as e:
+            logger.error(f"Wake-up failed on Rank {self.rank}: {e}", exc_info=True)
+            if torch.distributed.is_initialized():
+                try:
+                    torch.distributed.barrier()  # pairs with the barrier above when the failure preceded it
+                except Exception:
+                    pass
+            failed_id = task.get("task_id") if isinstance(task, dict) else task.task_id  # dict if conversion failed
+            return OmniACK(task_id=failed_id, status="ERROR", error_msg=str(e))
+
def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
"""Get memory pool context for sleep mode support."""
- if self.od_config.enable_sleep_mode:
+ is_sleep_enabled = getattr(self.od_config, "enable_sleep_mode", False)
+ if is_sleep_enabled:
+ current_omni_platform.synchronize()
+ gc.collect()
from vllm.device_allocator.cumem import CuMemAllocator
allocator = CuMemAllocator.get_instance()
if tag == "weights":
assert allocator.get_current_usage() == 0, "Sleep mode can only be used for one instance per process."
+ logger.info(f"[Worker {self.rank}] Activating Diffusion CuMem pool for tag: {tag}")
return allocator.use_memory_pool(tag=tag)
- else:
- return nullcontext()
+ return nullcontext()
def shutdown(self) -> None:
"""Shutdown the worker and cleanup distributed environment."""
@@ -395,11 +502,13 @@ def __init__(
od_config: OmniDiffusionConfig,
gpu_id: int,
broadcast_handle,
+ wake_event: mp.Event,
worker_extension_cls: str | None = None,
custom_pipeline_args: dict[str, Any] | None = None,
):
self.od_config = od_config
self.gpu_id = gpu_id
+ self.wake_event = wake_event
# Inter-process Communication
self.context = zmq.Context(io_threads=2)
@@ -414,7 +523,13 @@ def __init__(
if gpu_id == 0:
self.result_mq = MessageQueue(n_reader=1, n_local_reader=1, local_reader_ranks=[0])
self.result_mq_handle = self.result_mq.export_handle()
+ WorkerProc._shared_result_handle = self.result_mq_handle
logger.info(f"Worker {gpu_id} created result MessageQueue")
+ else:
+ handle = getattr(WorkerProc, "_shared_result_handle", None)
+ if handle:
+ self.result_mq = MessageQueue.create_from_handle(handle, gpu_id)
+ logger.info(f"Worker {gpu_id} attached to shared result MessageQueue")
assert od_config.master_port is not None
@@ -438,13 +553,17 @@ def _create_worker(
)
return wrapper
- def return_result(self, output: object):
+ def return_result(self, output: Any):
"""Reply to client, only on rank 0."""
if self.result_mq is not None:
+ if isinstance(output, OmniACK):
+ self.result_mq.enqueue(output)
+ return
try:
pack_diffusion_output_shm(output)
except Exception as e:
- logger.warning("SHM pack failed, falling back to raw enqueue: %s", e)
+ if hasattr(output, "output"):
+ logger.warning("SHM pack failed for model output: %s", e)
self.result_mq.enqueue(output)
def recv_message(self):
@@ -480,20 +599,31 @@ def worker_busy_loop(self) -> None:
while self._running:
msg = None
try:
- msg = self.recv_message()
- except Exception as e:
- logger.error(
- f"Error receiving message in worker loop: {e}",
- exc_info=True,
- )
+ msg = self.mq.dequeue(timeout=1.0)
+ except Exception:
+ if self.wake_event and self.wake_event.is_set():
+ self.wake_event.clear()
+ logger.info(f"Worker {self.gpu_id} caught OOB POKE, forcing wake-up sequence.")
+ msg = {"type": "wake_up", "task_id": "recovery-task", "tags": None}
+ else:
+ continue
+ if msg is None:
continue
if msg is None or len(msg) == 0:
logger.warning("Worker %s: Received empty payload, ignoring", self.gpu_id)
continue
+ if isinstance(msg, dict) and msg.get("type") == "sleep":
+ task = OmniSleepTask(level=msg.get("level", 2), task_id=msg.get("task_id", "local"))
+ ack = self.worker.handle_sleep_task(task)
+ self.return_result(ack)
+ elif isinstance(msg, dict) and msg.get("type") == "wake_up":
+ task = OmniWakeTask(tags=msg.get("tags"), task_id=msg.get("task_id", "local"))
+ ack = self.worker.handle_wake_task(task)
+ self.return_result(ack)
# Route message based on type
- if isinstance(msg, dict) and msg.get("type") == "rpc":
+ elif isinstance(msg, dict) and msg.get("type") == "rpc":
try:
result, should_reply = self.execute_rpc(msg)
if should_reply:
@@ -538,6 +668,7 @@ def worker_main(
od_config: OmniDiffusionConfig,
pipe_writer: mp.connection.Connection,
broadcast_handle,
+ wake_event: mp.Event,
worker_extension_cls: str | None = None,
custom_pipeline_args: dict[str, Any] | None = None,
) -> None:
@@ -549,6 +680,7 @@ def worker_main(
od_config,
gpu_id=rank,
broadcast_handle=broadcast_handle,
+ wake_event=wake_event,
worker_extension_cls=worker_extension_cls,
custom_pipeline_args=custom_pipeline_args,
)
@@ -574,6 +706,7 @@ def __init__(
gpu_id: int,
od_config: OmniDiffusionConfig,
base_worker_class: type = DiffusionWorker,
+ wake_event: mp.Event = None,
worker_extension_cls: str | None = None,
custom_pipeline_args: dict[str, Any] | None = None,
):
@@ -734,6 +867,12 @@ def wake_up(self, tags: list[str] | None = None) -> bool:
"""
return self.worker.wake_up(tags)
+ def handle_sleep_task(self, task):
+ return self.worker.handle_sleep_task(task)
+
+ def handle_wake_task(self, task):
+ return self.worker.handle_wake_task(task)
+
def shutdown(self) -> None:
"""Shutdown the worker and cleanup resources."""
return self.worker.shutdown()
diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py
index 6c92d7952de..55d119ad65c 100644
--- a/vllm_omni/engine/__init__.py
+++ b/vllm_omni/engine/__init__.py
@@ -76,6 +76,42 @@ class OmniEngineCoreRequest(EngineCoreRequest):
# Optional additional information dictionary (serialized)
additional_information: AdditionalInformationPayload | None = None
+ @classmethod
+ def from_request(
+ cls,
+ request: EngineCoreRequest,
+ *,
+ prompt_embeds: torch.Tensor | None = None,
+ additional_information: AdditionalInformationPayload | None = None,
+ ) -> "OmniEngineCoreRequest":
+ """Clone an EngineCoreRequest into an OmniEngineCoreRequest with optional payload overrides."""
+
+ if prompt_embeds is None:
+ prompt_embeds = request.prompt_embeds
+ if additional_information is None:
+ additional_information = getattr(request, "additional_information", None)
+
+ return cls(
+ request_id=request.request_id,
+ prompt_token_ids=request.prompt_token_ids,
+ mm_features=request.mm_features,
+ sampling_params=request.sampling_params,
+ pooling_params=request.pooling_params,
+ arrival_time=request.arrival_time,
+ lora_request=request.lora_request,
+ cache_salt=request.cache_salt,
+ data_parallel_rank=request.data_parallel_rank,
+ prompt_embeds=prompt_embeds,
+ client_index=request.client_index,
+ current_wave=request.current_wave,
+ priority=request.priority,
+ trace_headers=request.trace_headers,
+ resumable=request.resumable,
+ external_req_id=request.external_req_id,
+ reasoning_ended=request.reasoning_ended,
+ additional_information=additional_information,
+ )
+
class OmniEngineCoreOutput(EngineCoreOutput):
pooling_output: dict[str, torch.Tensor] | None = None
diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py
index d98ce7d419d..c93bd32c2f1 100644
--- a/vllm_omni/engine/arg_utils.py
+++ b/vllm_omni/engine/arg_utils.py
@@ -6,11 +6,12 @@
from dataclasses import dataclass, field, fields
from typing import Any
-from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.logger import init_logger
from vllm_omni.config import OmniModelConfig
from vllm_omni.engine.output_modality import OutputModality
+from vllm_omni.platforms import current_omni_platform
from vllm_omni.plugins import load_omni_general_plugins
logger = init_logger(__name__)
@@ -139,11 +140,30 @@ class OmniEngineArgs(EngineArgs):
hf_config_name: str | None = None
custom_process_next_stage_input_func: str | None = None
stage_connector_spec: dict[str, Any] = field(default_factory=dict)
+ subtalker_sampling_params: dict[str, Any] | None = None
async_chunk: bool = False
omni_kv_config: dict | None = None
quantization_config: Any | None = None
worker_type: str | None = None
task_type: str | None = None
+ worker_cls: str = None
+ enable_sleep_mode: bool = False
+ omni: bool = False
+
+ @classmethod
+ def _add_omni_specific_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ try:
+ parser.add_argument("--omni", action="store_true", default=False, help="Enable Omni engine features.")
+ except argparse.ArgumentError:
+ pass
+ try:
+ parser.add_argument(
+ "--enable-sleep-mode", action="store_true", default=False, help="Enable GPU memory pool for sleep mode."
+ )
+ except argparse.ArgumentError:
+ pass
+ return parser
+
omni_master_address: str | None = None
omni_master_port: int | None = None
stage_configs_path: str | None = None
@@ -152,6 +172,11 @@ class OmniEngineArgs(EngineArgs):
custom_pipeline_args: dict[str, Any] | None = None
def __post_init__(self) -> None:
+ if self.worker_cls is None:
+ if self.worker_type == "ar":
+ self.worker_cls = current_omni_platform.get_omni_ar_worker_cls()
+ elif self.worker_type == "generation":
+ self.worker_cls = current_omni_platform.get_omni_generation_worker_cls()
load_omni_general_plugins()
super().__post_init__()
@@ -291,11 +316,21 @@ def create_model_config(self) -> OmniModelConfig:
hf_config_name=self.hf_config_name,
custom_process_next_stage_input_func=self.custom_process_next_stage_input_func,
stage_connector_config=stage_connector_config,
+ subtalker_sampling_params=self.subtalker_sampling_params,
omni_kv_config=self.omni_kv_config,
task_type=self.task_type,
)
return omni_config
+
+@dataclass
+class OmniAsyncEngineArgs(AsyncEngineArgs, OmniEngineArgs):
+ @classmethod
+ def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ parser = AsyncEngineArgs.add_cli_args(parser)
+ parser = OmniEngineArgs._add_omni_specific_args(parser)
+ return parser
+# NOTE(review): ``output_modality`` below was an OmniEngineArgs member; after this insertion it appears to bind to OmniAsyncEngineArgs - confirm and move this class below it.
@property
def output_modality(self) -> OutputModality:
"""Parse engine_output_type into a type-safe OutputModality flag."""
diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py
index d02af153c45..32450bf2024 100644
--- a/vllm_omni/engine/async_omni_engine.py
+++ b/vllm_omni/engine/async_omni_engine.py
@@ -30,7 +30,6 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.logger import init_logger
-from vllm.tokenizers import cached_tokenizer_from_config
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor
@@ -46,7 +45,6 @@
)
from vllm_omni.engine import OmniEngineCoreRequest
from vllm_omni.engine.orchestrator import Orchestrator
-from vllm_omni.engine.output_processor import MultimodalOutputProcessor
from vllm_omni.engine.serialization import (
deserialize_additional_information,
serialize_additional_information,
@@ -63,17 +61,17 @@
register_stage_with_omni_master,
)
from vllm_omni.engine.stage_init_utils import (
- StartedLlmStage,
+ LogicalStageInitPlan,
+ ReplicaInitPlan,
_inject_inferred_kv_tp_topology,
acquire_device_locks,
build_diffusion_config,
build_engine_args_dict,
+ build_llm_stage_output_processor,
+ build_stage0_input_processor,
build_vllm_config,
- cleanup_failed_stage_initialization,
- close_started_llm_stage,
compute_replica_layout,
extract_stage_metadata,
- finalize_initialized_stages,
get_stage_connector_spec,
initialize_diffusion_stage,
inject_kv_stage_info,
@@ -89,7 +87,6 @@
inject_omni_kv_config,
load_and_resolve_stage_configs,
)
-from vllm_omni.inputs.preprocess import OmniInputPreprocessor
from vllm_omni.platforms import current_omni_platform
if TYPE_CHECKING:
@@ -97,6 +94,8 @@
logger = init_logger(__name__)
+_STARTUP_POLL_INTERVAL_S = 1.0
+
# ============================================================================
# Parent-EngineArgs field-routing contracts (consumed by
@@ -119,16 +118,6 @@
_PARENT_ARGS_NO_WARN: frozenset[str] = frozenset({"model"})
-def _patch_generation_config_if_needed(model_config: Any) -> None:
- """Ensure try_get_generation_config won't crash for models whose HF
- config.json lacks model_type (e.g. CosyVoice3). We probe it once;
- if it raises, we monkey-patch the method to return None."""
- try:
- model_config.try_get_generation_config()
- except Exception:
- model_config.try_get_generation_config = lambda: {}
-
-
def _inject_global_id(target: Any, request_id: str) -> None:
"""Inject global_request_id into a prompt dict's additional_information."""
if isinstance(target, dict):
@@ -161,24 +150,9 @@ def _upgrade_to_omni_request(
if prompt_embeds is None and additional_information is None:
return request
- return OmniEngineCoreRequest(
- request_id=request.request_id,
- prompt_token_ids=request.prompt_token_ids,
- mm_features=request.mm_features,
- sampling_params=request.sampling_params,
- pooling_params=request.pooling_params,
- arrival_time=request.arrival_time,
- lora_request=request.lora_request,
- cache_salt=request.cache_salt,
- data_parallel_rank=request.data_parallel_rank,
+ return OmniEngineCoreRequest.from_request(
+ request,
prompt_embeds=prompt_embeds,
- client_index=request.client_index,
- current_wave=request.current_wave,
- priority=request.priority,
- trace_headers=request.trace_headers,
- resumable=request.resumable,
- external_req_id=request.external_req_id,
- reasoning_ended=request.reasoning_ended,
additional_information=additional_information,
)
@@ -193,24 +167,8 @@ def _apply_omni_final_stage_metadata(
merged = deserialize_additional_information(request.additional_information)
merged["omni_final_stage_id"] = final_stage_id
payload = serialize_additional_information(merged)
- return OmniEngineCoreRequest(
- request_id=request.request_id,
- prompt_token_ids=request.prompt_token_ids,
- mm_features=request.mm_features,
- sampling_params=request.sampling_params,
- pooling_params=request.pooling_params,
- arrival_time=request.arrival_time,
- lora_request=request.lora_request,
- cache_salt=request.cache_salt,
- data_parallel_rank=request.data_parallel_rank,
- prompt_embeds=request.prompt_embeds,
- client_index=request.client_index,
- current_wave=request.current_wave,
- priority=request.priority,
- trace_headers=request.trace_headers,
- resumable=request.resumable,
- external_req_id=request.external_req_id,
- reasoning_ended=request.reasoning_ended,
+ return OmniEngineCoreRequest.from_request(
+ request,
additional_information=payload,
)
@@ -311,6 +269,7 @@ def __init__(
)
self.config_path, self.stage_configs = self._resolve_stage_configs(model, kwargs)
+ self._validate_single_stage_mode_replica_constraints()
self.num_stages = len(self.stage_configs)
stage0_args = getattr(self.stage_configs[0], "engine_args", None) if self.num_stages > 0 else None
@@ -343,22 +302,7 @@ def __init__(
name="orchestrator",
)
self.orchestrator_thread.start()
-
- # Wait for stage/runtime initialization result from orchestrator thread.
- try:
- startup_future.result(timeout=startup_timeout)
- except concurrent.futures.TimeoutError as e:
- try:
- self.shutdown()
- except Exception:
- logger.exception("[AsyncOmniEngine] Failed to cleanup after orchestrator startup timeout")
- raise TimeoutError(f"Orchestrator did not become ready within {startup_timeout}s") from e
- except Exception:
- try:
- self.shutdown()
- except Exception:
- logger.exception("[AsyncOmniEngine] Failed to cleanup after orchestrator startup failure")
- raise
+ self._wait_for_orchestrator_init(startup_future, startup_timeout)
# Stage runtime fields are assigned directly on self by the bootstrap thread.
self._weak_finalizer = weakref.finalize(
@@ -372,613 +316,658 @@ def __init__(
logger.info(f"[AsyncOmniEngine] Orchestrator ready with {self.num_stages} stages")
- def _launch_llm_stage(
+ @staticmethod
+ def _cleanup_launched_llm_resources(
+ *,
+ stage_id: int,
+ proc: Any = None,
+ engine_manager: Any = None,
+ coordinator: Any = None,
+ ) -> None:
+ """Release launch-only LLM resources when client creation never completed."""
+
+ if proc is not None:
+ try:
+ terminate_alive_proc(proc)
+ except Exception as cleanup_error:
+ logger.warning(
+ "[AsyncOmniEngine] Failed to terminate process for stage %s: %s",
+ stage_id,
+ cleanup_error,
+ )
+
+ for resource, resource_name in (
+ (engine_manager, "engine manager"),
+ (coordinator, "coordinator"),
+ ):
+ if resource is None:
+ continue
+ shutdown = getattr(resource, "shutdown", None)
+ close = getattr(resource, "close", None)
+ try:
+ if callable(shutdown):
+ shutdown()
+ elif callable(close):
+ close()
+ except Exception as cleanup_error:
+ logger.warning(
+ "[AsyncOmniEngine] Failed to cleanup launched %s for stage %s: %s",
+ resource_name,
+ stage_id,
+ cleanup_error,
+ )
+
+ @staticmethod
+ def _collect_initialized_clients_for_cleanup(
+ stage_pools: Sequence[Any],
+ initialized_clients_by_stage: Mapping[int, Sequence[Any | None]],
+ ) -> list[Any]:
+ """Collect initialized clients exactly once for failure cleanup."""
+
+ collected: list[Any] = []
+ seen: set[int] = set()
+
+ def _add_client(client: Any) -> None:
+ if client is None:
+ return
+ client_id = id(client)
+ if client_id in seen:
+ return
+ seen.add(client_id)
+ collected.append(client)
+
+ for pool in stage_pools:
+ for client in getattr(pool, "clients", ()):
+ _add_client(client)
+
+ for clients in initialized_clients_by_stage.values():
+ for client in clients:
+ _add_client(client)
+
+ return collected
+
+ @staticmethod
+ def _shutdown_initialized_clients(clients: Sequence[Any]) -> None:
+ """Best-effort shutdown for attached clients after init failure."""
+
+ for client in reversed(list(clients)):
+ if client is None:
+ continue
+ try:
+ client.shutdown()
+ except Exception as cleanup_error:
+ logger.warning(
+ "[AsyncOmniEngine] Failed to shutdown initialized client after init failure: %s",
+ cleanup_error,
+ )
+
+ def _validate_single_stage_mode_replica_constraints(self) -> None:
+ """Reject replica fan-out in single-stage mode until startup is replica-aware."""
+ if not self.single_stage_mode:
+ return
+
+ unsupported: list[tuple[int, int]] = []
+ for idx, stage_cfg in enumerate(self.stage_configs):
+ runtime_cfg = getattr(stage_cfg, "runtime", {})
+ num_replicas = int(
+ runtime_cfg.get("num_replicas", 1)
+ if hasattr(runtime_cfg, "get")
+ else getattr(runtime_cfg, "num_replicas", 1)
+ )
+ if num_replicas <= 1:
+ continue
+ stage_id = int(getattr(stage_cfg, "stage_id", idx))
+ unsupported.append((stage_id, num_replicas))
+
+ if unsupported:
+ # TODO(Peiqi): single_stage_mode / headless launch still assumes one
+ # OmniMasterServer endpoint allocation per logical stage. Support
+ # per-replica startup only after remote/local startup paths are made
+ # replica-aware end-to-end.
+ raise ValueError(f"single_stage_mode does not support num_replicas > 1 yet; found {unsupported}")
+
+ def _build_logical_stage_init_plans(
+ self,
+ omni_transfer_config: Any,
+ replicas_per_stage: Sequence[int],
+ replica_devices_map: Mapping[int, Sequence[str]],
+ ) -> tuple[list[LogicalStageInitPlan], Any]:
+ """Build startup plans for every logical stage and replica."""
+
+ prompt_expand_func = None
+ stage_plans: list[LogicalStageInitPlan] = []
+
+ for stage_idx, stage_cfg in enumerate(self.stage_configs):
+ base_metadata = extract_stage_metadata(stage_cfg)
+ configured_stage_id = base_metadata.stage_id
+ if base_metadata.prompt_expand_func is not None:
+ prompt_expand_func = base_metadata.prompt_expand_func
+
+ stage_connector_spec = get_stage_connector_spec(
+ omni_transfer_config=omni_transfer_config,
+ stage_id=configured_stage_id,
+ async_chunk=self.async_chunk,
+ )
+ omni_kv_connector = resolve_omni_kv_config_for_stage(omni_transfer_config, configured_stage_id)
+ num_replicas = replicas_per_stage[stage_idx]
+ launch_mode = "local"
+ if (
+ self.single_stage_mode
+ and self._single_stage_id_filter is not None
+ and configured_stage_id != self._single_stage_id_filter
+ ):
+ launch_mode = "remote"
+
+ replicas: list[ReplicaInitPlan] = []
+ stage_vllm_config = None
+ executor_class = None
+ if base_metadata.stage_type != "diffusion":
+ engine_args_dict = build_engine_args_dict(
+ stage_cfg,
+ self.model,
+ stage_connector_spec=stage_connector_spec,
+ )
+ omni_conn_cfg, omni_from, omni_to = omni_kv_connector
+ if omni_conn_cfg:
+ omni_kv = engine_args_dict.get("omni_kv_config") or {}
+ if not isinstance(omni_kv, dict):
+ omni_kv = dict(omni_kv)
+ omni_kv["connector_config"] = omni_conn_cfg
+ omni_kv["omni_from_stage"] = omni_from
+ omni_kv["omni_to_stage"] = omni_to
+ omni_kv.setdefault("stage_id", configured_stage_id)
+ engine_args_dict["omni_kv_config"] = omni_kv
+ if self.stage_configs:
+ _inject_inferred_kv_tp_topology(
+ engine_args_dict.get("omni_kv_config"),
+ configured_stage_id,
+ self.stage_configs,
+ )
+ stage_vllm_config, executor_class = build_vllm_config(
+ stage_cfg,
+ self.model,
+ stage_connector_spec=stage_connector_spec,
+ engine_args_dict=engine_args_dict,
+ )
+
+ for replica_id in range(num_replicas):
+ replica_cfg = copy.deepcopy(stage_cfg) if replica_id > 0 else stage_cfg
+ if stage_idx in replica_devices_map:
+ replica_cfg.runtime.devices = replica_devices_map[stage_idx][replica_id]
+
+ replica_metadata = extract_stage_metadata(replica_cfg)
+ replica_metadata.replica_id = replica_id
+ if self.single_stage_mode:
+ replica_metadata.runtime_cfg = None
+
+ replicas.append(
+ ReplicaInitPlan(
+ replica_id=replica_id,
+ num_replicas=num_replicas,
+ launch_mode=launch_mode,
+ stage_cfg=replica_cfg,
+ metadata=replica_metadata,
+ stage_connector_spec=stage_connector_spec,
+ omni_kv_connector=omni_kv_connector,
+ stage_vllm_config=stage_vllm_config,
+ executor_class=executor_class,
+ )
+ )
+
+ stage_plans.append(
+ LogicalStageInitPlan(
+ stage_idx=stage_idx,
+ configured_stage_id=configured_stage_id,
+ replicas=replicas,
+ )
+ )
+
+ return stage_plans, prompt_expand_func
+
+ def _start_omni_master_server(self, stage_plans: Sequence[LogicalStageInitPlan]) -> None:
+ """Bring up the OmniMasterServer used by single-stage mode.
+
+ Gathers ``configured_stage_id`` from every plan in plan order, rejects
+ duplicates, then constructs the server, stores it on
+ ``self._omni_master_server`` and starts it.
+
+ Raises:
+ ValueError: if the master address or port is unset/falsy, or if two
+ plans share the same stage id.
+ """
+
+ master_address = self._omni_master_address
+ master_port = self._omni_master_port
+ if not (master_address and master_port):
+ raise ValueError(
+ "AsyncOmniEngine single_stage_mode requires both omni_master_address and omni_master_port to be set."
+ )
+
+ # Keep first-seen order for the server while using a set for the uniqueness check.
+ registered_ids: set[int] = set()
+ ordered_ids: list[int] = []
+ for stage_plan in stage_plans:
+ candidate_id = stage_plan.configured_stage_id
+ if candidate_id in registered_ids:
+ raise ValueError(
+ f"Duplicate stage_id {candidate_id!r} detected among configured stages; stage_ids must be unique."
+ )
+ ordered_ids.append(candidate_id)
+ registered_ids.add(candidate_id)
+
+ master_server = OmniMasterServer(
+ master_address=master_address,
+ master_port=master_port,
+ stage_ids=ordered_ids,
+ )
+ # Publish the server on self before starting it, as the original did.
+ self._omni_master_server = master_server
+ master_server.start()
+ logger.info(
+ "[AsyncOmniEngine] OmniMasterServer started for stages %s",
+ ordered_ids,
+ )
+
+ def _initialize_llm_replica(
self,
- stage_cfg: Any,
- metadata: Any,
- stage_connector_spec: dict[str, Any],
+ plan: ReplicaInitPlan,
stage_init_timeout: int,
llm_stage_launch_lock: threading.Lock,
- omni_kv_connector: tuple[dict[str, Any] | None, str | None, str | None] = (None, None, None),
- ) -> StartedLlmStage:
- """Launch one LLM stage to READY state in a helper thread."""
- started_stage: StartedLlmStage | None = None
+ ) -> Any:
+ """Initialize one LLM replica end-to-end.
+
+ Depending on ``plan.launch_mode`` this either attaches to a remote
+ engine core through the OmniMasterServer ("remote") or launches the
+ engine core from this process, and then wraps the result in a client
+ built by ``StageEngineCoreClientBase.make_async_mp_client``.
+
+ Args:
+ plan: Per-replica startup plan. Its ``stage_vllm_config`` and
+ ``executor_class`` must already be set (asserted below).
+ stage_init_timeout: Timeout handed to the remote stage-config
+ lookup, to device-lock acquisition and to the local handshake.
+ llm_stage_launch_lock: Lock taken around the local launch section
+ that adjusts and restores the device-control environment variable.
+
+ Returns:
+ The stage client created for this replica.
+
+ Raises:
+ ValueError: if a remote stage registered without a stage config.
+ Exception: any other launch/attach failure, re-raised after the
+ partially started resources have been cleaned up.
+ """
+
+ proc = None
+ engine_manager = None
+ coordinator = None
+ stage_client = None
lock_fds: list[int] = []
device_control_env = current_omni_platform.device_control_env_var
+ stage_cfg = plan.stage_cfg
+
try:
- proc = None
- handshake_address = None
- with ExitStack() as launch_stack:
- with llm_stage_launch_lock:
- previous_visible_devices = os.environ.get(device_control_env)
- try:
- setup_stage_devices(metadata.stage_id, metadata.runtime_cfg)
- engine_args_dict = build_engine_args_dict(
- stage_cfg,
- self.model,
- stage_connector_spec=stage_connector_spec,
- )
- omni_conn_cfg, omni_from, omni_to = omni_kv_connector
- if omni_conn_cfg:
- omni_kv = engine_args_dict.get("omni_kv_config") or {}
- if not isinstance(omni_kv, dict):
- omni_kv = dict(omni_kv)
- omni_kv["connector_config"] = omni_conn_cfg
- omni_kv["omni_from_stage"] = omni_from
- omni_kv["omni_to_stage"] = omni_to
- omni_kv.setdefault("stage_id", metadata.stage_id)
- engine_args_dict["omni_kv_config"] = omni_kv
- if self.stage_configs:
- _inject_inferred_kv_tp_topology(
- engine_args_dict.get("omni_kv_config"),
- metadata.stage_id,
- self.stage_configs,
+ if plan.launch_mode == "remote":
+ assert self._omni_master_server is not None
+ raw_stage_cfg = self._omni_master_server.get_stage_config(
+ plan.metadata.stage_id,
+ timeout_s=stage_init_timeout,
+ )
+ # The fetched config is only checked for presence here; the
+ # vLLM config precomputed on the plan is what gets used below.
+ if raw_stage_cfg is None:
+ raise ValueError(f"Remote stage {plan.metadata.stage_id} registered without stage config")
+ vllm_config = plan.stage_vllm_config
+ executor_class = plan.executor_class
+ assert vllm_config is not None
+ assert executor_class is not None
+ # This writes to the config object stored on the plan.
+ # NOTE(review): presumably marks that no engine cores are local
+ # to this process for a remote stage - confirm against vLLM.
+ vllm_config.parallel_config.data_parallel_size_local = 0
+ launch_cm = connect_remote_engine_cores(
+ vllm_config=vllm_config,
+ omni_master_server=self._omni_master_server,
+ stage_id=plan.metadata.stage_id,
+ )
+ logger.info(
+ "[AsyncOmniEngine] Stage %s remote engine handshake started",
+ plan.metadata.stage_id,
+ )
+ with launch_cm as (engine_manager, coordinator, addresses):
+ client_addresses: dict[str, str] = {
+ "input_address": addresses.inputs[0],
+ "output_address": addresses.outputs[0],
+ }
+ if addresses.frontend_stats_publish_address is not None:
+ client_addresses["stats_update_address"] = addresses.frontend_stats_publish_address
+ stage_client = StageEngineCoreClientBase.make_async_mp_client(
+ vllm_config=vllm_config,
+ executor_class=executor_class,
+ metadata=plan.metadata,
+ client_addresses=client_addresses,
+ engine_manager=engine_manager,
+ coordinator=coordinator,
+ )
+ else:
+ handshake_address = None
+ with ExitStack() as launch_stack:
+ with llm_stage_launch_lock:
+ # Remember the current device-control value so the finally
+ # block below can put it back after the launch.
+ previous_visible_devices = os.environ.get(device_control_env)
+ try:
+ setup_stage_devices(plan.metadata.stage_id, plan.metadata.runtime_cfg)
+ vllm_config = plan.stage_vllm_config
+ executor_class = plan.executor_class
+ assert vllm_config is not None
+ assert executor_class is not None
+ # Rebuilt here only as the input to acquire_device_locks.
+ engine_args_dict = build_engine_args_dict(
+ stage_cfg,
+ self.model,
+ stage_connector_spec=plan.stage_connector_spec,
+ )
+ lock_fds = acquire_device_locks(
+ plan.metadata.stage_id,
+ engine_args_dict,
+ stage_init_timeout,
)
- vllm_config, executor_class = build_vllm_config(
- stage_cfg,
- self.model,
- stage_connector_spec=stage_connector_spec,
- engine_args_dict=engine_args_dict,
- )
- lock_fds = acquire_device_locks(
- metadata.stage_id,
- engine_args_dict,
- stage_init_timeout,
- )
- if self.single_stage_mode and self._omni_master_server is not None:
- engine_manager, coordinator, addresses = launch_stack.enter_context(
- launch_omni_core_engines(
+ if self.single_stage_mode and self._omni_master_server is not None:
+ engine_manager, coordinator, addresses = launch_stack.enter_context(
+ launch_omni_core_engines(
+ vllm_config=vllm_config,
+ executor_class=executor_class,
+ log_stats=False,
+ omni_master_server=self._omni_master_server,
+ stage_id=plan.metadata.stage_id,
+ stage_config=stage_cfg,
+ )
+ )
+ else:
+ addresses, proc, handshake_address = spawn_stage_core(
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
- omni_master_server=self._omni_master_server,
- stage_id=metadata.stage_id,
- stage_config=stage_cfg,
)
+ logger.info(
+ "[AsyncOmniEngine] Stage %s engine launch started",
+ plan.metadata.stage_id,
)
- started_stage = StartedLlmStage(
- stage_id=metadata.stage_id,
- metadata=metadata,
- vllm_config=vllm_config,
- executor_class=executor_class,
- addresses=addresses,
- engine_manager=engine_manager,
- coordinator=coordinator,
- )
- else:
- addresses, proc, handshake_address = spawn_stage_core(
- vllm_config=vllm_config,
- executor_class=executor_class,
- log_stats=False,
- )
- started_stage = StartedLlmStage(
- stage_id=metadata.stage_id,
- metadata=metadata,
- vllm_config=vllm_config,
- executor_class=executor_class,
- addresses=addresses,
- proc=proc,
- )
- logger.info("[AsyncOmniEngine] Stage %s engine launch started", metadata.stage_id)
- finally:
- if previous_visible_devices is None:
- current_omni_platform.unset_device_control_env_var()
- else:
- current_omni_platform.set_device_control_env_var(previous_visible_devices)
+ finally:
+ if previous_visible_devices is None:
+ current_omni_platform.unset_device_control_env_var()
+ else:
+ current_omni_platform.set_device_control_env_var(previous_visible_devices)
+
+ # Single-stage mode finishes by closing the stack that holds the
+ # launch_omni_core_engines context; otherwise the handshake with
+ # the spawned process is completed explicitly.
+ if self.single_stage_mode and self._omni_master_server is not None:
+ launch_stack.close()
+ else:
+ assert proc is not None
+ assert handshake_address is not None
+ complete_stage_handshake(proc, handshake_address, addresses, vllm_config, stage_init_timeout)
+ logger.info(
+ "[AsyncOmniEngine] Stage %s engine startup completed",
+ plan.metadata.stage_id,
+ )
+
+ client_addresses: dict[str, str] = {
+ "input_address": addresses.inputs[0],
+ "output_address": addresses.outputs[0],
+ }
+ if addresses.frontend_stats_publish_address is not None:
+ client_addresses["stats_update_address"] = addresses.frontend_stats_publish_address
+ stage_client = StageEngineCoreClientBase.make_async_mp_client(
+ vllm_config=vllm_config,
+ executor_class=executor_class,
+ metadata=plan.metadata,
+ client_addresses=client_addresses,
+ proc=proc,
+ engine_manager=engine_manager,
+ coordinator=coordinator,
+ )
- # After StageEngineCoreProc has been spawned it carries its
- # stage-specific device visibility into descendants, so the
- # slow HELLO/READY handshake can run without holding the
- # process-wide launch lock.
- if self.single_stage_mode and self._omni_master_server is not None:
- launch_stack.close()
- else:
- assert proc is not None
- assert handshake_address is not None
- complete_stage_handshake(proc, handshake_address, addresses, vllm_config, stage_init_timeout)
- logger.info("[AsyncOmniEngine] Stage %s engine startup completed", metadata.stage_id)
-
- assert started_stage is not None
- return started_stage
+ logger.info("[AsyncOmniEngine] Stage %s initialized", plan.metadata.stage_id)
+ return stage_client
except Exception:
- if started_stage is not None:
- close_started_llm_stage(started_stage)
+ # Once a client exists it is torn down through its own shutdown();
+ # before that point only the raw launch resources can be cleaned up.
+ if stage_client is not None:
+ try:
+ stage_client.shutdown()
+ except Exception as cleanup_error:
+ logger.warning(
+ "[AsyncOmniEngine] Failed to cleanup stage %s after attach failure: %s",
+ plan.metadata.stage_id,
+ cleanup_error,
+ )
+ else:
+ self._cleanup_launched_llm_resources(
+ stage_id=plan.metadata.stage_id,
+ proc=proc,
+ engine_manager=engine_manager,
+ coordinator=coordinator,
+ )
raise
finally:
if lock_fds:
release_device_locks(lock_fds)
- def _create_remote_llm_stage(
+ def _initialize_diffusion_replica(
self,
- stage_cfg: Any,
- metadata: Any,
- stage_connector_spec: dict[str, Any],
+ plan: ReplicaInitPlan,
stage_init_timeout: int,
- omni_master_server: OmniMasterServer,
- ) -> StartedLlmStage:
- """Attach to a remote engine core and wait for its startup handshake."""
- started_stage: StartedLlmStage | None = None
- try:
- raw_stage_cfg = omni_master_server.get_stage_config(
- metadata.stage_id,
- timeout_s=stage_init_timeout,
- )
- if raw_stage_cfg is None:
- raise ValueError(f"Remote stage {metadata.stage_id} registered without stage config")
- stage_cfg = OmegaConf.create(raw_stage_cfg)
- engine_args_dict = build_engine_args_dict(
- stage_cfg,
- self.model,
- stage_connector_spec=stage_connector_spec,
- )
- vllm_config, executor_class = build_vllm_config(
- stage_cfg,
- self.model,
- stage_connector_spec=stage_connector_spec,
- engine_args_dict=engine_args_dict,
- )
- vllm_config.parallel_config.data_parallel_size_local = 0
- launch_cm = connect_remote_engine_cores(
- vllm_config=vllm_config,
- omni_master_server=omni_master_server,
- stage_id=metadata.stage_id,
- )
- logger.info("[AsyncOmniEngine] Stage %s remote engine handshake started", metadata.stage_id)
- with launch_cm as (engine_manager, coordinator, addresses):
- started_stage = StartedLlmStage(
- stage_id=metadata.stage_id,
- metadata=metadata,
- vllm_config=vllm_config,
- executor_class=executor_class,
- engine_manager=engine_manager,
- coordinator=coordinator,
- addresses=addresses,
- )
- logger.info("[AsyncOmniEngine] Stage %s remote engine startup completed", metadata.stage_id)
- assert started_stage is not None
- return started_stage
- except Exception:
- if started_stage is not None:
- close_started_llm_stage(started_stage)
- raise
+ stage_launch_lock: threading.Lock,
+ ) -> Any:
+ """Initialize one diffusion replica end-to-end.
+
+ Remote plans attach to a stage that registered itself with the
+ OmniMasterServer. Local plans launch the stage while holding
+ ``stage_launch_lock`` and restore the device-control environment
+ variable afterwards.
+
+ Args:
+ plan: Per-replica startup plan (stage config, metadata, KV connector).
+ stage_init_timeout: Timeout for the remote stage-config lookup and
+ for the non-single-stage local initialization.
+ stage_launch_lock: Lock taken around the local launch section.
+
+ Returns:
+ The diffusion stage client for this replica.
+
+ Raises:
+ ValueError: if a remote stage registered without a stage config.
+ Exception: any other failure, re-raised after terminating the
+ spawned process if one exists.
+ """
- def _launch_diffusion_stage(
- self,
- stage_cfg: Any,
- metadata: Any,
- omni_master_server: OmniMasterServer,
- ) -> StageDiffusionClient:
- """Launch a local diffusion stage on OmniMasterServer-allocated sockets."""
+ client = None
proc = None
try:
- od_config = build_diffusion_config(self.model, stage_cfg, metadata)
- handshake_address, request_address, response_address = register_stage_with_omni_master(
- omni_master_address=omni_master_server.address,
- omni_master_port=omni_master_server.port,
- omni_stage_id=metadata.stage_id,
- omni_stage_config=stage_cfg,
- return_addresses=True,
- )
- logger.info(
- "[AsyncOmniEngine] Stage %s diffusion registration completed",
- metadata.stage_id,
- )
- proc, _, _, _ = spawn_diffusion_proc(
- self.model,
- od_config,
- handshake_address=handshake_address,
- request_address=request_address,
- response_address=response_address,
- )
- complete_diffusion_handshake(proc, handshake_address)
+ if plan.launch_mode == "remote":
+ assert self._omni_master_server is not None
+ raw_stage_cfg = self._omni_master_server.get_stage_config(
+ plan.metadata.stage_id,
+ timeout_s=stage_init_timeout,
+ )
+ # Raise the same explicit error as the remote LLM path instead of
+ # handing a missing config to OmegaConf and failing later.
+ if raw_stage_cfg is None:
+ raise ValueError(f"Remote stage {plan.metadata.stage_id} registered without stage config")
+ remote_stage_cfg = OmegaConf.create(raw_stage_cfg)
+ # The client is built from the metadata the remote stage
+ # registered, not from the local plan's metadata.
+ remote_metadata = extract_stage_metadata(remote_stage_cfg)
+ addresses = self._omni_master_server.get_zmq_addresses(plan.metadata.stage_id)
+ logger.info(
+ "[AsyncOmniEngine] Stage %s remote diffusion startup completed",
+ plan.metadata.stage_id,
+ )
+ client = StageDiffusionClient.from_addresses(
+ remote_metadata,
+ request_address=addresses.inputs[0],
+ response_address=addresses.outputs[0],
+ batch_size=self.diffusion_batch_size,
+ )
+ else:
+ device_control_env = current_omni_platform.device_control_env_var
+ with stage_launch_lock:
+ # Remember the current device-control value so the finally
+ # block below can put it back after the launch.
+ previous_visible_devices = os.environ.get(device_control_env)
+ try:
+ setup_stage_devices(plan.metadata.stage_id, plan.metadata.runtime_cfg)
+ omni_conn_cfg, omni_from, omni_to = plan.omni_kv_connector
+ if omni_conn_cfg:
+ inject_omni_kv_config(plan.stage_cfg, omni_conn_cfg, omni_from, omni_to)
+ inject_kv_stage_info(plan.stage_cfg, plan.metadata.stage_id, self.stage_configs)
+ if self.single_stage_mode:
+ assert self._omni_master_server is not None
+ od_config = build_diffusion_config(self.model, plan.stage_cfg, plan.metadata)
+ handshake_address, request_address, response_address = register_stage_with_omni_master(
+ omni_master_address=self._omni_master_server.address,
+ omni_master_port=self._omni_master_server.port,
+ omni_stage_id=plan.metadata.stage_id,
+ omni_stage_config=plan.stage_cfg,
+ return_addresses=True,
+ )
+ logger.info(
+ "[AsyncOmniEngine] Stage %s diffusion registration completed",
+ plan.metadata.stage_id,
+ )
+ proc, _, _, _ = spawn_diffusion_proc(
+ self.model,
+ od_config,
+ handshake_address=handshake_address,
+ request_address=request_address,
+ response_address=response_address,
+ )
+ complete_diffusion_handshake(proc, handshake_address)
+ logger.info(
+ "[AsyncOmniEngine] Stage %s diffusion startup completed",
+ plan.metadata.stage_id,
+ )
+ client = StageDiffusionClient.from_addresses(
+ plan.metadata,
+ request_address=request_address,
+ response_address=response_address,
+ proc=proc,
+ batch_size=self.diffusion_batch_size,
+ )
+ else:
+ client = initialize_diffusion_stage(
+ plan.metadata.stage_id,
+ self.model,
+ plan.stage_cfg,
+ plan.metadata,
+ stage_init_timeout=stage_init_timeout,
+ batch_size=self.diffusion_batch_size,
+ use_inline=self.num_stages == 1 and plan.num_replicas == 1,
+ )
+ finally:
+ if previous_visible_devices is None:
+ current_omni_platform.unset_device_control_env_var()
+ else:
+ current_omni_platform.set_device_control_env_var(previous_visible_devices)
+
logger.info(
- "[AsyncOmniEngine] Stage %s diffusion startup completed",
- metadata.stage_id,
- )
- return StageDiffusionClient.from_addresses(
- metadata,
- request_address=request_address,
- response_address=response_address,
- proc=proc,
- batch_size=self.diffusion_batch_size,
+ "[AsyncOmniEngine] Stage %s replica %s initialized (diffusion, batch_size=%d, devices=%s)",
+ plan.metadata.stage_id,
+ plan.replica_id,
+ self.diffusion_batch_size,
+ getattr(getattr(plan.stage_cfg, "runtime", None), "devices", "default"),
)
+ return client
except Exception:
if proc is not None:
terminate_alive_proc(proc)
raise
- def _create_remote_diffusion_stage(
+ def _initialize_replica(
self,
- metadata: Any,
+ plan: ReplicaInitPlan,
stage_init_timeout: int,
- omni_master_server: OmniMasterServer,
- ) -> StageDiffusionClient:
- """Attach to a remote diffusion stage registered with OmniMasterServer."""
- remote_stage_cfg = OmegaConf.create(
- omni_master_server.get_stage_config(
- metadata.stage_id,
- timeout_s=stage_init_timeout,
- )
- )
- remote_metadata = extract_stage_metadata(remote_stage_cfg)
- addresses = omni_master_server.get_zmq_addresses(metadata.stage_id)
- logger.info(
- "[AsyncOmniEngine] Stage %s remote diffusion startup completed",
- metadata.stage_id,
- )
- return StageDiffusionClient.from_addresses(
- remote_metadata,
- request_address=addresses.inputs[0],
- response_address=addresses.outputs[0],
- batch_size=self.diffusion_batch_size,
- )
+ stage_launch_lock: threading.Lock,
+ ) -> Any:
+ """Dispatch a replica plan to the initializer for its backend.
+
+ Plans whose ``metadata.stage_type`` is ``"diffusion"`` go to
+ ``_initialize_diffusion_replica``; every other stage type goes to
+ ``_initialize_llm_replica``. The timeout and launch lock are passed
+ through unchanged, and the chosen initializer's client is returned.
+ """
- def _attach_llm_stage(
+ is_diffusion = plan.metadata.stage_type == "diffusion"
+ initialize = self._initialize_diffusion_replica if is_diffusion else self._initialize_llm_replica
+ return initialize(plan, stage_init_timeout, stage_launch_lock)
+
+ def _initialize_stage_replicas(
self,
- started: StartedLlmStage,
- ) -> tuple[Any, Any | None, Any, InputProcessor | None]:
- """Attach a READY LLM stage to the orchestrator event loop."""
+ stage_plans: Sequence[LogicalStageInitPlan],
+ stage_init_timeout: int,
+ ) -> dict[int, list[Any | None]]:
+ """Initialize all stage replicas in parallel.
+
+ One ``_initialize_replica`` task is submitted per replica to a thread
+ pool sized to the total replica count; all tasks share one launch lock.
+
+ Args:
+ stage_plans: Logical stage plans whose replicas should be started.
+ stage_init_timeout: Timeout forwarded to every replica initializer.
+
+ Returns:
+ Mapping from ``stage_idx`` to a list of clients indexed by
+ ``replica_id``.
+
+ Raises:
+ Exception: the first replica failure observed. Before re-raising,
+ every remaining launch is allowed to finish, and the clients that
+ did initialize are attached to the exception as
+ ``_initialized_clients_by_stage`` so the caller can shut them down.
+ """
- client_addresses: dict[str, str] = {
- "input_address": started.addresses.inputs[0],
- "output_address": started.addresses.outputs[0],
+ stage_launch_lock = threading.Lock()
+ initialized_clients_by_stage: dict[int, list[Any | None]] = {
+ plan.stage_idx: [None] * len(plan.replicas) for plan in stage_plans
}
- if started.addresses.frontend_stats_publish_address is not None:
- client_addresses["stats_update_address"] = started.addresses.frontend_stats_publish_address
+ total_replicas = sum(len(plan.replicas) for plan in stage_plans)
+ future_to_replica: dict[concurrent.futures.Future[Any], tuple[int, int]] = {}
+
+ with concurrent.futures.ThreadPoolExecutor(
+ max_workers=max(1, total_replicas),
+ thread_name_prefix="stage-init",
+ ) as init_executor:
+ for plan in stage_plans:
+ for replica in plan.replicas:
+ future = init_executor.submit(
+ self._initialize_replica,
+ replica,
+ stage_init_timeout,
+ stage_launch_lock,
+ )
+ future_to_replica[future] = (plan.stage_idx, replica.replica_id)
- try:
- stage_client = StageEngineCoreClientBase.make_async_mp_client(
- vllm_config=started.vllm_config,
- executor_class=started.executor_class,
- metadata=started.metadata,
- client_addresses=client_addresses,
- proc=started.proc,
- engine_manager=started.engine_manager,
- coordinator=started.coordinator,
- )
- started.proc = None
- started.engine_manager = None
- started.coordinator = None
- except Exception:
- close_started_llm_stage(started)
- raise
+ try:
+ for future in concurrent.futures.as_completed(future_to_replica):
+ stage_idx, replica_id = future_to_replica[future]
+ initialized_clients_by_stage[stage_idx][replica_id] = future.result()
+ except Exception as exc:
+ # Replicas still launching when the first failure surfaced would be
+ # skipped by the harvest below and then finish during executor
+ # shutdown with nobody holding their client, leaking the engines.
+ # Let every in-flight launch settle first; the executor would
+ # block on them on exit anyway, so this adds no extra wait.
+ concurrent.futures.wait(future_to_replica)
+ for future, (stage_idx, replica_id) in future_to_replica.items():
+ if not future.done() or future.cancelled() or future.exception() is not None:
+ continue
+ if initialized_clients_by_stage[stage_idx][replica_id] is None:
+ initialized_clients_by_stage[stage_idx][replica_id] = future.result()
+ # Hand the successfully started clients to the caller for cleanup.
+ setattr(exc, "_initialized_clients_by_stage", initialized_clients_by_stage)
+ raise
- try:
+ return initialized_clients_by_stage
+
+ def _assemble_stage_pools(
+ self,
+ stage_plans: Sequence[LogicalStageInitPlan],
+ initialized_clients_by_stage: Mapping[int, Sequence[Any | None]],
+ ) -> list[StagePool]:
+ """Assemble logical stage pools and update top-level stage metadata.
+
+ For each plan a ``StagePool`` is built from the plan's non-``None``
+ replica clients. Non-diffusion stages additionally get an output
+ processor and the first replica's vLLM config. The first replica's
+ client supplies the stage's default sampling params and its metadata
+ entry.
+
+ Side effects:
+ Sets ``self.default_sampling_params_list`` and ``self.stage_metadata``.
+
+ Returns:
+ The stage pools, in plan order.
+
+ Raises:
+ RuntimeError: if a stage has no clients or its first replica's
+ client is missing.
+ """
+
+ stage_pools: list[StagePool] = []
+ default_sampling_params_list: list[Any] = []
+ stage_metadata_list: list[dict[str, Any]] = []
+
+ for plan in stage_plans:
+ replica_clients = initialized_clients_by_stage[plan.stage_idx]
+ # Replica 0 is the representative client for per-stage metadata.
+ first_client = replica_clients[0] if replica_clients else None
+ if first_client is None:
+ raise RuntimeError(f"Stage {plan.stage_idx} initialization completed with a missing client")
+
+ clients = [client for client in replica_clients if client is not None]
+ stage_vllm_config = None
output_processor = None
- if getattr(started.metadata, "replica_id", 0) == 0:
- if started.vllm_config.model_config.skip_tokenizer_init:
- tokenizer = None
- else:
- tokenizer = cached_tokenizer_from_config(
- model_config=started.vllm_config.model_config,
- )
- output_processor = MultimodalOutputProcessor(
- tokenizer=tokenizer,
- log_stats=False,
- engine_core_output_type=started.metadata.engine_output_type,
- )
- input_processor = None
- if started.stage_id == 0:
- # Some omni models (e.g. CosyVoice3) have an empty HF
- # config.json without model_type, which causes
- # try_get_generation_config -> AutoConfig.from_pretrained
- # to raise ValueError. Patch it to return None so
- # InputProcessor doesn't crash.
- _patch_generation_config_if_needed(started.vllm_config.model_config)
- input_processor = InputProcessor(vllm_config=started.vllm_config)
- # Use omni preprocessor so text-only prompts with
- # mm_processor_kwargs (e.g. GLM-Image t2i target_h/target_w)
- # still go through multimodal processor path.
- input_processor.input_preprocessor = OmniInputPreprocessor(
- vllm_config=started.vllm_config,
- renderer=input_processor.renderer,
- )
- except Exception:
- try:
- stage_client.shutdown()
- except Exception as cleanup_error:
- logger.warning(
- "[AsyncOmniEngine] Failed to cleanup stage %s after attach failure: %s",
- started.stage_id,
- cleanup_error,
+ # Diffusion stages keep both values as None; only other stage types
+ # carry a vLLM config and an output processor.
+ if plan.replicas[0].metadata.stage_type != "diffusion":
+ stage_vllm_config = plan.replicas[0].stage_vllm_config
+ assert stage_vllm_config is not None
+ output_processor = build_llm_stage_output_processor(plan, stage_vllm_config)
+
+ stage_pools.append(
+ StagePool(
+ plan.stage_idx,
+ clients,
+ output_processor=output_processor,
+ stage_vllm_config=stage_vllm_config,
)
- raise
+ )
+ default_sampling_params_list.append(first_client.default_sampling_params)
+ stage_metadata_list.append(
+ {
+ "final_output": first_client.final_output,
+ "final_output_type": first_client.final_output_type,
+ "stage_type": first_client.stage_type,
+ }
+ )
- logger.info("[AsyncOmniEngine] Stage %s initialized", started.stage_id)
- return stage_client, output_processor, started.vllm_config, input_processor
+ # Assigned only after every stage assembled successfully.
+ self.default_sampling_params_list = list(default_sampling_params_list)
+ self.stage_metadata = list(stage_metadata_list)
+ return stage_pools
def _initialize_stages(self, stage_init_timeout: int) -> None:
"""Initialize stage clients/processors in orchestrator thread and assign to self.
Phases:
1. Compute replica layout (counts + device splits).
- 2. Launch all stage engine processes (parallel via ThreadPoolExecutor).
- 3. Attach launched engines (parallel) and collect clients/processors.
- 4. Build StagePool list and finalize stage metadata.
+ 2. Build per-stage/per-replica startup plans.
+ 3. Initialize all replicas in parallel via backend-specific launchers.
+ 4. Build logical StagePools and finalize runtime metadata.
TODO(stage-pool): move per-stage launch + attach logic into a
StagePool.build_from_config() classmethod so this method only
iterates stage_configs, collects pools, and finalizes metadata.
"""
- device_control_env = current_omni_platform.device_control_env_var
- num_stages = self.num_stages
-
- replicas_per_stage, replica_devices_map, total_llm_replicas = compute_replica_layout(self.stage_configs)
+ num_stages = len(self.stage_configs)
+ self.num_stages = num_stages
+ self._validate_single_stage_mode_replica_constraints()
- input_processor: InputProcessor | None = None
- llm_stage_ids: list[int] = []
- llm_launch_futures: dict[int, list[concurrent.futures.Future[StartedLlmStage]]] = {}
- started_llm_stages: dict[int, list[StartedLlmStage]] = {}
- llm_stage_launch_lock = threading.Lock()
- diffusion_clients: dict[int, Any] = {}
- prompt_expand_func = None
- async_chunk = self.async_chunk
+ replicas_per_stage, replica_devices_map = compute_replica_layout(self.stage_configs)
prepare_engine_environment()
omni_transfer_config = load_omni_transfer_config_for_model(self.model, self.config_path)
+ stage_plans, prompt_expand_func = self._build_logical_stage_init_plans(
+ omni_transfer_config,
+ replicas_per_stage,
+ replica_devices_map,
+ )
+ if self.single_stage_mode:
+ self._start_omni_master_server(stage_plans)
stage_pools: list[StagePool] = []
-
- # ------------------------------------------------------------------ #
- # Single-stage mode: start OmniMasterServer before launching stages. #
- # ------------------------------------------------------------------ #
- if self.single_stage_mode:
- if not self._omni_master_address or not self._omni_master_port:
- raise ValueError(
- "AsyncOmniEngine single_stage_mode requires both "
- "omni_master_address and omni_master_port to be set."
- )
- # Collect all configured stage IDs for pre-allocation.
- all_stage_ids: list[int] = []
- seen_stage_ids: set[int] = set()
- for i, sc in enumerate(self.stage_configs):
- stage_id = int(getattr(sc, "stage_id", i))
- if stage_id in seen_stage_ids:
- raise ValueError(
- f"Duplicate stage_id {stage_id!r} detected among configured stages; stage_ids must be unique."
- )
- seen_stage_ids.add(stage_id)
- all_stage_ids.append(stage_id)
- self._omni_master_server = OmniMasterServer(
- master_address=self._omni_master_address,
- master_port=self._omni_master_port,
- stage_ids=all_stage_ids,
- )
- self._omni_master_server.start()
- logger.info(
- "[AsyncOmniEngine] OmniMasterServer started for stages %s",
- all_stage_ids,
- )
+ input_processor: InputProcessor | None = None
+ initialized_clients_by_stage: dict[int, list[Any | None]] = {
+ plan.stage_idx: [None] * len(plan.replicas) for plan in stage_plans
+ }
try:
- with concurrent.futures.ThreadPoolExecutor(
- max_workers=max(1, total_llm_replicas),
- thread_name_prefix="llm-stage-launch",
- ) as launch_executor:
- for stage_idx, stage_cfg in enumerate(self.stage_configs):
- metadata = extract_stage_metadata(stage_cfg)
- configured_stage_id = metadata.stage_id
- logger.info("[AsyncOmniEngine] Initializing stage %s", configured_stage_id)
- if metadata.prompt_expand_func is not None:
- prompt_expand_func = metadata.prompt_expand_func
-
- if self.single_stage_mode:
- metadata.runtime_cfg = None
-
- stage_connector_spec = get_stage_connector_spec(
- omni_transfer_config=omni_transfer_config,
- stage_id=configured_stage_id,
- async_chunk=async_chunk,
- )
-
- omni_kv_connector = resolve_omni_kv_config_for_stage(omni_transfer_config, configured_stage_id)
-
- if metadata.stage_type == "diffusion":
- is_remote_diffusion_stage = (
- self.single_stage_mode
- and self._single_stage_id_filter is not None
- and configured_stage_id != self._single_stage_id_filter
- )
- if is_remote_diffusion_stage:
- assert self._omni_master_server is not None
- diffusion_clients[stage_idx] = self._create_remote_diffusion_stage(
- metadata,
- stage_init_timeout,
- self._omni_master_server,
- )
- continue
-
- with llm_stage_launch_lock:
- previous_visible_devices = os.environ.get(device_control_env)
- try:
- setup_stage_devices(configured_stage_id, metadata.runtime_cfg)
- omni_conn_cfg, omni_from, omni_to = omni_kv_connector
- if omni_conn_cfg:
- inject_omni_kv_config(stage_cfg, omni_conn_cfg, omni_from, omni_to)
- inject_kv_stage_info(stage_cfg, configured_stage_id, self.stage_configs)
- if self.single_stage_mode:
- assert self._omni_master_server is not None
- diffusion_clients[stage_idx] = self._launch_diffusion_stage(
- stage_cfg,
- metadata,
- self._omni_master_server,
- )
- else:
- use_inline = True if self.num_stages == 1 else False
- diffusion_clients[stage_idx] = initialize_diffusion_stage(
- self.model,
- stage_cfg,
- metadata,
- stage_init_timeout=stage_init_timeout,
- batch_size=self.diffusion_batch_size,
- use_inline=use_inline,
- )
- logger.info(
- "[AsyncOmniEngine] Stage %s initialized (diffusion, batch_size=%d)",
- configured_stage_id,
- self.diffusion_batch_size,
- )
- finally:
- if previous_visible_devices is None:
- current_omni_platform.unset_device_control_env_var()
- else:
- current_omni_platform.set_device_control_env_var(previous_visible_devices)
- continue
-
- # Submit one launch future per replica
- llm_stage_ids.append(stage_idx)
- num_replicas = replicas_per_stage[stage_idx]
-
- # single_stage_mode pre-dates multi-replica support; when
- # a stage runs remotely (doesn't match the local filter)
- # replica fan-out is delegated to the remote process, so
- # we launch exactly one client future per such stage.
- # TODO: support remote multi-replica with stage-pool
- if (
- self.single_stage_mode
- and self._single_stage_id_filter is not None
- and configured_stage_id != self._single_stage_id_filter
- ):
- assert self._omni_master_server is not None
- if num_replicas > 1:
- logger.warning(
- "[AsyncOmniEngine] Stage %s has num_replicas=%d but runs remotely in "
- "single_stage_mode; only one remote client will be created.",
- configured_stage_id,
- num_replicas,
- )
- llm_launch_futures[stage_idx] = [
- launch_executor.submit(
- self._create_remote_llm_stage,
- stage_cfg,
- metadata,
- stage_connector_spec,
- stage_init_timeout,
- self._omni_master_server,
- )
- ]
- continue
-
- stage_futures: list[concurrent.futures.Future[StartedLlmStage]] = []
- for replica_id in range(num_replicas):
- # For replica > 0, deep-copy stage_cfg and override devices
- if replica_id > 0:
- replica_cfg = copy.deepcopy(stage_cfg)
- else:
- replica_cfg = stage_cfg
-
- if stage_idx in replica_devices_map:
- replica_cfg.runtime.devices = replica_devices_map[stage_idx][replica_id]
-
- replica_metadata = extract_stage_metadata(replica_cfg)
- replica_metadata.replica_id = replica_id
-
- logger.info(
- "[AsyncOmniEngine] Launching stage %s replica %s (devices=%s)",
- configured_stage_id,
- replica_id,
- getattr(getattr(replica_cfg, "runtime", None), "devices", "default"),
- )
-
- stage_futures.append(
- launch_executor.submit(
- self._launch_llm_stage,
- replica_cfg,
- replica_metadata,
- stage_connector_spec,
- stage_init_timeout,
- llm_stage_launch_lock,
- omni_kv_connector,
- )
- )
-
- llm_launch_futures[stage_idx] = stage_futures
-
- # Wait for all futures across all stages
- all_futures = [f for futures in llm_launch_futures.values() for f in futures]
- concurrent.futures.wait(all_futures)
-
- for stage_idx in llm_stage_ids:
- started_llm_stages[stage_idx] = [f.result() for f in llm_launch_futures[stage_idx]]
-
- # ---- Parallel attach across (stage_idx, replica_id) pairs ----
- attach_futures: dict[
- concurrent.futures.Future[tuple[Any, Any, Any, InputProcessor | None]],
- tuple[int, int],
- ] = {}
- total_replicas_to_attach = sum(len(started_llm_stages[s]) for s in llm_stage_ids)
- with concurrent.futures.ThreadPoolExecutor(
- max_workers=max(1, total_replicas_to_attach),
- thread_name_prefix="llm-stage-attach",
- ) as attach_executor:
- for stage_idx in llm_stage_ids:
- for replica_id, started in enumerate(started_llm_stages[stage_idx]):
- attach_futures[attach_executor.submit(self._attach_llm_stage, started)] = (
- stage_idx,
- replica_id,
- )
-
- stage_attach_results: dict[int, list[Any | None]] = {
- s: [None] * len(started_llm_stages[s]) for s in llm_stage_ids
- }
- stage_output_proc_results: dict[int, Any | None] = {s: None for s in llm_stage_ids}
- stage_vllm_cfg_results: dict[int, Any | None] = {s: None for s in llm_stage_ids}
-
- for future in concurrent.futures.as_completed(attach_futures):
- stage_idx, replica_id = attach_futures[future]
- stage_client, output_processor, vllm_config, stage0_input_processor = future.result()
- stage_attach_results[stage_idx][replica_id] = stage_client
- if stage_output_proc_results[stage_idx] is None and output_processor is not None:
- stage_output_proc_results[stage_idx] = output_processor
- if stage_vllm_cfg_results[stage_idx] is None:
- stage_vllm_cfg_results[stage_idx] = vllm_config
- if stage0_input_processor is not None:
- input_processor = stage0_input_processor
-
- # ---- Build StagePool list + finalize metadata ----
- # Use first replica's client per stage for finalize (default sampling params, metadata).
- logical_stage_clients_for_finalize: list[Any | None] = [None] * num_stages
- for stage_idx in llm_stage_ids:
- logical_stage_clients_for_finalize[stage_idx] = stage_attach_results[stage_idx][0]
- for stage_idx, diff_client in diffusion_clients.items():
- logical_stage_clients_for_finalize[stage_idx] = diff_client
-
- _, default_sampling_params_list, stage_metadata_list = finalize_initialized_stages(
- logical_stage_clients_for_finalize,
- input_processor,
+ initialized_clients_by_stage = self._initialize_stage_replicas(stage_plans, stage_init_timeout)
+ if stage_plans and stage_plans[0].replicas[0].metadata.stage_type != "diffusion":
+ stage0_vllm_config = stage_plans[0].replicas[0].stage_vllm_config
+ assert stage0_vllm_config is not None
+ input_processor = build_stage0_input_processor(stage0_vllm_config)
+ stage_pools = self._assemble_stage_pools(stage_plans, initialized_clients_by_stage)
+ except Exception as exc:
+ initialized_clients_by_stage = getattr(
+ exc,
+ "_initialized_clients_by_stage",
+ initialized_clients_by_stage,
+ )
+ cleanup_clients = self._collect_initialized_clients_for_cleanup(
+ stage_pools,
+ initialized_clients_by_stage,
)
-
- for stage_id in range(num_stages):
- if stage_id in diffusion_clients:
- stage_pools.append(StagePool(stage_id, diffusion_clients[stage_id]))
- else:
- stage_pools.append(
- StagePool(
- stage_id,
- stage_attach_results[stage_id],
- output_processor=stage_output_proc_results[stage_id],
- stage_vllm_config=stage_vllm_cfg_results[stage_id],
- )
- )
-
- except Exception:
- for stage_id, futures in llm_launch_futures.items():
- for f in futures:
- if not f.done() or f.cancelled() or f.exception() is not None:
- continue
- started_llm_stages.setdefault(stage_id, []).append(f.result())
- # Collect all initialized clients for cleanup
- cleanup_clients: list[Any] = list(diffusion_clients.values())
- for pool in stage_pools:
- for client in pool.clients:
- if client is not None:
- cleanup_clients.append(client)
- all_started = [s for stages in started_llm_stages.values() for s in stages]
logger.exception(
"[AsyncOmniEngine] Stage initialization failed; shutting down %s initialized client(s)",
len(cleanup_clients),
)
- cleanup_failed_stage_initialization(cleanup_clients, all_started)
+ self._shutdown_initialized_clients(cleanup_clients)
if self._omni_master_server is not None:
try:
self._omni_master_server.stop()
@@ -991,21 +980,18 @@ def _initialize_stages(self, stage_init_timeout: int) -> None:
self.prompt_expand_func = prompt_expand_func
# Derive logical-stage views for external readers (entrypoints/async_omni.py).
- self.stage_clients = [pool.stage_client for pool in stage_pools]
- self.stage_vllm_configs = [pool.stage_vllm_config for pool in stage_pools]
- self.output_processors = [pool.output_processor for pool in stage_pools]
+ self.stage_clients = [pool.stage_client for pool in self.stage_pools]
+ self.stage_vllm_configs = [pool.stage_vllm_config for pool in self.stage_pools]
+ self.output_processors = [pool.output_processor for pool in self.stage_pools]
# TODO(Peiqi): Hack here
supported_tasks: set[str] = set()
- if any(getattr(pool.stage_client, "is_comprehension", False) for pool in stage_pools):
+ if any(getattr(pool.stage_client, "is_comprehension", False) for pool in self.stage_pools):
supported_tasks.add("generate")
- if any(m.get("final_output_type") == "audio" for m in stage_metadata_list):
+ if any(m.get("final_output_type") == "audio" for m in self.stage_metadata):
supported_tasks.add("speech")
self.supported_tasks = tuple(supported_tasks) if supported_tasks else ("generate",)
- self.default_sampling_params_list = list(default_sampling_params_list)
- self.stage_metadata = list(stage_metadata_list)
-
def _initialize_janus_queues(self) -> None:
"""Initialize janus queues inside orchestrator thread loop context."""
self.request_queue = janus.Queue()
@@ -1044,13 +1030,17 @@ async def _run_orchestrator() -> None:
loop.run_until_complete(_run_orchestrator())
except Exception as e:
if not startup_future.done():
- startup_future.set_exception(RuntimeError(f"Orchestrator initialization failed: {e}"))
+ wrapped = RuntimeError(f"Orchestrator initialization failed: {e}")
+ wrapped.__cause__ = e
+ startup_future.set_exception(wrapped)
logger.exception("[AsyncOmniEngine] Orchestrator thread crashed")
+ error_text = str(e) or "Orchestrator thread crashed"
try:
+ error_msg = {"type": "error", "error": error_text, "fatal": True}
if self.output_queue is not None:
- self.output_queue.sync_q.put_nowait({"type": "error", "error": "Orchestrator thread crashed"})
+ self.output_queue.sync_q.put_nowait(error_msg)
if self.rpc_output_queue is not None:
- self.rpc_output_queue.sync_q.put_nowait({"type": "error", "error": "Orchestrator thread crashed"})
+ self.rpc_output_queue.sync_q.put_nowait(error_msg)
except Exception:
pass
raise
@@ -1070,6 +1060,31 @@ async def _run_orchestrator() -> None:
asyncio.set_event_loop(None)
loop.close()
+ def _wait_for_orchestrator_init(self, startup_future: concurrent.futures.Future, startup_timeout: int) -> None:
+ """
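+ Block until the orchestrator startup future completes. Raises TimeoutError if it is not ready within startup_timeout seconds, RuntimeError if the orchestrator thread dies during startup, or re-raises the startup failure; a best-effort shutdown is attempted before raising.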
+ """
+ deadline = time.monotonic() + startup_timeout
+ while True:
+ remaining = deadline - time.monotonic()
+ if remaining <= 0:
+ self._try_shutdown("[AsyncOmniEngine] Failed to cleanup after orchestrator startup timeout")
+ raise TimeoutError(f"Orchestrator did not become ready within {startup_timeout}s")
+ try:
+ startup_future.result(
+ timeout=min(remaining, _STARTUP_POLL_INTERVAL_S),
+ )
+ break
+ except concurrent.futures.TimeoutError:
+ if not self.orchestrator_thread.is_alive():
+ self._try_shutdown("[AsyncOmniEngine] Failed to cleanup after orchestrator startup failure")
+ if startup_future.done():
+ startup_future.result() # re-raises the real exception
+ raise RuntimeError("Orchestrator thread died during startup")
+ except Exception:
+ self._try_shutdown("[AsyncOmniEngine] Failed to cleanup after orchestrator startup failure")
+ raise
+
# ---- request helpers ----
def _build_add_request_message(
@@ -1493,6 +1508,12 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st
# Inject diffusion LoRA-related knobs from kwargs if not present in the stage config.
for cfg in stage_configs:
try:
+ if not hasattr(cfg, "engine_args") or cfg.engine_args is None:
+ cfg.engine_args = OmegaConf.create({})
+ global_sleep_mode = kwargs.get("enable_sleep_mode")
+ if global_sleep_mode is not None:
+ if not hasattr(cfg.engine_args, "enable_sleep_mode") or cfg.engine_args.enable_sleep_mode is None:
+ cfg.engine_args.enable_sleep_mode = global_sleep_mode
if getattr(cfg, "stage_type", None) != "diffusion":
continue
if not hasattr(cfg, "engine_args") or cfg.engine_args is None:
@@ -1814,3 +1835,9 @@ def shutdown(self) -> None:
except Exception:
logger.exception("[AsyncOmniEngine] Failed to stop OmniMasterServer during shutdown")
self._omni_master_server = None
+
+ def _try_shutdown(self, *args, **kwargs) -> None:
+ try:
+ self.shutdown()
+ except Exception:
+ logger.exception(*args, **kwargs)
diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py
index 81ff960116b..67843c19c99 100644
--- a/vllm_omni/engine/orchestrator.py
+++ b/vllm_omni/engine/orchestrator.py
@@ -20,6 +20,7 @@
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import EngineCoreOutputs
+from vllm.v1.engine.exceptions import EngineDeadError
from vllm_omni.engine import OmniEngineCoreRequest
from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker
@@ -140,6 +141,7 @@ def __init__(
self._shutdown_event = asyncio.Event()
self._stages_shutdown = False
+ self._fatal_error: str | None = None
async def run(self) -> None:
"""Main entry point for the Orchestrator event loop."""
@@ -168,6 +170,12 @@ async def run(self) -> None:
except Exception:
pass
+ # If a fatal error caused the shutdown, drain any pending
+ # add_request messages that were never processed and broadcast
+ # fatal error responses so callers are not left hanging.
+ if self._fatal_error is not None:
+ await self._drain_pending_requests_on_fatal()
+
self._shutdown_stages()
loop = asyncio.get_running_loop()
@@ -349,8 +357,10 @@ async def _handle_collective_rpc(self, msg: dict[str, Any]) -> None:
target_pools.extend(self.stage_pools)
else:
for lid in requested_stage_ids:
- if 0 <= lid < self.num_stages:
- target_pools.append(self.stage_pools[lid])
+ if not (0 <= lid < self.num_stages):
+ logger.warning("[Orchestrator] collective_rpc: ignoring invalid stage_id %s", lid)
+ continue
+ target_pools.append(self.stage_pools[lid])
results: list[Any] = []
stage_ids: list[int] = []
@@ -404,22 +414,55 @@ async def _orchestration_loop(self) -> None:
await self._handle_processed_outputs(stage_id, replica_id, [output])
idle = False
else:
- raw_outputs = await pool.poll_llm_raw_output(replica_id, timeout_s=0.001)
- if raw_outputs is None:
- continue
-
- await self._handle_kv_ready_raw_outputs(stage_id, raw_outputs)
- for eco in raw_outputs.outputs:
- req_state = self.request_states.get(getattr(eco, "request_id", None))
- if req_state is None:
+ try:
+ raw_outputs = await pool.poll_llm_raw_output(replica_id, timeout_s=0.001)
+ if raw_outputs is None:
continue
- req_state.streaming.segment_finished = bool(getattr(eco, "is_segment_finished", False))
- req_state.streaming.new_prompt_len_snapshot = getattr(
- eco,
- "new_prompt_len_snapshot",
- None,
+
+ await self._handle_kv_ready_raw_outputs(stage_id, raw_outputs)
+ for eco in raw_outputs.outputs:
+ req_state = self.request_states.get(getattr(eco, "request_id", None))
+ if req_state is None:
+ continue
+ req_state.streaming.segment_finished = bool(getattr(eco, "is_segment_finished", False))
+ req_state.streaming.new_prompt_len_snapshot = getattr(
+ eco,
+ "new_prompt_len_snapshot",
+ None,
+ )
+ raw_output = await pool.process_llm_raw_outputs(replica_id, raw_outputs)
+ except asyncio.CancelledError:
+ raise
+ except EngineDeadError as e:
+ logger.error(
+ "[Orchestrator] Stage-%s is dead: %s",
+ stage_id,
+ e,
)
- raw_output = await pool.process_llm_raw_outputs(replica_id, raw_outputs)
+ self._fatal_error = str(e)
+ for req_id, req_state in list(self.request_states.items()):
+ if stage_id in req_state.stage_submit_ts:
+ await self.output_async_queue.put(
+ {
+ "type": "error",
+ "error": str(e),
+ "fatal": True,
+ "request_id": req_id,
+ }
+ )
+ self.request_states.pop(req_id, None)
+ self._shutdown_event.set()
+ raise
+ except Exception:
+ if self._shutdown_event.is_set():
+ return
+ logger.exception(
+ "[Orchestrator] Stage-%s replica-%s processing failed",
+ stage_id,
+ replica_id,
+ )
+ raise
+
await self._handle_processed_outputs(stage_id, replica_id, raw_output)
idle = False
@@ -713,11 +756,22 @@ async def _forward_to_next_stage(
if next_pool.stage_type == "diffusion":
if next_client.custom_process_input_func is not None:
+ _t_ar2d = _time.perf_counter()
diffusion_prompt = next_client.custom_process_input_func(
source_outputs,
req_state.prompt,
requires_multimodal_data,
)
+ _dt_ar2d = (_time.perf_counter() - _t_ar2d) * 1000
+ logger.info(
+ "[Orchestrator] ar2diffusion req=%s wall_time=%.3fms stage=%d->%d",
+ req_id,
+ _dt_ar2d,
+ src_stage_id,
+ next_logical,
+ )
+ if already_submitted and isinstance(diffusion_prompt, list) and len(diffusion_prompt) == 1:
+ diffusion_prompt = diffusion_prompt[0]
else:
diffusion_prompt = req_state.prompt
@@ -920,6 +974,53 @@ def _build_kv_sender_info(
# ---- Shutdown / lifecycle ----
+ async def _drain_pending_requests_on_fatal(self) -> None:
+ """Drain the request queue and broadcast fatal errors for any
+ pending add_request messages that were never processed.
+
+ Called from the ``run()`` finally block when a fatal error
+ (e.g. ``EngineDeadError``) caused the orchestrator to shut down
+ before the request handler could process all queued messages.
+ Also broadcasts for any already-tracked requests still in
+ ``request_states`` that were not yet notified.
+ """
+ assert self._fatal_error is not None
+
+ notified: set[str] = set()
+
+ # 1) Drain pending messages from the request queue.
+ while True:
+ try:
+ msg = self.request_async_queue.get_nowait()
+ except Exception:
+ break
+ if msg.get("type") == "add_request":
+ req_id = msg["request_id"]
+ await self.output_async_queue.put(
+ {
+ "type": "error",
+ "error": self._fatal_error,
+ "fatal": True,
+ "request_id": req_id,
+ }
+ )
+ notified.add(req_id)
+
+ # 2) Broadcast for any tracked requests not already notified
+ # (e.g. request was registered but the EngineDeadError handler
+ # missed it because it wasn't submitted to the dead stage yet).
+ for req_id in list(self.request_states):
+ if req_id not in notified:
+ await self.output_async_queue.put(
+ {
+ "type": "error",
+ "error": self._fatal_error,
+ "fatal": True,
+ "request_id": req_id,
+ }
+ )
+ self.request_states.pop(req_id, None)
+
def _shutdown_stages(self) -> None:
"""Shutdown all stage pools."""
if self._stages_shutdown:
diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py
index 67b4dd16504..5d4626f10f9 100644
--- a/vllm_omni/engine/output_processor.py
+++ b/vllm_omni/engine/output_processor.py
@@ -118,10 +118,9 @@ def _consolidate_multimodal_tensors(self) -> None:
if isinstance(v, list) and v and isinstance(v[0], torch.Tensor):
try:
if k == "audio":
- # Concatenate delta audio chunks (1-D) into the full waveform.
- # Each entry is a per-step slice; flatten to -1 so chunks with
- # inconsistent leading dims can still be joined on the sample axis.
- self.mm_accumulated[k] = torch.cat([t.reshape(-1) for t in v], dim=0)
+ # Audio chunks can have inconsistent shapes, in which case torch.cat fails,
+ # so audio is skipped here. NOTE(review): joining on dim -1 looks intended but is not done here - confirm.
+ continue
elif k == "sr":
# Sample rate is a constant scalar, keep last value.
self.mm_accumulated[k] = v[-1]
@@ -332,6 +331,24 @@ def add_request(
self.parent_requests[parent_req.request_id] = parent_req
self.external_req_ids[req_state.external_req_id].append(request_id)
+ def remove_request(self, request_id: str) -> None:
+ """Roll back the bookkeeping for one registered request; intended for requests that were never submitted."""
+ req_state = self.request_states.pop(request_id, None)
+ if req_state is None:
+ return
+
+ external_req_id = getattr(req_state, "external_req_id", None)
+ if external_req_id is not None:
+ request_ids = self.external_req_ids.get(external_req_id)
+ if request_ids is not None:
+ self.external_req_ids[external_req_id] = [rid for rid in request_ids if rid != request_id]
+ if not self.external_req_ids[external_req_id]:
+ self.external_req_ids.pop(external_req_id, None)
+
+ parent_req = getattr(req_state, "parent_req", None)
+ if parent_req is not None:
+ self.parent_requests.pop(parent_req.request_id, None)
+
def process_outputs(
self,
engine_core_outputs: list[EngineCoreOutput],
diff --git a/vllm_omni/engine/stage_engine_core_client.py b/vllm_omni/engine/stage_engine_core_client.py
index cf0e4241032..269b76b4948 100644
--- a/vllm_omni/engine/stage_engine_core_client.py
+++ b/vllm_omni/engine/stage_engine_core_client.py
@@ -7,13 +7,17 @@
from __future__ import annotations
import inspect
+import multiprocessing.connection
import socket
+import threading
+import weakref
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from vllm.logger import init_logger
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import AsyncMPClient, DPLBAsyncMPClient
+from vllm.v1.engine.exceptions import EngineDeadError
from vllm_omni.distributed.omni_connectors.utils.initialization import (
KV_TRANSFER_PORT_OFFSET,
@@ -66,6 +70,8 @@ class StageEngineCoreClientBase:
``engine_manager`` / ``coordinator`` pair created elsewhere.
"""
+ replica_id: int = 0
+
@staticmethod
def make_async_mp_client(
vllm_config: Any,
@@ -124,8 +130,10 @@ def __init__(
manage the process lifecycle on shutdown.
"""
# -------- Stage metadata (public fields used at runtime) --------
+ self.replica_id = 0
if metadata is not None:
self.stage_id = metadata.stage_id
+ self.replica_id = getattr(metadata, "replica_id", 0)
self.stage_type = metadata.stage_type
self.engine_output_type = metadata.engine_output_type
self.is_comprehension = metadata.is_comprehension
@@ -147,9 +155,10 @@ def __init__(
client_name = self.__class__.__name__
logger.info(
- "[%s] Stage-%s initializing EngineCore",
+ "[%s] stage-%s [rep-%s] initializing EngineCore",
client_name,
self.stage_id,
+ self.replica_id,
)
try:
super().__init__(
@@ -166,35 +175,92 @@ def __init__(
self.resources.coordinator = coordinator
except Exception:
logger.exception(
- "[%s] Stage-%s EngineCore init failed",
+ "[%s] stage-%s [rep-%s] EngineCore init failed",
client_name,
self.stage_id,
+ self.replica_id,
)
try:
self.shutdown()
except Exception as shutdown_error:
logger.warning(
- "[%s] Stage-%s cleanup after init failure failed: %s",
+ "[%s] stage-%s [rep-%s] cleanup after init failure failed: %s",
client_name,
self.stage_id,
+ self.replica_id,
shutdown_error,
)
raise
+
self._initialize_kv_sender_endpoint()
+
+ if self._proc is not None:
+ self._start_proc_monitor()
+
logger.info(
- "[%s] Stage-%s EngineCore running",
+ "[%s] stage-%s [rep-%s] EngineCore running",
client_name,
self.stage_id,
+ self.replica_id,
+ )
+
+ def _start_proc_monitor(self) -> None:
+ """Start a daemon thread that watches the subprocess sentinel.
+
+ When the subprocess dies without sending the ZMQ ``ENGINE_CORE_DEAD``
+ sentinel (e.g. SIGKILL, segfault, OOM-killer), this thread sets
+ ``resources.engine_dead`` so subsequent calls raise
+ ``EngineDeadError``.
+ """
+ proc = self._proc
+ resources_ref = weakref.ref(self.resources)
+ stage_id = self.stage_id
+ replica_id = self.replica_id
+
+ def _monitor() -> None:
+ try:
+ multiprocessing.connection.wait([proc.sentinel])
+ except Exception:
+ return
+ resources = resources_ref()
+ if resources is None or resources.engine_dead:
+ return
+ resources.engine_dead = True
+ logger.error(
+ "[StageEngineCoreClient] stage-%s [rep-%s] subprocess died unexpectedly (exit code %s).",
+ stage_id,
+ replica_id,
+ proc.exitcode,
+ )
+
+ t = threading.Thread(
+ target=_monitor,
+ daemon=True,
+ name=f"StageCoreProcMonitor-{stage_id}",
)
+ t.start()
+
+ def check_health(self) -> None:
+ """Raise ``EngineDeadError`` if the stage subprocess is dead.
+
+ Called by ``OmniBase.check_health()`` and transitively by the
+ ``/health`` HTTP endpoint.
+ """
+ if self.resources.engine_dead:
+ raise EngineDeadError(f"Stage-{self.stage_id} engine core is dead")
+ if self._proc is not None and not self._proc.is_alive():
+ self.resources.engine_dead = True
+ raise EngineDeadError(f"Stage-{self.stage_id} subprocess is not alive (exit code {self._proc.exitcode})")
# ==================== Overrides ====================
async def add_request_async(self, request: EngineCoreRequest) -> None:
"""Add request to the stage engine core."""
logger.info(
- "[%s] Stage-%s adding request: %s",
+ "[%s] stage-%s [rep-%s] add request: %s",
self.__class__.__name__,
self.stage_id,
+ self.replica_id,
request.request_id,
)
await super().add_request_async(request)
@@ -275,9 +341,10 @@ def _initialize_kv_sender_endpoint(self) -> None:
sender_port = int(base_port) + KV_TRANSFER_PORT_OFFSET + int(from_stage)
except (TypeError, ValueError):
logger.warning(
- "[StageEngineCoreClient] Stage-%s could not resolve sender_zmq_port "
+ "[StageEngineCoreClient] stage-%s [rep-%s] could not resolve sender_zmq_port "
"from base_port=%s and from_stage=%s",
self.stage_id,
+ self.replica_id,
base_port,
from_stage,
)
diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py
index 456f5a9244b..5cc9d12a196 100644
--- a/vllm_omni/engine/stage_init_utils.py
+++ b/vllm_omni/engine/stage_init_utils.py
@@ -19,19 +19,47 @@
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
+from vllm.tokenizers import cached_tokenizer_from_config
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.executor import Executor
+from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.engine.arg_utils import OmniEngineArgs
+from vllm_omni.engine.output_processor import MultimodalOutputProcessor
from vllm_omni.entrypoints.stage_utils import _to_dict, set_stage_devices
from vllm_omni.entrypoints.utils import filter_dataclass_kwargs, resolve_model_config_path
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams
+from vllm_omni.inputs.preprocess import OmniInputPreprocessor
from vllm_omni.platforms import current_omni_platform
logger = init_logger(__name__)
+@dataclass
+class ReplicaInitPlan:
+ """One concrete replica startup unit within a logical stage."""
+
+ replica_id: int
+ num_replicas: int
+ launch_mode: str
+ stage_cfg: Any
+ metadata: Any
+ stage_connector_spec: dict[str, Any]
+ omni_kv_connector: tuple[dict[str, Any] | None, str | None, str | None]
+ stage_vllm_config: Any | None = None
+ executor_class: type | None = None
+
+
+@dataclass
+class LogicalStageInitPlan:
+ """Startup plan for one logical stage."""
+
+ stage_idx: int
+ configured_stage_id: int
+ replicas: list[ReplicaInitPlan]
+
+
def _resolve_model_to_local_path(model: str) -> str:
"""Resolve an HF Hub model ID to a local cache path."""
if os.path.isdir(model):
@@ -83,6 +111,14 @@ def terminate_alive_proc(proc, timeout=5):
proc.kill()
+def patch_generation_config_if_needed(model_config: Any) -> None:
+ """Guard InputProcessor init for models whose config lacks model_type."""
+ try:
+ model_config.try_get_generation_config()
+ except Exception:
+ model_config.try_get_generation_config = lambda: {}
+
+
def resolve_worker_cls(engine_args: dict[str, Any]) -> None:
"""Resolve worker_cls from worker_type for non-diffusion stages."""
worker_type = engine_args.get("worker_type", None)
@@ -262,20 +298,6 @@ class StageMetadata:
replica_id: int = 0
-@dataclass
-class StartedLlmStage:
- """Resources for an LLM stage that has completed startup."""
-
- stage_id: int
- metadata: Any
- vllm_config: Any
- executor_class: type
- addresses: Any
- proc: Any = None
- engine_manager: Any = None
- coordinator: Any = None
-
-
def extract_stage_metadata(stage_config: Any) -> StageMetadata:
"""Pure data extraction from a stage_config object."""
stage_id: int = stage_config.stage_id
@@ -421,16 +443,36 @@ def get_stage_tp_size(stage_cfg: Any) -> int:
return int(getattr(engine_args, "tensor_parallel_size", 1) or 1)
+def get_stage_devices_per_replica(stage_cfg: Any) -> int:
+ """Return the number of devices consumed by one replica of *stage_cfg*."""
+ if getattr(stage_cfg, "stage_type", "llm") != "diffusion":
+ return get_stage_tp_size(stage_cfg)
+
+ parallel_config = _get_attr_or_item(getattr(stage_cfg, "engine_args", {}), "parallel_config")
+ if parallel_config is None:
+ return 1
+
+ world_size = _get_attr_or_item(parallel_config, "world_size")
+ if world_size is not None:
+ return max(1, int(world_size))
+
+ try:
+ from vllm_omni.diffusion.data import DiffusionParallelConfig
+
+ return max(1, int(DiffusionParallelConfig.from_dict(_to_dict(parallel_config)).world_size))
+ except Exception:
+ return 1
+
+
def compute_replica_layout(
stage_configs: Sequence[Any],
-) -> tuple[list[int], dict[int, list[str]], int]:
+) -> tuple[list[int], dict[int, list[str]]]:
"""Compute per-stage replica counts and device assignments.
Returns:
replicas_per_stage: num_replicas per logical stage.
replica_devices_map: stage_idx -> per-replica device strings
(only for stages with num_replicas > 1).
- total_llm_replicas: total LLM replica count across all stages.
"""
replicas_per_stage: list[int] = []
for stage_cfg in stage_configs:
@@ -451,25 +493,22 @@ def compute_replica_layout(
devices_str = (
runtime_cfg.get("devices") if hasattr(runtime_cfg, "get") else getattr(runtime_cfg, "devices", None)
)
- tp_size = get_stage_tp_size(stage_cfg)
+ devices_per_replica = get_stage_devices_per_replica(stage_cfg)
replica_devices_map[stage_id] = split_devices_for_replicas(
devices_str,
num_replicas,
- tp_size,
+ devices_per_replica,
stage_id,
)
logger.info(
- "[stage_init] Stage %s: %d replicas, tp=%d, devices split: %s",
+ "[stage_init] Stage %s: %d replicas, devices_per_replica=%d, devices split: %s",
stage_id,
num_replicas,
- tp_size,
+ devices_per_replica,
replica_devices_map[stage_id],
)
- total_llm_replicas = sum(
- replicas_per_stage[i] for i, cfg in enumerate(stage_configs) if getattr(cfg, "stage_type", "llm") != "diffusion"
- )
- return replicas_per_stage, replica_devices_map, total_llm_replicas
+ return replicas_per_stage, replica_devices_map
def setup_stage_devices(stage_id: int, runtime_cfg: Any) -> None:
@@ -556,6 +595,35 @@ def build_vllm_config(
return vllm_config, executor_class
+def build_llm_stage_output_processor(plan: LogicalStageInitPlan, stage_vllm_config: Any) -> Any | None:
+ """Build one output processor per logical LLM stage."""
+
+ metadata = plan.replicas[0].metadata
+ if stage_vllm_config.model_config.skip_tokenizer_init:
+ tokenizer = None
+ else:
+ tokenizer = cached_tokenizer_from_config(
+ model_config=stage_vllm_config.model_config,
+ )
+ return MultimodalOutputProcessor(
+ tokenizer=tokenizer,
+ log_stats=False,
+ engine_core_output_type=metadata.engine_output_type,
+ )
+
+
+def build_stage0_input_processor(stage_vllm_config: Any) -> InputProcessor:
+ """Build the shared stage-0 input processor."""
+
+ patch_generation_config_if_needed(stage_vllm_config.model_config)
+ input_processor = InputProcessor(vllm_config=stage_vllm_config)
+ input_processor.input_preprocessor = OmniInputPreprocessor(
+ vllm_config=stage_vllm_config,
+ renderer=input_processor.renderer,
+ )
+ return input_processor
+
+
def acquire_device_locks(
stage_id: int,
engine_args_dict: dict[str, Any],
@@ -749,6 +817,7 @@ def build_diffusion_config(
def initialize_diffusion_stage(
+ stage_id: int,
model: str,
stage_cfg: Any,
metadata: StageMetadata,
@@ -770,99 +839,14 @@ def initialize_diffusion_stage(
"""
from vllm_omni.diffusion.stage_diffusion_client import create_diffusion_client
+ engine_args = _to_dict(stage_cfg.engine_args)
+ engine_args.pop("stage_id", None)
+ od_config = OmniDiffusionConfig.from_kwargs(
+ stage_id=stage_id,
+ model=model,
+ **engine_args,
+ )
+ if metadata.cfg_kv_collect_func is not None:
+ od_config.cfg_kv_collect_func = metadata.cfg_kv_collect_func
od_config = build_diffusion_config(model, stage_cfg, metadata)
return create_diffusion_client(model, od_config, metadata, stage_init_timeout, batch_size, use_inline)
-
-
-def _shutdown_or_close_resource(resource: Any, resource_name: str, stage_id: int) -> None:
- """vLLM CoreEngineProcManager / coordinators use ``shutdown()``, not ``close()``."""
- if resource is None:
- return
- shutdown = getattr(resource, "shutdown", None)
- if callable(shutdown):
- try:
- shutdown()
- except Exception as cleanup_error:
- logger.warning(
- "[stage_init] Failed to shutdown launched %s for stage %s: %s",
- resource_name,
- stage_id,
- cleanup_error,
- )
- return
- close = getattr(resource, "close", None)
- if callable(close):
- try:
- close()
- except Exception as cleanup_error:
- logger.warning(
- "[stage_init] Failed to close launched %s for stage %s: %s",
- resource_name,
- stage_id,
- cleanup_error,
- )
-
-
-def close_started_llm_stage(started: StartedLlmStage) -> None:
- """Release resources owned by a launched stage that never attached."""
- if started.proc is not None:
- try:
- terminate_alive_proc(started.proc)
- except Exception as cleanup_error:
- logger.warning(
- "[stage_init] Failed to terminate process for stage %s: %s",
- started.stage_id,
- cleanup_error,
- )
- _shutdown_or_close_resource(started.engine_manager, "engine manager", started.stage_id)
- _shutdown_or_close_resource(started.coordinator, "coordinator", started.stage_id)
-
-
-def finalize_initialized_stages(
- stage_clients: list[Any | None],
- input_processor: InputProcessor | None,
-) -> tuple[list[Any], list[Any], list[dict[str, Any]]]:
- """Validate successful init and build runtime metadata lists."""
- if any(stage_client is None for stage_client in stage_clients):
- raise RuntimeError("Stage initialization completed with missing stage clients")
-
- initialized_stage_clients = [stage_client for stage_client in stage_clients if stage_client is not None]
- default_sampling_params_list = [stage_client.default_sampling_params for stage_client in initialized_stage_clients]
- stage_metadata = [
- {
- "final_output": stage_client.final_output,
- "final_output_type": stage_client.final_output_type,
- "stage_type": stage_client.stage_type,
- }
- for stage_client in initialized_stage_clients
- ]
-
- if not isinstance(input_processor, InputProcessor):
- has_llm_stage = any(metadata.get("stage_type") != "diffusion" for metadata in stage_metadata)
- if has_llm_stage:
- raise RuntimeError("Failed to initialize stage-0 InputProcessor for LLM pipeline")
-
- return initialized_stage_clients, default_sampling_params_list, stage_metadata
-
-
-def cleanup_failed_stage_initialization(
- stage_clients: list[Any | None],
- started_llm_stages: list[StartedLlmStage],
-) -> None:
- """Shutdown attached stages and close any launched-but-unattached engines."""
- for cleanup_stage_id, stage_client in reversed(list(enumerate(stage_clients))):
- if stage_client is None:
- continue
- try:
- stage_client.shutdown()
- except Exception as cleanup_error:
- logger.warning(
- "[stage_init] Failed to shutdown initialized stage %s after init failure: %s",
- cleanup_stage_id,
- cleanup_error,
- )
-
- for started in reversed(started_llm_stages):
- if stage_clients[started.stage_id] is not None:
- continue
- close_started_llm_stage(started)
diff --git a/vllm_omni/engine/stage_pool.py b/vllm_omni/engine/stage_pool.py
index b5cddb388e5..00bc20e8627 100644
--- a/vllm_omni/engine/stage_pool.py
+++ b/vllm_omni/engine/stage_pool.py
@@ -199,14 +199,34 @@ async def submit_initial(
request_id,
affinity_request_id=affinity_request_id,
)
- self.output_processor.add_request(
- request=request,
- prompt=prompt_text,
- parent_req=None,
- request_index=0,
- queue=None,
- )
- await self.clients[replica_id].add_request_async(request, **submit_kwargs)
+ try:
+ self.output_processor.add_request(
+ request=request,
+ prompt=prompt_text,
+ parent_req=None,
+ request_index=0,
+ queue=None,
+ )
+ except Exception:
+ self.release_binding(request_id)
+ raise
+
+ try:
+ await self.clients[replica_id].add_request_async(request, **submit_kwargs)
+ except Exception:
+ self.release_binding(request_id)
+ rollback = getattr(self.output_processor, "remove_request", None)
+ if callable(rollback):
+ try:
+ rollback(request_id)
+ except Exception as rollback_error:
+ logger.warning(
+ "[StagePool] Failed to rollback output processor state for req=%s stage-%s: %s",
+ request_id,
+ self.stage_id,
+ rollback_error,
+ )
+ raise
return replica_id
async def submit_update(
@@ -302,6 +322,7 @@ async def abort_requests(self, request_ids: list[str]) -> None:
for request_id in request_ids:
replica_id = self.get_bound_replica_id(request_id)
if replica_id is None:
+ logger.debug("[StagePool] abort: no binding for req=%s in stage-%s", request_id, self.stage_id)
continue
request_ids_by_replica.setdefault(replica_id, []).append(request_id)
diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py
index 9606cc80d0d..67425ea2e65 100644
--- a/vllm_omni/entrypoints/async_omni.py
+++ b/vllm_omni/entrypoints/async_omni.py
@@ -9,6 +9,7 @@
import asyncio
import time
+import uuid
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
@@ -24,10 +25,12 @@
from vllm.tasks import SupportedTask
from vllm.v1.engine.exceptions import EngineDeadError
+from vllm_omni.diffusion.data import OmniACK, OmniSleepTask, OmniWakeTask
from vllm_omni.entrypoints.client_request_state import ClientRequestState
from vllm_omni.entrypoints.omni_base import OmniBase
from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics
from vllm_omni.outputs import OmniRequestOutput
+from vllm_omni.platforms import current_omni_platform
if TYPE_CHECKING:
from vllm.inputs.preprocess import InputPreprocessor
@@ -40,6 +43,63 @@
_FINAL_OUTPUT_IDLE_SLEEP_S = 0.001
+class AsyncEventResolver:
+ """
+ A generic signal aggregator designed for synchronized handshakes in
+ distributed or multi-stage environments. Supports waiting for a specified
+ number (expected_count) of worker signals in both inline and multiprocess modes.
+ """
+
+ def __init__(self, orchestrator=None):
+ self._pending_tasks: dict[str, dict] = {}
+ self.orchestrator = orchestrator
+ self._lock = asyncio.Lock()
+
+ def watch_task(self, task_id: str, expected_count: int = 1) -> asyncio.Future:
+ loop = asyncio.get_running_loop()
+ fut = loop.create_future()
+ self._pending_tasks[task_id] = {
+ "future": fut,
+ "expected_count": expected_count,
+ "received": [],
+ "start_time": time.time(),
+ }
+ return fut
+
+ async def resolve(self, ack: OmniACK):
+ tid = getattr(ack, "task_id", None)
+
+ if tid is None and isinstance(ack, dict):
+ tid = ack.get("task_id")
+
+ async with self._lock:
+ task_info = self._pending_tasks.get(tid)
+ if task_info is None:
+ logger.warning(f"Received stray ACK for task_id {tid}. Task might have timed out.")
+ return
+
+ task_info["received"].append(ack)
+ current_count = len(task_info["received"])
+ expected = task_info["expected_count"]
+
+ orchestrator = self.orchestrator
+ if orchestrator and hasattr(orchestrator, "metrics") and orchestrator.metrics:
+ freed = getattr(ack, "freed_bytes", 0)
+ if freed == 0 and isinstance(ack, dict):
+ freed = ack.get("freed_bytes", 0)
+ orchestrator.metrics.record_vram_reclaimed(freed)
+
+ logger.info(f"[Resolver] Task {tid} progress: {current_count}/{expected} ACKs received.")
+
+ if current_count >= expected:
+ self._pending_tasks.pop(tid)
+ fut = task_info["future"]
+ if not fut.done():
+ elapsed = time.time() - task_info["start_time"]
+ logger.info(f"[Resolver] Task {tid} completed successfully in {elapsed:.2f}s.")
+ fut.set_result(task_info["received"])
+
+
class AsyncOmni(EngineClient, OmniBase):
"""Asynchronous unified entry point for multi-stage pipelines using AsyncOmniEngine.
@@ -76,7 +136,7 @@ def __init__(self, *args: Any, model: str = "", **kwargs: Any) -> None:
self._paused: bool = False
self._is_sleeping: bool = False
self.final_output_task: asyncio.Task | None = None
-
+ self.event_resolver = AsyncEventResolver(orchestrator=self)
self.config_path = self.engine.config_path
self.tts_max_instructions_length = kwargs.get("tts_max_instructions_length", None)
self.input_processor = self.engine.input_processor
@@ -433,6 +493,9 @@ async def _process_orchestrator_results(
stage_id = result.get("stage_id", 0)
+ if result.get("type") == "error" and result.get("fatal"):
+ raise EngineDeadError(result.get("error", ""))
+
# Check for errors
if "error" in result:
logger.error(
@@ -443,6 +506,8 @@ async def _process_orchestrator_results(
)
raise RuntimeError(result)
+ self._check_engine_output_error(result, request_id, stage_id)
+
# Process the result (constructs OmniRequestOutput)
output_to_yield = self._process_single_result(
result,
@@ -488,6 +553,22 @@ async def _final_output_loop():
await asyncio.sleep(_FINAL_OUTPUT_IDLE_SLEEP_S)
continue
+ if isinstance(msg, dict) and msg.get("type") == "ack":
+ ack_data = msg.get("ack")
+ tid = getattr(ack_data, "task_id", "unknown")
+ logger.info(f"[{self._name}] Intercepted wrapped ACK for task {tid}")
+ await self.event_resolver.resolve(ack_data)
+ continue
+ if isinstance(msg, OmniACK):
+ logger.info(f"[{self._name}] Intercepted raw ACK object: {msg.task_id}")
+ await self.event_resolver.resolve(msg)
+ continue
+ if hasattr(msg, "task_id"):
+ tid = getattr(msg, "task_id")
+ logger.info(f"[{self._name}] Intercepted task-ID object: {tid}")
+ await self.event_resolver.resolve(msg)
+ continue
+
should_continue, _, stage_id, req_state = self._handle_output_message(msg)
if should_continue:
continue
@@ -499,6 +580,16 @@ async def _final_output_loop():
except asyncio.CancelledError:
raise
+ except EngineDeadError as e:
+ logger.error("[AsyncOmni] Engine dead: %s", e)
+ for req_state in list(self.request_states.values()):
+ error_msg = {
+ "type": "error",
+ "error": str(e),
+ "fatal": True,
+ "request_id": req_state.request_id,
+ }
+ await req_state.queue.put(error_msg)
except Exception as e:
logger.exception("[AsyncOmni] final_output_loop failed.")
for req_state in list(self.request_states.values()):
@@ -654,21 +745,68 @@ async def reset_prefix_cache(
logger.warning("[AsyncOmni] reset_prefix_cache not yet supported with Orchestrator process")
return True
- async def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
- """Sleep all stages.
-
- Best-effort: unsupported stages will emit a TODO result.
- """
+ async def sleep(
+ self, stage_ids: list[int] | None = None, level: int = 2, mode: PauseMode = "abort"
+ ) -> list[OmniACK]:
+ self._final_output_handler()
+ if stage_ids is None:
+ stage_ids = list(range(len(self.engine.stage_clients)))
+ total_workers = 0
+ for sid in stage_ids:
+ client = self.engine.stage_clients[sid]
+ # Diffusion stages currently report a single summary ACK from rank 0,
+ # regardless of their tensor-parallel size, so they count as one worker.
+ if getattr(client, "stage_type", "") == "diffusion":
+ total_workers += 1
+ else:
+ config = self.engine.stage_vllm_configs[sid]
+ actual_tp = config.parallel_config.tensor_parallel_size if config else 1
+ total_workers += actual_tp
+
+ task_id = str(uuid.uuid4())
+ self.event_resolver.watch_task(task_id, expected_count=total_workers)
+ logger.info(f"[{self._name}] Sleep initiated (Task: {task_id}). Awaiting {total_workers} ACKs...")
+ task = OmniSleepTask(level=level, task_id=task_id)
+ rpc_results = await self.collective_rpc(method="handle_sleep_task", args=(task,), stage_ids=stage_ids)
+ final_acks = []
+ for stage_res in rpc_results:
+ worker_acks = stage_res if isinstance(stage_res, list) else [stage_res]
+ for ack in worker_acks:
+ if ack is not None:
+ await self.event_resolver.resolve(ack)
+ final_acks.append(ack)
self._is_sleeping = True
- await self.collective_rpc(method="sleep", args=(level,))
-
- async def wake_up(self, tags: list[str] | None = None) -> None:
- """Wake up all stages.
-
- Best-effort: unsupported stages will emit a TODO result.
- """
+ return final_acks
+
+ async def wake_up(self, stage_ids: list[int] | None = None, tags: list[str] | None = None) -> list[OmniACK]:
+ self._final_output_handler()
+ if stage_ids is None:
+ stage_ids = list(range(len(self.engine.stage_clients)))
+ total_workers = 0
+ for sid in stage_ids:
+ client = self.engine.stage_clients[sid]
+ if getattr(client, "stage_type", "") == "diffusion":
+ total_workers += 1
+ else:
+ config = self.engine.stage_vllm_configs[sid]
+ total_workers += config.parallel_config.tensor_parallel_size if config else 1
+ task_id = str(uuid.uuid4())
+ self.event_resolver.watch_task(task_id, expected_count=total_workers)
+ logger.info(f"[{self._name}] Wake-up initiated (Task: {task_id}). Awaiting {total_workers} ACKs...")
+ task = OmniWakeTask(tags=tags, task_id=task_id)
+ rpc_results = await self.collective_rpc(method="handle_wake_task", args=(task,), stage_ids=stage_ids)
+ final_acks = []
+ for stage_res in rpc_results:
+ worker_acks = stage_res if isinstance(stage_res, list) else [stage_res]
+ for ack in worker_acks:
+ if ack is not None:
+ await self.event_resolver.resolve(ack)
+ final_acks.append(ack)
+ current_omni_platform.synchronize()
+ await asyncio.sleep(0.1)
self._is_sleeping = False
- await self.collective_rpc(method="wake_up", args=(tags,))
+ logger.info(f"[{self._name}] All {len(final_acks)}/{total_workers} workers reported WARM for task {task_id}.")
+ return final_acks
async def is_sleeping(self) -> bool:
"""Return whether all stages are sleeping.
@@ -718,12 +856,25 @@ async def pin_lora(self, adapter_id: int) -> bool:
@property
def is_running(self) -> bool:
"""Check if the engine is running."""
- return self.final_output_task is not None and not self.final_output_task.done()
+ orchestrator_alive = self.engine.is_alive()
+ task_alive = self.final_output_task is not None and not self.final_output_task.done()
+ return orchestrator_alive and task_alive
@property
def errored(self) -> bool:
- """Whether orchestrator thread has stopped unexpectedly."""
- return not self.engine.is_alive()
+ """Whether the engine is in a non-recoverable error state.
+
+ Delegates to ``OmniBase.errored`` which checks the orchestrator
+ thread and all stage clients. Redeclared here to satisfy the
+ ``EngineClient`` abstract-property requirement (Python's ABC
+ mechanism does not resolve abstract methods from sibling MRO
+ entries).
+ """
+ return OmniBase.errored.fget(self) # type: ignore[union-attr]
+
+ @property
+ def _name(self) -> str:
+ return "AsyncOrchestrator"
@property
def is_stopped(self) -> bool:
diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py
index 8bccfbb5916..e355ee679d8 100644
--- a/vllm_omni/entrypoints/cli/serve.py
+++ b/vllm_omni/entrypoints/cli/serve.py
@@ -139,6 +139,17 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
action="store_true",
help="Enable vLLM-Omni mode for multi-modal and diffusion models",
)
+
+ try:
+ omni_config_group.add_argument(
+ "--enable-sleep-mode",
+ action="store_true",
+ default=False,
+ help="Enable GPU memory pool for sleep mode.",
+ )
+ except argparse.ArgumentError:
+ pass
+
omni_config_group.add_argument(
"--task-type",
type=str,
diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py
index 8ef7e2ee5b7..223c208af98 100644
--- a/vllm_omni/entrypoints/omni.py
+++ b/vllm_omni/entrypoints/omni.py
@@ -166,6 +166,8 @@ def _run_generation(
logger.warning("[Omni] Received output for unknown/finished request_id=%s", req_id)
continue
+ self._check_engine_output_error(msg, req_id, stage_id)
+
if req_state.metrics is None:
continue
output_to_yield = self._process_single_result(
diff --git a/vllm_omni/entrypoints/omni_base.py b/vllm_omni/entrypoints/omni_base.py
index dca494efe72..3180c9c80c0 100644
--- a/vllm_omni/entrypoints/omni_base.py
+++ b/vllm_omni/entrypoints/omni_base.py
@@ -12,7 +12,7 @@
import huggingface_hub
from vllm.logger import init_logger
-from vllm.v1.engine.exceptions import EngineDeadError
+from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
from vllm_omni.entrypoints.client_request_state import ClientRequestState
@@ -133,7 +133,7 @@ def __init__(
if "log_requests" in kwargs:
raise TypeError("`log_requests` has been removed in Omni/AsyncOmni. Use `log_stats`.")
model = omni_snapshot_download(model)
- self._name = self.__class__.__name__
+ self.__dict__["_name"] = self.__class__.__name__
self.model = model
self.log_stats = log_stats
# Provisional value (mirrors the CLI/caller kwarg); the engine resolves
@@ -198,9 +198,33 @@ def stage_configs(self) -> list:
def is_running(self) -> bool:
return self.engine.is_alive()
+    @property
+    def errored(self) -> bool:
+        """Report whether the engine has reached a non-recoverable state.
+
+        The result is True when the orchestrator is no longer alive, or when
+        at least one stage client carries a dead-engine marker. Two markers
+        are consulted because the client types store the flag differently:
+        ``_engine_dead`` on the client itself, and ``engine_dead`` on the
+        client's ``resources`` object.
+        """
+        engine = self.engine
+        if not engine.is_alive():
+            return True
+
+        def _client_is_dead(client) -> bool:
+            # Flag kept directly on the client.
+            if getattr(client, "_engine_dead", False):
+                return True
+            # Flag kept on the client's ``resources`` holder, when there is one.
+            holder = getattr(client, "resources", None)
+            return holder is not None and bool(getattr(holder, "engine_dead", False))
+
+        return any(_client_is_dead(client) for client in engine.stage_clients)
+
def check_health(self) -> None:
if not self.engine.is_alive():
raise EngineDeadError("Orchestrator process is not alive")
+ for stage_client in self.engine.stage_clients:
+ if hasattr(stage_client, "check_health"):
+ stage_client.check_health()
def resolve_sampling_params_list(
self,
@@ -271,7 +295,10 @@ def _handle_output_message(
return True, None, None, None
if msg_type == "error":
- raise RuntimeError(msg.get("error", "Orchestrator returned an error message"))
+ error_text = msg.get("error", "Orchestrator returned an error message")
+ if msg.get("fatal"):
+ raise EngineDeadError(error_text)
+ raise RuntimeError(error_text)
if msg_type != "output":
logger.warning("[%s] got unexpected msg type: %s", self.__class__.__name__, msg_type)
@@ -300,6 +327,34 @@ def _handle_output_message(
return False, req_id, stage_id, req_state
+    def _check_engine_output_error(
+        self,
+        result: dict[str, Any],
+        request_id: str,
+        stage_id: int,
+    ) -> None:
+        """Raise when ``result["engine_outputs"]`` carries a non-None ``error``.
+
+        The error is logged, then surfaced as :class:`EngineDeadError` if
+        ``self.errored`` is true (engine unrecoverable) and as
+        :class:`EngineGenerateError` otherwise (single-request failure).
+        Returns silently when there is no ``engine_outputs`` entry or it has
+        no ``error``.
+        """
+        error_text = getattr(result.get("engine_outputs"), "error", None)
+        if error_text is not None:
+            logger.error(
+                "[%s] Stage error for req=%s stage-%s: %s",
+                self.__class__.__name__,
+                request_id,
+                stage_id,
+                error_text,
+            )
+            # NOTE: ``self.errored`` walks every stage client, so this costs
+            # O(n_stages) per reported error.
+            exc_type = EngineDeadError if self.errored else EngineGenerateError
+            raise exc_type(error_text)
+
def _process_single_result(
self,
result: dict[str, Any],
diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py
index 745b719d5b2..1def10e64cf 100644
--- a/vllm_omni/entrypoints/openai/api_server.py
+++ b/vllm_omni/entrypoints/openai/api_server.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import base64
+import dataclasses
import io
import json
import multiprocessing
@@ -30,7 +31,7 @@
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
from vllm.entrypoints.chat_utils import load_chat_template
-from vllm.entrypoints.launcher import serve_http
+from vllm.entrypoints.launcher import serve_http, terminate_if_errored
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.mcp.tool_server import DemoToolServer, MCPToolServer, ToolServer
from vllm.entrypoints.openai.api_server import build_app as build_openai_app
@@ -49,6 +50,7 @@
ModelCard,
ModelList,
ModelPermission,
+ RequestResponseMetadata,
)
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
@@ -73,6 +75,7 @@
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import (
+ create_error_response,
load_aware_call,
process_lora_modules,
with_cancellation,
@@ -82,6 +85,7 @@
from vllm.tool_parsers import ToolParserManager
from vllm.utils import random_uuid
from vllm.utils.system_utils import decorate_logs
+from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.errors import InvalidInputReferenceError
@@ -198,6 +202,50 @@ def _remove_route_from_app(app, path: str, methods: set[str] | None = None):
app.routes.remove(route)
+def _register_omni_exception_handlers(app) -> None:
+    """Install Omni's handler for engine exceptions on ``app``.
+
+    One coroutine is registered for both ``EngineGenerateError`` and
+    ``EngineDeadError``, replacing the upstream vLLM handlers that target a
+    single-process ``AsyncLLM``. For each caught exception it:
+
+    - logs the stack when ``args.log_error_stack`` is set,
+    - logs orchestrator liveness and the client's ``errored`` state when the
+      exception is an ``EngineDeadError``,
+    - calls ``terminate_if_errored`` with the app's server and engine client,
+    - returns the ``create_error_response`` payload as JSON, using its error
+      code as the HTTP status.
+    """
+
+    async def omni_engine_error_handler(
+        req: Request,
+        exc: EngineDeadError | EngineGenerateError,
+    ):
+        # Endpoints that attach request metadata expose their id here.
+        if hasattr(req.state, "request_metadata"):
+            request_id = req.state.request_metadata.request_id
+        else:
+            request_id = None
+
+        app_state = req.app.state
+        if app_state.args.log_error_stack:
+            logger.exception("Engine Exception caught. Request id: %s", request_id)
+
+        engine = app_state.engine_client
+        if isinstance(exc, EngineDeadError):
+            # Omni-specific diagnostics for a dead engine.
+            if hasattr(engine, "engine"):
+                orchestrator_alive = engine.engine.is_alive()
+            else:
+                orchestrator_alive = "N/A"
+            logger.error(
+                "EngineDeadError: orchestrator_alive=%s, errored=%s, request_id=%s",
+                orchestrator_alive,
+                engine.errored,
+                request_id,
+            )
+
+        terminate_if_errored(
+            server=app_state.server,
+            engine=engine,
+        )
+        err = create_error_response(exc)
+        return JSONResponse(err.model_dump(), status_code=err.error.code)
+
+    for exc_type in (EngineGenerateError, EngineDeadError):
+        app.exception_handler(exc_type)(omni_engine_error_handler)
+
+
class _DiffusionServingModels:
"""Minimal OpenAIServingModels implementation for diffusion-only servers.
@@ -307,6 +355,10 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None,
_remove_route_from_app(app, "/v1/models", {"GET"}) # Remove upstream /v1/models to use omni's handler
app.include_router(router)
+ # OMNI: Override upstream exception handlers with Omni-aware versions
+ # that understand the multi-stage orchestrator lifecycle.
+ _register_omni_exception_handlers(app)
+
await omni_init_app_state(engine_client, app.state, args)
# Conditionally register profiler endpoints based on stage YAML configs
@@ -837,6 +889,7 @@ async def omni_init_app_state(
state.enable_server_load_tracking = args.enable_server_load_tracking
state.server_load_metrics = 0
+ state.sleeping_stages = set()
def Omnivideo(request: Request) -> OmniOpenAIServingVideo | None:
@@ -876,6 +929,8 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
return base_server.create_error_response(message="The model does not support Chat Completions API")
try:
generator = await handler.create_chat_completion(request, raw_request)
+ except (EngineGenerateError, EngineDeadError):
+ raise # Propagate to the global Omni exception handler
except Exception as e:
logger.exception("Chat completion failed: %s", e)
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e
@@ -971,6 +1026,8 @@ async def create_speech(request: OpenAICreateSpeechRequest, raw_request: Request
status_code=result.error.code if result.error else 400,
)
return result
+ except (EngineGenerateError, EngineDeadError):
+ raise # Propagate to the global Omni exception handler
except Exception as e:
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e
@@ -1005,6 +1062,8 @@ async def create_speech_batch(request: BatchSpeechRequest, raw_request: Request)
status_code=result.error.code if result.error else 400,
)
return JSONResponse(content=result.model_dump())
+ except (EngineGenerateError, EngineDeadError):
+ raise # Propagate to the global Omni exception handler
except ValueError as e:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e)) from e
except Exception as e:
@@ -1242,31 +1301,26 @@ async def realtime_websocket(websocket: WebSocket):
async def health(raw_request: Request) -> JSONResponse:
"""Health check endpoint that works for both LLM and diffusion modes.
- Returns 200 OK if the server is healthy.
- For LLM mode: delegates to engine_client health check
- For diffusion mode: checks if diffusion_engine is running
+ Returns 200 OK if the server is healthy, 503 if the engine is dead.
+ Mirrors vLLM upstream's /health which catches EngineDeadError -> 503.
"""
- # Check if we're in diffusion mode
- diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None)
- if diffusion_engine is not None:
- # Diffusion mode health check
- if hasattr(diffusion_engine, "is_running") and diffusion_engine.is_running:
- return JSONResponse(content={"status": "healthy"})
+ engine_client = getattr(raw_request.app.state, "engine_client", None) or getattr(
+ raw_request.app.state, "diffusion_engine", None
+ )
+ if engine_client is None:
return JSONResponse(
- content={"status": "unhealthy", "reason": "Diffusion engine is not running"},
+ content={"status": "unhealthy", "reason": "No engine initialized"},
status_code=HTTPStatus.SERVICE_UNAVAILABLE.value,
)
- # LLM mode - delegate to engine_client
- engine_client = getattr(raw_request.app.state, "engine_client", None)
- if engine_client is not None:
+ try:
await engine_client.check_health()
return JSONResponse(content={"status": "healthy"})
-
- return JSONResponse(
- content={"status": "unhealthy", "reason": "No engine initialized"},
- status_code=HTTPStatus.SERVICE_UNAVAILABLE.value,
- )
+ except EngineDeadError:
+ return JSONResponse(
+ content={"status": "unhealthy"},
+ status_code=HTTPStatus.SERVICE_UNAVAILABLE.value,
+ )
# Remove existing models endpoint if present (from vllm imports)
@@ -1404,6 +1458,16 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
else:
size_str = "model default"
+ # Keep AR stage target grid in sync with requested output size.
+ # GLM-Image consumes target_h/target_w via mm_processor_kwargs.
+ if width is not None and height is not None:
+ prompt["mm_processor_kwargs"] = {
+ "target_h": height,
+ "target_w": width,
+ }
+ # Backward-compatible fallback for processors reading top-level fields.
+ prompt["height"] = height
+ prompt["width"] = width
app_state_args = getattr(raw_request.app.state, "args", None)
_check_max_generated_image_size(app_state_args, width, height)
@@ -1425,6 +1489,7 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
_update_if_not_none(gen_params, "layers", request.layers)
request_id = f"img_gen-{random_uuid()}"
+ raw_request.state.request_metadata = RequestResponseMetadata(request_id=request_id)
logger.info(f"Generating {request.n} image(s) {size_str}")
@@ -1456,8 +1521,9 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
data=image_data,
)
+ except (EngineGenerateError, EngineDeadError):
+ raise # Propagate to the global Omni exception handler
except HTTPException:
- # Re-raise HTTPExceptions as-is
raise
except ValueError as e:
logger.error(f"Validation error: {e}")
@@ -1629,6 +1695,18 @@ async def edit_images(
_check_max_generated_image_size(app_state_args, width, height, resolution)
size_str = f"{width}x{height}" if width is not None and height is not None else "auto"
+
+ # Keep AR stage target grid in sync with requested output size.
+ # GLM-Image consumes target_h/target_w via mm_processor_kwargs.
+ if width is not None and height is not None:
+ prompt["mm_processor_kwargs"] = {
+ "target_h": height,
+ "target_w": width,
+ }
+ # Backward-compatible fallback for processors reading top-level fields.
+ prompt["height"] = height
+ prompt["width"] = width
+
_update_if_not_none(gen_params, "width", width)
_update_if_not_none(gen_params, "height", height)
@@ -1648,6 +1726,7 @@ async def edit_images(
# 4. Generate images using AsyncOmni (multi-stage mode)
request_id = f"img_edit-{random_uuid()}"
+ raw_request.state.request_metadata = RequestResponseMetadata(request_id=request_id)
logger.info(f"Generating {n} image(s) {size_str}")
result = await _generate_with_async_omni(
engine_client=engine_client,
@@ -1679,8 +1758,9 @@ async def edit_images(
size=size_str,
)
+ except (EngineGenerateError, EngineDeadError):
+ raise # Propagate to the global Omni exception handler
except HTTPException:
- # Re-raise HTTPExceptions as-is
raise
except ValueError as e:
logger.error(f"Validation error: {e}")
@@ -2086,6 +2166,48 @@ def video_response_from_request(model_name: str, req: VideoGenerationRequest) ->
return resp
+def _status_code_for_video_failure(error: VideoError | None) -> int:
+    """Pick the HTTP status to return for a failed video job.
+
+    An integer ``error.code`` is used as-is when it lies in 400-599. The
+    string code ``"HTTPException"`` is resolved from the text before the first
+    ``":"`` of ``error.message``, under the same range rule. Everything else
+    (no error, out-of-range or unparsable status, and any other string code
+    such as ``"EngineDeadError"`` or ``"EngineGenerateError"``) yields 500.
+    """
+    fallback = HTTPStatus.INTERNAL_SERVER_ERROR.value
+    if error is None:
+        return fallback
+
+    code = error.code
+    if isinstance(code, int):
+        candidate = code
+    elif code == "HTTPException":
+        status_text = error.message.partition(":")[0]
+        try:
+            candidate = int(status_text)
+        except ValueError:
+            return fallback
+    else:
+        # Named engine errors and unknown codes all map to 500.
+        return fallback
+
+    return candidate if 400 <= candidate < 600 else fallback
+
+
+def _video_error_from_exception(exc: Exception) -> VideoError:
+    """Convert an exception raised during video generation into a ``VideoError``.
+
+    - ``HTTPException``: keeps its status code; the message is its ``detail``
+      when that is truthy, otherwise ``str(exc)``.
+    - ``EngineGenerateError`` / ``EngineDeadError``: code and message are taken
+      from ``create_error_response(exc)``.
+    - Anything else: HTTP 500 with ``str(exc)`` as the message.
+    """
+    if isinstance(exc, HTTPException):
+        return VideoError(code=exc.status_code, message=str(exc.detail or exc))
+
+    if isinstance(exc, (EngineGenerateError, EngineDeadError)):
+        engine_error = create_error_response(exc).error
+        return VideoError(code=engine_error.code, message=engine_error.message)
+
+    return VideoError(code=HTTPStatus.INTERNAL_SERVER_ERROR.value, message=str(exc))
+
+
def _cleanup_video(video_id: str, output_path: str | None):
try:
if output_path is not None:
@@ -2099,6 +2221,7 @@ async def _run_video_generation_job(
request: VideoGenerationRequest,
video_id: str,
reference_image: ReferenceImage | None = None,
+ app_state: Any | None = None,
) -> None:
job = await VIDEO_STORE.get(video_id)
if job is None:
@@ -2129,17 +2252,36 @@ async def _run_video_generation_job(
"peak_memory_mb": peak_memory_mb,
},
)
+ except (EngineGenerateError, EngineDeadError) as exc:
+ logger.exception("Video generation failed (engine error) for id=%s", video_id)
+
+ _cleanup_video(video_id, output_path)
+ await VIDEO_STORE.update_fields(
+ video_id,
+ {
+ "status": VideoGenerationStatus.FAILED,
+ "completed_at": int(time.time()),
+ "error": _video_error_from_exception(exc),
+ "inference_time_s": time.perf_counter() - started_at,
+ },
+ )
+ # Background tasks can't propagate exceptions to FastAPI handlers.
+ # Actively signal shutdown when the engine is dead.
+ if app_state is not None and isinstance(exc, EngineDeadError):
+ terminate_if_errored(
+ server=app_state.server,
+ engine=app_state.engine_client,
+ )
except Exception as exc:
logger.exception("Video generation failed for id=%s", video_id)
_cleanup_video(video_id, output_path)
- # TODO: It would be better to have a finite collection of errors to return rather than the exception name
await VIDEO_STORE.update_fields(
video_id,
{
"status": VideoGenerationStatus.FAILED,
"completed_at": int(time.time()),
- "error": VideoError(code=type(exc).__name__, message=str(exc)),
+ "error": _video_error_from_exception(exc),
"inference_time_s": time.perf_counter() - started_at,
},
)
@@ -2270,6 +2412,7 @@ async def _parse_video_form(
},
)
async def create_video(
+ raw_request: Request,
ctx: tuple[VideoGenerationRequest, OmniOpenAIServingVideo, str, ReferenceImage | None] = Depends(_parse_video_form),
) -> VideoResponse:
"""Create an asynchronous video generation job.
@@ -2280,7 +2423,9 @@ async def create_video(
request, handler, effective_model_name, reference_image = ctx
ref = video_response_from_request(effective_model_name, request)
await VIDEO_STORE.upsert(ref.id, ref)
- task = asyncio.create_task(_run_video_generation_job(handler, request, ref.id, reference_image))
+ task = asyncio.create_task(
+ _run_video_generation_job(handler, request, ref.id, reference_image, app_state=raw_request.app.state)
+ )
await VIDEO_TASKS.upsert(ref.id, task)
return ref
@@ -2295,6 +2440,7 @@ async def create_video(
},
)
async def create_video_sync(
+ raw_request: Request,
ctx: tuple[VideoGenerationRequest, OmniOpenAIServingVideo, str, ReferenceImage | None] = Depends(_parse_video_form),
) -> Response:
"""Synchronous video generation endpoint.
@@ -2308,6 +2454,7 @@ async def create_video_sync(
"""
request, handler, effective_model_name, reference_image = ctx
request_id = f"video_sync-{random_uuid()}"
+ raw_request.state.request_metadata = RequestResponseMetadata(request_id=request_id)
started_at = time.perf_counter()
try:
video_bytes, stage_durations, peak_memory_mb = await asyncio.wait_for(
@@ -2319,6 +2466,8 @@ async def create_video_sync(
status_code=HTTPStatus.GATEWAY_TIMEOUT.value,
detail=f"Video generation timed out after {VIDEO_SYNC_TIMEOUT_S}s.",
)
+ except (EngineGenerateError, EngineDeadError):
+ raise # Propagate to the global Omni exception handler
except HTTPException:
raise
except Exception as exc:
@@ -2379,8 +2528,8 @@ async def list_videos(
return VideoListResponse(data=jobs, has_more=has_more, first_id=first_id, last_id=last_id)
-@router.get("/v1/videos/{video_id}")
-async def retrieve_video(video_id: str) -> VideoResponse:
+@router.get("/v1/videos/{video_id}", response_model=None)
+async def retrieve_video(video_id: str) -> VideoResponse | JSONResponse:
"""Retrieve metadata for a previously created video job.
Args:
@@ -2395,6 +2544,15 @@ async def retrieve_video(video_id: str) -> VideoResponse:
job = await VIDEO_STORE.get(video_id)
if job is None:
raise HTTPException(status_code=404, detail="Video not found")
+ if job.status == VideoGenerationStatus.FAILED:
+ status_code = _status_code_for_video_failure(job.error)
+ content = job.model_dump(mode="json")
+ if content.get("error") is not None:
+ content["error"]["code"] = status_code
+ return JSONResponse(
+ content=content,
+ status_code=status_code,
+ )
return job
@@ -2531,3 +2689,64 @@ async def stop_profile(raw_request: Request, request: ProfileRequest | None = No
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=f"Failed to stop profiler: {str(e)}"
)
+
+
+# Request body for POST /v1/omni/sleep.
+class OmniSleepRequest(BaseModel):
+ # Stage ids passed as ``stage_ids`` to the engine client's ``sleep`` call and
+ # then recorded in ``app.state.sleeping_stages``.
+ stage_ids: list[int]
+ # Passed as ``level`` to the engine client's ``sleep`` call.
+ level: int = 2
+
+
+# Request body for POST /v1/omni/wakeup.
+class OmniWakeupRequest(BaseModel):
+ # Stage ids passed as ``stage_ids`` to the engine client's ``wake_up`` call and
+ # then removed from ``app.state.sleeping_stages``.
+ stage_ids: list[int]
+
+
+@router.post("/v1/omni/sleep")
+async def omni_sleep(request: OmniSleepRequest, raw_request: Request):
+    # Puts the requested stages to sleep through the engine client, records
+    # them in ``app.state.sleeping_stages`` and returns the collected ACKs.
+    # Responds 501 when the engine client has no ``sleep`` method.
+    app_state = raw_request.app.state
+    engine_client = app_state.engine_client
+    sleeping_set = app_state.sleeping_stages
+    if not hasattr(engine_client, "sleep"):
+        raise HTTPException(status_code=501, detail="Engine does not support sleep")
+    acks = await engine_client.sleep(stage_ids=request.stage_ids, level=request.level)
+    sleeping_set.update(request.stage_ids)
+    # Dataclass ACKs are converted to plain dicts; anything else is passed through.
+    serialized = []
+    for ack in acks:
+        serialized.append(dataclasses.asdict(ack) if dataclasses.is_dataclass(ack) else ack)
+    return {"status": "SUCCESS", "acks": serialized}
+
+
+@router.post("/v1/omni/wakeup")
+async def omni_wakeup(request: OmniWakeupRequest, raw_request: Request):
+    # Wakes the requested stages through the engine client and removes them
+    # from ``app.state.sleeping_stages``. Returns SKIPPED when none of the
+    # requested stages is recorded as sleeping, and 501 when the engine client
+    # has no ``wake_up`` method.
+    app_state = raw_request.app.state
+    engine_client = app_state.engine_client
+    sleeping_set = app_state.sleeping_stages
+    requested = request.stage_ids
+    if all(sid not in sleeping_set for sid in requested):
+        return {"status": "SKIPPED", "reason": "Target stages are not sleeping."}
+    if not hasattr(engine_client, "wake_up"):
+        raise HTTPException(status_code=501, detail="Engine does not support wake_up")
+    acks = await engine_client.wake_up(stage_ids=requested)
+    for sid in requested:
+        # ``discard`` ignores ids that were never recorded as sleeping.
+        sleeping_set.discard(sid)
+    # Dataclass ACKs are converted to plain dicts; anything else is passed through.
+    serialized = []
+    for ack in acks:
+        serialized.append(dataclasses.asdict(ack) if dataclasses.is_dataclass(ack) else ack)
+    return {"status": "SUCCESS", "acks": serialized}
+
+
+if __name__ == "__main__":
+    import argparse
+    import asyncio
+
+    from vllm.entrypoints.openai.cli_args import make_arg_parser
+
+    # Start from the upstream vLLM parser, then add the Omni flags only when
+    # that parser has not already registered them.
+    parser = make_arg_parser(argparse.ArgumentParser(description="vLLM-Omni OpenAI-Compatible REST API server"))
+    registered_flags = {flag for action in parser._actions for flag in action.option_strings}
+    omni_flags = {
+        "--omni": "Enable vLLM-Omni mode.",
+        "--enable-sleep-mode": "Enable GPU memory pool for sleep mode.",
+    }
+    for flag, help_text in omni_flags.items():
+        if flag not in registered_flags:
+            parser.add_argument(flag, action="store_true", default=False, help=help_text)
+    args = parser.parse_args()
+    # Fall back to ``model`` when ``model_tag`` is missing or unset.
+    if getattr(args, "model_tag", None) is None:
+        args.model_tag = args.model
+    asyncio.run(omni_run_server(args))
diff --git a/vllm_omni/entrypoints/openai/protocol/videos.py b/vllm_omni/entrypoints/openai/protocol/videos.py
index 7c2c3164d92..d46c8d43d6b 100644
--- a/vllm_omni/entrypoints/openai/protocol/videos.py
+++ b/vllm_omni/entrypoints/openai/protocol/videos.py
@@ -235,7 +235,7 @@ class VideoGenerationResponse(BaseModel):
class VideoError(BaseModel):
- code: str = Field(..., description="A machine-readable error code that was returned.")
+ code: int | str = Field(..., description="A machine-readable error code that was returned.")
message: str = Field(..., description="A human-readable description of the error that was returned.")
diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py
index 8cddac6a6c5..06b739e3bed 100644
--- a/vllm_omni/entrypoints/openai/serving_chat.py
+++ b/vllm_omni/entrypoints/openai/serving_chat.py
@@ -32,6 +32,7 @@
get_history_tool_calls_cnt,
make_tool_call_id,
)
+from vllm.entrypoints.launcher import terminate_if_errored
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionRequest,
@@ -80,6 +81,7 @@
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.utils.collection_utils import as_list
+from vllm.v1.engine.exceptions import EngineDeadError
from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin
from vllm_omni.entrypoints.openai.image_api_utils import validate_layered_layers
@@ -304,20 +306,9 @@ async def create_chat_completion(
# effectively unconditioned and produce nonsense images.
if request.modalities and ("image" in request.modalities):
try:
- messages_as_dicts: list[dict[str, Any]] = []
- for msg in request.messages:
- if hasattr(msg, "model_dump"):
- messages_as_dicts.append(msg.model_dump())
- elif isinstance(msg, dict):
- messages_as_dicts.append(msg)
- else:
- messages_as_dicts.append(
- {
- "role": getattr(msg, "role", "user"),
- "content": getattr(msg, "content", ""),
- }
- )
- extracted_prompt, reference_images = self._extract_diffusion_prompt_and_images(messages_as_dicts)
+ extracted_prompt, reference_images = self._extract_diffusion_prompt_and_images_from_messages(
+ request.messages
+ )
if not extracted_prompt:
return self.create_error_response("No text prompt found in messages")
@@ -328,41 +319,33 @@ async def create_chat_completion(
extra_body = getattr(request, "extra_body", None)
if not extra_body:
extra_body = request.model_extra or {}
- height = extra_body.get("height")
- width = extra_body.get("width")
+
+ height, width = self._resolve_height_width_from_extra_body(extra_body)
+
num_inference_steps = extra_body.get("num_inference_steps")
if num_inference_steps is not None:
try:
num_inference_steps = int(num_inference_steps)
except Exception:
num_inference_steps = None
- if "size" in extra_body:
- try:
- size_str = extra_body["size"]
- if isinstance(size_str, str) and "x" in size_str.lower():
- w, h = size_str.lower().split("x")
- width, height = int(w), int(h)
- except Exception:
- pass
+
negative_prompt = extra_body.get("negative_prompt")
cfg_text_scale = extra_body.get("cfg_text_scale")
cfg_img_scale = extra_body.get("cfg_img_scale")
engine_prompt_image: dict[str, Any] | None = None
- is_img2img = False
if reference_images:
# Best-effort decode first reference image for i2i.
try:
img_bytes = base64.b64decode(reference_images[0])
img = Image.open(BytesIO(img_bytes))
engine_prompt_image = {"img2img": img}
- is_img2img = True
except Exception:
engine_prompt_image = None
# Override the prompts produced by chat-template preprocessing.
tprompt: OmniTextPrompt = {"prompt": extracted_prompt}
- if is_img2img:
+ if engine_prompt_image:
tprompt["modalities"] = ["img2img"]
else:
tprompt["modalities"] = ["image"]
@@ -378,6 +361,13 @@ async def create_chat_completion(
tprompt["mm_processor_kwargs"] = mm_processor_kwargs
if engine_prompt_image is not None:
tprompt["multi_modal_data"] = engine_prompt_image
+ # Provide multi_modal_uuids so that newer vLLM versions
+ # can validate multi_modal_data / multi_modal_uuids
+ # consistency. After the multimodal processor consumes
+ # the image data, the uuids remain as a stable reference.
+ tprompt["multi_modal_uuids"] = {
+ k: [f"{request_id}-{k}-{i}"] for i, k in enumerate(engine_prompt_image)
+ }
engine_prompts = [tprompt]
# Store height/width for applying to diffusion stage sampling params later
@@ -447,6 +437,7 @@ async def create_chat_completion(
tokenizer,
request_metadata,
reasoning_parser,
+ raw_request=raw_request,
)
try:
@@ -544,20 +535,7 @@ async def _preprocess_chat(
# containing image tokens.
req_modalities = getattr(request, "modalities", [])
if req_modalities and ("image" in req_modalities):
- messages_as_dicts: list[dict[str, Any]] = []
- for msg in messages:
- if hasattr(msg, "model_dump"):
- messages_as_dicts.append(msg.model_dump())
- elif isinstance(msg, dict):
- messages_as_dicts.append(msg)
- else:
- messages_as_dicts.append(
- {
- "role": getattr(msg, "role", "user"),
- "content": getattr(msg, "content", ""),
- }
- )
- extracted_prompt, _ = self._extract_diffusion_prompt_and_images(messages_as_dicts)
+ extracted_prompt, _ = self._extract_diffusion_prompt_and_images_from_messages(messages)
if extracted_prompt:
engine_prompt["prompt"] = extracted_prompt
@@ -717,6 +695,9 @@ def _apply_request_overrides(
Starts with YAML defaults and only overrides fields that the user
explicitly provided (non-None values) in the request.
+ For GLM-Image AR stage, if max_tokens is not in YAML and user provides
+ height/width in extra_body, computes max_tokens dynamically.
+
Args:
default_params: Default SamplingParams from stage config YAML.
request: The chat completion request containing user-provided values.
@@ -726,11 +707,56 @@ def _apply_request_overrides(
"""
params = default_params.clone()
+ # Only apply fields explicitly provided by user, not protocol defaults.
+ # Pydantic v2 uses `model_fields_set`; keep v1 fallback for compatibility.
+ explicit_fields = getattr(request, "model_fields_set", None)
+ if explicit_fields is None:
+ explicit_fields = getattr(request, "__fields_set__", set())
+
for field_name in self._OPENAI_SAMPLING_FIELDS:
+ if field_name not in explicit_fields:
+ continue
+
value = getattr(request, field_name, None)
if (value is not None and not isinstance(value, list)) or (isinstance(value, list) and len(value) > 0):
setattr(params, field_name, value)
+        # For GLM-Image: compute max_tokens from height/width with mode-aware
+        # budgeting (t2i vs i2i). Fall back to model_extra like create_chat_completion does.
+        extra_body = getattr(request, "extra_body", None) or getattr(request, "model_extra", None) or {}
+        height, width = self._resolve_height_width_from_extra_body(extra_body)
+
+ # Best-effort mode detection from user messages.
+ # i2i requests include at least one reference image in message content.
+ _, reference_images = self._extract_diffusion_prompt_and_images_from_messages(request.messages)
+ ref_image_count = len(reference_images)
+ is_img2img = ref_image_count > 0
+
+        if height is not None and width is not None:
+            try:
+                from vllm_omni.model_executor.stage_input_processors.glm_image import compute_max_tokens
+                # explicit_fields holds field *names*: check membership, then read the value off the request.
+                max_tokens = getattr(request, "max_tokens", None) if "max_tokens" in explicit_fields else None
+                if max_tokens is None:
+                    max_tokens = compute_max_tokens(int(height), int(width), is_i2i=is_img2img)
+                params.max_tokens = max_tokens
+                # Keep target size in stage-0 sampling params so runner/model can
+                # build deterministic M-RoPE grids for t2i (no MM features).
+                extra_args = dict(getattr(params, "extra_args", {}) or {})
+                extra_args["target_h"] = int(height)
+                extra_args["target_w"] = int(width)
+                params.extra_args = extra_args
+            except (ImportError, ValueError, TypeError) as e:
+                logger.warning(f"Failed to compute max_tokens: {e}, using default if available")
+        else:
+            logger.info(
+                "[SamplingParams] Skip dynamic max_tokens (height=%s, width=%s, mode=%s, ref_images=%s)",
+                height,
+                width,
+                "i2i" if is_img2img else "t2i",
+                ref_image_count,
+            )
+
return params
@staticmethod
@@ -802,6 +828,7 @@ async def chat_completion_stream_generator(
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
reasoning_parser: ReasoningParser | None = None,
+ raw_request: Request | None = None,
):
created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk"
@@ -1514,6 +1541,21 @@ async def chat_completion_stream_generator(
delta=False,
)
+ except EngineDeadError as e:
+ logger.error(
+ "EngineDeadError during streaming for request %s: %s",
+ request_id,
+ e,
+ )
+ data = self.create_streaming_error_response(e)
+ yield f"data: {data}\n\n"
+ # Actively signal shutdown instead of waiting for the watchdog
+ # (5s polling interval).
+ if raw_request is not None:
+ terminate_if_errored(
+ server=raw_request.app.state.server,
+ engine=self.engine_client,
+ )
except Exception as e:
logger.exception("Error in chat completion stream generator.")
data = self.create_streaming_error_response(e)
@@ -2643,6 +2685,48 @@ def _extract_diffusion_prompt_and_images(
prompt = " ".join(prompt_parts).strip()
return prompt, images
+    def _extract_diffusion_prompt_and_images_from_messages(self, messages: list[Any]) -> tuple[str, list[str]]:
+        """Extract the diffusion prompt and reference images from messages that may be
+        objects exposing ``model_dump``, plain dicts, or other role/content objects."""
+        # The extractor is handed plain dicts, so normalize first.
+        as_dicts = self._messages_to_dicts(messages)
+        return self._extract_diffusion_prompt_and_images(as_dicts)
+
+    @staticmethod
+    def _messages_to_dicts(messages: list[Any]) -> list[dict[str, Any]]:
+        """Convert each message to a dict: ``model_dump()`` if available, dicts as-is, else role/content."""
+
+        def _as_dict(msg: Any) -> dict[str, Any]:
+            # Order matters: an object exposing model_dump wins over the dict check.
+            if hasattr(msg, "model_dump"):
+                return msg.model_dump()
+            if isinstance(msg, dict):
+                return msg
+            # Fallback for any other object: read role/content attributes with defaults.
+            return {
+                "role": getattr(msg, "role", "user"),
+                "content": getattr(msg, "content", ""),
+            }
+
+        return [_as_dict(msg) for msg in messages]
+
+    @staticmethod
+    def _resolve_height_width_from_extra_body(extra_body: dict[str, Any]) -> tuple[Any, Any]:
+        """Return ``(height, width)`` from ``extra_body``; a ``"WxH"`` ``size`` fills in when either is unset."""
+        height, width = extra_body.get("height"), extra_body.get("width")
+        if "size" not in extra_body or (height is not None and width is not None):
+            return height, width
+
+        try:
+            raw_size = extra_body["size"]
+            if isinstance(raw_size, str) and "x" in raw_size.lower():
+                raw_w, raw_h = raw_size.lower().split("x")
+                width, height = int(raw_w), int(raw_h)
+        except Exception:
+            # Best-effort: a malformed size leaves the explicit values untouched.
+            pass
+        return height, width
+
def _create_error_response(
self,
message: str,
diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py
index ba8292f0c27..c275c779590 100644
--- a/vllm_omni/entrypoints/openai/serving_speech.py
+++ b/vllm_omni/entrypoints/openai/serving_speech.py
@@ -17,12 +17,17 @@
from fastapi import Request, UploadFile
from fastapi.responses import Response, StreamingResponse
from transformers.utils.hub import cached_file
-from vllm.entrypoints.openai.engine.protocol import ErrorResponse
+from vllm.entrypoints.launcher import terminate_if_errored
+from vllm.entrypoints.openai.engine.protocol import (
+ ErrorResponse,
+ RequestResponseMetadata,
+)
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.logger import init_logger
from vllm.multimodal.media import MediaConnector
from vllm.utils import random_uuid
from vllm.utils.async_utils import make_async
+from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin
from vllm_omni.entrypoints.openai.protocol.audio import (
@@ -1151,7 +1156,13 @@ async def _resolve_ref_audio(self, ref_audio_str: str) -> tuple[list[float], int
)
return wav_np.tolist(), sr
- async def _generate_audio_chunks(self, generator, request_id: str, response_format: str = "pcm"):
+ async def _generate_audio_chunks(
+ self,
+ generator,
+ request_id: str,
+ response_format: str = "pcm",
+ raw_request: Request | None = None,
+ ):
"""Generate audio chunks for streaming response.
Handles two audio output modes from the engine:
@@ -1226,6 +1237,19 @@ async def _generate_audio_chunks(self, generator, request_id: str, response_form
except asyncio.CancelledError:
logger.info("Streaming request %s cancelled by client", request_id)
raise
+ except EngineDeadError as e:
+ logger.error(
+ "EngineDeadError during streaming speech for %s: %s",
+ request_id,
+ e,
+ )
+ # Actively signal shutdown rather than relying on the watchdog.
+ if raw_request is not None:
+ terminate_if_errored(
+ server=raw_request.app.state.server,
+ engine=self.engine_client,
+ )
+ raise
except Exception as e:
logger.exception("Streaming speech generation failed for %s: %s", request_id, e)
raise
@@ -1500,6 +1524,7 @@ async def _build_cosyvoice3_prompt(
async def _prepare_speech_generation(
self,
request: OpenAICreateSpeechRequest,
+ request_id: str | None = None,
) -> tuple[str, Any, dict[str, Any]]:
if self.engine_client.errored:
raise self.engine_client.dead_error
@@ -1592,7 +1617,7 @@ async def _prepare_speech_generation(
tts_params = {}
prompt = {"prompt": request.input}
- request_id = f"speech-{random_uuid()}"
+ request_id = request_id or f"speech-{random_uuid()}"
if self._is_fish_speech:
model_type = "fish_speech"
elif self._tts_model_type == "voxtral_tts":
@@ -1674,8 +1699,9 @@ async def _generate_audio_bytes(
self,
request: OpenAICreateSpeechRequest,
base64_encode: bool = False,
+ request_id: str | None = None,
) -> tuple[bytes | str, str]:
- request_id, generator, _ = await self._prepare_speech_generation(request)
+ request_id, generator, _ = await self._prepare_speech_generation(request, request_id=request_id)
final_output: OmniRequestOutput | None = None
async for res in generator:
@@ -1799,6 +1825,8 @@ async def _create_diffusion_speech(
except asyncio.CancelledError:
return self._diffusion_error_response("Client disconnected")
+ except (EngineGenerateError, EngineDeadError):
+ raise # Propagate to the global Omni exception handler
except ValueError as e:
return self._diffusion_error_response(str(e))
except Exception as e:
@@ -1844,6 +1872,12 @@ async def create_speech(
logger.error("Error with model %s", error_check_ret)
return error_check_ret
+ request_id = f"speech-{random_uuid()}"
+ if raw_request:
+ raw_request.state.request_metadata = RequestResponseMetadata(
+ request_id=request_id,
+ )
+
try:
if request.stream:
# Determine response format and media type for streaming
@@ -1864,17 +1898,24 @@ async def create_speech(
)
media_type = "audio/wav" if response_format == "wav" else "audio/pcm"
- request_id, generator, _ = await self._prepare_speech_generation(request)
+ _, generator, _ = await self._prepare_speech_generation(request, request_id=request_id)
return StreamingResponse(
- self._generate_audio_chunks(generator, request_id, response_format),
+ self._generate_audio_chunks(
+ generator,
+ request_id,
+ response_format,
+ raw_request=raw_request,
+ ),
media_type=media_type,
)
- audio_bytes, media_type = await self._generate_audio_bytes(request)
+ audio_bytes, media_type = await self._generate_audio_bytes(request, request_id=request_id)
return Response(content=audio_bytes, media_type=media_type)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
+ except (EngineGenerateError, EngineDeadError):
+ raise # Propagate to the global Omni exception handler
except ValueError as e:
return self.create_error_response(e)
except Exception as e:
diff --git a/vllm_omni/inputs/preprocess.py b/vllm_omni/inputs/preprocess.py
index c6dffd05426..cca6ce56870 100644
--- a/vllm_omni/inputs/preprocess.py
+++ b/vllm_omni/inputs/preprocess.py
@@ -29,6 +29,8 @@ def _process_text(
self,
parsed_content: OmniTextPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
+ *,
+ mm_uuids: Any | None = None,
) -> OmniTokenInputs | MultiModalInput:
"""Process text prompts with support for mm_processor_kwargs.
@@ -38,6 +40,10 @@ def _process_text(
"""
prompt_text = parsed_content["prompt"]
mm_processor_kwargs = parsed_content.get("mm_processor_kwargs") or {}
+ # When the deprecated raw-prompt path is used, process_inputs does
+ # not pass mm_uuids to preprocess(). Fall back to reading it from
+ # the prompt dict so the Renderer's _validate_mm_uuids can see it.
+ effective_mm_uuids = mm_uuids or parsed_content.get("multi_modal_uuids")
inputs: OmniTokenInputs | MultiModalInput
if multi_modal_data := parsed_content.get("multi_modal_data"):
@@ -46,6 +52,7 @@ def _process_text(
multi_modal_data,
mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
+ mm_uuids=effective_mm_uuids,
)
prompt_embeds = parsed_content.get("prompt_embeds")
if prompt_embeds is not None:
@@ -59,6 +66,7 @@ def _process_text(
{},
mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
+ mm_uuids=effective_mm_uuids,
)
else:
prompt_token_ids = self._tokenize_prompt(
@@ -142,6 +150,8 @@ def _prompt_to_llm_inputs(
self,
prompt: SingletonDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
+ *,
+ mm_uuids: Any | None = None,
) -> SingletonInput:
"""
Extract the singleton inputs from a prompt.
@@ -166,6 +176,7 @@ def _prompt_to_llm_inputs(
return self._process_text(
prompt, # type: ignore[arg-type]
tokenization_kwargs=tokenization_kwargs,
+ mm_uuids=mm_uuids,
)
assert_never(prompt) # type: ignore[arg-type]
diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py
index 62776cbb31f..45f7afc6931 100644
--- a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py
+++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py
@@ -649,6 +649,7 @@ def talker_mtp(
input_embeds: torch.Tensor,
last_talker_hidden: torch.Tensor,
text_step: torch.Tensor,
+ **kwargs: Any,
) -> tuple[torch.Tensor, torch.Tensor]:
"""GPU fast-path: run Fast AR to predict residual codebook codes.
diff --git a/vllm_omni/model_executor/models/glm_image/glm_image_ar.py b/vllm_omni/model_executor/models/glm_image/glm_image_ar.py
index 31eed9b2cb9..bf21c01a645 100644
--- a/vllm_omni/model_executor/models/glm_image/glm_image_ar.py
+++ b/vllm_omni/model_executor/models/glm_image/glm_image_ar.py
@@ -21,6 +21,7 @@
# limitations under the License.
"""Inference-only GLM-Image model compatible with HuggingFace weights."""
+import math
import os
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal
@@ -127,6 +128,14 @@ def _get_subparsers(self):
parsers["img2img"] = self._parse_image_data
return parsers
+    def parse_mm_data(self, mm_data, **kwargs):
+        # Fold the "img2img" key into "image" so that downstream stages
+        # (mm_hashes, _merge_mm_kwargs) only ever see one modality key.
+        def _canonical(key):
+            return "image" if key == "img2img" else key
+
+        return super().parse_mm_data({_canonical(k): v for k, v in mm_data.items()}, **kwargs)
+
class GlmImageProcessingInfo(BaseProcessingInfo):
"""
@@ -346,6 +355,10 @@ def _call_hf_processor(
target_h = mm_kwargs.get("target_h", 1024) if mm_kwargs else 1024
target_w = mm_kwargs.get("target_w", 1024) if mm_kwargs else 1024
+ logger.debug(
+ f"_call_hf_processor: target dimensions for generation: {target_h}x{target_w}, mm_kwargs={mm_kwargs}"
+ )
+
if not mm_data or not mm_data.get("images"):
# Text-to-image mode
if processor is not None:
@@ -566,6 +579,58 @@ def _apply_hf_processor_mm_only(
tensor_type="pt",
)
+    def _apply_hf_processor_text_only(
+        self, prompt_text: str, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object]
+    ) -> list[int]:
+        """Return prompt token ids from the parent text+mm processor path, invoked with an empty item set."""
+        no_mm_items = MultiModalDataItems({})
+        prompt_ids, _, _ = super()._apply_hf_processor_text_mm(
+            prompt_text=prompt_text, mm_items=no_mm_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs,
+        )
+        return prompt_ids
+
+    def _build_generation_grids(self, hf_processor_mm_kwargs: Mapping[str, object]) -> torch.Tensor:
+        """Build generation grids for M-RoPE decode positions.
+
+        For GLM-Image generation, decode order is:
+        1) small preview grid
+        2) large target grid
+        3) EOS
+
+        We store grids as [large, small] to match HF processor behavior, and
+        decode logic consumes them in reverse order.
+
+        Args:
+            hf_processor_mm_kwargs: Processor kwargs; integer ``target_h`` /
+                ``target_w`` are preferred, then ``height`` / ``width``,
+                then a 1024x1024 default.
+
+        Returns:
+            A ``torch.long`` tensor ``[[1, H, W], [1, h, w]]`` in units of
+            32-pixel tokens (large target grid first, small preview second).
+        """
+
+        def _int_or(key: str, default: int | None) -> int | None:
+            value = hf_processor_mm_kwargs.get(key)
+            return value if isinstance(value, int) else default
+
+        target_h, target_w = _int_or("target_h", None), _int_or("target_w", None)
+        if target_h is None or target_w is None:
+            # Some callers pass the generation size as height/width.
+            target_h, target_w = _int_or("height", 1024), _int_or("width", 1024)
+
+        factor = 32
+        # Clamp to one token per axis: sizes below `factor` would otherwise give
+        # ratio == 0 and raise ZeroDivisionError in the 1 / ratio term below.
+        token_h = max(1, target_h // factor)
+        token_w = max(1, target_w // factor)
+
+        ratio = token_h / token_w
+        small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2)))
+        small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2)))
+        return torch.tensor([[1, token_h, token_w], [1, small_token_h, small_token_w]], dtype=torch.long)
+
def _apply_hf_processor_main(
self,
prompt: str | list[int],
@@ -594,126 +659,145 @@ def _apply_hf_processor_main(
logger.debug(f"_apply_hf_processor_main: mm_counts={mm_counts}, num_images={num_images}")
- if num_images == 0 or enable_hf_prompt_update:
+ if num_images == 0 and isinstance(prompt, str):
# t2i mode or normal flow - use parent implementation
- return super()._apply_hf_processor_main(
- prompt=prompt,
+ prompt_ids = self._apply_hf_processor_text_only(
+ prompt_text=prompt,
+ hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+ tokenization_kwargs=tokenization_kwargs,
+ )
+ mm_processed_data = self._apply_hf_processor_mm_only(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
- enable_hf_prompt_update=enable_hf_prompt_update,
)
- # i2i mode with enable_hf_prompt_update=False (cache miss scenario)
- # We need to build prompt_ids with image placeholders
- logger.debug(f"_apply_hf_processor_main: i2i mode with enable_hf_prompt_update=False, num_images={num_images}")
-
- # Get mm data from our overridden _apply_hf_processor_mm_only
- mm_processed_data = self._apply_hf_processor_mm_only(
- mm_items=mm_items,
- hf_processor_mm_kwargs=hf_processor_mm_kwargs,
- tokenization_kwargs=tokenization_kwargs,
- )
-
- # In this path we do NOT call HF apply_chat_template, so we must still
- # provide full grids (source + target) for M-RoPE to compute decode positions.
- # Keep `image_grid_thw` source-only for MM batching/validation.
- try:
- source_grid_thw = mm_processed_data.get("image_grid_thw")
- if source_grid_thw is not None and isinstance(source_grid_thw, torch.Tensor):
- # Compute target grid following HF GlmImageProcessor: factor=32.
- # Prefer explicit target_h/target_w if present, otherwise fall back.
- target_h = (
- hf_processor_mm_kwargs.get("target_h")
- if isinstance(hf_processor_mm_kwargs.get("target_h"), int)
- else None
- )
- target_w = (
- hf_processor_mm_kwargs.get("target_w")
- if isinstance(hf_processor_mm_kwargs.get("target_w"), int)
- else None
+ # t2i has no source images, so mm features cannot provide image_grid_thw.
+ # Provide explicit generation grids for M-RoPE to avoid fallback token parsing
+ # (which can degrade high-resolution spatial positions, e.g. 1920x1920).
+ try:
+ mrope_grid_thw = self._build_generation_grids(hf_processor_mm_kwargs)
+ mm_processed_data["mrope_image_grid_thw"] = mrope_grid_thw
+ logger.info(
+ "_apply_hf_processor_main t2i: mrope_image_grid_thw=%s",
+ mrope_grid_thw.tolist(),
)
- if target_h is None or target_w is None:
- # Some callers pass generation size as height/width.
- target_h = (
- hf_processor_mm_kwargs.get("height")
- if isinstance(hf_processor_mm_kwargs.get("height"), int)
- else 1024
- )
- target_w = (
- hf_processor_mm_kwargs.get("width")
- if isinstance(hf_processor_mm_kwargs.get("width"), int)
- else 1024
- )
+ except Exception as e:
+ logger.warning("_apply_hf_processor_main t2i: failed to set mrope_image_grid_thw: %s", e)
- factor = 32
- target_h = (target_h // factor) * factor
- target_w = (target_w // factor) * factor
- token_h = target_h // factor
- token_w = target_w // factor
- target_grid = torch.tensor([[1, token_h, token_w]], dtype=source_grid_thw.dtype)
+ return prompt_ids, mm_processed_data, False
- mm_processed_data["mrope_image_grid_thw"] = torch.cat([source_grid_thw, target_grid], dim=0)
- except Exception:
- # Best-effort only; M-RoPE has additional fallbacks.
- pass
+ # i2i mode: use unified HF processor path only.
+ # This avoids drift between duplicated manual/HF i2i implementations.
+ logger.debug(
+ "_apply_hf_processor_main: i2i mode (enable_hf_prompt_update=%s), num_images=%s",
+ enable_hf_prompt_update,
+ num_images,
+ )
- # Build prompt_ids with image placeholders
- # _apply_prompt_updates will replace each [image_token_id] with expanded tokens
- tokenizer = self.info.get_tokenizer()
- image_token_id = tokenizer.convert_tokens_to_ids("<|image|>")
+ if not isinstance(prompt, str):
+ # Online OpenAI chat preprocessing can arrive here with tokenized
+ # prompts (list[int]) before serving_chat replaces engine prompt
+ # with the clean text prompt. Do not fail the whole request.
+ logger.warning(
+ "_apply_hf_processor_main i2i: got tokenized prompt type=%s; "
+ "using compatibility path for preprocessing",
+ type(prompt).__name__,
+ )
+
+ prompt_ids = list(prompt)
+ mm_processed_data = self._apply_hf_processor_mm_only(
+ mm_items=mm_items,
+ hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+ tokenization_kwargs=tokenization_kwargs,
+ )
- if isinstance(prompt, str):
- # Match HF GlmImageProcessor behavior: append target grid tokens + BOS.
- # This helps M-RoPE/grid parsing and keeps i2i vs t2i behavior aligned.
+ # Preserve full grids for M-RoPE decode (source + target), while
+ # keeping image_grid_thw source-only for MM batching.
try:
- grid_bos = getattr(tokenizer, "grid_bos_token", "")
- grid_eos = getattr(tokenizer, "grid_eos_token", "")
- bos = getattr(tokenizer, "bos_token", "")
-
- # Use the same target sizes we used for mrope grids when available.
- target_h = (
- hf_processor_mm_kwargs.get("target_h")
- if isinstance(hf_processor_mm_kwargs.get("target_h"), int)
- else None
- )
- target_w = (
- hf_processor_mm_kwargs.get("target_w")
- if isinstance(hf_processor_mm_kwargs.get("target_w"), int)
- else None
- )
- if target_h is None or target_w is None:
+ source_grid_thw = mm_processed_data.get("image_grid_thw")
+ if source_grid_thw is not None and isinstance(source_grid_thw, torch.Tensor):
target_h = (
- hf_processor_mm_kwargs.get("height")
- if isinstance(hf_processor_mm_kwargs.get("height"), int)
- else 1024
+ hf_processor_mm_kwargs.get("target_h")
+ if isinstance(hf_processor_mm_kwargs.get("target_h"), int)
+ else None
)
target_w = (
- hf_processor_mm_kwargs.get("width")
- if isinstance(hf_processor_mm_kwargs.get("width"), int)
- else 1024
+ hf_processor_mm_kwargs.get("target_w")
+ if isinstance(hf_processor_mm_kwargs.get("target_w"), int)
+ else None
)
+ if target_h is None or target_w is None:
+ target_h = (
+ hf_processor_mm_kwargs.get("height")
+ if isinstance(hf_processor_mm_kwargs.get("height"), int)
+ else 1024
+ )
+ target_w = (
+ hf_processor_mm_kwargs.get("width")
+ if isinstance(hf_processor_mm_kwargs.get("width"), int)
+ else 1024
+ )
- factor = 32
- target_h = (target_h // factor) * factor
- target_w = (target_w // factor) * factor
- token_h = target_h // factor
- token_w = target_w // factor
-
- expanded_prompt = f"{prompt}{grid_bos}{token_h} {token_w}{grid_eos}{bos}"
- text_ids = tokenizer.encode(expanded_prompt, add_special_tokens=False)
+ factor = 32
+ token_h = max(1, target_h // factor)
+ token_w = max(1, target_w // factor)
+ target_grid = torch.tensor([[1, token_h, token_w]], dtype=source_grid_thw.dtype)
+ mm_processed_data["mrope_image_grid_thw"] = torch.cat([source_grid_thw, target_grid], dim=0)
except Exception:
- text_ids = tokenizer.encode(prompt, add_special_tokens=False)
+ pass
+
+ # Prompt updates will expand image placeholders in this compatibility path.
+ return prompt_ids, mm_processed_data, False
+
+ images = mm_items.get_items("image", ImageProcessorItems)
+ image_list = [images.get(i) for i in range(images.get_count())]
+ if not image_list:
+ raise ValueError("GLM-Image i2i requires at least one source image in mm_items")
+
+ hf_inputs = self._call_hf_processor(
+ prompt=prompt,
+ mm_data={"images": image_list},
+ mm_kwargs=hf_processor_mm_kwargs,
+ tok_kwargs=tokenization_kwargs,
+ )
+
+ input_ids = hf_inputs.get("input_ids")
+ if input_ids is None:
+ raise ValueError("HF i2i processor returned no input_ids")
+
+ if isinstance(input_ids, torch.Tensor):
+ prompt_ids = input_ids[0].tolist() if input_ids.dim() > 1 else input_ids.tolist()
else:
- text_ids = list(prompt)
+ prompt_ids = (
+ input_ids[0]
+ if isinstance(input_ids, list) and input_ids and isinstance(input_ids[0], list)
+ else list(input_ids)
+ )
- # Prepend image placeholders - one per image
- prompt_ids = [image_token_id] * num_images + text_ids
+ mm_processed_data = BatchFeature(dict(), tensor_type="pt")
+ for key in ("pixel_values", "image_grid_thw", "mrope_image_grid_thw"):
+ value = hf_inputs.get(key)
+ if value is not None:
+ mm_processed_data[key] = value
- logger.debug(f"_apply_hf_processor_main: built prompt_ids with {num_images} image placeholders")
+ image_grid_thw = mm_processed_data.get("image_grid_thw")
+ mrope_grid_thw = mm_processed_data.get("mrope_image_grid_thw")
+ hf_config = self.info.get_hf_config()
+ image_token_id = getattr(hf_config, "image_token_id", 167855)
+ image_token_count = prompt_ids.count(image_token_id)
+ logger.info(
+ "_apply_hf_processor_main i2i(HF): num_images=%s, prompt_len=%s, image_token_count=%s, "
+ "source_grid_shape=%s, mrope_grid_shape=%s",
+ num_images,
+ len(prompt_ids),
+ image_token_count,
+ tuple(image_grid_thw.shape) if image_grid_thw is not None else None,
+ tuple(mrope_grid_thw.shape) if mrope_grid_thw is not None else None,
+ )
- # Return is_update_applied=False so _apply_prompt_updates will expand the placeholders
- return prompt_ids, mm_processed_data, False
+ # HF processor already expanded image placeholders in input_ids.
+ return prompt_ids, mm_processed_data, True
def _get_mm_fields_config(
self,
@@ -2667,9 +2751,23 @@ def get_mrope_input_positions(
# Input format: "textH Wh w" where =image_start_token_id=16384
# For 1024x1024: H=32, W=32 (large), h=16, w=16 (small preview)
if not image_grid_thw:
+            # Preferred path for t2i: use the explicit target size propagated from the request's
+            # sampling params instead of fragile grid parsing from token IDs (matches HF grids).
+            target_h = kwargs.get("target_h")
+            target_w = kwargs.get("target_w")
+            if isinstance(target_h, int) and isinstance(target_w, int) and target_h > 0 and target_w > 0:
+                factor = 32
+                # Clamp to >= 1 token: a size below `factor` floors to 0, making 1 / ratio raise.
+                token_h = max(1, target_h // factor)
+                token_w = max(1, target_w // factor)
+                ratio = token_h / token_w
+                small_h = max(1, int(math.sqrt(ratio) * (factor // 2)))
+                small_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2)))
+                image_grid_thw = [[1, token_h, token_w], [1, small_h, small_w]]
+
# Try to parse from kwargs (passed from processor)
hf_config_arg = kwargs.get("hf_config")
- if hf_config_arg is not None and hasattr(hf_config_arg, "image_grid_thw"):
+ if (not image_grid_thw) and hf_config_arg is not None and hasattr(hf_config_arg, "image_grid_thw"):
image_grid_thw = hf_config_arg.image_grid_thw
# If still empty, try to infer from input tokens
@@ -2723,19 +2821,29 @@ def get_mrope_input_positions(
prompt_ends_with_start = len(input_tokens) > 0 and input_tokens[-1] == image_start_token_id
if prompt_ends_with_start and len(image_grid_thw) == num_source_images and num_source_images > 0:
# i2i mode: source grids exist but no target grids
- # Parse target grids from prompt tokens or use defaults
- parsed_grids = self._parse_grid_from_tokens(input_tokens, hf_config)
- if parsed_grids:
- # parsed_grids contains all grids mentioned in prompt
- # For i2i, add only the generation target grids
- if len(parsed_grids) > num_source_images:
- image_grid_thw = list(image_grid_thw) + parsed_grids[num_source_images:]
+ # Prefer explicit target size propagated from request sampling params.
+ # This avoids fragile grid parsing from token IDs for non-1024 i2i.
+ target_h = kwargs.get("target_h")
+ target_w = kwargs.get("target_w")
+ if isinstance(target_h, int) and isinstance(target_w, int) and target_h > 0 and target_w > 0:
+ factor = 32
+ token_h = target_h // factor
+ token_w = target_w // factor
+ image_grid_thw = list(image_grid_thw) + [[1, token_h, token_w]]
+ else:
+ # Parse target grids from prompt tokens or use defaults
+ parsed_grids = self._parse_grid_from_tokens(input_tokens, hf_config)
+ if parsed_grids:
+ # parsed_grids contains all grids mentioned in prompt
+ # For i2i, add only the generation target grids
+ if len(parsed_grids) > num_source_images:
+ image_grid_thw = list(image_grid_thw) + parsed_grids[num_source_images:]
+ else:
+ # Fallback: add default 1024x1024 generation grid (1 target for i2i)
+ image_grid_thw = list(image_grid_thw) + [[1, 32, 32]]
else:
- # Fallback: add default 1024x1024 generation grids (1 target for i2i)
+ # Fallback to default 1024x1024 grid for generation
image_grid_thw = list(image_grid_thw) + [[1, 32, 32]]
- else:
- # Fallback to default 1024x1024 grids for generation
- image_grid_thw = list(image_grid_thw) + [[1, 32, 32]]
llm_pos_ids_list: list[torch.Tensor] = []
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
index f06ecf41d22..ce336037501 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
@@ -689,6 +689,7 @@ def talker_mtp(
input_embeds: torch.Tensor,
last_talker_hidden: torch.Tensor,
text_step: torch.Tensor,
+ **kwargs: Any,
):
# TODO(Peiqi): not support intermediate_tensors now
input_ids = safe_tensor_reshape(input_ids, (input_ids.shape[0], -1))
diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
index d9cbcf7d4ef..31fd8062278 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
@@ -415,6 +415,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# In-memory LRU cache for voice extraction artifacts (Base voice clone).
self._voice_cache = VoiceEmbeddingCache()
+ raw_subtalker_sampling = getattr(vllm_config.model_config, "subtalker_sampling_params", None)
+ self._subtalker_sampling_params: dict[str, Any] = (
+ dict(raw_subtalker_sampling) if isinstance(raw_subtalker_sampling, Mapping) else {}
+ )
# -------------------- vLLM required hooks --------------------
@@ -1638,6 +1642,10 @@ def talker_mtp(
input_embeds: torch.Tensor,
last_talker_hidden: torch.Tensor,
text_step: torch.Tensor,
+ do_sample: bool | None = None,
+ temperature: float | None = None,
+ top_k: int | None = None,
+ top_p: float | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""GPU fast-path used by OmniGPUModelRunner to predict residual codebooks (1..Q-1).
Returns (inputs_embeds, audio_codes) for the current step."""
@@ -1656,15 +1664,24 @@ def talker_mtp(
audio_codes = input_ids.reshape(bsz, 1)
return (last_id_hidden + text_step).reshape(bsz, -1), audio_codes
- # Predict residual codes (1..Q-1) with HF reference sampling params.
+ subtalker_params = self._subtalker_sampling_params
+ if do_sample is None:
+ do_sample = bool(subtalker_params.get("do_sample", True))
+ if temperature is None:
+ temperature = float(subtalker_params.get("temperature", 0.9))
+ if top_k is None:
+ top_k = int(subtalker_params.get("top_k", 50))
+ if top_p is None:
+ top_p = float(subtalker_params.get("top_p", 1.0))
+
audio_codes = self.code_predictor(
layer0_code=input_ids.reshape(bsz, 1),
layer0_embed=last_id_hidden,
last_talker_hidden=past_hidden,
- do_sample=True,
- temperature=0.9,
- top_k=50,
- top_p=1.0,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
) # [B, Q]
# Map invalid layer-0 ids (e.g. EOS) to PAD=0 so SpeechTokenizer sees only real codes.
diff --git a/vllm_omni/model_executor/stage_configs/glm_image.yaml b/vllm_omni/model_executor/stage_configs/glm_image.yaml
index 05ac84a7a09..f3ed6c7213d 100644
--- a/vllm_omni/model_executor/stage_configs/glm_image.yaml
+++ b/vllm_omni/model_executor/stage_configs/glm_image.yaml
@@ -33,7 +33,6 @@ stage_args:
temperature: 0.9 # From model's generation_config.json
top_p: 0.75 # From model's generation_config.json
top_k: 16512 # vision_vocab_size from generation_config.json
- max_tokens: 1281 # For 1024x1024: small(16x16=256) + large(32x32=1024) + EOS(1)
stop_token_ids: [16385] # eos_token_id from generation_config.json
seed: 42
detokenize: false
diff --git a/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml b/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml
index 7bd66c403fc..2a85a6dadbc 100644
--- a/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml
+++ b/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml
@@ -35,7 +35,6 @@ stage_args:
temperature: 0.9 # From model's generation_config.json
top_p: 0.75 # From model's generation_config.json
top_k: 16512 # vision_vocab_size from generation_config.json
- max_tokens: 1281 # For 1024x1024: small(16x16=256) + large(32x32=1024) + EOS(1)
stop_token_ids: [16385] # eos_token_id from generation_config.json
seed: 42
detokenize: false
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml
new file mode 100644
index 00000000000..f0797c63270
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml
@@ -0,0 +1,96 @@
+# Stage config for running Hunyuan-Image3.0 with AR→DiT KV reuse.
+# Stage 0: AR Model (vLLM implementation)
+# Stage 1: DiT Model (diffusion)
+#
+# text-to-image flow: AR (stage 0) → KV transfer → DiT (stage 1)
+# image-to-text flow: AR (stage 0) only
+#
+# Compared to hunyuan_image3_t2i.yaml, this config:
+# 1. Enables both stages [0, 1] for text-to-image (AR prefill + DiT denoising)
+# 2. Adds omni_kv_config to send/receive KV cache between stages
+
+# The following config has been verified on 8x L40S-48G GPUs (4 for AR + 4 for DiT).
+stage_args:
+ - stage_id: 0
+ stage_type: llm # Use llm stage type for AR stages
+ runtime:
+ process: true # Run this stage in a separate process
+ devices: "0,1,2,3" # AR stage uses GPU 0-3
+ engine_args:
+ model_stage: AR
+ max_num_seqs: 1
+ model_arch: HunyuanImage3ForCausalMM
+ worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.9
+ enforce_eager: true # Only eager mode is supported for now
+ trust_remote_code: true
+ engine_output_type: latent
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 4
+ pipeline_parallel_size: 1
+ hf_overrides:
+ rope_parameters:
+ mrope_section: [0, 32, 32]
+ rope_type: default
+ omni_kv_config:
+ need_send_cache: true
+ kv_transfer_criteria:
+ type: prefill_finished # Send KV cache after AR prefill completes
+ is_comprehension: true
+ final_output: true
+ final_output_type: text
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ detokenize: True
+ repetition_penalty: 1.1
+ - stage_id: 1
+ stage_type: diffusion
+ runtime:
+ process: true
+ devices: "4,5,6,7" # DiT stage uses GPU 4-7
+ max_batch_size: 1
+ engine_args:
+ model_stage: diffusion
+ enforce_eager: true
+ distributed_executor_backend: "mp"
+ vae_use_slicing: false
+ vae_use_tiling: false
+ cache_backend: null
+ cache_config: null
+ enable_cache_dit_summary: false
+ omni_kv_config:
+ need_recv_cache: true # Receive AR KV cache from stage 0
+ parallel_config:
+ pipeline_parallel_size: 1
+ data_parallel_size: 1
+ tensor_parallel_size: 4
+ enable_expert_parallel: false
+ sequence_parallel_size: 1
+ ulysses_degree: 1
+ ring_degree: 1
+ cfg_parallel_size: 1
+ vae_patch_parallel_size: 1
+ use_hsdp: false
+ hsdp_shard_size: -1
+ hsdp_replicate_size: 1
+ engine_input_source: [0] # Receive input (including KV) from stage 0
+ final_output: true
+ final_output_type: image
+
+# Top-level runtime config: windows, edges, and connectors
+runtime:
+ enabled: true
+ defaults:
+ window_size: -1 # Trigger downstream only after full upstream completion
+ max_inflight: 1 # Process serially within each stage
+
+ edges:
+ - from: 0
+ to: 1
+ window_size: -1
diff --git a/vllm_omni/model_executor/stage_input_processors/glm_image.py b/vllm_omni/model_executor/stage_input_processors/glm_image.py
index 738a3732618..5ea05befb03 100644
--- a/vllm_omni/model_executor/stage_input_processors/glm_image.py
+++ b/vllm_omni/model_executor/stage_input_processors/glm_image.py
@@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Stage input processor for GLM-Image: AR → Diffusion transition."""
+import math
+import time
from typing import Any
import torch
@@ -13,6 +15,86 @@
logger = init_logger(__name__)
+def _has_source_image(mm_data: Any) -> bool:
+ """Return whether prompt multi_modal_data contains a source image.
+
+ Checks the legacy and current keys used across omni pipelines:
+ - `image`: single PIL image or list
+ - `img2img`: legacy single-image key
+ - `images`: list or single image
+ """
+ if not isinstance(mm_data, dict):
+ return False
+ if mm_data.get("image") is not None:
+ return True
+ if mm_data.get("img2img") is not None:
+ return True
+ images = mm_data.get("images")
+ return bool(images)
+
+
+def _first_source_image(mm_data: Any) -> Any:
+ """Return the first image under the first non-None key of `image`, `img2img`, `images`; else None."""
+ if not isinstance(mm_data, dict):
+ return None
+
+ image = mm_data.get("image")
+ if image is not None:
+ if isinstance(image, list):
+ return image[0] if image else None
+ return image
+
+ image = mm_data.get("img2img")
+ if image is not None:
+ if isinstance(image, list):
+ return image[0] if image else None
+ return image
+
+ images = mm_data.get("images")
+ if isinstance(images, list):
+ return images[0] if images else None
+ return images
+
+
+def compute_max_tokens(height: int, width: int, factor: int = 32, is_i2i: bool = False) -> int:
+ """
+ Compute max_new_tokens for GLM-Image AR generation.
+
+ GLM-Image generation differs by mode:
+
+ - text-to-image (t2i): small preview + large target + EOS
+ - image-to-image (i2i): large target + EOS
+
+ Args:
+ height: Target image height in pixels
+ width: Target image width in pixels
+ factor: Downsampling factor (32 for GLM-Image AR output)
+ is_i2i: Whether the request is image-to-image mode
+
+ Returns:
+ Total number of tokens to generate for the specified mode
+ """
+ # Large image tokens (target resolution)
+ token_h = height // factor
+ token_w = width // factor
+ large_tokens = token_h * token_w
+
+ # Small preview tokens: roughly (factor // 2) ** 2 tokens, shaped to the target aspect ratio
+ import math
+
+ ratio = token_h / token_w if token_w > 0 else 1.0
+ small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2)))
+ small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2)))
+ small_tokens = small_token_h * small_token_w
+
+ # Mode-dependent totals:
+ # - t2i: small + large + EOS
+ # - i2i: large + EOS
+ if is_i2i:
+ return large_tokens + 1
+ return small_tokens + large_tokens + 1
+
+
def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor:
"""Upsample token IDs by 2x using nearest neighbor interpolation.
@@ -56,39 +138,49 @@ def _parse_generated_tokens(
large_image_tokens = token_h * token_w
# Calculate small preview image dimensions (used in text-to-image)
- small_token_h = token_h // 2
- small_token_w = token_w // 2
+ ratio = token_h / token_w if token_w > 0 else 1.0
+ small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2)))
+ small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2)))
small_image_tokens = small_token_h * small_token_w
token_tensor = torch.tensor(token_ids, dtype=torch.long)
# Remove EOS token (16385) from the end if present
eos_token_id = 16385
- if len(token_ids) > 0 and token_ids[-1] == eos_token_id:
+ has_terminal_eos = len(token_ids) > 0 and token_ids[-1] == eos_token_id
+ if has_terminal_eos:
token_tensor = token_tensor[:-1]
actual_tokens = len(token_tensor)
- logger.debug(
- f"[_parse_generated_tokens] height={height}, width={width}, "
- f"token_h={token_h}, token_w={token_w}, "
- f"large_image_tokens={large_image_tokens}, small_image_tokens={small_image_tokens}, "
- f"actual_tokens={actual_tokens}"
- )
-
if is_i2i:
- # Image-to-image mode: check if AR generated small+large tokens (like t2i) or just large tokens
- # Some AR models output small+large even in i2i mode
if actual_tokens >= small_image_tokens + large_image_tokens:
- # AR generated full t2i-style output, extract large tokens after small
large_start = small_image_tokens
large_end = large_start + large_image_tokens
prior_token_ids_d32 = token_tensor[large_start:large_end]
actual_h, actual_w = token_h, token_w
- else:
- # AR generated only large tokens (pure i2i output)
+ logger.warning(
+ "[_parse_generated_tokens] i2i detected t2i-style token layout; "
+ "using small-offset extraction: large_start=%s large_end=%s",
+ large_start,
+ large_end,
+ )
+ elif actual_tokens >= large_image_tokens:
prior_token_ids_d32 = token_tensor[:large_image_tokens]
actual_h, actual_w = token_h, token_w
+ logger.info(
+ "[_parse_generated_tokens] i2i using offset-0 extraction: large_tokens=%s",
+ large_image_tokens,
+ )
+ else:
+ logger.warning(
+ "[_parse_generated_tokens] i2i token parse failed: actual_tokens=%s < expected_large_tokens=%s",
+ actual_tokens,
+ large_image_tokens,
+ )
+ raise ValueError(
+ f"i2i token parse failed: actual_tokens={actual_tokens} < expected_large_tokens={large_image_tokens}"
+ )
elif actual_tokens >= small_image_tokens + large_image_tokens:
# Text-to-image: extract large image tokens after small image tokens
large_start = small_image_tokens
@@ -96,43 +188,22 @@ def _parse_generated_tokens(
prior_token_ids_d32 = token_tensor[large_start:large_end]
actual_h, actual_w = token_h, token_w
elif actual_tokens >= large_image_tokens:
- # Image-to-image: large image tokens are at the beginning
- prior_token_ids_d32 = token_tensor[:large_image_tokens]
- actual_h, actual_w = token_h, token_w
+ logger.warning(
+ "[_parse_generated_tokens] t2i token parse failed: got only large tokens without small preview "
+ "(actual_tokens=%s, expected_small_plus_large=%s)",
+ actual_tokens,
+ small_image_tokens + large_image_tokens,
+ )
+ raise ValueError("t2i token parse failed: missing small-preview tokens; refusing low-quality fallback")
else:
- # Insufficient tokens - try to infer the actual grid size
- import math
-
- for scale in [1, 2, 4]:
- test_h = token_h // scale
- test_w = token_w // scale
- test_small_h = test_h // 2
- test_small_w = test_w // 2
- test_large = test_h * test_w
- test_small = test_small_h * test_small_w
-
- if actual_tokens >= test_small + test_large:
- prior_token_ids_d32 = token_tensor[test_small : test_small + test_large]
- actual_h, actual_w = test_h, test_w
- height = test_h * factor
- width = test_w * factor
- logger.warning(f"Adjusted grid to {test_h}x{test_w}, output will be {height}x{width}")
- break
- elif actual_tokens >= test_large:
- prior_token_ids_d32 = token_tensor[:test_large]
- actual_h, actual_w = test_h, test_w
- height = test_h * factor
- width = test_w * factor
- logger.warning(f"Adjusted grid to {test_h}x{test_w}, output will be {height}x{width}")
- break
- else:
- sqrt_tokens = int(math.sqrt(actual_tokens))
- actual_h = actual_w = sqrt_tokens
- usable_tokens = sqrt_tokens * sqrt_tokens
- prior_token_ids_d32 = token_tensor[:usable_tokens]
- height = sqrt_tokens * factor
- width = sqrt_tokens * factor
- logger.error(f"Grid pattern mismatch. Using {sqrt_tokens}x{sqrt_tokens}, output: {height}x{width}")
+ logger.warning(
+ "[_parse_generated_tokens] token parse failed: insufficient tokens "
+ "(actual_tokens=%s, expected=%s, mode=%s)",
+ actual_tokens,
+ large_image_tokens if is_i2i else (small_image_tokens + large_image_tokens),
+ "i2i" if is_i2i else "t2i",
+ )
+ raise ValueError(f"token parse failed: actual_tokens={actual_tokens}, mode={'i2i' if is_i2i else 't2i'}")
# Upsample from 32x to 16x
prior_token_ids = _upsample_token_ids(prior_token_ids_d32, actual_h, actual_w)
@@ -144,8 +215,16 @@ def ar2diffusion(
source_outputs: list[Any],
prompt: OmniTokensPrompt | TextPrompt | list | None = None,
requires_multimodal_data: bool = False,
+ streaming_context: Any | None = None,
) -> list[dict[str, Any]]:
- """Process AR stage outputs to create Diffusion stage inputs."""
+ """Process AR stage outputs to create Diffusion stage inputs.
+
+ Stage-pool transition interface: ``ar2diffusion(source_outputs, prompt,
+ requires_multimodal_data, streaming_context)``; ``streaming_context`` is accepted but unused.
+ """
+ del streaming_context
+
+ _t_total = time.perf_counter()
ar_outputs = source_outputs
diffusion_inputs = []
@@ -154,6 +233,7 @@ def ar2diffusion(
prompt = [prompt] if prompt is not None else [{}]
for i, ar_output in enumerate(ar_outputs):
+ _t_req = time.perf_counter()
output = ar_output.outputs[0]
generated_token_ids = output.token_ids
@@ -168,23 +248,76 @@ def ar2diffusion(
else:
original_prompt = {}
- height = original_prompt.get("height", 1024)
- width = original_prompt.get("width", 1024)
+ mm_processor_kwargs = original_prompt.get("mm_processor_kwargs")
+
+ def _coerce_dim(v: Any, default: int) -> int:
+ try:
+ iv = int(v)
+ return iv if iv > 0 else default
+ except (TypeError, ValueError):
+ return default
+
+ # Prefer GLM-Image target size from mm_processor_kwargs (set by serving layer),
+ # then fall back to top-level fields for backward compatibility.
+ height = _coerce_dim(
+ mm_processor_kwargs.get("target_h") if isinstance(mm_processor_kwargs, dict) else None,
+ _coerce_dim(original_prompt.get("height"), 1024),
+ )
+ width = _coerce_dim(
+ mm_processor_kwargs.get("target_w") if isinstance(mm_processor_kwargs, dict) else None,
+ _coerce_dim(original_prompt.get("width"), 1024),
+ )
text_prompt = original_prompt.get("prompt", "")
- # Detect i2i mode first by checking if multimodal_output contains prior_token_image_ids
+ # Detect i2i mode. Any one of these signals is sufficient: "img2img" in the
+ # prompt's modalities, a source image in the prompt's multi_modal_data, or
+ # prior_token_image_ids in the AR stage's multimodal output.
+ _t_mode = time.perf_counter()
is_i2i = False
+
+ prompt_modalities = original_prompt.get("modalities")
+ if isinstance(prompt_modalities, list) and "img2img" in prompt_modalities:
+ is_i2i = True
+
+ prompt_mm_data = original_prompt.get("multi_modal_data")
+ if _has_source_image(prompt_mm_data):
+ is_i2i = True
+
if hasattr(ar_output, "multimodal_output") and ar_output.multimodal_output:
mm_output = ar_output.multimodal_output
- if isinstance(mm_output, dict) and mm_output.get("prior_token_image_ids") is not None:
- is_i2i = True
+ if isinstance(mm_output, dict):
+ if mm_output.get("prior_token_image_ids") is not None:
+ is_i2i = True
+ _dt_mode = (time.perf_counter() - _t_mode) * 1000
# Parse and upsample prior tokens
- prior_token_ids, pixel_h, pixel_w = _parse_generated_tokens(generated_token_ids, height, width, is_i2i=is_i2i)
+ _t_parse = time.perf_counter()
+ try:
+ prior_token_ids, pixel_h, pixel_w = _parse_generated_tokens(
+ generated_token_ids,
+ height,
+ width,
+ is_i2i=is_i2i,
+ )
+ except ValueError as e:
+ logger.warning(
+ "[ar2diffusion] Request %s: skip due to token parse failure: %s "
+ "(target=%sx%s, mode=%s, raw_tokens=%s, tail=%s)",
+ i,
+ e,
+ height,
+ width,
+ "i2i" if is_i2i else "t2i",
+ len(generated_token_ids),
+ generated_token_ids[-8:] if len(generated_token_ids) >= 8 else generated_token_ids,
+ )
+ continue
+ _dt_parse = (time.perf_counter() - _t_parse) * 1000
# Get prior_token_image_ids from AR model output (for i2i mode)
# This contains VQ-VAE tokens from input image, used for KV cache conditioning
# NOTE: multimodal_output is attached to ar_output (RequestOutput), NOT output (CompletionOutput)
+ _t_prior_img = time.perf_counter()
prior_token_image_ids = None
# Check ar_output (RequestOutput) for multimodal_output - this is the correct location
@@ -223,6 +356,7 @@ def ar2diffusion(
prior_token_image_ids = [raw_prior_image_ids]
elif isinstance(raw_prior_image_ids, list):
prior_token_image_ids = raw_prior_image_ids
+ _dt_prior_img = (time.perf_counter() - _t_prior_img) * 1000
diffusion_input = {
"prompt": text_prompt,
@@ -237,18 +371,38 @@ def ar2diffusion(
if requires_multimodal_data:
mm_data = original_prompt.get("multi_modal_data")
if mm_data:
- pil_image = mm_data.get("image")
- if pil_image is None:
- # Try "images" (plural) as fallback
- images = mm_data.get("images")
- if images:
- pil_image = images[0] if isinstance(images, list) else images
+ pil_image = _first_source_image(mm_data)
diffusion_input["pil_image"] = pil_image
for key in ["seed", "num_inference_steps", "guidance_scale", "negative_prompt"]:
if key in original_prompt:
diffusion_input[key] = original_prompt[key]
+ _dt_req = (time.perf_counter() - _t_req) * 1000
+ logger.info(
+ "[ar2diffusion] req=%d mode=%s target=%dx%d "
+ "raw_tokens=%d prior_tokens=%d prior_image_ids=%s "
+ "timing: mode_detect=%.3fms parse+upsample=%.3fms "
+ "prior_image_ids_extract=%.3fms req_total=%.3fms",
+ i,
+ "i2i" if is_i2i else "t2i",
+ pixel_h,
+ pixel_w,
+ len(generated_token_ids),
+ len(prior_token_ids),
+ "yes" if prior_token_image_ids is not None else "no",
+ _dt_mode,
+ _dt_parse,
+ _dt_prior_img,
+ _dt_req,
+ )
diffusion_inputs.append(diffusion_input)
+ _dt_total = (time.perf_counter() - _t_total) * 1000
+ logger.info(
+ "[ar2diffusion] batch done: %d reqs, total=%.3fms",
+ len(diffusion_inputs),
+ _dt_total,
+ )
+
return diffusion_inputs
diff --git a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py
index 0ea780db42b..4216f9259ed 100644
--- a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py
+++ b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py
@@ -9,6 +9,14 @@
logger = init_logger(__name__)
+# Maximum tokens supported by the code2wav stage. The flattened talker codec
+# sequence fed to stage-1 must not exceed this, otherwise gpu_input_batch
+# add_request will fail with a broadcast error when copying prompt_token_ids
+# into token_ids_cpu. Keep in sync with the stage-1 ``max_model_len`` in
+# ``vllm_omni/model_executor/stage_configs/mimo_audio.yaml`` and the offline
+# example ``examples/offline_inference/mimo_audio/end2end.py``.
+MAX_CODE2WAV_TOKENS = 18192
+
def prepend_and_flatten_colmajor(x: torch.Tensor, pad_vec: torch.Tensor) -> torch.Tensor:
"""
@@ -222,6 +230,20 @@ def llm2code2wav(
code_final = prepend_and_flatten_colmajor(codec_codes, pad_vec)
code_final = code_final.tolist()
+ # Guard against flattened sequences longer than code2wav's max_model_len.
+ # Without this, add_request raises ``could not broadcast input array
+ # from shape (N,) into shape (max_model_len,)`` and kills the engine
+ # core (see issue #2683). Mirrors the offline end2end.py safeguard.
+ if len(code_final) > MAX_CODE2WAV_TOKENS:
+ request_id = getattr(talker_output, "request_id", f"unknown_{i}")
+ logger.warning(
+ "Request %s: code_final len=%d > MAX_CODE2WAV_TOKENS=%d, truncating.",
+ request_id,
+ len(code_final),
+ MAX_CODE2WAV_TOKENS,
+ )
+ code_final = code_final[:MAX_CODE2WAV_TOKENS]
+
code2wav_inputs.append(
OmniTokensPrompt(
prompt_token_ids=code_final,
diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py
index c02c0c1427c..23930f358bb 100644
--- a/vllm_omni/outputs.py
+++ b/vllm_omni/outputs.py
@@ -100,9 +100,22 @@ class OmniRequestOutput:
# memory usage info
peak_memory_mb: float = 0.0
- # error handling
+ # Error information -- set when the output represents a failed request.
error: str | None = None
+ @classmethod
+ def from_error(
+ cls,
+ request_id: str,
+ error: str,
+ ) -> "OmniRequestOutput":
+ """Create an error output for a request that failed during generation."""
+ return cls(
+ request_id=request_id,
+ finished=True,
+ error=error,
+ )
+
@classmethod
def from_pipeline(
cls,
diff --git a/vllm_omni/platforms/npu/worker/npu_model_runner.py b/vllm_omni/platforms/npu/worker/npu_model_runner.py
index 8ef39adfa67..310d09311f0 100644
--- a/vllm_omni/platforms/npu/worker/npu_model_runner.py
+++ b/vllm_omni/platforms/npu/worker/npu_model_runner.py
@@ -417,10 +417,22 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te
req_embeds = self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded]
last_talker_hidden = self.last_talker_hidden.gpu[:num_tokens_padded]
text_step = self.text_step.gpu[:num_tokens_padded]
+ subtalker_params = getattr(self.vllm_config.model_config, "subtalker_sampling_params", None)
+ if not isinstance(subtalker_params, dict):
+ subtalker_params = {}
with set_ascend_forward_context(
None, self.vllm_config, aclgraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc
):
- req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step)
+ req_embeds, code_predictor_codes = self.talker_mtp(
+ req_input_ids,
+ req_embeds,
+ last_talker_hidden,
+ text_step,
+ do_sample=subtalker_params.get("do_sample"),
+ temperature=subtalker_params.get("temperature"),
+ top_k=subtalker_params.get("top_k"),
+ top_p=subtalker_params.get("top_p"),
+ )
# code_predictor_codes stays on GPU here; _update_intermediate_buffer
# keeps it device-resident when the key is in gpu_resident_buffer_keys.
# D2H is deferred to sample_tokens where hidden_states.to("cpu") already
diff --git a/vllm_omni/profiler/omni_torch_profiler.py b/vllm_omni/profiler/omni_torch_profiler.py
index 2257a212838..023ad5009b7 100644
--- a/vllm_omni/profiler/omni_torch_profiler.py
+++ b/vllm_omni/profiler/omni_torch_profiler.py
@@ -5,8 +5,8 @@
import os
import subprocess
-import time
-from typing import Literal
+from datetime import datetime
+from typing import Any, Literal
import torch
from typing_extensions import override
@@ -62,6 +62,13 @@ def __init__(
self._trace_path: str | None = None
self._table_path: str | None = None
+ self._activities = activities
+ self._session_dir: str | None = None
+ self._artifact_paths: dict[str, str | None] = {}
+ self._memory_history_enabled = False
+ self._memory_history_backend: str | None = None
+ self._memory_history_module = None
+
if local_rank in (None, 0):
logger.info_once(
"Omni torch profiling enabled. Traces will be saved to: %s",
@@ -72,6 +79,9 @@ def __init__(
self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1
self.profiler = self._create_profiler(profiler_config, activities)
+ def _rank(self) -> int:
+ return 0 if self.local_rank is None else self.local_rank
+
def _get_default_activities(self) -> list[TorchProfilerActivity]:
"""Get default activities for this platform.
@@ -106,19 +116,58 @@ def set_trace_filename(self, filename: str) -> None:
Can also be a full path (e.g. from diffusion engine).
"""
self._trace_filename = filename
+ self._session_dir = None
+ self._ensure_session_dir()
+
+ def _ensure_session_dir(self) -> str:
+ """Create one timestamped directory for this profiling run."""
+ if self._session_dir is not None:
+ os.makedirs(self._session_dir, exist_ok=True)
+ return self._session_dir
+
+ ts = datetime.now().strftime("%Y%m%d-%H%M%S")
+ base_name = self._trace_filename or self._worker_name
+
+ if os.path.dirname(base_name):
+ parent_dir = os.path.dirname(base_name)
+ leaf_name = os.path.basename(base_name)
+ session_name = f"{ts}_{leaf_name}"
+ self._session_dir = os.path.join(parent_dir, session_name)
+ else:
+ session_name = f"{ts}_{base_name}"
+ self._session_dir = os.path.join(self._trace_dir, session_name)
+
+ os.makedirs(self._session_dir, exist_ok=True)
+ self._artifact_paths["session_dir"] = self._session_dir
+ return self._session_dir
+
+ def _artifact_path(self, stem: str, suffix: str) -> str:
+ """Build artifact path under the session directory."""
+ return os.path.join(
+ self._ensure_session_dir(),
+ f"{stem}_rank{self._rank()}{suffix}",
+ )
+
+ def _write_text_artifact(self, name: str, content: str) -> str:
+ path = self._artifact_path(name, ".txt")
+ with open(path, "w") as f:
+ f.write(content)
+ self._artifact_paths[name] = path
+ return path
+
+ def _has_cuda_like_activity(self) -> bool:
+ return any(a in self._activities for a in ("CUDA", "MUSA"))
+
+ def _get_time_sort_key(self) -> str:
+ if self._has_cuda_like_activity():
+ return "self_cuda_time_total"
+ return "self_cpu_time_total"
def _on_trace_ready(self, prof) -> None:
"""Custom trace handler: export chrome trace with omni naming."""
- rank = self.local_rank
- filename = self._trace_filename or f"{self._worker_name}_{int(time.time())}"
- # If filename already contains a directory, use as-is (e.g. from
- # diffusion engine which builds full path). Otherwise join with trace_dir.
- if os.path.dirname(filename):
- json_file = f"{filename}_rank{rank}.json"
- else:
- json_file = os.path.join(self._trace_dir, f"{filename}_rank{rank}.json")
+ rank = self._rank()
- os.makedirs(os.path.dirname(json_file), exist_ok=True)
+ json_file = self._artifact_path("trace", ".json")
try:
prof.export_chrome_trace(json_file)
@@ -143,18 +192,211 @@ def _on_trace_ready(self, prof) -> None:
else:
self._trace_path = json_file
+ self._artifact_paths["trace"] = self._trace_path
+
except Exception as e:
logger.warning("[Rank %s] Failed to export trace: %s", rank, e)
+ def _try_enable_memory_history(self) -> None:
+ """Enable backend-specific memory history for snapshot analysis."""
+ if not self.profiler_config.torch_profiler_with_memory:
+ return
+
+ backend_name, memory_module = self._resolve_memory_history_backend()
+ if backend_name is None or memory_module is None:
+ return
+
+ record_memory_history = getattr(memory_module, "_record_memory_history", None)
+ if record_memory_history is None:
+ logger.info(
+ "[Rank %s] %s memory history is not supported on this platform",
+ self._rank(),
+ backend_name,
+ )
+ return
+
+ try:
+ record_memory_history(
+ enabled="all",
+ context="all",
+ stacks="python",
+ max_entries=100000,
+ clear_history=True,
+ )
+ self._memory_history_enabled = True
+ self._memory_history_backend = backend_name
+ self._memory_history_module = memory_module
+ logger.info("[Rank %s] %s memory history enabled", self._rank(), backend_name)
+ except Exception as e:
+ logger.warning(
+ "[Rank %s] Failed to enable %s memory history: %s",
+ self._rank(),
+ backend_name,
+ e,
+ )
+
+ def _try_dump_memory_snapshot(self) -> None:
+ """Dump a backend-specific memory snapshot into the session directory."""
+ if not self._memory_history_enabled:
+ return
+
+ try:
+ if self._memory_history_module is None or self._memory_history_backend is None:
+ return
+
+ dump_snapshot = getattr(self._memory_history_module, "_dump_snapshot", None)
+ if dump_snapshot is None:
+ logger.info(
+ "[Rank %s] %s memory snapshot is not supported on this platform",
+ self._rank(),
+ self._memory_history_backend,
+ )
+ return
+
+ snapshot_file = self._artifact_path("memory_snapshot", ".pickle")
+ dump_snapshot(snapshot_file)
+ self._artifact_paths["memory_snapshot"] = snapshot_file
+ logger.info(
+ "[Rank %s] %s memory snapshot dumped to %s",
+ self._rank(),
+ self._memory_history_backend,
+ snapshot_file,
+ )
+ except Exception as e:
+ logger.warning(
+ "[Rank %s] Failed to dump %s memory snapshot: %s",
+ self._rank(),
+ self._memory_history_backend,
+ e,
+ )
+ finally:
+ try:
+ if self._memory_history_module is not None:
+ disable_memory_history = getattr(
+ self._memory_history_module,
+ "_record_memory_history",
+ None,
+ )
+ if disable_memory_history is not None:
+ disable_memory_history(enabled=None)
+ except Exception:
+ pass
+ self._memory_history_enabled = False
+ self._memory_history_backend = None
+ self._memory_history_module = None
+
+ def _resolve_memory_history_backend(self) -> tuple[str | None, Any]:
+ """Resolve the memory backend that supports history/snapshot APIs."""
+ backend_specs = [
+ ("CUDA", self._has_cuda_like_activity(), getattr(torch, "cuda", None)),
+ ("NPU", "NPU" in self._activities, getattr(torch, "npu", None)),
+ ("XPU", "XPU" in self._activities, getattr(torch, "xpu", None)),
+ ("MUSA", "MUSA" in self._activities, getattr(torch, "musa", None)),
+ ]
+
+ for backend_name, enabled, device_module in backend_specs:
+ if not enabled or device_module is None:
+ continue
+
+ is_available = getattr(device_module, "is_available", None)
+ if callable(is_available) and not is_available():
+ continue
+
+ memory_module = getattr(device_module, "memory", None)
+ if memory_module is not None:
+ return backend_name, memory_module
+
+ return None, None
+
+ def _safe_get(self, obj, name: str, default=None):
+ return getattr(obj, name, default)
+
+ def _event_list_to_rows(self, event_list) -> list[dict]:
+ rows = []
+ for evt in event_list:
+ row = {
+ "name": self._safe_get(evt, "key", None) or self._safe_get(evt, "name", None),
+ "count": self._safe_get(evt, "count", None),
+ "device_type": self._safe_get(evt, "device_type", None),
+ "node_id": self._safe_get(evt, "node_id", None),
+ "self_cpu_time_total_us": self._safe_get(evt, "self_cpu_time_total", None),
+ "cpu_time_total_us": self._safe_get(evt, "cpu_time_total", None),
+ "self_cuda_time_total_us": self._safe_get(evt, "self_cuda_time_total", None),
+ "cuda_time_total_us": self._safe_get(evt, "cuda_time_total", None),
+ "self_xpu_time_total_us": self._safe_get(evt, "self_xpu_time_total", None),
+ "xpu_time_total_us": self._safe_get(evt, "xpu_time_total", None),
+ "self_cpu_memory_usage_bytes": self._safe_get(evt, "self_cpu_memory_usage", None),
+ "cpu_memory_usage_bytes": self._safe_get(evt, "cpu_memory_usage", None),
+ "self_cuda_memory_usage_bytes": self._safe_get(evt, "self_cuda_memory_usage", None),
+ "cuda_memory_usage_bytes": self._safe_get(evt, "cuda_memory_usage", None),
+ "self_xpu_memory_usage_bytes": self._safe_get(evt, "self_xpu_memory_usage", None),
+ "xpu_memory_usage_bytes": self._safe_get(evt, "xpu_memory_usage", None),
+ "input_shapes": str(self._safe_get(evt, "input_shapes", None)),
+ "stack": "\n".join(self._safe_get(evt, "stack", []) or []),
+ "overload_name": self._safe_get(evt, "overload_name", None),
+ "is_async": self._safe_get(evt, "is_async", None),
+ "is_legacy": self._safe_get(evt, "is_legacy", None),
+ }
+ rows.append(row)
+ return rows
+
+ def _write_excel_artifact(self, name: str, sheets: dict[str, list[dict]]) -> str:
+ path = self._artifact_path(name, ".xlsx")
+
+ try:
+ import pandas as pd
+ except Exception as e:
+ logger.warning(
+ "[Rank %s] pandas not available, skip Excel export: %s",
+ self._rank(),
+ e,
+ )
+ return path
+
+ with pd.ExcelWriter(path, engine="openpyxl") as writer:
+ for sheet_name, rows in sheets.items():
+ df = pd.DataFrame(rows)
+
+ safe_sheet_name = sheet_name if sheet_name else "Sheet1"
+
+ df.to_excel(
+ writer,
+ sheet_name=safe_sheet_name,
+ index=False,
+ freeze_panes=(1, 0),
+ )
+
+ ws = writer.sheets[safe_sheet_name]
+ ws.auto_filter.ref = ws.dimensions
+
+ for col_cells in ws.columns:
+ max_len = 0
+ col_letter = col_cells[0].column_letter
+ for cell in col_cells[:200]:
+ try:
+ val = "" if cell.value is None else str(cell.value)
+ max_len = max(max_len, len(val))
+ except Exception:
+ pass
+ ws.column_dimensions[col_letter].width = min(max(max_len + 2, 12), 80)
+
+ self._artifact_paths[name] = path
+ return path
+
@override
def _start(self) -> None:
+ # Call ``_ensure_session_dir`` and ``_try_enable_memory_history`` before
+ # the profiler itself is started.
+ self._ensure_session_dir()
+ self._try_enable_memory_history()
self.profiler.start()
@override
def _stop(self) -> None:
"""Stop profiler, export trace via on_trace_ready, and dump table."""
self.profiler.stop()
- self._on_stop_hook()
+ try:
+ self._on_stop_hook()
+ finally:
+ # Runs whether or not the stop hook raised.
+ self._try_dump_memory_snapshot()
def _on_stop_hook(self) -> None:
"""Hook called after profiler.stop().
@@ -163,6 +405,68 @@ def _on_stop_hook(self) -> None:
Base implementation handles CUDA time total dump.
"""
rank = self.local_rank
+ sort_key = self._get_time_sort_key()
+
+ excel_sheets: dict[str, list[dict]] = {}
+
+ # 1) Summary op table
+ summary_events = self.profiler.key_averages()
+ excel_sheets["summary"] = self._event_list_to_rows(summary_events)
+
+ # 2) Shape-grouped op table
+ if self.profiler_config.torch_profiler_record_shapes:
+ try:
+ shape_events = self.profiler.key_averages(
+ group_by_input_shape=True,
+ )
+ excel_sheets["by_shape"] = self._event_list_to_rows(shape_events)
+ except Exception as e:
+ logger.warning(
+ "[Rank %s] Failed to export shape-grouped op table: %s",
+ rank,
+ e,
+ )
+
+ # 3) Stack-grouped op table
+ if self.profiler_config.torch_profiler_with_stack:
+ try:
+ stack_events = self.profiler.key_averages(
+ group_by_stack_n=8,
+ )
+ excel_sheets["by_stack"] = self._event_list_to_rows(stack_events)
+ except Exception as e:
+ logger.warning(
+ "[Rank %s] Failed to export stack-grouped op table: %s",
+ rank,
+ e,
+ )
+
+ # 4) Export stack files
+ try:
+ cpu_stack_file = self._artifact_path("stacks_cpu", ".txt")
+ self.profiler.export_stacks(
+ cpu_stack_file,
+ metric="self_cpu_time_total",
+ )
+ self._artifact_paths["stacks_cpu"] = cpu_stack_file
+ except Exception as e:
+ logger.warning("[Rank %s] export_stacks(cpu) failed: %s", rank, e)
+
+ if self._has_cuda_like_activity():
+ try:
+ cuda_stack_file = self._artifact_path("stacks_cuda", ".txt")
+ self.profiler.export_stacks(
+ cuda_stack_file,
+ metric="self_cuda_time_total",
+ )
+ self._artifact_paths["stacks_cuda"] = cuda_stack_file
+ except Exception as e:
+ logger.warning("[Rank %s] export_stacks(cuda) failed: %s", rank, e)
+
+ try:
+ self._table_path = self._write_excel_artifact("ops", excel_sheets)
+ except Exception as e:
+ logger.warning("[Rank %s] Failed to export Excel workbook: %s", rank, e)
if self.profiler_config.torch_profiler_dump_cuda_time_total:
profiler_dir = self.profiler_config.torch_profiler_dir
@@ -190,6 +494,7 @@ def get_results(self) -> dict:
return {
"trace": self._trace_path,
"table": self._table_path,
+ **self._artifact_paths,
}
diff --git a/vllm_omni/worker/base.py b/vllm_omni/worker/base.py
index f7f5dbd1d8b..8bd9efc89c4 100644
--- a/vllm_omni/worker/base.py
+++ b/vllm_omni/worker/base.py
@@ -2,15 +2,23 @@
from __future__ import annotations
+import gc
import os
import time
+from contextlib import AbstractContextManager, nullcontext
import torch
from vllm.logger import init_logger
from vllm.utils.mem_utils import format_gib, memory_profiling
from vllm.v1.worker.gpu_worker import Worker as GPUWorker
+from vllm_omni.diffusion.data import (
+ OmniACK,
+ OmniSleepTask,
+ OmniWakeTask,
+)
from vllm_omni.entrypoints.utils import detect_pid_host
+from vllm_omni.platforms import current_omni_platform
from vllm_omni.worker.gpu_memory_utils import (
get_process_gpu_memory,
is_process_scoped_memory_available,
@@ -30,6 +38,13 @@ class OmniGPUWorkerBase(GPUWorker):
for custom trace naming, background gzip, and trace path collection.
"""
+ def load_model(self, *args, **kwargs):
+ """Load the model through the parent class using the "weights" pool context.
+
+ All arguments are forwarded unchanged to the parent ``load_model``, which
+ runs under the context manager returned by
+ ``_maybe_get_memory_pool_context("weights")``. The platform is then
+ synchronized and a garbage-collection pass is run.
+
+ Returns:
+ Whatever the parent ``load_model`` returned.
+ """
+ with self._maybe_get_memory_pool_context("weights"):
+ res = super().load_model(*args, **kwargs)
+ current_omni_platform.synchronize()
+ gc.collect()
+ return res
+
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -94,10 +109,8 @@ def determine_available_memory(self) -> int:
"""
if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
self.model_runner.profile_run()
- logger.info(
- "Using explicit kv_cache_memory_bytes: %s GiB",
- format_gib(kv_cache_memory_bytes),
- )
+ if current_omni_platform.is_rocm():
+ torch.cuda.synchronize()
return kv_cache_memory_bytes
with memory_profiling(
@@ -154,3 +167,141 @@ def determine_available_memory(self) -> int:
)
return int(self.available_kv_cache_memory_bytes)
+
+ # Provide memory pool context
+ def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
+ """Return the context manager under which allocations for ``tag`` are made.
+
+ Sleep mode counts as enabled when ``enable_sleep_mode`` is truthy on
+ ``self.vllm_config.model_config`` (only consulted if ``self`` has a
+ ``vllm_config`` attribute) or on ``self.cache_config``.
+
+ Args:
+ tag: Tag passed to ``CuMemAllocator.use_memory_pool``.
+
+ Returns:
+ ``CuMemAllocator.get_instance().use_memory_pool(tag=tag)`` when sleep
+ mode is enabled, otherwise ``nullcontext()``.
+ """
+ v1_config_enabled = False
+ if hasattr(self, "vllm_config"):
+ model_cfg = getattr(self.vllm_config, "model_config", None)
+ v1_config_enabled = getattr(model_cfg, "enable_sleep_mode", False)
+
+ is_sleep_enabled = v1_config_enabled or getattr(self.cache_config, "enable_sleep_mode", False)
+ if is_sleep_enabled:
+ # Synchronize and collect garbage before the pool is activated.
+ current_omni_platform.synchronize()
+ gc.collect()
+ # Imported here, so the allocator module is only loaded when sleep mode is on.
+ from vllm.device_allocator.cumem import CuMemAllocator
+
+ allocator = CuMemAllocator.get_instance()
+ logger.info(f"[LLM Worker {self.rank}] Sleep Mode ENABLED. Activating CuMem pool for tag: {tag}")
+ return allocator.use_memory_pool(tag=tag)
+ else:
+ # Logged at warning level on every call made with sleep mode off.
+ logger.warning(f"[LLM Worker {self.rank}] Sleep Mode DISABLED.")
+ return nullcontext()
+
+ def sleep(self, level: int = 1) -> bool:
+ """
+ Put the worker to sleep through the CuMem allocator and log the memory freed.
+
+ Args:
+ level: 1 passes ``("weights",)`` as the allocator's offload tags
+ (described by the author as "offload weights to CPU"); any other
+ value passes no offload tags (described as "total discard" for
+ level 2).
+
+ Returns:
+ Always ``True``.
+ """
+ from vllm.device_allocator.cumem import CuMemAllocator
+
+ # Sample memory usage before and after so the freed amount can be logged.
+ mem_before = current_omni_platform.get_current_memory_usage(self.device)
+ offload_tags = ("weights",) if level == 1 else tuple()
+ allocator = CuMemAllocator.get_instance()
+ allocator.sleep(offload_tags=offload_tags)
+ current_omni_platform.empty_cache()
+ current_omni_platform.synchronize()
+ mem_after = current_omni_platform.get_current_memory_usage(self.device)
+ # Clamp at zero in case usage went up between the two samples.
+ freed = max(0, mem_before - mem_after)
+ remaining_gb = mem_after / 1024**3
+ logger.info(
+ f"[LLM Worker {self.rank}] Level {level} Sleep: Freed "
+ f"{freed / 1024**3:.2f} GiB. {remaining_gb:.2f}GiB memory "
+ "is still in use."
+ )
+ return True
+
+ def wake_up(self, tags: list[str] | None = None) -> bool:
+ """Wake the CuMem allocator for ``tags`` and synchronize the platform.
+
+ Args:
+ tags: Forwarded unchanged to ``CuMemAllocator.wake_up``; ``None`` is
+ passed through as-is.
+
+ Returns:
+ Always ``True``.
+ """
+ from vllm.device_allocator.cumem import CuMemAllocator
+
+ allocator = CuMemAllocator.get_instance()
+ allocator.wake_up(tags)
+ current_omni_platform.synchronize()
+ logger.info(f"[LLM Worker {self.rank}] Wake-up complete.")
+ return True
+
+ def handle_sleep_task(self, task: OmniSleepTask | dict) -> OmniACK | None:
+ """Handle a sleep command from the main process and report the memory freed.
+
+ Args:
+ task: An ``OmniSleepTask``, or a dict of its constructor keyword
+ arguments.
+
+ Returns:
+ On success, rank 0 returns an ``OmniACK`` with status "SUCCESS" (also
+ put on ``self.result_mq`` when that attribute exists and is truthy)
+ and every other rank returns ``None``. If anything raises, an
+ ``OmniACK`` with status "ERROR" is returned instead of propagating
+ the exception.
+ """
+ try:
+ if isinstance(task, dict):
+ task = OmniSleepTask(**task)
+ logger.info(f"[LLM Worker {self.rank}] Handshake Received: Task {task.task_id}, Level {task.level}")
+ # Level 2 also clears ``model_runner.graph_runners`` when the runner has it.
+ if task.level == 2:
+ if hasattr(self.model_runner, "graph_runners"):
+ self.model_runner.graph_runners.clear()
+ logger.info(f"[LLM Worker {self.rank}] LLM CUDA Graphs cleared.")
+ # Measure this rank's usage around the sleep call.
+ mem_before = current_omni_platform.get_current_memory_usage(self.device)
+ self.sleep(level=task.level)
+ mem_after = current_omni_platform.get_current_memory_usage(self.device)
+ rank_freed = max(0, mem_before - mem_after)
+ if torch.distributed.is_initialized():
+ # all_reduce defaults to a sum, giving the total freed across ranks.
+ t_freed = torch.tensor([float(rank_freed)], device=self.device)
+ torch.distributed.all_reduce(t_freed)
+ total_freed = int(t_freed.item())
+ torch.distributed.barrier()
+ else:
+ total_freed = rank_freed
+ # Only rank 0 builds and emits the ACK.
+ if self.rank != 0:
+ return None
+ current_stage_id = getattr(self.vllm_config.model_config, "stage_id", 0)
+ ack = OmniACK(
+ task_id=task.task_id,
+ status="SUCCESS",
+ stage_id=current_stage_id,
+ rank=self.rank,
+ freed_bytes=total_freed,
+ metadata={
+ "source": "omni_platform_audit",
+ "total_freed_gib": f"{total_freed / 1024**3:.2f}",
+ "rank_residual_gib": f"{mem_after / 1024**3:.2f}",
+ },
+ )
+ if hasattr(self, "result_mq") and self.result_mq:
+ self.result_mq.put(ack)
+ logger.info(f"[LLM Worker {self.rank}] ACK emitted for Task {task.task_id}")
+ return ack
+ except Exception as e:
+ logger.error(f"[LLM Worker {self.rank}] Sleep Task Failed: {e}", exc_info=True)
+ # Best-effort barrier; any error it raises is deliberately ignored.
+ # NOTE(review): if the failure happened before the all_reduce above, the
+ # other ranks may be waiting in a different collective - confirm this
+ # cannot hang.
+ if torch.distributed.is_initialized():
+ try:
+ torch.distributed.barrier()
+ except Exception:
+ pass
+ # ``task`` is still the raw dict when constructing OmniSleepTask itself
+ # failed, so plain ``task.task_id`` would raise AttributeError here and
+ # hide the original error. Read the id defensively instead.
+ tid = task.get("task_id", "unknown") if isinstance(task, dict) else getattr(task, "task_id", "unknown")
+ return OmniACK(task_id=tid, status="ERROR", error_msg=str(e))
+
+ def handle_wake_task(self, task: OmniWakeTask) -> OmniACK:
+ """Handle a wake-up command from the main process.
+
+ ``task`` may also be a dict of ``OmniWakeTask`` keyword arguments; it is
+ converted before use.
+
+ Returns:
+ On success, rank 0 returns an ``OmniACK`` with status "SUCCESS" (also
+ put on ``self.result_mq`` when that attribute exists and is truthy)
+ and every other rank returns ``None``. If anything raises, an
+ ``OmniACK`` with status "ERROR" is returned instead of propagating
+ the exception.
+ """
+ try:
+ if isinstance(task, dict):
+ task = OmniWakeTask(**task)
+ self.wake_up(tags=task.tags)
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier()
+ gc.collect()
+ current_omni_platform.synchronize()
+ # Usage sampled after wake-up; reported in the ACK metadata.
+ usage_now = current_omni_platform.get_current_memory_usage(self.device)
+ # Only rank 0 builds and emits the ACK.
+ if self.rank != 0:
+ return None
+ current_stage_id = getattr(self.vllm_config.model_config, "stage_id", 0)
+ ack = OmniACK(
+ task_id=task.task_id,
+ status="SUCCESS",
+ stage_id=current_stage_id,
+ rank=self.rank,
+ metadata={"state": "WARM", "current_vram_gib": f"{usage_now / 1024**3:.2f}"},
+ )
+ if hasattr(self, "result_mq") and self.result_mq:
+ self.result_mq.put(ack)
+ logger.info(f"[LLM Worker {self.rank}] Wake-up ACK emitted.")
+ return ack
+ except Exception as e:
+ logger.error(f"[LLM Worker {self.rank}] Wake-up Failed: {e}", exc_info=True)
+ # Best-effort barrier; any error it raises is deliberately ignored.
+ if torch.distributed.is_initialized():
+ try:
+ torch.distributed.barrier()
+ except Exception:
+ pass
+ # ``task`` can still be a raw dict here (no ``task_id`` attribute) when
+ # constructing OmniWakeTask failed; the id is then reported as "unknown".
+ tid = task.task_id if hasattr(task, "task_id") else "unknown"
+ return OmniACK(task_id=tid, status="ERROR", error_msg=str(e))
diff --git a/vllm_omni/worker/gpu_ar_worker.py b/vllm_omni/worker/gpu_ar_worker.py
index 4abe21964b3..2668bbe8b3a 100644
--- a/vllm_omni/worker/gpu_ar_worker.py
+++ b/vllm_omni/worker/gpu_ar_worker.py
@@ -12,6 +12,7 @@
from vllm.v1.worker.utils import request_memory
from vllm.v1.worker.workspace import init_workspace_manager
+from vllm_omni.diffusion.data import OmniACK, OmniSleepTask, OmniWakeTask
from vllm_omni.worker.base import OmniGPUWorkerBase
from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner
from vllm_omni.worker.mixins import OmniWorkerMixin
@@ -104,3 +105,23 @@ def init_device(self):
if self.rank == 0:
# If usage stat is enabled, collect relevant info.
report_usage_stats(self.vllm_config)
+
+ def handle_sleep_task(self, task: OmniSleepTask | dict) -> OmniACK:
+ """
+ Explicitly handle sleep commands.
+ Calls the implementation in the base class OmniGPUWorkerBase.
+
+ Args:
+ task: An ``OmniSleepTask`` or a dict of its constructor keyword arguments.
+
+ Returns:
+ Whatever the base-class ``handle_sleep_task`` returns.
+ """
+ logger.debug(f"[AR Worker {self.rank}] Resolving handle_sleep_task dispatch")
+ # Normalize a dict payload into an OmniSleepTask before delegating.
+ if isinstance(task, dict):
+ task = OmniSleepTask(**task)
+ return super().handle_sleep_task(task)
+
+ def handle_wake_task(self, task: OmniWakeTask | dict) -> OmniACK:
+ """
+ Explicitly handle wake-up commands.
+ Calls the implementation in the base class OmniGPUWorkerBase.
+
+ Args:
+ task: An ``OmniWakeTask`` or a dict of its constructor keyword arguments.
+
+ Returns:
+ Whatever the base-class ``handle_wake_task`` returns.
+ """
+ logger.debug(f"[AR Worker {self.rank}] Resolving handle_wake_task dispatch")
+ # Normalize a dict payload into an OmniWakeTask before delegating.
+ if isinstance(task, dict):
+ task = OmniWakeTask(**task)
+ return super().handle_wake_task(task)
diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py
index d1c15eac640..d3ccbaaf303 100644
--- a/vllm_omni/worker/gpu_model_runner.py
+++ b/vllm_omni/worker/gpu_model_runner.py
@@ -152,8 +152,10 @@ def _init_mrope_positions(self, req_state: CachedRequestState):
if supports_mrope(self.get_model()):
# Model implements SupportsMRoPE interface
# Pass all extracted metadata; models use what they need via **kwargs
- req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions(
- req_state.prompt_token_ids,
+ sp_extra_args = getattr(req_state.sampling_params, "extra_args", {}) if req_state.sampling_params else {}
+ target_h = sp_extra_args.get("target_h") if isinstance(sp_extra_args, dict) else None
+ target_w = sp_extra_args.get("target_w") if isinstance(sp_extra_args, dict) else None
+ kwargs = dict(
mm_features=req_state.mm_features,
hf_config=self.model_config.hf_config,
image_grid_thw=image_grid_thw,
@@ -162,6 +164,14 @@ def _init_mrope_positions(self, req_state: CachedRequestState):
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
+ if target_h is not None:
+ kwargs["target_h"] = target_h
+ if target_w is not None:
+ kwargs["target_w"] = target_w
+ req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions(
+ req_state.prompt_token_ids,
+ **kwargs,
+ )
else:
req_state.mrope_positions, req_state.mrope_position_delta = MRotaryEmbedding.get_input_positions_tensor(
req_state.prompt_token_ids,
@@ -1345,10 +1355,22 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te
req_embeds = self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded]
last_talker_hidden = self.last_talker_hidden.gpu[:num_tokens_padded]
text_step = self.text_step.gpu[:num_tokens_padded]
+ subtalker_params = getattr(self.vllm_config.model_config, "subtalker_sampling_params", None)
+ if not isinstance(subtalker_params, dict):
+ subtalker_params = {}
with set_forward_context(
None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc
):
- req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step)
+ req_embeds, code_predictor_codes = self.talker_mtp(
+ req_input_ids,
+ req_embeds,
+ last_talker_hidden,
+ text_step,
+ do_sample=subtalker_params.get("do_sample"),
+ temperature=subtalker_params.get("temperature"),
+ top_k=subtalker_params.get("top_k"),
+ top_p=subtalker_params.get("top_p"),
+ )
# code_predictor_codes stays on GPU here; _update_intermediate_buffer
# keeps it device-resident when the key is in gpu_resident_buffer_keys.
# D2H is deferred to sample_tokens where hidden_states.to("cpu") already