diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index f56f23b5deb..2707dfefe83 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -64,12 +64,13 @@ while true; do done echo "--- Pulling container" -## Temporary change to use AMD Docker Hub to store the vllm-ci image +## Temporary change to use AMD Docker Hub to store the vllm-omni image # to bypass the rate limit issue with ECR Public Gallery. +# Images are now stored in a separate repository for vllm-omni, instead of vllm-ci. # TODO: @tjtanaa point back to ECR Public Gallery # once the amd agents are configured to use ECR Public Gallery. # image_name="public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:${BUILDKITE_COMMIT}-rocm-omni" -image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}-rocm-omni" +image_name="rocm/vllm-omni:${BUILDKITE_COMMIT}" container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)" # TODO: @tjtanaa uncomment this once the amd agents are configured to use ECR Public Gallery. diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index e175385ff0d..0b9a3f47aba 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -117,3 +117,17 @@ steps: - export VLLM_LOGGING_LEVEL=DEBUG - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -s -v tests/e2e/online_serving/test_image_gen_edit.py + + +- label: "Omni Sleep Mode Test" + timeout_in_minutes: 40 + agent_pool: mi325_2 + depends_on: amd-build + mirror_hardwares: [amdproduction] + grade: Blocking + commands: + - export GPU_ARCHS=gfx942 + - export VLLM_LOGGING_LEVEL=DEBUG + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - export VLLM_TEST_CLEAN_GPU_MEMORY="1" + - pytest -s -v tests/e2e/offline_inference/test_omni_sleep_mode.py -m "advanced_model and omni and MI325" --run-level "advanced_model" diff --git a/.buildkite/test-merge.yml b/.buildkite/test-merge.yml index 2a6cb6488a0..785cc58fab9 100644 --- a/.buildkite/test-merge.yml +++ b/.buildkite/test-merge.yml @@ -360,6 +360,46 @@ steps: path: /mnt/hf-cache type: DirectoryOrCreate + - label: "Omni Sleep Mode Test with H100" + timeout_in_minutes: 30 + depends_on: upload-merge-pipeline + commands: + - export VLLM_TEST_CLEAN_GPU_MEMORY="1" + - pytest -s -v tests/e2e/offline_inference/test_omni_sleep_mode.py -m "advanced_model and H100 and omni" --run-level "advanced_model" + agents: + queue: "mithril-h100-pool" + plugins: + - kubernetes: + podSpec: + containers: + - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + resources: + limits: + nvidia.com/gpu: 2 + volumeMounts: + - name: devshm + mountPath: /dev/shm + - name: hf-cache + mountPath: /root/.cache/huggingface + env: + - name: HF_HOME + value: /root/.cache/huggingface + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + nodeSelector: + node.kubernetes.io/instance-type: gpu-h100-sxm + volumes: + - name: devshm + emptyDir: + medium: Memory + - name: hf-cache + hostPath: + path: /mnt/hf-cache + type: DirectoryOrCreate + - label: "Voxtral-TTS E2E Test" timeout_in_minutes: 20 depends_on: upload-merge-pipeline diff --git a/.buildkite/test-nightly.yml b/.buildkite/test-nightly.yml index c33d7b4d10d..fb21a2acaac 100644 --- a/.buildkite/test-nightly.yml +++ b/.buildkite/test-nightly.yml @@ -151,6 +151,44 @@ steps: type: DirectoryOrCreate + - label: ":full_moon: Omni Multi-Replica Startup Test with 4x H100" + timeout_in_minutes: 45 + commands: + - pytest -s -v tests/e2e/online_serving/test_qwen3_omni_multi_replicas.py -m "core_model" --run-level "core_model" + agents: + queue: "mithril-h100-pool" + plugins: + - kubernetes: + podSpec: + containers: + - image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + resources: + limits: + nvidia.com/gpu: 4 + volumeMounts: + - name: devshm + mountPath: /dev/shm + - name: hf-cache + mountPath: /root/.cache/huggingface + env: + - name: HF_HOME + value: /root/.cache/huggingface + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + nodeSelector: + node.kubernetes.io/instance-type: gpu-h100-sxm + volumes: + - name: devshm + emptyDir: + medium: Memory + - name: hf-cache + hostPath: + path: /mnt/hf-cache + type: DirectoryOrCreate + - group: ":card_index_dividers: TTS Model Test" key: nightly-tts-test-group depends_on: upload-nightly-pipeline diff --git a/.buildkite/test-template-amd-omni.j2 b/.buildkite/test-template-amd-omni.j2 index f4c386a5fe3..78f47d1aec0 100644 --- a/.buildkite/test-template-amd-omni.j2 +++ b/.buildkite/test-template-amd-omni.j2 @@ -3,7 +3,7 @@ Last synced: 2025-12-15 Modifications: Removed unused CUDA/NVIDIA logic, keeping only AMD tests #} -{% set docker_image_amd = "rocm/vllm-ci:$BUILDKITE_COMMIT-rocm-omni" %} +{% set docker_image_amd = "rocm/vllm-omni:$BUILDKITE_COMMIT" %} {% set default_working_dir = "/app/vllm-omni" %} - group: "AMD Tests" diff --git a/.gitignore b/.gitignore index 35dc7571ee2..a631aa05f74 100644 --- a/.gitignore +++ b/.gitignore @@ -174,6 +174,7 @@ CLAUDE.md # Codex AGENTS.md +.codex .codex/ # cursor diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index 6c209e5659a..a7b508302d2 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -138,9 +138,27 @@ Single-stage diffusion serving with torch profiler: vllm serve Wan-AI/Wan2.2-I2V-A14B-Diffusers \ --omni \ --port 8091 \ - --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile"}' + --profiler-config '{ + "profiler": "torch", + "torch_profiler_dir": "/tmp/vllm_profile_wan22_i2v", + "torch_profiler_with_stack": true, + "torch_profiler_with_flops": false, + "torch_profiler_use_gzip": true, + "torch_profiler_dump_cuda_time_total": true, + "torch_profiler_record_shapes": true, + "torch_profiler_with_memory": true, + "delay_iterations": 0, + "max_iterations": 0 + }' ``` +Useful optional fields: + +- `torch_profiler_with_stack`: export `by_stack` operator views and stack text files. +- `torch_profiler_record_shapes`: export `by_shape` operator views. +- `torch_profiler_with_memory`: dump `memory_snapshot_rank*.pickle` when the backend supports memory history. +- `torch_profiler_use_gzip`: write the trace as `trace_rank*.json.gz`. + Single-stage diffusion serving with Nsight Systems: ```bash @@ -177,8 +195,11 @@ For mixed-stage pipelines, use explicit `stages` and pass the same stage list to Torch profiler output: -- Chrome/Perfetto traces under `torch_profiler_dir` -- Optional aggregated CUDA-time tables under the same directory +- Chrome/Perfetto trace: `trace_rank*.json` or `trace_rank*.json.gz` +- Excel workbook: `ops_rank*.xlsx` with `summary`, and optional `by_shape` / `by_stack` sheets +- Stack exports: `stacks_cpu_rank*.txt` and `stacks_cuda_rank*.txt` when stack capture is enabled +- Memory snapshot: `memory_snapshot_rank*.pickle` when memory capture is enabled and supported by the backend +- Optional aggregated CUDA-time tables under the same session directory CUDA profiler / Nsight Systems output: diff --git a/docs/features/sleep_mode.md b/docs/features/sleep_mode.md index 41aa48c1735..b4eec162d31 100644 --- a/docs/features/sleep_mode.md +++ b/docs/features/sleep_mode.md @@ -1,39 +1,243 @@ -# Sleep Mode +# Sleep Mode & ACK Protocol -vLLM-Omni’s **Sleep Mode** allows you to temporarily release most GPU memory used by a model—such as model weights and key-value (KV) caches (for autoregressive models)—**without stopping the server or unloading the Docker container**. +vLLM-Omni’s **Sleep Mode** allows you to temporarily release most GPU memory used by a model—such as model weights and key-value (KV) caches—**without stopping the server or unloading the Docker container**. -This feature is inherited from [vLLM’s Sleep Mode](https://blog.vllm.ai/2025/10/26/sleep-mode.html), which provides zero-reload model switching for multi-model serving. - -It is especially useful in **RLHF**, **training**, or **cost-saving scenarios**, where GPU resources must be freed between inference workloads. +This feature is inherited from [vLLM’s Sleep Mode](https://blog.vllm.ai/2025/10/26/sleep-mode.html) and extended with the **Omni ACK Protocol** to support multi-stage pipelines and heterogeneous hardware backends (NVIDIA, AMD, Intel, Huawei). It is especially useful in **RLHF**, **dynamic model switching**, or **cost-saving scenarios**. --- -## Omni Model +## 1. Feature Documentation -Omni model inherit the feature from vLLM' Sleep Mode +### Overview +Omni Sleep Mode provides a mechanism to "sleep" specific model stages. When a stage enters sleep, its physical VRAM is reclaimed by the system, while the process state is preserved for rapid "wake-up" without full re-initialization. -This means: +### Sleep Levels +We support two levels of hibernation to balance recovery speed and memory efficiency: -- Support both Level 1 and Level 2 sleep, allow to release and reset both model weights and KV Cache +| Level | Name | Mechanism | Recovery Speed | Memory Freed | +| :--- | :--- | :--- | :--- | :--- | +| **Level 1** | **Weight Offloading** | Offloads weights to Host CPU RAM. | **Fast** (DMA) | Substantial | +| **Level 2** | **Full De-mapping** | Physically releases memory pages via VRAM scavenging. | **Moderate** | **Maximum** (up to 95%+) | -## Diffusion Model Extension +### Supported Platforms -We added Sleep Mode support for **diffusion models**, which previously lacked this functionality. -In diffusion pipelines, this currently only offloads **model weight memory**, as these models typically do not use KV caches. +Omni Sleep Mode is optimized for high-performance computing backends: -This means: +* **NVIDIA**: Supported via Virtual Memory Management (VMM). +* **AMD (ROCm)**: Fully supported with physical page de-mapping. +* **Intel XPU**: Supported via Level Zero memory management. +* **Huawei NPU**: Supported via Ascend memory scavenging. -- Diffusion models can now enter Level 1 sleep. -- Pipeline states (e.g., noise schedulers, buffers) remain intact after waking. -- Useful for releasing VRAM between image generation or training cycles. +### Hardware Requirements +* **Memory Considerations**: System RAM must be sufficient to hold offloaded weights during sleep. +* **TP Support**: Tensor Parallel groups synchronize sleep/wake transitions across all workers. --- -## Enable sleep mode -To enable sleep mode, set the `enable_sleep_mode` in `engine_args` to `True` +## 2. Usage Examples + +### Python API Example +You can programmatically control the lifecycle of stages using the `AsyncOmni` engine. -Example: ```python -omni = Omni(model=...,enable_sleep_mode=True) + +import asyncio +from vllm_omni.entrypoints.async_omni import AsyncOmni + +async def run_sleep_demo(): + # 1. initialization + engine = AsyncOmni( + model="ByteDance-Seed/BAGEL-7B-MoT", + enable_sleep_mode=True + ) + + # 2. sleep mode level2 + acks = await engine.sleep(stage_ids=[0], level=2) + print(f"Freed {acks[0].freed_bytes / 1024**3:.2f} GiB on Stage 0") + + # 3. wake up + await engine.wake_up(stage_ids=[0]) + +if __name__ == "__main__": + asyncio.run(run_sleep_demo()) + +``` + +### server command Example +Start the server with sleep mode enabled: + +The first method + +``` + +vllm serve ByteDance-Seed/BAGEL-7B-MoT \ +--omni \ +--enable-sleep-mode \ +--trust-remote-code \ +--gpu-memory-utilization 0.7 + +``` + +The second method + +``` + +python3 -m vllm_omni.entrypoints.openai.api_server \ + --model ByteDance-Seed/BAGEL-7B-MoT \ + --omni \ + --enable-sleep-mode \ + --trust-remote-code \ +--gpu-memory-utilization 0.7 + +``` + + + + +### Test Scenarios & Commands + +#### Scenario 1: LLM Engine Sleep + +Objective: Verify VRAM reclamation for Stage 0 (Thinker). + +Trigger sleep (Level 1 or Level 2) via client: + ``` + +curl -X POST http://localhost:8000/v1/omni/sleep \ + -H "Content-Type: application/json" \ + -d '{"stage_ids": [0], "level": 2}' + +``` + +Tip: Open a new terminal and run rocm-smi or nvidia-smi or to observe the immediate drop in VRAM usage. + + + +#### Scenario 2: Diffusion Sleep +Objective: Verify VRAM reclamation for Stage 1 (Diffusion). + +Trigger sleep (Level 1 or Level 2) via client: + +``` + +curl -X POST http://localhost:8000/v1/omni/sleep \ + -H "Content-Type: application/json" \ + -d '{"stage_ids": [1], "level": 2}' + +``` + + + +#### Scenario 3: Multi-Stage Coordinated Stress Test +Objective: Test concurrent sleep and rapid wake-up across multiple stages. + +Concurrent Sleep (Stage 0 & 1): + +``` + +curl -X POST http://localhost:8000/v1/omni/sleep \ + -H "Content-Type: application/json" \ + -d '{"stage_ids": [0, 1], "level": 2}' + +``` + + +Rapid Wake-up: + +``` + +curl -X POST http://localhost:8000/v1/omni/wakeup \ + -H "Content-Type: application/json" \ + -d '{"stage_ids": [0, 1]}' + +``` + + +#### Scenario 4: Full Lifecycle Memory Audit & Functional Integrity +Objective: Audit the complete flow from Sleep to Wake-up followed by an Inference validation. + +Check Initial State: Observe baseline VRAM usage. + +Trigger Deep Sleep (Level 2): + +``` + +curl -X POST http://localhost:8000/v1/omni/sleep \ + -H "Content-Type: application/json" \ + -d '{"stage_ids": [0], "level": 2}' + +``` + +Wake-up Model: + +``` + +curl -X POST http://localhost:8000/v1/omni/wakeup \ + -H "Content-Type: application/json" \ + -d '{"stage_ids": [0]}' + +``` + +Verify Functional Integrity (Inference): +Ensure the model still generates valid output after reloading weights. + +``` + +curl -X POST http://localhost:8000/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "A huge swimming pool, with many people swimming.", + "model": "ByteDance-Seed/BAGEL-7B-MoT", + "response_format": "b64_json", + "extra_body": {"sampling_params": {"num_inference_steps": 4, "seed": 42}} + }' > post.json + +``` + + + + +## 3. API Reference + + +### Methods + +| Method | Arguments | Return Type | Description | +| :--- | :--- | :--- | :--- | +| **sleep** | `stage_ids: List[int], level: int` | `List[OmniACK]` | Triggers hibernation for specified stages. | +| **wake_up** | `stage_ids: List[int]` | `List[OmniACK]` | Reloads weights and re-maps memory. | + + + +### OmniACK Dataclass Fields + +| Field | Type | Description | +| :--- | :--- | :--- | +| **task_id** | `str` | Unique identifier for the operation. | +| **status** | `str` | `SUCCESS` or `ERROR`. | +| **stage_id** | `int` | The ID of the stage that responded. | +| **rank** | `int` | The rank ID within the Tensor Parallel group. | +| **freed_bytes** | `int` | Actual amount of physical VRAM reclaimed. | +| **metadata** | `dict` | Additional platform-specific metrics. | + +Metadata Field Analysis +The metadata field is a dynamic dictionary containing hardware-specific telemetry and audit data, primarily used for verifying memory reclamation on various backends (e.g., AMD ROCm, NVIDIA CUDA). + +``` +"metadata": { + "source": "Platform_AMD_Instinct_MI300X", + "total_freed_gib": "78.57", + "rank_residual_gib": "2.07" +} +``` + +#### Core Utility: +**VRAM Reclamation Audit (total_freed_gib)**: Converts raw freed_bytes into human-readable GiB. It serves as the primary metric to verify that Level 2 sleep has successfully purged model weights from VRAM. + +**Residual & Fragmentation Monitoring (rank_residual_gib)**: Reports the remaining VRAM footprint after memory de-mapping. A low residual value (e.g., 2.07 GiB) confirms a successful "clean" state, ensuring the device is ready for high-memory co-located tasks like training or diffusion pipelines. + +**Backend Traceability (source)**: Identifies the underlying hardware driver or audit source. This is critical for debugging synchronization issues in multi-stage, distributed environments. + +**Performance Analytics (Roadmap)**: Future updates will include latency_ms (context-switch overhead) and cuda_graph_recalled (graph engine status) to optimize performance in high-frequency sleep/wake scenarios. diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md index 7bdeede446a..be5b2d10ca3 100644 --- a/docs/user_guide/diffusion_features.md +++ b/docs/user_guide/diffusion_features.md @@ -140,7 +140,7 @@ The following tables show which models support each feature: |-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:| | **Wan2.2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (encode/decode) | ❌ | ❌ | | **Wan2.1-VACE** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ | -| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | +| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | **Helios** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ | | **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | diff --git a/docs/user_guide/examples/online_serving/text_to_video.md b/docs/user_guide/examples/online_serving/text_to_video.md index 00a9c167239..b918aac19d0 100644 --- a/docs/user_guide/examples/online_serving/text_to_video.md +++ b/docs/user_guide/examples/online_serving/text_to_video.md @@ -288,6 +288,14 @@ vllm serve Lightricks/LTX-2 --omni --port 8098 \ --enforce-eager --flow-shift 1.0 --boundary-ratio 1.0 ``` +For multi-GPU memory reduction, you can enable HSDP: + +```bash +vllm serve Lightricks/LTX-2 --omni --port 8098 \ + --enforce-eager --flow-shift 1.0 --boundary-ratio 1.0 \ + --use-hsdp --hsdp-shard-size 2 +``` + #### Start with Optimization Presets Use the LTX-2 startup script with built-in optimization presets: diff --git a/examples/offline_inference/glm_image/end2end.py b/examples/offline_inference/glm_image/end2end.py index 13bcd23f55a..7ae9478ba75 100644 --- a/examples/offline_inference/glm_image/end2end.py +++ b/examples/offline_inference/glm_image/end2end.py @@ -57,22 +57,22 @@ GLM_IMAGE_VISION_VOCAB_SIZE = 16512 # top_k should be vision_vocab_size -def compute_max_tokens(height: int, width: int, factor: int = 32) -> int: +def compute_max_tokens(height: int, width: int, factor: int = 32, is_i2i: bool = False) -> int: """ Compute max_new_tokens for GLM-Image AR generation. - GLM-Image generates tokens in this order for text-to-image: - 1. Small preview image (half resolution in each dimension) - 2. Large target image (full resolution) - 3. EOS token + GLM-Image generation differs by mode: + - text-to-image (t2i): small preview + large target + EOS + - image-to-image (i2i): large target + EOS Args: height: Target image height in pixels width: Target image width in pixels factor: Downsampling factor (32 for GLM-Image AR output) + is_i2i: Whether the request is image-to-image mode Returns: - Total number of tokens to generate (small + large + EOS) + Total number of tokens to generate for the specified mode """ # Large image tokens (target resolution) token_h = height // factor @@ -80,11 +80,15 @@ def compute_max_tokens(height: int, width: int, factor: int = 32) -> int: large_tokens = token_h * token_w # Small preview tokens (half resolution in each dimension) - small_h = token_h // 2 - small_w = token_w // 2 - small_tokens = small_h * small_w + import math - # Total: small + large + EOS + ratio = token_h / token_w if token_w > 0 else 1.0 + small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2))) + small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2))) + small_tokens = small_token_h * small_token_w + + if is_i2i: + return large_tokens + 1 return small_tokens + large_tokens + 1 @@ -282,14 +286,18 @@ def main(args: argparse.Namespace) -> None: # Compute max_tokens dynamically based on target image size target_height = prompt_dict.get("height", 1024) target_width = prompt_dict.get("width", 1024) - calculated_max_tokens = compute_max_tokens(target_height, target_width) + is_i2i = source_image is not None + calculated_max_tokens = compute_max_tokens(target_height, target_width, is_i2i=is_i2i) # Use calculated value unless user explicitly specified a different value # Default args.max_tokens is 16384 (very large), so prefer calculated value effective_max_tokens = calculated_max_tokens if args.max_tokens == 16384 else args.max_tokens if args.verbose: - print(f"AR max_tokens: {effective_max_tokens} (calculated: {calculated_max_tokens}, arg: {args.max_tokens})") + print( + f"AR max_tokens: {effective_max_tokens} " + f"(calculated: {calculated_max_tokens}, arg: {args.max_tokens}, mode: {'i2i' if is_i2i else 't2i'})" + ) # IMPORTANT: GLM-Image AR model requires these exact sampling parameters # from generation_config.json for proper image token generation. @@ -303,6 +311,12 @@ def main(args: argparse.Namespace) -> None: stop_token_ids=[GLM_IMAGE_EOS_TOKEN_ID], # 16385, CRITICAL for stopping seed=args.seed, detokenize=False, + # Keep target size available in runner/model for deterministic M-RoPE + # decode grids in t2i (no mm_features available in this path). + extra_args={ + "target_h": int(target_height), + "target_w": int(target_width), + }, ) # For diffusion stage, sampling_params contains diffusion-specific parameters diff --git a/examples/offline_inference/hunyuan_image3/README.md b/examples/offline_inference/hunyuan_image3/README.md index da28a44d9e6..3cd8fa01b2e 100644 --- a/examples/offline_inference/hunyuan_image3/README.md +++ b/examples/offline_inference/hunyuan_image3/README.md @@ -1,25 +1,161 @@ -# HunyuanImage-3.0 Image-to-Text Inference +# HunyuanImage-3.0-Instruct -This example demonstrates how to run HunyuanImage-3.0 Image-to-Text with the vLLM-Omni. +## Set up -## Local CLI Usage +Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup. -Download the example image: +## Run examples + +**Note**: These examples work with the default configuration on **8x NVIDIA L40S (48GB)**. For different GPU setups, modify the stage configuration to adjust device allocation and memory utilization. + +Get into the hunyuan_image3 folder: + +```bash +cd examples/offline_inference/hunyuan_image3 +``` + +### Modality Control + +HunyuanImage-3.0-Instruct supports multiple modality modes. You can control the mode using the `--modality` argument: + +#### Text to Image (text2img) + +- **Pipeline**: Text → AR (CoT + latent tokens) → DiT (denoise) → VAE Decode → Image +- **Stages Used**: Stage 0 (AR) + Stage 1 (DiT) +- **KV Transfer**: AR sends KV cache to DiT for conditioned generation +- **Default Config**: `hunyuan_image3_t2i.yaml` + +```bash +python end2end.py --model tencent/HunyuanImage-3.0-Instruct \ + --modality text2img \ + --prompts "A cute cat sitting on a windowsill watching the sunset" +``` + +#### Image to Image (img2img) + +- **Pipeline**: Image + Text → AR (CoT + recaption + latent) → DiT → Edited Image +- **Stages Used**: Stage 0 (AR) + Stage 1 (DiT) +- **KV Transfer**: AR sends KV cache to DiT +- **Default Config**: `hunyuan_image3_it2i.yaml` + +```bash +python end2end.py --model tencent/HunyuanImage-3.0-Instruct \ + --modality img2img \ + --image-path /path/to/image.png \ + --prompts "Make the petals neon pink" +``` + +#### Image to Text (img2text) + +- **Pipeline**: Image + Question → AR → Text description +- **Stages Used**: Stage 0 (AR) only +- **Default Config**: `hunyuan_image3_i2t.yaml` ```bash -wget https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg +python end2end.py --model tencent/HunyuanImage-3.0-Instruct \ + --modality img2text \ + --image-path /path/to/image.jpg \ + --prompts "Describe the content of the picture." ``` -Run example: +#### Text to Text (text2text) + +- **Pipeline**: Text → AR → Text +- **Stages Used**: Stage 0 (AR) only +- **Default Config**: `hunyuan_image3_t2t.yaml` ```bash -python image_to_text.py \ - --image cherry_blossom.jpg \ - --prompt "<|startoftext|>You are an assistant that understands images and outputs text.Describe the content of the picture." +python end2end.py --model tencent/HunyuanImage-3.0-Instruct \ + --modality text2text \ + --prompts "What is the capital of France?" ``` -Key arguments: +### Inference Steps & Guidance + +Control generation quality for image modalities: + +```bash +python end2end.py --modality text2img \ + --steps 50 \ + --guidance-scale 5.0 \ + --height 1024 --width 1024 \ + --prompts "A photo-realistic sunset over the ocean" +``` + +### Key Arguments + +#### 📌 Command Line Arguments (end2end.py) + +| Argument | Type | Default | Description | +| :--------------------- | :----- | :----------------------------------- | :----------------------------------------------------------- | +| `--model` | string | `tencent/HunyuanImage-3.0-Instruct` | Model path or name | +| `--modality` | choice | `text2img` | Modality: `text2img`, `img2img`, `img2text`, `text2text` | +| `--prompts` | list | `None` | Input text prompts | +| `--image-path` | string | `None` | Input image path (for `img2img`/`img2text`) | +| `--output` | string | `.` | Output directory for saved images | +| `--steps` | int | `50` | Number of inference steps | +| `--guidance-scale` | float | `5.0` | Classifier-free guidance scale | +| `--seed` | int | `42` | Random seed | +| `--height` | int | `1024` | Output image height | +| `--width` | int | `1024` | Output image width | +| `--bot-task` | string | auto | Override prompt task (e.g. `it2i_think`, `t2i_recaption`) | +| `--sys-type` | string | auto | Override system prompt type (e.g. `en_unified`, `en_vanilla`) | +| `--stage-configs-path` | string | auto | Custom stage config YAML path | +| `--enforce-eager` | flag | `False` | Disable torch.compile | +| `--init-timeout` | int | `300` | Initialization timeout (seconds) | + +------ + +#### ⚙️ Stage Configurations + +| Config YAML | Modality | Stages | GPUs | Description | +| :---------------------------------- | :-------- | :----- | :----- | :------------------------------------ | +| `hunyuan_image3_t2i.yaml` | text2img | 2 | 8 | T2I with AR→DiT, 4 GPU each | +| `hunyuan_image3_it2i.yaml` | img2img | 2 | 8 | IT2I with AR→DiT, 4 GPU each | +| `hunyuan_image3_i2t.yaml` | img2text | 1 | 4 | I2T (AR only) | +| `hunyuan_image3_t2t.yaml` | text2text | 1 | 4 | T2T (AR only) | +| `hunyuan_image3_t2i_2gpu.yaml` | text2img | 2 | 2 | T2I for 2-GPU setups | +| `hunyuan_image3_moe.yaml` | text2img | 2 | 8 | T2I with MoE AR→DiT KV reuse | +| `hunyuan_image3_moe_dit_2gpu_fp8.yaml` | text2img | 2 | 2 | T2I with FP8 quantization | + +------ + +## Using MoE Config + +The `hunyuan_image3_moe.yaml` config enables AR→DiT KV cache reuse with 8 GPUs (4 for AR + 4 for DiT). + +```bash +python end2end.py --model tencent/HunyuanImage-3.0-Instruct \ + --modality text2img \ + --stage-configs-path hunyuan_image3_moe.yaml \ + --prompts "A cute cat" +``` + +------ + +## Prompt Format + +HunyuanImage-3.0 uses a pretrain template format: + +``` +<|startoftext|>{system_prompt}{}{trigger_tag}{user_prompt} +``` + +- ``: Placeholder for each input image (auto-inserted by `prompt_utils.py`) +- Trigger tags: `` (CoT), `` (recaptioning) +- System prompt: Auto-selected based on task + +The `prompt_utils.build_prompt()` handles this formatting automatically. + +------ + +## FAQ + +- **OOM errors**: Decrease `gpu_memory_utilization` in the YAML stage config, or use a smaller `max_num_batched_tokens`. +- **Custom image sizes**: Use `--height` and `--width` flags (multiples of 16 recommended). -- `--model`: Model used. Default is: tencent/HunyuanImage-3.0-Instruct (Optional). -- `--image`: Path to input image (required). -- `--prompt`: Text description used to guide image understanding (required). +| Stage | VRAM (approx) | +| :---------------- | :------------------- | +| Stage 0 (AR) | ~15 GiB + KV Cache | +| Stage 1 (DiT) | ~30 GiB | +| Total (8-GPU) | ~45 GiB + KV Cache | diff --git a/examples/offline_inference/hunyuan_image3/end2end.py b/examples/offline_inference/hunyuan_image3/end2end.py new file mode 100644 index 00000000000..3c1ae386678 --- /dev/null +++ b/examples/offline_inference/hunyuan_image3/end2end.py @@ -0,0 +1,262 @@ +""" +HunyuanImage-3.0-Instruct unified end-to-end inference script. + +Supports all modalities through a single entry point: + - text2img: Text → AR → DiT → Image + - img2img: Text+Image → AR → DiT → Edited Image (IT2I) + - img2text: Image+Text → AR → Text description (I2T) + - text2text: Text → AR → Text (comprehension, no image) + +Usage: + python end2end.py --modality text2img --prompts "A cute cat" + python end2end.py --modality img2img --image-path input.png --prompts "Make it snowy" + python end2end.py --modality img2text --image-path input.png --prompts "Describe this image" +""" + +import argparse +import os + +from vllm_omni.diffusion.models.hunyuan_image3.system_prompt import ( + get_system_prompt, +) +from vllm_omni.entrypoints.omni import Omni +from vllm_omni.inputs.data import OmniPromptType + +# task → (sys_type, bot_task, trigger_tag) +_TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = { + "t2t": ("en_unified", None, None), + "i2t": ("en_unified", None, None), + "it2i_think": ("en_unified", "think", ""), + "it2i_recaption": ("en_unified", "recaption", ""), + "t2i_think": ("en_unified", "think", ""), + "t2i_recaption": ("en_unified", "recaption", ""), + "t2i_vanilla": ("en_vanilla", "image", None), +} + +# Modality → prompt_utils task mapping +_MODALITY_TASK_MAP = { + "text2img": "t2i_think", + "img2img": "it2i_think", + "img2text": "i2t", + "text2text": "t2t", +} + + +def build_prompt( + user_prompt: str, + task: str = "it2i_think", + sys_type: str | None = None, + custom_system_prompt: str | None = None, +) -> str: + """Build a HunyuanImage-3.0 prompt using pretrain template format.""" + if task not in _TASK_PRESETS: + raise ValueError(f"Unknown task {task!r}. Choose from: {sorted(_TASK_PRESETS)}") + + preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task] + effective_sys_type = sys_type or preset_sys_type + + system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt) + sys_text = system_prompt.strip() if system_prompt else "" + + has_image_input = task.startswith("i2t") or task.startswith("it2i") + + parts = ["<|startoftext|>"] + if sys_text: + parts.append(sys_text) + if has_image_input: + parts.append("") + if trigger_tag: + parts.append(trigger_tag) + parts.append(user_prompt) + + return "".join(parts) + + +# Modality → default stage config +_MODALITY_DEFAULT_CONFIG = { + "text2img": "hunyuan_image3_t2i.yaml", + "img2img": "hunyuan_image3_it2i.yaml", + "img2text": "hunyuan_image3_i2t.yaml", + "text2text": "hunyuan_image3_t2t.yaml", +} + + +def parse_args(): + parser = argparse.ArgumentParser(description="HunyuanImage-3.0-Instruct end-to-end inference.") + parser.add_argument( + "--model", + default="tencent/HunyuanImage-3.0-Instruct", + help="Model name or local path.", + ) + parser.add_argument( + "--modality", + default="text2img", + choices=["text2img", "img2img", "img2text", "text2text"], + help="Modality mode to control stage execution.", + ) + parser.add_argument("--prompts", nargs="+", default=None, help="Input text prompts.") + parser.add_argument( + "--image-path", + type=str, + default=None, + help="Path to input image (for img2img/img2text).", + ) + parser.add_argument( + "--output", + type=str, + default=".", + help="Output directory to save results.", + ) + + # Generation parameters + parser.add_argument("--steps", type=int, default=50, help="Number of inference steps.") + parser.add_argument("--guidance-scale", type=float, default=5.0, help="Classifier-free guidance scale.") + parser.add_argument("--seed", type=int, default=42, help="Random seed.") + parser.add_argument("--height", type=int, default=1024, help="Output image height.") + parser.add_argument("--width", type=int, default=1024, help="Output image width.") + + # Prompt configuration + parser.add_argument( + "--bot-task", + type=str, + default=None, + help="Override prompt task (e.g. it2i_think, t2i_recaption). Default: auto from modality.", + ) + parser.add_argument( + "--sys-type", + type=str, + default=None, + help="Override system prompt type (e.g. en_unified, en_vanilla).", + ) + + # Omni init args + parser.add_argument("--stage-configs-path", type=str, default=None, help="Custom stage config YAML path.") + parser.add_argument("--log-stats", action="store_true", default=False) + parser.add_argument("--init-timeout", type=int, default=300, help="Initialization timeout in seconds.") + parser.add_argument("--enforce-eager", action="store_true", help="Disable torch.compile.") + + return parser.parse_args() + + +def main(): + args = parse_args() + os.makedirs(args.output, exist_ok=True) + + # Determine task for prompt formatting + task = args.bot_task or _MODALITY_TASK_MAP[args.modality] + + # Determine stage config + stage_configs_path = args.stage_configs_path or _MODALITY_DEFAULT_CONFIG[args.modality] + + # Build Omni + omni_kwargs = { + "model": args.model, + "stage_configs_path": stage_configs_path, + "log_stats": args.log_stats, + "init_timeout": args.init_timeout, + "enforce_eager": args.enforce_eager, + } + if args.modality in ("text2img", "img2img"): + omni_kwargs["mode"] = "text-to-image" + + omni = Omni(**omni_kwargs) + + # Prepare prompts + prompts = args.prompts or ["A cute cat"] + if not prompts: + print("[Info] No prompts provided, using default.") + prompts = ["A cute cat"] + + # Load image if needed + input_image = None + if args.modality in ("img2img", "img2text"): + if not args.image_path or not os.path.exists(args.image_path): + raise ValueError(f"--image-path required for {args.modality}, got: {args.image_path}") + from PIL import Image + + input_image = Image.open(args.image_path).convert("RGB") + + # Format prompts + formatted_prompts: list[OmniPromptType] = [] + for p in prompts: + formatted_text = build_prompt(p, task=task, sys_type=args.sys_type) + + prompt_dict: dict = {"prompt": formatted_text} + + if args.modality == "text2img": + prompt_dict["modalities"] = ["image"] + elif args.modality == "img2img": + prompt_dict["modalities"] = ["image"] + prompt_dict["multi_modal_data"] = {"image": input_image} + prompt_dict["height"] = input_image.height + prompt_dict["width"] = input_image.width + elif args.modality == "img2text": + prompt_dict["modalities"] = ["text"] + prompt_dict["multi_modal_data"] = {"image": input_image} + elif args.modality == "text2text": + prompt_dict["modalities"] = ["text"] + + formatted_prompts.append(prompt_dict) + + # Build sampling params from defaults + params_list = list(omni.default_sampling_params_list) + + # Override diffusion params if applicable + from vllm_omni.inputs.data import OmniDiffusionSamplingParams + + for i, sp in enumerate(params_list): + if isinstance(sp, OmniDiffusionSamplingParams): + sp.num_inference_steps = args.steps + sp.guidance_scale = args.guidance_scale + if args.seed is not None: + sp.seed = args.seed + if args.modality in ("text2img",): + sp.height = args.height + sp.width = args.width + + # Print configuration + print(f"\n{'=' * 60}") + print("HunyuanImage-3.0 Generation Configuration:") + print(f" Model: {args.model}") + print(f" Modality: {args.modality}") + print(f" Stage config: {stage_configs_path}") + print(f" Num stages: {omni.num_stages}") + if args.modality in ("text2img", "img2img"): + print(f" Inference steps: {args.steps}") + print(f" Guidance scale: {args.guidance_scale}") + print(f" Seed: {args.seed}") + if args.modality == "text2img": + print(f" Output size: {args.width}x{args.height}") + if args.image_path: + print(f" Input image: {args.image_path}") + print(f" Prompts: {prompts}") + print(f"{'=' * 60}\n") + + # Generate + omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list)) + + # Process outputs + img_idx = 0 + for req_output in omni_outputs: + # Text output (AR stage or text-only) + ro = getattr(req_output, "request_output", None) + if ro and getattr(ro, "outputs", None): + txt = "".join(getattr(o, "text", "") or "" for o in ro.outputs) + if txt: + print(f"[Output] Text:\n{txt}") + + # Image output (DiT stage) + images = getattr(req_output, "images", None) + if not images and ro and hasattr(ro, "images"): + images = ro.images + + if images: + for j, img in enumerate(images): + save_path = os.path.join(args.output, f"output_{img_idx}_{j}.png") + img.save(save_path) + print(f"[Output] Saved image to {save_path}") + img_idx += 1 + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/hunyuan_image3/image_to_text.py b/examples/offline_inference/hunyuan_image3/image_to_text.py deleted file mode 100644 index d40134ac0a0..00000000000 --- a/examples/offline_inference/hunyuan_image3/image_to_text.py +++ /dev/null @@ -1,92 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import argparse -import os - -from PIL import Image - -from vllm_omni.entrypoints.omni import Omni - -""" -The tencent/HunyuanImage-3.0-Instruct base model uses the tencent/Hunyuan-A13B-Instruct backbone. It utilizes two tokenizer delimiter templates: - -1) Pretrained template (default for gen_text mode), which concatenates system, image - tokens, and user question WITHOUT role delimiters: -"<|startoftext|>{system_prompt}{image_tokens}{user_question}" - - Example (before image token expansion): -"<|startoftext|>You are an assistant that understands images and outputs text.Describe the content of the picture." - -2) Instruct template, which uses explicit role prefixes and separators. -""" - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Generate text from image using HunyuanImage-3.0-Instruct.") - parser.add_argument( - "--model", - default="tencent/HunyuanImage-3.0-Instruct", - help="Model name or local path.", - ) - parser.add_argument( - "--image", - type=str, - required=True, - help="Path to input image file (PNG, JPG, etc.).", - ) - parser.add_argument( - "--prompt", - type=str, - required=True, - help="Pretrain template prompt: <|startoftext|>{system}{question}", - ) - parser.add_argument( - "--enable-diffusion-pipeline-profiler", - action="store_true", - help="Enable diffusion pipeline profiler to display stage durations.", - ) - return parser.parse_args() - - -def load_image(image_path: str) -> Image.Image: - """Load an image from file path.""" - if not os.path.exists(image_path): - raise FileNotFoundError(f"Image file not found: {image_path}") - return Image.open(image_path).convert("RGB") - - -def main(args: argparse.Namespace) -> None: - omni = Omni( - model=args.model, - enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler, - mode="image-to-text", - ) - - prompt = "<|startoftext|>You are an assistant that understands images and outputs text." + args.prompt - - prompt_dict = { - "prompt": prompt, - "modalities": ["text"], - } - - # Add image input if provided - if args.image: - if not os.path.exists(args.image): - raise FileNotFoundError(f"Input image not found: {args.image}") - - input_image = load_image(args.image) - prompt_dict["multi_modal_data"] = {"image": input_image} - - prompts = [prompt_dict] - omni_outputs = omni.generate(prompts=prompts) - - prompt_text = omni_outputs[0].request_output.prompt - generated_text = omni_outputs[0].request_output.outputs[0].text - print(f"Prompt: {prompt_text}") - print(f"Text: {generated_text}") - - -if __name__ == "__main__": - args = parse_args() - main(args) diff --git a/examples/offline_inference/hunyuan_image3/prompt_utils.py b/examples/offline_inference/hunyuan_image3/prompt_utils.py deleted file mode 100644 index a5ef8e15369..00000000000 --- a/examples/offline_inference/hunyuan_image3/prompt_utils.py +++ /dev/null @@ -1,88 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Prompt construction utilities for HunyuanImage-3.0-Instruct examples. - -Wraps system_prompt.get_system_prompt() with task-aware presets so that -examples and tests don't need to manually concatenate system prompts, -, , and tags. - -Usage: - from prompt_utils import build_prompt - - # IT2I (image editing, think+recaption mode) - prompt = build_prompt("Make the petals neon pink", task="it2i_think") - - # I2T (image understanding) - prompt = build_prompt("Describe the content of the picture.", task="i2t") -""" - -from __future__ import annotations - -from vllm_omni.diffusion.models.hunyuan_image3.system_prompt import ( - get_system_prompt, -) - -# task → (sys_type, bot_task, trigger_tag) -# trigger_tag: "", "", or None -_TASK_PRESETS: dict[str, tuple[str, str | None, str | None]] = { - # Pure text generation (text → text, no image) - "t2t": ("en_unified", None, None), - # Image understanding (image → text) - "i2t": ("en_unified", None, None), - # Image editing (image+text → image), think+recaption mode - "it2i_think": ("en_unified", "think", ""), - # Image editing, recaption-only mode - "it2i_recaption": ("en_unified", "recaption", ""), - # Text-to-image, think mode - "t2i_think": ("en_unified", "think", ""), - # Text-to-image, recaption mode - "t2i_recaption": ("en_unified", "recaption", ""), - # Text-to-image, vanilla (no CoT) - "t2i_vanilla": ("en_vanilla", "image", None), -} - - -def build_prompt( - user_prompt: str, - task: str = "it2i_think", - sys_type: str | None = None, - custom_system_prompt: str | None = None, -) -> str: - """Build a complete HunyuanImage-3.0 prompt with auto-selected system - prompt and mode trigger tags. - - Args: - user_prompt: The user's raw instruction or question. - task: One of the preset task keys (see _TASK_PRESETS). - sys_type: Override the preset's sys_type for get_system_prompt(). - custom_system_prompt: Custom system prompt text (used when - sys_type="custom"). - - Returns: - Fully formatted prompt string ready for Omni.generate(). - """ - if task not in _TASK_PRESETS: - raise ValueError(f"Unknown task {task!r}. Choose from: {sorted(_TASK_PRESETS)}") - - preset_sys_type, preset_bot_task, trigger_tag = _TASK_PRESETS[task] - effective_sys_type = sys_type or preset_sys_type - - system_prompt = get_system_prompt(effective_sys_type, preset_bot_task, custom_system_prompt) - sys_text = system_prompt.strip() if system_prompt else "" - - has_image_input = task.startswith("i2t") or task.startswith("it2i") - - parts = ["<|startoftext|>"] - if sys_text: - parts.append(sys_text) - # Instruct conversation template: \n\nUser: ... \n\nAssistant: - parts.append("\n\nUser: ") - if has_image_input: - parts.append("") - parts.append(user_prompt) - parts.append("\n\nAssistant: ") - if trigger_tag: - parts.append(trigger_tag) - - return "".join(parts) diff --git a/pyproject.toml b/pyproject.toml index 9b034a7c8e9..012bcd47c41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -237,3 +237,4 @@ ue = "ue" semantics = "semantics" fullset = "fullset" Vai = "Vai" +CANN = "CANN" diff --git a/recipes/README.md b/recipes/README.md index 5b3dfb5430b..01ecc41f185 100644 --- a/recipes/README.md +++ b/recipes/README.md @@ -27,6 +27,8 @@ recipes/ - [`Qwen/Qwen3-Omni.md`](./Qwen/Qwen3-Omni.md): online serving recipe for multimodal chat on `1x A100 80GB` +- [`Wan-AI/Wan2.2-I2V.md`](./Wan-AI/Wan2.2-I2V.md): image-to-video serving + recipe for Wan2.2 14B on `8x Ascend NPU (A2/A3)` Within a single recipe file, include different hardware support sections such as `GPU`, `ROCm`, and `NPU`, and add concrete tested configurations like diff --git a/recipes/Wan-AI/Wan2.2-I2V.md b/recipes/Wan-AI/Wan2.2-I2V.md new file mode 100644 index 00000000000..99ceac3cebe --- /dev/null +++ b/recipes/Wan-AI/Wan2.2-I2V.md @@ -0,0 +1,136 @@ +# Wan2.2 Image To Video + +## Summary + +- Vendor: Wan-AI +- Model: `Wan-AI/Wan2.2-I2V-A14B-Diffusers` +- Task: Image-to-video generation +- Mode: Online serving with the OpenAI-compatible API +- Maintainer: Community + +## When to use this recipe + +Use this recipe when you want to deploy the Wan2.2 14B image-to-video model +with vLLM-Omni using multi-card parallelism. Two configurations are provided: + +1. **Distilled model (no negative-prompt / CFG computation)** — higher + throughput, recommended when using a distilled checkpoint that does not + require classifier-free guidance. +2. **Official open-source model (with CFG)** — uses `--cfg 2` to run negative + and positive samples in parallel for the original released weights. + +## References + +- Upstream model card: + +## Hardware Support + +## NPU + +### 8x Ascend A2 / A3 + +#### Environment + +- OS: Linux +- Python: 3.10+ +- Driver / runtime: Ascend NPU driver with CANN toolkit +- Recommended operator library: **mindie-sd** (Ascend high-performance fused + operators — enables `adalayernorm` and other fused kernels automatically upon + installation) +- vLLM version: Match the repository requirements for your checkout +- vLLM-Omni version or commit: Use the commit you are deploying from + +#### Prerequisites + +Install the **mindie-sd** operator library to enable Ascend-optimized fused +operators (`adalayernorm`, etc.): + +```bash +git clone https://gitcode.com/Ascend/MindIE-SD.git && cd MindIE-SD + +# Comment out the tik_ops build step (not needed for this use case) +sed -i 's|^\(\s*\)source ${current_script_dir}/build_tik_ops.sh|\1# source ${current_script_dir}/build_tik_ops.sh|' build/build_ops.sh + +python setup.py bdist_wheel +cd dist +pip install mindiesd-*.whl +``` + +After installation, enable the Laser Attention kernel for significant +long-sequence speedups (up to ~40% at 720p in tested workloads): + +```bash +export MINDIE_SD_FA_TYPE=ascend_laser_attention +``` + +When using HSDP with FSDP2, set the following environment variable to work +around a PyTorch NPU multi-stream memory reuse issue +([pytorch/pytorch#147168](https://github.com/pytorch/pytorch/issues/147168)). +This issue has been fixed on CUDA but still applies to NPU: + +```bash +export MULTI_STREAM_MEMORY_REUSE=2 +``` + +#### Command + +**Distilled model (no CFG, recommended for distilled checkpoints):** + +```bash +export MINDIE_SD_FA_TYPE=ascend_laser_attention +export MULTI_STREAM_MEMORY_REUSE=2 + +vllm serve \ + --omni Wan-AI/Wan2.2-I2V-A14B-Diffusers \ + --use-hsdp \ + --usp 8 \ + --vae-patch-parallel-size 8 \ + --vae-use-tiling +``` + +**Official open-source model (with CFG):** + +```bash +export MINDIE_SD_FA_TYPE=ascend_laser_attention +export MULTI_STREAM_MEMORY_REUSE=2 + +vllm serve \ + --omni Wan-AI/Wan2.2-I2V-A14B-Diffusers \ + --use-hsdp \ + --usp 4 \ + --cfg 2 \ + --vae-patch-parallel-size 8 \ + --vae-use-tiling +``` + +> **Why the difference?** With `--cfg 2`, two copies of the input (positive and +> negative prompts) are processed in parallel, effectively doubling the compute +> for the DiT backbone. USP is therefore halved from 8 to 4 so that the total +> parallelism across the 8 cards remains balanced (`usp * cfg = 8`). + +#### Verification + +After the server is ready, see +[`examples/online_serving/image_to_video/README.md`](../../examples/online_serving/image_to_video/README.md) +for complete client examples and request formats. + +#### Notes + +- **Key flags:** + - `--omni` — enables vLLM-Omni diffusion serving. + - `--use-hsdp` — enables Hybrid Sharded Data Parallelism for the DiT model + weights. + - `--usp ` — Unified Sequence Parallelism degree. + - `--cfg ` — Classifier-Free Guidance parallelism; set to 2 for models + that require negative-prompt computation, omit for distilled models. + - `--vae-patch-parallel-size 8` — parallelizes VAE decoding across all 8 + cards. + - `--vae-use-tiling` — enables tiled VAE decoding to reduce peak memory. +- **Performance tips:** + - Installing mindie-sd and enabling Laser Attention + (`MINDIE_SD_FA_TYPE=ascend_laser_attention`) provides up to ~40% + performance improvement at 720p resolution due to long-sequence attention + optimization. +- **Known limitations:** + - `MULTI_STREAM_MEMORY_REUSE=2` is required on NPU when using HSDP/FSDP2 + due to a multi-stream memory reuse bug. This is not needed on CUDA. diff --git a/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json b/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json index 5ec7f1cc2b6..cdd0cac2c03 100644 --- a/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json +++ b/tests/dfx/perf/tests/test_qwen_image_vllm_omni.json @@ -105,22 +105,6 @@ } }, "benchmark_params": [ - { - "name": "512x512_steps20", - "dataset": "random", - "task": "t2i", - "width": 512, - "height": 512, - "num-inference-steps": 20, - "num-prompts": 10, - "max-concurrency": 1, - "enable-negative-prompt": true, - "baseline": { - "throughput_qps": 0.1, - "latency_mean": 2.7, - "peak_memory_mb_mean": 61000 - } - }, { "name": "1536x1536_steps35", "dataset": "random", diff --git a/tests/diffusion/models/dmd2/__init__.py b/tests/diffusion/models/dmd2/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py b/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py new file mode 100644 index 00000000000..e270390bd99 --- /dev/null +++ b/tests/diffusion/models/dmd2/test_dmd2_request_sanitization.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline, LTX2T2VDMD2Pipeline +from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline +from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + +# DMD2 subclass → immediate base pipeline whose __init__ loads model weights (mocked in tests). +_DMD2_BASE = { + WanT2VDMD2Pipeline: Wan22Pipeline, + WanI2VDMD2Pipeline: Wan22I2VPipeline, + LTX2T2VDMD2Pipeline: LTX2Pipeline, + LTX2I2VDMD2Pipeline: LTX2ImageToVideoPipeline, +} + + +def _make_pipeline(cls): + """Run the DMD2 __init__ with the base pipeline mocked out (no model weights loaded).""" + + base = _DMD2_BASE[cls] + od_config = MagicMock() + od_config.model = "/nonexistent" + + def _mock_base_init(self, *a, **kw): + self.od_config = od_config + + with patch.object(base, "__init__", _mock_base_init): + pipeline = object.__new__(cls) + torch.nn.Module.__init__(pipeline) + cls.__init__(pipeline, od_config=od_config) + return pipeline + + +def _make_request(prompts=None, **sp_kwargs) -> OmniDiffusionRequest: + sp = OmniDiffusionSamplingParams(**sp_kwargs) + return OmniDiffusionRequest( + prompts=prompts or [{"prompt": "a cat dancing"}], + sampling_params=sp, + ) + + +@pytest.fixture( + params=list(_DMD2_BASE.keys()), + ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v"], +) +def pipeline(request): + return _make_pipeline(request.param) + + +# --------------------------------------------------------------------------- +# num_inference_steps +# --------------------------------------------------------------------------- + + +def test_num_inference_steps_forced_to_dmd2_value(pipeline): + req = _make_request(num_inference_steps=40) + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.num_inference_steps == pipeline.num_inference_steps + + +def test_num_inference_steps_already_correct(pipeline): + req = _make_request(num_inference_steps=pipeline.num_inference_steps) + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.num_inference_steps == pipeline.num_inference_steps + + +# --------------------------------------------------------------------------- +# guidance_scale +# --------------------------------------------------------------------------- + + +def test_guidance_scale_forced_to_one(pipeline): + req = _make_request(guidance_scale=5.0, guidance_scale_provided=True) + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale + assert req.sampling_params.guidance_scale_provided is False + + +def test_guidance_scale_already_correct(pipeline): + req = _make_request(guidance_scale=pipeline.dmd2_guidance_scale, guidance_scale_provided=False) + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale + + +def test_guidance_scale_provided_flag_cleared(pipeline): + """guidance_scale_provided=True must be cleared even if scale is already dmd2_guidance_scale.""" + req = _make_request(guidance_scale=pipeline.dmd2_guidance_scale, guidance_scale_provided=True) + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.guidance_scale_provided is False + + +def test_guidance_scale_2_cleared(pipeline): + req = _make_request(guidance_scale_2=3.0) + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.guidance_scale_2 is None + + +def test_guidance_scale_2_unset_unchanged(pipeline): + req = _make_request() + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.guidance_scale_2 is None + + +def test_true_cfg_scale_cleared(pipeline): + req = _make_request(true_cfg_scale=2.0) + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.true_cfg_scale is None + + +def test_do_classifier_free_guidance_forced_false(pipeline): + req = _make_request(do_classifier_free_guidance=True) + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.do_classifier_free_guidance is False + + +def test_is_cfg_negative_forced_false(pipeline): + req = _make_request(is_cfg_negative=True) + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.is_cfg_negative is False + + +def test_negative_prompt_stripped_from_prompt_dict(pipeline): + req = _make_request(prompts=[{"prompt": "a cat", "negative_prompt": "blurry"}]) + pipeline._sanitize_dmd2_request(req) + assert "negative_prompt" not in req.prompts[0] + assert req.prompts[0]["prompt"] == "a cat" + + +def test_no_negative_prompt_unchanged(pipeline): + req = _make_request(prompts=[{"prompt": "a cat"}]) + pipeline._sanitize_dmd2_request(req) + assert req.prompts[0] == {"prompt": "a cat"} + + +def test_string_prompt_not_mutated(pipeline): + """String prompts (not dicts) must pass through unchanged.""" + req = _make_request(prompts=["a cat dancing"]) + pipeline._sanitize_dmd2_request(req) + assert req.prompts == ["a cat dancing"] + + +def test_multiple_prompts_all_sanitized(pipeline): + req = _make_request( + prompts=[ + {"prompt": "a cat", "negative_prompt": "blurry"}, + {"prompt": "a dog", "negative_prompt": "ugly"}, + ] + ) + pipeline._sanitize_dmd2_request(req) + for p in req.prompts: + assert "negative_prompt" not in p + + +# --------------------------------------------------------------------------- +# Clean request — nothing changes +# --------------------------------------------------------------------------- + + +def test_clean_request_no_changes(pipeline): + req = _make_request( + guidance_scale=pipeline.dmd2_guidance_scale, + guidance_scale_provided=False, + do_classifier_free_guidance=False, + is_cfg_negative=False, + ) + pipeline._sanitize_dmd2_request(req) + assert req.sampling_params.guidance_scale == pipeline.dmd2_guidance_scale + assert req.sampling_params.guidance_scale_provided is False + assert req.sampling_params.guidance_scale_2 is None + assert req.sampling_params.true_cfg_scale is None + assert req.sampling_params.do_classifier_free_guidance is False + assert req.sampling_params.is_cfg_negative is False diff --git a/tests/diffusion/models/dmd2/test_dmd2_scheduler.py b/tests/diffusion/models/dmd2/test_dmd2_scheduler.py new file mode 100644 index 00000000000..32d00dbf18e --- /dev/null +++ b/tests/diffusion/models/dmd2/test_dmd2_scheduler.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import LTX2Pipeline, LTX2T2VDMD2Pipeline +from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline, WanT2VDMD2Pipeline +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import Wan22I2VPipeline, WanI2VDMD2Pipeline +from vllm_omni.diffusion.request import OmniDiffusionRequest, OmniDiffusionSamplingParams + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + +_DMD2_TIMESTEPS = [999, 937, 833, 624] + +# DMD2 subclass → immediate base pipeline whose __init__ loads model weights (mocked in tests). +_DMD2_BASE = { + WanT2VDMD2Pipeline: Wan22Pipeline, + WanI2VDMD2Pipeline: Wan22I2VPipeline, + LTX2T2VDMD2Pipeline: LTX2Pipeline, + LTX2I2VDMD2Pipeline: LTX2ImageToVideoPipeline, +} + + +def _make_pipeline(cls): + """Run the DMD2 __init__ (including __init_dmd2__) with the base pipeline mocked.""" + + base = _DMD2_BASE[cls] + od_config = MagicMock() + od_config.model = "/nonexistent" + + def _mock_base_init(self, *a, **kw): + self.od_config = od_config # __init_dmd2__ needs this + + with patch.object(base, "__init__", _mock_base_init): + pipeline = object.__new__(cls) + torch.nn.Module.__init__(pipeline) + cls.__init__(pipeline, od_config=od_config) + return pipeline + + +def _make_request(**sp_kwargs) -> OmniDiffusionRequest: + sp = OmniDiffusionSamplingParams(**sp_kwargs) + return OmniDiffusionRequest(prompts=[{"prompt": "a cat"}], sampling_params=sp) + + +@pytest.fixture( + params=list(_DMD2_BASE.keys()), + ids=["wan_t2v", "wan_i2v", "ltx2_t2v", "ltx2_i2v"], +) +def pipeline(request): + return _make_pipeline(request.param) + + +# --------------------------------------------------------------------------- +# forward() timestep injection +# --------------------------------------------------------------------------- + + +def _fake_parent_forward(self, req, *args, num_inference_steps=40, **kwargs): + """Stub that calls set_timesteps as the real parent does.""" + self.scheduler.set_timesteps(num_inference_steps, device="cpu") + return MagicMock() + + +def test_forward_timesteps_match_dmd2_schedule(pipeline): + """After forward() runs, scheduler.timesteps must equal the DMD2 training schedule.""" + parent = _DMD2_BASE[type(pipeline)] + + # Baseline: calling set_timesteps(40) without the DMD2 override gives a different schedule + pipeline.scheduler.set_timesteps(40, device="cpu") + default_timesteps = pipeline.scheduler.timesteps.long().tolist() + assert default_timesteps == _DMD2_TIMESTEPS, ( + "DMD2EulerScheduler should always return DMD2 timesteps regardless of num_steps" + ) + + with patch.object(parent, "forward", _fake_parent_forward): + pipeline.forward(_make_request()) + + assert pipeline.scheduler.timesteps.long().tolist() == _DMD2_TIMESTEPS + + +def test_forward_timesteps_idempotent_across_calls(pipeline): + """Successive forward() calls must not cause scheduler state to drift.""" + parent = _DMD2_BASE[type(pipeline)] + + with patch.object(parent, "forward", _fake_parent_forward): + pipeline.forward(_make_request()) + pipeline.forward(_make_request()) + + assert pipeline.scheduler.timesteps.long().tolist() == _DMD2_TIMESTEPS diff --git a/tests/diffusion/models/ltx2/test_ltx2_hsdp.py b/tests/diffusion/models/ltx2/test_ltx2_hsdp.py new file mode 100644 index 00000000000..4dd07e1bf82 --- /dev/null +++ b/tests/diffusion/models/ltx2/test_ltx2_hsdp.py @@ -0,0 +1,25 @@ +import pytest +import torch.nn as nn + +from vllm_omni.diffusion.models.ltx2.ltx2_transformer import LTX2VideoTransformer3DModel + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def test_ltx2_exposes_hsdp_shard_conditions_for_transformer_blocks(): + model = object.__new__(LTX2VideoTransformer3DModel) + nn.Module.__init__(model) + model.transformer_blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)]) + model.norm_out = nn.LayerNorm(4) + + conditions = getattr(model, "_hsdp_shard_conditions", None) + + assert conditions is not None + assert len(conditions) == 1 + + matched = [] + for name, module in model.named_modules(): + if any(cond(name, module) for cond in conditions): + matched.append(name) + + assert matched == ["transformer_blocks.0", "transformer_blocks.1"] diff --git a/tests/diffusion/models/wan2_2/__init__.py b/tests/diffusion/models/wan2_2/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py b/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py deleted file mode 100644 index 64c2b271c9c..00000000000 --- a/tests/diffusion/models/wan2_2/test_wan22_max_sequence_length.py +++ /dev/null @@ -1,150 +0,0 @@ -from types import SimpleNamespace - -import PIL.Image -import pytest -import torch -from torch import nn - -from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( - WAN22_MAX_SEQUENCE_LENGTH, - Wan22Pipeline, -) -from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import ( - Wan22I2VPipeline, -) -from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_ti2v import ( - Wan22TI2VPipeline, -) -from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_vace import ( - Wan22VACEPipeline, -) - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu] - - -class _RejectingTextEncoder: - dtype = torch.float32 - - def __call__(self, *args, **kwargs): - raise AssertionError("text encoder should not run for prompts that exceed max_sequence_length") - - -class _FakeTokenBatch: - def __init__(self, total_sequence_length: int): - attention_mask = torch.ones((1, total_sequence_length), dtype=torch.long) - self.input_ids = attention_mask.clone() - self.attention_mask = attention_mask - - -class _FakeTokenizer: - def __init__(self, total_sequence_length: int): - self.total_sequence_length = total_sequence_length - - def __call__(self, *args, **kwargs): - return _FakeTokenBatch(self.total_sequence_length) - - -PIPELINE_CASES = [ - pytest.param(Wan22Pipeline, id="wan22-t2v"), - pytest.param(Wan22I2VPipeline, id="wan22-i2v"), - pytest.param(Wan22TI2VPipeline, id="wan22-ti2v"), - pytest.param(Wan22VACEPipeline, id="wan22-vace"), -] - - -def _make_pipeline(pipeline_class: type, *, total_sequence_length: int): - pipeline = object.__new__(pipeline_class) - nn.Module.__init__(pipeline) - pipeline.device = torch.device("cpu") - pipeline.text_encoder = _RejectingTextEncoder() - pipeline.tokenizer = _FakeTokenizer(total_sequence_length) - pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH - return pipeline - - -@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES) -def test_encode_prompt_rejects_prompt_longer_than_default_max_sequence_length(pipeline_class: type): - pipeline = _make_pipeline(pipeline_class, total_sequence_length=WAN22_MAX_SEQUENCE_LENGTH + 1) - - with pytest.raises(ValueError, match=r"got 513 tokens, but `max_sequence_length` is 512"): - pipeline.encode_prompt(prompt="prompt") - - -@pytest.mark.parametrize("pipeline_class", PIPELINE_CASES) -def test_encode_prompt_rejects_prompt_longer_than_explicit_max_sequence_length(pipeline_class: type): - pipeline = _make_pipeline(pipeline_class, total_sequence_length=17) - - with pytest.raises(ValueError, match=r"got 17 tokens, but `max_sequence_length` is 16"): - pipeline.encode_prompt(prompt="prompt", max_sequence_length=16) - - -def _sampling_params(**overrides): - defaults = dict( - height=None, - width=None, - num_frames=None, - num_inference_steps=None, - generator=None, - guidance_scale_provided=False, - guidance_scale_2=None, - boundary_ratio=None, - num_outputs_per_prompt=0, - max_sequence_length=None, - seed=None, - extra_args={}, - prompt_embeds=None, - negative_prompt_embeds=None, - ) - defaults.update(overrides) - return SimpleNamespace(**defaults) - - -@pytest.mark.parametrize( - ("pipeline_class", "prompt_value", "forward_kwargs"), - [ - pytest.param(Wan22Pipeline, "prompt", {}, id="wan22-t2v"), - pytest.param( - Wan22I2VPipeline, - {"prompt": "prompt", "multi_modal_data": {"image": PIL.Image.new("RGB", (64, 64))}}, - {"image": PIL.Image.new("RGB", (64, 64))}, - id="wan22-i2v", - ), - pytest.param( - Wan22TI2VPipeline, - {"prompt": "prompt", "multi_modal_data": {"image": PIL.Image.new("RGB", (64, 64))}}, - {"image": PIL.Image.new("RGB", (64, 64))}, - id="wan22-ti2v", - ), - pytest.param(Wan22VACEPipeline, "prompt", {}, id="wan22-vace"), - ], -) -def test_forward_defaults_to_wan22_tokenizer_max_length( - pipeline_class: type, - prompt_value, - forward_kwargs, -): - pipeline = object.__new__(pipeline_class) - nn.Module.__init__(pipeline) - pipeline.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH - pipeline.boundary_ratio = None - pipeline.vae_scale_factor_temporal = 4 - pipeline.vae_scale_factor_spatial = 8 - pipeline.transformer_config = SimpleNamespace(patch_size=(1, 2, 2)) - - captured = {} - - def _fake_check_inputs(*args, **kwargs): - captured["max_sequence_length"] = kwargs["max_sequence_length"] - raise RuntimeError("stop after capture") - - pipeline.check_inputs = _fake_check_inputs - - req = SimpleNamespace( - prompts=[prompt_value], - sampling_params=_sampling_params(), - ) - - with pytest.raises(RuntimeError, match="stop after capture"): - pipeline.forward(req, **forward_kwargs) - - assert captured["max_sequence_length"] == WAN22_MAX_SEQUENCE_LENGTH diff --git a/tests/diffusion/test_diffusion_worker.py b/tests/diffusion/test_diffusion_worker.py index e2bd7ef8a32..fc08c5f7f03 100644 --- a/tests/diffusion/test_diffusion_worker.py +++ b/tests/diffusion/test_diffusion_worker.py @@ -16,7 +16,7 @@ from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker -pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.gpu] @pytest.fixture @@ -81,17 +81,31 @@ def test_load_weights_empty_iterable(self, mocker: MockerFixture, mock_gpu_worke class TestDiffusionWorkerSleep: """Test DiffusionWorker.sleep method.""" + @pytest.fixture(autouse=True) + def setup_allocator(self, mocker: MockerFixture): + """ + Unified interception of Allocators, and provision of default security values. + """ + self.mock_allocator_class = mocker.patch("vllm.device_allocator.cumem.CuMemAllocator") + self.mock_allocator = mocker.Mock() + self.mock_allocator_class.get_instance.return_value = self.mock_allocator + self.mock_allocator.get_current_usage.return_value = 4 * 1024**3 + self.mock_allocator.sleep = mocker.Mock() + def test_sleep_level_1(self, mocker: MockerFixture, mock_gpu_worker): """Test sleep mode level 1 (offload weights only).""" mock_allocator_class = mocker.patch("vllm.device_allocator.cumem.CuMemAllocator") - mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform") + mock_platform = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform") + mock_platform.get_free_memory.side_effect = [10 * 1024**3, 12 * 1024**3] + mock_platform.get_device_total_memory.return_value = 80 * 1024**3 mock_get_process_memory = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.get_process_gpu_memory") # Setup process-scoped memory mocks # Before sleep: 3GB used # After sleep: 1GB used (freed 2GB) + initial_usage = 3 * 1024**3 mock_get_process_memory.side_effect = [ - 3 * 1024**3, + initial_usage, 1 * 1024**3, ] @@ -99,25 +113,29 @@ def test_sleep_level_1(self, mocker: MockerFixture, mock_gpu_worker): mock_allocator = mocker.Mock() mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator) mock_allocator.sleep = mocker.Mock() + mock_allocator.get_current_usage.return_value = initial_usage # Call sleep with level 1 result = mock_gpu_worker.sleep(level=1) # Verify sleep was called with correct tags mock_allocator.sleep.assert_called_once_with(offload_tags=("weights",)) - assert result is True + assert bool(result) is True # Verify buffers were NOT saved (level 1 doesn't save buffers) assert len(mock_gpu_worker._sleep_saved_buffers) == 0 def test_sleep_level_2(self, mocker: MockerFixture, mock_gpu_worker): """Test sleep mode level 2 (offload all, save buffers).""" mock_allocator_class = mocker.patch("vllm.device_allocator.cumem.CuMemAllocator") - mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform") + mock_platform = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform") + mock_platform.get_free_memory.side_effect = [5 * 1024**3, 10 * 1024**3] + mock_platform.get_device_total_memory.return_value = 80 * 1024**3 mock_get_process_memory = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.get_process_gpu_memory") # Setup process-scoped memory mocks + initial_usage = 5 * 1024**3 mock_get_process_memory.side_effect = [ - 5 * 1024**3, # Before sleep + initial_usage, # Before sleep 1 * 1024**3, # After sleep (freed 4GB) ] @@ -125,6 +143,7 @@ def test_sleep_level_2(self, mocker: MockerFixture, mock_gpu_worker): mock_allocator = mocker.Mock() mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator) mock_allocator.sleep = mocker.Mock() + mock_allocator.get_current_usage.return_value = initial_usage # Mock pipeline buffers mock_buffer1 = torch.randn(10, 10) @@ -141,7 +160,7 @@ def test_sleep_level_2(self, mocker: MockerFixture, mock_gpu_worker): # Verify sleep was called with empty tags (offload all) mock_allocator.sleep.assert_called_once_with(offload_tags=tuple()) - assert result is True + assert bool(result) is True # Verify buffers were saved assert len(mock_gpu_worker._sleep_saved_buffers) == 2 @@ -151,22 +170,26 @@ def test_sleep_level_2(self, mocker: MockerFixture, mock_gpu_worker): def test_sleep_memory_freed_validation(self, mocker: MockerFixture, mock_gpu_worker): """Test that sleep validates memory was actually freed.""" mock_allocator_class = mocker.patch("vllm.device_allocator.cumem.CuMemAllocator") - mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform") + mock_platform = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.current_omni_platform") + mock_platform.get_free_memory.return_value = 10 * 1024**3 + mock_platform.get_device_total_memory.return_value = 80 * 1024**3 mock_get_process_memory = mocker.patch("vllm_omni.diffusion.worker.diffusion_worker.get_process_gpu_memory") # Simulate process memory increase (should trigger assertion error) + initial_usage = 1 * 1024**3 mock_get_process_memory.side_effect = [ - 1 * 1024**3, # Before sleep: 1GB used + initial_usage, # Before sleep: 1GB used 3 * 1024**3, # After sleep: 3GB used (negative freed) ] mock_allocator = mocker.Mock() mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator) mock_allocator.sleep = mocker.Mock() + mock_allocator.get_current_usage.return_value = initial_usage # This should raise an assertion error - with pytest.raises(AssertionError, match="Memory usage increased after sleeping"): - mock_gpu_worker.sleep(level=1) + result = mock_gpu_worker.sleep(level=1) + assert result == initial_usage def test_sleep_falls_back_to_device_memory_when_nvml_unavailable(self, mocker: MockerFixture, mock_gpu_worker): """Test sleep uses device-scoped fallback when NVML is unavailable.""" @@ -184,11 +207,12 @@ def test_sleep_falls_back_to_device_memory_when_nvml_unavailable(self, mocker: M mock_allocator = mocker.Mock() mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator) mock_allocator.sleep = mocker.Mock() + mock_allocator.get_current_usage.return_value = 2 * 1024**3 result = mock_gpu_worker.sleep(level=1) mock_allocator.sleep.assert_called_once_with(offload_tags=("weights",)) - assert result is True + assert bool(result) is True class TestDiffusionWorkerWakeUp: @@ -202,6 +226,7 @@ def test_wake_up_without_buffers(self, mocker: MockerFixture, mock_gpu_worker): mock_allocator = mocker.Mock() mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator) mock_allocator.wake_up = mocker.Mock() + mock_allocator.get_current_usage.return_value = 10 * 1024**3 # Ensure no saved buffers mock_gpu_worker._sleep_saved_buffers = {} @@ -211,7 +236,7 @@ def test_wake_up_without_buffers(self, mocker: MockerFixture, mock_gpu_worker): # Verify allocator.wake_up was called mock_allocator.wake_up.assert_called_once_with(["weights"]) - assert result is True + assert bool(result) is True def test_wake_up_with_buffers(self, mocker: MockerFixture, mock_gpu_worker): """Test wake_up with saved buffers (level 2 sleep).""" @@ -221,6 +246,7 @@ def test_wake_up_with_buffers(self, mocker: MockerFixture, mock_gpu_worker): mock_allocator = mocker.Mock() mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator) mock_allocator.wake_up = mocker.Mock() + mock_allocator.get_current_usage.return_value = 10 * 1024**3 # Create saved buffers saved_buffer1 = torch.randn(10, 10) @@ -255,7 +281,7 @@ def test_wake_up_with_buffers(self, mocker: MockerFixture, mock_gpu_worker): # Verify saved buffers were cleared assert len(mock_gpu_worker._sleep_saved_buffers) == 0 - assert result is True + assert bool(result) is True def test_wake_up_partial_buffer_restore(self, mocker: MockerFixture, mock_gpu_worker): """Test wake_up only restores buffers that were saved.""" @@ -265,6 +291,7 @@ def test_wake_up_partial_buffer_restore(self, mocker: MockerFixture, mock_gpu_wo mock_allocator = mocker.Mock() mock_allocator_class.get_instance = mocker.Mock(return_value=mock_allocator) mock_allocator.wake_up = mocker.Mock() + mock_allocator.get_current_usage.return_value = 10 * 1024**3 # Only save buffer1, not buffer2 saved_buffer1 = torch.randn(10, 10) @@ -293,4 +320,4 @@ def test_wake_up_partial_buffer_restore(self, mocker: MockerFixture, mock_gpu_wo # buffer2 should NOT be restored since it wasn't saved mock_buffer2.data.copy_.assert_not_called() - assert result is True + assert bool(result) is True diff --git a/tests/diffusion/test_multiproc_engine_concurrency.py b/tests/diffusion/test_multiproc_engine_concurrency.py index 4bc3e05fe91..9ec06e8107d 100644 --- a/tests/diffusion/test_multiproc_engine_concurrency.py +++ b/tests/diffusion/test_multiproc_engine_concurrency.py @@ -1,17 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import multiprocessing as mp import queue import threading +import time from types import SimpleNamespace +from unittest.mock import MagicMock, Mock import pytest import torch +import zmq +from vllm.v1.engine.exceptions import EngineDeadError from vllm_omni.diffusion.data import DiffusionOutput from vllm_omni.diffusion.diffusion_engine import DiffusionEngine from vllm_omni.diffusion.executor.multiproc_executor import MultiprocDiffusionExecutor from vllm_omni.diffusion.sched import RequestScheduler +from vllm_omni.diffusion.stage_diffusion_proc import StageDiffusionProc +from vllm_omni.outputs import OmniRequestOutput pytestmark = [pytest.mark.diffusion, pytest.mark.core_model, pytest.mark.cpu] @@ -51,6 +59,8 @@ def _make_executor(num_gpus: int = 1): executor._result_mq = mock_rmq executor._closed = False executor._processes = [] + executor.is_failed = False + executor._failure_callbacks = [] return executor, req_q, res_q @@ -334,8 +344,9 @@ def test_collective_rpc_closed_executor_raises(self): class TestCollectiveRpcTimeoutWhileLockHeld: """``collective_rpc(timeout=...)`` must honour its timeout even when - another thread holds ``engine._rpc_lock`` indefinitely (e.g. a stalled - ``add_req`` waiting on an unresponsive worker). + another thread holds ``engine._rpc_lock`` indefinitely (e.g. request + execution stalled on ``add_req_and_wait_for_response`` → ``execute_fn`` + → ``collective_rpc`` while blocked on an unresponsive worker). """ def test_rpc_times_out_when_lock_held_directly(self): @@ -359,10 +370,10 @@ def _hold_lock(): with pytest.raises(TimeoutError): engine.collective_rpc("health", timeout=0.5) - def test_rpc_times_out_when_add_req_stalled_on_worker(self): + def test_rpc_times_out_when_request_execution_stalled_on_worker(self): """Real-world scenario the bot flagged: - ``add_req`` holds ``_rpc_lock`` while blocked on + The scheduler/execute path holds ``_rpc_lock`` while blocked on ``executor._result_mq.dequeue()`` because the worker never replies. A concurrent ``collective_rpc(timeout=...)`` must still time out instead of hanging forever waiting for the lock. @@ -428,3 +439,353 @@ def _hold_and_release(): t.join(5) assert result.error == "ok" + + +# ───────── error handling: EngineDeadError propagation through layers ───── + + +class TestMultiprocExecutorRaisesEngineDeadError: + """``collective_rpc`` raises ``EngineDeadError`` when the engine is failed.""" + + def test_collective_rpc_raises_when_is_failed(self): + executor = object.__new__(MultiprocDiffusionExecutor) + executor._closed = False + executor._broadcast_mq = MagicMock() + executor._result_mq = MagicMock() + executor._result_mq.dequeue = MagicMock(side_effect=TimeoutError) + executor.is_failed = True + + with pytest.raises(EngineDeadError): + executor.collective_rpc( + "generate", + args=(MagicMock(),), + unique_reply_rank=0, + exec_all_ranks=True, + ) + + def test_collective_rpc_raises_mid_dequeue_when_is_failed(self): + """Worker dies while we are polling the dequeue loop.""" + executor, _, res_q = _make_executor() + + call_count = 0 + orig_dequeue = executor._result_mq.dequeue + + def _dying_dequeue(timeout=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + executor.is_failed = True + raise TimeoutError + return orig_dequeue(timeout=timeout) + + executor._result_mq.dequeue = _dying_dequeue + + with pytest.raises(EngineDeadError): + executor.collective_rpc( + "generate", + args=(MagicMock(),), + unique_reply_rank=0, + exec_all_ranks=True, + ) + + +class TestDiffusionEngineDeadErrorPassthrough: + """``DiffusionEngine.add_req_and_wait_for_response`` re-raises + ``EngineDeadError`` from executor and wraps other errors.""" + + def test_engine_dead_error_propagates(self): + engine, executor, _, _ = _make_engine() + engine.execute_fn = Mock(side_effect=EngineDeadError()) + + with pytest.raises(EngineDeadError): + engine.add_req_and_wait_for_response(_mock_request("dead")) + + def test_runtime_error_wrapped_in_output(self): + engine, executor, _, _ = _make_engine() + engine.execute_fn = Mock(side_effect=RuntimeError("gpu fault")) + + out = engine.add_req_and_wait_for_response(_mock_request("fault")) + assert isinstance(out, DiffusionOutput) + assert "gpu fault" in out.error + + +class TestStageDiffusionClientErrorPropagation: + """Error surface behaviour of ``StageDiffusionClient``. + + Uses ``object.__new__`` to construct a client without spawning a real + subprocess, then manually sets the fields needed for each test. + """ + + def _make_client(self, *, engine_dead=False, proc_alive=True): + from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient + + client = object.__new__(StageDiffusionClient) + client.stage_id = 0 + client.final_output = True + client.final_output_type = "image" + client.default_sampling_params = None + client.custom_process_input_func = None + client.engine_input_source = None + + client._output_queue = asyncio.Queue() + client._rpc_results = {} + client._pending_rpcs = set() + client._tasks = {} + client._shutting_down = False + client._engine_dead = engine_dead + client._owns_process = True + client._proc = MagicMock( + is_alive=MagicMock(return_value=proc_alive), + exitcode=1, + ) + client._request_socket = MagicMock() + client._response_socket = MagicMock() + client._encoder = MagicMock() + client._decoder = MagicMock() + + return client + + @pytest.mark.asyncio + async def test_add_request_raises_when_dead(self): + client = self._make_client(engine_dead=True) + + with pytest.raises(EngineDeadError): + await client.add_request_async("req-3", "test prompt", None) + + def test_check_health_raises_when_dead(self): + client = self._make_client(engine_dead=True) + + with pytest.raises(EngineDeadError): + client.check_health() + + def test_check_health_ok_when_alive(self): + client = self._make_client() + client.check_health() + + def test_get_output_raises_engine_dead_when_dead(self): + """When ``_engine_dead`` is True and the output queue is empty, + ``get_diffusion_output_nowait`` must raise ``EngineDeadError``.""" + client = self._make_client(engine_dead=True) + # Simulate _drain_responses as a no-op (no ZMQ socket) + client._response_socket.recv.side_effect = zmq.Again + + with pytest.raises(EngineDeadError): + client.get_diffusion_output_nowait() + + def test_get_output_returns_none_when_alive_and_empty(self): + """When the engine is alive and the queue is empty, return None.""" + client = self._make_client() + client._response_socket.recv.side_effect = zmq.Again + + assert client.get_diffusion_output_nowait() is None + + def test_check_health_raises_when_proc_dead(self): + """``check_health`` detects a dead subprocess via ``_proc.is_alive()`` + and raises ``EngineDeadError``, setting ``_engine_dead`` as a + side effect.""" + client = self._make_client(proc_alive=False) + + with pytest.raises(EngineDeadError, match="not alive"): + client.check_health() + + assert client._engine_dead is True + + def test_get_output_raises_when_proc_dead(self): + """When the subprocess has died (non-signal exit) and the output + queue is empty, ``get_diffusion_output_nowait`` must raise + ``EngineDeadError`` with the exit code.""" + client = self._make_client(proc_alive=False) + client._response_socket.recv.side_effect = zmq.Again + + with pytest.raises(EngineDeadError, match="exit code"): + client.get_diffusion_output_nowait() + + assert client._engine_dead is True + + def test_get_output_returns_none_on_signal_death(self): + """When the subprocess was killed by a signal (exit code > 128), + ``get_diffusion_output_nowait`` returns ``None`` and sets + ``_shutting_down`` instead of raising.""" + client = self._make_client(proc_alive=False) + client._proc.exitcode = 137 # SIGKILL (128 + 9) + client._response_socket.recv.side_effect = zmq.Again + + result = client.get_diffusion_output_nowait() + + assert result is None + assert client._shutting_down is True + assert client._engine_dead is True + + +# ───────── monitor thread & death sentinel integration tests ───────── + + +def _poll_flag(get_flag, *, timeout=5.0, interval=0.05) -> bool: + """Poll until ``get_flag()`` returns True or *timeout* elapses.""" + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if get_flag(): + return True + time.sleep(interval) + return False + + +def _make_short_lived_process() -> mp.Process: + """Spawn a real subprocess that exits immediately. + + The process must be started with ``"fork"`` (or the platform default) + so that it can use a plain ``lambda`` as its target — ``"spawn"`` would + fail to pickle it. + """ + ctx = mp.get_context("fork") + p = ctx.Process(target=lambda: None, name="ShortLivedWorker-0") + p.start() + return p + + +class TestMultiprocExecutorWorkerMonitor: + """Integration tests for ``start_worker_monitor``. + + Uses real short-lived subprocesses so that OS-level sentinel fd + readiness is exercised end-to-end. + """ + + def test_worker_monitor_sets_is_failed_and_calls_callbacks_on_death(self): + """When a worker process dies, the monitor thread must: + 1. Set ``is_failed = True`` + 2. Call ``shutdown()`` (which sets ``_closed = True``) + 3. Invoke all registered failure callbacks + """ + executor = object.__new__(MultiprocDiffusionExecutor) + executor._closed = False + executor.is_failed = False + executor._failure_callbacks = [] + executor._broadcast_mq = None + executor._result_mq = None + executor.resources = None + # Use a no-op so shutdown() doesn't crash on None resources. + executor._finalizer = lambda: None + + proc = _make_short_lived_process() + executor._processes = [proc] + + callback_called = threading.Event() + executor.register_failure_callback(callback_called.set) + + executor.start_worker_monitor() + + # Wait for the process to exit and the monitor to react. + proc.join(5) + assert _poll_flag(lambda: executor.is_failed), "is_failed was not set" + assert executor._closed, "shutdown() was not called" + assert callback_called.wait(timeout=2), "failure callback was not invoked" + + def test_worker_monitor_noop_when_already_closed(self): + """If ``_closed`` is already True when the process dies (orderly + shutdown), the monitor must *not* set ``is_failed``.""" + executor = object.__new__(MultiprocDiffusionExecutor) + executor._closed = True # already shut down + executor.is_failed = False + executor._failure_callbacks = [] + executor._broadcast_mq = None + executor._result_mq = None + executor.resources = None + executor._finalizer = lambda: None + + proc = _make_short_lived_process() + executor._processes = [proc] + + executor.start_worker_monitor() + proc.join(5) + + # Give the monitor thread a chance to run (it should early-return). + time.sleep(0.3) + assert not executor.is_failed, "is_failed should remain False on orderly shutdown" + + +class TestStageDiffusionClientProcMonitor: + """Integration test for ``StageDiffusionClient._start_proc_monitor``. + + Uses a real short-lived subprocess to verify the sentinel-based + detection pipeline. + """ + + def test_proc_monitor_sets_engine_dead_on_process_death(self): + """When the subprocess dies, the monitor thread must set + ``_engine_dead = True``.""" + from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient + + client = object.__new__(StageDiffusionClient) + client.stage_id = 0 + client._shutting_down = False + client._engine_dead = False + + proc = _make_short_lived_process() + client._proc = proc + + client._start_proc_monitor() + proc.join(5) + + assert _poll_flag(lambda: client._engine_dead), "_engine_dead was not set" + + +class TestDrainResponsesDeathSentinel: + """Tests for death sentinel and error routing in + ``StageDiffusionClient._drain_responses()``. + """ + + def _make_client(self): + from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient + + client = object.__new__(StageDiffusionClient) + client.stage_id = 0 + client._engine_dead = False + client._shutting_down = False + client._output_queue = asyncio.Queue() + client._rpc_results = {} + client._pending_rpcs = set() + client._response_socket = MagicMock() + client._decoder = MagicMock() + return client + + def test_drain_responses_sets_engine_dead_on_death_sentinel(self): + """When ``_drain_responses`` receives the ``DIFFUSION_PROC_DEAD`` + sentinel, it must set ``_engine_dead = True`` and stop draining + (decoder is never called).""" + client = self._make_client() + + # First recv returns the death sentinel, second would be a normal + # message but should never be reached. + client._response_socket.recv.side_effect = [ + StageDiffusionProc.DIFFUSION_PROC_DEAD, + b"should-not-be-reached", + ] + + client._drain_responses() + + assert client._engine_dead is True + client._decoder.decode.assert_not_called() + + def test_drain_responses_routes_error_as_omni_request_output(self): + """When ``_drain_responses`` receives a ``{"type": "error"}`` message + with a ``request_id``, it must place an ``OmniRequestOutput`` with + the error on ``_output_queue``.""" + client = self._make_client() + + error_msg = { + "type": "error", + "request_id": "req-fail", + "error": "gpu fault", + } + # First recv returns the encoded error, second raises zmq.Again. + client._response_socket.recv.side_effect = [b"encoded-error", zmq.Again] + client._decoder.decode.return_value = error_msg + + client._drain_responses() + + assert not client._output_queue.empty() + output = client._output_queue.get_nowait() + assert isinstance(output, OmniRequestOutput) + assert output.request_id == "req-fail" + assert output.error == "gpu fault" + assert output.finished is True diff --git a/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py b/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py index 24daa8ccf54..3aa5da85c24 100644 --- a/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py +++ b/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py @@ -567,7 +567,6 @@ def test_wan22_i2v_diffusers_offline_generates_video( @pytest.mark.benchmark @pytest.mark.diffusion @hardware_test(res={"cuda": "H100"}, num_cards=2) -@pytest.mark.skip(reason="issue: #2874") @pytest.mark.parametrize("omni_server", SERVER_CASES, indirect=True) def test_wan22_i2v_online_serving_generates_video( omni_server, diff --git a/tests/e2e/offline_inference/test_omni_sleep_mode.py b/tests/e2e/offline_inference/test_omni_sleep_mode.py new file mode 100644 index 00000000000..2ad0b53b010 --- /dev/null +++ b/tests/e2e/offline_inference/test_omni_sleep_mode.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +End-to-end tests for Omni Sleep Mode across various model architectures. +""" + +import pytest +import torch +from vllm import SamplingParams + +from tests.helpers.mark import hardware_test +from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +MODEL = "ByteDance-Seed/BAGEL-7B-MoT" +MODEL_DIFF = "riverclouds/qwen_image_random" + + +def get_ack_info(ack, key, default=None): + if hasattr(ack, key): + return getattr(ack, key) + if isinstance(ack, dict): + return ack.get(key, default) + return default + + +def get_dynamic_devices(stage_idx, num_stages, tp_size): + total_gpus = torch.cuda.device_count() + gpus_per_stage = tp_size + start_idx = stage_idx * gpus_per_stage + if start_idx + gpus_per_stage > total_gpus: + start_idx = start_idx % total_gpus + device_ids = [str(start_idx + i) for i in range(gpus_per_stage)] + return ",".join(device_ids) + + +# Test 1: Diffusion Model (2-Stage BAGEL) +@pytest.mark.advanced_model +@pytest.mark.omni +@pytest.mark.parametrize("tp_size", [1, 2]) +@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2) +@pytest.mark.asyncio +async def test_diffusion_model_sleep_tp(tp_size): + num_gpus = torch.cuda.device_count() + if num_gpus < tp_size: + pytest.skip(f"Skipping TP={tp_size}") + + engine_args = { + "model": MODEL, + "enable_sleep_mode": True, + "tensor_parallel_size": tp_size, + "enforce_eager": True, + "trust_remote_code": True, + "dtype": "bfloat16", + "gpu_memory_utilization": 0.5, + } + + engine = AsyncOmni(**engine_args, stage_init_timeout=1200) + try: + # BAGEL requires 2 params + diff_sp = OmniDiffusionSamplingParams(num_inference_steps=2, height=256, width=256) + llm_sp = SamplingParams() + + # Warmup + async for _ in engine.generate("test", sampling_params=[llm_sp, diff_sp]): + pass + + # Sleep all + acks = await engine.sleep(level=2) + statuses = [get_ack_info(ack, "status") for ack in acks] + assert all(s == "SUCCESS" for s in statuses), f"Sleep failed. Statuses: {statuses}" + + # Wakeup & Verify + await engine.wake_up() + async for _ in engine.generate("verify", sampling_params=[llm_sp, diff_sp]): + pass + + print(f"Diffusion TP={tp_size} Lifecycle OK") + finally: + engine.shutdown() + + +# Test 2: Multi-stage Manual Config +@pytest.mark.advanced_model +@pytest.mark.omni +@pytest.mark.parametrize("tp_size", [1, 2]) +@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2) +@pytest.mark.asyncio +async def test_multistage_sleep_h100(tp_size): + num_gpus = torch.cuda.device_count() + if num_gpus < tp_size * 2: + pytest.skip("Not enough GPUs") + + stages = [] + for i in range(2): + devs = get_dynamic_devices(i, 2, tp_size) + stages.append( + { + "stage_id": i, + "stage_type": "llm" if i == 0 else "diffusion", + "runtime": {"process": True, "devices": devs}, + "engine_args": { + "model": MODEL, + "model_stage": "thinker" if i == 0 else "base", + "tensor_parallel_size": tp_size, + "gpu_memory_utilization": 0.4, + "dtype": "bfloat16", + "enable_sleep_mode": True, + "trust_remote_code": True, + }, + } + ) + + connectors = [{"src_stage_id": 0, "dst_stage_id": 1, "connector_type": "queue"}] + + engine = AsyncOmni( + model=MODEL, stages=stages, connectors=connectors, enable_sleep_mode=True, stage_init_timeout=1200 + ) + try: + sp = OmniDiffusionSamplingParams(num_inference_steps=2) + async for _ in engine.generate("warmup", sampling_params=[SamplingParams(), sp]): + pass + + acks = await engine.sleep(stage_ids=[0, 1], level=2) + assert len(acks) == 2 * tp_size + + await engine.wake_up(stage_ids=[0, 1]) + async for _ in engine.generate("verify", sampling_params=[SamplingParams(), sp]): + pass + finally: + engine.shutdown() + + +# Test 3: Pure Diffusion Single-Stage +@pytest.mark.advanced_model +@pytest.mark.omni +@pytest.mark.parametrize("tp_size", [1, 2]) +@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2) +@pytest.mark.asyncio +async def test_pure_diffusion_scenario(tp_size): + engine_args = { + "model": MODEL_DIFF, + "enable_sleep_mode": True, + "tensor_parallel_size": tp_size, + "enforce_eager": True, + "dtype": "bfloat16", + "gpu_memory_utilization": 0.5, + } + + engine = AsyncOmni(**engine_args, stage_init_timeout=1200) + try: + await engine.sleep(level=1) + await engine.wake_up() + async for _ in engine.generate("test", sampling_params=[SamplingParams()]): + pass + print("Pure Diffusion OK") + finally: + engine.shutdown() diff --git a/tests/e2e/online_serving/test_qwen3_omni_multi_replicas.py b/tests/e2e/online_serving/test_qwen3_omni_multi_replicas.py new file mode 100644 index 00000000000..5b51615feaf --- /dev/null +++ b/tests/e2e/online_serving/test_qwen3_omni_multi_replicas.py @@ -0,0 +1,137 @@ +""" +Core-model CI guard for Qwen3-Omni multi-replica stage-pool routing on 4 GPUs. +""" + +import os + +import pytest + +from tests.helpers.mark import hardware_test +from tests.helpers.media import generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video +from tests.helpers.runtime import OmniResponse, OmniServerParams, dummy_messages_from_mix_data +from tests.helpers.stage_config import get_deploy_config_path + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0" + +MODEL = "Qwen/Qwen3-Omni-30B-A3B-Instruct" +MULTI_REPLICA_DEPLOY = get_deploy_config_path("ci/qwen3_omni_moe_multi_replicas_4gpu.yaml") +ROUTE_STRESS_REQUESTS = 6 +MIXED_MODAL_REQUESTS = 4 + +test_params = [ + OmniServerParams( + model=MODEL, + stage_config_path=MULTI_REPLICA_DEPLOY, + server_args=["--disable-log-stats"], + ) +] + + +def _system_prompt() -> dict[str, object]: + return { + "role": "system", + "content": [ + { + "type": "text", + "text": ( + "You are Qwen, a virtual human developed by the Qwen Team, " + "Alibaba Group, capable of perceiving auditory and visual inputs, " + "as well as generating text and speech." + ), + } + ], + } + + +def _text_messages() -> list[dict[str, object]]: + return dummy_messages_from_mix_data( + system_prompt=_system_prompt(), + content_text="What is the capital of China? Answer in one short sentence.", + ) + + +def _mixed_messages() -> list[dict[str, object]]: + video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}" + image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}" + audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(5, 1)['base64']}" + return dummy_messages_from_mix_data( + system_prompt=_system_prompt(), + video_data_url=video_data_url, + image_data_url=image_data_url, + audio_data_url=audio_data_url, + content_text="What is recited in the audio? What is in this image? Describe the video briefly.", + ) + + +def _assert_batch_size(responses: list[OmniResponse], expected: int) -> None: + assert len(responses) == expected, f"Expected {expected} responses, got {len(responses)}" + assert all(resp.success for resp in responses), "At least one request failed" + assert all(resp.e2e_latency is not None and resp.e2e_latency > 0 for resp in responses), "Missing request latency" + + +def _assert_text_outputs(responses: list[OmniResponse]) -> None: + assert all(resp.text_content is not None for resp in responses), "Missing text output" + assert all(resp.audio_bytes is None for resp in responses), "Text-only request unexpectedly produced audio" + + +def _assert_audio_outputs(responses: list[OmniResponse], *, expect_text: bool) -> None: + assert all(resp.audio_bytes is not None and len(resp.audio_bytes) > 128 for resp in responses), ( + "Missing audio output" + ) + if expect_text: + assert all(resp.text_content is not None for resp in responses), "Missing text output" + else: + assert all(not (resp.text_content or "").strip() for resp in responses), ( + "Audio-only request unexpectedly produced text" + ) + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_text_only_batch_uses_multi_replica_talker(omni_server, openai_client) -> None: + request_config = { + "model": omni_server.model, + "messages": _text_messages(), + "stream": False, + "modalities": ["text"], + } + + responses = openai_client.send_omni_request(request_config, request_num=ROUTE_STRESS_REQUESTS) + _assert_batch_size(responses, ROUTE_STRESS_REQUESTS) + _assert_text_outputs(responses) + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_text_to_audio_stream_batch_uses_multi_replica_vocoder(omni_server, openai_client) -> None: + request_config = { + "model": omni_server.model, + "messages": _text_messages(), + "stream": True, + "modalities": ["audio"], + } + + responses = openai_client.send_omni_request(request_config, request_num=ROUTE_STRESS_REQUESTS) + _assert_batch_size(responses, ROUTE_STRESS_REQUESTS) + _assert_audio_outputs(responses, expect_text=False) + + +@pytest.mark.core_model +@pytest.mark.omni +@hardware_test(res={"cuda": "H100"}, num_cards=4) +@pytest.mark.parametrize("omni_server", test_params, indirect=True) +def test_mixed_modal_stream_batch_generates_text_and_audio(omni_server, openai_client) -> None: + request_config = { + "model": omni_server.model, + "messages": _mixed_messages(), + "stream": True, + } + + responses = openai_client.send_omni_request(request_config, request_num=MIXED_MODAL_REQUESTS) + _assert_batch_size(responses, MIXED_MODAL_REQUESTS) + _assert_audio_outputs(responses, expect_text=True) diff --git a/tests/engine/test_async_omni_engine_input.py b/tests/engine/test_async_omni_engine_input.py index 3700e426d42..d86483f1561 100644 --- a/tests/engine/test_async_omni_engine_input.py +++ b/tests/engine/test_async_omni_engine_input.py @@ -59,7 +59,7 @@ def test_build_add_request_message_preserves_additional_information(mocker: Mock assert request.additional_information is not None assert request.additional_information.entries["text"].list_data == ["hello world"] assert request.additional_information.entries["speaker"].list_data == ["vivian"] - output_processor.add_request.assert_called_once() + output_processor.add_request.assert_not_called() def test_build_add_request_message_with_resumable_streaming(mocker: MockerFixture): diff --git a/tests/engine/test_async_omni_engine_outputs.py b/tests/engine/test_async_omni_engine_outputs.py index ef3cfab3bf8..47b3d5e9f14 100644 --- a/tests/engine/test_async_omni_engine_outputs.py +++ b/tests/engine/test_async_omni_engine_outputs.py @@ -63,3 +63,39 @@ async def test_try_get_output_async_raises_after_orchestrator_dies(mocker: Mocke with pytest.raises(RuntimeError, match="Orchestrator died unexpectedly"): await engine.try_get_output_async() + + +def test_fatal_error_message_surfaces_through_try_get_output(mocker: MockerFixture): + """When the orchestrator thread crashes, it enqueues a fatal error message. + + ``try_get_output`` must return this message so the caller + (``OmniBase._handle_output_message``) can detect the fatal flag. + """ + fatal_msg = {"type": "error", "error": "Orchestrator thread crashed", "fatal": True} + + mock_queue = mocker.MagicMock() + mock_queue.sync_q.get.return_value = fatal_msg + + engine = _make_engine(mock_queue, mocker, thread_alive=False) + + msg = engine.try_get_output() + assert msg is not None + assert msg["type"] == "error" + assert msg["fatal"] is True + assert "crashed" in msg["error"] + + +@pytest.mark.asyncio +async def test_fatal_error_message_surfaces_through_try_get_output_async(mocker: MockerFixture): + """Async variant of the fatal error message test.""" + fatal_msg = {"type": "error", "error": "Orchestrator thread crashed", "fatal": True} + + mock_queue = mocker.MagicMock() + mock_queue.sync_q.get_nowait.return_value = fatal_msg + + engine = _make_engine(mock_queue, mocker, thread_alive=False) + + msg = await engine.try_get_output_async() + assert msg is not None + assert msg["type"] == "error" + assert msg["fatal"] is True diff --git a/tests/engine/test_async_omni_engine_stage_init.py b/tests/engine/test_async_omni_engine_stage_init.py index e0999ca8796..ed1d15b3d1e 100644 --- a/tests/engine/test_async_omni_engine_stage_init.py +++ b/tests/engine/test_async_omni_engine_stage_init.py @@ -1,15 +1,138 @@ import importlib import os import threading +import time import types import pytest from vllm_omni.engine.async_omni_engine import AsyncOmniEngine +from vllm_omni.engine.stage_init_utils import ( + LogicalStageInitPlan, + ReplicaInitPlan, + build_stage0_input_processor, + compute_replica_layout, +) pytestmark = [pytest.mark.core_model, pytest.mark.cpu] +def _make_llm_metadata( + stage_id: int, + *, + replica_id: int = 0, + final_output: bool = False, + final_output_type: str | None = None, + is_comprehension: bool = False, +): + return types.SimpleNamespace( + stage_id=stage_id, + stage_type="llm", + runtime_cfg={}, + prompt_expand_func=None, + final_output=final_output, + final_output_type=final_output_type, + default_sampling_params=types.SimpleNamespace(name=f"sp-{stage_id}-{replica_id}"), + custom_process_input_func=None, + engine_input_source=[] if stage_id == 0 else [stage_id - 1], + engine_output_type="token_ids", + replica_id=replica_id, + is_comprehension=is_comprehension, + ) + + +def _make_diffusion_metadata(stage_id: int, *, replica_id: int = 0, final_output_type: str = "image"): + return types.SimpleNamespace( + stage_id=stage_id, + stage_type="diffusion", + runtime_cfg={"devices": str(replica_id)}, + prompt_expand_func=None, + final_output=True, + final_output_type=final_output_type, + default_sampling_params=types.SimpleNamespace(name=f"dsp-{stage_id}-{replica_id}"), + custom_process_input_func=None, + engine_input_source=[], + cfg_kv_collect_func=None, + replica_id=replica_id, + ) + + +def _make_llm_plan( + stage_idx: int, + *, + configured_stage_id: int, + vllm_config: object, + num_replicas: int = 1, + final_output: bool = False, + final_output_type: str | None = None, + is_comprehension: bool = False, +): + replicas: list[ReplicaInitPlan] = [] + for replica_id in range(num_replicas): + stage_cfg = types.SimpleNamespace( + stage_id=configured_stage_id, + stage_type="llm", + runtime=types.SimpleNamespace(devices=str(replica_id)), + engine_args={}, + ) + replicas.append( + ReplicaInitPlan( + replica_id=replica_id, + num_replicas=num_replicas, + launch_mode="local", + stage_cfg=stage_cfg, + metadata=_make_llm_metadata( + configured_stage_id, + replica_id=replica_id, + final_output=final_output, + final_output_type=final_output_type, + is_comprehension=is_comprehension and replica_id == 0, + ), + stage_connector_spec={}, + omni_kv_connector=(None, None, None), + stage_vllm_config=vllm_config, + executor_class=object, + ) + ) + return LogicalStageInitPlan( + stage_idx=stage_idx, + configured_stage_id=configured_stage_id, + replicas=replicas, + ) + + +def _make_diffusion_plan( + stage_idx: int, + *, + configured_stage_id: int, + num_replicas: int = 1, +): + replicas: list[ReplicaInitPlan] = [] + for replica_id in range(num_replicas): + stage_cfg = types.SimpleNamespace( + stage_id=configured_stage_id, + stage_type="diffusion", + runtime=types.SimpleNamespace(devices=str(replica_id)), + engine_args={}, + ) + replicas.append( + ReplicaInitPlan( + replica_id=replica_id, + num_replicas=num_replicas, + launch_mode="local", + stage_cfg=stage_cfg, + metadata=_make_diffusion_metadata(configured_stage_id, replica_id=replica_id), + stage_connector_spec={}, + omni_kv_connector=(None, None, None), + ) + ) + return LogicalStageInitPlan( + stage_idx=stage_idx, + configured_stage_id=configured_stage_id, + replicas=replicas, + ) + + def test_stage_engine_core_client_module_reload_keeps_forward_refs_deferred(): """Regression test for forward references in make_async_mp_client.""" import vllm_omni.engine.stage_engine_core_client as client_mod @@ -21,64 +144,68 @@ def test_stage_engine_core_client_module_reload_keeps_forward_refs_deferred(): ) -def test_initialize_stages_restores_device_visibility_after_diffusion_init(monkeypatch): - """Regression test for stage device env leakage across stage init. +def test_compute_replica_layout_splits_diffusion_devices_by_world_size(): + stage_cfg = types.SimpleNamespace( + stage_id=0, + stage_type="diffusion", + engine_args={"parallel_config": {"tensor_parallel_size": 2}}, + runtime={"devices": "0,1,2,3", "num_replicas": 2}, + ) + + replicas_per_stage, replica_devices_map = compute_replica_layout([stage_cfg]) + + assert replicas_per_stage == [2] + assert replica_devices_map == {0: ["0,1", "2,3"]} + + +def test_collect_initialized_clients_for_cleanup_deduplicates_clients(): + shared = types.SimpleNamespace(name="shared") + extra = types.SimpleNamespace(name="extra") + + cleanup_clients = AsyncOmniEngine._collect_initialized_clients_for_cleanup( + stage_pools=[types.SimpleNamespace(clients=[shared, None])], + initialized_clients_by_stage={0: [shared], 1: [extra]}, + ) + + assert cleanup_clients == [shared, extra] + + +def test_initialize_stages_rejects_replicas_in_single_stage_mode(): + engine = object.__new__(AsyncOmniEngine) + engine.single_stage_mode = True + engine.stage_configs = [types.SimpleNamespace(stage_id=0, runtime={"num_replicas": 2})] + + with pytest.raises(ValueError, match="single_stage_mode does not support num_replicas > 1 yet"): + engine._validate_single_stage_mode_replica_constraints() + - Diffusion init mutates process-level CUDA visibility. Ensure AsyncOmniEngine - restores the previous value after diffusion stage setup. - """ +def test_initialize_diffusion_replica_restores_device_visibility_after_local_init(monkeypatch): import vllm_omni.engine.async_omni_engine as engine_mod from vllm_omni.platforms import current_omni_platform engine = object.__new__(AsyncOmniEngine) engine.model = "dummy-model" - engine.config_path = "dummy-config" engine.num_stages = 1 - engine.async_chunk = False engine.diffusion_batch_size = 1 engine.single_stage_mode = False - engine._single_stage_id_filter = None engine._omni_master_server = None - engine.stage_configs = [types.SimpleNamespace(stage_id=0, stage_type="diffusion")] + engine.stage_configs = [] + + plan = _make_diffusion_plan(0, configured_stage_id=0).replicas[0] env_var = current_omni_platform.device_control_env_var old_env = os.environ.get(env_var) os.environ[env_var] = "0,1" - diffusion_client = types.SimpleNamespace(is_comprehension=False) - - metadata = types.SimpleNamespace( - stage_id=0, - stage_type="diffusion", - runtime_cfg={"devices": "1"}, - prompt_expand_func=None, - ) - - monkeypatch.setattr(engine_mod, "prepare_engine_environment", lambda: None) - monkeypatch.setattr(engine_mod, "load_omni_transfer_config_for_model", lambda *_: None) - monkeypatch.setattr(engine_mod, "extract_stage_metadata", lambda _cfg: metadata) - monkeypatch.setattr(engine_mod, "get_stage_connector_spec", lambda **_: {}) - monkeypatch.setattr(engine_mod, "resolve_omni_kv_config_for_stage", lambda *_: (None, None, None)) - def _fake_setup_stage_devices(_stage_id, _runtime_cfg): - # Simulate diffusion setup mutating process-global visibility. current_omni_platform.set_device_control_env_var("1") monkeypatch.setattr(engine_mod, "setup_stage_devices", _fake_setup_stage_devices) monkeypatch.setattr(engine_mod, "inject_kv_stage_info", lambda *_: None) - monkeypatch.setattr(engine_mod, "initialize_diffusion_stage", lambda *_, **__: diffusion_client) - monkeypatch.setattr( - engine_mod, - "finalize_initialized_stages", - lambda stage_clients, _input_processor: ( - stage_clients, - [types.SimpleNamespace()], - [{"final_output_type": "image"}], - ), - ) + monkeypatch.setattr(engine_mod, "initialize_diffusion_stage", lambda *_, **__: types.SimpleNamespace()) try: - engine._initialize_stages(stage_init_timeout=1) + engine._initialize_diffusion_replica(plan, stage_init_timeout=1, stage_launch_lock=threading.Lock()) assert os.environ.get(env_var) == "0,1" finally: if old_env is None: @@ -87,290 +214,273 @@ def _fake_setup_stage_devices(_stage_id, _runtime_cfg): os.environ[env_var] = old_env -def test_initialize_stages_passes_stage_init_timeout_to_diffusion_handshake(monkeypatch): - """Regression test for stage_init_timeout passing to complete_diffusion_handshake - in the diffusion stage path. - """ - import vllm_omni.diffusion.data as diffusion_data_mod - import vllm_omni.diffusion.stage_diffusion_client as client_mod +def test_initialize_diffusion_replica_passes_stage_init_timeout_and_inline_flag(monkeypatch): + import vllm_omni.engine.async_omni_engine as engine_mod + + engine = object.__new__(AsyncOmniEngine) + engine.model = "dummy-model" + engine.num_stages = 1 + engine.diffusion_batch_size = 4 + engine.single_stage_mode = False + engine._omni_master_server = None + engine.stage_configs = [] + + plan = _make_diffusion_plan(0, configured_stage_id=0).replicas[0] + + captured: dict[str, object] = {} + + monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None) + monkeypatch.setattr(engine_mod, "inject_kv_stage_info", lambda *_: None) + + def _capture_initialize_diffusion_stage( + stage_id, _model, _stage_cfg, _metadata, *, stage_init_timeout, batch_size, use_inline + ): + captured["stage_id"] = stage_id + captured["stage_init_timeout"] = stage_init_timeout + captured["batch_size"] = batch_size + captured["use_inline"] = use_inline + return types.SimpleNamespace() + + monkeypatch.setattr(engine_mod, "initialize_diffusion_stage", _capture_initialize_diffusion_stage) + + engine._initialize_diffusion_replica(plan, stage_init_timeout=302, stage_launch_lock=threading.Lock()) + + assert captured == { + "stage_id": 0, + "stage_init_timeout": 302, + "batch_size": 4, + "use_inline": True, + } + + +def test_initialize_stages_exposes_logical_stage_views_and_builds_top_level_input_processor(monkeypatch): import vllm_omni.engine.async_omni_engine as engine_mod - from vllm_omni.platforms import current_omni_platform engine = object.__new__(AsyncOmniEngine) - engine.log_stats = False engine.model = "dummy-model" engine.config_path = "dummy-config" engine.num_stages = 2 engine.async_chunk = False engine.diffusion_batch_size = 1 engine.single_stage_mode = False + engine._single_stage_id_filter = None engine._omni_master_server = None - engine.stage_configs = [types.SimpleNamespace(stage_id=0, stage_type="diffusion", engine_args={})] + engine.stage_configs = [types.SimpleNamespace(), types.SimpleNamespace()] - metadata = types.SimpleNamespace( - stage_id=0, - stage_type="diffusion", - runtime_cfg={"devices": "0"}, - prompt_expand_func=None, + cfg0 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64)) + cfg1 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64)) + stage_plans = [ + _make_llm_plan(0, configured_stage_id=0, vllm_config=cfg0, num_replicas=2, is_comprehension=True), + _make_llm_plan(1, configured_stage_id=1, vllm_config=cfg1, final_output=True), + ] + + stage0_client_r0 = types.SimpleNamespace( + stage_type="llm", + is_comprehension=True, + final_output=False, + final_output_type=None, + default_sampling_params=types.SimpleNamespace(name="sp0"), + ) + stage0_client_r1 = types.SimpleNamespace( + stage_type="llm", + is_comprehension=False, + final_output=False, + final_output_type=None, + default_sampling_params=types.SimpleNamespace(name="sp0r1"), + ) + stage1_client_r0 = types.SimpleNamespace( + stage_type="llm", + is_comprehension=False, final_output=True, - final_output_type="image", - default_sampling_params=None, - custom_process_input_func=None, - engine_input_source=None, - cfg_kv_collect_func=None, + final_output_type=None, + default_sampling_params=types.SimpleNamespace(name="sp1"), ) + initialized_clients = { + 0: [stage0_client_r0, stage0_client_r1], + 1: [stage1_client_r0], + } - captured_timeout = None - device_env_var = current_omni_platform.device_control_env_var - prev_device_env = os.environ.get(device_env_var) - os.environ[device_env_var] = "0" + stage0_output_processor = object() + stage1_output_processor = object() + top_level_input_processor = object() monkeypatch.setattr(engine_mod, "prepare_engine_environment", lambda: None) monkeypatch.setattr(engine_mod, "load_omni_transfer_config_for_model", lambda *_: None) - monkeypatch.setattr(engine_mod, "extract_stage_metadata", lambda _cfg: metadata) - monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None) + monkeypatch.setattr(engine_mod, "compute_replica_layout", lambda _cfgs: ([2, 1], {})) + monkeypatch.setattr(engine, "_build_logical_stage_init_plans", lambda *_: (stage_plans, None)) + monkeypatch.setattr(engine, "_initialize_stage_replicas", lambda *_: initialized_clients) monkeypatch.setattr( engine_mod, - "finalize_initialized_stages", - lambda stage_clients, _input_processor: ( - stage_clients, - [types.SimpleNamespace()], - [{"final_output_type": "image"}], - ), + "build_llm_stage_output_processor", + lambda plan, _cfg: stage0_output_processor if plan.stage_idx == 0 else stage1_output_processor, ) - monkeypatch.setattr( - diffusion_data_mod.OmniDiffusionConfig, - "from_kwargs", - classmethod(lambda cls, **kwargs: types.SimpleNamespace(parallel_config=types.SimpleNamespace(world_size=1))), - ) - monkeypatch.setattr( - client_mod, - "spawn_diffusion_proc", - lambda model, od_cfg: (object(), "ipc://handshake", "ipc://request", "ipc://response"), - ) - - def _capture_handshake_timeout(proc, handshake_address, handshake_timeout): - nonlocal captured_timeout - captured_timeout = handshake_timeout + monkeypatch.setattr(engine_mod, "build_stage0_input_processor", lambda _cfg: top_level_input_processor) - monkeypatch.setattr(client_mod, "complete_diffusion_handshake", _capture_handshake_timeout) - monkeypatch.setattr( - client_mod.zmq, - "Context", - lambda: types.SimpleNamespace(socket=lambda _: types.SimpleNamespace(connect=lambda _: None)), - ) - - try: - engine._initialize_stages(stage_init_timeout=302) - finally: - if prev_device_env is None: - os.environ.pop(device_env_var, None) - else: - os.environ[device_env_var] = prev_device_env + engine._initialize_stages(stage_init_timeout=1) - assert captured_timeout == 302 + assert len(engine.stage_pools) == 2 + assert engine.input_processor is top_level_input_processor + assert engine.stage_clients == [stage0_client_r0, stage1_client_r0] + assert engine.stage_vllm_configs == [cfg0, cfg1] + assert engine.output_processors == [stage0_output_processor, stage1_output_processor] + assert engine.default_sampling_params_list == [ + stage0_client_r0.default_sampling_params, + stage1_client_r0.default_sampling_params, + ] + assert engine.stage_metadata == [ + {"final_output": False, "final_output_type": None, "stage_type": "llm"}, + {"final_output": True, "final_output_type": None, "stage_type": "llm"}, + ] -def test_initialize_stages_exposes_logical_stage_views_and_shares_stage_output_processor(monkeypatch): +def test_build_logical_stage_init_plans_applies_replica_device_splits(monkeypatch): import vllm_omni.engine.async_omni_engine as engine_mod engine = object.__new__(AsyncOmniEngine) engine.model = "dummy-model" - engine.config_path = "dummy-config" - engine.num_stages = 2 engine.async_chunk = False - engine.diffusion_batch_size = 1 engine.single_stage_mode = False engine._single_stage_id_filter = None - engine._omni_master_server = None engine.stage_configs = [ - types.SimpleNamespace(stage_id=0, stage_type="llm", engine_args={}, runtime=types.SimpleNamespace()), - types.SimpleNamespace(stage_id=1, stage_type="llm", engine_args={}, runtime=types.SimpleNamespace()), + types.SimpleNamespace(stage_id=0, stage_type="llm", engine_args={}, runtime=types.SimpleNamespace(devices="0")), + types.SimpleNamespace( + stage_id=1, stage_type="llm", engine_args={}, runtime=types.SimpleNamespace(devices="1,2,3") + ), ] - stage0_client_r0 = types.SimpleNamespace(is_comprehension=False, stage_type="llm") - stage0_client_r1 = types.SimpleNamespace(is_comprehension=False, stage_type="llm") - stage1_client_r0 = types.SimpleNamespace(is_comprehension=False, stage_type="llm") - - stage0_proc_r0 = object() - stage1_proc_r0 = object() - - cfg0_r0 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64)) - cfg0_r1 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64)) - cfg1_r0 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64)) - - metadata_map = { - 0: types.SimpleNamespace( - stage_id=0, - stage_type="llm", - runtime_cfg={}, - prompt_expand_func=None, - final_output=False, - final_output_type=None, - default_sampling_params=types.SimpleNamespace(), - custom_process_input_func=None, - engine_input_source=[], - engine_output_type="token_ids", - ), - 1: types.SimpleNamespace( - stage_id=1, - stage_type="llm", - runtime_cfg={}, - prompt_expand_func=None, - final_output=True, - final_output_type=None, - default_sampling_params=types.SimpleNamespace(), - custom_process_input_func=None, - engine_input_source=[0], - engine_output_type="token_ids", - ), + metadata_by_stage = { + 0: _make_llm_metadata(0), + 1: _make_llm_metadata(1), } - monkeypatch.setattr(engine_mod, "prepare_engine_environment", lambda: None) - monkeypatch.setattr(engine_mod, "load_omni_transfer_config_for_model", lambda *_: None) - monkeypatch.setattr(engine_mod, "get_stage_connector_spec", lambda **_: {}) - monkeypatch.setattr(engine_mod, "resolve_omni_kv_config_for_stage", lambda *_: (None, None, None)) - monkeypatch.setattr(engine_mod, "compute_replica_layout", lambda _cfgs: ([2, 1], {}, 3)) monkeypatch.setattr( engine_mod, "extract_stage_metadata", - lambda cfg: types.SimpleNamespace(**metadata_map[cfg.stage_id].__dict__), + lambda cfg: types.SimpleNamespace(**metadata_by_stage[cfg.stage_id].__dict__), ) + monkeypatch.setattr(engine_mod, "get_stage_connector_spec", lambda **_: {}) + monkeypatch.setattr(engine_mod, "resolve_omni_kv_config_for_stage", lambda *_: (None, None, None)) + monkeypatch.setattr(engine_mod, "build_engine_args_dict", lambda *_, **__: {}) monkeypatch.setattr( engine_mod, - "finalize_initialized_stages", - lambda stage_clients, _input_processor: ( - stage_clients, - [types.SimpleNamespace(), types.SimpleNamespace()], - [{"final_output_type": None}, {"final_output_type": None}], - ), + "build_vllm_config", + lambda stage_cfg, *_args, **_kwargs: (types.SimpleNamespace(tag=f"cfg-{stage_cfg.stage_id}"), object), ) - started_by_stage = { - 0: [ - types.SimpleNamespace(stage_id=0, metadata=types.SimpleNamespace(replica_id=0), vllm_config=cfg0_r0), - types.SimpleNamespace(stage_id=0, metadata=types.SimpleNamespace(replica_id=1), vllm_config=cfg0_r1), - ], - 1: [ - types.SimpleNamespace(stage_id=1, metadata=types.SimpleNamespace(replica_id=0), vllm_config=cfg1_r0), - ], - } + stage_plans, prompt_expand_func = engine._build_logical_stage_init_plans( + omni_transfer_config=None, + replicas_per_stage=[1, 3], + replica_devices_map={1: ["1", "2", "3"]}, + ) - attach_outputs = { - (0, 0): (stage0_client_r0, stage0_proc_r0, cfg0_r0, object()), - (0, 1): (stage0_client_r1, None, cfg0_r1, None), - (1, 0): (stage1_client_r0, stage1_proc_r0, cfg1_r0, None), - } + assert prompt_expand_func is None + assert [plan.configured_stage_id for plan in stage_plans] == [0, 1] + assert [replica.stage_cfg.runtime.devices for replica in stage_plans[1].replicas] == ["1", "2", "3"] + assert [replica.replica_id for replica in stage_plans[1].replicas] == [0, 1, 2] + assert all(replica.num_replicas == 3 for replica in stage_plans[1].replicas) - launch_counters = {0: 0, 1: 0} - def _fake_launch_llm_stage(stage_cfg, metadata, *_args, **_kwargs): - idx = metadata.stage_id - launch_idx = launch_counters[idx] - launch_counters[idx] += 1 - return started_by_stage[idx][launch_idx] +def test_initialize_stage_replicas_collects_results_by_stage_and_replica_id(monkeypatch): + engine = object.__new__(AsyncOmniEngine) - def _fake_attach_llm_stage(started): - return attach_outputs[(started.stage_id, started.metadata.replica_id)] + cfg0 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64)) + cfg1 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64)) + stage_plans = [ + _make_llm_plan(0, configured_stage_id=0, vllm_config=cfg0, num_replicas=2), + _make_llm_plan(1, configured_stage_id=1, vllm_config=cfg1, num_replicas=2), + ] - monkeypatch.setattr(engine, "_launch_llm_stage", _fake_launch_llm_stage) - monkeypatch.setattr(engine, "_attach_llm_stage", _fake_attach_llm_stage) + clients = { + (0, 0): types.SimpleNamespace(name="stage0-replica0"), + (0, 1): types.SimpleNamespace(name="stage0-replica1"), + (1, 0): types.SimpleNamespace(name="stage1-replica0"), + (1, 1): types.SimpleNamespace(name="stage1-replica1"), + } - engine._initialize_stages(stage_init_timeout=1) + def _initialize_replica(plan, _stage_init_timeout, _stage_launch_lock): + time.sleep(0.02 * (3 - plan.metadata.stage_id - plan.replica_id)) + return clients[(plan.metadata.stage_id, plan.replica_id)] - assert len(engine.stage_pools) == 2 - assert len(engine.stage_clients) == 2 - assert len(engine.stage_vllm_configs) == 2 - assert len(engine.output_processors) == 2 + monkeypatch.setattr(engine, "_initialize_replica", _initialize_replica) - assert engine.stage_clients[0] is stage0_client_r0 - assert engine.stage_clients[1] is stage1_client_r0 - assert engine.stage_vllm_configs[0] is cfg0_r0 - assert engine.stage_vllm_configs[1] is cfg1_r0 + initialized_clients = engine._initialize_stage_replicas(stage_plans, stage_init_timeout=123) - stage0_pool = engine.stage_pools[0] - assert engine.output_processors[0] is stage0_pool.output_processor + assert initialized_clients == { + 0: [clients[(0, 0)], clients[(0, 1)]], + 1: [clients[(1, 0)], clients[(1, 1)]], + } -def test_launch_llm_stage_passes_stage_init_timeout_to_complete_stage_handshake(monkeypatch): - """Regression test for stage_init_timeout reaching complete_stage_handshake - in the LLM stage path. - """ +def test_initialize_stages_cleans_up_successful_replicas_after_partial_multi_replica_failure(monkeypatch): import vllm_omni.engine.async_omni_engine as engine_mod - from vllm_omni.platforms import current_omni_platform engine = object.__new__(AsyncOmniEngine) - engine.log_stats = False engine.model = "dummy-model" + engine.config_path = "dummy-config" + engine.num_stages = 1 + engine.async_chunk = False + engine.diffusion_batch_size = 1 engine.single_stage_mode = False + engine._single_stage_id_filter = None engine._omni_master_server = None - engine.stage_configs = [] + engine.stage_configs = [types.SimpleNamespace()] - metadata = types.SimpleNamespace(stage_id=0, runtime_cfg={"devices": "0"}) - fake_vllm_config = types.SimpleNamespace() - fake_addresses = types.SimpleNamespace() - fake_proc = types.SimpleNamespace() + cfg0 = types.SimpleNamespace(model_config=types.SimpleNamespace(max_model_len=64)) + stage_plans = [_make_llm_plan(0, configured_stage_id=0, vllm_config=cfg0, num_replicas=2)] + initialized_client = types.SimpleNamespace(shutdown=lambda: None) - captured_timeout = None + monkeypatch.setattr(engine_mod, "prepare_engine_environment", lambda: None) + monkeypatch.setattr(engine_mod, "load_omni_transfer_config_for_model", lambda *_: None) + monkeypatch.setattr(engine_mod, "compute_replica_layout", lambda _cfgs: ([2], {})) + monkeypatch.setattr(engine, "_build_logical_stage_init_plans", lambda *_: (stage_plans, None)) - device_env_var = current_omni_platform.device_control_env_var - prev_device_env = os.environ.get(device_env_var) - os.environ[device_env_var] = "0" + def _initialize_replica(plan, _stage_init_timeout, _stage_launch_lock): + if plan.replica_id == 0: + return initialized_client + time.sleep(0.05) + raise RuntimeError("replica launch failed") - monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None) - monkeypatch.setattr(engine_mod, "build_engine_args_dict", lambda *_, **__: {}) - monkeypatch.setattr(engine_mod, "build_vllm_config", lambda *_, **__: (fake_vllm_config, object)) - monkeypatch.setattr(engine_mod, "acquire_device_locks", lambda *_: []) - monkeypatch.setattr( - engine_mod, - "spawn_stage_core", - lambda **_: (fake_addresses, fake_proc, "ipc://handshake"), - ) + monkeypatch.setattr(engine, "_initialize_replica", _initialize_replica) - def _capture_stage_timeout(_proc, _handshake_addr, _addresses, _vllm_cfg, handshake_timeout): - nonlocal captured_timeout - captured_timeout = handshake_timeout + captured_cleanup: list[list[object]] = [] - monkeypatch.setattr(engine_mod, "complete_stage_handshake", _capture_stage_timeout) + def _capture_shutdown(clients): + captured_cleanup.append(list(clients)) - try: - engine._launch_llm_stage( - stage_cfg=types.SimpleNamespace(engine_args={}), - metadata=metadata, - stage_connector_spec={}, - stage_init_timeout=302, - llm_stage_launch_lock=threading.Lock(), - ) - finally: - if prev_device_env is None: - os.environ.pop(device_env_var, None) - else: - os.environ[device_env_var] = prev_device_env + monkeypatch.setattr(engine, "_shutdown_initialized_clients", _capture_shutdown) - assert captured_timeout == 302 + with pytest.raises(RuntimeError, match="replica launch failed"): + engine._initialize_stages(stage_init_timeout=1) + + assert captured_cleanup == [[initialized_client]] -def test_launch_llm_stage_releases_launch_lock_before_complete_stage_handshake(monkeypatch): - """Regression test for parallel LLM stage startup during handshake wait.""" +def test_initialize_llm_replica_passes_stage_init_timeout_to_complete_stage_handshake(monkeypatch): import vllm_omni.engine.async_omni_engine as engine_mod from vllm_omni.platforms import current_omni_platform engine = object.__new__(AsyncOmniEngine) - engine.log_stats = False engine.model = "dummy-model" engine.single_stage_mode = False engine._omni_master_server = None engine.stage_configs = [] fake_vllm_config = types.SimpleNamespace() - fake_addresses = types.SimpleNamespace() - shared_launch_lock = threading.Lock() - counter_lock = threading.Lock() - first_handshake_started = threading.Event() - second_stage_spawned = threading.Event() - allow_first_handshake_to_finish = threading.Event() - launch_errors: list[BaseException] = [] - spawn_count = 0 + fake_addresses = types.SimpleNamespace(inputs=["in"], outputs=["out"], frontend_stats_publish_address=None) + fake_proc = types.SimpleNamespace() + captured_timeout: int | None = None + + plan = ReplicaInitPlan( + replica_id=0, + num_replicas=1, + launch_mode="local", + stage_cfg=types.SimpleNamespace(engine_args={}, runtime=types.SimpleNamespace(devices="0")), + metadata=types.SimpleNamespace(stage_id=0, runtime_cfg={"devices": "0"}), + stage_connector_spec={}, + omni_kv_connector=(None, None, None), + stage_vllm_config=fake_vllm_config, + executor_class=object, + ) device_env_var = current_omni_platform.device_control_env_var prev_device_env = os.environ.get(device_env_var) @@ -378,127 +488,52 @@ def test_launch_llm_stage_releases_launch_lock_before_complete_stage_handshake(m monkeypatch.setattr(engine_mod, "setup_stage_devices", lambda *_: None) monkeypatch.setattr(engine_mod, "build_engine_args_dict", lambda *_, **__: {}) - monkeypatch.setattr(engine_mod, "build_vllm_config", lambda *_, **__: (fake_vllm_config, object)) monkeypatch.setattr(engine_mod, "acquire_device_locks", lambda *_: []) + monkeypatch.setattr(engine_mod, "spawn_stage_core", lambda **_: (fake_addresses, fake_proc, "ipc://handshake")) - def _spawn_stage_core(**_): - nonlocal spawn_count - with counter_lock: - spawn_count += 1 - call_idx = spawn_count - if call_idx == 2: - second_stage_spawned.set() - return fake_addresses, types.SimpleNamespace(), f"ipc://handshake-{call_idx}" - - def _complete_stage_handshake(_proc, handshake_address, _addresses, _vllm_cfg, _timeout): - if handshake_address == "ipc://handshake-1": - first_handshake_started.set() - assert second_stage_spawned.wait(timeout=1), ( - "second stage did not reach spawn_stage_core while first stage waited in handshake" - ) - assert allow_first_handshake_to_finish.wait(timeout=1), ( - "second stage did not enter handshake while first stage was still waiting" - ) - else: - allow_first_handshake_to_finish.set() - - monkeypatch.setattr(engine_mod, "spawn_stage_core", _spawn_stage_core) - monkeypatch.setattr(engine_mod, "complete_stage_handshake", _complete_stage_handshake) + def _capture_stage_timeout(_proc, _handshake_addr, _addresses, _vllm_cfg, handshake_timeout): + nonlocal captured_timeout + captured_timeout = handshake_timeout - def _launch_stage(stage_id: int) -> None: - metadata = types.SimpleNamespace(stage_id=stage_id, runtime_cfg={"devices": str(stage_id)}) - try: - engine._launch_llm_stage( - stage_cfg=types.SimpleNamespace(engine_args={}), - metadata=metadata, - stage_connector_spec={}, - stage_init_timeout=302, - llm_stage_launch_lock=shared_launch_lock, - ) - except BaseException as exc: # pragma: no cover - surfaced through assertion below - launch_errors.append(exc) + monkeypatch.setattr(engine_mod, "complete_stage_handshake", _capture_stage_timeout) + monkeypatch.setattr( + engine_mod.StageEngineCoreClientBase, + "make_async_mp_client", + staticmethod(lambda **_: types.SimpleNamespace(shutdown=lambda: None)), + ) try: - first_thread = threading.Thread(target=_launch_stage, args=(0,)) - first_thread.start() - assert first_handshake_started.wait(timeout=1), "first stage never entered handshake" - - second_thread = threading.Thread(target=_launch_stage, args=(1,)) - second_thread.start() - - first_thread.join(timeout=3) - second_thread.join(timeout=3) + engine._initialize_llm_replica(plan, 302, threading.Lock()) finally: if prev_device_env is None: os.environ.pop(device_env_var, None) else: os.environ[device_env_var] = prev_device_env - assert not first_thread.is_alive() - assert not second_thread.is_alive() - assert second_stage_spawned.is_set() - assert not launch_errors - - -def test_attach_llm_stage_uses_omni_input_preprocessor(monkeypatch): - """Regression test for GLM-Image t2i preprocessing path. - - Stage-0 InputProcessor must use OmniInputPreprocessor so text prompts with - mm_processor_kwargs go through multimodal preprocessing. - """ - import vllm_omni.engine.async_omni_engine as engine_mod - - class DummyStageEngineCoreClient: - def __init__(self, **kwargs): - self.kwargs = kwargs + assert captured_timeout == 302 - def shutdown(self): - return None - class DummyOutputProcessor: - def __init__(self, **kwargs): - self.kwargs = kwargs +def test_build_stage0_input_processor_uses_omni_input_preprocessor(monkeypatch): + import vllm_omni.engine.stage_init_utils as init_mod class DummyInputProcessor: def __init__(self, vllm_config): self.vllm_config = vllm_config self.renderer = object() - self.input_preprocessor = object() + self.input_preprocessor = None class DummyOmniInputPreprocessor: def __init__(self, vllm_config, renderer=None): self.vllm_config = vllm_config self.renderer = renderer - monkeypatch.setattr( - engine_mod.StageEngineCoreClientBase, - "make_async_mp_client", - staticmethod(lambda **kwargs: DummyStageEngineCoreClient(**kwargs)), - ) - monkeypatch.setattr(engine_mod, "MultimodalOutputProcessor", DummyOutputProcessor) - monkeypatch.setattr(engine_mod, "InputProcessor", DummyInputProcessor) - monkeypatch.setattr(engine_mod, "OmniInputPreprocessor", DummyOmniInputPreprocessor) + monkeypatch.setattr(init_mod, "InputProcessor", DummyInputProcessor) + monkeypatch.setattr(init_mod, "OmniInputPreprocessor", DummyOmniInputPreprocessor) - started = types.SimpleNamespace( - stage_id=0, - metadata=types.SimpleNamespace(stage_id=0, engine_output_type="token_ids"), - vllm_config=types.SimpleNamespace(model_config=types.SimpleNamespace(skip_tokenizer_init=True)), - executor_class=object, - engine_manager=object(), - coordinator=object(), - proc=None, - addresses=types.SimpleNamespace( - inputs=["inproc://input"], - outputs=["inproc://output"], - frontend_stats_publish_address=None, - ), + input_processor = build_stage0_input_processor( + types.SimpleNamespace(model_config=types.SimpleNamespace(try_get_generation_config=lambda: {})) ) - engine = object.__new__(AsyncOmniEngine) - - _stage_client, _out_proc, _vllm_cfg, input_processor = engine._attach_llm_stage(started) - - assert input_processor is not None assert isinstance(input_processor.input_preprocessor, DummyOmniInputPreprocessor) assert input_processor.input_preprocessor.renderer is input_processor.renderer diff --git a/tests/engine/test_orchestrator.py b/tests/engine/test_orchestrator.py index b012728cce5..88251ff7410 100644 --- a/tests/engine/test_orchestrator.py +++ b/tests/engine/test_orchestrator.py @@ -2,6 +2,7 @@ import asyncio import concurrent.futures +import logging import queue import threading import time @@ -14,7 +15,7 @@ from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams -from vllm_omni.engine.orchestrator import Orchestrator +from vllm_omni.engine.orchestrator import Orchestrator, OrchestratorRequestState from vllm_omni.engine.stage_pool import StagePool from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput @@ -88,6 +89,25 @@ def push_diffusion_output(self, output) -> None: self._diffusion_outputs.put_nowait(output) +class FakeCollectiveRpcStageClient(FakeStageClient): + def __init__(self, *args, rpc_result: Any = None, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.rpc_result = rpc_result + self.collective_rpc_calls: list[tuple[str, float | None, tuple[Any, ...], dict[str, Any]]] = [] + + async def collective_rpc_async( + self, + *, + method: str, + timeout: float | None = None, + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] | None = None, + ) -> Any: + normalized_kwargs = dict(kwargs or {}) + self.collective_rpc_calls.append((method, timeout, args, normalized_kwargs)) + return self.rpc_result + + class FakeOutputProcessor: def __init__(self, *, request_outputs: list[object] | None = None) -> None: self.request_outputs = list(request_outputs or []) @@ -281,6 +301,18 @@ async def _get_output_message(orchestrator_fixture: OrchestratorFixture, *, time return msg +async def _get_rpc_message(orchestrator_fixture: OrchestratorFixture, *, timeout: float = 2.0) -> dict: + deadline = time.monotonic() + timeout + rpc_sync_q = orchestrator_fixture.queues[2].sync_q + while True: + if time.monotonic() >= deadline: + raise AssertionError("Timed out waiting for orchestrator rpc output") + try: + return rpc_sync_q.get_nowait() + except queue.Empty: + await asyncio.sleep(0.01) + + async def _enqueue_add_request( orchestrator_fixture: OrchestratorFixture, *, @@ -691,3 +723,207 @@ async def test_multi_replica_shutdown_all_replicas(orchestrator_factory) -> None assert not orchestrator_fixture.thread.is_alive() for client in [stage0_r0, stage0_r1, stage1]: assert client.shutdown_calls == 1 + + +@pytest.mark.asyncio +async def test_stage_pool_submit_update_reuses_existing_binding() -> None: + """A request admitted to one replica must keep using that replica on updates.""" + stage0_r0 = FakeStageClient(stage_type="llm", final_output=False) + stage0_r1 = FakeStageClient(stage_type="llm", final_output=False) + pool = StagePool( + 0, + [stage0_r0, stage0_r1], + output_processor=FakeOutputProcessor(), + stage_vllm_config=SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)), + ) + + req0_state = OrchestratorRequestState( + request_id="req-0", + sampling_params_list=[_sampling_params()], + final_stage_id=0, + ) + req1_state = OrchestratorRequestState( + request_id="req-1", + sampling_params_list=[_sampling_params()], + final_stage_id=0, + ) + + await pool.submit_initial("req-0", req0_state, SimpleNamespace(request_id="req-0", prompt_token_ids=[1, 2])) + await pool.submit_update("req-0", req0_state, SimpleNamespace(request_id="req-0", prompt_token_ids=[3])) + await pool.submit_initial("req-1", req1_state, SimpleNamespace(request_id="req-1", prompt_token_ids=[4, 5])) + await pool.submit_update("req-1", req1_state, SimpleNamespace(request_id="req-1", prompt_token_ids=[6])) + + assert pool.get_bound_replica_id("req-0") == 0 + assert pool.get_bound_replica_id("req-1") == 1 + assert len(stage0_r0.add_request_calls) == 2 + assert len(stage0_r1.add_request_calls) == 2 + assert stage0_r0.add_request_calls[0][0].request_id == "req-0" + assert stage0_r0.add_request_calls[1][0].request_id == "req-0" + assert stage0_r1.add_request_calls[0][0].request_id == "req-1" + assert stage0_r1.add_request_calls[1][0].request_id == "req-1" + + +@pytest.mark.asyncio +async def test_stage_pool_submit_initial_rolls_back_output_processor_when_client_submit_fails() -> None: + class FailingStageClient(FakeStageClient): + async def add_request_async(self, *args, **_kwargs) -> None: + raise RuntimeError("submit failed") + + class TrackingOutputProcessor(FakeOutputProcessor): + def __init__(self) -> None: + super().__init__() + self.added_request_ids: list[str] = [] + self.removed_request_ids: list[str] = [] + + def add_request(self, request, *_args, **_kwargs) -> None: + self.added_request_ids.append(request.request_id) + + def remove_request(self, request_id: str) -> None: + self.removed_request_ids.append(request_id) + + client = FailingStageClient(stage_type="llm", final_output=False) + output_processor = TrackingOutputProcessor() + pool = StagePool( + 0, + [client], + output_processor=output_processor, + stage_vllm_config=SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)), + ) + req_state = OrchestratorRequestState( + request_id="req-0", + sampling_params_list=[_sampling_params()], + final_stage_id=0, + ) + + with pytest.raises(RuntimeError, match="submit failed"): + await pool.submit_initial("req-0", req_state, SimpleNamespace(request_id="req-0", prompt_token_ids=[1, 2])) + + assert output_processor.added_request_ids == ["req-0"] + assert output_processor.removed_request_ids == ["req-0"] + assert pool.get_bound_replica_id("req-0") is None + + +@pytest.mark.asyncio +async def test_stage_pool_abort_requests_logs_when_binding_is_missing(caplog) -> None: + stage0 = FakeStageClient(stage_type="llm", final_output=False) + pool = StagePool( + 0, + [stage0], + output_processor=FakeOutputProcessor(), + stage_vllm_config=SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)), + ) + + target_logger = logging.getLogger("vllm_omni.engine.stage_pool") + target_logger.addHandler(caplog.handler) + prev_level = target_logger.level + target_logger.setLevel(logging.DEBUG) + try: + await pool.abort_requests(["missing-req"]) + finally: + target_logger.removeHandler(caplog.handler) + target_logger.setLevel(prev_level) + + assert not stage0.abort_calls + assert "abort: no binding for req=missing-req in stage-0" in caplog.text + + +@pytest.mark.asyncio +async def test_collective_rpc_ignores_invalid_stage_ids(orchestrator_factory, caplog) -> None: + stage0 = FakeCollectiveRpcStageClient(stage_type="llm", final_output=True, rpc_result={"stage": 0}) + stage1 = FakeCollectiveRpcStageClient(stage_type="llm", final_output=True, rpc_result={"stage": 1}) + stage_pools = _build_stage_pools( + [[stage0], [stage1]], + output_processors=[FakeOutputProcessor(), FakeOutputProcessor()], + stage_vllm_configs=[ + SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)), + SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)), + ], + ) + orchestrator_fixture = orchestrator_factory([], stage_pools=stage_pools) + + try: + target_logger = logging.getLogger("vllm_omni.engine.orchestrator") + target_logger.addHandler(caplog.handler) + prev_level = target_logger.level + target_logger.setLevel(logging.WARNING) + try: + orchestrator_fixture.request_sync_q.put_nowait( + { + "type": "collective_rpc", + "rpc_id": "rpc-1", + "method": "list_loras", + "stage_ids": [99, 1], + } + ) + + msg = await _get_rpc_message(orchestrator_fixture) + finally: + target_logger.removeHandler(caplog.handler) + target_logger.setLevel(prev_level) + + assert msg["type"] == "collective_rpc_result" + assert msg["rpc_id"] == "rpc-1" + assert msg["stage_ids"] == [1] + assert msg["results"] == [{"stage": 1}] + assert not stage0.collective_rpc_calls + assert len(stage1.collective_rpc_calls) == 1 + assert "collective_rpc: ignoring invalid stage_id 99" in caplog.text + finally: + await _shutdown_orchestrator(orchestrator_fixture) + + +@pytest.mark.asyncio +async def test_multi_replica_cfg_companion_inherits_parent_affinity(orchestrator_factory) -> None: + """CFG companions should be routed to the same stage-0 replica as their parent.""" + stage0_r0 = FakeStageClient(stage_type="llm", final_output=False) + stage0_r1 = FakeStageClient(stage_type="llm", final_output=False) + default_vllm_cfg = SimpleNamespace(model_config=SimpleNamespace(max_model_len=64)) + stage_pools = _build_stage_pools( + [[stage0_r0, stage0_r1]], + output_processors=[FakeOutputProcessor()], + stage_vllm_configs=[default_vllm_cfg], + ) + orchestrator_fixture = orchestrator_factory([], stage_pools=stage_pools) + + try: + # Consume replica-0 first so the parent request binds to replica-1. + await _enqueue_add_request( + orchestrator_fixture, + request_id="warmup", + prompt=SimpleNamespace(request_id="warmup", prompt_token_ids=[0]), + original_prompt={"prompt": "warmup"}, + sampling_params_list=[_sampling_params()], + final_stage_id=0, + ) + await _wait_for(lambda: len(stage0_r0.add_request_calls) == 1) + + await _enqueue_add_request( + orchestrator_fixture, + request_id="parent", + prompt=SimpleNamespace(request_id="parent", prompt_token_ids=[1, 2]), + original_prompt={"prompt": "parent"}, + sampling_params_list=[_sampling_params()], + final_stage_id=0, + ) + await _wait_for(lambda: len(stage0_r1.add_request_calls) == 1) + + orchestrator_fixture.request_sync_q.put_nowait( + { + "type": "add_companion_request", + "companion_id": "parent-neg", + "parent_id": "parent", + "role": "negative", + "prompt": SimpleNamespace(request_id="parent-neg", prompt_token_ids=[9]), + "companion_prompt_text": {"prompt": "negative"}, + "sampling_params_list": [_sampling_params()], + } + ) + await _wait_for(lambda: len(stage0_r1.add_request_calls) == 2) + + assert stage_pools[0].get_bound_replica_id("parent") == 1 + assert stage_pools[0].get_bound_replica_id("parent-neg") == 1 + assert len(stage0_r0.add_request_calls) == 1 + assert stage0_r1.add_request_calls[0][0].request_id == "parent" + assert stage0_r1.add_request_calls[1][0].request_id == "parent-neg" + finally: + await _shutdown_orchestrator(orchestrator_fixture) diff --git a/tests/engine/test_orchestrator_error_handling.py b/tests/engine/test_orchestrator_error_handling.py new file mode 100644 index 00000000000..18099c01640 --- /dev/null +++ b/tests/engine/test_orchestrator_error_handling.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for error propagation paths within the Orchestrator. + +Covers: +- EngineDeadError from an LLM stage poll → fatal error broadcast + shutdown +- Diffusion stage error output (OmniRequestOutput.from_error) → routed correctly +""" + +from __future__ import annotations + +import asyncio +import queue +import time +from types import SimpleNamespace + +import pytest +from vllm.v1.engine.exceptions import EngineDeadError + +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.outputs import OmniRequestOutput + +from .test_orchestrator import ( + FakeStageClient, + OrchestratorFixture, + _build_harness, + _enqueue_add_request, + _wait_for, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _sampling_params(max_tokens: int = 4): + from vllm.sampling_params import SamplingParams + + return SamplingParams(max_tokens=max_tokens) + + +async def _get_any_output_message(fixture: OrchestratorFixture, *, timeout: float = 2.0) -> dict: + """Like _get_output_message but returns any message type (including errors).""" + deadline = time.monotonic() + timeout + while True: + if time.monotonic() >= deadline: + raise AssertionError("Timed out waiting for orchestrator output") + try: + return fixture.output_sync_q.get_nowait() + except queue.Empty: + await asyncio.sleep(0.01) + + +@pytest.fixture +def orchestrator_factory(): + fixtures: list[OrchestratorFixture] = [] + + def _factory(*args, **kwargs) -> OrchestratorFixture: + fixture = _build_harness(*args, **kwargs) + fixtures.append(fixture) + return fixture + + yield _factory + + for fixture in fixtures: + if fixture.thread.is_alive(): + fixture.request_sync_q.put_nowait({"type": "shutdown"}) + fixture.thread.join(timeout=5) + for q in fixture.queues: + q.close() + + +# ───────── EngineDeadError from LLM stage poll ───────── + + +class FakeDeadLLMStageClient(FakeStageClient): + """LLM stage client that raises EngineDeadError on get_output_async.""" + + async def get_output_async(self): + raise EngineDeadError("Stage-0 engine core is dead") + + +@pytest.mark.asyncio +async def test_engine_dead_error_broadcasts_fatal_and_shuts_down(orchestrator_factory) -> None: + """When a stage raises EngineDeadError during poll, the orchestrator must: + 1. Enqueue a fatal error message for each affected request + 2. Shut itself down (thread exits) + """ + stage0 = FakeDeadLLMStageClient(stage_type="llm", final_output=True) + orchestrator_fixture = orchestrator_factory([stage0]) + request = SimpleNamespace(request_id="req-dead", prompt_token_ids=[1, 2]) + + try: + await _enqueue_add_request( + orchestrator_fixture, + request_id="req-dead", + prompt=request, + original_prompt={"prompt": "hello"}, + sampling_params_list=[_sampling_params()], + final_stage_id=0, + ) + + # Collect the fatal error message. + msg = await _get_any_output_message(orchestrator_fixture) + + assert msg["type"] == "error" + assert msg["fatal"] is True + assert msg["request_id"] == "req-dead" + assert "Stage-0 engine core is dead" in msg["error"] + + # The orchestrator thread should exit after the fatal error. + orchestrator_fixture.thread.join(timeout=5) + assert not orchestrator_fixture.thread.is_alive() + + # Request state should be cleaned up. + assert "req-dead" not in orchestrator_fixture.orchestrator.request_states + finally: + if orchestrator_fixture.thread.is_alive(): + orchestrator_fixture.request_sync_q.put_nowait({"type": "shutdown"}) + orchestrator_fixture.thread.join(timeout=5) + + +# ───────── Diffusion stage error output routing ───────── + + +@pytest.mark.asyncio +async def test_diffusion_error_output_routed_as_finished(orchestrator_factory) -> None: + """When a diffusion stage returns an OmniRequestOutput with a non-None + error, the orchestrator must route it as an error message and clean up + the request state. + """ + stage0 = FakeStageClient(stage_type="diffusion", final_output=True, final_output_type="image") + orchestrator_fixture = orchestrator_factory([stage0]) + params = OmniDiffusionSamplingParams() + + try: + await _enqueue_add_request( + orchestrator_fixture, + request_id="req-err", + prompt={"prompt": "draw a cat"}, + original_prompt={"prompt": "draw a cat"}, + sampling_params_list=[params], + final_stage_id=0, + ) + + await _wait_for(lambda: len(stage0.add_request_calls) == 1) + + # Push an error output from the diffusion stage. + stage0.push_diffusion_output(OmniRequestOutput.from_error("req-err", "gpu fault")) + + msg = await _get_any_output_message(orchestrator_fixture) + + assert msg["type"] == "error" + assert msg["request_id"] == "req-err" + assert msg["stage_id"] == 0 + assert msg["error"] == "gpu fault" + + # Request state should be cleaned up. + await _wait_for(lambda: "req-err" not in orchestrator_fixture.orchestrator.request_states) + finally: + orchestrator_fixture.request_sync_q.put_nowait({"type": "shutdown"}) + orchestrator_fixture.thread.join(timeout=5) diff --git a/tests/engine/test_orchestrator_kv_sender_info.py b/tests/engine/test_orchestrator_kv_sender_info.py index a030ab3478b..2fa9878fb84 100644 --- a/tests/engine/test_orchestrator_kv_sender_info.py +++ b/tests/engine/test_orchestrator_kv_sender_info.py @@ -153,7 +153,7 @@ def test_forward_to_diffusion_attaches_kv_sender_info(): ) output = SimpleNamespace(request_id="req-1", finished=True) - asyncio.run(Orchestrator._forward_to_next_stage(orchestrator, "req-1", sender_pool, output, req_state)) + asyncio.run(Orchestrator._forward_to_next_stage(orchestrator, "req-1", sender_pool.stage_id, output, req_state)) assert diffusion_stage.calls[0]["request_id"] == "req-1" assert diffusion_stage.calls[0]["kv_sender_info"] == { @@ -182,7 +182,7 @@ def test_forward_to_diffusion_uses_engine_input_source_for_kv_sender_info(): ) output = SimpleNamespace(request_id="req-3", finished=True) - asyncio.run(Orchestrator._forward_to_next_stage(orchestrator, "req-3", previous_pool, output, req_state)) + asyncio.run(Orchestrator._forward_to_next_stage(orchestrator, "req-3", previous_pool.stage_id, output, req_state)) assert diffusion_stage.calls[0]["kv_sender_info"] == { 0: {"host": "10.0.0.2", "zmq_port": 50151}, diff --git a/tests/engine/test_single_stage_mode.py b/tests/engine/test_single_stage_mode.py index 28ccccaa2b5..f183b769b43 100644 --- a/tests/engine/test_single_stage_mode.py +++ b/tests/engine/test_single_stage_mode.py @@ -1,20 +1,8 @@ -"""Unit tests for AsyncOmniEngine single-stage mode and OmniMasterServer. - -These tests cover: -- OmniMasterServer address pre-allocation & ZMQ registration handshake -- AsyncOmniEngine single_stage_mode detection / _single_stage_id_filter setup -- _initialize_stages stage routing (local launch vs. remote-wait) in - single_stage_mode -- _create_remote_llm_stage delegation to connect_remote_engine_cores -- _launch_llm_stage delegation to launch_omni_core_engines in - single_stage_mode - -All tests run without real hardware by mocking ZMQ, vllm_config, and the -heavy initialization helpers. -""" +"""Unit tests for AsyncOmniEngine single-stage mode and OmniMasterServer.""" from __future__ import annotations +import os import threading from contextlib import contextmanager from types import SimpleNamespace @@ -25,6 +13,7 @@ from vllm.v1.engine.utils import EngineZmqAddresses from vllm_omni.engine.async_omni_engine import AsyncOmniEngine +from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClientBase from vllm_omni.engine.stage_engine_startup import ( OmniMasterServer, StageAllocation, @@ -32,21 +21,17 @@ connect_remote_engine_cores, launch_omni_core_engines, ) -from vllm_omni.engine.stage_init_utils import StartedLlmStage +from vllm_omni.engine.stage_init_utils import LogicalStageInitPlan, ReplicaInitPlan pytestmark = [pytest.mark.core_model, pytest.mark.cpu] -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - def _make_stage_cfg(stage_id: int, stage_type: str = "llm"): """Return a lightweight stage config mock.""" return SimpleNamespace( stage_id=stage_id, stage_type=stage_type, + runtime=SimpleNamespace(devices="0"), engine_args=SimpleNamespace( async_chunk=False, model_stage=None, @@ -55,67 +40,108 @@ def _make_stage_cfg(stage_id: int, stage_type: str = "llm"): ) -def _make_started_llm_stage(stage_id: int) -> StartedLlmStage: - """Return a minimal StartedLlmStage for mocking.""" - addresses = SimpleNamespace( - inputs=["tcp://127.0.0.1:5000"], - outputs=["tcp://127.0.0.1:5001"], - frontend_stats_publish_address=None, +def _make_llm_plan( + stage_idx: int, + *, + configured_stage_id: int, + launch_mode: str, + vllm_config: Any | None = None, +) -> LogicalStageInitPlan: + stage_cfg = _make_stage_cfg(configured_stage_id) + metadata = SimpleNamespace( + stage_id=configured_stage_id, + stage_type="llm", + runtime_cfg={"devices": "0"}, + prompt_expand_func=None, + final_output=False, + final_output_type=None, + default_sampling_params=SimpleNamespace(), + custom_process_input_func=None, + engine_input_source=[], + engine_output_type="token_ids", + replica_id=0, ) - return StartedLlmStage( - stage_id=stage_id, - metadata=SimpleNamespace(stage_id=stage_id), - vllm_config=SimpleNamespace(), - executor_class=SimpleNamespace(), - engine_manager=SimpleNamespace(), - coordinator=SimpleNamespace(), - addresses=addresses, + return LogicalStageInitPlan( + stage_idx=stage_idx, + configured_stage_id=configured_stage_id, + replicas=[ + ReplicaInitPlan( + replica_id=0, + num_replicas=1, + launch_mode=launch_mode, + stage_cfg=stage_cfg, + metadata=metadata, + stage_connector_spec={}, + omni_kv_connector=(None, None, None), + stage_vllm_config=vllm_config + or SimpleNamespace(parallel_config=SimpleNamespace(data_parallel_size_local=1)), + executor_class=object, + ) + ], + ) + + +def _make_diffusion_plan( + stage_idx: int, + *, + configured_stage_id: int, + launch_mode: str, +) -> LogicalStageInitPlan: + stage_cfg = _make_stage_cfg(configured_stage_id, stage_type="diffusion") + metadata = SimpleNamespace( + stage_id=configured_stage_id, + stage_type="diffusion", + runtime_cfg={"devices": "0"}, + prompt_expand_func=None, + final_output=True, + final_output_type="image", + default_sampling_params=SimpleNamespace(), + custom_process_input_func=None, + engine_input_source=[], + cfg_kv_collect_func=None, + replica_id=0, + ) + return LogicalStageInitPlan( + stage_idx=stage_idx, + configured_stage_id=configured_stage_id, + replicas=[ + ReplicaInitPlan( + replica_id=0, + num_replicas=1, + launch_mode=launch_mode, + stage_cfg=stage_cfg, + metadata=metadata, + stage_connector_spec={}, + omni_kv_connector=(None, None, None), + ) + ], ) # --------------------------------------------------------------------------- -# OmniMasterServer – address pre-allocation +# OmniMasterServer address pre-allocation # --------------------------------------------------------------------------- class TestOmniMasterServerAllocation: - """Test address pre-allocation in OmniMasterServer.__init__.""" - def test_public_address_and_port_properties_expose_registration_endpoint(self): - server = OmniMasterServer( - master_address="127.0.0.1", - master_port=15000, - stage_ids=[0], - ) + server = OmniMasterServer(master_address="127.0.0.1", master_port=15000, stage_ids=[0]) assert server.address == "127.0.0.1" assert server.port == 15000 def test_allocations_created_for_each_stage_id(self): - server = OmniMasterServer( - master_address="127.0.0.1", - master_port=15000, - stage_ids=[0, 1, 2], - ) + server = OmniMasterServer(master_address="127.0.0.1", master_port=15000, stage_ids=[0, 1, 2]) assert set(server._allocations.keys()) == {0, 1, 2} def test_each_allocation_is_stage_allocation(self): - server = OmniMasterServer( - master_address="127.0.0.1", - master_port=15000, - stage_ids=[0, 1], - ) + server = OmniMasterServer(master_address="127.0.0.1", master_port=15000, stage_ids=[0, 1]) for sid in (0, 1): - alloc = server._allocations[sid] - assert isinstance(alloc, StageAllocation) + assert isinstance(server._allocations[sid], StageAllocation) def test_allocation_addresses_reference_master_address(self): - server = OmniMasterServer( - master_address="192.168.1.10", - master_port=20000, - stage_ids=[0], - ) + server = OmniMasterServer(master_address="192.168.1.10", master_port=20000, stage_ids=[0]) alloc = server._allocations[0] - for addr in ( + for address in ( alloc.handshake_bind_address, alloc.handshake_connect_address, alloc.input_bind_address, @@ -123,73 +149,48 @@ def test_allocation_addresses_reference_master_address(self): alloc.output_bind_address, alloc.output_connect_address, ): - assert "192.168.1.10" in addr, f"Expected master address in {addr}" + assert "192.168.1.10" in address def test_port_uniqueness_within_single_allocation(self): - """Each allocation uses three distinct ports.""" - server = OmniMasterServer( - master_address="127.0.0.1", - master_port=15001, - stage_ids=[0], - ) + server = OmniMasterServer(master_address="127.0.0.1", master_port=15001, stage_ids=[0]) alloc = server._allocations[0] - hs_port = int(alloc.handshake_bind_address.split(":")[-1]) - inp_port = int(alloc.input_bind_address.split(":")[-1]) - out_port = int(alloc.output_bind_address.split(":")[-1]) - assert len({hs_port, inp_port, out_port}) == 3, "Expected three distinct ports per stage allocation" + handshake_port = int(alloc.handshake_bind_address.split(":")[-1]) + input_port = int(alloc.input_bind_address.split(":")[-1]) + output_port = int(alloc.output_bind_address.split(":")[-1]) + assert len({handshake_port, input_port, output_port}) == 3 def test_get_zmq_addresses_returns_bind_addresses(self): - server = OmniMasterServer( - master_address="127.0.0.1", - master_port=15002, - stage_ids=[0], - ) + server = OmniMasterServer(master_address="127.0.0.1", master_port=15002, stage_ids=[0]) alloc = server._allocations[0] zmq_addrs = server.get_zmq_addresses(0) assert zmq_addrs.inputs == [alloc.input_bind_address] assert zmq_addrs.outputs == [alloc.output_bind_address] def test_get_engine_zmq_addresses_returns_connect_addresses(self): - server = OmniMasterServer( - master_address="127.0.0.1", - master_port=15003, - stage_ids=[0], - ) + server = OmniMasterServer(master_address="127.0.0.1", master_port=15003, stage_ids=[0]) alloc = server._allocations[0] - engine_addrs = server.get_engine_zmq_addresses(0) - assert engine_addrs.inputs == [alloc.input_connect_address] - assert engine_addrs.outputs == [alloc.output_connect_address] + zmq_addrs = server.get_engine_zmq_addresses(0) + assert zmq_addrs.inputs == [alloc.input_connect_address] + assert zmq_addrs.outputs == [alloc.output_connect_address] def test_get_allocation_returns_correct_object(self): - server = OmniMasterServer( - master_address="127.0.0.1", - master_port=15004, - stage_ids=[3], - ) + server = OmniMasterServer(master_address="127.0.0.1", master_port=15004, stage_ids=[3]) assert server.get_allocation(3) is server._allocations[3] # --------------------------------------------------------------------------- -# OmniMasterServer – ZMQ registration flow +# OmniMasterServer registration flow # --------------------------------------------------------------------------- class TestOmniMasterServerRegistration: - """Test that the server correctly handles a stage registration.""" - def test_registration_reply_contains_handshake_address(self): - """A DEALER client that sends a registration msg gets the handshake - address back from the ROUTER registration socket.""" import msgspec import zmq from vllm.utils.network_utils import get_open_port master_port = get_open_port() - server = OmniMasterServer( - master_address="127.0.0.1", - master_port=master_port, - stage_ids=[0], - ) + server = OmniMasterServer(master_address="127.0.0.1", master_port=master_port, stage_ids=[0]) server.start() expected_hs = server._allocations[0].handshake_connect_address @@ -198,8 +199,7 @@ def test_registration_reply_contains_handshake_address(self): sock = ctx.socket(zmq.DEALER) sock.connect(f"tcp://127.0.0.1:{master_port}") sock.send(msgspec.msgpack.encode({"stage_id": 0})) - if not sock.poll(timeout=5_000): - pytest.fail("No reply received from OmniMasterServer within 5 s") + assert sock.poll(timeout=5_000) reply = msgspec.msgpack.decode(sock.recv()) assert reply["handshake_address"] == expected_hs finally: @@ -208,92 +208,65 @@ def test_registration_reply_contains_handshake_address(self): server.stop() def test_server_handles_unknown_stage_id_gracefully(self): - """A registration for an unrecognised stage_id must not crash the server.""" import msgspec import zmq from vllm.utils.network_utils import get_open_port master_port = get_open_port() - server = OmniMasterServer( - master_address="127.0.0.1", - master_port=master_port, - stage_ids=[0], - ) + server = OmniMasterServer(master_address="127.0.0.1", master_port=master_port, stage_ids=[0]) server.start() ctx = zmq.Context() + bad_sock = None + good_sock = None try: bad_sock = ctx.socket(zmq.DEALER) bad_sock.connect(f"tcp://127.0.0.1:{master_port}") - # Send unknown stage_id=99 bad_sock.send(msgspec.msgpack.encode({"stage_id": 99})) - # Server should NOT reply for an unknown id; wait briefly - has_reply = bad_sock.poll(timeout=500) - assert not has_reply, "Server should not reply to unknown stage_id" - # Then register the valid stage so the server thread can exit + assert not bad_sock.poll(timeout=500) + good_sock = ctx.socket(zmq.DEALER) good_sock.connect(f"tcp://127.0.0.1:{master_port}") good_sock.send(msgspec.msgpack.encode({"stage_id": 0})) - good_sock.poll(timeout=2_000) + assert good_sock.poll(timeout=2_000) + good_sock.recv() finally: - for s in (bad_sock, good_sock): - try: - s.close(linger=0) - except Exception: - pass + for sock in (bad_sock, good_sock): + if sock is not None: + sock.close(linger=0) ctx.term() server.stop() def test_registration_stores_stage_config(self): - """Stage registration should persist the sender's stage config.""" import msgspec import zmq from vllm.utils.network_utils import get_open_port master_port = get_open_port() - server = OmniMasterServer( - master_address="127.0.0.1", - master_port=master_port, - stage_ids=[0], - ) + server = OmniMasterServer(master_address="127.0.0.1", master_port=master_port, stage_ids=[0]) server.start() - payload = { - "stage_id": 0, - "stage_config": { - "stage_id": 0, - "stage_type": "llm", - "engine_args": {"model": "fake-model"}, - }, - } - + stage_config = {"stage_id": 0, "stage_type": "llm"} ctx = zmq.Context() try: sock = ctx.socket(zmq.DEALER) sock.connect(f"tcp://127.0.0.1:{master_port}") - sock.send(msgspec.msgpack.encode(payload)) + sock.send(msgspec.msgpack.encode({"stage_id": 0, "stage_config": stage_config})) assert sock.poll(timeout=5_000) sock.recv() - - stored = server.get_stage_config(0, timeout_s=0.1) - assert stored == payload["stage_config"] + assert server.get_stage_config(0) == stage_config finally: sock.close(linger=0) ctx.term() server.stop() def test_registration_stores_coordinator_addresses(self): - """Stage registration should persist optional coordinator addresses.""" import msgspec import zmq from vllm.utils.network_utils import get_open_port master_port = get_open_port() - server = OmniMasterServer( - master_address="127.0.0.1", - master_port=master_port, - stage_ids=[0], - ) + server = OmniMasterServer(master_address="127.0.0.1", master_port=master_port, stage_ids=[0]) server.start() payload = { @@ -303,7 +276,6 @@ def test_registration_stores_coordinator_addresses(self): "coordinator_output": "tcp://127.0.0.1:31002", "frontend_stats_publish_address": "tcp://127.0.0.1:31003", } - ctx = zmq.Context() try: sock = ctx.socket(zmq.DEALER) @@ -311,9 +283,7 @@ def test_registration_stores_coordinator_addresses(self): sock.send(msgspec.msgpack.encode(payload)) assert sock.poll(timeout=5_000) sock.recv() - - stored = server.get_stage_coordinator_addresses(0, timeout_s=0.1) - assert stored == StageCoordinatorAddresses( + assert server.get_stage_coordinator_addresses(0) == StageCoordinatorAddresses( coordinator_input=payload["coordinator_input"], coordinator_output=payload["coordinator_output"], frontend_stats_publish_address=payload["frontend_stats_publish_address"], @@ -327,57 +297,47 @@ def test_stop_joins_server_thread(self): from vllm.utils.network_utils import get_open_port master_port = get_open_port() - server = OmniMasterServer( - master_address="127.0.0.1", - master_port=master_port, - stage_ids=[], # no stages → thread exits immediately - ) + server = OmniMasterServer(master_address="127.0.0.1", master_port=master_port, stage_ids=[]) server.start() + assert server._thread is not None server.stop() - # Thread should have exited (joined with timeout=10 inside stop()) assert not server._thread.is_alive() # --------------------------------------------------------------------------- -# AsyncOmniEngine – single_stage_mode detection in __init__ +# AsyncOmniEngine single_stage_mode detection in __init__ # --------------------------------------------------------------------------- class TestSingleStageModeDetection: - """Test __init__ single_stage_mode / _single_stage_id_filter setup. - - We bypass the real __init__ by patching _resolve_stage_configs and - the orchestrator thread, so no actual engines are started. - """ - - def _make_engine_no_thread(self, mocker: MockerFixture, **kwargs: Any) -> AsyncOmniEngine: - """Create an AsyncOmniEngine without starting the orchestrator thread.""" - stage_cfg = _make_stage_cfg(0) - mock_stage_configs = [stage_cfg] + def _make_engine_no_thread( + self, + mocker: MockerFixture, + *, + stage_cfgs: list[Any] | None = None, + **kwargs: Any, + ) -> AsyncOmniEngine: + mock_stage_configs = stage_cfgs or [_make_stage_cfg(0)] mocker.patch.object( AsyncOmniEngine, "_resolve_stage_configs", return_value=("/fake/path", mock_stage_configs), ) - mocker.patch.object( - AsyncOmniEngine, - "_bootstrap_orchestrator", - ) + mocker.patch.object(AsyncOmniEngine, "_bootstrap_orchestrator") mock_thread_cls = mocker.patch("threading.Thread") mock_future_cls = mocker.patch("concurrent.futures.Future") mock_future = mocker.Mock() - mock_future.result.return_value = mocker.Mock() # simulates a loop + mock_future.result.return_value = mocker.Mock() mock_future_cls.return_value = mock_future mock_thread = mocker.Mock() mock_thread.is_alive.return_value = False mock_thread_cls.return_value = mock_thread - engine = AsyncOmniEngine(model="fake-model", **kwargs) - return engine + return AsyncOmniEngine(model="fake-model", **kwargs) def test_explicit_single_stage_mode_true(self, mocker: MockerFixture): engine = self._make_engine_no_thread( @@ -407,9 +367,7 @@ def test_stage_id_kwarg_sets_filter(self, mocker: MockerFixture): assert engine._single_stage_id_filter == 1 def test_no_stage_id_no_single_stage_mode(self, mocker: MockerFixture): - engine = self._make_engine_no_thread( - mocker, - ) + engine = self._make_engine_no_thread(mocker) assert engine.single_stage_mode is False assert engine._single_stage_id_filter is None @@ -420,6 +378,7 @@ def test_single_stage_mode_without_stage_id_has_no_filter(self, mocker: MockerFi omni_master_address="127.0.0.1", omni_master_port=20003, ) + assert engine.single_stage_mode is True assert engine._single_stage_id_filter is None def test_master_address_and_port_stored(self, mocker: MockerFixture): @@ -433,899 +392,489 @@ def test_master_address_and_port_stored(self, mocker: MockerFixture): assert engine._omni_master_port == 12345 def test_omni_master_server_starts_as_none(self, mocker: MockerFixture): - engine = self._make_engine_no_thread( - mocker, - ) + engine = self._make_engine_no_thread(mocker) assert engine._omni_master_server is None # --------------------------------------------------------------------------- -# AsyncOmniEngine – _initialize_stages stage routing +# AsyncOmniEngine single-stage initialization paths # --------------------------------------------------------------------------- -class TestInitializeStagesRouting: - """Verify that _initialize_stages routes each stage to the correct launch - function depending on single_stage_mode and _single_stage_id_filter.""" - - _COMMON_PATCHES = [ - "vllm_omni.engine.async_omni_engine.prepare_engine_environment", - "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", - "vllm_omni.engine.async_omni_engine.get_stage_connector_spec", - "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", - "vllm_omni.engine.async_omni_engine.extract_stage_metadata", - "vllm_omni.engine.async_omni_engine.finalize_initialized_stages", - ] - - def _build_engine_skeleton( - self, - stage_cfgs: list[Any], - single_stage_mode: bool, - stage_id_filter: int | None, - omni_master_address: str = "127.0.0.1", - omni_master_port: int = 25000, +class TestSingleStageInitialization: + def _build_engine( + self, stage_cfgs: list[Any], *, single_stage_mode: bool, stage_id_filter: int | None ) -> AsyncOmniEngine: - """Build a bare AsyncOmniEngine without launching any threads.""" engine = object.__new__(AsyncOmniEngine) engine.model = "fake-model" - engine.config_path = "/fake" - engine.stage_configs = stage_cfgs + engine.config_path = "/fake/stages.yaml" engine.num_stages = len(stage_cfgs) - engine.async_chunk = False + engine.stage_configs = stage_cfgs engine.single_stage_mode = single_stage_mode engine._single_stage_id_filter = stage_id_filter - engine._omni_master_address = omni_master_address - engine._omni_master_port = omni_master_port + engine._omni_master_address = "127.0.0.1" + engine._omni_master_port = 26000 engine._omni_master_server = None - engine._llm_stage_launch_lock = __import__("threading").Lock() - engine.diffusion_batch_size = 1 - engine.stage_clients = [] - engine.stage_vllm_configs = [] - engine.output_processors = [] - engine.input_processor = None - engine.supported_tasks = ("generate",) - engine.default_sampling_params_list = [] - engine.stage_metadata = [] - engine.prompt_expand_func = None + engine.async_chunk = False + engine.diffusion_batch_size = 2 return engine - def _fake_metadata(self, mocker: MockerFixture, stage_id: int, stage_type: str = "llm") -> Any: - meta = mocker.Mock() - meta.stage_id = stage_id - meta.stage_type = stage_type - meta.runtime_cfg = {} - meta.prompt_expand_func = None - meta.engine_output_type = None - meta.is_comprehension = False - meta.final_output = True if stage_id == 0 else False - meta.final_output_type = None - return meta - - def _run_initialize_stages_mocked( - self, - mocker: MockerFixture, - engine: AsyncOmniEngine, - stage_cfgs: list[Any], - *, - launch_side_effect: Any = None, - remote_side_effect: Any = None, - attach_result: Any = None, - ) -> tuple[Any, Any]: - """Execute _initialize_stages with all heavy helpers mocked. - - Returns (mock_launch_llm_stage, mock_create_remote_llm_stage). - """ - started_by_stage: dict[int, StartedLlmStage] = { - cfg.stage_id: _make_started_llm_stage(cfg.stage_id) - for cfg in stage_cfgs - if getattr(cfg, "stage_type", "llm") != "diffusion" - } - - default_attach = (mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()) + def test_build_logical_stage_init_plans_marks_non_matching_stage_remote(self, mocker: MockerFixture): + import vllm_omni.engine.async_omni_engine as engine_mod - mock_launch = mocker.Mock( - side_effect=launch_side_effect - or (lambda cfg, meta, spec, timeout, llm_stage_launch_lock, kv: started_by_stage[meta.stage_id]) - ) - mock_remote = mocker.Mock( - side_effect=remote_side_effect or (lambda cfg, meta, spec, timeout, srv: started_by_stage[meta.stage_id]) + stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11)] + engine = self._build_engine(stage_cfgs, single_stage_mode=True, stage_id_filter=7) + + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr( + engine_mod, + "extract_stage_metadata", + lambda cfg: SimpleNamespace( + stage_id=cfg.stage_id, + stage_type=getattr(cfg, "stage_type", "llm"), + prompt_expand_func=None, + runtime_cfg={}, + ), ) - mock_attach = mocker.Mock(return_value=attach_result or default_attach) + monkeypatch.setattr(engine_mod, "get_stage_connector_spec", lambda **_: {}) + monkeypatch.setattr(engine_mod, "resolve_omni_kv_config_for_stage", lambda *_: (None, None, None)) + monkeypatch.setattr(engine_mod, "build_engine_args_dict", lambda *_, **__: {}) + monkeypatch.setattr(engine_mod, "build_vllm_config", lambda *_, **__: (SimpleNamespace(), object)) + try: + stage_plans, _ = engine._build_logical_stage_init_plans(None, [1, 1], {}) + finally: + monkeypatch.undo() + + assert [plan.replicas[0].launch_mode for plan in stage_plans] == ["local", "remote"] + + def test_start_omni_master_server_uses_configured_stage_ids(self, mocker: MockerFixture): + import vllm_omni.engine.async_omni_engine as engine_mod + engine = self._build_engine([], single_stage_mode=True, stage_id_filter=7) mock_oms = mocker.Mock(spec=OmniMasterServer) - mock_oms.get_zmq_addresses.side_effect = lambda sid: mocker.Mock() + mocker.patch.object(engine_mod, "OmniMasterServer", return_value=mock_oms) - finalized = ( - [mocker.Mock() for _ in stage_cfgs], - [mocker.Mock() for _ in stage_cfgs], - [{"final_output": True, "final_output_type": None, "stage_type": "llm"} for _ in stage_cfgs], - ) + stage_plans = [ + _make_llm_plan(0, configured_stage_id=7, launch_mode="local"), + _make_diffusion_plan(1, configured_stage_id=11, launch_mode="remote"), + ] - mocker.patch.object(engine, "_launch_llm_stage", mock_launch) - mocker.patch.object(engine, "_create_remote_llm_stage", mock_remote) - mocker.patch.object(engine, "_attach_llm_stage", mock_attach) - mocker.patch( - "vllm_omni.engine.async_omni_engine.OmniMasterServer", - return_value=mock_oms, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.prepare_engine_environment", - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", - return_value=None, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.get_stage_connector_spec", - return_value={}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", - return_value=(None, None, None), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.extract_stage_metadata", - side_effect=lambda cfg: self._fake_metadata( - mocker, - cfg.stage_id, - getattr(cfg, "stage_type", "llm"), - ), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.finalize_initialized_stages", - return_value=finalized, + engine._start_omni_master_server(stage_plans) + + engine_mod.OmniMasterServer.assert_called_once_with( + master_address="127.0.0.1", + master_port=26000, + stage_ids=[7, 11], ) + mock_oms.start.assert_called_once() - engine._initialize_stages(stage_init_timeout=60) + def test_start_omni_master_server_duplicate_stage_ids_raise(self): + engine = self._build_engine([], single_stage_mode=True, stage_id_filter=7) + stage_plans = [ + _make_llm_plan(0, configured_stage_id=7, launch_mode="local"), + _make_llm_plan(1, configured_stage_id=7, launch_mode="remote"), + ] - return mock_launch, mock_remote + with pytest.raises(ValueError, match="Duplicate stage_id"): + engine._start_omni_master_server(stage_plans) - # -- single-stage mode: stage matches filter → local launch --------------- + def test_start_omni_master_server_missing_address_raises(self): + engine = self._build_engine([], single_stage_mode=True, stage_id_filter=7) + engine._omni_master_address = None - def test_matching_stage_uses_launch_llm_stage(self, mocker: MockerFixture): - """stage_id == _single_stage_id_filter → _launch_llm_stage is called.""" - stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)] - engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0) - mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs) + with pytest.raises(ValueError, match="requires both"): + engine._start_omni_master_server([_make_llm_plan(0, configured_stage_id=7, launch_mode="local")]) - launched_ids = [c.args[1].stage_id for c in mock_launch.call_args_list] - assert 0 in launched_ids, "_launch_llm_stage should be called for stage 0" + def test_build_logical_stage_init_plans_clears_runtime_cfg_in_single_stage_mode(self, mocker: MockerFixture): + import vllm_omni.engine.async_omni_engine as engine_mod - def test_non_matching_stage_uses_create_remote_llm_stage(self, mocker: MockerFixture): - """stage_id != _single_stage_id_filter → _create_remote_llm_stage is called.""" - stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)] - engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0) - mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs) + engine = self._build_engine([_make_stage_cfg(7)], single_stage_mode=True, stage_id_filter=7) - remote_ids = [c.args[1].stage_id for c in mock_remote.call_args_list] - assert 1 in remote_ids, "_create_remote_llm_stage should be called for stage 1" + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr( + engine_mod, + "extract_stage_metadata", + lambda cfg: SimpleNamespace( + stage_id=cfg.stage_id, + stage_type="llm", + prompt_expand_func=None, + runtime_cfg={"devices": "0"}, + ), + ) + monkeypatch.setattr(engine_mod, "get_stage_connector_spec", lambda **_: {}) + monkeypatch.setattr(engine_mod, "resolve_omni_kv_config_for_stage", lambda *_: (None, None, None)) + monkeypatch.setattr(engine_mod, "build_engine_args_dict", lambda *_, **__: {}) + monkeypatch.setattr( + engine_mod, + "build_vllm_config", + lambda *_, **__: (SimpleNamespace(parallel_config=SimpleNamespace(data_parallel_size_local=1)), object), + ) + try: + stage_plans, _ = engine._build_logical_stage_init_plans(None, [1], {}) + finally: + monkeypatch.undo() - def test_filter_1_routes_correctly(self, mocker: MockerFixture): - """With filter=1, stage 0 is remote and stage 1 is local.""" - stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)] - engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=1) - mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs) + assert stage_plans[0].replicas[0].metadata.runtime_cfg is None - launched_ids = [c.args[1].stage_id for c in mock_launch.call_args_list] - remote_ids = [c.args[1].stage_id for c in mock_remote.call_args_list] - assert 1 in launched_ids, "stage 1 should be launched locally with filter=1" - assert 0 in remote_ids, "stage 0 should use remote path with filter=1" + def test_initialize_stages_calls_master_server_only_in_single_stage_mode(self, mocker: MockerFixture): + import vllm_omni.engine.async_omni_engine as engine_mod - def test_no_filter_all_stages_use_launch_path(self, mocker: MockerFixture): - """single_stage_mode=True but no filter → all stages use _launch_llm_stage.""" - stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)] - engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=None) - mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs) + stage_cfgs = [_make_stage_cfg(0)] + engine = self._build_engine(stage_cfgs, single_stage_mode=True, stage_id_filter=0) + stage_plan = _make_llm_plan(0, configured_stage_id=0, launch_mode="local") + client = SimpleNamespace( + stage_type="llm", + is_comprehension=False, + final_output=True, + final_output_type=None, + default_sampling_params=SimpleNamespace(), + ) + + mocker.patch.object(engine_mod, "prepare_engine_environment") + mocker.patch.object(engine_mod, "load_omni_transfer_config_for_model", return_value=None) + mocker.patch.object(engine_mod, "compute_replica_layout", return_value=([1], {})) + mocker.patch.object(engine, "_build_logical_stage_init_plans", return_value=([stage_plan], None)) + mock_start = mocker.patch.object(engine, "_start_omni_master_server") + mocker.patch.object(engine, "_initialize_stage_replicas", return_value={0: [client]}) + mocker.patch.object(engine_mod, "build_stage0_input_processor", return_value=object()) + mocker.patch.object(engine_mod, "build_llm_stage_output_processor", return_value=object()) - assert mock_remote.call_count == 0, "No remote launches without a filter" - launched_ids = [c.args[1].stage_id for c in mock_launch.call_args_list] - assert set(launched_ids) == {0, 1} + engine._initialize_stages(stage_init_timeout=60) + mock_start.assert_called_once() + + engine = self._build_engine(stage_cfgs, single_stage_mode=False, stage_id_filter=None) + mocker.patch.object(engine_mod, "prepare_engine_environment") + mocker.patch.object(engine_mod, "load_omni_transfer_config_for_model", return_value=None) + mocker.patch.object(engine_mod, "compute_replica_layout", return_value=([1], {})) + mocker.patch.object(engine, "_build_logical_stage_init_plans", return_value=([stage_plan], None)) + mock_start = mocker.patch.object(engine, "_start_omni_master_server") + mocker.patch.object(engine, "_initialize_stage_replicas", return_value={0: [client]}) + mocker.patch.object(engine_mod, "build_stage0_input_processor", return_value=object()) + mocker.patch.object(engine_mod, "build_llm_stage_output_processor", return_value=object()) - def test_non_single_stage_mode_never_calls_create_remote(self, mocker: MockerFixture): - """Outside single_stage_mode, _create_remote_llm_stage must not be called.""" - stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1)] - engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=False, stage_id_filter=None) - mock_launch, mock_remote = self._run_initialize_stages_mocked(mocker, engine, stage_cfgs) + engine._initialize_stages(stage_init_timeout=60) + mock_start.assert_not_called() - assert mock_remote.call_count == 0 + def test_initialize_stages_stops_master_server_and_shuts_down_initialized_clients_on_failure( + self, + mocker: MockerFixture, + ): + import vllm_omni.engine.async_omni_engine as engine_mod - def test_omni_master_server_started_in_single_stage_mode(self, mocker: MockerFixture): - """OmniMasterServer.start() must be called when single_stage_mode=True.""" stage_cfgs = [_make_stage_cfg(0)] - engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0) - mock_oms = mocker.Mock(spec=OmniMasterServer) - mock_oms.get_zmq_addresses.return_value = mocker.Mock() - finalized = ( - [mocker.Mock()], - [mocker.Mock()], - [{"final_output": True, "final_output_type": None, "stage_type": "llm"}], - ) + engine = self._build_engine(stage_cfgs, single_stage_mode=True, stage_id_filter=0) + stage_plan = _make_llm_plan(0, configured_stage_id=0, launch_mode="local") + initialized_client = mocker.Mock() + mock_master = mocker.Mock(spec=OmniMasterServer) + + mocker.patch.object(engine_mod, "prepare_engine_environment") + mocker.patch.object(engine_mod, "load_omni_transfer_config_for_model", return_value=None) + mocker.patch.object(engine_mod, "compute_replica_layout", return_value=([1], {})) + mocker.patch.object(engine, "_build_logical_stage_init_plans", return_value=([stage_plan], None)) + + def _start_master(_plans): + engine._omni_master_server = mock_master + + mocker.patch.object(engine, "_start_omni_master_server", side_effect=_start_master) + mocker.patch.object(engine, "_initialize_stage_replicas", return_value={0: [initialized_client]}) + mocker.patch.object(engine_mod, "build_stage0_input_processor", return_value=object()) + mocker.patch.object(engine, "_assemble_stage_pools", side_effect=RuntimeError("assemble failed")) + mock_shutdown = mocker.patch.object(engine, "_shutdown_initialized_clients") + + with pytest.raises(RuntimeError, match="assemble failed"): + engine._initialize_stages(stage_init_timeout=60) - mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0)) - mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(0)) - mocker.patch.object( - engine, - "_attach_llm_stage", - return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.OmniMasterServer", - return_value=mock_oms, - ) - mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment") - mocker.patch( - "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", - return_value=None, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.get_stage_connector_spec", - return_value={}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", - return_value=(None, None, None), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.extract_stage_metadata", - side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.finalize_initialized_stages", - return_value=finalized, - ) + mock_shutdown.assert_called_once_with([initialized_client]) + mock_master.stop.assert_called_once() - engine._initialize_stages(stage_init_timeout=60) - mock_oms.start.assert_called_once() +class TestSingleStageReplicaInitialization: + def test_initialize_llm_replica_remote_uses_connect_remote_engine_cores(self, mocker: MockerFixture): + import vllm_omni.engine.async_omni_engine as engine_mod - def test_omni_master_server_uses_configured_stage_ids(self, mocker: MockerFixture): - """Configured stage IDs, not list indexes, should drive pre-allocation.""" - stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11)] - engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=7) - mock_oms = mocker.Mock(spec=OmniMasterServer) - mock_oms.get_zmq_addresses.return_value = mocker.Mock() - finalized = ( - [mocker.Mock(), mocker.Mock()], - [mocker.Mock(), mocker.Mock()], - [{"final_output": False, "final_output_type": None, "stage_type": "llm"} for _ in stage_cfgs], - ) + engine = object.__new__(AsyncOmniEngine) + engine.model = "fake-model" + engine.single_stage_mode = True + engine._omni_master_server = mocker.Mock(spec=OmniMasterServer) + engine._omni_master_server.get_stage_config.return_value = {"stage_id": 7, "stage_type": "llm"} - mocker.patch.object( - engine, - "_launch_llm_stage", - side_effect=[_make_started_llm_stage(7), _make_started_llm_stage(11)], - ) - mocker.patch.object( - engine, - "_create_remote_llm_stage", - return_value=_make_started_llm_stage(11), - ) - mocker.patch.object( - engine, - "_attach_llm_stage", - return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()), - ) - mock_oms_cls = mocker.patch( - "vllm_omni.engine.async_omni_engine.OmniMasterServer", - return_value=mock_oms, - ) - mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment") - mocker.patch( - "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", - return_value=None, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.get_stage_connector_spec", - return_value={}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", - return_value=(None, None, None), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.extract_stage_metadata", - side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.finalize_initialized_stages", - return_value=finalized, + fake_vllm_config = SimpleNamespace(parallel_config=SimpleNamespace(data_parallel_size_local=1)) + fake_addresses = SimpleNamespace( + inputs=["tcp://in"], outputs=["tcp://out"], frontend_stats_publish_address=None ) + fake_manager = mocker.Mock() + fake_coordinator = mocker.Mock() - engine._initialize_stages(stage_init_timeout=60) + @contextmanager + def _fake_connect(**kwargs): + yield fake_manager, fake_coordinator, fake_addresses - mock_oms_cls.assert_called_once_with( - master_address=engine._omni_master_address, - master_port=engine._omni_master_port, - stage_ids=[7, 11], - ) + plan = _make_llm_plan(0, configured_stage_id=7, launch_mode="remote", vllm_config=fake_vllm_config).replicas[0] + sentinel_client = SimpleNamespace() - def test_single_stage_filter_uses_configured_stage_ids(self, mocker: MockerFixture): - """Local/remote dispatch should compare against configured stage IDs.""" - stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11)] - engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=7) - mock_oms = mocker.Mock(spec=OmniMasterServer) - finalized = ( - [mocker.Mock(), mocker.Mock()], - [mocker.Mock(), mocker.Mock()], - [{"final_output": False, "final_output_type": None, "stage_type": "llm"} for _ in stage_cfgs], - ) - - mock_launch = mocker.patch.object( - engine, - "_launch_llm_stage", - side_effect=[_make_started_llm_stage(7)], - ) - mock_remote = mocker.patch.object( - engine, - "_create_remote_llm_stage", - return_value=_make_started_llm_stage(11), - ) + mock_connect = mocker.patch.object(engine_mod, "connect_remote_engine_cores", side_effect=_fake_connect) mocker.patch.object( - engine, - "_attach_llm_stage", - return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.OmniMasterServer", - return_value=mock_oms, - ) - mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment") - mocker.patch( - "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", - return_value=None, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.get_stage_connector_spec", - return_value={}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", - return_value=(None, None, None), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.extract_stage_metadata", - side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.finalize_initialized_stages", - return_value=finalized, + StageEngineCoreClientBase, + "make_async_mp_client", + side_effect=lambda **_: sentinel_client, ) - engine._initialize_stages(stage_init_timeout=60) + result = engine._initialize_llm_replica(plan, stage_init_timeout=60, llm_stage_launch_lock=threading.Lock()) - assert [call.args[1].stage_id for call in mock_launch.call_args_list] == [7] - assert [call.args[1].stage_id for call in mock_remote.call_args_list] == [11] + assert result is sentinel_client + engine._omni_master_server.get_stage_config.assert_called_once_with(7, timeout_s=60) + assert fake_vllm_config.parallel_config.data_parallel_size_local == 0 + assert mock_connect.call_args.kwargs["stage_id"] == 7 - def test_omni_master_server_preallocates_diffusion_stage_ids(self, mocker: MockerFixture): - """Diffusion stages should also receive OmniMasterServer allocations.""" - stage_cfgs = [_make_stage_cfg(7), _make_stage_cfg(11, stage_type="diffusion")] - engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=7) - mock_oms = mocker.Mock(spec=OmniMasterServer) - finalized = ( - [mocker.Mock(), mocker.Mock()], - [mocker.Mock(), mocker.Mock()], - [ - {"final_output": False, "final_output_type": None, "stage_type": "llm"}, - {"final_output": False, "final_output_type": None, "stage_type": "diffusion"}, - ], - ) + def test_initialize_llm_replica_remote_missing_registered_stage_config_raises(self, mocker: MockerFixture): + engine = object.__new__(AsyncOmniEngine) + engine.single_stage_mode = True + engine._omni_master_server = mocker.Mock(spec=OmniMasterServer) + engine._omni_master_server.get_stage_config.return_value = None - mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(7)) - mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(7)) - mocker.patch.object(engine, "_launch_diffusion_stage", return_value=mocker.Mock()) - mocker.patch.object( - engine, - "_create_remote_diffusion_stage", - return_value=mocker.Mock(), - ) - mocker.patch.object( - engine, - "_attach_llm_stage", - return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()), - ) - mock_oms_cls = mocker.patch( - "vllm_omni.engine.async_omni_engine.OmniMasterServer", - return_value=mock_oms, - ) - mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment") - mocker.patch( - "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", - return_value=None, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.get_stage_connector_spec", - return_value={}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", - return_value=(None, None, None), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.extract_stage_metadata", - side_effect=lambda cfg: self._fake_metadata( - mocker, - cfg.stage_id, - getattr(cfg, "stage_type", "llm"), - ), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.finalize_initialized_stages", - return_value=finalized, - ) + plan = _make_llm_plan(0, configured_stage_id=7, launch_mode="remote").replicas[0] - engine._initialize_stages(stage_init_timeout=60) + with pytest.raises(ValueError, match="registered without stage config"): + engine._initialize_llm_replica(plan, stage_init_timeout=60, llm_stage_launch_lock=threading.Lock()) - mock_oms_cls.assert_called_once_with( - master_address=engine._omni_master_address, - master_port=engine._omni_master_port, - stage_ids=[7, 11], - ) + def test_initialize_llm_replica_remote_attach_failure_cleans_up_started_resources(self, mocker: MockerFixture): + import vllm_omni.engine.async_omni_engine as engine_mod - def test_duplicate_llm_stage_ids_raise(self, mocker: MockerFixture): - """Duplicate configured LLM stage IDs should fail fast.""" - stage_cfgs = [_make_stage_cfg(3), _make_stage_cfg(3)] - engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=3) + engine = object.__new__(AsyncOmniEngine) + engine.single_stage_mode = True + engine._omni_master_server = mocker.Mock(spec=OmniMasterServer) + engine._omni_master_server.get_stage_config.return_value = {"stage_id": 7, "stage_type": "llm"} - mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment") - mocker.patch( - "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", - return_value=None, + fake_vllm_config = SimpleNamespace(parallel_config=SimpleNamespace(data_parallel_size_local=1)) + fake_addresses = SimpleNamespace( + inputs=["tcp://in"], outputs=["tcp://out"], frontend_stats_publish_address=None ) - with pytest.raises(ValueError, match="Duplicate stage_id"): - engine._initialize_stages(stage_init_timeout=60) + fake_manager = mocker.Mock() + fake_coordinator = mocker.Mock() - def test_omni_master_server_not_started_in_normal_mode(self, mocker: MockerFixture): - """OmniMasterServer must NOT be instantiated outside single_stage_mode.""" - stage_cfgs = [_make_stage_cfg(0)] - engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=False, stage_id_filter=None) - finalized = ( - [mocker.Mock()], - [mocker.Mock()], - [{"final_output": True, "final_output_type": None, "stage_type": "llm"}], - ) + @contextmanager + def _fake_connect(**kwargs): + yield fake_manager, fake_coordinator, fake_addresses - mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0)) + plan = _make_llm_plan(0, configured_stage_id=7, launch_mode="remote", vllm_config=fake_vllm_config).replicas[0] + mocker.patch.object(engine_mod, "connect_remote_engine_cores", side_effect=_fake_connect) mocker.patch.object( - engine, - "_attach_llm_stage", - return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()), - ) - mock_oms_cls = mocker.patch("vllm_omni.engine.async_omni_engine.OmniMasterServer") - mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment") - mocker.patch( - "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", - return_value=None, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.get_stage_connector_spec", - return_value={}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", - return_value=(None, None, None), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.extract_stage_metadata", - side_effect=lambda cfg: self._fake_metadata(mocker, cfg.stage_id), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.finalize_initialized_stages", - return_value=finalized, + StageEngineCoreClientBase, + "make_async_mp_client", + side_effect=RuntimeError("attach failed"), ) - engine._initialize_stages(stage_init_timeout=60) - - mock_oms_cls.assert_not_called() + with pytest.raises(RuntimeError, match="attach failed"): + engine._initialize_llm_replica(plan, stage_init_timeout=60, llm_stage_launch_lock=threading.Lock()) - def test_single_stage_mode_missing_master_address_raises(self, mocker: MockerFixture): - """single_stage_mode without master address/port raises ValueError.""" - stage_cfgs = [_make_stage_cfg(0)] - engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0) - engine._omni_master_address = None # missing - engine._omni_master_port = None + fake_manager.shutdown.assert_called_once() + fake_coordinator.shutdown.assert_called_once() - mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment") - mocker.patch( - "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", - return_value=None, - ) - with pytest.raises(ValueError, match="omni_master_address"): - engine._initialize_stages(stage_init_timeout=60) + def test_initialize_llm_replica_single_stage_local_uses_launch_omni_core_engines(self, mocker: MockerFixture): + import vllm_omni.engine.async_omni_engine as engine_mod + from vllm_omni.platforms import current_omni_platform - def test_matching_diffusion_stage_uses_local_registered_launch(self, mocker: MockerFixture): - """A local diffusion stage should use the registered single-stage launch path.""" - stage_cfgs = [_make_stage_cfg(0, stage_type="diffusion"), _make_stage_cfg(1)] - engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0) - mock_oms = mocker.Mock(spec=OmniMasterServer) - diffusion_client = mocker.Mock(stage_type="diffusion") - finalized = ( - [diffusion_client, mocker.Mock()], - [mocker.Mock(), mocker.Mock()], - [ - {"final_output": False, "final_output_type": None, "stage_type": "diffusion"}, - {"final_output": False, "final_output_type": None, "stage_type": "llm"}, - ], - ) + engine = object.__new__(AsyncOmniEngine) + engine.model = "fake-model" + engine.single_stage_mode = True + engine._omni_master_server = mocker.Mock(spec=OmniMasterServer) + engine.stage_configs = [] - mock_local_diff = mocker.patch.object( - engine, - "_launch_diffusion_stage", - return_value=diffusion_client, - ) - mock_remote_diff = mocker.patch.object(engine, "_create_remote_diffusion_stage") - mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(1)) - mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(1)) - mocker.patch.object( - engine, - "_attach_llm_stage", - return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.OmniMasterServer", - return_value=mock_oms, - ) - mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment") - mocker.patch( - "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", - return_value=None, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.get_stage_connector_spec", - return_value={}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", - return_value=(None, None, None), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.extract_stage_metadata", - side_effect=lambda cfg: self._fake_metadata( - mocker, - cfg.stage_id, - getattr(cfg, "stage_type", "llm"), - ), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.finalize_initialized_stages", - return_value=finalized, + fake_vllm_config = SimpleNamespace(parallel_config=SimpleNamespace()) + fake_addresses = SimpleNamespace( + inputs=["tcp://in"], outputs=["tcp://out"], frontend_stats_publish_address=None ) - engine._initialize_stages(stage_init_timeout=60) + @contextmanager + def _fake_launch(**kwargs): + yield mocker.Mock(), None, fake_addresses - assert mock_local_diff.call_count == 1 - assert mock_local_diff.call_args.args[1].stage_id == 0 - mock_remote_diff.assert_not_called() + plan = _make_llm_plan(0, configured_stage_id=3, launch_mode="local", vllm_config=fake_vllm_config).replicas[0] + sentinel_client = SimpleNamespace() - def test_non_matching_diffusion_stage_uses_remote_diffusion_client(self, mocker: MockerFixture): - """A non-local diffusion stage should attach via the remote diffusion path.""" - stage_cfgs = [_make_stage_cfg(0), _make_stage_cfg(1, stage_type="diffusion")] - engine = self._build_engine_skeleton(stage_cfgs, single_stage_mode=True, stage_id_filter=0) - mock_oms = mocker.Mock(spec=OmniMasterServer) - remote_diffusion_client = mocker.Mock(stage_type="diffusion") - finalized = ( - [mocker.Mock(), remote_diffusion_client], - [mocker.Mock(), mocker.Mock()], - [ - {"final_output": False, "final_output_type": None, "stage_type": "llm"}, - {"final_output": False, "final_output_type": None, "stage_type": "diffusion"}, - ], - ) + device_env_var = current_omni_platform.device_control_env_var + prev_device_env = os.environ.get(device_env_var) + os.environ[device_env_var] = "0" - mock_local_diff = mocker.patch.object(engine, "_launch_diffusion_stage") - mock_remote_diff = mocker.patch.object( - engine, - "_create_remote_diffusion_stage", - return_value=remote_diffusion_client, - ) - mocker.patch.object(engine, "_launch_llm_stage", return_value=_make_started_llm_stage(0)) - mocker.patch.object(engine, "_create_remote_llm_stage", return_value=_make_started_llm_stage(0)) + mocker.patch.object(engine_mod, "setup_stage_devices") + mocker.patch.object(engine_mod, "build_engine_args_dict", return_value={}) + mocker.patch.object(engine_mod, "acquire_device_locks", return_value=[]) + mocker.patch.object(engine_mod, "release_device_locks") + mock_launch = mocker.patch.object(engine_mod, "launch_omni_core_engines", side_effect=_fake_launch) mocker.patch.object( - engine, - "_attach_llm_stage", - return_value=(mocker.Mock(), mocker.Mock(), mocker.Mock(), mocker.Mock()), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.OmniMasterServer", - return_value=mock_oms, - ) - mocker.patch("vllm_omni.engine.async_omni_engine.prepare_engine_environment") - mocker.patch( - "vllm_omni.engine.async_omni_engine.load_omni_transfer_config_for_model", - return_value=None, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.get_stage_connector_spec", - return_value={}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.resolve_omni_kv_config_for_stage", - return_value=(None, None, None), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.extract_stage_metadata", - side_effect=lambda cfg: self._fake_metadata( - mocker, - cfg.stage_id, - getattr(cfg, "stage_type", "llm"), - ), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.finalize_initialized_stages", - return_value=finalized, + StageEngineCoreClientBase, + "make_async_mp_client", + side_effect=lambda **_: sentinel_client, ) + try: + result = engine._initialize_llm_replica(plan, stage_init_timeout=60, llm_stage_launch_lock=threading.Lock()) + finally: + if prev_device_env is None: + os.environ.pop(device_env_var, None) + else: + os.environ[device_env_var] = prev_device_env - engine._initialize_stages(stage_init_timeout=60) + assert result is sentinel_client + assert mock_launch.call_args.kwargs["stage_id"] == 3 + assert mock_launch.call_args.kwargs["stage_config"] is plan.stage_cfg + + def test_initialize_diffusion_replica_remote_uses_from_addresses(self, mocker: MockerFixture): + import vllm_omni.engine.async_omni_engine as engine_mod + + engine = object.__new__(AsyncOmniEngine) + engine.single_stage_mode = True + engine.diffusion_batch_size = 4 + engine._omni_master_server = mocker.Mock(spec=OmniMasterServer) + engine._omni_master_server.get_stage_config.return_value = {"stage_id": 11, "stage_type": "diffusion"} + engine._omni_master_server.get_zmq_addresses.return_value = SimpleNamespace( + inputs=["tcp://in"], + outputs=["tcp://out"], + ) - mock_local_diff.assert_not_called() - assert mock_remote_diff.call_count == 1 - assert mock_remote_diff.call_args.args[0].stage_id == 1 + remote_metadata = _make_diffusion_plan(1, configured_stage_id=11, launch_mode="remote").replicas[0].metadata + plan = _make_diffusion_plan(1, configured_stage_id=11, launch_mode="remote").replicas[0] + sentinel_client = SimpleNamespace() + mocker.patch.object(engine_mod, "extract_stage_metadata", return_value=remote_metadata) + mock_from_addresses = mocker.patch.object( + engine_mod.StageDiffusionClient, "from_addresses", return_value=sentinel_client + ) -# --------------------------------------------------------------------------- -# AsyncOmniEngine – _launch_diffusion_stage -# --------------------------------------------------------------------------- + result = engine._initialize_diffusion_replica(plan, stage_init_timeout=60, stage_launch_lock=threading.Lock()) + assert result is sentinel_client + engine._omni_master_server.get_stage_config.assert_called_once_with(11, timeout_s=60) + mock_from_addresses.assert_called_once() -class TestLaunchDiffusionStage: - """Test local diffusion stage launch wiring.""" + def test_initialize_diffusion_replica_single_stage_local_registers_with_master(self, mocker: MockerFixture): + import vllm_omni.engine.async_omni_engine as engine_mod + from vllm_omni.platforms import current_omni_platform - def test_registers_stage_with_public_master_properties(self, mocker: MockerFixture): engine = object.__new__(AsyncOmniEngine) engine.model = "fake-model" + engine.single_stage_mode = True engine.diffusion_batch_size = 4 + engine.stage_configs = [] + engine._omni_master_server = mocker.Mock(spec=OmniMasterServer) + engine._omni_master_server.address = "127.0.0.1" + engine._omni_master_server.port = 25000 - stage_cfg = _make_stage_cfg(5, stage_type="diffusion") - metadata = mocker.Mock(stage_id=5) - omni_master_server = mocker.Mock(spec=OmniMasterServer) - omni_master_server.address = "127.0.0.1" - omni_master_server.port = 25000 - + plan = _make_diffusion_plan(0, configured_stage_id=5, launch_mode="local").replicas[0] + sentinel_client = SimpleNamespace() proc = mocker.Mock() - diffusion_client = mocker.Mock() - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_diffusion_config", - return_value="diffusion-config", - ) - mock_register = mocker.patch( - "vllm_omni.engine.async_omni_engine.register_stage_with_omni_master", - return_value=( - "tcp://127.0.0.1:25001", - "tcp://127.0.0.1:25002", - "tcp://127.0.0.1:25003", - ), - ) - mock_spawn = mocker.patch( - "vllm_omni.engine.async_omni_engine.spawn_diffusion_proc", + device_env_var = current_omni_platform.device_control_env_var + prev_device_env = os.environ.get(device_env_var) + os.environ[device_env_var] = "0" + + mocker.patch.object(engine_mod, "setup_stage_devices") + mocker.patch.object(engine_mod, "inject_kv_stage_info") + mocker.patch.object(engine_mod, "build_diffusion_config", return_value="diffusion-config") + mock_register = mocker.patch.object( + engine_mod, + "register_stage_with_omni_master", + return_value=("tcp://hs", "tcp://req", "tcp://resp"), + ) + mock_spawn = mocker.patch.object( + engine_mod, + "spawn_diffusion_proc", return_value=(proc, None, None, None), ) - mock_handshake = mocker.patch("vllm_omni.engine.async_omni_engine.complete_diffusion_handshake") - mock_from_addresses = mocker.patch( - "vllm_omni.engine.async_omni_engine.StageDiffusionClient.from_addresses", - return_value=diffusion_client, + mock_handshake = mocker.patch.object(engine_mod, "complete_diffusion_handshake") + mock_from_addresses = mocker.patch.object( + engine_mod.StageDiffusionClient, + "from_addresses", + return_value=sentinel_client, ) - result = engine._launch_diffusion_stage( - stage_cfg=stage_cfg, - metadata=metadata, - omni_master_server=omni_master_server, - ) + try: + result = engine._initialize_diffusion_replica( + plan, stage_init_timeout=60, stage_launch_lock=threading.Lock() + ) + finally: + if prev_device_env is None: + os.environ.pop(device_env_var, None) + else: + os.environ[device_env_var] = prev_device_env + assert result is sentinel_client mock_register.assert_called_once_with( omni_master_address="127.0.0.1", omni_master_port=25000, omni_stage_id=5, - omni_stage_config=stage_cfg, + omni_stage_config=plan.stage_cfg, return_addresses=True, ) mock_spawn.assert_called_once_with( "fake-model", "diffusion-config", - handshake_address="tcp://127.0.0.1:25001", - request_address="tcp://127.0.0.1:25002", - response_address="tcp://127.0.0.1:25003", + handshake_address="tcp://hs", + request_address="tcp://req", + response_address="tcp://resp", ) - mock_handshake.assert_called_once_with(proc, "tcp://127.0.0.1:25001") + mock_handshake.assert_called_once_with(proc, "tcp://hs") mock_from_addresses.assert_called_once_with( - metadata, - request_address="tcp://127.0.0.1:25002", - response_address="tcp://127.0.0.1:25003", + plan.metadata, + request_address="tcp://req", + response_address="tcp://resp", proc=proc, batch_size=4, ) - assert result is diffusion_client - -# --------------------------------------------------------------------------- -# AsyncOmniEngine – _create_remote_llm_stage -# --------------------------------------------------------------------------- - - -class TestCreateRemoteLlmStage: - """Test _create_remote_llm_stage delegates correctly.""" + def test_initialize_diffusion_replica_local_failure_terminates_proc(self, mocker: MockerFixture): + import vllm_omni.engine.async_omni_engine as engine_mod + from vllm_omni.platforms import current_omni_platform - def _engine(self, mocker: MockerFixture) -> AsyncOmniEngine: engine = object.__new__(AsyncOmniEngine) engine.model = "fake-model" engine.single_stage_mode = True - engine._single_stage_id_filter = 0 + engine.diffusion_batch_size = 4 + engine.stage_configs = [] engine._omni_master_server = mocker.Mock(spec=OmniMasterServer) - engine._omni_master_server.get_zmq_addresses.return_value = mocker.Mock() - engine._omni_master_server.get_allocation.return_value = mocker.Mock() - engine._omni_master_server.get_stage_config.return_value = { - "stage_id": 0, - "stage_type": "llm", - "engine_args": {}, - } - return engine - - def _mock_build_and_connect(self, mocker: MockerFixture, stage_id: int): - fake_vllm_config = mocker.Mock() - fake_executor_cls = mocker.Mock() - fake_addresses = mocker.Mock() - fake_addresses.inputs = ["tcp://127.0.0.1:5000"] - fake_addresses.outputs = ["tcp://127.0.0.1:5001"] - fake_addresses.frontend_stats_publish_address = None - - eng_mgr = mocker.Mock() - coordinator = mocker.Mock() + engine._omni_master_server.address = "127.0.0.1" + engine._omni_master_server.port = 25000 - @contextmanager - def fake_connect_cm(*args, **kwargs): - yield eng_mgr, coordinator, fake_addresses - - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_engine_args_dict", - return_value={"model": "fake", "stage_id": stage_id}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_vllm_config", - return_value=(fake_vllm_config, fake_executor_cls), - ) - mock_connect = mocker.patch( - "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores", - return_value=fake_connect_cm(), - ) - - return mock_connect, fake_vllm_config, fake_executor_cls, fake_addresses - - def test_returns_started_llm_stage_with_correct_stage_id(self, mocker: MockerFixture): - engine = self._engine(mocker) - stage_cfg = _make_stage_cfg(1) - metadata = mocker.Mock(stage_id=1) - omni_ms = engine._omni_master_server - omni_ms.get_stage_config.return_value = { - "stage_id": 1, - "stage_type": "llm", - "engine_args": {}, - } - - self._mock_build_and_connect(mocker, 1) - result = engine._create_remote_llm_stage( - stage_cfg=stage_cfg, - metadata=metadata, - stage_connector_spec={}, - stage_init_timeout=60, - omni_master_server=omni_ms, - ) - assert isinstance(result, StartedLlmStage) - assert result.stage_id == 1 - - def test_connect_remote_engine_cores_called_with_stage_id(self, mocker: MockerFixture): - engine = self._engine(mocker) - stage_cfg = _make_stage_cfg(2) - metadata = mocker.Mock(stage_id=2) - omni_ms = engine._omni_master_server - omni_ms.get_zmq_addresses.return_value = mocker.Mock(inputs=["x"], outputs=["y"]) - omni_ms.get_stage_config.return_value = { - "stage_id": 2, - "stage_type": "llm", - "engine_args": {}, - } - - fake_vllm_config = mocker.Mock() - fake_executor_cls = mocker.Mock() - fake_addresses = mocker.Mock() - fake_addresses.inputs = ["tcp://127.0.0.1:5000"] - fake_addresses.outputs = ["tcp://127.0.0.1:5001"] - fake_addresses.frontend_stats_publish_address = None + plan = _make_diffusion_plan(0, configured_stage_id=5, launch_mode="local").replicas[0] + proc = mocker.Mock() - @contextmanager - def fake_connect_cm(*args, **kwargs): - yield mocker.Mock(), mocker.Mock(), fake_addresses + device_env_var = current_omni_platform.device_control_env_var + prev_device_env = os.environ.get(device_env_var) + os.environ[device_env_var] = "0" - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_engine_args_dict", - return_value={"model": "fake", "stage_id": 2}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_vllm_config", - return_value=(fake_vllm_config, fake_executor_cls), - ) - mock_connect = mocker.patch( - "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores", - return_value=fake_connect_cm(), + mocker.patch.object(engine_mod, "setup_stage_devices") + mocker.patch.object(engine_mod, "inject_kv_stage_info") + mocker.patch.object(engine_mod, "build_diffusion_config", return_value="diffusion-config") + mocker.patch.object( + engine_mod, + "register_stage_with_omni_master", + return_value=("tcp://hs", "tcp://req", "tcp://resp"), ) - - engine._create_remote_llm_stage( - stage_cfg=stage_cfg, - metadata=metadata, - stage_connector_spec={}, - stage_init_timeout=60, - omni_master_server=omni_ms, + mocker.patch.object( + engine_mod, + "spawn_diffusion_proc", + return_value=(proc, None, None, None), ) + mocker.patch.object(engine_mod, "complete_diffusion_handshake", side_effect=RuntimeError("handshake failed")) + mock_terminate = mocker.patch.object(engine_mod, "terminate_alive_proc") - mock_connect.assert_called_once() - _, kwargs = mock_connect.call_args - assert kwargs.get("stage_id") == 2 or mock_connect.call_args.args[-1] == 2 - omni_ms.get_stage_config.assert_called_once_with(2, timeout_s=60) - - def test_missing_registered_stage_config_raises_value_error(self, mocker: MockerFixture): - engine = self._engine(mocker) - stage_cfg = _make_stage_cfg(3) - metadata = mocker.Mock(stage_id=3) - omni_ms = engine._omni_master_server - omni_ms.get_stage_config.return_value = None - - mock_build_args = mocker.patch("vllm_omni.engine.async_omni_engine.build_engine_args_dict") - with pytest.raises( - ValueError, - match="Remote stage 3 registered without stage config", - ): - engine._create_remote_llm_stage( - stage_cfg=stage_cfg, - metadata=metadata, - stage_connector_spec={}, - stage_init_timeout=60, - omni_master_server=omni_ms, - ) + try: + with pytest.raises(RuntimeError, match="handshake failed"): + engine._initialize_diffusion_replica(plan, stage_init_timeout=60, stage_launch_lock=threading.Lock()) + finally: + if prev_device_env is None: + os.environ.pop(device_env_var, None) + else: + os.environ[device_env_var] = prev_device_env - mock_build_args.assert_not_called() - - def test_exception_during_connect_closes_started_stage(self, mocker: MockerFixture): - """If an error occurs after StartedLlmStage creation, close_started_llm_stage is called.""" - engine = self._engine(mocker) - stage_cfg = _make_stage_cfg(1) - metadata = mocker.Mock(stage_id=1) - omni_ms = engine._omni_master_server - omni_ms.get_stage_config.return_value = { - "stage_id": 1, - "stage_type": "llm", - "engine_args": {}, - } + mock_terminate.assert_called_once_with(proc) - @contextmanager - def boom(*args, **kwargs): - yield mocker.Mock(), mocker.Mock(), mocker.Mock() - raise RuntimeError("handshake failed") - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_engine_args_dict", - return_value={"model": "fake", "stage_id": 1}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_vllm_config", - return_value=(mocker.Mock(), mocker.Mock()), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.connect_remote_engine_cores", - return_value=boom(), - ) - mock_close = mocker.patch("vllm_omni.engine.async_omni_engine.close_started_llm_stage") - with pytest.raises(RuntimeError, match="handshake failed"): - engine._create_remote_llm_stage( - stage_cfg=stage_cfg, - metadata=metadata, - stage_connector_spec={}, - stage_init_timeout=60, - omni_master_server=omni_ms, - ) - mock_close.assert_called_once() +# --------------------------------------------------------------------------- +# Stage engine startup helpers +# --------------------------------------------------------------------------- class TestConnectRemoteEngineCoresCoordinator: - """Test coordinator launch parity with launch_core_engines.""" - @staticmethod def _build_vllm_config( mocker: MockerFixture, *, dp_rank: int = 0, offline_mode: bool = False, needs_dp_coordinator: bool = True @@ -1347,7 +896,8 @@ def test_uses_registered_coordinator_addresses(self, mocker: MockerFixture): omni_master_server = mocker.Mock(spec=OmniMasterServer) omni_master_server.get_zmq_addresses.return_value = EngineZmqAddresses( - inputs=["tcp://client-in"], outputs=["tcp://client-out"] + inputs=["tcp://client-in"], + outputs=["tcp://client-out"], ) omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001") omni_master_server.get_stage_coordinator_addresses.return_value = StageCoordinatorAddresses( @@ -1360,10 +910,7 @@ def test_uses_registered_coordinator_addresses(self, mocker: MockerFixture): def fake_socket_ctx(*args, **kwargs): yield mocker.Mock() - mocker.patch( - "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", - return_value=fake_socket_ctx(), - ) + mocker.patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx()) mock_wait = mocker.patch("vllm_omni.engine.stage_engine_startup._wait_for_omni_engine_startup") with connect_remote_engine_cores( vllm_config=vllm_config, @@ -1379,16 +926,12 @@ def fake_socket_ctx(*args, **kwargs): mock_wait.assert_called_once() def test_defaults_to_no_coordinator_addresses_when_none_registered(self, mocker: MockerFixture): - vllm_config = self._build_vllm_config( - mocker, - dp_rank=0, - offline_mode=False, - needs_dp_coordinator=True, - ) + vllm_config = self._build_vllm_config(mocker, dp_rank=0, offline_mode=False, needs_dp_coordinator=True) omni_master_server = mocker.Mock(spec=OmniMasterServer) omni_master_server.get_zmq_addresses.return_value = EngineZmqAddresses( - inputs=["tcp://client-in"], outputs=["tcp://client-out"] + inputs=["tcp://client-in"], + outputs=["tcp://client-out"], ) omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001") omni_master_server.get_stage_coordinator_addresses.return_value = StageCoordinatorAddresses() @@ -1397,10 +940,7 @@ def test_defaults_to_no_coordinator_addresses_when_none_registered(self, mocker: def fake_socket_ctx(*args, **kwargs): yield mocker.Mock() - mocker.patch( - "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", - return_value=fake_socket_ctx(), - ) + mocker.patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx()) mocker.patch("vllm_omni.engine.stage_engine_startup._wait_for_omni_engine_startup") with connect_remote_engine_cores( vllm_config=vllm_config, @@ -1414,8 +954,6 @@ def fake_socket_ctx(*args, **kwargs): class TestLaunchOmniCoreEngines: - """Tests for local omni engine launch wiring.""" - def test_registers_stage_once_and_reuses_handshake_for_all_local_engines(self, mocker: MockerFixture): parallel_config = mocker.Mock( data_parallel_size_local=2, @@ -1440,10 +978,7 @@ def fake_socket_ctx(*args, **kwargs): "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master", return_value="tcp://127.0.0.1:26001", ) - mocker.patch( - "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", - return_value=fake_socket_ctx(), - ) + mocker.patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx()) mock_manager_cls = mocker.patch( "vllm_omni.engine.stage_engine_startup.CoreEngineProcManager", return_value=local_engine_manager, @@ -1459,6 +994,7 @@ def fake_socket_ctx(*args, **kwargs): ) as (yielded_manager, yielded_coordinator, yielded_addresses): assert yielded_manager is local_engine_manager assert yielded_coordinator is None + assert yielded_addresses is not None mock_register.assert_called_once_with( omni_master_address="127.0.0.1", @@ -1467,15 +1003,11 @@ def fake_socket_ctx(*args, **kwargs): omni_stage_config=stage_config, coordinator=None, ) - mock_manager_cls.assert_called_once() manager_kwargs = mock_manager_cls.call_args.kwargs assert manager_kwargs["local_engine_count"] == 2 assert manager_kwargs["start_index"] == 3 assert manager_kwargs["local_start_index"] == 0 - assert manager_kwargs["vllm_config"] is vllm_config - assert manager_kwargs["local_client"] is True assert manager_kwargs["handshake_address"] == "tcp://127.0.0.1:26001" - assert manager_kwargs["executor_class"] is not None def test_registers_stage_with_coordinator_when_started(self, mocker: MockerFixture): parallel_config = mocker.Mock( @@ -1483,15 +1015,19 @@ def test_registers_stage_with_coordinator_when_started(self, mocker: MockerFixtu data_parallel_size=2, data_parallel_rank=0, ) - vllm_config = mocker.Mock(parallel_config=parallel_config) - vllm_config.needs_dp_coordinator = True - vllm_config.model_config = mocker.Mock(is_moe=False) + vllm_config = mocker.Mock( + parallel_config=parallel_config, + needs_dp_coordinator=True, + model_config=mocker.Mock(is_moe=False), + cache_config=mocker.Mock(), + ) omni_master_server = mocker.Mock(spec=OmniMasterServer) omni_master_server.address = "127.0.0.1" omni_master_server.port = 26000 omni_master_server.get_zmq_addresses.return_value = EngineZmqAddresses( - inputs=["tcp://client-in"], outputs=["tcp://client-out"] + inputs=["tcp://client-in"], + outputs=["tcp://client-out"], ) omni_master_server.get_allocation.return_value = mocker.Mock(handshake_bind_address="tcp://127.0.0.1:26001") @@ -1509,11 +1045,8 @@ def fake_socket_ctx(*args, **kwargs): "vllm_omni.engine.stage_engine_startup.register_stage_with_omni_master", return_value="tcp://127.0.0.1:26001", ) + mocker.patch("vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", return_value=fake_socket_ctx()) mocker.patch( - "vllm_omni.engine.stage_engine_startup.zmq_socket_ctx", - return_value=fake_socket_ctx(), - ) - mock_manager_cls = mocker.patch( "vllm_omni.engine.stage_engine_startup.CoreEngineProcManager", return_value=mocker.Mock(), ) @@ -1525,8 +1058,11 @@ def fake_socket_ctx(*args, **kwargs): omni_master_server=omni_master_server, stage_id=7, stage_config={"stage_id": 7}, - ): - pass + ) as (_, yielded_coordinator, yielded_addresses): + assert yielded_coordinator is coordinator + assert yielded_addresses.coordinator_input == "tcp://coord-in" + assert yielded_addresses.coordinator_output == "tcp://coord-out" + assert yielded_addresses.frontend_stats_publish_address == "tcp://stats" mock_register.assert_called_once_with( omni_master_address="127.0.0.1", @@ -1535,312 +1071,4 @@ def fake_socket_ctx(*args, **kwargs): omni_stage_config={"stage_id": 7}, coordinator=coordinator, ) - manager_kwargs = mock_manager_cls.call_args.kwargs - assert manager_kwargs["log_stats"] is False mock_wait.assert_called_once() - - -# --------------------------------------------------------------------------- -# AsyncOmniEngine – _launch_llm_stage single_stage_mode codepath -# --------------------------------------------------------------------------- - - -class TestLaunchLlmStageSingleStageMode: - """Test that _launch_llm_stage selects launch_omni_core_engines when - single_stage_mode=True and _omni_master_server is set.""" - - def _build_engine_with_oms(self, mocker: MockerFixture) -> AsyncOmniEngine: - engine = object.__new__(AsyncOmniEngine) - engine.model = "fake-model" - engine.single_stage_mode = True - engine._single_stage_id_filter = 0 - engine._llm_stage_launch_lock = threading.Lock() - engine.stage_configs = [] - mock_oms = mocker.Mock(spec=OmniMasterServer) - mock_oms.address = "127.0.0.1" - mock_oms.port = 25000 - alloc = mocker.Mock() - alloc.handshake_bind_address = "tcp://127.0.0.1:25001" - mock_oms.get_allocation.return_value = alloc - fake_addresses = mocker.Mock() - fake_addresses.inputs = ["tcp://127.0.0.1:5000"] - fake_addresses.outputs = ["tcp://127.0.0.1:5001"] - fake_addresses.frontend_stats_publish_address = None - mock_oms.get_zmq_addresses.return_value = fake_addresses - engine._omni_master_server = mock_oms - return engine - - def _mock_launch_omni(self, mocker: MockerFixture, stage_id: int): - fake_vllm_config = mocker.Mock() - fake_executor_cls = mocker.Mock() - fake_addresses = mocker.Mock() - fake_addresses.inputs = ["tcp://127.0.0.1:5000"] - fake_addresses.outputs = ["tcp://127.0.0.1:5001"] - fake_addresses.frontend_stats_publish_address = None - - eng_mgr = mocker.Mock() - - @contextmanager - def fake_launch_omni(*args, **kwargs): - yield eng_mgr, None, fake_addresses - - mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices") - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_engine_args_dict", - return_value={"model": "fake", "stage_id": stage_id}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_vllm_config", - return_value=(fake_vllm_config, fake_executor_cls), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.acquire_device_locks", - return_value=[], - ) - mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks") - return mocker.patch( - "vllm_omni.engine.async_omni_engine.launch_omni_core_engines", - return_value=fake_launch_omni(), - ) - - def test_launch_omni_core_engines_used_in_single_stage_mode(self, mocker: MockerFixture): - """single_stage_mode + _omni_master_server → launch_omni_core_engines.""" - engine = self._build_engine_with_oms(mocker) - metadata = mocker.Mock(stage_id=0, runtime_cfg={}) - stage_cfg = _make_stage_cfg(0) - - mock_launch_omni = self._mock_launch_omni(mocker, 0) - result = engine._launch_llm_stage( - stage_cfg=stage_cfg, - metadata=metadata, - stage_connector_spec={}, - stage_init_timeout=60, - llm_stage_launch_lock=threading.Lock(), - ) - - mock_launch_omni.assert_called_once() - assert mock_launch_omni.call_args.kwargs["stage_config"] is stage_cfg - assert isinstance(result, StartedLlmStage) - assert result.stage_id == 0 - - def test_spawn_stage_core_used_in_normal_mode(self, mocker: MockerFixture): - """~single_stage_mode → spawn_stage_core + complete_stage_handshake.""" - engine = object.__new__(AsyncOmniEngine) - engine.model = "fake-model" - engine.single_stage_mode = False - engine._omni_master_server = None - engine._llm_stage_launch_lock = threading.Lock() - engine.stage_configs = [] - - fake_vllm_config = mocker.Mock() - fake_executor_cls = mocker.Mock() - fake_addresses = mocker.Mock() - fake_addresses.inputs = ["tcp://127.0.0.1:5000"] - fake_addresses.outputs = ["tcp://127.0.0.1:5001"] - fake_addresses.frontend_stats_publish_address = None - - fake_proc = mocker.Mock() - fake_handshake_address = "ipc:///tmp/fake-handshake" - stage_init_timeout = 60 - - mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices") - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_engine_args_dict", - return_value={"model": "fake", "stage_id": 0}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_vllm_config", - return_value=(fake_vllm_config, fake_executor_cls), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.acquire_device_locks", - return_value=[], - ) - mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks") - mock_spawn = mocker.patch( - "vllm_omni.engine.async_omni_engine.spawn_stage_core", - return_value=(fake_addresses, fake_proc, fake_handshake_address), - ) - mock_handshake = mocker.patch("vllm_omni.engine.async_omni_engine.complete_stage_handshake") - mock_omni = mocker.patch("vllm_omni.engine.async_omni_engine.launch_omni_core_engines") - metadata = mocker.Mock(stage_id=0, runtime_cfg={}) - result = engine._launch_llm_stage( - stage_cfg=_make_stage_cfg(0), - metadata=metadata, - stage_connector_spec={}, - stage_init_timeout=stage_init_timeout, - llm_stage_launch_lock=threading.Lock(), - ) - - mock_spawn.assert_called_once_with( - vllm_config=fake_vllm_config, - executor_class=fake_executor_cls, - log_stats=False, - ) - mock_handshake.assert_called_once_with( - fake_proc, - fake_handshake_address, - fake_addresses, - fake_vllm_config, - stage_init_timeout, - ) - mock_omni.assert_not_called() - assert isinstance(result, StartedLlmStage) - assert result.proc is fake_proc - - def test_launch_omni_passes_stage_id_and_master_server(self, mocker: MockerFixture): - """launch_omni_core_engines receives the correct stage_id and omni_master_server.""" - engine = self._build_engine_with_oms(mocker) - metadata = mocker.Mock(stage_id=0, runtime_cfg={}) - - captured_kwargs: dict[str, Any] = {} - - @contextmanager - def capturing_launch(*args, **kwargs): - captured_kwargs.update(kwargs) - fake_addresses = mocker.Mock() - fake_addresses.inputs = ["tcp://127.0.0.1:5000"] - fake_addresses.outputs = ["tcp://127.0.0.1:5001"] - fake_addresses.frontend_stats_publish_address = None - yield mocker.Mock(), None, fake_addresses - - mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices") - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_engine_args_dict", - return_value={"model": "fake", "stage_id": 0}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_vllm_config", - return_value=(mocker.Mock(), mocker.Mock()), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.acquire_device_locks", - return_value=[], - ) - mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks") - mocker.patch( - "vllm_omni.engine.async_omni_engine.launch_omni_core_engines", - side_effect=capturing_launch, - ) - - engine._launch_llm_stage( - stage_cfg=_make_stage_cfg(0), - metadata=metadata, - stage_connector_spec={}, - stage_init_timeout=60, - llm_stage_launch_lock=threading.Lock(), - ) - - assert captured_kwargs.get("stage_id") == 0 - assert captured_kwargs.get("omni_master_server") is engine._omni_master_server - - def test_launch_omni_context_exits_before_stage_cleanup_on_error(self, mocker: MockerFixture): - """Errors after entering the omni launch context still unwind it first.""" - engine = self._build_engine_with_oms(mocker) - metadata = mocker.Mock(stage_id=0, runtime_cfg={}) - - fake_addresses = mocker.Mock() - fake_addresses.inputs = ["tcp://127.0.0.1:5000"] - fake_addresses.outputs = ["tcp://127.0.0.1:5001"] - fake_addresses.frontend_stats_publish_address = None - - events: list[str] = [] - - @contextmanager - def fake_launch_omni(*args, **kwargs): - try: - yield mocker.Mock(), None, fake_addresses - finally: - events.append("launch_exit") - - mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices") - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_engine_args_dict", - return_value={"model": "fake", "stage_id": 0}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_vllm_config", - return_value=(mocker.Mock(), mocker.Mock()), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.acquire_device_locks", - return_value=[], - ) - mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks") - mocker.patch( - "vllm_omni.engine.async_omni_engine.launch_omni_core_engines", - return_value=fake_launch_omni(), - ) - mocker.patch("vllm_omni.engine.async_omni_engine.logger.info", side_effect=RuntimeError("boom")) - mock_close_stage = mocker.patch( - "vllm_omni.engine.async_omni_engine.close_started_llm_stage", - side_effect=lambda _started: events.append("stage_close"), - ) - with pytest.raises(RuntimeError, match="boom"): - engine._launch_llm_stage( - stage_cfg=_make_stage_cfg(0), - metadata=metadata, - stage_connector_spec={}, - stage_init_timeout=60, - llm_stage_launch_lock=threading.Lock(), - ) - - mock_close_stage.assert_called_once() - assert events == ["launch_exit", "stage_close"] - - def test_base_exception_propagates_without_started_stage_cleanup(self, mocker: MockerFixture): - """BaseException subclasses should bypass the Exception cleanup path.""" - engine = self._build_engine_with_oms(mocker) - metadata = mocker.Mock(stage_id=0, runtime_cfg={}) - - fake_addresses = mocker.Mock() - fake_addresses.inputs = ["tcp://127.0.0.1:5000"] - fake_addresses.outputs = ["tcp://127.0.0.1:5001"] - fake_addresses.frontend_stats_publish_address = None - - events: list[str] = [] - - class FatalLaunchInterrupt(BaseException): - pass - - @contextmanager - def fake_launch_omni(*args, **kwargs): - try: - yield mocker.Mock(), None, fake_addresses - finally: - events.append("launch_exit") - - mocker.patch("vllm_omni.engine.async_omni_engine.setup_stage_devices") - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_engine_args_dict", - return_value={"model": "fake", "stage_id": 0}, - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.build_vllm_config", - return_value=(mocker.Mock(), mocker.Mock()), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.acquire_device_locks", - return_value=[], - ) - mocker.patch("vllm_omni.engine.async_omni_engine.release_device_locks") - mocker.patch( - "vllm_omni.engine.async_omni_engine.launch_omni_core_engines", - return_value=fake_launch_omni(), - ) - mocker.patch( - "vllm_omni.engine.async_omni_engine.logger.info", - side_effect=FatalLaunchInterrupt("stop"), - ) - mock_close_stage = mocker.patch("vllm_omni.engine.async_omni_engine.close_started_llm_stage") - with pytest.raises(FatalLaunchInterrupt, match="stop"): - engine._launch_llm_stage( - stage_cfg=_make_stage_cfg(0), - metadata=metadata, - stage_connector_spec={}, - stage_init_timeout=60, - llm_stage_launch_lock=threading.Lock(), - ) - - mock_close_stage.assert_not_called() - assert events == ["launch_exit"] diff --git a/tests/engine/test_stage_engine_core_client.py b/tests/engine/test_stage_engine_core_client.py new file mode 100644 index 00000000000..dde0927af2d --- /dev/null +++ b/tests/engine/test_stage_engine_core_client.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for StageEngineCoreClient.check_health(). + +Uses object.__new__ to construct a minimal client — check_health only +touches self.resources, self.stage_id, and self._proc. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from vllm.v1.engine.exceptions import EngineDeadError + +from vllm_omni.engine.stage_engine_core_client import StageEngineCoreClient + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _make_client(*, engine_dead=False, proc_alive=True): + client = object.__new__(StageEngineCoreClient) + client.stage_id = 0 + client.resources = SimpleNamespace(engine_dead=engine_dead) + client._proc = MagicMock(is_alive=MagicMock(return_value=proc_alive), exitcode=1) + return client + + +def test_check_health_passes_when_alive(): + client = _make_client(engine_dead=False, proc_alive=True) + client.check_health() # no exception + + +def test_check_health_raises_when_resources_engine_dead(): + client = _make_client(engine_dead=True, proc_alive=True) + with pytest.raises(EngineDeadError, match="engine core is dead"): + client.check_health() + + +def test_check_health_raises_when_proc_not_alive(): + client = _make_client(engine_dead=False, proc_alive=False) + with pytest.raises(EngineDeadError, match="not alive"): + client.check_health() + # Verify it set resources.engine_dead as a side effect + assert client.resources.engine_dead is True diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py index 607b3eaa813..8a373f74d27 100644 --- a/tests/entrypoints/openai_api/test_image_server.py +++ b/tests/entrypoints/openai_api/test_image_server.py @@ -407,6 +407,29 @@ def test_health_endpoint_no_engine(): assert data["status"] == "unhealthy" +def test_health_endpoint_dead_engine(): + """Health returns 503 when the engine raises EngineDeadError.""" + from unittest.mock import AsyncMock + + from fastapi import FastAPI + from vllm.v1.engine.exceptions import EngineDeadError + + from vllm_omni.entrypoints.openai.api_server import router + + app = FastAPI() + app.include_router(router) + + dead_engine = AsyncMock() + dead_engine.check_health = AsyncMock(side_effect=EngineDeadError()) + app.state.engine_client = dead_engine + + client = TestClient(app) + response = client.get("/health") + assert response.status_code == 503 + data = response.json() + assert data["status"] == "unhealthy" + + def test_models_endpoint(test_client): """Test /v1/models endpoint for diffusion mode""" response = test_client.get("/v1/models") diff --git a/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py index 4190b1fbb13..0a9d45a8ccb 100644 --- a/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py +++ b/tests/entrypoints/openai_api/test_serving_chat_sampling_params.py @@ -100,6 +100,10 @@ def mock_request(mocker: MockerFixture): request.stop_token_ids = None request.frequency_penalty = None request.presence_penalty = None + # Must be real Python objects (not MagicMock) so the code's explicit-field + # and extra_body checks work correctly. + request.model_fields_set = set() + request.extra_body = {} return request @@ -150,6 +154,7 @@ def test_preserves_yaml_defaults_when_no_request_params(serving_chat, mock_reque def test_request_temperature_overrides_yaml_default(serving_chat, mock_request): """Test that request temperature overrides YAML default.""" mock_request.temperature = 0.8 + mock_request.model_fields_set = {"temperature"} result = serving_chat._build_sampling_params_list_from_request(mock_request) @@ -162,6 +167,7 @@ def test_request_temperature_overrides_yaml_default(serving_chat, mock_request): def test_request_top_p_overrides_yaml_default(serving_chat, mock_request): """Test that request top_p overrides YAML default.""" mock_request.top_p = 0.95 + mock_request.model_fields_set = {"top_p"} result = serving_chat._build_sampling_params_list_from_request(mock_request) @@ -173,6 +179,7 @@ def test_request_top_p_overrides_yaml_default(serving_chat, mock_request): def test_request_max_tokens_overrides_yaml_default(serving_chat, mock_request): """Test that request max_tokens overrides YAML default.""" mock_request.max_tokens = 100 + mock_request.model_fields_set = {"max_tokens"} result = serving_chat._build_sampling_params_list_from_request(mock_request) @@ -189,6 +196,7 @@ def test_max_tokens_uses_yaml_default_when_not_specified(serving_chat, mock_requ def test_request_seed_overrides_yaml_default(serving_chat, mock_request): """Test that request seed overrides YAML default.""" mock_request.seed = 123 + mock_request.model_fields_set = {"seed"} result = serving_chat._build_sampling_params_list_from_request(mock_request) @@ -200,6 +208,7 @@ def test_request_seed_overrides_yaml_default(serving_chat, mock_request): def test_request_frequency_penalty_overrides(serving_chat, mock_request): """Test that request frequency_penalty is applied.""" mock_request.frequency_penalty = 0.5 + mock_request.model_fields_set = {"frequency_penalty"} result = serving_chat._build_sampling_params_list_from_request(mock_request) @@ -209,6 +218,7 @@ def test_request_frequency_penalty_overrides(serving_chat, mock_request): def test_request_presence_penalty_overrides(serving_chat, mock_request): """Test that request presence_penalty is applied.""" mock_request.presence_penalty = 0.3 + mock_request.model_fields_set = {"presence_penalty"} result = serving_chat._build_sampling_params_list_from_request(mock_request) @@ -235,6 +245,7 @@ def test_multiple_params_override_together(serving_chat, mock_request): mock_request.temperature = 0.7 mock_request.top_p = 0.85 mock_request.seed = 999 + mock_request.model_fields_set = {"max_tokens", "temperature", "top_p", "seed"} result = serving_chat._build_sampling_params_list_from_request(mock_request) @@ -275,6 +286,7 @@ def test_apply_request_overrides_applies_values(serving_chat, mock_request, defa """Test that _apply_request_overrides applies non-None request values.""" mock_request.temperature = 0.8 mock_request.seed = 123 + mock_request.model_fields_set = {"temperature", "seed"} result = serving_chat._apply_request_overrides(default_comprehension_params, mock_request) @@ -304,6 +316,8 @@ def test_apply_overrides_empty_stop_list_preserves_default(serving_chat, mocker) request.stop_token_ids = None request.frequency_penalty = None request.presence_penalty = None + request.model_fields_set = {"stop"} + request.extra_body = {} result = serving_chat._apply_request_overrides(default_params, request) @@ -325,6 +339,8 @@ def test_apply_overrides_nonempty_stop_list_overrides_default(serving_chat, mock request.stop_token_ids = None request.frequency_penalty = None request.presence_penalty = None + request.model_fields_set = {"stop"} + request.extra_body = {} result = serving_chat._apply_request_overrides(default_params, request) @@ -367,6 +383,8 @@ def test_apply_overrides_nonempty_stop_token_ids_overrides_default(serving_chat, request.stop_token_ids = [100] # non-empty list — should override request.frequency_penalty = None request.presence_penalty = None + request.model_fields_set = {"stop_token_ids"} + request.extra_body = {} result = serving_chat._apply_request_overrides(default_params, request) @@ -392,6 +410,8 @@ def test_apply_overrides_mixed_empty_and_nonempty_lists(serving_chat, mocker): request.stop_token_ids = [100, 200] # non-empty — SHOULD override request.frequency_penalty = None request.presence_penalty = None + request.model_fields_set = {"temperature", "stop", "stop_token_ids"} + request.extra_body = {} result = serving_chat._apply_request_overrides(default_params, request) @@ -415,6 +435,8 @@ def test_apply_overrides_none_scalar_still_preserves_default(serving_chat, mocke request.stop_token_ids = None request.frequency_penalty = None request.presence_penalty = None + request.model_fields_set = set() + request.extra_body = {} result = serving_chat._apply_request_overrides(default_params, request) @@ -442,6 +464,8 @@ def test_apply_overrides_both_lists_empty_preserves_defaults(serving_chat, mocke request.stop_token_ids = [] request.frequency_penalty = None request.presence_penalty = None + request.model_fields_set = {"stop", "stop_token_ids"} + request.extra_body = {} result = serving_chat._apply_request_overrides(default_params, request) @@ -511,3 +535,165 @@ def test_get_comprehension_stage_index_raises_when_not_found(mocker: MockerFixtu with pytest.raises(ValueError, match="No comprehension stage"): instance._get_comprehension_stage_index() + + +# ============================================================================= +# Tests for _resolve_height_width_from_extra_body +# ============================================================================= + + +class TestResolveHeightWidth: + def test_explicit_height_width(self): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"height": 512, "width": 768}) + assert h == 512 + assert w == 768 + + def test_size_string(self): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "768x512"}) + assert w == 768 + assert h == 512 + + def test_size_string_uppercase(self): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "768X512"}) + assert w == 768 + assert h == 512 + + def test_size_fallback_when_height_missing(self): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "512x512", "width": 1024}) + # height is None -> size fallback fires and sets BOTH width and height + assert h == 512 + assert w == 512 + + def test_empty_extra_body(self): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({}) + assert h is None + assert w is None + + def test_invalid_size_format_ignored(self): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + h, w = OmniOpenAIServingChat._resolve_height_width_from_extra_body({"size": "invalid"}) + assert h is None + assert w is None + + +# ============================================================================= +# Tests for _apply_request_overrides with GLM-Image (max_tokens computation) +# ============================================================================= + + +class TestApplyRequestOverridesGLMImage: + """Test dynamic max_tokens computation for GLM-Image AR stage.""" + + @pytest.fixture + def glm_serving_chat(self, mock_engine_client, mocker: MockerFixture): + from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat + + instance = object.__new__(OmniOpenAIServingChat) + instance.engine_client = mock_engine_client + # Mock the image extraction to return no reference images (t2i by default) + instance._extract_diffusion_prompt_and_images_from_messages = mocker.MagicMock(return_value=("a cat", [])) + return instance + + @pytest.fixture + def glm_request(self, mocker: MockerFixture): + req = mocker.MagicMock() + req.temperature = None + req.top_p = None + req.top_k = None + req.max_tokens = None + req.min_tokens = None + req.seed = None + req.ignore_eos = None + req.stop = None + req.stop_token_ids = None + req.frequency_penalty = None + req.presence_penalty = None + req.extra_body = {"height": 1024, "width": 1024} + req.model_fields_set = set() + return req + + def test_t2i_computes_max_tokens(self, glm_serving_chat, glm_request, default_comprehension_params): + """t2i mode: max_tokens computed from height/width, no reference images.""" + result = glm_serving_chat._apply_request_overrides(default_comprehension_params, glm_request) + # t2i 1024x1024 = 256 + 1024 + 1 = 1281 + assert result.max_tokens == 1281 + assert result.extra_args["target_h"] == 1024 + assert result.extra_args["target_w"] == 1024 + + def test_i2i_computes_fewer_tokens( + self, glm_serving_chat, glm_request, default_comprehension_params, mocker: MockerFixture + ): + """i2i mode: max_tokens should be smaller than t2i for same dimensions.""" + # Make it detect reference images + glm_serving_chat._extract_diffusion_prompt_and_images_from_messages = mocker.MagicMock( + return_value=("edit this", ["fake_image"]) + ) + + result = glm_serving_chat._apply_request_overrides(default_comprehension_params, glm_request) + # i2i 1024x1024 = 1024 + 1 = 1025 + assert result.max_tokens == 1025 + + def test_dynamic_max_tokens_overrides_user_value(self, glm_serving_chat, glm_request, default_comprehension_params): + """When height/width are provided, dynamic computation overrides user max_tokens.""" + glm_request.max_tokens = 500 + glm_request.model_fields_set = {"max_tokens"} + + result = glm_serving_chat._apply_request_overrides(default_comprehension_params, glm_request) + # Dynamic computation from height/width always wins when present + assert result.max_tokens == 1281 + + def test_no_height_width_preserves_default( + self, glm_serving_chat, mocker: MockerFixture, default_comprehension_params + ): + """When no height/width in extra_body, keep YAML default max_tokens.""" + req = mocker.MagicMock() + req.temperature = None + req.top_p = None + req.top_k = None + req.max_tokens = None + req.min_tokens = None + req.seed = None + req.ignore_eos = None + req.stop = None + req.stop_token_ids = None + req.frequency_penalty = None + req.presence_penalty = None + req.extra_body = {} + req.model_fields_set = set() + + result = glm_serving_chat._apply_request_overrides(default_comprehension_params, req) + assert result.max_tokens == 2048 # YAML default + + def test_size_string_parsed_for_glm_image( + self, glm_serving_chat, mocker: MockerFixture, default_comprehension_params + ): + """'size' in extra_body is parsed as fallback for height/width.""" + req = mocker.MagicMock() + req.temperature = None + req.top_p = None + req.top_k = None + req.max_tokens = None + req.min_tokens = None + req.seed = None + req.ignore_eos = None + req.stop = None + req.stop_token_ids = None + req.frequency_penalty = None + req.presence_penalty = None + req.extra_body = {"size": "512x512"} + req.model_fields_set = set() + + result = glm_serving_chat._apply_request_overrides(default_comprehension_params, req) + # 512x512 t2i = 256 + 256 + 1 = 513 + assert result.max_tokens == 513 diff --git a/tests/entrypoints/openai_api/test_video_server.py b/tests/entrypoints/openai_api/test_video_server.py index 6157d82313e..1cc5616657d 100644 --- a/tests/entrypoints/openai_api/test_video_server.py +++ b/tests/entrypoints/openai_api/test_video_server.py @@ -133,7 +133,6 @@ def _wait_for_status(client: TestClient, video_id: str, status: str, timeout_s: last_payload = None while time.time() < deadline: response = client.get(f"/v1/videos/{video_id}") - assert response.status_code == 200 last_payload = response.json() if last_payload["status"] == status: return last_payload @@ -644,7 +643,7 @@ def test_invalid_lora_returns_400(test_client): assert response.status_code == 200 video_id = response.json()["id"] failed = _wait_for_status(test_client, video_id, VideoGenerationStatus.FAILED.value) - assert failed["error"]["code"] == "HTTPException" + assert failed["error"]["code"] == 400 assert "lora object" in failed["error"]["message"].lower() diff --git a/tests/entrypoints/test_omni_entrypoints.py b/tests/entrypoints/test_omni_entrypoints.py index 3cffcd37df4..289cf2673fb 100644 --- a/tests/entrypoints/test_omni_entrypoints.py +++ b/tests/entrypoints/test_omni_entrypoints.py @@ -4,11 +4,13 @@ from collections.abc import Callable from types import SimpleNamespace from typing import Any +from unittest.mock import MagicMock import pytest from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.omni import Omni @@ -163,6 +165,15 @@ def _patch_engine(monkeypatch: pytest.MonkeyPatch, engine: FakeAsyncOmniEngine) monkeypatch.setattr("vllm_omni.entrypoints.omni_base.omni_snapshot_download", lambda model: model) +def _make_base(): + from vllm_omni.entrypoints.omni_base import OmniBase + + obj = object.__new__(OmniBase) + obj.engine = MagicMock() + obj.request_states = {} + return obj + + def _stage_spec( stage_id: int, *, @@ -687,3 +698,255 @@ def test_omni_forces_final_only_on_llm_stages(monkeypatch: pytest.MonkeyPatch): assert submitted_params[1].output_kind == RequestOutputKind.FINAL_ONLY assert submitted_params[2].output_kind == original_diffusion_output_kind assert len(outputs) == 2 + + +def test_fatal_error_raises_engine_dead(): + base = _make_base() + msg = {"type": "error", "error": "orchestrator crashed", "fatal": True} + + with pytest.raises(EngineDeadError, match="orchestrator crashed"): + base._handle_output_message(msg) + + +def test_non_fatal_error_raises_runtime(): + base = _make_base() + msg = {"type": "error", "error": "something wrong"} + + with pytest.raises(RuntimeError, match="something wrong"): + base._handle_output_message(msg) + + +def test_async_omni_errored_property_alive(): + omni = object.__new__(AsyncOmni) + omni.engine = SimpleNamespace( + is_alive=lambda: True, + stage_clients=[SimpleNamespace(is_comprehension=False)], + ) + + assert omni.errored is False + + +def test_async_omni_errored_property_dead_engine(): + omni = object.__new__(AsyncOmni) + omni.engine = SimpleNamespace( + is_alive=lambda: False, + stage_clients=[SimpleNamespace(is_comprehension=False)], + ) + + assert omni.errored is True + + +def test_async_omni_errored_property_dead_stage(): + omni = object.__new__(AsyncOmni) + dead_stage = SimpleNamespace(is_comprehension=False, _engine_dead=True) + omni.engine = SimpleNamespace( + is_alive=lambda: True, + stage_clients=[dead_stage], + ) + + assert omni.errored is True + + +def _enqueue_stage_error( + engine: FakeAsyncOmniEngine, + msg, + *, + error_text: str, + kill_engine: bool = False, +): + """Enqueue a stage error output, optionally killing the engine.""" + if kill_engine: + engine._alive = False + engine.output_q.put_nowait( + { + "type": "output", + "request_id": msg["request_id"], + "stage_id": 0, + "engine_outputs": SimpleNamespace( + payload="", + finished=True, + images=[], + stage_durations={}, + error=error_text, + ), + "finished": False, + } + ) + + +@pytest.mark.asyncio +async def test_async_omni_propagates_engine_dead_error(monkeypatch: pytest.MonkeyPatch): + """When the engine is dead and an error output arrives, ``generate()`` + must raise ``EngineDeadError`` (not plain ``RuntimeError``).""" + + engine = FakeAsyncOmniEngine( + stage_metadata=THREE_STAGE_META, + on_add_request=lambda eng, msg: _enqueue_stage_error(eng, msg, error_text="worker OOM", kill_engine=True), + ) + _patch_engine(monkeypatch, engine) + + app = AsyncOmni("dummy-model") + try: + with pytest.raises(EngineDeadError, match="worker OOM"): + async for _ in app.generate(prompt="hello", request_id="req-dead"): + pass + finally: + app.shutdown() + + +@pytest.mark.asyncio +async def test_async_omni_propagates_engine_generate_error(monkeypatch: pytest.MonkeyPatch): + """When the engine is alive but a stage error occurs, ``generate()`` + must raise ``EngineGenerateError`` (recoverable, not ``EngineDeadError``).""" + + engine = FakeAsyncOmniEngine( + stage_metadata=THREE_STAGE_META, + on_add_request=lambda eng, msg: _enqueue_stage_error(eng, msg, error_text="diffusion step failed"), + ) + _patch_engine(monkeypatch, engine) + + app = AsyncOmni("dummy-model") + try: + with pytest.raises(EngineGenerateError): + async for _ in app.generate(prompt="hello", request_id="req-recover"): + pass + finally: + app.shutdown() + + +# ───────── OmniBase.check_health() aggregation ───────── + + +def test_check_health_passes_when_all_healthy(): + base = _make_base() + healthy_stage = MagicMock() + healthy_stage.check_health = MagicMock() + base.engine.is_alive.return_value = True + base.engine.stage_clients = [healthy_stage] + base.check_health() # should not raise + + +def test_check_health_raises_when_stage_dead(): + base = _make_base() + dead_stage = MagicMock() + dead_stage.check_health = MagicMock(side_effect=EngineDeadError("Stage-1 dead")) + base.engine.is_alive.return_value = True + base.engine.stage_clients = [dead_stage] + with pytest.raises(EngineDeadError, match="Stage-1 dead"): + base.check_health() + + +def test_check_health_raises_when_orchestrator_dead(): + base = _make_base() + base.engine.is_alive.return_value = False + base.engine.stage_clients = [] + with pytest.raises(EngineDeadError, match="not alive"): + base.check_health() + + +# ───────── OmniBase.errored property ───────── + + +def test_omni_base_errored_false_when_alive(): + base = _make_base() + base.engine.is_alive.return_value = True + base.engine.stage_clients = [SimpleNamespace()] + assert base.errored is False + + +def test_omni_base_errored_true_when_orchestrator_dead(): + base = _make_base() + base.engine.is_alive.return_value = False + base.engine.stage_clients = [] + assert base.errored is True + + +def test_omni_base_errored_true_when_stage_engine_dead(): + base = _make_base() + base.engine.is_alive.return_value = True + dead_stage = SimpleNamespace(_engine_dead=True) + base.engine.stage_clients = [dead_stage] + assert base.errored is True + + +def test_omni_base_errored_true_when_stage_resources_engine_dead(): + base = _make_base() + base.engine.is_alive.return_value = True + dead_stage = SimpleNamespace(resources=SimpleNamespace(engine_dead=True)) + base.engine.stage_clients = [dead_stage] + assert base.errored is True + + +# ───────── Omni (sync) EngineDeadError / EngineGenerateError ───────── + + +def test_omni_propagates_engine_dead_error(monkeypatch: pytest.MonkeyPatch): + """When the engine is dead and a stage error output arrives, + ``Omni.generate()`` must raise ``EngineDeadError``.""" + engine = FakeAsyncOmniEngine( + stage_metadata=THREE_STAGE_META, + on_add_request=lambda eng, msg: _enqueue_stage_error(eng, msg, error_text="worker OOM", kill_engine=True), + ) + _patch_engine(monkeypatch, engine) + + app = Omni("dummy-model") + try: + with pytest.raises(EngineDeadError, match="worker OOM"): + list(app.generate(["hello"], py_generator=False, use_tqdm=False)) + finally: + app.shutdown() + + +def test_omni_propagates_engine_generate_error(monkeypatch: pytest.MonkeyPatch): + """When the engine is alive but a stage error occurs, + ``Omni.generate()`` must raise ``EngineGenerateError`` (recoverable).""" + engine = FakeAsyncOmniEngine( + stage_metadata=THREE_STAGE_META, + on_add_request=lambda eng, msg: _enqueue_stage_error(eng, msg, error_text="diffusion step failed"), + ) + _patch_engine(monkeypatch, engine) + + app = Omni("dummy-model") + try: + with pytest.raises(EngineGenerateError): + list(app.generate(["hello"], py_generator=False, use_tqdm=False)) + finally: + app.shutdown() + + +def test_omni_errored_property_alive(monkeypatch: pytest.MonkeyPatch): + """Omni.errored (inherited from OmniBase) returns False when healthy.""" + engine = FakeAsyncOmniEngine(stage_metadata=THREE_STAGE_META) + _patch_engine(monkeypatch, engine) + + app = Omni("dummy-model") + try: + assert app.errored is False + finally: + app.shutdown() + + +def test_omni_errored_property_dead_engine(monkeypatch: pytest.MonkeyPatch): + """Omni.errored returns True when the orchestrator is dead.""" + engine = FakeAsyncOmniEngine(stage_metadata=THREE_STAGE_META) + _patch_engine(monkeypatch, engine) + + app = Omni("dummy-model") + try: + engine._alive = False + assert app.errored is True + finally: + app.shutdown() + + +def test_omni_errored_property_dead_stage(monkeypatch: pytest.MonkeyPatch): + """Omni.errored returns True when a stage client is marked dead.""" + engine = FakeAsyncOmniEngine(stage_metadata=THREE_STAGE_META) + _patch_engine(monkeypatch, engine) + + app = Omni("dummy-model") + try: + engine.stage_clients[0]._engine_dead = True + assert app.errored is True + finally: + app.shutdown() diff --git a/tests/entrypoints/test_omni_sleep_mode.py b/tests/entrypoints/test_omni_sleep_mode.py new file mode 100644 index 00000000000..aa7be1ba0f7 --- /dev/null +++ b/tests/entrypoints/test_omni_sleep_mode.py @@ -0,0 +1,336 @@ +import asyncio +import logging +import os + +import pytest +import torch +from vllm import SamplingParams + +from tests.helpers.mark import hardware_test +from vllm_omni.entrypoints.async_omni import AsyncOmni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.platforms import current_omni_platform + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("OmniTest") +pytestmark = [pytest.mark.advanced_model] + + +def clean_gpu_envs(): + """clean up GPU environment variables to ensure tests run on all available devices.""" + device_visibility_vars = [ + "CUDA_VISIBLE_DEVICES", # NVIDIA + "HIP_VISIBLE_DEVICES", # AMD ROCm + "ZE_AFFINITY_MASK", # Intel XPU + "ONEAPI_DEVICE_SELECTOR", # Intel OneAPI + "ASCEND_RT_VISIBLE_DEVICES", # Huawei NPU (CAN) + ] + for key in device_visibility_vars: + os.environ.pop(key, None) + + +def get_vram_info(device_id: int) -> dict: + """Obtain a snapshot of the specified GPU's memory (GiB).""" + try: + if current_omni_platform.is_rocm(): + num_gpus = torch.cuda.device_count() + safe_id = device_id if device_id < num_gpus else 0 + torch.cuda.synchronize(safe_id) + return { + "reserved": torch.cuda.memory_reserved(safe_id) / 1024**3, + "allocated": torch.cuda.memory_allocated(safe_id) / 1024**3, + } + else: + with torch.cuda.device(device_id): + torch.cuda.synchronize() + return { + "reserved": torch.cuda.memory_reserved() / 1024**3, + "allocated": torch.cuda.memory_allocated() / 1024**3, + } + except Exception as e: + logger.warning(f"memory skip ({device_id}): {e}") + return {"reserved": 0.0, "allocated": 0.0} + + +def get_ack_info(ack, key, default=None): + """ + Since ACKs in a distributed environment can be either objects or dictionaries, + this tool ensures compatibility. + """ + if hasattr(ack, key): + return getattr(ack, key) + if isinstance(ack, dict): + return ack.get(key, default) + return default + + +@pytest.fixture(scope="function") +async def llm_engine(): + if current_omni_platform.is_rocm(): + clean_gpu_envs() + model_name = "ByteDance-Seed/BAGEL-7B-MoT" + common_args = { + "worker_type": "ar", + "enable_sleep_mode": True, + "dtype": "bfloat16", + "trust_remote_code": True, + "max_model_len": 2048, + "max_num_batched_tokens": 8192, + "enforce_eager": True, + } + stages = [ + { + "stage_id": 0, + "stage_type": "llm", + "runtime": {"process": True, "devices": "0", "max_batch_size": 1}, + "engine_args": {**common_args, "model_stage": "thinker", "gpu_memory_utilization": 0.1}, + }, + { + "stage_id": 1, + "stage_type": "llm", + "engine_input_source": [0], + "runtime": {"process": True, "devices": "1", "max_batch_size": 1, "connector_type": "queue"}, + "engine_args": {**common_args, "model_stage": "talker", "gpu_memory_utilization": 0.1}, + }, + ] + connectors = [{"src_stage_id": 0, "dst_stage_id": 1, "connector_type": "queue"}] + engine = AsyncOmni(model=model_name, stages=stages, connectors=connectors, init_timeout=600, enable_sleep_mode=True) + yield engine + engine.shutdown() + + +@pytest.fixture(scope="function") +async def diffusion_engine(): + if current_omni_platform.is_rocm(): + clean_gpu_envs() + model_name = "ByteDance-Seed/BAGEL-7B-MoT" + stages = [ + { + "stage_id": 0, + "stage_type": "diffusion", + "runtime": {"process": True, "devices": "0,1", "max_batch_size": 1}, + "engine_args": { + "model_stage": "base", + "gpu_memory_utilization": 0.1, + "model_class_name": "BagelPipeline", + "enable_sleep_mode": True, + "enforce_eager": True, + "max_num_batched_tokens": 8192, + "parallel_config": { + "tensor_parallel_size": 2, + }, + }, + "final_output": True, + "final_output_type": "image", + } + ] + engine = AsyncOmni(model=model_name, stages=stages, init_timeout=600, enable_sleep_mode=True) + yield engine + engine.shutdown() + + +class TestOmniSleepMode: + @pytest.mark.asyncio + @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=1) + async def test_llm_sleep_ack(self, llm_engine: AsyncOmni): + """LLM Thinker (GPU0) Signal and Physical Recycling Audit""" + try: + acks = await llm_engine.sleep(stage_ids=[0], level=2) + # Verification signal successful + assert all(get_ack_info(ack, "status") == "SUCCESS" for ack in acks) + # Verify physical recycling volume + total_freed_bytes = sum(get_ack_info(ack, "freed_bytes", 0) for ack in acks) + freed_gib = total_freed_bytes / 1024**3 + logger.info(f"Thinker VRAM physically reclaimed: {freed_gib:.2f} GiB") + assert freed_gib > 5.0 + finally: + llm_engine.shutdown() + + @pytest.mark.asyncio + @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2) + async def test_diffusion_sleep_handshake(self, diffusion_engine: AsyncOmni): + """Diffusion Worker stage signal loop""" + try: + logger.info("Starting Diffusion Worker Handshake Test") + acks = await diffusion_engine.sleep(stage_ids=[0], level=2) + + def _get_status(ack): + return ack.status if hasattr(ack, "status") else ack.get("status") + + assert len(acks) >= 1, "Expected at least 1 ACK from Diffusion Workers" + assert all(_get_status(ack) == "SUCCESS" for ack in acks) + logger.info(f"Success: Received {len(acks)} Diffusion Worker ACKs") + logger.info("Testing auto-wakeup before test end...") + await diffusion_engine.wake_up(stage_ids=[0]) + logger.info("Test logic finished, triggering manual shutdown...") + finally: + diffusion_engine.shutdown() + logger.info("Manual shutdown executed. Test should exit now.") + + @pytest.mark.asyncio + @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2) + async def test_cross_device_cleanup(self, diffusion_engine: AsyncOmni): + """Physical recycling audit: leveraging deterministic data returned by Workers""" + try: + acks = await diffusion_engine.sleep(stage_ids=[0], level=1) + # Sum up the release amounts reported by all Workers. + total_freed_bytes = sum(get_ack_info(ack, "freed_bytes", 0) for ack in acks) + freed_gb = total_freed_bytes / 1024**3 + logger.info("Physical reclamation summary from workers:") + logger.info(f"- Total Workers: {len(acks)}") + logger.info(f"- Total Freed: {freed_gb:.2f} GiB") + assert freed_gb > 14.0 + logger.info("SUCCESS: 100% weights offloaded.") + finally: + diffusion_engine.shutdown() + + @pytest.mark.asyncio + @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2) + async def test_diffusion_integrity_bit_level(self, diffusion_engine: AsyncOmni): + """Bit-level consistency after Diffusion wake-up (prevent image corruption)""" + try: + prompt = "A huge swimming pool, with many people swimming." + sp = OmniDiffusionSamplingParams(num_inference_steps=4, height=512, width=512, seed=42) + llm_sp = SamplingParams() + + # Baseline Generation + logger.info("Running Baseline Generation...") + base_output = None + async for output in diffusion_engine.generate(prompt, request_id="base", sampling_params_list=[llm_sp, sp]): + base_output = output + assert base_output is not None and len(base_output.images) > 0 + logger.info("Baseline Generation successful.") + # Sleep Level 2 + logger.info("Entering Deep Sleep (VRAM Scavenging)...") + await diffusion_engine.sleep(stage_ids=[0], level=2) + # Wake-up + logger.info("Waking up (Reloading Weights)...") + await diffusion_engine.wake_up(stage_ids=[0]) + + await asyncio.sleep(2.0) + import gc + + gc.collect() + + logger.info("Running Post-Wakeup Generation...") + post_output = None + async for output in diffusion_engine.generate(prompt, request_id="post", sampling_params_list=[llm_sp, sp]): + post_output = output + # Assert result consistency + assert post_output is not None + assert len(base_output.images) == len(post_output.images) + assert post_output.images[0] is not None + logger.info("SUCCESS: Diffusion integrity verified after Sleep/Wake cycle.") + except Exception as e: + logger.error(f"Integrity test failed: {e}") + raise e + finally: + logger.info("Triggering mandatory cleanup...") + diffusion_engine.shutdown() + logger.info("Cleanup complete, test exiting.") + + @pytest.mark.asyncio + @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2) + async def test_coordinated_cross_device(self, llm_engine: AsyncOmni, diffusion_engine: AsyncOmni): + """Heterogeneous Coordinated Cleanup Test (Talker and Diffusion on GPU 1)""" + device_id = 1 + try: + logger.info(f"Waking up both engines on GPU {device_id}...") + await llm_engine.wake_up(stage_ids=[1]) + await diffusion_engine.wake_up(stage_ids=[0]) + + get_vram_info(device_id) + torch.cuda.empty_cache() + await asyncio.sleep(2) + + initial_vram = get_vram_info(device_id)["reserved"] + logger.info(f"GPU {device_id} Peak Pressure: {initial_vram:.2f} GiB") + + # coordinated sleep + logger.info("Issuing concurrent SLEEP commands...") + await llm_engine.sleep(stage_ids=[1], level=2) + await asyncio.sleep(1.0) + await diffusion_engine.sleep(stage_ids=[0], level=2) + + await asyncio.sleep(3.0) + torch.cuda.empty_cache() + + final_vram = get_vram_info(device_id)["reserved"] + logger.info(f"GPU {device_id} Final VRAM after coordinated sleep: {final_vram:.2f} GiB") + + assert initial_vram - final_vram > 15.0 or final_vram < 8.0 + logger.info(f"SUCCESS: Heterogeneous VRAM drop verified on GPU {device_id}.") + except Exception as e: + logger.error(f"Coordinated test failed: {e}") + raise e + finally: + logger.info("Triggering mandatory cleanup for both engines...") + llm_engine.shutdown() + diffusion_engine.shutdown() + logger.info("All engines scavenged. Ready for next test.") + + @pytest.mark.asyncio + @hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=2) + async def test_diffusion_vram_lifecycle_audit(self, diffusion_engine: AsyncOmni): + """Diffusion memory loop: Active -> Deep Sleep -> Active -> inference sanity check""" + device_id = 1 + try: + get_vram_info(device_id) + torch.cuda.empty_cache() + vram_initial = get_vram_info(device_id)["reserved"] + logger.info(f"Diffusion Initial VRAM: {vram_initial:.2f} GiB") + + # Sleep + logger.info("Triggering Level 2 Deep Sleep (Full Weight Offloading)...") + acks = await diffusion_engine.sleep(stage_ids=[0], level=2) + + reported_freed_bytes = sum(getattr(ack, "freed_bytes", 0) for ack in acks) + reported_freed_gib = reported_freed_bytes / 1024**3 + logger.info(f"Worker internally reported freed: {reported_freed_gib:.2f} GiB") + + await asyncio.sleep(2) + get_vram_info(device_id) + torch.cuda.empty_cache() + + vram_sleeping = get_vram_info(device_id)["reserved"] + logger.info(f"External VRAM measurement during Sleep: {vram_sleeping:.2f} GiB") + + assert reported_freed_gib > 14.0 or vram_sleeping < 5.0, ( + f"Reclamation failed. Reported: {reported_freed_gib:.2f}G, Measured: {vram_sleeping:.2f}G" + ) + + # wake-up + logger.info("Triggering Wake-up (Reloading weights to GPU)...") + await diffusion_engine.wake_up(stage_ids=[0]) + + await asyncio.sleep(2) + get_vram_info(device_id) + torch.cuda.empty_cache() + vram_restored = get_vram_info(device_id)["reserved"] + logger.info(f"VRAM after Wake-up: {vram_restored:.2f} GiB") + + assert abs(vram_restored - vram_initial) < 3.0, "VRAM failed to restore to initial levels" + + # inference sanity check + logger.info("Running post-lifecycle inference smoke test...") + prompt = "A futuristic lab with glowing lights, high quality." + sp = OmniDiffusionSamplingParams(num_inference_steps=2, height=512, width=512, seed=42) + llm_sp = SamplingParams() + + base_img_found = False + async for output in diffusion_engine.generate( + prompt, request_id="lifecycle-check", sampling_params_list=[llm_sp, sp] + ): + if output.images and output.images[0] is not None: + base_img_found = True + + assert base_img_found, "Inference failed after Wake-up cycle!" + logger.info("SUCCESS: Full Diffusion Lifecycle (Active -> Sleep -> Active -> Generate) audited.") + + except Exception as e: + logger.error(f"Lifecycle audit failed: {e}") + raise e + finally: + logger.info("Cleaning up engine and scavenging processes...") + diffusion_engine.shutdown() + await asyncio.sleep(1) diff --git a/tests/entrypoints/test_utils.py b/tests/entrypoints/test_utils.py index 248629d51df..4252d5e837e 100644 --- a/tests/entrypoints/test_utils.py +++ b/tests/entrypoints/test_utils.py @@ -358,6 +358,48 @@ def test_load_and_resolve_with_kwargs(self): assert len(stage_configs) == 1 assert "dtype" in stage_configs[0]["engine_args"] + def test_stage_configs_path_promotes_new_deploy_yaml_without_expanding_replicas( + self, tmp_path, mocker: MockerFixture + ): + deploy_path = tmp_path / "qwen3_multi.yaml" + deploy_path.write_text( + 'stages:\n - stage_id: 0\n devices: "0"\n - stage_id: 1\n devices: "1,2,3"\n num_replicas: 3\n', + encoding="utf-8", + ) + + returned_stage_configs = [ + create_config({"stage_id": 0, "runtime": {"devices": "0"}, "engine_args": {"model": "dummy"}}), + create_config( + { + "stage_id": 1, + "runtime": {"devices": "1,2,3", "num_replicas": 3}, + "engine_args": {"model": "dummy"}, + } + ), + ] + load_stage_configs = mocker.patch( + "vllm_omni.entrypoints.utils.load_stage_configs_from_model", + return_value=returned_stage_configs, + ) + + config_path, stage_configs = load_and_resolve_stage_configs( + model="dummy-model", + stage_configs_path=str(deploy_path), + kwargs={}, + ) + + load_stage_configs.assert_called_once_with( + "dummy-model", + base_engine_args={}, + deploy_config_path=str(deploy_path), + stage_overrides=None, + cli_explicit_keys=None, + ) + assert config_path == str(deploy_path) + assert len(stage_configs) == 2 + assert stage_configs[1].runtime.num_replicas == 3 + assert stage_configs[1].runtime.devices == "1,2,3" + class TestLoadStageConfigsFromYaml: """Regression tests for stage-config loading and merging.""" diff --git a/tests/helpers/media.py b/tests/helpers/media.py index 3c45c2a9d95..4463acbbbb6 100644 --- a/tests/helpers/media.py +++ b/tests/helpers/media.py @@ -544,7 +544,12 @@ def cosine_similarity_text(text1, text2, n: int = 3): norm2 = sum(b * b for b in vec2) ** 0.5 if norm1 == 0 or norm2 == 0: return 0.0 - return dot_product / (norm1 * norm2) + cosine = dot_product / (norm1 * norm2) + # Down-weight when lengths differ: repeated/hallucinated transcripts stay + # high in bag-of-ngrams cosine (e.g. ABCABCABC vs ABC) but should score low. + len1, len2 = len(text1), len(text2) + length_harmony = (2.0 * min(len1, len2)) / (len1 + len2) + return cosine * length_harmony def _merge_base64_audio_to_segment(base64_list: list[str]): diff --git a/tests/helpers/stage_config.py b/tests/helpers/stage_config.py index 29a80372ecf..8310cedceff 100644 --- a/tests/helpers/stage_config.py +++ b/tests/helpers/stage_config.py @@ -411,6 +411,42 @@ def delete_by_path(config_dict: dict, path: str) -> None: }, }, }, + "qwen3_omni_moe_multi_replicas_4gpu": { + "base_config": "qwen3_omni_moe.yaml", + "async_chunk": True, + "stages": [ + { + "stage_id": 0, + "devices": "0", + "gpu_memory_utilization": 0.85, + "max_num_seqs": 6, + "max_model_len": 32768, + "mm_processor_cache_gb": 0, + "load_format": "dummy", + "default_sampling_params": {"max_tokens": 150, "ignore_eos": False}, + }, + { + "stage_id": 1, + "devices": "1,2,3", + "num_replicas": 3, + "gpu_memory_utilization": 0.6, + "max_num_seqs": 2, + "max_model_len": 32768, + "load_format": "dummy", + "default_sampling_params": {"max_tokens": 1000}, + }, + { + "stage_id": 2, + "devices": "1,2,3", + "num_replicas": 3, + "gpu_memory_utilization": 0.1, + "max_num_seqs": 2, + "max_num_batched_tokens": 65536, + "load_format": "dummy", + "default_sampling_params": {"max_tokens": 2000}, + }, + ], + }, # Single-stage thinker-only topology for the abort test. "qwen2_5_omni_thinker_only": { "async_chunk": False, diff --git a/tests/model_executor/models/glm_image/test_glm_image_ar.py b/tests/model_executor/models/glm_image/test_glm_image_ar.py new file mode 100644 index 00000000000..32a016b2a67 --- /dev/null +++ b/tests/model_executor/models/glm_image/test_glm_image_ar.py @@ -0,0 +1,352 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for GLM-Image AR model: DataParser, processor, and M-RoPE.""" + +import importlib.util +import os +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest +import torch + +# --------------------------------------------------------------------------- +# Load target classes via importlib to avoid requiring transformers.models.glm_image +# (which may not exist in CI). This follows the same pattern as +# tests/model_executor/models/qwen3_tts/test_code_predictor_dtype.py. +# --------------------------------------------------------------------------- + +_BASE = os.path.join( + os.path.dirname(__file__), + os.pardir, + os.pardir, + os.pardir, + os.pardir, + "vllm_omni", + "model_executor", + "models", + "glm_image", +) + + +def _load_module(name: str, filename: str): + path = os.path.abspath(os.path.join(_BASE, filename)) + spec = importlib.util.spec_from_file_location(name, path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +def _build_mock_modules() -> dict[str, object]: + """Build the dict of modules to inject into sys.modules.""" + # Stub transformers.models.glm_image submodules + glm_image_mod = types.ModuleType("transformers.models.glm_image") + glm_config_mod = types.ModuleType("transformers.models.glm_image.configuration_glm_image") + glm_config_mod.GlmImageConfig = type("GlmImageConfig", (), {}) + glm_config_mod.GlmImageTextConfig = type("GlmImageTextConfig", (), {}) + glm_config_mod.GlmImageVisionConfig = type("GlmImageVisionConfig", (), {}) + glm_config_mod.GlmImageVQVAEConfig = type("GlmImageVQVAEConfig", (), {}) + glm_proc_mod = types.ModuleType("transformers.models.glm_image.processing_glm_image") + glm_proc_mod.GlmImageProcessor = type("GlmImageProcessor", (), {}) + + # vllm_omni submodules needed by the import chain + vllm_omni_mod = MagicMock() + vllm_omni_models = types.ModuleType("vllm_omni.model_executor.models") + vllm_omni_glm_image_pkg = types.ModuleType("vllm_omni.model_executor.models.glm_image") + vllm_omni_glm_image_pkg.__path__ = [os.path.abspath(_BASE)] + vllm_omni_output = MagicMock() + + return { + "transformers.models.glm_image": glm_image_mod, + "transformers.models.glm_image.configuration_glm_image": glm_config_mod, + "transformers.models.glm_image.processing_glm_image": glm_proc_mod, + "vllm_omni": vllm_omni_mod, + "vllm_omni.model_executor": types.ModuleType("vllm_omni.model_executor"), + "vllm_omni.model_executor.models": vllm_omni_models, + "vllm_omni.model_executor.models.glm_image": vllm_omni_glm_image_pkg, + "vllm_omni.model_executor.models.output_templates": vllm_omni_output, + } + + +def _load_target_classes(): + """Load the glm_image_ar module with mocked dependencies.""" + mocks = _build_mock_modules() + with patch.dict(sys.modules, mocks): + mod = _load_module( + "vllm_omni.model_executor.models.glm_image.glm_image_ar", + "glm_image_ar.py", + ) + sys.modules["vllm_omni.model_executor.models.glm_image.glm_image_ar"] = mod + return mod + + +_ar_mod = _load_target_classes() + +GlmImageDataParser = _ar_mod.GlmImageDataParser +GlmImageMultiModalProcessor = _ar_mod.GlmImageMultiModalProcessor +GlmImageForConditionalGeneration = _ar_mod.GlmImageForConditionalGeneration +GlmImageRotaryEmbedding = _ar_mod.GlmImageRotaryEmbedding + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +# ============================================================================= +# Helper: Minimal config for testing +# ============================================================================= + + +def _make_hf_config(**overrides): + """Create a minimal GlmImageConfig-like object for testing.""" + defaults = { + "image_token_id": 167855, + "image_start_token_id": 16384, + "image_end_token_id": 16385, + "grid_bos_token_id": None, + "grid_eos_token_id": None, + } + defaults.update(overrides) + from types import SimpleNamespace + + return SimpleNamespace(**defaults) + + +# ============================================================================= +# Tests for GlmImageDataParser +# ============================================================================= + + +class TestGlmImageDataParser: + """Test that img2img key is normalized to image in the data parser.""" + + def test_img2img_normalized_to_image(self): + parser = GlmImageDataParser.__new__(GlmImageDataParser) + parser._expected_hidden_size = 4096 + # The _get_subparsers should include img2img + subparsers = parser._get_subparsers() + assert "img2img" in subparsers + assert subparsers["img2img"] == parser._parse_image_data + + def test_parse_mm_data_normalizes_img2img(self): + parser = GlmImageDataParser.__new__(GlmImageDataParser) + parser._expected_hidden_size = 4096 + # Create a mock for the parent parse_mm_data + original_parse = type(parser).parse_mm_data + + calls = [] + + def mock_parse(mm_data, **kwargs): + calls.append(mm_data) + return MagicMock() + + # Monkey-patch temporarily + type(parser).parse_mm_data = mock_parse + try: + parser.parse_mm_data({"img2img": "fake_image"}) + except Exception: + pass # parse might fail on mock, we just check the normalization + finally: + type(parser).parse_mm_data = original_parse + + # Verify that "img2img" was normalized to "image" + if calls: + assert "image" in calls[0] + assert "img2img" not in calls[0] + + +# ============================================================================= +# Tests for _build_generation_grids +# ============================================================================= + + +class TestBuildGenerationGrids: + """Test M-RoPE grid construction for t2i mode.""" + + @pytest.fixture + def processor(self): + """Create a minimal processor instance with mocked info.""" + proc = object.__new__(GlmImageMultiModalProcessor) + proc.info = MagicMock() + return proc + + def test_1024x1024(self, processor): + kwargs = {"target_h": 1024, "target_w": 1024} + grids = processor._build_generation_grids(kwargs) + # token_h = 32, token_w = 32 + # ratio = 1.0, small_h = 16, small_w = 16 + assert grids.shape == (2, 3) + assert grids[0].tolist() == [1, 32, 32] # large + assert grids[1].tolist() == [1, 16, 16] # small + + def test_512x512(self, processor): + kwargs = {"target_h": 512, "target_w": 512} + grids = processor._build_generation_grids(kwargs) + assert grids.shape == (2, 3) + assert grids[0].tolist() == [1, 16, 16] + # small: ratio=1.0, small_h=int(sqrt(1)*16)=16, small_w=16 + assert grids[1].tolist() == [1, 16, 16] + + def test_non_square(self, processor): + kwargs = {"target_h": 1024, "target_w": 512} + grids = processor._build_generation_grids(kwargs) + # token_h = 32, token_w = 16, ratio = 2.0 + # small_h = int(sqrt(2)*16) = 22, small_w = int(sqrt(0.5)*16) = 11 + assert grids[0].tolist() == [1, 32, 16] + assert grids[1].tolist() == [1, 22, 11] + + def test_defaults_to_1024_when_no_target(self, processor): + kwargs = {} + grids = processor._build_generation_grids(kwargs) + assert grids[0].tolist() == [1, 32, 32] + + def test_height_width_fallback(self, processor): + kwargs = {"height": 512, "width": 512} + grids = processor._build_generation_grids(kwargs) + assert grids[0].tolist() == [1, 16, 16] + + def test_aligned_to_factor(self, processor): + # 1000 not aligned to 32, should be rounded down to 992 + kwargs = {"target_h": 1000, "target_w": 1000} + grids = processor._build_generation_grids(kwargs) + # 1000 // 32 = 31 + assert grids[0].tolist() == [1, 31, 31] + + +# ============================================================================= +# Tests for get_mrope_input_positions +# ============================================================================= + + +class TestGetMropeInputPositions: + """Test M-RoPE position ID computation.""" + + @pytest.fixture + def model(self): + """Create a minimal model instance for M-RoPE testing.""" + model = object.__new__(GlmImageForConditionalGeneration) + model.config = _make_hf_config() + return model + + def test_pure_text(self, model): + """Pure text tokens: all 3 dimensions get same sequential positions.""" + input_tokens = [100, 101, 102, 103] + positions, delta = model.get_mrope_input_positions(input_tokens) + assert positions.shape == (3, 4) + # All three dims should be [0, 1, 2, 3] + for dim in range(3): + assert positions[dim].tolist() == [0, 1, 2, 3] + assert delta == 0 # max(3) + 1 - seq_len(4) = 0 + + def test_t2i_with_target_size(self, model): + """t2i with explicit target_h/target_w: grids built from them.""" + input_tokens = [100, 101, 102, 16384] # text + + kwargs = {"target_h": 256, "target_w": 256} + + positions, delta = model.get_mrope_input_positions(input_tokens, **kwargs) + # 256/32=8 -> grids = [[1,8,8], [1,16,16]] (small uses factor//2=16 base) + # Decode order (reversed): grid[-1]=[1,16,16]=256, grid[-2]=[1,8,8]=64, EOS=1 + total_decode = 256 + 64 + 1 # 321 + assert positions.shape == (3, 4 + total_decode) + # delta = max_position + 1 - seq_len + # Positions advance by max(h,w) per grid: max(16,16)=16, max(8,8)=8 + # max_pos = seq_len(4) + 16 + 8 = 28, then EOS at 28 + # delta = 28 + 1 - 4 = 25 + assert delta == 25 + + def test_t2i_1024_default_grids(self, model): + """t2i with default 1024x1024 grids when no explicit target size.""" + # Prompt ending with image_start_token_id but no image_end_token_id + input_tokens = [100, 101, 16384] + # No target_h/target_w, no mrope_image_grid_thw + # Falls back to token parsing then to default [[1,32,32], [1,16,16]] + positions, delta = model.get_mrope_input_positions(input_tokens) + assert positions.shape[0] == 3 + + def test_i2i_with_mrope_grid(self, model): + """i2i: mrope_image_grid_thw contains source + target grids.""" + # Source image tokens: [16384, 167855*4, 16385] + text + 16384(bos) + source_grid = [1, 2, 2] # 2x2 = 4 image tokens + target_grid = [1, 32, 32] # 32x32 = 1024 tokens + mrope_grid = torch.tensor([source_grid, target_grid], dtype=torch.long) + + # input_tokens: text + + 4*image_token + + + input_tokens = [100, 101, 16384] + [167855] * 4 + [16385, 16384] + + positions, delta = model.get_mrope_input_positions(input_tokens, mrope_image_grid_thw=mrope_grid) + + # 1 source image (num_complete_images=1), 1 target grid (num_decode_grids=1) + # Prefill covers all input tokens + # Decode covers: 32*32 + 1(EOS) = 1025 tokens + assert positions.shape[0] == 3 + + def test_position_delta_non_negative(self, model): + """mrope_position_delta should be non-negative for valid inputs.""" + input_tokens = [100, 16384] + kwargs = {"target_h": 64, "target_w": 64} + positions, delta = model.get_mrope_input_positions(input_tokens, **kwargs) + assert delta >= 0 + + +# ============================================================================= +# Tests for GlmImageRotaryEmbedding._apply_mrope +# ============================================================================= + + +class TestGlmImageRotaryEmbedding: + """Test M-RoPE section interleaving in the rotary embedding.""" + + @pytest.fixture + def rotary_emb(self): + # mrope_section=[8,12,12] sums to 32, so rotary_dim//2 must be >= 32 + # -> head_dim=64 gives rotary_dim=64, rotary_dim//2=32 + return GlmImageRotaryEmbedding(head_dim=64, mrope_section=[8, 12, 12]) + + def test_apply_mrope_shape(self, rotary_emb): + """Output shape matches [num_tokens, rotary_dim // 2].""" + freqs = torch.randn(3, 5, 32) # 3 dims, 5 tokens, rotary_dim//2=32 + result = rotary_emb._apply_mrope(freqs) + assert result.shape == (5, 32) + + def test_apply_mrope_interleaving(self, rotary_emb): + """Verify that M-RoPE correctly interleaves T/H/W sections.""" + # mrope_section = [8, 12, 12] splits dim 32 into 3 chunks: [8, 12, 12] + # chunk 0 (size 8): dim 0 % 3 = 0 (temporal) + # chunk 1 (size 12): dim 1 % 3 = 1 (height) + # chunk 2 (size 12): dim 2 % 3 = 2 (width) + freqs = torch.ones(3, 1, 32) + freqs[0, :, :] = 1.0 # temporal + freqs[1, :, :] = 2.0 # height + freqs[2, :, :] = 3.0 # width + + result = rotary_emb._apply_mrope(freqs) + assert result.shape == (1, 32) + assert (result[0, :8] == 1.0).all() # chunk 0: temporal + assert (result[0, 8:20] == 2.0).all() # chunk 1: height + assert (result[0, 20:32] == 3.0).all() # chunk 2: width + + def test_forward_1d_positions(self, rotary_emb): + """Forward with 1D positions (text-only) produces correct shapes.""" + positions = torch.arange(10) # [10] + q = torch.randn(10, 64) + k = torch.randn(10, 64) + q_out, k_out = rotary_emb(positions, q, k) + assert q_out.shape == (10, 64) + assert k_out.shape == (10, 64) + + def test_forward_3d_positions(self, rotary_emb): + """Forward with 3D M-RoPE positions produces correct shapes.""" + positions = torch.arange(30).reshape(3, 10) # [3, 10] + q = torch.randn(10, 64) + k = torch.randn(10, 64) + q_out, k_out = rotary_emb(positions, q, k) + assert q_out.shape == (10, 64) + assert k_out.shape == (10, 64) + + def test_forward_preserves_dtype(self, rotary_emb): + """Output dtype matches input dtype.""" + positions = torch.arange(5) + q = torch.randn(5, 64, dtype=torch.float32) + k = torch.randn(5, 64, dtype=torch.float32) + q_out, k_out = rotary_emb(positions, q, k) + assert q_out.dtype == torch.float32 + assert k_out.dtype == torch.float32 diff --git a/tests/model_executor/stage_input_processors/test_glm_image.py b/tests/model_executor/stage_input_processors/test_glm_image.py new file mode 100644 index 00000000000..88352cac248 --- /dev/null +++ b/tests/model_executor/stage_input_processors/test_glm_image.py @@ -0,0 +1,389 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for GLM-Image stage input processor.""" + +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.model_executor.stage_input_processors.glm_image import ( + _first_source_image, + _has_source_image, + _parse_generated_tokens, + _upsample_token_ids, + ar2diffusion, + compute_max_tokens, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +# ============================================================================= +# Helpers +# ============================================================================= + + +def _source_output(token_ids: list[int], mm_output: dict | None = None): + """Create a minimal AR output mock.""" + return SimpleNamespace( + outputs=[SimpleNamespace(token_ids=token_ids)], + multimodal_output=mm_output, + ) + + +# ============================================================================= +# Tests for _has_source_image +# ============================================================================= + + +class TestHasSourceImage: + def test_none_input(self): + assert _has_source_image(None) is False + + def test_non_dict_input(self): + assert _has_source_image("not_a_dict") is False + + def test_empty_dict(self): + assert _has_source_image({}) is False + + def test_image_key_present(self): + from PIL import Image + + img = Image.new("RGB", (64, 64)) + assert _has_source_image({"image": img}) is True + + def test_image_key_none(self): + assert _has_source_image({"image": None}) is False + + def test_img2img_key_present(self): + from PIL import Image + + img = Image.new("RGB", (64, 64)) + assert _has_source_image({"img2img": img}) is True + + def test_images_key_list(self): + from PIL import Image + + imgs = [Image.new("RGB", (64, 64))] + assert _has_source_image({"images": imgs}) is True + + def test_images_key_empty_list(self): + assert _has_source_image({"images": []}) is False + + def test_images_key_single(self): + from PIL import Image + + img = Image.new("RGB", (64, 64)) + assert _has_source_image({"images": img}) is True + + +# ============================================================================= +# Tests for _first_source_image +# ============================================================================= + + +class TestFirstSourceImage: + def test_none_input(self): + assert _first_source_image(None) is None + + def test_non_dict_input(self): + assert _first_source_image("not_a_dict") is None + + def test_image_key_single(self): + from PIL import Image + + img = Image.new("RGB", (64, 64)) + assert _first_source_image({"image": img}) is img + + def test_image_key_list(self): + from PIL import Image + + img = Image.new("RGB", (64, 64)) + assert _first_source_image({"image": [img]}) is img + + def test_image_key_empty_list(self): + assert _first_source_image({"image": []}) is None + + def test_img2img_key_single(self): + from PIL import Image + + img = Image.new("RGB", (64, 64)) + assert _first_source_image({"img2img": img}) is img + + def test_images_key_list(self): + from PIL import Image + + imgs = [Image.new("RGB", (64, 64))] + assert _first_source_image({"images": imgs}) is imgs[0] + + def test_images_key_empty_list(self): + assert _first_source_image({"images": []}) is None + + def test_images_key_single_not_list(self): + from PIL import Image + + img = Image.new("RGB", (64, 64)) + assert _first_source_image({"images": img}) is img + + +# ============================================================================= +# Tests for compute_max_tokens +# ============================================================================= + + +class TestComputeMaxTokens: + def test_t2i_1024x1024(self): + # t2i: small_tokens + large_tokens + 1 (EOS) + # token_h = 1024/32 = 32, token_w = 1024/32 = 32 + # large = 32*32 = 1024 + # ratio = 1.0, small_h = sqrt(1)*16 = 16, small_w = sqrt(1)*16 = 16, small = 256 + # total = 256 + 1024 + 1 = 1281 + result = compute_max_tokens(1024, 1024, is_i2i=False) + assert result == 1281 + + def test_i2i_1024x1024(self): + # i2i: large_tokens + 1 (EOS) + # large = 32*32 = 1024, total = 1025 + result = compute_max_tokens(1024, 1024, is_i2i=True) + assert result == 1025 + + def test_t2i_512x512(self): + # token_h = 16, token_w = 16, large = 256 + # ratio = 1.0, small_h = 16, small_w = 16, small = 256 + # total = 256 + 256 + 1 = 513 + result = compute_max_tokens(512, 512, is_i2i=False) + assert result == 513 + + def test_i2i_512x512(self): + # large = 256, total = 257 + result = compute_max_tokens(512, 512, is_i2i=True) + assert result == 257 + + def test_non_square_t2i(self): + # 1024x512: token_h=32, token_w=16, large=512 + # ratio = 32/16 = 2.0 + # small_h = max(1, int(sqrt(2)*16)) = 22, small_w = max(1, int(sqrt(0.5)*16)) = 11 + # small = 22*11 = 242 + # total = 242 + 512 + 1 = 755 + result = compute_max_tokens(1024, 512, is_i2i=False) + assert result == 242 + 512 + 1 + + def test_custom_factor(self): + # factor=16, 512x512: token_h=32, token_w=32, large=1024 + # ratio=1.0, small_h=8, small_w=8, small=64 + # total = 64 + 1024 + 1 = 1089 + result = compute_max_tokens(512, 512, factor=16, is_i2i=False) + assert result == 1089 + + def test_i2i_smaller_than_t2i(self): + t2i = compute_max_tokens(1024, 1024, is_i2i=False) + i2i = compute_max_tokens(1024, 1024, is_i2i=True) + assert i2i < t2i + + +# ============================================================================= +# Tests for _upsample_token_ids +# ============================================================================= + + +class TestUpsampleTokenIds: + def test_2x2_to_4x4(self): + tokens = torch.tensor([1, 2, 3, 4]) + result = _upsample_token_ids(tokens, 2, 2) + assert result.shape == (16,) # 4 * 4 = 16 (2x each dim) + + def test_1x1_to_2x2(self): + tokens = torch.tensor([7]) + result = _upsample_token_ids(tokens, 1, 1) + assert result.shape == (4,) # 2 * 2 + assert (result == 7).all() + + def test_4x4_to_8x8(self): + tokens = torch.arange(16, dtype=torch.long) + result = _upsample_token_ids(tokens, 4, 4) + assert result.shape == (64,) + + def test_preserves_dtype(self): + tokens = torch.tensor([1, 2, 3, 4], dtype=torch.long) + result = _upsample_token_ids(tokens, 2, 2) + assert result.dtype == torch.long + + +# ============================================================================= +# Tests for _parse_generated_tokens +# ============================================================================= + + +class TestParseGeneratedTokens: + def test_t2i_standard(self): + # 1024x1024, t2i: small(256) + large(1024) + EOS + # Generate 256 + 1024 + 1 = 1281 tokens, last is EOS (16385) + large_tokens = list(range(1024)) + small_tokens = list(range(1000, 1256)) + eos = [16385] + token_ids = small_tokens + large_tokens + eos + + prior, h, w = _parse_generated_tokens(token_ids, 1024, 1024, is_i2i=False) + assert h == 1024 + assert w == 1024 + # Prior tokens should be upsampled: 1024 tokens -> 4*1024 = 4096 + assert prior.shape[0] == 1024 * 4 + + def test_i2i_standard(self): + # 1024x1024, i2i: large(1024) + EOS + large_tokens = list(range(1024)) + eos = [16385] + token_ids = large_tokens + eos + + prior, h, w = _parse_generated_tokens(token_ids, 1024, 1024, is_i2i=True) + assert h == 1024 + assert w == 1024 + assert prior.shape[0] == 1024 * 4 + + def test_i2i_without_eos(self): + # i2i without EOS marker + large_tokens = list(range(1024)) + prior, h, w = _parse_generated_tokens(large_tokens, 1024, 1024, is_i2i=True) + assert h == 1024 + assert w == 1024 + + def test_i2i_too_few_tokens_raises(self): + with pytest.raises(ValueError, match="i2i token parse failed"): + _parse_generated_tokens([1, 2, 3], 1024, 1024, is_i2i=True) + + def test_t2i_too_few_tokens_raises(self): + # Only large tokens, no small preview + large_tokens = list(range(1024)) + with pytest.raises(ValueError, match="t2i token parse failed"): + _parse_generated_tokens(large_tokens, 1024, 1024, is_i2i=False) + + def test_i2i_t2i_style_layout_fallback(self): + # i2i but got t2i-style (small + large) tokens + small_tokens = list(range(256)) + large_tokens = list(range(1024)) + token_ids = small_tokens + large_tokens + + prior, h, w = _parse_generated_tokens(token_ids, 1024, 1024, is_i2i=True) + # Should extract the large portion + assert h == 1024 + assert w == 1024 + + +# ============================================================================= +# Tests for ar2diffusion +# ============================================================================= + + +class TestAr2Diffusion: + def test_basic_t2i(self): + """Test basic text-to-image pipeline: AR -> Diffusion.""" + # 1024x1024 t2i: small(256) + large(1024) + EOS + token_ids = list(range(256)) + list(range(1024)) + [16385] + source_outputs = [_source_output(token_ids)] + + prompt = {"prompt": "a cat", "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024}} + + result = ar2diffusion(source_outputs, prompt=[prompt]) + assert len(result) == 1 + assert result[0]["prompt"] == "a cat" + assert result[0]["height"] == 1024 + assert result[0]["width"] == 1024 + assert "prior_token_ids" in result[0]["extra"] + + def test_i2i_with_mm_output(self): + """Test image-to-image with prior_token_image_ids from AR model.""" + token_ids = list(range(1024)) + [16385] + mm_output = {"prior_token_image_ids": torch.tensor([1, 2, 3])} + source_outputs = [_source_output(token_ids, mm_output)] + + from PIL import Image + + img = Image.new("RGB", (64, 64)) + prompt = { + "prompt": "edit this", + "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024}, + "multi_modal_data": {"image": img}, + } + + result = ar2diffusion(source_outputs, prompt=[prompt]) + assert len(result) == 1 + assert result[0]["extra"]["prior_token_image_ids"] is not None + + def test_i2i_detected_via_modalities(self): + """Test i2i mode detected via modalities field.""" + token_ids = list(range(1024)) + [16385] + source_outputs = [_source_output(token_ids)] + + prompt = { + "prompt": "edit this", + "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024}, + "modalities": ["img2img"], + } + + result = ar2diffusion(source_outputs, prompt=[prompt]) + assert len(result) == 1 + + def test_empty_source_outputs_returns_empty_list(self): + assert ar2diffusion([], prompt={}) == [] + + def test_default_dimensions(self): + """When no height/width in prompt, defaults to 1024x1024.""" + token_ids = list(range(256)) + list(range(1024)) + [16385] + source_outputs = [_source_output(token_ids)] + + prompt = {"prompt": "test"} + result = ar2diffusion(source_outputs, prompt=[prompt]) + assert result[0]["height"] == 1024 + assert result[0]["width"] == 1024 + + def test_requires_multimodal_data_with_pil_image(self): + """Test that pil_image is included when requires_multimodal_data=True.""" + token_ids = list(range(256)) + list(range(1024)) + [16385] + source_outputs = [_source_output(token_ids)] + + from PIL import Image + + img = Image.new("RGB", (64, 64)) + prompt = { + "prompt": "test", + "multi_modal_data": {"image": img}, + } + + result = ar2diffusion(source_outputs, prompt=[prompt], requires_multimodal_data=True) + assert result[0]["pil_image"] is img + + def test_extra_params_passed_through(self): + """Test that seed, num_inference_steps, guidance_scale, negative_prompt are passed.""" + token_ids = list(range(256)) + list(range(1024)) + [16385] + source_outputs = [_source_output(token_ids)] + + prompt = { + "prompt": "test", + "seed": 42, + "num_inference_steps": 50, + "guidance_scale": 7.5, + "negative_prompt": "blurry", + } + + result = ar2diffusion(source_outputs, prompt=[prompt]) + assert result[0]["seed"] == 42 + assert result[0]["num_inference_steps"] == 50 + assert result[0]["guidance_scale"] == 7.5 + assert result[0]["negative_prompt"] == "blurry" + + def test_batch_requests(self): + """Test processing multiple requests in a batch.""" + tokens1 = list(range(256)) + list(range(1024)) + [16385] + tokens2 = list(range(256)) + list(range(1024)) + [16385] + source_outputs = [_source_output(tokens1), _source_output(tokens2)] + + prompts = [ + {"prompt": "first", "mm_processor_kwargs": {"target_h": 1024, "target_w": 1024}}, + {"prompt": "second", "mm_processor_kwargs": {"target_h": 512, "target_w": 512}}, + ] + + result = ar2diffusion(source_outputs, prompt=prompts) + assert len(result) == 2 + assert result[0]["prompt"] == "first" + assert result[1]["prompt"] == "second" diff --git a/tests/model_executor/stage_input_processors/test_mimo_audio_llm2code2wav.py b/tests/model_executor/stage_input_processors/test_mimo_audio_llm2code2wav.py new file mode 100644 index 00000000000..1ea0ccfa708 --- /dev/null +++ b/tests/model_executor/stage_input_processors/test_mimo_audio_llm2code2wav.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import logging +from types import SimpleNamespace + +import pytest +import torch + +from vllm_omni.model_executor.stage_input_processors import mimo_audio as sip +from vllm_omni.model_executor.stage_input_processors.mimo_audio import ( + MAX_CODE2WAV_TOKENS, + llm2code2wav, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu] + + +def _make_source_outputs(codec_codes: torch.Tensor, request_id: str = "req-0"): + """Build minimal source_outputs with one talker output carrying codec_codes.""" + output = SimpleNamespace(multimodal_output={"code_predictor_codes": codec_codes}) + talker_output = SimpleNamespace(outputs=[output], request_id=request_id) + return [talker_output] + + +def test_llm2code2wav_truncates_when_flat_exceeds_max(caplog): + """Flat codec sequences longer than MAX_CODE2WAV_TOKENS must be truncated, not passed through.""" + # prepend_and_flatten_colmajor produces 36 ids per (8, 4) codec frame: + # pad adds one row -> (9, 4) per frame, permuted and flattened. + # Pick enough frames to comfortably exceed the cap. + frames = (MAX_CODE2WAV_TOKENS // 36) + 100 + codec_codes = torch.ones(frames, 1, 8, 4, dtype=torch.long) + + source_outputs = _make_source_outputs(codec_codes, request_id="req-long") + + # Attach caplog's handler directly to the module logger so the warning is + # captured regardless of propagation (vllm's logger configuration can + # interact badly with caplog.at_level's default root-handler path). + target_logger = logging.getLogger("vllm_omni.model_executor.stage_input_processors.mimo_audio") + target_logger.addHandler(caplog.handler) + prev_level = target_logger.level + target_logger.setLevel(logging.WARNING) + try: + prompts = llm2code2wav(source_outputs) + finally: + target_logger.removeHandler(caplog.handler) + target_logger.setLevel(prev_level) + + assert len(prompts) == 1 + assert len(prompts[0]["prompt_token_ids"]) == MAX_CODE2WAV_TOKENS + assert any("truncating" in rec.getMessage() for rec in caplog.records), ( + f"Expected a 'truncating' warning; captured records: {[r.getMessage() for r in caplog.records]}" + ) + + +def test_llm2code2wav_short_sequence_unchanged(): + """Short codec sequences are returned without truncation.""" + codec_codes = torch.ones(4, 1, 8, 4, dtype=torch.long) + source_outputs = _make_source_outputs(codec_codes, request_id="req-short") + + prompts = llm2code2wav(source_outputs) + + assert len(prompts) == 1 + # 4 frames + 1 pad row, flattened col-major → well below the cap + assert 0 < len(prompts[0]["prompt_token_ids"]) <= MAX_CODE2WAV_TOKENS + + +def test_llm2code2wav_truncation_boundary_constant_matches_yaml(): + """MAX_CODE2WAV_TOKENS must match the stage-1 max_model_len in mimo_audio.yaml and end2end.py.""" + assert sip.MAX_CODE2WAV_TOKENS == 18192 diff --git a/tests/profile/test_omni_torch_profiler.py b/tests/profile/test_omni_torch_profiler.py new file mode 100644 index 00000000000..3920078af4d --- /dev/null +++ b/tests/profile/test_omni_torch_profiler.py @@ -0,0 +1,582 @@ +# tests/test_omni_torch_profiler.py +from __future__ import annotations + +import gzip +from dataclasses import dataclass +from pathlib import Path +from types import SimpleNamespace + +import pytest +from openpyxl import load_workbook + +import vllm_omni.profiler.omni_torch_profiler as profiler_mod +from vllm_omni.profiler.omni_torch_profiler import OmniTorchProfilerWrapper + + +@pytest.fixture(autouse=True) +def patch_worker_profiler_init(monkeypatch): + def fake_init(self, profiler_config): + self.profiler_config = profiler_config + + monkeypatch.setattr( + profiler_mod.WorkerProfiler, + "__init__", + fake_init, + ) + + +@dataclass +class DummyProfilerConfig: + torch_profiler_dir: str + torch_profiler_use_gzip: bool = False + torch_profiler_record_shapes: bool = True + torch_profiler_with_memory: bool = True + torch_profiler_with_stack: bool = True + torch_profiler_with_flops: bool = False + torch_profiler_dump_cuda_time_total: bool = False + + +class FakeEvent: + def __init__( + self, + *, + name: str = "aten::mm", + count: int = 1, + input_shapes=None, + stack=None, + self_cpu_time_total: float = 10.0, + cpu_time_total: float = 12.0, + self_cuda_time_total: float = 20.0, + cuda_time_total: float = 25.0, + self_xpu_time_total: float = 0.0, + xpu_time_total: float = 0.0, + self_cpu_memory_usage: int = 128, + cpu_memory_usage: int = 256, + self_cuda_memory_usage: int = 1024, + cuda_memory_usage: int = 2048, + self_xpu_memory_usage: int = 0, + xpu_memory_usage: int = 0, + device_type: str = "CUDA", + node_id: int = 0, + overload_name: str = "", + is_async: bool = False, + is_legacy: bool = False, + ): + self.key = name + self.name = name + self.count = count + self.input_shapes = input_shapes if input_shapes is not None else [[2, 2], [2, 2]] + self.stack = stack if stack is not None else ["frame_a", "frame_b"] + self.self_cpu_time_total = self_cpu_time_total + self.cpu_time_total = cpu_time_total + self.self_cuda_time_total = self_cuda_time_total + self.cuda_time_total = cuda_time_total + self.self_xpu_time_total = self_xpu_time_total + self.xpu_time_total = xpu_time_total + self.self_cpu_memory_usage = self_cpu_memory_usage + self.cpu_memory_usage = cpu_memory_usage + self.self_cuda_memory_usage = self_cuda_memory_usage + self.cuda_memory_usage = cuda_memory_usage + self.self_xpu_memory_usage = self_xpu_memory_usage + self.xpu_memory_usage = xpu_memory_usage + self.device_type = device_type + self.node_id = node_id + self.overload_name = overload_name + self.is_async = is_async + self.is_legacy = is_legacy + + +class FakeEventList(list): + def table(self, sort_by=None, row_limit=-1): + return f"fake_table(sort_by={sort_by}, row_limit={row_limit}, len={len(self)})" + + +class FakeTorchProfiler: + def __init__(self, on_trace_ready=None): + self.started = False + self.stopped = False + self.on_trace_ready = on_trace_ready + self.exported_traces = [] + self.exported_stacks = [] + + def start(self): + self.started = True + + def stop(self): + self.stopped = True + if self.on_trace_ready is not None: + self.on_trace_ready(self) + + def export_chrome_trace(self, path): + Path(path).write_text('{"traceEvents": []}') + self.exported_traces.append(path) + + def export_stacks(self, path, metric): + Path(path).write_text(f"metric={metric}\nstack_line_1\nstack_line_2\n") + self.exported_stacks.append((path, metric)) + + def key_averages(self, group_by_input_shape=False, group_by_stack_n=0): + if group_by_input_shape: + return FakeEventList( + [ + FakeEvent( + name="aten::bmm", + input_shapes=[[4, 8, 16], [4, 16, 32]], + ) + ] + ) + if group_by_stack_n: + return FakeEventList( + [ + FakeEvent( + name="aten::all_reduce", + stack=["python_a", "python_b", "python_c"], + ) + ] + ) + return FakeEventList( + [ + FakeEvent(name="aten::mm"), + FakeEvent(name="nccl:all_reduce"), + ] + ) + + +@pytest.fixture +def fake_config(tmp_path): + return DummyProfilerConfig(torch_profiler_dir=str(tmp_path)) + + +@pytest.fixture +def fake_profiler_factory(monkeypatch): + created = {} + + def fake_profile(*args, **kwargs): + profiler = FakeTorchProfiler(on_trace_ready=kwargs.get("on_trace_ready")) + created["profiler"] = profiler + created["args"] = args + created["kwargs"] = kwargs + return profiler + + monkeypatch.setattr(profiler_mod.torch.profiler, "profile", fake_profile) + return created + + +@pytest.fixture +def wrapper(fake_config, fake_profiler_factory): + return OmniTorchProfilerWrapper( + profiler_config=fake_config, + worker_name="worker0", + local_rank=0, + activities=["CPU", "CUDA"], + ) + + +def test_set_trace_filename_creates_timestamped_session_dir(wrapper, monkeypatch, tmp_path): + class FixedDatetime: + @classmethod + def now(cls): + class _Now: + def strftime(self, fmt): + return "20260403-034200" + + return _Now() + + monkeypatch.setattr(profiler_mod, "datetime", FixedDatetime) + + wrapper.set_trace_filename("stage_0_llm_1234567890") + + session_dir = Path(wrapper._session_dir) + assert session_dir.exists() + assert session_dir.parent == tmp_path + assert session_dir.name == "20260403-034200_stage_0_llm_1234567890" + + +def test_set_trace_filename_with_full_path_creates_timestamped_leaf(wrapper, monkeypatch, tmp_path): + class FixedDatetime: + @classmethod + def now(cls): + class _Now: + def strftime(self, fmt): + return "20260403-111111" + + return _Now() + + monkeypatch.setattr(profiler_mod, "datetime", FixedDatetime) + + target = tmp_path / "nested" / "stage_x" + wrapper.set_trace_filename(str(target)) + + session_dir = Path(wrapper._session_dir) + assert session_dir.exists() + assert session_dir.parent == target.parent + assert session_dir.name == "20260403-111111_stage_x" + + +def test_on_trace_ready_exports_trace_json(wrapper): + wrapper.set_trace_filename("case_trace") + + wrapper._on_trace_ready(wrapper.profiler) + + trace_path = Path(wrapper._trace_path) + assert trace_path.exists() + assert trace_path.name == "trace_rank0.json" + assert trace_path.read_text() == '{"traceEvents": []}' + + +def test_on_trace_ready_exports_gzip_trace(fake_config, fake_profiler_factory, monkeypatch): + fake_config.torch_profiler_use_gzip = True + + wrapper = OmniTorchProfilerWrapper( + profiler_config=fake_config, + worker_name="worker0", + local_rank=0, + activities=["CPU", "CUDA"], + ) + wrapper.set_trace_filename("case_gzip") + + def fake_popen(cmd): + assert cmd[:2] == ["gzip", "-f"] + src = Path(cmd[2]) + gz_path = src.with_suffix(src.suffix + ".gz") + gz_path.write_bytes(gzip.compress(src.read_bytes())) + src.unlink() + + class DummyProc: + pass + + return DummyProc() + + monkeypatch.setattr(profiler_mod.subprocess, "Popen", fake_popen) + + wrapper._on_trace_ready(wrapper.profiler) + + assert wrapper._trace_path.endswith(".json.gz") + gz_path = Path(wrapper._trace_path) + assert gz_path.exists() + assert gzip.decompress(gz_path.read_bytes()) == b'{"traceEvents": []}' + + +def test_start_enables_memory_history(wrapper, monkeypatch): + calls = [] + + monkeypatch.setattr(profiler_mod.torch.cuda, "is_available", lambda: True) + + def fake_record_memory_history(*args, **kwargs): + calls.append((args, kwargs)) + + monkeypatch.setattr( + profiler_mod.torch.cuda.memory, + "_record_memory_history", + fake_record_memory_history, + ) + + wrapper.set_trace_filename("case_memory_start") + wrapper._start() + + assert wrapper.profiler.started is True + assert wrapper._memory_history_enabled is True + assert len(calls) == 1 + assert calls[0][1]["enabled"] == "all" + assert calls[0][1]["context"] == "all" + assert calls[0][1]["stacks"] == "python" + assert calls[0][1]["max_entries"] == 100000 + assert calls[0][1]["clear_history"] is True + + +def test_start_skips_memory_history_when_memory_disabled(fake_config, fake_profiler_factory, monkeypatch): + fake_config.torch_profiler_with_memory = False + + wrapper = OmniTorchProfilerWrapper( + profiler_config=fake_config, + worker_name="worker0", + local_rank=0, + activities=["CPU", "CUDA"], + ) + + called = {"n": 0} + + monkeypatch.setattr(profiler_mod.torch.cuda, "is_available", lambda: True) + + def fake_record_memory_history(*args, **kwargs): + called["n"] += 1 + + monkeypatch.setattr( + profiler_mod.torch.cuda.memory, + "_record_memory_history", + fake_record_memory_history, + ) + + wrapper.set_trace_filename("case_skip_memory") + wrapper._start() + + assert called["n"] == 0 + assert wrapper._memory_history_enabled is False + + +def test_try_dump_memory_snapshot_writes_pickle(wrapper, monkeypatch): + wrapper.set_trace_filename("case_snapshot") + wrapper._memory_history_enabled = True + wrapper._memory_history_backend = "CUDA" + wrapper._memory_history_module = profiler_mod.torch.cuda.memory + + disable_calls = [] + + def fake_record_memory_history(*args, **kwargs): + disable_calls.append((args, kwargs)) + + def fake_dump_snapshot(path): + Path(path).write_bytes(b"fake pickle bytes") + + monkeypatch.setattr( + profiler_mod.torch.cuda.memory, + "_record_memory_history", + fake_record_memory_history, + ) + monkeypatch.setattr( + profiler_mod.torch.cuda.memory, + "_dump_snapshot", + fake_dump_snapshot, + ) + + wrapper._try_dump_memory_snapshot() + + snapshot = Path(wrapper._artifact_paths["memory_snapshot"]) + assert snapshot.exists() + assert snapshot.name == "memory_snapshot_rank0.pickle" + assert snapshot.read_bytes() == b"fake pickle bytes" + assert wrapper._memory_history_enabled is False + + assert disable_calls[-1][1]["enabled"] is None + + +def test_stop_always_dumps_memory_snapshot_on_success_path(wrapper, monkeypatch): + wrapper.set_trace_filename("case_stop") + + record_calls = [] + dump_calls = [] + + monkeypatch.setattr(profiler_mod.torch.cuda, "is_available", lambda: True) + + def fake_record_memory_history(*args, **kwargs): + record_calls.append((args, kwargs)) + + def fake_dump_snapshot(path): + dump_calls.append(path) + Path(path).write_bytes(b"snapshot-bytes") + + monkeypatch.setattr( + profiler_mod.torch.cuda.memory, + "_record_memory_history", + fake_record_memory_history, + ) + monkeypatch.setattr( + profiler_mod.torch.cuda.memory, + "_dump_snapshot", + fake_dump_snapshot, + ) + + wrapper._start() + wrapper._stop() + + session_dir = Path(wrapper._session_dir) + + assert wrapper.profiler.started is True + assert wrapper.profiler.stopped is True + assert (session_dir / "memory_snapshot_rank0.pickle").exists() + assert len(dump_calls) == 1 + assert record_calls[0][1]["enabled"] == "all" + assert record_calls[-1][1]["enabled"] is None + + +def test_on_stop_hook_generates_stack_and_excel_artifacts(wrapper): + wrapper.set_trace_filename("case_artifacts") + wrapper._on_stop_hook() + + session_dir = Path(wrapper._session_dir) + + assert not (session_dir / "ops_summary_rank0.txt").exists() + assert not (session_dir / "ops_by_shape_rank0.txt").exists() + assert not (session_dir / "ops_by_stack_rank0.txt").exists() + assert (session_dir / "stacks_cpu_rank0.txt").exists() + assert (session_dir / "stacks_cuda_rank0.txt").exists() + assert (session_dir / "ops_rank0.xlsx").exists() + + +def test_excel_contains_expected_sheets(wrapper): + wrapper.set_trace_filename("case_excel") + wrapper._on_stop_hook() + + xlsx_path = Path(wrapper._session_dir) / "ops_rank0.xlsx" + wb = load_workbook(xlsx_path) + + assert "summary" in wb.sheetnames + assert "by_shape" in wb.sheetnames + assert "by_stack" in wb.sheetnames + + +def test_excel_summary_has_expected_columns(wrapper): + wrapper.set_trace_filename("case_excel_columns") + wrapper._on_stop_hook() + + xlsx_path = Path(wrapper._session_dir) / "ops_rank0.xlsx" + wb = load_workbook(xlsx_path) + ws = wb["summary"] + + headers = [cell.value for cell in next(ws.iter_rows(min_row=1, max_row=1))] + assert "name" in headers + assert "count" in headers + assert "self_cpu_time_total_us" in headers + assert "self_cuda_time_total_us" in headers + assert "self_cpu_memory_usage_bytes" in headers + assert "self_cuda_memory_usage_bytes" in headers + assert "input_shapes" in headers + assert "stack" in headers + + +def test_get_results_returns_all_artifact_paths(wrapper, monkeypatch): + wrapper.set_trace_filename("case_results") + + monkeypatch.setattr(profiler_mod.torch.cuda, "is_available", lambda: True) + monkeypatch.setattr( + profiler_mod.torch.cuda.memory, + "_record_memory_history", + lambda *args, **kwargs: None, + ) + monkeypatch.setattr( + profiler_mod.torch.cuda.memory, + "_dump_snapshot", + lambda path: Path(path).write_bytes(b"snapshot"), + ) + + wrapper._start() + wrapper._stop() + + results = wrapper.get_results() + + assert "trace" in results + assert "table" in results + assert "session_dir" in results + assert "ops" in results + assert "memory_snapshot" in results + assert Path(results["session_dir"]).exists() + assert Path(results["ops"]).exists() + assert Path(results["table"]).exists() + assert Path(results["table"]).name == "ops_rank0.xlsx" + assert Path(results["memory_snapshot"]).exists() + + +def test_start_uses_xpu_memory_history_when_available(wrapper, monkeypatch): + calls = [] + + def fake_record_memory_history(*args, **kwargs): + calls.append((args, kwargs)) + + fake_memory_module = SimpleNamespace( + _record_memory_history=fake_record_memory_history, + ) + monkeypatch.setattr( + wrapper, + "_resolve_memory_history_backend", + lambda: ("XPU", fake_memory_module), + ) + + wrapper.set_trace_filename("case_xpu_memory_start") + wrapper._start() + + assert wrapper._memory_history_enabled is True + assert wrapper._memory_history_backend == "XPU" + assert wrapper._memory_history_module is fake_memory_module + assert calls[0][1]["enabled"] == "all" + + +def test_start_uses_npu_memory_history_when_available(wrapper, monkeypatch): + calls = [] + + def fake_record_memory_history(*args, **kwargs): + calls.append((args, kwargs)) + + fake_memory_module = SimpleNamespace( + _record_memory_history=fake_record_memory_history, + ) + monkeypatch.setattr( + wrapper, + "_resolve_memory_history_backend", + lambda: ("NPU", fake_memory_module), + ) + + wrapper.set_trace_filename("case_npu_memory_start") + wrapper._start() + + assert wrapper._memory_history_enabled is True + assert wrapper._memory_history_backend == "NPU" + assert wrapper._memory_history_module is fake_memory_module + assert calls[0][1]["enabled"] == "all" + + +def test_start_skips_memory_history_when_backend_api_missing(wrapper, monkeypatch): + fake_memory_module = SimpleNamespace() + monkeypatch.setattr( + wrapper, + "_resolve_memory_history_backend", + lambda: ("XPU", fake_memory_module), + ) + + wrapper.set_trace_filename("case_missing_memory_api") + wrapper._start() + + assert wrapper._memory_history_enabled is False + assert wrapper._memory_history_backend is None + assert wrapper._memory_history_module is None + + +def test_try_dump_memory_snapshot_uses_resolved_backend_module(wrapper): + wrapper.set_trace_filename("case_xpu_snapshot") + wrapper._memory_history_enabled = True + wrapper._memory_history_backend = "XPU" + + calls = [] + + def fake_record_memory_history(*args, **kwargs): + calls.append((args, kwargs)) + + def fake_dump_snapshot(path): + Path(path).write_bytes(b"xpu snapshot bytes") + + wrapper._memory_history_module = SimpleNamespace( + _record_memory_history=fake_record_memory_history, + _dump_snapshot=fake_dump_snapshot, + ) + + wrapper._try_dump_memory_snapshot() + + snapshot = Path(wrapper._artifact_paths["memory_snapshot"]) + assert snapshot.exists() + assert snapshot.read_bytes() == b"xpu snapshot bytes" + assert calls[-1][1]["enabled"] is None + assert wrapper._memory_history_enabled is False + assert wrapper._memory_history_backend is None + assert wrapper._memory_history_module is None + + +def test_event_list_to_rows_contains_expected_fields(wrapper): + rows = wrapper._event_list_to_rows( + [ + FakeEvent( + name="aten::linear", + input_shapes=[[8, 16], [16, 32]], + stack=["f1", "f2"], + ) + ] + ) + + assert len(rows) == 1 + row = rows[0] + assert row["name"] == "aten::linear" + assert row["count"] == 1 + assert row["self_cpu_time_total_us"] == 10.0 + assert row["self_cuda_time_total_us"] == 20.0 + assert row["self_cpu_memory_usage_bytes"] == 128 + assert row["self_cuda_memory_usage_bytes"] == 1024 + assert "[[8, 16], [16, 32]]" == row["input_shapes"] + assert row["stack"] == "f1\nf2" diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py index 4a271426ff3..778314f7a90 100644 --- a/tests/test_config_factory.py +++ b/tests/test_config_factory.py @@ -961,6 +961,36 @@ def test_per_stage_model_arch_flows_through_merge(self, tmp_path): # Stage 1 uses its per-stage override assert stages[1].yaml_engine_args["model_arch"] == "Qwen3TTSCode2Wav" + def test_subtalker_sampling_params_deep_merge_preserves_base_keys(self): + """Verify subtalker sampling params participate in stage deep-merge.""" + from vllm_omni.config.stage_config import _deep_merge_stage + + base = { + "stage_id": 0, + "subtalker_sampling_params": { + "do_sample": True, + "temperature": 0.9, + "top_k": 50, + "top_p": 1.0, + }, + } + overlay = { + "stage_id": 0, + "subtalker_sampling_params": { + "temperature": 0.7, + "top_k": 32, + }, + } + + merged = _deep_merge_stage(base, overlay) + + assert merged["subtalker_sampling_params"] == { + "do_sample": True, + "temperature": 0.7, + "top_k": 32, + "top_p": 1.0, + } + class TestBaseConfigInheritance: """Test deploy YAML base_config inheritance.""" diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py index b2d61931558..a74c9ffc2d2 100644 --- a/tests/worker/test_omni_gpu_model_runner.py +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -41,7 +41,17 @@ def __init__(self): class DummyTalkerMTP(torch.nn.Module): """A fake talker_mtp module for deterministic CPU testing.""" - def forward(self, req_input_ids, req_embeds, last_talker_hidden, text_step): + def forward( + self, + req_input_ids, + req_embeds, + last_talker_hidden, + text_step, + do_sample=None, + temperature=None, + top_k=None, + top_p=None, + ): # Deterministic behavior: # - output embeds = input embeds + 1 # - output codes = [[0], [1], ...] @@ -51,6 +61,36 @@ def forward(self, req_input_ids, req_embeds, last_talker_hidden, text_step): return new_embeds, codes +class CaptureTalkerMTP(torch.nn.Module): + """A fake talker_mtp module that records sampling kwargs.""" + + def __init__(self): + super().__init__() + self.calls = [] + + def forward( + self, + req_input_ids, + req_embeds, + last_talker_hidden, + text_step, + do_sample=None, + temperature=None, + top_k=None, + top_p=None, + ): + self.calls.append( + { + "do_sample": do_sample, + "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + } + ) + codes = torch.zeros((req_embeds.shape[0], 1), dtype=torch.int64) + return req_embeds, codes + + @contextmanager def _noop_forward_context(*args, **kwargs): """A no-op context manager to replace vLLM forward context in CPU tests.""" @@ -80,7 +120,7 @@ def _make_runner(req_ids=("r1", "r2"), hidden_size=4): runner.talker_mtp = DummyTalkerMTP() runner.model = SimpleNamespace(talker_mtp_output_key="code_predictor_codes") - runner.vllm_config = object() + runner.vllm_config = SimpleNamespace(model_config=SimpleNamespace()) # Provide a minimal implementation that returns the expected 4-tuple. def _determine_batch_execution_and_padding(**kwargs): @@ -168,6 +208,43 @@ def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch): assert torch.allclose(inputs_embeds, before) +def test_talker_mtp_forward_passes_qwen3_tts_subtalker_sampling_params_to_talker(monkeypatch): + import vllm_omni.worker.gpu_model_runner as mod + + monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context) + + runner = _make_runner(req_ids=("r1",), hidden_size=4) + runner.talker_mtp = CaptureTalkerMTP() + runner.vllm_config = SimpleNamespace( + model_config=SimpleNamespace( + subtalker_sampling_params={ + "do_sample": False, + "temperature": 0.2, + "top_k": 9, + "top_p": 0.55, + } + ) + ) + + def fake_determine(self, num_tokens, num_reqs, num_scheduled_tokens_np, max_num_scheduled_tokens, use_cascade_attn): + batch_desc = SimpleNamespace(num_tokens=int(num_tokens)) + return (False, batch_desc, None, None, None) + + monkeypatch.setattr(runner, "_determine_batch_execution_and_padding", fake_determine.__get__(runner, type(runner))) + + inputs_embeds = torch.zeros((2, 4), dtype=torch.float32) + OmniGPUModelRunner._talker_mtp_forward(runner, ["r1"], inputs_embeds) + + assert runner.talker_mtp.calls == [ + { + "do_sample": False, + "temperature": 0.2, + "top_k": 9, + "top_p": 0.55, + } + ] + + def test_update_intermediate_buffer_writes_to_buffer_and_setattr(monkeypatch): """Validate that _update_intermediate_buffer writes to model_intermediate_buffer (forward path) and mirrors to additional_information_cpu setattr (backward compat).""" diff --git a/vllm_omni/config/model.py b/vllm_omni/config/model.py index 588efabfc4f..96a34a8d79f 100644 --- a/vllm_omni/config/model.py +++ b/vllm_omni/config/model.py @@ -109,9 +109,11 @@ class OmniModelConfig(ModelConfig): "extra": {}, } ) + subtalker_sampling_params: dict[str, Any] | None = None omni_kv_config: dict | None = None codec_frame_rate_hz: float | None = None task_type: str | None = None + enable_sleep_mode: bool = False @property def registry(self): diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py index ab975cafc3b..8c248059280 100644 --- a/vllm_omni/config/stage_config.py +++ b/vllm_omni/config/stage_config.py @@ -394,6 +394,7 @@ class StageDeployConfig: output_connectors: dict[str, str] | None = None input_connectors: dict[str, str] | None = None default_sampling_params: dict[str, Any] | None = None + subtalker_sampling_params: dict[str, Any] | None = None engine_extras: dict[str, Any] = field(default_factory=dict) @@ -472,7 +473,7 @@ def _parse_stage_deploy(stage_data: dict[str, Any]) -> StageDeployConfig: return StageDeployConfig(**kwargs) -_DEEP_MERGE_KEYS = frozenset({"default_sampling_params", "engine_extras", "engine_args"}) +_DEEP_MERGE_KEYS = frozenset({"default_sampling_params", "subtalker_sampling_params", "engine_extras", "engine_args"}) def _deep_merge_stage(base: dict, overlay: dict) -> dict: diff --git a/vllm_omni/deploy/qwen3_tts.yaml b/vllm_omni/deploy/qwen3_tts.yaml index 32dceebd805..5948007a97b 100644 --- a/vllm_omni/deploy/qwen3_tts.yaml +++ b/vllm_omni/deploy/qwen3_tts.yaml @@ -43,6 +43,11 @@ stages: max_tokens: 4096 seed: 42 repetition_penalty: 1.05 + subtalker_sampling_params: + do_sample: true + temperature: 0.9 + top_k: 50 + top_p: 1.0 - stage_id: 1 max_num_seqs: 1 diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 0a19eb11974..012c0130c7c 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -353,6 +353,8 @@ def __getattr__(self, item: str) -> Any: @dataclass class OmniDiffusionConfig: # Model and path configuration (for convenience) + stage_id: int = 0 + model: str | None = None model_class_name: str | None = None @@ -508,6 +510,8 @@ class OmniDiffusionConfig: # Step mode settings step_execution: bool = False + # sleep mode + enable_sleep_mode: bool = False # Maximum number of sequences to generate in a batch max_num_seqs: int = 1 @@ -797,5 +801,43 @@ def __str__(self): return self.name.lower() +@dataclass +class OmniACK: + """ + Handshake payload from Workers to Orchestrator. + """ + + task_id: str + status: str + stage_id: int | None = None + rank: int | None = None + freed_bytes: int = 0 + metadata: dict[str, Any] = field(default_factory=dict) + """ + Additional telemetry such as: + - max_contiguous_block: for fragmentation analysis. + - cuda_graph_recalled: boolean if graphs were successfully destroyed/rebuilt. + - latency_ms: time taken for the D2H/H2D transfer. + """ + error_msg: str | None = None + + +@dataclass +class OmniSleepTask: + """Structured sleep instruction.""" + + task_id: str + level: int = 2 + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class OmniWakeTask: + """Structured wake-up instruction.""" + + task_id: str + tags: list[str] | None = None + + # Special message broadcast via scheduler queues to signal worker shutdown. SHUTDOWN_MESSAGE = {"type": "shutdown"} diff --git a/vllm_omni/diffusion/diffusion_engine.py b/vllm_omni/diffusion/diffusion_engine.py index fe940d623e5..abaf5989598 100644 --- a/vllm_omni/diffusion/diffusion_engine.py +++ b/vllm_omni/diffusion/diffusion_engine.py @@ -14,6 +14,7 @@ import PIL.Image import torch from vllm.logger import init_logger +from vllm.v1.engine.exceptions import EngineDeadError from vllm_omni.diffusion.data import ( DiffusionOutput, @@ -122,7 +123,7 @@ def step(self, request: OmniDiffusionRequest) -> list[OmniRequestOutput]: if output.aborted: raise DiffusionRequestAbortedError(output.abort_message or "Diffusion request aborted.") if output.error: - raise RuntimeError(f"{output.error}") + raise RuntimeError(output.error) logger.info("Generation completed successfully.") if output.output is None: @@ -358,6 +359,8 @@ def add_req_and_wait_for_response(self, request: OmniDiffusionRequest) -> Diffus sched_req_id = sched_output.scheduled_req_ids[0] try: runner_output = self.execute_fn(sched_output) + except EngineDeadError: + raise except Exception as exc: logger.error("Execution failed for diffusion request %s", sched_req_id, exc_info=True) runner_output = RunnerOutput( diff --git a/vllm_omni/diffusion/executor/multiproc_executor.py b/vllm_omni/diffusion/executor/multiproc_executor.py index e55a464fb4a..dcb35cfde1f 100644 --- a/vllm_omni/diffusion/executor/multiproc_executor.py +++ b/vllm_omni/diffusion/executor/multiproc_executor.py @@ -1,14 +1,18 @@ from __future__ import annotations import multiprocessing as mp +import multiprocessing.connection +import threading import time import weakref +from collections.abc import Callable from dataclasses import dataclass from typing import TYPE_CHECKING, Any import zmq from vllm.distributed.device_communicators.shm_broadcast import MessageQueue from vllm.logger import init_logger +from vllm.v1.engine.exceptions import EngineDeadError from vllm_omni.diffusion.data import SHUTDOWN_MESSAGE, DiffusionOutput from vllm_omni.diffusion.executor.abstract import DiffusionExecutor @@ -22,6 +26,8 @@ logger = init_logger(__name__) +_DEQUEUE_TIMEOUT_S = 5.0 + @dataclass class BackgroundResources: @@ -36,10 +42,14 @@ class BackgroundResources: def __call__(self): """Clean up background resources.""" + if hasattr(self, "wake_events") and self.wake_events: + for ev in self.wake_events: + ev.set() + if self.broadcast_mq is not None: try: for _ in range(self.num_workers): - self.broadcast_mq.enqueue(SHUTDOWN_MESSAGE) + self.broadcast_mq.enqueue(SHUTDOWN_MESSAGE, timeout=1.0) self.broadcast_mq = None self.result_mq = None @@ -63,13 +73,17 @@ class MultiprocDiffusionExecutor(DiffusionExecutor): def _init_executor(self) -> None: self._processes: list[mp.Process] = [] self._closed = False + self.is_failed = False + self._failure_callbacks: list[Callable[[], None]] = [] num_workers = self.od_config.num_gpus + self.wake_events = [mp.Event() for _ in range(num_workers)] + self._broadcast_mq = self._init_broadcast_queue(num_workers) broadcast_handle = self._broadcast_mq.export_handle() # Launch workers - processes, result_handle = self._launch_workers(broadcast_handle) + processes, result_handle = self._launch_workers(broadcast_handle, self.wake_events) self._result_mq = self._init_result_queue(result_handle) self._processes = processes @@ -81,6 +95,8 @@ def _init_executor(self) -> None: ) self._finalizer = weakref.finalize(self, self.resources) + self.start_worker_monitor() + def _init_broadcast_queue(self, num_workers: int) -> MessageQueue: return MessageQueue( n_reader=num_workers, @@ -100,7 +116,24 @@ def _ensure_open(self) -> None: if self._result_mq is None: raise RuntimeError("Result queue not initialized") - def _launch_workers(self, broadcast_handle): + def _dequeue_one_with_failure_polling(self, deadline: float | None, method: str) -> Any: + """Block until one result message, polling ``is_failed`` between chunk timeouts.""" + while True: + if deadline is None: + chunk_timeout = _DEQUEUE_TIMEOUT_S + else: + remaining = deadline - time.monotonic() + if remaining <= 0: + raise TimeoutError(f"RPC call to {method} timed out.") + chunk_timeout = min(_DEQUEUE_TIMEOUT_S, remaining) + try: + return self._result_mq.dequeue(timeout=chunk_timeout) + except (TimeoutError, zmq.error.Again): + if self.is_failed: + raise EngineDeadError() + continue + + def _launch_workers(self, broadcast_handle, wake_events): od_config = self.od_config logger.info("Starting server...") @@ -126,6 +159,7 @@ def _launch_workers(self, broadcast_handle): od_config, writer, broadcast_handle, + wake_events[i], worker_extension_cls, custom_pipeline_args, ), @@ -164,6 +198,49 @@ def _launch_workers(self, broadcast_handle): return processes, result_handle + def start_worker_monitor(self) -> None: + # Monitors worker process liveness. If any die unexpectedly, + # logs an error, shuts down the executor and invokes the failure + # callback to inform the engine. + sentinels = [p.sentinel for p in self._processes] + if not sentinels: + return + + def _monitor() -> None: + try: + finished = multiprocessing.connection.wait(sentinels) + except OSError: + return + + if self._closed: + return + + dead = [p.name for p in self._processes if p.sentinel in finished] + if dead: + logger.error( + "Diffusion worker(s) died unexpectedly: %s", + dead, + ) + self.is_failed = True + + self.shutdown() + + for cb in self._failure_callbacks: + try: + cb() + except Exception: + logger.exception("failure_callback raised") + + t = threading.Thread(target=_monitor, daemon=True, name="diffusion-worker-monitor") + t.start() + + def register_failure_callback( + self, + callback: Callable[[], None], + ) -> None: + """Register a callback invoked when a worker process dies.""" + self._failure_callbacks.append(callback) + def add_req(self, request: OmniDiffusionRequest) -> DiffusionOutput: self._ensure_open() rpc_request = { @@ -286,27 +363,21 @@ def collective_rpc( responses = [] for _ in range(num_responses): - dequeue_timeout = None if deadline is None else max(0, deadline - time.monotonic()) + response = self._dequeue_one_with_failure_polling(deadline, method) + try: - response = self._result_mq.dequeue(timeout=dequeue_timeout) - - try: - unpack_diffusion_output_shm(response) - except Exception as e: - logger.warning("SHM unpack failed (data may already be inline): %s", e) - - # Check if response indicates an error - if isinstance(response, dict) and response.get("status") == "error": - raise RuntimeError( - f"Worker failed with error '{response.get('error')}', " - "please check the stack trace above for the root cause" - ) - - responses.append(response) - except zmq.error.Again as e: - raise TimeoutError(f"RPC call to {method} timed out.") from e - except TimeoutError as e: - raise TimeoutError(f"RPC call to {method} timed out.") from e + unpack_diffusion_output_shm(response) + except Exception as e: + logger.warning("SHM unpack failed (data may already be inline): %s", e) + + # Check if response indicates an error + if isinstance(response, dict) and response.get("status") == "error": + raise RuntimeError( + f"Worker failed with error '{response.get('error')}', " + "please check the stack trace above for the root cause" + ) + + responses.append(response) return responses[0] if unique_reply_rank is not None else responses except Exception as e: @@ -314,10 +385,13 @@ def collective_rpc( raise def check_health(self) -> None: - # Simple check if processes are alive + self._ensure_open() + if self.is_failed: + raise EngineDeadError() for p in self._processes: if not p.is_alive(): - raise RuntimeError(f"Worker process {p.name} is dead") + self.is_failed = True + raise EngineDeadError(f"Worker process {p.name} is dead") def shutdown(self) -> None: self._closed = True diff --git a/vllm_omni/diffusion/inline_stage_diffusion_client.py b/vllm_omni/diffusion/inline_stage_diffusion_client.py index 5cdd6e4fabc..0eec8fee996 100644 --- a/vllm_omni/diffusion/inline_stage_diffusion_client.py +++ b/vllm_omni/diffusion/inline_stage_diffusion_client.py @@ -34,6 +34,7 @@ class InlineStageDiffusionClient: """Runs DiffusionEngine in a thread executor inside the Orchestrator.""" stage_type: str = "diffusion" + replica_id: int = 0 def __init__( self, @@ -45,6 +46,7 @@ def __init__( self.model = model self.od_config = od_config self.stage_id = metadata.stage_id + self.replica_id = getattr(metadata, "replica_id", 0) self.final_output = metadata.final_output self.final_output_type = metadata.final_output_type self.default_sampling_params = metadata.default_sampling_params @@ -62,8 +64,9 @@ def __init__( self._shutting_down = False logger.info( - "[InlineStageDiffusionClient] Stage-%s initialized inline (batch_size=%d)", + "[InlineStageDiffusionClient] stage-%s [rep-%s] initialized inline (batch_size=%d)", self.stage_id, + self.replica_id, self.batch_size, ) @@ -82,6 +85,12 @@ async def add_request_async( sampling_params: OmniDiffusionSamplingParams, kv_sender_info: dict[int, dict[str, Any]] | None = None, ) -> None: + logger.info( + "[InlineStageDiffusionClient] stage-%s [rep-%s] add request: %s", + self.stage_id, + self.replica_id, + request_id, + ) task = asyncio.create_task( self._dispatch_request( request_id, @@ -135,6 +144,13 @@ async def add_batch_request_async( sampling_params: OmniDiffusionSamplingParams, kv_sender_info: dict[int, dict[str, Any]] | None = None, ) -> None: + logger.info( + "[InlineStageDiffusionClient] stage-%s [rep-%s] add batch request: %s (%d prompts)", + self.stage_id, + self.replica_id, + request_id, + len(prompts), + ) task = asyncio.create_task( self._dispatch_batch( request_id, @@ -254,7 +270,7 @@ async def collective_rpc_async( is_start = args[0] if args else True profile_prefix = args[1] if len(args) > 1 else None if is_start and profile_prefix is None: - profile_prefix = f"stage_{self.stage_id}_diffusion_{int(time.time())}" + profile_prefix = f"stage_{self.stage_id}_rep_{self.replica_id}_diffusion_{int(time.time())}" return await loop.run_in_executor( self._executor, self._engine.profile, diff --git a/vllm_omni/diffusion/models/dmd2/__init__.py b/vllm_omni/diffusion/models/dmd2/__init__.py new file mode 100644 index 00000000000..d0c8219d4d1 --- /dev/null +++ b/vllm_omni/diffusion/models/dmd2/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.models.dmd2.mixin import DMD2PipelineMixin + +__all__ = [ + "DMD2PipelineMixin", +] diff --git a/vllm_omni/diffusion/models/dmd2/mixin.py b/vllm_omni/diffusion/models/dmd2/mixin.py new file mode 100644 index 00000000000..60c4b95baff --- /dev/null +++ b/vllm_omni/diffusion/models/dmd2/mixin.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import logging +import os + +from vllm_omni.diffusion.data import DiffusionOutput +from vllm_omni.diffusion.models.schedulers import DMD2EulerScheduler +from vllm_omni.diffusion.models.utils import _load_json +from vllm_omni.diffusion.request import OmniDiffusionRequest + +logger = logging.getLogger(__name__) + + +class DMD2PipelineMixin: + """Mixin for FastGen DMD2-distilled models. Must appear before the base pipeline in MRO.""" + + def __init_dmd2__(self) -> None: + """Call after super().__init__() to apply DMD2 scheduler and read model_index.""" + local_files_only = os.path.exists(self.od_config.model) + try: + model_index = _load_json(self.od_config.model, "model_index.json", local_files_only) + except Exception: + model_index = {} + + dmd2_timesteps = model_index.get("dmd2_denoising_timesteps", [999, 937, 833, 624]) + self.num_inference_steps = model_index.get("dmd2_num_inference_steps", 4) + shift = model_index.get("dmd2_scheduler_shift", 1.0) + self.dmd2_guidance_scale = model_index.get("dmd2_guidance_scale", 1.0) + + self.scheduler = DMD2EulerScheduler( + num_train_timesteps=1000, + shift=shift, + dmd2_timesteps=dmd2_timesteps, + ) + + def _sanitize_dmd2_request(self, req: OmniDiffusionRequest) -> None: + """Sanitize CFG-related fields in-place. Mutates req.sampling_params and req.prompts.""" + sp = req.sampling_params + + if sp.num_inference_steps and sp.num_inference_steps != self.num_inference_steps: + logger.warning( + "DMD2: ignoring num_inference_steps=%d, forcing %d.", + sp.num_inference_steps, + self.num_inference_steps, + ) + sp.num_inference_steps = self.num_inference_steps + + if sp.guidance_scale_provided and sp.guidance_scale != self.dmd2_guidance_scale: + logger.warning( + "DMD2: ignoring guidance_scale=%.2f, forcing %.2f.", + sp.guidance_scale, + self.dmd2_guidance_scale, + ) + sp.guidance_scale = self.dmd2_guidance_scale + sp.guidance_scale_provided = False + + if sp.guidance_scale_2 is not None: + logger.warning("DMD2: ignoring guidance_scale_2.") + sp.guidance_scale_2 = None + + if sp.true_cfg_scale is not None: + logger.warning("DMD2: ignoring true_cfg_scale.") + sp.true_cfg_scale = None + + sp.do_classifier_free_guidance = False + sp.is_cfg_negative = False + + fixed = [] + for p in req.prompts: + if isinstance(p, dict) and "negative_prompt" in p: + logger.warning("DMD2: ignoring negative_prompt.") + p = {k: v for k, v in p.items() if k != "negative_prompt"} + fixed.append(p) + req.prompts = fixed + + def forward(self, req: OmniDiffusionRequest, **kwargs) -> DiffusionOutput: + self._sanitize_dmd2_request(req) + kwargs.pop("guidance_scale", None) + kwargs.pop("num_inference_steps", None) + return super().forward( + req, + guidance_scale=self.dmd2_guidance_scale, + num_inference_steps=self.num_inference_steps, + **kwargs, + ) diff --git a/vllm_omni/diffusion/models/ltx2/__init__.py b/vllm_omni/diffusion/models/ltx2/__init__.py index 9f9d70f0106..2a78b61baeb 100644 --- a/vllm_omni/diffusion/models/ltx2/__init__.py +++ b/vllm_omni/diffusion/models/ltx2/__init__.py @@ -4,12 +4,14 @@ from vllm_omni.diffusion.models.ltx2.ltx2_transformer import LTX2VideoTransformer3DModel from vllm_omni.diffusion.models.ltx2.pipeline_ltx2 import ( LTX2Pipeline, + LTX2T2VDMD2Pipeline, LTX2TwoStagesPipeline, create_transformer_from_config, get_ltx2_post_process_func, load_transformer_config, ) from vllm_omni.diffusion.models.ltx2.pipeline_ltx2_image2video import ( + LTX2I2VDMD2Pipeline, LTX2ImageToVideoPipeline, LTX2ImageToVideoTwoStagesPipeline, ) @@ -17,7 +19,9 @@ __all__ = [ "LTX2Pipeline", + "LTX2T2VDMD2Pipeline", "LTX2ImageToVideoPipeline", + "LTX2I2VDMD2Pipeline", "LTX2LatentUpsamplePipeline", "LTX2TwoStagesPipeline", "LTX2ImageToVideoTwoStagesPipeline", diff --git a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py index 95ef919c24e..0ae71aa2d71 100644 --- a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py +++ b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py @@ -41,6 +41,7 @@ from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata from vllm_omni.diffusion.attention.layer import Attention +from vllm_omni.diffusion.distributed.hsdp_utils import is_transformer_block_module from vllm_omni.diffusion.distributed.sp_plan import SequenceParallelInput, SequenceParallelOutput from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available @@ -1265,6 +1266,7 @@ class LTX2VideoTransformer3DModel(nn.Module): _skip_layerwise_casting_patterns = ["norm"] _repeated_blocks = ["LTX2VideoTransformerBlock"] _layerwise_offload_blocks_attrs = ["transformer_blocks"] + _hsdp_shard_conditions = [is_transformer_block_module] _sp_plan: dict[str, Any] | None = None @staticmethod diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py index c60b192f0a5..f06ffab165a 100644 --- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py +++ b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2.py @@ -33,6 +33,7 @@ from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.lora.request import LoRARequest @@ -1304,3 +1305,11 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + +class LTX2T2VDMD2Pipeline(DMD2PipelineMixin, LTX2Pipeline): + """LTX-2 T2V pipeline for FastGen DMD2-distilled models.""" + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__(od_config=od_config, prefix=prefix) + self.__init_dmd2__() diff --git a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py index 65e7454b73f..50a71a54b61 100644 --- a/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py +++ b/vllm_omni/diffusion/models/ltx2/pipeline_ltx2_image2video.py @@ -25,6 +25,7 @@ from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.lora.request import LoRARequest @@ -889,3 +890,11 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + +class LTX2I2VDMD2Pipeline(DMD2PipelineMixin, LTX2ImageToVideoPipeline): + """LTX-2 I2V pipeline for FastGen DMD2-distilled models.""" + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__(od_config=od_config, prefix=prefix) + self.__init_dmd2__() diff --git a/vllm_omni/diffusion/models/magi_human/pipeline_magi_human.py b/vllm_omni/diffusion/models/magi_human/pipeline_magi_human.py index 881c72edc6d..c1abdf91f04 100644 --- a/vllm_omni/diffusion/models/magi_human/pipeline_magi_human.py +++ b/vllm_omni/diffusion/models/magi_human/pipeline_magi_human.py @@ -48,6 +48,7 @@ ) from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.models.t5_encoder.t5_gemma_encoder import T5GemmaEncoderModelTP +from vllm_omni.diffusion.models.utils import _load_json from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import ( DiffusionPipelineProfilerMixin, ) @@ -1640,20 +1641,6 @@ def post_process(output): # =========================================================================== -def _load_json(model_path: str, filename: str, local_files_only: bool = True) -> dict: - """Load a JSON config file from a local path or HuggingFace Hub repo.""" - if local_files_only: - path = os.path.join(model_path, *filename.split("/")) - with open(path) as f: - return json.load(f) - else: - from huggingface_hub import hf_hub_download - - cached = hf_hub_download(repo_id=model_path, filename=filename) - with open(cached) as f: - return json.load(f) - - def _resolve_subdir( model_path: str, subfolder: str, diff --git a/vllm_omni/diffusion/models/schedulers/__init__.py b/vllm_omni/diffusion/models/schedulers/__init__.py index 6f8df78ebf0..e683ed27203 100644 --- a/vllm_omni/diffusion/models/schedulers/__init__.py +++ b/vllm_omni/diffusion/models/schedulers/__init__.py @@ -1,10 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm_omni.diffusion.models.schedulers.scheduling_dmd2_euler import DMD2EulerScheduler from vllm_omni.diffusion.models.schedulers.scheduling_flow_unipc_multistep import ( FlowUniPCMultistepScheduler, ) __all__ = [ + "DMD2EulerScheduler", "FlowUniPCMultistepScheduler", ] diff --git a/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py b/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py new file mode 100644 index 00000000000..01447a41d77 --- /dev/null +++ b/vllm_omni/diffusion/models/schedulers/scheduling_dmd2_euler.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import torch +from diffusers import FlowMatchEulerDiscreteScheduler + + +class DMD2EulerScheduler(FlowMatchEulerDiscreteScheduler): + """Euler scheduler that always uses the fixed DMD2 training timestep schedule.""" + + def __init__(self, *args, dmd2_timesteps: list[int], **kwargs): + super().__init__(*args, **kwargs) + self._dmd2_timesteps = dmd2_timesteps + + def set_timesteps( + self, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + **kwargs, + ) -> None: + super().set_timesteps(timesteps=self._dmd2_timesteps, device=device) diff --git a/vllm_omni/diffusion/models/utils.py b/vllm_omni/diffusion/models/utils.py new file mode 100644 index 00000000000..ba0d8dda20c --- /dev/null +++ b/vllm_omni/diffusion/models/utils.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import json +import os + + +def _load_json(model_path: str, filename: str, local_files_only: bool = True) -> dict: + """Load a JSON config file from a local path or HuggingFace Hub repo.""" + if local_files_only: + path = os.path.join(model_path, *filename.split("/")) + with open(path) as f: + return json.load(f) + else: + from huggingface_hub import hf_hub_download + + cached = hf_hub_download(repo_id=model_path, filename=filename) + with open(cached) as f: + return json.load(f) diff --git a/vllm_omni/diffusion/models/wan2_2/__init__.py b/vllm_omni/diffusion/models/wan2_2/__init__.py index d418001d952..97808df29d8 100644 --- a/vllm_omni/diffusion/models/wan2_2/__init__.py +++ b/vllm_omni/diffusion/models/wan2_2/__init__.py @@ -3,6 +3,7 @@ from .pipeline_wan2_2 import ( Wan22Pipeline, + WanT2VDMD2Pipeline, create_transformer_from_config, get_wan22_post_process_func, get_wan22_pre_process_func, @@ -11,6 +12,7 @@ ) from .pipeline_wan2_2_i2v import ( Wan22I2VPipeline, + WanI2VDMD2Pipeline, get_wan22_i2v_post_process_func, get_wan22_i2v_pre_process_func, ) @@ -28,6 +30,7 @@ from .wan2_2_vace_transformer import VaceWanTransformerBlock, WanVACETransformer3DModel __all__ = [ + "WanT2VDMD2Pipeline", "Wan22Pipeline", "get_wan22_post_process_func", "get_wan22_pre_process_func", @@ -35,6 +38,7 @@ "load_transformer_config", "create_transformer_from_config", "Wan22I2VPipeline", + "WanI2VDMD2Pipeline", "get_wan22_i2v_post_process_func", "get_wan22_i2v_pre_process_func", "Wan22TI2VPipeline", diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 652425d5097..c74b72441d4 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -22,6 +22,7 @@ from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler from vllm_omni.diffusion.models.wan2_2.scheduling_wan_euler import WanEulerScheduler @@ -29,16 +30,12 @@ from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.utils.prompt_utils import ( - validate_prompt_sequence_lengths, -) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform logger = logging.getLogger(__name__) DEBUG_PERF = False WAN_SAMPLE_SOLVER_CHOICES = {"unipc", "euler"} -WAN22_MAX_SEQUENCE_LENGTH = 512 def build_wan_scheduler(sample_solver: str, flow_shift: float) -> Any: @@ -293,7 +290,6 @@ def __init__( pass self.boundary_ratio = od_config.boundary_ratio - self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH # Determine which transformers to load based on boundary_ratio # boundary_ratio=1.0: only load transformer_2 (low-noise stage only) @@ -565,7 +561,6 @@ def forward( negative_prompt_embeds=negative_prompt_embeds, guidance_scale_2=guidance_high if boundary_ratio is not None else None, boundary_ratio=boundary_ratio, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) if num_frames % self.vae_scale_factor_temporal != 1: @@ -599,7 +594,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, + max_sequence_length=req.sampling_params.max_sequence_length or 512, device=device, dtype=dtype, ) @@ -832,20 +827,6 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) - text_inputs_untruncated = self.tokenizer( - prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - error_context="for Wan2.2 text encoding", - ) text_inputs = self.tokenizer( prompt_clean, @@ -874,24 +855,8 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] - neg_text_inputs_untruncated = self.tokenizer( - negative_prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - neg_text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - prompt_name="negative_prompt", - error_context="for Wan2.2 text encoding", - ) neg_text_inputs = self.tokenizer( - negative_prompt_clean, + [self._prompt_clean(p) for p in negative_prompt], padding="max_length", max_length=max_sequence_length, truncation=True, @@ -963,7 +928,6 @@ def check_inputs( negative_prompt_embeds=None, guidance_scale_2=None, boundary_ratio=None, - max_sequence_length=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -990,10 +954,18 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: - raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" - ) - if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") + + +# --------------------------------------------------------------------------- +# DMD2-distilled variant +# --------------------------------------------------------------------------- + + +class WanT2VDMD2Pipeline(DMD2PipelineMixin, Wan22Pipeline): + """Wan 2.x T2V pipeline for FastGen DMD2-distilled models.""" + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__(od_config=od_config, prefix=prefix) + self.__init_dmd2__() diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 95d1e08bbc7..726d8853a7c 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -23,10 +23,11 @@ from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader +from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin, _is_rank_zero +from vllm_omni.diffusion.models.utils import _load_json from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( - WAN22_MAX_SEQUENCE_LENGTH, build_wan_scheduler, create_transformer_from_config, load_transformer_config, @@ -37,9 +38,6 @@ from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.profiler.diffusion_pipeline_profiler import DiffusionPipelineProfilerMixin from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.utils.prompt_utils import ( - validate_prompt_sequence_lengths, -) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform @@ -47,29 +45,6 @@ DEBUG_PERF = False -def _load_model_index(model: str, local_files_only: bool) -> dict: - """Load model_index.json from local path or HF Hub.""" - if local_files_only: - model_index_path = os.path.join(model, "model_index.json") - if os.path.exists(model_index_path): - import json - - with open(model_index_path) as f: - return json.load(f) - else: - try: - import json - - from huggingface_hub import hf_hub_download - - model_index_path = hf_hub_download(model, "model_index.json") - with open(model_index_path) as f: - return json.load(f) - except Exception: - pass - return {} - - def get_wan22_i2v_post_process_func( od_config: OmniDiffusionConfig, ): @@ -200,7 +175,10 @@ def __init__( ] # Load model_index.json to detect available components - model_index = _load_model_index(model, local_files_only) + try: + model_index = _load_json(model, "model_index.json", local_files_only) + except Exception: + model_index = {} # Check if this is a two-stage model (MoE with transformer_2) self.has_transformer_2 = "transformer_2" in model_index @@ -218,7 +196,6 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) - self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -474,7 +451,6 @@ def forward( image_embeds=image_embeds, guidance_scale_2=guidance_high if boundary_ratio is not None else None, boundary_ratio=boundary_ratio, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) # Adjust num_frames to be compatible with VAE temporal scaling @@ -503,7 +479,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_low > 1.0 or guidance_high > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, + max_sequence_length=req.sampling_params.max_sequence_length or 512, device=device, dtype=dtype, ) @@ -708,20 +684,6 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) - text_inputs_untruncated = self.tokenizer( - prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - error_context="for Wan2.2 text encoding", - ) text_inputs = self.tokenizer( prompt_clean, @@ -750,24 +712,8 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] - neg_text_inputs_untruncated = self.tokenizer( - negative_prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - neg_text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - prompt_name="negative_prompt", - error_context="for Wan2.2 text encoding", - ) neg_text_inputs = self.tokenizer( - negative_prompt_clean, + [self._prompt_clean(p) for p in negative_prompt], padding="max_length", max_length=max_sequence_length, truncation=True, @@ -911,7 +857,6 @@ def check_inputs( image_embeds=None, guidance_scale_2=None, boundary_ratio=None, - max_sequence_length=None, ): if image is None and image_embeds is None: raise ValueError("Provide either `image` or `image_embeds`. Cannot leave both undefined.") @@ -933,11 +878,6 @@ def check_inputs( if prompt is None and prompt_embeds is None: raise ValueError("Provide either `prompt` or `prompt_embeds`.") - if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: - raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" - ) - if boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when `boundary_ratio` is set.") @@ -945,3 +885,16 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + +# --------------------------------------------------------------------------- +# DMD2-distilled variant +# --------------------------------------------------------------------------- + + +class WanI2VDMD2Pipeline(DMD2PipelineMixin, Wan22I2VPipeline): + """Wan 2.x I2V pipeline for FastGen DMD2-distilled models.""" + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__(od_config=od_config, prefix=prefix) + self.__init_dmd2__() diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index dba76ba8af8..8170a8f5ab5 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -37,7 +37,6 @@ from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.progress_bar import ProgressBarMixin from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( - WAN22_MAX_SEQUENCE_LENGTH, build_wan_scheduler, create_transformer_from_config, load_transformer_config, @@ -47,9 +46,6 @@ ) from vllm_omni.diffusion.postprocess import interpolate_video_tensor from vllm_omni.diffusion.request import OmniDiffusionRequest -from vllm_omni.diffusion.utils.prompt_utils import ( - validate_prompt_sequence_lengths, -) from vllm_omni.inputs.data import OmniTextPrompt from vllm_omni.platforms import current_omni_platform @@ -189,7 +185,6 @@ def __init__( # Text encoder self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) - self.tokenizer_max_length = WAN22_MAX_SEQUENCE_LENGTH self.text_encoder = UMT5EncoderModel.from_pretrained( model, subfolder="text_encoder", torch_dtype=dtype, local_files_only=local_files_only ).to(self.device) @@ -377,7 +372,6 @@ def forward( width=width, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) # Adjust num_frames to be compatible with VAE temporal scaling @@ -401,7 +395,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_scale > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, + max_sequence_length=req.sampling_params.max_sequence_length or 512, device=device, dtype=dtype, ) @@ -551,20 +545,6 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt prompt_clean = [self._prompt_clean(p) for p in prompt] batch_size = len(prompt_clean) - text_inputs_untruncated = self.tokenizer( - prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - error_context="for Wan2.2 text encoding", - ) text_inputs = self.tokenizer( prompt_clean, @@ -593,24 +573,8 @@ def encode_prompt( if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_clean = [self._prompt_clean(p) for p in negative_prompt] - neg_text_inputs_untruncated = self.tokenizer( - negative_prompt_clean, - padding=True, - truncation=False, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - validate_prompt_sequence_lengths( - neg_text_inputs_untruncated.attention_mask, - max_sequence_length=max_sequence_length, - supported_max_sequence_length=self.tokenizer_max_length, - prompt_name="negative_prompt", - error_context="for Wan2.2 text encoding", - ) neg_text_inputs = self.tokenizer( - negative_prompt_clean, + [self._prompt_clean(p) for p in negative_prompt], padding="max_length", max_length=max_sequence_length, truncation=True, @@ -735,7 +699,6 @@ def check_inputs( width, prompt_embeds=None, negative_prompt_embeds=None, - max_sequence_length=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -751,11 +714,6 @@ def check_inputs( if prompt is None and prompt_embeds is None: raise ValueError("Provide either `prompt` or `prompt_embeds`.") - if max_sequence_length is not None and max_sequence_length > self.tokenizer_max_length: - raise ValueError( - f"`max_sequence_length` cannot be greater than {self.tokenizer_max_length} but is {max_sequence_length}" - ) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py index 11408e2d24b..0458f88597e 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py @@ -243,7 +243,6 @@ def check_inputs( video=None, mask=None, reference_images=None, - max_sequence_length=None, ): super().check_inputs( prompt=prompt, @@ -252,7 +251,6 @@ def check_inputs( width=width, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - max_sequence_length=max_sequence_length, ) # VACE-specific: validate video/mask/reference_images consistency @@ -549,7 +547,6 @@ def forward( video=source_video, mask=source_mask, reference_images=reference_images, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, ) device = self.device @@ -568,7 +565,7 @@ def forward( negative_prompt=negative_prompt, do_classifier_free_guidance=guidance_scale > 1.0, num_videos_per_prompt=req.sampling_params.num_outputs_per_prompt or 1, - max_sequence_length=req.sampling_params.max_sequence_length or self.tokenizer_max_length, + max_sequence_length=req.sampling_params.max_sequence_length or 512, device=device, dtype=dtype, ) diff --git a/vllm_omni/diffusion/offloader/layerwise_backend.py b/vllm_omni/diffusion/offloader/layerwise_backend.py index 7876b009475..d1216e2f28c 100644 --- a/vllm_omni/diffusion/offloader/layerwise_backend.py +++ b/vllm_omni/diffusion/offloader/layerwise_backend.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations from itertools import chain from typing import Any diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index 0bf8c04517b..4001109cc9e 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -83,6 +83,16 @@ "pipeline_ltx2_image2video", "LTX2ImageToVideoTwoStagesPipeline", ), + "LTX2T2VDMD2Pipeline": ( + "ltx2", + "pipeline_ltx2", + "LTX2T2VDMD2Pipeline", + ), + "LTX2I2VDMD2Pipeline": ( + "ltx2", + "pipeline_ltx2_image2video", + "LTX2I2VDMD2Pipeline", + ), "StableAudioPipeline": ( "stable_audio", "pipeline_stable_audio", @@ -93,6 +103,16 @@ "pipeline_wan2_2_i2v", "Wan22I2VPipeline", ), + "WanT2VDMD2Pipeline": ( + "wan2_2", + "pipeline_wan2_2", + "WanT2VDMD2Pipeline", + ), + "WanI2VDMD2Pipeline": ( + "wan2_2", + "pipeline_wan2_2_i2v", + "WanI2VDMD2Pipeline", + ), "LongCatImagePipeline": ( "longcat_image", "pipeline_longcat_image", @@ -357,8 +377,12 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "LTX2TwoStagesPipeline": "get_ltx2_post_process_func", "LTX2ImageToVideoPipeline": "get_ltx2_post_process_func", "LTX2ImageToVideoTwoStagesPipeline": "get_ltx2_post_process_func", + "LTX2T2VDMD2Pipeline": "get_ltx2_post_process_func", + "LTX2I2VDMD2Pipeline": "get_ltx2_post_process_func", "StableAudioPipeline": "get_stable_audio_post_process_func", "WanImageToVideoPipeline": "get_wan22_i2v_post_process_func", + "WanT2VDMD2Pipeline": "get_wan22_post_process_func", + "WanI2VDMD2Pipeline": "get_wan22_i2v_post_process_func", "LongCatImagePipeline": "get_longcat_image_post_process_func", "BagelPipeline": "get_bagel_post_process_func", "LongCatImageEditPipeline": "get_longcat_image_post_process_func", @@ -390,6 +414,8 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "WanPipeline": "get_wan22_pre_process_func", "WanVACEPipeline": "get_wan22_vace_pre_process_func", "WanImageToVideoPipeline": "get_wan22_i2v_pre_process_func", + "WanT2VDMD2Pipeline": "get_wan22_pre_process_func", + "WanI2VDMD2Pipeline": "get_wan22_i2v_pre_process_func", "OmniGen2Pipeline": "get_omnigen2_pre_process_func", "HeliosPipeline": "get_helios_pre_process_func", "HeliosPyramidPipeline": "get_helios_pre_process_func", diff --git a/vllm_omni/diffusion/stage_diffusion_client.py b/vllm_omni/diffusion/stage_diffusion_client.py index f952f32d9ce..dacbe387add 100644 --- a/vllm_omni/diffusion/stage_diffusion_client.py +++ b/vllm_omni/diffusion/stage_diffusion_client.py @@ -8,15 +8,20 @@ from __future__ import annotations import asyncio +import multiprocessing.connection import time import uuid +import weakref from dataclasses import fields, is_dataclass +from threading import Thread from typing import TYPE_CHECKING, Any import zmq from vllm.logger import init_logger +from vllm.v1.engine.exceptions import EngineDeadError from vllm_omni.diffusion.stage_diffusion_proc import ( + StageDiffusionProc, complete_diffusion_handshake, spawn_diffusion_proc, ) @@ -62,6 +67,7 @@ class StageDiffusionClient: """ stage_type: str = "diffusion" + replica_id: int = 0 def __init__( self, @@ -107,12 +113,13 @@ def _initialize_client( batch_size: int, ) -> None: self.stage_id = metadata.stage_id + self.replica_id = getattr(metadata, "replica_id", 0) self.final_output = metadata.final_output self.final_output_type = metadata.final_output_type self.default_sampling_params = metadata.default_sampling_params - self.requires_multimodal_data = metadata.requires_multimodal_data - self.custom_process_input_func = metadata.custom_process_input_func - self.engine_input_source = metadata.engine_input_source + self.requires_multimodal_data = getattr(metadata, "requires_multimodal_data", False) + self.custom_process_input_func = getattr(metadata, "custom_process_input_func", None) + self.engine_input_source = getattr(metadata, "engine_input_source", []) self._proc = proc self._owns_process = proc is not None @@ -130,14 +137,53 @@ def _initialize_client( self._pending_rpcs: set[str] = set() self._tasks: dict[str, asyncio.Task] = {} self._shutting_down = False + self._engine_dead: bool = False + + # Background thread to detect silent process death (SIGKILL, segfault) + # where the subprocess cannot send the ZMQ death sentinel. + # Mirrors MPClient.start_engine_core_monitor() in vLLM. + self._start_proc_monitor() logger.info( - "[StageDiffusionClient] Stage-%s initialized (owns_process=%s, batch_size=%d)", + "[StageDiffusionClient] stage-%s [rep-%s] initialized (owns_process=%s, batch_size=%d)", self.stage_id, + self.replica_id, self._owns_process, batch_size, ) + # ------------------------------------------------------------------ + # Process monitor (mirrors vLLM's MPClient.start_engine_core_monitor) + # ------------------------------------------------------------------ + + def _start_proc_monitor(self) -> None: + """Start a daemon thread that watches the subprocess sentinel. + + When the subprocess dies without sending the ZMQ death sentinel + (e.g. SIGKILL, segfault), this thread sets ``_engine_dead`` so + subsequent calls raise ``EngineDeadError``. + """ + proc = self._proc + self_ref = weakref.ref(self) + + def _monitor() -> None: + try: + multiprocessing.connection.wait([proc.sentinel]) + except Exception: + return + client = self_ref() + if client is None or client._shutting_down or client._engine_dead: + return + client._engine_dead = True + logger.error( + "[StageDiffusionClient] stage-%s [rep-%s] StageDiffusionProc died unexpectedly (exit code %s).", + client.stage_id, + client.replica_id, + proc.exitcode, + ) + + Thread(target=_monitor, daemon=True, name="DiffusionProcMonitor").start() + # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ @@ -150,6 +196,16 @@ def _drain_responses(self) -> None: except zmq.Again: break + # Check for the death sentinel (raw bytes, not msgpack-encoded). + if raw == StageDiffusionProc.DIFFUSION_PROC_DEAD: + self._engine_dead = True + logger.error( + "[StageDiffusionClient] stage-%s [rep-%s] received DIFFUSION_PROC_DEAD sentinel from subprocess.", + self.stage_id, + self.replica_id, + ) + break + msg = self._decoder.decode(raw) msg_type = msg.get("type") @@ -162,8 +218,9 @@ def _drain_responses(self) -> None: rpc_id = msg.get("rpc_id") error_msg = msg.get("error") logger.error( - "[StageDiffusionClient] Stage-%s subprocess error for %s: %s", + "[StageDiffusionClient] stage-%s [rep-%s] subprocess error for %s: %s", self.stage_id, + self.replica_id, rpc_id or req_id, error_msg, ) @@ -173,13 +230,10 @@ def _drain_responses(self) -> None: "error": True, "reason": error_msg, } - elif req_id is not None: - error_output = OmniRequestOutput.from_diffusion( - request_id=req_id, - images=[], - ) - error_output.error = error_msg - self._output_queue.put_nowait(error_output) + # Route request errors as error outputs so the Orchestrator + # sees the request complete (instead of hanging forever). + if req_id is not None: + self._output_queue.put_nowait(OmniRequestOutput.from_error(req_id, error_msg)) # Fields that are subprocess-local and cannot be serialized across # process boundaries. They are recreated in the subprocess with @@ -243,6 +297,14 @@ async def add_request_async( sampling_params: OmniDiffusionSamplingParams, kv_sender_info: dict[int, dict[str, Any]] | None = None, ) -> None: + if self._engine_dead: + raise EngineDeadError() + logger.info( + "[StageDiffusionClient] stage-%s [rep-%s] add request: %s", + self.stage_id, + self.replica_id, + request_id, + ) self._request_socket.send( self._encoder.encode( { @@ -270,6 +332,15 @@ async def add_batch_request_async( and the combined result is placed on the output queue with a single *request_id*. """ + if self._engine_dead: + raise EngineDeadError() + logger.info( + "[StageDiffusionClient] stage-%s [rep-%s] add batch request: %s (%d prompts)", + self.stage_id, + self.replica_id, + request_id, + len(prompts), + ) task = asyncio.create_task( self._run_batch( request_id, @@ -302,11 +373,13 @@ async def _run_batch( ) except Exception as e: logger.exception( - "[StageDiffusionClient] Stage-%s batch req=%s failed: %s", + "[StageDiffusionClient] stage-%s [rep-%s] batch req=%s failed: %s", self.stage_id, + self.replica_id, request_id, e, ) + await self._output_queue.put(OmniRequestOutput.from_error(request_id, str(e))) finally: self._tasks.pop(request_id, None) @@ -315,7 +388,10 @@ def get_diffusion_output_nowait(self) -> OmniRequestOutput | None: try: return self._output_queue.get_nowait() except asyncio.QueueEmpty: + if self._engine_dead: + raise EngineDeadError() if not self._shutting_down and self._owns_process and self._proc is not None and not self._proc.is_alive(): + self._engine_dead = True exitcode = self._proc.exitcode # One final drain – the last ZMQ frame may have arrived # between the first drain and the is_alive() check. @@ -329,7 +405,7 @@ def get_diffusion_output_nowait(self) -> OmniRequestOutput | None: logger.warning("StageDiffusionProc was killed by signal %d; treating as external shutdown.", sig) self._shutting_down = True return None - raise RuntimeError(f"StageDiffusionProc died unexpectedly (exit code {exitcode})") + raise EngineDeadError(f"StageDiffusionProc died unexpectedly (exit code {exitcode})") return None async def abort_requests_async(self, request_ids: list[str]) -> None: @@ -350,13 +426,16 @@ async def collective_rpc_async( kwargs: dict[str, Any] | None = None, ) -> Any: """Forward control RPCs to the diffusion subprocess.""" + if self._engine_dead: + raise EngineDeadError() + # Inject a default profile_prefix that includes stage_id when profiling. if method == "profile": args_list = list(args) is_start = args_list[0] if args_list else True profile_prefix = args_list[1] if len(args_list) > 1 else None if is_start and profile_prefix is None: - profile_prefix = f"stage_{self.stage_id}_diffusion_{int(time.time())}" + profile_prefix = f"stage_{self.stage_id}_rep_{self.replica_id}_diffusion_{int(time.time())}" if len(args_list) > 1: args_list[1] = profile_prefix else: @@ -387,8 +466,9 @@ async def collective_rpc_async( self._drain_responses() if rpc_id in self._rpc_results: return self._rpc_results.pop(rpc_id) - if self._owns_process and self._proc is not None and not self._proc.is_alive(): - raise RuntimeError( + if self._engine_dead or (self._owns_process and self._proc is not None and not self._proc.is_alive()): + self._engine_dead = True + raise EngineDeadError( f"StageDiffusionProc died while waiting for " f"collective_rpc '{method}' (exit code {self._proc.exitcode})" ) @@ -398,6 +478,19 @@ async def collective_rpc_async( finally: self._pending_rpcs.discard(rpc_id) + def check_health(self) -> None: + """Raise ``EngineDeadError`` if the diffusion engine is dead. + + Mirrors the ``check_health`` protocol on vLLM's ``EngineClient``. + """ + if self._engine_dead: + raise EngineDeadError(f"Stage-{self.stage_id} diffusion subprocess is dead") + if self._proc is not None and not self._proc.is_alive(): + self._engine_dead = True + raise EngineDeadError( + f"Stage-{self.stage_id} diffusion subprocess is not alive (exit code: {self._proc.exitcode})." + ) + def shutdown(self) -> None: self._shutting_down = True try: diff --git a/vllm_omni/diffusion/stage_diffusion_proc.py b/vllm_omni/diffusion/stage_diffusion_proc.py index eced444fd32..d36c5c644c7 100644 --- a/vllm_omni/diffusion/stage_diffusion_proc.py +++ b/vllm_omni/diffusion/stage_diffusion_proc.py @@ -46,6 +46,8 @@ class StageDiffusionProc: and ZMQ-based communication with StageDiffusionClient. """ + DIFFUSION_PROC_DEAD = b"DIFFUSION_PROC_DEAD" + def __init__(self, model: str, od_config: OmniDiffusionConfig) -> None: self._model = model self._od_config = od_config @@ -450,6 +452,16 @@ async def _dispatch_batch( elif msg_type == "shutdown": break + except Exception: + # Send the death sentinel so the client can detect the + # fatal failure promptly (mirrors EngineCoreProc._send_engine_dead). + try: + response_socket.setsockopt(zmq.LINGER, 4000) + await response_socket.send(StageDiffusionProc.DIFFUSION_PROC_DEAD) + except Exception: + logger.warning("Failed to send DIFFUSION_PROC_DEAD sentinel to client.") + raise + finally: for task in tasks.values(): task.cancel() diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py index 160309e0d8d..621f7eebd86 100644 --- a/vllm_omni/diffusion/worker/diffusion_worker.py +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -27,7 +27,10 @@ from vllm_omni.diffusion.data import ( DiffusionOutput, + OmniACK, OmniDiffusionConfig, + OmniSleepTask, + OmniWakeTask, ) from vllm_omni.diffusion.distributed.parallel_state import ( destroy_distributed_env, @@ -77,6 +80,7 @@ def __init__( self.model_runner: DiffusionModelRunner | None = None self._sleep_saved_buffers: dict[str, torch.Tensor] = {} self.lora_manager: DiffusionLoRAManager | None = None + self.stage_id = getattr(od_config, "stage_id", 0) self.init_device() # Create model runner self.model_runner = DiffusionModelRunner( @@ -159,8 +163,10 @@ def _create_profiler(self) -> WorkerProfiler | None: def _get_profiler(self) -> WorkerProfiler | None: return getattr(self, "profiler", None) - def load_model(self, load_format: str = "default", custom_pipeline_name: str | None = None) -> None: + def load_model(self, load_format: str = "default", custom_pipeline_name: str | None = None, **kwargs) -> None: """Load the diffusion model using DiffusionModelRunner.""" + load_format = kwargs.get("load_format", load_format) + custom_pipeline_name = kwargs.get("custom_pipeline_name", custom_pipeline_name) with ( set_forward_context(vllm_config=self.vllm_config, omni_diffusion_config=self.od_config), set_current_vllm_config(self.vllm_config), @@ -170,6 +176,8 @@ def load_model(self, load_format: str = "default", custom_pipeline_name: str | N load_format=load_format, custom_pipeline_name=custom_pipeline_name, ) + current_omni_platform.synchronize() + gc.collect() process_memory = get_process_gpu_memory(self.local_rank) if process_memory is not None: logger.info( @@ -284,77 +292,176 @@ def sleep(self, level: int = 1) -> bool: """ from vllm.device_allocator.cumem import CuMemAllocator - process_memory_before_sleep = get_process_gpu_memory(self.local_rank) - free_bytes_before_sleep = None - if process_memory_before_sleep is None: - free_bytes_before_sleep = current_omni_platform.get_free_memory() + allocator = CuMemAllocator.get_instance() + + usage_before = allocator.get_current_usage() - # Save the buffers before level 2 sleep if level == 2 and self.model_runner is not None: + if hasattr(self.model_runner, "graph_runners"): + self.model_runner.graph_runners.clear() + logger.info(f"[Worker {self.rank}] CUDA Graphs cleared.") model = self.model_runner.pipeline self._sleep_saved_buffers = {name: buffer.cpu().clone() for name, buffer in model.named_buffers()} - allocator = CuMemAllocator.get_instance() - allocator.sleep(offload_tags=("weights",) if level == 1 else tuple()) - process_memory_after_sleep = get_process_gpu_memory(self.local_rank) - if process_memory_before_sleep is not None and process_memory_after_sleep is not None: - freed_bytes = process_memory_before_sleep - process_memory_after_sleep - used_bytes = process_memory_after_sleep - accounting_scope = "process-scoped" + free_mem_before = current_omni_platform.get_free_memory() + + # Level 1: Offload weights; Level 2: Total Discard + offload_tags = ("weights",) if level == 1 else tuple() + allocator.sleep(offload_tags=offload_tags) + + current_omni_platform.empty_cache() + current_omni_platform.synchronize() + + free_mem_after = current_omni_platform.get_free_memory() + try: + total_mem = current_omni_platform.get_device_total_memory() + except (NotImplementedError, AttributeError): + total_mem = torch.cuda.get_device_properties(self.device).total_memory + + phys_freed_bytes = max(0, free_mem_after - free_mem_before) + phys_used_bytes = total_mem - free_mem_after + + if usage_before > 0: + logger.info( + f"[Diffusion Worker {self.rank}] Sleep Level {level}: " + f"physically freed {phys_freed_bytes / GiB_bytes:.2f} GiB, " + f"{phys_used_bytes / GiB_bytes:.2f} GiB is still in use." + ) else: - free_bytes_after_sleep = current_omni_platform.get_free_memory() - assert free_bytes_before_sleep is not None - device_id = self.device.index if self.device.index is not None else 0 - total = current_omni_platform.get_device_total_memory(device_id) - freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep - used_bytes = total - free_bytes_after_sleep - accounting_scope = "device-scoped fallback" - assert freed_bytes >= 0, "Memory usage increased after sleeping." - logger.info( - "Sleep mode (%s) freed %.2f GiB memory, %.2f GiB memory is still in use.", - accounting_scope, - freed_bytes / GiB_bytes, - used_bytes / GiB_bytes, - ) - return True + logger.info(f"[Worker {self.rank}] Sleep Level {level} completed (GPU was already empty).") + logger.info(f"[Worker {self.rank}] Memory usage before sleep: {usage_before / GiB_bytes:.2f} GiB.") + return usage_before def wake_up(self, tags: list[str] | None = None) -> bool: """ - Wake up the worker from sleep mode. See the sleep function - method for more details. + Wake up the worker from sleep mode. + + Re-activates the memory allocator for the specified tags and restores + model buffers from CPU back to GPU if they were saved during Level 2 sleep. Args: - tags: An optional list of tags to reallocate the worker memory - for specific memory allocations. Values must be in - `("weights")`. If None, all memory is reallocated. - wake_up should be called with all tags (or None) before the - worker is used again. + tags: List of memory pool tags to re-activate (e.g., ["weights"] + to match Level 1 sleep). If None, all pools are re-activated. """ from vllm.device_allocator.cumem import CuMemAllocator allocator = CuMemAllocator.get_instance() allocator.wake_up(tags) - - # Restore the buffers after level 2 sleep + current_omni_platform.synchronize() if len(self._sleep_saved_buffers) and self.model_runner is not None: model = self.model_runner.pipeline for name, buffer in model.named_buffers(): if name in self._sleep_saved_buffers: buffer.data.copy_(self._sleep_saved_buffers[name].data) self._sleep_saved_buffers = {} + logger.info(f"[Worker {self.rank}] Buffers restored from CPU.") + logger.info(f"[Worker {self.rank}] Wake-up complete.") return True + def handle_sleep_task(self, task: OmniSleepTask) -> OmniACK: + from vllm_omni.platforms import current_omni_platform + + try: + if isinstance(task, dict): + task = OmniSleepTask(**task) + logger.info(f"[Worker {self.rank}] Handshake Received: Task {task.task_id}") + + current_omni_platform.synchronize() + usage_before = current_omni_platform.get_current_memory_usage(self.device) + self.sleep(level=task.level) + current_omni_platform.synchronize() + usage_after = current_omni_platform.get_current_memory_usage(self.device) + real_freed = max(0, usage_before - usage_after) + logger.info(f"[Worker {self.rank}] Preparing ACK: freed_bytes={real_freed / GiB_bytes:.2f} GiB.") + + # Ensure all ranks have completed sleep before measuring memory and sending ACK + if torch.distributed.is_initialized(): + t_freed = torch.tensor([float(real_freed)], device=self.device) + torch.distributed.all_reduce(t_freed) + real_freed = int(t_freed.item()) + + if self.rank != 0: + return None + + ack = OmniACK( + task_id=task.task_id, + status="SUCCESS", + stage_id=self.stage_id, + rank=self.rank, + freed_bytes=real_freed, + # return RL need metadata + metadata={ + "source": f"Platform_{current_omni_platform.get_device_name()}", + "total_freed_gib": f"{real_freed / GiB_bytes:.2f}", + "rank_residual_gib": f"{usage_after / GiB_bytes:.2f}", + }, + ) + logger.info(f"[Worker {self.rank}] ACK emitted. Freed {real_freed / GiB_bytes:.2f} GiB.") + return ack + except Exception as e: + logger.error(f"Sleep failed: {e}", exc_info=True) + if torch.distributed.is_initialized(): + try: + torch.distributed.barrier() + except Exception: + pass + return OmniACK(task_id=task.task_id, status="ERROR", error_msg=str(e)) + + def handle_wake_task(self, task: OmniWakeTask) -> OmniACK: + from vllm_omni.platforms import current_omni_platform + + try: + if isinstance(task, dict): + task = OmniWakeTask(**task) + logger.info(f"[Worker {self.rank}] Responding to Wake-up Task: {task.task_id}") + self.wake_up(tags=task.tags) + + logger.info(f"[Worker {self.rank}] wake_up logic finished, entering barrier...") + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + current_omni_platform.synchronize() + usage_now = current_omni_platform.get_current_memory_usage(self.device) + current_used_gib = usage_now / (1024**3) + + if self.rank != 0: + return None + logger.info(f"[Worker {self.rank}] PASSED barrier, about to return to loop.") + + return OmniACK( + task_id=task.task_id, + status="SUCCESS", + stage_id=self.stage_id, + rank=self.rank, + metadata={ + "state": "WARM", + "source": f"Platform_{current_omni_platform.get_device_name()}", + "current_vram_gib": f"{current_used_gib:.2f}", + }, + ) + except Exception as e: + logger.error(f"Wake-up failed on Rank {self.rank}: {e}", exc_info=True) + if torch.distributed.is_initialized(): + try: + torch.distributed.barrier() + except Exception: + pass + return OmniACK(task_id=task.task_id, status="ERROR", error_msg=str(e)) + def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: """Get memory pool context for sleep mode support.""" - if self.od_config.enable_sleep_mode: + is_sleep_enabled = getattr(self.od_config, "enable_sleep_mode", False) + if is_sleep_enabled: + current_omni_platform.synchronize() + gc.collect() from vllm.device_allocator.cumem import CuMemAllocator allocator = CuMemAllocator.get_instance() if tag == "weights": assert allocator.get_current_usage() == 0, "Sleep mode can only be used for one instance per process." + logger.info(f"[Worker {self.rank}] Activating Diffusion CuMem pool for tag: {tag}") return allocator.use_memory_pool(tag=tag) - else: - return nullcontext() + return nullcontext() def shutdown(self) -> None: """Shutdown the worker and cleanup distributed environment.""" @@ -395,11 +502,13 @@ def __init__( od_config: OmniDiffusionConfig, gpu_id: int, broadcast_handle, + wake_event: mp.Event, worker_extension_cls: str | None = None, custom_pipeline_args: dict[str, Any] | None = None, ): self.od_config = od_config self.gpu_id = gpu_id + self.wake_event = wake_event # Inter-process Communication self.context = zmq.Context(io_threads=2) @@ -414,7 +523,13 @@ def __init__( if gpu_id == 0: self.result_mq = MessageQueue(n_reader=1, n_local_reader=1, local_reader_ranks=[0]) self.result_mq_handle = self.result_mq.export_handle() + WorkerProc._shared_result_handle = self.result_mq_handle logger.info(f"Worker {gpu_id} created result MessageQueue") + else: + handle = getattr(WorkerProc, "_shared_result_handle", None) + if handle: + self.result_mq = MessageQueue.create_from_handle(handle, gpu_id) + logger.info(f"Worker {gpu_id} attached to shared result MessageQueue") assert od_config.master_port is not None @@ -438,13 +553,17 @@ def _create_worker( ) return wrapper - def return_result(self, output: object): + def return_result(self, output: Any): """Reply to client, only on rank 0.""" if self.result_mq is not None: + if isinstance(output, OmniACK): + self.result_mq.enqueue(output) + return try: pack_diffusion_output_shm(output) except Exception as e: - logger.warning("SHM pack failed, falling back to raw enqueue: %s", e) + if hasattr(output, "output"): + logger.warning("SHM pack failed for model output: %s", e) self.result_mq.enqueue(output) def recv_message(self): @@ -480,20 +599,31 @@ def worker_busy_loop(self) -> None: while self._running: msg = None try: - msg = self.recv_message() - except Exception as e: - logger.error( - f"Error receiving message in worker loop: {e}", - exc_info=True, - ) + msg = self.mq.dequeue(timeout=1.0) + except Exception: + if self.wake_event and self.wake_event.is_set(): + self.wake_event.clear() + logger.info(f"Worker {self.gpu_id} caught OOB POKE, forcing wake-up sequence.") + msg = {"type": "wake_up", "task_id": "recovery-task", "tags": None} + else: + continue + if msg is None: continue if msg is None or len(msg) == 0: logger.warning("Worker %s: Received empty payload, ignoring", self.gpu_id) continue + if isinstance(msg, dict) and msg.get("type") == "sleep": + task = OmniSleepTask(level=msg.get("level", 2), task_id=msg.get("task_id", "local")) + ack = self.worker.handle_sleep_task(task) + self.return_result(ack) + elif isinstance(msg, dict) and msg.get("type") == "wake_up": + task = OmniWakeTask(tags=msg.get("tags"), task_id=msg.get("task_id", "local")) + ack = self.worker.handle_wake_task(task) + self.return_result(ack) # Route message based on type - if isinstance(msg, dict) and msg.get("type") == "rpc": + elif isinstance(msg, dict) and msg.get("type") == "rpc": try: result, should_reply = self.execute_rpc(msg) if should_reply: @@ -538,6 +668,7 @@ def worker_main( od_config: OmniDiffusionConfig, pipe_writer: mp.connection.Connection, broadcast_handle, + wake_event: mp.Event, worker_extension_cls: str | None = None, custom_pipeline_args: dict[str, Any] | None = None, ) -> None: @@ -549,6 +680,7 @@ def worker_main( od_config, gpu_id=rank, broadcast_handle=broadcast_handle, + wake_event=wake_event, worker_extension_cls=worker_extension_cls, custom_pipeline_args=custom_pipeline_args, ) @@ -574,6 +706,7 @@ def __init__( gpu_id: int, od_config: OmniDiffusionConfig, base_worker_class: type = DiffusionWorker, + wake_event: mp.Event = None, worker_extension_cls: str | None = None, custom_pipeline_args: dict[str, Any] | None = None, ): @@ -734,6 +867,12 @@ def wake_up(self, tags: list[str] | None = None) -> bool: """ return self.worker.wake_up(tags) + def handle_sleep_task(self, task): + return self.worker.handle_sleep_task(task) + + def handle_wake_task(self, task): + return self.worker.handle_wake_task(task) + def shutdown(self) -> None: """Shutdown the worker and cleanup resources.""" return self.worker.shutdown() diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py index 6c92d7952de..55d119ad65c 100644 --- a/vllm_omni/engine/__init__.py +++ b/vllm_omni/engine/__init__.py @@ -76,6 +76,42 @@ class OmniEngineCoreRequest(EngineCoreRequest): # Optional additional information dictionary (serialized) additional_information: AdditionalInformationPayload | None = None + @classmethod + def from_request( + cls, + request: EngineCoreRequest, + *, + prompt_embeds: torch.Tensor | None = None, + additional_information: AdditionalInformationPayload | None = None, + ) -> "OmniEngineCoreRequest": + """Clone an EngineCoreRequest into an OmniEngineCoreRequest with optional payload overrides.""" + + if prompt_embeds is None: + prompt_embeds = request.prompt_embeds + if additional_information is None: + additional_information = getattr(request, "additional_information", None) + + return cls( + request_id=request.request_id, + prompt_token_ids=request.prompt_token_ids, + mm_features=request.mm_features, + sampling_params=request.sampling_params, + pooling_params=request.pooling_params, + arrival_time=request.arrival_time, + lora_request=request.lora_request, + cache_salt=request.cache_salt, + data_parallel_rank=request.data_parallel_rank, + prompt_embeds=prompt_embeds, + client_index=request.client_index, + current_wave=request.current_wave, + priority=request.priority, + trace_headers=request.trace_headers, + resumable=request.resumable, + external_req_id=request.external_req_id, + reasoning_ended=request.reasoning_ended, + additional_information=additional_information, + ) + class OmniEngineCoreOutput(EngineCoreOutput): pooling_output: dict[str, torch.Tensor] | None = None diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index d98ce7d419d..c93bd32c2f1 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -6,11 +6,12 @@ from dataclasses import dataclass, field, fields from typing import Any -from vllm.engine.arg_utils import EngineArgs +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.logger import init_logger from vllm_omni.config import OmniModelConfig from vllm_omni.engine.output_modality import OutputModality +from vllm_omni.platforms import current_omni_platform from vllm_omni.plugins import load_omni_general_plugins logger = init_logger(__name__) @@ -139,11 +140,30 @@ class OmniEngineArgs(EngineArgs): hf_config_name: str | None = None custom_process_next_stage_input_func: str | None = None stage_connector_spec: dict[str, Any] = field(default_factory=dict) + subtalker_sampling_params: dict[str, Any] | None = None async_chunk: bool = False omni_kv_config: dict | None = None quantization_config: Any | None = None worker_type: str | None = None task_type: str | None = None + worker_cls: str = None + enable_sleep_mode: bool = False + omni: bool = False + + @classmethod + def _add_omni_specific_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + try: + parser.add_argument("--omni", action="store_true", default=False, help="Enable Omni engine features.") + except argparse.ArgumentError: + pass + try: + parser.add_argument( + "--enable-sleep-mode", action="store_true", default=False, help="Enable GPU memory pool for sleep mode." + ) + except argparse.ArgumentError: + pass + return parser + omni_master_address: str | None = None omni_master_port: int | None = None stage_configs_path: str | None = None @@ -152,6 +172,11 @@ class OmniEngineArgs(EngineArgs): custom_pipeline_args: dict[str, Any] | None = None def __post_init__(self) -> None: + if self.worker_cls is None: + if self.worker_type == "ar": + self.worker_cls = current_omni_platform.get_omni_ar_worker_cls() + elif self.worker_type == "generation": + self.worker_cls = current_omni_platform.get_omni_generation_worker_cls() load_omni_general_plugins() super().__post_init__() @@ -291,11 +316,21 @@ def create_model_config(self) -> OmniModelConfig: hf_config_name=self.hf_config_name, custom_process_next_stage_input_func=self.custom_process_next_stage_input_func, stage_connector_config=stage_connector_config, + subtalker_sampling_params=self.subtalker_sampling_params, omni_kv_config=self.omni_kv_config, task_type=self.task_type, ) return omni_config + +@dataclass +class OmniAsyncEngineArgs(AsyncEngineArgs, OmniEngineArgs): + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser = AsyncEngineArgs.add_cli_args(parser) + parser = OmniEngineArgs._add_omni_specific_args(parser) + return parser + @property def output_modality(self) -> OutputModality: """Parse engine_output_type into a type-safe OutputModality flag.""" diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index d02af153c45..32450bf2024 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -30,7 +30,6 @@ from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType from vllm.logger import init_logger -from vllm.tokenizers import cached_tokenizer_from_config from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.input_processor import InputProcessor @@ -46,7 +45,6 @@ ) from vllm_omni.engine import OmniEngineCoreRequest from vllm_omni.engine.orchestrator import Orchestrator -from vllm_omni.engine.output_processor import MultimodalOutputProcessor from vllm_omni.engine.serialization import ( deserialize_additional_information, serialize_additional_information, @@ -63,17 +61,17 @@ register_stage_with_omni_master, ) from vllm_omni.engine.stage_init_utils import ( - StartedLlmStage, + LogicalStageInitPlan, + ReplicaInitPlan, _inject_inferred_kv_tp_topology, acquire_device_locks, build_diffusion_config, build_engine_args_dict, + build_llm_stage_output_processor, + build_stage0_input_processor, build_vllm_config, - cleanup_failed_stage_initialization, - close_started_llm_stage, compute_replica_layout, extract_stage_metadata, - finalize_initialized_stages, get_stage_connector_spec, initialize_diffusion_stage, inject_kv_stage_info, @@ -89,7 +87,6 @@ inject_omni_kv_config, load_and_resolve_stage_configs, ) -from vllm_omni.inputs.preprocess import OmniInputPreprocessor from vllm_omni.platforms import current_omni_platform if TYPE_CHECKING: @@ -97,6 +94,8 @@ logger = init_logger(__name__) +_STARTUP_POLL_INTERVAL_S = 1.0 + # ============================================================================ # Parent-EngineArgs field-routing contracts (consumed by @@ -119,16 +118,6 @@ _PARENT_ARGS_NO_WARN: frozenset[str] = frozenset({"model"}) -def _patch_generation_config_if_needed(model_config: Any) -> None: - """Ensure try_get_generation_config won't crash for models whose HF - config.json lacks model_type (e.g. CosyVoice3). We probe it once; - if it raises, we monkey-patch the method to return None.""" - try: - model_config.try_get_generation_config() - except Exception: - model_config.try_get_generation_config = lambda: {} - - def _inject_global_id(target: Any, request_id: str) -> None: """Inject global_request_id into a prompt dict's additional_information.""" if isinstance(target, dict): @@ -161,24 +150,9 @@ def _upgrade_to_omni_request( if prompt_embeds is None and additional_information is None: return request - return OmniEngineCoreRequest( - request_id=request.request_id, - prompt_token_ids=request.prompt_token_ids, - mm_features=request.mm_features, - sampling_params=request.sampling_params, - pooling_params=request.pooling_params, - arrival_time=request.arrival_time, - lora_request=request.lora_request, - cache_salt=request.cache_salt, - data_parallel_rank=request.data_parallel_rank, + return OmniEngineCoreRequest.from_request( + request, prompt_embeds=prompt_embeds, - client_index=request.client_index, - current_wave=request.current_wave, - priority=request.priority, - trace_headers=request.trace_headers, - resumable=request.resumable, - external_req_id=request.external_req_id, - reasoning_ended=request.reasoning_ended, additional_information=additional_information, ) @@ -193,24 +167,8 @@ def _apply_omni_final_stage_metadata( merged = deserialize_additional_information(request.additional_information) merged["omni_final_stage_id"] = final_stage_id payload = serialize_additional_information(merged) - return OmniEngineCoreRequest( - request_id=request.request_id, - prompt_token_ids=request.prompt_token_ids, - mm_features=request.mm_features, - sampling_params=request.sampling_params, - pooling_params=request.pooling_params, - arrival_time=request.arrival_time, - lora_request=request.lora_request, - cache_salt=request.cache_salt, - data_parallel_rank=request.data_parallel_rank, - prompt_embeds=request.prompt_embeds, - client_index=request.client_index, - current_wave=request.current_wave, - priority=request.priority, - trace_headers=request.trace_headers, - resumable=request.resumable, - external_req_id=request.external_req_id, - reasoning_ended=request.reasoning_ended, + return OmniEngineCoreRequest.from_request( + request, additional_information=payload, ) @@ -311,6 +269,7 @@ def __init__( ) self.config_path, self.stage_configs = self._resolve_stage_configs(model, kwargs) + self._validate_single_stage_mode_replica_constraints() self.num_stages = len(self.stage_configs) stage0_args = getattr(self.stage_configs[0], "engine_args", None) if self.num_stages > 0 else None @@ -343,22 +302,7 @@ def __init__( name="orchestrator", ) self.orchestrator_thread.start() - - # Wait for stage/runtime initialization result from orchestrator thread. - try: - startup_future.result(timeout=startup_timeout) - except concurrent.futures.TimeoutError as e: - try: - self.shutdown() - except Exception: - logger.exception("[AsyncOmniEngine] Failed to cleanup after orchestrator startup timeout") - raise TimeoutError(f"Orchestrator did not become ready within {startup_timeout}s") from e - except Exception: - try: - self.shutdown() - except Exception: - logger.exception("[AsyncOmniEngine] Failed to cleanup after orchestrator startup failure") - raise + self._wait_for_orchestrator_init(startup_future, startup_timeout) # Stage runtime fields are assigned directly on self by the bootstrap thread. self._weak_finalizer = weakref.finalize( @@ -372,613 +316,658 @@ def __init__( logger.info(f"[AsyncOmniEngine] Orchestrator ready with {self.num_stages} stages") - def _launch_llm_stage( + @staticmethod + def _cleanup_launched_llm_resources( + *, + stage_id: int, + proc: Any = None, + engine_manager: Any = None, + coordinator: Any = None, + ) -> None: + """Release launch-only LLM resources when client creation never completed.""" + + if proc is not None: + try: + terminate_alive_proc(proc) + except Exception as cleanup_error: + logger.warning( + "[AsyncOmniEngine] Failed to terminate process for stage %s: %s", + stage_id, + cleanup_error, + ) + + for resource, resource_name in ( + (engine_manager, "engine manager"), + (coordinator, "coordinator"), + ): + if resource is None: + continue + shutdown = getattr(resource, "shutdown", None) + close = getattr(resource, "close", None) + try: + if callable(shutdown): + shutdown() + elif callable(close): + close() + except Exception as cleanup_error: + logger.warning( + "[AsyncOmniEngine] Failed to cleanup launched %s for stage %s: %s", + resource_name, + stage_id, + cleanup_error, + ) + + @staticmethod + def _collect_initialized_clients_for_cleanup( + stage_pools: Sequence[Any], + initialized_clients_by_stage: Mapping[int, Sequence[Any | None]], + ) -> list[Any]: + """Collect initialized clients exactly once for failure cleanup.""" + + collected: list[Any] = [] + seen: set[int] = set() + + def _add_client(client: Any) -> None: + if client is None: + return + client_id = id(client) + if client_id in seen: + return + seen.add(client_id) + collected.append(client) + + for pool in stage_pools: + for client in getattr(pool, "clients", ()): + _add_client(client) + + for clients in initialized_clients_by_stage.values(): + for client in clients: + _add_client(client) + + return collected + + @staticmethod + def _shutdown_initialized_clients(clients: Sequence[Any]) -> None: + """Best-effort shutdown for attached clients after init failure.""" + + for client in reversed(list(clients)): + if client is None: + continue + try: + client.shutdown() + except Exception as cleanup_error: + logger.warning( + "[AsyncOmniEngine] Failed to shutdown initialized client after init failure: %s", + cleanup_error, + ) + + def _validate_single_stage_mode_replica_constraints(self) -> None: + """Reject replica fan-out in single-stage mode until startup is replica-aware.""" + if not self.single_stage_mode: + return + + unsupported: list[tuple[int, int]] = [] + for idx, stage_cfg in enumerate(self.stage_configs): + runtime_cfg = getattr(stage_cfg, "runtime", {}) + num_replicas = int( + runtime_cfg.get("num_replicas", 1) + if hasattr(runtime_cfg, "get") + else getattr(runtime_cfg, "num_replicas", 1) + ) + if num_replicas <= 1: + continue + stage_id = int(getattr(stage_cfg, "stage_id", idx)) + unsupported.append((stage_id, num_replicas)) + + if unsupported: + # TODO(Peiqi): single_stage_mode / headless launch still assumes one + # OmniMasterServer endpoint allocation per logical stage. Support + # per-replica startup only after remote/local startup paths are made + # replica-aware end-to-end. + raise ValueError(f"single_stage_mode does not support num_replicas > 1 yet; found {unsupported}") + + def _build_logical_stage_init_plans( + self, + omni_transfer_config: Any, + replicas_per_stage: Sequence[int], + replica_devices_map: Mapping[int, Sequence[str]], + ) -> tuple[list[LogicalStageInitPlan], Any]: + """Build startup plans for every logical stage and replica.""" + + prompt_expand_func = None + stage_plans: list[LogicalStageInitPlan] = [] + + for stage_idx, stage_cfg in enumerate(self.stage_configs): + base_metadata = extract_stage_metadata(stage_cfg) + configured_stage_id = base_metadata.stage_id + if base_metadata.prompt_expand_func is not None: + prompt_expand_func = base_metadata.prompt_expand_func + + stage_connector_spec = get_stage_connector_spec( + omni_transfer_config=omni_transfer_config, + stage_id=configured_stage_id, + async_chunk=self.async_chunk, + ) + omni_kv_connector = resolve_omni_kv_config_for_stage(omni_transfer_config, configured_stage_id) + num_replicas = replicas_per_stage[stage_idx] + launch_mode = "local" + if ( + self.single_stage_mode + and self._single_stage_id_filter is not None + and configured_stage_id != self._single_stage_id_filter + ): + launch_mode = "remote" + + replicas: list[ReplicaInitPlan] = [] + stage_vllm_config = None + executor_class = None + if base_metadata.stage_type != "diffusion": + engine_args_dict = build_engine_args_dict( + stage_cfg, + self.model, + stage_connector_spec=stage_connector_spec, + ) + omni_conn_cfg, omni_from, omni_to = omni_kv_connector + if omni_conn_cfg: + omni_kv = engine_args_dict.get("omni_kv_config") or {} + if not isinstance(omni_kv, dict): + omni_kv = dict(omni_kv) + omni_kv["connector_config"] = omni_conn_cfg + omni_kv["omni_from_stage"] = omni_from + omni_kv["omni_to_stage"] = omni_to + omni_kv.setdefault("stage_id", configured_stage_id) + engine_args_dict["omni_kv_config"] = omni_kv + if self.stage_configs: + _inject_inferred_kv_tp_topology( + engine_args_dict.get("omni_kv_config"), + configured_stage_id, + self.stage_configs, + ) + stage_vllm_config, executor_class = build_vllm_config( + stage_cfg, + self.model, + stage_connector_spec=stage_connector_spec, + engine_args_dict=engine_args_dict, + ) + + for replica_id in range(num_replicas): + replica_cfg = copy.deepcopy(stage_cfg) if replica_id > 0 else stage_cfg + if stage_idx in replica_devices_map: + replica_cfg.runtime.devices = replica_devices_map[stage_idx][replica_id] + + replica_metadata = extract_stage_metadata(replica_cfg) + replica_metadata.replica_id = replica_id + if self.single_stage_mode: + replica_metadata.runtime_cfg = None + + replicas.append( + ReplicaInitPlan( + replica_id=replica_id, + num_replicas=num_replicas, + launch_mode=launch_mode, + stage_cfg=replica_cfg, + metadata=replica_metadata, + stage_connector_spec=stage_connector_spec, + omni_kv_connector=omni_kv_connector, + stage_vllm_config=stage_vllm_config, + executor_class=executor_class, + ) + ) + + stage_plans.append( + LogicalStageInitPlan( + stage_idx=stage_idx, + configured_stage_id=configured_stage_id, + replicas=replicas, + ) + ) + + return stage_plans, prompt_expand_func + + def _start_omni_master_server(self, stage_plans: Sequence[LogicalStageInitPlan]) -> None: + """Start OmniMasterServer for single-stage mode.""" + + if not self._omni_master_address or not self._omni_master_port: + raise ValueError( + "AsyncOmniEngine single_stage_mode requires both omni_master_address and omni_master_port to be set." + ) + + all_stage_ids: list[int] = [] + seen_stage_ids: set[int] = set() + for plan in stage_plans: + stage_id = plan.configured_stage_id + if stage_id in seen_stage_ids: + raise ValueError( + f"Duplicate stage_id {stage_id!r} detected among configured stages; stage_ids must be unique." + ) + seen_stage_ids.add(stage_id) + all_stage_ids.append(stage_id) + + self._omni_master_server = OmniMasterServer( + master_address=self._omni_master_address, + master_port=self._omni_master_port, + stage_ids=all_stage_ids, + ) + self._omni_master_server.start() + logger.info( + "[AsyncOmniEngine] OmniMasterServer started for stages %s", + all_stage_ids, + ) + + def _initialize_llm_replica( self, - stage_cfg: Any, - metadata: Any, - stage_connector_spec: dict[str, Any], + plan: ReplicaInitPlan, stage_init_timeout: int, llm_stage_launch_lock: threading.Lock, - omni_kv_connector: tuple[dict[str, Any] | None, str | None, str | None] = (None, None, None), - ) -> StartedLlmStage: - """Launch one LLM stage to READY state in a helper thread.""" - started_stage: StartedLlmStage | None = None + ) -> Any: + """Initialize one LLM replica end-to-end.""" + + proc = None + engine_manager = None + coordinator = None + stage_client = None lock_fds: list[int] = [] device_control_env = current_omni_platform.device_control_env_var + stage_cfg = plan.stage_cfg + try: - proc = None - handshake_address = None - with ExitStack() as launch_stack: - with llm_stage_launch_lock: - previous_visible_devices = os.environ.get(device_control_env) - try: - setup_stage_devices(metadata.stage_id, metadata.runtime_cfg) - engine_args_dict = build_engine_args_dict( - stage_cfg, - self.model, - stage_connector_spec=stage_connector_spec, - ) - omni_conn_cfg, omni_from, omni_to = omni_kv_connector - if omni_conn_cfg: - omni_kv = engine_args_dict.get("omni_kv_config") or {} - if not isinstance(omni_kv, dict): - omni_kv = dict(omni_kv) - omni_kv["connector_config"] = omni_conn_cfg - omni_kv["omni_from_stage"] = omni_from - omni_kv["omni_to_stage"] = omni_to - omni_kv.setdefault("stage_id", metadata.stage_id) - engine_args_dict["omni_kv_config"] = omni_kv - if self.stage_configs: - _inject_inferred_kv_tp_topology( - engine_args_dict.get("omni_kv_config"), - metadata.stage_id, - self.stage_configs, + if plan.launch_mode == "remote": + assert self._omni_master_server is not None + raw_stage_cfg = self._omni_master_server.get_stage_config( + plan.metadata.stage_id, + timeout_s=stage_init_timeout, + ) + if raw_stage_cfg is None: + raise ValueError(f"Remote stage {plan.metadata.stage_id} registered without stage config") + vllm_config = plan.stage_vllm_config + executor_class = plan.executor_class + assert vllm_config is not None + assert executor_class is not None + vllm_config.parallel_config.data_parallel_size_local = 0 + launch_cm = connect_remote_engine_cores( + vllm_config=vllm_config, + omni_master_server=self._omni_master_server, + stage_id=plan.metadata.stage_id, + ) + logger.info( + "[AsyncOmniEngine] Stage %s remote engine handshake started", + plan.metadata.stage_id, + ) + with launch_cm as (engine_manager, coordinator, addresses): + client_addresses: dict[str, str] = { + "input_address": addresses.inputs[0], + "output_address": addresses.outputs[0], + } + if addresses.frontend_stats_publish_address is not None: + client_addresses["stats_update_address"] = addresses.frontend_stats_publish_address + stage_client = StageEngineCoreClientBase.make_async_mp_client( + vllm_config=vllm_config, + executor_class=executor_class, + metadata=plan.metadata, + client_addresses=client_addresses, + engine_manager=engine_manager, + coordinator=coordinator, + ) + else: + handshake_address = None + with ExitStack() as launch_stack: + with llm_stage_launch_lock: + previous_visible_devices = os.environ.get(device_control_env) + try: + setup_stage_devices(plan.metadata.stage_id, plan.metadata.runtime_cfg) + vllm_config = plan.stage_vllm_config + executor_class = plan.executor_class + assert vllm_config is not None + assert executor_class is not None + engine_args_dict = build_engine_args_dict( + stage_cfg, + self.model, + stage_connector_spec=plan.stage_connector_spec, + ) + lock_fds = acquire_device_locks( + plan.metadata.stage_id, + engine_args_dict, + stage_init_timeout, ) - vllm_config, executor_class = build_vllm_config( - stage_cfg, - self.model, - stage_connector_spec=stage_connector_spec, - engine_args_dict=engine_args_dict, - ) - lock_fds = acquire_device_locks( - metadata.stage_id, - engine_args_dict, - stage_init_timeout, - ) - if self.single_stage_mode and self._omni_master_server is not None: - engine_manager, coordinator, addresses = launch_stack.enter_context( - launch_omni_core_engines( + if self.single_stage_mode and self._omni_master_server is not None: + engine_manager, coordinator, addresses = launch_stack.enter_context( + launch_omni_core_engines( + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=False, + omni_master_server=self._omni_master_server, + stage_id=plan.metadata.stage_id, + stage_config=stage_cfg, + ) + ) + else: + addresses, proc, handshake_address = spawn_stage_core( vllm_config=vllm_config, executor_class=executor_class, log_stats=False, - omni_master_server=self._omni_master_server, - stage_id=metadata.stage_id, - stage_config=stage_cfg, ) + logger.info( + "[AsyncOmniEngine] Stage %s engine launch started", + plan.metadata.stage_id, ) - started_stage = StartedLlmStage( - stage_id=metadata.stage_id, - metadata=metadata, - vllm_config=vllm_config, - executor_class=executor_class, - addresses=addresses, - engine_manager=engine_manager, - coordinator=coordinator, - ) - else: - addresses, proc, handshake_address = spawn_stage_core( - vllm_config=vllm_config, - executor_class=executor_class, - log_stats=False, - ) - started_stage = StartedLlmStage( - stage_id=metadata.stage_id, - metadata=metadata, - vllm_config=vllm_config, - executor_class=executor_class, - addresses=addresses, - proc=proc, - ) - logger.info("[AsyncOmniEngine] Stage %s engine launch started", metadata.stage_id) - finally: - if previous_visible_devices is None: - current_omni_platform.unset_device_control_env_var() - else: - current_omni_platform.set_device_control_env_var(previous_visible_devices) + finally: + if previous_visible_devices is None: + current_omni_platform.unset_device_control_env_var() + else: + current_omni_platform.set_device_control_env_var(previous_visible_devices) + + if self.single_stage_mode and self._omni_master_server is not None: + launch_stack.close() + else: + assert proc is not None + assert handshake_address is not None + complete_stage_handshake(proc, handshake_address, addresses, vllm_config, stage_init_timeout) + logger.info( + "[AsyncOmniEngine] Stage %s engine startup completed", + plan.metadata.stage_id, + ) + + client_addresses: dict[str, str] = { + "input_address": addresses.inputs[0], + "output_address": addresses.outputs[0], + } + if addresses.frontend_stats_publish_address is not None: + client_addresses["stats_update_address"] = addresses.frontend_stats_publish_address + stage_client = StageEngineCoreClientBase.make_async_mp_client( + vllm_config=vllm_config, + executor_class=executor_class, + metadata=plan.metadata, + client_addresses=client_addresses, + proc=proc, + engine_manager=engine_manager, + coordinator=coordinator, + ) - # After StageEngineCoreProc has been spawned it carries its - # stage-specific device visibility into descendants, so the - # slow HELLO/READY handshake can run without holding the - # process-wide launch lock. - if self.single_stage_mode and self._omni_master_server is not None: - launch_stack.close() - else: - assert proc is not None - assert handshake_address is not None - complete_stage_handshake(proc, handshake_address, addresses, vllm_config, stage_init_timeout) - logger.info("[AsyncOmniEngine] Stage %s engine startup completed", metadata.stage_id) - - assert started_stage is not None - return started_stage + logger.info("[AsyncOmniEngine] Stage %s initialized", plan.metadata.stage_id) + return stage_client except Exception: - if started_stage is not None: - close_started_llm_stage(started_stage) + if stage_client is not None: + try: + stage_client.shutdown() + except Exception as cleanup_error: + logger.warning( + "[AsyncOmniEngine] Failed to cleanup stage %s after attach failure: %s", + plan.metadata.stage_id, + cleanup_error, + ) + else: + self._cleanup_launched_llm_resources( + stage_id=plan.metadata.stage_id, + proc=proc, + engine_manager=engine_manager, + coordinator=coordinator, + ) raise finally: if lock_fds: release_device_locks(lock_fds) - def _create_remote_llm_stage( + def _initialize_diffusion_replica( self, - stage_cfg: Any, - metadata: Any, - stage_connector_spec: dict[str, Any], + plan: ReplicaInitPlan, stage_init_timeout: int, - omni_master_server: OmniMasterServer, - ) -> StartedLlmStage: - """Attach to a remote engine core and wait for its startup handshake.""" - started_stage: StartedLlmStage | None = None - try: - raw_stage_cfg = omni_master_server.get_stage_config( - metadata.stage_id, - timeout_s=stage_init_timeout, - ) - if raw_stage_cfg is None: - raise ValueError(f"Remote stage {metadata.stage_id} registered without stage config") - stage_cfg = OmegaConf.create(raw_stage_cfg) - engine_args_dict = build_engine_args_dict( - stage_cfg, - self.model, - stage_connector_spec=stage_connector_spec, - ) - vllm_config, executor_class = build_vllm_config( - stage_cfg, - self.model, - stage_connector_spec=stage_connector_spec, - engine_args_dict=engine_args_dict, - ) - vllm_config.parallel_config.data_parallel_size_local = 0 - launch_cm = connect_remote_engine_cores( - vllm_config=vllm_config, - omni_master_server=omni_master_server, - stage_id=metadata.stage_id, - ) - logger.info("[AsyncOmniEngine] Stage %s remote engine handshake started", metadata.stage_id) - with launch_cm as (engine_manager, coordinator, addresses): - started_stage = StartedLlmStage( - stage_id=metadata.stage_id, - metadata=metadata, - vllm_config=vllm_config, - executor_class=executor_class, - engine_manager=engine_manager, - coordinator=coordinator, - addresses=addresses, - ) - logger.info("[AsyncOmniEngine] Stage %s remote engine startup completed", metadata.stage_id) - assert started_stage is not None - return started_stage - except Exception: - if started_stage is not None: - close_started_llm_stage(started_stage) - raise + stage_launch_lock: threading.Lock, + ) -> Any: + """Initialize one diffusion replica end-to-end.""" - def _launch_diffusion_stage( - self, - stage_cfg: Any, - metadata: Any, - omni_master_server: OmniMasterServer, - ) -> StageDiffusionClient: - """Launch a local diffusion stage on OmniMasterServer-allocated sockets.""" + client = None proc = None try: - od_config = build_diffusion_config(self.model, stage_cfg, metadata) - handshake_address, request_address, response_address = register_stage_with_omni_master( - omni_master_address=omni_master_server.address, - omni_master_port=omni_master_server.port, - omni_stage_id=metadata.stage_id, - omni_stage_config=stage_cfg, - return_addresses=True, - ) - logger.info( - "[AsyncOmniEngine] Stage %s diffusion registration completed", - metadata.stage_id, - ) - proc, _, _, _ = spawn_diffusion_proc( - self.model, - od_config, - handshake_address=handshake_address, - request_address=request_address, - response_address=response_address, - ) - complete_diffusion_handshake(proc, handshake_address) + if plan.launch_mode == "remote": + assert self._omni_master_server is not None + remote_stage_cfg = OmegaConf.create( + self._omni_master_server.get_stage_config( + plan.metadata.stage_id, + timeout_s=stage_init_timeout, + ) + ) + remote_metadata = extract_stage_metadata(remote_stage_cfg) + addresses = self._omni_master_server.get_zmq_addresses(plan.metadata.stage_id) + logger.info( + "[AsyncOmniEngine] Stage %s remote diffusion startup completed", + plan.metadata.stage_id, + ) + client = StageDiffusionClient.from_addresses( + remote_metadata, + request_address=addresses.inputs[0], + response_address=addresses.outputs[0], + batch_size=self.diffusion_batch_size, + ) + else: + device_control_env = current_omni_platform.device_control_env_var + with stage_launch_lock: + previous_visible_devices = os.environ.get(device_control_env) + try: + setup_stage_devices(plan.metadata.stage_id, plan.metadata.runtime_cfg) + omni_conn_cfg, omni_from, omni_to = plan.omni_kv_connector + if omni_conn_cfg: + inject_omni_kv_config(plan.stage_cfg, omni_conn_cfg, omni_from, omni_to) + inject_kv_stage_info(plan.stage_cfg, plan.metadata.stage_id, self.stage_configs) + if self.single_stage_mode: + assert self._omni_master_server is not None + od_config = build_diffusion_config(self.model, plan.stage_cfg, plan.metadata) + handshake_address, request_address, response_address = register_stage_with_omni_master( + omni_master_address=self._omni_master_server.address, + omni_master_port=self._omni_master_server.port, + omni_stage_id=plan.metadata.stage_id, + omni_stage_config=plan.stage_cfg, + return_addresses=True, + ) + logger.info( + "[AsyncOmniEngine] Stage %s diffusion registration completed", + plan.metadata.stage_id, + ) + proc, _, _, _ = spawn_diffusion_proc( + self.model, + od_config, + handshake_address=handshake_address, + request_address=request_address, + response_address=response_address, + ) + complete_diffusion_handshake(proc, handshake_address) + logger.info( + "[AsyncOmniEngine] Stage %s diffusion startup completed", + plan.metadata.stage_id, + ) + client = StageDiffusionClient.from_addresses( + plan.metadata, + request_address=request_address, + response_address=response_address, + proc=proc, + batch_size=self.diffusion_batch_size, + ) + else: + client = initialize_diffusion_stage( + plan.metadata.stage_id, + self.model, + plan.stage_cfg, + plan.metadata, + stage_init_timeout=stage_init_timeout, + batch_size=self.diffusion_batch_size, + use_inline=self.num_stages == 1 and plan.num_replicas == 1, + ) + finally: + if previous_visible_devices is None: + current_omni_platform.unset_device_control_env_var() + else: + current_omni_platform.set_device_control_env_var(previous_visible_devices) + logger.info( - "[AsyncOmniEngine] Stage %s diffusion startup completed", - metadata.stage_id, - ) - return StageDiffusionClient.from_addresses( - metadata, - request_address=request_address, - response_address=response_address, - proc=proc, - batch_size=self.diffusion_batch_size, + "[AsyncOmniEngine] Stage %s replica %s initialized (diffusion, batch_size=%d, devices=%s)", + plan.metadata.stage_id, + plan.replica_id, + self.diffusion_batch_size, + getattr(getattr(plan.stage_cfg, "runtime", None), "devices", "default"), ) + return client except Exception: if proc is not None: terminate_alive_proc(proc) raise - def _create_remote_diffusion_stage( + def _initialize_replica( self, - metadata: Any, + plan: ReplicaInitPlan, stage_init_timeout: int, - omni_master_server: OmniMasterServer, - ) -> StageDiffusionClient: - """Attach to a remote diffusion stage registered with OmniMasterServer.""" - remote_stage_cfg = OmegaConf.create( - omni_master_server.get_stage_config( - metadata.stage_id, - timeout_s=stage_init_timeout, - ) - ) - remote_metadata = extract_stage_metadata(remote_stage_cfg) - addresses = omni_master_server.get_zmq_addresses(metadata.stage_id) - logger.info( - "[AsyncOmniEngine] Stage %s remote diffusion startup completed", - metadata.stage_id, - ) - return StageDiffusionClient.from_addresses( - remote_metadata, - request_address=addresses.inputs[0], - response_address=addresses.outputs[0], - batch_size=self.diffusion_batch_size, - ) + stage_launch_lock: threading.Lock, + ) -> Any: + """Initialize one replica, regardless of backend type.""" - def _attach_llm_stage( + if plan.metadata.stage_type == "diffusion": + return self._initialize_diffusion_replica(plan, stage_init_timeout, stage_launch_lock) + return self._initialize_llm_replica(plan, stage_init_timeout, stage_launch_lock) + + def _initialize_stage_replicas( self, - started: StartedLlmStage, - ) -> tuple[Any, Any | None, Any, InputProcessor | None]: - """Attach a READY LLM stage to the orchestrator event loop.""" + stage_plans: Sequence[LogicalStageInitPlan], + stage_init_timeout: int, + ) -> dict[int, list[Any | None]]: + """Initialize all stage replicas in parallel.""" - client_addresses: dict[str, str] = { - "input_address": started.addresses.inputs[0], - "output_address": started.addresses.outputs[0], + stage_launch_lock = threading.Lock() + initialized_clients_by_stage: dict[int, list[Any | None]] = { + plan.stage_idx: [None] * len(plan.replicas) for plan in stage_plans } - if started.addresses.frontend_stats_publish_address is not None: - client_addresses["stats_update_address"] = started.addresses.frontend_stats_publish_address + total_replicas = sum(len(plan.replicas) for plan in stage_plans) + future_to_replica: dict[concurrent.futures.Future[Any], tuple[int, int]] = {} + + with concurrent.futures.ThreadPoolExecutor( + max_workers=max(1, total_replicas), + thread_name_prefix="stage-init", + ) as init_executor: + for plan in stage_plans: + for replica in plan.replicas: + future = init_executor.submit( + self._initialize_replica, + replica, + stage_init_timeout, + stage_launch_lock, + ) + future_to_replica[future] = (plan.stage_idx, replica.replica_id) - try: - stage_client = StageEngineCoreClientBase.make_async_mp_client( - vllm_config=started.vllm_config, - executor_class=started.executor_class, - metadata=started.metadata, - client_addresses=client_addresses, - proc=started.proc, - engine_manager=started.engine_manager, - coordinator=started.coordinator, - ) - started.proc = None - started.engine_manager = None - started.coordinator = None - except Exception: - close_started_llm_stage(started) - raise + try: + for future in concurrent.futures.as_completed(future_to_replica): + stage_idx, replica_id = future_to_replica[future] + initialized_clients_by_stage[stage_idx][replica_id] = future.result() + except Exception as exc: + for future, (stage_idx, replica_id) in future_to_replica.items(): + if not future.done() or future.cancelled() or future.exception() is not None: + continue + if initialized_clients_by_stage[stage_idx][replica_id] is None: + initialized_clients_by_stage[stage_idx][replica_id] = future.result() + setattr(exc, "_initialized_clients_by_stage", initialized_clients_by_stage) + raise - try: + return initialized_clients_by_stage + + def _assemble_stage_pools( + self, + stage_plans: Sequence[LogicalStageInitPlan], + initialized_clients_by_stage: Mapping[int, Sequence[Any | None]], + ) -> list[StagePool]: + """Assemble logical stage pools and update top-level stage metadata.""" + + stage_pools: list[StagePool] = [] + default_sampling_params_list: list[Any] = [] + stage_metadata_list: list[dict[str, Any]] = [] + + for plan in stage_plans: + replica_clients = initialized_clients_by_stage[plan.stage_idx] + first_client = replica_clients[0] if replica_clients else None + if first_client is None: + raise RuntimeError(f"Stage {plan.stage_idx} initialization completed with a missing client") + + clients = [client for client in replica_clients if client is not None] + stage_vllm_config = None output_processor = None - if getattr(started.metadata, "replica_id", 0) == 0: - if started.vllm_config.model_config.skip_tokenizer_init: - tokenizer = None - else: - tokenizer = cached_tokenizer_from_config( - model_config=started.vllm_config.model_config, - ) - output_processor = MultimodalOutputProcessor( - tokenizer=tokenizer, - log_stats=False, - engine_core_output_type=started.metadata.engine_output_type, - ) - input_processor = None - if started.stage_id == 0: - # Some omni models (e.g. CosyVoice3) have an empty HF - # config.json without model_type, which causes - # try_get_generation_config -> AutoConfig.from_pretrained - # to raise ValueError. Patch it to return None so - # InputProcessor doesn't crash. - _patch_generation_config_if_needed(started.vllm_config.model_config) - input_processor = InputProcessor(vllm_config=started.vllm_config) - # Use omni preprocessor so text-only prompts with - # mm_processor_kwargs (e.g. GLM-Image t2i target_h/target_w) - # still go through multimodal processor path. - input_processor.input_preprocessor = OmniInputPreprocessor( - vllm_config=started.vllm_config, - renderer=input_processor.renderer, - ) - except Exception: - try: - stage_client.shutdown() - except Exception as cleanup_error: - logger.warning( - "[AsyncOmniEngine] Failed to cleanup stage %s after attach failure: %s", - started.stage_id, - cleanup_error, + if plan.replicas[0].metadata.stage_type != "diffusion": + stage_vllm_config = plan.replicas[0].stage_vllm_config + assert stage_vllm_config is not None + output_processor = build_llm_stage_output_processor(plan, stage_vllm_config) + + stage_pools.append( + StagePool( + plan.stage_idx, + clients, + output_processor=output_processor, + stage_vllm_config=stage_vllm_config, ) - raise + ) + default_sampling_params_list.append(first_client.default_sampling_params) + stage_metadata_list.append( + { + "final_output": first_client.final_output, + "final_output_type": first_client.final_output_type, + "stage_type": first_client.stage_type, + } + ) - logger.info("[AsyncOmniEngine] Stage %s initialized", started.stage_id) - return stage_client, output_processor, started.vllm_config, input_processor + self.default_sampling_params_list = list(default_sampling_params_list) + self.stage_metadata = list(stage_metadata_list) + return stage_pools def _initialize_stages(self, stage_init_timeout: int) -> None: """Initialize stage clients/processors in orchestrator thread and assign to self. Phases: 1. Compute replica layout (counts + device splits). - 2. Launch all stage engine processes (parallel via ThreadPoolExecutor). - 3. Attach launched engines (parallel) and collect clients/processors. - 4. Build StagePool list and finalize stage metadata. + 2. Build per-stage/per-replica startup plans. + 3. Initialize all replicas in parallel via backend-specific launchers. + 4. Build logical StagePools and finalize runtime metadata. TODO(stage-pool): move per-stage launch + attach logic into a StagePool.build_from_config() classmethod so this method only iterates stage_configs, collects pools, and finalizes metadata. """ - device_control_env = current_omni_platform.device_control_env_var - num_stages = self.num_stages - - replicas_per_stage, replica_devices_map, total_llm_replicas = compute_replica_layout(self.stage_configs) + num_stages = len(self.stage_configs) + self.num_stages = num_stages + self._validate_single_stage_mode_replica_constraints() - input_processor: InputProcessor | None = None - llm_stage_ids: list[int] = [] - llm_launch_futures: dict[int, list[concurrent.futures.Future[StartedLlmStage]]] = {} - started_llm_stages: dict[int, list[StartedLlmStage]] = {} - llm_stage_launch_lock = threading.Lock() - diffusion_clients: dict[int, Any] = {} - prompt_expand_func = None - async_chunk = self.async_chunk + replicas_per_stage, replica_devices_map = compute_replica_layout(self.stage_configs) prepare_engine_environment() omni_transfer_config = load_omni_transfer_config_for_model(self.model, self.config_path) + stage_plans, prompt_expand_func = self._build_logical_stage_init_plans( + omni_transfer_config, + replicas_per_stage, + replica_devices_map, + ) + if self.single_stage_mode: + self._start_omni_master_server(stage_plans) stage_pools: list[StagePool] = [] - - # ------------------------------------------------------------------ # - # Single-stage mode: start OmniMasterServer before launching stages. # - # ------------------------------------------------------------------ # - if self.single_stage_mode: - if not self._omni_master_address or not self._omni_master_port: - raise ValueError( - "AsyncOmniEngine single_stage_mode requires both " - "omni_master_address and omni_master_port to be set." - ) - # Collect all configured stage IDs for pre-allocation. - all_stage_ids: list[int] = [] - seen_stage_ids: set[int] = set() - for i, sc in enumerate(self.stage_configs): - stage_id = int(getattr(sc, "stage_id", i)) - if stage_id in seen_stage_ids: - raise ValueError( - f"Duplicate stage_id {stage_id!r} detected among configured stages; stage_ids must be unique." - ) - seen_stage_ids.add(stage_id) - all_stage_ids.append(stage_id) - self._omni_master_server = OmniMasterServer( - master_address=self._omni_master_address, - master_port=self._omni_master_port, - stage_ids=all_stage_ids, - ) - self._omni_master_server.start() - logger.info( - "[AsyncOmniEngine] OmniMasterServer started for stages %s", - all_stage_ids, - ) + input_processor: InputProcessor | None = None + initialized_clients_by_stage: dict[int, list[Any | None]] = { + plan.stage_idx: [None] * len(plan.replicas) for plan in stage_plans + } try: - with concurrent.futures.ThreadPoolExecutor( - max_workers=max(1, total_llm_replicas), - thread_name_prefix="llm-stage-launch", - ) as launch_executor: - for stage_idx, stage_cfg in enumerate(self.stage_configs): - metadata = extract_stage_metadata(stage_cfg) - configured_stage_id = metadata.stage_id - logger.info("[AsyncOmniEngine] Initializing stage %s", configured_stage_id) - if metadata.prompt_expand_func is not None: - prompt_expand_func = metadata.prompt_expand_func - - if self.single_stage_mode: - metadata.runtime_cfg = None - - stage_connector_spec = get_stage_connector_spec( - omni_transfer_config=omni_transfer_config, - stage_id=configured_stage_id, - async_chunk=async_chunk, - ) - - omni_kv_connector = resolve_omni_kv_config_for_stage(omni_transfer_config, configured_stage_id) - - if metadata.stage_type == "diffusion": - is_remote_diffusion_stage = ( - self.single_stage_mode - and self._single_stage_id_filter is not None - and configured_stage_id != self._single_stage_id_filter - ) - if is_remote_diffusion_stage: - assert self._omni_master_server is not None - diffusion_clients[stage_idx] = self._create_remote_diffusion_stage( - metadata, - stage_init_timeout, - self._omni_master_server, - ) - continue - - with llm_stage_launch_lock: - previous_visible_devices = os.environ.get(device_control_env) - try: - setup_stage_devices(configured_stage_id, metadata.runtime_cfg) - omni_conn_cfg, omni_from, omni_to = omni_kv_connector - if omni_conn_cfg: - inject_omni_kv_config(stage_cfg, omni_conn_cfg, omni_from, omni_to) - inject_kv_stage_info(stage_cfg, configured_stage_id, self.stage_configs) - if self.single_stage_mode: - assert self._omni_master_server is not None - diffusion_clients[stage_idx] = self._launch_diffusion_stage( - stage_cfg, - metadata, - self._omni_master_server, - ) - else: - use_inline = True if self.num_stages == 1 else False - diffusion_clients[stage_idx] = initialize_diffusion_stage( - self.model, - stage_cfg, - metadata, - stage_init_timeout=stage_init_timeout, - batch_size=self.diffusion_batch_size, - use_inline=use_inline, - ) - logger.info( - "[AsyncOmniEngine] Stage %s initialized (diffusion, batch_size=%d)", - configured_stage_id, - self.diffusion_batch_size, - ) - finally: - if previous_visible_devices is None: - current_omni_platform.unset_device_control_env_var() - else: - current_omni_platform.set_device_control_env_var(previous_visible_devices) - continue - - # Submit one launch future per replica - llm_stage_ids.append(stage_idx) - num_replicas = replicas_per_stage[stage_idx] - - # single_stage_mode pre-dates multi-replica support; when - # a stage runs remotely (doesn't match the local filter) - # replica fan-out is delegated to the remote process, so - # we launch exactly one client future per such stage. - # TODO: support remote multi-replica with stage-pool - if ( - self.single_stage_mode - and self._single_stage_id_filter is not None - and configured_stage_id != self._single_stage_id_filter - ): - assert self._omni_master_server is not None - if num_replicas > 1: - logger.warning( - "[AsyncOmniEngine] Stage %s has num_replicas=%d but runs remotely in " - "single_stage_mode; only one remote client will be created.", - configured_stage_id, - num_replicas, - ) - llm_launch_futures[stage_idx] = [ - launch_executor.submit( - self._create_remote_llm_stage, - stage_cfg, - metadata, - stage_connector_spec, - stage_init_timeout, - self._omni_master_server, - ) - ] - continue - - stage_futures: list[concurrent.futures.Future[StartedLlmStage]] = [] - for replica_id in range(num_replicas): - # For replica > 0, deep-copy stage_cfg and override devices - if replica_id > 0: - replica_cfg = copy.deepcopy(stage_cfg) - else: - replica_cfg = stage_cfg - - if stage_idx in replica_devices_map: - replica_cfg.runtime.devices = replica_devices_map[stage_idx][replica_id] - - replica_metadata = extract_stage_metadata(replica_cfg) - replica_metadata.replica_id = replica_id - - logger.info( - "[AsyncOmniEngine] Launching stage %s replica %s (devices=%s)", - configured_stage_id, - replica_id, - getattr(getattr(replica_cfg, "runtime", None), "devices", "default"), - ) - - stage_futures.append( - launch_executor.submit( - self._launch_llm_stage, - replica_cfg, - replica_metadata, - stage_connector_spec, - stage_init_timeout, - llm_stage_launch_lock, - omni_kv_connector, - ) - ) - - llm_launch_futures[stage_idx] = stage_futures - - # Wait for all futures across all stages - all_futures = [f for futures in llm_launch_futures.values() for f in futures] - concurrent.futures.wait(all_futures) - - for stage_idx in llm_stage_ids: - started_llm_stages[stage_idx] = [f.result() for f in llm_launch_futures[stage_idx]] - - # ---- Parallel attach across (stage_idx, replica_id) pairs ---- - attach_futures: dict[ - concurrent.futures.Future[tuple[Any, Any, Any, InputProcessor | None]], - tuple[int, int], - ] = {} - total_replicas_to_attach = sum(len(started_llm_stages[s]) for s in llm_stage_ids) - with concurrent.futures.ThreadPoolExecutor( - max_workers=max(1, total_replicas_to_attach), - thread_name_prefix="llm-stage-attach", - ) as attach_executor: - for stage_idx in llm_stage_ids: - for replica_id, started in enumerate(started_llm_stages[stage_idx]): - attach_futures[attach_executor.submit(self._attach_llm_stage, started)] = ( - stage_idx, - replica_id, - ) - - stage_attach_results: dict[int, list[Any | None]] = { - s: [None] * len(started_llm_stages[s]) for s in llm_stage_ids - } - stage_output_proc_results: dict[int, Any | None] = {s: None for s in llm_stage_ids} - stage_vllm_cfg_results: dict[int, Any | None] = {s: None for s in llm_stage_ids} - - for future in concurrent.futures.as_completed(attach_futures): - stage_idx, replica_id = attach_futures[future] - stage_client, output_processor, vllm_config, stage0_input_processor = future.result() - stage_attach_results[stage_idx][replica_id] = stage_client - if stage_output_proc_results[stage_idx] is None and output_processor is not None: - stage_output_proc_results[stage_idx] = output_processor - if stage_vllm_cfg_results[stage_idx] is None: - stage_vllm_cfg_results[stage_idx] = vllm_config - if stage0_input_processor is not None: - input_processor = stage0_input_processor - - # ---- Build StagePool list + finalize metadata ---- - # Use first replica's client per stage for finalize (default sampling params, metadata). - logical_stage_clients_for_finalize: list[Any | None] = [None] * num_stages - for stage_idx in llm_stage_ids: - logical_stage_clients_for_finalize[stage_idx] = stage_attach_results[stage_idx][0] - for stage_idx, diff_client in diffusion_clients.items(): - logical_stage_clients_for_finalize[stage_idx] = diff_client - - _, default_sampling_params_list, stage_metadata_list = finalize_initialized_stages( - logical_stage_clients_for_finalize, - input_processor, + initialized_clients_by_stage = self._initialize_stage_replicas(stage_plans, stage_init_timeout) + if stage_plans and stage_plans[0].replicas[0].metadata.stage_type != "diffusion": + stage0_vllm_config = stage_plans[0].replicas[0].stage_vllm_config + assert stage0_vllm_config is not None + input_processor = build_stage0_input_processor(stage0_vllm_config) + stage_pools = self._assemble_stage_pools(stage_plans, initialized_clients_by_stage) + except Exception as exc: + initialized_clients_by_stage = getattr( + exc, + "_initialized_clients_by_stage", + initialized_clients_by_stage, + ) + cleanup_clients = self._collect_initialized_clients_for_cleanup( + stage_pools, + initialized_clients_by_stage, ) - - for stage_id in range(num_stages): - if stage_id in diffusion_clients: - stage_pools.append(StagePool(stage_id, diffusion_clients[stage_id])) - else: - stage_pools.append( - StagePool( - stage_id, - stage_attach_results[stage_id], - output_processor=stage_output_proc_results[stage_id], - stage_vllm_config=stage_vllm_cfg_results[stage_id], - ) - ) - - except Exception: - for stage_id, futures in llm_launch_futures.items(): - for f in futures: - if not f.done() or f.cancelled() or f.exception() is not None: - continue - started_llm_stages.setdefault(stage_id, []).append(f.result()) - # Collect all initialized clients for cleanup - cleanup_clients: list[Any] = list(diffusion_clients.values()) - for pool in stage_pools: - for client in pool.clients: - if client is not None: - cleanup_clients.append(client) - all_started = [s for stages in started_llm_stages.values() for s in stages] logger.exception( "[AsyncOmniEngine] Stage initialization failed; shutting down %s initialized client(s)", len(cleanup_clients), ) - cleanup_failed_stage_initialization(cleanup_clients, all_started) + self._shutdown_initialized_clients(cleanup_clients) if self._omni_master_server is not None: try: self._omni_master_server.stop() @@ -991,21 +980,18 @@ def _initialize_stages(self, stage_init_timeout: int) -> None: self.prompt_expand_func = prompt_expand_func # Derive logical-stage views for external readers (entrypoints/async_omni.py). - self.stage_clients = [pool.stage_client for pool in stage_pools] - self.stage_vllm_configs = [pool.stage_vllm_config for pool in stage_pools] - self.output_processors = [pool.output_processor for pool in stage_pools] + self.stage_clients = [pool.stage_client for pool in self.stage_pools] + self.stage_vllm_configs = [pool.stage_vllm_config for pool in self.stage_pools] + self.output_processors = [pool.output_processor for pool in self.stage_pools] # TODO(Peiqi): Hack here supported_tasks: set[str] = set() - if any(getattr(pool.stage_client, "is_comprehension", False) for pool in stage_pools): + if any(getattr(pool.stage_client, "is_comprehension", False) for pool in self.stage_pools): supported_tasks.add("generate") - if any(m.get("final_output_type") == "audio" for m in stage_metadata_list): + if any(m.get("final_output_type") == "audio" for m in self.stage_metadata): supported_tasks.add("speech") self.supported_tasks = tuple(supported_tasks) if supported_tasks else ("generate",) - self.default_sampling_params_list = list(default_sampling_params_list) - self.stage_metadata = list(stage_metadata_list) - def _initialize_janus_queues(self) -> None: """Initialize janus queues inside orchestrator thread loop context.""" self.request_queue = janus.Queue() @@ -1044,13 +1030,17 @@ async def _run_orchestrator() -> None: loop.run_until_complete(_run_orchestrator()) except Exception as e: if not startup_future.done(): - startup_future.set_exception(RuntimeError(f"Orchestrator initialization failed: {e}")) + wrapped = RuntimeError(f"Orchestrator initialization failed: {e}") + wrapped.__cause__ = e + startup_future.set_exception(wrapped) logger.exception("[AsyncOmniEngine] Orchestrator thread crashed") + error_text = str(e) or "Orchestrator thread crashed" try: + error_msg = {"type": "error", "error": error_text, "fatal": True} if self.output_queue is not None: - self.output_queue.sync_q.put_nowait({"type": "error", "error": "Orchestrator thread crashed"}) + self.output_queue.sync_q.put_nowait(error_msg) if self.rpc_output_queue is not None: - self.rpc_output_queue.sync_q.put_nowait({"type": "error", "error": "Orchestrator thread crashed"}) + self.rpc_output_queue.sync_q.put_nowait(error_msg) except Exception: pass raise @@ -1070,6 +1060,31 @@ async def _run_orchestrator() -> None: asyncio.set_event_loop(None) loop.close() + def _wait_for_orchestrator_init(self, startup_future: concurrent.futures.Future, startup_timeout: int) -> None: + """ + Wait for orchestrator startup future to return ready. Raises exception on any failures to the init process. + """ + deadline = time.monotonic() + startup_timeout + while True: + remaining = deadline - time.monotonic() + if remaining <= 0: + self._try_shutdown("[AsyncOmniEngine] Failed to cleanup after orchestrator startup timeout") + raise TimeoutError(f"Orchestrator did not become ready within {startup_timeout}s") + try: + startup_future.result( + timeout=min(remaining, _STARTUP_POLL_INTERVAL_S), + ) + break + except concurrent.futures.TimeoutError: + if not self.orchestrator_thread.is_alive(): + self._try_shutdown("[AsyncOmniEngine] Failed to cleanup after orchestrator startup failure") + if startup_future.done(): + startup_future.result() # re-raises the real exception + raise RuntimeError("Orchestrator thread died during startup") + except Exception: + self._try_shutdown("[AsyncOmniEngine] Failed to cleanup after orchestrator startup failure") + raise + # ---- request helpers ---- def _build_add_request_message( @@ -1493,6 +1508,12 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st # Inject diffusion LoRA-related knobs from kwargs if not present in the stage config. for cfg in stage_configs: try: + if not hasattr(cfg, "engine_args") or cfg.engine_args is None: + cfg.engine_args = OmegaConf.create({}) + global_sleep_mode = kwargs.get("enable_sleep_mode") + if global_sleep_mode is not None: + if not hasattr(cfg.engine_args, "enable_sleep_mode") or cfg.engine_args.enable_sleep_mode is None: + cfg.engine_args.enable_sleep_mode = global_sleep_mode if getattr(cfg, "stage_type", None) != "diffusion": continue if not hasattr(cfg, "engine_args") or cfg.engine_args is None: @@ -1814,3 +1835,9 @@ def shutdown(self) -> None: except Exception: logger.exception("[AsyncOmniEngine] Failed to stop OmniMasterServer during shutdown") self._omni_master_server = None + + def _try_shutdown(self, *args, **kwargs) -> None: + try: + self.shutdown() + except Exception: + logger.exception(*args, **kwargs) diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py index 81ff960116b..67843c19c99 100644 --- a/vllm_omni/engine/orchestrator.py +++ b/vllm_omni/engine/orchestrator.py @@ -20,6 +20,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.engine import EngineCoreOutputs +from vllm.v1.engine.exceptions import EngineDeadError from vllm_omni.engine import OmniEngineCoreRequest from vllm_omni.engine.cfg_companion_tracker import CfgCompanionTracker @@ -140,6 +141,7 @@ def __init__( self._shutdown_event = asyncio.Event() self._stages_shutdown = False + self._fatal_error: str | None = None async def run(self) -> None: """Main entry point for the Orchestrator event loop.""" @@ -168,6 +170,12 @@ async def run(self) -> None: except Exception: pass + # If a fatal error caused the shutdown, drain any pending + # add_request messages that were never processed and broadcast + # fatal error responses so callers are not left hanging. + if self._fatal_error is not None: + await self._drain_pending_requests_on_fatal() + self._shutdown_stages() loop = asyncio.get_running_loop() @@ -349,8 +357,10 @@ async def _handle_collective_rpc(self, msg: dict[str, Any]) -> None: target_pools.extend(self.stage_pools) else: for lid in requested_stage_ids: - if 0 <= lid < self.num_stages: - target_pools.append(self.stage_pools[lid]) + if not (0 <= lid < self.num_stages): + logger.warning("[Orchestrator] collective_rpc: ignoring invalid stage_id %s", lid) + continue + target_pools.append(self.stage_pools[lid]) results: list[Any] = [] stage_ids: list[int] = [] @@ -404,22 +414,55 @@ async def _orchestration_loop(self) -> None: await self._handle_processed_outputs(stage_id, replica_id, [output]) idle = False else: - raw_outputs = await pool.poll_llm_raw_output(replica_id, timeout_s=0.001) - if raw_outputs is None: - continue - - await self._handle_kv_ready_raw_outputs(stage_id, raw_outputs) - for eco in raw_outputs.outputs: - req_state = self.request_states.get(getattr(eco, "request_id", None)) - if req_state is None: + try: + raw_outputs = await pool.poll_llm_raw_output(replica_id, timeout_s=0.001) + if raw_outputs is None: continue - req_state.streaming.segment_finished = bool(getattr(eco, "is_segment_finished", False)) - req_state.streaming.new_prompt_len_snapshot = getattr( - eco, - "new_prompt_len_snapshot", - None, + + await self._handle_kv_ready_raw_outputs(stage_id, raw_outputs) + for eco in raw_outputs.outputs: + req_state = self.request_states.get(getattr(eco, "request_id", None)) + if req_state is None: + continue + req_state.streaming.segment_finished = bool(getattr(eco, "is_segment_finished", False)) + req_state.streaming.new_prompt_len_snapshot = getattr( + eco, + "new_prompt_len_snapshot", + None, + ) + raw_output = await pool.process_llm_raw_outputs(replica_id, raw_outputs) + except asyncio.CancelledError: + raise + except EngineDeadError as e: + logger.error( + "[Orchestrator] Stage-%s is dead: %s", + stage_id, + e, ) - raw_output = await pool.process_llm_raw_outputs(replica_id, raw_outputs) + self._fatal_error = str(e) + for req_id, req_state in list(self.request_states.items()): + if stage_id in req_state.stage_submit_ts: + await self.output_async_queue.put( + { + "type": "error", + "error": str(e), + "fatal": True, + "request_id": req_id, + } + ) + self.request_states.pop(req_id, None) + self._shutdown_event.set() + raise + except Exception: + if self._shutdown_event.is_set(): + return + logger.exception( + "[Orchestrator] Stage-%s replica-%s processing failed", + stage_id, + replica_id, + ) + raise + await self._handle_processed_outputs(stage_id, replica_id, raw_output) idle = False @@ -713,11 +756,22 @@ async def _forward_to_next_stage( if next_pool.stage_type == "diffusion": if next_client.custom_process_input_func is not None: + _t_ar2d = _time.perf_counter() diffusion_prompt = next_client.custom_process_input_func( source_outputs, req_state.prompt, requires_multimodal_data, ) + _dt_ar2d = (_time.perf_counter() - _t_ar2d) * 1000 + logger.info( + "[Orchestrator] ar2diffusion req=%s wall_time=%.3fms stage=%d->%d", + req_id, + _dt_ar2d, + src_stage_id, + next_logical, + ) + if already_submitted and isinstance(diffusion_prompt, list) and len(diffusion_prompt) == 1: + diffusion_prompt = diffusion_prompt[0] else: diffusion_prompt = req_state.prompt @@ -920,6 +974,53 @@ def _build_kv_sender_info( # ---- Shutdown / lifecycle ---- + async def _drain_pending_requests_on_fatal(self) -> None: + """Drain the request queue and broadcast fatal errors for any + pending add_request messages that were never processed. + + Called from the ``run()`` finally block when a fatal error + (e.g. ``EngineDeadError``) caused the orchestrator to shut down + before the request handler could process all queued messages. + Also broadcasts for any already-tracked requests still in + ``request_states`` that were not yet notified. + """ + assert self._fatal_error is not None + + notified: set[str] = set() + + # 1) Drain pending messages from the request queue. + while True: + try: + msg = self.request_async_queue.get_nowait() + except Exception: + break + if msg.get("type") == "add_request": + req_id = msg["request_id"] + await self.output_async_queue.put( + { + "type": "error", + "error": self._fatal_error, + "fatal": True, + "request_id": req_id, + } + ) + notified.add(req_id) + + # 2) Broadcast for any tracked requests not already notified + # (e.g. request was registered but the EngineDeadError handler + # missed it because it wasn't submitted to the dead stage yet). + for req_id in list(self.request_states): + if req_id not in notified: + await self.output_async_queue.put( + { + "type": "error", + "error": self._fatal_error, + "fatal": True, + "request_id": req_id, + } + ) + self.request_states.pop(req_id, None) + def _shutdown_stages(self) -> None: """Shutdown all stage pools.""" if self._stages_shutdown: diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py index 67b4dd16504..5d4626f10f9 100644 --- a/vllm_omni/engine/output_processor.py +++ b/vllm_omni/engine/output_processor.py @@ -118,10 +118,9 @@ def _consolidate_multimodal_tensors(self) -> None: if isinstance(v, list) and v and isinstance(v[0], torch.Tensor): try: if k == "audio": - # Concatenate delta audio chunks (1-D) into the full waveform. - # Each entry is a per-step slice; flatten to -1 so chunks with - # inconsistent leading dims can still be joined on the sample axis. - self.mm_accumulated[k] = torch.cat([t.reshape(-1) for t in v], dim=0) + # When the audio tensor shape is inconsistent, torch.cat will fail. + # We need to use torch.cat in -1 dimension. + continue elif k == "sr": # Sample rate is a constant scalar, keep last value. self.mm_accumulated[k] = v[-1] @@ -332,6 +331,24 @@ def add_request( self.parent_requests[parent_req.request_id] = parent_req self.external_req_ids[req_state.external_req_id].append(request_id) + def remove_request(self, request_id: str) -> None: + """Rollback one previously registered request if it was never submitted.""" + req_state = self.request_states.pop(request_id, None) + if req_state is None: + return + + external_req_id = getattr(req_state, "external_req_id", None) + if external_req_id is not None: + request_ids = self.external_req_ids.get(external_req_id) + if request_ids is not None: + self.external_req_ids[external_req_id] = [rid for rid in request_ids if rid != request_id] + if not self.external_req_ids[external_req_id]: + self.external_req_ids.pop(external_req_id, None) + + parent_req = getattr(req_state, "parent_req", None) + if parent_req is not None: + self.parent_requests.pop(parent_req.request_id, None) + def process_outputs( self, engine_core_outputs: list[EngineCoreOutput], diff --git a/vllm_omni/engine/stage_engine_core_client.py b/vllm_omni/engine/stage_engine_core_client.py index cf0e4241032..269b76b4948 100644 --- a/vllm_omni/engine/stage_engine_core_client.py +++ b/vllm_omni/engine/stage_engine_core_client.py @@ -7,13 +7,17 @@ from __future__ import annotations import inspect +import multiprocessing.connection import socket +import threading +import weakref from typing import TYPE_CHECKING, Any from urllib.parse import urlparse from vllm.logger import init_logger from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core_client import AsyncMPClient, DPLBAsyncMPClient +from vllm.v1.engine.exceptions import EngineDeadError from vllm_omni.distributed.omni_connectors.utils.initialization import ( KV_TRANSFER_PORT_OFFSET, @@ -66,6 +70,8 @@ class StageEngineCoreClientBase: ``engine_manager`` / ``coordinator`` pair created elsewhere. """ + replica_id: int = 0 + @staticmethod def make_async_mp_client( vllm_config: Any, @@ -124,8 +130,10 @@ def __init__( manage the process lifecycle on shutdown. """ # -------- Stage metadata (public fields used at runtime) -------- + self.replica_id = 0 if metadata is not None: self.stage_id = metadata.stage_id + self.replica_id = getattr(metadata, "replica_id", 0) self.stage_type = metadata.stage_type self.engine_output_type = metadata.engine_output_type self.is_comprehension = metadata.is_comprehension @@ -147,9 +155,10 @@ def __init__( client_name = self.__class__.__name__ logger.info( - "[%s] Stage-%s initializing EngineCore", + "[%s] stage-%s [rep-%s] initializing EngineCore", client_name, self.stage_id, + self.replica_id, ) try: super().__init__( @@ -166,35 +175,92 @@ def __init__( self.resources.coordinator = coordinator except Exception: logger.exception( - "[%s] Stage-%s EngineCore init failed", + "[%s] stage-%s [rep-%s] EngineCore init failed", client_name, self.stage_id, + self.replica_id, ) try: self.shutdown() except Exception as shutdown_error: logger.warning( - "[%s] Stage-%s cleanup after init failure failed: %s", + "[%s] stage-%s [rep-%s] cleanup after init failure failed: %s", client_name, self.stage_id, + self.replica_id, shutdown_error, ) raise + self._initialize_kv_sender_endpoint() + + if self._proc is not None: + self._start_proc_monitor() + logger.info( - "[%s] Stage-%s EngineCore running", + "[%s] stage-%s [rep-%s] EngineCore running", client_name, self.stage_id, + self.replica_id, + ) + + def _start_proc_monitor(self) -> None: + """Start a daemon thread that watches the subprocess sentinel. + + When the subprocess dies without sending the ZMQ ``ENGINE_CORE_DEAD`` + sentinel (e.g. SIGKILL, segfault, OOM-killer), this thread sets + ``resources.engine_dead`` so subsequent calls raise + ``EngineDeadError``. + """ + proc = self._proc + resources_ref = weakref.ref(self.resources) + stage_id = self.stage_id + replica_id = self.replica_id + + def _monitor() -> None: + try: + multiprocessing.connection.wait([proc.sentinel]) + except Exception: + return + resources = resources_ref() + if resources is None or resources.engine_dead: + return + resources.engine_dead = True + logger.error( + "[StageEngineCoreClient] stage-%s [rep-%s] subprocess died unexpectedly (exit code %s).", + stage_id, + replica_id, + proc.exitcode, + ) + + t = threading.Thread( + target=_monitor, + daemon=True, + name=f"StageCoreProcMonitor-{stage_id}", ) + t.start() + + def check_health(self) -> None: + """Raise ``EngineDeadError`` if the stage subprocess is dead. + + Called by ``OmniBase.check_health()`` and transitively by the + ``/health`` HTTP endpoint. + """ + if self.resources.engine_dead: + raise EngineDeadError(f"Stage-{self.stage_id} engine core is dead") + if self._proc is not None and not self._proc.is_alive(): + self.resources.engine_dead = True + raise EngineDeadError(f"Stage-{self.stage_id} subprocess is not alive (exit code {self._proc.exitcode})") # ==================== Overrides ==================== async def add_request_async(self, request: EngineCoreRequest) -> None: """Add request to the stage engine core.""" logger.info( - "[%s] Stage-%s adding request: %s", + "[%s] stage-%s [rep-%s] add request: %s", self.__class__.__name__, self.stage_id, + self.replica_id, request.request_id, ) await super().add_request_async(request) @@ -275,9 +341,10 @@ def _initialize_kv_sender_endpoint(self) -> None: sender_port = int(base_port) + KV_TRANSFER_PORT_OFFSET + int(from_stage) except (TypeError, ValueError): logger.warning( - "[StageEngineCoreClient] Stage-%s could not resolve sender_zmq_port " + "[StageEngineCoreClient] stage-%s [rep-%s] could not resolve sender_zmq_port " "from base_port=%s and from_stage=%s", self.stage_id, + self.replica_id, base_port, from_stage, ) diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py index 456f5a9244b..5cc9d12a196 100644 --- a/vllm_omni/engine/stage_init_utils.py +++ b/vllm_omni/engine/stage_init_utils.py @@ -19,19 +19,47 @@ from vllm.logger import init_logger from vllm.sampling_params import SamplingParams +from vllm.tokenizers import cached_tokenizer_from_config from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.executor import Executor +from vllm_omni.diffusion.data import OmniDiffusionConfig from vllm_omni.engine.arg_utils import OmniEngineArgs +from vllm_omni.engine.output_processor import MultimodalOutputProcessor from vllm_omni.entrypoints.stage_utils import _to_dict, set_stage_devices from vllm_omni.entrypoints.utils import filter_dataclass_kwargs, resolve_model_config_path from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniSamplingParams +from vllm_omni.inputs.preprocess import OmniInputPreprocessor from vllm_omni.platforms import current_omni_platform logger = init_logger(__name__) +@dataclass +class ReplicaInitPlan: + """One concrete replica startup unit within a logical stage.""" + + replica_id: int + num_replicas: int + launch_mode: str + stage_cfg: Any + metadata: Any + stage_connector_spec: dict[str, Any] + omni_kv_connector: tuple[dict[str, Any] | None, str | None, str | None] + stage_vllm_config: Any | None = None + executor_class: type | None = None + + +@dataclass +class LogicalStageInitPlan: + """Startup plan for one logical stage.""" + + stage_idx: int + configured_stage_id: int + replicas: list[ReplicaInitPlan] + + def _resolve_model_to_local_path(model: str) -> str: """Resolve an HF Hub model ID to a local cache path.""" if os.path.isdir(model): @@ -83,6 +111,14 @@ def terminate_alive_proc(proc, timeout=5): proc.kill() +def patch_generation_config_if_needed(model_config: Any) -> None: + """Guard InputProcessor init for models whose config lacks model_type.""" + try: + model_config.try_get_generation_config() + except Exception: + model_config.try_get_generation_config = lambda: {} + + def resolve_worker_cls(engine_args: dict[str, Any]) -> None: """Resolve worker_cls from worker_type for non-diffusion stages.""" worker_type = engine_args.get("worker_type", None) @@ -262,20 +298,6 @@ class StageMetadata: replica_id: int = 0 -@dataclass -class StartedLlmStage: - """Resources for an LLM stage that has completed startup.""" - - stage_id: int - metadata: Any - vllm_config: Any - executor_class: type - addresses: Any - proc: Any = None - engine_manager: Any = None - coordinator: Any = None - - def extract_stage_metadata(stage_config: Any) -> StageMetadata: """Pure data extraction from a stage_config object.""" stage_id: int = stage_config.stage_id @@ -421,16 +443,36 @@ def get_stage_tp_size(stage_cfg: Any) -> int: return int(getattr(engine_args, "tensor_parallel_size", 1) or 1) +def get_stage_devices_per_replica(stage_cfg: Any) -> int: + """Return the number of devices consumed by one replica of *stage_cfg*.""" + if getattr(stage_cfg, "stage_type", "llm") != "diffusion": + return get_stage_tp_size(stage_cfg) + + parallel_config = _get_attr_or_item(getattr(stage_cfg, "engine_args", {}), "parallel_config") + if parallel_config is None: + return 1 + + world_size = _get_attr_or_item(parallel_config, "world_size") + if world_size is not None: + return max(1, int(world_size)) + + try: + from vllm_omni.diffusion.data import DiffusionParallelConfig + + return max(1, int(DiffusionParallelConfig.from_dict(_to_dict(parallel_config)).world_size)) + except Exception: + return 1 + + def compute_replica_layout( stage_configs: Sequence[Any], -) -> tuple[list[int], dict[int, list[str]], int]: +) -> tuple[list[int], dict[int, list[str]]]: """Compute per-stage replica counts and device assignments. Returns: replicas_per_stage: num_replicas per logical stage. replica_devices_map: stage_idx -> per-replica device strings (only for stages with num_replicas > 1). - total_llm_replicas: total LLM replica count across all stages. """ replicas_per_stage: list[int] = [] for stage_cfg in stage_configs: @@ -451,25 +493,22 @@ def compute_replica_layout( devices_str = ( runtime_cfg.get("devices") if hasattr(runtime_cfg, "get") else getattr(runtime_cfg, "devices", None) ) - tp_size = get_stage_tp_size(stage_cfg) + devices_per_replica = get_stage_devices_per_replica(stage_cfg) replica_devices_map[stage_id] = split_devices_for_replicas( devices_str, num_replicas, - tp_size, + devices_per_replica, stage_id, ) logger.info( - "[stage_init] Stage %s: %d replicas, tp=%d, devices split: %s", + "[stage_init] Stage %s: %d replicas, devices_per_replica=%d, devices split: %s", stage_id, num_replicas, - tp_size, + devices_per_replica, replica_devices_map[stage_id], ) - total_llm_replicas = sum( - replicas_per_stage[i] for i, cfg in enumerate(stage_configs) if getattr(cfg, "stage_type", "llm") != "diffusion" - ) - return replicas_per_stage, replica_devices_map, total_llm_replicas + return replicas_per_stage, replica_devices_map def setup_stage_devices(stage_id: int, runtime_cfg: Any) -> None: @@ -556,6 +595,35 @@ def build_vllm_config( return vllm_config, executor_class +def build_llm_stage_output_processor(plan: LogicalStageInitPlan, stage_vllm_config: Any) -> Any | None: + """Build one output processor per logical LLM stage.""" + + metadata = plan.replicas[0].metadata + if stage_vllm_config.model_config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = cached_tokenizer_from_config( + model_config=stage_vllm_config.model_config, + ) + return MultimodalOutputProcessor( + tokenizer=tokenizer, + log_stats=False, + engine_core_output_type=metadata.engine_output_type, + ) + + +def build_stage0_input_processor(stage_vllm_config: Any) -> InputProcessor: + """Build the shared stage-0 input processor.""" + + patch_generation_config_if_needed(stage_vllm_config.model_config) + input_processor = InputProcessor(vllm_config=stage_vllm_config) + input_processor.input_preprocessor = OmniInputPreprocessor( + vllm_config=stage_vllm_config, + renderer=input_processor.renderer, + ) + return input_processor + + def acquire_device_locks( stage_id: int, engine_args_dict: dict[str, Any], @@ -749,6 +817,7 @@ def build_diffusion_config( def initialize_diffusion_stage( + stage_id: int, model: str, stage_cfg: Any, metadata: StageMetadata, @@ -770,99 +839,14 @@ def initialize_diffusion_stage( """ from vllm_omni.diffusion.stage_diffusion_client import create_diffusion_client + engine_args = _to_dict(stage_cfg.engine_args) + engine_args.pop("stage_id", None) + od_config = OmniDiffusionConfig.from_kwargs( + stage_id=stage_id, + model=model, + **engine_args, + ) + if metadata.cfg_kv_collect_func is not None: + od_config.cfg_kv_collect_func = metadata.cfg_kv_collect_func od_config = build_diffusion_config(model, stage_cfg, metadata) return create_diffusion_client(model, od_config, metadata, stage_init_timeout, batch_size, use_inline) - - -def _shutdown_or_close_resource(resource: Any, resource_name: str, stage_id: int) -> None: - """vLLM CoreEngineProcManager / coordinators use ``shutdown()``, not ``close()``.""" - if resource is None: - return - shutdown = getattr(resource, "shutdown", None) - if callable(shutdown): - try: - shutdown() - except Exception as cleanup_error: - logger.warning( - "[stage_init] Failed to shutdown launched %s for stage %s: %s", - resource_name, - stage_id, - cleanup_error, - ) - return - close = getattr(resource, "close", None) - if callable(close): - try: - close() - except Exception as cleanup_error: - logger.warning( - "[stage_init] Failed to close launched %s for stage %s: %s", - resource_name, - stage_id, - cleanup_error, - ) - - -def close_started_llm_stage(started: StartedLlmStage) -> None: - """Release resources owned by a launched stage that never attached.""" - if started.proc is not None: - try: - terminate_alive_proc(started.proc) - except Exception as cleanup_error: - logger.warning( - "[stage_init] Failed to terminate process for stage %s: %s", - started.stage_id, - cleanup_error, - ) - _shutdown_or_close_resource(started.engine_manager, "engine manager", started.stage_id) - _shutdown_or_close_resource(started.coordinator, "coordinator", started.stage_id) - - -def finalize_initialized_stages( - stage_clients: list[Any | None], - input_processor: InputProcessor | None, -) -> tuple[list[Any], list[Any], list[dict[str, Any]]]: - """Validate successful init and build runtime metadata lists.""" - if any(stage_client is None for stage_client in stage_clients): - raise RuntimeError("Stage initialization completed with missing stage clients") - - initialized_stage_clients = [stage_client for stage_client in stage_clients if stage_client is not None] - default_sampling_params_list = [stage_client.default_sampling_params for stage_client in initialized_stage_clients] - stage_metadata = [ - { - "final_output": stage_client.final_output, - "final_output_type": stage_client.final_output_type, - "stage_type": stage_client.stage_type, - } - for stage_client in initialized_stage_clients - ] - - if not isinstance(input_processor, InputProcessor): - has_llm_stage = any(metadata.get("stage_type") != "diffusion" for metadata in stage_metadata) - if has_llm_stage: - raise RuntimeError("Failed to initialize stage-0 InputProcessor for LLM pipeline") - - return initialized_stage_clients, default_sampling_params_list, stage_metadata - - -def cleanup_failed_stage_initialization( - stage_clients: list[Any | None], - started_llm_stages: list[StartedLlmStage], -) -> None: - """Shutdown attached stages and close any launched-but-unattached engines.""" - for cleanup_stage_id, stage_client in reversed(list(enumerate(stage_clients))): - if stage_client is None: - continue - try: - stage_client.shutdown() - except Exception as cleanup_error: - logger.warning( - "[stage_init] Failed to shutdown initialized stage %s after init failure: %s", - cleanup_stage_id, - cleanup_error, - ) - - for started in reversed(started_llm_stages): - if stage_clients[started.stage_id] is not None: - continue - close_started_llm_stage(started) diff --git a/vllm_omni/engine/stage_pool.py b/vllm_omni/engine/stage_pool.py index b5cddb388e5..00bc20e8627 100644 --- a/vllm_omni/engine/stage_pool.py +++ b/vllm_omni/engine/stage_pool.py @@ -199,14 +199,34 @@ async def submit_initial( request_id, affinity_request_id=affinity_request_id, ) - self.output_processor.add_request( - request=request, - prompt=prompt_text, - parent_req=None, - request_index=0, - queue=None, - ) - await self.clients[replica_id].add_request_async(request, **submit_kwargs) + try: + self.output_processor.add_request( + request=request, + prompt=prompt_text, + parent_req=None, + request_index=0, + queue=None, + ) + except Exception: + self.release_binding(request_id) + raise + + try: + await self.clients[replica_id].add_request_async(request, **submit_kwargs) + except Exception: + self.release_binding(request_id) + rollback = getattr(self.output_processor, "remove_request", None) + if callable(rollback): + try: + rollback(request_id) + except Exception as rollback_error: + logger.warning( + "[StagePool] Failed to rollback output processor state for req=%s stage-%s: %s", + request_id, + self.stage_id, + rollback_error, + ) + raise return replica_id async def submit_update( @@ -302,6 +322,7 @@ async def abort_requests(self, request_ids: list[str]) -> None: for request_id in request_ids: replica_id = self.get_bound_replica_id(request_id) if replica_id is None: + logger.debug("[StagePool] abort: no binding for req=%s in stage-%s", request_id, self.stage_id) continue request_ids_by_replica.setdefault(replica_id, []).append(request_id) diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 9606cc80d0d..67425ea2e65 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -9,6 +9,7 @@ import asyncio import time +import uuid from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence from typing import TYPE_CHECKING, Any @@ -24,10 +25,12 @@ from vllm.tasks import SupportedTask from vllm.v1.engine.exceptions import EngineDeadError +from vllm_omni.diffusion.data import OmniACK, OmniSleepTask, OmniWakeTask from vllm_omni.entrypoints.client_request_state import ClientRequestState from vllm_omni.entrypoints.omni_base import OmniBase from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics from vllm_omni.outputs import OmniRequestOutput +from vllm_omni.platforms import current_omni_platform if TYPE_CHECKING: from vllm.inputs.preprocess import InputPreprocessor @@ -40,6 +43,63 @@ _FINAL_OUTPUT_IDLE_SLEEP_S = 0.001 +class AsyncEventResolver: + """ + A generic signal aggregator designed for synchronized handshakes in + distributed or multi-stage environments. Supports waiting for a specified + number (expected_count) of worker signals in both inline and multiprocess modes. + """ + + def __init__(self, orchestrator=None): + self._pending_tasks: dict[str, dict] = {} + self.orchestrator = orchestrator + self._lock = asyncio.Lock() + + def watch_task(self, task_id: str, expected_count: int = 1) -> asyncio.Future: + loop = asyncio.get_running_loop() + fut = loop.create_future() + self._pending_tasks[task_id] = { + "future": fut, + "expected_count": expected_count, + "received": [], + "start_time": time.time(), + } + return fut + + async def resolve(self, ack: OmniACK): + tid = getattr(ack, "task_id", None) + + if tid is None and isinstance(ack, dict): + tid = ack.get("task_id") + + async with self._lock: + task_info = self._pending_tasks.get(tid) + if task_info is None: + logger.warning(f"Received stray ACK for task_id {tid}. Task might have timed out.") + return + + task_info["received"].append(ack) + current_count = len(task_info["received"]) + expected = task_info["expected_count"] + + orchestrator = self.orchestrator + if orchestrator and hasattr(orchestrator, "metrics") and orchestrator.metrics: + freed = getattr(ack, "freed_bytes", 0) + if freed == 0 and isinstance(ack, dict): + freed = ack.get("freed_bytes", 0) + orchestrator.metrics.record_vram_reclaimed(freed) + + logger.info(f"[Resolver] Task {tid} progress: {current_count}/{expected} ACKs received.") + + if current_count >= expected: + self._pending_tasks.pop(tid) + fut = task_info["future"] + if not fut.done(): + elapsed = time.time() - task_info["start_time"] + logger.info(f"[Resolver] Task {tid} completed successfully in {elapsed:.2f}s.") + fut.set_result(task_info["received"]) + + class AsyncOmni(EngineClient, OmniBase): """Asynchronous unified entry point for multi-stage pipelines using AsyncOmniEngine. @@ -76,7 +136,7 @@ def __init__(self, *args: Any, model: str = "", **kwargs: Any) -> None: self._paused: bool = False self._is_sleeping: bool = False self.final_output_task: asyncio.Task | None = None - + self.event_resolver = AsyncEventResolver(orchestrator=self) self.config_path = self.engine.config_path self.tts_max_instructions_length = kwargs.get("tts_max_instructions_length", None) self.input_processor = self.engine.input_processor @@ -433,6 +493,9 @@ async def _process_orchestrator_results( stage_id = result.get("stage_id", 0) + if result.get("type") == "error" and result.get("fatal"): + raise EngineDeadError(result.get("error", "")) + # Check for errors if "error" in result: logger.error( @@ -443,6 +506,8 @@ async def _process_orchestrator_results( ) raise RuntimeError(result) + self._check_engine_output_error(result, request_id, stage_id) + # Process the result (constructs OmniRequestOutput) output_to_yield = self._process_single_result( result, @@ -488,6 +553,22 @@ async def _final_output_loop(): await asyncio.sleep(_FINAL_OUTPUT_IDLE_SLEEP_S) continue + if isinstance(msg, dict) and msg.get("type") == "ack": + ack_data = msg.get("ack") + tid = getattr(ack_data, "task_id", "unknown") + logger.info(f"[{self._name}] Intercepted wrapped ACK for task {tid}") + await self.event_resolver.resolve(ack_data) + continue + if isinstance(msg, OmniACK): + logger.info(f"[{self._name}] Intercepted raw ACK object: {msg.task_id}") + await self.event_resolver.resolve(msg) + continue + if hasattr(msg, "task_id"): + tid = getattr(msg, "task_id") + logger.info(f"[{self._name}] Intercepted task-ID object: {tid}") + await self.event_resolver.resolve(msg) + continue + should_continue, _, stage_id, req_state = self._handle_output_message(msg) if should_continue: continue @@ -499,6 +580,16 @@ async def _final_output_loop(): except asyncio.CancelledError: raise + except EngineDeadError as e: + logger.error("[AsyncOmni] Engine dead: %s", e) + for req_state in list(self.request_states.values()): + error_msg = { + "type": "error", + "error": str(e), + "fatal": True, + "request_id": req_state.request_id, + } + await req_state.queue.put(error_msg) except Exception as e: logger.exception("[AsyncOmni] final_output_loop failed.") for req_state in list(self.request_states.values()): @@ -654,21 +745,68 @@ async def reset_prefix_cache( logger.warning("[AsyncOmni] reset_prefix_cache not yet supported with Orchestrator process") return True - async def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None: - """Sleep all stages. - - Best-effort: unsupported stages will emit a TODO result. - """ + async def sleep( + self, stage_ids: list[int] | None = None, level: int = 2, mode: PauseMode = "abort" + ) -> list[OmniACK]: + self._final_output_handler() + if stage_ids is None: + stage_ids = list(range(len(self.engine.stage_clients))) + total_workers = 0 + for sid in stage_ids: + client = self.engine.stage_clients[sid] + # During the Diffusion phase, regardless of the TP amount, + # currently only a summary ACK is reported at Rank 0. + if getattr(client, "stage_type", "") == "diffusion": + total_workers += 1 + else: + config = self.engine.stage_vllm_configs[sid] + actual_tp = config.parallel_config.tensor_parallel_size if config else 1 + total_workers += actual_tp + + task_id = str(uuid.uuid4()) + self.event_resolver.watch_task(task_id, expected_count=total_workers) + logger.info(f"[{self._name}] Sleep initiated (Task: {task_id}). Awaiting {total_workers} ACKs...") + task = OmniSleepTask(level=level, task_id=task_id) + rpc_results = await self.collective_rpc(method="handle_sleep_task", args=(task,), stage_ids=stage_ids) + final_acks = [] + for stage_res in rpc_results: + worker_acks = stage_res if isinstance(stage_res, list) else [stage_res] + for ack in worker_acks: + if ack is not None: + await self.event_resolver.resolve(ack) + final_acks.append(ack) self._is_sleeping = True - await self.collective_rpc(method="sleep", args=(level,)) - - async def wake_up(self, tags: list[str] | None = None) -> None: - """Wake up all stages. - - Best-effort: unsupported stages will emit a TODO result. - """ + return final_acks + + async def wake_up(self, stage_ids: list[int] | None = None, tags: list[str] | None = None) -> list[OmniACK]: + self._final_output_handler() + if stage_ids is None: + stage_ids = list(range(len(self.engine.stage_clients))) + total_workers = 0 + for sid in stage_ids: + client = self.engine.stage_clients[sid] + if getattr(client, "stage_type", "") == "diffusion": + total_workers += 1 + else: + config = self.engine.stage_vllm_configs[sid] + total_workers += config.parallel_config.tensor_parallel_size if config else 1 + task_id = str(uuid.uuid4()) + self.event_resolver.watch_task(task_id, expected_count=total_workers) + logger.info(f"[{self._name}] Wake-up initiated (Task: {task_id}). Awaiting {total_workers} ACKs...") + task = OmniWakeTask(tags=tags, task_id=task_id) + rpc_results = await self.collective_rpc(method="handle_wake_task", args=(task,), stage_ids=stage_ids) + final_acks = [] + for stage_res in rpc_results: + worker_acks = stage_res if isinstance(stage_res, list) else [stage_res] + for ack in worker_acks: + if ack is not None: + await self.event_resolver.resolve(ack) + final_acks.append(ack) + current_omni_platform.synchronize() + await asyncio.sleep(0.1) self._is_sleeping = False - await self.collective_rpc(method="wake_up", args=(tags,)) + logger.info(f"[{self._name}] All {len(final_acks)}/{total_workers} workers reported WARM for task {task_id}.") + return final_acks async def is_sleeping(self) -> bool: """Return whether all stages are sleeping. @@ -718,12 +856,25 @@ async def pin_lora(self, adapter_id: int) -> bool: @property def is_running(self) -> bool: """Check if the engine is running.""" - return self.final_output_task is not None and not self.final_output_task.done() + orchestrator_alive = self.engine.is_alive() + task_alive = self.final_output_task is not None and not self.final_output_task.done() + return orchestrator_alive and task_alive @property def errored(self) -> bool: - """Whether orchestrator thread has stopped unexpectedly.""" - return not self.engine.is_alive() + """Whether the engine is in a non-recoverable error state. + + Delegates to ``OmniBase.errored`` which checks the orchestrator + thread and all stage clients. Redeclared here to satisfy the + ``EngineClient`` abstract-property requirement (Python's ABC + mechanism does not resolve abstract methods from sibling MRO + entries). + """ + return OmniBase.errored.fget(self) # type: ignore[union-attr] + + @property + def _name(self) -> str: + return "AsyncOrchestrator" @property def is_stopped(self) -> bool: diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index 8bccfbb5916..e355ee679d8 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -139,6 +139,17 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu action="store_true", help="Enable vLLM-Omni mode for multi-modal and diffusion models", ) + + try: + omni_config_group.add_argument( + "--enable-sleep-mode", + action="store_true", + default=False, + help="Enable GPU memory pool for sleep mode.", + ) + except argparse.ArgumentError: + pass + omni_config_group.add_argument( "--task-type", type=str, diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 8ef7e2ee5b7..223c208af98 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -166,6 +166,8 @@ def _run_generation( logger.warning("[Omni] Received output for unknown/finished request_id=%s", req_id) continue + self._check_engine_output_error(msg, req_id, stage_id) + if req_state.metrics is None: continue output_to_yield = self._process_single_result( diff --git a/vllm_omni/entrypoints/omni_base.py b/vllm_omni/entrypoints/omni_base.py index dca494efe72..3180c9c80c0 100644 --- a/vllm_omni/entrypoints/omni_base.py +++ b/vllm_omni/entrypoints/omni_base.py @@ -12,7 +12,7 @@ import huggingface_hub from vllm.logger import init_logger -from vllm.v1.engine.exceptions import EngineDeadError +from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm_omni.engine.async_omni_engine import AsyncOmniEngine from vllm_omni.entrypoints.client_request_state import ClientRequestState @@ -133,7 +133,7 @@ def __init__( if "log_requests" in kwargs: raise TypeError("`log_requests` has been removed in Omni/AsyncOmni. Use `log_stats`.") model = omni_snapshot_download(model) - self._name = self.__class__.__name__ + self.__dict__["_name"] = self.__class__.__name__ self.model = model self.log_stats = log_stats # Provisional value (mirrors the CLI/caller kwarg); the engine resolves @@ -198,9 +198,33 @@ def stage_configs(self) -> list: def is_running(self) -> bool: return self.engine.is_alive() + @property + def errored(self) -> bool: + """Whether the engine is in a non-recoverable error state. + + True when the orchestrator thread is dead **or** any stage client + has been marked dead (e.g. diffusion worker OOM / process death). + + Checks both ``_engine_dead`` (StageDiffusionClient) and + ``resources.engine_dead`` (StageEngineCoreClient / AsyncMPClient) + since the two client types store the flag differently. + """ + if not self.engine.is_alive(): + return True + for stage_client in self.engine.stage_clients: + if getattr(stage_client, "_engine_dead", False): + return True + resources = getattr(stage_client, "resources", None) + if resources is not None and getattr(resources, "engine_dead", False): + return True + return False + def check_health(self) -> None: if not self.engine.is_alive(): raise EngineDeadError("Orchestrator process is not alive") + for stage_client in self.engine.stage_clients: + if hasattr(stage_client, "check_health"): + stage_client.check_health() def resolve_sampling_params_list( self, @@ -271,7 +295,10 @@ def _handle_output_message( return True, None, None, None if msg_type == "error": - raise RuntimeError(msg.get("error", "Orchestrator returned an error message")) + error_text = msg.get("error", "Orchestrator returned an error message") + if msg.get("fatal"): + raise EngineDeadError(error_text) + raise RuntimeError(error_text) if msg_type != "output": logger.warning("[%s] got unexpected msg type: %s", self.__class__.__name__, msg_type) @@ -300,6 +327,34 @@ def _handle_output_message( return False, req_id, stage_id, req_state + def _check_engine_output_error( + self, + result: dict[str, Any], + request_id: str, + stage_id: int, + ) -> None: + """Raise if ``engine_outputs`` carries an error field. + + Raises :class:`EngineDeadError` when ``self.errored`` indicates the + engine is unrecoverable, otherwise raises :class:`EngineGenerateError` + (recoverable, single-request failure). + """ + engine_outputs = result.get("engine_outputs") + error_text = getattr(engine_outputs, "error", None) + if error_text is None: + return + logger.error( + "[%s] Stage error for req=%s stage-%s: %s", + self.__class__.__name__, + request_id, + stage_id, + error_text, + ) + # NOTE: O(n_stages) check for every error. + if self.errored: + raise EngineDeadError(error_text) + raise EngineGenerateError(error_text) + def _process_single_result( self, result: dict[str, Any], diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 745b719d5b2..1def10e64cf 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio import base64 +import dataclasses import io import json import multiprocessing @@ -30,7 +31,7 @@ from vllm.engine.protocol import EngineClient from vllm.entrypoints.anthropic.serving import AnthropicServingMessages from vllm.entrypoints.chat_utils import load_chat_template -from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.launcher import serve_http, terminate_if_errored from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.mcp.tool_server import DemoToolServer, MCPToolServer, ToolServer from vllm.entrypoints.openai.api_server import build_app as build_openai_app @@ -49,6 +50,7 @@ ModelCard, ModelList, ModelPermission, + RequestResponseMetadata, ) from vllm.entrypoints.openai.models.protocol import BaseModelPath from vllm.entrypoints.openai.models.serving import OpenAIServingModels @@ -73,6 +75,7 @@ from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization from vllm.entrypoints.utils import ( + create_error_response, load_aware_call, process_lora_modules, with_cancellation, @@ -82,6 +85,7 @@ from vllm.tool_parsers import ToolParserManager from vllm.utils import random_uuid from vllm.utils.system_utils import decorate_logs +from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.openai.errors import InvalidInputReferenceError @@ -198,6 +202,50 @@ def _remove_route_from_app(app, path: str, methods: set[str] | None = None): app.routes.remove(route) +def _register_omni_exception_handlers(app) -> None: + """Override upstream vLLM exception handlers with Omni-aware versions. + + The upstream ``engine_error_handler`` is designed for ``AsyncLLM`` (single + EngineCore process). Omni uses a multi-stage orchestrator with different + health semantics, so we register our own handlers that: + + - Log multi-stage diagnostic info (orchestrator liveness, per-stage health) + when an ``EngineDeadError`` is caught. + - Call ``terminate_if_errored`` + - Return an OpenAI-compatible error JSON response. + """ + + async def omni_engine_error_handler( + req: Request, + exc: EngineDeadError | EngineGenerateError, + ): + request_id = req.state.request_metadata.request_id if hasattr(req.state, "request_metadata") else None + + if req.app.state.args.log_error_stack: + logger.exception("Engine Exception caught. Request id: %s", request_id) + + engine = req.app.state.engine_client + if isinstance(exc, EngineDeadError): + # Log Omni-specific diagnostic information for dead engines. + orchestrator_alive = engine.engine.is_alive() if hasattr(engine, "engine") else "N/A" + logger.error( + "EngineDeadError: orchestrator_alive=%s, errored=%s, request_id=%s", + orchestrator_alive, + engine.errored, + request_id, + ) + + terminate_if_errored( + server=req.app.state.server, + engine=engine, + ) + err = create_error_response(exc) + return JSONResponse(err.model_dump(), status_code=err.error.code) + + app.exception_handler(EngineGenerateError)(omni_engine_error_handler) + app.exception_handler(EngineDeadError)(omni_engine_error_handler) + + class _DiffusionServingModels: """Minimal OpenAIServingModels implementation for diffusion-only servers. @@ -307,6 +355,10 @@ async def omni_run_server_worker(listen_address, sock, args, client_config=None, _remove_route_from_app(app, "/v1/models", {"GET"}) # Remove upstream /v1/models to use omni's handler app.include_router(router) + # OMNI: Override upstream exception handlers with Omni-aware versions + # that understand the multi-stage orchestrator lifecycle. + _register_omni_exception_handlers(app) + await omni_init_app_state(engine_client, app.state, args) # Conditionally register profiler endpoints based on stage YAML configs @@ -837,6 +889,7 @@ async def omni_init_app_state( state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 + state.sleeping_stages = set() def Omnivideo(request: Request) -> OmniOpenAIServingVideo | None: @@ -876,6 +929,8 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re return base_server.create_error_response(message="The model does not support Chat Completions API") try: generator = await handler.create_chat_completion(request, raw_request) + except (EngineGenerateError, EngineDeadError): + raise # Propagate to the global Omni exception handler except Exception as e: logger.exception("Chat completion failed: %s", e) raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e @@ -971,6 +1026,8 @@ async def create_speech(request: OpenAICreateSpeechRequest, raw_request: Request status_code=result.error.code if result.error else 400, ) return result + except (EngineGenerateError, EngineDeadError): + raise # Propagate to the global Omni exception handler except Exception as e: raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e @@ -1005,6 +1062,8 @@ async def create_speech_batch(request: BatchSpeechRequest, raw_request: Request) status_code=result.error.code if result.error else 400, ) return JSONResponse(content=result.model_dump()) + except (EngineGenerateError, EngineDeadError): + raise # Propagate to the global Omni exception handler except ValueError as e: raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e)) from e except Exception as e: @@ -1242,31 +1301,26 @@ async def realtime_websocket(websocket: WebSocket): async def health(raw_request: Request) -> JSONResponse: """Health check endpoint that works for both LLM and diffusion modes. - Returns 200 OK if the server is healthy. - For LLM mode: delegates to engine_client health check - For diffusion mode: checks if diffusion_engine is running + Returns 200 OK if the server is healthy, 503 if the engine is dead. + Mirrors vLLM upstream's /health which catches EngineDeadError -> 503. """ - # Check if we're in diffusion mode - diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None) - if diffusion_engine is not None: - # Diffusion mode health check - if hasattr(diffusion_engine, "is_running") and diffusion_engine.is_running: - return JSONResponse(content={"status": "healthy"}) + engine_client = getattr(raw_request.app.state, "engine_client", None) or getattr( + raw_request.app.state, "diffusion_engine", None + ) + if engine_client is None: return JSONResponse( - content={"status": "unhealthy", "reason": "Diffusion engine is not running"}, + content={"status": "unhealthy", "reason": "No engine initialized"}, status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, ) - # LLM mode - delegate to engine_client - engine_client = getattr(raw_request.app.state, "engine_client", None) - if engine_client is not None: + try: await engine_client.check_health() return JSONResponse(content={"status": "healthy"}) - - return JSONResponse( - content={"status": "unhealthy", "reason": "No engine initialized"}, - status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, - ) + except EngineDeadError: + return JSONResponse( + content={"status": "unhealthy"}, + status_code=HTTPStatus.SERVICE_UNAVAILABLE.value, + ) # Remove existing models endpoint if present (from vllm imports) @@ -1404,6 +1458,16 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) else: size_str = "model default" + # Keep AR stage target grid in sync with requested output size. + # GLM-Image consumes target_h/target_w via mm_processor_kwargs. + if width is not None and height is not None: + prompt["mm_processor_kwargs"] = { + "target_h": height, + "target_w": width, + } + # Backward-compatible fallback for processors reading top-level fields. + prompt["height"] = height + prompt["width"] = width app_state_args = getattr(raw_request.app.state, "args", None) _check_max_generated_image_size(app_state_args, width, height) @@ -1425,6 +1489,7 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) _update_if_not_none(gen_params, "layers", request.layers) request_id = f"img_gen-{random_uuid()}" + raw_request.state.request_metadata = RequestResponseMetadata(request_id=request_id) logger.info(f"Generating {request.n} image(s) {size_str}") @@ -1456,8 +1521,9 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request) data=image_data, ) + except (EngineGenerateError, EngineDeadError): + raise # Propagate to the global Omni exception handler except HTTPException: - # Re-raise HTTPExceptions as-is raise except ValueError as e: logger.error(f"Validation error: {e}") @@ -1629,6 +1695,18 @@ async def edit_images( _check_max_generated_image_size(app_state_args, width, height, resolution) size_str = f"{width}x{height}" if width is not None and height is not None else "auto" + + # Keep AR stage target grid in sync with requested output size. + # GLM-Image consumes target_h/target_w via mm_processor_kwargs. + if width is not None and height is not None: + prompt["mm_processor_kwargs"] = { + "target_h": height, + "target_w": width, + } + # Backward-compatible fallback for processors reading top-level fields. + prompt["height"] = height + prompt["width"] = width + _update_if_not_none(gen_params, "width", width) _update_if_not_none(gen_params, "height", height) @@ -1648,6 +1726,7 @@ async def edit_images( # 4. Generate images using AsyncOmni (multi-stage mode) request_id = f"img_edit-{random_uuid()}" + raw_request.state.request_metadata = RequestResponseMetadata(request_id=request_id) logger.info(f"Generating {n} image(s) {size_str}") result = await _generate_with_async_omni( engine_client=engine_client, @@ -1679,8 +1758,9 @@ async def edit_images( size=size_str, ) + except (EngineGenerateError, EngineDeadError): + raise # Propagate to the global Omni exception handler except HTTPException: - # Re-raise HTTPExceptions as-is raise except ValueError as e: logger.error(f"Validation error: {e}") @@ -2086,6 +2166,48 @@ def video_response_from_request(model_name: str, req: VideoGenerationRequest) -> return resp +def _status_code_for_video_failure(error: VideoError | None) -> int: + if error is None: + return HTTPStatus.INTERNAL_SERVER_ERROR.value + + if isinstance(error.code, int): + if 400 <= error.code < 600: + return error.code + return HTTPStatus.INTERNAL_SERVER_ERROR.value + + if error.code == "HTTPException": + status_text, _, _ = error.message.partition(":") + try: + status_code = int(status_text) + except ValueError: + return HTTPStatus.INTERNAL_SERVER_ERROR.value + if 400 <= status_code < 600: + return status_code + return HTTPStatus.INTERNAL_SERVER_ERROR.value + + if error.code == "EngineDeadError": + return HTTPStatus.INTERNAL_SERVER_ERROR.value + if error.code == "EngineGenerateError": + return HTTPStatus.INTERNAL_SERVER_ERROR.value + + return HTTPStatus.INTERNAL_SERVER_ERROR.value + + +def _video_error_from_exception(exc: Exception) -> VideoError: + if isinstance(exc, HTTPException): + message = str(exc.detail) if exc.detail else str(exc) + return VideoError(code=exc.status_code, message=message) + + if isinstance(exc, (EngineGenerateError, EngineDeadError)): + err = create_error_response(exc) + return VideoError(code=err.error.code, message=err.error.message) + + return VideoError( + code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + message=str(exc), + ) + + def _cleanup_video(video_id: str, output_path: str | None): try: if output_path is not None: @@ -2099,6 +2221,7 @@ async def _run_video_generation_job( request: VideoGenerationRequest, video_id: str, reference_image: ReferenceImage | None = None, + app_state: Any | None = None, ) -> None: job = await VIDEO_STORE.get(video_id) if job is None: @@ -2129,17 +2252,36 @@ async def _run_video_generation_job( "peak_memory_mb": peak_memory_mb, }, ) + except (EngineGenerateError, EngineDeadError) as exc: + logger.exception("Video generation failed (engine error) for id=%s", video_id) + + _cleanup_video(video_id, output_path) + await VIDEO_STORE.update_fields( + video_id, + { + "status": VideoGenerationStatus.FAILED, + "completed_at": int(time.time()), + "error": _video_error_from_exception(exc), + "inference_time_s": time.perf_counter() - started_at, + }, + ) + # Background tasks can't propagate exceptions to FastAPI handlers. + # Actively signal shutdown when the engine is dead. + if app_state is not None and isinstance(exc, EngineDeadError): + terminate_if_errored( + server=app_state.server, + engine=app_state.engine_client, + ) except Exception as exc: logger.exception("Video generation failed for id=%s", video_id) _cleanup_video(video_id, output_path) - # TODO: It would be better to have a finite collection of errors to return rather than the exception name await VIDEO_STORE.update_fields( video_id, { "status": VideoGenerationStatus.FAILED, "completed_at": int(time.time()), - "error": VideoError(code=type(exc).__name__, message=str(exc)), + "error": _video_error_from_exception(exc), "inference_time_s": time.perf_counter() - started_at, }, ) @@ -2270,6 +2412,7 @@ async def _parse_video_form( }, ) async def create_video( + raw_request: Request, ctx: tuple[VideoGenerationRequest, OmniOpenAIServingVideo, str, ReferenceImage | None] = Depends(_parse_video_form), ) -> VideoResponse: """Create an asynchronous video generation job. @@ -2280,7 +2423,9 @@ async def create_video( request, handler, effective_model_name, reference_image = ctx ref = video_response_from_request(effective_model_name, request) await VIDEO_STORE.upsert(ref.id, ref) - task = asyncio.create_task(_run_video_generation_job(handler, request, ref.id, reference_image)) + task = asyncio.create_task( + _run_video_generation_job(handler, request, ref.id, reference_image, app_state=raw_request.app.state) + ) await VIDEO_TASKS.upsert(ref.id, task) return ref @@ -2295,6 +2440,7 @@ async def create_video( }, ) async def create_video_sync( + raw_request: Request, ctx: tuple[VideoGenerationRequest, OmniOpenAIServingVideo, str, ReferenceImage | None] = Depends(_parse_video_form), ) -> Response: """Synchronous video generation endpoint. @@ -2308,6 +2454,7 @@ async def create_video_sync( """ request, handler, effective_model_name, reference_image = ctx request_id = f"video_sync-{random_uuid()}" + raw_request.state.request_metadata = RequestResponseMetadata(request_id=request_id) started_at = time.perf_counter() try: video_bytes, stage_durations, peak_memory_mb = await asyncio.wait_for( @@ -2319,6 +2466,8 @@ async def create_video_sync( status_code=HTTPStatus.GATEWAY_TIMEOUT.value, detail=f"Video generation timed out after {VIDEO_SYNC_TIMEOUT_S}s.", ) + except (EngineGenerateError, EngineDeadError): + raise # Propagate to the global Omni exception handler except HTTPException: raise except Exception as exc: @@ -2379,8 +2528,8 @@ async def list_videos( return VideoListResponse(data=jobs, has_more=has_more, first_id=first_id, last_id=last_id) -@router.get("/v1/videos/{video_id}") -async def retrieve_video(video_id: str) -> VideoResponse: +@router.get("/v1/videos/{video_id}", response_model=None) +async def retrieve_video(video_id: str) -> VideoResponse | JSONResponse: """Retrieve metadata for a previously created video job. Args: @@ -2395,6 +2544,15 @@ async def retrieve_video(video_id: str) -> VideoResponse: job = await VIDEO_STORE.get(video_id) if job is None: raise HTTPException(status_code=404, detail="Video not found") + if job.status == VideoGenerationStatus.FAILED: + status_code = _status_code_for_video_failure(job.error) + content = job.model_dump(mode="json") + if content.get("error") is not None: + content["error"]["code"] = status_code + return JSONResponse( + content=content, + status_code=status_code, + ) return job @@ -2531,3 +2689,64 @@ async def stop_profile(raw_request: Request, request: ProfileRequest | None = No raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=f"Failed to stop profiler: {str(e)}" ) + + +class OmniSleepRequest(BaseModel): + stage_ids: list[int] + level: int = 2 + + +class OmniWakeupRequest(BaseModel): + stage_ids: list[int] + + +@router.post("/v1/omni/sleep") +async def omni_sleep(request: OmniSleepRequest, raw_request: Request): + engine_client = raw_request.app.state.engine_client + sleeping_set = raw_request.app.state.sleeping_stages + if not hasattr(engine_client, "sleep"): + raise HTTPException(status_code=501, detail="Engine does not support sleep") + acks = await engine_client.sleep(stage_ids=request.stage_ids, level=request.level) + for sid in request.stage_ids: + sleeping_set.add(sid) + return {"status": "SUCCESS", "acks": [dataclasses.asdict(a) if dataclasses.is_dataclass(a) else a for a in acks]} + + +@router.post("/v1/omni/wakeup") +async def omni_wakeup(request: OmniWakeupRequest, raw_request: Request): + engine_client = raw_request.app.state.engine_client + sleeping_set = raw_request.app.state.sleeping_stages + if not any(sid in sleeping_set for sid in request.stage_ids): + return {"status": "SKIPPED", "reason": "Target stages are not sleeping."} + if not hasattr(engine_client, "wake_up"): + raise HTTPException(status_code=501, detail="Engine does not support wake_up") + acks = await engine_client.wake_up(stage_ids=request.stage_ids) + for sid in request.stage_ids: + if sid in sleeping_set: + sleeping_set.remove(sid) + return {"status": "SUCCESS", "acks": [dataclasses.asdict(a) if dataclasses.is_dataclass(a) else a for a in acks]} + + +if __name__ == "__main__": + import argparse + import asyncio + + from vllm.entrypoints.openai.cli_args import make_arg_parser + + parser = argparse.ArgumentParser(description="vLLM-Omni OpenAI-Compatible REST API server") + parser = make_arg_parser(parser) + registered_flags = set() + for action in parser._actions: + registered_flags.update(action.option_strings) + if "--omni" not in registered_flags: + parser.add_argument("--omni", action="store_true", default=False, help="Enable vLLM-Omni mode.") + if "--enable-sleep-mode" not in registered_flags: + parser.add_argument( + "--enable-sleep-mode", action="store_true", default=False, help="Enable GPU memory pool for sleep mode." + ) + args = parser.parse_args() + if not hasattr(args, "model_tag"): + setattr(args, "model_tag", args.model) + if hasattr(args, "model_tag") and args.model_tag is None: + args.model_tag = args.model + asyncio.run(omni_run_server(args)) diff --git a/vllm_omni/entrypoints/openai/protocol/videos.py b/vllm_omni/entrypoints/openai/protocol/videos.py index 7c2c3164d92..d46c8d43d6b 100644 --- a/vllm_omni/entrypoints/openai/protocol/videos.py +++ b/vllm_omni/entrypoints/openai/protocol/videos.py @@ -235,7 +235,7 @@ class VideoGenerationResponse(BaseModel): class VideoError(BaseModel): - code: str = Field(..., description="A machine-readable error code that was returned.") + code: int | str = Field(..., description="A machine-readable error code that was returned.") message: str = Field(..., description="A human-readable description of the error that was returned.") diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py index 8cddac6a6c5..06b739e3bed 100644 --- a/vllm_omni/entrypoints/openai/serving_chat.py +++ b/vllm_omni/entrypoints/openai/serving_chat.py @@ -32,6 +32,7 @@ get_history_tool_calls_cnt, make_tool_call_id, ) +from vllm.entrypoints.launcher import terminate_if_errored from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, @@ -80,6 +81,7 @@ from vllm.tool_parsers import ToolParser from vllm.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.utils.collection_utils import as_list +from vllm.v1.engine.exceptions import EngineDeadError from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin from vllm_omni.entrypoints.openai.image_api_utils import validate_layered_layers @@ -304,20 +306,9 @@ async def create_chat_completion( # effectively unconditioned and produce nonsense images. if request.modalities and ("image" in request.modalities): try: - messages_as_dicts: list[dict[str, Any]] = [] - for msg in request.messages: - if hasattr(msg, "model_dump"): - messages_as_dicts.append(msg.model_dump()) - elif isinstance(msg, dict): - messages_as_dicts.append(msg) - else: - messages_as_dicts.append( - { - "role": getattr(msg, "role", "user"), - "content": getattr(msg, "content", ""), - } - ) - extracted_prompt, reference_images = self._extract_diffusion_prompt_and_images(messages_as_dicts) + extracted_prompt, reference_images = self._extract_diffusion_prompt_and_images_from_messages( + request.messages + ) if not extracted_prompt: return self.create_error_response("No text prompt found in messages") @@ -328,41 +319,33 @@ async def create_chat_completion( extra_body = getattr(request, "extra_body", None) if not extra_body: extra_body = request.model_extra or {} - height = extra_body.get("height") - width = extra_body.get("width") + + height, width = self._resolve_height_width_from_extra_body(extra_body) + num_inference_steps = extra_body.get("num_inference_steps") if num_inference_steps is not None: try: num_inference_steps = int(num_inference_steps) except Exception: num_inference_steps = None - if "size" in extra_body: - try: - size_str = extra_body["size"] - if isinstance(size_str, str) and "x" in size_str.lower(): - w, h = size_str.lower().split("x") - width, height = int(w), int(h) - except Exception: - pass + negative_prompt = extra_body.get("negative_prompt") cfg_text_scale = extra_body.get("cfg_text_scale") cfg_img_scale = extra_body.get("cfg_img_scale") engine_prompt_image: dict[str, Any] | None = None - is_img2img = False if reference_images: # Best-effort decode first reference image for i2i. try: img_bytes = base64.b64decode(reference_images[0]) img = Image.open(BytesIO(img_bytes)) engine_prompt_image = {"img2img": img} - is_img2img = True except Exception: engine_prompt_image = None # Override the prompts produced by chat-template preprocessing. tprompt: OmniTextPrompt = {"prompt": extracted_prompt} - if is_img2img: + if engine_prompt_image: tprompt["modalities"] = ["img2img"] else: tprompt["modalities"] = ["image"] @@ -378,6 +361,13 @@ async def create_chat_completion( tprompt["mm_processor_kwargs"] = mm_processor_kwargs if engine_prompt_image is not None: tprompt["multi_modal_data"] = engine_prompt_image + # Provide multi_modal_uuids so that newer vLLM versions + # can validate multi_modal_data / multi_modal_uuids + # consistency. After the multimodal processor consumes + # the image data, the uuids remain as a stable reference. + tprompt["multi_modal_uuids"] = { + k: [f"{request_id}-{k}-{i}"] for i, k in enumerate(engine_prompt_image) + } engine_prompts = [tprompt] # Store height/width for applying to diffusion stage sampling params later @@ -447,6 +437,7 @@ async def create_chat_completion( tokenizer, request_metadata, reasoning_parser, + raw_request=raw_request, ) try: @@ -544,20 +535,7 @@ async def _preprocess_chat( # containing image tokens. req_modalities = getattr(request, "modalities", []) if req_modalities and ("image" in req_modalities): - messages_as_dicts: list[dict[str, Any]] = [] - for msg in messages: - if hasattr(msg, "model_dump"): - messages_as_dicts.append(msg.model_dump()) - elif isinstance(msg, dict): - messages_as_dicts.append(msg) - else: - messages_as_dicts.append( - { - "role": getattr(msg, "role", "user"), - "content": getattr(msg, "content", ""), - } - ) - extracted_prompt, _ = self._extract_diffusion_prompt_and_images(messages_as_dicts) + extracted_prompt, _ = self._extract_diffusion_prompt_and_images_from_messages(messages) if extracted_prompt: engine_prompt["prompt"] = extracted_prompt @@ -717,6 +695,9 @@ def _apply_request_overrides( Starts with YAML defaults and only overrides fields that the user explicitly provided (non-None values) in the request. + For GLM-Image AR stage, if max_tokens is not in YAML and user provides + height/width in extra_body, computes max_tokens dynamically. + Args: default_params: Default SamplingParams from stage config YAML. request: The chat completion request containing user-provided values. @@ -726,11 +707,56 @@ def _apply_request_overrides( """ params = default_params.clone() + # Only apply fields explicitly provided by user, not protocol defaults. + # Pydantic v2 uses `model_fields_set`; keep v1 fallback for compatibility. + explicit_fields = getattr(request, "model_fields_set", None) + if explicit_fields is None: + explicit_fields = getattr(request, "__fields_set__", set()) + for field_name in self._OPENAI_SAMPLING_FIELDS: + if field_name not in explicit_fields: + continue + value = getattr(request, field_name, None) if (value is not None and not isinstance(value, list)) or (isinstance(value, list) and len(value) > 0): setattr(params, field_name, value) + # For GLM-Image: compute max_tokens from height/width with mode-aware + # budgeting (t2i vs i2i). + extra_body = getattr(request, "extra_body", {}) or {} + height, width = self._resolve_height_width_from_extra_body(extra_body) + + # Best-effort mode detection from user messages. + # i2i requests include at least one reference image in message content. + _, reference_images = self._extract_diffusion_prompt_and_images_from_messages(request.messages) + ref_image_count = len(reference_images) + is_img2img = ref_image_count > 0 + + if height is not None and width is not None: + try: + from vllm_omni.model_executor.stage_input_processors.glm_image import compute_max_tokens + + max_tokens = getattr(explicit_fields, "max_tokens", None) + if max_tokens is None: + max_tokens = compute_max_tokens(int(height), int(width), is_i2i=is_img2img) + params.max_tokens = max_tokens + # Keep target size in stage-0 sampling params so runner/model can + # build deterministic M-RoPE grids for t2i (no MM features). + extra_args = dict(getattr(params, "extra_args", {}) or {}) + extra_args["target_h"] = int(height) + extra_args["target_w"] = int(width) + params.extra_args = extra_args + except (ImportError, ValueError, TypeError) as e: + logger.warning(f"Failed to compute max_tokens: {e}, using default if available") + else: + logger.info( + "[SamplingParams] Skip dynamic max_tokens (height=%s, width=%s, mode=%s, ref_images=%s)", + height, + width, + "i2i" if is_img2img else "t2i", + ref_image_count, + ) + return params @staticmethod @@ -802,6 +828,7 @@ async def chat_completion_stream_generator( tokenizer: AnyTokenizer, request_metadata: RequestResponseMetadata, reasoning_parser: ReasoningParser | None = None, + raw_request: Request | None = None, ): created_time = int(time.time()) chunk_object_type: Final = "chat.completion.chunk" @@ -1514,6 +1541,21 @@ async def chat_completion_stream_generator( delta=False, ) + except EngineDeadError as e: + logger.error( + "EngineDeadError during streaming for request %s: %s", + request_id, + e, + ) + data = self.create_streaming_error_response(e) + yield f"data: {data}\n\n" + # Actively signal shutdown instead of waiting for the watchdog + # (5s polling interval). + if raw_request is not None: + terminate_if_errored( + server=raw_request.app.state.server, + engine=self.engine_client, + ) except Exception as e: logger.exception("Error in chat completion stream generator.") data = self.create_streaming_error_response(e) @@ -2643,6 +2685,48 @@ def _extract_diffusion_prompt_and_images( prompt = " ".join(prompt_parts).strip() return prompt, images + def _extract_diffusion_prompt_and_images_from_messages( + self, + messages: list[Any], + ) -> tuple[str, list[str]]: + """Normalize mixed message types and extract prompt + reference images once.""" + return self._extract_diffusion_prompt_and_images(self._messages_to_dicts(messages)) + + @staticmethod + def _messages_to_dicts(messages: list[Any]) -> list[dict[str, Any]]: + """Normalize request messages to plain dicts.""" + out: list[dict[str, Any]] = [] + for msg in messages: + if hasattr(msg, "model_dump"): + out.append(msg.model_dump()) + elif isinstance(msg, dict): + out.append(msg) + else: + out.append( + { + "role": getattr(msg, "role", "user"), + "content": getattr(msg, "content", ""), + } + ) + return out + + @staticmethod + def _resolve_height_width_from_extra_body(extra_body: dict[str, Any]) -> tuple[Any, Any]: + """Extract generation height/width with optional size string fallback.""" + height = extra_body.get("height") + width = extra_body.get("width") + + if "size" in extra_body and (height is None or width is None): + try: + size_str = extra_body["size"] + if isinstance(size_str, str) and "x" in size_str.lower(): + w, h = size_str.lower().split("x") + width, height = int(w), int(h) + except Exception: + pass + + return height, width + def _create_error_response( self, message: str, diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index ba8292f0c27..c275c779590 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -17,12 +17,17 @@ from fastapi import Request, UploadFile from fastapi.responses import Response, StreamingResponse from transformers.utils.hub import cached_file -from vllm.entrypoints.openai.engine.protocol import ErrorResponse +from vllm.entrypoints.launcher import terminate_if_errored +from vllm.entrypoints.openai.engine.protocol import ( + ErrorResponse, + RequestResponseMetadata, +) from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.logger import init_logger from vllm.multimodal.media import MediaConnector from vllm.utils import random_uuid from vllm.utils.async_utils import make_async +from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin from vllm_omni.entrypoints.openai.protocol.audio import ( @@ -1151,7 +1156,13 @@ async def _resolve_ref_audio(self, ref_audio_str: str) -> tuple[list[float], int ) return wav_np.tolist(), sr - async def _generate_audio_chunks(self, generator, request_id: str, response_format: str = "pcm"): + async def _generate_audio_chunks( + self, + generator, + request_id: str, + response_format: str = "pcm", + raw_request: Request | None = None, + ): """Generate audio chunks for streaming response. Handles two audio output modes from the engine: @@ -1226,6 +1237,19 @@ async def _generate_audio_chunks(self, generator, request_id: str, response_form except asyncio.CancelledError: logger.info("Streaming request %s cancelled by client", request_id) raise + except EngineDeadError as e: + logger.error( + "EngineDeadError during streaming speech for %s: %s", + request_id, + e, + ) + # Actively signal shutdown rather than relying on the watchdog. + if raw_request is not None: + terminate_if_errored( + server=raw_request.app.state.server, + engine=self.engine_client, + ) + raise except Exception as e: logger.exception("Streaming speech generation failed for %s: %s", request_id, e) raise @@ -1500,6 +1524,7 @@ async def _build_cosyvoice3_prompt( async def _prepare_speech_generation( self, request: OpenAICreateSpeechRequest, + request_id: str | None = None, ) -> tuple[str, Any, dict[str, Any]]: if self.engine_client.errored: raise self.engine_client.dead_error @@ -1592,7 +1617,7 @@ async def _prepare_speech_generation( tts_params = {} prompt = {"prompt": request.input} - request_id = f"speech-{random_uuid()}" + request_id = request_id or f"speech-{random_uuid()}" if self._is_fish_speech: model_type = "fish_speech" elif self._tts_model_type == "voxtral_tts": @@ -1674,8 +1699,9 @@ async def _generate_audio_bytes( self, request: OpenAICreateSpeechRequest, base64_encode: bool = False, + request_id: str | None = None, ) -> tuple[bytes | str, str]: - request_id, generator, _ = await self._prepare_speech_generation(request) + request_id, generator, _ = await self._prepare_speech_generation(request, request_id=request_id) final_output: OmniRequestOutput | None = None async for res in generator: @@ -1799,6 +1825,8 @@ async def _create_diffusion_speech( except asyncio.CancelledError: return self._diffusion_error_response("Client disconnected") + except (EngineGenerateError, EngineDeadError): + raise # Propagate to the global Omni exception handler except ValueError as e: return self._diffusion_error_response(str(e)) except Exception as e: @@ -1844,6 +1872,12 @@ async def create_speech( logger.error("Error with model %s", error_check_ret) return error_check_ret + request_id = f"speech-{random_uuid()}" + if raw_request: + raw_request.state.request_metadata = RequestResponseMetadata( + request_id=request_id, + ) + try: if request.stream: # Determine response format and media type for streaming @@ -1864,17 +1898,24 @@ async def create_speech( ) media_type = "audio/wav" if response_format == "wav" else "audio/pcm" - request_id, generator, _ = await self._prepare_speech_generation(request) + _, generator, _ = await self._prepare_speech_generation(request, request_id=request_id) return StreamingResponse( - self._generate_audio_chunks(generator, request_id, response_format), + self._generate_audio_chunks( + generator, + request_id, + response_format, + raw_request=raw_request, + ), media_type=media_type, ) - audio_bytes, media_type = await self._generate_audio_bytes(request) + audio_bytes, media_type = await self._generate_audio_bytes(request, request_id=request_id) return Response(content=audio_bytes, media_type=media_type) except asyncio.CancelledError: return self.create_error_response("Client disconnected") + except (EngineGenerateError, EngineDeadError): + raise # Propagate to the global Omni exception handler except ValueError as e: return self.create_error_response(e) except Exception as e: diff --git a/vllm_omni/inputs/preprocess.py b/vllm_omni/inputs/preprocess.py index c6dffd05426..cca6ce56870 100644 --- a/vllm_omni/inputs/preprocess.py +++ b/vllm_omni/inputs/preprocess.py @@ -29,6 +29,8 @@ def _process_text( self, parsed_content: OmniTextPrompt, tokenization_kwargs: dict[str, Any] | None = None, + *, + mm_uuids: Any | None = None, ) -> OmniTokenInputs | MultiModalInput: """Process text prompts with support for mm_processor_kwargs. @@ -38,6 +40,10 @@ def _process_text( """ prompt_text = parsed_content["prompt"] mm_processor_kwargs = parsed_content.get("mm_processor_kwargs") or {} + # When the deprecated raw-prompt path is used, process_inputs does + # not pass mm_uuids to preprocess(). Fall back to reading it from + # the prompt dict so the Renderer's _validate_mm_uuids can see it. + effective_mm_uuids = mm_uuids or parsed_content.get("multi_modal_uuids") inputs: OmniTokenInputs | MultiModalInput if multi_modal_data := parsed_content.get("multi_modal_data"): @@ -46,6 +52,7 @@ def _process_text( multi_modal_data, mm_processor_kwargs, tokenization_kwargs=tokenization_kwargs, + mm_uuids=effective_mm_uuids, ) prompt_embeds = parsed_content.get("prompt_embeds") if prompt_embeds is not None: @@ -59,6 +66,7 @@ def _process_text( {}, mm_processor_kwargs, tokenization_kwargs=tokenization_kwargs, + mm_uuids=effective_mm_uuids, ) else: prompt_token_ids = self._tokenize_prompt( @@ -142,6 +150,8 @@ def _prompt_to_llm_inputs( self, prompt: SingletonDictPrompt, tokenization_kwargs: dict[str, Any] | None = None, + *, + mm_uuids: Any | None = None, ) -> SingletonInput: """ Extract the singleton inputs from a prompt. @@ -166,6 +176,7 @@ def _prompt_to_llm_inputs( return self._process_text( prompt, # type: ignore[arg-type] tokenization_kwargs=tokenization_kwargs, + mm_uuids=mm_uuids, ) assert_never(prompt) # type: ignore[arg-type] diff --git a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py index 62776cbb31f..45f7afc6931 100644 --- a/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py +++ b/vllm_omni/model_executor/models/fish_speech/fish_speech_slow_ar.py @@ -649,6 +649,7 @@ def talker_mtp( input_embeds: torch.Tensor, last_talker_hidden: torch.Tensor, text_step: torch.Tensor, + **kwargs: Any, ) -> tuple[torch.Tensor, torch.Tensor]: """GPU fast-path: run Fast AR to predict residual codebook codes. diff --git a/vllm_omni/model_executor/models/glm_image/glm_image_ar.py b/vllm_omni/model_executor/models/glm_image/glm_image_ar.py index 31eed9b2cb9..bf21c01a645 100644 --- a/vllm_omni/model_executor/models/glm_image/glm_image_ar.py +++ b/vllm_omni/model_executor/models/glm_image/glm_image_ar.py @@ -21,6 +21,7 @@ # limitations under the License. """Inference-only GLM-Image model compatible with HuggingFace weights.""" +import math import os from collections.abc import Iterable, Mapping, Sequence from typing import Annotated, Literal @@ -127,6 +128,14 @@ def _get_subparsers(self): parsers["img2img"] = self._parse_image_data return parsers + def parse_mm_data(self, mm_data, **kwargs): + # Normalize "img2img" to "image" so the rest of the pipeline + # (mm_hashes, _merge_mm_kwargs) uses a single modality key. + normalized = {} + for k, v in mm_data.items(): + normalized["image" if k == "img2img" else k] = v + return super().parse_mm_data(normalized, **kwargs) + class GlmImageProcessingInfo(BaseProcessingInfo): """ @@ -346,6 +355,10 @@ def _call_hf_processor( target_h = mm_kwargs.get("target_h", 1024) if mm_kwargs else 1024 target_w = mm_kwargs.get("target_w", 1024) if mm_kwargs else 1024 + logger.debug( + f"_call_hf_processor: target dimensions for generation: {target_h}x{target_w}, mm_kwargs={mm_kwargs}" + ) + if not mm_data or not mm_data.get("images"): # Text-to-image mode if processor is not None: @@ -566,6 +579,58 @@ def _apply_hf_processor_mm_only( tensor_type="pt", ) + def _apply_hf_processor_text_only( + self, prompt_text: str, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object] + ) -> list[int]: + prompt_ids, _, _ = super()._apply_hf_processor_text_mm( + prompt_text=prompt_text, + mm_items=MultiModalDataItems({}), + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + ) + return prompt_ids + + def _build_generation_grids(self, hf_processor_mm_kwargs: Mapping[str, object]) -> torch.Tensor: + """Build generation grids for M-RoPE decode positions. + + For GLM-Image generation, decode order is: + 1) small preview grid + 2) large target grid + 3) EOS + + We store grids as [large, small] to match HF processor behavior, and + decode logic consumes them in reverse order. + """ + + target_h = ( + hf_processor_mm_kwargs.get("target_h") if isinstance(hf_processor_mm_kwargs.get("target_h"), int) else None + ) + target_w = ( + hf_processor_mm_kwargs.get("target_w") if isinstance(hf_processor_mm_kwargs.get("target_w"), int) else None + ) + if target_h is None or target_w is None: + target_h = ( + hf_processor_mm_kwargs.get("height") if isinstance(hf_processor_mm_kwargs.get("height"), int) else 1024 + ) + target_w = ( + hf_processor_mm_kwargs.get("width") if isinstance(hf_processor_mm_kwargs.get("width"), int) else 1024 + ) + + factor = 32 + target_h = (target_h // factor) * factor + target_w = (target_w // factor) * factor + token_h = target_h // factor + token_w = target_w // factor + + ratio = token_h / token_w if token_w > 0 else 1.0 + small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2))) + small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2))) + + return torch.tensor( + [[1, token_h, token_w], [1, small_token_h, small_token_w]], + dtype=torch.long, + ) + def _apply_hf_processor_main( self, prompt: str | list[int], @@ -594,126 +659,145 @@ def _apply_hf_processor_main( logger.debug(f"_apply_hf_processor_main: mm_counts={mm_counts}, num_images={num_images}") - if num_images == 0 or enable_hf_prompt_update: + if num_images == 0 and isinstance(prompt, str): # t2i mode or normal flow - use parent implementation - return super()._apply_hf_processor_main( - prompt=prompt, + prompt_ids = self._apply_hf_processor_text_only( + prompt_text=prompt, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + ) + mm_processed_data = self._apply_hf_processor_mm_only( mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - enable_hf_prompt_update=enable_hf_prompt_update, ) - # i2i mode with enable_hf_prompt_update=False (cache miss scenario) - # We need to build prompt_ids with image placeholders - logger.debug(f"_apply_hf_processor_main: i2i mode with enable_hf_prompt_update=False, num_images={num_images}") - - # Get mm data from our overridden _apply_hf_processor_mm_only - mm_processed_data = self._apply_hf_processor_mm_only( - mm_items=mm_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, - tokenization_kwargs=tokenization_kwargs, - ) - - # In this path we do NOT call HF apply_chat_template, so we must still - # provide full grids (source + target) for M-RoPE to compute decode positions. - # Keep `image_grid_thw` source-only for MM batching/validation. - try: - source_grid_thw = mm_processed_data.get("image_grid_thw") - if source_grid_thw is not None and isinstance(source_grid_thw, torch.Tensor): - # Compute target grid following HF GlmImageProcessor: factor=32. - # Prefer explicit target_h/target_w if present, otherwise fall back. - target_h = ( - hf_processor_mm_kwargs.get("target_h") - if isinstance(hf_processor_mm_kwargs.get("target_h"), int) - else None - ) - target_w = ( - hf_processor_mm_kwargs.get("target_w") - if isinstance(hf_processor_mm_kwargs.get("target_w"), int) - else None + # t2i has no source images, so mm features cannot provide image_grid_thw. + # Provide explicit generation grids for M-RoPE to avoid fallback token parsing + # (which can degrade high-resolution spatial positions, e.g. 1920x1920). + try: + mrope_grid_thw = self._build_generation_grids(hf_processor_mm_kwargs) + mm_processed_data["mrope_image_grid_thw"] = mrope_grid_thw + logger.info( + "_apply_hf_processor_main t2i: mrope_image_grid_thw=%s", + mrope_grid_thw.tolist(), ) - if target_h is None or target_w is None: - # Some callers pass generation size as height/width. - target_h = ( - hf_processor_mm_kwargs.get("height") - if isinstance(hf_processor_mm_kwargs.get("height"), int) - else 1024 - ) - target_w = ( - hf_processor_mm_kwargs.get("width") - if isinstance(hf_processor_mm_kwargs.get("width"), int) - else 1024 - ) + except Exception as e: + logger.warning("_apply_hf_processor_main t2i: failed to set mrope_image_grid_thw: %s", e) - factor = 32 - target_h = (target_h // factor) * factor - target_w = (target_w // factor) * factor - token_h = target_h // factor - token_w = target_w // factor - target_grid = torch.tensor([[1, token_h, token_w]], dtype=source_grid_thw.dtype) + return prompt_ids, mm_processed_data, False - mm_processed_data["mrope_image_grid_thw"] = torch.cat([source_grid_thw, target_grid], dim=0) - except Exception: - # Best-effort only; M-RoPE has additional fallbacks. - pass + # i2i mode: use unified HF processor path only. + # This avoids drift between duplicated manual/HF i2i implementations. + logger.debug( + "_apply_hf_processor_main: i2i mode (enable_hf_prompt_update=%s), num_images=%s", + enable_hf_prompt_update, + num_images, + ) - # Build prompt_ids with image placeholders - # _apply_prompt_updates will replace each [image_token_id] with expanded tokens - tokenizer = self.info.get_tokenizer() - image_token_id = tokenizer.convert_tokens_to_ids("<|image|>") + if not isinstance(prompt, str): + # Online OpenAI chat preprocessing can arrive here with tokenized + # prompts (list[int]) before serving_chat replaces engine prompt + # with the clean text prompt. Do not fail the whole request. + logger.warning( + "_apply_hf_processor_main i2i: got tokenized prompt type=%s; " + "using compatibility path for preprocessing", + type(prompt).__name__, + ) + + prompt_ids = list(prompt) + mm_processed_data = self._apply_hf_processor_mm_only( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + ) - if isinstance(prompt, str): - # Match HF GlmImageProcessor behavior: append target grid tokens + BOS. - # This helps M-RoPE/grid parsing and keeps i2i vs t2i behavior aligned. + # Preserve full grids for M-RoPE decode (source + target), while + # keeping image_grid_thw source-only for MM batching. try: - grid_bos = getattr(tokenizer, "grid_bos_token", "") - grid_eos = getattr(tokenizer, "grid_eos_token", "") - bos = getattr(tokenizer, "bos_token", "") - - # Use the same target sizes we used for mrope grids when available. - target_h = ( - hf_processor_mm_kwargs.get("target_h") - if isinstance(hf_processor_mm_kwargs.get("target_h"), int) - else None - ) - target_w = ( - hf_processor_mm_kwargs.get("target_w") - if isinstance(hf_processor_mm_kwargs.get("target_w"), int) - else None - ) - if target_h is None or target_w is None: + source_grid_thw = mm_processed_data.get("image_grid_thw") + if source_grid_thw is not None and isinstance(source_grid_thw, torch.Tensor): target_h = ( - hf_processor_mm_kwargs.get("height") - if isinstance(hf_processor_mm_kwargs.get("height"), int) - else 1024 + hf_processor_mm_kwargs.get("target_h") + if isinstance(hf_processor_mm_kwargs.get("target_h"), int) + else None ) target_w = ( - hf_processor_mm_kwargs.get("width") - if isinstance(hf_processor_mm_kwargs.get("width"), int) - else 1024 + hf_processor_mm_kwargs.get("target_w") + if isinstance(hf_processor_mm_kwargs.get("target_w"), int) + else None ) + if target_h is None or target_w is None: + target_h = ( + hf_processor_mm_kwargs.get("height") + if isinstance(hf_processor_mm_kwargs.get("height"), int) + else 1024 + ) + target_w = ( + hf_processor_mm_kwargs.get("width") + if isinstance(hf_processor_mm_kwargs.get("width"), int) + else 1024 + ) - factor = 32 - target_h = (target_h // factor) * factor - target_w = (target_w // factor) * factor - token_h = target_h // factor - token_w = target_w // factor - - expanded_prompt = f"{prompt}{grid_bos}{token_h} {token_w}{grid_eos}{bos}" - text_ids = tokenizer.encode(expanded_prompt, add_special_tokens=False) + factor = 32 + token_h = max(1, target_h // factor) + token_w = max(1, target_w // factor) + target_grid = torch.tensor([[1, token_h, token_w]], dtype=source_grid_thw.dtype) + mm_processed_data["mrope_image_grid_thw"] = torch.cat([source_grid_thw, target_grid], dim=0) except Exception: - text_ids = tokenizer.encode(prompt, add_special_tokens=False) + pass + + # Prompt updates will expand image placeholders in this compatibility path. + return prompt_ids, mm_processed_data, False + + images = mm_items.get_items("image", ImageProcessorItems) + image_list = [images.get(i) for i in range(images.get_count())] + if not image_list: + raise ValueError("GLM-Image i2i requires at least one source image in mm_items") + + hf_inputs = self._call_hf_processor( + prompt=prompt, + mm_data={"images": image_list}, + mm_kwargs=hf_processor_mm_kwargs, + tok_kwargs=tokenization_kwargs, + ) + + input_ids = hf_inputs.get("input_ids") + if input_ids is None: + raise ValueError("HF i2i processor returned no input_ids") + + if isinstance(input_ids, torch.Tensor): + prompt_ids = input_ids[0].tolist() if input_ids.dim() > 1 else input_ids.tolist() else: - text_ids = list(prompt) + prompt_ids = ( + input_ids[0] + if isinstance(input_ids, list) and input_ids and isinstance(input_ids[0], list) + else list(input_ids) + ) - # Prepend image placeholders - one per image - prompt_ids = [image_token_id] * num_images + text_ids + mm_processed_data = BatchFeature(dict(), tensor_type="pt") + for key in ("pixel_values", "image_grid_thw", "mrope_image_grid_thw"): + value = hf_inputs.get(key) + if value is not None: + mm_processed_data[key] = value - logger.debug(f"_apply_hf_processor_main: built prompt_ids with {num_images} image placeholders") + image_grid_thw = mm_processed_data.get("image_grid_thw") + mrope_grid_thw = mm_processed_data.get("mrope_image_grid_thw") + hf_config = self.info.get_hf_config() + image_token_id = getattr(hf_config, "image_token_id", 167855) + image_token_count = prompt_ids.count(image_token_id) + logger.info( + "_apply_hf_processor_main i2i(HF): num_images=%s, prompt_len=%s, image_token_count=%s, " + "source_grid_shape=%s, mrope_grid_shape=%s", + num_images, + len(prompt_ids), + image_token_count, + tuple(image_grid_thw.shape) if image_grid_thw is not None else None, + tuple(mrope_grid_thw.shape) if mrope_grid_thw is not None else None, + ) - # Return is_update_applied=False so _apply_prompt_updates will expand the placeholders - return prompt_ids, mm_processed_data, False + # HF processor already expanded image placeholders in input_ids. + return prompt_ids, mm_processed_data, True def _get_mm_fields_config( self, @@ -2667,9 +2751,23 @@ def get_mrope_input_positions( # Input format: "textH Wh w" where =image_start_token_id=16384 # For 1024x1024: H=32, W=32 (large), h=16, w=16 (small preview) if not image_grid_thw: + # Preferred path for t2i: use explicit target size propagated from + # serving/request sampling params. This avoids fragile grid parsing + # from token IDs and matches HF processor grid construction. + target_h = kwargs.get("target_h") + target_w = kwargs.get("target_w") + if isinstance(target_h, int) and isinstance(target_w, int) and target_h > 0 and target_w > 0: + factor = 32 + token_h = target_h // factor + token_w = target_w // factor + ratio = token_h / token_w if token_w > 0 else 1.0 + small_h = max(1, int(math.sqrt(ratio) * (factor // 2))) + small_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2))) + image_grid_thw = [[1, token_h, token_w], [1, small_h, small_w]] + # Try to parse from kwargs (passed from processor) hf_config_arg = kwargs.get("hf_config") - if hf_config_arg is not None and hasattr(hf_config_arg, "image_grid_thw"): + if (not image_grid_thw) and hf_config_arg is not None and hasattr(hf_config_arg, "image_grid_thw"): image_grid_thw = hf_config_arg.image_grid_thw # If still empty, try to infer from input tokens @@ -2723,19 +2821,29 @@ def get_mrope_input_positions( prompt_ends_with_start = len(input_tokens) > 0 and input_tokens[-1] == image_start_token_id if prompt_ends_with_start and len(image_grid_thw) == num_source_images and num_source_images > 0: # i2i mode: source grids exist but no target grids - # Parse target grids from prompt tokens or use defaults - parsed_grids = self._parse_grid_from_tokens(input_tokens, hf_config) - if parsed_grids: - # parsed_grids contains all grids mentioned in prompt - # For i2i, add only the generation target grids - if len(parsed_grids) > num_source_images: - image_grid_thw = list(image_grid_thw) + parsed_grids[num_source_images:] + # Prefer explicit target size propagated from request sampling params. + # This avoids fragile grid parsing from token IDs for non-1024 i2i. + target_h = kwargs.get("target_h") + target_w = kwargs.get("target_w") + if isinstance(target_h, int) and isinstance(target_w, int) and target_h > 0 and target_w > 0: + factor = 32 + token_h = target_h // factor + token_w = target_w // factor + image_grid_thw = list(image_grid_thw) + [[1, token_h, token_w]] + else: + # Parse target grids from prompt tokens or use defaults + parsed_grids = self._parse_grid_from_tokens(input_tokens, hf_config) + if parsed_grids: + # parsed_grids contains all grids mentioned in prompt + # For i2i, add only the generation target grids + if len(parsed_grids) > num_source_images: + image_grid_thw = list(image_grid_thw) + parsed_grids[num_source_images:] + else: + # Fallback: add default 1024x1024 generation grid (1 target for i2i) + image_grid_thw = list(image_grid_thw) + [[1, 32, 32]] else: - # Fallback: add default 1024x1024 generation grids (1 target for i2i) + # Fallback to default 1024x1024 grid for generation image_grid_thw = list(image_grid_thw) + [[1, 32, 32]] - else: - # Fallback to default 1024x1024 grids for generation - image_grid_thw = list(image_grid_thw) + [[1, 32, 32]] llm_pos_ids_list: list[torch.Tensor] = [] diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index f06ecf41d22..ce336037501 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -689,6 +689,7 @@ def talker_mtp( input_embeds: torch.Tensor, last_talker_hidden: torch.Tensor, text_step: torch.Tensor, + **kwargs: Any, ): # TODO(Peiqi): not support intermediate_tensors now input_ids = safe_tensor_reshape(input_ids, (input_ids.shape[0], -1)) diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py index d9cbcf7d4ef..31fd8062278 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py +++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py @@ -415,6 +415,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # In-memory LRU cache for voice extraction artifacts (Base voice clone). self._voice_cache = VoiceEmbeddingCache() + raw_subtalker_sampling = getattr(vllm_config.model_config, "subtalker_sampling_params", None) + self._subtalker_sampling_params: dict[str, Any] = ( + dict(raw_subtalker_sampling) if isinstance(raw_subtalker_sampling, Mapping) else {} + ) # -------------------- vLLM required hooks -------------------- @@ -1638,6 +1642,10 @@ def talker_mtp( input_embeds: torch.Tensor, last_talker_hidden: torch.Tensor, text_step: torch.Tensor, + do_sample: bool | None = None, + temperature: float | None = None, + top_k: int | None = None, + top_p: float | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """GPU fast-path used by OmniGPUModelRunner to predict residual codebooks (1..Q-1). Returns (inputs_embeds, audio_codes) for the current step.""" @@ -1656,15 +1664,24 @@ def talker_mtp( audio_codes = input_ids.reshape(bsz, 1) return (last_id_hidden + text_step).reshape(bsz, -1), audio_codes - # Predict residual codes (1..Q-1) with HF reference sampling params. + subtalker_params = self._subtalker_sampling_params + if do_sample is None: + do_sample = bool(subtalker_params.get("do_sample", True)) + if temperature is None: + temperature = float(subtalker_params.get("temperature", 0.9)) + if top_k is None: + top_k = int(subtalker_params.get("top_k", 50)) + if top_p is None: + top_p = float(subtalker_params.get("top_p", 1.0)) + audio_codes = self.code_predictor( layer0_code=input_ids.reshape(bsz, 1), layer0_embed=last_id_hidden, last_talker_hidden=past_hidden, - do_sample=True, - temperature=0.9, - top_k=50, - top_p=1.0, + do_sample=do_sample, + temperature=temperature, + top_k=top_k, + top_p=top_p, ) # [B, Q] # Map invalid layer-0 ids (e.g. EOS) to PAD=0 so SpeechTokenizer sees only real codes. diff --git a/vllm_omni/model_executor/stage_configs/glm_image.yaml b/vllm_omni/model_executor/stage_configs/glm_image.yaml index 05ac84a7a09..f3ed6c7213d 100644 --- a/vllm_omni/model_executor/stage_configs/glm_image.yaml +++ b/vllm_omni/model_executor/stage_configs/glm_image.yaml @@ -33,7 +33,6 @@ stage_args: temperature: 0.9 # From model's generation_config.json top_p: 0.75 # From model's generation_config.json top_k: 16512 # vision_vocab_size from generation_config.json - max_tokens: 1281 # For 1024x1024: small(16x16=256) + large(32x32=1024) + EOS(1) stop_token_ids: [16385] # eos_token_id from generation_config.json seed: 42 detokenize: false diff --git a/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml b/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml index 7bd66c403fc..2a85a6dadbc 100644 --- a/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml +++ b/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml @@ -35,7 +35,6 @@ stage_args: temperature: 0.9 # From model's generation_config.json top_p: 0.75 # From model's generation_config.json top_k: 16512 # vision_vocab_size from generation_config.json - max_tokens: 1281 # For 1024x1024: small(16x16=256) + large(32x32=1024) + EOS(1) stop_token_ids: [16385] # eos_token_id from generation_config.json seed: 42 detokenize: false diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml new file mode 100644 index 00000000000..f0797c63270 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe.yaml @@ -0,0 +1,96 @@ +# Stage config for running Hunyuan-Image3.0 with AR→DiT KV reuse. +# Stage 0: AR Model (vLLM implementation) +# Stage 1: DiT Model (diffusion) +# +# text-to-image flow: AR (stage 0) → KV transfer → DiT (stage 1) +# image-to-text flow: AR (stage 0) only +# +# Compared to hunyuan_image3_t2i.yaml, this config: +# 1. Enables both stages [0, 1] for text-to-image (AR prefill + DiT denoising) +# 2. Adds omni_kv_config to send/receive KV cache between stages + +# The following config has been verified on 8x L40S-48G GPU (4 for AR + 4 for DiT). +stage_args: + - stage_id: 0 + stage_type: llm # Use llm stage type for AR stages + runtime: + process: true # Run this stage in a separate process + devices: "0,1,2,3" # AR stage uses GPU 0-3 + engine_args: + model_stage: AR + max_num_seqs: 1 + model_arch: HunyuanImage3ForCausalMM + worker_cls: vllm_omni.worker.gpu_ar_worker.GPUARWorker + scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler + gpu_memory_utilization: 0.9 + enforce_eager: true # Now we only support eager mode + trust_remote_code: true + engine_output_type: latent + enable_prefix_caching: false + max_num_batched_tokens: 32768 + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + hf_overrides: + rope_parameters: + mrope_section: [0, 32, 32] + rope_type: default + omni_kv_config: + need_send_cache: true + kv_transfer_criteria: + type: prefill_finished # Send KV cache after AR prefill completes + is_comprehension: true + final_output: true + final_output_type: text + default_sampling_params: + temperature: 0.0 + top_p: 1.0 + top_k: -1 + max_tokens: 2048 + seed: 42 + detokenize: True + repetition_penalty: 1.1 + - stage_id: 1 + stage_type: diffusion + runtime: + process: true + devices: "4,5,6,7" # DiT stage uses GPU 4-7 + max_batch_size: 1 + engine_args: + model_stage: diffusion + enforce_eager: true + distributed_executor_backend: "mp" + vae_use_slicing: false + vae_use_tiling: false + cache_backend: null + cache_config: null + enable_cache_dit_summary: false + omni_kv_config: + need_recv_cache: true # Receive AR KV cache from stage 0 + parallel_config: + pipeline_parallel_size: 1 + data_parallel_size: 1 + tensor_parallel_size: 4 + enable_expert_parallel: false + sequence_parallel_size: 1 + ulysses_degree: 1 + ring_degree: 1 + cfg_parallel_size: 1 + vae_patch_parallel_size: 1 + use_hsdp: false + hsdp_shard_size: -1 + hsdp_replicate_size: 1 + engine_input_source: [0] # Receive input (including KV) from stage 0 + final_output: true + final_output_type: image + +# Top-level runtime config: windows, edges, and connectors +runtime: + enabled: true + defaults: + window_size: -1 # Trigger downstream only after full upstream completion + max_inflight: 1 # Process serially within each stage + + edges: + - from: 0 + to: 1 + window_size: -1 diff --git a/vllm_omni/model_executor/stage_input_processors/glm_image.py b/vllm_omni/model_executor/stage_input_processors/glm_image.py index 738a3732618..5ea05befb03 100644 --- a/vllm_omni/model_executor/stage_input_processors/glm_image.py +++ b/vllm_omni/model_executor/stage_input_processors/glm_image.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Stage input processor for GLM-Image: AR → Diffusion transition.""" +import math +import time from typing import Any import torch @@ -13,6 +15,86 @@ logger = init_logger(__name__) +def _has_source_image(mm_data: Any) -> bool: + """Return whether prompt multi_modal_data contains a source image. + + Normalizes legacy/new keys used across omni pipelines: + - `image`: single PIL image or list + - `img2img`: legacy single-image key + - `images`: list or single image + """ + if not isinstance(mm_data, dict): + return False + if mm_data.get("image") is not None: + return True + if mm_data.get("img2img") is not None: + return True + images = mm_data.get("images") + return bool(images) + + +def _first_source_image(mm_data: Any) -> Any: + """Get first source image from normalized multimodal keys.""" + if not isinstance(mm_data, dict): + return None + + image = mm_data.get("image") + if image is not None: + if isinstance(image, list): + return image[0] if image else None + return image + + image = mm_data.get("img2img") + if image is not None: + if isinstance(image, list): + return image[0] if image else None + return image + + images = mm_data.get("images") + if isinstance(images, list): + return images[0] if images else None + return images + + +def compute_max_tokens(height: int, width: int, factor: int = 32, is_i2i: bool = False) -> int: + """ + Compute max_new_tokens for GLM-Image AR generation. + + GLM-Image generation differs by mode: + + - text-to-image (t2i): small preview + large target + EOS + - image-to-image (i2i): large target + EOS + + Args: + height: Target image height in pixels + width: Target image width in pixels + factor: Downsampling factor (32 for GLM-Image AR output) + is_i2i: Whether the request is image-to-image mode + + Returns: + Total number of tokens to generate for the specified mode + """ + # Large image tokens (target resolution) + token_h = height // factor + token_w = width // factor + large_tokens = token_h * token_w + + # Small preview tokens (half resolution in each dimension) + import math + + ratio = token_h / token_w if token_w > 0 else 1.0 + small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2))) + small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2))) + small_tokens = small_token_h * small_token_w + + # Mode-dependent totals: + # - t2i: small + large + EOS + # - i2i: large + EOS + if is_i2i: + return large_tokens + 1 + return small_tokens + large_tokens + 1 + + def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor: """Upsample token IDs by 2x using nearest neighbor interpolation. @@ -56,39 +138,49 @@ def _parse_generated_tokens( large_image_tokens = token_h * token_w # Calculate small preview image dimensions (used in text-to-image) - small_token_h = token_h // 2 - small_token_w = token_w // 2 + ratio = token_h / token_w if token_w > 0 else 1.0 + small_token_h = max(1, int(math.sqrt(ratio) * (factor // 2))) + small_token_w = max(1, int(math.sqrt(1 / ratio) * (factor // 2))) small_image_tokens = small_token_h * small_token_w token_tensor = torch.tensor(token_ids, dtype=torch.long) # Remove EOS token (16385) from the end if present eos_token_id = 16385 - if len(token_ids) > 0 and token_ids[-1] == eos_token_id: + has_terminal_eos = len(token_ids) > 0 and token_ids[-1] == eos_token_id + if has_terminal_eos: token_tensor = token_tensor[:-1] actual_tokens = len(token_tensor) - logger.debug( - f"[_parse_generated_tokens] height={height}, width={width}, " - f"token_h={token_h}, token_w={token_w}, " - f"large_image_tokens={large_image_tokens}, small_image_tokens={small_image_tokens}, " - f"actual_tokens={actual_tokens}" - ) - if is_i2i: - # Image-to-image mode: check if AR generated small+large tokens (like t2i) or just large tokens - # Some AR models output small+large even in i2i mode if actual_tokens >= small_image_tokens + large_image_tokens: - # AR generated full t2i-style output, extract large tokens after small large_start = small_image_tokens large_end = large_start + large_image_tokens prior_token_ids_d32 = token_tensor[large_start:large_end] actual_h, actual_w = token_h, token_w - else: - # AR generated only large tokens (pure i2i output) + logger.warning( + "[_parse_generated_tokens] i2i detected t2i-style token layout; " + "using small-offset extraction: large_start=%s large_end=%s", + large_start, + large_end, + ) + elif actual_tokens >= large_image_tokens: prior_token_ids_d32 = token_tensor[:large_image_tokens] actual_h, actual_w = token_h, token_w + logger.info( + "[_parse_generated_tokens] i2i using offset-0 extraction: large_tokens=%s", + large_image_tokens, + ) + else: + logger.warning( + "[_parse_generated_tokens] i2i token parse failed: actual_tokens=%s < expected_large_tokens=%s", + actual_tokens, + large_image_tokens, + ) + raise ValueError( + f"i2i token parse failed: actual_tokens={actual_tokens} < expected_large_tokens={large_image_tokens}" + ) elif actual_tokens >= small_image_tokens + large_image_tokens: # Text-to-image: extract large image tokens after small image tokens large_start = small_image_tokens @@ -96,43 +188,22 @@ def _parse_generated_tokens( prior_token_ids_d32 = token_tensor[large_start:large_end] actual_h, actual_w = token_h, token_w elif actual_tokens >= large_image_tokens: - # Image-to-image: large image tokens are at the beginning - prior_token_ids_d32 = token_tensor[:large_image_tokens] - actual_h, actual_w = token_h, token_w + logger.warning( + "[_parse_generated_tokens] t2i token parse failed: got only large tokens without small preview " + "(actual_tokens=%s, expected_small_plus_large=%s)", + actual_tokens, + small_image_tokens + large_image_tokens, + ) + raise ValueError("t2i token parse failed: missing small-preview tokens; refusing low-quality fallback") else: - # Insufficient tokens - try to infer the actual grid size - import math - - for scale in [1, 2, 4]: - test_h = token_h // scale - test_w = token_w // scale - test_small_h = test_h // 2 - test_small_w = test_w // 2 - test_large = test_h * test_w - test_small = test_small_h * test_small_w - - if actual_tokens >= test_small + test_large: - prior_token_ids_d32 = token_tensor[test_small : test_small + test_large] - actual_h, actual_w = test_h, test_w - height = test_h * factor - width = test_w * factor - logger.warning(f"Adjusted grid to {test_h}x{test_w}, output will be {height}x{width}") - break - elif actual_tokens >= test_large: - prior_token_ids_d32 = token_tensor[:test_large] - actual_h, actual_w = test_h, test_w - height = test_h * factor - width = test_w * factor - logger.warning(f"Adjusted grid to {test_h}x{test_w}, output will be {height}x{width}") - break - else: - sqrt_tokens = int(math.sqrt(actual_tokens)) - actual_h = actual_w = sqrt_tokens - usable_tokens = sqrt_tokens * sqrt_tokens - prior_token_ids_d32 = token_tensor[:usable_tokens] - height = sqrt_tokens * factor - width = sqrt_tokens * factor - logger.error(f"Grid pattern mismatch. Using {sqrt_tokens}x{sqrt_tokens}, output: {height}x{width}") + logger.warning( + "[_parse_generated_tokens] token parse failed: insufficient tokens " + "(actual_tokens=%s, expected=%s, mode=%s)", + actual_tokens, + large_image_tokens if is_i2i else (small_image_tokens + large_image_tokens), + "i2i" if is_i2i else "t2i", + ) + raise ValueError(f"token parse failed: actual_tokens={actual_tokens}, mode={'i2i' if is_i2i else 't2i'}") # Upsample from 32x to 16x prior_token_ids = _upsample_token_ids(prior_token_ids_d32, actual_h, actual_w) @@ -144,8 +215,16 @@ def ar2diffusion( source_outputs: list[Any], prompt: OmniTokensPrompt | TextPrompt | list | None = None, requires_multimodal_data: bool = False, + streaming_context: Any | None = None, ) -> list[dict[str, Any]]: - """Process AR stage outputs to create Diffusion stage inputs.""" + """Process AR stage outputs to create Diffusion stage inputs. + + This processor accepts the stage-pool transition interface: + ``ar2diffusion(source_outputs, prompt, requires_multimodal_data)``. + """ + del streaming_context + + _t_total = time.perf_counter() ar_outputs = source_outputs diffusion_inputs = [] @@ -154,6 +233,7 @@ def ar2diffusion( prompt = [prompt] if prompt is not None else [{}] for i, ar_output in enumerate(ar_outputs): + _t_req = time.perf_counter() output = ar_output.outputs[0] generated_token_ids = output.token_ids @@ -168,23 +248,76 @@ def ar2diffusion( else: original_prompt = {} - height = original_prompt.get("height", 1024) - width = original_prompt.get("width", 1024) + mm_processor_kwargs = original_prompt.get("mm_processor_kwargs") + + def _coerce_dim(v: Any, default: int) -> int: + try: + iv = int(v) + return iv if iv > 0 else default + except (TypeError, ValueError): + return default + + # Prefer GLM-Image target size from mm_processor_kwargs (set by serving layer), + # then fall back to top-level fields for backward compatibility. + height = _coerce_dim( + mm_processor_kwargs.get("target_h") if isinstance(mm_processor_kwargs, dict) else None, + _coerce_dim(original_prompt.get("height"), 1024), + ) + width = _coerce_dim( + mm_processor_kwargs.get("target_w") if isinstance(mm_processor_kwargs, dict) else None, + _coerce_dim(original_prompt.get("width"), 1024), + ) text_prompt = original_prompt.get("prompt", "") - # Detect i2i mode first by checking if multimodal_output contains prior_token_image_ids + # Detect i2i mode. + # Prefer normalized prompt multi_modal_data source-image presence, with + # multimodal output as secondary signal. + _t_mode = time.perf_counter() is_i2i = False + + prompt_modalities = original_prompt.get("modalities") + if isinstance(prompt_modalities, list) and "img2img" in prompt_modalities: + is_i2i = True + + prompt_mm_data = original_prompt.get("multi_modal_data") + if _has_source_image(prompt_mm_data): + is_i2i = True + if hasattr(ar_output, "multimodal_output") and ar_output.multimodal_output: mm_output = ar_output.multimodal_output - if isinstance(mm_output, dict) and mm_output.get("prior_token_image_ids") is not None: - is_i2i = True + if isinstance(mm_output, dict): + if mm_output.get("prior_token_image_ids") is not None: + is_i2i = True + _dt_mode = (time.perf_counter() - _t_mode) * 1000 # Parse and upsample prior tokens - prior_token_ids, pixel_h, pixel_w = _parse_generated_tokens(generated_token_ids, height, width, is_i2i=is_i2i) + _t_parse = time.perf_counter() + try: + prior_token_ids, pixel_h, pixel_w = _parse_generated_tokens( + generated_token_ids, + height, + width, + is_i2i=is_i2i, + ) + except ValueError as e: + logger.warning( + "[ar2diffusion] Request %s: skip due to token parse failure: %s " + "(target=%sx%s, mode=%s, raw_tokens=%s, tail=%s)", + i, + e, + height, + width, + "i2i" if is_i2i else "t2i", + len(generated_token_ids), + generated_token_ids[-8:] if len(generated_token_ids) >= 8 else generated_token_ids, + ) + continue + _dt_parse = (time.perf_counter() - _t_parse) * 1000 # Get prior_token_image_ids from AR model output (for i2i mode) # This contains VQ-VAE tokens from input image, used for KV cache conditioning # NOTE: multimodal_output is attached to ar_output (RequestOutput), NOT output (CompletionOutput) + _t_prior_img = time.perf_counter() prior_token_image_ids = None # Check ar_output (RequestOutput) for multimodal_output - this is the correct location @@ -223,6 +356,7 @@ def ar2diffusion( prior_token_image_ids = [raw_prior_image_ids] elif isinstance(raw_prior_image_ids, list): prior_token_image_ids = raw_prior_image_ids + _dt_prior_img = (time.perf_counter() - _t_prior_img) * 1000 diffusion_input = { "prompt": text_prompt, @@ -237,18 +371,38 @@ def ar2diffusion( if requires_multimodal_data: mm_data = original_prompt.get("multi_modal_data") if mm_data: - pil_image = mm_data.get("image") - if pil_image is None: - # Try "images" (plural) as fallback - images = mm_data.get("images") - if images: - pil_image = images[0] if isinstance(images, list) else images + pil_image = _first_source_image(mm_data) diffusion_input["pil_image"] = pil_image for key in ["seed", "num_inference_steps", "guidance_scale", "negative_prompt"]: if key in original_prompt: diffusion_input[key] = original_prompt[key] + _dt_req = (time.perf_counter() - _t_req) * 1000 + logger.info( + "[ar2diffusion] req=%d mode=%s target=%dx%d " + "raw_tokens=%d prior_tokens=%d prior_image_ids=%s " + "timing: mode_detect=%.3fms parse+upsample=%.3fms " + "prior_image_ids_extract=%.3fms req_total=%.3fms", + i, + "i2i" if is_i2i else "t2i", + pixel_h, + pixel_w, + len(generated_token_ids), + len(prior_token_ids), + "yes" if prior_token_image_ids is not None else "no", + _dt_mode, + _dt_parse, + _dt_prior_img, + _dt_req, + ) diffusion_inputs.append(diffusion_input) + _dt_total = (time.perf_counter() - _t_total) * 1000 + logger.info( + "[ar2diffusion] batch done: %d reqs, total=%.3fms", + len(diffusion_inputs), + _dt_total, + ) + return diffusion_inputs diff --git a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py index 0ea780db42b..4216f9259ed 100644 --- a/vllm_omni/model_executor/stage_input_processors/mimo_audio.py +++ b/vllm_omni/model_executor/stage_input_processors/mimo_audio.py @@ -9,6 +9,14 @@ logger = init_logger(__name__) +# Maximum tokens supported by the code2wav stage. The flattened talker codec +# sequence fed to stage-1 must not exceed this, otherwise gpu_input_batch +# add_request will fail with a broadcast error when copying prompt_token_ids +# into token_ids_cpu. Keep in sync with the stage-1 ``max_model_len`` in +# ``vllm_omni/model_executor/stage_configs/mimo_audio.yaml`` and the offline +# example ``examples/offline_inference/mimo_audio/end2end.py``. +MAX_CODE2WAV_TOKENS = 18192 + def prepend_and_flatten_colmajor(x: torch.Tensor, pad_vec: torch.Tensor) -> torch.Tensor: """ @@ -222,6 +230,20 @@ def llm2code2wav( code_final = prepend_and_flatten_colmajor(codec_codes, pad_vec) code_final = code_final.tolist() + # Guard against flattened sequences longer than code2wav's max_model_len. + # Without this, add_request raises ``could not broadcast input array + # from shape (N,) into shape (max_model_len,)`` and kills the engine + # core (see issue #2683). Mirrors the offline end2end.py safeguard. + if len(code_final) > MAX_CODE2WAV_TOKENS: + request_id = getattr(talker_output, "request_id", f"unknown_{i}") + logger.warning( + "Request %s: code_final len=%d > MAX_CODE2WAV_TOKENS=%d, truncating.", + request_id, + len(code_final), + MAX_CODE2WAV_TOKENS, + ) + code_final = code_final[:MAX_CODE2WAV_TOKENS] + code2wav_inputs.append( OmniTokensPrompt( prompt_token_ids=code_final, diff --git a/vllm_omni/outputs.py b/vllm_omni/outputs.py index c02c0c1427c..23930f358bb 100644 --- a/vllm_omni/outputs.py +++ b/vllm_omni/outputs.py @@ -100,9 +100,22 @@ class OmniRequestOutput: # memory usage info peak_memory_mb: float = 0.0 - # error handling + # Error information -- set when the output represents a failed request. error: str | None = None + @classmethod + def from_error( + cls, + request_id: str, + error: str, + ) -> "OmniRequestOutput": + """Create an error output for a request that failed during generation.""" + return cls( + request_id=request_id, + finished=True, + error=error, + ) + @classmethod def from_pipeline( cls, diff --git a/vllm_omni/platforms/npu/worker/npu_model_runner.py b/vllm_omni/platforms/npu/worker/npu_model_runner.py index 8ef39adfa67..310d09311f0 100644 --- a/vllm_omni/platforms/npu/worker/npu_model_runner.py +++ b/vllm_omni/platforms/npu/worker/npu_model_runner.py @@ -417,10 +417,22 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te req_embeds = self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded] last_talker_hidden = self.last_talker_hidden.gpu[:num_tokens_padded] text_step = self.text_step.gpu[:num_tokens_padded] + subtalker_params = getattr(self.vllm_config.model_config, "subtalker_sampling_params", None) + if not isinstance(subtalker_params, dict): + subtalker_params = {} with set_ascend_forward_context( None, self.vllm_config, aclgraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc ): - req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) + req_embeds, code_predictor_codes = self.talker_mtp( + req_input_ids, + req_embeds, + last_talker_hidden, + text_step, + do_sample=subtalker_params.get("do_sample"), + temperature=subtalker_params.get("temperature"), + top_k=subtalker_params.get("top_k"), + top_p=subtalker_params.get("top_p"), + ) # code_predictor_codes stays on GPU here; _update_intermediate_buffer # keeps it device-resident when the key is in gpu_resident_buffer_keys. # D2H is deferred to sample_tokens where hidden_states.to("cpu") already diff --git a/vllm_omni/profiler/omni_torch_profiler.py b/vllm_omni/profiler/omni_torch_profiler.py index 2257a212838..023ad5009b7 100644 --- a/vllm_omni/profiler/omni_torch_profiler.py +++ b/vllm_omni/profiler/omni_torch_profiler.py @@ -5,8 +5,8 @@ import os import subprocess -import time -from typing import Literal +from datetime import datetime +from typing import Any, Literal import torch from typing_extensions import override @@ -62,6 +62,13 @@ def __init__( self._trace_path: str | None = None self._table_path: str | None = None + self._activities = activities + self._session_dir: str | None = None + self._artifact_paths: dict[str, str | None] = {} + self._memory_history_enabled = False + self._memory_history_backend: str | None = None + self._memory_history_module = None + if local_rank in (None, 0): logger.info_once( "Omni torch profiling enabled. Traces will be saved to: %s", @@ -72,6 +79,9 @@ def __init__( self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1 self.profiler = self._create_profiler(profiler_config, activities) + def _rank(self) -> int: + return 0 if self.local_rank is None else self.local_rank + def _get_default_activities(self) -> list[TorchProfilerActivity]: """Get default activities for this platform. @@ -106,19 +116,58 @@ def set_trace_filename(self, filename: str) -> None: Can also be a full path (e.g. from diffusion engine). """ self._trace_filename = filename + self._session_dir = None + self._ensure_session_dir() + + def _ensure_session_dir(self) -> str: + """Create one timestamped directory for this profiling run.""" + if self._session_dir is not None: + os.makedirs(self._session_dir, exist_ok=True) + return self._session_dir + + ts = datetime.now().strftime("%Y%m%d-%H%M%S") + base_name = self._trace_filename or self._worker_name + + if os.path.dirname(base_name): + parent_dir = os.path.dirname(base_name) + leaf_name = os.path.basename(base_name) + session_name = f"{ts}_{leaf_name}" + self._session_dir = os.path.join(parent_dir, session_name) + else: + session_name = f"{ts}_{base_name}" + self._session_dir = os.path.join(self._trace_dir, session_name) + + os.makedirs(self._session_dir, exist_ok=True) + self._artifact_paths["session_dir"] = self._session_dir + return self._session_dir + + def _artifact_path(self, stem: str, suffix: str) -> str: + """Build artifact path under the session directory.""" + return os.path.join( + self._ensure_session_dir(), + f"{stem}_rank{self._rank()}{suffix}", + ) + + def _write_text_artifact(self, name: str, content: str) -> str: + path = self._artifact_path(name, ".txt") + with open(path, "w") as f: + f.write(content) + self._artifact_paths[name] = path + return path + + def _has_cuda_like_activity(self) -> bool: + return any(a in self._activities for a in ("CUDA", "MUSA")) + + def _get_time_sort_key(self) -> str: + if self._has_cuda_like_activity(): + return "self_cuda_time_total" + return "self_cpu_time_total" def _on_trace_ready(self, prof) -> None: """Custom trace handler: export chrome trace with omni naming.""" - rank = self.local_rank - filename = self._trace_filename or f"{self._worker_name}_{int(time.time())}" - # If filename already contains a directory, use as-is (e.g. from - # diffusion engine which builds full path). Otherwise join with trace_dir. - if os.path.dirname(filename): - json_file = f"{filename}_rank{rank}.json" - else: - json_file = os.path.join(self._trace_dir, f"{filename}_rank{rank}.json") + rank = self._rank() - os.makedirs(os.path.dirname(json_file), exist_ok=True) + json_file = self._artifact_path("trace", ".json") try: prof.export_chrome_trace(json_file) @@ -143,18 +192,211 @@ def _on_trace_ready(self, prof) -> None: else: self._trace_path = json_file + self._artifact_paths["trace"] = self._trace_path + except Exception as e: logger.warning("[Rank %s] Failed to export trace: %s", rank, e) + def _try_enable_memory_history(self) -> None: + """Enable backend-specific memory history for snapshot analysis.""" + if not self.profiler_config.torch_profiler_with_memory: + return + + backend_name, memory_module = self._resolve_memory_history_backend() + if backend_name is None or memory_module is None: + return + + record_memory_history = getattr(memory_module, "_record_memory_history", None) + if record_memory_history is None: + logger.info( + "[Rank %s] %s memory history is not supported on this platform", + self._rank(), + backend_name, + ) + return + + try: + record_memory_history( + enabled="all", + context="all", + stacks="python", + max_entries=100000, + clear_history=True, + ) + self._memory_history_enabled = True + self._memory_history_backend = backend_name + self._memory_history_module = memory_module + logger.info("[Rank %s] %s memory history enabled", self._rank(), backend_name) + except Exception as e: + logger.warning( + "[Rank %s] Failed to enable %s memory history: %s", + self._rank(), + backend_name, + e, + ) + + def _try_dump_memory_snapshot(self) -> None: + """Dump a backend-specific memory snapshot into the session directory.""" + if not self._memory_history_enabled: + return + + try: + if self._memory_history_module is None or self._memory_history_backend is None: + return + + dump_snapshot = getattr(self._memory_history_module, "_dump_snapshot", None) + if dump_snapshot is None: + logger.info( + "[Rank %s] %s memory snapshot is not supported on this platform", + self._rank(), + self._memory_history_backend, + ) + return + + snapshot_file = self._artifact_path("memory_snapshot", ".pickle") + dump_snapshot(snapshot_file) + self._artifact_paths["memory_snapshot"] = snapshot_file + logger.info( + "[Rank %s] %s memory snapshot dumped to %s", + self._rank(), + self._memory_history_backend, + snapshot_file, + ) + except Exception as e: + logger.warning( + "[Rank %s] Failed to dump %s memory snapshot: %s", + self._rank(), + self._memory_history_backend, + e, + ) + finally: + try: + if self._memory_history_module is not None: + disable_memory_history = getattr( + self._memory_history_module, + "_record_memory_history", + None, + ) + if disable_memory_history is not None: + disable_memory_history(enabled=None) + except Exception: + pass + self._memory_history_enabled = False + self._memory_history_backend = None + self._memory_history_module = None + + def _resolve_memory_history_backend(self) -> tuple[str | None, Any]: + """Resolve the memory backend that supports history/snapshot APIs.""" + backend_specs = [ + ("CUDA", self._has_cuda_like_activity(), getattr(torch, "cuda", None)), + ("NPU", "NPU" in self._activities, getattr(torch, "npu", None)), + ("XPU", "XPU" in self._activities, getattr(torch, "xpu", None)), + ("MUSA", "MUSA" in self._activities, getattr(torch, "musa", None)), + ] + + for backend_name, enabled, device_module in backend_specs: + if not enabled or device_module is None: + continue + + is_available = getattr(device_module, "is_available", None) + if callable(is_available) and not is_available(): + continue + + memory_module = getattr(device_module, "memory", None) + if memory_module is not None: + return backend_name, memory_module + + return None, None + + def _safe_get(self, obj, name: str, default=None): + return getattr(obj, name, default) + + def _event_list_to_rows(self, event_list) -> list[dict]: + rows = [] + for evt in event_list: + row = { + "name": self._safe_get(evt, "key", None) or self._safe_get(evt, "name", None), + "count": self._safe_get(evt, "count", None), + "device_type": self._safe_get(evt, "device_type", None), + "node_id": self._safe_get(evt, "node_id", None), + "self_cpu_time_total_us": self._safe_get(evt, "self_cpu_time_total", None), + "cpu_time_total_us": self._safe_get(evt, "cpu_time_total", None), + "self_cuda_time_total_us": self._safe_get(evt, "self_cuda_time_total", None), + "cuda_time_total_us": self._safe_get(evt, "cuda_time_total", None), + "self_xpu_time_total_us": self._safe_get(evt, "self_xpu_time_total", None), + "xpu_time_total_us": self._safe_get(evt, "xpu_time_total", None), + "self_cpu_memory_usage_bytes": self._safe_get(evt, "self_cpu_memory_usage", None), + "cpu_memory_usage_bytes": self._safe_get(evt, "cpu_memory_usage", None), + "self_cuda_memory_usage_bytes": self._safe_get(evt, "self_cuda_memory_usage", None), + "cuda_memory_usage_bytes": self._safe_get(evt, "cuda_memory_usage", None), + "self_xpu_memory_usage_bytes": self._safe_get(evt, "self_xpu_memory_usage", None), + "xpu_memory_usage_bytes": self._safe_get(evt, "xpu_memory_usage", None), + "input_shapes": str(self._safe_get(evt, "input_shapes", None)), + "stack": "\n".join(self._safe_get(evt, "stack", []) or []), + "overload_name": self._safe_get(evt, "overload_name", None), + "is_async": self._safe_get(evt, "is_async", None), + "is_legacy": self._safe_get(evt, "is_legacy", None), + } + rows.append(row) + return rows + + def _write_excel_artifact(self, name: str, sheets: dict[str, list[dict]]) -> str: + path = self._artifact_path(name, ".xlsx") + + try: + import pandas as pd + except Exception as e: + logger.warning( + "[Rank %s] pandas not available, skip Excel export: %s", + self._rank(), + e, + ) + return path + + with pd.ExcelWriter(path, engine="openpyxl") as writer: + for sheet_name, rows in sheets.items(): + df = pd.DataFrame(rows) + + safe_sheet_name = sheet_name if sheet_name else "Sheet1" + + df.to_excel( + writer, + sheet_name=safe_sheet_name, + index=False, + freeze_panes=(1, 0), + ) + + ws = writer.sheets[safe_sheet_name] + ws.auto_filter.ref = ws.dimensions + + for col_cells in ws.columns: + max_len = 0 + col_letter = col_cells[0].column_letter + for cell in col_cells[:200]: + try: + val = "" if cell.value is None else str(cell.value) + max_len = max(max_len, len(val)) + except Exception: + pass + ws.column_dimensions[col_letter].width = min(max(max_len + 2, 12), 80) + + self._artifact_paths[name] = path + return path + @override def _start(self) -> None: + self._ensure_session_dir() + self._try_enable_memory_history() self.profiler.start() @override def _stop(self) -> None: """Stop profiler, export trace via on_trace_ready, and dump table.""" self.profiler.stop() - self._on_stop_hook() + try: + self._on_stop_hook() + finally: + self._try_dump_memory_snapshot() def _on_stop_hook(self) -> None: """Hook called after profiler.stop(). @@ -163,6 +405,68 @@ def _on_stop_hook(self) -> None: Base implementation handles CUDA time total dump. """ rank = self.local_rank + sort_key = self._get_time_sort_key() + + excel_sheets: dict[str, list[dict]] = {} + + # 1) Summary op table + summary_events = self.profiler.key_averages() + excel_sheets["summary"] = self._event_list_to_rows(summary_events) + + # 2) Shape-grouped op table + if self.profiler_config.torch_profiler_record_shapes: + try: + shape_events = self.profiler.key_averages( + group_by_input_shape=True, + ) + excel_sheets["by_shape"] = self._event_list_to_rows(shape_events) + except Exception as e: + logger.warning( + "[Rank %s] Failed to export shape-grouped op table: %s", + rank, + e, + ) + + # 3) Stack-grouped op table + if self.profiler_config.torch_profiler_with_stack: + try: + stack_events = self.profiler.key_averages( + group_by_stack_n=8, + ) + excel_sheets["by_stack"] = self._event_list_to_rows(stack_events) + except Exception as e: + logger.warning( + "[Rank %s] Failed to export stack-grouped op table: %s", + rank, + e, + ) + + # 4) Export stack files + try: + cpu_stack_file = self._artifact_path("stacks_cpu", ".txt") + self.profiler.export_stacks( + cpu_stack_file, + metric="self_cpu_time_total", + ) + self._artifact_paths["stacks_cpu"] = cpu_stack_file + except Exception as e: + logger.warning("[Rank %s] export_stacks(cpu) failed: %s", rank, e) + + if self._has_cuda_like_activity(): + try: + cuda_stack_file = self._artifact_path("stacks_cuda", ".txt") + self.profiler.export_stacks( + cuda_stack_file, + metric="self_cuda_time_total", + ) + self._artifact_paths["stacks_cuda"] = cuda_stack_file + except Exception as e: + logger.warning("[Rank %s] export_stacks(cuda) failed: %s", rank, e) + + try: + self._table_path = self._write_excel_artifact("ops", excel_sheets) + except Exception as e: + logger.warning("[Rank %s] Failed to export Excel workbook: %s", rank, e) if self.profiler_config.torch_profiler_dump_cuda_time_total: profiler_dir = self.profiler_config.torch_profiler_dir @@ -190,6 +494,7 @@ def get_results(self) -> dict: return { "trace": self._trace_path, "table": self._table_path, + **self._artifact_paths, } diff --git a/vllm_omni/worker/base.py b/vllm_omni/worker/base.py index f7f5dbd1d8b..8bd9efc89c4 100644 --- a/vllm_omni/worker/base.py +++ b/vllm_omni/worker/base.py @@ -2,15 +2,23 @@ from __future__ import annotations +import gc import os import time +from contextlib import AbstractContextManager, nullcontext import torch from vllm.logger import init_logger from vllm.utils.mem_utils import format_gib, memory_profiling from vllm.v1.worker.gpu_worker import Worker as GPUWorker +from vllm_omni.diffusion.data import ( + OmniACK, + OmniSleepTask, + OmniWakeTask, +) from vllm_omni.entrypoints.utils import detect_pid_host +from vllm_omni.platforms import current_omni_platform from vllm_omni.worker.gpu_memory_utils import ( get_process_gpu_memory, is_process_scoped_memory_available, @@ -30,6 +38,13 @@ class OmniGPUWorkerBase(GPUWorker): for custom trace naming, background gzip, and trace path collection. """ + def load_model(self, *args, **kwargs): + with self._maybe_get_memory_pool_context("weights"): + res = super().load_model(*args, **kwargs) + current_omni_platform.synchronize() + gc.collect() + return res + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -94,10 +109,8 @@ def determine_available_memory(self) -> int: """ if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes: self.model_runner.profile_run() - logger.info( - "Using explicit kv_cache_memory_bytes: %s GiB", - format_gib(kv_cache_memory_bytes), - ) + if current_omni_platform.is_rocm(): + torch.cuda.synchronize() return kv_cache_memory_bytes with memory_profiling( @@ -154,3 +167,141 @@ def determine_available_memory(self) -> int: ) return int(self.available_kv_cache_memory_bytes) + + # Provide memory pool context + def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager: + v1_config_enabled = False + if hasattr(self, "vllm_config"): + model_cfg = getattr(self.vllm_config, "model_config", None) + v1_config_enabled = getattr(model_cfg, "enable_sleep_mode", False) + + is_sleep_enabled = v1_config_enabled or getattr(self.cache_config, "enable_sleep_mode", False) + if is_sleep_enabled: + current_omni_platform.synchronize() + gc.collect() + from vllm.device_allocator.cumem import CuMemAllocator + + allocator = CuMemAllocator.get_instance() + logger.info(f"[LLM Worker {self.rank}] Sleep Mode ENABLED. Activating CuMem pool for tag: {tag}") + return allocator.use_memory_pool(tag=tag) + else: + logger.warning(f"[LLM Worker {self.rank}] Sleep Mode DISABLED.") + return nullcontext() + + def sleep(self, level: int = 1) -> bool: + """ + Put the worker to sleep. + Args: + level: 1 (Offload weights to CPU), level: 2 (Total Discard). + """ + from vllm.device_allocator.cumem import CuMemAllocator + + mem_before = current_omni_platform.get_current_memory_usage(self.device) + offload_tags = ("weights",) if level == 1 else tuple() + allocator = CuMemAllocator.get_instance() + allocator.sleep(offload_tags=offload_tags) + current_omni_platform.empty_cache() + current_omni_platform.synchronize() + mem_after = current_omni_platform.get_current_memory_usage(self.device) + freed = max(0, mem_before - mem_after) + remaining_gb = mem_after / 1024**3 + logger.info( + f"[LLM Worker {self.rank}] Level {level} Sleep: Freed " + f"{freed / 1024**3:.2f} GiB. {remaining_gb:.2f}GiB memory " + "is still in use." + ) + return True + + def wake_up(self, tags: list[str] | None = None) -> bool: + "Physical video memory reloading logic" + from vllm.device_allocator.cumem import CuMemAllocator + + allocator = CuMemAllocator.get_instance() + allocator.wake_up(tags) + current_omni_platform.synchronize() + logger.info(f"[LLM Worker {self.rank}] Wake-up complete.") + return True + + def handle_sleep_task(self, task: OmniSleepTask) -> OmniACK: + "Handle deterministic Sleep command from the main process" + try: + if isinstance(task, dict): + task = OmniSleepTask(**task) + logger.info(f"[LLM Worker {self.rank}] Handshake Received: Task {task.task_id}, Level {task.level}") + if task.level == 2: + if hasattr(self.model_runner, "graph_runners"): + self.model_runner.graph_runners.clear() + logger.info(f"[LLM Worker {self.rank}] LLM CUDA Graphs cleared.") + mem_before = current_omni_platform.get_current_memory_usage(self.device) + self.sleep(level=task.level) + mem_after = current_omni_platform.get_current_memory_usage(self.device) + rank_freed = max(0, mem_before - mem_after) + if torch.distributed.is_initialized(): + t_freed = torch.tensor([float(rank_freed)], device=self.device) + torch.distributed.all_reduce(t_freed) + total_freed = int(t_freed.item()) + torch.distributed.barrier() + else: + total_freed = rank_freed + if self.rank != 0: + return None + current_stage_id = getattr(self.vllm_config.model_config, "stage_id", 0) + ack = OmniACK( + task_id=task.task_id, + status="SUCCESS", + stage_id=current_stage_id, + rank=self.rank, + freed_bytes=total_freed, + metadata={ + "source": "omni_platform_audit", + "total_freed_gib": f"{total_freed / 1024**3:.2f}", + "rank_residual_gib": f"{mem_after / 1024**3:.2f}", + }, + ) + if hasattr(self, "result_mq") and self.result_mq: + self.result_mq.put(ack) + logger.info(f"[LLM Worker {self.rank}] ACK emitted for Task {task.task_id}") + return ack + except Exception as e: + logger.error(f"[LLM Worker {self.rank}] Sleep Task Failed: {e}", exc_info=True) + if torch.distributed.is_initialized(): + try: + torch.distributed.barrier() + except Exception: + pass + return OmniACK(task_id=task.task_id, status="ERROR", error_msg=str(e)) + + def handle_wake_task(self, task: OmniWakeTask) -> OmniACK: + "Handle deterministic Wakeup command from the main process" + try: + if isinstance(task, dict): + task = OmniWakeTask(**task) + self.wake_up(tags=task.tags) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + gc.collect() + current_omni_platform.synchronize() + usage_now = current_omni_platform.get_current_memory_usage(self.device) + if self.rank != 0: + return None + current_stage_id = getattr(self.vllm_config.model_config, "stage_id", 0) + ack = OmniACK( + task_id=task.task_id, + status="SUCCESS", + stage_id=current_stage_id, + rank=self.rank, + metadata={"state": "WARM", "current_vram_gib": f"{usage_now / 1024**3:.2f}"}, + ) + if hasattr(self, "result_mq") and self.result_mq: + self.result_mq.put(ack) + logger.info(f"[LLM Worker {self.rank}] Wake-up ACK emitted.") + return ack + except Exception as e: + logger.error(f"[LLM Worker {self.rank}] Wake-up Failed: {e}", exc_info=True) + if torch.distributed.is_initialized(): + try: + torch.distributed.barrier() + except Exception: + pass + tid = task.task_id if hasattr(task, "task_id") else "unknown" + return OmniACK(task_id=tid, status="ERROR", error_msg=str(e)) diff --git a/vllm_omni/worker/gpu_ar_worker.py b/vllm_omni/worker/gpu_ar_worker.py index 4abe21964b3..2668bbe8b3a 100644 --- a/vllm_omni/worker/gpu_ar_worker.py +++ b/vllm_omni/worker/gpu_ar_worker.py @@ -12,6 +12,7 @@ from vllm.v1.worker.utils import request_memory from vllm.v1.worker.workspace import init_workspace_manager +from vllm_omni.diffusion.data import OmniACK, OmniSleepTask, OmniWakeTask from vllm_omni.worker.base import OmniGPUWorkerBase from vllm_omni.worker.gpu_ar_model_runner import GPUARModelRunner from vllm_omni.worker.mixins import OmniWorkerMixin @@ -104,3 +105,23 @@ def init_device(self): if self.rank == 0: # If usage stat is enabled, collect relevant info. report_usage_stats(self.vllm_config) + + def handle_sleep_task(self, task: OmniSleepTask | dict) -> OmniACK: + """ + Explicitly handle sleep commands. + Calls the implementation in the base class OmniGPUWorkerBase. + """ + logger.debug(f"[AR Worker {self.rank}] Resolving handle_sleep_task dispatch") + if isinstance(task, dict): + task = OmniSleepTask(**task) + return super().handle_sleep_task(task) + + def handle_wake_task(self, task: OmniWakeTask | dict) -> OmniACK: + """ + Explicitly handle wake-up commands. + Calls the implementation in the base class OmniGPUWorkerBase. + """ + logger.debug(f"[AR Worker {self.rank}] Resolving handle_wake_task dispatch") + if isinstance(task, dict): + task = OmniWakeTask(**task) + return super().handle_wake_task(task) diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index d1c15eac640..d3ccbaaf303 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -152,8 +152,10 @@ def _init_mrope_positions(self, req_state: CachedRequestState): if supports_mrope(self.get_model()): # Model implements SupportsMRoPE interface # Pass all extracted metadata; models use what they need via **kwargs - req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions( - req_state.prompt_token_ids, + sp_extra_args = getattr(req_state.sampling_params, "extra_args", {}) if req_state.sampling_params else {} + target_h = sp_extra_args.get("target_h") if isinstance(sp_extra_args, dict) else None + target_w = sp_extra_args.get("target_w") if isinstance(sp_extra_args, dict) else None + kwargs = dict( mm_features=req_state.mm_features, hf_config=self.model_config.hf_config, image_grid_thw=image_grid_thw, @@ -162,6 +164,14 @@ def _init_mrope_positions(self, req_state: CachedRequestState): audio_feature_lengths=audio_feature_lengths, use_audio_in_video=use_audio_in_video, ) + if target_h is not None: + kwargs["target_h"] = target_h + if target_w is not None: + kwargs["target_w"] = target_w + req_state.mrope_positions, req_state.mrope_position_delta = self.model.get_mrope_input_positions( + req_state.prompt_token_ids, + **kwargs, + ) else: req_state.mrope_positions, req_state.mrope_position_delta = MRotaryEmbedding.get_input_positions_tensor( req_state.prompt_token_ids, @@ -1345,10 +1355,22 @@ def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Te req_embeds = self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded] last_talker_hidden = self.last_talker_hidden.gpu[:num_tokens_padded] text_step = self.text_step.gpu[:num_tokens_padded] + subtalker_params = getattr(self.vllm_config.model_config, "subtalker_sampling_params", None) + if not isinstance(subtalker_params, dict): + subtalker_params = {} with set_forward_context( None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc ): - req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) + req_embeds, code_predictor_codes = self.talker_mtp( + req_input_ids, + req_embeds, + last_talker_hidden, + text_step, + do_sample=subtalker_params.get("do_sample"), + temperature=subtalker_params.get("temperature"), + top_k=subtalker_params.get("top_k"), + top_p=subtalker_params.get("top_p"), + ) # code_predictor_codes stays on GPU here; _update_intermediate_buffer # keeps it device-resident when the key is in gpu_resident_buffer_keys. # D2H is deferred to sample_tokens where hidden_states.to("cpu") already