diff --git a/.buildkite/test-nightly.yml b/.buildkite/test-nightly.yml
index 02d8cced403..c33d7b4d10d 100644
--- a/.buildkite/test-nightly.yml
+++ b/.buildkite/test-nightly.yml
@@ -552,7 +552,7 @@ steps:
- label: ":full_moon: Diffusion X2V · Accuracy Test"
timeout_in_minutes: 180
commands:
- - pytest -s -v tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py --run-level advanced_model
+ - pytest -s -v tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py -m advanced_model --run-level advanced_model
agents:
queue: "mithril-h100-pool"
plugins:
diff --git a/.buildkite/test-ready.yml b/.buildkite/test-ready.yml
index 68f8e615286..3ca1747fe64 100644
--- a/.buildkite/test-ready.yml
+++ b/.buildkite/test-ready.yml
@@ -367,6 +367,33 @@ steps:
volumes:
- "/fsx/hf_cache:/fsx/hf_cache"
+ - label: "Qwen3-TTS Base E2E Test (ModelRunner V2)"
+ depends_on: upload-ready-pipeline
+ soft_fail:
+ - exit_status: 1
+ commands:
+ - |
+ timeout 20m bash -c '
+ export VLLM_LOGGING_LEVEL=DEBUG
+ export VLLM_WORKER_MULTIPROC_METHOD=spawn
+ export VLLM_ALLOW_LONG_MAX_MODEL_LEN="1"
+ export VLLM_OMNI_USE_V2_RUNNER="1"
+ pytest -s -v tests/e2e/online_serving/test_qwen3_tts_base.py -m "core_model" --run-level "core_model"
+ '
+ agents:
+ queue: "gpu_1_queue"
+ plugins:
+ - docker#v5.2.0:
+ image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+ always-pull: true
+ propagate-environment: true
+ shm-size: "8gb"
+ environment:
+ - "HF_HOME=/fsx/hf_cache"
+ - "HF_TOKEN"
+ volumes:
+ - "/fsx/hf_cache:/fsx/hf_cache"
+
- label: "Voxtral-TTS E2E Test"
timeout_in_minutes: 20
depends_on: upload-ready-pipeline
diff --git a/.claude/skills/vllm-omni-npu-upgrade/SKILL.md b/.claude/skills/vllm-omni-npu-upgrade/SKILL.md
new file mode 100644
index 00000000000..1ef7ab39301
--- /dev/null
+++ b/.claude/skills/vllm-omni-npu-upgrade/SKILL.md
@@ -0,0 +1,300 @@
+---
+name: vllm-omni-npu-model-runner-upgrade
+description: "Upgrade vllm-omni NPU model runners (OmniNPUModelRunner, NPUARModelRunner, NPUGenerationModelRunner) to align with the latest vllm-ascend NPUModelRunner while preserving omni-specific logic."
+---
+
+# vLLM-Omni NPU Model Runner Upgrade Skill
+
+## Overview
+
+This skill guides the process of upgrading vllm-omni's NPU model runners to align with the latest vllm-ascend codebase while preserving omni-specific enhancements. The NPU runners are designed to run omni multimodal models (like Qwen3-Omni, Bagel, MiMoAudio) on Ascend NPUs.
+
+## File Structure
+
+### NPU Model Runner Files
+```
+vllm-omni/vllm_omni/platforms/npu/worker/
+├── __init__.py
+├── npu_model_runner.py # OmniNPUModelRunner (base class)
+├── npu_ar_model_runner.py # NPUARModelRunner (autoregressive)
+├── npu_ar_worker.py # AR worker
+├── npu_generation_model_runner.py # NPUGenerationModelRunner (diffusion/non-AR)
+└── npu_generation_worker.py # Generation worker
+```
+
+### GPU Reference Files (for omni-specific logic sync)
+```
+vllm-omni/vllm_omni/worker/
+├── __init__.py
+├── gpu_model_runner.py # OmniGPUModelRunner
+├── gpu_ar_model_runner.py # GPUARModelRunner
+├── gpu_ar_worker.py
+├── gpu_generation_model_runner.py
+├── gpu_generation_worker.py
+├── mixins.py
+├── base.py
+└── gpu_memory_utils.py
+```
+
+### vllm-ascend Reference Files
+```
+vllm-ascend/vllm_ascend/worker/
+├── model_runner_v1.py # NPUModelRunner (base class to copy from)
+├── npu_input_batch.py
+├── block_table.py
+├── pcp_utils.py
+└── worker.py
+```
+
+## Inheritance Hierarchy
+
+```
+ GPUModelRunner (vllm)
+ |
+ +----------------+----------------+
+ | |
+ OmniGPUModelRunner NPUModelRunner (vllm-ascend)
+ (vllm_omni/worker) (vllm_ascend/worker)
+ | |
+ +----------- OmniNPUModelRunner --+
+ (multiple inheritance)
+ |
+ +---------------+---------------+
+ | |
+ NPUARModelRunner NPUGenerationModelRunner
+ (autoregressive) (non-autoregressive/diffusion)
+```
+
+## Omni-Specific Comment Markers
+
+Omni-specific logic is marked with comment blocks:
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+# ... omni-specific code ...
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+Or simpler variations:
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+# ------------------------------------------------------------------------------------------------
+```
+
+**Important**:
+- Always preserve and add these markers when modifying code.
+- **The reference documents (`references/omni-specific-blocks.md`) may not be up-to-date.** Always grep for `Omni-new` in the GPU implementations to find the authoritative list of omni-specific blocks.
+- When you discover new omni-specific code that is not documented in the references, please update the reference files.
+
+## Key Methods Requiring Attention
+
+### OmniNPUModelRunner (npu_model_runner.py)
+
+| Method | Description | Omni-Specific Logic |
+|--------|-------------|---------------------|
+| `load_model` | Load model and initialize talker_mtp | Uses `ACLGraphWrapper` instead of `CUDAGraphWrapper`, initializes talker buffers |
+| `_dummy_run` | Warmup/profiling run | talker_mtp dummy forward, `extract_multimodal_outputs` |
+| `_model_forward` | Forward pass wrapper | Injects `model_kwargs_extra`, wraps with `OmniOutput`, NPU-specific graph updates |
+| `_talker_mtp_forward` | Talker MTP forward for Qwen3-Omni | Uses `set_ascend_forward_context` |
+
+### NPUARModelRunner (npu_ar_model_runner.py)
+
+| Method | Description | Omni-Specific Logic |
+|--------|-------------|---------------------|
+| `__init__` | Initialize with KV transfer manager | `OmniKVTransferManager` setup |
+| `execute_model` | Main inference entry | KV transfer handling, `_update_states` override, `extract_multimodal_outputs` |
+| `sample_tokens` | Token sampling | Hidden states extraction, multimodal outputs processing, `OmniModelRunnerOutput` |
+| `_resolve_global_request_id` | Request ID resolution | For disaggregated inference |
+
+### NPUGenerationModelRunner (npu_generation_model_runner.py)
+
+| Method | Description | Omni-Specific Logic |
+|--------|-------------|---------------------|
+| `_update_request_states` | Update request states for async chunk | async_chunk handling |
+| `execute_model` | Generation forward | async_chunk, `seq_token_counts`, `_run_generation_model` |
+| `sample_tokens` | Output processing | multimodal output packaging to `OmniModelRunnerOutput` |
+| `_dummy_run` | Dummy run override | model_kwargs initialization, multimodal extraction |
+| `_run_generation_model` | Run generation model | Calls `_model_forward` with sampler |
+
+## Upgrade Workflow
+
+### Step 1: Preparation
+
+1. **Identify target versions**(Use gh cli to check):
+ - We're using vllm-omni main branch
+ - Check the last release of vllm-omni
+ - Target vllm-ascend version(Just directly use the local latest vllm-ascend code)
+
+2. **Check GPU-side changes** (since last release):
+ ```bash
+ cd /root/vllm-workspace/vllm-omni
+ git log --oneline --since="" -- vllm_omni/worker/
+ ```
+
+3. **Read latest vllm-ascend code**:
+ - We don't track vllm-ascend changes - just directly use the latest code from `/root/vllm-workspace/vllm-ascend/vllm_ascend/worker/model_runner_v1.py`
+ - Copy the relevant methods and re-insert omni-specific blocks
+
+### Step 2: Analyze Omni-Specific Logic
+
+For each NPU model runner file:
+
+1. **Extract existing omni-specific blocks**:
+ ```bash
+ grep -n "Omni-new" vllm_omni/platforms/npu/worker/npu_model_runner.py
+ ```
+
+2. **Document each omni block**:
+ - Which method it belongs to
+ - What functionality it provides
+ - Dependencies on other omni code
+
+### Step 3: Update Base Class (OmniNPUModelRunner)
+
+**Note**: Always check the GPU implementation `gpu_model_runner.py` for any new omni logic not yet documented in references.
+
+1. **Read the latest vllm-ascend `NPUModelRunner.load_model`**
+2. **Copy the method, keeping the structure**
+3. **Re-insert omni-specific logic** (check GPU `gpu_model_runner.py` for authoritative list):
+ - Replace `CUDAGraphWrapper` with `ACLGraphWrapper`
+ - Keep talker_mtp initialization
+ - Preserve buffer allocations for talker
+ - Check for any new omni blocks added since last sync
+
+4. **Update `_dummy_run`**:
+ - Copy from vllm-ascend
+ - Compare with GPU `_dummy_run` for omni-specific blocks
+ - Re-insert all `Omni-new` marked code from GPU version
+
+5. **Update `_model_forward`**:
+ - Keep the omni wrapper logic
+ - Update NPU-specific parts (graph params, SP all-gather)
+ - Check GPU version for any new omni logic
+
+### Step 4: Update AR Model Runner
+
+1. **Compare with GPU `gpu_ar_model_runner.py`** for any new omni features
+2. **Copy `execute_model` from vllm-ascend**
+3. **Re-insert omni blocks** (reference `references/omni-specific-blocks.md`, but note it may be incomplete):
+ - **IMPORTANT**: Always check the GPU implementation `gpu_ar_model_runner.py` for all `Omni-new` marked code blocks
+ - The reference doc may not include newly added omni logic - treat it as a starting point, not exhaustive
+ - When discovering new omni code blocks, please update `references/omni-specific-blocks.md`
+ - Common omni blocks include but are not limited to: KV transfer, multimodal outputs, sampling_metadata handling, etc.
+
+4. **Update `sample_tokens`** (also compare with GPU implementation):
+ - Compare with `gpu_ar_model_runner.py`'s `sample_tokens` method
+ - Identify all `Omni-new` marked code blocks
+ - Ensure NPU version includes all omni-specific logic
+
+### Step 5: Update Generation Model Runner
+
+**Note**: Generation model runner may have unique omni logic for diffusion/non-AR models.
+
+1. **Compare with GPU `gpu_generation_model_runner.py`** - grep for all `Omni-new` blocks
+2. **Update `execute_model`**:
+ - Check GPU version for all omni-specific blocks
+ - Keep async_chunk handling
+ - Keep `seq_token_counts` injection
+ - Update forward/context setup from vllm-ascend
+ - Look for any new omni logic not documented in references
+
+3. **Update `_dummy_run`**:
+ - Copy from vllm-ascend base
+ - Compare with GPU `_dummy_run` if exists
+ - Re-insert all omni-specific logic
+
+### Step 6: Update Imports
+
+Check and update imports at the top of each file:
+
+```python
+# Common vllm-ascend imports
+from vllm_ascend.ascend_forward_context import get_forward_context, set_ascend_forward_context
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.utils import using_paged_attention
+from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params
+from vllm_ascend.ops.rotary_embedding import update_cos_sin
+from vllm_ascend.utils import enable_sp, lmhead_tp_enable
+from vllm_ascend.worker.model_runner_v1 import SEQ_LEN_WITH_MAX_PA_WORKSPACE, NPUModelRunner
+
+# Omni-specific imports
+from vllm_omni.model_executor.models.output_templates import OmniOutput
+from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
+from vllm_omni.outputs import OmniModelRunnerOutput
+from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager
+```
+
+### Step 7: Sync GPU-Side Omni Changes
+
+1. **Check recent GPU worker changes**:
+ ```bash
+ git diff .. -- vllm_omni/worker/gpu_model_runner.py
+ git diff .. -- vllm_omni/worker/gpu_ar_model_runner.py
+ ```
+
+2. **Identify new omni features** that need to be ported to NPU
+
+3. **Apply corresponding changes** to NPU runners
+
+### Step 8: Validation
+
+1. **Run type checking**:
+ ```bash
+ cd /root/vllm-workspace/vllm-omni
+ python -m py_compile vllm_omni/platforms/npu/worker/npu_model_runner.py
+ python -m py_compile vllm_omni/platforms/npu/worker/npu_ar_model_runner.py
+ python -m py_compile vllm_omni/platforms/npu/worker/npu_generation_model_runner.py
+ ```
+
+2. **Run import test**:
+ ```bash
+ python -c "from vllm_omni.platforms.npu.worker import *"
+ ```
+
+3. **Run model serving test** (if hardware available):
+ ```bash
+ vllm serve --trust-remote-code
+ ```
+
+## Common Pitfalls
+
+### 1. Forward Context Differences
+- GPU uses `set_forward_context`
+- NPU uses `set_ascend_forward_context`
+- Parameters may differ slightly
+
+### 2. Graph Wrapper Differences
+- GPU: `CUDAGraphWrapper`
+- NPU: `ACLGraphWrapper`
+- Constructor parameters may differ
+
+### 3. Buffer Creation
+- GPU: `_make_buffer` returns different structure
+- NPU: May need numpy=True/False parameter
+
+### 4. Attention Metadata
+- GPU: Uses vllm attention metadata builders
+- NPU: Uses `AscendCommonAttentionMetadata`
+
+### 5. Sampling
+- GPU: Uses vllm sampler
+- NPU: Uses `AscendSampler`
+
+## Checklist Before Commit
+
+- [ ] All omni-specific comment markers preserved
+- [ ] New omni logic from GPU side synced
+- [ ] Imports updated to latest vllm-ascend
+- [ ] No `CUDAGraphWrapper` references in NPU code
+- [ ] `set_ascend_forward_context` used instead of `set_forward_context`
+- [ ] `ACLGraphWrapper` used for talker_mtp wrapping
+- [ ] Type hints match vllm-ascend signatures
+- [ ] No duplicate code blocks
+- [ ] Python syntax valid (py_compile passes)
+
+## Reference Files for Comparison
+
+When upgrading, keep these files open for reference:
+
+1. **vllm-ascend NPUModelRunner**: `/root/vllm-workspace/vllm-ascend/vllm_ascend/worker/model_runner_v1.py`
+2. **vllm GPUModelRunner**: `/root/vllm-workspace/vllm/vllm/v1/worker/gpu_model_runner.py`
+3. **vllm-omni OmniGPUModelRunner**: `/root/vllm-workspace/vllm-omni/vllm_omni/worker/gpu_model_runner.py`
diff --git a/.claude/skills/vllm-omni-npu-upgrade/references/gpu-to-npu-translation.md b/.claude/skills/vllm-omni-npu-upgrade/references/gpu-to-npu-translation.md
new file mode 100644
index 00000000000..89067d37b2d
--- /dev/null
+++ b/.claude/skills/vllm-omni-npu-upgrade/references/gpu-to-npu-translation.md
@@ -0,0 +1,335 @@
+# GPU to NPU Translation Patterns
+
+This document provides a quick reference for translating GPU code patterns to NPU equivalents when porting omni-specific logic.
+
+## Import Translations
+
+### Forward Context
+```python
+# GPU
+from vllm.forward_context import set_forward_context
+
+# NPU
+from vllm_ascend.ascend_forward_context import set_ascend_forward_context
+```
+
+### Graph Wrapper
+```python
+# GPU
+from vllm.compilation.cuda_graph import CUDAGraphWrapper
+
+# NPU
+from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
+```
+
+### Attention State
+```python
+# GPU (no equivalent - uses FlashAttention states directly)
+
+# NPU
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
+```
+
+### Utilities
+```python
+# GPU
+# (directly use torch.cuda functions)
+
+# NPU
+from vllm_ascend.utils import enable_sp, lmhead_tp_enable
+from vllm_ascend.ops.rotary_embedding import update_cos_sin
+```
+
+## Context Manager Translations
+
+### Forward Context Setup
+```python
+# GPU
+with set_forward_context(
+ attn_metadata,
+ self.vllm_config,
+ num_tokens=num_tokens_padded,
+ num_tokens_across_dp=num_tokens_across_dp,
+ cudagraph_runtime_mode=cudagraph_mode,
+ batch_descriptor=batch_desc,
+):
+ # forward pass
+
+# NPU
+with set_ascend_forward_context(
+ attn_metadata,
+ self.vllm_config,
+ num_tokens=num_tokens_padded,
+ num_tokens_across_dp=num_tokens_across_dp,
+ aclgraph_runtime_mode=cudagraph_mode, # Note: 'aclgraph' not 'cudagraph'
+ batch_descriptor=batch_desc,
+ num_actual_tokens=scheduler_output.total_num_scheduled_tokens,
+ model_instance=self.model,
+):
+ # forward pass
+```
+
+### Graph Capture Context
+```python
+# GPU
+from vllm.compilation.cuda_graph import graph_capture as cuda_graph_capture
+with cuda_graph_capture(self.device):
+ # capture
+
+# NPU
+from vllm_ascend.worker.model_runner_v1 import graph_capture
+with graph_capture(self.device):
+ # capture
+```
+
+## Graph Wrapper Usage
+
+### Creating Graph Wrapper
+```python
+# GPU
+if cudagraph_mode.has_full_cudagraphs() and has_separate_talker:
+ self.talker_mtp = CUDAGraphWrapper(
+ talker_mtp,
+ self.vllm_config,
+ runtime_mode=CUDAGraphMode.FULL
+ )
+
+# NPU
+if cudagraph_mode.has_full_cudagraphs() and has_separate_talker:
+ self.talker_mtp = ACLGraphWrapper(
+ talker_mtp,
+ self.vllm_config,
+ runtime_mode=CUDAGraphMode.FULL
+ )
+```
+
+### Checking Graph Wrapper Type
+```python
+# GPU
+if not isinstance(self.talker_mtp, CUDAGraphWrapper):
+ _cudagraph_mode = CUDAGraphMode.NONE
+
+# NPU
+if not isinstance(self.talker_mtp, ACLGraphWrapper):
+ _cudagraph_mode = CUDAGraphMode.NONE
+```
+
+## Device Operations
+
+### Synchronization
+```python
+# GPU
+torch.cuda.synchronize()
+
+# NPU
+torch.npu.synchronize()
+```
+
+### Stream Operations
+```python
+# GPU
+stream = torch.cuda.Stream(device=device)
+torch.cuda.current_stream()
+
+# NPU
+stream = torch.npu.Stream(device=device)
+torch.npu.current_stream()
+```
+
+## Attention Metadata
+
+### State Setting (NPU-specific)
+```python
+# GPU - handled internally by attention backends
+
+# NPU - explicit state setting required
+self.attn_state = AscendAttentionState.DecodeOnly
+if self.speculative_config and self.speculative_config.method == "mtp":
+ if self.vllm_config.model_config.use_mla:
+ self.attn_state = AscendAttentionState.SpecDecoding
+ else:
+ self.attn_state = AscendAttentionState.ChunkedPrefill
+```
+
+### Building Attention Metadata
+```python
+# GPU - uses vllm attention builders
+
+# NPU - may need additional parameters
+(attn_metadata, spec_decode_common_attn_metadata) = self._build_attention_metadata(
+ num_tokens=num_tokens_unpadded,
+ num_tokens_padded=num_tokens_padded,
+ num_reqs=num_reqs,
+ num_reqs_padded=num_reqs_padded,
+ max_query_len=max_num_scheduled_tokens,
+ ubatch_slices=ubatch_slices_attn,
+ logits_indices=logits_indices,
+ use_spec_decode=use_spec_decode,
+ num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
+ num_scheduled_tokens_np=num_scheduled_tokens_np,
+ cascade_attn_prefix_lens=cascade_attn_prefix_lens,
+)
+```
+
+## Rotary Embedding
+
+### Update Cos/Sin Cache
+```python
+# GPU - typically handled inside attention
+
+# NPU - explicit update required before forward
+from vllm_ascend.ops.rotary_embedding import update_cos_sin
+update_cos_sin(positions)
+```
+
+## Sequence Parallelism
+
+### Enable SP Check
+```python
+# GPU - use vllm distributed utilities
+
+# NPU - use vllm-ascend wrapper
+from vllm_ascend.utils import enable_sp
+
+if enable_sp():
+ # sequence parallelism enabled
+```
+
+## Sampler
+
+### Sampler Type
+```python
+# GPU - uses vllm sampler
+self.sampler = Sampler()
+
+# NPU - uses AscendSampler
+from vllm_ascend.sample.sampler import AscendSampler
+self.sampler = AscendSampler()
+```
+
+## Input Batch
+
+### Batch Class
+```python
+# GPU
+from vllm.v1.worker.gpu_input_batch import InputBatch
+
+# NPU
+from vllm_ascend.worker.npu_input_batch import NPUInputBatch
+```
+
+## Graph Parameter Updates
+
+### Full Graph Params Update (NPU-specific)
+```python
+# GPU - not needed
+
+# NPU - required for FULL graph mode
+from vllm_ascend.compilation.acl_graph import update_full_graph_params
+
+forward_context = get_forward_context()
+if (
+ forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL
+ and not forward_context.capturing
+ and not self.use_sparse
+):
+ update_full_graph_params(
+ self.attn_backend,
+ self.update_stream,
+ forward_context,
+ num_tokens_padded,
+ self.vllm_config,
+ self.speculative_config,
+ positions.shape[0],
+ )
+```
+
+## Paged Attention Check
+
+```python
+# GPU - not typically needed
+
+# NPU
+from vllm_ascend.attention.utils import using_paged_attention
+
+if is_graph_capturing and using_paged_attention(num_tokens, self.vllm_config):
+ seq_lens = SEQ_LEN_WITH_MAX_PA_WORKSPACE
+```
+
+## Common Method Signature Differences
+
+### _dummy_run Parameters
+```python
+# GPU (v0.17.0)
+def _dummy_run(
+ self,
+ num_tokens: int,
+ cudagraph_runtime_mode: CUDAGraphMode | None = None,
+ force_attention: bool = False,
+ uniform_decode: bool = False,
+ allow_microbatching: bool = True,
+ skip_eplb: bool = False,
+ is_profile: bool = False,
+ create_mixed_batch: bool = False,
+ remove_lora: bool = True,
+ is_graph_capturing: bool = False,
+ num_active_loras: int = 0,
+) -> tuple[torch.Tensor, torch.Tensor]:
+
+# NPU (v0.17.0) - adds with_prefill, activate_lora
+def _dummy_run(
+ self,
+ num_tokens: int,
+ with_prefill: bool = False,
+ cudagraph_runtime_mode: CUDAGraphMode | None = None,
+ force_attention: bool = False,
+ uniform_decode: bool = False,
+ is_profile: bool = False,
+ create_mixed_batch: bool = False,
+ allow_microbatching: bool = True,
+ skip_eplb: bool = False,
+ remove_lora: bool = True,
+ activate_lora: bool = False,
+ is_graph_capturing: bool = False,
+ num_active_loras: int = 0,
+) -> tuple[torch.Tensor, torch.Tensor]:
+```
+
+### _model_forward Parameters
+```python
+# GPU - no num_tokens_padded
+def _model_forward(
+ self,
+ input_ids: torch.Tensor | None = None,
+ positions: torch.Tensor | None = None,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **model_kwargs: dict[str, Any],
+):
+
+# NPU - has num_tokens_padded as first parameter
+def _model_forward(
+ self,
+ num_tokens_padded: int,
+ input_ids: torch.Tensor | None = None,
+ positions: torch.Tensor | None = None,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **model_kwargs: dict[str, Any],
+):
+```
+
+## Quick Reference Table
+
+| Feature | GPU | NPU |
+|---------|-----|-----|
+| Graph wrapper | `CUDAGraphWrapper` | `ACLGraphWrapper` |
+| Forward context | `set_forward_context` | `set_ascend_forward_context` |
+| Runtime mode param | `cudagraph_runtime_mode` | `aclgraph_runtime_mode` |
+| Device sync | `torch.cuda.synchronize()` | `torch.npu.synchronize()` |
+| Stream | `torch.cuda.Stream` | `torch.npu.Stream` |
+| Current stream | `torch.cuda.current_stream()` | `torch.npu.current_stream()` |
+| Input batch | `InputBatch` | `NPUInputBatch` |
+| Sampler | `Sampler` | `AscendSampler` |
+| Attention state | N/A | `AscendAttentionState` |
+| RoPE update | N/A | `update_cos_sin()` |
diff --git a/.claude/skills/vllm-omni-npu-upgrade/references/omni-specific-blocks.md b/.claude/skills/vllm-omni-npu-upgrade/references/omni-specific-blocks.md
new file mode 100644
index 00000000000..8c5d32ab4c1
--- /dev/null
+++ b/.claude/skills/vllm-omni-npu-upgrade/references/omni-specific-blocks.md
@@ -0,0 +1,374 @@
+# Omni-Specific Code Blocks Reference
+
+This document catalogs omni-specific code blocks in the NPU model runners, making it easier to identify what needs to be preserved during upgrades.
+
+> **IMPORTANT**: This document may not be complete or up-to-date!
+>
+> - Always grep for `Omni-new` in the GPU implementations (`vllm_omni/worker/`) to find the authoritative list
+> - New omni features may be added that are not yet documented here
+> - When you discover new omni-specific blocks during an upgrade, please update this document
+> - Last verified: Check git history for this file
+
+## OmniNPUModelRunner (npu_model_runner.py)
+
+### load_model - Talker MTP Initialization
+
+```python
+def load_model(self, *args, **kwargs) -> None:
+ NPUModelRunner.load_model(self, *args, **kwargs)
+ # Initialize enable_sp cache to avoid get_current_vllm_config() error
+ # in _pad_for_sequence_parallelism during execute_model.
+ # This is a workaround for vllm-ascend not passing vllm_config to enable_sp().
+ enable_sp(self.vllm_config)
+ # TODO move this model specific logic to a separate class
+ # TTS model IS the talker (no .talker sub-attr); use getattr to support both Omni and TTS.
+ talker_mtp = getattr(self.model, "talker_mtp", None)
+ if talker_mtp is not None:
+ self.talker_mtp = talker_mtp # type: ignore[assignment]
+ cudagraph_mode = self.compilation_config.cudagraph_mode
+ assert cudagraph_mode is not None
+ # Only wrap talker_mtp in CUDAGraphWrapper for Omni models that
+ # have a separate .talker sub-module. TTS models' code predictor
+ # has internal AR loops / torch.multinomial — not graph-safe.
+ has_separate_talker = getattr(self.model, "talker", None) is not None
+ if cudagraph_mode.has_full_cudagraphs() and has_separate_talker:
+ # NOTE: Use ACLGraphWrapper on NPU, not CUDAGraphWrapper
+ self.talker_mtp = ACLGraphWrapper(talker_mtp, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)
+ # TTS exposes mtp_hidden_size; Omni uses hf_text_config.hidden_size.
+ hidden_size = int(
+ getattr(self.model, "mtp_hidden_size", 0) or getattr(self.model_config.hf_text_config, "hidden_size")
+ )
+ max_batch_size = max(self.max_num_reqs, self.compilation_config.max_cudagraph_capture_size)
+ self.talker_mtp_input_ids = self._make_buffer(max_batch_size, dtype=torch.int32)
+ self.talker_mtp_inputs_embeds = self._make_buffer(
+ max_batch_size, hidden_size, dtype=self.dtype, numpy=False
+ )
+ self.last_talker_hidden = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False)
+ self.text_step = self._make_buffer(max_batch_size, hidden_size, dtype=self.dtype, numpy=False)
+```
+
+### _dummy_run - Talker MTP Dummy Forward
+
+Location: Inside `set_ascend_forward_context` block, before main model forward
+
+```python
+# ---------------------------------------Omni-new----------------------------------------------
+if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"):
+ num_tokens_padded_talker_mtp = num_tokens_padded
+ if num_tokens_padded_talker_mtp == self.max_num_tokens:
+ num_tokens_padded_talker_mtp = self.talker_mtp_input_ids.gpu.shape[0]
+ outputs = self.talker_mtp(
+ self.talker_mtp_input_ids.gpu[:num_tokens_padded_talker_mtp],
+ self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded_talker_mtp],
+ self.last_talker_hidden.gpu[:num_tokens_padded_talker_mtp],
+ self.text_step.gpu[:num_tokens_padded_talker_mtp],
+ )
+ self.compilation_config.cache_dir = None
+# ---------------------------------------Omni-new----------------------------------------------
+```
+
+### _dummy_run - Extract Multimodal Outputs
+
+Location: After model forward, before dummy_compute_logits
+
+```python
+# ---------------------------------------Omni-new----------------------------------------------
+hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states)
+# ---------------------------------------Omni-new----------------------------------------------
+```
+
+### _model_forward - Omni Output Wrapping
+
+```python
+def _model_forward(
+ self,
+ num_tokens_padded: int,
+ input_ids: torch.Tensor | None = None,
+ positions: torch.Tensor | None = None,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **model_kwargs: dict[str, Any],
+):
+ """Override to combine NPUModelRunner's signature with OmniGPUModelRunner's logic."""
+ # Omni-specific: build and inject extra model kwargs
+ model_kwargs_extra = self._build_model_kwargs_extra()
+
+ # Call the model forward (same as NPUModelRunner)
+ assert self.model is not None
+ model_output = self.model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ **model_kwargs,
+ **model_kwargs_extra,
+ )
+
+ # Omni-specific: wrap output if needed
+ if not isinstance(model_output, OmniOutput) and hasattr(self.model, "make_omni_output"):
+ model_output = self.model.make_omni_output(model_output, **model_kwargs_extra)
+
+ # Omni-specific: cache model output for later sample_tokens
+ self._omni_last_model_output = model_output
+
+ # NPU-specific: update full graph params (keep from vllm-ascend)
+ forward_context = get_forward_context()
+ # ... NPU graph update logic ...
+
+ # NPU-specific: all-gather for sequence parallelism (keep from vllm-ascend)
+ if get_forward_context().sp_enabled and not isinstance(model_output, IntermediateTensors):
+ model_output = self._all_gather_hidden_states_and_aux(model_output)
+
+ return model_output
+```
+
+---
+
+## NPUARModelRunner (npu_ar_model_runner.py)
+
+### __init__ - KV Transfer Manager
+
+```python
+def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
+ # each model stage has their own hidden size
+ self.hidden_size = self.model_config.hf_text_config.hidden_size
+ self.inputs_embeds = self._make_buffer(self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False)
+ # Initialize KV cache manager (preserve vllm_config fallback behavior)
+ self.kv_transfer_manager = OmniKVTransferManager.from_vllm_config(self.vllm_config, self.model_config)
+```
+
+### execute_model - KV Transfer Before Update States
+
+Location: At the very beginning of execute_model
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+# [Omni] Handle KV transfer BEFORE updating states (which removes finished requests)
+self.kv_extracted_req_ids = self.kv_transfer_manager.handle_finished_requests_kv_transfer(
+ finished_reqs=getattr(scheduler_output, "finished_requests_needing_kv_transfer", {}),
+ kv_caches=self.kv_caches,
+ block_size=self.cache_config.block_size,
+ cache_dtype=str(self.cache_config.cache_dtype),
+ request_id_resolver=self._resolve_global_request_id,
+)
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+### execute_model - Custom _update_states Call
+
+Location: Inside synchronize_input_prep context
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+self._update_states(scheduler_output)
+# ------------------------------------------------------------------------------------------------
+```
+
+### execute_model - Extract Multimodal Outputs
+
+Location: In post process section, after hidden_states assignment
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+hidden_states, multimodal_outputs = self.extract_multimodal_outputs(hidden_states)
+
+if multimodal_outputs is not None:
+ keys_or_type = (
+ list(multimodal_outputs.keys())
+ if isinstance(multimodal_outputs, dict)
+ else type(multimodal_outputs)
+ )
+ logger.debug(f"[AR] execute_model: multimodal_outputs keys = {keys_or_type}")
+else:
+ logger.debug("[AR] execute_model: multimodal_outputs is None")
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+### execute_model - Compute Logits with sampling_metadata
+
+Location: In both broadcast_pp_output True and False branches
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+# Try with sampling_metadata first; fall back to without for models that don't support it
+try:
+ logits = self.model.compute_logits(
+ sample_hidden_states, sampling_metadata=self.input_batch.sampling_metadata
+ )
+except TypeError:
+ logits = self.model.compute_logits(sample_hidden_states)
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+### sample_tokens - KV Extracted Req IDs
+
+Location: At the beginning of sample_tokens
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+kv_extracted_req_ids = getattr(self, "kv_extracted_req_ids", None)
+self.kv_extracted_req_ids = None
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+### sample_tokens - Process Additional Information and Build Output
+
+Location: After bookkeeping sync, replacing the original output construction
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+hidden_states_cpu = hidden_states.detach().to("cpu").contiguous()
+num_scheduled_tokens_np = getattr(self, "_omni_num_scheduled_tokens_np", None)
+if num_scheduled_tokens_np is None:
+ req_ids = self.input_batch.req_ids
+ num_scheduled_tokens_np = np.array(
+ [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids],
+ dtype=np.int32,
+ )
+
+self._process_additional_information_updates(
+ hidden_states, multimodal_outputs, num_scheduled_tokens_np, scheduler_output
+)
+
+pooler_output: list[dict[str, object]] = []
+for rid in req_ids_output_copy:
+ idx = req_id_to_index_output_copy[rid]
+ start = int(self.query_start_loc.cpu[idx])
+ sched = int(num_scheduled_tokens_np[idx])
+ end = start + sched
+ hidden_slice = hidden_states_cpu[start:end]
+ payload: dict[str, object] = {"hidden": hidden_slice}
+ if isinstance(multimodal_outputs, dict) and multimodal_outputs:
+ # ... multimodal output slicing logic ...
+ pooler_output.append(payload)
+
+model_runner_output = OmniModelRunnerOutput(
+ req_ids=req_ids_output_copy,
+ req_id_to_index=req_id_to_index_output_copy,
+ sampled_token_ids=valid_sampled_token_ids,
+ logprobs=logprobs_lists,
+ prompt_logprobs_dict=prompt_logprobs_dict,
+ pooler_output=(pooler_output if self.vllm_config.model_config.engine_output_type != "text" else None),
+ kv_connector_output=kv_connector_output,
+)
+model_runner_output.kv_extracted_req_ids = kv_extracted_req_ids
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+---
+
+## NPUGenerationModelRunner (npu_generation_model_runner.py)
+
+### execute_model - Async Chunk Update
+
+Location: Inside prepare input section, before synchronize_input_prep
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+if self.model_config.async_chunk and num_scheduled_tokens:
+ self._update_request_states(scheduler_output)
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+### execute_model - Seq Token Counts
+
+Location: After _preprocess call
+
+```python
+# [Omni] Pass token counts per request for code2wav output slicing
+model_kwargs["seq_token_counts"] = tokens
+```
+
+### execute_model - Run Generation Model
+
+Location: Inside forward context
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+outputs = self._run_generation_model(
+ num_tokens_padded=num_tokens_padded,
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ model_kwargs=model_kwargs,
+ logits_indices=logits_indices,
+)
+_, multimodal_outputs = self.extract_multimodal_outputs(outputs)
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+### sample_tokens - Multimodal Output Processing
+
+The entire sample_tokens method body is omni-specific for generation models:
+
+```python
+# -------------------------------------- Omni-new -------------------------------------------------
+pooler_output: list[object] = []
+if isinstance(multimodal_outputs, torch.Tensor):
+ # ... tensor handling ...
+elif isinstance(multimodal_outputs, list):
+ # ... list handling ...
+elif isinstance(multimodal_outputs, dict):
+ # ... dict handling per request ...
+else:
+ raise RuntimeError("Unsupported diffusion output type")
+# [Omni] Copy req_id mappings to avoid async scheduling mutation.
+req_ids_output_copy = self.input_batch.req_ids.copy()
+req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy()
+output = OmniModelRunnerOutput(
+ req_ids=req_ids_output_copy,
+ req_id_to_index=req_id_to_index_output_copy,
+ sampled_token_ids=[],
+ logprobs=None,
+ prompt_logprobs_dict={},
+ pooler_output=pooler_output,
+ kv_connector_output=kv_connector_output,
+ num_nans_in_logits={},
+ ec_connector_output=ec_connector_output if self.supports_mm_inputs else None,
+)
+# -------------------------------------- Omni-new -------------------------------------------------
+```
+
+### _dummy_run - Model Kwargs Init and Multimodal Extract
+
+Location: Before model forward and after
+
+```python
+model_kwargs = self._init_model_kwargs() # Before forward
+
+# ... forward ...
+
+# -------------------------------------- Omni-new -------------------------------------------------
+hidden_states, _ = self.extract_multimodal_outputs(hidden_states)
+# -------------------------------------------------------------------------------------------------
+```
+
+---
+
+## ExecuteModelState Extension
+
+The `ExecuteModelState` NamedTuple is extended for omni:
+
+```python
+class ExecuteModelState(NamedTuple):
+ """Ephemeral cached state transferred between execute_model() and
+ sample_tokens(), after execute_model() returns None."""
+
+ scheduler_output: SchedulerOutput
+ logits: torch.Tensor
+ spec_decode_metadata: SpecDecodeMetadata | None
+ spec_decode_common_attn_metadata: AscendCommonAttentionMetadata | None
+ hidden_states: torch.Tensor
+ sample_hidden_states: torch.Tensor
+ aux_hidden_states: list[torch.Tensor] | None
+ attn_metadata: PerLayerAttnMetadata
+ positions: torch.Tensor
+ ec_connector_output: ECConnectorOutput | None
+ cudagraph_stats: CUDAGraphStat | None
+ multimodal_outputs: Any # <-- Omni extension
+```
+
+This extended state must be imported from `npu_ar_model_runner` in `npu_generation_model_runner`.
diff --git a/.claude/skills/vllm-omni-npu-upgrade/references/workflow-checklist.md b/.claude/skills/vllm-omni-npu-upgrade/references/workflow-checklist.md
new file mode 100644
index 00000000000..4f184df0ecb
--- /dev/null
+++ b/.claude/skills/vllm-omni-npu-upgrade/references/workflow-checklist.md
@@ -0,0 +1,222 @@
+# NPU Model Runner Upgrade Workflow Checklist
+
+> **Note**: Reference documents (`omni-specific-blocks.md`) may not be complete. Always grep for `Omni-new` in GPU implementations to find all omni-specific code blocks. Update the reference docs when discovering new blocks.
+
+## Pre-Upgrade Preparation
+
+### 1. Version Information
+- [ ] Identify current vllm-omni version: `_________`
+- [ ] Identify target vllm-ascend version: `_________`
+- [ ] Identify target vllm version: `_________`
+- [ ] Last release date for GPU worker changes: `_________`
+
+### 2. Gather Git History
+```bash
+# GPU-side omni changes since last release
+cd /root/vllm-workspace/vllm-omni
+git log --oneline --since="YYYY-MM-DD" -- vllm_omni/worker/
+
+# vllm-ascend NPUModelRunner changes
+cd /root/vllm-workspace/vllm-ascend
+git log --oneline .. -- vllm_ascend/worker/model_runner_v1.py
+```
+
+### 3. Backup Current Files
+- [ ] Create backup of current NPU runners:
+ ```bash
+ cp -r vllm_omni/platforms/npu/worker vllm_omni/platforms/npu/worker.backup
+ ```
+
+---
+
+## OmniNPUModelRunner (npu_model_runner.py)
+
+### Read and Understand
+- [ ] Read current `npu_model_runner.py`
+- [ ] Read latest `vllm_ascend/worker/model_runner_v1.py`
+- [ ] Read latest `vllm_omni/worker/gpu_model_runner.py`
+
+### Method: load_model
+- [ ] Document existing omni-specific logic
+- [ ] Copy latest NPUModelRunner.load_model structure
+- [ ] Re-insert: `enable_sp(self.vllm_config)` call
+- [ ] Re-insert: talker_mtp detection and setup
+- [ ] Replace: `CUDAGraphWrapper` → `ACLGraphWrapper`
+- [ ] Re-insert: Buffer allocations (talker_mtp_input_ids, etc.)
+
+### Method: _dummy_run
+- [ ] Document existing omni-specific logic locations
+- [ ] Copy latest NPUModelRunner._dummy_run
+- [ ] Re-insert: talker_mtp dummy forward block (inside context)
+- [ ] Re-insert: `extract_multimodal_outputs` call
+- [ ] Verify: Comment markers are present
+
+### Method: _model_forward
+- [ ] Copy latest NPUModelRunner._model_forward structure
+- [ ] Re-insert: `_build_model_kwargs_extra()` call
+- [ ] Re-insert: OmniOutput wrapping logic
+- [ ] Re-insert: `_omni_last_model_output` caching
+- [ ] Keep: NPU graph params update
+- [ ] Keep: SP all-gather logic
+
+### Method: _talker_mtp_forward
+- [ ] Verify: Uses `set_ascend_forward_context`
+- [ ] Verify: Uses `ACLGraphWrapper` check
+- [ ] Sync any changes from GPU `_talker_mtp_forward`
+
+### Imports
+- [ ] Update vllm-ascend imports to latest paths
+- [ ] Verify all omni imports are present
+- [ ] Remove any deprecated imports
+
+---
+
+## NPUARModelRunner (npu_ar_model_runner.py)
+
+### Read and Understand
+- [ ] Read current `npu_ar_model_runner.py`
+- [ ] Read latest `vllm_ascend/worker/model_runner_v1.py` execute_model
+- [ ] Read latest `vllm_omni/worker/gpu_ar_model_runner.py`
+
+### Method: __init__
+- [ ] Sync any new initialization from GPU side
+- [ ] Keep: `OmniKVTransferManager` setup
+- [ ] Keep: Custom buffer allocations
+
+### Method: execute_model
+- [ ] Document all omni blocks with line numbers
+- [ ] Copy latest NPUModelRunner.execute_model structure
+- [ ] Re-insert: KV transfer handling (beginning)
+- [ ] Re-insert: Custom `_update_states` call
+- [ ] Re-insert: `extract_multimodal_outputs`
+- [ ] Re-insert: `compute_logits` with sampling_metadata try/except
+- [ ] Update: ExecuteModelState to include multimodal_outputs
+
+### Method: sample_tokens
+- [ ] Document all omni blocks
+- [ ] Copy latest NPUModelRunner.sample_tokens structure
+- [ ] Re-insert: `kv_extracted_req_ids` handling
+- [ ] Re-insert: Hidden states CPU copy
+- [ ] Re-insert: `_process_additional_information_updates`
+- [ ] Re-insert: `OmniModelRunnerOutput` construction
+
+### ExecuteModelState
+- [ ] Verify: `multimodal_outputs` field is present
+- [ ] Verify: Imported/used correctly in execute_model
+
+### Imports
+- [ ] Update all vllm-ascend imports
+- [ ] Keep omni-specific imports
+
+---
+
+## NPUGenerationModelRunner (npu_generation_model_runner.py)
+
+### Read and Understand
+- [ ] Read current `npu_generation_model_runner.py`
+- [ ] Read latest GPU `gpu_generation_model_runner.py`
+
+### Method: _update_request_states
+- [ ] Verify: async_chunk handling is correct
+- [ ] Sync any changes from GPU side
+
+### Method: execute_model
+- [ ] Document all omni blocks
+- [ ] Copy latest NPUModelRunner.execute_model base structure
+- [ ] Re-insert: async_chunk update logic
+- [ ] Re-insert: `seq_token_counts` injection
+- [ ] Re-insert: `_run_generation_model` call
+- [ ] Re-insert: `extract_multimodal_outputs`
+- [ ] Use: ExecuteModelState from npu_ar_model_runner
+
+### Method: sample_tokens
+- [ ] Keep: Entire omni multimodal output processing
+- [ ] Update: Any new output fields needed
+- [ ] Keep: `OmniModelRunnerOutput` construction
+
+### Method: _run_generation_model
+- [ ] Sync any changes from GPU side
+- [ ] Keep: `_model_forward` call with sampler
+
+### Method: _dummy_run
+- [ ] Copy latest NPUModelRunner._dummy_run
+- [ ] Re-insert: `model_kwargs = self._init_model_kwargs()`
+- [ ] Re-insert: `extract_multimodal_outputs` at end
+
+### Imports
+- [ ] Import ExecuteModelState from npu_ar_model_runner
+- [ ] Update vllm-ascend imports
+
+---
+
+## Post-Upgrade Validation
+
+### Syntax Validation
+- [ ] `python -m py_compile vllm_omni/platforms/npu/worker/npu_model_runner.py`
+- [ ] `python -m py_compile vllm_omni/platforms/npu/worker/npu_ar_model_runner.py`
+- [ ] `python -m py_compile vllm_omni/platforms/npu/worker/npu_generation_model_runner.py`
+
+### Import Validation
+- [ ] `python -c "from vllm_omni.platforms.npu.worker.npu_model_runner import OmniNPUModelRunner"`
+- [ ] `python -c "from vllm_omni.platforms.npu.worker.npu_ar_model_runner import NPUARModelRunner"`
+- [ ] `python -c "from vllm_omni.platforms.npu.worker.npu_generation_model_runner import NPUGenerationModelRunner"`
+
+### Comment Markers
+- [ ] Grep for "Omni-new" in all three files
+- [ ] Verify all omni blocks have closing markers
+
+### Code Review
+- [ ] No `CUDAGraphWrapper` references
+- [ ] All `set_forward_context` replaced with `set_ascend_forward_context`
+- [ ] Parameter names correct (`aclgraph_runtime_mode` not `cudagraph_runtime_mode`)
+- [ ] No duplicate code blocks
+- [ ] No missing imports
+
+---
+
+## Git Commit
+
+### Commit Message Template
+```
+[NPU] Upgrade model runners to align with vllm-ascend vX.Y.Z
+
+- Update OmniNPUModelRunner with latest NPUModelRunner base
+- Update NPUARModelRunner execute_model and sample_tokens
+- Update NPUGenerationModelRunner for async_chunk changes
+- Sync GPU-side omni changes from vX.Y.Z release
+- Preserve all omni-specific logic (marked with Omni-new comments)
+
+Changes from vllm-ascend:
+-
+
+Changes synced from GPU:
+-
+```
+
+### Files to Stage
+- [ ] `vllm_omni/platforms/npu/worker/npu_model_runner.py`
+- [ ] `vllm_omni/platforms/npu/worker/npu_ar_model_runner.py`
+- [ ] `vllm_omni/platforms/npu/worker/npu_generation_model_runner.py`
+- [ ] Any other modified files
+
+---
+
+## Troubleshooting
+
+### Import Errors
+- Check if vllm-ascend module paths have changed
+- Verify PYTHONPATH includes both vllm-ascend and vllm-omni
+
+### Type Errors
+- Check method signatures match between GPU and NPU
+- Verify NamedTuple fields match expected structure
+
+### Runtime Errors
+- Enable debug logging: `export VLLM_LOGGING_LEVEL=DEBUG`
+- Check graph capture issues: try `--enforce-eager`
+- Check attention issues: verify AscendAttentionState usage
+
+### Performance Regression
+- Compare with previous version on same model
+- Check if graph capture is working: look for ACLGraph logs
+- Verify SP/EP configurations are correct
diff --git a/.gitignore b/.gitignore
index c0ee968064c..35dc7571ee2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -203,6 +203,7 @@ checkpoints/
# Cache directories
cache/
!vllm_omni/diffusion/cache/
+!tests/diffusion/cache/
.cache/
diffusion_cache/
kv_cache/
@@ -262,3 +263,5 @@ tmp_test
vllm_omni/_version.py
# output files
*.wav
+# CI overlay yamls materialized from tests/utils.py:_CI_OVERLAYS at test time
+tests/.ci_generated/
diff --git a/benchmarks/qwen3-tts/README.md b/benchmarks/qwen3-tts/README.md
index 9c01f29aa9f..a1c2ebe12ff 100644
--- a/benchmarks/qwen3-tts/README.md
+++ b/benchmarks/qwen3-tts/README.md
@@ -35,8 +35,8 @@ MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice bash run_benchmark.sh --async-only
# Use a Voice Clone model
MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-Base TASK_TYPE=Base bash run_benchmark.sh --async-only
-# Use bs16 config for higher throughput
-STAGE_CONFIG=vllm_omni/configs/qwen3_tts_bs16.yaml bash run_benchmark.sh --async-only
+# Use batch size 16 for higher throughput
+BATCH_SIZE=16 bash run_benchmark.sh --async-only
# Custom GPU, prompt count, concurrency levels
GPU_DEVICE=1 NUM_PROMPTS=20 CONCURRENCY="1 4" bash run_benchmark.sh
@@ -50,7 +50,8 @@ GPU_DEVICE=1 NUM_PROMPTS=20 CONCURRENCY="1 4" bash run_benchmark.sh
CUDA_VISIBLE_DEVICES=0 python -m vllm_omni.entrypoints.cli.main serve \
"Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice" \
--omni --host 127.0.0.1 --port 8000 \
- --stage-configs-path benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
+ --stage-overrides '{"0":{"max_num_seqs":1,"gpu_memory_utilization":0.3,"max_num_batched_tokens":512},"1":{"max_num_seqs":1,"gpu_memory_utilization":0.3,"max_num_batched_tokens":8192}}' \
--trust-remote-code
```
@@ -84,16 +85,19 @@ python benchmarks/qwen3-tts/plot_results.py \
--output results/comparison.png
```
-## Stage Configs
+## Batch-size presets
-| Config | max_num_seqs | Description |
-|--------|:------------:|-------------|
-| `vllm_omni/configs/qwen3_tts_bs1.yaml` | 1 | Single-request processing (lowest latency) |
-| `vllm_omni/configs/qwen3_tts_bs16.yaml` | 16 | High-throughput concurrent processing |
+The bench script loads the bundled production deploy (`vllm_omni/deploy/qwen3_tts.yaml`) and layers per-stage budgets on top via `--stage-overrides`, driven by the `BATCH_SIZE` env var. Each batch size picks compatible per-stage `max_num_seqs`, `max_num_batched_tokens`, and `gpu_memory_utilization` defaults:
-All configs use a 2-stage pipeline (Talker -> Code2Wav) with `async_chunk` streaming enabled. The `SharedMemoryConnector` streams codec frames (25-frame chunks with 25-frame context overlap) between stages.
+| `BATCH_SIZE` | Description |
+|:--:|-------------|
+| `1` (default) | Single-request processing (lowest latency) |
+| `4` | Moderate-throughput concurrent processing |
+| `16` | High-throughput concurrent processing |
-The model is specified via the CLI `--model` flag (or `MODEL` env var), so the same configs work for both the 0.6B and 1.7B model variants.
+The 2-stage pipeline (Talker -> Code2Wav) runs with `async_chunk` streaming enabled via the prod deploy; the `SharedMemoryConnector` streams codec frames (25-frame chunks with 25-frame context overlap) between stages.
+
+The model is specified via the CLI `--model` flag (or `MODEL` env var), so the same bench script works for both the 0.6B and 1.7B model variants.
## Metrics
diff --git a/benchmarks/qwen3-tts/run_benchmark.sh b/benchmarks/qwen3-tts/run_benchmark.sh
index 283b6b844c1..8c3e46903ca 100755
--- a/benchmarks/qwen3-tts/run_benchmark.sh
+++ b/benchmarks/qwen3-tts/run_benchmark.sh
@@ -26,8 +26,8 @@
# # Use Voice Clone model
# MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-Base TASK_TYPE=Base bash run_benchmark.sh --async-only
#
-# # Use batch_size=4 config:
-# STAGE_CONFIG=vllm_omni/configs/qwen3_tts_bs4.yaml bash run_benchmark.sh --async-only
+# # Use batch_size=4:
+# BATCH_SIZE=4 bash run_benchmark.sh --async-only
#
# Environment variables:
# GPU_DEVICE - GPU index to use (default: 0)
@@ -35,9 +35,9 @@
# CONCURRENCY - Space-separated concurrency levels (default: "1 4 10")
# MODEL - Model name (default: Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice)
# PORT - Server port (default: 8000)
-# GPU_MEM_TALKER - gpu_memory_utilization for talker stage (default: 0.3)
-# GPU_MEM_CODE2WAV - gpu_memory_utilization for code2wav stage (default: 0.2)
-# STAGE_CONFIG - Path to stage config YAML (default: configs/qwen3_tts_bs1.yaml)
+# BATCH_SIZE - Per-stage ``max_num_seqs`` for both talker and code2wav (default: 1)
+# GPU_MEM_TALKER - gpu_memory_utilization for talker stage (default: 0.3 at bs=1, else 0.2)
+# GPU_MEM_CODE2WAV - gpu_memory_utilization for code2wav stage (default: 0.3 at bs=1, else 0.2)
# TASK_TYPE - Task type: CustomVoice, VoiceDesign, Base (default: CustomVoice)
set -euo pipefail
@@ -51,14 +51,36 @@ NUM_PROMPTS="${NUM_PROMPTS:-50}"
CONCURRENCY="${CONCURRENCY:-1 4 10}"
MODEL="${MODEL:-Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice}"
PORT="${PORT:-8000}"
-GPU_MEM_TALKER="${GPU_MEM_TALKER:-0.3}"
-GPU_MEM_CODE2WAV="${GPU_MEM_CODE2WAV:-0.2}"
+BATCH_SIZE="${BATCH_SIZE:-1}"
+DEFAULT_MEM=$([ "${BATCH_SIZE}" = "1" ] && echo "0.3" || echo "0.2")
+GPU_MEM_TALKER="${GPU_MEM_TALKER:-${DEFAULT_MEM}}"
+GPU_MEM_CODE2WAV="${GPU_MEM_CODE2WAV:-${DEFAULT_MEM}}"
NUM_WARMUPS="${NUM_WARMUPS:-3}"
-STAGE_CONFIG="${STAGE_CONFIG:-vllm_omni/configs/qwen3_tts_bs1.yaml}"
+DEPLOY_CONFIG="vllm_omni/deploy/qwen3_tts.yaml"
RESULT_DIR="${SCRIPT_DIR}/results"
TIMESTAMP="$(date +%Y%m%d_%H%M%S)"
TASK_TYPE="${TASK_TYPE:-CustomVoice}"
+# Build --stage-overrides JSON from BATCH_SIZE + GPU_MEM_*.
+STAGE_OVERRIDES=$(
+ BATCH_SIZE="${BATCH_SIZE}" \
+ GPU_MEM_TALKER="${GPU_MEM_TALKER}" \
+ GPU_MEM_CODE2WAV="${GPU_MEM_CODE2WAV}" \
+ python - <<'PYEOF'
+import json, os
+bs = int(os.environ["BATCH_SIZE"])
+mem_t = float(os.environ["GPU_MEM_TALKER"])
+mem_c = float(os.environ["GPU_MEM_CODE2WAV"])
+# Prefill budget grows with batch size on both stages.
+talker_batched = 512 if bs <= 4 else 4096
+code2wav_batched = 8192 if bs <= 4 else 32768
+print(json.dumps({
+ "0": {"max_num_seqs": bs, "gpu_memory_utilization": mem_t, "max_num_batched_tokens": talker_batched},
+ "1": {"max_num_seqs": bs, "gpu_memory_utilization": mem_c, "max_num_batched_tokens": code2wav_batched},
+}))
+PYEOF
+)
+
# Parse args
RUN_ASYNC=true
RUN_HF=true
@@ -75,41 +97,27 @@ mkdir -p "${RESULT_DIR}"
echo "============================================================"
echo " Qwen3-TTS Benchmark"
echo "============================================================"
-echo " GPU: ${GPU_DEVICE}"
-echo " Model: ${MODEL}"
-echo " Prompts: ${NUM_PROMPTS}"
-echo " Concurrency: ${CONCURRENCY}"
-echo " Port: ${PORT}"
-echo " Stage config: ${STAGE_CONFIG}"
-echo " Results: ${RESULT_DIR}"
-echo " Task type: ${TASK_TYPE}"
+echo " GPU: ${GPU_DEVICE}"
+echo " Model: ${MODEL}"
+echo " Prompts: ${NUM_PROMPTS}"
+echo " Concurrency: ${CONCURRENCY}"
+echo " Port: ${PORT}"
+echo " Deploy config: ${DEPLOY_CONFIG}"
+echo " Batch size: ${BATCH_SIZE}"
+echo " GPU mem T/C: ${GPU_MEM_TALKER} / ${GPU_MEM_CODE2WAV}"
+echo " Results: ${RESULT_DIR}"
+echo " Task type: ${TASK_TYPE}"
echo "============================================================"
-# Prepare stage config with correct GPU device and memory settings
-prepare_config() {
- local config_template="$1"
- local config_name="$2"
- local output_path="${RESULT_DIR}/${config_name}_stage_config.yaml"
-
- # Use sed to patch GPU device and memory utilization
- sed \
- -e "s/devices: \"0\"/devices: \"${GPU_DEVICE}\"/g" \
- -e "s/gpu_memory_utilization: 0.3/gpu_memory_utilization: ${GPU_MEM_TALKER}/g" \
- -e "s/gpu_memory_utilization: 0.2/gpu_memory_utilization: ${GPU_MEM_CODE2WAV}/g" \
- "${config_template}" > "${output_path}"
-
- echo "${output_path}"
-}
-
# Start server and wait for it to be ready
start_server() {
- local stage_config="$1"
- local config_name="$2"
+ local config_name="$1"
local log_file="${RESULT_DIR}/server_${config_name}_${TIMESTAMP}.log"
echo ""
echo "Starting server with config: ${config_name}"
- echo " Stage config: ${stage_config}"
+ echo " Deploy config: ${DEPLOY_CONFIG}"
+ echo " Stage overrides: ${STAGE_OVERRIDES}"
echo " Log file: ${log_file}"
VLLM_WORKER_MULTIPROC_METHOD=spawn \
@@ -118,7 +126,8 @@ start_server() {
--omni \
--host 127.0.0.1 \
--port "${PORT}" \
- --stage-configs-path "${stage_config}" \
+ --deploy-config "${DEPLOY_CONFIG}" \
+ --stage-overrides "${STAGE_OVERRIDES}" \
--stage-init-timeout 120 \
--trust-remote-code \
--disable-log-stats \
@@ -175,17 +184,13 @@ trap 'stop_server' EXIT
# Run benchmark for a given config
run_bench() {
local config_name="$1"
- local config_template="$2"
echo ""
echo "============================================================"
echo " Benchmarking: ${config_name}"
echo "============================================================"
- local stage_config
- stage_config=$(prepare_config "${config_template}" "${config_name}")
-
- start_server "${stage_config}" "${config_name}"
+ start_server "${config_name}"
# Convert concurrency string to args
local conc_args=""
@@ -212,7 +217,7 @@ run_bench() {
# Run vllm-omni benchmark
if [ "${RUN_ASYNC}" = true ]; then
- run_bench "async_chunk" "${SCRIPT_DIR}/${STAGE_CONFIG}"
+ run_bench "async_chunk"
fi
# Run HuggingFace baseline benchmark
diff --git a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml b/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml
deleted file mode 100644
index ca441d286dd..00000000000
--- a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs1.yaml
+++ /dev/null
@@ -1,93 +0,0 @@
-# Qwen3-TTS batch_size=1 config (streaming with async_chunk)
-# 2-stage pipeline: Talker -> Code2Wav
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm
- is_comprehension: true
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 1
- model_stage: qwen3_tts
- model_arch: Qwen3TTSTalkerForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: false
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: latent
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 512
- max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 1
- model_stage: code2wav
- model_arch: Qwen3TTSCode2Wav
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: audio
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 8192
- max_model_len: 32768
- engine_input_source: [0]
- final_output: true
- final_output_type: audio
- input_connectors:
- from_stage_0: connector_of_shared_memory
- tts_args:
- max_instructions_length: 500
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: true
- repetition_penalty: 1.0
-
-runtime:
- enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- codec_streaming: true
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- codec_chunk_frames: 25
- codec_left_context_frames: 25
-
- edges:
- - from: 0
- to: 1
- window_size: -1
diff --git a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs16.yaml b/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs16.yaml
deleted file mode 100644
index 2cc5cf53532..00000000000
--- a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs16.yaml
+++ /dev/null
@@ -1,94 +0,0 @@
-# Qwen3-TTS max_num_seqs=16 config (streaming with async_chunk)
-# High-throughput concurrent request processing
-# 2-stage pipeline: Talker -> Code2Wav
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm
- is_comprehension: true
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 16
- model_stage: qwen3_tts
- model_arch: Qwen3TTSTalkerForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: false
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: latent
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 4096
- max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 16
- model_stage: code2wav
- model_arch: Qwen3TTSCode2Wav
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: audio
- gpu_memory_utilization: 0.2
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 16384
- max_model_len: 32768
- engine_input_source: [0]
- final_output: true
- final_output_type: audio
- input_connectors:
- from_stage_0: connector_of_shared_memory
- tts_args:
- max_instructions_length: 500
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: true
- repetition_penalty: 1.0
-
-runtime:
- enabled: true
- defaults:
- window_size: -1
- max_inflight: 16
-
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- codec_streaming: true
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- codec_chunk_frames: 25
- codec_left_context_frames: 25
-
- edges:
- - from: 0
- to: 1
- window_size: -1
diff --git a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs4.yaml b/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs4.yaml
deleted file mode 100644
index 5de107d4976..00000000000
--- a/benchmarks/qwen3-tts/vllm_omni/configs/qwen3_tts_bs4.yaml
+++ /dev/null
@@ -1,94 +0,0 @@
-# Qwen3-TTS batch_size=4 config (streaming with async_chunk)
-# Enables concurrent request processing
-# 2-stage pipeline: Talker -> Code2Wav
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm
- is_comprehension: true
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 4
- model_stage: qwen3_tts
- model_arch: Qwen3TTSTalkerForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: false
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: latent
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 512
- max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 4
- model_stage: code2wav
- model_arch: Qwen3TTSCode2Wav
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: audio
- gpu_memory_utilization: 0.2
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 8192
- max_model_len: 32768
- engine_input_source: [0]
- final_output: true
- final_output_type: audio
- input_connectors:
- from_stage_0: connector_of_shared_memory
- tts_args:
- max_instructions_length: 500
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: true
- repetition_penalty: 1.0
-
-runtime:
- enabled: true
- defaults:
- window_size: -1
- max_inflight: 4
-
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- codec_streaming: true
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- codec_chunk_frames: 25
- codec_left_context_frames: 25
-
- edges:
- - from: 0
- to: 1
- window_size: -1
diff --git a/benchmarks/qwen3-tts/vllm_omni/run_async_chunk_benchmark.sh b/benchmarks/qwen3-tts/vllm_omni/run_async_chunk_benchmark.sh
index 61cf7757a9b..0ede359ea37 100755
--- a/benchmarks/qwen3-tts/vllm_omni/run_async_chunk_benchmark.sh
+++ b/benchmarks/qwen3-tts/vllm_omni/run_async_chunk_benchmark.sh
@@ -31,8 +31,11 @@ PORT_OFF="${PORT_OFF:-8001}"
RESULT_DIR="${SCRIPT_DIR}/results"
TIMESTAMP="$(date +%Y%m%d_%H%M%S)"
-STAGE_CONFIG_ON="vllm_omni/model_executor/stage_configs/qwen3_tts.yaml"
-STAGE_CONFIG_OFF="vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml"
+# The bundled ``vllm_omni/deploy/qwen3_tts.yaml`` is auto-loaded by the model
+# registry; no ``--deploy-config`` flag needed on the default (ON) path.
+# async_chunk OFF is selected by the ``--no-async-chunk`` CLI flag —
+# the single ``qwen3_tts`` pipeline dispatches to the end-to-end codec
+# processor when ``deploy.async_chunk`` is false.
mkdir -p "${RESULT_DIR}"
@@ -77,7 +80,6 @@ wait_for_server() {
echo ""
echo "[Phase 1] Starting async_chunk ON server on port ${PORT_ON}..."
CUDA_VISIBLE_DEVICES=${GPU_DEVICE} vllm-omni serve "${MODEL}" \
- --stage-configs-path "${STAGE_CONFIG_ON}" \
--host 0.0.0.0 --port "${PORT_ON}" \
--trust-remote-code --enforce-eager --omni \
> "${RESULT_DIR}/server_on_${TIMESTAMP}.log" 2>&1 &
@@ -104,7 +106,7 @@ sleep 5
echo ""
echo "[Phase 2] Starting async_chunk OFF server on port ${PORT_OFF}..."
CUDA_VISIBLE_DEVICES=${GPU_DEVICE} vllm-omni serve "${MODEL}" \
- --stage-configs-path "${STAGE_CONFIG_OFF}" \
+ --no-async-chunk \
--host 0.0.0.0 --port "${PORT_OFF}" \
--trust-remote-code --enforce-eager --omni \
> "${RESULT_DIR}/server_off_${TIMESTAMP}.log" 2>&1 &
diff --git a/docs/.nav.yml b/docs/.nav.yml
index 79d7c38e274..455a0525056 100644
--- a/docs/.nav.yml
+++ b/docs/.nav.yml
@@ -107,7 +107,7 @@ nav:
- design/feature/hsdp.md
- design/feature/cache_dit.md
- design/feature/teacache.md
- - design/feature/async_chunk_design.md
+ - design/feature/async_chunk.md
- design/feature/vae_parallel.md
- design/feature/diffusion_step_execution.md
- Module Design:
diff --git a/docs/assets/WeChat.jpg b/docs/assets/WeChat.jpg
index 416439f7eb0..83252b7569d 100644
Binary files a/docs/assets/WeChat.jpg and b/docs/assets/WeChat.jpg differ
diff --git a/docs/configuration/README.md b/docs/configuration/README.md
index b5761a7f1bc..390176e9cea 100644
--- a/docs/configuration/README.md
+++ b/docs/configuration/README.md
@@ -6,7 +6,7 @@ For options within a vLLM Engine. Please refer to [vLLM Configuration](https://d
Currently, the main options are maintained by stage configs for each model.
-For specific example, please refer to [Qwen2.5-omni stage config](stage_configs/qwen2_5_omni.yaml)
+For a specific example, see the [Qwen2.5-Omni deploy config](gh-file:vllm_omni/deploy/qwen2_5_omni.yaml). The matching frozen pipeline topology lives at [vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py](gh-file:vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py).
For introduction, please check [Introduction for stage config](./stage_configs.md)
diff --git a/docs/configuration/pd_disaggregation.md b/docs/configuration/pd_disaggregation.md
index 1cf6189e603..9196bdb0240 100644
--- a/docs/configuration/pd_disaggregation.md
+++ b/docs/configuration/pd_disaggregation.md
@@ -11,7 +11,7 @@ deployment-specific values usually change per environment:
- connector backend and connector ports
- connector IPs or bootstrap addresses
-Start from the [default Qwen3-Omni stage config](gh-file:vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml)
+Start from the [default Qwen3-Omni stage config](gh-file:vllm_omni/deploy/qwen3_omni_moe.yaml)
and copy it to your own file, for example `qwen3_omni_pd.yaml`. Then apply the
changes below.
@@ -145,19 +145,13 @@ Compared with the default Qwen3-Omni config:
```yaml
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
edges:
- from: 0
to: 1
- window_size: -1
- from: 1
to: 2
- window_size: -1
- from: 2
to: 3
- window_size: -1
```
## 4. Launch with your custom config
diff --git a/docs/configuration/stage_configs.md b/docs/configuration/stage_configs.md
index 95c42afcc70..55b4053cc71 100644
--- a/docs/configuration/stage_configs.md
+++ b/docs/configuration/stage_configs.md
@@ -3,7 +3,147 @@
In vLLM-Omni, the target model is separated into multiple stages, which are processed by different LLMEngines, DiffusionEngines or other types of engines. Depending on different types of stages, such as Autoregressive (AR) stage or Diffusion transformer (DiT) stage, each can choose corresponding schedulers, model workers to load with the Engines in a plug-in fashion.
!!! note
- Default stage config YAMLs (for example, `vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml` and `vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml`) are bundled and loaded automatically when `stage_configs_path` is not provided. They have been verified to work on 1xH100 for Qwen2.5-Omni and 2xH100 for Qwen3-Omni.
+ Default deploy config YAMLs (for example, `vllm_omni/deploy/qwen2_5_omni.yaml`, `vllm_omni/deploy/qwen3_omni_moe.yaml`, and `vllm_omni/deploy/qwen3_tts.yaml`) are bundled and loaded automatically when neither `--stage-configs-path` nor `--deploy-config` is provided — the model registry resolves the right pipeline + deploy YAML by `model_type`. The bundled defaults have been verified on 1xH100 for Qwen2.5-Omni and 2xH100 for Qwen3-Omni. Models that have not yet migrated to the new schema continue to use the legacy `vllm_omni/model_executor/stage_configs/.yaml` files via `--stage-configs-path`.
+
+## New deploy schema reference
+
+The new deploy schema lives under `vllm_omni/deploy/` and is paired with a frozen `PipelineConfig` registered by the model's `pipeline.py`. Each deploy YAML has these top-level fields:
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `base_config` | str (path) | optional | — | Overlay parent (relative or absolute). `stages:` / `platforms:` deep-merged by stage_id; other scalars overlay-wins. Intended for user-authored overlays; prod yamls stay flat. |
+| `async_chunk` | bool | optional | `true` | Enable chunked streaming between stages. Pin to `false` if the pipeline runs end-to-end. |
+| `connectors` | dict | optional | `null` | Named connector specs (`{name, extra}`). Referenced by each stage's `input_connectors` / `output_connectors`. See [Connector schema](#connector-schema). |
+| `edges` | list | optional | `null` | Explicit edge list for the KV transfer graph. Auto-derived from stage inputs if omitted. |
+| `stages` | list | required | — | Per-stage engine args + wiring (see [Stage fields](#stage-fields)). |
+| `platforms` | dict | optional | `null` | Keyed by `npu` / `rocm` / `xpu`, each contains a `stages:` list with per-platform overrides applied on top of the CUDA defaults. |
+| `pipeline` | str | optional | `null` | Override the auto-detected pipeline registry key (used for structural variants like `qwen2_5_omni_thinker_only`). |
+| `trust_remote_code` | bool | optional | `true` | **Pipeline-wide.** Trust HF remote code on model load; applies to every stage. |
+| `distributed_executor_backend` | str | optional | `"mp"` | **Pipeline-wide.** Executor backend (`"mp"` or `"ray"`). |
+| `dtype` | str \| null | optional | `null` | **Pipeline-wide.** Model dtype for every stage. |
+| `quantization` | str \| null | optional | `null` | **Pipeline-wide.** Quantization method for every stage. |
+| `enable_prefix_caching` | bool | optional | `false` | **Pipeline-wide.** Prefix cache toggle applied to every stage. |
+| `enable_chunked_prefill` | bool \| null | optional | `null` | **Pipeline-wide.** Chunked prefill toggle applied to every stage. |
+| `data_parallel_size` | int | optional | `1` | **Pipeline-wide.** DP degree for every stage. |
+| `pipeline_parallel_size` | int | optional | `1` | **Pipeline-wide.** PP degree for every stage. |
+
+### Stage fields
+
+Each entry under `stages:` accepts any `StageDeployConfig` field directly (no nested `engine_args:`). Only fields whose value legitimately varies across stages live here; pipeline-wide settings (trust_remote_code, distributed_executor_backend, dtype, quantization, prefix/chunked prefill, DP/PP sizes) are declared at the top level and applied to every stage. Unknown keys fall through to `engine_extras:` and are forwarded to the engine.
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `stage_id` | int | required | — | Stage identity; matched against `PipelineConfig.stages[*].stage_id`. |
+| `max_num_seqs` | int | optional | `64` | Max concurrent sequences per stage. |
+| `gpu_memory_utilization` | float | optional | `0.9` | Per-stage memory budget. |
+| `tensor_parallel_size` | int | optional | `1` | TP degree for this stage. |
+| `enforce_eager` | bool | optional | `false` | Disable CUDA graphs. |
+| `max_num_batched_tokens` | int | optional | `32768` | Prefill budget. |
+| `max_model_len` | int \| null | optional | `null` | Per-stage context length (auto-sets `VLLM_ALLOW_LONG_MAX_MODEL_LEN=1` when larger than HF default). |
+| `async_scheduling` | bool \| null | optional | `null` | Per-stage async scheduling toggle. |
+| `devices` | str | optional | `"0"` | `CUDA_VISIBLE_DEVICES`-style device list. |
+| `output_connectors` | dict \| null | optional | `null` | Keyed by `to_stage_`; values are names registered under top-level `connectors:`. |
+| `input_connectors` | dict \| null | optional | `null` | Keyed by `from_stage_`; values are names registered under top-level `connectors:`. |
+| `default_sampling_params` | dict \| null | optional | `null` | Baseline sampling params. Deep-merged with pipeline `sampling_constraints` (pipeline wins). |
+| `engine_extras` | dict | optional | `{}` | Catch-all for keys not listed above; deep-merged across overlays. Also carries per-stage overrides of pipeline-wide settings (e.g. stage-specific `dtype`). |
+
+### Connector schema
+
+Each entry under top-level `connectors:` follows this shape:
+
+```yaml
+connectors:
+ :
+ name: # required — class registered in vllm_omni.distributed
+ extra: # optional — forwarded to the connector's __init__
+ :
+ ...
+```
+
+| Connector class | Use case | `extra` keys |
+|-----------------|----------|--------------|
+| `SharedMemoryConnector` | Same-host KV transfer between stages (default for bundled YAMLs). | `shm_threshold_bytes` (int, default `65536`). |
+| `MooncakeStoreConnector` | Cross-host KV transfer over TCP. Required for multi-node deployments. | `host`, `metadata_server`, `master`, `segment` (int bytes), `localbuf` (int bytes), `proto` (`"tcp"` / `"rdma"`). |
+
+A stage references a connector by name in its `input_connectors` / `output_connectors`:
+
+```yaml
+connectors:
+ shm:
+ name: SharedMemoryConnector
+
+stages:
+ - stage_id: 0
+ output_connectors: {to_stage_1: shm}
+ - stage_id: 1
+ input_connectors: {from_stage_0: shm}
+```
+
+### CLI flags introduced in this refactor
+
+| Flag | Description |
+|------|-------------|
+| `--deploy-config PATH` | Load a new-schema deploy YAML. Takes precedence over `--stage-configs-path`. **Optional** — when omitted, the bundled `vllm_omni/deploy/.yaml` is auto-loaded by the model registry. |
+| `--stage-overrides JSON` | Per-stage JSON overrides, e.g. `'{"0":{"gpu_memory_utilization":0.5}}'`. Per-stage values always win over global flags. |
+| `--async-chunk` / `--no-async-chunk` | Flip the deploy YAML's `async_chunk:` bool. Unset (default) leaves the YAML value in force. |
+| `--stage-configs-path` | **Deprecated.** Accepts legacy `stage_args` yamls and (auto-detected) new deploy yamls; emits a deprecation warning. Migrate to `--deploy-config`. To be removed in a follow-up PR. |
+
+### Precedence
+
+From highest to lowest:
+
+1. Per-stage flags (`--stage-overrides` JSON, `--stage--` if registered)
+2. Explicit global CLI flags (`--gpu-memory-utilization 0.85`, etc.)
+3. Platform section (`platforms.npu.stages`, etc.) on top of the base `stages:`
+4. Overlay YAML (via `base_config:`) on top of the base YAML
+5. Parser defaults
+
+### Worked override example
+
+Starting from the bundled `vllm_omni/deploy/qwen3_omni_moe.yaml`:
+
+```yaml
+# vllm_omni/deploy/qwen3_omni_moe.yaml (excerpt)
+async_chunk: true
+stages:
+ - stage_id: 0
+ gpu_memory_utilization: 0.9
+ max_num_seqs: 32
+ - stage_id: 1
+ gpu_memory_utilization: 0.7
+ max_num_seqs: 16
+```
+
+A user-authored overlay that inherits the base and overrides only stage 1:
+
+```yaml
+# my_overrides.yaml
+base_config: /path/to/vllm_omni/deploy/qwen3_omni_moe.yaml
+stages:
+ - stage_id: 1
+ gpu_memory_utilization: 0.5 # smaller GPU
+```
+
+Launched with both an explicit global flag and a per-stage override:
+
+```bash
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
+ --deploy-config my_overrides.yaml \
+ --max-model-len 16384 \
+ --stage-overrides '{"0": {"max_num_seqs": 8}}'
+```
+
+Effective config per stage after the merge:
+
+| Stage | Field | Final value | Source |
+|-------|-------|-------------|--------|
+| 0 | `gpu_memory_utilization` | `0.9` | base YAML (overlay didn't touch stage 0) |
+| 0 | `max_num_seqs` | `8` | per-stage CLI (`--stage-overrides`) — wins over base `32` |
+| 0 | `max_model_len` | `16384` | global CLI |
+| 1 | `gpu_memory_utilization` | `0.5` | overlay YAML — wins over base `0.7` |
+| 1 | `max_num_seqs` | `16` | base YAML (overlay didn't touch this field) |
+| 1 | `max_model_len` | `16384` | global CLI |
+| 2 | (all defaults) | — | base YAML (no overrides apply) |
Therefore, as a core part of vLLM-Omni, the stage configs for a model have several main functions:
@@ -35,7 +175,7 @@ stage_args:
- stage_id: 0 # mark the unique id for each stage
runtime: # The disaggregated configuration
process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
+ devices: "0" # Logical device index for this stage (mapped through CUDA_VISIBLE_DEVICES / ASCEND_RT_VISIBLE_DEVICES if set)
engine_args: # Engine arguments for a certain engine
model_stage: thinker
max_num_seqs: 1
@@ -114,16 +254,12 @@ stage_args:
# Top-level runtime config (concise): default windows and stage edges
runtime:
enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
+
edges:
- from: 0 # thinker → talker: trigger only after receiving full input (-1)
to: 1
- window_size: -1
- from: 1 # talker → code2wav: trigger only after receiving full input (-1)
to: 2
- window_size: -1
```
@@ -155,7 +291,9 @@ Default: `true`
#### `runtime.devices`
-Visible devices for this stage, specified as a string. This controls which GPU devices are available to the stage process, similar to setting `CUDA_VISIBLE_DEVICES` or using `torch.cuda.set_device()`. For example, `"0"` uses GPU 0, `"1"` uses GPU 1, and `"0,1"` makes both GPUs 0 and 1 visible.
+Logical device indices for this stage, specified as a string. Values are **logical indices** (`0`, `1`, `2`, ...) — not physical GPU IDs — and are mapped through the platform's visibility env var (`CUDA_VISIBLE_DEVICES` on CUDA, `ASCEND_RT_VISIBLE_DEVICES` on NPU) before being applied via `torch.cuda.set_device()` (or the equivalent).
+
+Example: if `CUDA_VISIBLE_DEVICES=0,2,4` is set in the environment, then `devices: "0"` selects physical GPU 0 (the first visible), `devices: "1"` selects physical GPU 2, and `devices: "0,1"` makes physical GPUs 0 and 2 available to the stage. If no visibility env var is set, logical and physical IDs coincide.
Default: `"0"`
diff --git a/docs/configuration/stage_configs/qwen2_5_omni.yaml b/docs/configuration/stage_configs/qwen2_5_omni.yaml
deleted file mode 100644
index 690577b84a8..00000000000
--- a/docs/configuration/stage_configs/qwen2_5_omni.yaml
+++ /dev/null
@@ -1,94 +0,0 @@
-# stage config for running qwen2.5-omni with AsyncOmniEngine + Orchestrator runtime.
-stage_args:
- - stage_id: 0
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
- - stage_id: 2
- runtime:
- process: true
- devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.15
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/docs/contributing/ci/CI_5levels.md b/docs/contributing/ci/CI_5levels.md
index b0428ddd7de..2452ef5d4a3 100644
--- a/docs/contributing/ci/CI_5levels.md
+++ b/docs/contributing/ci/CI_5levels.md
@@ -231,8 +231,7 @@ vllm_omni/ tests/
│ ├── test_qwen3_omni_expansion.py
│ ├── test_mimo_audio.py
│ ├── test_image_gen_edit.py
- │ ├── test_images_generations_lora.py
- │ └── stage_configs/
+ │ └── test_images_generations_lora.py
└── offline_inference/ ✅
├── test_qwen2_5_omni.py
├── test_qwen3_omni.py
@@ -248,11 +247,12 @@ vllm_omni/ tests/
├── test_diffusion_layerwise_offload.py
├── test_diffusion_lora.py
├── test_sequence_parallel.py
- └── stage_configs/
- ├── qwen2_5_omni_ci.yaml
- ├── qwen3_omni_ci.yaml
- ├── bagel_*.yaml
- └── npu/, rocm/, etc.
+ └── stage_configs/ (legacy schema, still
+ ├── bagel_*.yaml present for unmigrated
+ └── npu/, rocm/, etc. models)
+
+# Migrated models (qwen3_omni_moe, qwen2_5_omni, qwen3_tts) live under
+# vllm_omni/deploy/ instead — see docs/configuration/stage_configs.md.
```
diff --git a/docs/contributing/ci/test_examples/l4_functionality_tests.inc.md b/docs/contributing/ci/test_examples/l4_functionality_tests.inc.md
index 69d6ad82871..ab4deecd60d 100644
--- a/docs/contributing/ci/test_examples/l4_functionality_tests.inc.md
+++ b/docs/contributing/ci/test_examples/l4_functionality_tests.inc.md
@@ -40,7 +40,7 @@ Currently all the features are available in online serving mode. Hence, only nee
- Test marks: always add `advanced_model` and `diffusion`. Add GPU-related marks if needed. Ref: [Markers for Tests](https://docs.vllm.ai/projects/vllm-omni/en/latest/contributing/ci/tests_markers/).
- To maximize code reuse, you may refer to
- `tests/conftest.py` for `omni_server` (running server in subprocess) and `openai_client` fixtures (sending requests and validating output), `generate_synthetic_image` and `assert_XXX_valid` helper.
- - `tests/utils.py` for `@hardware_test(...)` and `hardware_marks`.
+ - `tests/helpers/mark.py` for `@hardware_test(...)` and `hardware_marks`.
- [Parametrizing tests (pytest doc)](https://docs.pytest.org/en/stable/example/parametrize.html) to reuse test function implementation for different cases.
- Doc: add a concise docstring for each test function.
- Reference L4 test implementation: [tests/e2e/online_serving/test_qwen_image_edit_expansion.py](https://github.com/vllm-project/vllm-omni/blob/main/tests/e2e/online_serving/test_qwen_image_edit_expansion.py).
diff --git a/docs/contributing/ci/tests_markers.md b/docs/contributing/ci/tests_markers.md
index 7c1ba1c73bd..7628db284a7 100644
--- a/docs/contributing/ci/tests_markers.md
+++ b/docs/contributing/ci/tests_markers.md
@@ -38,7 +38,7 @@ Defined in `pyproject.toml`:
### Example usage for markers
```python
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
@pytest.mark.core_model
@pytest.mark.omni
@@ -53,7 +53,7 @@ def test_video_to_audio()
### Decorator: `@hardware_test`
-This decorator is intended to make hardware-aware, cross-platform test authoring easier and more robust for CI/CD environments. The `hardware_test` decorator in `vllm-omni/tests/utils.py` performs the following actions:
+This decorator is intended to make hardware-aware, cross-platform test authoring easier and more robust for CI/CD environments. The `hardware_test` decorator in `vllm-omni/tests/helpers/mark.py` performs the following actions:
1. **Applies platform and resource markers**
Adds the appropriate pytest markers for each specified hardware platform (e.g., `cuda`, `rocm`, `xpu`, `npu`) and resource type (e.g., `L4`, `H100`, `MI325`, `B60`, `A2`, `A3`).
@@ -105,7 +105,7 @@ This decorator is intended to make hardware-aware, cross-platform test authoring
`hardware_marks` returns a list of pytest mark objects with the same signature as `@hardware_test`. Use it when you need more flexibility, such as attaching hardware marks to individual `pytest.param` entries rather than an entire test function.
```python
-from tests.utils import hardware_marks
+from tests.helpers.mark import hardware_marks
MULTI_CARD_MARKS = hardware_marks(
res={"cuda": "H100", "rocm": "MI325", "npu": "A2"}, num_cards=2
@@ -133,9 +133,9 @@ If you want to add support for a new platform (e.g., "tpu" for a new accelerator
"distributed_tpu: Tests that require multiple TPU devices",
]
```
-2. **Implement a marker construction function for your platform** in `vllm-omni/tests/utils.py`:
+2. **Implement a marker construction function for your platform** in `vllm-omni/tests/helpers/mark.py`:
```python
- # In vllm-omni/tests/utils.py
+ # In vllm-omni/tests/helpers/mark.py
def tpu_marks(*, res: str, num_cards: int):
test_platform = pytest.mark.tpu
@@ -175,4 +175,4 @@ If you want to add support for a new platform (e.g., "tpu" for a new accelerator
- Plug into `hardware_marks`
- You're done: tests using `@hardware_test` or `hardware_marks` with your platform now automatically get the correct markers, distribution, and isolation!
-See code in `vllm-omni/tests/utils.py` for existing examples (`cuda_marks`, `rocm_marks`, `npu_marks`).
+See code in `vllm-omni/tests/helpers/mark.py` for existing examples (`cuda_marks`, `rocm_marks`, `npu_marks`).
diff --git a/docs/contributing/ci/tests_style.md b/docs/contributing/ci/tests_style.md
index 69d5b16d7a5..3a8cb0f127c 100644
--- a/docs/contributing/ci/tests_style.md
+++ b/docs/contributing/ci/tests_style.md
@@ -135,8 +135,7 @@ vllm_omni/ tests/
│ ├── test_qwen3_omni_expansion.py
│ ├── test_mimo_audio.py
│ ├── test_image_gen_edit.py
- │ ├── test_images_generations_lora.py
- │ └── stage_configs/
+ │ └── test_images_generations_lora.py
└── offline_inference/ ✅
├── test_qwen2_5_omni.py
├── test_qwen3_omni.py
@@ -153,11 +152,12 @@ vllm_omni/ tests/
├── test_diffusion_lora.py
├── test_sequence_parallel.py
├── test_qwen_image_edit_expansion.py
- └── stage_configs/
- ├── qwen2_5_omni_ci.yaml
- ├── qwen3_omni_ci.yaml
- ├── bagel_*.yaml
+ └── stage_configs/ (legacy schema, still present
+ ├── bagel_*.yaml for unmigrated models)
└── npu/, rocm/, etc.
+
+# Migrated models (qwen3_omni_moe, qwen2_5_omni, qwen3_tts) live under
+# vllm_omni/deploy/ instead — see docs/configuration/stage_configs.md.
examples/ tests
│ └── examples
├── online_serving/ → ├── online_serving/
@@ -221,14 +221,13 @@ from pathlib import Path
import openai
import pytest
-from tests.conftest import (
- OmniServer,
- convert_audio_to_text,
+from tests.helpers.media import (
+ convert_audio_bytes_to_text,
cosine_similarity_text,
- dummy_messages_from_mix_data,
generate_synthetic_video,
- merge_base64_and_convert_to_text,
)
+from tests.helpers.runtime import OmniServer, dummy_messages_from_mix_data
+from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
from vllm_omni.platforms import current_omni_platform
# Edit: model name and stage config path
@@ -236,7 +235,7 @@ models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
#If you use the default configuration file, you can directly use the following address.
def get_default_config():
- return str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml")
+ return get_deploy_config_path("ci/qwen3_omni_moe.yaml")
#If you need to modify the configuration file, you can use modify_stage_config.
def get_chunk_config():
@@ -405,7 +404,7 @@ def test_mix_to_text_audio_001(client: openai.OpenAI, omni_server, request) -> N
# PURPOSE: Verify text and audio outputs convey the same information
# CUSTOMIZATION: Adjust similarity threshold (0.9) based on accuracy requirements
assert audio_data is not None, "No audio output is generated"
- audio_content = merge_base64_and_convert_to_text(audio_data)
+ audio_content = convert_audio_bytes_to_text(audio_data)
print(f"text content is: {text_content}")
print(f"audio content is: {audio_content}")
similarity = cosine_similarity_text(audio_content.lower(), text_content.lower())
@@ -428,7 +427,7 @@ from pathlib import Path
import pytest
from vllm.assets.video import VideoAsset
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
from ..multi_stages.conftest import OmniRunner
# Optional: set process start method for workers
diff --git a/docs/contributing/model/adding_omni_model.md b/docs/contributing/model/adding_omni_model.md
index a0619e33811..1eaff10596c 100644
--- a/docs/contributing/model/adding_omni_model.md
+++ b/docs/contributing/model/adding_omni_model.md
@@ -313,7 +313,7 @@ The registry uses lazy loading, so the model class is imported only when needed.
## Stage Configuration
-Create a YAML configuration file in `vllm_omni/model_executor/stage_configs/`. For a complete example, see the [Qwen3-Omni configuration file](gh-file:vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml).
+Create a YAML configuration file in `vllm_omni/deploy/`. For a complete example, see the [Qwen3-Omni configuration file](gh-file:vllm_omni/deploy/qwen3_omni_moe.yaml).
### Key Configuration Fields
@@ -408,18 +408,17 @@ Understanding the data structures is crucial for implementing stage transitions:
**Input to your function:**
- `stage_list[source_stage_id].engine_outputs`: List of `EngineCoreOutput` objects
- - Each contains `outputs`: List of `RequestOutput` objects
- - Each `RequestOutput` has:
- - `token_ids`: Generated token IDs
- - `multimodal_output`: Dict with keys like `"code_predictor_codes"`, etc.
- - These are the hidden states or intermediate outputs from the model's forward pass
- - `prompt_token_ids`: Original prompt token IDs
+- - Each contains `outputs`: List of `RequestOutput` objects
+ - Each `RequestOutput` has:
+- - - `token_ids`: Generated token IDs
+ - `multimodal_output`: Dict with keys like `"code_predictor_codes"`, etc.These are the hidden states or intermediate outputs from the model's forward pass
+ - `prompt_token_ids`: Original prompt token IDs
**Output from your function:**
- Must return `list[OmniTokensPrompt]` where each `OmniTokensPrompt` contains:
- - `prompt_token_ids`: List[int] - Token IDs for the next stage
- - `additional_information`: Dict[str, Any] - Optional metadata (e.g., embeddings, hidden states)
- - `multi_modal_data`: Optional multimodal data if needed
+- - `prompt_token_ids`: List[int] - Token IDs for the next stage
+ - `additional_information`: Dict[str, Any] - Optional metadata (e.g., embeddings, hidden states)
+ - `multi_modal_data`: Optional multimodal data if needed
### How Model Outputs Are Stored
@@ -614,7 +613,7 @@ For a complete reference implementation, see:
- **Thinker**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_thinker.py`
- **Talker**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py`
- **Code2Wav**: `vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_code2wav.py`
-- **Stage config**: `vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml`
+- **Stage config**: `vllm_omni/deploy/qwen3_omni_moe.yaml`
- **Input processors**: `vllm_omni/model_executor/stage_input_processors/qwen3_omni.py`
- **Registry**: `vllm_omni/model_executor/models/registry.py`
- **Testing**: `vllm_omni/tests/e2e/offline_inference/test_qwen3_omni.py`
diff --git a/docs/contributing/model/adding_tts_model.md b/docs/contributing/model/adding_tts_model.md
index e48ae5049ff..622064173cd 100644
--- a/docs/contributing/model/adding_tts_model.md
+++ b/docs/contributing/model/adding_tts_model.md
@@ -28,7 +28,7 @@ and can be placed on different devices. Qwen3-TTS has two stages:
Each stage is a separate model class configured independently via YAML. The two stages
are connected by the `async_chunk` framework, which enables inter-stage streaming for
-low first-packet latency (see [Async Chunk Design](../../design/feature/async_chunk_design.md)).
+low first-packet latency (see [Async Chunk Design](../../design/feature/async_chunk.md)).
### Without async_chunk (batch mode)
@@ -120,8 +120,18 @@ vllm_omni/model_executor/stage_configs/
| `models/qwen3_tts/qwen3_tts.py` | Unified model class |
| `models/qwen3_tts/qwen3_tts_code_predictor_vllm.py` | Stage 0 - optimized AR |
| `models/qwen3_tts/qwen3_tts_code2wav.py` | Stage 1 - decoder |
-| `stage_configs/qwen3_tts.yaml` | Stage config (async_chunk enabled) |
-| `stage_configs/qwen3_tts_batch.yaml` | Batch mode config |
+| `deploy/qwen3_tts.yaml` (new schema) | Deploy config (async_chunk enabled) — paired with `models/qwen3_tts/pipeline.py` for the frozen topology |
+
+> **Chunked vs end-to-end modes**: `qwen3_tts` registers a single
+> pipeline whose stage 1 declares alternate processor functions — an
+> `async_chunk_process_next_stage_input_func` (per-chunk streaming, used
+> when `deploy.async_chunk=True`) and a `sync_process_input_func`
+> (batch-end, used when `deploy.async_chunk=False`). The loader selects
+> one at merge time based on the bool, so `--no-async-chunk` alone
+> switches modes — no variant yaml or variant pipeline registration is
+> needed. Pipelines that only make sense in one mode (e.g.
+> `qwen3_omni_moe` is always chunked) can keep using the unconditional
+> `custom_process_*` fields.
| `stage_input_processors/qwen3_tts.py` | Stage transition processors |
## Step-by-Step Implementation
@@ -574,11 +584,12 @@ Adding a TTS model to vLLM-Omni involves:
| `models/qwen3_tts/qwen3_tts.py` | Unified model class |
| `models/qwen3_tts/qwen3_tts_code_predictor_vllm.py` | AR stage with vLLM fused ops |
| `models/qwen3_tts/qwen3_tts_code2wav.py` | Decoder stage with `chunked_decode_streaming()` |
-| `stage_configs/qwen3_tts.yaml` | Stage configuration |
+| `models/qwen3_tts/pipeline.py` | Frozen pipeline topology (registered at import time) |
+| `deploy/qwen3_tts.yaml` | Deploy config (user-editable, async_chunk + SharedMemoryConnector) |
| `stage_input_processors/qwen3_tts.py` | Stage transition processors |
For more information, see:
- [Architecture Overview](../../design/architecture_overview.md)
-- [Async Chunk Design](../../design/feature/async_chunk_design.md)
+- [Async Chunk Design](../../design/feature/async_chunk.md)
- [Stage Configuration Guide](../../configuration/stage_configs.md)
diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md
index 418fb707ae9..6c209e5659a 100644
--- a/docs/contributing/profiling.md
+++ b/docs/contributing/profiling.md
@@ -127,10 +127,11 @@ Multi-stage omni serving:
```bash
vllm serve Qwen/Qwen2.5-Omni-7B \
--omni \
- --stage-configs-path qwen2_5_omni.yaml \
--port 8091
```
+(The default deploy config at `vllm_omni/deploy/qwen2_5_omni.yaml` is loaded automatically. Pass `--deploy-config /path/to/custom.yaml` to override.)
+
Single-stage diffusion serving with torch profiler:
```bash
diff --git a/docs/design/feature/async_chunk_design.md b/docs/design/feature/async_chunk.md
similarity index 99%
rename from docs/design/feature/async_chunk_design.md
rename to docs/design/feature/async_chunk.md
index 45314a0aec6..57b4209b8df 100644
--- a/docs/design/feature/async_chunk_design.md
+++ b/docs/design/feature/async_chunk.md
@@ -1,4 +1,4 @@
-# Async Chunk Design
+# Async Chunk
## Table of Contents
@@ -88,8 +88,9 @@ The following diagram illustrates the **Async Chunk Architecture** for multi-sta
**Diagram Legend:**
+
| Step | Stage Type | Description |
-|:------:|:-----------:|:------------|
+|------|-----------|------------|
| `prefill` | Initialization | Context processing, KV cache initialization |
| `decode` | Autoregressive | Token-by-token generation in AR stages |
| `codes` | Audio Encoding | RVQ codec codes from Talker stage |
diff --git a/docs/design/feature/teacache.md b/docs/design/feature/teacache.md
index 9fa315cee77..8577cff1f05 100644
--- a/docs/design/feature/teacache.md
+++ b/docs/design/feature/teacache.md
@@ -326,9 +326,41 @@ for prompt in tqdm(prompts, desc="Collecting data"):
# Estimate coefficients
coeffs = estimator.estimate(poly_order=4)
-print(f"Estimated coefficients: {coeffs.tolist()}")
+print(f"Estimated coefficients: {coeffs}")
```
+Note: some models may require the vLLM context and config to be initialized to initialize vLLM modules. To this end, you may need a workaround like the following to be able to run coefficient estimation.
+```python
+from vllm_omni.diffusion.forward_context import set_forward_context
+from vllm_omni.diffusion.distributed.parallel_state import (
+ init_distributed_environment,
+ initialize_model_parallel,
+)
+from vllm.config import VllmConfig
+...
+
+if __name__ == "__main__":
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "8192"
+ os.environ["LOCAL_RANK"] = "0"
+ os.environ["RANK"] = "0"
+ os.environ["WORLD_SIZE"] = "1"
+
+ vllm_config = VllmConfig()
+ init_distributed_environment()
+ initialize_model_parallel()
+
+ # NOTE: you may have to pass an initialized OmniDiffusionConfig as a kwarg
+ # here to make current sp checks happy; if this is the case, just create one
+ # .from_kwargs() with the model name to get around this check for now,
+ # since your estimator subclass should handle the actual model configuration.
+ #
+ # This will be cleaned up in the future
+ with set_forward_context(vllm_config):
+
+```
+
+
**Data Statistics Guide:**
| Metric | Good Range | Warning Signs |
diff --git a/docs/design/figures/omni/E2EL_s_vllm_omni_vs_transformers.png b/docs/design/figures/omni/E2EL_s_vllm_omni_vs_transformers.png
new file mode 100644
index 00000000000..15112d5862a
Binary files /dev/null and b/docs/design/figures/omni/E2EL_s_vllm_omni_vs_transformers.png differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_RTF_Baseline_vs_Batch.png b/docs/design/figures/omni/Mean_AUDIO_RTF_Baseline_vs_Batch.png
new file mode 100644
index 00000000000..2f0615f77bb
Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_RTF_Baseline_vs_Batch.png differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_CUDA_Graph_vs_Async_Chunk.png b/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_CUDA_Graph_vs_Async_Chunk.png
new file mode 100644
index 00000000000..62d8bc79b6b
Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_CUDA_Graph_vs_Async_Chunk.png differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_vs_Batch_CUDA_Graph.png b/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_vs_Batch_CUDA_Graph.png
new file mode 100644
index 00000000000..5838b45319e
Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_RTF_Batch_vs_Batch_CUDA_Graph.png differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Baseline_vs_Batch.png b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Baseline_vs_Batch.png
new file mode 100644
index 00000000000..24be814b7e9
Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Baseline_vs_Batch.png differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_CUDA_Graph_vs_Async_Chunk.png b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_CUDA_Graph_vs_Async_Chunk.png
new file mode 100644
index 00000000000..c8df58ebcdf
Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_CUDA_Graph_vs_Async_Chunk.png differ
diff --git a/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_vs_Batch_CUDA_Graph.png b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_vs_Batch_CUDA_Graph.png
new file mode 100644
index 00000000000..2d1a04e9c2c
Binary files /dev/null and b/docs/design/figures/omni/Mean_AUDIO_TTFP_ms_Batch_vs_Batch_CUDA_Graph.png differ
diff --git a/docs/design/figures/omni/Mean_E2EL_ms_Baseline_vs_Batch.png b/docs/design/figures/omni/Mean_E2EL_ms_Baseline_vs_Batch.png
new file mode 100644
index 00000000000..e598b543431
Binary files /dev/null and b/docs/design/figures/omni/Mean_E2EL_ms_Baseline_vs_Batch.png differ
diff --git a/docs/design/figures/omni/Mean_E2EL_ms_Batch_CUDA_Graph_vs_Async_Chunk.png b/docs/design/figures/omni/Mean_E2EL_ms_Batch_CUDA_Graph_vs_Async_Chunk.png
new file mode 100644
index 00000000000..54452013eb4
Binary files /dev/null and b/docs/design/figures/omni/Mean_E2EL_ms_Batch_CUDA_Graph_vs_Async_Chunk.png differ
diff --git a/docs/design/figures/omni/Mean_E2EL_ms_Batch_vs_Batch_CUDA_Graph.png b/docs/design/figures/omni/Mean_E2EL_ms_Batch_vs_Batch_CUDA_Graph.png
new file mode 100644
index 00000000000..04c5ad7396a
Binary files /dev/null and b/docs/design/figures/omni/Mean_E2EL_ms_Batch_vs_Batch_CUDA_Graph.png differ
diff --git a/docs/design/figures/omni/RTF_vllm_omni_vs_transformers.png b/docs/design/figures/omni/RTF_vllm_omni_vs_transformers.png
new file mode 100644
index 00000000000..d93ba0b2af5
Binary files /dev/null and b/docs/design/figures/omni/RTF_vllm_omni_vs_transformers.png differ
diff --git a/docs/design/figures/omni/Summary_E2EL_ms_vs_features.png b/docs/design/figures/omni/Summary_E2EL_ms_vs_features.png
new file mode 100644
index 00000000000..04087b5910f
Binary files /dev/null and b/docs/design/figures/omni/Summary_E2EL_ms_vs_features.png differ
diff --git a/docs/design/figures/omni/Summary_RTF_vs_features.png b/docs/design/figures/omni/Summary_RTF_vs_features.png
new file mode 100644
index 00000000000..c2c8ad40834
Binary files /dev/null and b/docs/design/figures/omni/Summary_RTF_vs_features.png differ
diff --git a/docs/design/figures/omni/Summary_TTFP_ms_vs_features.png b/docs/design/figures/omni/Summary_TTFP_ms_vs_features.png
new file mode 100644
index 00000000000..3dcc1c55379
Binary files /dev/null and b/docs/design/figures/omni/Summary_TTFP_ms_vs_features.png differ
diff --git a/docs/design/figures/omni/TTFP_s_vllm_omni_vs_transformers.png b/docs/design/figures/omni/TTFP_s_vllm_omni_vs_transformers.png
new file mode 100644
index 00000000000..9a5b6c9bdaf
Binary files /dev/null and b/docs/design/figures/omni/TTFP_s_vllm_omni_vs_transformers.png differ
diff --git a/docs/design/figures/tts/Mean_AUDIO_RTF_vllm_omni_vs_transformers.png b/docs/design/figures/tts/Mean_AUDIO_RTF_vllm_omni_vs_transformers.png
new file mode 100644
index 00000000000..68f0ef17e88
Binary files /dev/null and b/docs/design/figures/tts/Mean_AUDIO_RTF_vllm_omni_vs_transformers.png differ
diff --git a/docs/design/figures/tts/Mean_AUDIO_TTFP_(ms)_vllm_omni_vs_transformers.png b/docs/design/figures/tts/Mean_AUDIO_TTFP_(ms)_vllm_omni_vs_transformers.png
new file mode 100644
index 00000000000..44be96e96da
Binary files /dev/null and b/docs/design/figures/tts/Mean_AUDIO_TTFP_(ms)_vllm_omni_vs_transformers.png differ
diff --git a/docs/design/figures/tts/Mean_E2EL_(ms)_vllm_omni_vs_transformers.png b/docs/design/figures/tts/Mean_E2EL_(ms)_vllm_omni_vs_transformers.png
new file mode 100644
index 00000000000..2e5d1482bd7
Binary files /dev/null and b/docs/design/figures/tts/Mean_E2EL_(ms)_vllm_omni_vs_transformers.png differ
diff --git a/docs/design/figures/tts/Mean_mean_e2e_ms_baseline_vs_batch.png b/docs/design/figures/tts/Mean_mean_e2e_ms_baseline_vs_batch.png
new file mode 100644
index 00000000000..04d8f0bac53
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_e2e_ms_baseline_vs_batch.png differ
diff --git a/docs/design/figures/tts/Mean_mean_e2e_ms_batch_vs_cuda_graph.png b/docs/design/figures/tts/Mean_mean_e2e_ms_batch_vs_cuda_graph.png
new file mode 100644
index 00000000000..eb85ec0dd4f
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_e2e_ms_batch_vs_cuda_graph.png differ
diff --git a/docs/design/figures/tts/Mean_mean_e2e_ms_cuda_graph_vs_async_chunk.png b/docs/design/figures/tts/Mean_mean_e2e_ms_cuda_graph_vs_async_chunk.png
new file mode 100644
index 00000000000..6f0e0e2529d
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_e2e_ms_cuda_graph_vs_async_chunk.png differ
diff --git a/docs/design/figures/tts/Mean_mean_rtf_baseline_vs_batch.png b/docs/design/figures/tts/Mean_mean_rtf_baseline_vs_batch.png
new file mode 100644
index 00000000000..89ea30a8643
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_rtf_baseline_vs_batch.png differ
diff --git a/docs/design/figures/tts/Mean_mean_rtf_batch_vs_cuda_graph.png b/docs/design/figures/tts/Mean_mean_rtf_batch_vs_cuda_graph.png
new file mode 100644
index 00000000000..2b207b88987
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_rtf_batch_vs_cuda_graph.png differ
diff --git a/docs/design/figures/tts/Mean_mean_rtf_cuda_graph_vs_async_chunk.png b/docs/design/figures/tts/Mean_mean_rtf_cuda_graph_vs_async_chunk.png
new file mode 100644
index 00000000000..f5f7ad72c8f
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_rtf_cuda_graph_vs_async_chunk.png differ
diff --git a/docs/design/figures/tts/Mean_mean_ttfp_ms_baseline_vs_batch.png b/docs/design/figures/tts/Mean_mean_ttfp_ms_baseline_vs_batch.png
new file mode 100644
index 00000000000..6f8c1da4a5b
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_ttfp_ms_baseline_vs_batch.png differ
diff --git a/docs/design/figures/tts/Mean_mean_ttfp_ms_batch_vs_cuda_graph.png b/docs/design/figures/tts/Mean_mean_ttfp_ms_batch_vs_cuda_graph.png
new file mode 100644
index 00000000000..b0fe1d02a9d
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_ttfp_ms_batch_vs_cuda_graph.png differ
diff --git a/docs/design/figures/tts/Mean_mean_ttfp_ms_cuda_graph_vs_async_chunk.png b/docs/design/figures/tts/Mean_mean_ttfp_ms_cuda_graph_vs_async_chunk.png
new file mode 100644
index 00000000000..008ba9bf78f
Binary files /dev/null and b/docs/design/figures/tts/Mean_mean_ttfp_ms_cuda_graph_vs_async_chunk.png differ
diff --git a/docs/design/figures/tts/Summary_mean_e2e_ms_vs_features.png b/docs/design/figures/tts/Summary_mean_e2e_ms_vs_features.png
new file mode 100644
index 00000000000..7c65aa11770
Binary files /dev/null and b/docs/design/figures/tts/Summary_mean_e2e_ms_vs_features.png differ
diff --git a/docs/design/figures/tts/Summary_mean_rtf_vs_features.png b/docs/design/figures/tts/Summary_mean_rtf_vs_features.png
new file mode 100644
index 00000000000..71bb2c54680
Binary files /dev/null and b/docs/design/figures/tts/Summary_mean_rtf_vs_features.png differ
diff --git a/docs/design/figures/tts/Summary_mean_ttfp_ms_vs_features.png b/docs/design/figures/tts/Summary_mean_ttfp_ms_vs_features.png
new file mode 100644
index 00000000000..cef2546d6fe
Binary files /dev/null and b/docs/design/figures/tts/Summary_mean_ttfp_ms_vs_features.png differ
diff --git a/docs/design/qwen3_omni_tts_performance_optimization.md b/docs/design/qwen3_omni_tts_performance_optimization.md
new file mode 100644
index 00000000000..2f18a1b1bc0
--- /dev/null
+++ b/docs/design/qwen3_omni_tts_performance_optimization.md
@@ -0,0 +1,539 @@
+# Speech Generation on vLLM-Omni: Performance Optimizations for Qwen3-Omni and Qwen3-TTS
+
+## Summary
+
+vLLM-Omni supports end-to-end serving for speech-generating models, including both **Qwen3-Omni** (multimodal understanding + speech) and **Qwen3-TTS** (text-to-speech). Despite their different architectures, both models share the same multi-stage pipeline design and benefit from the same set of stacked optimizations:
+
+1. **Batching** improves GPU utilization stage by stage and increases overall throughput.
+2. **CUDA Graph** reduces CPU launch overhead and decode-time jitter on stable shapes.
+3. **Async Chunk and Streaming Output** overlap compute and communication across stages and emit audio incrementally, improving both TTFP and E2E.
+
+### Model architectures
+
+**Qwen3-Omni** is a native multimodal model that understands text, audio, image, and video inputs, and generates both text and speech outputs. Its pipeline has three stages:
+
+- **Thinker**: multimodal understanding and text generation
+- **Talker (+ Talker-MTP / code predictor path)**: converts semantic/text representations into codec tokens
+- **Code2Wav**: decodes codec tokens into waveform audio
+
+**Qwen3-TTS** is a lightweight, high-quality text-to-speech model. Its pipeline has two stages:
+
+- **Talker (AR decoder)**: auto-regressively generates codec tokens from text input
+- **Code2Wav (vocoder)**: decodes codec tokens into waveform audio
+
+The optimizations described in this post apply to both models. We present results for each side by side.
+
+### vLLM-Omni vs HF Transformers
+
+Compared with **HF Transformers** (offline, single request), vLLM-Omni with the full optimization stack delivers dramatically lower latency and higher efficiency for both models.
+
+**Qwen3-Omni** (A100):
+
+
+
+| Metric | vLLM-Omni | HF Transformers | Improvement |
+| --- | --- | --- | --- |
+| E2E latency (s) | 23.78 | 336.10 | ~93% reduction |
+| TTFP (s) | 0.934 | 336.10 | ~99.7% reduction |
+| RTF | 0.32 | 3.776 | ~91% reduction (~12× faster) |
+
+- **E2E latency**: 23.78 s vs 336.10 s - **~93%** reduction
+- **TTFP**: 0.934 s vs 336.10 s - **~99.7%** reduction
+- **RTF**: 0.32 vs 3.776 - **~91%** reduction (~12x faster)
+
+**Qwen3-TTS** (H200, concurrency 1):
+
+
+
+| Metric | vLLM-Omni | HF Transformers | Improvement |
+| --- | --- | --- | --- |
+| E2E latency (ms) | 941 | 15,513 | ~94% reduction |
+| TTFP (ms) | 64 | 15,513 | ~99.6% reduction (242× faster) |
+| RTF | 0.16 | 2.64 | ~94% reduction (~16.5× faster) |
+
+- **E2E latency**: 941 ms vs 15,513 ms - **~94%** reduction
+- **TTFP**: 64 ms vs 15,513 ms - **~99.6%** reduction (242x faster)
+- **RTF**: 0.16 vs 2.64 - **~94%** reduction (~16.5x faster)
+
+### Stacked optimization summary
+
+Each optimization stacks on the previous one. The summary plots below show the cumulative effect at each step, with one line per concurrency level (1, 4, 10).
+
+**Qwen3-Omni** (A100):
+
+
+
+- **E2EL reduction**: ~74% at concurrency 10 (410,054 ms -> 104,901 ms); ~90% at concurrency 1 (426,529 ms -> 41,216 ms)
+- **TTFP reduction**: ~96% at concurrency 10 (409,705 ms -> 16,482 ms); ~99.7% at concurrency 1 (426,078 ms -> 1,164 ms)
+- **RTF reduction**: ~74% at concurrency 10 (2.83 -> 0.74); ~90% at concurrency 1 (2.08 -> 0.21)
+
+**Qwen3-TTS** (H200):
+
+
+
+- **E2EL reduction**: ~85% at concurrency 10 (12,141 ms -> 1,767 ms); ~29% at concurrency 1 (1,323 ms -> 941 ms)
+- **TTFP reduction**: ~96.5% at concurrency 10 (12,141 ms -> 425 ms); ~95% at concurrency 1 (1,323 ms -> 64 ms)
+- **RTF reduction**: ~86% at concurrency 10 (2.19 -> 0.31); ~30% at concurrency 1 (0.23 -> 0.16)
+
+**Benchmark environment:**
+
+| | Qwen3-Omni | Qwen3-TTS |
+| --- |-----------------------------| --- |
+| **GPU** | A100 | H200 |
+| **Model** | Qwen3-Omni-30B-A3B-Instruct | Qwen3-TTS-12Hz-1.7B-CustomVoice |
+| **vLLM** | v0.17.0 | v0.18.0 |
+| **vllm-omni** | commit 199f7832 | v0.18.0rc2 |
+| **CUDA** | 12.9 | 12.8 |
+
+This post walks through each optimization in the same order they are typically enabled in practice, then ends with deployment playbooks for both models.
+
+---
+
+## Pipeline Batching
+
+### How stage-wise batching works
+
+For both Qwen3-Omni and Qwen3-TTS, batching is a pipeline-level optimization:
+
+- Requests are grouped per stage using `runtime.max_batch_size`
+- Each stage executes batch inference with its own scheduler/worker
+- Stage outputs are routed to downstream stages with per-request mapping preserved
+
+**Batching strategy by stage:** The understanding and decode stages (Thinker for Omni, Talker for both) use **continuous batching**: requests can join and leave the batch over time. Code2Wav uses **static batching**: once a batch is formed, the stage runs the whole batch before starting the next. This matches the decode pattern of Code2Wav and keeps implementation simple while still improving throughput.
+
+### Batching results (Baseline vs. Batch)
+
+Batching alone greatly reduces E2EL and RTF across all concurrencies. The biggest gains appear at high concurrency where requests share GPU resources.
+
+**Qwen3-Omni** (A100):
+
+
+
+| Metric | Concurrency | Baseline | + Batch | Improvement |
+| --- | --- | --- | --- | --- |
+| E2EL (ms) | 1 | 426,529 | 307,719 | 1.4× |
+| E2EL (ms) | 4 | 407,213 | 376,934 | 1.1× |
+| E2EL (ms) | 10 | 410,054 | 234,844 | 1.7× |
+| TTFP (ms) | 1 | 426,078 | 307,262 | 1.4× |
+| TTFP (ms) | 4 | 406,843 | 376,466 | 1.1× |
+| TTFP (ms) | 10 | 409,705 | 234,557 | 1.7× |
+| RTF | 1 | 2.08 | 1.51 | 1.4× |
+| RTF | 4 | 2.55 | 1.83 | 1.4× |
+| RTF | 10 | 2.83 | 2.28 | 1.2× |
+
+At concurrency 10, E2EL drops from ~410 s to ~235 s; at concurrency 1, from ~427 s to ~308 s.
+
+**Qwen3-TTS** (H200):
+
+
+
+| Metric | Concurrency | Baseline | + Batch | Improvement |
+| --- | --- | --- | --- | --- |
+| E2EL (ms) | 1 | 1,323 | 1,339 | 1.0× |
+| E2EL (ms) | 4 | 5,171 | 1,471 | 3.5× |
+| E2EL (ms) | 10 | 12,141 | 1,705 | 7.1× |
+| RTF | 1 | 0.230 | 0.234 | 1.0× |
+| RTF | 4 | 0.908 | 0.255 | 3.6× |
+| RTF | 10 | 2.186 | 0.292 | 7.5× |
+| Throughput (audio-s/wall-s) | 10 | 3.99 | 33.53 | 8.4× |
+
+At concurrency 10, batching alone brings Qwen3-TTS RTF from 2.19 (slower than realtime) down to 0.29 (faster than realtime), and throughput from 4.0 to 33.5 audio-sec/wall-sec.
+
+---
+
+## CUDA Graph on the Critical Decode Path
+
+### Why CUDA Graph helps here
+
+In decode-heavy serving, repeatedly launching many small kernels from CPU can become a visible overhead. CUDA Graph reduces this overhead by capturing and replaying stable execution graphs.
+
+In stage configs, this is represented by `enforce_eager: false` for stages where graph capture is desired (Thinker/Talker), while Code2Wav keeps eager mode depending on stage behavior.
+
+### CUDA Graph results on top of batching
+
+**Qwen3-Omni** (A100):
+
+
+
+| Metric | Concurrency | Batch | + CUDA Graph | Improvement |
+| --- | --- | --- | --- | --- |
+| E2EL (ms) | 1 | 307,719 | 61,613 | 5.0× |
+| E2EL (ms) | 4 | 376,934 | 79,019 | 4.8× |
+| E2EL (ms) | 10 | 234,844 | 126,867 | 1.9× |
+| TTFP (ms) | 1 | 307,262 | 61,257 | 5.0× |
+| TTFP (ms) | 4 | 376,466 | 78,634 | 4.8× |
+| TTFP (ms) | 10 | 234,557 | 126,534 | 1.9× |
+| RTF | 1 | 1.51 | 0.32 | 4.7× |
+| RTF | 4 | 1.83 | 0.43 | 4.3× |
+| RTF | 10 | 2.28 | 0.90 | 2.5× |
+
+For the larger Qwen3-Omni model (30B-A3B), CUDA Graph provides a significant improvement. At concurrency 1, E2EL drops from ~308 s to ~62 s; at concurrency 10, from ~235 s to ~127 s.
+
+**Qwen3-TTS** (H200):
+
+
+
+| Metric | Concurrency | Batch | + CUDA Graph | Improvement |
+| --- | --- | --- | --- | --- |
+| E2EL (ms) | 1 | 1,339 | 733 | 1.8× |
+| E2EL (ms) | 4 | 1,471 | 987 | 1.5× |
+| E2EL (ms) | 10 | 1,705 | 1,197 | 1.4× |
+| RTF | 1 | 0.234 | 0.124 | 1.9× |
+| RTF | 10 | 0.292 | 0.203 | 1.4× |
+| Throughput (audio-s/wall-s) | 10 | 33.53 | 47.15 | 1.4× |
+
+At concurrency 1, CUDA Graph reduces E2EL from 1,339 ms to 733 ms and RTF from 0.234 to 0.124 - nearly a 2x improvement. The benefit is consistent across all concurrency levels.
+
+---
+
+## Async Chunk and Streaming Output: Earlier Audio and Cross-Stage Overlap
+
+### Why this step matters for first-packet latency
+
+Two mechanisms work together to improve user-visible latency:
+
+- **Streaming output**: audio streaming emits audio chunks as soon as they are decoded (lower **TTFP**). Without streaming, the client waits for larger buffers or end-of-sequence.
+- **Async chunk** is the main enabler for *earlier* audio: instead of handing off whole-request results between stages, each stage forwards **chunks** so the next stage can start as soon as the first chunk is ready. For Omni: Thinker -> Talker forwards hidden-state chunks; for both: Talker -> Code2Wav forwards codec chunks; Code2Wav decodes and emits packets incrementally. This **overlaps compute and communication** across stages and directly reduces time-to-first-audio-packet (TTFP) and end-to-end latency (E2EL).
+
+So in practice: streaming output defines *how* bytes are sent to the client; async chunk defines *when* the pipeline can produce the first bytes.
+
+**Dependency between the two:** Async chunk and audio streaming output are mutually dependent. Without async chunk, **audio streaming output cannot truly take effect**. Without audio streaming output, async chunk's **TTFP advantage is not fully realized**: the client would still wait for larger buffers or end-of-sequence instead of hearing the first packet as soon as it is ready. We therefore recommend enabling **both** on top of batching + CUDA Graph; the benchmarks in this post use both.
+
+### Results: Batch + CUDA Graph vs. Batch + CUDA Graph + Async Chunk + Streaming Output
+
+**Qwen3-Omni** (A100):
+
+
+
+| Metric | Concurrency | Batch + CG | + Async Chunk | Improvement |
+| --- | --- | --- | --- | --- |
+| E2EL (ms) | 1 | 61,613 | 41,216 | 1.5× |
+| E2EL (ms) | 4 | 79,019 | 67,584 | 1.2× |
+| E2EL (ms) | 10 | 126,867 | 104,901 | 1.2× |
+| TTFP (ms) | 1 | 61,257 | 1,164 | 53× |
+| TTFP (ms) | 4 | 78,634 | 3,152 | 24.9× |
+| TTFP (ms) | 10 | 126,534 | 16,482 | 7.7× |
+| RTF | 1 | 0.32 | 0.21 | 1.5× |
+| RTF | 4 | 0.43 | 0.34 | 1.3× |
+| RTF | 10 | 0.90 | 0.74 | 1.2× |
+
+Enabling both brings TTFP down sharply (concurrency 1: 61,257 ms -> 1,164 ms, **~98% reduction**; concurrency 4: 78,634 ms -> 3,152 ms, **~96% reduction**). E2EL and RTF also improve at every concurrency.
+
+**Qwen3-TTS** (H200):
+
+
+
+| Metric | Concurrency | Batch + CG | + Async Chunk | Improvement |
+| --- | --- | --- | --- | --- |
+| TTFP (ms) | 1 | 733 | **64** | **11.5×** |
+| TTFP (ms) | 4 | 987 | **119** | **8.3×** |
+| TTFP (ms) | 10 | 1,197 | **425** | **2.8×** |
+| E2EL (ms) | 1 | 733 | 941 | 0.8× |
+| E2EL (ms) | 10 | 1,197 | 1,767 | 0.7× |
+| RTF | 1 | 0.124 | 0.160 | 0.8× |
+| RTF | 10 | 0.203 | 0.314 | 0.6× |
+
+The TTFP improvement is the headline result for both models. For Qwen3-TTS at concurrency 1, users hear the first audio in **64 ms** instead of 733 ms - an **11.5x reduction**. For Qwen3-Omni at concurrency 1, TTFP drops from 61 s to 1.2 s - a **53x reduction**.
+
+### Why E2EL and RTF are higher with async chunk (TTS)
+
+The table above shows that enabling async chunk + streaming *increases* E2EL and RTF for TTS compared to CUDA Graph alone. This is expected - the two configurations optimize for fundamentally different metrics:
+
+- **CUDA Graph (no async chunk)** generates the entire audio end-to-end before returning. No chunking overhead, so total compute is minimized.
+- **Async Chunk + Streaming** splits the pipeline into incremental chunks, adding overhead from chunked transport, context overlap in Code2Wav (`codec_left_context_frames=25`), and smaller effective batch sizes per chunk.
+
+**The tradeoff is intentional.** Async chunk trades ~30% higher total compute for **11x faster time-to-first-audio**. For interactive applications (voice assistants, chatbots), TTFP determines perceived responsiveness. For offline batch processing, CUDA Graph without async chunk is the better choice.
+
+---
+
+## TTS-Specific: Code Predictor Re-prefill + `torch.compile`
+
+Qwen3-TTS has a **code predictor** - a small 5-layer transformer that generates residual codebook tokens (groups 1 through Q-1) autoregressively. Each AR step operates on very short sequences (2 to ~16 tokens).
+
+The naive approach uses a KV cache for this small transformer, similar to the main Talker. But the KV cache machinery (block tables, slot mappings, paged attention) introduces significant overhead relative to the tiny model. Two optimizations replace that:
+
+### Re-prefill (stateless forward, no KV cache)
+
+Instead of maintaining a KV cache across steps, the code predictor **re-feeds the full growing sequence** at each AR step using `F.scaled_dot_product_attention`. With sequences of at most ~16 tokens through 5 layers, the O(T^2) attention cost is negligible - and removing the KV cache machinery (block table management, `set_forward_context`, slot mapping) saves far more time than it costs.
+
+### `torch.compile` on the code predictor forward
+
+The 5-layer transformer forward pass launches ~60 small CUDA kernels per step. `torch.compile(mode="default", dynamic=True)` fuses these into fewer kernels via Inductor:
+
+```python
+self._compiled_model_fwd = torch.compile(
+ self.model.forward,
+ mode="default", # no Inductor CUDA graphs, avoids conflict with vLLM's CUDAGraphWrapper
+ dynamic=True, # sequence length grows each step (2, 3, ..., num_groups+1)
+)
+```
+
+`mode="default"` is used instead of `mode="reduce-overhead"` to avoid conflicts with vLLM's own CUDA graph capture on the main Talker model. `dynamic=True` handles the growing sequence length without recompilation.
+
+These optimizations are always-on in the current codebase - all Qwen3-TTS benchmark results in this post include them.
+
+---
+
+## TTS-Specific: Dynamic Initial Chunk for Faster First Audio
+
+In the async chunk pipeline, the standard `codec_chunk_frames` is 25 (each chunk = ~2 seconds of audio at 12 Hz). Waiting for 25 frames before forwarding the first chunk to Code2Wav adds unnecessary TTFP. The **initial codec chunk** optimization sends a smaller first chunk so Code2Wav can start decoding earlier.
+
+**Dynamic initial chunk sizing (default behavior):**
+
+Rather than using a fixed initial chunk size, vLLM-Omni dynamically selects it based on current server load. The initial chunk size is chosen from power-of-2 steps [2, 4, 8, 16] based on load factor (`active_requests / max_batch_size`):
+
+| Server load | Initial chunk frames | Rationale |
+| --- | --- | --- |
+| Low (e.g. 1/10 active) | **2** (~167 ms of audio) | Minimize TTFP when there's headroom |
+| Medium (e.g. 5/10 active) | **4-8** | Balance TTFP vs decode efficiency |
+| High (e.g. 10/10 active) | **16** | Larger first chunk to amortize decode cost |
+
+After the initial chunk, all subsequent chunks use the standard `codec_chunk_frames` (25) size.
+
+**How it works in the pipeline:**
+
+1. Talker generates codec tokens auto-regressively
+2. The stage input processor checks current load and picks an initial chunk size (e.g. **2 frames** at low load)
+3. After that many frames, the first chunk is forwarded to Code2Wav
+4. Code2Wav decodes this small chunk and emits the first audio packet
+5. Subsequent chunks use the standard 25-frame size for efficient batch decoding
+
+**Per-request override:** Clients can also set a fixed initial chunk size via the API:
+
+```json
+{"initial_codec_chunk_frames": 2}
+```
+
+This overrides the dynamic calculation for that request.
+
+**Config (server-side):**
+
+```yaml
+runtime:
+ connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ codec_streaming: true
+ codec_chunk_frames: 25 # standard chunk size (~2s of audio)
+ codec_left_context_frames: 25
+ # initial chunk is computed dynamically by default
+ # set initial_codec_chunk_frames: 2 to force a fixed value
+```
+
+The 64 ms TTFP result reported above for Qwen3-TTS at concurrency 1 uses the dynamic initial chunk, which picks `initial_codec_chunk_frames=2` at low load. At higher concurrency the dynamic sizing increases the initial chunk to maintain decode efficiency.
+
+---
+
+## Live Demo: Streaming TTS over WebSocket
+
+vLLM-Omni supports real-time streaming audio output for Qwen3-TTS over WebSocket ([PR #1719](https://github.com/vllm-project/vllm-omni/pull/1719)). With `stream_audio: true`, the server sends chunked PCM audio frames as they are generated, so clients can start playback before full sentence synthesis completes.
+
+The WebSocket protocol uses `audio.start` / binary PCM chunks / `audio.done` framing per sentence:
+
+```json
+// Client sends:
+{"type":"session.config","voice":"Vivian","response_format":"pcm","stream_audio":true}
+{"type":"input.text","text":"Hello world. This is a streaming demo."}
+{"type":"input.done"}
+
+// Server streams back per sentence:
+{"type":"audio.start","sentence_index":0,"sentence_text":"Hello world.","format":"pcm","sample_rate":24000}
+
+
+...
+{"type":"audio.done","sentence_index":0,"total_bytes":96000,"error":false}
+{"type":"audio.start","sentence_index":1,"sentence_text":"This is a streaming demo.","format":"pcm","sample_rate":24000}
+
+...
+{"type":"audio.done","sentence_index":1,"total_bytes":72000,"error":false}
+{"type":"session.done","total_sentences":2}
+```
+
+VIDEO
+
+---
+
+## Deployment Playbook
+
+### Qwen3-Omni
+
+#### 1) Serve with the default 3-stage config
+
+```bash
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct \
+ --omni \
+ --port 8091
+```
+
+Notes:
+
+- `runtime.max_batch_size` controls stage-level batching.
+- Thinker/Talker commonly use `enforce_eager: false` for CUDA Graph paths.
+- Code2Wav often remains eager (`enforce_eager: true`) depending on runtime behavior.
+
+#### 2) Enable async chunk
+
+```bash
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct \
+ --omni \
+ --port 8091 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
+```
+
+#### 3) Key config knobs
+
+```yaml
+async_chunk: true
+stage_args:
+ - stage_id: 0 # thinker
+ runtime:
+ max_batch_size: 64
+ engine_args:
+ enforce_eager: false
+ max_num_batched_tokens: 32768
+ custom_process_next_stage_input_func: >-
+ vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk
+
+ - stage_id: 1 # talker
+ runtime:
+ max_batch_size: 64
+ engine_args:
+ enforce_eager: false
+ max_num_batched_tokens: 32768
+ custom_process_next_stage_input_func: >-
+ vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk
+
+ - stage_id: 2 # code2wav
+ runtime:
+ max_batch_size: 64
+ engine_args:
+ enforce_eager: true
+ max_num_batched_tokens: 51200
+```
+
+#### Reproduce Qwen3-Omni benchmarks
+
+```bash
+vllm bench serve \
+ --dataset-name random \
+ --port ${PORT} \
+ --model ${MODEL_PATH} \
+ --endpoint /v1/chat/completions \
+ --backend openai-chat-omni \
+ --max-concurrency ${MAX_CONCURRENCY} \
+ --num-prompts ${NUM_PROMPTS} \
+ --random-input-len 2500 \
+ --ignore-eos \
+ --percentile-metrics ttft,tpot,itl,e2el,audio_ttfp,audio_rtf \
+ --random-output-len 900 \
+ --extra_body '{"modalities": ["text","audio"]}'
+```
+
+### Qwen3-TTS
+
+#### 1) Serve with async chunk (recommended)
+
+```bash
+vllm-omni serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
+ --omni \
+ --port 8000
+```
+
+The default config (`qwen3_tts.yaml`) enables the full optimization stack:
+
+- Batching with `max_batch_size: 10` on the Talker stage
+- CUDA Graph on the Talker (`enforce_eager: false`)
+- Async chunk with streaming transport
+
+#### 2) Serve without async chunk (for comparison)
+
+```bash
+vllm-omni serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
+ --omni \
+ --port 8000 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml
+```
+
+#### 3) Key config knobs
+
+```yaml
+async_chunk: true
+stage_args:
+ - stage_id: 0 # Talker (AR decoder)
+ runtime:
+ max_batch_size: 10
+ engine_args:
+ enforce_eager: false
+ max_num_batched_tokens: 512
+ custom_process_next_stage_input_func: >-
+ vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
+
+ - stage_id: 1 # Code2Wav (vocoder)
+ runtime:
+ max_batch_size: 1
+ engine_args:
+ enforce_eager: true
+ max_num_batched_tokens: 8192
+
+runtime:
+ connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ codec_streaming: true
+ codec_chunk_frames: 25
+ codec_left_context_frames: 25
+```
+
+#### Reproduce Qwen3-TTS benchmarks
+
+```bash
+GPU_DEVICE=0 \
+MODEL=Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
+NUM_PROMPTS=50 \
+CONCURRENCY="1 4 10" \
+bash benchmarks/qwen3-tts/vllm_omni/run_stacked_benchmark.sh
+```
+
+This cycles through four configs (Baseline -> + Batch -> + CUDA Graph -> + Async Chunk + Streaming), benchmarks each at the specified concurrency levels, and generates all comparison figures automatically.
diff --git a/docs/serving/speech_api.md b/docs/serving/speech_api.md
index ecbe8d9ac98..733811081a7 100644
--- a/docs/serving/speech_api.md
+++ b/docs/serving/speech_api.md
@@ -15,7 +15,7 @@ Each server instance runs a single model (specified at startup via `vllm serve <
```bash
# Qwen3-TTS: CustomVoice model (predefined speakers)
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -300,7 +300,7 @@ curl -X POST http://localhost:8091/v1/audio/speech \
```bash
# Start server with VoiceDesign model first
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -322,7 +322,7 @@ curl -X POST http://localhost:8091/v1/audio/speech \
```bash
# Start server with Base model first
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-Base \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -517,15 +517,16 @@ for result in response.json()["results"]:
All items are fanned out to `generate()` concurrently. The engine's stage worker automatically batches them up to the configured `max_batch_size` and queues the rest — no client-side throttling needed.
-For best throughput, use a batch-optimized stage config with `max_batch_size > 1`:
+For best throughput, set both stages' `max_num_seqs` to ≥4 via `--stage-overrides`:
```bash
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml \
- --omni --port 8091 --trust-remote-code --enforce-eager
+ --omni --port 8091 --trust-remote-code --enforce-eager \
+ --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2},
+ "1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}'
```
-The default `qwen3_tts.yaml` uses `max_batch_size: 1` (single request). The `qwen3_tts_batch.yaml` config sets `max_batch_size: 4` for ~4x throughput.
+The bundled `qwen3_tts.yaml` uses `max_num_seqs: 1` (single request) on both stages. Bumping to 4 yields roughly 4× throughput on the talker and lets stage 1 batch chunks across in-flight requests.
## Supported Models
@@ -617,7 +618,7 @@ Enable debug logging:
```bash
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
diff --git a/docs/source/architecture/async-chunk-architecture.png b/docs/source/architecture/async-chunk-architecture.png
index 249de53bfe3..7b3e95e4df9 100644
Binary files a/docs/source/architecture/async-chunk-architecture.png and b/docs/source/architecture/async-chunk-architecture.png differ
diff --git a/docs/source/architecture/vllm-omni-dataflow-between-stages.png b/docs/source/architecture/vllm-omni-dataflow-between-stages.png
index cdbc9a8b7b3..74abc81ff07 100644
Binary files a/docs/source/architecture/vllm-omni-dataflow-between-stages.png and b/docs/source/architecture/vllm-omni-dataflow-between-stages.png differ
diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md
index 4e7003cce37..7bdeede446a 100644
--- a/docs/user_guide/diffusion_features.md
+++ b/docs/user_guide/diffusion_features.md
@@ -115,8 +115,8 @@ The following tables show which models support each feature:
| **FLUX.2-dev** | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **GLM-Image** | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| **HunyuanImage3** | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
-| **LongCat-Image** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
-| **LongCat-Image-Edit** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
+| **LongCat-Image** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
+| **LongCat-Image-Edit** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| **MagiHuman** | ❌ | ❌ | ❌ | ❓ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| **MammothModa2(T2I)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Nextstep_1(T2I)** | ❓ | ❓ | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
@@ -140,10 +140,10 @@ The following tables show which models support each feature:
|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:|
| **Wan2.2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (encode/decode) | ❌ | ❌ |
| **Wan2.1-VACE** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ |
-| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
| **Helios** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ |
-| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
+| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
**Frame Interpolation Support**
diff --git a/docs/user_guide/examples/offline_inference/bagel.md b/docs/user_guide/examples/offline_inference/bagel.md
index e6266868722..1fb4d404578 100644
--- a/docs/user_guide/examples/offline_inference/bagel.md
+++ b/docs/user_guide/examples/offline_inference/bagel.md
@@ -176,8 +176,6 @@ Example configuration for TP=2 on GPUs 0 and 1:
| Parameter | Value | Description |
| :-------------------- | :------ | :------------------------------- |
-| `window_size` | `-1` | Window size (-1 means unlimited) |
-| `max_inflight` | `1` | Maximum inflight requests |
| `shm_threshold_bytes` | `65536` | Shared memory threshold (64KB) |
## Using Mooncake Connector
diff --git a/docs/user_guide/examples/offline_inference/qwen3_tts.md b/docs/user_guide/examples/offline_inference/qwen3_tts.md
index 4ece5219d7f..7226ac1fe4b 100644
--- a/docs/user_guide/examples/offline_inference/qwen3_tts.md
+++ b/docs/user_guide/examples/offline_inference/qwen3_tts.md
@@ -144,13 +144,13 @@ completes. This demonstrates that audio data is available progressively rather t
## Batched Decoding
-The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, provide a stage config with `max_num_seqs > 1` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`.
+The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, set `max_num_seqs > 1` on both stages via `--stage-overrides` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`.
```
python end2end.py --query-type CustomVoice \
--txt-prompts benchmark_prompts.txt \
--batch-size 4 \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
+ --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2},"1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}'
```
**Important:** `--batch-size` must match a CUDA graph capture size (1, 2, 4, 8, 16...) because the Talker's code predictor KV cache is sized to `max_num_seqs`, and CUDA graphs pad the batch to the next capture size. Both stages need `max_num_seqs >= batch_size` in the stage config for batching to take effect. If only stage 1 has a higher `max_num_seqs`, it won't help — stage 1 can only batch chunks from requests that are in-flight simultaneously, which requires stage 0 to also process multiple requests concurrently.
diff --git a/docs/user_guide/examples/online_serving/qwen3_omni.md b/docs/user_guide/examples/online_serving/qwen3_omni.md
index 6f6d9ae4a9d..611eb6fd3fc 100644
--- a/docs/user_guide/examples/online_serving/qwen3_omni.md
+++ b/docs/user_guide/examples/online_serving/qwen3_omni.md
@@ -18,12 +18,12 @@ vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
If you want to open async chunking for qwen3-omni, launch the server with command below
```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --deploy-config /vllm_omni/deploy/qwen3_omni_moe.yaml
```
If you have custom stage configs file, launch the server with command below
```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --deploy-config /path/to/deploy_config_file
```
### Send Multi-modal Request
@@ -187,7 +187,7 @@ The script supports the following arguments:
- `--model`: Model name/path (default: Qwen/Qwen3-Omni-30B-A3B-Instruct)
- `--server-port`: Port for vLLM server (default: 8091)
- `--gradio-port`: Port for Gradio demo (default: 7861)
-- `--stage-configs-path`: Path to custom stage configs YAML file (optional)
+- `--deploy-config`: Path to custom deploy config YAML file (optional)
- `--server-host`: Host for vLLM server (default: 0.0.0.0)
- `--gradio-ip`: IP for Gradio demo (default: 127.0.0.1)
- `--share`: Share Gradio demo publicly (creates a public link)
@@ -202,7 +202,7 @@ vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
If you have custom stage configs file:
```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --deploy-config /path/to/deploy_config_file
```
**Step 2: Run the Gradio demo**
diff --git a/docs/user_guide/examples/online_serving/qwen3_tts.md b/docs/user_guide/examples/online_serving/qwen3_tts.md
index 4e632d4c288..95f234f02de 100644
--- a/docs/user_guide/examples/online_serving/qwen3_tts.md
+++ b/docs/user_guide/examples/online_serving/qwen3_tts.md
@@ -58,7 +58,7 @@ Then open http://localhost:7860 in your browser.
```bash
# CustomVoice model (predefined speakers)
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -66,7 +66,7 @@ vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
# VoiceDesign model
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
@@ -74,7 +74,7 @@ vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign \
# Base model (voice cloning)
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-Base \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--omni \
--port 8091 \
--trust-remote-code \
diff --git a/examples/offline_inference/bagel/README.md b/examples/offline_inference/bagel/README.md
index 48517b1cda0..3e653d0e3ab 100644
--- a/examples/offline_inference/bagel/README.md
+++ b/examples/offline_inference/bagel/README.md
@@ -173,8 +173,6 @@ Example configuration for TP=2 on GPUs 0 and 1:
| Parameter | Value | Description |
| :-------------------- | :------ | :------------------------------- |
-| `window_size` | `-1` | Window size (-1 means unlimited) |
-| `max_inflight` | `1` | Maximum inflight requests |
| `shm_threshold_bytes` | `65536` | Shared memory threshold (64KB) |
## Using Mooncake Connector
diff --git a/examples/offline_inference/ming_flash_omni/README.md b/examples/offline_inference/ming_flash_omni/README.md
new file mode 100644
index 00000000000..7414163fc01
--- /dev/null
+++ b/examples/offline_inference/ming_flash_omni/README.md
@@ -0,0 +1,76 @@
+# Ming-flash-omni 2.0
+
+[Ming-flash-omni-2.0](https://github.com/inclusionAI/Ming) is an omni-modal model supporting text, image, video, and audio understanding, with outputs in text, image, and audio. For now, Ming-flash-omni-2.0 in vLLM-Omni is supported with thinker stage (multi-modal understanding).
+
+## Setup
+
+Please refer to the [stage configuration documentation](https://docs.vllm.ai/projects/vllm-omni/en/latest/configuration/stage_configs/) to configure memory allocation appropriately for your hardware setup.
+
+## Run examples
+
+### Text-only
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type text
+```
+
+#### Reasoning (Thinking Mode)
+
+Reasoning (Thinking) mode is enabled via applying "detailed thinking on" when building the system prompt template (in `apply_chat_template`).
+
+In the end2end example, a default problem for thinking mode is provided, as referred to the example usage of Ming's cookbook;
+To utilize it, you have to download the example figure from https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/figures/cases/3_0.png
+
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py -q reasoning --image-path ./3_0.png
+```
+
+### Image understanding
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image
+
+# With a local image
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image --image-path /path/to/image.jpg
+```
+
+### Audio understanding
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio
+
+# With a local audio file
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio --audio-path /path/to/audio.wav
+```
+
+### Video understanding
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_video
+
+# With a local video and custom frame count
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_video --video-path /path/to/video.mp4 --num-frames 16
+```
+
+### Mixed modalities (image + audio)
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_mixed_modalities \
+ --image-path /path/to/image.jpg \
+ --audio-path /path/to/audio.wav
+```
+
+If media file paths are not provided, the script uses built-in default assets.
+
+### Modality control
+To control output modalities (e.g. text-only output):
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_audio --modalities text
+```
+
+*For now, only text output is supported*
+
+### Custom stage config
+```bash
+python examples/offline_inference/ming_flash_omni/end2end.py --query-type use_image \
+ --stage-configs-path /path/to/your_config.yaml
+```
+
+## Online serving
+
+For online serving via the OpenAI-compatible API, see [examples/online_serving/ming_flash_omni/README.md](../../online_serving/ming_flash_omni/README.md).
diff --git a/examples/offline_inference/ming_flash_omni/end2end.py b/examples/offline_inference/ming_flash_omni/end2end.py
new file mode 100644
index 00000000000..49cdbcc0186
--- /dev/null
+++ b/examples/offline_inference/ming_flash_omni/end2end.py
@@ -0,0 +1,485 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Partial example cases are referred from
+# https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/cookbook.ipynb
+import os
+import time
+from typing import NamedTuple
+
+import librosa
+import numpy as np
+import vllm
+from PIL import Image
+from transformers import AutoProcessor
+from vllm import SamplingParams
+from vllm.assets.audio import AudioAsset
+from vllm.assets.image import ImageAsset
+from vllm.assets.video import VideoAsset, video_to_ndarrays
+from vllm.multimodal.image import convert_image_mode
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+import vllm_omni
+from vllm_omni.entrypoints.omni import Omni
+
+# Imports the processor also registers itself
+from vllm_omni.transformers_utils.processors.ming import MingFlashOmniProcessor # noqa: F401
+
+SEED = 42
+MODEL_NAME = "Jonathan1909/Ming-flash-omni-2.0"
+
+
+class QueryResult(NamedTuple):
+ inputs: dict
+ limit_mm_per_prompt: dict[str, int]
+
+
+def get_text_query(processor: MingFlashOmniProcessor, question: str | None = None) -> QueryResult:
+ if question is None:
+ question = "请详细介绍鹦鹉的生活习性。"
+ conversation = [{"role": "HUMAN", "content": question}]
+ prompt = processor.apply_chat_template(conversation, tokenize=False)
+ return QueryResult(
+ inputs={"prompt": prompt},
+ limit_mm_per_prompt={},
+ )
+
+
+def get_image_query(
+ processor: MingFlashOmniProcessor,
+ question: str | None = None,
+ image_path: str | None = None,
+) -> QueryResult:
+ if question is None:
+ question = "Describe this image in detail."
+
+ if image_path:
+ if not os.path.exists(image_path):
+ raise FileNotFoundError(f"Image file not found: {image_path}")
+ image_data = convert_image_mode(Image.open(image_path), "RGB")
+ else:
+ image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
+
+ conversation = [
+ {
+ "role": "HUMAN",
+ "content": [
+ {"type": "image", "image": image_data},
+ {"type": "text", "text": question},
+ ],
+ }
+ ]
+ prompt = processor.apply_chat_template(conversation, tokenize=False)
+
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {"image": image_data},
+ },
+ limit_mm_per_prompt={"image": 1},
+ )
+
+
+def get_audio_query(
+ processor: MingFlashOmniProcessor,
+ question: str | None = None,
+ audio_path: str | None = None,
+ sampling_rate: int = 16000,
+) -> QueryResult:
+ if question is None:
+ question = "Please recognize the language of this speech and transcribe it. Format: oral."
+
+ if audio_path:
+ if not os.path.exists(audio_path):
+ raise FileNotFoundError(f"Audio file not found: {audio_path}")
+ audio_signal, sr = librosa.load(audio_path, sr=sampling_rate)
+ audio_data = (audio_signal.astype(np.float32), sr)
+ else:
+ audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
+
+ # Use a string for "audio" so the processor counts it as 1 audio input
+ conversation = [
+ {
+ "role": "HUMAN",
+ "content": [
+ {"type": "audio", "audio": "input"},
+ {"type": "text", "text": question},
+ ],
+ }
+ ]
+ prompt = processor.apply_chat_template(conversation, tokenize=False)
+
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {"audio": audio_data},
+ },
+ limit_mm_per_prompt={"audio": 1},
+ )
+
+
+def get_video_query(
+ processor: MingFlashOmniProcessor,
+ question: str | None = None,
+ video_path: str | None = None,
+ num_frames: int = 16,
+) -> QueryResult:
+ if question is None:
+ question = "Describe what is happening in this video."
+
+ if video_path:
+ if not os.path.exists(video_path):
+ raise FileNotFoundError(f"Video file not found: {video_path}")
+ video_frames = video_to_ndarrays(video_path, num_frames=num_frames)
+ else:
+ video_frames = VideoAsset(name="baby_reading", num_frames=num_frames).np_ndarrays
+
+ conversation = [
+ {
+ "role": "HUMAN",
+ "content": [
+ {"type": "video"},
+ {"type": "text", "text": question},
+ ],
+ }
+ ]
+ prompt = processor.apply_chat_template(conversation, tokenize=False)
+
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {"video": video_frames},
+ },
+ limit_mm_per_prompt={"video": 1},
+ )
+
+
+def get_mixed_modalities_query(
+ processor: MingFlashOmniProcessor,
+ image_path: str | None = None,
+ audio_path: str | None = None,
+ sampling_rate: int = 16000,
+) -> QueryResult:
+ """Mixed image + audio understanding."""
+ question = "Describe the image, and recognize the language of this speech and transcribe it. Format: oral"
+
+ if image_path:
+ if not os.path.exists(image_path):
+ raise FileNotFoundError(f"Image file not found: {image_path}")
+ image_data = convert_image_mode(Image.open(image_path), "RGB")
+ else:
+ image_data = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
+
+ if audio_path:
+ if not os.path.exists(audio_path):
+ raise FileNotFoundError(f"Audio file not found: {audio_path}")
+ sig, sr = librosa.load(audio_path, sr=sampling_rate)
+ audio_data = (sig.astype(np.float32), sr)
+ else:
+ audio_data = AudioAsset("mary_had_lamb").audio_and_sample_rate
+
+ conversation = [
+ {
+ "role": "HUMAN",
+ "content": [
+ {"type": "image", "image": image_data},
+ {"type": "audio", "audio": "input"},
+ {"type": "text", "text": question},
+ ],
+ }
+ ]
+ prompt = processor.apply_chat_template(conversation, tokenize=False)
+
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {"image": image_data, "audio": audio_data},
+ },
+ limit_mm_per_prompt={"image": 1, "audio": 1},
+ )
+
+
+def get_reasoning_query(
+ processor: MingFlashOmniProcessor,
+ question: str | None = None,
+ image_path: str | None = None,
+) -> QueryResult:
+ if question is None:
+ # NOTE: To use the following default question, input with example figure provided by Ming
+ # https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/figures/cases/3_0.png
+ # E.g.,
+ # python examples/offline_inference/ming_flash_omni/end2end.py -q reasoning --image-path ./3_0.png
+ # Otherwise, the problem solving might be false.
+ question = (
+ "Based on the following rules:\n•\tYou control the smiley face character\n"
+ "•\tYou can move up, down, left, and right, and only a single square at a time\n"
+ "•\tWalls are dark grey and cannot be moved into\n•\tThe brown square is a box\n•"
+ "\tThe box can be pushed by moving into it (i.e., if you are in the square "
+ "adjacent to the box to the left, and move onto the square with the box, "
+ "the box will move one square to the right).\n"
+ "•\tThe box cannot be pushed into walls\n"
+ "•\tThe blue door at the bottom is locked and cannot be passed through, "
+ "unless the box is placed on the blue square\n"
+ "•\tThe square beneath the blue door is the exit\n"
+ "•\tMoving from one square to another\n\n"
+ "Let's assume a coordinate system where the smiley face is "
+ "on the top left at (1,1) and the square below it is (1,2). "
+ "The smiley face performs the following moves: {down, right, right, right}, "
+ "such that the smiley face is at square (4,2) and the box is in square (5,2). "
+ "What are the next sequence of moves that must be done to move the box down to (5,3)? "
+ "Give your answer as a comma separated list."
+ )
+
+ if image_path:
+ if not os.path.exists(image_path):
+ raise FileNotFoundError(f"Image file not found: {image_path}")
+ image_data = convert_image_mode(Image.open(image_path), "RGB")
+ conversation = [
+ {
+ "role": "HUMAN",
+ "content": [
+ {"type": "image", "image": image_data},
+ {"type": "text", "text": question},
+ ],
+ }
+ ]
+ prompt = processor.apply_chat_template(conversation, tokenize=False, use_cot_system_prompt=True)
+ return QueryResult(
+ inputs={
+ "prompt": prompt,
+ "multi_modal_data": {"image": image_data},
+ },
+ limit_mm_per_prompt={"image": 1},
+ )
+
+ conversation = [{"role": "HUMAN", "content": question}]
+ prompt = processor.apply_chat_template(conversation, tokenize=False, use_cot_system_prompt=True)
+ return QueryResult(
+ inputs={"prompt": prompt},
+ limit_mm_per_prompt={},
+ )
+
+
+query_map = {
+ "text": get_text_query,
+ "use_audio": get_audio_query,
+ "use_image": get_image_query,
+ "use_video": get_video_query,
+ "use_mixed_modalities": get_mixed_modalities_query,
+ "reasoning": get_reasoning_query,
+}
+
+
+def main(args):
+ print(
+ "=" * 20,
+ "\n",
+ f"vllm version: {vllm.__version__}\n",
+ f"vllm-omni version: {vllm_omni.__version__}\n",
+ "=" * 20,
+ sep="",
+ )
+
+ processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
+ assert isinstance(processor, MingFlashOmniProcessor), f"Wrong processor type being used: {type(processor)}"
+
+ query_func = query_map[args.query_type]
+ if args.query_type == "use_image":
+ query_result = query_func(processor, image_path=args.image_path)
+ elif args.query_type == "use_audio":
+ query_result = query_func(processor, audio_path=args.audio_path, sampling_rate=args.sampling_rate)
+ elif args.query_type == "use_video":
+ query_result = query_func(processor, video_path=args.video_path, num_frames=args.num_frames)
+ elif args.query_type == "use_mixed_modalities":
+ query_result = query_func(
+ processor,
+ image_path=args.image_path,
+ audio_path=args.audio_path,
+ sampling_rate=args.sampling_rate,
+ )
+ elif args.query_type == "reasoning":
+ query_result = query_func(processor, image_path=args.image_path)
+ else:
+ query_result = query_func(processor)
+
+ # Initialize Omni (with thinker-only stage config)
+ omni = Omni(
+ model=MODEL_NAME,
+ stage_configs_path=args.stage_configs_path,
+ log_stats=args.log_stats,
+ init_timeout=args.init_timeout,
+ stage_init_timeout=args.stage_init_timeout,
+ )
+
+ # Thinker sampling params
+ thinker_sampling_params = SamplingParams(
+ temperature=0.4,
+ top_p=0.9,
+ max_tokens=args.max_tokens,
+ repetition_penalty=1.05,
+ seed=SEED,
+ detokenize=True,
+ )
+ sampling_params_list = [thinker_sampling_params]
+
+ prompts = [query_result.inputs for _ in range(args.num_prompts)]
+
+ if args.modalities is not None:
+ output_modalities = args.modalities.split(",")
+ for prompt in prompts:
+ prompt["modalities"] = output_modalities
+
+ total_requests = len(prompts)
+ processed_count = 0
+ print(f"Query type: {args.query_type}")
+ print(f"Number of prompts: {total_requests}")
+
+ output_dir = args.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+
+ profiler_enabled = args.enable_profiler
+ if profiler_enabled:
+ omni.start_profile(stages=args.profiler_stages)
+
+ for stage_outputs in omni.generate(prompts, sampling_params_list):
+ output = stage_outputs.request_output
+ if stage_outputs.final_output_type == "text":
+ request_id = output.request_id
+ text_output = output.outputs[0].text
+ lines = []
+ lines.append("Prompt:\n")
+ lines.append(str(output.prompt) + "\n")
+ lines.append("Text Output:\n")
+ lines.append(str(text_output).strip() + "\n")
+ print(*lines, sep="")
+
+ # Save to file
+ out_txt = os.path.join(output_dir, f"{request_id}.txt")
+ try:
+ with open(out_txt, "w", encoding="utf-8") as f:
+ f.writelines(lines)
+ print(f"Request ID: {request_id}, text saved to {out_txt}")
+ except Exception as e:
+ print(f"Failed to write output file {out_txt}: {e}")
+
+ elif stage_outputs.final_output_type == "audio":
+ raise NotImplementedError("Add audio example after talker supported.")
+
+ processed_count += 1
+ if profiler_enabled and processed_count >= total_requests:
+ print(f"[Info] Processed {processed_count}/{total_requests}. Stopping profiler inside active loop...")
+ # Stop the profiler while workers are still alive
+ omni.stop_profile(stages=args.profiler_stages)
+
+ print("[Info] Waiting 30s for workers to write trace files to disk...")
+ time.sleep(30)
+ print("[Info] Trace export wait time finished.")
+
+ omni.close()
+
+
+def parse_args():
+ parser = FlexibleArgumentParser(description="Ming-flash-omni 2.0 offline inference example")
+ parser.add_argument(
+ "--query-type",
+ "-q",
+ type=str,
+ default="text",
+ choices=query_map.keys(),
+ help="Query type.",
+ )
+ parser.add_argument(
+ "--stage-configs-path",
+ type=str,
+ default=None,
+ help="Path to a stage configs YAML file.",
+ )
+ parser.add_argument(
+ "--log-stats",
+ action="store_true",
+ default=False,
+ help="Enable detailed statistics logging.",
+ )
+ parser.add_argument("--init-timeout", type=int, default=2000, help="Timeout for initializing in seconds.")
+ parser.add_argument(
+ "--stage-init-timeout",
+ type=int,
+ default=2000,
+ help="Timeout for initializing a single stage in seconds.",
+ )
+ parser.add_argument(
+ "--enable-profiler",
+ action="store_true",
+ default=False,
+ help="Enables profiling when set.",
+ )
+ parser.add_argument(
+ "--profiler-stages",
+ type=int,
+ nargs="*",
+ default=[0],
+ help="List of stage IDs to profile. If not set, profiles all stages.",
+ )
+ parser.add_argument(
+ "--image-path",
+ "-i",
+ type=str,
+ default=None,
+ help="Path to local image file. Uses default asset if not provided.",
+ )
+ parser.add_argument(
+ "--audio-path",
+ "-a",
+ type=str,
+ default=None,
+ help="Path to local audio file. Uses default asset if not provided.",
+ )
+ parser.add_argument(
+ "--video-path",
+ "-v",
+ type=str,
+ default=None,
+ help="Path to local video file. Uses default asset if not provided.",
+ )
+ parser.add_argument(
+ "--num-frames",
+ type=int,
+ default=16,
+ help="Number of frames to extract from video.",
+ )
+ parser.add_argument(
+ "--sampling-rate",
+ type=int,
+ default=16000,
+ help="Sampling rate for audio loading.",
+ )
+ parser.add_argument(
+ "--max-tokens",
+ type=int,
+ default=16384,
+ help="Maximum tokens to generate.",
+ )
+ parser.add_argument(
+ "--num-prompts",
+ type=int,
+ default=1,
+ help="Number of prompts to generate.",
+ )
+ parser.add_argument(
+ "--modalities",
+ type=str,
+ default=None,
+ help="Output modalities (comma-separated).",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="output_ming",
+ help="Output directory for results.",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/offline_inference/qwen2_5_omni/end2end.py b/examples/offline_inference/qwen2_5_omni/end2end.py
index d8f1898ec91..dfe124700de 100644
--- a/examples/offline_inference/qwen2_5_omni/end2end.py
+++ b/examples/offline_inference/qwen2_5_omni/end2end.py
@@ -320,14 +320,7 @@ def main(args):
query_result = query_func(audio_path=audio_path, sampling_rate=sampling_rate)
else:
query_result = query_func()
- omni = Omni(
- model=model_name,
- log_stats=args.log_stats,
- stage_init_timeout=args.stage_init_timeout,
- batch_timeout=args.batch_timeout,
- init_timeout=args.init_timeout,
- shm_threshold_bytes=args.shm_threshold_bytes,
- )
+ omni = Omni.from_cli_args(args, model=model_name)
thinker_sampling_params = SamplingParams(
temperature=0.0, # Deterministic - no randomness
top_p=1.0, # Disable nucleus sampling
diff --git a/examples/offline_inference/qwen3_omni/README.md b/examples/offline_inference/qwen3_omni/README.md
index d69ad6abfc9..0710faa133c 100644
--- a/examples/offline_inference/qwen3_omni/README.md
+++ b/examples/offline_inference/qwen3_omni/README.md
@@ -70,8 +70,8 @@ For true stage-level concurrency -- where downstream stages (Talker, Code2Wav)
start **before** the upstream stage (Thinker) finishes -- use the async_chunk
example. This requires:
-1. A stage config YAML with ``async_chunk: true`` (e.g.
- ``qwen3_omni_moe_async_chunk.yaml``).
+1. A deploy config YAML with ``async_chunk: true`` (e.g.
+ ``qwen3_omni_moe.yaml``).
2. Hardware that matches the config (e.g. 2x H100 for the default 3-stage
config).
@@ -101,7 +101,7 @@ python end2end_async_chunk.py --query-type text --modalities text
```bash
python end2end_async_chunk.py \
--query-type use_audio \
- --stage-configs-path /path/to/your_async_chunk.yaml
+ --deploy-config /path/to/your_deploy_config.yaml
```
> **Note**: The synchronous ``end2end.py`` (using ``Omni``) is still the
diff --git a/examples/offline_inference/qwen3_omni/end2end.py b/examples/offline_inference/qwen3_omni/end2end.py
index 056f820ff07..f028c32aa1b 100644
--- a/examples/offline_inference/qwen3_omni/end2end.py
+++ b/examples/offline_inference/qwen3_omni/end2end.py
@@ -294,14 +294,7 @@ def main(args):
else:
query_result = query_func()
- omni = Omni(
- model=model_name,
- dtype=args.dtype,
- stage_configs_path=args.stage_configs_path,
- log_stats=args.log_stats,
- stage_init_timeout=args.stage_init_timeout,
- init_timeout=args.init_timeout,
- )
+ omni = Omni.from_cli_args(args, model=model_name)
thinker_sampling_params = SamplingParams(
temperature=0.9,
diff --git a/examples/offline_inference/qwen3_omni/end2end_async_chunk.py b/examples/offline_inference/qwen3_omni/end2end_async_chunk.py
index 07442631302..f38922e9437 100644
--- a/examples/offline_inference/qwen3_omni/end2end_async_chunk.py
+++ b/examples/offline_inference/qwen3_omni/end2end_async_chunk.py
@@ -14,7 +14,7 @@
Usage
-----
python end2end_async_chunk.py --query-type use_audio \
- --stage-configs-path
+ --deploy-config
See ``--help`` for all options.
"""
@@ -179,20 +179,26 @@ def clone_prompt_for_request(template: dict) -> dict:
return cloned
-def _default_async_chunk_stage_configs_path() -> str | None:
- """Best-effort default stage config for running Qwen3-Omni with async_chunk.
+def _default_deploy_config_path() -> str | None:
+ """Best-effort default deploy config for running Qwen3-Omni with async_chunk.
- When this example is executed from within the repository, we resolve the
- default YAML path relative to this file. When installed elsewhere, the
- file may not exist and callers should pass --stage-configs-path explicitly.
+ The default ``vllm_omni/deploy/qwen3_omni_moe.yaml`` ships with
+ ``async_chunk: true`` at the top level, so loading it is enough to
+ enable async-chunk semantics. To disable it, copy the YAML and set
+ ``async_chunk: false`` (or pass ``--deploy-config`` to a YAML that
+ overrides the flag).
+
+ When this example is executed from within the repository, we resolve
+ the default YAML path relative to this file. When installed elsewhere,
+ the file may not exist and callers should pass ``--deploy-config``
+ explicitly.
"""
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
candidate = os.path.join(
repo_root,
"vllm_omni",
- "model_executor",
- "stage_configs",
- "qwen3_omni_moe_async_chunk.yaml",
+ "deploy",
+ "qwen3_omni_moe.yaml",
)
return candidate if os.path.exists(candidate) else None
@@ -374,15 +380,16 @@ async def run_all(args):
prompt["modalities"] = output_modalities
# Create AsyncOmni
- print(f"[Info] Creating AsyncOmni with stage_configs_path={args.stage_configs_path}")
+ print(f"[Info] Creating AsyncOmni with deploy_config={args.deploy_config}")
async_omni = None
try:
- async_omni = AsyncOmni(
- model=args.model,
- stage_configs_path=args.stage_configs_path,
- log_stats=args.log_stats,
- stage_init_timeout=args.stage_init_timeout,
- )
+ # ``from_cli_args`` expands vars(args) into kwargs and auto-captures
+ # ``_cli_explicit_keys`` from ``sys.argv[1:]`` so argparse defaults
+ # do not silently override deploy YAML values. Mirrors the
+ # ``EngineArgs.from_cli_args`` pattern used throughout vllm /
+ # vllm-omni. ``deploy_config=None`` (the default) falls through to
+ # the bundled ``vllm_omni/deploy/qwen3_omni_moe.yaml``.
+ async_omni = AsyncOmni.from_cli_args(args)
# Use default sampling params from stage config (they are pre-configured
# in the YAML for each stage).
@@ -470,11 +477,11 @@ def parse_args():
help="Query type.",
)
parser.add_argument(
- "--stage-configs-path",
+ "--deploy-config",
type=str,
- default=_default_async_chunk_stage_configs_path(),
+ default=_default_deploy_config_path(),
help=(
- "Path to an async_chunk stage config YAML. "
+ "Path to a deploy config YAML. "
"If not set, uses the model's default config "
"(make sure it has async_chunk: true)."
),
diff --git a/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh b/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh
index 809054867c3..2f2be20915a 100755
--- a/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh
+++ b/examples/offline_inference/qwen3_omni/run_multiple_prompts_async_chunk.sh
@@ -17,7 +17,7 @@ REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
python "${SCRIPT_DIR}/end2end_async_chunk.py" \
--query-type text \
--txt-prompts "${SCRIPT_DIR}/text_prompts_10.txt" \
- --stage-configs-path "${REPO_ROOT}/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml" \
+ --deploy-config "${REPO_ROOT}/vllm_omni/deploy/qwen3_omni_moe.yaml" \
--output-dir output_audio_async_chunk \
--max-in-flight 2 \
"$@"
diff --git a/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh b/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh
index 918c7ee4fd9..9ef69293cb5 100755
--- a/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh
+++ b/examples/offline_inference/qwen3_omni/run_single_prompt_async_chunk.sh
@@ -6,13 +6,13 @@
# achieving true stage-level concurrency via chunk-level streaming.
#
# Prerequisites:
-# - An async_chunk stage config YAML (e.g. qwen3_omni_moe_async_chunk.yaml)
+# - A deploy config YAML (e.g. qwen3_omni_moe.yaml)
# - Hardware matching the config (e.g. 2x H100 for the default 3-stage config)
#
# Usage:
# bash run_single_prompt_async_chunk.sh
# bash run_single_prompt_async_chunk.sh --query-type text --modalities text
-# bash run_single_prompt_async_chunk.sh --stage-configs-path /path/to/custom.yaml
+# bash run_single_prompt_async_chunk.sh --deploy-config /path/to/custom.yaml
set -euo pipefail
@@ -21,6 +21,6 @@ REPO_ROOT="$(cd "${SCRIPT_DIR}/../../.." && pwd)"
python "${SCRIPT_DIR}/end2end_async_chunk.py" \
--query-type use_audio \
- --stage-configs-path "${REPO_ROOT}/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml" \
+ --deploy-config "${REPO_ROOT}/vllm_omni/deploy/qwen3_omni_moe.yaml" \
--output-dir output_audio_async_chunk \
"$@"
diff --git a/examples/offline_inference/qwen3_tts/README.md b/examples/offline_inference/qwen3_tts/README.md
index c38a2b462d1..2971ad716a2 100644
--- a/examples/offline_inference/qwen3_tts/README.md
+++ b/examples/offline_inference/qwen3_tts/README.md
@@ -104,13 +104,13 @@ completes. This demonstrates that audio data is available progressively rather t
## Batched Decoding
-The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, provide a stage config with `max_num_seqs > 1` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`.
+The Code2Wav stage (stage 1) supports batched decoding, where multiple requests are decoded in a single forward pass through the SpeechTokenizer. To use it, set `max_num_seqs > 1` on both stages via `--stage-overrides` and pass multiple prompts via `--txt-prompts` with a matching `--batch-size`.
```
python end2end.py --query-type CustomVoice \
--txt-prompts benchmark_prompts.txt \
--batch-size 4 \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
+ --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2},"1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}'
```
**Important:** `--batch-size` must match a CUDA graph capture size (1, 2, 4, 8, 16...) because the Talker's code predictor KV cache is sized to `max_num_seqs`, and CUDA graphs pad the batch to the next capture size. Both stages need `max_num_seqs >= batch_size` in the stage config for batching to take effect. If only stage 1 has a higher `max_num_seqs`, it won't help — stage 1 can only batch chunks from requests that are in-flight simultaneously, which requires stage 0 to also process multiple requests concurrently.
diff --git a/examples/offline_inference/qwen3_tts/end2end.py b/examples/offline_inference/qwen3_tts/end2end.py
index 901418c39b8..77da356b4f8 100644
--- a/examples/offline_inference/qwen3_tts/end2end.py
+++ b/examples/offline_inference/qwen3_tts/end2end.py
@@ -366,12 +366,7 @@ def main(args):
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
- omni = Omni(
- model=model_name,
- stage_configs_path=args.stage_configs_path,
- log_stats=args.log_stats,
- stage_init_timeout=args.stage_init_timeout,
- )
+ omni = Omni.from_cli_args(args, model=model_name)
batch_size = args.batch_size
for batch_start in range(0, len(inputs), batch_size):
@@ -387,12 +382,7 @@ async def main_streaming(args):
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
- omni = AsyncOmni(
- model=model_name,
- stage_configs_path=args.stage_configs_path,
- log_stats=args.log_stats,
- stage_init_timeout=args.stage_init_timeout,
- )
+ omni = AsyncOmni.from_cli_args(args, model=model_name)
for i, prompt in enumerate(inputs):
request_id = str(i)
diff --git a/examples/offline_inference/voxcpm2/end2end.py b/examples/offline_inference/voxcpm2/end2end.py
index 687e596018c..6b6bf78ddf1 100644
--- a/examples/offline_inference/voxcpm2/end2end.py
+++ b/examples/offline_inference/voxcpm2/end2end.py
@@ -65,6 +65,12 @@ def parse_args():
default=None,
help="Text matching --prompt-audio for continuation mode.",
)
+ parser.add_argument(
+ "--ref-text",
+ type=str,
+ default=None,
+ help="Optional transcript of --reference-audio (enables ref_continuation mode).",
+ )
return parser.parse_args()
@@ -103,24 +109,40 @@ def main():
stage_configs_path=args.stage_configs_path,
)
- additional: dict = {}
- if args.reference_audio:
- additional["reference_audio"] = args.reference_audio
- if args.prompt_audio and args.prompt_text:
- additional["prompt_audio"] = args.prompt_audio
- additional["prompt_text"] = args.prompt_text
+ from transformers import AutoTokenizer
- prompt: dict = {"prompt": args.text}
- if additional:
- prompt["additional_information"] = additional
+ from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import (
+ build_cjk_split_map,
+ build_voxcpm2_prompt,
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
+ split_map = build_cjk_split_map(tokenizer)
+ hf_config = engine.engine.stage_vllm_configs[0].model_config.hf_config
+
+ ref_audio_arg = args.reference_audio or args.prompt_audio
+ ref_text_arg = args.ref_text or args.prompt_text
+ ref_wav, ref_sr = (None, None)
+ if ref_audio_arg:
+ ref_wav_arr, ref_sr = sf.read(ref_audio_arg)
+ ref_wav = ref_wav_arr.mean(axis=-1).tolist() if ref_wav_arr.ndim > 1 else ref_wav_arr.tolist()
+
+ prompt = build_voxcpm2_prompt(
+ hf_config=hf_config,
+ tokenizer=tokenizer,
+ split_map=split_map,
+ text=args.text,
+ ref_audio=ref_wav,
+ ref_sr=ref_sr,
+ ref_text=ref_text_arg,
+ )
print(f"Model : {args.model}")
print(f"Text : {args.text}")
- if args.reference_audio:
- print(f"Ref audio : {args.reference_audio}")
- if args.prompt_audio:
- print(f"Prompt audio: {args.prompt_audio}")
- print(f"Prompt text : {args.prompt_text}")
+ if ref_audio_arg:
+ print(f"Ref audio : {ref_audio_arg}")
+ if ref_text_arg:
+ print(f"Ref text : {ref_text_arg}")
print(f"Output dir : {output_dir}")
t_start = time.perf_counter()
diff --git a/examples/offline_inference/x_to_video_audio/x_to_video_audio.py b/examples/offline_inference/x_to_video_audio/x_to_video_audio.py
index 322b184e520..497284ceb96 100644
--- a/examples/offline_inference/x_to_video_audio/x_to_video_audio.py
+++ b/examples/offline_inference/x_to_video_audio/x_to_video_audio.py
@@ -58,6 +58,11 @@ def parse_args() -> argparse.Namespace:
default=False,
help="Enable CPU offloading for diffusion models.",
)
+ parser.add_argument(
+ "--enable-layerwise-offload",
+ action="store_true",
+ help="Enable layerwise (blockwise) offloading on DiT modules.",
+ )
return parser.parse_args()
@@ -126,6 +131,7 @@ def main() -> None:
parallel_config=parallel_config,
model_type=args.model_type,
enable_cpu_offload=args.enable_cpu_offload,
+ enable_layerwise_offload=args.enable_layerwise_offload,
)
start = time.perf_counter()
outputs = omni.generate(prompt, sampling_params)
diff --git a/examples/online_serving/ming_flash_omni/README.md b/examples/online_serving/ming_flash_omni/README.md
new file mode 100644
index 00000000000..502232725c2
--- /dev/null
+++ b/examples/online_serving/ming_flash_omni/README.md
@@ -0,0 +1,204 @@
+# Ming-flash-omni 2.0
+
+## Installation
+
+Please refer to [README.md](../../../README.md)
+
+## Run examples (Ming-flash-omni 2.0)
+
+### Launch the Server
+
+```bash
+vllm serve Jonathan1909/Ming-flash-omni-2.0 --omni --port 8091
+```
+
+If you have custom stage configs file, launch the server with command below
+```bash
+vllm serve Jonathan1909/Ming-flash-omni-2.0 --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
+```
+
+### Send Multi-modal Request
+
+#### Send request via python
+
+```bash
+python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py --model Jonathan1909/Ming-flash-omni-2.0 --query-type use_mixed_modalities --port 8091 --host "localhost" --modalities text
+```
+
+The Python client supports the following command-line arguments:
+
+- `--query-type` (or `-q`): Query type. Options: `text`, `use_audio`, `use_image`, `use_video`, `use_mixed_modalities`
+- `--video-path` (or `-v`): Path to local video file or URL. If not provided and query-type uses video, uses default video URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs. Example: `--video-path /path/to/video.mp4` or `--video-path https://example.com/video.mp4`
+- `--image-path` (or `-i`): Path to local image file or URL. If not provided and query-type uses image, uses default image URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common image formats: JPEG, PNG, GIF, WebP. Example: `--image-path /path/to/image.jpg` or `--image-path https://example.com/image.png`
+- `--audio-path` (or `-a`): Path to local audio file or URL. If not provided and query-type uses audio, uses default audio URL. Supports local file paths (automatically encoded to base64) or HTTP/HTTPS URLs and common audio formats: MP3, WAV, OGG, FLAC, M4A. Example: `--audio-path /path/to/audio.wav` or `--audio-path https://example.com/audio.mp3`
+- `--prompt` (or `-p`): Custom text prompt/question. If not provided, uses default prompt for the selected query type. Example: `--prompt "What are the main activities shown in this video?"`
+- `--modalities`: Output modalities. For now, only `text` is supported. Example: `--modalities text`
+
+
+#### Send request via curl
+
+```bash
+bash run_curl_multimodal_generation.sh text
+bash run_curl_multimodal_generation.sh use_image
+bash run_curl_multimodal_generation.sh use_audio
+bash run_curl_multimodal_generation.sh use_video
+bash run_curl_multimodal_generation.sh use_mixed_modalities
+```
+
+## Modality control
+
+Ming-flash-omni 2.0 currently supports text output only (thinker stage).
+
+| Modalities | Output |
+|------------|--------|
+| `["text"]` | Text only |
+| Not specified | Text only (default) |
+
+### Using curl
+
+```bash
+curl http://localhost:8091/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "Jonathan1909/Ming-flash-omni-2.0",
+ "messages": [
+ {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]},
+ {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"}
+ ],
+ "modalities": ["text"]
+ }'
+```
+
+### Using OpenAI Python SDK
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+ model="Jonathan1909/Ming-flash-omni-2.0",
+ messages=[
+ {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]},
+ {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"},
+ ],
+ modalities=["text"],
+)
+print(response.choices[0].message.content)
+```
+
+### Multi-modal input with OpenAI Python SDK
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+ model="Jonathan1909/Ming-flash-omni-2.0",
+ messages=[
+ {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]},
+ {
+ "role": "user",
+ "content": [
+ {"type": "image_url", "image_url": {"url": "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg"}},
+ {"type": "text", "text": "Describe this image in detail."},
+ ],
+ },
+ ],
+ modalities=["text"],
+)
+print(response.choices[0].message.content)
+```
+
+## Streaming Output
+
+To enable streaming output:
+
+```bash
+python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py \
+ --query-type use_image \
+ --model Jonathan1909/Ming-flash-omni-2.0 \
+ --modalities text \
+ --stream
+```
+
+Or with the OpenAI Python SDK:
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+ model="Jonathan1909/Ming-flash-omni-2.0",
+ messages=[
+ {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]},
+ {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"},
+ ],
+ modalities=["text"],
+ stream=True,
+)
+for chunk in response:
+ for choice in chunk.choices:
+ if hasattr(choice, "delta") and choice.delta.content:
+ print(choice.delta.content, end="", flush=True)
+print()
+```
+
+Or using curl:
+
+```bash
+curl http://localhost:8091/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "Jonathan1909/Ming-flash-omni-2.0",
+ "messages": [
+ {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking off"}]},
+ {"role": "user", "content": "请详细介绍鹦鹉的生活习性。"}
+ ],
+ "modalities": ["text"],
+ "stream": true,
+ }'
+```
+
+
+## Reasoning (Thinking Mode)
+
+To enable reasoning/thinking mode, change `detailed thinking off` to `detailed thinking on` in the system prompt:
+
+### Using curl
+
+```bash
+curl http://localhost:8091/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "Jonathan1909/Ming-flash-omni-2.0",
+ "messages": [
+ {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking on"}]},
+ {"role": "user", "content": [
+ {"type": "image_url", "image_url": {"url": "https://example.com/math_problem.png"}},
+ {"type": "text", "text": "Solve this math problem step by step."}
+ ]}
+ ],
+ "modalities": ["text"]
+ }'
+```
+
+### Using OpenAI Python SDK
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8091/v1", api_key="EMPTY")
+
+response = client.chat.completions.create(
+ model="Jonathan1909/Ming-flash-omni-2.0",
+ messages=[
+ {"role": "system", "content": [{"type": "text", "text": "你是一个友好的AI助手。\n\ndetailed thinking on"}]},
+ {"role": "user", "content": "If a train travels 120 km in 2 hours, what is its average speed?"},
+ ],
+ modalities=["text"],
+)
+print(response.choices[0].message.content)
+```
diff --git a/examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh b/examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh
new file mode 100755
index 00000000000..768a424e451
--- /dev/null
+++ b/examples/online_serving/ming_flash_omni/run_curl_multimodal_generation.sh
@@ -0,0 +1,145 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Server port
+PORT="${PORT:-8091}"
+# Default query type
+QUERY_TYPE="${1:-text}"
+
+# Validate query type
+if [[ ! "$QUERY_TYPE" =~ ^(text|use_audio|use_image|use_video|use_mixed_modalities)$ ]]; then
+ echo "Error: Invalid query type '$QUERY_TYPE'"
+ echo "Usage: $0 [text|use_audio|use_image|use_video|use_mixed_modalities]"
+ echo " text: Text-only query"
+ echo " use_audio: Audio + Text query"
+ echo " use_image: Image + Text query"
+ echo " use_video: Video + Text query"
+ echo " use_mixed_modalities: Audio + Image + Video + Text query"
+ exit 1
+fi
+
+thinker_sampling_params='{
+ "temperature": 0.4,
+ "top_p": 0.9,
+ "top_k": -1,
+ "max_tokens": 16384,
+ "seed": 42,
+ "detokenize": true,
+ "repetition_penalty": 1.05
+}'
+# Above is optional, it has a default setting in stage_configs of the corresponding model.
+
+# Define URLs for assets
+MARY_HAD_LAMB_AUDIO_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/mary_had_lamb.ogg"
+CHERRY_BLOSSOM_IMAGE_URL="https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/cherry_blossom.jpg"
+SAMPLE_VIDEO_URL="https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4"
+
+# Build user content based on query type
+case "$QUERY_TYPE" in
+ text)
+ user_content='[
+ {
+ "type": "text",
+ "text": "请详细介绍鹦鹉的生活习性。"
+ }
+ ]'
+ ;;
+ use_image)
+ user_content='[
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "'"$CHERRY_BLOSSOM_IMAGE_URL"'"
+ }
+ },
+ {
+ "type": "text",
+ "text": "Describe this image in detail."
+ }
+ ]'
+ ;;
+ use_audio)
+ user_content='[
+ {
+ "type": "audio_url",
+ "audio_url": {
+ "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'"
+ }
+ },
+ {
+ "type": "text",
+ "text": "Please recognize the language of this speech and transcribe it. Format: oral."
+ }
+ ]'
+ ;;
+ use_video)
+ user_content='[
+ {
+ "type": "video_url",
+ "video_url": {
+ "url": "'"$SAMPLE_VIDEO_URL"'"
+ }
+ },
+ {
+ "type": "text",
+ "text": "Describe what is happening in this video."
+ }
+ ]'
+ ;;
+ use_mixed_modalities)
+ user_content='[
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "'"$CHERRY_BLOSSOM_IMAGE_URL"'"
+ }
+ },
+ {
+ "type": "audio_url",
+ "audio_url": {
+ "url": "'"$MARY_HAD_LAMB_AUDIO_URL"'"
+ }
+ },
+ {
+ "type": "text",
+ "text": "Describe the image, and recognize the language of this speech and transcribe it. Format: oral"
+ }
+ ]'
+ ;;
+esac
+
+echo "Running query type: $QUERY_TYPE"
+echo ""
+
+request_body=$(cat < **Note on `--no-async-chunk`**: Flips the deploy yaml's `async_chunk:`
+> bool. Pipelines that implement alternate processor functions for
+> chunked vs end-to-end modes (e.g. qwen3_tts code2wav) dispatch
+> automatically based on that bool — no extra flag or variant yaml is
+> needed.
+
+> ⚠️ **For multi-stage models that share GPUs (qwen3_omni_moe by default
+> shares cuda:1 between stages 1 and 2), avoid using global memory flags.**
+> A global `--gpu-memory-utilization 0.85` would apply to every stage and
+> oversubscribe the shared device. Use per-stage overrides instead — see
+> below.
+
+#### 2. Per-stage overrides via `--stage-overrides` (recommended for memory)
+
```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
+# Lower stage 1's memory budget; leave others at the YAML default
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 \
+ --stage-overrides '{
+ "1": {"gpu_memory_utilization": 0.5},
+ "2": {"max_num_batched_tokens": 65536}
+ }'
+```
+
+Per-stage values are always treated as explicit and beat YAML defaults for
+the named stage. Other stages keep their YAML values.
+
+#### 3. Custom deploy YAML
+
+When per-stage overrides get long, write a small overlay YAML that inherits
+from the bundled default:
+
+```yaml
+# my_qwen3_omni_overrides.yaml
+base_config: /path/to/vllm_omni/deploy/qwen3_omni_moe.yaml
+
+stages:
+ - stage_id: 0
+ max_num_batched_tokens: 65536
+ enforce_eager: true
+ - stage_id: 1
+ gpu_memory_utilization: 0.5
+ - stage_id: 2
+ max_model_len: 8192
```
+Then start the server with `--deploy-config my_qwen3_omni_overrides.yaml`.
+The `base_config:` line tells the loader to inherit everything else (stages,
+connectors, edges, platforms section) from the bundled production YAML, so
+you only need to spell out the deltas.
+
+#### 4. Multi-node deployment (cross-host transfer connector)
+
+The bundled `qwen3_omni_moe.yaml` uses `SharedMemoryConnector` between stages,
+which only works when all stages run on the same physical host. For
+**cross-node** deployments, write a small overlay YAML that swaps in a
+network-capable connector (e.g. `MooncakeStoreConnector`) and re-points each
+stage's connector wiring at it. The connector spec carries your own server
+addresses — there is no checked-in default because every cluster is
+different.
+
+```yaml
+# my_qwen3_omni_multinode.yaml
+base_config: /path/to/vllm_omni/deploy/qwen3_omni_moe.yaml
+
+connectors:
+ mooncake_connector:
+ name: MooncakeStoreConnector
+ extra:
+ host: "127.0.0.1"
+ metadata_server: "http://YOUR_METADATA_HOST:8080/metadata"
+ master: "YOUR_MASTER_HOST:50051"
+ segment: 512000000 # 512 MB transfer segment
+ localbuf: 64000000 # 64 MB local buffer
+ proto: "tcp"
+
+stages:
+ - stage_id: 0
+ output_connectors:
+ to_stage_1: mooncake_connector
+ - stage_id: 1
+ input_connectors:
+ from_stage_0: mooncake_connector
+ output_connectors:
+ to_stage_2: mooncake_connector
+ - stage_id: 2
+ input_connectors:
+ from_stage_1: mooncake_connector
+```
+
+Then launch with `--deploy-config my_qwen3_omni_multinode.yaml`. Same
+pattern works for Qwen2.5-Omni — replace `base_config:` with the path to
+`vllm_omni/deploy/qwen2_5_omni.yaml`.
+
+> ⚠️ Replace `YOUR_METADATA_HOST` / `YOUR_MASTER_HOST` with the actual
+> mooncake server addresses for your cluster. The `base_config:` overlay
+> inherits all stage budgets, devices, and edges from the bundled prod
+> YAML — you only need to spell out the connector swap.
+
### Send Multi-modal Request
Get into the example folder
@@ -38,36 +188,43 @@ python examples/online_serving/openai_chat_completion_client_for_multimodal_gene
#### Realtime WebSocket client (`openai_realtime_client.py`)
-[`openai_realtime_client.py`](./openai_realtime_client.py) connects to **`ws://:/v1/realtime`**, uploads a local audio file as **PCM16 mono @ 16 kHz** chunks (OpenAI-style `input_audio_buffer.append` / `commit`), and prints **streaming transcription** (`transcription.delta` / `transcription.done`).
+[`openai_realtime_client.py`](./openai_realtime_client.py) connects to **`ws://:/v1/realtime`**, streams a local WAV as **PCM16 mono @ 16 kHz** in fixed-size chunks (OpenAI-style `input_audio_buffer.append` / `commit`), and receives **`response.audio.delta`** (incremental PCM for the reply) plus **`transcription.*`** events. By default it concatenates audio deltas and writes **`--output-wav`** (model output is typically **24 kHz**). Optional **`--delta-dump-dir`** saves each delta as `delta_000001.wav`, … for debugging.
+
+Streaming input works well for translation-style use cases; if the Thinker runs while input is still incomplete, consider limiting **`max_tokens`** in your session / server defaults to avoid over-generation.
**Dependencies:**
```bash
-pip install websockets numpy
+pip install websockets
```
**From this directory** (`examples/online_serving/qwen3_omni`):
```bash
python openai_realtime_client.py \
- --host localhost \
- --port 8091 \
+ --url ws://localhost:8091/v1/realtime \
--model Qwen/Qwen3-Omni-30B-A3B-Instruct \
- --audio_path /path/to/your.wav
+ --input-wav /path/to/input_16k_mono.wav \
+ --output-wav realtime_output.wav \
+ --delta-dump-dir ./rt_delta_wavs
```
-If `--audio_path` is omitted, the script uses a bundled default clip (`mary_had_lamb` via vLLM assets).
-
**Arguments:**
| Flag | Default | Description |
|------|---------|-------------|
-| `--host` | `localhost` | API server host |
-| `--port` | `8000` | API server port (match your `vllm serve` port, e.g. `8091`) |
-| `--model` | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | Must match the served model (also sent in `session.update`) |
-| `--audio_path` | *(optional)* | Path to input audio; resampled to 16 kHz mono inside the client |
-
-Ensure the vLLM-Omni server is running with realtime support for this endpoint, for example:
+| `--url` | `ws://localhost:8091/v1/realtime` | Full WebSocket URL including path |
+| `--model` | `Qwen/Qwen3-Omni-30B-A3B-Instruct` | Must match the served model (sent in `session.update`) |
+| `--input-wav` | *(required)* | Input WAV: mono, 16-bit PCM, **16 kHz** |
+| `--output-wav` | `realtime_output.wav` | Output path for concatenated reply audio |
+| `--output-text` | *(optional)* | If set, write final transcription text to this path |
+| `--chunk-ms` | `200` | Size of each uploaded audio chunk (milliseconds of audio) |
+| `--send-delay-ms` | `0` | Delay between chunk sends (simulate realtime upload) |
+| `--delta-dump-dir` | *(optional)* | Directory to write per-`response.audio.delta` WAV files |
+| `--num-requests` | `1` | Number of sequential sessions (see `--concurrency`) |
+| `--concurrency` | `1` | Max concurrent WebSocket sessions when `--num-requests` > 1 |
+
+Ensure the server is running **without** `async_chunk` if you use `/v1/realtime`, for example:
```bash
vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
@@ -276,7 +433,7 @@ The script supports the following arguments:
- `--model`: Model name/path (default: Qwen/Qwen3-Omni-30B-A3B-Instruct)
- `--server-port`: Port for vLLM server (default: 8091)
- `--gradio-port`: Port for Gradio demo (default: 7861)
-- `--stage-configs-path`: Path to custom stage configs YAML file (optional)
+- `--deploy-config`: Path to custom deploy config YAML file (optional)
- `--server-host`: Host for vLLM server (default: 0.0.0.0)
- `--gradio-ip`: IP for Gradio demo (default: 127.0.0.1)
- `--share`: Share Gradio demo publicly (creates a public link)
@@ -291,7 +448,7 @@ vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
If you have custom stage configs file:
```bash
-vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --stage-configs-path /path/to/stage_configs_file
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091 --deploy-config /path/to/deploy_config_file
```
**Step 2: Run the Gradio demo**
diff --git a/examples/online_serving/qwen3_omni/openai_realtime_client.py b/examples/online_serving/qwen3_omni/openai_realtime_client.py
index 660e4ac336a..79e30a3f50b 100644
--- a/examples/online_serving/qwen3_omni/openai_realtime_client.py
+++ b/examples/online_serving/qwen3_omni/openai_realtime_client.py
@@ -1,81 +1,118 @@
-"""
-This script demonstrates how to use the vLLM-Omni Realtime WebSocket API to perform
-audio transcription by uploading an audio file.
+"""Realtime client for vLLM-Omni /v1/realtime (audio + text events).
+
+This client:
+1) Reads a local WAV file (must be mono, 16-bit PCM, 16kHz),
+2) Streams PCM16 chunks to /v1/realtime with OpenAI-style events,
+3) Receives response.audio.* and transcription.* events,
+4) Saves synthesized audio to an output WAV file and optional text file.
-Before running this script, you must start the vLLM-Omni server with a realtime-capable
-model, for example:
+By default each ``response.audio.delta`` is treated as an **incremental PCM**
+chunk and all chunks are concatenated into the final ``--output-wav``.
- vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni
+Optional debugging: pass ``--delta-dump-dir DIR`` to write every
+``response.audio.delta`` payload as ``delta_000001.wav``, ``delta_000002.wav``, …
-Requirements:
-- vllm with audio support
-- websockets
-- soundfile
-- numpy
+Usage:
+ python openai_realtime_client.py \
+ --url ws://localhost:8091/v1/realtime \
+ --model Qwen/Qwen3-Omni-30B-A3B-Instruct \
+ --input-wav input_16k_mono.wav \
+ --output-wav realtime_output.wav \
+ --delta-dump-dir ./rt_delta_wavs
-The script:
-1. Connects to the Realtime WebSocket endpoint
-2. Converts an audio file to PCM16 @ 16kHz
-3. Sends audio chunks to the server
-4. Receives and prints transcription as it streams
+Dependencies:
+ pip install websockets
"""
+from __future__ import annotations
+
import argparse
import asyncio
import base64
import json
+import wave
+from pathlib import Path
+
+try:
+ import websockets
+except ImportError:
+ print("Please install websockets: pip install websockets")
+ raise SystemExit(1)
+
+
+def _read_wav_pcm16(path: Path) -> bytes:
+ with wave.open(str(path), "rb") as wf:
+ nchannels = wf.getnchannels()
+ sampwidth = wf.getsampwidth()
+ framerate = wf.getframerate()
+ comptype = wf.getcomptype()
+ nframes = wf.getnframes()
+
+ if nchannels != 1:
+ raise ValueError(f"Input WAV must be mono (got {nchannels} channels).")
+ if sampwidth != 2:
+ raise ValueError(f"Input WAV must be 16-bit PCM (got sample width={sampwidth}).")
+ if framerate != 16000:
+ raise ValueError(f"Input WAV must be 16kHz (got {framerate} Hz).")
+ if comptype != "NONE":
+ raise ValueError(f"Input WAV must be uncompressed PCM (got comptype={comptype}).")
+ if nframes <= 0:
+ raise ValueError("Input WAV has no audio frames.")
+
+ return wf.readframes(nframes)
+
+
+def _write_wav_pcm16(path: Path, pcm16_bytes: bytes, sample_rate_hz: int) -> None:
+ with wave.open(str(path), "wb") as wf:
+ wf.setnchannels(1)
+ wf.setsampwidth(2)
+ wf.setframerate(sample_rate_hz)
+ wf.writeframes(pcm16_bytes)
+
+
+async def run_client(
+ url: str,
+ model: str,
+ input_wav: Path,
+ output_wav: Path,
+ output_text: Path | None,
+ chunk_ms: int,
+ send_delay_ms: int,
+ delta_dump_dir: Path | None,
+ request_idx: int = 1,
+ total_requests: int = 1,
+) -> None:
+ log_prefix = f"[req {request_idx:02d}/{total_requests:02d}] " if total_requests > 1 else ""
+ pcm16 = _read_wav_pcm16(input_wav)
+ bytes_per_ms = 16000 * 2 // 1000 # mono PCM16 at 16kHz
+ chunk_bytes = max(bytes_per_ms * chunk_ms, 2)
-import numpy as np
-import websockets
-from vllm.assets.audio import AudioAsset
-from vllm.multimodal.media.audio import load_audio
-
-
-def audio_to_pcm16_base64(audio_path: str) -> str:
- """
- Load an audio file and convert it to base64-encoded PCM16 @ 16kHz.
- """
- # Load audio and resample to 16kHz mono
- audio, _ = load_audio(audio_path, sr=16000, mono=True)
- # Convert to PCM16
- pcm16 = (audio * 32767).astype(np.int16)
- # Encode as base64
- return base64.b64encode(pcm16.tobytes()).decode("utf-8")
-
-
-async def realtime_transcribe(audio_path: str, host: str, port: int, model: str):
- """
- Connect to the Realtime API and transcribe an audio file.
- """
- uri = f"ws://{host}:{port}/v1/realtime"
-
- async with websockets.connect(uri) as ws:
- # Wait for session.created
- response = json.loads(await ws.recv())
- if response["type"] == "session.created":
- print(f"Session created: {response['id']}")
- else:
- print(f"Unexpected response: {response}")
- return
-
- # Validate model
- await ws.send(json.dumps({"type": "session.update", "model": model}))
-
- # Signal ready to start
- await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))
-
- # Convert audio file to base64 PCM16
- print(f"Loading audio from: {audio_path}")
- audio_base64 = audio_to_pcm16_base64(audio_path)
-
- # Send audio in chunks (4KB of raw audio = ~8KB base64)
- chunk_size = 4096
- audio_bytes = base64.b64decode(audio_base64)
- total_chunks = (len(audio_bytes) + chunk_size - 1) // chunk_size
-
- print(f"Sending {total_chunks} audio chunks...")
- for i in range(0, len(audio_bytes), chunk_size):
- chunk = audio_bytes[i : i + chunk_size]
+ incremental_pcm_parts: list[bytes] = []
+ output_sample_rate = 24000
+ delta_index = 0
+ text_chunks: list[str] = []
+ final_text: str = ""
+
+ if delta_dump_dir is not None:
+ delta_dump_dir.mkdir(parents=True, exist_ok=True)
+
+ async with websockets.connect(url, max_size=64 * 1024 * 1024) as ws:
+ # 1) Validate model.
+ await ws.send(
+ json.dumps(
+ {
+ "type": "session.update",
+ "model": model,
+ }
+ )
+ )
+
+ # 2) Start generation once (non-final commit).
+ await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": False}))
+
+ # 3) Stream audio chunks.
+ for i in range(0, len(pcm16), chunk_bytes):
+ chunk = pcm16[i : i + chunk_bytes]
await ws.send(
json.dumps(
{
@@ -84,63 +121,212 @@ async def realtime_transcribe(audio_path: str, host: str, port: int, model: str)
}
)
)
+ if send_delay_ms > 0:
+ await asyncio.sleep(send_delay_ms / 1000.0)
- # Signal all audio is sent
+ # 4) Final commit closes input stream.
await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True}))
- print("Audio sent. Waiting for transcription...\n")
- # Receive transcription
- print("Transcription: ", end="", flush=True)
+ # 5) Receive server events until audio done.
while True:
- response = json.loads(await ws.recv())
- if response["type"] == "transcription.delta":
- print(response["delta"], end="", flush=True)
- elif response["type"] == "transcription.done":
- print(f"\n\nFinal transcription: {response['text']}")
- if response.get("usage"):
- print(f"Usage: {response['usage']}")
- break
- elif response["type"] == "error":
- print(f"\nError: {response['error']}")
+ message = await ws.recv()
+ if isinstance(message, bytes):
+ # We only expect JSON text frames.
+ continue
+
+ event = json.loads(message)
+ event_type = event.get("type")
+
+ if event_type == "session.created":
+ continue
+
+ if event_type == "response.audio.delta":
+ sr = event.get("sample_rate_hz")
+ if isinstance(sr, int) and sr > 0:
+ output_sample_rate = sr
+ audio_b64 = event.get("audio", "")
+ if audio_b64:
+ pcm_delta = base64.b64decode(audio_b64)
+ incremental_pcm_parts.append(pcm_delta)
+ if delta_dump_dir is not None and pcm_delta:
+ delta_index += 1
+ dump_path = delta_dump_dir / f"delta_{delta_index:06d}.wav"
+ _write_wav_pcm16(dump_path, pcm_delta, output_sample_rate)
+ print(
+ f"{log_prefix}delta dump #{delta_index}: {dump_path} "
+ f"(pcm bytes={len(pcm_delta)}, sr={output_sample_rate})"
+ )
+ continue
+
+ if event_type == "transcription.delta":
+ delta = event.get("delta", "")
+ if delta:
+ text_chunks.append(delta)
+ print(delta, end="", flush=True)
+ continue
+
+ if event_type == "transcription.done":
+ final_text = event.get("text", "") or "".join(text_chunks)
+ usage = event.get("usage")
+ final_text_with_tag = f"Final transcription: {final_text}"
+ if text_chunks:
+ print()
+ print(f"{log_prefix}{final_text_with_tag}")
+ if usage:
+ print(f"{log_prefix}text usage: {usage}")
+ continue
+
+ if event_type == "response.audio.done":
break
+ if event_type == "error":
+ raise RuntimeError(f"Server error: {event}")
-def main(args):
- if args.audio_path:
- audio_path = args.audio_path
- else:
- # Use default audio asset
- audio_path = str(AudioAsset("mary_had_lamb").get_local_path())
- print(f"No audio path provided, using default: {audio_path}")
+ all_pcm16 = b"".join(incremental_pcm_parts)
+ if not all_pcm16:
+ raise RuntimeError("No audio received from server.")
- asyncio.run(realtime_transcribe(audio_path, args.host, args.port, args.model))
+ output_wav.parent.mkdir(parents=True, exist_ok=True)
+ _write_wav_pcm16(output_wav, all_pcm16, output_sample_rate)
+ print(f"{log_prefix}Saved realtime audio to: {output_wav} (incremental chunks joined)")
+ if output_text is not None:
+ text_to_save = final_text if final_text else "".join(text_chunks)
+ output_text.parent.mkdir(parents=True, exist_ok=True)
+ output_text.write_text(text_to_save, encoding="utf-8")
+ print(f"{log_prefix}Saved realtime text to: {output_text}")
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="Realtime WebSocket Transcription Client")
+
+def _indexed_output_path(path: Path | None, index: int, total: int) -> Path | None:
+ if path is None or total <= 1:
+ return path
+ return path.with_name(f"{path.stem}_{index:02d}{path.suffix}")
+
+
+async def run_clients_concurrent(
+ *,
+ url: str,
+ model: str,
+ input_wav: Path,
+ output_wav: Path,
+ output_text: Path | None,
+ chunk_ms: int,
+ send_delay_ms: int,
+ delta_dump_dir: Path | None,
+ num_requests: int,
+ concurrency: int,
+) -> None:
+ sem = asyncio.Semaphore(concurrency)
+
+ async def _run_one(index: int) -> tuple[int, bool, str | None]:
+ per_output_wav = _indexed_output_path(output_wav, index, num_requests)
+ per_output_text = _indexed_output_path(output_text, index, num_requests)
+ per_delta_dir = None
+ if delta_dump_dir is not None:
+ per_delta_dir = delta_dump_dir / f"req_{index:02d}"
+ async with sem:
+ try:
+ await run_client(
+ url=url,
+ model=model,
+ input_wav=input_wav,
+ output_wav=per_output_wav,
+ output_text=per_output_text,
+ chunk_ms=chunk_ms,
+ send_delay_ms=send_delay_ms,
+ delta_dump_dir=per_delta_dir,
+ request_idx=index,
+ total_requests=num_requests,
+ )
+ return index, True, None
+ except Exception as exc:
+ return index, False, str(exc)
+
+ tasks = [asyncio.create_task(_run_one(i), name=f"rt-client-{i}") for i in range(1, num_requests + 1)]
+ results = await asyncio.gather(*tasks)
+
+ failed = [(idx, err) for idx, ok, err in results if not ok]
+ succeeded = num_requests - len(failed)
+ print(f"[summary] succeeded={succeeded}, failed={len(failed)}, total={num_requests}")
+ if failed:
+ for idx, err in failed:
+ print(f"[summary] req {idx:02d} failed: {err}")
+ raise RuntimeError(f"{len(failed)} concurrent request(s) failed")
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Realtime audio/text client for vLLM-Omni")
+ parser.add_argument("--url", default="ws://localhost:8091/v1/realtime", help="WebSocket URL")
parser.add_argument(
"--model",
- type=str,
default="Qwen/Qwen3-Omni-30B-A3B-Instruct",
- help="Model that is served and should be pinged.",
+ help="Model name for session.update",
)
+ parser.add_argument("--input-wav", required=True, type=Path, help="Input WAV (mono, PCM16, 16kHz)")
+ parser.add_argument("--output-wav", default=Path("realtime_output.wav"), type=Path, help="Output WAV path")
parser.add_argument(
- "--audio_path",
- type=str,
+ "--output-text",
default=None,
- help="Path to the audio file to transcribe.",
+ type=Path,
+ help="Optional output text path for final transcription",
)
+ parser.add_argument("--chunk-ms", type=int, default=200, help="Input chunk size in milliseconds")
parser.add_argument(
- "--host",
- type=str,
- default="localhost",
- help="vLLM-Omni server host (default: localhost)",
+ "--send-delay-ms",
+ type=int,
+ default=0,
+ help="Delay between chunk sends; set >0 to simulate realtime upload",
)
parser.add_argument(
- "--port",
+ "--delta-dump-dir",
+ type=Path,
+ default=None,
+ help="If set, each response.audio.delta is saved as delta_NNNNNN.wav under this directory",
+ )
+ parser.add_argument("--num-requests", type=int, default=1, help="Total number of requests to send")
+ parser.add_argument(
+ "--concurrency",
type=int,
- default=8000,
- help="vLLM-Omni server port (default: 8000)",
+ default=1,
+ help="Maximum number of concurrent websocket requests",
)
args = parser.parse_args()
- main(args)
+
+ if args.num_requests <= 0:
+ raise ValueError("--num-requests must be >= 1")
+ if args.concurrency <= 0:
+ raise ValueError("--concurrency must be >= 1")
+ concurrency = min(args.concurrency, args.num_requests)
+
+ if args.num_requests == 1:
+ asyncio.run(
+ run_client(
+ url=args.url,
+ model=args.model,
+ input_wav=args.input_wav,
+ output_wav=args.output_wav,
+ output_text=args.output_text,
+ chunk_ms=args.chunk_ms,
+ send_delay_ms=args.send_delay_ms,
+ delta_dump_dir=args.delta_dump_dir,
+ )
+ )
+ else:
+ asyncio.run(
+ run_clients_concurrent(
+ url=args.url,
+ model=args.model,
+ input_wav=args.input_wav,
+ output_wav=args.output_wav,
+ output_text=args.output_text,
+ chunk_ms=args.chunk_ms,
+ send_delay_ms=args.send_delay_ms,
+ delta_dump_dir=args.delta_dump_dir,
+ num_requests=args.num_requests,
+ concurrency=concurrency,
+ )
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/online_serving/qwen3_tts/README.md b/examples/online_serving/qwen3_tts/README.md
index b48db9cf453..350fcb71cac 100644
--- a/examples/online_serving/qwen3_tts/README.md
+++ b/examples/online_serving/qwen3_tts/README.md
@@ -43,7 +43,7 @@ Then open http://localhost:7860 in your browser.
### Launch the Server
-The default stage config is located at `vllm_omni/model_executor/stage_configs/qwen3_tts.yaml`. For other platforms (e.g., NPU), refer to `vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml`.
+The default deploy config is located at `vllm_omni/deploy/qwen3_tts.yaml` and is loaded automatically by the model registry — no `--deploy-config` flag needed for default use. Platform-specific deltas (NPU, ROCm, XPU) are merged in automatically from the `platforms:` block of the same YAML based on the detected runtime.
```bash
# CustomVoice model (predefined speakers)
@@ -70,6 +70,22 @@ vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
--port 8091
```
+#### Sync vs async-chunk mode
+
+Qwen3-TTS supports both **chunked streaming** (default, lower latency) and
+**synchronous end-to-end** modes from the same deploy YAML. The bundled
+`qwen3_tts.yaml` ships with `async_chunk: true`; flip with `--no-async-chunk`
+and the pipeline automatically dispatches to the end-to-end codec processor:
+
+```bash
+vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice --omni --port 8091 \
+ --no-async-chunk
+```
+
+No variant YAML or extra flag is needed — the `StagePipelineConfig` on each
+stage declares both processor functions and the runtime picks based on the
+`async_chunk:` bool.
+
Alternatively, use the convenience script:
```bash
./run_server.sh # Default: CustomVoice model
diff --git a/examples/online_serving/qwen3_tts/batch_speech_client.py b/examples/online_serving/qwen3_tts/batch_speech_client.py
index 7d48e650f88..47fdc3691c7 100644
--- a/examples/online_serving/qwen3_tts/batch_speech_client.py
+++ b/examples/online_serving/qwen3_tts/batch_speech_client.py
@@ -5,11 +5,13 @@
batch level and generate many utterances in the cloned voice without repeating
the reference for each item.
-Start the server (with batch-optimized config for best throughput):
+Start the server (with batch-optimized stage settings for best throughput):
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml \
- --trust-remote-code
+ --omni \
+ --trust-remote-code \
+ --stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2},
+ "1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}'
Examples:
# Batch with a predefined voice
diff --git a/examples/online_serving/qwen3_tts/run_gradio_demo.sh b/examples/online_serving/qwen3_tts/run_gradio_demo.sh
index bcc0ddb7cf5..d79be3c2abd 100644
--- a/examples/online_serving/qwen3_tts/run_gradio_demo.sh
+++ b/examples/online_serving/qwen3_tts/run_gradio_demo.sh
@@ -127,7 +127,7 @@ echo "Starting vLLM server..."
LOG_FILE="/tmp/vllm_tts_server_${SERVER_PORT}.log"
vllm-omni serve "$MODEL" \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--host "$SERVER_HOST" \
--port "$SERVER_PORT" \
--gpu-memory-utilization 0.9 \
diff --git a/examples/online_serving/qwen3_tts/run_server.sh b/examples/online_serving/qwen3_tts/run_server.sh
index 6f4aa83a0b9..78dd2c305d3 100755
--- a/examples/online_serving/qwen3_tts/run_server.sh
+++ b/examples/online_serving/qwen3_tts/run_server.sh
@@ -31,7 +31,7 @@ esac
echo "Starting Qwen3-TTS server with model: $MODEL"
vllm-omni serve "$MODEL" \
- --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_tts.yaml \
+ --deploy-config vllm_omni/deploy/qwen3_tts.yaml \
--host 0.0.0.0 \
--port 8091 \
--gpu-memory-utilization 0.9 \
diff --git a/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py b/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py
index 38a2bdea929..7790fa51276 100644
--- a/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py
+++ b/examples/online_serving/qwen3_tts/speaker_embedding_interpolation.py
@@ -5,7 +5,7 @@
using SLERP and sends the result to the /v1/audio/speech API.
Requirements:
- pip install torch resampy soundfile numpy httpx
+ pip install torch soundfile numpy httpx
Examples:
# Extract and save an embedding
@@ -143,11 +143,12 @@ def _load_speaker_encoder_weights(encoder: torch.nn.Module, model_path: str) ->
def compute_mel_spectrogram(audio: np.ndarray, sr: int = 24000) -> torch.Tensor:
"""Compute 128-bin mel spectrogram matching Qwen3-TTS's extraction pipeline."""
- from vllm.multimodal.audio import resample_audio_resampy
+ from vllm.multimodal.audio import AudioResampler
# Resample to 24kHz if needed
if sr != 24000:
- audio = resample_audio_resampy(audio.astype(np.float32), orig_sr=sr, target_sr=24000)
+ resampler = AudioResampler(target_sr=24000)
+ audio = resampler.resample(audio.astype(np.float32), orig_sr=sr)
y = torch.from_numpy(audio).unsqueeze(0).float()
diff --git a/recipes/Qwen/Qwen3-Omni.md b/recipes/Qwen/Qwen3-Omni.md
new file mode 100644
index 00000000000..081e1453d37
--- /dev/null
+++ b/recipes/Qwen/Qwen3-Omni.md
@@ -0,0 +1,90 @@
+# Qwen3-Omni for multimodal chat on 1x A100 80GB
+
+## Summary
+
+- Vendor: Qwen
+- Model: `Qwen/Qwen3-Omni-30B-A3B-Instruct`
+- Task: Multimodal chat with text, image, audio, or video input
+- Mode: Online serving with the OpenAI-compatible API
+- Maintainer: Community
+
+## When to use this recipe
+
+Use this recipe when you want a known-good starting point for serving
+`Qwen/Qwen3-Omni-30B-A3B-Instruct` with vLLM-Omni on a single 80 GB A100 and
+validate the deployment with the existing multimodal client examples in this
+repository.
+
+## References
+
+- Upstream or canonical docs:
+ [`docs/user_guide/examples/online_serving/qwen3_omni.md`](../../docs/user_guide/examples/online_serving/qwen3_omni.md)
+- Related example under `examples/`:
+ [`examples/online_serving/qwen3_omni/README.md`](../../examples/online_serving/qwen3_omni/README.md)
+- Related issue or discussion:
+ [RFC: add recipes folder](https://github.com/vllm-project/vllm-omni/issues/2645)
+
+## Hardware Support
+
+This recipe currently documents one tested-style reference configuration for
+CUDA GPU serving. Add more sections for other hardware as community validation
+lands.
+
+## GPU
+
+### 1x A100 80GB
+
+#### Environment
+
+- OS: Linux
+- Python: 3.10+
+- Driver / runtime: NVIDIA CUDA environment with an A100 80 GB GPU
+- vLLM version: Match the repository requirements for your checkout
+- vLLM-Omni version or commit: Use the commit you are deploying from
+
+#### Command
+
+Start the server from the repository root:
+
+```bash
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct --omni --port 8091
+```
+
+To enable async chunking, use the bundled stage config:
+
+```bash
+vllm serve Qwen/Qwen3-Omni-30B-A3B-Instruct \
+ --omni \
+ --port 8091 \
+ --stage-configs-path vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
+```
+
+#### Verification
+
+Run one of the existing example clients after the server is ready:
+
+```bash
+python examples/online_serving/openai_chat_completion_client_for_multimodal_generation.py \
+ --model Qwen/Qwen3-Omni-30B-A3B-Instruct \
+ --query-type use_image \
+ --port 8091 \
+ --host localhost
+```
+
+For a quick API smoke test, request text-only output:
+
+```bash
+curl http://localhost:8091/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+ "messages": [{"role": "user", "content": "Describe vLLM in brief."}],
+ "modalities": ["text"]
+ }'
+```
+
+#### Notes
+
+- Memory usage: Size depends on runtime options and output modalities; leave headroom for multimodal workloads.
+- Key flags: `--omni` is required; `--stage-configs-path` is optional for custom or async-chunk stage configs.
+- Known limitations: This starter recipe is intentionally narrow and focuses on the single-GPU online-serving path already documented in the repo examples.
diff --git a/recipes/README.md b/recipes/README.md
new file mode 100644
index 00000000000..5b3dfb5430b
--- /dev/null
+++ b/recipes/README.md
@@ -0,0 +1,35 @@
+# Community Recipes
+
+This directory contains community-maintained recipes for answering a
+practical user question:
+
+> How do I run model X on hardware Y for task Z?
+
+Add recipes for this repository under this in-repo `recipes/` directory. To
+keep naming and layout consistent, organize recipes by model vendor in a way
+that is aligned with
+[`vllm-project/recipes`](https://github.com/vllm-project/recipes), but treat
+that external repository as a reference for structure rather than the place to
+add files for this repo. Use one Markdown file per model family by default.
+
+Example layout:
+
+```text
+recipes/
+ Qwen/
+ Qwen3-Omni.md
+ Qwen3-TTS.md
+ Tencent-Hunyuan/
+ HunyuanVideo.md
+```
+
+## Available Recipes
+
+- [`Qwen/Qwen3-Omni.md`](./Qwen/Qwen3-Omni.md): online serving recipe for
+ multimodal chat on `1x A100 80GB`
+
+Within a single recipe file, include different hardware support sections such
+as `GPU`, `ROCm`, and `NPU`, and add concrete tested configurations like
+`1x A100 80GB` or `2x L40S` inside those sections when applicable.
+
+See [TEMPLATE.md](./TEMPLATE.md) for the recommended format.
diff --git a/recipes/TEMPLATE.md b/recipes/TEMPLATE.md
new file mode 100644
index 00000000000..9bf8cb9c759
--- /dev/null
+++ b/recipes/TEMPLATE.md
@@ -0,0 +1,82 @@
+# Recipe Title
+
+> Example: Qwen3-Omni for speech chat on 1x A100 80GB
+
+## Summary
+
+- Vendor:
+- Model:
+- Task:
+- Mode:
+- Maintainer:
+
+## When to use this recipe
+
+Briefly describe the concrete scenario this recipe covers.
+
+## References
+
+- Upstream or canonical docs:
+- Related example under `examples/`:
+- Related issue or discussion:
+
+## Hardware Support
+
+Add one section per platform, such as `GPU`, `ROCm`, or `NPU`. Under each
+platform section, document one or more tested hardware configurations.
+
+## GPU
+
+### 1x A100 80GB
+
+#### Environment
+
+- OS:
+- Python:
+- Driver / runtime:
+- vLLM version:
+- vLLM-Omni version or commit:
+
+#### Command
+
+```bash
+# Add the exact command(s) here
+```
+
+#### Verification
+
+```bash
+# Add a quick validation command or expected output here
+```
+
+#### Notes
+
+- Memory usage:
+- Key flags:
+- Known limitations:
+
+### 2x L40S
+
+Repeat the same structure for other hardware setups as needed.
+
+## ROCm
+
+### Example hardware configuration
+
+Repeat the same nested structure for ROCm setups as needed:
+
+- `#### Environment`
+- `#### Command`
+- `#### Verification`
+- `#### Notes`
+
+## NPU
+
+### Example hardware configuration
+
+Repeat the same nested structure for NPU setups as needed:
+
+- `#### Environment`
+- `#### Command`
+- `#### Verification`
+- `#### Notes`
diff --git a/requirements/common.txt b/requirements/common.txt
index 1f44d343c62..63e16d580ff 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -1,7 +1,6 @@
# Common dependencies for all platforms
av>=14.0.0
omegaconf>=2.3.0
-resampy>=0.4.3
diffusers>=0.36.0
accelerate==1.12.0
soundfile>=0.13.1
diff --git a/tests/comfyui/test_comfyui_integration.py b/tests/comfyui/test_comfyui_integration.py
index 80e86d82412..5164f3b9acb 100644
--- a/tests/comfyui/test_comfyui_integration.py
+++ b/tests/comfyui/test_comfyui_integration.py
@@ -523,6 +523,7 @@ def run_server():
"Qwen/Qwen-Image-Edit",
True,
id="image-to-image-dalle-endpoint",
+ marks=pytest.mark.skip(reason="Temporarily disabled due to failure."),
),
pytest.param(
ServerCase(
diff --git a/tests/config/__init__.py b/tests/config/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/config/test_pipeline_registry.py b/tests/config/test_pipeline_registry.py
new file mode 100644
index 00000000000..3483d530c63
--- /dev/null
+++ b/tests/config/test_pipeline_registry.py
@@ -0,0 +1,111 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for the central pipeline registry (2.5/N)."""
+
+from __future__ import annotations
+
+import pytest
+
+from vllm_omni.config.pipeline_registry import (
+ _DIFFUSION_PIPELINES,
+ _OMNI_PIPELINES,
+ _VLLM_OMNI_PIPELINES,
+)
+from vllm_omni.config.stage_config import (
+ _PIPELINE_REGISTRY,
+ PipelineConfig,
+ StageExecutionType,
+ StagePipelineConfig,
+ register_pipeline,
+)
+
+
+class TestCentralRegistryDeclarations:
+ """Every in-tree pipeline must be declared exactly once in the central registry."""
+
+ def test_union_contains_all_omni(self):
+ for key in _OMNI_PIPELINES:
+ assert key in _VLLM_OMNI_PIPELINES
+
+ def test_union_contains_all_diffusion(self):
+ for key in _DIFFUSION_PIPELINES:
+ assert key in _VLLM_OMNI_PIPELINES
+
+ def test_no_duplicate_model_type_between_omni_and_diffusion(self):
+ overlap = set(_OMNI_PIPELINES) & set(_DIFFUSION_PIPELINES)
+ assert not overlap, f"Duplicate model_types across omni/diffusion: {overlap}"
+
+ def test_expected_omni_pipelines_present(self):
+ # Guard against accidental removal during future refactors.
+ assert "qwen2_5_omni" in _OMNI_PIPELINES
+ assert "qwen2_5_omni_thinker_only" in _OMNI_PIPELINES
+ assert "qwen3_omni_moe" in _OMNI_PIPELINES
+ assert "qwen3_tts" in _OMNI_PIPELINES
+
+
+class TestLazyLoading:
+ """Pipelines are imported only on first access."""
+
+ def test_contains_without_import(self):
+ # ``in`` hits the lazy map, not the loaded cache.
+ assert "qwen3_omni_moe" in _PIPELINE_REGISTRY
+
+ def test_getitem_loads_correct_pipeline(self):
+ pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+ assert pipeline.model_type == "qwen3_omni_moe"
+ assert pipeline.model_arch == "Qwen3OmniMoeForConditionalGeneration"
+
+ def test_unknown_model_type_returns_none_via_get(self):
+ assert _PIPELINE_REGISTRY.get("not_a_real_pipeline") is None
+
+ def test_unknown_model_type_raises_keyerror_via_getitem(self):
+ with pytest.raises(KeyError):
+ _PIPELINE_REGISTRY["not_a_real_pipeline"]
+
+ def test_iteration_yields_registered_pipelines(self):
+ keys = set(_PIPELINE_REGISTRY)
+ assert "qwen2_5_omni" in keys
+ assert "qwen3_omni_moe" in keys
+
+
+class TestDynamicRegistration:
+ """``register_pipeline()`` still works for plugins and tests."""
+
+ def test_register_adds_to_registry(self):
+ custom = PipelineConfig(
+ model_type="_test_dynamic_registration",
+ model_arch="DynamicTestModel",
+ stages=(
+ StagePipelineConfig(
+ stage_id=0,
+ model_stage="test",
+ execution_type=StageExecutionType.LLM_AR,
+ input_sources=(),
+ final_output=True,
+ ),
+ ),
+ )
+ register_pipeline(custom)
+ try:
+ assert "_test_dynamic_registration" in _PIPELINE_REGISTRY
+ assert _PIPELINE_REGISTRY["_test_dynamic_registration"] is custom
+ finally:
+ # Don't leak the test registration into other tests.
+ if "_test_dynamic_registration" in _PIPELINE_REGISTRY:
+ del _PIPELINE_REGISTRY["_test_dynamic_registration"]
+
+ def test_dynamic_registration_overrides_lazy_entry(self):
+ # Build a substitute for qwen3_omni_moe that we can distinguish.
+ original = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+ override = PipelineConfig(
+ model_type="qwen3_omni_moe",
+ model_arch="OverriddenArch",
+ stages=original.stages,
+ )
+ register_pipeline(override)
+ try:
+ assert _PIPELINE_REGISTRY["qwen3_omni_moe"].model_arch == "OverriddenArch"
+ finally:
+ # Remove the dynamic override so later tests see the original.
+ if "qwen3_omni_moe" in _PIPELINE_REGISTRY._loaded:
+ del _PIPELINE_REGISTRY["qwen3_omni_moe"]
diff --git a/tests/conftest.py b/tests/conftest.py
index ad1008b7263..77075f9525a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3429 +1,62 @@
-import atexit
-import base64
-import datetime
-import io
-import json
-import math
-import os
-import random
-import re
-import tempfile
-
-import requests
-
-os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-# Set CPU device for CI environments without GPU
-if "VLLM_TARGET_DEVICE" not in os.environ:
- os.environ["VLLM_TARGET_DEVICE"] = "cpu"
-
-import concurrent.futures
-import contextlib
-import gc
-import multiprocessing
-import socket
-import subprocess
-import sys
-import threading
-import time
-import uuid
-from collections.abc import Generator
-from dataclasses import dataclass
-from io import BytesIO
-from pathlib import Path
-from typing import Any, NamedTuple
-
-import cv2
-import numpy as np
-import psutil
-import pytest
-import soundfile as sf
-import torch
-import yaml
-from openai import OpenAI, omit
-from PIL import Image
-from transformers import pipeline
-from vllm import TextPrompt
-from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
-from vllm.logger import init_logger
-from vllm.utils.network_utils import get_open_port
-
-from vllm_omni.entrypoints.omni import Omni
-from vllm_omni.inputs.data import OmniSamplingParams
-from vllm_omni.outputs import OmniRequestOutput
-from vllm_omni.platforms import current_omni_platform
-
-logger = init_logger(__name__)
-
-
-PromptAudioInput = list[tuple[Any, int]] | tuple[Any, int] | None
-PromptImageInput = list[Any] | Any | None
-PromptVideoInput = list[Any] | Any | None
-
-_GENDER_PIPELINE = None
-# transformers.Pipeline is not thread-safe; concurrent e2e requests must serialize inference.
-_GENDER_PIPELINE_LOCK = threading.Lock()
-
-# int16 mono PCM from /v1/audio/speech when response_format=pcm (Qwen3-TTS code2wav output rate).
-_PCM_SPEECH_SAMPLE_RATE_HZ = 24_000
-
-
-class OmniServerParams(NamedTuple):
- model: str
- port: int | None = None
- stage_config_path: str | None = None
- server_args: list[str] | None = None
- env_dict: dict[str, str] | None = None
- use_omni: bool = True
- use_stage_cli: bool = False
- init_timeout: int | None = None
- stage_init_timeout: int | None = None # None defers to the server's own default (300 s)
-
-
-def assert_image_diffusion_response(
- response,
- request_config: dict[str, Any],
- run_level: str = None,
-) -> None:
- """
- Validate image diffusion response.
-
- Expected request_config schema:
- {
- "request_type": "image",
- "extra_body": {
- "num_outputs_per_prompt": 1,
- "width": ...,
- "height": ...,
- ...
- }
- }
- """
- assert response.images is not None, "Image response is None"
- assert len(response.images) > 0, "No images in response"
-
- extra_body = request_config.get("extra_body") or {}
-
- num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt")
- if num_outputs_per_prompt is not None:
- assert len(response.images) == num_outputs_per_prompt, (
- f"Expected {num_outputs_per_prompt} images, got {len(response.images)}"
- )
-
- if run_level == "advanced_model":
- width = extra_body.get("width")
- height = extra_body.get("height")
-
- if width is not None or height is not None:
- for img in response.images:
- assert_image_valid(img, width=width, height=height)
-
-
-def assert_video_diffusion_response(
- response,
- request_config: dict[str, Any],
- run_level: str = None,
-) -> None:
- """
- Validate video diffusion response.
-
- Expected request_config schema:
- {
- "request_type": "video",
- "form_data": {
- "prompt": "...",
- "num_frames": ...,
- "width": ...,
- "height": ...,
- "fps": ...,
- ...
- }
- }
- """
- form_data = request_config.get("form_data", {})
-
- assert response.videos is not None, "Video response is None"
- assert len(response.videos) > 0, "No videos in response"
-
- expected_frames = _maybe_int(form_data.get("num_frames"))
- expected_width = _maybe_int(form_data.get("width"))
- expected_height = _maybe_int(form_data.get("height"))
- expected_fps = _maybe_int(form_data.get("fps"))
-
- for vid_bytes in response.videos:
- assert_video_valid(
- vid_bytes,
- num_frames=expected_frames,
- width=expected_width,
- height=expected_height,
- fps=expected_fps,
- )
-
-
-def assert_audio_diffusion_response(
- response,
- request_config: dict[str, Any],
- run_level: str = None,
-) -> None:
- """
- Validate audio diffusion response.
- """
- raise NotImplementedError("Audio validation is not implemented yet")
-
-
-def _maybe_int(value: Any) -> int | None:
- if value is None:
- return None
- return int(value)
-
-
-def assert_image_valid(image: Path | Image.Image, *, width: int | None = None, height: int | None = None):
- """Assert the file is a loadable image with optional exact dimensions."""
- if isinstance(image, Path):
- assert image.exists(), f"Image not found: {image}"
- image = Image.open(image)
- image.load()
- assert image.width > 0 and image.height > 0
- if width is not None:
- assert image.width == width, f"Expected width={width}, got {image.width}"
- if height is not None:
- assert image.height == height, f"Expected height={height}, got {image.height}"
- return image
-
-
-def assert_video_valid(
- video: Path | bytes | BytesIO,
- *,
- num_frames: int | None = None,
- width: int | None = None,
- height: int | None = None,
- fps: float | None = None,
-) -> dict[str, int | float]:
- """Assert the MP4 has the expected resolution and exact frame count."""
- temp_path = None
- cap = None
- try:
- # Normalize input to file path
- if isinstance(video, Path):
- if not video.exists():
- raise AssertionError(f"Video file not found: {video}")
- video_path = str(video)
- else:
- # Create temp file for bytes/BytesIO
- suffix = ".mp4"
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix, mode="wb") as tmp:
- if isinstance(video, bytes):
- tmp.write(video)
- elif isinstance(video, BytesIO):
- tmp.write(video.getvalue())
- else:
- raise TypeError(f"Unsupported video type: {type(video)}")
- temp_path = Path(tmp.name)
- video_path = str(temp_path)
-
- # Open video capture
- cap = cv2.VideoCapture(video_path)
- if not cap.isOpened():
- raise AssertionError(f"Failed to open video: {video_path}")
-
- # Extract properties
- actual_num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
- actual_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
- actual_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
- actual_fps = cap.get(cv2.CAP_PROP_FPS)
-
- actual_num_frames = 0
- while True:
- ok, _frame = cap.read()
- if not ok:
- break
- actual_num_frames += 1
-
- # Basic validity checks
- if actual_num_frames <= 0:
- raise AssertionError(f"Invalid frame count: {actual_num_frames} (must be > 0)")
- if actual_width <= 0 or actual_height <= 0:
- raise AssertionError(f"Invalid dimensions: {actual_width}x{actual_height} (must be > 0)")
- if actual_fps <= 0:
- raise AssertionError(f"Invalid FPS: {actual_fps} (must be > 0)")
-
- # Validate against expectations
- if num_frames is not None:
- expected_num_frames = (num_frames // 4) * 4 + 1
- assert actual_num_frames == expected_num_frames, (
- f"Frame count mismatch: expected {num_frames}, got {actual_num_frames}"
- )
- if width is not None:
- assert actual_width == width, f"Width mismatch: expected {width}px, got {actual_width}px"
- if height is not None:
- assert actual_height == height, f"Height mismatch: expected {height}px, got {actual_height}px"
- if fps is not None:
- # Use tolerance for float comparison (codec rounding)
- assert abs(actual_fps - fps) < 0.5, f"FPS mismatch: expected {fps}, got {actual_fps:.2f}"
-
- return {"num_frames": actual_num_frames, "width": actual_width, "height": actual_height, "fps": actual_fps}
-
- except Exception as e:
- print(f"ERROR: {type(e).__name__}: {e}", flush=True)
- raise
-
- finally:
- # Cleanup resources
- if cap is not None:
- cap.release()
- if temp_path and temp_path.exists():
- try:
- temp_path.unlink()
- except OSError:
- pass
-
-
-def assert_audio_valid(
- audio_or_path: Path | np.ndarray,
- *,
- sample_rate: int,
- channels: int,
- duration_s: float,
-) -> None:
- """Assert WAV file or (batch, channels, samples) ndarray matches expected audio format."""
- expected_samples = int(duration_s * sample_rate)
- if isinstance(audio_or_path, np.ndarray):
- audio = audio_or_path
- assert audio.ndim == 3, f"Expected audio ndim=3 (batch, channels, samples), got shape {audio.shape}"
- assert audio.shape[0] == 1, f"Expected batch size 1, got {audio.shape[0]}"
- assert audio.shape[1] == channels, f"Expected {channels} channels, got {audio.shape[1]}"
- assert audio.shape[2] == expected_samples, (
- f"Expected {expected_samples} samples ({duration_s}s @ {sample_rate} Hz), got {audio.shape[2]}"
- )
- return
-
- path = audio_or_path
- assert path.exists(), f"Audio not found: {path}"
- info = sf.info(str(path))
- assert info.samplerate == sample_rate, f"Expected sample_rate={sample_rate}, got {info.samplerate}"
- assert info.channels == channels, f"Expected {channels} channel(s), got {info.channels}"
- assert info.frames == expected_samples, (
- f"Expected {expected_samples} frames ({duration_s}s @ {sample_rate} Hz), got {info.frames}"
- )
-
-
-def decode_b64_image(b64: str):
- img = Image.open(BytesIO(base64.b64decode(b64)))
- img.load()
- return img
-
-
-@pytest.fixture(scope="session")
-def model_prefix() -> str:
- """Optional model-path prefix from MODEL_PREFIX env var.
- Useful if models are downloaded to non-default local directories.
- """
- prefix = os.environ.get("MODEL_PREFIX", "")
- return f"{prefix.rstrip('/')}/" if prefix else ""
-
-
-@pytest.fixture(autouse=True)
-def default_vllm_config():
- """Set a default VllmConfig for all tests.
-
- This fixture is auto-used for all tests to ensure that any test
- that directly instantiates vLLM CustomOps (e.g., RMSNorm, LayerNorm)
- or model components has the required VllmConfig context.
-
- This fixture is required for vLLM 0.14.0+ where CustomOp initialization
- requires a VllmConfig context set via set_current_vllm_config().
- """
- from vllm.config import DeviceConfig, VllmConfig, set_current_vllm_config
-
- # Use CPU device if no GPU is available (e.g., in CI environments)
- has_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 0
- device = "cuda" if has_gpu else "cpu"
- device_config = DeviceConfig(device=device)
-
- with set_current_vllm_config(VllmConfig(device_config=device_config)):
- yield
-
-
-@pytest.fixture(autouse=True)
-def clean_gpu_memory_between_tests():
- print("\n=== PRE-TEST GPU CLEANUP ===")
- _run_pre_test_cleanup()
- yield
- _run_post_test_cleanup()
-
-
-@pytest.fixture(autouse=True)
-def log_test_name_before_test(request):
- print(f"--- Running test: {request.node.name}")
- yield
-
-
-def _run_pre_test_cleanup(enable_force=False):
- if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force:
- print("\nPre-test GPU cleanup skipped(Default off is typical when one worker/instance runs many tests.)\n")
- return
-
- print("\nPre-test GPU status:")
-
- num_gpus = torch.cuda.device_count()
- if num_gpus > 0:
- try:
- from tests.utils import wait_for_gpu_memory_to_clear
-
- wait_for_gpu_memory_to_clear(
- devices=list(range(num_gpus)),
- threshold_ratio=0.05,
- )
- except Exception as e:
- print(f"Pre-test cleanup note: {e}")
-
-
-def _run_post_test_cleanup(enable_force=False):
- if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force:
- print("GPU cleanup disabled")
- return
-
- if torch.cuda.is_available():
- gc.collect()
- torch.cuda.empty_cache()
-
- print("Post-test GPU status:")
- _print_gpu_processes()
-
-
-def _print_gpu_processes():
- """Print GPU information including nvidia-smi and system processes"""
-
- print("\n" + "=" * 80)
- print("NVIDIA GPU Information (nvidia-smi)")
- print("=" * 80)
-
- try:
- nvidia_result = subprocess.run(
- ["nvidia-smi"],
- capture_output=True,
- text=True,
- timeout=5,
- )
-
- if nvidia_result.returncode == 0:
- lines = nvidia_result.stdout.strip().split("\n")
- for line in lines[:20]:
- print(line)
-
- if len(lines) > 20:
- print(f"... (showing first 20 of {len(lines)} lines)")
- else:
- print("nvidia-smi command failed")
-
- except (subprocess.TimeoutExpired, FileNotFoundError):
- print("nvidia-smi not available or timed out")
- except Exception as e:
- print(f"Error running nvidia-smi: {e}")
-
- print("\n" + "=" * 80)
- print("Detailed GPU Processes (nvidia-smi pmon)")
- print("=" * 80)
-
- try:
- pmon_result = subprocess.run(
- ["nvidia-smi", "pmon", "-c", "1"],
- capture_output=True,
- text=True,
- timeout=3,
- )
-
- if pmon_result.returncode == 0 and pmon_result.stdout.strip():
- print(pmon_result.stdout)
- else:
- print("No active GPU processes found via nvidia-smi pmon")
-
- except Exception:
- print("nvidia-smi pmon not available")
-
- print("\n" + "=" * 80)
- print("System Processes with GPU keywords")
- print("=" * 80)
-
-
-def dummy_messages_from_mix_data(
- system_prompt: dict[str, Any] = None,
- video_data_url: Any = None,
- audio_data_url: Any = None,
- image_data_url: Any = None,
- content_text: str = None,
-):
- """Create messages with video、image、audio data URL for OpenAI API."""
-
- if content_text is not None:
- content = [{"type": "text", "text": content_text}]
- else:
- content = []
-
- media_items = []
- if isinstance(video_data_url, list):
- for video_url in video_data_url:
- media_items.append((video_url, "video"))
- else:
- media_items.append((video_data_url, "video"))
-
- if isinstance(image_data_url, list):
- for url in image_data_url:
- media_items.append((url, "image"))
- else:
- media_items.append((image_data_url, "image"))
-
- if isinstance(audio_data_url, list):
- for url in audio_data_url:
- media_items.append((url, "audio"))
- else:
- media_items.append((audio_data_url, "audio"))
-
- content.extend(
- {"type": f"{media_type}_url", f"{media_type}_url": {"url": url}}
- for url, media_type in media_items
- if url is not None
- )
- messages = [{"role": "user", "content": content}]
- if system_prompt is not None:
- messages = [system_prompt] + messages
- return messages
-
-
-def generate_synthetic_audio(
- duration: int, # seconds
- num_channels: int, # 1:Mono,2:Stereo 5:5.1 surround sound
- sample_rate: int = 48000, # Default use 48000Hz.
- save_to_file: bool = False,
-) -> dict[str, Any]:
- """
- Generate TTS speech with pyttsx3 and return base64 string.
- """
-
- import pyttsx3
- import soundfile as sf
-
- def _pick_voice(engine: pyttsx3.Engine) -> str | None:
- voices = engine.getProperty("voices")
- if not voices:
- return None
-
- preferred_tokens = (
- "natural",
- "jenny",
- "sonia",
- "susan",
- "zira",
- "aria",
- "hazel",
- "samantha",
- "ava",
- "allison",
- "female",
- "woman",
- "english-us",
- "en-us",
- "english",
- )
- discouraged_tokens = (
- "espeak",
- "robot",
- "mbrola",
- "microsoft david",
- "male",
- "man",
- )
-
- best_voice = voices[0]
- best_score = float("-inf")
- for voice in voices:
- voice_text = f"{getattr(voice, 'id', '')} {getattr(voice, 'name', '')}".lower()
- voice_languages = " ".join(
- lang.decode(errors="ignore") if isinstance(lang, bytes) else str(lang)
- for lang in getattr(voice, "languages", [])
- ).lower()
- combined_text = f"{voice_text} {voice_languages}"
- score = 0
- for idx, token in enumerate(preferred_tokens):
- if token in combined_text:
- score += 20 - idx
- for token in discouraged_tokens:
- if token in combined_text:
- score -= 10
- if "english" in combined_text or "en_" in combined_text or "en-" in combined_text:
- score += 4
- if "en-us" in combined_text or "english-us" in combined_text:
- score += 4
- if score > best_score:
- best_score = score
- best_voice = voice
-
- return best_voice.id
-
- def _resample_audio(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
- if src_sr == dst_sr or len(audio) == 0:
- return audio.astype(np.float32)
-
- src_len = audio.shape[0]
- dst_len = max(1, int(round(src_len * float(dst_sr) / float(src_sr))))
- src_idx = np.arange(src_len, dtype=np.float32)
- dst_idx = np.linspace(0, src_len - 1, dst_len, dtype=np.float32)
-
- resampled_channels: list[np.ndarray] = []
- for ch in range(audio.shape[1]):
- resampled_channels.append(np.interp(dst_idx, src_idx, audio[:, ch]).astype(np.float32))
- return np.stack(resampled_channels, axis=1)
-
- def _match_channels(audio: np.ndarray, target_channels: int) -> np.ndarray:
- current_channels = audio.shape[1]
- if current_channels == target_channels:
- return audio.astype(np.float32)
- if target_channels == 1:
- return np.mean(audio, axis=1, keepdims=True, dtype=np.float32)
- if current_channels == 1:
- return np.repeat(audio, target_channels, axis=1).astype(np.float32)
-
- collapsed = np.mean(audio, axis=1, keepdims=True, dtype=np.float32)
- return np.repeat(collapsed, target_channels, axis=1).astype(np.float32)
-
- def _trim_silence(audio: np.ndarray, threshold: float = 0.01) -> np.ndarray:
- if len(audio) == 0:
- return audio
- energy = np.max(np.abs(audio), axis=1)
- voiced = np.where(energy > threshold)[0]
- if len(voiced) == 0:
- return audio
- start = max(0, int(voiced[0]) - int(sample_rate * 0.02))
- end = min(len(audio), int(voiced[-1]) + int(sample_rate * 0.04) + 1)
- return audio[start:end]
-
- def _enhance_speech(audio: np.ndarray) -> np.ndarray:
- if len(audio) == 0:
- return audio.astype(np.float32)
- enhanced = audio.astype(np.float32).copy()
- enhanced -= np.mean(enhanced, axis=0, keepdims=True, dtype=np.float32)
- if len(enhanced) > 1:
- preemphasis = enhanced.copy()
- preemphasis[1:] = enhanced[1:] - 0.94 * enhanced[:-1]
- enhanced = 0.7 * enhanced + 0.3 * preemphasis
- # Mild dynamic-range compression for ASR/TTS robustness.
- enhanced = np.sign(enhanced) * np.sqrt(np.abs(enhanced))
- # Light fade to avoid clicks after trimming/repeating.
- fade = min(len(enhanced) // 4, max(1, int(sample_rate * 0.01)))
- if fade > 1:
- ramp_in = np.linspace(0.0, 1.0, fade, dtype=np.float32)
- ramp_out = np.linspace(1.0, 0.0, fade, dtype=np.float32)
- enhanced[:fade] *= ramp_in[:, None]
- enhanced[-fade:] *= ramp_out[:, None]
- peak = float(np.max(np.abs(enhanced)))
- if peak > 1e-8:
- enhanced = enhanced / peak * 0.95
- return enhanced.astype(np.float32)
-
- phrase_text = "test"
- num_samples = int(sample_rate * max(1, duration))
- audio_data = np.zeros((num_samples, num_channels), dtype=np.float32)
-
- engine = pyttsx3.init()
- engine.setProperty("rate", 112)
- engine.setProperty("volume", 1.0)
- selected_voice = _pick_voice(engine)
- if selected_voice is not None:
- engine.setProperty("voice", selected_voice)
-
- temp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
- temp_wav.close()
-
- try:
- engine.save_to_file(phrase_text, temp_wav.name)
- engine.runAndWait()
- engine.stop()
-
- ready = False
- for _ in range(50):
- if os.path.exists(temp_wav.name) and os.path.getsize(temp_wav.name) > 44:
- ready = True
- break
- time.sleep(0.1)
-
- if not ready:
- raise RuntimeError("pyttsx3 did not produce a WAV file in time.")
-
- tts_audio, tts_sr = sf.read(temp_wav.name, dtype="float32", always_2d=True)
- finally:
- if os.path.exists(temp_wav.name):
- os.unlink(temp_wav.name)
-
- if len(tts_audio) == 0:
- raise RuntimeError("pyttsx3 produced an empty WAV file.")
-
- tts_audio = _resample_audio(tts_audio, tts_sr, sample_rate)
- tts_audio = _match_channels(tts_audio, num_channels)
- tts_audio = _trim_silence(tts_audio, threshold=0.012)
- tts_audio = _enhance_speech(tts_audio)
-
- lead_silence = min(int(sample_rate * 0.02), num_samples // 8)
- pause_samples = int(sample_rate * 0.18)
- start = lead_silence
- phrase_len = tts_audio.shape[0]
-
- while start < num_samples:
- take = min(phrase_len, num_samples - start)
- audio_data[start : start + take] = tts_audio[:take]
- start += phrase_len + pause_samples
-
- max_amp = float(np.max(np.abs(audio_data)))
- if max_amp > 0:
- audio_data = audio_data / max_amp * 0.95
-
- audio_bytes: bytes | None = None
- output_path: str | None = None
- result: dict[str, Any] = {
- "np_array": audio_data.copy(),
- }
-
- if save_to_file:
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
- output_path = f"audio_{num_channels}ch_{timestamp}.wav"
-
- try:
- sf.write(output_path, audio_data, sample_rate, format="WAV", subtype="PCM_16")
- print(f"Audio saved: {output_path}")
-
- with open(output_path, "rb") as f:
- audio_bytes = f.read()
- except Exception as e:
- print(f"Save failed: {e}")
- save_to_file = False
-
- # If not saving or save failed, create in memory
- if not save_to_file or audio_bytes is None:
- buffer = io.BytesIO()
- sf.write(buffer, audio_data, sample_rate, format="WAV", subtype="PCM_16")
- buffer.seek(0)
- audio_bytes = buffer.read()
-
- # Return result
- base64_audio = base64.b64encode(audio_bytes).decode("utf-8")
- result["base64"] = base64_audio
- # Always include file_path to avoid KeyError in callers.
- result["file_path"] = output_path if save_to_file and output_path else None
-
- return result
-
-
-def _mux_mp4_bytes_with_synthetic_audio(
- video_mp4_bytes: bytes,
- *,
- num_frames: int,
- fps: float = 30.0,
- sample_rate: int = 48000,
-) -> bytes:
- """
- Mux a video-only MP4 with mono TTS audio from :func:`generate_synthetic_audio` (AAC).
-
- Audio length is at least the video duration in whole seconds (rounded up); ffmpeg
- ``-shortest`` trims to the video when the WAV is longer.
-
- Uses ffmpeg from ``imageio_ffmpeg`` when available, else ``ffmpeg`` on PATH.
- If TTS or mux fails, returns ``video_mp4_bytes`` unchanged.
-
- Mux subprocess does **not** use ``capture_output=True``: ffmpeg can block writing
- to a full stderr pipe while :func:`subprocess.run` waits for exit (classic deadlock).
- """
- duration_sec = num_frames / fps if fps > 0 else 0.0
- # generate_synthetic_audio(duration=int) uses at least 1s of buffer internally
- duration_int = max(1, int(math.ceil(duration_sec)))
-
- try:
- audio_result = generate_synthetic_audio(
- duration=duration_int,
- num_channels=1,
- sample_rate=sample_rate,
- save_to_file=False,
- )
- audio_pcm = audio_result["np_array"]
- except Exception as e:
- logger.warning("Synthetic video: generate_synthetic_audio failed (%s); using video-only MP4.", e)
- return video_mp4_bytes
-
- try:
- import imageio_ffmpeg
-
- ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe()
- except Exception:
- ffmpeg_exe = "ffmpeg"
-
- import tempfile
-
- try:
- with tempfile.TemporaryDirectory(prefix="syn_vid_mux_") as tmp:
- vid_path = os.path.join(tmp, "video.mp4")
- wav_path = os.path.join(tmp, "audio.wav")
- out_path = os.path.join(tmp, "out.mp4")
- with open(vid_path, "wb") as f:
- f.write(video_mp4_bytes)
- sf.write(wav_path, audio_pcm, sample_rate, format="WAV", subtype="PCM_16")
- cmd = [
- ffmpeg_exe,
- "-y",
- "-nostdin",
- "-hide_banner",
- "-loglevel",
- "error",
- "-i",
- vid_path,
- "-i",
- wav_path,
- "-c:v",
- "copy",
- "-c:a",
- "aac",
- "-b:a",
- "128k",
- "-shortest",
- "-movflags",
- "+faststart",
- out_path,
- ]
- subprocess.run(
- cmd,
- check=True,
- stdin=subprocess.DEVNULL,
- timeout=300,
- )
- with open(out_path, "rb") as f:
- return f.read()
- except (
- FileNotFoundError,
- subprocess.CalledProcessError,
- subprocess.TimeoutExpired,
- OSError,
- ) as e:
- logger.warning("Synthetic video: audio mux failed (%s); using video-only MP4.", e)
- return video_mp4_bytes
-
-
-def generate_synthetic_video(
- width: int,
- height: int,
- num_frames: int,
- save_to_file: bool = False,
- *,
- embed_audio: bool = False,
-) -> dict[str, Any]:
- """Generate synthetic video with bouncing balls and base64 MP4.
-
- When ``embed_audio`` is True, muxes mono AAC from :func:`generate_synthetic_audio`
- (TTS + ffmpeg) into the MP4; otherwise returns video-only MP4 (faster when tests do
- not need an audio track).
- """
-
- import cv2
- import imageio
-
- # Create random balls
- num_balls = random.randint(3, 8)
- balls = []
-
- for _ in range(num_balls):
- radius = min(width, height) // 8
- if radius < 1:
- raise ValueError(f"Video dimensions ({width}x{height}) are too small for synthetic video generation")
- x = random.randint(radius, width - radius)
- y = random.randint(radius, height - radius)
-
- speed = random.uniform(3.0, 8.0)
- angle = random.uniform(0, 2 * math.pi)
- vx = speed * math.cos(angle)
- vy = speed * math.sin(angle)
-
- # OpenCV uses BGR format, but imageio expects RGB
- # We'll create in BGR first, then convert to RGB later
- color_bgr = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255))
-
- balls.append({"x": x, "y": y, "vx": vx, "vy": vy, "radius": radius, "color_bgr": color_bgr})
-
- # Generate video frames
- video_frames = []
-
- for frame_idx in range(num_frames):
- # Create black background (BGR format)
- frame_bgr = np.zeros((height, width, 3), dtype=np.uint8)
-
- for ball in balls:
- # Update position
- ball["x"] += ball["vx"]
- ball["y"] += ball["vy"]
-
- # Boundary collision detection
- if ball["x"] - ball["radius"] <= 0 or ball["x"] + ball["radius"] >= width:
- ball["vx"] = -ball["vx"]
- ball["x"] = max(ball["radius"], min(width - ball["radius"], ball["x"]))
-
- if ball["y"] - ball["radius"] <= 0 or ball["y"] + ball["radius"] >= height:
- ball["vy"] = -ball["vy"]
- ball["y"] = max(ball["radius"], min(height - ball["radius"], ball["y"]))
-
- # Use cv2 to draw circle
- x, y = int(ball["x"]), int(ball["y"])
- radius = ball["radius"]
-
- # Draw solid circle (main circle)
- cv2.circle(frame_bgr, (x, y), radius, ball["color_bgr"], -1)
-
- # Add simple 3D effect: draw a brighter center
- if radius > 3: # Only add highlight when radius is large enough
- highlight_radius = max(1, radius // 2)
- highlight_x = max(highlight_radius, min(x - radius // 4, width - highlight_radius))
- highlight_y = max(highlight_radius, min(y - radius // 4, height - highlight_radius))
-
- # Create highlight color (brighter)
- highlight_color = tuple(min(c + 40, 255) for c in ball["color_bgr"])
- cv2.circle(frame_bgr, (highlight_x, highlight_y), highlight_radius, highlight_color, -1)
-
- # Convert BGR to RGB for imageio
- frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
- video_frames.append(frame_rgb)
-
- video_array = np.array(video_frames)
- result = {
- "np_array": video_array,
- }
- saved_file_path = None
-
- fps = 30
- buffer = io.BytesIO()
- writer_kwargs = {
- "format": "mp4",
- "fps": fps,
- "codec": "libx264",
- "quality": 7,
- "pixelformat": "yuv420p",
- "macro_block_size": 16,
- "ffmpeg_params": [
- "-preset",
- "medium",
- "-crf",
- "23",
- "-movflags",
- "+faststart",
- "-pix_fmt",
- "yuv420p",
- "-vf",
- f"scale={width}:{height}",
- ],
- }
-
- try:
- with imageio.get_writer(buffer, **writer_kwargs) as writer:
- for frame in video_frames:
- writer.append_data(frame)
- buffer.seek(0)
- video_only_bytes = buffer.read()
- except Exception as e:
- print(f"Warning: Failed to encode synthetic video: {e}")
- raise
-
- if embed_audio:
- video_bytes = _mux_mp4_bytes_with_synthetic_audio(video_only_bytes, num_frames=num_frames, fps=float(fps))
- else:
- video_bytes = video_only_bytes
-
- if save_to_file:
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
- output_path = f"video_{width}x{height}_{timestamp}.mp4"
- try:
- with open(output_path, "wb") as f:
- f.write(video_bytes)
- saved_file_path = output_path
- print(f"Video saved to: {saved_file_path}")
- except Exception as e:
- print(f"Warning: Failed to save video to file {output_path}: {e}")
-
- base64_video = base64.b64encode(video_bytes).decode("utf-8")
-
- result["base64"] = base64_video
- if save_to_file and saved_file_path:
- result["file_path"] = saved_file_path
-
- return result
-
-
-def generate_synthetic_image(width: int, height: int, save_to_file: bool = False) -> dict[str, Any]:
- """Generate synthetic image with randomly colored squares and return base64 string."""
- from PIL import Image, ImageDraw
-
- # Create white background
- image = Image.new("RGB", (width, height), (255, 255, 255))
- draw = ImageDraw.Draw(image)
-
- # Generate random number of squares
- num_squares = random.randint(3, 8)
-
- for _ in range(num_squares):
- # Random square size
- square_size = random.randint(min(width, height) // 8, min(width, height) // 4)
-
- # Random position
- x = random.randint(0, width - square_size - 1)
- y = random.randint(0, height - square_size - 1)
-
- # Random color
- color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
-
- # Random border width
- border_width = random.randint(1, 5)
-
- # Draw square
- draw.rectangle([x, y, x + square_size, y + square_size], fill=color, outline=(0, 0, 0), width=border_width)
-
- image_array = np.array(image)
- result = {"np_array": image_array.copy()}
-
- # Handle file saving
- image_bytes = None
- saved_file_path = None
-
- if save_to_file:
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
- output_path = f"image_{width}x{height}_{timestamp}.jpg"
-
- try:
- # Save image to file
- image.save(output_path, format="JPEG", quality=85, optimize=True)
- saved_file_path = output_path
- print(f"Image saved to: {saved_file_path}")
-
- # Read file for base64 encoding
- with open(output_path, "rb") as f:
- image_bytes = f.read()
-
- except Exception as e:
- print(f"Warning: Failed to save image to file {output_path}: {e}")
- save_to_file = False
-
- # If not saving or save failed, create in memory
- if not save_to_file or image_bytes is None:
- buffer = io.BytesIO()
- image.save(buffer, format="JPEG", quality=85, optimize=True)
- buffer.seek(0)
- image_bytes = buffer.read()
-
- # Generate base64
- base64_image = base64.b64encode(image_bytes).decode("utf-8")
-
- # Return result
- result["base64"] = base64_image
- if save_to_file and saved_file_path:
- result["file_path"] = saved_file_path
-
- return result
-
-
-def preprocess_text(text):
- import opencc
-
- word_to_num = {
- "zero": "0",
- "one": "1",
- "two": "2",
- "three": "3",
- "four": "4",
- "five": "5",
- "six": "6",
- "seven": "7",
- "eight": "8",
- "nine": "9",
- "ten": "10",
- }
-
- for word, num in word_to_num.items():
- pattern = r"\b" + re.escape(word) + r"\b"
- text = re.sub(pattern, num, text, flags=re.IGNORECASE)
-
- text = re.sub(r"[^\w\s]", "", text)
- text = re.sub(r"\s+", " ", text)
- cc = opencc.OpenCC("t2s")
- text = cc.convert(text)
-
- # Special handling for spaces between Chinese characters:
- # - Keep single spaces between English words/numbers
- # - Remove spaces only when surrounded by Chinese characters on both sides to prevent incorrect word segmentation
- text = re.sub(r"(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])", "", text)
-
- return text.lower().strip()
-
-
-def cosine_similarity_text(text1, text2, n: int = 3):
- from collections import Counter
-
- if not text1 or not text2:
- return 0.0
-
- text1 = preprocess_text(text1)
- text2 = preprocess_text(text2)
- print(f"cosine similarity text1 is: {text1}, text2 is: {text2}")
-
- ngrams1 = [text1[i : i + n] for i in range(len(text1) - n + 1)]
- ngrams2 = [text2[i : i + n] for i in range(len(text2) - n + 1)]
-
- counter1 = Counter(ngrams1)
- counter2 = Counter(ngrams2)
-
- all_ngrams = set(counter1.keys()) | set(counter2.keys())
- vec1 = [counter1.get(ng, 0) for ng in all_ngrams]
- vec2 = [counter2.get(ng, 0) for ng in all_ngrams]
-
- dot_product = sum(a * b for a, b in zip(vec1, vec2))
- norm1 = sum(a * a for a in vec1) ** 0.5
- norm2 = sum(b * b for b in vec2) ** 0.5
-
- if norm1 == 0 or norm2 == 0:
- return 0.0
- return dot_product / (norm1 * norm2)
-
-
-def convert_audio_to_text(audio_data):
- """
- Convert base64 encoded audio data to text using speech recognition.
- """
- audio_data = base64.b64decode(audio_data)
- output_path = f"./test_{uuid.uuid4().hex}.wav"
- with open(output_path, "wb") as audio_file:
- audio_file.write(audio_data)
-
- print(f"audio data is saved: {output_path}")
- text = convert_audio_file_to_text(output_path=output_path)
- return text
-
-
-def _merge_base64_audio_to_segment(base64_list: list[str]):
- """Merge a list of base64-encoded audio chunks into one pydub AudioSegment."""
- from pydub import AudioSegment
-
- merged = None
- for b64 in base64_list:
- raw = base64.b64decode(b64.split(",", 1)[-1])
- seg = AudioSegment.from_file(io.BytesIO(raw))
- merged = seg if merged is None else merged + seg
- return merged
-
-
-@contextlib.contextmanager
-def _serialize_whisper_small_model_download():
- """Serialize Whisper ``small`` cache writes across processes (Linux; ``fcntl``)."""
- import fcntl
-
- lock_path = Path.home() / ".cache" / "whisper" / ".small_model_download.lock"
- lock_path.parent.mkdir(parents=True, exist_ok=True)
- f = open(lock_path, "a+b")
- try:
- fcntl.flock(f.fileno(), fcntl.LOCK_EX)
- yield
- finally:
- fcntl.flock(f.fileno(), fcntl.LOCK_UN)
- f.close()
-
-
-def _whisper_transcribe_in_current_process(output_path: str) -> str:
- import whisper
-
- # Multi-GPU: use last visible device to avoid colliding with default device 0; single device uses 0.
- device_index = None
- if current_omni_platform.is_available():
- n = current_omni_platform.get_device_count()
- if n == 1:
- device_index = 0
- elif n > 1:
- device_index = n - 1
-
- if device_index is not None:
- torch_device = current_omni_platform.get_torch_device(device_index)
- current_omni_platform.set_device(torch_device)
- device = str(torch_device)
- use_accelerator = True
- else:
- use_accelerator = False
- device = "cpu"
- with _serialize_whisper_small_model_download():
- model = whisper.load_model("small", device=device)
- try:
- text = model.transcribe(
- output_path,
- temperature=0.0,
- word_timestamps=True,
- condition_on_previous_text=False,
- )["text"]
- finally:
- del model
- gc.collect()
- if use_accelerator:
- current_omni_platform.synchronize()
- current_omni_platform.empty_cache()
-
- return text or ""
-
-
-def convert_audio_file_to_text(output_path: str) -> str:
- """Convert an audio file to text in an isolated subprocess (spawn)."""
- ctx = multiprocessing.get_context("spawn")
- with concurrent.futures.ProcessPoolExecutor(max_workers=1, mp_context=ctx) as executor:
- future = executor.submit(_whisper_transcribe_in_current_process, output_path)
- return future.result()
-
-
-def convert_audio_bytes_to_text(raw_bytes: bytes) -> str:
- """
- Write container audio bytes (WAV, etc.) to a temp WAV file suitable for Whisper/ffmpeg.
- Normalizes with soundfile to PCM_16 WAV when possible to avoid codec issues.
- """
- output_path = f"./test_{uuid.uuid4().hex}.wav"
- data, samplerate = sf.read(io.BytesIO(raw_bytes))
- sf.write(output_path, data, samplerate, format="WAV", subtype="PCM_16")
- text = convert_audio_file_to_text(output_path)
- return text
-
-
-def modify_stage_config(
- yaml_path: str,
- updates: dict[str, Any] = None,
- deletes: dict[str, Any] = None,
-) -> str:
- """
- Modify configurations in a YAML file, supporting both top-level and stage-specific modifications,
- including addition, modification, and deletion of configurations.
-
- Args:
- yaml_path: Path to the YAML configuration file.
- updates: Dictionary containing both top-level and stage-specific modifications to add or update.
- Format: {
- 'async_chunk': True,
- 'stage_args': {
- 0: {'engine_args.max_model_len': 5800},
- 1: {'engine_args.max_num_seqs': 2}
- }
- }
- deletes: Dictionary containing configurations to delete.
- Format: {
- 'old_config': None, # Delete entire key
- 'stage_args': {
- 0: ['engine_args.old_param'],
- 1: ['runtime.unused_setting']
- }
- }
-
- Returns:
- str: Path to the newly created modified YAML file with timestamp suffix.
- """
- path = Path(yaml_path)
- if not path.exists():
- raise FileNotFoundError(f"yaml does not exist: {path}")
-
- try:
- with open(yaml_path, encoding="utf-8") as f:
- config = yaml.safe_load(f) or {}
- except Exception as e:
- raise ValueError(f"Cannot parse YAML file: {e}")
-
- # Helper function to apply update
- def apply_update(config_dict: dict, key_path: str, value: Any) -> None:
- """Apply update to dictionary using dot-separated path."""
- # Handle direct list assignment (e.g., engine_input_source: [1, 2])
- if "." not in key_path:
- # Simple key, set directly
- config_dict[key_path] = value
- return
-
- current = config_dict
- keys = key_path.split(".")
-
- for i in range(len(keys) - 1):
- key = keys[i]
-
- # Handle list indices
- if key.isdigit() and isinstance(current, list):
- index = int(key)
- if index < 0:
- raise ValueError(f"Negative list index not allowed: {index}")
- if index >= len(current):
- # Expand list if needed
- while len(current) <= index:
- # If we need to go deeper (more keys after this), create a dict
- # Otherwise, create None placeholder
- current.append({} if i < len(keys) - 2 else None)
- current = current[index]
- elif isinstance(current, dict):
- # Handle dictionary keys
- if key not in current:
- # If there are more keys after this, create appropriate structure
- if i < len(keys) - 1:
- # Check if next key is a digit (list index) or string (dict key)
- if keys[i + 1].isdigit():
- current[key] = []
- else:
- current[key] = {}
- else:
- # This is the last key, create based on value type
- current[key] = [] if isinstance(value, list) else {}
- elif not isinstance(current[key], (dict, list)) and i < len(keys) - 1:
- # If current value is not dict/list but we need to go deeper, replace it
- if keys[i + 1].isdigit():
- current[key] = []
- else:
- current[key] = {}
- current = current[key]
- else:
- # Current is not a dict or list, cannot traverse further
- raise TypeError(
- f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list. It's a {type(current).__name__}"
- )
-
- # Set the final value
- last_key = keys[-1]
- if isinstance(current, list) and last_key.isdigit():
- # Setting a value in a list by index
- index = int(last_key)
- if index < 0:
- raise ValueError(f"Negative list index not allowed: {index}")
- if index >= len(current):
- # Expand list if needed
- while len(current) <= index:
- current.append(None)
- current[index] = value
- elif isinstance(current, dict):
- # Special case: if the value is a list and we're setting a top-level key
- # Example: updating engine_input_source with [1, 2]
- current[last_key] = value
- else:
- # Current is not a dict, cannot set key
- raise TypeError(f"Cannot set value at {key_path}. Current type is {type(current).__name__}, expected dict.")
-
- # Helper function to delete by path
- def delete_by_path(config_dict: dict, path: str) -> None:
- """Delete configuration by dot-separated path."""
- if not path:
- return
-
- current = config_dict
- keys = path.split(".")
-
- # Traverse to the parent
- for i in range(len(keys) - 1):
- key = keys[i]
-
- # Handle list indices
- if key.isdigit() and isinstance(current, list):
- index = int(key)
- if index < 0 or index >= len(current):
- raise KeyError(f"List index {index} out of bounds")
- current = current[index]
- elif isinstance(current, dict):
- if key not in current:
- raise KeyError(f"Path {'.'.join(keys[: i + 1])} does not exist")
- current = current[key]
- else:
- raise TypeError(
- f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list. It's a {type(current).__name__}"
- )
-
- # Delete the item
- last_key = keys[-1]
-
- if isinstance(current, list) and last_key.isdigit():
- index = int(last_key)
- if index < 0 or index >= len(current):
- raise KeyError(f"List index {index} out of bounds")
- del current[index]
- elif isinstance(current, dict) and last_key in current:
- del current[last_key]
- else:
- print(f"Path {path} does not exist")
-
- # Apply deletions first
- if deletes:
- for key, value in deletes.items():
- if key == "stage_args":
- if value and isinstance(value, dict):
- stage_args = config.get("stage_args", [])
- if not stage_args:
- raise ValueError("stage_args does not exist in config")
-
- for stage_id, delete_paths in value.items():
- if not delete_paths:
- continue
-
- # Find stage by ID
- target_stage = None
- for stage in stage_args:
- if stage.get("stage_id") == int(stage_id):
- target_stage = stage
- break
-
- if target_stage is None:
- continue
-
- # Delete specified paths in this stage
- # Avoid shadowing the original YAML Path used for the output filename below.
- for delete_path in delete_paths:
- if delete_path: # Skip empty paths
- delete_by_path(target_stage, delete_path)
- elif "." in key:
- # Delete using dot-separated path
- delete_by_path(config, key)
- elif value is None and key in config:
- # Delete entire key
- del config[key]
-
- # Apply updates
- if updates:
- for key, value in updates.items():
- if key == "stage_args":
- if value and isinstance(value, dict):
- stage_args = config.get("stage_args", [])
- if not stage_args:
- raise ValueError("stage_args does not exist in config")
-
- for stage_id, stage_updates in value.items():
- # Find stage by ID
- target_stage = None
- for stage in stage_args:
- if stage.get("stage_id") == int(stage_id):
- target_stage = stage
- break
-
- if target_stage is None:
- available_ids = [s.get("stage_id") for s in stage_args if "stage_id" in s]
- raise KeyError(f"Stage ID {stage_id} not found, available: {available_ids}")
-
- # Apply updates to this stage
- for update_path, val in stage_updates.items():
- # Check if this is a simple key (not dot-separated)
- # Example: 'engine_input_source' vs 'engine_args.max_model_len'
- if "." not in update_path:
- # Direct key assignment (e.g., updating a list value)
- target_stage[update_path] = val
- else:
- # Dot-separated path (e.g., nested dict access)
- apply_update(target_stage, update_path, val)
- elif "." in key:
- # Apply using dot-separated path
- apply_update(config, key, value)
- else:
- # Direct top-level key
- config[key] = value
-
- # Unique suffix: multiple modify_stage_config calls in one process often run
- # within the same second (e.g. test_qwen3_omni_expansion imports both
- # get_chunk_config and get_batch_token_config). int(time.time()) would collide
- # and the later write would overwrite the earlier YAML on disk.
- # Keep generated configs outside the repo and delete them when pytest exits.
- output_fd, output_path = tempfile.mkstemp(prefix=f"{path.stem}_", suffix=".yaml")
- atexit.register(Path(output_path).unlink, missing_ok=True)
-
- with os.fdopen(output_fd, "w", encoding="utf-8") as f:
- yaml.dump(config, f, default_flow_style=None, sort_keys=False, allow_unicode=True, indent=2)
-
- return str(output_path)
-
-
-class OmniServer:
- """Omniserver for vLLM-Omni tests."""
-
- def __init__(
- self,
- model: str,
- serve_args: list[str],
- *,
- port: int | None = None,
- env_dict: dict[str, str] | None = None,
- use_omni: bool = True,
- ) -> None:
- _run_pre_test_cleanup(enable_force=True)
- _run_post_test_cleanup(enable_force=True)
- cleanup_dist_env_and_memory()
- self.model = model
- self.serve_args = serve_args
- self.env_dict = env_dict
- self.use_omni = use_omni
- self.proc: subprocess.Popen | None = None
- self.host = "127.0.0.1"
- if port is None:
- self.port = get_open_port()
- else:
- self.port = port
-
- def _start_server(self) -> None:
- """Start the vLLM-Omni server subprocess."""
- env = os.environ.copy()
- env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
- if self.env_dict is not None:
- env.update(self.env_dict)
-
- cmd = [
- sys.executable,
- "-m",
- "vllm_omni.entrypoints.cli.main",
- "serve",
- self.model,
- "--host",
- self.host,
- "--port",
- str(self.port),
- ]
- if self.use_omni:
- cmd.append("--omni")
- cmd += self.serve_args
-
- print(f"Launching OmniServer with: {' '.join(cmd)}")
- self.proc = subprocess.Popen(
- cmd,
- env=env,
- cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # Set working directory to vllm-omni root
- )
-
- # Wait for server to be ready
- max_wait = 1200 # 20 minutes
- start_time = time.time()
- while time.time() - start_time < max_wait:
- # Check for process status
- ret = self.proc.poll()
- if ret is not None:
- raise RuntimeError(f"Server processes exited with code {ret} before becoming ready.")
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
- sock.settimeout(1)
- result = sock.connect_ex((self.host, self.port))
- if result == 0:
- print(f"Server ready on {self.host}:{self.port}")
- return
- time.sleep(2)
-
- raise RuntimeError(f"Server failed to start within {max_wait} seconds")
-
- def _kill_process_tree(self, pid):
- """kill process and its children with verification"""
- try:
- parent = psutil.Process(pid)
- children = parent.children(recursive=True)
-
- # Get all PIDs first
- all_pids = [pid] + [child.pid for child in children]
-
- # Terminate children
- for child in children:
- try:
- child.terminate()
- except psutil.NoSuchProcess:
- pass
-
- # Wait for children
- gone, still_alive = psutil.wait_procs(children, timeout=10)
-
- # Kill remaining children
- for child in still_alive:
- try:
- child.kill()
- except psutil.NoSuchProcess:
- pass
-
- # Terminate parent
- try:
- parent.terminate()
- parent.wait(timeout=10)
- except (psutil.NoSuchProcess, psutil.TimeoutExpired):
- try:
- parent.kill()
- except psutil.NoSuchProcess:
- pass
-
- # VERIFICATION: Check if all processes are gone
- time.sleep(1) # Give system time
- alive_processes = []
- for check_pid in all_pids:
- if psutil.pid_exists(check_pid):
- alive_processes.append(check_pid)
-
- if alive_processes:
- print(f"Warning: Processes still alive: {alive_processes}")
- # Optional: Try system kill
- import subprocess
-
- for alive_pid in alive_processes:
- try:
- subprocess.run(["kill", "-9", str(alive_pid)], timeout=2)
- except Exception as e:
- print(f"Cleanup failed: {e}")
-
- except psutil.NoSuchProcess:
- pass
-
- def __enter__(self):
- self._start_server()
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- if self.proc:
- self._kill_process_tree(self.proc.pid)
- _run_pre_test_cleanup(enable_force=True)
- _run_post_test_cleanup(enable_force=True)
- cleanup_dist_env_and_memory()
-
-
-class OmniServerStageCli(OmniServer):
- """Omni server harness that exercises the stage CLI flow."""
-
- def __init__(
- self,
- model: str,
- stage_config_path: str,
- serve_args: list[str] | None = None,
- *,
- stage_ids: list[int] | None = None,
- port: int | None = None,
- env_dict: dict[str, str] | None = None,
- ) -> None:
- super().__init__(model, serve_args or [], port=port, env_dict=env_dict, use_omni=True)
- self.stage_config_path = stage_config_path
- self.master_port = get_open_port()
- self.visible_device_list = self._load_visible_device_list(env_dict)
- self.stage_runtime_devices = self._load_stage_runtime_devices(stage_config_path)
- self.stage_ids = stage_ids or self._load_stage_ids(stage_config_path)
- if 0 not in self.stage_ids:
- raise ValueError(f"Stage CLI test requires stage_id=0 in config: {stage_config_path}")
- self.stage_procs: dict[int, subprocess.Popen] = {}
- self.proc = None
-
- @staticmethod
- def _load_stage_ids(stage_config_path: str) -> list[int]:
- with open(stage_config_path, encoding="utf-8") as f:
- cfg = yaml.safe_load(f) or {}
-
- stage_ids = [stage["stage_id"] for stage in cfg.get("stage_args", []) if "stage_id" in stage]
- if not stage_ids:
- raise ValueError(f"No stage IDs found in config: {stage_config_path}")
- return stage_ids
-
- @staticmethod
- def _load_stage_runtime_devices(stage_config_path: str) -> dict[int, str]:
- with open(stage_config_path, encoding="utf-8") as f:
- cfg = yaml.safe_load(f) or {}
-
- runtime_devices: dict[int, str] = {}
- for stage in cfg.get("stage_args", []):
- stage_id = stage.get("stage_id")
- devices = stage.get("runtime", {}).get("devices")
- if stage_id is not None and devices:
- runtime_devices[int(stage_id)] = str(devices)
- return runtime_devices
-
- @classmethod
- def _parse_device_list(cls, devices: str | int) -> list[str]:
- if isinstance(devices, int):
- if devices < 0:
- raise ValueError("Device IDs must be non-negative integers")
- return [str(devices)]
- return [token.strip() for token in str(devices).split(",") if token.strip()]
-
- @classmethod
- def _load_visible_device_list(cls, env_dict: dict[str, str] | None) -> list[str] | None:
- env = os.environ.copy()
- if env_dict is not None:
- env.update(env_dict)
-
- env_var = getattr(current_omni_platform, "device_control_env_var", None)
- if env_var and env_var in env:
- return [token.strip() for token in env[env_var].split(",") if token.strip()]
- return None
-
- @classmethod
- def _map_stage_devices(cls, stage_id: int, visible_device_list: list[str] | None, devices: str) -> str:
- device_list = cls._parse_device_list(devices)
-
- if visible_device_list is None:
- return ",".join(device_list)
-
- if not all(device.isdigit() for device in device_list):
- raise ValueError("Logical devices must be non-negative integers")
-
- logical_ids = [int(device) for device in device_list]
- if logical_ids and max(logical_ids) >= len(visible_device_list):
- raise ValueError(
- f"Stage {stage_id} has logical IDs {device_list}, one or more of which exceed the number of visible devices"
- )
-
- return ",".join(visible_device_list[idx] for idx in logical_ids)
-
- def _set_stage_device_env(self, stage_id: int, env: dict[str, str], devices: str) -> None:
- mapped_devices = self._map_stage_devices(stage_id, self.visible_device_list, devices)
- env_var = getattr(current_omni_platform, "device_control_env_var", None)
- if env_var:
- env[env_var] = mapped_devices
-
- def _build_stage_cmd(self, stage_id: int, *, headless: bool) -> list[str]:
- cmd = [
- sys.executable,
- "-m",
- "vllm_omni.entrypoints.cli.main",
- "serve",
- self.model,
- "--omni",
- "--stage-configs-path",
- self.stage_config_path,
- "--stage-id",
- str(stage_id),
- "--omni-master-address",
- self.host,
- "--omni-master-port",
- str(self.master_port),
- ]
-
- if headless:
- cmd.append("--headless")
- else:
- cmd += ["--host", self.host, "--port", str(self.port)]
-
- cmd += self.serve_args
- return cmd
-
- def _launch_stage(self, stage_id: int, *, headless: bool) -> None:
- env = os.environ.copy()
- env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
- if self.env_dict is not None:
- env.update(self.env_dict)
-
- devices = self.stage_runtime_devices.get(stage_id)
- if devices:
- self._set_stage_device_env(stage_id, env, devices)
-
- cmd = self._build_stage_cmd(stage_id, headless=headless)
- print(f"Launching OmniServerStageCli stage {stage_id}: {' '.join(cmd)}")
- proc = subprocess.Popen(
- cmd,
- env=env,
- cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
- )
- self.stage_procs[stage_id] = proc
- if stage_id == 0:
- self.proc = proc
-
- def _ensure_stage_processes_alive(self) -> None:
- for stage_id, proc in self.stage_procs.items():
- ret = proc.poll()
- if ret is not None:
- raise RuntimeError(f"Stage {stage_id} exited with code {ret} before API server became ready.")
-
- def _start_server(self) -> None:
- ordered_stage_ids = [0, *[stage_id for stage_id in self.stage_ids if stage_id != 0]]
-
- self._launch_stage(0, headless=False)
- time.sleep(2)
- self._ensure_stage_processes_alive()
-
- for stage_id in ordered_stage_ids[1:]:
- self._launch_stage(stage_id, headless=True)
-
- max_wait = 1200
- start_time = time.time()
- while time.time() - start_time < max_wait:
- self._ensure_stage_processes_alive()
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
- sock.settimeout(1)
- result = sock.connect_ex((self.host, self.port))
- if result == 0:
- print(f"OmniServerStageCli ready on {self.host}:{self.port}")
- return
- time.sleep(2)
-
- raise RuntimeError(f"OmniServerStageCli failed to start within {max_wait} seconds")
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- for stage_id in sorted(self.stage_procs, reverse=True):
- proc = self.stage_procs[stage_id]
- if proc.poll() is None:
- self._kill_process_tree(proc.pid)
- _run_pre_test_cleanup(enable_force=True)
- _run_post_test_cleanup(enable_force=True)
- cleanup_dist_env_and_memory()
-
-
-def pytest_addoption(parser):
- parser.addoption(
- "--run-level",
- action="store",
- default="core_model",
- choices=["core_model", "advanced_model"],
- help="Test level to run: L2, L3",
- )
-
-
-@pytest.fixture(scope="session")
-def run_level(request) -> str:
- """A command-line argument that specifies the level of tests to run in this session.
- See https://docs.vllm.ai/projects/vllm-omni/en/latest/contributing/ci/CI_5levels/"""
- return request.config.getoption("--run-level")
-
-
-_omni_server_lock = threading.Lock()
-
-
-@pytest.fixture(scope="module")
-def omni_server(request: pytest.FixtureRequest, run_level: str, model_prefix: str) -> Generator[OmniServer, Any, None]:
- """Start vLLM-Omni through the standard or stage-CLI launcher.
-
- The fixture stays module-scoped because multi-stage initialization is costly.
- The ``use_stage_cli`` flag on ``OmniServerParams`` routes the setup through the
- stage-CLI harness while still reusing the same fixture grouping semantics.
- """
- with _omni_server_lock:
- params: OmniServerParams = request.param
- model = model_prefix + params.model
- port = params.port
- stage_config_path = params.stage_config_path
- if run_level == "advanced_model" and stage_config_path is not None:
- with open(stage_config_path, encoding="utf-8") as f:
- cfg = yaml.safe_load(f) or {}
- stage_ids = [stage["stage_id"] for stage in cfg.get("stage_args", []) if "stage_id" in stage]
- stage_config_path = modify_stage_config(
- stage_config_path,
- deletes={"stage_args": {stage_id: ["engine_args.load_format"] for stage_id in stage_ids}},
- )
-
- server_args = params.server_args or []
- if params.use_omni and params.stage_init_timeout is not None:
- server_args = [*server_args, "--stage-init-timeout", str(params.stage_init_timeout)]
- else:
- server_args = [*server_args, "--stage-init-timeout", "600"]
- if params.init_timeout is not None:
- server_args = [*server_args, "--init-timeout", str(params.init_timeout)]
- else:
- server_args = [*server_args, "--init-timeout", "900"]
- if params.use_stage_cli:
- if not params.use_omni:
- raise ValueError("omni_server with use_stage_cli=True requires use_omni=True")
- if stage_config_path is None:
- raise ValueError("omni_server with use_stage_cli=True requires a stage_config_path")
-
- with OmniServerStageCli(
- model,
- stage_config_path,
- server_args,
- port=port,
- env_dict=params.env_dict,
- ) as server:
- print("OmniServer started successfully")
- yield server
- print("OmniServer stopping...")
- else:
- if stage_config_path is not None:
- server_args += ["--stage-configs-path", stage_config_path]
-
- with (
- OmniServer(
- model,
- server_args,
- port=port,
- env_dict=params.env_dict,
- use_omni=params.use_omni,
- )
- if port
- else OmniServer(
- model,
- server_args,
- env_dict=params.env_dict,
- use_omni=params.use_omni,
- )
- ) as server:
- print("OmniServer started successfully")
- yield server
- print("OmniServer stopping...")
-
- print("OmniServer stopped")
-
-
-@dataclass
-class OmniResponse:
- text_content: str | None = None
- audio_data: list[str] | None = None
- audio_content: str | None = None
- audio_format: str | None = None
- audio_bytes: bytes | None = None
- similarity: float | None = None
- e2e_latency: float | None = None
- success: bool = False
- error_message: str | None = None
- cached_tokens: int | None = None
-
-
-@dataclass
-class DiffusionResponse:
- text_content: str | None = None
- images: list[Image.Image] | None = None
- audios: list[Any] | None = None
- videos: list[Any] | None = None
- e2e_latency: float | None = None
- success: bool = False
- error_message: str | None = None
-
-
-def _load_gender_pipeline():
- """
- Lazy-load a cached audio-classification pipeline for gender.
-
- We prefer the pipeline wrapper because it encapsulates processor/model loading
- and avoids direct AutoProcessor.from_pretrained call sites in this file.
- """
- global _GENDER_PIPELINE
- if _GENDER_PIPELINE is not None:
- return _GENDER_PIPELINE
-
- model_name = "7wolf/wav2vec2-base-gender-classification"
- try:
- # device=-1 forces CPU for pipeline.
- _GENDER_PIPELINE = pipeline(
- task="audio-classification",
- model=model_name,
- device=-1,
- )
- return _GENDER_PIPELINE
- except Exception as exc: # pragma: no cover - best-effort fallback
- print(f"Warning: failed to create gender pipeline '{model_name}': {exc}")
- _GENDER_PIPELINE = None
- return None
-
-
-def _median_pitch_hz_from_autocorr(mono: np.ndarray, sr: int) -> float | None:
- """
- Rough median F0 (Hz) over short-time frames. Used to debias wav2vec2 gender head on TTS,
- which often labels lower-pitched synthetic speech as female under load or on clean signals.
- Returns None if the clip is too short or mostly unvoiced.
- """
- x = np.asarray(mono, dtype=np.float64)
- x = x - np.mean(x)
- if x.size < int(0.15 * sr):
- return None
- frame_len = int(0.04 * sr)
- hop = max(frame_len // 2, 1)
- f0_min_hz, f0_max_hz = 70.0, 400.0
- lag_min = max(1, int(sr / f0_max_hz))
- lag_max = min(frame_len - 2, int(sr / f0_min_hz))
- if lag_max <= lag_min:
- return None
- win = np.hamming(frame_len)
- pitches: list[float] = []
- for start in range(0, int(x.shape[0]) - frame_len, hop):
- frame = x[start : start + frame_len] * win
- frame = frame - np.mean(frame)
- if float(np.sqrt(np.mean(frame**2))) < 1e-4:
- continue
- ac = np.correlate(frame, frame, mode="full")[frame_len - 1 :]
- ac = ac / (float(ac[0]) + 1e-12)
- region = ac[lag_min : lag_max + 1]
- peak_rel = int(np.argmax(region))
- peak_lag = peak_rel + lag_min
- if peak_lag <= 0:
- continue
- f0 = float(sr) / float(peak_lag)
- if f0_min_hz <= f0 <= f0_max_hz:
- pitches.append(f0)
- if len(pitches) < 4:
- return None
- return float(np.median(np.asarray(pitches, dtype=np.float64)))
-
-
-def _estimate_voice_gender_from_audio(audio_bytes: bytes) -> str:
- """
- Estimate voice gender from audio using a small pre-trained classification model.
-
- Uses a cached `audio-classification` pipeline to classify the clip.
- Returns 'male' / 'female' when the model confidence is >= 0.9 and the label
- maps to one of these; otherwise returns 'unknown'. If the model is unavailable
- or inference fails, returns 'unknown' to keep tests stable.
-
- Under concurrent tests, a global lock serializes pipeline calls (the HF pipeline is not
- thread-safe). A coarse F0 median can correct systematic "male -> female" errors on TTS audio.
- """
- data, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True)
- if data.size == 0:
- raise ValueError("Empty audio")
- mono = np.mean(data, axis=1)
-
- try:
- target_sr = 16000
- if int(sr) != target_sr and mono.size > 1:
- src_len = int(mono.shape[0])
- dst_len = max(1, int(round(src_len * float(target_sr) / float(sr))))
- src_idx = np.arange(src_len, dtype=np.float32)
- dst_idx = np.linspace(0, src_len - 1, dst_len, dtype=np.float32)
- mono = np.interp(dst_idx, src_idx, mono.astype(np.float32, copy=False)).astype(np.float32)
- sr = target_sr
-
- median_f0 = _median_pitch_hz_from_autocorr(mono, sr)
-
- clf = _load_gender_pipeline()
- if clf is None:
- print("gender model not available, returning 'unknown'")
- return "unknown"
-
- # transformers pipeline returns a list of {label, score} (highest score first).
- with _GENDER_PIPELINE_LOCK:
- outputs = clf(mono, sampling_rate=sr)
- if not outputs:
- return "unknown"
-
- top = outputs[0]
- label = str(top.get("label", "")).lower()
- conf = float(top.get("score", 0.0))
-
- if conf < 0.5:
- gender = "unknown"
- # Some models use non-English labels (e.g., Russian). Normalize to 'male'/'female'.
- elif ("female" in label) or ("жен" in label):
- gender = "female"
- elif ("male" in label) or ("муж" in label):
- gender = "male"
- else:
- gender = "unknown"
-
- # Debias: wav2vec2 gender heads often call TTS / band-limited male speech "female".
- # Low median F0 (~speech male range) + female label -> trust pitch when score is not overwhelming.
- if gender == "female" and median_f0 is not None and median_f0 < 165.0 and conf < 0.88:
- print(f"gender pitch assist: reclassifying female->male (median_f0={median_f0:.1f} Hz, conf={conf:.3f})")
- gender = "male"
- elif gender == "male" and median_f0 is not None and median_f0 > 230.0 and conf < 0.88:
- print(f"gender pitch assist: reclassifying male->female (median_f0={median_f0:.1f} Hz, conf={conf:.3f})")
- gender = "female"
-
- print(
- f"gender classifier: label={label}, conf={conf:.3f}, gender={gender}"
- + (f", median_f0={median_f0:.1f}Hz" if median_f0 is not None else "")
- )
- return gender
- except Exception as exc: # pragma: no cover - best-effort fallback
- print(f"Warning: gender classification failed, returning 'unknown': {exc}")
- return "unknown"
-
-
-_PRESET_VOICE_GENDER_MAP: dict[str, str] = {
- "serena": "female",
- "uncle_fu": "male",
- "chelsie": "female",
- "clone": "female",
- "ethan": "male",
+"""
+Root pytest entrypoint for the vLLM-Omni test suite.
+
+- `tests/conftest.py` stays thin: plugin registration + compatibility re-exports.
+- Importable utilities live under `tests/helpers/`.
+- Fixtures live under `tests/helpers/fixtures/` and are loaded via `pytest_plugins`.
+"""
+
+from __future__ import annotations
+
+pytest_plugins = (
+ "tests.helpers.fixtures.env",
+ "tests.helpers.fixtures.log",
+ "tests.helpers.fixtures.run_args",
+ "tests.helpers.fixtures.runtime",
+)
+
+
+def pytest_terminal_summary(terminalreporter, exitstatus, config):
+ # Marker for Buildkite log folding before pytest summary lines.
+ terminalreporter.write_line("--- Running Summary")
+
+
+# Backward-compatible lazy re-exports.
+# (Many tests still import from `tests.conftest`; migrate these imports to `tests.helpers.*` over time.)
+# Keep these lazy so conftest import does not trigger heavy helper dependencies.
+_ASSERTION_EXPORT_NAMES = (
+ "assert_audio_speech_response",
+ "assert_diffusion_response",
+ "assert_image_diffusion_response",
+ "assert_image_valid",
+ "assert_omni_response",
+ "assert_video_diffusion_response",
+ "assert_video_valid",
+)
+_MEDIA_EXPORT_NAMES = (
+ "convert_audio_bytes_to_text",
+ "convert_audio_file_to_text",
+ "cosine_similarity_text",
+ "decode_b64_image",
+ "generate_synthetic_audio",
+ "generate_synthetic_image",
+ "generate_synthetic_video",
+)
+_STAGE_CONFIG_EXPORT_NAMES = ("modify_stage_config",)
+_RUNTIME_EXPORT_NAMES = (
+ "DiffusionResponse",
+ "OmniResponse",
+ "OmniRunner",
+ "OmniRunnerHandler",
+ "OmniServer",
+ "OmniServerParams",
+ "OmniServerStageCli",
+ "OpenAIClientHandler",
+ "dummy_messages_from_mix_data",
+)
+_LAZY_EXPORT_MODULES = {
+ **{name: "tests.helpers.assertions" for name in _ASSERTION_EXPORT_NAMES},
+ **{name: "tests.helpers.media" for name in _MEDIA_EXPORT_NAMES},
+ **{name: "tests.helpers.stage_config" for name in _STAGE_CONFIG_EXPORT_NAMES},
+ **{name: "tests.helpers.runtime" for name in _RUNTIME_EXPORT_NAMES},
}
-
-
-def _assert_preset_voice_gender_from_audio(
- audio_bytes: bytes | None,
- voice_name: str | None,
-) -> None:
- """If ``voice_name`` matches a known preset, assert classifier gender matches (skip when unknown)."""
- if not voice_name or not audio_bytes:
- return
- key = str(voice_name).lower()
- expected_gender = _PRESET_VOICE_GENDER_MAP.get(key)
- if expected_gender is None:
- return
- estimated_gender = _estimate_voice_gender_from_audio(audio_bytes)
- print(f"Preset voice gender check: preset={key!r}, estimated={estimated_gender!r}, expected={expected_gender!r}")
- if estimated_gender != "unknown":
- assert estimated_gender == expected_gender, (
- f"{voice_name!r} is expected {expected_gender}, but estimated gender is {estimated_gender!r}"
- )
-
-
-# Threshold aligned with _compute_pcm_hnr_db docstring (clean clone vs distorted).
-_MIN_PCM_SPEECH_HNR_DB = 1.0
-
-
-def _compute_pcm_hnr_db(pcm_samples: np.ndarray, sr: int = _PCM_SPEECH_SAMPLE_RATE_HZ) -> float:
- """Compute mean Harmonic-to-Noise Ratio (dB) for speech quality.
-
- Clean cloned speech has HNR > 1.2 dB; distorted speech (e.g. lost
- ref_code decoder context) drops below 1.0 dB.
- """
- frame_len = int(0.03 * sr) # 30ms frames
- hop = frame_len // 2
- hnr_values: list[float] = []
-
- for start in range(0, len(pcm_samples) - frame_len, hop):
- frame = pcm_samples[start : start + frame_len].astype(np.float32, copy=False)
- frame = frame - np.mean(frame)
- if np.max(np.abs(frame)) < 0.01:
- continue
- ac = np.correlate(frame, frame, mode="full")[len(frame) - 1 :]
- ac = ac / (ac[0] + 1e-10)
- min_lag = int(sr / 400)
- max_lag = min(int(sr / 80), len(ac))
- if min_lag >= max_lag:
- continue
- peak = float(np.max(ac[min_lag:max_lag]))
- if 0 < peak < 1:
- hnr_values.append(10 * np.log10(peak / (1 - peak + 1e-10)))
-
- return float(np.mean(hnr_values)) if hnr_values else 0.0
-
-
-def _assert_pcm_int16_speech_hnr(audio_bytes: bytes) -> None:
- """Validate harmonic-to-noise ratio on raw int16 PCM from /v1/audio/speech."""
- assert audio_bytes is not None and len(audio_bytes) >= 2, "missing PCM bytes"
- assert len(audio_bytes) % 2 == 0, "PCM byte length must be aligned to int16"
- pcm_samples = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
- hnr = _compute_pcm_hnr_db(pcm_samples)
- print(f"PCM speech HNR: {hnr:.2f} dB (threshold: {_MIN_PCM_SPEECH_HNR_DB} dB)")
- assert hnr >= _MIN_PCM_SPEECH_HNR_DB, (
- f"Audio distortion detected: HNR={hnr:.2f} dB < {_MIN_PCM_SPEECH_HNR_DB} dB. "
- "Voice clone decoder may be losing ref_code speaker context on later chunks."
- )
-
-
-def assert_omni_response(response: OmniResponse, request_config: dict[str, Any], run_level):
- """
- Validate response results.
-
- Args:
- response: OmniResponse object
-
- Raises:
- AssertionError: When the response does not meet validation criteria
- """
- assert response.success, "The request failed."
- e2e_latency = response.e2e_latency
- if e2e_latency is not None:
- print(f"the e2e latency is: {e2e_latency}")
-
- modalities = request_config.get("modalities", ["text", "audio"])
-
- if run_level == "advanced_model":
- if "audio" in modalities:
- assert response.audio_content is not None, "No audio output is generated"
- print(f"audio content is: {response.audio_content}")
- speaker = request_config.get("speaker")
- if speaker:
- _assert_preset_voice_gender_from_audio(
- response.audio_bytes,
- speaker,
- )
-
- if "text" in modalities:
- assert response.text_content is not None, "No text output is generated"
- print(f"text content is: {response.text_content}")
-
- # Verify image description
- word_types = ["text", "image", "audio", "video"]
- keywords_dict = request_config.get("key_words", {})
- for word_type in word_types:
- keywords = keywords_dict.get(word_type)
- if "text" in modalities:
- if keywords:
- text_lower = response.text_content.lower()
- assert any(str(kw).lower() in text_lower for kw in keywords), (
- "The output does not contain any of the keywords."
- )
- else:
- if keywords:
- audio_lower = response.audio_content.lower()
- assert any(str(kw).lower() in audio_lower for kw in keywords), (
- "The output does not contain any of the keywords."
- )
-
- # Verify similarity (Whisper transcript vs streamed/detokenized text)
- if "text" in modalities and "audio" in modalities:
- assert response.similarity is not None and response.similarity > 0.9, (
- "The audio content is not same as the text"
- )
- print(f"similarity is: {response.similarity}")
-
-
-def assert_audio_speech_response(
- response: OmniResponse,
- request_config: dict[str, Any],
- run_level: str,
-) -> None:
- """
- Validate /v1/audio/speech response: success, optional format check, transcription similarity
- and gender (non-PCM only for advanced_model), and int16 PCM HNR when response_format is pcm.
- """
- assert response.success, "The request failed."
-
- req_fmt = request_config.get("response_format")
-
- if req_fmt == "pcm" and response.audio_bytes:
- _assert_pcm_int16_speech_hnr(response.audio_bytes)
- if response.audio_format:
- assert "pcm" in response.audio_format.lower(), (
- f"Expected audio/pcm content-type, got {response.audio_format!r}"
- )
-
- elif req_fmt == "wav" and response.audio_format:
- assert req_fmt in response.audio_format, (
- f"The response audio format {response.audio_format} don't match the request audio format {req_fmt}"
- )
-
- e2e_latency = response.e2e_latency
- if e2e_latency is not None:
- print(f"the avg e2e latency is: {e2e_latency}")
-
- if run_level == "advanced_model" and req_fmt != "pcm":
- # Text–audio semantic similarity check (skipped for raw PCM: no Whisper transcript).
- expected_text = request_config.get("input")
- if expected_text:
- transcript = (response.audio_content or "").strip()
- print(f"audio content is: {transcript}")
- print(f"input text is: {expected_text}")
- similarity = cosine_similarity_text(transcript.lower(), expected_text.lower())
- print(f"Cosine similarity: {similarity:.3f}")
- assert similarity > 0.9, (
- f"Transcript doesn't match input: similarity={similarity:.2f}, transcript='{transcript}'"
- )
-
- # Voice gender consistency check (preset names in ``_PRESET_VOICE_GENDER_MAP``).
- # When the estimator returns 'unknown', we treat it as inconclusive and do NOT fail the test.
- _assert_preset_voice_gender_from_audio(
- response.audio_bytes,
- request_config.get("voice"),
- )
-
-
-def assert_diffusion_response(response: DiffusionResponse, request_config: dict[str, Any], run_level: str = None):
- """
- Validate diffusion response results.
-
- Dispatcher that routes validation to modality-specific assert functions.
-
- Args:
- response: DiffusionResponse object.
- request_config: Request configuration dictionary.
- run_level: Test run level (e.g. "core_model", "advanced_model")
-
- Raises:
- AssertionError: When the response does not meet validation criteria
- KeyError: When the request_config does not contain necessary parameters for validation
- """
- assert response.success, "The request failed."
-
- e2e_latency = response.e2e_latency
- if e2e_latency is not None:
- print(f"the avg e2e is: {e2e_latency}")
-
- has_any_content = any(content is not None for content in (response.images, response.videos, response.audios))
- assert has_any_content, "Response contains no images, videos, or audios"
-
- if response.images is not None:
- assert_image_diffusion_response(
- response=response,
- request_config=request_config,
- run_level=run_level,
- )
-
- if response.videos is not None:
- assert_video_diffusion_response(
- response=response,
- request_config=request_config,
- run_level=run_level,
- )
-
- if response.audios is not None:
- assert_audio_diffusion_response(
- response=response,
- request_config=request_config,
- run_level=run_level,
- )
-
-
-class OpenAIClientHandler:
- """
- OpenAI client handler class, encapsulating both streaming and non-streaming response processing logic.
-
- This class integrates OpenAI API request sending, response handling, and validation functionality,
- supporting both single request and concurrent request modes.
- """
-
- def __init__(
- self, host: str = "127.0.0.1", port: int = get_open_port(), api_key: str = "EMPTY", run_level: str = None
- ):
- """
- Initialize the OpenAI client.
-
- Args:
- host: vLLM-Omni server host address
- port: vLLM-Omni server port
- api_key: API key (defaults to "EMPTY")
- """
- self.base_url = f"http://{host}:{port}"
- self.client = OpenAI(base_url=f"http://{host}:{port}/v1", api_key=api_key)
- self.run_level = run_level
-
- def _process_stream_omni_response(self, chat_completion) -> OmniResponse:
- """
- Process streaming responses.
-
- Args:
- chat_completion: OpenAI streaming response object
- request_config: Request configuration dictionary
-
- Returns:
- OmniResponse: Processed response object
- """
- result = OmniResponse()
- start_time = time.perf_counter()
-
- try:
- text_content = ""
- audio_data = []
-
- for chunk in chat_completion:
- for choice in chunk.choices:
- # Get content data
- if hasattr(choice, "delta"):
- content = getattr(choice.delta, "content", None)
- else:
- content = None
-
- # Get modality type
- modality = getattr(chunk, "modality", None)
-
- # Process content based on modality type
- if modality == "audio" and content:
- audio_data.append(content)
- elif modality == "text" and content:
- text_content += content if content else ""
-
- # Calculate end-to-end latency
- result.e2e_latency = time.perf_counter() - start_time
-
- # Process audio and text content
- audio_content = None
- similarity = None
-
- if audio_data or text_content:
- if audio_data:
- merged_seg = _merge_base64_audio_to_segment(audio_data)
- wav_buf = BytesIO()
- merged_seg.export(wav_buf, format="wav")
- result.audio_bytes = wav_buf.getvalue()
- audio_content = convert_audio_bytes_to_text(result.audio_bytes)
- if audio_content and text_content:
- similarity = cosine_similarity_text(audio_content.lower(), text_content.lower())
-
- # Populate result object
- result.text_content = text_content
- result.audio_data = audio_data
- result.audio_content = audio_content
- result.similarity = similarity
- result.success = True
-
- except Exception as e:
- result.error_message = f"Stream processing error: {str(e)}"
- print(f"Error: {result.error_message}")
-
- return result
-
- def _process_non_stream_omni_response(self, chat_completion) -> OmniResponse:
- """
- Process non-streaming responses.
-
- Args:
- chat_completion: OpenAI non-streaming response object
- request_config: Request configuration dictionary
-
- Returns:
- OmniResponse: Processed response object
- """
- result = OmniResponse()
- start_time = time.perf_counter()
-
- try:
- audio_data = None
- text_content = None
-
- # Iterate through all choices
- for choice in chat_completion.choices:
- # Process audio data
- if hasattr(choice.message, "audio") and choice.message.audio is not None:
- audio_message = choice.message
- audio_data = audio_message.audio.data
-
- # Process text content
- if hasattr(choice.message, "content") and choice.message.content is not None:
- text_content = choice.message.content
-
- # Extract cached_tokens for prefix caching tests
- usage = getattr(chat_completion, "usage", None)
- if usage and (details := getattr(usage, "prompt_tokens_details", None)):
- result.cached_tokens = details.cached_tokens
-
- # Calculate end-to-end latency
- result.e2e_latency = time.perf_counter() - start_time
-
- # Process audio and text content
- audio_content = None
- similarity = None
-
- if audio_data or text_content:
- if audio_data:
- result.audio_bytes = base64.b64decode(audio_data)
- audio_content = convert_audio_bytes_to_text(result.audio_bytes)
- if audio_content and text_content:
- similarity = cosine_similarity_text(audio_content.lower(), text_content.lower())
-
- # Populate result object
- result.text_content = text_content
- result.audio_content = audio_content
- result.similarity = similarity
- result.success = True
-
- except Exception as e:
- result.error_message = f"Non-stream processing error: {str(e)}"
- print(f"Error: {result.error_message}")
-
- return result
-
- def _process_diffusion_response(self, chat_completion) -> DiffusionResponse:
- """
- Process diffusion responses (image generation/editing).
-
- Args:
- chat_completion: OpenAI response object
-
- Returns:
- DiffusionResponse: Processed response object
- """
- result = DiffusionResponse()
- start_time = time.perf_counter()
-
- try:
- images = []
- # [TODO] reading video and audio output from API response for later validation
-
- for choice in chat_completion.choices:
- if hasattr(choice.message, "content") and choice.message.content is not None:
- content = choice.message.content
- if isinstance(content, list):
- for item in content:
- if isinstance(item, dict):
- image_url = item.get("image_url", {}).get("url")
- else:
- image_url_obj = getattr(item, "image_url", None)
- image_url = getattr(image_url_obj, "url", None) if image_url_obj else None
- if image_url and image_url.startswith("data:image"):
- b64_data = image_url.split(",", 1)[1]
- img = decode_b64_image(b64_data)
- images.append(img)
-
- result.e2e_latency = time.perf_counter() - start_time
- result.images = images if images else None
- result.success = True
-
- except Exception as e:
- result.error_message = f"Diffusion response processing error: {str(e)}"
- print(f"Error: {result.error_message}")
-
- return result
-
- def _process_stream_audio_speech_response(self, response, *, response_format: str | None = None) -> OmniResponse:
- """
- Process streaming /v1/audio/speech responses into an OmniResponse.
-
- This mirrors _process_stream_omni_response but operates on low-level
- audio bytes and produces an OmniResponse with audio_content filled
- from Whisper transcription.
- """
- result = OmniResponse()
- start_time = time.perf_counter()
-
- try:
- # Aggregate all audio bytes from the streaming response.
- data = bytearray()
-
- # Preferred OpenAI helper.
- if hasattr(response, "iter_bytes") and callable(getattr(response, "iter_bytes")):
- for chunk in response.iter_bytes():
- if chunk:
- data.extend(chunk)
- else:
- # Generic iterable-of-bytes fallback (e.g., generator or list of chunks).
- try:
- iterator = iter(response)
- except TypeError:
- iterator = None
-
- if iterator is not None:
- for chunk in iterator:
- if not chunk:
- continue
- if isinstance(chunk, (bytes, bytearray)):
- data.extend(chunk)
- elif hasattr(chunk, "data"):
- data.extend(chunk.data) # type: ignore[arg-type]
- elif hasattr(chunk, "content"):
- data.extend(chunk.content) # type: ignore[arg-type]
- else:
- raise TypeError(f"Unsupported stream chunk type: {type(chunk)}")
- else:
- raise TypeError(f"Unsupported audio speech streaming response type: {type(response)}")
-
- raw_bytes = bytes(data)
- if response_format == "pcm":
- transcript = None
- else:
- transcript = convert_audio_bytes_to_text(raw_bytes)
-
- # Populate OmniResponse.
- result.audio_bytes = raw_bytes
- result.audio_content = transcript
- result.e2e_latency = time.perf_counter() - start_time
- result.success = True
- result.audio_format = getattr(response, "response", None)
- if result.audio_format is not None:
- result.audio_format = result.audio_format.headers.get("content-type", "")
-
- except Exception as e:
- result.error_message = f"Audio speech stream processing error: {str(e)}"
- print(f"Error: {result.error_message}")
-
- return result
-
- def _process_non_stream_audio_speech_response(
- self, response, *, response_format: str | None = None
- ) -> OmniResponse:
- """
- Process non-streaming /v1/audio/speech responses into an OmniResponse.
-
- This mirrors _process_non_stream_omni_response but for the binary
- audio payload returned by audio.speech.create.
- """
- result = OmniResponse()
- start_time = time.perf_counter()
-
- try:
- # OpenAI non-streaming audio.speech.create returns HttpxBinaryResponseContent (.read() or .content)
- if hasattr(response, "read") and callable(getattr(response, "read")):
- raw_bytes = response.read()
- elif hasattr(response, "content"):
- raw_bytes = response.content # type: ignore[assignment]
- else:
- raise TypeError(f"Unsupported audio speech response type: {type(response)}")
-
- if response_format == "pcm":
- transcript = None
- else:
- transcript = convert_audio_bytes_to_text(raw_bytes)
-
- result.audio_bytes = raw_bytes
- result.audio_content = transcript
- result.e2e_latency = time.perf_counter() - start_time
- result.success = True
- result.audio_format = getattr(response, "response", None)
- if result.audio_format is not None:
- result.audio_format = result.audio_format.headers.get("content-type", "")
-
- except Exception as e:
- result.error_message = f"Audio speech non-stream processing error: {str(e)}"
- print(f"Error: {result.error_message}")
-
- return result
-
- def send_omni_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]:
- """
- Send OpenAI requests.
-
- Args:
- request_config: Request configuration dictionary containing parameters like model, messages, stream.
- Optional ``use_audio_in_video`` (bool): when true, sets
- ``extra_body["mm_processor_kwargs"] = {"use_audio_in_video": True}`` for Qwen-Omni video+audio
- extraction.
- Optional top-level ``speaker`` (str): Qwen3-Omni preset TTS speaker name; sent as
- ``extra_body["speaker"]`` to ``chat.completions.create``.
- request_num: Number of requests, defaults to 1 (single request)
-
- Returns:
- List[OmniResponse]: List of response objects
- """
-
- responses = []
- stream = request_config.get("stream", False)
- modalities = request_config.get("modalities", ["text", "audio"])
-
- extra_body: dict[str, Any] = {}
- if "speaker" in request_config:
- extra_body["speaker"] = request_config["speaker"]
- if request_config.get("use_audio_in_video"):
- mm = dict(extra_body.get("mm_processor_kwargs") or {})
- mm["use_audio_in_video"] = True
- extra_body["mm_processor_kwargs"] = mm
- extra_body_arg: dict[str, Any] | None = extra_body if extra_body else None
-
- create_kwargs: dict[str, Any] = {
- "model": request_config.get("model"),
- "messages": request_config.get("messages"),
- "stream": stream,
- "modalities": modalities,
- }
- if extra_body_arg is not None:
- create_kwargs["extra_body"] = extra_body_arg
-
- if request_num == 1:
- # Send single request
- chat_completion = self.client.chat.completions.create(**create_kwargs)
-
- if stream:
- response = self._process_stream_omni_response(chat_completion)
- else:
- response = self._process_non_stream_omni_response(chat_completion)
-
- assert_omni_response(response, request_config, run_level=self.run_level)
- responses.append(response)
-
- else:
- # Send concurrent requests: run create + process in worker so e2e_latency includes full round-trip.
- def _one_omni_request():
- start = time.perf_counter()
- worker_kwargs: dict[str, Any] = {
- "model": request_config.get("model"),
- "messages": request_config.get("messages"),
- "modalities": modalities,
- "stream": stream,
- }
- if extra_body_arg is not None:
- worker_kwargs["extra_body"] = extra_body_arg
- chat_completion = self.client.chat.completions.create(**worker_kwargs)
- if stream:
- response = self._process_stream_omni_response(chat_completion)
- else:
- response = self._process_non_stream_omni_response(chat_completion)
- response.e2e_latency = time.perf_counter() - start
- return response
-
- with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
- futures = [executor.submit(_one_omni_request) for _ in range(request_num)]
- for future in concurrent.futures.as_completed(futures):
- response = future.result()
- assert_omni_response(response, request_config, run_level=self.run_level)
- responses.append(response)
-
- return responses
-
- def send_audio_speech_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]:
- """
- Call the /v1/audio/speech endpoint using the same configuration-dict
- style as send_omni_request, but via the OpenAI Python client's
- audio.speech APIs.
-
- Expected keys in request_config:
- - model: model name/path (required)
- - input: text to synthesize (required)
- - response_format: audio format such as "wav" or "pcm" (optional)
- - task_type, ref_text, ref_audio: TTS-specific extras (optional, passed via extra_body)
- - timeout: request timeout in seconds (float, optional, default 120.0)
- - stream: whether to use streaming API (bool, optional, default False)
- """
- timeout = float(request_config.get("timeout", 120.0))
-
- model = request_config["model"]
- text_input = request_config["input"]
- stream = bool(request_config.get("stream", False))
- voice = request_config.get("voice", None)
-
- # Standard OpenAI param: use omit when not provided to keep default behavior.
- response_format = request_config.get("response_format", omit)
-
- # Qwen3-TTS custom fields, forwarded via extra_body.
- extra_body: dict[str, Any] = {}
- # Keep this list aligned with vllm_omni.entrypoints.openai.protocol.audio params.
- for key in ("task_type", "ref_text", "ref_audio", "language", "max_new_tokens"):
- if key in request_config:
- extra_body[key] = request_config[key]
-
- responses: list[OmniResponse] = []
-
- speech_fmt: str | None = None if response_format is omit else str(response_format).lower()
-
- if request_num == 1:
- if stream:
- # Use streaming response helper.
- with self.client.audio.speech.with_streaming_response.create(
- model=model,
- input=text_input,
- response_format=response_format,
- extra_body=extra_body or None,
- timeout=timeout,
- voice=voice,
- ) as resp:
- omni_resp = self._process_stream_audio_speech_response(resp, response_format=speech_fmt)
- else:
- # Non-streaming response.
- resp = self.client.audio.speech.create(
- model=model,
- input=text_input,
- response_format=response_format,
- extra_body=extra_body or None,
- timeout=timeout,
- voice=voice,
- )
- omni_resp = self._process_non_stream_audio_speech_response(resp, response_format=speech_fmt)
-
- assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level)
- responses.append(omni_resp)
- return responses
- else:
- # request_num > 1: concurrent requests (use same params as single-request path)
-
- if stream:
-
- def _stream_task():
- with self.client.audio.speech.with_streaming_response.create(
- model=model,
- input=text_input,
- response_format=response_format,
- extra_body=extra_body or None,
- timeout=timeout,
- voice=voice,
- ) as resp:
- return self._process_stream_audio_speech_response(resp, response_format=speech_fmt)
-
- with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
- futures = [executor.submit(_stream_task) for _ in range(request_num)]
- for future in concurrent.futures.as_completed(futures):
- omni_resp = future.result()
- assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level)
- responses.append(omni_resp)
- else:
- with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
- futures = []
- for _ in range(request_num):
- future = executor.submit(
- self.client.audio.speech.create,
- model=model,
- input=text_input,
- response_format=response_format,
- extra_body=extra_body or None,
- timeout=timeout,
- voice=voice,
- )
- futures.append(future)
-
- for future in concurrent.futures.as_completed(futures):
- resp = future.result()
- omni_resp = self._process_non_stream_audio_speech_response(resp, response_format=speech_fmt)
- assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level)
- responses.append(omni_resp)
-
- return responses
-
- def send_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[DiffusionResponse]:
- """
- Send OpenAI requests for diffusion models.
-
- Args:
- request_config: Request configuration dictionary containing parameters like model, messages
- request_num: Number of requests to send concurrently, defaults to 1 (single request)
- Returns:
- List[DiffusionResponse]: List of response objects
- """
- responses: list[DiffusionResponse] = []
- stream = request_config.get("stream", False)
- modalities = request_config.get("modalities", omit) # Most diffusion models don't require modalities param
- extra_body = request_config.get("extra_body", None)
-
- if stream:
- raise NotImplementedError("Streaming is not currently implemented for diffusion model e2e test")
-
- if request_num == 1:
- # Send single request
- chat_completion = self.client.chat.completions.create(
- model=request_config.get("model"),
- messages=request_config.get("messages"),
- extra_body=extra_body,
- modalities=modalities,
- )
-
- response = self._process_diffusion_response(chat_completion)
- assert_diffusion_response(response, request_config, run_level=self.run_level)
- responses.append(response)
-
- else:
- # Send concurrent requests
- with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
- futures = []
-
- # Submit all request tasks
- for _ in range(request_num):
- future = executor.submit(
- self.client.chat.completions.create,
- model=request_config.get("model"),
- messages=request_config.get("messages"),
- modalities=modalities,
- extra_body=extra_body,
- )
- futures.append(future)
-
- # Process completed tasks
- for future in concurrent.futures.as_completed(futures):
- chat_completion = future.result()
- response = self._process_diffusion_response(chat_completion)
- assert_diffusion_response(response, request_config, run_level=self.run_level)
- responses.append(response)
-
- return responses
-
- def send_video_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]:
- """
- Send native /v1/videos requests.
- """
- if request_num != 1:
- raise NotImplementedError("Concurrent video diffusion requests are not currently implemented")
-
- if request_config.get("stream", False):
- raise NotImplementedError("Streaming is not currently implemented for video diffusion e2e test")
-
- form_data = request_config.get("form_data")
- if not isinstance(form_data, dict):
- raise ValueError("Video request_config must contain 'form_data'")
-
- if not form_data.get("prompt"):
- raise ValueError("Video request_config['form_data'] must contain 'prompt'")
-
- normalized_form_data = {key: str(value) for key, value in form_data.items() if value is not None}
-
- files: dict[str, tuple[str, BytesIO, str]] = {}
- image_reference = request_config.get("image_reference")
- if image_reference:
- if image_reference.startswith("data:image"):
- header, encoded = image_reference.split(",", 1)
- content_type = header.split(";")[0].removeprefix("data:")
- extension = content_type.split("/")[-1]
- file_data = base64.b64decode(encoded)
-
- files["input_reference"] = (
- f"reference.{extension}",
- BytesIO(file_data),
- content_type,
- )
- else:
- normalized_form_data["image_reference"] = json.dumps({"image_url": image_reference})
-
- result = DiffusionResponse()
- start_time = time.perf_counter()
-
- try:
- create_url = self._build_url("/v1/videos")
- response = requests.post(
- create_url,
- data=normalized_form_data,
- files=files,
- headers={"Accept": "application/json"},
- timeout=60,
- )
- response.raise_for_status()
-
- job_data = response.json()
- video_id = job_data["id"]
-
- self._wait_until_video_completed(video_id)
-
- video_content = self._download_video_content(video_id)
-
- result.success = True
- result.videos = [video_content]
- result.e2e_latency = time.perf_counter() - start_time
-
- assert_diffusion_response(result, request_config, run_level=self.run_level)
-
- except Exception as e:
- result.success = False
- result.error_message = f"Diffusion response processing error: {e}"
- assert False, result.error_message
-
- return [result]
-
- def _wait_until_video_completed(
- self,
- video_id: str,
- poll_interval_seconds: int = 2,
- timeout_seconds: int = 300,
- ) -> None:
- status_url = self._build_url(f"/v1/videos/{video_id}")
- deadline = time.monotonic() + timeout_seconds
-
- while time.monotonic() < deadline:
- status_resp = requests.get(
- status_url,
- headers={"Accept": "application/json"},
- timeout=30,
- )
- status_resp.raise_for_status()
-
- status_data = status_resp.json()
- current_status = status_data["status"]
-
- if current_status == "completed":
- return
-
- if current_status == "failed":
- error_msg = status_data.get("last_error", "Unknown error")
- raise RuntimeError(f"Job failed: {error_msg}")
-
- time.sleep(poll_interval_seconds)
-
- raise TimeoutError(f"Video job {video_id} did not complete within {timeout_seconds}s")
-
- def _download_video_content(self, video_id: str) -> bytes:
- download_url = self._build_url(f"/v1/videos/{video_id}/content")
- video_resp = requests.get(download_url, stream=True, timeout=60)
- video_resp.raise_for_status()
-
- video_bytes = BytesIO()
- for chunk in video_resp.iter_content(chunk_size=8192):
- if chunk:
- video_bytes.write(chunk)
-
- return video_bytes.getvalue()
-
- def _build_url(self, path: str) -> str:
- return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
-
-
-@pytest.fixture
-def openai_client(request: pytest.FixtureRequest, run_level: str):
- """Create OpenAIClientHandler fixture to facilitate communication with OmniServer
- with encapsulated request sending, concurrent requests, response handling, and validation."""
- server = request.getfixturevalue("omni_server")
- return OpenAIClientHandler(host=server.host, port=server.port, api_key="EMPTY", run_level=run_level)
-
-
-class OmniRunner:
- """
- Offline test runner for Omni models.
- """
-
- def __init__(
- self,
- model_name: str,
- seed: int = 42,
- stage_init_timeout: int = 600,
- batch_timeout: int = 10,
- init_timeout: int = 900,
- shm_threshold_bytes: int = 65536,
- log_stats: bool = False,
- stage_configs_path: str | None = None,
- **kwargs,
- ) -> None:
- """
- Initialize an OmniRunner for testing.
-
- Args:
- model_name: The model name or path
- seed: Random seed for reproducibility
- stage_init_timeout: Timeout for initializing a single stage in seconds
- batch_timeout: Timeout for batching in seconds
- init_timeout: Timeout for initializing stages in seconds
- shm_threshold_bytes: Threshold for using shared memory
- log_stats: Enable detailed statistics logging
- stage_configs_path: Optional path to YAML stage config file
- **kwargs: Additional arguments passed to Omni
- """
- cleanup_dist_env_and_memory()
- _run_pre_test_cleanup(enable_force=True)
- _run_post_test_cleanup(enable_force=True)
- self.model_name = model_name
- self.seed = seed
-
- self.omni = Omni(
- model=model_name,
- log_stats=log_stats,
- stage_init_timeout=stage_init_timeout,
- batch_timeout=batch_timeout,
- init_timeout=init_timeout,
- shm_threshold_bytes=shm_threshold_bytes,
- stage_configs_path=stage_configs_path,
- **kwargs,
- )
-
- def _estimate_prompt_len(
- self,
- additional_information: dict[str, Any],
- model_name: str,
- _cache: dict[str, Any] = {},
- ) -> int:
- """Estimate prompt_token_ids placeholder length for the Talker stage.
-
- The AR Talker replaces all input embeddings via ``preprocess``, so the
- placeholder values are irrelevant but the **length** must match the
- embeddings that ``preprocess`` will produce.
- """
- try:
- from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import Qwen3TTSConfig
- from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import (
- Qwen3TTSTalkerForConditionalGeneration,
- )
-
- if model_name not in _cache:
- from transformers import AutoTokenizer
-
- tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left")
- cfg = Qwen3TTSConfig.from_pretrained(model_name, trust_remote_code=True)
- _cache[model_name] = (tok, getattr(cfg, "talker_config", None))
-
- tok, tcfg = _cache[model_name]
- task_type = (additional_information.get("task_type") or ["CustomVoice"])[0]
- return Qwen3TTSTalkerForConditionalGeneration.estimate_prompt_len_from_additional_information(
- additional_information=additional_information,
- task_type=task_type,
- tokenize_prompt=lambda t: tok(t, padding=False)["input_ids"],
- codec_language_id=getattr(tcfg, "codec_language_id", None),
- spk_is_dialect=getattr(tcfg, "spk_is_dialect", None),
- )
- except Exception as exc:
- logger.warning("Failed to estimate prompt length, using fallback 2048: %s", exc)
- return 2048
-
- def get_default_sampling_params_list(self) -> list[OmniSamplingParams]:
- """
- Get a list of default sampling parameters for all stages.
-
- Returns:
- List of SamplingParams with default decoding for each stage
- """
- if not hasattr(self.omni, "default_sampling_params_list"):
- raise AttributeError("Omni.default_sampling_params_list is not available")
- return list(self.omni.default_sampling_params_list)
-
- def get_omni_inputs(
- self,
- prompts: list[str] | str,
- system_prompt: str | None = None,
- audios: PromptAudioInput = None,
- images: PromptImageInput = None,
- videos: PromptVideoInput = None,
- mm_processor_kwargs: dict[str, Any] | None = None,
- modalities: list[str] | None = None,
- ) -> list[TextPrompt]:
- """
- Construct Omni input format from prompts and multimodal data.
-
- Args:
- prompts: Text prompt(s) - either a single string or list of strings
- system_prompt: Optional system prompt (defaults to Qwen system prompt)
- audios: Audio input(s) - tuple of (audio_array, sample_rate) or list of tuples
- images: Image input(s) - PIL Image or list of PIL Images
- videos: Video input(s) - numpy array or list of numpy arrays
- mm_processor_kwargs: Optional processor kwargs (e.g., use_audio_in_video)
-
- Returns:
- List of prompt dictionaries suitable for Omni.generate()
- """
- if system_prompt is None:
- system_prompt = (
- "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
- "Group, capable of perceiving auditory and visual inputs, as well as "
- "generating text and speech."
- )
-
- video_padding_token = "<|VIDEO|>"
- image_padding_token = "<|IMAGE|>"
- audio_padding_token = "<|AUDIO|>"
-
- if "Qwen3-Omni-30B-A3B-Instruct" in self.model_name:
- video_padding_token = "<|video_pad|>"
- image_padding_token = "<|image_pad|>"
- audio_padding_token = "<|audio_pad|>"
-
- if isinstance(prompts, str):
- prompts = [prompts]
-
- # Qwen-TTS: follow examples/offline_inference/qwen3_tts/end2end.py style.
- # Stage 0 expects token placeholders + additional_information (text/speaker/task_type/...),
- # and Talker replaces embeddings in preprocess based on additional_information only.
- is_tts_model = "Qwen3-TTS" in self.model_name or "qwen3_tts" in self.model_name.lower()
- if is_tts_model and modalities == ["audio"]:
- tts_kw = mm_processor_kwargs or {}
- task_type = tts_kw.get("task_type", "CustomVoice")
- speaker = tts_kw.get("speaker", "Vivian")
- language = tts_kw.get("language", "Auto")
- max_new_tokens = int(tts_kw.get("max_new_tokens", 2048))
- ref_audio = tts_kw.get("ref_audio", None)
- ref_text = tts_kw.get("ref_text", None)
-
- omni_inputs: list[TextPrompt] = []
- for prompt_text in prompts:
- text_str = str(prompt_text).strip() or " "
- additional_information: dict[str, Any] = {
- "task_type": [task_type],
- "text": [text_str],
- "language": [language],
- "speaker": [speaker],
- "max_new_tokens": [max_new_tokens],
- }
- if ref_audio is not None:
- additional_information["ref_audio"] = [ref_audio]
- if ref_text is not None:
- additional_information["ref_text"] = [ref_text]
- # Use official helper to get correct placeholder length
- plen = self._estimate_prompt_len(additional_information, self.model_name)
- input_dict: TextPrompt = {
- "prompt_token_ids": [0] * plen,
- "additional_information": additional_information,
- }
- omni_inputs.append(input_dict)
- return omni_inputs
-
- def _normalize_mm_input(mm_input, num_prompts):
- if mm_input is None:
- return [None] * num_prompts
- if isinstance(mm_input, list):
- if len(mm_input) != num_prompts:
- raise ValueError(
- f"Multimodal input list length ({len(mm_input)}) must match prompts length ({num_prompts})"
- )
- return mm_input
- return [mm_input] * num_prompts
-
- num_prompts = len(prompts)
- audios_list = _normalize_mm_input(audios, num_prompts)
- images_list = _normalize_mm_input(images, num_prompts)
- videos_list = _normalize_mm_input(videos, num_prompts)
-
- omni_inputs = []
- for i, prompt_text in enumerate(prompts):
- user_content = ""
- multi_modal_data = {}
-
- audio = audios_list[i]
- if audio is not None:
- if isinstance(audio, list):
- for _ in audio:
- user_content += f"<|audio_bos|>{audio_padding_token}<|audio_eos|>"
- multi_modal_data["audio"] = audio
- else:
- user_content += f"<|audio_bos|>{audio_padding_token}<|audio_eos|>"
- multi_modal_data["audio"] = audio
-
- image = images_list[i]
- if image is not None:
- if isinstance(image, list):
- for _ in image:
- user_content += f"<|vision_bos|>{image_padding_token}<|vision_eos|>"
- multi_modal_data["image"] = image
- else:
- user_content += f"<|vision_bos|>{image_padding_token}<|vision_eos|>"
- multi_modal_data["image"] = image
-
- video = videos_list[i]
- if video is not None:
- if isinstance(video, list):
- for _ in video:
- user_content += f"<|vision_bos|>{video_padding_token}<|vision_eos|>"
- multi_modal_data["video"] = video
- else:
- user_content += f"<|vision_bos|>{video_padding_token}<|vision_eos|>"
- multi_modal_data["video"] = video
-
- user_content += prompt_text
-
- full_prompt = (
- f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
- f"<|im_start|>user\n{user_content}<|im_end|>\n"
- f"<|im_start|>assistant\n"
- )
-
- input_dict: TextPrompt = {"prompt": full_prompt}
- if multi_modal_data:
- input_dict["multi_modal_data"] = multi_modal_data
- if modalities:
- input_dict["modalities"] = modalities
- if mm_processor_kwargs:
- input_dict["mm_processor_kwargs"] = mm_processor_kwargs
-
- omni_inputs.append(input_dict)
-
- return omni_inputs
-
- def generate(
- self,
- prompts: list[TextPrompt],
- sampling_params_list: list[OmniSamplingParams] | None = None,
- ) -> list[OmniRequestOutput]:
- """
- Generate outputs for the given prompts.
-
- Args:
- prompts: List of prompt dictionaries with 'prompt' and optionally
- 'multi_modal_data' keys
- sampling_params_list: List of sampling parameters for each stage.
- If None, uses default parameters.
-
- Returns:
- List of OmniRequestOutput objects from stages with final_output=True
- """
- if sampling_params_list is None:
- sampling_params_list = self.get_default_sampling_params_list()
-
- return self.omni.generate(prompts, sampling_params_list)
-
- def generate_multimodal(
- self,
- prompts: list[str] | str,
- sampling_params_list: list[OmniSamplingParams] | None = None,
- system_prompt: str | None = None,
- audios: PromptAudioInput = None,
- images: PromptImageInput = None,
- videos: PromptVideoInput = None,
- mm_processor_kwargs: dict[str, Any] | None = None,
- modalities: list[str] | None = None,
- ) -> list[OmniRequestOutput]:
- """
- Convenience method to generate with multimodal inputs.
-
- Args:
- prompts: Text prompt(s)
- sampling_params_list: List of sampling parameters for each stage
- system_prompt: Optional system prompt
- audios: Audio input(s)
- images: Image input(s)
- videos: Video input(s)
- mm_processor_kwargs: Optional processor kwargs
-
- Returns:
- List of OmniRequestOutput objects from stages with final_output=True
- """
- omni_inputs = self.get_omni_inputs(
- prompts=prompts,
- system_prompt=system_prompt,
- audios=audios,
- images=images,
- videos=videos,
- mm_processor_kwargs=mm_processor_kwargs,
- modalities=modalities,
- )
- return self.generate(omni_inputs, sampling_params_list)
-
- def start_profile(
- self,
- profile_prefix: str | None = None,
- stages: list[int] | None = None,
- ) -> list[Any]:
- """Start profiling specified stages.
-
- Args:
- profile_prefix: Optional prefix for the trace file names.
- stages: List of stage IDs to profile. If None, profiles all stages.
-
- Returns:
- List of results from each stage.
- """
- return self.omni.start_profile(profile_prefix=profile_prefix, stages=stages)
-
- def stop_profile(self, stages: list[int] | None = None) -> list[Any]:
- """Stop profiling specified stages.
-
- Args:
- stages: List of stage IDs to profile. If None, stops all stages.
-
- Returns:
- List of results from each stage.
- """
- return self.omni.stop_profile(stages=stages)
-
- def _cleanup_process(self):
- try:
- keywords = ["enginecore"]
- matched = []
-
- for proc in psutil.process_iter(["pid", "name", "cmdline", "username"]):
- try:
- cmdline = " ".join(proc.cmdline()).lower() if proc.cmdline() else ""
- name = proc.name().lower()
-
- is_process = any(keyword in cmdline for keyword in keywords) or any(
- keyword in name for keyword in keywords
- )
-
- if is_process:
- print(f"Found vllm process: PID={proc.pid}, cmd={cmdline[:100]}")
- matched.append(proc)
- except (psutil.NoSuchProcess, psutil.AccessDenied):
- pass
-
- for proc in matched:
- try:
- proc.terminate()
- except (psutil.NoSuchProcess, psutil.AccessDenied):
- pass
-
- _, still_alive = psutil.wait_procs(matched, timeout=5)
- for proc in still_alive:
- try:
- proc.kill()
- except (psutil.NoSuchProcess, psutil.AccessDenied):
- pass
-
- if still_alive:
- _, stubborn = psutil.wait_procs(still_alive, timeout=3)
- if stubborn:
- print(f"Warning: failed to kill residual vllm pids: {[p.pid for p in stubborn]}")
- else:
- print(f"Force-killed residual vllm pids: {[p.pid for p in still_alive]}")
- elif matched:
- print(f"Terminated vllm pids: {[p.pid for p in matched]}")
-
- except Exception as e:
- print(f"Error in psutil vllm cleanup: {e}")
-
- def __enter__(self):
- """Context manager entry."""
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- """Context manager exit - cleanup resources."""
- if hasattr(self.omni, "close"):
- self.omni.close()
- self._cleanup_process()
- _run_pre_test_cleanup(enable_force=True)
- _run_post_test_cleanup(enable_force=True)
- cleanup_dist_env_and_memory()
-
-
-@pytest.fixture(scope="module")
-def omni_runner(request, model_prefix):
- with _omni_server_lock:
- model, stage_config_path = request.param
- model = model_prefix + model
- with OmniRunner(model, seed=42, stage_configs_path=stage_config_path) as runner:
- print("OmniRunner started successfully")
- yield runner
- print("OmniRunner stopping...")
-
- print("OmniRunner stopped")
-
-
-class OmniRunnerHandler:
- def __init__(self, omni_runner):
- self.runner = omni_runner
-
- def _process_output(self, outputs: list[Any]) -> OmniResponse:
- result = OmniResponse()
- try:
- text_content = None
- audio_content = None
- for stage_output in outputs:
- if getattr(stage_output, "final_output_type", None) == "text":
- text_content = stage_output.request_output.outputs[0].text
- if getattr(stage_output, "final_output_type", None) == "audio":
- audio_content = stage_output.request_output.outputs[0].multimodal_output["audio"]
-
- result.audio_content = audio_content
- result.text_content = text_content
- result.success = True
-
- except Exception as e:
- result.error_message = f"Output processing error: {str(e)}"
- result.success = False
- print(f"Error: {result.error_message}")
-
- return result
-
- def send_request(self, request_config: dict[str, Any] | None = None) -> OmniResponse:
- if request_config is None:
- request_config = {}
- prompts = request_config.get("prompts")
- videos = request_config.get("videos")
- images = request_config.get("images")
- audios = request_config.get("audios")
- modalities = request_config.get("modalities", ["text", "audio"])
- outputs = self.runner.generate_multimodal(
- prompts=prompts, videos=videos, images=images, audios=audios, modalities=modalities
- )
- response = self._process_output(outputs)
- assert_omni_response(response, request_config, run_level="core_model")
- return response
-
- def send_audio_speech_request(
- self,
- request_config: dict[str, Any],
- ) -> OmniResponse:
- """
- Offline TTS: text -> audio via generate_multimodal, then validate with assert_audio_speech_response.
-
- request_config must contain:
- - 'input' or 'prompts': text to synthesize.
- Optional keys:
- - 'voice' -> speaker (CustomVoice)
- - 'task_type' -> task_type in additional_information (default: "CustomVoice")
- - 'language' -> language in additional_information (default: "Auto")
- - 'max_new_tokens' -> max_new_tokens in additional_information (default: 2048)
- - 'response_format' -> desired audio format (used only for assertion)
- """
- input_text = request_config.get("input") or request_config.get("prompts")
- if input_text is None:
- raise ValueError("request_config must contain 'input' or 'prompts' for TTS")
- if isinstance(input_text, list):
- input_text = input_text[0] if input_text else ""
-
- # Build TTS-specific kwargs passed through to get_omni_inputs for Qwen3-TTS,
- # matching examples/offline_inference/qwen3_tts/end2end.py.
- mm_processor_kwargs: dict[str, Any] = {}
- if "voice" in request_config:
- mm_processor_kwargs["speaker"] = request_config["voice"]
- if "task_type" in request_config:
- mm_processor_kwargs["task_type"] = request_config["task_type"]
- if "ref_audio" in request_config:
- mm_processor_kwargs["ref_audio"] = request_config["ref_audio"]
- if "ref_text" in request_config:
- mm_processor_kwargs["ref_text"] = request_config["ref_text"]
- if "language" in request_config:
- mm_processor_kwargs["language"] = request_config["language"]
- if "max_new_tokens" in request_config:
- mm_processor_kwargs["max_new_tokens"] = request_config["max_new_tokens"]
-
- outputs = self.runner.generate_multimodal(
- prompts=input_text,
- modalities=["audio"],
- mm_processor_kwargs=mm_processor_kwargs or None,
- )
- mm_out: dict[str, Any] | None = None
- for stage_out in outputs:
- if getattr(stage_out, "final_output_type", None) == "audio":
- mm_out = stage_out.request_output.outputs[0].multimodal_output
- break
- if mm_out is None:
- result = OmniResponse(success=False, error_message="No audio output from pipeline")
- assert result.success, result.error_message
- return result
-
- audio_data = mm_out.get("audio")
- if audio_data is None:
- result = OmniResponse(success=False, error_message="No audio tensor in multimodal output")
- assert result.success, result.error_message
- return result
-
- sr_raw = mm_out.get("sr")
- sr_val = sr_raw[-1] if isinstance(sr_raw, list) and sr_raw else sr_raw
- sr = int(sr_val.item() if hasattr(sr_val, "item") else sr_val)
- wav_tensor = torch.cat(audio_data, dim=-1) if isinstance(audio_data, list) else audio_data
- wav_buf = io.BytesIO()
- sf.write(
- wav_buf,
- wav_tensor.float().cpu().numpy().reshape(-1),
- samplerate=sr,
- format="WAV",
- subtype="PCM_16",
- )
- result = OmniResponse(success=True, audio_bytes=wav_buf.getvalue(), audio_format="audio/wav")
- assert_audio_speech_response(result, request_config, run_level="core_model")
- return result
-
- def start_profile(
- self,
- profile_prefix: str | None = None,
- stages: list[int] | None = None,
- ) -> list[Any]:
- """Start profiling specified stages."""
- return self.runner.start_profile(profile_prefix=profile_prefix, stages=stages)
-
- def stop_profile(self, stages: list[int] | None = None) -> list[Any]:
- """Stop profiling specified stages."""
- return self.runner.stop_profile(stages=stages)
-
-
-@pytest.fixture
-def omni_runner_handler(omni_runner):
- return OmniRunnerHandler(omni_runner)
diff --git a/tests/core/sched/test_omni_scheduler_mixin.py b/tests/core/sched/test_omni_scheduler_mixin.py
new file mode 100644
index 00000000000..e04a9c39fbc
--- /dev/null
+++ b/tests/core/sched/test_omni_scheduler_mixin.py
@@ -0,0 +1,129 @@
+"""Unit tests for OmniSchedulerMixin streaming session replacement.
+
+These tests pin the behavior of `_replace_session_with_streaming_update` against
+current vLLM `Request` / `StreamingUpdate` (and Omni patches). When upgrading
+vLLM, failures here should highlight incompatible changes to request state or
+update payloads early.
+"""
+
+from __future__ import annotations
+
+from dataclasses import replace
+
+import pytest
+
+# Imports must run in this order: vllm_omni applies patches to vllm.v1.request before
+# Request / StreamingUpdate are bound in this module. Ruff isort would reorder them.
+# isort: off
+import vllm_omni # noqa: F401 - import for side effects (patch vLLM)
+from vllm.sampling_params import SamplingParams
+from vllm.v1.engine import EngineCoreEventType
+from vllm.v1.request import Request, RequestStatus, StreamingUpdate
+from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin
+
+# isort: on
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+class _SchedulerStub(OmniSchedulerMixin):
+ """Minimal scheduler surface required by OmniSchedulerMixin."""
+
+ def __init__(self, *, log_stats: bool = False) -> None:
+ self.num_waiting_for_streaming_input = 0
+ self.log_stats = log_stats
+
+
+def _make_request(**kwargs) -> Request:
+ sp = SamplingParams(max_tokens=8)
+ defaults = dict(
+ request_id="req-mixin-test",
+ prompt_token_ids=[1, 2, 3],
+ sampling_params=sp,
+ pooling_params=None,
+ arrival_time=100.0,
+ block_hasher=None,
+ )
+ defaults.update(kwargs)
+ return Request(**defaults)
+
+
+def _make_update(**kwargs) -> StreamingUpdate:
+ sp_new = SamplingParams(max_tokens=16)
+ defaults = dict(
+ mm_features=None,
+ prompt_token_ids=[10, 20],
+ max_tokens=32,
+ arrival_time=200.0,
+ sampling_params=sp_new,
+ )
+ defaults.update(kwargs)
+ return StreamingUpdate(**defaults)
+
+
+class TestReplaceSessionWithStreamingUpdate:
+ def test_resets_tokens_and_prompt_from_update(self) -> None:
+ sched = _SchedulerStub()
+ session = _make_request()
+ session.append_output_token_ids([7, 8])
+ session.num_computed_tokens = 99
+ session.status = RequestStatus.WAITING_FOR_STREAMING_REQ
+
+ update = _make_update(prompt_token_ids=[40, 41, 42])
+ sched.num_waiting_for_streaming_input = 3
+ sched._replace_session_with_streaming_update(session, update)
+
+ assert session._output_token_ids == []
+ assert list(session._all_token_ids) == [40, 41, 42]
+ assert session.prompt_token_ids == [40, 41, 42]
+ assert session.num_computed_tokens == 0
+ assert session.num_prompt_tokens == 3
+ assert session.arrival_time == 200.0
+ assert session.sampling_params is update.sampling_params
+ assert session.status == RequestStatus.WAITING
+ assert sched.num_waiting_for_streaming_input == 2
+
+ def test_none_prompt_token_ids_becomes_empty(self) -> None:
+ sched = _SchedulerStub()
+ session = _make_request()
+ session.status = RequestStatus.RUNNING
+ update = _make_update(prompt_token_ids=None)
+ sched._replace_session_with_streaming_update(session, update)
+
+ assert session.prompt_token_ids == ()
+ assert list(session._all_token_ids) == []
+ assert session.num_prompt_tokens == 0
+ assert sched.num_waiting_for_streaming_input == 0
+
+ def test_additional_information_cleared_when_update_omits_it(self) -> None:
+ sched = _SchedulerStub()
+ session = _make_request()
+ if not hasattr(session, "additional_information"):
+ pytest.skip("Request has no additional_information (Omni patch inactive?)")
+ session.additional_information = {"keep": True}
+ session.status = RequestStatus.RUNNING
+
+ base = _make_update()
+ if not hasattr(base, "additional_information"):
+ pytest.skip("StreamingUpdate has no additional_information (Omni patch inactive?)")
+ update = replace(base, additional_information=None)
+
+ sched._replace_session_with_streaming_update(session, update)
+ assert session.additional_information is None
+
+ def test_does_not_decrement_waiting_when_not_streaming_status(self) -> None:
+ sched = _SchedulerStub()
+ session = _make_request()
+ session.status = RequestStatus.RUNNING
+ sched.num_waiting_for_streaming_input = 5
+ sched._replace_session_with_streaming_update(session, _make_update())
+ assert sched.num_waiting_for_streaming_input == 5
+
+ def test_records_queued_event_when_log_stats_enabled(self) -> None:
+ sched = _SchedulerStub(log_stats=True)
+ session = _make_request()
+ session.status = RequestStatus.WAITING_FOR_STREAMING_REQ
+ sched._replace_session_with_streaming_update(session, _make_update())
+
+ assert session.events
+ assert session.events[-1].type == EngineCoreEventType.QUEUED
diff --git a/tests/core/test_prefix_cache.py b/tests/core/test_prefix_cache.py
index c3d8c1ff928..b5d0e96d305 100644
--- a/tests/core/test_prefix_cache.py
+++ b/tests/core/test_prefix_cache.py
@@ -1,5 +1,3 @@
-from unittest.mock import Mock, patch
-
import pytest
import torch
@@ -19,10 +17,14 @@ def __init__(self, num_computed_tokens_cpu):
self.req_ids = ["req1", "req2"]
self.req_id_to_index = {req_id: i for i, req_id in enumerate(self.req_ids)}
self.num_computed_tokens_cpu = num_computed_tokens_cpu
+
# Block table is only mocked for validation of length;
# we don't actually need to add valid values here since
# we patch the table when testing.
- self.block_table = Mock()
+ class _DummyBlockTable:
+ pass
+
+ self.block_table = _DummyBlockTable()
self.block_table.block_tables = [None]
@@ -186,7 +188,7 @@ def fake_get_cached_block_ids(self, req_idx, *args, **kwargs):
@pytest.mark.parametrize("num_tokens_padded", [None, 16])
-def test_get_merged_hidden_states(num_tokens_padded):
+def test_get_merged_hidden_states(num_tokens_padded, mocker):
"""Ensure that hidden states are merged correctly."""
cache = get_omni_pcache()
@@ -221,16 +223,16 @@ def test_get_merged_hidden_states(num_tokens_padded):
input_batch = MockInputBatch(num_computed_tokens_cpu=torch.Tensor([orig_num_tokens_unpadded, 0]))
- with patch(
+ mocker.patch(
"vllm_omni.core.prefix_cache.OmniTensorPrefixCache._get_cached_block_ids",
new=fake_get_cached_block_ids,
- ):
- merged_states = cache.get_merged_hidden_states(
- query_start_loc=[0, num_new_toks_req1],
- input_batch=input_batch,
- hidden_states=new_hidden_states,
- num_scheduled_tokens=num_scheduled_tokens,
- )
+ )
+ merged_states = cache.get_merged_hidden_states(
+ query_start_loc=[0, num_new_toks_req1],
+ input_batch=input_batch,
+ hidden_states=new_hidden_states,
+ num_scheduled_tokens=num_scheduled_tokens,
+ )
assert "req1" in merged_states and "req2" in merged_states
req1_merged_states = merged_states["req1"]
@@ -255,7 +257,7 @@ def test_get_merged_hidden_states(num_tokens_padded):
{"foo": 100, "bar": 50, "baz": 10},
],
)
-def test_get_merged_multimodal_outputs(feat_dims, num_tokens_padded):
+def test_get_merged_multimodal_outputs(feat_dims, num_tokens_padded, mocker):
cache = get_omni_pcache_with_mm_tensors(feat_dims, seq_len=DEFAULT_SEQ_LEN)
orig_num_tokens_unpadded = 8
@@ -298,16 +300,16 @@ def test_get_merged_multimodal_outputs(feat_dims, num_tokens_padded):
input_batch = MockInputBatch(num_computed_tokens_cpu=torch.Tensor([orig_num_tokens_unpadded, 0]))
- with patch(
+ mocker.patch(
"vllm_omni.core.prefix_cache.OmniTensorPrefixCache._get_cached_block_ids",
new=fake_get_cached_block_ids,
- ):
- merged_mm_outputs = cache.get_merged_multimodal_states(
- query_start_loc=[0, num_new_toks_req1],
- input_batch=input_batch,
- multimodal_outputs=new_mm_outputs,
- num_scheduled_tokens=num_scheduled_tokens,
- )
+ )
+ merged_mm_outputs = cache.get_merged_multimodal_states(
+ query_start_loc=[0, num_new_toks_req1],
+ input_batch=input_batch,
+ multimodal_outputs=new_mm_outputs,
+ num_scheduled_tokens=num_scheduled_tokens,
+ )
# Ensure the passthrough data wasn't dropped
assert "passthrough_data" in merged_mm_outputs
diff --git a/tests/dfx/conftest.py b/tests/dfx/conftest.py
index 997f25e6e54..c3f6d0a15d8 100644
--- a/tests/dfx/conftest.py
+++ b/tests/dfx/conftest.py
@@ -4,7 +4,7 @@
import pytest
-from tests.conftest import modify_stage_config
+from tests.helpers.stage_config import modify_stage_config
def load_configs(config_path: str) -> list[dict[str, Any]]:
@@ -40,22 +40,32 @@ def modify_stage(default_path, updates, deletes):
def create_unique_server_params(
configs: list[dict[str, Any]],
stage_configs_dir: Path,
-) -> list[tuple[str, str, str]]:
+) -> list[tuple[str, str, str | None, str | None, tuple[str, ...]]]:
unique_params = []
seen = set()
for config in configs:
test_name = config["test_name"]
- model = config["server_params"]["model"]
- stage_config_name = config["server_params"].get("stage_config_name")
+ server_params = config["server_params"]
+ model = server_params["model"]
+ stage_config_name = server_params.get("stage_config_name")
if stage_config_name:
stage_config_path = str(stage_configs_dir / stage_config_name)
- delete = config["server_params"].get("delete", None)
- update = config["server_params"].get("update", None)
+ delete = server_params.get("delete", None)
+ update = server_params.get("update", None)
stage_config_path = modify_stage(stage_config_path, update, delete)
else:
stage_config_path = None
- server_param = (test_name, model, stage_config_path)
+ stage_overrides = server_params.get("stage_overrides")
+ stage_overrides_json = json.dumps(stage_overrides) if stage_overrides else None
+
+ # ``extra_cli_args`` passes raw CLI flags straight through to
+ # ``vllm_omni.entrypoints.cli.main serve`` — used for flags that
+ # don't map to stage-level overrides, e.g. ``--async-chunk`` /
+ # ``--no-async-chunk`` toggling the deploy-level async_chunk bool.
+ extra_cli_args = tuple(server_params.get("extra_cli_args") or ())
+
+ server_param = (test_name, model, stage_config_path, stage_overrides_json, extra_cli_args)
if server_param not in seen:
seen.add(server_param)
unique_params.append(server_param)
diff --git a/tests/dfx/perf/scripts/run_benchmark.py b/tests/dfx/perf/scripts/run_benchmark.py
index bea46f684be..13011b4bdab 100644
--- a/tests/dfx/perf/scripts/run_benchmark.py
+++ b/tests/dfx/perf/scripts/run_benchmark.py
@@ -8,7 +8,6 @@
import pytest
-from tests.conftest import OmniServer
from tests.dfx.conftest import (
create_benchmark_indices,
create_test_parameter_mapping,
@@ -16,6 +15,7 @@
get_benchmark_params_for_server,
load_configs,
)
+from tests.helpers.runtime import OmniServer
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
@@ -48,8 +48,8 @@ def _get_config_file_from_argv() -> str | None:
OMNI_RESULT_TEMPLATE_PATH = Path(__file__).parent / "result_omni_template.json"
-STAGE_CONFIGS_DIR = Path(__file__).parent.parent / "stage_configs"
-test_params = create_unique_server_params(BENCHMARK_CONFIGS, STAGE_CONFIGS_DIR)
+DEPLOY_CONFIGS_DIR = Path(__file__).parent.parent / "deploy"
+test_params = create_unique_server_params(BENCHMARK_CONFIGS, DEPLOY_CONFIGS_DIR)
server_to_benchmark_mapping = create_test_parameter_mapping(BENCHMARK_CONFIGS)
_omni_server_lock = threading.Lock()
@@ -62,13 +62,19 @@ def omni_server(request):
Multi-stage initialization can take 10-20+ minutes.
"""
with _omni_server_lock:
- test_name, model, stage_config_path = request.param
+ test_name, model, stage_config_path, stage_overrides, extra_cli_args = request.param
print(f"Starting OmniServer with test: {test_name}, model: {model}")
server_args = ["--stage-init-timeout", "600", "--init-timeout", "900"]
+ # --deploy-config and --stage-overrides compose at the CLI (see vllm_omni/entrypoints/utils.py):
+ # deploy-config sets the base; stage-overrides are applied on top. Both can be set.
if stage_config_path:
- server_args = ["--stage-configs-path", stage_config_path] + server_args
+ server_args = ["--deploy-config", stage_config_path] + server_args
+ if stage_overrides:
+ server_args = ["--stage-overrides", stage_overrides] + server_args
+ if extra_cli_args:
+ server_args = list(extra_cli_args) + server_args
with OmniServer(model, server_args) as server:
server.test_name = test_name
print("OmniServer started successfully")
diff --git a/tests/dfx/perf/stage_configs/qwen3_omni.yaml b/tests/dfx/perf/stage_configs/qwen3_omni.yaml
deleted file mode 100644
index 2add22b8732..00000000000
--- a/tests/dfx/perf/stage_configs/qwen3_omni.yaml
+++ /dev/null
@@ -1,101 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-async_chunk: false
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "0"
- engine_args:
- model_stage: thinker
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 100000
- hf_config_name: thinker_config
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/tests/dfx/perf/stage_configs/qwen3_tts.yaml b/tests/dfx/perf/stage_configs/qwen3_tts.yaml
deleted file mode 100644
index 97b30905603..00000000000
--- a/tests/dfx/perf/stage_configs/qwen3_tts.yaml
+++ /dev/null
@@ -1,96 +0,0 @@
-# Stage config for running Qwen3-TTS with 2-stage architecture
-# Stage 0: Talker (text -> 8-layer RVQ codec codes)
-# Stage 1: Code2Wav (codec codes -> audio waveform)
-#
-# The following config has been verified on 1x H100-80G GPU.
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm
- is_comprehension: true
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 4
- model_stage: qwen3_tts
- model_arch: Qwen3TTSTalkerForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: false
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: latent
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 512
- max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 4
- model_stage: code2wav
- model_arch: Qwen3TTSCode2Wav
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio
- gpu_memory_utilization: 0.2
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 8192
- max_model_len: 32768
- engine_input_source: [0]
- final_output: true
- final_output_type: audio
- input_connectors:
- from_stage_0: connector_of_shared_memory
- tts_args:
- max_instructions_length: 500
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: true
- repetition_penalty: 1.0
-
-runtime:
- enabled: true
- defaults:
- window_size: -1
- max_inflight: 4
-
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- codec_streaming: true
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- codec_chunk_frames: 25
- codec_left_context_frames: 72
-
- edges:
- - from: 0
- to: 1
- window_size: -1
diff --git a/tests/dfx/perf/tests/test_qwen_omni.json b/tests/dfx/perf/tests/test_qwen_omni.json
index 4662f8c0c71..ca3eb555708 100644
--- a/tests/dfx/perf/tests/test_qwen_omni.json
+++ b/tests/dfx/perf/tests/test_qwen_omni.json
@@ -3,7 +3,7 @@
"test_name": "test_qwen3_omni",
"server_params": {
"model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
- "stage_config_name": "qwen3_omni.yaml"
+ "extra_cli_args": ["--no-async-chunk"]
},
"benchmark_params": [
{
@@ -109,25 +109,7 @@
"test_name": "test_qwen3_omni_chunk",
"server_params": {
"model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
- "stage_config_name": "qwen3_omni.yaml",
- "update": {
- "async_chunk": true,
- "stage_args": {
- "0": {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk"
- },
- "1": {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk"
- }
- }
- },
- "delete": {
- "stage_args": {
- "2": [
- "custom_process_input_func"
- ]
- }
- }
+ "extra_cli_args": ["--async-chunk"]
},
"benchmark_params": [
{
diff --git a/tests/dfx/stability/conftest.py b/tests/dfx/stability/conftest.py
index 3a0aee7608f..e36c88b9aa6 100644
--- a/tests/dfx/stability/conftest.py
+++ b/tests/dfx/stability/conftest.py
@@ -1,125 +1,17 @@
-"""
-Stability-specific conftest: when pytest is executed under this directory,
-resource monitoring is started before each test and finalized after each test,
-so each stability test case gets its own HTML report (one report per case).
-No need to wrap pytest with `bash resource_monitor.sh run -- pytest ...`.
-"""
+"""Stability pytest hooks and fixtures."""
-import os
import subprocess
import sys
import threading
-import time
-from pathlib import Path
import pytest
-STABILITY_DIR = Path(__file__).resolve().parent
-RESOURCE_MONITOR_SCRIPT = STABILITY_DIR / "scripts" / "resource_monitor.sh"
-REPO_ROOT = STABILITY_DIR.parent.parent.parent
-
-
-def _start_resource_monitor():
- """Start `resource_monitor.sh start` in the background and return `Popen` or `None`."""
- if not RESOURCE_MONITOR_SCRIPT.is_file():
- return None
- try:
- proc = subprocess.Popen(
- ["bash", str(RESOURCE_MONITOR_SCRIPT), "start", "--backend", "gpu"],
- cwd=str(REPO_ROOT),
- stdout=subprocess.DEVNULL,
- stderr=subprocess.PIPE,
- start_new_session=True,
- )
- try:
- proc.wait(timeout=2)
- if proc.returncode != 0:
- stderr = proc.stderr.read().decode("utf-8", errors="ignore") if proc.stderr else ""
- if stderr.strip():
- sys.stderr.write(f"[Stability] Resource monitor failed to start: {stderr.strip()}\n")
- return None
- except subprocess.TimeoutExpired:
- pass
- return proc
- except (FileNotFoundError, OSError):
- return None
-
-
-def _get_monitor_data_root() -> Path:
- data_root = os.environ.get("RESOURCE_MONITOR_DATA_ROOT") or os.environ.get("GPU_MONITOR_DATA_ROOT")
- if data_root:
- return Path(data_root)
- return STABILITY_DIR / "gpu_monitor_data"
-
-
-def _wait_for_run_dir(timeout_sec: int = 10) -> Path | None:
- data_root = _get_monitor_data_root()
- run_id_file = data_root / "current_run_id"
- deadline = time.time() + timeout_sec
- while time.time() < deadline:
- if run_id_file.is_file():
- run_id = run_id_file.read_text(encoding="utf-8").strip()
- if run_id:
- run_dir = data_root / run_id
- if run_dir.is_dir():
- return run_dir
- time.sleep(0.5)
- return None
-
-
-def _report_latest_gpu_samples(stop_event: threading.Event) -> None:
- """Periodically print the latest sampled GPU line."""
- log_interval = int(
- os.environ.get("RESOURCE_MONITOR_LOG_INTERVAL") or os.environ.get("GPU_MONITOR_LOG_INTERVAL") or "15"
- )
- log_interval = max(log_interval, 1)
- last_line = ""
-
- time.sleep(min(log_interval, 5))
- while not stop_event.wait(log_interval):
- run_dir = _wait_for_run_dir(timeout_sec=1)
- if run_dir is None:
- continue
- csv_file = run_dir / "gpu_metrics.csv"
- if not csv_file.is_file():
- continue
- try:
- lines = csv_file.read_text(encoding="utf-8").splitlines()
- except OSError:
- continue
- if len(lines) <= 1:
- continue
- latest = lines[-1].strip()
- if latest and latest != last_line:
- last_line = latest
- sys.stderr.write(f"[GPU] {latest}\n")
-
-
-def _finalize_resource_monitor() -> str | None:
- """
- Run `resource_monitor.sh finalize` for the current run and generate the report.
- Returns the bundle dir path (for this test case's report) if successful, else None.
- """
- if not RESOURCE_MONITOR_SCRIPT.is_file():
- return None
- try:
- result = subprocess.run(
- ["bash", str(RESOURCE_MONITOR_SCRIPT), "finalize", "--backend", "gpu"],
- cwd=str(REPO_ROOT),
- capture_output=True,
- text=True,
- timeout=60,
- check=False,
- )
- if result.returncode != 0:
- return None
- for line in (result.stdout or "").splitlines():
- if line.startswith("GPU_MONITOR_BUNDLE_DIR=") or line.startswith("RESOURCE_MONITOR_BUNDLE_DIR="):
- _, _, value = line.partition("=")
- return value.strip() if value else None
- return None
- except (FileNotFoundError, OSError, subprocess.TimeoutExpired):
- return None
+from tests.dfx.stability.helpers import (
+ finalize_resource_monitor,
+ report_latest_gpu_samples,
+ start_resource_monitor,
+ wait_for_run_dir,
+)
@pytest.fixture(autouse=True)
@@ -128,19 +20,19 @@ def stability_resource_monitor_per_test(request: pytest.FixtureRequest):
For each test under this directory: start GPU monitor before the test,
then finalize after the test so this case gets its own report.html.
"""
- proc = _start_resource_monitor()
+ proc = start_resource_monitor()
stop_event = threading.Event()
reporter: threading.Thread | None = None
if proc is not None:
reporter = threading.Thread(
- target=_report_latest_gpu_samples,
+ target=report_latest_gpu_samples,
args=(stop_event,),
name="stability-resource-monitor-reporter",
daemon=True,
)
reporter.start()
- run_dir = _wait_for_run_dir(timeout_sec=5)
+ run_dir = wait_for_run_dir(timeout_sec=5)
node_name = request.node.name
if run_dir is not None:
sys.stderr.write(f"[Stability] Resource monitor started for test: {node_name} | run dir: {run_dir}\n")
@@ -161,7 +53,7 @@ def stability_resource_monitor_per_test(request: pytest.FixtureRequest):
except subprocess.TimeoutExpired:
proc.kill()
proc.wait()
- bundle_dir = _finalize_resource_monitor()
+ bundle_dir = finalize_resource_monitor()
node_name = request.node.name
if bundle_dir:
sys.stderr.write(f"[Stability] Report for test «{node_name}»: {bundle_dir}/report.html\n")
diff --git a/tests/dfx/stability/helpers.py b/tests/dfx/stability/helpers.py
new file mode 100644
index 00000000000..3956bd21304
--- /dev/null
+++ b/tests/dfx/stability/helpers.py
@@ -0,0 +1,117 @@
+"""Stability resource monitor helpers."""
+
+from __future__ import annotations
+
+import os
+import subprocess
+import sys
+import threading
+import time
+from pathlib import Path
+
+STABILITY_DIR = Path(__file__).resolve().parent
+RESOURCE_MONITOR_SCRIPT = STABILITY_DIR / "scripts" / "resource_monitor.sh"
+REPO_ROOT = STABILITY_DIR.parent.parent.parent
+
+
+def start_resource_monitor():
+ """Start `resource_monitor.sh start` in the background and return `Popen` or `None`."""
+ if not RESOURCE_MONITOR_SCRIPT.is_file():
+ return None
+ try:
+ proc = subprocess.Popen(
+ ["bash", str(RESOURCE_MONITOR_SCRIPT), "start", "--backend", "gpu"],
+ cwd=str(REPO_ROOT),
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.PIPE,
+ start_new_session=True,
+ )
+ try:
+ proc.wait(timeout=2)
+ if proc.returncode != 0:
+ stderr = proc.stderr.read().decode("utf-8", errors="ignore") if proc.stderr else ""
+ if stderr.strip():
+ sys.stderr.write(f"[Stability] Resource monitor failed to start: {stderr.strip()}\n")
+ return None
+ except subprocess.TimeoutExpired:
+ pass
+ return proc
+ except (FileNotFoundError, OSError):
+ return None
+
+
+def get_monitor_data_root() -> Path:
+ data_root = os.environ.get("RESOURCE_MONITOR_DATA_ROOT") or os.environ.get("GPU_MONITOR_DATA_ROOT")
+ if data_root:
+ return Path(data_root)
+ return STABILITY_DIR / "gpu_monitor_data"
+
+
+def wait_for_run_dir(timeout_sec: int = 10) -> Path | None:
+ data_root = get_monitor_data_root()
+ run_id_file = data_root / "current_run_id"
+ deadline = time.time() + timeout_sec
+ while time.time() < deadline:
+ if run_id_file.is_file():
+ run_id = run_id_file.read_text(encoding="utf-8").strip()
+ if run_id:
+ run_dir = data_root / run_id
+ if run_dir.is_dir():
+ return run_dir
+ time.sleep(0.5)
+ return None
+
+
+def report_latest_gpu_samples(stop_event: threading.Event) -> None:
+ """Periodically print the latest sampled GPU line."""
+ log_interval = int(
+ os.environ.get("RESOURCE_MONITOR_LOG_INTERVAL") or os.environ.get("GPU_MONITOR_LOG_INTERVAL") or "15"
+ )
+ log_interval = max(log_interval, 1)
+ last_line = ""
+
+ time.sleep(min(log_interval, 5))
+ while not stop_event.wait(log_interval):
+ run_dir = wait_for_run_dir(timeout_sec=1)
+ if run_dir is None:
+ continue
+ csv_file = run_dir / "gpu_metrics.csv"
+ if not csv_file.is_file():
+ continue
+ try:
+ lines = csv_file.read_text(encoding="utf-8").splitlines()
+ except OSError:
+ continue
+ if len(lines) <= 1:
+ continue
+ latest = lines[-1].strip()
+ if latest and latest != last_line:
+ last_line = latest
+ sys.stderr.write(f"[GPU] {latest}\n")
+
+
+def finalize_resource_monitor() -> str | None:
+ """
+ Run `resource_monitor.sh finalize` for the current run and generate the report.
+ Returns the bundle dir path (for this test case's report) if successful, else None.
+ """
+ if not RESOURCE_MONITOR_SCRIPT.is_file():
+ return None
+ try:
+ result = subprocess.run(
+ ["bash", str(RESOURCE_MONITOR_SCRIPT), "finalize", "--backend", "gpu"],
+ cwd=str(REPO_ROOT),
+ capture_output=True,
+ text=True,
+ timeout=60,
+ check=False,
+ )
+ if result.returncode != 0:
+ return None
+ for line in (result.stdout or "").splitlines():
+ if line.startswith("GPU_MONITOR_BUNDLE_DIR=") or line.startswith("RESOURCE_MONITOR_BUNDLE_DIR="):
+ _, _, value = line.partition("=")
+ return value.strip() if value else None
+ return None
+ except (FileNotFoundError, OSError, subprocess.TimeoutExpired):
+ return None
diff --git a/tests/dfx/stability/scripts/test_benchmark_stability.py b/tests/dfx/stability/scripts/test_benchmark_stability.py
index a9faae8ab84..620241762d3 100644
--- a/tests/dfx/stability/scripts/test_benchmark_stability.py
+++ b/tests/dfx/stability/scripts/test_benchmark_stability.py
@@ -24,7 +24,6 @@
import pytest
-from tests.conftest import OmniServer
from tests.dfx.conftest import (
create_benchmark_indices,
create_test_parameter_mapping,
@@ -33,9 +32,10 @@
load_configs,
)
from tests.dfx.perf.scripts.run_benchmark import run_benchmark
+from tests.helpers.runtime import OmniServer
STABILITY_DIR = Path(__file__).resolve().parent.parent
-STAGE_CONFIGS_DIR = STABILITY_DIR / "stage_configs"
+DEPLOY_CONFIGS_DIR = STABILITY_DIR / "deploy"
CONFIG_FILE_PATH = str(STABILITY_DIR / "tests" / "test.json")
DEFAULT_NUM_PROMPTS_PER_BATCH = 20
@@ -45,7 +45,7 @@
except FileNotFoundError:
BENCHMARK_CONFIGS = []
-test_params = create_unique_server_params(BENCHMARK_CONFIGS, STAGE_CONFIGS_DIR) if BENCHMARK_CONFIGS else []
+test_params = create_unique_server_params(BENCHMARK_CONFIGS, DEPLOY_CONFIGS_DIR) if BENCHMARK_CONFIGS else []
server_to_benchmark_mapping = create_test_parameter_mapping(BENCHMARK_CONFIGS) if BENCHMARK_CONFIGS else {}
_omni_server_lock = threading.Lock()
@@ -219,11 +219,20 @@ def omni_server(request):
Multi-stage initialization can take 10-20+ minutes.
"""
with _omni_server_lock:
- test_name, model, stage_config_path = request.param
+ test_name, model, stage_config_path, stage_overrides, extra_cli_args = request.param
print(f"Starting OmniServer with test: {test_name}, model: {model}")
- with OmniServer(model, ["--stage-configs-path", stage_config_path, "--stage-init-timeout", "120"]) as server:
+ server_args = ["--stage-init-timeout", "120"]
+ # --deploy-config and --stage-overrides compose at the CLI (see vllm_omni/entrypoints/utils.py):
+ # deploy-config sets the base; stage-overrides are applied on top. Both can be set.
+ if stage_config_path:
+ server_args = ["--deploy-config", stage_config_path] + server_args
+ if stage_overrides:
+ server_args = ["--stage-overrides", stage_overrides] + server_args
+ if extra_cli_args:
+ server_args = list(extra_cli_args) + server_args
+ with OmniServer(model, server_args) as server:
server.test_name = test_name
print("OmniServer started successfully")
yield server
diff --git a/tests/dfx/stability/stage_configs/qwen3_omni.yaml b/tests/dfx/stability/stage_configs/qwen3_omni.yaml
deleted file mode 100644
index 802f8dd2494..00000000000
--- a/tests/dfx/stability/stage_configs/qwen3_omni.yaml
+++ /dev/null
@@ -1,101 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-async_chunk: false
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type to launch OmniLLM
- runtime:
- devices: "0"
- max_batch_size: 64
- engine_args:
- model_stage: thinker
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- stage_type: llm # Use llm stage type to launch OmniLLM
- runtime:
- devices: "1"
- max_batch_size: 64
- engine_args:
- model_stage: talker
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type to launch OmniLLM
- runtime:
- devices: "1"
- max_batch_size: 64
- engine_args:
- model_stage: code2wav
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 1000000
- hf_config_name: thinker_config
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/tests/dfx/stability/tests/test.json b/tests/dfx/stability/tests/test.json
index 95993c9c556..255cd5b1091 100644
--- a/tests/dfx/stability/tests/test.json
+++ b/tests/dfx/stability/tests/test.json
@@ -3,7 +3,11 @@
"test_name": "test_qwen3_omni_stability",
"server_params": {
"model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
- "stage_config_name": "qwen3_omni.yaml"
+ "stage_overrides": {
+ "2": {
+ "max_num_batched_tokens": 1000000
+ }
+ }
},
"benchmark_params": [
{
@@ -36,25 +40,12 @@
"test_name": "test_qwen3_omni_stability_async_chunk",
"server_params": {
"model": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
- "stage_config_name": "qwen3_omni.yaml",
- "update": {
- "async_chunk": true,
- "stage_args": {
- "0": {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk"
- },
- "1": {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk"
- }
+ "stage_overrides": {
+ "2": {
+ "max_num_batched_tokens": 1000000
}
},
- "delete": {
- "stage_args": {
- "2": [
- "custom_process_input_func"
- ]
- }
- }
+ "extra_cli_args": ["--async-chunk"]
},
"benchmark_params": [
{
diff --git a/tests/diffusion/cache/test_cache_dit.py b/tests/diffusion/cache/test_cache_dit.py
new file mode 100644
index 00000000000..0b7ef723585
--- /dev/null
+++ b/tests/diffusion/cache/test_cache_dit.py
@@ -0,0 +1,40 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""
+Model specific tests for CacheDiT enablement.
+"""
+
+from unittest.mock import Mock, patch
+
+import pytest
+
+import vllm_omni.diffusion.cache.cache_dit_backend as cd_backend
+from vllm_omni.diffusion.data import DiffusionCacheConfig
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+SEPARATE_CFG_ENABLERS = [
+ cd_backend.enable_cache_for_ltx2,
+ cd_backend.enable_cache_for_wan22,
+ cd_backend.enable_cache_for_longcat_image,
+]
+
+SAMPLE_CACHE_CONFIG = DiffusionCacheConfig()
+
+
+@pytest.mark.parametrize("enabler", SEPARATE_CFG_ENABLERS)
+@patch("vllm_omni.diffusion.cache.cache_dit_backend.BlockAdapter")
+@patch("vllm_omni.diffusion.cache.cache_dit_backend.cache_dit")
+def test_separate_cfg(mock_cache_dit, mock_block_adapter, enabler):
+ """Ensure that custom enablers for models with separate CFG pass
+ the param through to cache_dit correctly.
+
+ Regression test for: https://github.com/vllm-project/vllm-omni/pull/2860
+ """
+ mock_pipeline = Mock()
+ enabler(mock_pipeline, SAMPLE_CACHE_CONFIG)
+
+ mock_cache_dit.enable_cache.assert_called_once()
+ adapter_kwargs = mock_block_adapter.call_args.kwargs
+ assert adapter_kwargs["has_separate_cfg"] is True
diff --git a/tests/diffusion/cache/test_teacache_extractors.py b/tests/diffusion/cache/test_teacache_extractors.py
index c22a60e227e..4bb958a36c1 100644
--- a/tests/diffusion/cache/test_teacache_extractors.py
+++ b/tests/diffusion/cache/test_teacache_extractors.py
@@ -21,7 +21,7 @@
import pytest
import torch
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
from vllm_omni.diffusion.cache.teacache.extractors import extract_flux2_context, extract_flux2_klein_context
from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import (
Flux2Transformer2DModel,
diff --git a/tests/diffusion/distributed/test_ulysses_uaa_perf.py b/tests/diffusion/distributed/test_ulysses_uaa_perf.py
index 04bbf5ee863..2a16a9ae578 100644
--- a/tests/diffusion/distributed/test_ulysses_uaa_perf.py
+++ b/tests/diffusion/distributed/test_ulysses_uaa_perf.py
@@ -17,7 +17,7 @@
import torch
import torch.distributed as dist
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
from vllm_omni.diffusion.attention.parallel.ulysses import (
_all_gather_int,
_ulysses_all_to_all_any_o,
diff --git a/tests/diffusion/layers/test_rotary_emb_equivalence.py b/tests/diffusion/layers/test_rotary_emb_equivalence.py
new file mode 100644
index 00000000000..2fbb7a31f5a
--- /dev/null
+++ b/tests/diffusion/layers/test_rotary_emb_equivalence.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Numerical equivalence tests for rotary embedding implementations (#2436).
+
+Verifies that the optimized stack+flatten RoPE produces bit-identical results
+to the original strided-slice implementation across various tensor shapes and
+dtypes, ensuring the refactor is safe.
+"""
+
+from __future__ import annotations
+
+import pytest
+import torch
+
+
+def _apply_rotary_emb_helios_original(
+ hidden_states: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> torch.Tensor:
+ """Original Helios RoPE using strided slice assignment (pre-#2436)."""
+ x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2]
+ out[..., 1::2] = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2]
+ return out.type_as(hidden_states)
+
+
+def _apply_rotary_emb_helios_optimized(
+ hidden_states: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> torch.Tensor:
+ """Optimized Helios RoPE using stack+flatten (post-#2436)."""
+ x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
+ rotated = torch.stack(
+ (
+ x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2],
+ x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2],
+ ),
+ dim=-1,
+ )
+ return rotated.flatten(-2, -1).type_as(hidden_states)
+
+
+def _make_inputs(
+ batch: int,
+ seq_len: int,
+ num_heads: int,
+ head_dim: int,
+ dtype: torch.dtype = torch.float32,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Generate random hidden_states and freqs_cis for testing."""
+ torch.manual_seed(42)
+ hidden_states = torch.randn(batch, seq_len, num_heads, head_dim, dtype=dtype)
+ # freqs_cis: [B, seq, head_dim*2] — cos and sin concatenated along last dim
+ freqs_cis = torch.randn(batch, seq_len, head_dim * 2, dtype=dtype)
+ return hidden_states, freqs_cis
+
+
+class TestHeliosRoPEEquivalence:
+ """Verify optimized Helios RoPE is numerically identical to original."""
+
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+ def test_equivalence_across_dtypes(self, dtype: torch.dtype) -> None:
+ """Optimized output must be bit-identical to original across dtypes."""
+ hidden, freqs = _make_inputs(2, 16, 8, 64, dtype=dtype)
+ original = _apply_rotary_emb_helios_original(hidden, freqs)
+ optimized = _apply_rotary_emb_helios_optimized(hidden, freqs)
+ torch.testing.assert_close(optimized, original, atol=0, rtol=0)
+
+ @pytest.mark.parametrize(
+ "batch,seq_len,num_heads,head_dim",
+ [
+ (1, 8, 1, 32), # minimal: single batch, single head
+ (2, 16, 8, 64), # typical transformer config
+ (1, 8192, 4, 64), # video-scale patch tokens (720p DiT)
+ (4, 32, 16, 128), # large head_dim
+ ],
+ )
+ def test_equivalence_across_shapes(self, batch: int, seq_len: int, num_heads: int, head_dim: int) -> None:
+ """Equivalence must hold across different tensor shapes."""
+ hidden, freqs = _make_inputs(batch, seq_len, num_heads, head_dim)
+ original = _apply_rotary_emb_helios_original(hidden, freqs)
+ optimized = _apply_rotary_emb_helios_optimized(hidden, freqs)
+ torch.testing.assert_close(optimized, original, atol=0, rtol=0)
+
+ def test_output_contiguous(self) -> None:
+ """Optimized output should be contiguous in memory."""
+ hidden, freqs = _make_inputs(2, 16, 8, 64)
+ optimized = _apply_rotary_emb_helios_optimized(hidden, freqs)
+ assert optimized.is_contiguous()
+
+ def test_output_shape_preserved(self) -> None:
+ """Output shape must match input shape."""
+ hidden, freqs = _make_inputs(2, 16, 8, 64)
+ optimized = _apply_rotary_emb_helios_optimized(hidden, freqs)
+ assert optimized.shape == hidden.shape
+
+ def test_output_dtype_preserved(self) -> None:
+ """Output dtype must match input dtype."""
+ hidden, freqs = _make_inputs(2, 16, 8, 64, dtype=torch.float16)
+ optimized = _apply_rotary_emb_helios_optimized(hidden, freqs)
+ assert optimized.dtype == hidden.dtype
+
+ def test_odd_head_dim_raises(self) -> None:
+ """Odd head_dim should fail at unflatten (not a valid RoPE config)."""
+ hidden = torch.randn(1, 4, 2, 63)
+ freqs = torch.randn(1, 4, 126)
+ with pytest.raises(RuntimeError):
+ _apply_rotary_emb_helios_optimized(hidden, freqs)
diff --git a/tests/diffusion/lora/conftest.py b/tests/diffusion/lora/helpers.py
similarity index 100%
rename from tests/diffusion/lora/conftest.py
rename to tests/diffusion/lora/helpers.py
diff --git a/tests/diffusion/lora/test_lora_manager.py b/tests/diffusion/lora/test_lora_manager.py
index 83ac7a1144b..785f5d84217 100644
--- a/tests/diffusion/lora/test_lora_manager.py
+++ b/tests/diffusion/lora/test_lora_manager.py
@@ -8,7 +8,7 @@
from vllm.lora.lora_weights import LoRALayerWeights
from vllm.lora.utils import get_supported_lora_modules
-from tests.diffusion.lora.conftest import (
+from tests.diffusion.lora.helpers import (
DummyBaseLayerWithLoRA,
FakeLinearBase,
fake_replace_submodule,
diff --git a/tests/diffusion/models/bagel/test_bagel_lora.py b/tests/diffusion/models/bagel/test_bagel_lora.py
index 8cb3446ed53..c285758fe86 100644
--- a/tests/diffusion/models/bagel/test_bagel_lora.py
+++ b/tests/diffusion/models/bagel/test_bagel_lora.py
@@ -11,7 +11,7 @@
import torch
from safetensors.torch import save_file
-from tests.diffusion.lora.conftest import (
+from tests.diffusion.lora.helpers import (
DummyBaseLayerWithLoRA,
FakeLinearBase,
fake_replace_submodule,
diff --git a/tests/diffusion/models/flux2/test_flux2_transformer_tp.py b/tests/diffusion/models/flux2/test_flux2_transformer_tp.py
index 54dda1dd07e..c613bb0b4c8 100644
--- a/tests/diffusion/models/flux2/test_flux2_transformer_tp.py
+++ b/tests/diffusion/models/flux2/test_flux2_transformer_tp.py
@@ -2,7 +2,7 @@
import torch
from pytest_mock import MockerFixture
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
from vllm_omni.diffusion.models.flux2.flux2_transformer import (
Flux2PosEmbed,
Flux2Transformer2DModel,
diff --git a/tests/diffusion/models/glm_image/test_glm_image_sp.py b/tests/diffusion/models/glm_image/test_glm_image_sp.py
index 1b1c8d7a75b..40d1c873070 100644
--- a/tests/diffusion/models/glm_image/test_glm_image_sp.py
+++ b/tests/diffusion/models/glm_image/test_glm_image_sp.py
@@ -2,27 +2,26 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for GLM-Image Sequence Parallelism support."""
-from unittest.mock import MagicMock, patch
-
import pytest
from vllm_omni.diffusion.data import DiffusionParallelConfig
@pytest.fixture(scope="function", autouse=True)
-def setup_sp_groups():
+def setup_sp_groups(mocker):
"""Set up SP and TP groups for each test function."""
- with patch("vllm_omni.diffusion.distributed.parallel_state.get_sp_group") as mock_get_sp_group:
- with patch("vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size", return_value=1):
- with patch("vllm.distributed.parallel_state.get_tp_group") as mock_get_tp_group:
- mock_sp_group = MagicMock()
- mock_sp_group.world_size = 4
- mock_get_sp_group.return_value = mock_sp_group
-
- mock_tp_group = MagicMock()
- mock_tp_group.world_size = 1
- mock_get_tp_group.return_value = mock_tp_group
- yield
+ mock_get_sp_group = mocker.patch("vllm_omni.diffusion.distributed.parallel_state.get_sp_group")
+ mocker.patch("vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size", return_value=1)
+ mock_get_tp_group = mocker.patch("vllm.distributed.parallel_state.get_tp_group")
+
+ mock_sp_group = mocker.MagicMock()
+ mock_sp_group.world_size = 4
+ mock_get_sp_group.return_value = mock_sp_group
+
+ mock_tp_group = mocker.MagicMock()
+ mock_tp_group.world_size = 1
+ mock_get_tp_group.return_value = mock_tp_group
+ yield
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
diff --git a/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py b/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py
new file mode 100644
index 00000000000..873b52bf7a6
--- /dev/null
+++ b/tests/diffusion/models/qwen_image/test_qwen_image_edit_plus.py
@@ -0,0 +1,38 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+from pathlib import Path
+from types import SimpleNamespace
+
+import numpy as np
+import pytest
+from PIL import Image
+
+from vllm_omni.diffusion.models.qwen_image.pipeline_qwen_image_edit_plus import (
+ get_qwen_image_edit_plus_pre_process_func,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]
+
+
+def test_qwen_image_edit_plus_rejects_too_many_input_images(tmp_path: Path):
+ vae_dir = tmp_path / "vae"
+ vae_dir.mkdir()
+ # Keep the mock config intentionally minimal: this test only needs the
+ # fields touched during pre-process initialization.
+ (vae_dir / "config.json").write_text(json.dumps({"z_dim": 16}))
+
+ pre_process = get_qwen_image_edit_plus_pre_process_func(SimpleNamespace(model=str(tmp_path)))
+ image = Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8))
+ request = SimpleNamespace(
+ prompts=[
+ {
+ "prompt": "combine",
+ "multi_modal_data": {"image": [image, image, image, image, image]},
+ }
+ ],
+ sampling_params=SimpleNamespace(height=None, width=None),
+ )
+
+ with pytest.raises(ValueError, match=r"At most 4 images are supported by this model"):
+ pre_process(request)
diff --git a/tests/diffusion/models/wan2_2/conftest.py b/tests/diffusion/models/wan2_2/conftest.py
new file mode 100644
index 00000000000..f836fa545fd
--- /dev/null
+++ b/tests/diffusion/models/wan2_2/conftest.py
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from types import SimpleNamespace
+
+import torch
+from torch import nn
+
+
+class StubTransformer(nn.Module):
+ def __init__(self, *, name: str = "transformer", in_channels: int = 4, out_channels: int = 4) -> None:
+ super().__init__()
+ self.name = name
+ self.config = SimpleNamespace(
+ patch_size=(1, 2, 2),
+ in_channels=in_channels,
+ out_channels=out_channels,
+ image_dim=None,
+ )
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return torch.float32
+
+ def forward(self, **kwargs):
+ hidden_states = kwargs["hidden_states"]
+ return (torch.zeros_like(hidden_states[:, : self.config.out_channels]),)
+
+
+class StubScheduler:
+ def __init__(self, timesteps: list[int]) -> None:
+ self.timesteps = torch.tensor(timesteps, dtype=torch.int64)
+ self.config = SimpleNamespace(num_train_timesteps=1000)
+ self.set_timesteps_calls: list[tuple[int, torch.device]] = []
+
+ def set_timesteps(self, num_steps: int, device: torch.device) -> None:
+ self.set_timesteps_calls.append((num_steps, device))
+
+
+class StubVAE:
+ dtype = torch.float32
+
+ def __init__(self, z_dim: int = 4) -> None:
+ self.config = SimpleNamespace(
+ z_dim=z_dim,
+ scale_factor_temporal=4,
+ scale_factor_spatial=8,
+ latents_mean=[0.0] * z_dim,
+ latents_std=[1.0] * z_dim,
+ )
+
+ def encode(self, video: torch.Tensor):
+ latent_frames = (video.shape[2] + self.config.scale_factor_temporal - 1) // self.config.scale_factor_temporal
+ latent_height = video.shape[-2] // self.config.scale_factor_spatial
+ latent_width = video.shape[-1] // self.config.scale_factor_spatial
+ latents = torch.ones(
+ video.shape[0],
+ self.config.z_dim,
+ latent_frames,
+ latent_height,
+ latent_width,
+ dtype=video.dtype,
+ device=video.device,
+ )
+ return SimpleNamespace(latents=latents)
+
+ def decode(self, latents: torch.Tensor, return_dict: bool = False):
+ del return_dict
+ return (latents,)
+
+
+@contextmanager
+def noop_progress_bar(*args, **kwargs):
+ del args, kwargs
+
+ class Bar:
+ def update(self) -> None:
+ return None
+
+ yield Bar()
diff --git a/tests/diffusion/models/wan2_2/test_wan22_i2v_pipeline.py b/tests/diffusion/models/wan2_2/test_wan22_i2v_pipeline.py
new file mode 100644
index 00000000000..04e834ac47c
--- /dev/null
+++ b/tests/diffusion/models/wan2_2/test_wan22_i2v_pipeline.py
@@ -0,0 +1,126 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from types import SimpleNamespace
+
+import pytest
+import torch
+from PIL import Image
+from torch import nn
+
+from tests.diffusion.models.wan2_2.conftest import StubTransformer, StubVAE, noop_progress_bar
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_i2v import (
+ Wan22I2VPipeline,
+ get_wan22_i2v_pre_process_func,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion]
+
+
+def _make_i2v_pipeline(*, expand_timesteps: bool) -> Wan22I2VPipeline:
+ pipeline = object.__new__(Wan22I2VPipeline)
+ nn.Module.__init__(pipeline)
+ pipeline.device = torch.device("cpu")
+ pipeline.transformer = StubTransformer(name="high", in_channels=8, out_channels=4)
+ pipeline.transformer_2 = StubTransformer(name="low", in_channels=8, out_channels=4)
+ pipeline.vae = StubVAE(z_dim=4)
+ pipeline.vae_scale_factor_temporal = 4
+ pipeline.vae_scale_factor_spatial = 8
+ pipeline.expand_timesteps = expand_timesteps
+ pipeline.progress_bar = noop_progress_bar
+ return pipeline
+
+
+def test_i2v_preprocess_requires_image_and_resizes_to_480p_aspect() -> None:
+ preprocess = get_wan22_i2v_pre_process_func(SimpleNamespace())
+ request = SimpleNamespace(
+ prompts=[{"prompt": "p", "multi_modal_data": {"image": Image.new("RGB", (320, 160), "red")}}],
+ sampling_params=SimpleNamespace(height=None, width=None),
+ )
+
+ result = preprocess(request)
+ prompt = result.prompts[0]
+
+ assert result.sampling_params.height == 432
+ assert result.sampling_params.width == 880
+ assert prompt["multi_modal_data"]["image"].size == (880, 432)
+ assert prompt["additional_information"]["preprocessed_image"].shape[-2:] == (432, 880)
+
+ missing_image = SimpleNamespace(
+ prompts=[{"prompt": "p", "multi_modal_data": {}}],
+ sampling_params=SimpleNamespace(height=None, width=None),
+ )
+ with pytest.raises(ValueError, match="No image is provided"):
+ preprocess(missing_image)
+
+
+def test_i2v_diffuse_selects_stage_guidance_and_expands_timesteps() -> None:
+ pipeline = _make_i2v_pipeline(expand_timesteps=True)
+ latents = torch.zeros(1, 4, 2, 4, 4)
+ condition = torch.ones_like(latents)
+ first_frame_mask = torch.ones(1, 1, 2, 4, 4)
+ first_frame_mask[:, :, 0] = 0
+ timesteps = torch.tensor([900, 100])
+
+ calls = []
+
+ def fake_predict_noise_maybe_with_cfg(**kwargs):
+ positive = kwargs["positive_kwargs"]
+ calls.append(
+ {
+ "model": positive["current_model"].name,
+ "scale": kwargs["true_cfg_scale"],
+ "timestep_shape": tuple(positive["timestep"].shape),
+ "timestep_values": positive["timestep"].clone(),
+ "hidden_states": positive["hidden_states"].clone(),
+ }
+ )
+ return torch.ones_like(latents)
+
+ pipeline.predict_noise_maybe_with_cfg = fake_predict_noise_maybe_with_cfg # type: ignore[method-assign]
+ pipeline.scheduler_step_maybe_with_cfg = lambda noise, t, current, cfg: current + noise # type: ignore[method-assign]
+
+ result = pipeline.diffuse(
+ latents=latents,
+ timesteps=timesteps,
+ prompt_embeds=torch.zeros(1, 2, 3),
+ negative_prompt_embeds=None,
+ image_embeds=None,
+ guidance_low=1.0,
+ guidance_high=2.0,
+ boundary_timestep=500.0,
+ dtype=torch.float32,
+ attention_kwargs={},
+ condition=condition,
+ first_frame_mask=first_frame_mask,
+ )
+
+ assert [call["model"] for call in calls] == ["high", "low"]
+ assert [call["scale"] for call in calls] == [1.0, 2.0]
+ assert calls[0]["timestep_shape"] == (1, 8)
+ timestep_dtype = calls[0]["timestep_values"].dtype
+ torch.testing.assert_close(calls[0]["timestep_values"][0, :4], torch.zeros(4, dtype=timestep_dtype))
+ torch.testing.assert_close(calls[0]["timestep_values"][0, 4:], torch.full((4,), 900, dtype=timestep_dtype))
+ torch.testing.assert_close(calls[0]["hidden_states"][:, :, 0], torch.ones(1, 4, 4, 4))
+ torch.testing.assert_close(result, torch.full_like(latents, 2.0))
+
+
+def test_i2v_prepare_latents_builds_expand_condition_and_first_frame_mask() -> None:
+ pipeline = _make_i2v_pipeline(expand_timesteps=True)
+ latents, condition, first_frame_mask = pipeline.prepare_latents(
+ image=torch.zeros(1, 3, 16, 16),
+ batch_size=1,
+ num_channels_latents=4,
+ height=16,
+ width=16,
+ num_frames=5,
+ dtype=torch.float32,
+ device=torch.device("cpu"),
+ generator=torch.Generator(device="cpu").manual_seed(0),
+ )
+
+ assert latents.shape == (1, 4, 2, 2, 2)
+ assert condition.shape == (1, 4, 1, 2, 2)
+ assert first_frame_mask.shape == (1, 1, 2, 2, 2)
+ assert first_frame_mask[:, :, 0].sum() == 0
+ assert first_frame_mask[:, :, 1].sum() == 4
diff --git a/tests/diffusion/models/wan2_2/test_wan22_pipeline_diffuse.py b/tests/diffusion/models/wan2_2/test_wan22_pipeline_diffuse.py
new file mode 100644
index 00000000000..54bb672ef81
--- /dev/null
+++ b/tests/diffusion/models/wan2_2/test_wan22_pipeline_diffuse.py
@@ -0,0 +1,155 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from contextlib import contextmanager
+from types import SimpleNamespace
+
+import pytest
+import torch
+from torch import nn
+
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import Wan22Pipeline
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion]
+
+
+class _StubTransformer(nn.Module):
+ @property
+ def dtype(self) -> torch.dtype:
+ return torch.float32
+
+
+class _StubScheduler:
+ def __init__(self, timesteps: list[int]) -> None:
+ self.timesteps = torch.tensor(timesteps, dtype=torch.int64)
+ self.config = SimpleNamespace(num_train_timesteps=1000)
+ self.set_timesteps_calls: list[tuple[int, torch.device]] = []
+
+ def set_timesteps(self, num_steps: int, device: torch.device) -> None:
+ self.set_timesteps_calls.append((num_steps, device))
+
+
+@contextmanager
+def _noop_progress_bar(*args, **kwargs):
+ del args, kwargs
+
+ class _Bar:
+ def update(self) -> None:
+ return None
+
+ yield _Bar()
+
+
+def _make_pipeline() -> Wan22Pipeline:
+ pipeline = object.__new__(Wan22Pipeline)
+ nn.Module.__init__(pipeline)
+ pipeline.device = torch.device("cpu")
+ pipeline.transformer = _StubTransformer()
+ pipeline.transformer_2 = None
+ pipeline.transformer_config = SimpleNamespace(patch_size=(1, 2, 2), in_channels=4, out_channels=4)
+ pipeline.scheduler = _StubScheduler([9, 5])
+ pipeline.od_config = SimpleNamespace(flow_shift=5.0)
+ pipeline._sample_solver = "unipc"
+ pipeline._flow_shift = 5.0
+ pipeline.vae_scale_factor_temporal = 4
+ pipeline.vae_scale_factor_spatial = 8
+ pipeline.boundary_ratio = 0.875
+ pipeline.expand_timesteps = False
+ pipeline._guidance_scale = None
+ pipeline._guidance_scale_2 = None
+ pipeline._num_timesteps = None
+ pipeline._current_timestep = None
+ pipeline.check_inputs = lambda **kwargs: None
+ pipeline.prepare_latents = lambda **kwargs: torch.zeros((1, 4, 1, 8, 8), dtype=torch.float32)
+ pipeline.progress_bar = _noop_progress_bar
+ return pipeline
+
+
+def test_forward_delegates_denoising_to_diffuse(monkeypatch) -> None:
+ pipeline = _make_pipeline()
+
+ prompt_embeds = torch.randn(1, 8)
+ captured: dict[str, object] = {}
+
+ def _fake_diffuse(**kwargs):
+ captured.update(kwargs)
+ return kwargs["latents"] + 1
+
+ pipeline.diffuse = _fake_diffuse # type: ignore[method-assign]
+
+ req = SimpleNamespace(
+ prompts=["prompt"],
+ sampling_params=SimpleNamespace(
+ height=None,
+ width=None,
+ num_frames=1,
+ num_inference_steps=2,
+ guidance_scale_provided=False,
+ guidance_scale=None,
+ guidance_scale_2=None,
+ boundary_ratio=None,
+ generator=None,
+ seed=None,
+ num_outputs_per_prompt=1,
+ max_sequence_length=32,
+ latents=None,
+ extra_args={},
+ ),
+ )
+
+ output = pipeline.forward(req, prompt_embeds=prompt_embeds, output_type="latent", guidance_scale=1.0)
+
+ assert torch.equal(output.output, torch.ones((1, 4, 1, 8, 8)))
+ assert torch.equal(captured["timesteps"], pipeline.scheduler.timesteps)
+ assert captured["guidance_low"] == 1.0
+ assert captured["guidance_high"] == 1.0
+ assert captured["boundary_timestep"] == pytest.approx(875.0)
+ assert captured["latent_condition"] is None
+ assert captured["first_frame_mask"] is None
+ assert pipeline.scheduler.set_timesteps_calls == [(2, torch.device("cpu"))]
+
+
+def test_diffuse_runs_prediction_and_scheduler_for_each_timestep() -> None:
+ pipeline = _make_pipeline()
+ latents = torch.zeros((1, 1, 1, 2, 2), dtype=torch.float32)
+ timesteps = torch.tensor([7, 3], dtype=torch.int64)
+ prompt_embeds = torch.randn(1, 8)
+
+ predict_calls: list[dict[str, object]] = []
+ scheduler_calls: list[tuple[float, int, float, bool]] = []
+
+ def _fake_predict_noise_maybe_with_cfg(**kwargs):
+ predict_calls.append(kwargs)
+ timestep = kwargs["positive_kwargs"]["timestep"]
+ assert isinstance(timestep, torch.Tensor)
+ return torch.full_like(latents, float(timestep[0].item()))
+
+ def _fake_scheduler_step_maybe_with_cfg(noise_pred, t, current_latents, do_true_cfg):
+ scheduler_calls.append(
+ (float(noise_pred[0, 0, 0, 0, 0]), int(t.item()), float(current_latents.sum()), do_true_cfg)
+ )
+ return current_latents + noise_pred
+
+ pipeline.predict_noise_maybe_with_cfg = _fake_predict_noise_maybe_with_cfg # type: ignore[method-assign]
+ pipeline.scheduler_step_maybe_with_cfg = _fake_scheduler_step_maybe_with_cfg # type: ignore[method-assign]
+
+ result = pipeline.diffuse(
+ latents=latents,
+ timesteps=timesteps,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=None,
+ guidance_low=1.0,
+ guidance_high=2.0,
+ boundary_timestep=5.0,
+ dtype=torch.float32,
+ attention_kwargs={},
+ )
+
+ assert len(predict_calls) == 2
+ assert predict_calls[0]["true_cfg_scale"] == 1.0
+ assert predict_calls[1]["true_cfg_scale"] == 2.0
+ assert scheduler_calls == [
+ (7.0, 7, 0.0, False),
+ (3.0, 3, 28.0, False),
+ ]
+ assert torch.equal(result, torch.full_like(latents, 10.0))
diff --git a/tests/diffusion/models/wan2_2/test_wan22_pipeline_helpers.py b/tests/diffusion/models/wan2_2/test_wan22_pipeline_helpers.py
new file mode 100644
index 00000000000..31471786976
--- /dev/null
+++ b/tests/diffusion/models/wan2_2/test_wan22_pipeline_helpers.py
@@ -0,0 +1,81 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+from types import SimpleNamespace
+
+import pytest
+import torch
+
+import vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 as wan22_module
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import (
+ create_transformer_from_config,
+ load_transformer_config,
+ retrieve_latents,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion]
+
+
+class _LatentDist:
+ def sample(self, generator):
+ assert isinstance(generator, torch.Generator)
+ return torch.tensor([1.0])
+
+ def mode(self):
+ return torch.tensor([2.0])
+
+
+def test_retrieve_latents_supports_sample_mode_argmax_and_direct_latents() -> None:
+ generator = torch.Generator(device="cpu")
+
+ assert retrieve_latents(SimpleNamespace(latent_dist=_LatentDist()), generator).item() == 1.0
+ assert retrieve_latents(SimpleNamespace(latent_dist=_LatentDist()), sample_mode="argmax").item() == 2.0
+ torch.testing.assert_close(retrieve_latents(SimpleNamespace(latents=torch.tensor([3.0]))), torch.tensor([3.0]))
+
+
+def test_retrieve_latents_rejects_unknown_encoder_output() -> None:
+ with pytest.raises(AttributeError, match="Could not access latents"):
+ retrieve_latents(SimpleNamespace())
+
+
+def test_load_transformer_config_reads_local_subfolder_config(tmp_path) -> None:
+ config_dir = tmp_path / "transformer_2"
+ config_dir.mkdir(parents=True)
+ (config_dir / "config.json").write_text(json.dumps({"patch_size": [1, 2, 2], "num_layers": 2}))
+
+ assert load_transformer_config(str(tmp_path), "transformer_2") == {"patch_size": [1, 2, 2], "num_layers": 2}
+ assert load_transformer_config(str(tmp_path), "missing") == {}
+
+
+def test_create_transformer_from_config_maps_supported_keys(monkeypatch) -> None:
+ captured = {}
+
+ class FakeTransformer:
+ def __init__(self, **kwargs) -> None:
+ captured.update(kwargs)
+
+ monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer)
+
+ transformer = create_transformer_from_config(
+ {
+ "patch_size": [1, 2, 2],
+ "num_attention_heads": 8,
+ "attention_head_dim": 128,
+ "in_channels": 16,
+ "out_channels": 16,
+ "text_dim": 4096,
+ "vace_layers": [0],
+ "ignored": "value",
+ }
+ )
+
+ assert isinstance(transformer, FakeTransformer)
+ assert captured == {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 8,
+ "attention_head_dim": 128,
+ "in_channels": 16,
+ "out_channels": 16,
+ "text_dim": 4096,
+ }
diff --git a/tests/diffusion/models/wan2_2/test_wan22_ti2v_pipeline.py b/tests/diffusion/models/wan2_2/test_wan22_ti2v_pipeline.py
new file mode 100644
index 00000000000..983350c4cf9
--- /dev/null
+++ b/tests/diffusion/models/wan2_2/test_wan22_ti2v_pipeline.py
@@ -0,0 +1,98 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from types import SimpleNamespace
+
+import pytest
+import torch
+from PIL import Image
+from torch import nn
+
+from tests.diffusion.models.wan2_2.conftest import StubTransformer, StubVAE, noop_progress_bar
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_ti2v import (
+ Wan22TI2VPipeline,
+ get_wan22_ti2v_pre_process_func,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion]
+
+
+def _make_ti2v_pipeline() -> Wan22TI2VPipeline:
+ pipeline = object.__new__(Wan22TI2VPipeline)
+ nn.Module.__init__(pipeline)
+ pipeline.device = torch.device("cpu")
+ pipeline.transformer = StubTransformer(in_channels=4, out_channels=4)
+ pipeline.vae = StubVAE(z_dim=4)
+ pipeline.vae_scale_factor_temporal = 4
+ pipeline.vae_scale_factor_spatial = 8
+ pipeline.progress_bar = noop_progress_bar
+ return pipeline
+
+
+def test_ti2v_preprocess_uses_720p_area_for_image_condition() -> None:
+ preprocess = get_wan22_ti2v_pre_process_func(SimpleNamespace())
+ request = SimpleNamespace(
+ prompts=[{"prompt": "p", "multi_modal_data": {"image": Image.new("RGB", (320, 160), "blue")}}],
+ sampling_params=SimpleNamespace(height=None, width=None),
+ )
+
+ result = preprocess(request)
+
+ assert result.sampling_params.height == 672
+ assert result.sampling_params.width == 1344
+ assert result.prompts[0]["multi_modal_data"]["image"].size == (1344, 672)
+ assert result.prompts[0]["additional_information"]["preprocessed_image"].shape[-2:] == (672, 1344)
+
+
+def test_ti2v_diffuse_without_image_condition_expands_patch_timesteps() -> None:
+ pipeline = _make_ti2v_pipeline()
+ latents = torch.zeros(1, 4, 2, 4, 4)
+ calls = []
+
+ def fake_predict_noise_maybe_with_cfg(**kwargs):
+ calls.append(kwargs)
+ return torch.ones_like(latents)
+
+ pipeline.predict_noise_maybe_with_cfg = fake_predict_noise_maybe_with_cfg # type: ignore[method-assign]
+ pipeline.scheduler_step_maybe_with_cfg = lambda noise, t, current, cfg: current + noise # type: ignore[method-assign]
+
+ result = pipeline.diffuse(
+ latents=latents,
+ timesteps=torch.tensor([7]),
+ prompt_embeds=torch.zeros(1, 2, 3),
+ negative_prompt_embeds=torch.zeros(1, 2, 3),
+ guidance_scale=3.0,
+ dtype=torch.float32,
+ attention_kwargs={"a": "b"},
+ num_latent_frames=2,
+ latent_height=4,
+ latent_width=4,
+ )
+
+ positive = calls[0]["positive_kwargs"]
+ assert calls[0]["do_true_cfg"] is True
+ assert positive["timestep"].shape == (1, 8)
+ torch.testing.assert_close(positive["timestep"], torch.full((1, 8), 7, dtype=positive["timestep"].dtype))
+ torch.testing.assert_close(positive["hidden_states"], latents)
+ torch.testing.assert_close(result, torch.ones_like(latents))
+
+
+def test_ti2v_prepare_i2v_latents_encodes_condition_and_masks_first_frame() -> None:
+ pipeline = _make_ti2v_pipeline()
+ latents, latent_condition, first_frame_mask = pipeline.prepare_i2v_latents(
+ image=torch.zeros(1, 3, 16, 16),
+ batch_size=1,
+ num_channels_latents=4,
+ height=16,
+ width=16,
+ num_frames=5,
+ dtype=torch.float32,
+ device=torch.device("cpu"),
+ generator=None,
+ latents=torch.zeros(1, 4, 2, 2, 2),
+ )
+
+ torch.testing.assert_close(latents, torch.zeros(1, 4, 2, 2, 2))
+ assert latent_condition.shape == (1, 4, 1, 2, 2)
+ assert first_frame_mask[:, :, 0].sum() == 0
+ assert first_frame_mask[:, :, 1].sum() == 4
diff --git a/tests/diffusion/models/wan2_2/test_wan22_vace_pipeline.py b/tests/diffusion/models/wan2_2/test_wan22_vace_pipeline.py
new file mode 100644
index 00000000000..9fa9b67c499
--- /dev/null
+++ b/tests/diffusion/models/wan2_2/test_wan22_vace_pipeline.py
@@ -0,0 +1,137 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from types import SimpleNamespace
+
+import pytest
+import torch
+from PIL import Image
+from torch import nn
+
+from tests.diffusion.models.wan2_2.conftest import StubTransformer, StubVAE, noop_progress_bar
+from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_vace import (
+ Wan22VACEPipeline,
+ create_vace_transformer_from_config,
+ get_wan22_vace_pre_process_func,
+)
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion]
+
+
+def _make_vace_pipeline() -> Wan22VACEPipeline:
+ pipeline = object.__new__(Wan22VACEPipeline)
+ nn.Module.__init__(pipeline)
+ pipeline.device = torch.device("cpu")
+ pipeline.transformer = StubTransformer(in_channels=4, out_channels=4)
+ pipeline.transformer_config = pipeline.transformer.config
+ pipeline.vae = StubVAE(z_dim=4)
+ pipeline.vae_scale_factor_temporal = 4
+ pipeline.vae_scale_factor_spatial = 8
+ pipeline.progress_bar = noop_progress_bar
+ return pipeline
+
+
+def test_vace_preprocess_collects_reference_video_and_mask_inputs() -> None:
+ preprocess = get_wan22_vace_pre_process_func(SimpleNamespace())
+ ref = Image.new("RGB", (320, 160), "green")
+ frame = Image.new("RGB", (64, 64), "black")
+ mask = Image.new("L", (64, 64), 255)
+ request = SimpleNamespace(
+ prompts=[
+ {
+ "prompt": "p",
+ "multi_modal_data": {
+ "image": ref,
+ "video": [frame],
+ "mask": mask,
+ },
+ }
+ ],
+ sampling_params=SimpleNamespace(height=None, width=None),
+ )
+
+ result = preprocess(request)
+ additional_info = result.prompts[0]["additional_information"]
+
+ assert result.sampling_params.height == 432
+ assert result.sampling_params.width == 880
+ assert additional_info["reference_images"] == [ref]
+ assert additional_info["source_video"] == [frame]
+ assert additional_info["mask"] == [mask]
+
+
+def test_create_vace_transformer_from_config_maps_vace_specific_keys(monkeypatch) -> None:
+ captured = {}
+
+ class FakeVACETransformer:
+ def __init__(self, **kwargs) -> None:
+ captured.update(kwargs)
+
+ monkeypatch.setattr(
+ "vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_vace.WanVACETransformer3DModel",
+ FakeVACETransformer,
+ )
+
+ transformer = create_vace_transformer_from_config(
+ {
+ "patch_size": [1, 2, 2],
+ "in_channels": 96,
+ "out_channels": 16,
+ "vace_layers": [0, 1, 2],
+ "vace_in_channels": 132,
+ "unknown": "ignored",
+ }
+ )
+
+ assert isinstance(transformer, FakeVACETransformer)
+ assert captured == {
+ "patch_size": (1, 2, 2),
+ "in_channels": 96,
+ "out_channels": 16,
+ "vace_layers": [0, 1, 2],
+ "vace_in_channels": 132,
+ }
+
+
+def test_vace_prepare_masks_encodes_spatial_stride_and_reference_padding() -> None:
+ pipeline = _make_vace_pipeline()
+ mask = torch.ones(1, 3, 5, 16, 16)
+ reference_images = [[torch.zeros(3, 16, 16), torch.zeros(3, 16, 16)]]
+
+ encoded = pipeline.prepare_masks(mask, reference_images)
+
+ assert encoded.shape == (1, 64, 4, 2, 2)
+ torch.testing.assert_close(encoded[:, :, :2], torch.zeros(1, 64, 2, 2, 2))
+ torch.testing.assert_close(encoded[:, :, 2:], torch.ones(1, 64, 2, 2, 2))
+
+
+def test_vace_diffuse_passes_context_and_scale_to_cfg_branches() -> None:
+ pipeline = _make_vace_pipeline()
+ latents = torch.zeros(1, 4, 1, 2, 2)
+ vace_context = torch.ones(1, 12, 1, 2, 2)
+ calls = []
+
+ def fake_predict_noise_maybe_with_cfg(**kwargs):
+ calls.append(kwargs)
+ return torch.ones_like(latents)
+
+ pipeline.predict_noise_maybe_with_cfg = fake_predict_noise_maybe_with_cfg # type: ignore[method-assign]
+ pipeline.scheduler_step_maybe_with_cfg = lambda noise, t, current, cfg: current + noise # type: ignore[method-assign]
+
+ result = pipeline.diffuse(
+ latents=latents,
+ timesteps=torch.tensor([5]),
+ prompt_embeds=torch.zeros(1, 2, 3),
+ negative_prompt_embeds=torch.zeros(1, 2, 3),
+ guidance_scale=4.0,
+ dtype=torch.float32,
+ attention_kwargs={},
+ vace_context=vace_context,
+ vace_context_scale=0.75,
+ )
+
+ assert calls[0]["do_true_cfg"] is True
+ assert calls[0]["true_cfg_scale"] == 4.0
+ assert calls[0]["positive_kwargs"]["vace_context"] is vace_context
+ assert calls[0]["negative_kwargs"]["vace_context_scale"] == 0.75
+ torch.testing.assert_close(result, torch.ones_like(latents))
diff --git a/tests/diffusion/quantization/test_quantization_quality.py b/tests/diffusion/quantization/test_quantization_quality.py
index 3d8f1873698..ba6a150c4bb 100644
--- a/tests/diffusion/quantization/test_quantization_quality.py
+++ b/tests/diffusion/quantization/test_quantization_quality.py
@@ -32,7 +32,7 @@
import pytest
import torch
-from tests.utils import hardware_marks
+from tests.helpers.mark import hardware_marks
# ---------------------------------------------------------------------------
# Configuration — add new quantization methods / models here
diff --git a/tests/diffusion/test_diffusion_model_runner.py b/tests/diffusion/test_diffusion_model_runner.py
index 8768986f01d..b63f6d8887f 100644
--- a/tests/diffusion/test_diffusion_model_runner.py
+++ b/tests/diffusion/test_diffusion_model_runner.py
@@ -8,7 +8,7 @@
import torch
import vllm_omni.diffusion.worker.diffusion_model_runner as model_runner_module
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
from vllm_omni.diffusion.worker.diffusion_model_runner import DiffusionModelRunner
pytestmark = [pytest.mark.diffusion]
diff --git a/tests/diffusion/test_diffusion_step_pipeline.py b/tests/diffusion/test_diffusion_step_pipeline.py
index 42687d4a1ed..06f8cd14dc8 100644
--- a/tests/diffusion/test_diffusion_step_pipeline.py
+++ b/tests/diffusion/test_diffusion_step_pipeline.py
@@ -13,7 +13,7 @@
from pytest_mock import MockerFixture
import vllm_omni.diffusion.worker.diffusion_model_runner as model_runner_module
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
from vllm_omni.diffusion.data import DiffusionOutput
from vllm_omni.diffusion.diffusion_engine import DiffusionEngine
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
diff --git a/tests/e2e/accuracy/conftest.py b/tests/e2e/accuracy/conftest.py
index 3d614b8cdc1..709fdf345ec 100644
--- a/tests/e2e/accuracy/conftest.py
+++ b/tests/e2e/accuracy/conftest.py
@@ -1,7 +1,6 @@
from __future__ import annotations
import os
-import shutil
import subprocess
from contextlib import contextmanager
from dataclasses import dataclass
@@ -13,7 +12,7 @@
import torch
from PIL import Image
-from tests.conftest import OmniServer, OmniServerParams
+from tests.helpers.runtime import OmniServer, OmniServerParams
def pytest_addoption(parser):
@@ -208,18 +207,6 @@ def rabbit_image(accuracy_artifact_root: Path) -> Image.Image:
return image
-def reset_artifact_dir(path: Path) -> Path:
- if path.exists():
- shutil.rmtree(path)
- path.mkdir(parents=True, exist_ok=True)
- return path
-
-
-def infer_model_label(model: str) -> str:
- label = Path(model.rstrip("/\\")).name or "model"
- return "".join(char if char.isalnum() or char in {"-", "_"} else "_" for char in label)
-
-
def _build_accuracy_server_config(
*,
generate_model: str,
diff --git a/tests/e2e/accuracy/utils.py b/tests/e2e/accuracy/helpers.py
similarity index 91%
rename from tests/e2e/accuracy/utils.py
rename to tests/e2e/accuracy/helpers.py
index d722b69b011..382d3ea9b5f 100644
--- a/tests/e2e/accuracy/utils.py
+++ b/tests/e2e/accuracy/helpers.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
from pathlib import Path
import numpy as np
@@ -9,6 +7,20 @@
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
+def reset_artifact_dir(path: Path) -> Path:
+ import shutil
+
+ if path.exists():
+ shutil.rmtree(path)
+ path.mkdir(parents=True, exist_ok=True)
+ return path
+
+
+def infer_model_label(model: str) -> str:
+ label = Path(model.rstrip("/\\")).name or "model"
+ return "".join(char if char.isalnum() or char in {"-", "_"} else "_" for char in label)
+
+
def model_output_dir(parent_dir: Path, model: str) -> Path:
safe_model_name = model.split("/")[-1].replace(".", "_")
path = parent_dir / safe_model_name
diff --git a/tests/e2e/accuracy/test_gebench_h100_smoke.py b/tests/e2e/accuracy/test_gebench_h100_smoke.py
index b4b83187135..74891926910 100644
--- a/tests/e2e/accuracy/test_gebench_h100_smoke.py
+++ b/tests/e2e/accuracy/test_gebench_h100_smoke.py
@@ -6,8 +6,8 @@
import pytest
from benchmarks.accuracy.text_to_image.gbench import main as gbench_main
-from tests.e2e.accuracy.conftest import infer_model_label, reset_artifact_dir
-from tests.utils import hardware_test
+from tests.e2e.accuracy.helpers import infer_model_label, reset_artifact_dir
+from tests.helpers.mark import hardware_test
@pytest.mark.advanced_model
diff --git a/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py b/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py
index 960ea57960c..6227d636863 100644
--- a/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py
+++ b/tests/e2e/accuracy/test_gedit_bench_h100_smoke.py
@@ -7,8 +7,8 @@
from benchmarks.accuracy.image_to_image.gedit_bench import GROUPS
from benchmarks.accuracy.image_to_image.gedit_bench import main as gedit_main
-from tests.e2e.accuracy.conftest import infer_model_label, reset_artifact_dir
-from tests.utils import hardware_test
+from tests.e2e.accuracy.helpers import infer_model_label, reset_artifact_dir
+from tests.helpers.mark import hardware_test
@pytest.mark.advanced_model
diff --git a/tests/e2e/accuracy/test_qwen_image.py b/tests/e2e/accuracy/test_qwen_image.py
index e73195017aa..8922d9d1044 100644
--- a/tests/e2e/accuracy/test_qwen_image.py
+++ b/tests/e2e/accuracy/test_qwen_image.py
@@ -12,13 +12,10 @@
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from PIL import Image
-from tests.conftest import (
- OmniServer,
- _run_post_test_cleanup,
- _run_pre_test_cleanup,
-)
-from tests.e2e.accuracy.utils import assert_similarity, model_output_dir
-from tests.utils import hardware_test
+from tests.e2e.accuracy.helpers import assert_similarity, model_output_dir
+from tests.helpers.env import run_post_test_cleanup, run_pre_test_cleanup
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServer
MODEL_ID = "Qwen/Qwen-Image"
MODEL_ENV_VAR = "QWEN_IMAGE_MODEL"
@@ -70,7 +67,7 @@ def _run_vllm_omni_qwen_image(*, model: str, output_path: Path) -> Image.Image:
def _run_diffusers_qwen_image(*, model: str, output_path: Path) -> Image.Image:
- _run_pre_test_cleanup(enable_force=True)
+ run_pre_test_cleanup(enable_force=True)
pipe: DiffusionPipeline | None = None
try:
pipe = DiffusionPipeline.from_pretrained(
@@ -99,7 +96,7 @@ def _run_diffusers_qwen_image(*, model: str, output_path: Path) -> Image.Image:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
- _run_post_test_cleanup(enable_force=True)
+ run_post_test_cleanup(enable_force=True)
@pytest.mark.advanced_model
diff --git a/tests/e2e/accuracy/test_qwen_image_edit.py b/tests/e2e/accuracy/test_qwen_image_edit.py
index 9a970103438..e17aca6e99b 100644
--- a/tests/e2e/accuracy/test_qwen_image_edit.py
+++ b/tests/e2e/accuracy/test_qwen_image_edit.py
@@ -10,13 +10,10 @@
from PIL import Image
from benchmarks.accuracy.common import decode_base64_image, pil_to_png_bytes
-from tests.conftest import (
- OmniServer,
- _run_post_test_cleanup,
- _run_pre_test_cleanup,
-)
-from tests.e2e.accuracy.utils import assert_similarity, model_output_dir
-from tests.utils import hardware_test
+from tests.e2e.accuracy.helpers import assert_similarity, model_output_dir
+from tests.helpers.env import run_post_test_cleanup, run_pre_test_cleanup
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServer
SINGLE_MODEL = "Qwen/Qwen-Image-Edit"
MULTIPLE_MODEL = "Qwen/Qwen-Image-Edit-2509"
@@ -77,7 +74,7 @@ def _run_diffusers_image_edit(
input_images: list[Image.Image],
output_path: Path,
) -> Image.Image:
- _run_pre_test_cleanup(enable_force=True)
+ run_pre_test_cleanup(enable_force=True)
pipe: QwenImageEditPipeline | QwenImageEditPlusPipeline | None = None
device = torch.device("cuda:0")
torch.cuda.set_device(device)
@@ -110,7 +107,7 @@ def _run_diffusers_image_edit(
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
- _run_post_test_cleanup(enable_force=True)
+ run_post_test_cleanup(enable_force=True)
def _vllm_omni_output_single_image(
diff --git a/tests/e2e/accuracy/test_qwen_image_layered.py b/tests/e2e/accuracy/test_qwen_image_layered.py
index 04b13df3bb2..0ab9cb32363 100644
--- a/tests/e2e/accuracy/test_qwen_image_layered.py
+++ b/tests/e2e/accuracy/test_qwen_image_layered.py
@@ -12,13 +12,10 @@
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from PIL import Image
-from tests.conftest import (
- OmniServer,
- _run_post_test_cleanup,
- _run_pre_test_cleanup,
-)
-from tests.e2e.accuracy.utils import assert_image_sequence_similarity, model_output_dir
-from tests.utils import hardware_test
+from tests.e2e.accuracy.helpers import assert_image_sequence_similarity, model_output_dir
+from tests.helpers.env import run_post_test_cleanup, run_pre_test_cleanup
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServer
MODEL_ID = "Qwen/Qwen-Image-Layered"
MODEL_ENV_VAR = "QWEN_IMAGE_LAYERED_MODEL"
@@ -93,7 +90,7 @@ def _run_vllm_omni_qwen_image_layered(*, model: str, input_image: Image.Image, o
def _run_diffusers_qwen_image_layered(*, model: str, input_image: Image.Image, output_dir: Path) -> list[Image.Image]:
- _run_pre_test_cleanup(enable_force=True)
+ run_pre_test_cleanup(enable_force=True)
pipe: DiffusionPipeline | None = None
try:
pipe = DiffusionPipeline.from_pretrained(
@@ -126,7 +123,7 @@ def _run_diffusers_qwen_image_layered(*, model: str, input_image: Image.Image, o
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
- _run_post_test_cleanup(enable_force=True)
+ run_post_test_cleanup(enable_force=True)
@pytest.mark.advanced_model
diff --git a/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py b/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py
index 3cdda1f9ffa..24daa8ccf54 100644
--- a/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py
+++ b/tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py
@@ -22,7 +22,6 @@
from diffusers import UniPCMultistepScheduler
from PIL import Image
-from tests.conftest import OmniServerParams
from tests.e2e.accuracy.wan22_i2v.run_wan22_i2v_diffusers_cp import (
_configure_scheduler,
_ensure_wan_ftfy_fallback,
@@ -48,7 +47,8 @@
SSIM_THRESHOLD,
WIDTH,
)
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServerParams
def test_parse_video_metadata_extracts_dimensions_and_fps() -> None:
@@ -567,6 +567,7 @@ def test_wan22_i2v_diffusers_offline_generates_video(
@pytest.mark.benchmark
@pytest.mark.diffusion
@hardware_test(res={"cuda": "H100"}, num_cards=2)
+@pytest.mark.skip(reason="issue: #2874")
@pytest.mark.parametrize("omni_server", SERVER_CASES, indirect=True)
def test_wan22_i2v_online_serving_generates_video(
omni_server,
diff --git a/tests/e2e/offline_inference/custom_pipeline/test_async_omni_collective_rpc.py b/tests/e2e/offline_inference/custom_pipeline/test_async_omni_collective_rpc.py
index 57743d62bf6..bd3f2e09975 100644
--- a/tests/e2e/offline_inference/custom_pipeline/test_async_omni_collective_rpc.py
+++ b/tests/e2e/offline_inference/custom_pipeline/test_async_omni_collective_rpc.py
@@ -26,7 +26,7 @@
import pytest
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
diff --git a/tests/e2e/offline_inference/custom_pipeline/test_async_omni_qwen_image_generate.py b/tests/e2e/offline_inference/custom_pipeline/test_async_omni_qwen_image_generate.py
index 03bd12efae2..0681687fe73 100644
--- a/tests/e2e/offline_inference/custom_pipeline/test_async_omni_qwen_image_generate.py
+++ b/tests/e2e/offline_inference/custom_pipeline/test_async_omni_qwen_image_generate.py
@@ -19,7 +19,7 @@
import pytest
from transformers import AutoTokenizer
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
diff --git a/tests/e2e/offline_inference/custom_pipeline/test_worker_extension.py b/tests/e2e/offline_inference/custom_pipeline/test_worker_extension.py
index ffbe703ca78..653b35d7e2f 100644
--- a/tests/e2e/offline_inference/custom_pipeline/test_worker_extension.py
+++ b/tests/e2e/offline_inference/custom_pipeline/test_worker_extension.py
@@ -10,7 +10,7 @@
from tests.e2e.offline_inference.custom_pipeline.worker_extension import (
vLLMOmniColocateWorkerExtensionForTest,
)
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
from vllm_omni.diffusion.worker.diffusion_worker import CustomPipelineWorkerExtension
from vllm_omni.entrypoints.async_omni import AsyncOmni
diff --git a/tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml b/tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml
index 1f0d06cb8c0..b7768c071f6 100644
--- a/tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml
+++ b/tests/e2e/offline_inference/stage_configs/bagel_mooncake_ci.yaml
@@ -64,9 +64,6 @@ stage_args:
# Top-level runtime config with Mooncake connector
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
mooncake_connector:
name: MooncakeConnector
@@ -80,4 +77,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml b/tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml
index 36b1d2bbe48..504f3c98e92 100644
--- a/tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml
+++ b/tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_ci.yaml
@@ -62,10 +62,6 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
# Distributed connectors configuration (optional)
# More connectors will be supported in the future.
connectors:
@@ -78,4 +74,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml b/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml
deleted file mode 100644
index f93a6c71473..00000000000
--- a/tests/e2e/offline_inference/stage_configs/npu/qwen2_5_omni_ci.yaml
+++ /dev/null
@@ -1,103 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# This config is optimized for CI e2e tests.
-stage_args:
- - stage_id: 0
- runtime:
- process: true # Run this stage in a separate process
- devices: "0"
- engine_args:
- model_stage: thinker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 896
- max_num_batched_tokens: 896
- max_num_seqs: 1
- gpu_memory_utilization: 0.8
- skip_mm_profiling: true
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 128
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 896
- max_num_batched_tokens: 896
- max_num_seqs: 1
- gpu_memory_utilization: 0.8
- skip_mm_profiling: true
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 128
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
- - stage_id: 2
- runtime:
- process: true
- devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.15
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 128
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/tests/e2e/offline_inference/test_bagel_img2img.py b/tests/e2e/offline_inference/test_bagel_img2img.py
index 63d2a37da79..66aec80c7c4 100644
--- a/tests/e2e/offline_inference/test_bagel_img2img.py
+++ b/tests/e2e/offline_inference/test_bagel_img2img.py
@@ -22,9 +22,10 @@
from PIL import Image
from vllm.assets.image import ImageAsset
-from tests.conftest import OmniRunner, modify_stage_config
-from tests.utils import hardware_test
-from vllm_omni import Omni
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
+from tests.helpers.stage_config import modify_stage_config
+from vllm_omni.entrypoints.omni import Omni
from vllm_omni.platforms import current_omni_platform
# Reference pixel data extracted from the known-good output image
@@ -32,30 +33,30 @@
# prompt='Change the grass color to red',
# input image: 2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg
REFERENCE_PIXELS = [
- {"position": (100, 100), "rgb": (157, 172, 217)},
- {"position": (400, 50), "rgb": (105, 144, 218)},
- {"position": (700, 100), "rgb": (118, 159, 233)},
- {"position": (150, 400), "rgb": (195, 34, 60)},
- {"position": (512, 336), "rgb": (222, 214, 193)},
- {"position": (700, 400), "rgb": (197, 15, 43)},
- {"position": (100, 600), "rgb": (105, 13, 18)},
- {"position": (400, 600), "rgb": (169, 33, 44)},
- {"position": (700, 600), "rgb": (101, 86, 93)},
- {"position": (256, 256), "rgb": (181, 202, 222)},
+ {"position": (100, 100), "rgb": (156, 172, 217)},
+ {"position": (400, 50), "rgb": (105, 144, 217)},
+ {"position": (700, 100), "rgb": (118, 159, 232)},
+ {"position": (150, 400), "rgb": (180, 22, 52)},
+ {"position": (512, 336), "rgb": (221, 211, 194)},
+ {"position": (700, 400), "rgb": (192, 10, 46)},
+ {"position": (100, 600), "rgb": (102, 12, 22)},
+ {"position": (400, 600), "rgb": (161, 28, 47)},
+ {"position": (700, 600), "rgb": (100, 87, 94)},
+ {"position": (256, 256), "rgb": (181, 201, 221)},
]
if current_omni_platform.is_rocm():
REFERENCE_PIXELS = [
- {"position": (100, 100), "rgb": (156, 172, 215)},
- {"position": (400, 50), "rgb": (106, 144, 216)},
- {"position": (700, 100), "rgb": (118, 158, 231)},
- {"position": (150, 400), "rgb": (183, 23, 48)},
- {"position": (512, 336), "rgb": (218, 215, 191)},
- {"position": (700, 400), "rgb": (194, 14, 42)},
- {"position": (100, 600), "rgb": (105, 10, 16)},
- {"position": (400, 600), "rgb": (167, 33, 46)},
- {"position": (700, 600), "rgb": (102, 86, 92)},
- {"position": (256, 256), "rgb": (181, 201, 220)},
+ {"position": (100, 100), "rgb": (156, 172, 217)},
+ {"position": (400, 50), "rgb": (105, 144, 217)},
+ {"position": (700, 100), "rgb": (118, 159, 232)},
+ {"position": (150, 400), "rgb": (180, 22, 52)},
+ {"position": (512, 336), "rgb": (221, 211, 194)},
+ {"position": (700, 400), "rgb": (192, 10, 46)},
+ {"position": (100, 600), "rgb": (102, 12, 22)},
+ {"position": (400, 600), "rgb": (161, 28, 47)},
+ {"position": (700, 600), "rgb": (100, 87, 94)},
+ {"position": (256, 256), "rgb": (181, 201, 221)},
]
PIXEL_TOLERANCE = 10
diff --git a/tests/e2e/offline_inference/test_bagel_lora.py b/tests/e2e/offline_inference/test_bagel_lora.py
index 501d23eaa88..75f41f9beea 100644
--- a/tests/e2e/offline_inference/test_bagel_lora.py
+++ b/tests/e2e/offline_inference/test_bagel_lora.py
@@ -31,9 +31,10 @@
from PIL import Image
from safetensors.torch import save_file
-from tests.conftest import OmniRunner, modify_stage_config
-from tests.utils import hardware_test
-from vllm_omni import Omni
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
+from tests.helpers.stage_config import modify_stage_config
+from vllm_omni.entrypoints.omni import Omni
from vllm_omni.lora.request import LoRARequest
from vllm_omni.lora.utils import stable_lora_int_id
diff --git a/tests/e2e/offline_inference/test_bagel_text2img.py b/tests/e2e/offline_inference/test_bagel_text2img.py
index e45d64f2ac5..0819f103a0a 100644
--- a/tests/e2e/offline_inference/test_bagel_text2img.py
+++ b/tests/e2e/offline_inference/test_bagel_text2img.py
@@ -27,9 +27,10 @@
import pytest
from PIL import Image
-from tests.conftest import OmniRunner, modify_stage_config
-from tests.utils import hardware_test
-from vllm_omni import Omni
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
+from tests.helpers.stage_config import modify_stage_config
+from vllm_omni.entrypoints.omni import Omni
from vllm_omni.platforms import current_omni_platform
# Reference pixel data extracted from the known-good output image
@@ -37,30 +38,30 @@
# "Generated with seed=52, num_inference_steps=15,
# prompt='A futuristic city skyline at twilight, cyberpunk style'"
REFERENCE_PIXELS = [
- {"position": (100, 100), "rgb": (121, 118, 100)},
- {"position": (400, 50), "rgb": (163, 162, 143)},
- {"position": (700, 100), "rgb": (170, 156, 127)},
- {"position": (150, 400), "rgb": (129, 127, 112)},
- {"position": (512, 512), "rgb": (135, 61, 59)},
- {"position": (700, 400), "rgb": (205, 107, 43)},
- {"position": (100, 700), "rgb": (197, 177, 157)},
- {"position": (400, 700), "rgb": (139, 107, 86)},
- {"position": (700, 700), "rgb": (247, 205, 146)},
- {"position": (256, 256), "rgb": (171, 160, 153)},
+ {"position": (100, 100), "rgb": (115, 113, 94)},
+ {"position": (400, 50), "rgb": (159, 160, 144)},
+ {"position": (700, 100), "rgb": (164, 151, 123)},
+ {"position": (150, 400), "rgb": (120, 121, 107)},
+ {"position": (512, 512), "rgb": (165, 133, 127)},
+ {"position": (700, 400), "rgb": (217, 130, 66)},
+ {"position": (100, 700), "rgb": (191, 168, 152)},
+ {"position": (400, 700), "rgb": (130, 96, 77)},
+ {"position": (700, 700), "rgb": (247, 203, 140)},
+ {"position": (256, 256), "rgb": (167, 156, 150)},
]
if current_omni_platform.is_rocm():
REFERENCE_PIXELS = [
- {"position": (100, 100), "rgb": (123, 119, 100)},
- {"position": (400, 50), "rgb": (162, 161, 142)},
- {"position": (700, 100), "rgb": (171, 156, 127)},
- {"position": (150, 400), "rgb": (131, 128, 112)},
- {"position": (512, 512), "rgb": (134, 61, 59)},
- {"position": (700, 400), "rgb": (204, 107, 43)},
- {"position": (100, 700), "rgb": (201, 180, 165)},
- {"position": (400, 700), "rgb": (140, 108, 87)},
- {"position": (700, 700), "rgb": (247, 205, 145)},
- {"position": (256, 256), "rgb": (171, 160, 153)},
+ {"position": (100, 100), "rgb": (115, 113, 94)},
+ {"position": (400, 50), "rgb": (159, 160, 144)},
+ {"position": (700, 100), "rgb": (164, 151, 123)},
+ {"position": (150, 400), "rgb": (120, 121, 107)},
+ {"position": (512, 512), "rgb": (165, 133, 127)},
+ {"position": (700, 400), "rgb": (217, 130, 66)},
+ {"position": (100, 700), "rgb": (191, 168, 152)},
+ {"position": (400, 700), "rgb": (130, 96, 77)},
+ {"position": (700, 700), "rgb": (247, 203, 140)},
+ {"position": (256, 256), "rgb": (167, 156, 150)},
]
# Maximum allowed difference per color channel
diff --git a/tests/e2e/offline_inference/test_bagel_understanding.py b/tests/e2e/offline_inference/test_bagel_understanding.py
index bbee3298079..c3ed97b42bd 100644
--- a/tests/e2e/offline_inference/test_bagel_understanding.py
+++ b/tests/e2e/offline_inference/test_bagel_understanding.py
@@ -26,8 +26,9 @@
import pytest
from vllm.assets.image import ImageAsset
-from tests.conftest import OmniRunner, modify_stage_config
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
+from tests.helpers.stage_config import modify_stage_config
MODEL_NAME = "ByteDance-Seed/BAGEL-7B-MoT"
STAGE_CONFIG = str(Path(__file__).parent / "stage_configs" / "bagel_sharedmemory_ci.yaml")
diff --git a/tests/e2e/offline_inference/test_cache_dit.py b/tests/e2e/offline_inference/test_cache_dit.py
index fc08da7bedf..1577dd9f6db 100644
--- a/tests/e2e/offline_inference/test_cache_dit.py
+++ b/tests/e2e/offline_inference/test_cache_dit.py
@@ -11,8 +11,8 @@
import pytest
import torch
-from tests.conftest import OmniRunner
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
diff --git a/tests/e2e/offline_inference/test_cosyvoice3.py b/tests/e2e/offline_inference/test_cosyvoice3.py
index 8c88d972d5e..db5debac828 100644
--- a/tests/e2e/offline_inference/test_cosyvoice3.py
+++ b/tests/e2e/offline_inference/test_cosyvoice3.py
@@ -26,8 +26,8 @@
from huggingface_hub import snapshot_download
from vllm.sampling_params import SamplingParams
-from tests.conftest import OmniRunner
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
from vllm_omni.model_executor.models.cosyvoice3.config import CosyVoice3Config
from vllm_omni.model_executor.models.cosyvoice3.tokenizer import get_qwen_tokenizer
diff --git a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
index 257755ef8b9..d7fd6f72f5b 100644
--- a/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
+++ b/tests/e2e/offline_inference/test_diffusion_cpu_offload.py
@@ -4,8 +4,9 @@
import torch
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
-from tests.conftest import OmniRunner
-from tests.utils import DeviceMemoryMonitor, hardware_test
+from tests.helpers.env import DeviceMemoryMonitor
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
diff --git a/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
index bdfd594c774..4f19c100476 100644
--- a/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
+++ b/tests/e2e/offline_inference/test_diffusion_layerwise_offload.py
@@ -2,8 +2,8 @@
import torch
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
-from tests.conftest import OmniRunner
-from tests.utils import DeviceMemoryMonitor
+from tests.helpers.env import DeviceMemoryMonitor
+from tests.helpers.runtime import OmniRunner
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
diff --git a/tests/e2e/offline_inference/test_diffusion_lora.py b/tests/e2e/offline_inference/test_diffusion_lora.py
index 7edd03f20d1..027dadb3f4e 100644
--- a/tests/e2e/offline_inference/test_diffusion_lora.py
+++ b/tests/e2e/offline_inference/test_diffusion_lora.py
@@ -7,7 +7,7 @@
import torch
from safetensors.torch import save_file
-from tests.conftest import OmniRunner
+from tests.helpers.runtime import OmniRunner
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
diff --git a/tests/e2e/offline_inference/test_dynin_omni.py b/tests/e2e/offline_inference/test_dynin_omni.py
index 5388ac67468..f891fc4f12e 100644
--- a/tests/e2e/offline_inference/test_dynin_omni.py
+++ b/tests/e2e/offline_inference/test_dynin_omni.py
@@ -18,7 +18,7 @@
import torch
from transformers import AutoTokenizer
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
diff --git a/tests/e2e/offline_inference/test_expert_parallel.py b/tests/e2e/offline_inference/test_expert_parallel.py
index 29d84d7a3e2..f11646b300d 100644
--- a/tests/e2e/offline_inference/test_expert_parallel.py
+++ b/tests/e2e/offline_inference/test_expert_parallel.py
@@ -18,8 +18,8 @@
import torch.distributed as dist
from PIL import Image
-from tests.conftest import OmniRunner
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
diff --git a/tests/e2e/offline_inference/test_flux_autoround_w4a16.py b/tests/e2e/offline_inference/test_flux_autoround_w4a16.py
index cbcd1009dd5..ef5d6f9e051 100644
--- a/tests/e2e/offline_inference/test_flux_autoround_w4a16.py
+++ b/tests/e2e/offline_inference/test_flux_autoround_w4a16.py
@@ -14,8 +14,9 @@
import torch
from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
-from tests.conftest import OmniRunner
-from tests.utils import DeviceMemoryMonitor, hardware_test
+from tests.helpers.env import DeviceMemoryMonitor
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
diff --git a/tests/e2e/offline_inference/test_flux_kontext.py b/tests/e2e/offline_inference/test_flux_kontext.py
index cd711d6b818..057319c855f 100644
--- a/tests/e2e/offline_inference/test_flux_kontext.py
+++ b/tests/e2e/offline_inference/test_flux_kontext.py
@@ -13,7 +13,7 @@
from PIL import Image
from vllm.assets.image import ImageAsset
-from tests.conftest import OmniRunner
+from tests.helpers.runtime import OmniRunner
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
diff --git a/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py b/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py
index ec4f4693d75..bd0d132d093 100644
--- a/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py
+++ b/tests/e2e/offline_inference/test_hunyuanimage3_text2img.py
@@ -8,7 +8,7 @@
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
-from tests.conftest import OmniRunner
+from tests.helpers.runtime import OmniRunner
from vllm_omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
diff --git a/tests/e2e/offline_inference/test_ltx2_cfg_parallel_parity.py b/tests/e2e/offline_inference/test_ltx2_cfg_parallel_parity.py
index 659040929e2..07aa5a647be 100644
--- a/tests/e2e/offline_inference/test_ltx2_cfg_parallel_parity.py
+++ b/tests/e2e/offline_inference/test_ltx2_cfg_parallel_parity.py
@@ -11,7 +11,7 @@
import pytest
from PIL import Image
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
REPO_ROOT = Path(__file__).resolve().parents[3]
T2V_EXAMPLE = REPO_ROOT / "examples" / "offline_inference" / "text_to_video" / "text_to_video.py"
diff --git a/tests/e2e/offline_inference/test_magi_human.py b/tests/e2e/offline_inference/test_magi_human.py
index abb7f9c163c..6d46141729e 100644
--- a/tests/e2e/offline_inference/test_magi_human.py
+++ b/tests/e2e/offline_inference/test_magi_human.py
@@ -8,8 +8,8 @@
import numpy as np
import pytest
-from tests.conftest import OmniRunner
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
from vllm_omni.diffusion.utils.media_utils import mux_video_audio_bytes
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
diff --git a/tests/e2e/offline_inference/test_mammoth_moda2.py b/tests/e2e/offline_inference/test_mammoth_moda2.py
index ff744c86e1e..c3d95844c11 100644
--- a/tests/e2e/offline_inference/test_mammoth_moda2.py
+++ b/tests/e2e/offline_inference/test_mammoth_moda2.py
@@ -23,8 +23,10 @@
import torch
from vllm.sampling_params import SamplingParams
-from tests.conftest import OmniRunner
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
+
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
# ---------------------------------------------------------------------------
# Constants
diff --git a/tests/e2e/offline_inference/test_ming_flash_omni.py b/tests/e2e/offline_inference/test_ming_flash_omni.py
new file mode 100644
index 00000000000..c591e910ac3
--- /dev/null
+++ b/tests/e2e/offline_inference/test_ming_flash_omni.py
@@ -0,0 +1,142 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
+
+from pathlib import Path
+
+import pytest
+
+from tests.helpers.mark import hardware_test
+from tests.helpers.media import (
+ generate_synthetic_audio,
+ generate_synthetic_image,
+ generate_synthetic_video,
+)
+from tests.helpers.stage_config import modify_stage_config
+
+models = ["Jonathan1909/Ming-flash-omni-2.0"]
+
+# Ming-specific
+SYSTEM_PROMPT = "你是一个友好的AI助手。\n\ndetailed thinking off"
+EOS_TOKEN = "<|role_end|>"
+IMAGE_TOKEN = ""
+VIDEO_TOKEN = ""
+AUDIO_TOKEN = ""
+
+
+def build_prompt(user_text: str) -> str:
+ """Build a Ming chat prompt."""
+ return (
+ f"SYSTEM {SYSTEM_PROMPT}{EOS_TOKEN}HUMAN {user_text}{EOS_TOKEN}ASSISTANT "
+ )
+
+
+def get_eager_config():
+ path = modify_stage_config(
+ str(Path(__file__).parent.parent / "stage_configs" / "bailingmm_moe_v2_lite_ci.yaml"),
+ updates={
+ "stage_args": {
+ 0: {
+ "engine_args.enforce_eager": "true",
+ },
+ },
+ },
+ )
+ return path
+
+
+stage_configs = [get_eager_config()]
+test_params = [(model, stage_config) for model in models for stage_config in stage_configs]
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
+def test_text_to_text(omni_runner, omni_runner_handler) -> None:
+ """
+ Test text-only input processing and text output generation.
+ Input Modal: text
+ Output Modal: text
+ """
+ prompt = build_prompt("请详细介绍鹦鹉的生活习性。")
+ request_config = {"prompts": prompt, "modalities": ["text"]}
+
+ omni_runner_handler.send_request(request_config)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
+def test_image_to_text(omni_runner, omni_runner_handler) -> None:
+ """
+ Test image understanding with text output.
+ Input Modal: image + text
+ Output Modal: text
+ """
+ image = generate_synthetic_image(224, 224)["np_array"]
+ prompt = build_prompt(f"{IMAGE_TOKEN}Describe this image briefly.")
+ request_config = {"prompts": prompt, "images": image, "modalities": ["text"]}
+
+ omni_runner_handler.send_request(request_config)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
+def test_audio_to_text(omni_runner, omni_runner_handler) -> None:
+ """
+ Test audio understanding with text output.
+ Input Modal: audio + text
+ Output Modal: text
+ """
+ audio = generate_synthetic_audio(2, 1, 16000)["np_array"]
+ if len(audio.shape) == 2:
+ audio = audio.squeeze()
+ prompt = build_prompt(f"{AUDIO_TOKEN}Please recognize the language of this speech and transcribe it. Format: oral.")
+ request_config = {"prompts": prompt, "audios": audio, "modalities": ["text"]}
+
+ omni_runner_handler.send_request(request_config)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
+def test_video_to_text(omni_runner, omni_runner_handler) -> None:
+ """
+ Test video understanding with text output.
+ Input Modal: video + text
+ Output Modal: text
+ """
+ video = generate_synthetic_video(224, 224, 30)["np_array"]
+ prompt = build_prompt(f"{VIDEO_TOKEN}Describe what is happening in this video.")
+ request_config = {"prompts": prompt, "videos": video, "modalities": ["text"]}
+
+ omni_runner_handler.send_request(request_config)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_runner", test_params, indirect=True)
+def test_mixed_to_text(omni_runner, omni_runner_handler) -> None:
+ """
+ Test mixed modality input (image + audio) with text output.
+ Input Modal: image + audio + text
+ Output Modal: text
+ """
+ image = generate_synthetic_image(224, 224)["np_array"]
+ audio = generate_synthetic_audio(2, 1, 16000)["np_array"]
+ if len(audio.shape) == 2:
+ audio = audio.squeeze()
+ prompt = build_prompt(f"{IMAGE_TOKEN}{AUDIO_TOKEN}Describe the image and transcribe the audio.")
+ request_config = {"prompts": prompt, "images": image, "audios": audio, "modalities": ["text"]}
+
+ omni_runner_handler.send_request(request_config)
diff --git a/tests/e2e/offline_inference/test_omnivoice.py b/tests/e2e/offline_inference/test_omnivoice.py
index bb4c8a5dd7e..30a3427bee6 100644
--- a/tests/e2e/offline_inference/test_omnivoice.py
+++ b/tests/e2e/offline_inference/test_omnivoice.py
@@ -16,8 +16,8 @@
import numpy as np
import pytest
-from tests.conftest import OmniRunner
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
MODEL = "k2-fsa/OmniVoice"
diff --git a/tests/e2e/offline_inference/test_ovis_image.py b/tests/e2e/offline_inference/test_ovis_image.py
index 41e21bca3a9..70fab4fe101 100644
--- a/tests/e2e/offline_inference/test_ovis_image.py
+++ b/tests/e2e/offline_inference/test_ovis_image.py
@@ -16,7 +16,7 @@
import torch
from pytest_mock import MockerFixture
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig
# Mock the OvisImageTransformer2DModel to avoid complex init if needed,
diff --git a/tests/e2e/offline_inference/test_quantization_fp8.py b/tests/e2e/offline_inference/test_quantization_fp8.py
index 291779fd931..92cf351e3d3 100644
--- a/tests/e2e/offline_inference/test_quantization_fp8.py
+++ b/tests/e2e/offline_inference/test_quantization_fp8.py
@@ -36,8 +36,8 @@
import pytest
import torch
-from tests.conftest import OmniRunner
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
diff --git a/tests/e2e/offline_inference/test_qwen2_5_omni.py b/tests/e2e/offline_inference/test_qwen2_5_omni.py
index 4c4315aab9c..8ea41b00778 100644
--- a/tests/e2e/offline_inference/test_qwen2_5_omni.py
+++ b/tests/e2e/offline_inference/test_qwen2_5_omni.py
@@ -2,46 +2,39 @@
E2E tests for Qwen2.5-Omni model with mixed modality inputs, audio and text output.
"""
-from pathlib import Path
-
import pytest
-from tests.conftest import (
+from tests.helpers.mark import hardware_test
+from tests.helpers.media import (
generate_synthetic_audio,
generate_synthetic_image,
generate_synthetic_video,
- modify_stage_config,
)
-from tests.utils import hardware_test
+from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
from vllm_omni.platforms import current_omni_platform
models = ["Qwen/Qwen2.5-Omni-7B"]
+# Single CI deploy YAML; rocm/xpu deltas are picked automatically via the
+# platforms: section. NPU still uses the legacy per-platform YAML until it
+# also migrates to the new schema.
+_CI_DEPLOY = get_deploy_config_path("ci/qwen2_5_omni.yaml")
+
def get_cuda_graph_config():
- path = modify_stage_config(
- str(Path(__file__).parent.parent / "stage_configs" / "qwen2_5_omni_ci.yaml"),
+ return modify_stage_config(
+ _CI_DEPLOY,
updates={
- "stage_args": {
- 0: {
- "engine_args.enforce_eager": "true",
- },
- 1: {"engine_args.enforce_eager": "true"},
+ "stages": {
+ 0: {"enforce_eager": True},
+ 1: {"enforce_eager": True},
},
},
)
- return path
-
-
-# CI stage config optimized for 24GB GPU (L4/RTX3090) or NPU
-if current_omni_platform.is_npu():
- stage_config = str(Path(__file__).parent / "stage_configs" / "npu" / "qwen2_5_omni_ci.yaml")
-elif current_omni_platform.is_rocm():
- # ROCm stage config optimized for MI325 GPU
- stage_config = str(Path(__file__).parent.parent / "stage_configs" / "rocm" / "qwen2_5_omni_ci.yaml")
-elif current_omni_platform.is_xpu():
- # Intel XPU stage config optimized for B60 GPU
- stage_config = str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen2_5_omni_ci.yaml")
+
+
+if current_omni_platform.is_rocm() or current_omni_platform.is_xpu() or current_omni_platform.is_npu():
+ stage_config = _CI_DEPLOY
else:
stage_config = get_cuda_graph_config()
diff --git a/tests/e2e/offline_inference/test_qwen3_omni.py b/tests/e2e/offline_inference/test_qwen3_omni.py
index cc0af437eca..c4d257b5114 100644
--- a/tests/e2e/offline_inference/test_qwen3_omni.py
+++ b/tests/e2e/offline_inference/test_qwen3_omni.py
@@ -7,41 +7,35 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from pathlib import Path
-
import pytest
-from tests.conftest import (
- generate_synthetic_video,
- modify_stage_config,
-)
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.media import generate_synthetic_video
+from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
from vllm_omni.platforms import current_omni_platform
models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
+# Single CI deploy YAML; rocm/xpu deltas are picked automatically via the
+# platforms: section. Only CUDA needs an extra enforce_eager tweak.
+_CI_DEPLOY = get_deploy_config_path("ci/qwen3_omni_moe.yaml")
+
+
def get_cuda_graph_config():
- path = modify_stage_config(
- str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml"),
+ return modify_stage_config(
+ _CI_DEPLOY,
updates={
- "stage_args": {
- 0: {
- "engine_args.enforce_eager": "true",
- },
- 1: {"engine_args.enforce_eager": "true"},
+ "stages": {
+ 0: {"enforce_eager": True},
+ 1: {"enforce_eager": True},
},
},
)
- return path
-# CI stage config for 2xH100-80G GPUs or AMD GPU MI325
-if current_omni_platform.is_rocm():
- # ROCm stage config optimized for MI325 GPU
- stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "rocm" / "qwen3_omni_ci.yaml")]
-elif current_omni_platform.is_xpu():
- stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")]
+if current_omni_platform.is_rocm() or current_omni_platform.is_xpu():
+ stage_configs = [_CI_DEPLOY]
else:
stage_configs = [get_cuda_graph_config()]
diff --git a/tests/e2e/offline_inference/test_qwen3_tts_base.py b/tests/e2e/offline_inference/test_qwen3_tts_base.py
index be7bd50a36a..af2b5195b98 100644
--- a/tests/e2e/offline_inference/test_qwen3_tts_base.py
+++ b/tests/e2e/offline_inference/test_qwen3_tts_base.py
@@ -13,12 +13,10 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from pathlib import Path
-
import pytest
-from tests.conftest import modify_stage_config
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
REF_AUDIO_URL = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone_2.wav"
@@ -26,23 +24,31 @@
def get_cuda_graph_config():
- path = modify_stage_config(
- get_stage_config(),
+ """Build a temp deploy yaml mirroring the deleted qwen3_tts_no_async_chunk.yaml.
+
+ Composes the synchronous (no-async-chunk) variant on top of the bundled
+ qwen3_tts.yaml prod default, with cudagraphs disabled. Replaces the deleted
+ standalone variant yaml; same effective config, no checked-in file needed.
+ """
+ return modify_stage_config(
+ get_deploy_config_path("qwen3_tts.yaml"),
updates={
- "stage_args": {
+ "async_chunk": False,
+ "stages": {
0: {
- "engine_args.enforce_eager": "true",
+ "max_num_seqs": 1,
+ "gpu_memory_utilization": 0.2,
+ "enforce_eager": True,
+ "async_scheduling": False,
+ },
+ 1: {
+ "gpu_memory_utilization": 0.2,
+ "enforce_eager": True,
+ "async_scheduling": False,
},
- 1: {"engine_args.enforce_eager": "true"},
},
},
)
- return path
-
-
-def get_stage_config(name: str = "qwen3_tts_no_async_chunk.yaml"):
- """Get the no_async_chunk stage config path (async_chunk disable, cuda_graph disabled)."""
- return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
# Same structure as test_qwen3_omni: models, stage_configs, test_params
diff --git a/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py b/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py
index 67d72df908c..3214541af8f 100644
--- a/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py
+++ b/tests/e2e/offline_inference/test_qwen3_tts_customvoice.py
@@ -13,34 +13,40 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from pathlib import Path
-
import pytest
-from tests.conftest import modify_stage_config
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
MODEL = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
def get_cuda_graph_config():
- path = modify_stage_config(
- get_stage_config(),
+ """Build a temp deploy yaml mirroring the deleted qwen3_tts_no_async_chunk.yaml.
+
+ Composes the synchronous (no-async-chunk) variant on top of the bundled
+ qwen3_tts.yaml prod default, with cudagraphs disabled. Replaces the deleted
+ standalone variant yaml; same effective config, no checked-in file needed.
+ """
+ return modify_stage_config(
+ get_deploy_config_path("qwen3_tts.yaml"),
updates={
- "stage_args": {
+ "async_chunk": False,
+ "stages": {
0: {
- "engine_args.enforce_eager": "true",
+ "max_num_seqs": 1,
+ "gpu_memory_utilization": 0.2,
+ "enforce_eager": True,
+ "async_scheduling": False,
+ },
+ 1: {
+ "gpu_memory_utilization": 0.2,
+ "enforce_eager": True,
+ "async_scheduling": False,
},
- 1: {"engine_args.enforce_eager": "true"},
},
},
)
- return path
-
-
-def get_stage_config(name: str = "qwen3_tts_no_async_chunk.yaml"):
- """Get the no_async_chunk stage config path (async_chunk disable, cuda_graph disabled)."""
- return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
# Same structure as test_qwen3_omni: models, stage_configs, test_params
diff --git a/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py b/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py
index f0b0b55c9f6..2ce113d5bfd 100644
--- a/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py
+++ b/tests/e2e/offline_inference/test_qwen_image_diffusion_batching.py
@@ -36,8 +36,8 @@
import pytest
import torch
-from tests.conftest import OmniRunner
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
diff --git a/tests/e2e/offline_inference/test_sequence_parallel.py b/tests/e2e/offline_inference/test_sequence_parallel.py
index d3abccd78cf..9f76b3b75c5 100644
--- a/tests/e2e/offline_inference/test_sequence_parallel.py
+++ b/tests/e2e/offline_inference/test_sequence_parallel.py
@@ -20,8 +20,8 @@
import torch.distributed as dist
from PIL import Image
-from tests.conftest import OmniRunner
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.platforms import current_omni_platform
diff --git a/tests/e2e/offline_inference/test_stable_audio_expansion.py b/tests/e2e/offline_inference/test_stable_audio_expansion.py
index 54c1799e145..2fc07eb24ed 100644
--- a/tests/e2e/offline_inference/test_stable_audio_expansion.py
+++ b/tests/e2e/offline_inference/test_stable_audio_expansion.py
@@ -15,8 +15,8 @@
import pytest
import torch
-from tests.conftest import assert_audio_valid
-from tests.utils import hardware_test
+from tests.helpers.assertions import assert_audio_valid
+from tests.helpers.mark import hardware_test
from vllm_omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
diff --git a/tests/e2e/offline_inference/test_t2i_model.py b/tests/e2e/offline_inference/test_t2i_model.py
index fc54f9a7ff1..702f902cdb9 100644
--- a/tests/e2e/offline_inference/test_t2i_model.py
+++ b/tests/e2e/offline_inference/test_t2i_model.py
@@ -1,7 +1,7 @@
import pytest
import torch
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
diff --git a/tests/e2e/offline_inference/test_t2v_model.py b/tests/e2e/offline_inference/test_t2v_model.py
index 6fe623cfc82..cedc9e59b37 100644
--- a/tests/e2e/offline_inference/test_t2v_model.py
+++ b/tests/e2e/offline_inference/test_t2v_model.py
@@ -3,7 +3,7 @@
import pytest
import torch
-from tests.conftest import OmniRunner
+from tests.helpers.runtime import OmniRunner
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
diff --git a/tests/e2e/offline_inference/test_teacache.py b/tests/e2e/offline_inference/test_teacache.py
index 7cd1c5a4797..8152792fc01 100644
--- a/tests/e2e/offline_inference/test_teacache.py
+++ b/tests/e2e/offline_inference/test_teacache.py
@@ -11,8 +11,8 @@
import pytest
import torch
-from tests.conftest import OmniRunner
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
from vllm_omni.platforms import current_omni_platform
diff --git a/tests/e2e/offline_inference/test_vae_decode_parallelism.py b/tests/e2e/offline_inference/test_vae_decode_parallelism.py
index 0fce28d6692..32902c318fa 100644
--- a/tests/e2e/offline_inference/test_vae_decode_parallelism.py
+++ b/tests/e2e/offline_inference/test_vae_decode_parallelism.py
@@ -18,7 +18,7 @@
import time
-from tests.conftest import OmniRunner
+from tests.helpers.runtime import OmniRunner
from vllm_omni.platforms import current_omni_platform
# os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "1"
diff --git a/tests/e2e/offline_inference/test_voxcpm.py b/tests/e2e/offline_inference/test_voxcpm.py
index d7f65525e93..bda087612de 100644
--- a/tests/e2e/offline_inference/test_voxcpm.py
+++ b/tests/e2e/offline_inference/test_voxcpm.py
@@ -12,9 +12,9 @@
import pytest
import torch
-import tests.conftest as omni_test_conftest
-from tests.conftest import OmniRunner
-from tests.utils import hardware_test
+import tests.helpers.runtime as omni_runtime
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
from vllm_omni.model_executor.models.voxcpm.voxcpm_runtime_utils import (
prepare_voxcpm_hf_config_dir,
resolve_voxcpm_model_dir,
@@ -30,7 +30,7 @@
@pytest.fixture(autouse=True)
def _patch_npu_cleanup_for_voxcpm(monkeypatch: pytest.MonkeyPatch):
"""Limit the NPU cleanup workaround to this VoxCPM test module only."""
- original_cleanup = omni_test_conftest.cleanup_dist_env_and_memory
+ original_cleanup = omni_runtime.cleanup_dist_env_and_memory
def _safe_cleanup() -> None:
try:
@@ -40,7 +40,7 @@ def _safe_cleanup() -> None:
return
raise
- monkeypatch.setattr(omni_test_conftest, "cleanup_dist_env_and_memory", _safe_cleanup)
+ monkeypatch.setattr(omni_runtime, "cleanup_dist_env_and_memory", _safe_cleanup)
def _build_prompt(text: str) -> dict[str, Any]:
diff --git a/tests/e2e/offline_inference/test_voxcpm2.py b/tests/e2e/offline_inference/test_voxcpm2.py
index 6ec4630a45e..cffdee8001e 100644
--- a/tests/e2e/offline_inference/test_voxcpm2.py
+++ b/tests/e2e/offline_inference/test_voxcpm2.py
@@ -5,8 +5,8 @@
import pytest
import torch
-from tests.conftest import OmniRunner
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
VOXCPM2_MODEL = "openbmb/VoxCPM2"
STAGE_CONFIG = os.path.join(
@@ -100,3 +100,31 @@ def test_voxcpm2_voice_clone_002(voxcpm2_engine):
audio = _extract_audio(outputs[0].outputs[0].multimodal_output)
duration_s = audio.shape[0] / SAMPLE_RATE
assert 0.5 < duration_s < 30.0, f"Audio duration out of range: {duration_s:.2f}s"
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "L4"}, num_cards=1)
+def test_voxcpm2_prefill_decode_mixed_batch_003(voxcpm2_engine):
+ """Regression: prefill+decode mixed batch must not crash (PR #2903)."""
+ long_prompt = (
+ "This is a deliberately long prompt that will stay in the decode "
+ "phase for many steps so that subsequent shorter prompts keep "
+ "entering prefill alongside it, reproducing the prefill plus "
+ "decode mixed batch scheduling pattern."
+ )
+ short_prompts = [
+ "Hello one.",
+ "Hello two.",
+ "Hello three.",
+ "Hello four.",
+ ]
+ requests = [{"prompt": long_prompt}] + [{"prompt": p} for p in short_prompts]
+
+ outputs = voxcpm2_engine.generate(requests)
+ assert len(outputs) == len(requests)
+
+ for i, out in enumerate(outputs):
+ audio = _extract_audio(out.outputs[0].multimodal_output)
+ duration_s = audio.shape[0] / SAMPLE_RATE
+ assert 0.1 < duration_s < 30.0, f"Request {i} audio duration out of range: {duration_s:.2f}s"
diff --git a/tests/e2e/offline_inference/test_voxtral_tts.py b/tests/e2e/offline_inference/test_voxtral_tts.py
index 4f440f243bf..6bf95fce1e8 100644
--- a/tests/e2e/offline_inference/test_voxtral_tts.py
+++ b/tests/e2e/offline_inference/test_voxtral_tts.py
@@ -29,8 +29,9 @@
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from vllm import SamplingParams
-from tests.conftest import OmniRunner, modify_stage_config
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
+from tests.helpers.stage_config import modify_stage_config
from vllm_omni.entrypoints.async_omni import AsyncOmni
MODEL = "mistralai/Voxtral-4B-TTS-2603"
diff --git a/tests/e2e/offline_inference/test_zimage_parallelism.py b/tests/e2e/offline_inference/test_zimage_parallelism.py
index 27edc48f205..ab330ee9a26 100644
--- a/tests/e2e/offline_inference/test_zimage_parallelism.py
+++ b/tests/e2e/offline_inference/test_zimage_parallelism.py
@@ -20,8 +20,9 @@
import torch
from PIL import Image
-from tests.conftest import OmniRunner
-from tests.utils import DeviceMemoryMonitor, hardware_test
+from tests.helpers.env import DeviceMemoryMonitor
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniRunner
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
from vllm_omni.outputs import OmniRequestOutput
@@ -213,7 +214,10 @@ def test_zimage_tensor_parallel_tp2(tmp_path: Path):
)
-@pytest.mark.integration
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
+@pytest.mark.parallel
+@hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards={"cuda": 4, "rocm": 2})
def test_zimage_vae_patch_parallel_tp2(tmp_path: Path):
if current_omni_platform.is_npu():
pytest.skip("Z-Image VAE patch parallel e2e test is only supported on CUDA and ROCm for now.")
diff --git a/tests/e2e/online_serving/test_bagel_expansion.py b/tests/e2e/online_serving/test_bagel_expansion.py
index e2d75e0d199..aa289c6d9a2 100644
--- a/tests/e2e/online_serving/test_bagel_expansion.py
+++ b/tests/e2e/online_serving/test_bagel_expansion.py
@@ -16,13 +16,8 @@
import pytest
-from tests.conftest import (
- OmniServer,
- OmniServerParams,
- OpenAIClientHandler,
- dummy_messages_from_mix_data,
-)
-from tests.utils import hardware_marks
+from tests.helpers.mark import hardware_marks
+from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
PROMPT = "A futuristic city skyline at twilight, cyberpunk style, ultra-detailed, high resolution."
NEGATIVE_PROMPT = "low quality, blurry, distorted, deformed, watermark"
@@ -88,7 +83,7 @@ def _get_diffusion_feature_cases(model: str):
],
),
id="parallel_tp_2",
- marks=PARALLEL_FEATURE_MARKS,
+ marks=[*PARALLEL_FEATURE_MARKS, pytest.mark.skip(reason="issue: #2862")],
),
# Ulysses-SP degree=2 (2 GPUs)
pytest.param(
@@ -138,7 +133,7 @@ def test_bagel(
- Ulysses-SP (degree=2)
- Ring-Attention (degree=2)
- Validation is delegated to assert_diffusion_response in tests.conftest,
+ Validation is delegated to assert_diffusion_response in tests/helpers/assertions.py,
which checks output dimensions and basic correctness.
"""
diff --git a/tests/e2e/online_serving/test_bagel_online.py b/tests/e2e/online_serving/test_bagel_online.py
index a3f999f13da..745e9a1161f 100644
--- a/tests/e2e/online_serving/test_bagel_online.py
+++ b/tests/e2e/online_serving/test_bagel_online.py
@@ -28,8 +28,8 @@
import pytest
from vllm.assets.image import ImageAsset
-from tests.conftest import OmniServerParams
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServerParams
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
diff --git a/tests/e2e/online_serving/test_cosyvoice3_tts.py b/tests/e2e/online_serving/test_cosyvoice3_tts.py
index 276b1782f52..e05b5e34f41 100644
--- a/tests/e2e/online_serving/test_cosyvoice3_tts.py
+++ b/tests/e2e/online_serving/test_cosyvoice3_tts.py
@@ -16,8 +16,8 @@
import pytest
-from tests.conftest import OmniServerParams
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServerParams
MODEL = "FunAudioLLM/Fun-CosyVoice3-0.5B-2512"
diff --git a/tests/e2e/online_serving/test_dynin_omni_expansion.py b/tests/e2e/online_serving/test_dynin_omni_expansion.py
index 710c480f08d..b63fcbd7521 100644
--- a/tests/e2e/online_serving/test_dynin_omni_expansion.py
+++ b/tests/e2e/online_serving/test_dynin_omni_expansion.py
@@ -15,9 +15,9 @@
import soundfile as sf
from vllm.assets.image import ImageAsset
-from tests import conftest as tests_conftest
-from tests.conftest import OmniServerParams
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.media import convert_audio_bytes_to_text
+from tests.helpers.runtime import OmniServerParams
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
@@ -87,7 +87,7 @@ def _convert_audio_bytes_to_text_without_ffmpeg(raw_bytes: bytes) -> str:
@pytest.fixture
def dynin_t2s_openai_client(openai_client, monkeypatch):
monkeypatch.setattr(
- tests_conftest,
+ convert_audio_bytes_to_text,
"convert_audio_bytes_to_text",
_convert_audio_bytes_to_text_without_ffmpeg,
)
diff --git a/tests/e2e/online_serving/test_flux2_expansion.py b/tests/e2e/online_serving/test_flux2_expansion.py
index 336bd83a1d2..ce06ad56461 100644
--- a/tests/e2e/online_serving/test_flux2_expansion.py
+++ b/tests/e2e/online_serving/test_flux2_expansion.py
@@ -11,12 +11,8 @@
import pytest
-from tests.conftest import (
- OmniServer,
- OmniServerParams,
- OpenAIClientHandler,
-)
-from tests.utils import hardware_marks
+from tests.helpers.mark import hardware_marks
+from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler
FOUR_CARD_FEATURE_MARKS = hardware_marks(res={"cuda": "L4"}, num_cards=4)
POSITIVE_PROMPT = "A cat sitting on a windowsill"
diff --git a/tests/e2e/online_serving/test_flux2_klein_inpaint_expansion.py b/tests/e2e/online_serving/test_flux2_klein_inpaint_expansion.py
index f59a0e783d7..0c45bb33f1d 100644
--- a/tests/e2e/online_serving/test_flux2_klein_inpaint_expansion.py
+++ b/tests/e2e/online_serving/test_flux2_klein_inpaint_expansion.py
@@ -14,7 +14,7 @@
import pytest
from PIL import Image, ImageDraw
-from tests.conftest import OmniServer, OmniServerParams
+from tests.helpers.runtime import OmniServer, OmniServerParams
MODEL = "black-forest-labs/FLUX.2-klein-4B"
diff --git a/tests/e2e/online_serving/test_flux_2_dev_expansion.py b/tests/e2e/online_serving/test_flux_2_dev_expansion.py
index f7477ed803e..addf6f00248 100644
--- a/tests/e2e/online_serving/test_flux_2_dev_expansion.py
+++ b/tests/e2e/online_serving/test_flux_2_dev_expansion.py
@@ -14,13 +14,8 @@
import pytest
-from tests.conftest import (
- OmniServer,
- OmniServerParams,
- OpenAIClientHandler,
- dummy_messages_from_mix_data,
-)
-from tests.utils import hardware_marks
+from tests.helpers.mark import hardware_marks
+from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
MODEL = "black-forest-labs/FLUX.2-dev"
PROMPT = "A cinematic mountain landscape at sunrise, dramatic clouds, ultra-detailed, realistic photography."
diff --git a/tests/e2e/online_serving/test_flux_kontext_expansion.py b/tests/e2e/online_serving/test_flux_kontext_expansion.py
index c13e1e8189d..574bb3db4a9 100644
--- a/tests/e2e/online_serving/test_flux_kontext_expansion.py
+++ b/tests/e2e/online_serving/test_flux_kontext_expansion.py
@@ -5,13 +5,8 @@
import pytest
-from tests.conftest import (
- OmniServer,
- OmniServerParams,
- OpenAIClientHandler,
- dummy_messages_from_mix_data,
- generate_synthetic_image,
-)
+from tests.helpers.media import generate_synthetic_image
+from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
EDIT_PROMPT = "Transform this modern, geometrist image into a Vincent van Gogh style impressionist painting."
NEGATIVE_PROMPT = "blurry, low quality, modern, geometrist"
diff --git a/tests/e2e/online_serving/test_hunyuan_video_15_expansion.py b/tests/e2e/online_serving/test_hunyuan_video_15_expansion.py
index de950edb900..7af4c43443d 100644
--- a/tests/e2e/online_serving/test_hunyuan_video_15_expansion.py
+++ b/tests/e2e/online_serving/test_hunyuan_video_15_expansion.py
@@ -11,12 +11,8 @@
import pytest
-from tests.conftest import (
- OmniServer,
- OmniServerParams,
- OpenAIClientHandler,
-)
-from tests.utils import hardware_marks
+from tests.helpers.mark import hardware_marks
+from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler
PROMPT = "A cat walking across a sunlit garden, cinematic lighting, slow motion."
NEGATIVE_PROMPT = "low quality, blurry, distorted"
diff --git a/tests/e2e/online_serving/test_image_gen_edit.py b/tests/e2e/online_serving/test_image_gen_edit.py
index 7db740f2037..56747abd16e 100644
--- a/tests/e2e/online_serving/test_image_gen_edit.py
+++ b/tests/e2e/online_serving/test_image_gen_edit.py
@@ -22,7 +22,7 @@
from vllm.assets.image import ImageAsset
from vllm.utils.network_utils import get_open_port
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# Increase timeout for downloading assets from S3 (default 5s is too short for CI)
diff --git a/tests/e2e/online_serving/test_images_generations_lora.py b/tests/e2e/online_serving/test_images_generations_lora.py
index fb1e3ea1e0f..931a572878e 100644
--- a/tests/e2e/online_serving/test_images_generations_lora.py
+++ b/tests/e2e/online_serving/test_images_generations_lora.py
@@ -22,8 +22,8 @@
from PIL import Image
from safetensors.torch import save_file
-from tests.conftest import OmniServer
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServer
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
diff --git a/tests/e2e/online_serving/test_longcat_image_edit_expansion.py b/tests/e2e/online_serving/test_longcat_image_edit_expansion.py
index 8a2cfbcc145..9a96280f2e9 100644
--- a/tests/e2e/online_serving/test_longcat_image_edit_expansion.py
+++ b/tests/e2e/online_serving/test_longcat_image_edit_expansion.py
@@ -13,14 +13,9 @@
import pytest
-from tests.conftest import (
- OmniServer,
- OmniServerParams,
- OpenAIClientHandler,
- dummy_messages_from_mix_data,
- generate_synthetic_image,
-)
-from tests.utils import hardware_marks
+from tests.helpers.mark import hardware_marks
+from tests.helpers.media import generate_synthetic_image
+from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
EDIT_PROMPT = "Transform this modern image into a cinematic animation style with vibrant colors and soft lighting."
NEGATIVE_PROMPT = "blurry, low quality, distorted, oversaturated"
diff --git a/tests/e2e/online_serving/test_longcat_image_expansion.py b/tests/e2e/online_serving/test_longcat_image_expansion.py
index 161e7cd2e65..7d12db6a2ca 100644
--- a/tests/e2e/online_serving/test_longcat_image_expansion.py
+++ b/tests/e2e/online_serving/test_longcat_image_expansion.py
@@ -13,13 +13,8 @@
import pytest
-from tests.conftest import (
- OmniServer,
- OmniServerParams,
- OpenAIClientHandler,
- dummy_messages_from_mix_data,
-)
-from tests.utils import hardware_marks
+from tests.helpers.mark import hardware_marks
+from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
TEXT_TO_IMAGE_PROMPT = (
"A cinematic illustration of a cat typing on a silver laptop, soft window light, highly detailed."
diff --git a/tests/e2e/online_serving/test_mimo_audio.py b/tests/e2e/online_serving/test_mimo_audio.py
index 43eeb773355..331baf18a54 100644
--- a/tests/e2e/online_serving/test_mimo_audio.py
+++ b/tests/e2e/online_serving/test_mimo_audio.py
@@ -9,13 +9,10 @@
import pytest
-from tests.conftest import (
- OmniServerParams,
- dummy_messages_from_mix_data,
- generate_synthetic_audio,
- modify_stage_config,
-)
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.media import generate_synthetic_audio
+from tests.helpers.runtime import OmniServerParams, dummy_messages_from_mix_data
+from tests.helpers.stage_config import modify_stage_config
from vllm_omni.model_executor.model_loader.weight_utils import (
download_weights_from_hf_specific,
)
diff --git a/tests/e2e/online_serving/test_ming_flash_omni.py b/tests/e2e/online_serving/test_ming_flash_omni.py
new file mode 100644
index 00000000000..8161c438929
--- /dev/null
+++ b/tests/e2e/online_serving/test_ming_flash_omni.py
@@ -0,0 +1,246 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+E2E online serving tests for Ming-flash-omni-2.0 model (Thinker stage).
+Tests multimodal understanding via OpenAI-compatible API.
+"""
+
+import os
+from pathlib import Path
+
+import pytest
+
+from tests.helpers.mark import hardware_test
+from tests.helpers.media import (
+ generate_synthetic_audio,
+ generate_synthetic_image,
+ generate_synthetic_video,
+)
+from tests.helpers.runtime import OmniServerParams, dummy_messages_from_mix_data
+from tests.helpers.stage_config import modify_stage_config
+
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
+
+models = ["Jonathan1909/Ming-flash-omni-2.0"]
+
+
+def get_eager_config():
+ path = modify_stage_config(
+ str(Path(__file__).parent.parent / "stage_configs" / "bailingmm_moe_v2_lite_ci.yaml"),
+ updates={
+ "stage_args": {
+ 0: {
+ "engine_args.enforce_eager": "true",
+ },
+ },
+ },
+ )
+ return path
+
+
+stage_configs = [get_eager_config()]
+
+# Create parameter combinations for model and stage config
+test_params = [
+ OmniServerParams(model=model, stage_config_path=stage_config) for model in models for stage_config in stage_configs
+]
+
+
+def get_system_prompt():
+ return {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": "你是一个友好的AI助手。\n\ndetailed thinking off",
+ }
+ ],
+ }
+
+
+def get_prompt(prompt_type="text_only"):
+ prompts = {
+ "text_only": "What is the capital of China? Answer in 20 words.",
+ "text_image": "What is in this image?",
+ "text_audio": "What is in this audio?",
+ "text_video": "What is in this video?",
+ "mix": "What is recited in the audio? What is in this image? What is in this video?",
+ }
+ return prompts.get(prompt_type, prompts["text_only"])
+
+
+def get_max_batch_size(size_type="few"):
+ batch_sizes = {"few": 5, "medium": 100, "large": 256}
+ return batch_sizes.get(size_type, 5)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_text_to_text_001(omni_server, openai_client) -> None:
+ """
+ Input Modal: text
+ Output Modal: text
+ Input Setting: stream=False
+ Datasets: single request
+ """
+ messages = dummy_messages_from_mix_data(
+ system_prompt=get_system_prompt(),
+ content_text=get_prompt("text_only"),
+ )
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "stream": False,
+ "modalities": ["text"],
+ "key_words": {"text": ["beijing"]},
+ }
+
+ openai_client.send_omni_request(request_config)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_text_to_text_stream_001(omni_server, openai_client) -> None:
+ """
+ Input Modal: text
+ Output Modal: text
+ Input Setting: stream=True
+ Datasets: few requests
+ """
+ messages = dummy_messages_from_mix_data(
+ system_prompt=get_system_prompt(),
+ content_text=get_prompt("text_only"),
+ )
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "stream": True,
+ "modalities": ["text"],
+ "key_words": {"text": ["beijing"]},
+ }
+
+ openai_client.send_omni_request(request_config, request_num=get_max_batch_size())
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_image_to_text_001(omni_server, openai_client) -> None:
+ """
+ Input Modal: image + text
+ Output Modal: text
+ Input Setting: stream=True
+ Datasets: single request
+ """
+ image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
+ messages = dummy_messages_from_mix_data(
+ system_prompt=get_system_prompt(),
+ image_data_url=image_data_url,
+ content_text=get_prompt("text_image"),
+ )
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "stream": True,
+ "modalities": ["text"],
+ }
+
+ openai_client.send_omni_request(request_config)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_audio_to_text_001(omni_server, openai_client) -> None:
+ """
+ Input Modal: audio + text
+ Output Modal: text
+ Input Setting: stream=True
+ Datasets: single request
+ """
+ audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(2, 1)['base64']}"
+ messages = dummy_messages_from_mix_data(
+ system_prompt=get_system_prompt(),
+ audio_data_url=audio_data_url,
+ content_text=get_prompt("text_audio"),
+ )
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "stream": True,
+ "modalities": ["text"],
+ }
+
+ openai_client.send_omni_request(request_config)
+
+
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_video_to_text_001(omni_server, openai_client) -> None:
+ """
+ Input Modal: video + text
+ Output Modal: text
+ Input Setting: stream=False
+ Datasets: single request
+ """
+ video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
+ messages = dummy_messages_from_mix_data(
+ system_prompt=get_system_prompt(),
+ video_data_url=video_data_url,
+ content_text=get_prompt("text_video"),
+ )
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "stream": False,
+ "modalities": ["text"],
+ }
+
+ openai_client.send_omni_request(request_config)
+
+
+@pytest.mark.advanced_model
+@pytest.mark.core_model
+@pytest.mark.omni
+@hardware_test(res={"cuda": "H100"}, num_cards=4)
+@pytest.mark.parametrize("omni_server", test_params, indirect=True)
+def test_mix_to_text_001(omni_server, openai_client) -> None:
+ """
+ Input Modal: text + audio + image + video
+ Output Modal: text
+ Input Setting: stream=True
+ Datasets: single request
+ """
+ video_data_url = f"data:video/mp4;base64,{generate_synthetic_video(224, 224, 300)['base64']}"
+ image_data_url = f"data:image/jpeg;base64,{generate_synthetic_image(224, 224)['base64']}"
+ audio_data_url = f"data:audio/wav;base64,{generate_synthetic_audio(2, 1)['base64']}"
+ messages = dummy_messages_from_mix_data(
+ system_prompt=get_system_prompt(),
+ video_data_url=video_data_url,
+ image_data_url=image_data_url,
+ audio_data_url=audio_data_url,
+ content_text=get_prompt("mix"),
+ )
+
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "stream": True,
+ "modalities": ["text"],
+ }
+
+ openai_client.send_omni_request(request_config)
diff --git a/tests/e2e/online_serving/test_nextstep_expansion.py b/tests/e2e/online_serving/test_nextstep_expansion.py
new file mode 100644
index 00000000000..004a0967b56
--- /dev/null
+++ b/tests/e2e/online_serving/test_nextstep_expansion.py
@@ -0,0 +1,71 @@
+"""
+Online serving E2E for NextStep-1.1 text-to-image (tensor parallel).
+"""
+
+import os
+
+import pytest
+
+from tests.helpers.mark import hardware_marks
+from tests.helpers.runtime import (
+ OmniServer,
+ OmniServerParams,
+ OpenAIClientHandler,
+ dummy_messages_from_mix_data,
+)
+
+# L4: 4 GPUs + TP=4; XPU B60: 2 cards (use num_cards={"cuda": 4, "xpu": 4} if needed)
+FOUR_CARD_MARKS = hardware_marks(
+ res={"cuda": "L4", "xpu": "B60"},
+ num_cards={"cuda": 2, "xpu": 2},
+)
+
+POSITIVE_PROMPT = "A small red barn in a snowy field, simple illustration."
+NEGATIVE_PROMPT = "blurry, low quality"
+
+_DEFAULT_MODEL = "stepfun-ai/NextStep-1.1"
+
+
+def _get_diffusion_feature_cases(model: str):
+ """Single online config: TP=4, explicit pipeline class."""
+ return [
+ pytest.param(
+ OmniServerParams(
+ model=model,
+ server_args=[
+ "--tensor-parallel-size",
+ "2",
+ "--model-class-name",
+ "NextStep11Pipeline",
+ ],
+ ),
+ id="nextstep_tp4_pipeline",
+ marks=FOUR_CARD_MARKS,
+ ),
+ ]
+
+
+@pytest.mark.advanced_model
+@pytest.mark.diffusion
+@pytest.mark.parametrize(
+ "omni_server",
+ _get_diffusion_feature_cases(model=os.environ.get("VLLM_TEST_NEXTSTEP_MODEL", _DEFAULT_MODEL)),
+ indirect=True,
+)
+def test_nextstep_11(omni_server: OmniServer, openai_client: OpenAIClientHandler):
+ messages = dummy_messages_from_mix_data(content_text=POSITIVE_PROMPT)
+ request_config = {
+ "model": omni_server.model,
+ "messages": messages,
+ "extra_body": {
+ "height": 512,
+ "width": 512,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "guidance_scale_2": 1.0,
+ "negative_prompt": NEGATIVE_PROMPT,
+ "seed": 42,
+ },
+ }
+
+ openai_client.send_diffusion_request(request_config)
diff --git a/tests/e2e/online_serving/test_omnivoice.py b/tests/e2e/online_serving/test_omnivoice.py
index 4a0069f4022..892896e05c7 100644
--- a/tests/e2e/online_serving/test_omnivoice.py
+++ b/tests/e2e/online_serving/test_omnivoice.py
@@ -17,8 +17,9 @@
import httpx
import pytest
-from tests.conftest import OmniServerParams, generate_synthetic_audio
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.media import generate_synthetic_audio
+from tests.helpers.runtime import OmniServerParams
try:
from transformers import HiggsAudioV2TokenizerModel # noqa: F401
diff --git a/tests/e2e/online_serving/test_qwen2_5_omni.py b/tests/e2e/online_serving/test_qwen2_5_omni.py
index e2913ce0215..8a1a8eb9950 100644
--- a/tests/e2e/online_serving/test_qwen2_5_omni.py
+++ b/tests/e2e/online_serving/test_qwen2_5_omni.py
@@ -3,20 +3,13 @@
"""
import os
-from pathlib import Path
import pytest
-from tests.conftest import (
- OmniServerParams,
- dummy_messages_from_mix_data,
- generate_synthetic_audio,
- generate_synthetic_image,
- generate_synthetic_video,
- modify_stage_config,
-)
-from tests.utils import hardware_test
-from vllm_omni.platforms import current_omni_platform
+from tests.helpers.mark import hardware_test
+from tests.helpers.media import generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video
+from tests.helpers.runtime import OmniServerParams, dummy_messages_from_mix_data
+from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
@@ -24,20 +17,9 @@
models = ["Qwen/Qwen2.5-Omni-7B"]
-
-def get_config():
- path = modify_stage_config(
- str(Path(__file__).parent.parent / "stage_configs" / "qwen2_5_omni_ci.yaml"),
- )
- return path
-
-
-# CI stage config for 2xH100-80G GPUs or AMD GPU MI325
-if current_omni_platform.is_rocm():
- # ROCm stage config optimized for MI325 GPU
- stage_configs = [str(Path(__file__).parent.parent / "stage_configs" / "rocm" / "qwen2_5_omni_ci.yaml")]
-else:
- stage_configs = [get_config()]
+# Single CI deploy YAML; rocm/xpu deltas are picked automatically via the
+# platforms: section in vllm_omni/deploy/ci/qwen2_5_omni.yaml.
+stage_configs = [modify_stage_config(get_deploy_config_path("ci/qwen2_5_omni.yaml"))]
# Create parameter combinations for model and stage config
test_params = [
diff --git a/tests/e2e/online_serving/test_qwen3_omni.py b/tests/e2e/online_serving/test_qwen3_omni.py
index 13af2ad1109..7d1a181d271 100644
--- a/tests/e2e/online_serving/test_qwen3_omni.py
+++ b/tests/e2e/online_serving/test_qwen3_omni.py
@@ -3,19 +3,13 @@
"""
import os
-from pathlib import Path
import pytest
-from tests.conftest import (
- OmniServerParams,
- dummy_messages_from_mix_data,
- generate_synthetic_audio,
- generate_synthetic_image,
- generate_synthetic_video,
- modify_stage_config,
-)
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.media import generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video
+from tests.helpers.runtime import OmniServerParams, dummy_messages_from_mix_data
+from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
from vllm_omni.platforms import current_omni_platform
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@@ -23,27 +17,24 @@
models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
-QWEN3_OMNI_CONFIG_PATH = str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml")
-QWEN3_OMNI_XPU_CONFIG_PATH = str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")
+# Set VLLM_TEST_PD_MODE=1 to test PD disaggregation (follow-up — deploy overlay not yet migrated).
+_USE_PD = os.environ.get("VLLM_TEST_PD_MODE", "0") == "1"
+
+_CI_DEPLOY = get_deploy_config_path("ci/qwen3_omni_moe.yaml")
-def get_chunk_config(config_path: str):
- path = modify_stage_config(
- config_path,
- updates={
- "async_chunk": True,
- "stage_args": {
- 0: {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk"
- },
- 1: {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk"
- },
- },
- },
- deletes={"stage_args": {2: ["custom_process_input_func"]}},
- )
- return path
+
+def get_chunk_config(config_path: str | None = None):
+ """Load the qwen3_omni CI deploy yaml with async_chunk modifications for streaming mode."""
+ if config_path is None:
+ config_path = _CI_DEPLOY
+ # TODO: remove this workaround once legacy `stage_args` path is deleted.
+ # The pipeline (qwen3_omni/pipeline.py) already wires
+ # thinker2talker_async_chunk / talker2code2wav_async_chunk on stage 0/1,
+ # so only async_chunk needs flipping. Writing nested `engine_args:` into
+ # the new-schema overlay trips _parse_stage_deploy's legacy branch and
+ # drops flat fields (load_format, max_num_seqs, ...).
+ return modify_stage_config(config_path, updates={"async_chunk": True})
def get_prefix_caching_config(config_path: str):
@@ -59,12 +50,16 @@ def get_prefix_caching_config(config_path: str):
return path
+# Platform-specific overrides live inside the new deploy yaml's ``platforms:``
+# section, so a single ``_CI_DEPLOY`` path serves CUDA, ROCm, and XPU.
+# TODO: re-add VLLM_TEST_PD_MODE branch once the PD-disaggregation deploy
+# overlay has been migrated to the new schema (previously used the deleted
+# ``qwen3_omni_moe_pd_ci.yaml`` stage-configs file).
if current_omni_platform.is_xpu():
- stage_configs = [QWEN3_OMNI_XPU_CONFIG_PATH]
- prefix_caching_stage_configs = [get_prefix_caching_config(QWEN3_OMNI_XPU_CONFIG_PATH)]
-else: # MI325 GPU should share the same config as H100
- stage_configs = [get_chunk_config(QWEN3_OMNI_CONFIG_PATH)]
- prefix_caching_stage_configs = [get_prefix_caching_config(QWEN3_OMNI_CONFIG_PATH)]
+ stage_configs = [_CI_DEPLOY]
+else: # CUDA + ROCm MI325 share the same deploy config
+ stage_configs = [get_chunk_config()]
+prefix_caching_stage_configs = [get_prefix_caching_config(_CI_DEPLOY)]
# Create parameter combinations for model and stage config
test_params = [
@@ -116,7 +111,8 @@ def get_max_batch_size(size_type="few"):
@pytest.mark.advanced_model
@pytest.mark.core_model
@pytest.mark.omni
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
+@pytest.mark.skipif(_USE_PD, reason="Temporarily skip PD mode in this test module.")
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=3 if _USE_PD else 2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_mix_to_text_audio_001(omni_server, openai_client) -> None:
"""
@@ -155,7 +151,8 @@ def test_mix_to_text_audio_001(omni_server, openai_client) -> None:
@pytest.mark.advanced_model
@pytest.mark.core_model
@pytest.mark.omni
-@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
+@pytest.mark.skipif(_USE_PD, reason="Temporarily skip PD mode in this test module.")
+@hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=3 if _USE_PD else 2)
@pytest.mark.parametrize("omni_server", test_params, indirect=True)
def test_text_to_text_001(omni_server, openai_client) -> None:
"""
diff --git a/tests/e2e/online_serving/test_qwen3_omni_expansion.py b/tests/e2e/online_serving/test_qwen3_omni_expansion.py
index 06847f3d51b..6be2ebb2ac3 100644
--- a/tests/e2e/online_serving/test_qwen3_omni_expansion.py
+++ b/tests/e2e/online_serving/test_qwen3_omni_expansion.py
@@ -6,22 +6,14 @@
import os
-from vllm_omni.platforms import current_omni_platform
-
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-from pathlib import Path
import pytest
-from tests.conftest import (
- OmniServerParams,
- dummy_messages_from_mix_data,
- generate_synthetic_audio,
- generate_synthetic_image,
- generate_synthetic_video,
- modify_stage_config,
-)
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.media import generate_synthetic_audio, generate_synthetic_image, generate_synthetic_video
+from tests.helpers.runtime import OmniServerParams, dummy_messages_from_mix_data
+from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
model = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
@@ -40,47 +32,57 @@
LONG_AUDIO_DURATION_SEC = 120
-def get_chunk_config(default_path):
- path = modify_stage_config(
+def get_batch_token_config(default_path):
+ """Override stage 1's max_num_batched_tokens to exercise small-batch paths.
+
+ Uses the new flat-stage schema (``stages..``); the legacy
+ ``stage_args..engine_args.`` path no longer applies because
+ the deploy YAML doesn't nest engine fields under ``engine_args:``.
+ """
+ return modify_stage_config(
default_path,
updates={
- "async_chunk": True,
- "stage_args": {
- 0: {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk",
- "default_sampling_params.max_tokens": 2048,
- },
- 1: {
- "engine_args.custom_process_next_stage_input_func": "vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk"
- },
- },
+ "stages": {1: {"max_num_batched_tokens": 64}},
},
- deletes={"stage_args": {2: ["custom_process_input_func"]}},
)
- return path
-def get_batch_token_config(default_path):
- path = modify_stage_config(
+def get_async_chunk_config(default_path):
+ """Flip async_chunk on and bump stage 0 thinker output to 2048 tokens.
+
+ Pipeline registry (qwen3_omni/pipeline.py) already wires
+ thinker2talker_async_chunk / talker2code2wav_async_chunk on stages 0/1,
+ so no per-stage processor override is needed. Using only flat-schema
+ writes so _parse_stage_deploy stays in its flat branch (nested
+ ``engine_args:`` would drop other overlay fields).
+ """
+ return modify_stage_config(
default_path,
updates={
- "stage_args": {1: {"engine_args.max_num_batched_tokens": 64}},
+ "stages": {0: {"default_sampling_params.max_tokens": 2048}},
},
)
- return path
-
-# CI stage config for 2*H100-80G GPUs
-default_path = str(Path(__file__).parent.parent / "stage_configs" / "qwen3_omni_ci.yaml")
-if current_omni_platform.is_xpu():
- default_path = str(Path(__file__).parent.parent / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")
+# CI deploy YAML (single file; xpu deltas applied via ``platforms:`` section).
+# The overlay explicitly sets ``async_chunk: False``, so ``default`` tests the
+# sync path and ``async_chunk`` tests the streaming path with a longer thinker
+# output — two distinct scenarios, kept as separate parametrizations.
+default_path = get_deploy_config_path("ci/qwen3_omni_moe.yaml")
-# Create parameter combinations for model and stage config
test_params = [
- pytest.param(OmniServerParams(model=model, stage_config_path=default_path, use_stage_cli=True), id="default"),
pytest.param(
- OmniServerParams(model=model, stage_config_path=get_chunk_config(default_path), use_stage_cli=True),
+ OmniServerParams(
+ model=model, stage_config_path=default_path, use_stage_cli=True, server_args=["--no-async-chunk"]
+ ),
+ id="default",
+ ),
+ pytest.param(
+ OmniServerParams(
+ model=model,
+ stage_config_path=get_async_chunk_config(default_path),
+ use_stage_cli=True,
+ ),
id="async_chunk",
),
]
@@ -448,7 +450,7 @@ def test_one_word_prompt_001(omni_server, openai_client) -> None:
"key_words": {"text": ["london"]},
}
- # Retry only when assert_omni_response fails on text/audio cosine similarity (see tests/conftest.py).
+ # Retry only when assert_omni_response fails on text/audio cosine similarity (see tests/helpers/assertions.py).
_similarity_assert_msg = "The audio content is not same as the text"
_max_retries = 3
for attempt in range(_max_retries):
@@ -513,7 +515,7 @@ def test_speaker_002(omni_server, openai_client) -> None:
"key_words": {"text": ["beijing"]},
}
- # Retry only when assert_omni_response fails on preset voice gender (see tests/conftest.py).
+ # Retry only when assert_omni_response fails on preset voice gender (see tests/helpers/assertions.py).
_gender_assert_substr = "estimated gender"
_max_retries = 3
for attempt in range(_max_retries):
diff --git a/tests/e2e/online_serving/test_qwen3_omni_realtime_websocket.py b/tests/e2e/online_serving/test_qwen3_omni_realtime_websocket.py
new file mode 100644
index 00000000000..f3b26108199
--- /dev/null
+++ b/tests/e2e/online_serving/test_qwen3_omni_realtime_websocket.py
@@ -0,0 +1,206 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+E2E online tests for Qwen3-Omni /v1/realtime WebSocket (streaming PCM in, audio out).
+"""
+
+from __future__ import annotations
+
+import asyncio
+import base64
+import io
+import json
+import os
+import wave
+
+import pytest
+import websockets
+
+from tests.helpers.mark import hardware_test
+from tests.helpers.media import (
+ convert_audio_bytes_to_text,
+ cosine_similarity_text,
+ generate_synthetic_audio,
+)
+from tests.helpers.runtime import OmniServerParams
+from tests.helpers.stage_config import get_deploy_config_path, modify_stage_config
+
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+MODEL = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
+
+# The new-schema CI overlay bakes in async_chunk: False and covers CUDA/ROCm/XPU
+# via its ``platforms:`` section, so one path serves all three.
+default_stage_config = get_deploy_config_path("ci/qwen3_omni_moe.yaml")
+
+
+def _realtime_stage_config_path() -> str:
+ """CI omni layout without async_chunk; stage 0 thinker max_tokens=10."""
+ return modify_stage_config(
+ default_stage_config,
+ updates={"stages": {0: {"default_sampling_params.max_tokens": 10}}},
+ )
+
+
+realtime_server_params = [
+ pytest.param(
+ OmniServerParams(
+ model=MODEL,
+ stage_config_path=_realtime_stage_config_path(),
+ use_stage_cli=True,
+ ),
+ id="thinker_max_tokens_10",
+ ),
+]
+
+
+def _pcm16_mono_16k_from_wav_bytes(wav_bytes: bytes) -> bytes:
+ with wave.open(io.BytesIO(wav_bytes), "rb") as wf:
+ if wf.getnchannels() != 1:
+ raise ValueError(f"Expected mono WAV, got {wf.getnchannels()} channels")
+ if wf.getsampwidth() != 2:
+ raise ValueError(f"Expected 16-bit PCM, sampwidth={wf.getsampwidth()}")
+ if wf.getframerate() != 16000:
+ raise ValueError(f"Expected 16 kHz input for /v1/realtime, got {wf.getframerate()} Hz")
+ if wf.getcomptype() != "NONE":
+ raise ValueError(f"Expected uncompressed PCM, comptype={wf.getcomptype()!r}")
+ return wf.readframes(wf.getnframes())
+
+
+def _wav_bytes_from_pcm16(pcm: bytes, sample_rate_hz: int) -> bytes:
+ buf = io.BytesIO()
+ with wave.open(buf, "wb") as wf:
+ wf.setnchannels(1)
+ wf.setsampwidth(2)
+ wf.setframerate(sample_rate_hz)
+ wf.writeframes(pcm)
+ return buf.getvalue()
+
+
+async def _run_realtime_audio_roundtrip(
+ host: str,
+ port: int,
+ model: str,
+ pcm16: bytes,
+ *,
+ chunk_ms: int = 100,
+) -> dict:
+ uri = f"ws://{host}:{port}/v1/realtime"
+ incremental: list[bytes] = []
+ output_sr = 24000
+ text_chunks: list[str] = []
+ final_text = ""
+ delta_events = 0
+
+ bytes_per_ms = 16000 * 2 // 1000
+ chunk_bytes = max(bytes_per_ms * chunk_ms, 2)
+
+ async with websockets.connect(uri, max_size=64 * 1024 * 1024) as ws:
+ await ws.send(json.dumps({"type": "session.update", "model": model}))
+ await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": False}))
+
+ for i in range(0, len(pcm16), chunk_bytes):
+ chunk = pcm16[i : i + chunk_bytes]
+ await ws.send(
+ json.dumps(
+ {
+ "type": "input_audio_buffer.append",
+ "audio": base64.b64encode(chunk).decode("utf-8"),
+ }
+ )
+ )
+
+ await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True}))
+
+ while True:
+ message = await asyncio.wait_for(ws.recv(), timeout=600)
+ if isinstance(message, bytes):
+ continue
+
+ event = json.loads(message)
+ event_type = event.get("type")
+
+ if event_type == "session.created":
+ continue
+
+ if event_type == "response.audio.delta":
+ delta_events += 1
+ sr = event.get("sample_rate_hz")
+ if isinstance(sr, int) and sr > 0:
+ output_sr = sr
+ audio_b64 = event.get("audio", "")
+ if audio_b64:
+ incremental.append(base64.b64decode(audio_b64))
+ continue
+
+ if event_type == "transcription.delta":
+ d = event.get("delta", "")
+ if d:
+ text_chunks.append(d)
+ continue
+
+ if event_type == "transcription.done":
+ final_text = event.get("text", "") or "".join(text_chunks)
+ continue
+
+ if event_type == "response.audio.done":
+ break
+
+ if event_type == "error":
+ raise AssertionError(f"WebSocket error: {event}")
+
+ raise AssertionError(f"Unexpected WebSocket event: {event}")
+
+ out_pcm = b"".join(incremental)
+ return {
+ "output_pcm": out_pcm,
+ "output_sample_rate": output_sr,
+ "transcription_text": final_text if final_text else "".join(text_chunks),
+ "delta_events": delta_events,
+ }
+
+
+class TestQwen3OmniRealtimeWebSocket:
+ @pytest.mark.advanced_model
+ @pytest.mark.omni
+ @hardware_test(res={"cuda": "H100", "rocm": "MI325"}, num_cards=2)
+ @pytest.mark.parametrize("omni_server", realtime_server_params, indirect=True)
+ def test_streaming_audio_input_pcm_output(self, omni_server) -> None:
+ """
+ Short streamed 16 kHz mono PCM16 input; expect streamed PCM16 audio deltas and
+ transcription. Verify Whisper(output audio) aligns with model text (same idea
+ as multimodal omni e2e).
+ """
+ syn = generate_synthetic_audio(10, 1, sample_rate=16000)
+ wav_bytes = base64.b64decode(syn["base64"])
+ pcm16 = _pcm16_mono_16k_from_wav_bytes(wav_bytes)
+
+ result = asyncio.run(
+ _run_realtime_audio_roundtrip(
+ omni_server.host,
+ omni_server.port,
+ omni_server.model,
+ pcm16,
+ chunk_ms=100,
+ )
+ )
+
+ out_pcm = result["output_pcm"]
+ assert result["delta_events"] >= 1
+ assert out_pcm, "No output PCM from response.audio.delta"
+ assert len(out_pcm) % 2 == 0
+ assert len(out_pcm) >= 4096, "Output audio unexpectedly small"
+ assert result["output_sample_rate"] > 0
+
+ final_text = (result["transcription_text"] or "").strip()
+ assert final_text, "Expected non-empty transcription (model text stream)"
+
+ wav_out = _wav_bytes_from_pcm16(out_pcm, result["output_sample_rate"])
+ whisper_text = convert_audio_bytes_to_text(wav_out).strip()
+ assert whisper_text, "Whisper returned empty string for synthesized output audio"
+
+ sim = cosine_similarity_text(whisper_text.lower(), final_text.lower())
+ assert sim > 0.9, (
+ f"Output audio transcript should match model text (sim={sim:.3f}): "
+ f"whisper={whisper_text!r}, model_text={final_text!r}"
+ )
diff --git a/tests/e2e/online_serving/test_qwen3_tts_base.py b/tests/e2e/online_serving/test_qwen3_tts_base.py
index 002f9d99724..fd7bc43b55d 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_base.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_base.py
@@ -12,12 +12,11 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from pathlib import Path
-
import pytest
-from tests.conftest import OmniServerParams
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServerParams
+from tests.helpers.stage_config import get_deploy_config_path
MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
@@ -25,11 +24,6 @@
REF_TEXT = "Okay. Yeah. I resent you. I love you. I respect you. But you know what? You blew it! And thanks to you."
-def get_stage_config(name: str = "qwen3_tts.yaml"):
- """Get the stage config path from vllm_omni model_executor stage_configs."""
- return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
-
-
def get_prompt(prompt_type="text"):
"""Text prompt for text-to-audio tests (same as test_qwen3_omni - beijing test case)."""
prompts = {
@@ -48,7 +42,7 @@ def get_max_batch_size(size_type="few"):
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_stage_config("qwen3_tts.yaml"),
+ stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
server_args=["--trust-remote-code", "--disable-log-stats"],
),
id="async_chunk",
diff --git a/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py b/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py
index 3c33485e4f4..527eeaef4dd 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_base_expansion.py
@@ -12,12 +12,11 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from pathlib import Path
-
import pytest
-from tests.conftest import OmniServerParams
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServerParams
+from tests.helpers.stage_config import get_deploy_config_path
MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
@@ -25,11 +24,6 @@
REF_TEXT = "Okay. Yeah. I resent you. I love you. I respect you. But you know what? You blew it! And thanks to you."
-def get_stage_config(name: str = "qwen3_tts.yaml"):
- """Get the stage config path from vllm_omni model_executor stage_configs."""
- return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
-
-
def get_prompt(prompt_type="text"):
"""Text prompt for text-to-audio tests (same as test_qwen3_omni - beijing test case)."""
prompts = {
@@ -48,16 +42,19 @@ def get_max_batch_size(size_type="few"):
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_stage_config("qwen3_tts.yaml"),
+ stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
server_args=["--trust-remote-code", "--disable-log-stats"],
),
id="async_chunk",
),
+ # Synchronous (no async-chunk) variant — ``--no-async-chunk`` alone
+ # flips the deploy yaml's bool and the pipeline dispatches to the
+ # end-to-end codec processor. No variant yaml / pipeline needed.
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_stage_config("qwen3_tts_no_async_chunk.yaml"),
- server_args=["--trust-remote-code", "--disable-log-stats"],
+ stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
+ server_args=["--trust-remote-code", "--disable-log-stats", "--no-async-chunk"],
),
id="no_async_chunk",
),
diff --git a/tests/e2e/online_serving/test_qwen3_tts_batch.py b/tests/e2e/online_serving/test_qwen3_tts_batch.py
index 1a453afb72a..3ca0688195b 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_batch.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_batch.py
@@ -22,19 +22,18 @@
import pytest
import yaml
-from tests.conftest import (
- OmniServer,
- convert_audio_file_to_text,
- cosine_similarity_text,
-)
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.media import convert_audio_file_to_text, cosine_similarity_text
+from tests.helpers.runtime import OmniServer
+from tests.helpers.stage_config import get_deploy_config_path
MODEL = "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice"
STAGE_INIT_TIMEOUT_S = 120
-def get_stage_config(name: str = "qwen3_tts.yaml"):
- return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
+def get_stage_config(name: str = "qwen3_tts.yaml") -> str:
+ """Resolve a deploy config path under vllm_omni/deploy/."""
+ return get_deploy_config_path(name)
@pytest.fixture(scope="module")
diff --git a/tests/e2e/online_serving/test_qwen3_tts_customvoice.py b/tests/e2e/online_serving/test_qwen3_tts_customvoice.py
index fb60df725ba..2577361a0c8 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_customvoice.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_customvoice.py
@@ -12,21 +12,15 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from pathlib import Path
-
import pytest
-from tests.conftest import OmniServerParams
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServerParams
+from tests.helpers.stage_config import get_deploy_config_path
MODEL = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
-def get_stage_config(name: str = "qwen3_tts.yaml"):
- """Get the stage config path from vllm_omni model_executor stage_configs."""
- return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
-
-
def get_prompt(prompt_type="text"):
"""Text prompt for text-to-audio tests (same as test_qwen3_omni - beijing test case)."""
prompts = {
@@ -45,7 +39,7 @@ def get_max_batch_size(size_type="few"):
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_stage_config("qwen3_tts.yaml"),
+ stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
server_args=["--trust-remote-code", "--disable-log-stats"],
),
id="async_chunk",
diff --git a/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py b/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py
index 03a985896e4..b32303842aa 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_customvoice_expansion.py
@@ -12,21 +12,15 @@
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
-from pathlib import Path
-
import pytest
-from tests.conftest import OmniServerParams
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServerParams
+from tests.helpers.stage_config import get_deploy_config_path
MODEL = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
-def get_stage_config(name: str = "qwen3_tts.yaml"):
- """Get the stage config path from vllm_omni model_executor stage_configs."""
- return str(Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / name)
-
-
def get_prompt(prompt_type="english"):
"""Text prompt for text-to-audio tests (same as test_qwen3_omni - beijing test case)."""
prompts = {
@@ -46,16 +40,19 @@ def get_max_batch_size(size_type="few"):
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_stage_config("qwen3_tts.yaml"),
+ stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
server_args=["--trust-remote-code", "--disable-log-stats"],
),
id="async_chunk",
),
+ # Synchronous (no async-chunk) variant — ``--no-async-chunk`` alone
+ # flips the deploy yaml's bool and the pipeline dispatches to the
+ # end-to-end codec processor. No variant yaml / pipeline needed.
pytest.param(
OmniServerParams(
model=MODEL,
- stage_config_path=get_stage_config("qwen3_tts_no_async_chunk.yaml"),
- server_args=["--trust-remote-code", "--disable-log-stats"],
+ stage_config_path=get_deploy_config_path("qwen3_tts.yaml"),
+ server_args=["--trust-remote-code", "--disable-log-stats", "--no-async-chunk"],
),
id="no_async_chunk",
),
diff --git a/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py b/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py
index 8c1c860819c..b91548c5a66 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_speaker_embedding.py
@@ -13,13 +13,13 @@
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
import struct
-from pathlib import Path
import httpx
import pytest
-from tests.conftest import OmniServer
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServer
+from tests.helpers.stage_config import get_deploy_config_path
MODEL_BASE = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
MODEL_BASE_1_7B = "Qwen/Qwen3-TTS-12Hz-1.7B-Base"
@@ -37,10 +37,8 @@
MAX_NEW_TOKENS = 256
-def get_stage_config():
- return str(
- Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / "qwen3_tts.yaml"
- )
+def get_stage_config() -> str:
+ return get_deploy_config_path("qwen3_tts.yaml")
def _server_args():
diff --git a/tests/e2e/online_serving/test_qwen3_tts_websocket.py b/tests/e2e/online_serving/test_qwen3_tts_websocket.py
index 849d1c11585..5ac021cf88b 100644
--- a/tests/e2e/online_serving/test_qwen3_tts_websocket.py
+++ b/tests/e2e/online_serving/test_qwen3_tts_websocket.py
@@ -7,13 +7,13 @@
import asyncio
import json
import os
-from pathlib import Path
import pytest
import websockets
-from tests.conftest import OmniServer
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServer
+from tests.helpers.stage_config import get_deploy_config_path
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
@@ -23,9 +23,7 @@
def get_stage_config() -> str:
- return str(
- Path(__file__).parent.parent.parent.parent / "vllm_omni" / "model_executor" / "stage_configs" / "qwen3_tts.yaml"
- )
+ return get_deploy_config_path("qwen3_tts.yaml")
@pytest.fixture(scope="module")
diff --git a/tests/e2e/online_serving/test_qwen_image_edit_expansion.py b/tests/e2e/online_serving/test_qwen_image_edit_expansion.py
index 14e4c915b6b..c2461977d64 100644
--- a/tests/e2e/online_serving/test_qwen_image_edit_expansion.py
+++ b/tests/e2e/online_serving/test_qwen_image_edit_expansion.py
@@ -7,14 +7,9 @@
import pytest
-from tests.conftest import (
- OmniServer,
- OmniServerParams,
- OpenAIClientHandler,
- dummy_messages_from_mix_data,
- generate_synthetic_image,
-)
-from tests.utils import hardware_marks
+from tests.helpers.mark import hardware_marks
+from tests.helpers.media import generate_synthetic_image
+from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
EDIT_PROMPT = "Transform this modern, geometrist image into a Vincent van Gogh style impressionist painting."
MULTI_EDIT_PROMPT = (
diff --git a/tests/e2e/online_serving/test_qwen_image_expansion.py b/tests/e2e/online_serving/test_qwen_image_expansion.py
index 88e56cc3e10..b6f91c13daa 100644
--- a/tests/e2e/online_serving/test_qwen_image_expansion.py
+++ b/tests/e2e/online_serving/test_qwen_image_expansion.py
@@ -12,13 +12,8 @@
import pytest
-from tests.conftest import (
- OmniServer,
- OmniServerParams,
- OpenAIClientHandler,
- dummy_messages_from_mix_data,
-)
-from tests.utils import hardware_marks
+from tests.helpers.mark import hardware_marks
+from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
T2I_PROMPT = "A photo of a cat sitting on a laptop keyboard, digital art style."
NEGATIVE_PROMPT = "blurry, low quality"
diff --git a/tests/e2e/online_serving/test_qwen_image_layered_expansion.py b/tests/e2e/online_serving/test_qwen_image_layered_expansion.py
index fc73801c0e0..b958cfc054c 100644
--- a/tests/e2e/online_serving/test_qwen_image_layered_expansion.py
+++ b/tests/e2e/online_serving/test_qwen_image_layered_expansion.py
@@ -14,15 +14,9 @@
import pytest
-from tests.conftest import (
- OmniServer,
- OmniServerParams,
- OpenAIClientHandler,
- decode_b64_image,
- dummy_messages_from_mix_data,
- generate_synthetic_image,
-)
-from tests.utils import hardware_marks
+from tests.helpers.mark import hardware_marks
+from tests.helpers.media import decode_b64_image, generate_synthetic_image
+from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler, dummy_messages_from_mix_data
MODEL = "Qwen/Qwen-Image-Layered"
EDIT_PROMPT = "Decompose this image into layers."
diff --git a/tests/e2e/online_serving/test_sd3_expansion.py b/tests/e2e/online_serving/test_sd3_expansion.py
index 3ed5cc5f308..37e590ba3e1 100644
--- a/tests/e2e/online_serving/test_sd3_expansion.py
+++ b/tests/e2e/online_serving/test_sd3_expansion.py
@@ -4,12 +4,8 @@
import pytest
-from tests.conftest import (
- OmniServer,
- OmniServerParams,
- OpenAIClientHandler,
-)
-from tests.utils import hardware_marks
+from tests.helpers.mark import hardware_marks
+from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler
FOUR_CARD_FEATURE_MARKS = hardware_marks(res={"cuda": "L4"}, num_cards=4)
POSITIVE_PROMPT = "A serene mountain landscape at sunset"
diff --git a/tests/e2e/online_serving/test_video_generation_api.py b/tests/e2e/online_serving/test_video_generation_api.py
index 0711a1048e3..6a8fe45875a 100644
--- a/tests/e2e/online_serving/test_video_generation_api.py
+++ b/tests/e2e/online_serving/test_video_generation_api.py
@@ -16,8 +16,8 @@
import pytest
import requests
-from tests.conftest import OmniServer
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServer
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
os.environ["VLLM_TEST_CLEAN_GPU_MEMORY"] = "0"
diff --git a/tests/e2e/online_serving/test_voxtral_tts.py b/tests/e2e/online_serving/test_voxtral_tts.py
index f795288f375..2dd46fcfaa8 100644
--- a/tests/e2e/online_serving/test_voxtral_tts.py
+++ b/tests/e2e/online_serving/test_voxtral_tts.py
@@ -17,8 +17,8 @@
import httpx
import pytest
-from tests.conftest import OmniServerParams
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.runtime import OmniServerParams
MODEL = "mistralai/Voxtral-4B-TTS-2603"
diff --git a/tests/e2e/online_serving/test_wan22_expansion.py b/tests/e2e/online_serving/test_wan22_expansion.py
index e5e2d748d58..7e5bc912113 100644
--- a/tests/e2e/online_serving/test_wan22_expansion.py
+++ b/tests/e2e/online_serving/test_wan22_expansion.py
@@ -19,13 +19,9 @@
import pytest
-from tests.conftest import (
- OmniServer,
- OmniServerParams,
- OpenAIClientHandler,
- generate_synthetic_image,
-)
-from tests.utils import hardware_marks
+from tests.helpers.mark import hardware_marks
+from tests.helpers.media import generate_synthetic_image
+from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler
PROMPT = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
NEGATIVE_PROMPT = "low quality, blurry, distorted face, extra limbs, bad anatomy, watermark, logo, text, ugly, deformed, mutated, jpeg artifacts"
diff --git a/tests/e2e/online_serving/test_wan_2_1_vace_expansion.py b/tests/e2e/online_serving/test_wan_2_1_vace_expansion.py
index 0de70afe862..1f7a8c0722f 100644
--- a/tests/e2e/online_serving/test_wan_2_1_vace_expansion.py
+++ b/tests/e2e/online_serving/test_wan_2_1_vace_expansion.py
@@ -23,12 +23,8 @@
import pytest
-from tests.conftest import (
- OmniServer,
- OmniServerParams,
- OpenAIClientHandler,
-)
-from tests.utils import hardware_marks
+from tests.helpers.mark import hardware_marks
+from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler
MODEL = "Wan-AI/Wan2.1-VACE-1.3B-diffusers"
PROMPT = "A cat walking slowly across a sunlit garden path"
diff --git a/tests/e2e/online_serving/test_zimage_expansion.py b/tests/e2e/online_serving/test_zimage_expansion.py
index 9f90ec855b6..679233c82a9 100644
--- a/tests/e2e/online_serving/test_zimage_expansion.py
+++ b/tests/e2e/online_serving/test_zimage_expansion.py
@@ -13,12 +13,8 @@
import pytest
-from tests.conftest import (
- OmniServer,
- OmniServerParams,
- OpenAIClientHandler,
-)
-from tests.utils import hardware_marks
+from tests.helpers.mark import hardware_marks
+from tests.helpers.runtime import OmniServer, OmniServerParams, OpenAIClientHandler
MODEL = "Tongyi-MAI/Z-Image-Turbo"
PROMPT = "A high-detail studio photo of an orange tabby cat sitting on a laptop keyboard."
diff --git a/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml b/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml
new file mode 100644
index 00000000000..fb0c72cc513
--- /dev/null
+++ b/tests/e2e/stage_configs/bailingmm_moe_v2_lite_ci.yaml
@@ -0,0 +1,35 @@
+# Thinker stage only
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ devices: "0,1,2,3"
+ max_batch_size: 1
+ engine_args:
+ model_stage: thinker
+ model_arch: MingFlashOmniForConditionalGeneration
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.9
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: latent
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ max_model_len: 32768
+ tensor_parallel_size: 4
+ hf_config_name: llm_config
+ load_format: dummy
+ mm_processor_cache_gb: 0
+ final_output: true
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ max_tokens: 100
+ repetition_penalty: 1.05
+ seed: 42
+ detokenize: true
+ ignore_eos: false
diff --git a/tests/e2e/stage_configs/dynin_omni_ci.yaml b/tests/e2e/stage_configs/dynin_omni_ci.yaml
index 02400075104..525b7d888c2 100644
--- a/tests/e2e/stage_configs/dynin_omni_ci.yaml
+++ b/tests/e2e/stage_configs/dynin_omni_ci.yaml
@@ -72,13 +72,8 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
edges:
- from: 0
to: 1
- window_size: -1
- from: 1
to: 2
- window_size: -1
diff --git a/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml b/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml
deleted file mode 100644
index a7c637d486a..00000000000
--- a/tests/e2e/stage_configs/qwen2_5_omni_ci.yaml
+++ /dev/null
@@ -1,109 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# The following config has been verified on 2x 24GB GPU (L4/RTX3090/RTX4090).
-# This config is optimized for CI e2e tests.
-stage_args:
- - stage_id: 0
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: thinker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 16384
- max_num_batched_tokens: 16384
- max_num_seqs: 1
- gpu_memory_utilization: 0.9
- skip_mm_profiling: true
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- load_format: dummy
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 128
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 16384
- max_num_batched_tokens: 16384
- max_num_seqs: 1
- gpu_memory_utilization: 0.4
- skip_mm_profiling: true
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- load_format: dummy
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 4096
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
- - stage_id: 2
- runtime:
- process: true
- devices: "2" # Example: use a different GPU than the previous stage; use "0" if single GPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.5 #increase the gpu memory utilization to enable the test on H800
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio
- max_num_batched_tokens: 8192
- max_model_len: 8192
- load_format: dummy
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 8192
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/tests/e2e/stage_configs/qwen2_5_omni_thinker_ci.yaml b/tests/e2e/stage_configs/qwen2_5_omni_thinker_ci.yaml
deleted file mode 100644
index 94013828478..00000000000
--- a/tests/e2e/stage_configs/qwen2_5_omni_thinker_ci.yaml
+++ /dev/null
@@ -1,31 +0,0 @@
-stage_args:
- - stage_id: 0
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: thinker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 16384
- max_num_batched_tokens: 16384
- max_num_seqs: 1
- gpu_memory_utilization: 0.9
- skip_mm_profiling: true
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 128
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/tests/e2e/stage_configs/qwen3_omni_ci.yaml b/tests/e2e/stage_configs/qwen3_omni_ci.yaml
deleted file mode 100644
index 08dd49de953..00000000000
--- a/tests/e2e/stage_configs/qwen3_omni_ci.yaml
+++ /dev/null
@@ -1,102 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-stage_args:
-- stage_id: 0
- runtime:
- devices: "0"
- engine_args:
- model_stage: thinker
- max_num_seqs: 5
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 32768
- max_model_len: 32768
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- load_format: dummy
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 150
- seed: 42
- ignore_eos: False
- detokenize: True
- repetition_penalty: 1.05
-
-- stage_id: 1
- runtime:
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 5
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.5
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- max_model_len: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- load_format: dummy
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 1000
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
-- stage_id: 2
- runtime:
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 5
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 100000
- hf_config_name: thinker_config
- async_scheduling: false
- load_format: dummy
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2000
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/tests/e2e/stage_configs/rocm/qwen2_5_omni_ci.yaml b/tests/e2e/stage_configs/rocm/qwen2_5_omni_ci.yaml
deleted file mode 100644
index 0c756ce56b1..00000000000
--- a/tests/e2e/stage_configs/rocm/qwen2_5_omni_ci.yaml
+++ /dev/null
@@ -1,106 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# The following config has been verified on 2x 24GB GPU (L4/RTX3090/RTX4090).
-# This config is optimized for CI e2e tests.
-stage_args:
- - stage_id: 0
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: thinker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 16384
- max_num_batched_tokens: 16384
- max_num_seqs: 1
- gpu_memory_utilization: 0.8
- skip_mm_profiling: true
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 128
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 16384
- max_num_batched_tokens: 16384
- max_num_seqs: 1
- gpu_memory_utilization: 0.8
- skip_mm_profiling: true
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 4096
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
- - stage_id: 2
- runtime:
- process: true
- devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.15
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio
- max_num_batched_tokens: 4096
- max_model_len: 4096
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 4096
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/tests/e2e/stage_configs/rocm/qwen3_omni_ci.yaml b/tests/e2e/stage_configs/rocm/qwen3_omni_ci.yaml
deleted file mode 100644
index ac2b1fbd713..00000000000
--- a/tests/e2e/stage_configs/rocm/qwen3_omni_ci.yaml
+++ /dev/null
@@ -1,100 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-stage_args:
- - stage_id: 0
- runtime:
- devices: "0"
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- load_format: dummy
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 100
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- runtime:
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- # tensor_parallel_size: 2
- enable_prefix_caching: false
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- load_format: dummy
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 100
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- runtime:
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 1000000
- hf_config_name: thinker_config
- load_format: dummy
- async_scheduling: false
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 200
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/tests/e2e/stage_configs/xpu/qwen2_5_omni_ci.yaml b/tests/e2e/stage_configs/xpu/qwen2_5_omni_ci.yaml
deleted file mode 100644
index 14ef3c34385..00000000000
--- a/tests/e2e/stage_configs/xpu/qwen2_5_omni_ci.yaml
+++ /dev/null
@@ -1,108 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# The following config is verified with 2 * Intel Arc Pro B60 XPU.
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage
- engine_args:
- model_stage: thinker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 16384
- max_num_batched_tokens: 16384
- max_num_seqs: 1
- gpu_memory_utilization: 0.9 # thinker weight is around 16.74GB for Qwen2.5-Omni-7B
- skip_mm_profiling: true
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 128
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- max_model_len: 16384
- max_num_batched_tokens: 16384
- max_num_seqs: 1
- gpu_memory_utilization: 0.5 # talker weight is 6.03GB for Qwen2.5-Omni-7B
- skip_mm_profiling: true
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 4096
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "2"
- engine_args:
- max_num_seqs: 1
- model_stage: code2wav
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.3 # code2wav weight is around 1.46GB for Qwen2.5-Omni-7B
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
-
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/tests/e2e/stage_configs/xpu/qwen3_omni_ci.yaml b/tests/e2e/stage_configs/xpu/qwen3_omni_ci.yaml
deleted file mode 100644
index c4586e06649..00000000000
--- a/tests/e2e/stage_configs/xpu/qwen3_omni_ci.yaml
+++ /dev/null
@@ -1,109 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config is verified with 8 * Intel Arc Pro B60 XPU.
-stage_args:
-- stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "0,1,2,3"
- engine_args:
- max_num_seqs: 1
- model_stage: thinker
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.85 # thinker weight is around 61.08GB for Qwen3-Omni-30B-A3B-Instruct
- skip_mm_profiling: true
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 4096
- max_model_len: 4096
- enable_prefix_caching: false
- hf_config_name: thinker_config
- tensor_parallel_size: 4
- max_cudagraph_capture_size: 0
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 100
- seed: 42
- ignore_eos: False
- detokenize: True
- repetition_penalty: 1.05
-
-- stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "4"
- engine_args:
- max_num_seqs: 1
- model_stage: talker
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6 # talker weight is around 8.5GB for Qwen3-Omni-30B-A3B-Instruct
- skip_mm_profiling: true
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 4096
- max_model_len: 4096
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- max_cudagraph_capture_size: 0
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
-- stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "5"
- engine_args:
- max_num_seqs: 1
- model_stage: code2wav
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.3 # code2wav weight is around 0.4GB for Qwen3-Omni-30B-A3B-Instruct
- skip_mm_profiling: true
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 100000
- hf_config_name: thinker_config
- async_scheduling: false
- max_cudagraph_capture_size: 0
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2000
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py
index 565c83c1ad4..4d69f24c56a 100644
--- a/tests/engine/test_arg_utils.py
+++ b/tests/engine/test_arg_utils.py
@@ -39,21 +39,28 @@ def test_default_stage_id_is_concrete_int():
assert cfg.stage_id == 0
-def test_multimodal_kwarg_overrides():
+def test_multimodal_kwarg_overrides(mocker):
"""Ensure that overrides in the multimodal config are preserved."""
- # Get a different value than the default for a multimodal field
sig = inspect.signature(OmniEngineArgs)
default_mm_cache = sig.parameters["mm_processor_cache_gb"].default
override_val = default_mm_cache + 1
- # NOTE: This needs to be a model that resolves to supports_multimodal=True
- # in vLLM, otherwise we won't have an MM config
+ fake_model_config = SimpleNamespace(
+ multimodal_config=SimpleNamespace(mm_processor_cache_gb=override_val),
+ )
+
+ def _fake_parent_create_model_config(self):
+ assert self.mm_processor_cache_gb == override_val
+ return fake_model_config
+
+ mocker.patch.object(EngineArgs, "create_model_config", _fake_parent_create_model_config)
+ mocker.patch.object(OmniModelConfig, "from_vllm_model_config", side_effect=lambda model_config, **_: model_config)
+
cfg = OmniEngineArgs(
model="Qwen/Qwen2-VL-2B-Instruct",
mm_processor_cache_gb=override_val,
).create_model_config()
- # Ensure that the override was applied correctly
assert cfg.multimodal_config is not None
assert cfg.multimodal_config.mm_processor_cache_gb == override_val
diff --git a/tests/engine/test_async_omni_engine_abort.py b/tests/engine/test_async_omni_engine_abort.py
index 34fdf45ea25..eda7a7a788e 100644
--- a/tests/engine/test_async_omni_engine_abort.py
+++ b/tests/engine/test_async_omni_engine_abort.py
@@ -2,20 +2,21 @@
import os
import sys
from contextlib import ExitStack
-from pathlib import Path
import pytest
from vllm import SamplingParams
from vllm.inputs import PromptType
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.stage_config import get_deploy_config_path
from vllm_omni.entrypoints.async_omni import AsyncOmni
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
SEED = 42
-stage_config = str(Path(__file__).parent.parent / "e2e" / "stage_configs" / "qwen2_5_omni_thinker_ci.yaml")
+# Single-stage thinker-only deploy, materialized from tests.helpers.stage_config._CI_OVERLAYS.
+stage_config = get_deploy_config_path("ci/qwen2_5_omni_thinker_only.yaml")
model = "Qwen/Qwen2.5-Omni-7B"
@@ -60,7 +61,6 @@ async def generate(
@pytest.mark.core_model
@pytest.mark.omni
-@pytest.mark.real_hf_config
@hardware_test(res={"cuda": "L4", "rocm": "MI325"}, num_cards=1)
@pytest.mark.asyncio
async def test_abort():
diff --git a/tests/engine/test_orchestrator.py b/tests/engine/test_orchestrator.py
index f9316467fa7..b012728cce5 100644
--- a/tests/engine/test_orchestrator.py
+++ b/tests/engine/test_orchestrator.py
@@ -68,7 +68,10 @@ def get_diffusion_output_nowait(self):
except queue.Empty:
return None
- def process_engine_inputs(self, source_outputs, prompt=None):
+ def set_engine_outputs(self, outputs) -> None:
+ return None
+
+ def process_engine_inputs(self, source_outputs, prompt=None, streaming_context=None):
return list(self.next_inputs)
async def abort_requests_async(self, request_ids: list[str]) -> None:
diff --git a/tests/entrypoints/openai_api/test_image_server.py b/tests/entrypoints/openai_api/test_image_server.py
index cd1e83393b0..607b3eaa813 100644
--- a/tests/entrypoints/openai_api/test_image_server.py
+++ b/tests/entrypoints/openai_api/test_image_server.py
@@ -106,10 +106,13 @@ def test_encode_image_base64():
class MockGenerationResult:
- """Mock result object from AsyncOmni.generate()"""
+ """Mock result object compatible with current diffusion output shape."""
def __init__(self, images):
self.images = images
+ self.request_output = SimpleNamespace(images=images)
+ self.stage_durations = {}
+ self.peak_memory_mb = 0.0
class FakeAsyncOmni:
@@ -117,20 +120,26 @@ class FakeAsyncOmni:
def __init__(self, images=None):
self.stage_configs = [
- SimpleNamespace(stage_type="llm"),
- SimpleNamespace(stage_type="diffusion"),
+ SimpleNamespace(stage_type="llm", is_comprehension=True),
+ SimpleNamespace(stage_type="diffusion", is_comprehension=False),
]
self.default_sampling_params_list = [SamplingParams(temperature=0.1), OmniDiffusionSamplingParams()]
self.captured_sampling_params_list = None
self.captured_prompt = None
self._images = images or [Image.new("RGB", (64, 64), color="green")]
- async def generate(self, prompt, request_id, sampling_params_list):
- self.captured_sampling_params_list = sampling_params_list
+ async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
+ if sampling_params_list is not None:
+ self.captured_sampling_params_list = sampling_params_list
+ else:
+ self.captured_sampling_params_list = [sampling_params]
self.captured_prompt = prompt
images = [img.copy() for img in self._images]
yield MockGenerationResult(images)
+ def __class_getitem__(cls, item):
+ return cls
+
@pytest.fixture
def mock_async_diffusion(mocker: MockerFixture):
@@ -189,12 +198,54 @@ def async_omni_test_client():
"""Create test client with mocked AsyncOmni engine."""
from fastapi import FastAPI
+ from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.api_server import router
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ class FakeAsyncOmniClass(AsyncOmni):
+ def __init__(self):
+ stage_configs = [
+ SimpleNamespace(stage_type="llm", is_comprehension=True),
+ SimpleNamespace(stage_type="diffusion", is_comprehension=False),
+ ]
+ default_sampling_params_list = [
+ SamplingParams(temperature=0.1),
+ OmniDiffusionSamplingParams(),
+ ]
+ self.engine = SimpleNamespace(
+ stage_configs=stage_configs,
+ default_sampling_params_list=default_sampling_params_list,
+ )
+ self.default_sampling_params_list = default_sampling_params_list
+ self.captured_sampling_params_list = None
+ self.captured_prompt = None
+ self._images = [Image.new("RGB", (64, 64), color="green")]
+ self.od_config = SimpleNamespace(supports_multimodal_inputs=True)
+
+ async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
+ if sampling_params_list is not None:
+ self.captured_sampling_params_list = sampling_params_list
+ else:
+ self.captured_sampling_params_list = [sampling_params]
+ self.captured_prompt = prompt
+ images = [img.copy() for img in self._images]
+ yield MockGenerationResult(images)
+
+ def __class_getitem__(cls, item):
+ return cls
+
+ def get_diffusion_od_config(self):
+ return self.od_config
app = FastAPI()
app.include_router(router)
- app.state.engine_client = FakeAsyncOmni()
+ engine = FakeAsyncOmniClass()
+ chat_handler = object.__new__(OmniOpenAIServingChat)
+ chat_handler.engine_client = engine
+ chat_handler._diffusion_engine = None
+ app.state.openai_serving_chat = chat_handler
+ app.state.engine_client = engine
app.state.stage_configs = [
SimpleNamespace(stage_type="llm"),
SimpleNamespace(stage_type="diffusion"),
@@ -211,12 +262,54 @@ def async_omni_rgba_test_client():
"""Create test client with mocked AsyncOmni engine returning RGBA output."""
from fastapi import FastAPI
+ from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.api_server import router
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ class FakeAsyncOmniClass(AsyncOmni):
+ def __init__(self):
+ stage_configs = [
+ SimpleNamespace(stage_type="llm", is_comprehension=True),
+ SimpleNamespace(stage_type="diffusion", is_comprehension=False),
+ ]
+ default_sampling_params_list = [
+ SamplingParams(temperature=0.1),
+ OmniDiffusionSamplingParams(),
+ ]
+ self.engine = SimpleNamespace(
+ stage_configs=stage_configs,
+ default_sampling_params_list=default_sampling_params_list,
+ )
+ self.default_sampling_params_list = default_sampling_params_list
+ self.captured_sampling_params_list = None
+ self.captured_prompt = None
+ self._images = [Image.new("RGBA", (64, 64), color=(0, 255, 0, 128))]
+ self.od_config = SimpleNamespace(supports_multimodal_inputs=True)
+
+ async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
+ if sampling_params_list is not None:
+ self.captured_sampling_params_list = sampling_params_list
+ else:
+ self.captured_sampling_params_list = [sampling_params]
+ self.captured_prompt = prompt
+ images = [img.copy() for img in self._images]
+ yield MockGenerationResult(images)
+
+ def __class_getitem__(cls, item):
+ return cls
+
+ def get_diffusion_od_config(self):
+ return self.od_config
app = FastAPI()
app.include_router(router)
- app.state.engine_client = FakeAsyncOmni(images=[Image.new("RGBA", (64, 64), color=(0, 255, 0, 128))])
+ engine = FakeAsyncOmniClass()
+ chat_handler = object.__new__(OmniOpenAIServingChat)
+ chat_handler.engine_client = engine
+ chat_handler._diffusion_engine = None
+ app.state.openai_serving_chat = chat_handler
+ app.state.engine_client = engine
app.state.stage_configs = [
SimpleNamespace(stage_type="llm"),
SimpleNamespace(stage_type="diffusion"),
@@ -233,16 +326,55 @@ def async_omni_stage_configs_only_client():
"""Create test client with refactored AsyncOmni compatibility surface only."""
from fastapi import FastAPI
+ from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.openai.api_server import router
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ class FakeAsyncOmniClass(AsyncOmni):
+ def __init__(self):
+ stage_configs = [
+ SimpleNamespace(stage_type="llm", is_comprehension=True),
+ SimpleNamespace(stage_type="diffusion", is_comprehension=False),
+ ]
+ default_sampling_params_list = [
+ SamplingParams(temperature=0.1),
+ OmniDiffusionSamplingParams(),
+ ]
+ self.engine = SimpleNamespace(
+ stage_configs=stage_configs,
+ default_sampling_params_list=default_sampling_params_list,
+ )
+ self.default_sampling_params_list = default_sampling_params_list
+ self.captured_sampling_params_list = None
+ self.captured_prompt = None
+ self._images = [Image.new("RGB", (64, 64), color="green")]
+ self.od_config = SimpleNamespace(supports_multimodal_inputs=True)
+
+ async def generate(self, prompt, request_id, sampling_params=None, sampling_params_list=None):
+ if sampling_params_list is not None:
+ self.captured_sampling_params_list = sampling_params_list
+ else:
+ self.captured_sampling_params_list = [sampling_params]
+ self.captured_prompt = prompt
+ images = [img.copy() for img in self._images]
+ yield MockGenerationResult(images)
+
+ def __class_getitem__(cls, item):
+ return cls
+
+ def get_diffusion_od_config(self):
+ return self.od_config
app = FastAPI()
app.include_router(router)
- engine = FakeAsyncOmni()
+ engine = FakeAsyncOmniClass()
assert not hasattr(engine, "stage_list")
app.state.engine_client = engine
- # Intentionally do not populate app.state.stage_configs. Refactored
- # AsyncOmni exposes stage_configs on the engine instance.
+ chat_handler = object.__new__(OmniOpenAIServingChat)
+ chat_handler.engine_client = engine
+ chat_handler._diffusion_engine = None
+ app.state.openai_serving_chat = chat_handler
app.state.args = Namespace(
default_sampling_params='{"1": {"num_inference_steps":4, "guidance_scale":7.5, "generator_device":"cpu"}}',
max_generated_image_size=1024 * 1792,
@@ -306,6 +438,9 @@ def test_models_endpoint_no_engine():
def test_generate_single_image(test_client):
"""Test generating a single image"""
+ # Single-stage path should not require openai_serving_chat.
+ assert not hasattr(test_client.app.state, "openai_serving_chat")
+
response = test_client.post(
"/v1/images/generations",
json={
@@ -374,6 +509,43 @@ def test_generate_images_async_omni_stage_configs_only(async_omni_stage_configs_
assert captured[1].seed == 11
+def test_multistage_images_async_omni_construction(async_omni_test_client):
+ """Regression: multistage image generation builds the expected chat-style payload."""
+ response = async_omni_test_client.post(
+ "/v1/images/generations",
+ json={
+ "prompt": "a cat",
+ "n": 2,
+ "size": "128x256",
+ "seed": 7,
+ "num_inference_steps": 12,
+ "guidance_scale": 6.5,
+ },
+ )
+ assert response.status_code == 200
+
+ engine = async_omni_test_client.app.state.engine_client
+ captured_prompt = engine.captured_prompt
+ assert captured_prompt["prompt"] == "a cat"
+ assert captured_prompt["modalities"] == ["image"]
+ assert captured_prompt["mm_processor_kwargs"] == {
+ "target_h": 256,
+ "target_w": 128,
+ }
+
+ captured = engine.captured_sampling_params_list
+ assert captured is not None
+ assert len(captured) == 2
+ assert captured[0].temperature == 0.1
+ assert captured[0].seed == 7
+ assert captured[1].num_outputs_per_prompt == 2
+ assert captured[1].width == 128
+ assert captured[1].height == 256
+ assert captured[1].seed == 7
+ assert captured[1].num_inference_steps == 12
+ assert captured[1].guidance_scale == 6.5
+
+
def test_image_edits_async_omni_stage_configs_only(async_omni_stage_configs_only_client):
"""Regression: image edits accepts refactored AsyncOmni without stage_list."""
img_bytes = make_test_image_bytes((16, 16))
@@ -679,6 +851,19 @@ def test_model_field_omitted_works(test_client):
assert response.status_code == 200
+def test_generate_images_rejects_model_mismatch(test_client):
+ response = test_client.post(
+ "/v1/images/generations",
+ json={
+ "prompt": "test",
+ "model": "Qwen/Qwen-Image-2512",
+ "size": "1024x1024",
+ },
+ )
+ assert response.status_code == 400
+ assert "model mismatch" in response.json()["detail"].lower()
+
+
def make_test_image_bytes(size=(64, 64)) -> bytes:
img = Image.new(
"RGB",
@@ -782,6 +967,77 @@ def test_image_edit_rejects_multiple_images_when_model_does_not_support_them(asy
assert engine.captured_prompt is None
+def test_image_edit_rejects_model_mismatch(test_client):
+ img_bytes = make_test_image_bytes((16, 16))
+ response = test_client.post(
+ "/v1/images/edits",
+ files=[("image", img_bytes)],
+ data={
+ "prompt": "edit me",
+ "model": "Qwen/Qwen-Image-Edit",
+ },
+ )
+ assert response.status_code == 400
+ assert "model mismatch" in response.json()["detail"].lower()
+
+
+def test_image_edit_rejects_too_many_images_for_qwen_image_edit_2511(async_omni_test_client):
+ engine = async_omni_test_client.app.state.engine_client
+ engine.get_diffusion_od_config = lambda: SimpleNamespace(
+ supports_multimodal_inputs=True,
+ max_multimodal_image_inputs=4,
+ )
+
+ response = async_omni_test_client.post(
+ "/v1/images/edits",
+ files=[
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ],
+ data={"prompt": "hello world."},
+ )
+
+ assert response.status_code == 400
+ assert response.json()["detail"] == "Received 5 input images. At most 4 images are supported by this model."
+ assert engine.captured_prompt is None
+
+
+def test_image_edit_rejects_too_many_images_for_qwen_image_edit_2511_before_loading(
+ async_omni_test_client, monkeypatch: pytest.MonkeyPatch
+):
+ import vllm_omni.entrypoints.openai.api_server as api_server_module
+
+ engine = async_omni_test_client.app.state.engine_client
+ engine.get_diffusion_od_config = lambda: SimpleNamespace(
+ supports_multimodal_inputs=True,
+ max_multimodal_image_inputs=4,
+ )
+
+ def _fail_load(*args, **kwargs):
+ raise AssertionError("_load_input_images should not run for over-limit requests")
+
+ monkeypatch.setattr(api_server_module, "_load_input_images", _fail_load)
+
+ response = async_omni_test_client.post(
+ "/v1/images/edits",
+ files=[
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ("image", make_test_image_bytes((16, 16))),
+ ],
+ data={"prompt": "hello world."},
+ )
+
+ assert response.status_code == 400
+ assert response.json()["detail"] == "Received 5 input images. At most 4 images are supported by this model."
+ assert engine.captured_prompt is None
+
+
def test_image_edit_parameter_pass(async_omni_test_client):
img_bytes_1 = make_test_image_bytes((16, 16))
diff --git a/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py b/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py
new file mode 100644
index 00000000000..a9b9f53ba8a
--- /dev/null
+++ b/tests/entrypoints/openai_api/test_serving_chat_multistage_generation.py
@@ -0,0 +1,82 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Regression tests for multistage diffusion generation input construction."""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+from PIL import Image
+from vllm.sampling_params import SamplingParams
+
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+@pytest.fixture
+def serving_chat():
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ return object.__new__(OmniOpenAIServingChat)
+
+
+def test_build_multistage_generation_inputs_applies_stage_specific_overrides(serving_chat):
+ from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
+
+ engine = SimpleNamespace(
+ stage_configs=[
+ SimpleNamespace(stage_type="llm", is_comprehension=True),
+ SimpleNamespace(stage_type="diffusion", is_comprehension=False),
+ SimpleNamespace(stage_type="diffusion", is_comprehension=False),
+ ],
+ default_sampling_params_list=[
+ SamplingParams(temperature=0.2, seed=11),
+ OmniDiffusionSamplingParams(),
+ OmniDiffusionSamplingParams(),
+ ],
+ )
+ reference_image = Image.new("RGB", (24, 24), color="green")
+ extra_body = {
+ "negative_prompt": "blurry",
+ "num_inference_steps": 28,
+ "guidance_scale": 7.5,
+ "true_cfg_scale": 5.0,
+ "guidance_scale_2": 1.25,
+ "layers": 6,
+ "resolution": 1024,
+ "lora": {"name": "adapter-a", "path": "/tmp/adapter-a", "scale": 0.6},
+ }
+ gen_params = OmniDiffusionSamplingParams(height=768, width=1024, seed=0, num_outputs_per_prompt=2)
+
+ engine_prompt, sampling_params_list = OmniOpenAIServingChat._build_multistage_generation_inputs(
+ serving_chat,
+ engine=engine,
+ prompt="draw a robot",
+ extra_body=extra_body,
+ reference_images=[reference_image],
+ gen_params=gen_params,
+ )
+
+ assert engine_prompt["prompt"] == "draw a robot"
+ assert engine_prompt["modalities"] == ["img2img"]
+ assert engine_prompt["negative_prompt"] == "blurry"
+ assert engine_prompt["mm_processor_kwargs"] == {"target_h": 768, "target_w": 1024}
+ assert engine_prompt["multi_modal_data"]["img2img"].size == (24, 24)
+
+ assert len(sampling_params_list) == 3
+ assert sampling_params_list[0].temperature == 0.2
+ assert sampling_params_list[0].seed == 0
+ assert sampling_params_list[1].height == 768
+ assert sampling_params_list[1].width == 1024
+ assert sampling_params_list[1].seed == 0
+ assert sampling_params_list[1].num_inference_steps == 28
+ assert sampling_params_list[1].guidance_scale == 7.5
+ assert sampling_params_list[1].num_outputs_per_prompt == 2
+ assert sampling_params_list[1].true_cfg_scale == 5.0
+ assert sampling_params_list[1].lora_request.name == "adapter-a"
+ assert sampling_params_list[2].height == 768
+ assert sampling_params_list[2].width == 1024
+ assert sampling_params_list[2].num_inference_steps == 28
+ assert engine.default_sampling_params_list[1].height is None
+ assert engine.default_sampling_params_list[2].resolution == 640
diff --git a/tests/entrypoints/openai_api/test_text_splitter.py b/tests/entrypoints/openai_api/test_text_splitter.py
index a1886662ae5..b9022e015dd 100644
--- a/tests/entrypoints/openai_api/test_text_splitter.py
+++ b/tests/entrypoints/openai_api/test_text_splitter.py
@@ -4,7 +4,7 @@
from vllm_omni.entrypoints.openai.text_splitter import SentenceSplitter
-pytestmark = [pytest.mark.openai, pytest.mark.speech, pytest.mark.core_model, pytest.mark.cpu]
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
class TestSentenceSplitterEnglish:
diff --git a/tests/entrypoints/openai_api/test_video_api_utils.py b/tests/entrypoints/openai_api/test_video_api_utils.py
index 5012c9b9826..9e732403fbb 100644
--- a/tests/entrypoints/openai_api/test_video_api_utils.py
+++ b/tests/entrypoints/openai_api/test_video_api_utils.py
@@ -72,7 +72,7 @@ def test_frame_interpolator_runs_actual_torch_tensor_path(monkeypatch):
assert torch.isfinite(output_video).all()
-def test_frame_interpolator_prefers_input_tensor_device(monkeypatch):
+def test_frame_interpolator_uses_platform_device_when_tensor_is_cpu(monkeypatch):
chosen_devices = []
model = rife_interpolator.Model().eval()
@@ -83,10 +83,11 @@ def _fake_ensure_model_loaded(*, preferred_device=None):
interpolator = rife_interpolator.FrameInterpolator()
monkeypatch.setattr(interpolator, "_ensure_model_loaded", _fake_ensure_model_loaded)
monkeypatch.setattr(model.flownet, "to", lambda device: model.flownet)
+ monkeypatch.setattr(rife_interpolator, "_select_torch_device", lambda: torch.device("cuda"))
video = torch.zeros(1, 3, 2, 32, 32)
output_video, multiplier = interpolator.interpolate_tensor(video, exp=1, scale=1.0)
- assert chosen_devices == [video.device]
+ assert chosen_devices == [torch.device("cuda")]
assert multiplier == 2
assert output_video.shape == (1, 3, 3, 32, 32)
diff --git a/tests/entrypoints/openai_api/test_video_server.py b/tests/entrypoints/openai_api/test_video_server.py
index 7a395bab5b0..6157d82313e 100644
--- a/tests/entrypoints/openai_api/test_video_server.py
+++ b/tests/entrypoints/openai_api/test_video_server.py
@@ -564,6 +564,18 @@ def test_missing_prompt_returns_422(test_client):
assert response.status_code == 422
+def test_video_generation_rejects_model_mismatch(test_client):
+ response = test_client.post(
+ "/v1/videos",
+ data={
+ "prompt": "bad model",
+ "model": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
+ },
+ )
+ assert response.status_code == 400
+ assert "model mismatch" in response.json()["detail"].lower()
+
+
def test_invalid_size_parse_returns_422(test_client):
response = test_client.post(
"/v1/videos",
diff --git a/tests/entrypoints/test_pd_disaggregation.py b/tests/entrypoints/test_pd_disaggregation.py
new file mode 100644
index 00000000000..5ffabfbf2af
--- /dev/null
+++ b/tests/entrypoints/test_pd_disaggregation.py
@@ -0,0 +1,1222 @@
+"""Unit tests for PD (Prefill-Decode) disaggregation in the Omni orchestrator.
+
+Tests the PD detection, validation, config parsing, sampling param
+preparation, and routing logic added by the PD disaggregation feature
+(issue #1188). All tests run without GPU.
+
+NOTE (v1908 adaptation): Tests that relied on the old OmniStage / stage_list
+architecture (removed in PR #1908) are marked xfail with
+``reason="Requires migration to v1908 Orchestrator architecture"``.
+The remaining tests exercise PDDisaggregationMixin directly and work
+without spinning up a real engine.
+"""
+
+import uuid
+import warnings
+from queue import Empty, Queue
+from types import SimpleNamespace
+from typing import Any
+
+import pytest
+from vllm import SamplingParams
+
+from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin
+
+pytestmark = pytest.mark.skip(reason="Temporarily skip PD entrypoint tests while PD config is being removed.")
+
+# Suppress noisy DeprecationWarnings from optional Swig bindings imported by vLLM dependencies.
+warnings.filterwarnings(
+ "ignore",
+ message=r"builtin type SwigPy.*has no __module__ attribute",
+ category=DeprecationWarning,
+)
+
+
+def _ns(**kwargs):
+ """Create a lightweight attribute object for tests."""
+ return SimpleNamespace(**kwargs)
+
+
+# ---------------------------------------------------------------------------
+# Fake helpers (same pattern as test_omni_llm.py)
+# ---------------------------------------------------------------------------
+
+
+class _FakeEngineArgs(dict):
+ """Fake engine args that supports both attribute and dict access."""
+
+ def __init__(self, args_dict: dict[str, Any]):
+ super().__init__(args_dict)
+ if "model_stage" not in self:
+ self["model_stage"] = None
+ if "engine_output_type" not in self:
+ self["engine_output_type"] = None
+ for key, value in self.items():
+ setattr(self, key, value)
+
+
+class _FakeStageConfig:
+ def __init__(self, config_dict: dict[str, Any]):
+ engine_args_dict = config_dict.get("engine_args", {})
+ self.engine_args = _FakeEngineArgs(engine_args_dict)
+ self.final_output = config_dict.get("final_output", False)
+ self.final_output_type = config_dict.get("final_output_type", None)
+ self.stage_id = config_dict.get("stage_id", 0)
+ self.is_prefill_only = config_dict.get("is_prefill_only", False)
+ self.is_decode_only = config_dict.get("is_decode_only", False)
+ self.engine_input_source = config_dict.get("engine_input_source", [])
+ self.is_comprehension = config_dict.get("is_comprehension", False)
+ self._config_dict = config_dict
+
+
+class _FakeQueue:
+ def __init__(self, maxsize=0):
+ self._queue = Queue(maxsize=maxsize)
+
+ def put(self, item):
+ self._queue.put(item)
+
+ def put_nowait(self, item):
+ self._queue.put_nowait(item)
+
+ def get(self):
+ return self._queue.get()
+
+ def get_nowait(self):
+ return self._queue.get_nowait()
+
+ def empty(self):
+ return self._queue.empty()
+
+
+class _FakeStage:
+ """Lightweight stage stub with PD disaggregation flag support."""
+
+ def __init__(self, config, stage_init_timeout: int = 300):
+ if isinstance(config, dict):
+ config = _FakeStageConfig(config)
+ self.config = config
+ self.stage_config = config
+ self.engine = None
+ self.engine_outputs = None
+ self.stage_id = getattr(config, "stage_id", 0)
+ self.engine_args = config.engine_args
+ self.model_stage = getattr(config.engine_args, "model_stage", None)
+ self.stage_type = "llm"
+ self.default_sampling_params = SamplingParams(temperature=1.0)
+ self.final_output = config.final_output if hasattr(config, "final_output") else False
+ self.final_output_type = getattr(config, "final_output_type", None)
+ self.is_prefill_only = getattr(config, "is_prefill_only", False)
+ self.is_decode_only = getattr(config, "is_decode_only", False)
+ self.engine_input_source = getattr(config, "engine_input_source", [])
+ self.is_comprehension = getattr(config, "is_comprehension", False)
+ processed_input = getattr(config, "_config_dict", {}).get("processed_input", ["processed"])
+ self._processed_input = processed_input
+ self._in_q = None
+ self._out_q = None
+ self._proc = None
+ self._stage_init_timeout = max(0, int(stage_init_timeout))
+
+ def attach_queues(self, in_q, out_q):
+ self._in_q = in_q
+ self._out_q = out_q
+
+ def init_stage_worker(
+ self, model: str, *, is_async=False, shm_threshold_bytes=65536, ctx=None, batch_timeout=10, **kwargs
+ ):
+ self._proc = _ns(
+ start=lambda: None,
+ join=lambda timeout=None: None,
+ is_alive=lambda: False,
+ terminate=lambda: None,
+ )
+ if self._out_q is not None:
+ try:
+ self._out_q.put_nowait({"type": "stage_ready", "stage_id": self.stage_id})
+ except Exception:
+ pass
+
+ def stop_stage_worker(self):
+ if self._in_q is not None:
+ try:
+ self._in_q.put_nowait({"type": "shutdown"})
+ except Exception:
+ pass
+
+ def submit(self, payload: dict[str, Any]):
+ if self._in_q is not None:
+ self._in_q.put(payload)
+
+ def try_collect(self) -> Any:
+ if self._out_q is None:
+ return None
+ try:
+ return self._out_q.get_nowait()
+ except Empty:
+ return None
+
+ def set_engine_outputs(self, outputs):
+ self.engine_outputs = outputs
+
+ def process_engine_inputs(self, stage_list, prompts):
+ return self._processed_input
+
+
+# ---------------------------------------------------------------------------
+# Shared mock setup helpers
+# ---------------------------------------------------------------------------
+
+
+def _setup_engine_mocks(monkeypatch):
+ fake_engine = _ns()
+ fake_engine.tokenizer = _ns()
+ fake_engine.log_stats = False
+ fake_engine.vllm_config = _ns()
+ fake_engine.vllm_config.model_config = _ns()
+ fake_engine.vllm_config.model_config.io_processor_plugin = None
+ fake_engine.get_supported_tasks = lambda: []
+ fake_engine.model_config = _ns()
+ fake_engine.model_config.io_processor_plugin = None
+ fake_registry = _ns()
+ fake_registry.resolve_model_cls = lambda *args, **kwargs: (_ns(), "test_arch")
+ fake_engine.model_config.registry = fake_registry
+ fake_engine.vllm_config.model_config.registry = fake_registry
+
+ monkeypatch.setattr(
+ "vllm.v1.engine.llm_engine.LLMEngine.from_engine_args",
+ lambda **kw: fake_engine,
+ raising=False,
+ )
+
+ class FakeModelClass:
+ pass
+
+ monkeypatch.setattr(
+ "vllm.model_executor.model_loader.utils.get_model_architecture",
+ lambda model_config: (FakeModelClass, "test_arch"),
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "vllm.model_executor.model_loader.utils._get_model_architecture",
+ lambda model_config: (FakeModelClass, "test_arch"),
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "vllm.model_executor.models.adapters.try_create_mm_pooling_model_cls",
+ lambda model_cls: model_cls,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "vllm.multimodal.cache._enable_processor_cache",
+ lambda model_config, mm_registry: False,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ "vllm.plugins.io_processors.get_io_processor",
+ lambda vllm_config, io_processor_plugin: None,
+ raising=False,
+ )
+
+
+def _setup_multiprocessing_mocks(monkeypatch):
+ import multiprocessing as mp
+
+ fake_process_instance = _ns(
+ start=lambda: None,
+ join=lambda timeout=None: None,
+ is_alive=lambda: False,
+ terminate=lambda: None,
+ )
+
+ def fake_process_class(*args, **kwargs):
+ return fake_process_instance
+
+ fake_ctx = _ns()
+ fake_ctx.Queue = lambda maxsize=0: _FakeQueue(maxsize=maxsize)
+ fake_ctx.Process = fake_process_class
+
+ monkeypatch.setattr(mp, "get_context", lambda method: fake_ctx, raising=False)
+ monkeypatch.setattr(mp, "Process", fake_process_class, raising=False)
+
+
+def _setup_ipc_mocks(monkeypatch):
+ # These IPC helpers existed in the old architecture; no-op in new arch.
+ pass
+
+
+def _setup_log_mocks(monkeypatch):
+ class _FakeOrchestratorAggregator:
+ def __init__(self, num_stages, enable_stats, wall_start_ts, final_stage_id_for_e2e=None):
+ self.num_stages = num_stages
+ self.enable_stats = enable_stats
+ self.stage_first_ts = [None] * num_stages
+ self.stage_last_ts = [None] * num_stages
+ self.stage_total_tokens = [0] * num_stages
+ self.accumulated_gen_time_ms = {}
+ self.e2e_done = set()
+ self.e2e_count = 0
+ self.e2e_total_ms = 0.0
+
+ def on_stage_metrics(self, stage_id, req_id, metrics, final_output_type=None):
+ pass
+
+ def on_finalize_request(self, stage_id, req_id, start_ts):
+ self.e2e_done.add(req_id)
+
+ def on_forward(self, from_stage, to_stage, req_id, size_bytes, tx_ms, use_shm):
+ pass
+
+ def accumulate_diffusion_metrics(self, stage_type, req_id, engine_outputs):
+ pass
+
+ def record_audio_generated_frames(self, output, stage_id, req_id):
+ pass
+
+ def stage_postprocess_timer(self, stage_id, req_id):
+ from contextlib import contextmanager
+
+ @contextmanager
+ def _noop():
+ yield
+
+ return _noop()
+
+ def build_and_log_summary(self):
+ return "Fake summary"
+
+ monkeypatch.setattr(
+ "vllm_omni.entrypoints.omni.OrchestratorAggregator",
+ _FakeOrchestratorAggregator,
+ raising=False,
+ )
+
+
+def _clear_modules():
+ import sys
+
+ for module_name in [
+ "vllm_omni.entrypoints.utils",
+ "vllm_omni.entrypoints.omni",
+ ]:
+ if module_name in sys.modules:
+ del sys.modules[module_name]
+
+
+@pytest.fixture(autouse=True)
+def mock_get_config(monkeypatch):
+ """Auto-mock get_config and related model loading functions."""
+ import sys
+
+ fake_tokenizer = _ns()
+ fake_tokenizer.encode = lambda *args, **kwargs: [1, 2, 3]
+ fake_tokenizer.decode = lambda *args, **kwargs: "test"
+
+ def _mock_init_tokenizer_from_configs(model_config=None, **kwargs):
+ return fake_tokenizer
+
+ monkeypatch.setattr(
+ "vllm.transformers_utils.tokenizer.init_tokenizer_from_configs",
+ _mock_init_tokenizer_from_configs,
+ raising=False,
+ )
+ tokenizer_module_path = "vllm.transformers_utils.tokenizer"
+ if tokenizer_module_path in sys.modules:
+ setattr(sys.modules[tokenizer_module_path], "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs)
+
+ def _mock_length_from_prompt_token_ids_or_embeds(prompt_token_ids=None, prompt_embeds=None):
+ if prompt_token_ids is not None:
+ if isinstance(prompt_token_ids, list):
+ return len(prompt_token_ids)
+ return 10
+
+ monkeypatch.setattr(
+ "vllm.utils.length_from_prompt_token_ids_or_embeds", _mock_length_from_prompt_token_ids_or_embeds, raising=False
+ )
+ monkeypatch.setattr(
+ "vllm_omni.engine.input_processor.length_from_prompt_token_ids_or_embeds",
+ _mock_length_from_prompt_token_ids_or_embeds,
+ raising=False,
+ )
+
+ processor_module_path = "vllm_omni.engine.input_processor"
+ if processor_module_path in sys.modules:
+ setattr(
+ sys.modules[processor_module_path],
+ "length_from_prompt_token_ids_or_embeds",
+ _mock_length_from_prompt_token_ids_or_embeds,
+ )
+
+ monkeypatch.setattr(
+ "vllm_omni.entrypoints.async_omni.init_tokenizer_from_configs", _mock_init_tokenizer_from_configs, raising=False
+ )
+ async_omni_path = "vllm_omni.entrypoints.async_omni"
+ if async_omni_path in sys.modules:
+ setattr(sys.modules[async_omni_path], "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs)
+
+ fake_hf_config = _ns()
+ fake_hf_config.model_type = "qwen2_5_omni"
+
+ monkeypatch.setattr(
+ "vllm.transformers_utils.config.get_config", lambda model, **kwargs: fake_hf_config, raising=False
+ )
+ monkeypatch.setattr("vllm_omni.entrypoints.utils.get_config", lambda model, **kwargs: fake_hf_config, raising=False)
+
+ def _mock_cached_file(path_or_repo_id, *args, **kwargs):
+ import os
+ import tempfile
+
+ fake_config_file = os.path.join(tempfile.gettempdir(), "fake_config.json")
+ if not os.path.exists(fake_config_file):
+ with open(fake_config_file, "w") as f:
+ f.write('{"model_type": "qwen2_5_omni"}')
+ return fake_config_file
+
+ monkeypatch.setattr("transformers.utils.hub.cached_file", _mock_cached_file, raising=False)
+ monkeypatch.setattr(
+ "transformers.utils.hub.cached_files",
+ lambda path_or_repo_id, filenames, **kwargs: (
+ [_mock_cached_file(path_or_repo_id, filenames[0])] if filenames else None
+ ),
+ raising=False,
+ )
+
+
+# ---------------------------------------------------------------------------
+# Helper to build an Omni instance with PD stage configs
+# ---------------------------------------------------------------------------
+
+
+def _make_pd_omni(monkeypatch, stage_configs, *, extra_setup=None):
+ """Create a lightweight PDDisaggregationMixin instance for unit tests.
+
+ Bypasses the full OmniBase / AsyncOmniEngine init chain so tests run
+ without GPU. Returns an object that has all PDDisaggregationMixin
+ methods and state (``_pd_separation_pair``, ``_pd_kv_params_by_req``,
+ etc.) initialised from *stage_configs*.
+
+ Tests that need the full ``Omni.generate()`` loop (old stage_list / queue
+ infrastructure) are marked ``xfail`` and not covered here.
+ """
+ configs = [_FakeStageConfig(c) for c in stage_configs]
+
+ class _LightweightOmni(PDDisaggregationMixin):
+ """Minimal shim: exposes stage_configs so PDDisaggregationMixin works."""
+
+ def __init__(self):
+ self._name = "Omni"
+ self._stage_configs = configs
+ self._init_pd_state()
+
+ @property
+ def stage_configs(self):
+ return self._stage_configs
+
+ if extra_setup:
+ import vllm_omni.entrypoints.omni as omni_module
+
+ extra_setup(monkeypatch, omni_module)
+
+ return _LightweightOmni()
+
+
+# ---------------------------------------------------------------------------
+# Stage config templates
+# ---------------------------------------------------------------------------
+
+
+def _prefill_stage_cfg(stage_id=0, **overrides):
+ cfg = {
+ "stage_id": stage_id,
+ "engine_args": {
+ "model_stage": "thinker",
+ "kv_transfer_config": {
+ "kv_connector": "MooncakeConnector",
+ "kv_role": "kv_producer",
+ "kv_rank": 0,
+ "kv_parallel_size": 2,
+ "kv_connector_extra_config": {"mooncake_bootstrap_port": 25201},
+ },
+ },
+ "is_prefill_only": True,
+ "final_output": False,
+ "is_comprehension": True,
+ }
+ cfg.update(overrides)
+ return cfg
+
+
+def _decode_stage_cfg(stage_id=1, engine_input_source=None, **overrides):
+ cfg = {
+ "stage_id": stage_id,
+ "engine_args": {
+ "model_stage": "thinker",
+ "kv_transfer_config": {
+ "kv_connector": "MooncakeConnector",
+ "kv_role": "kv_consumer",
+ "kv_rank": 1,
+ "kv_parallel_size": 2,
+ "kv_connector_extra_config": {"mooncake_bootstrap_port": 25202},
+ },
+ },
+ "is_decode_only": True,
+ "engine_input_source": engine_input_source if engine_input_source is not None else [0],
+ "final_output": True,
+ "final_output_type": "text",
+ "is_comprehension": True,
+ }
+ cfg.update(overrides)
+ return cfg
+
+
+def _talker_stage_cfg(stage_id=2, engine_input_source=None, **overrides):
+ cfg = {
+ "stage_id": stage_id,
+ "engine_args": {"model_stage": "talker"},
+ "engine_input_source": engine_input_source if engine_input_source is not None else [1],
+ "final_output": False,
+ }
+ cfg.update(overrides)
+ return cfg
+
+
+def _code2wav_stage_cfg(stage_id=3, engine_input_source=None, **overrides):
+ cfg = {
+ "stage_id": stage_id,
+ "engine_args": {"model_stage": "code2wav"},
+ "engine_input_source": engine_input_source if engine_input_source is not None else [2],
+ "final_output": True,
+ "final_output_type": "audio",
+ }
+ cfg.update(overrides)
+ return cfg
+
+
+# ===================================================================
+# Tests: PD pair detection
+# ===================================================================
+
+
+class TestDetectPDSeparation:
+ """Tests for Omni._detect_pd_separation()."""
+
+ def test_detects_pd_pair(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(stage_id=0),
+ _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
+ ],
+ )
+ assert omni._pd_separation_pair == (0, 1)
+
+ def test_no_pd_pair_without_flags(self, monkeypatch):
+ """Normal (non-PD) pipeline has no PD pair."""
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ {
+ "stage_id": 0,
+ "engine_args": {"model_stage": "thinker"},
+ "final_output": True,
+ "final_output_type": "text",
+ },
+ {
+ "stage_id": 1,
+ "engine_args": {"model_stage": "talker"},
+ "engine_input_source": [0],
+ "final_output": True,
+ "final_output_type": "audio",
+ },
+ ],
+ )
+ assert omni._pd_separation_pair is None
+
+ def test_detects_pd_pair_in_4_stage_pipeline(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(stage_id=0),
+ _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
+ _talker_stage_cfg(stage_id=2, engine_input_source=[1]),
+ _code2wav_stage_cfg(stage_id=3, engine_input_source=[2]),
+ ],
+ )
+ assert omni._pd_separation_pair == (0, 1)
+
+ def test_pd_pair_uses_stage_id_for_input_source(self, monkeypatch):
+ """engine_input_source references stage_id, not list index."""
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(stage_id=10),
+ _decode_stage_cfg(stage_id=20, engine_input_source=[10]),
+ ],
+ )
+ assert omni._pd_separation_pair == (0, 1)
+
+
+# ===================================================================
+# Tests: PD config validation
+# ===================================================================
+
+
+class TestValidatePDConfig:
+ """Tests for Omni._validate_pd_separation_config()."""
+
+ def test_valid_config_passes(self, monkeypatch):
+ """Valid PD config should not raise."""
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ # If we got here without error, validation passed
+ assert omni._pd_separation_pair == (0, 1)
+
+ def test_mismatched_connector_raises(self, monkeypatch):
+ """Different kv_connector types should raise ValueError."""
+ decode_cfg = _decode_stage_cfg(engine_input_source=[0])
+ decode_cfg["engine_args"]["kv_transfer_config"]["kv_connector"] = "NixlConnector"
+
+ with pytest.raises(ValueError, match="connector mismatch"):
+ _make_pd_omni(monkeypatch, [_prefill_stage_cfg(), decode_cfg])
+
+ def test_wrong_prefill_role_raises(self, monkeypatch):
+ """Prefill with kv_consumer role should raise."""
+ prefill_cfg = _prefill_stage_cfg()
+ prefill_cfg["engine_args"]["kv_transfer_config"]["kv_role"] = "kv_consumer"
+
+ with pytest.raises(ValueError, match="kv_role must be"):
+ _make_pd_omni(monkeypatch, [prefill_cfg, _decode_stage_cfg(engine_input_source=[0])])
+
+ def test_wrong_decode_role_raises(self, monkeypatch):
+ """Decode with kv_producer role should raise."""
+ decode_cfg = _decode_stage_cfg(engine_input_source=[0])
+ decode_cfg["engine_args"]["kv_transfer_config"]["kv_role"] = "kv_producer"
+
+ with pytest.raises(ValueError, match="kv_role must be"):
+ _make_pd_omni(monkeypatch, [_prefill_stage_cfg(), decode_cfg])
+
+ def test_missing_kv_transfer_config_raises(self, monkeypatch):
+ """Missing kv_transfer_config should raise."""
+ prefill_cfg = _prefill_stage_cfg()
+ del prefill_cfg["engine_args"]["kv_transfer_config"]
+
+ with pytest.raises(ValueError, match="kv_transfer_config"):
+ _make_pd_omni(monkeypatch, [prefill_cfg, _decode_stage_cfg(engine_input_source=[0])])
+
+ def test_mismatched_buffer_device_raises(self, monkeypatch):
+ """Mismatched kv_buffer_device should raise."""
+ prefill_cfg = _prefill_stage_cfg()
+ prefill_cfg["engine_args"]["kv_transfer_config"]["kv_buffer_device"] = "cuda"
+ decode_cfg = _decode_stage_cfg(engine_input_source=[0])
+ decode_cfg["engine_args"]["kv_transfer_config"]["kv_buffer_device"] = "cpu"
+
+ with pytest.raises(ValueError, match="kv_buffer_device mismatch"):
+ _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg])
+
+
+# ===================================================================
+# Tests: Connector info extraction
+# ===================================================================
+
+
+class TestGetPDConnectorInfo:
+ """Tests for Omni._get_pd_connector_info()."""
+
+ def test_extracts_bootstrap_addr_for_mooncake(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ info = omni._pd_connector_info
+ assert "prefill_bootstrap_addr" in info
+ assert info["prefill_bootstrap_addr"] == "127.0.0.1:25201"
+
+ def test_none_for_non_pd_pipeline(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ {"stage_id": 0, "engine_args": {}, "final_output": True, "final_output_type": "text"},
+ ],
+ )
+ assert omni._pd_connector_info is None
+
+
+# ===================================================================
+# Tests: Prefill sampling params preparation
+# ===================================================================
+
+
+class TestPreparePrefillSamplingParams:
+ """Tests for Omni._prepare_prefill_sampling_params()."""
+
+ def test_sets_max_tokens_to_1(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048)
+ result = omni._prepare_prefill_sampling_params("req-1", sp)
+
+ assert result.max_tokens == 1
+ assert result is not sp # should be cloned
+
+ def test_injects_kv_transfer_params(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048)
+ result = omni._prepare_prefill_sampling_params("req-1", sp)
+
+ kv_params = result.extra_args["kv_transfer_params"]
+ assert kv_params["do_remote_decode"] is True
+ assert kv_params["do_remote_prefill"] is False
+ assert kv_params["transfer_id"] == "xfer-req-1"
+
+ def test_preserves_existing_extra_args(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048, extra_args={"custom_key": "value"})
+ result = omni._prepare_prefill_sampling_params("req-1", sp)
+
+ assert result.extra_args["custom_key"] == "value"
+ assert "kv_transfer_params" in result.extra_args
+
+ def test_does_not_mutate_original(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048)
+ _ = omni._prepare_prefill_sampling_params("req-1", sp)
+
+ assert sp.max_tokens == 2048
+ assert sp.extra_args is None
+
+
+# ===================================================================
+# Tests: Sampling params auto-duplication for PD split
+# ===================================================================
+
+
+@pytest.mark.xfail(reason="Requires migration to v1908 Orchestrator architecture (no stage_list / OmniStage)")
+class TestSamplingParamsAutoDuplication:
+ """When user provides N-1 sampling params (for logical stages), the
+ orchestrator should auto-duplicate the thinker params for the decode stage.
+ """
+
+ def test_auto_duplicates_for_4_stage_pipeline(self, monkeypatch):
+ """User provides 3 params for 4 physical stages -> auto-insert decode params."""
+ test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000001")
+
+ def _extra_setup(mp, omni_module):
+ mp.setattr(uuid, "uuid4", lambda: test_uuid)
+ mp.setattr(omni_module, "uuid", uuid)
+
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(stage_id=0),
+ _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
+ _talker_stage_cfg(stage_id=2, engine_input_source=[1]),
+ _code2wav_stage_cfg(stage_id=3, engine_input_source=[2]),
+ ],
+ extra_setup=_extra_setup,
+ )
+
+ assert omni._pd_separation_pair == (0, 1)
+ assert len(omni.stage_list) == 4
+
+ # Simulate outputs for all stages
+ expected_rid = f"0_{test_uuid}"
+ for i in range(4):
+ omni.stage_list[i]._out_q.put_nowait(
+ {
+ "request_id": expected_rid,
+ "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2])])],
+ "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
+ }
+ )
+
+ # Provide 3 params (one less than 4 stages) - should auto-duplicate
+ sp_thinker = SamplingParams(temperature=0.4, max_tokens=2048)
+ sp_talker = SamplingParams(temperature=0.9, max_tokens=4096)
+ sp_code2wav = SamplingParams(temperature=0.0, max_tokens=65536)
+
+ # This should NOT raise ValueError about param count mismatch
+ outputs = omni.generate(
+ prompts=["hello"],
+ sampling_params_list=[sp_thinker, sp_talker, sp_code2wav],
+ )
+ assert isinstance(outputs, list)
+
+
+# ===================================================================
+# Tests: KV transfer params normalization
+# ===================================================================
+
+
+class TestNormalizeKVTransferParams:
+ def test_dict_passthrough(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ d = {"transfer_id": "test", "do_remote_decode": True}
+ assert omni._normalize_kv_transfer_params(d) is d
+
+ def test_none_returns_none(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ assert omni._normalize_kv_transfer_params(None) is None
+
+ def test_dataclass_to_dict(self, monkeypatch):
+ from dataclasses import dataclass
+
+ @dataclass
+ class FakeKVParams:
+ transfer_id: str = "test"
+ do_remote_decode: bool = True
+
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ result = omni._normalize_kv_transfer_params(FakeKVParams())
+ assert isinstance(result, dict)
+ assert result["transfer_id"] == "test"
+
+
+# ===================================================================
+# Tests: _kv_cfg_to_dict
+# ===================================================================
+
+
+class TestKvCfgToDict:
+ def test_dict_passthrough(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ d = {"kv_connector": "MooncakeConnector"}
+ assert omni._kv_cfg_to_dict(d) is d
+
+ def test_none_returns_empty(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ assert omni._kv_cfg_to_dict(None) == {}
+
+ def test_dataclass_converted(self, monkeypatch):
+ from dataclasses import dataclass
+
+ @dataclass
+ class FakeCfg:
+ kv_connector: str = "TestConnector"
+ kv_role: str = "kv_producer"
+
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ result = omni._kv_cfg_to_dict(FakeCfg())
+ assert result["kv_connector"] == "TestConnector"
+ assert result["kv_role"] == "kv_producer"
+
+
+# ===================================================================
+# Tests: PD routing in scheduling loop
+# ===================================================================
+
+
+@pytest.mark.xfail(reason="Requires migration to v1908 Orchestrator architecture (no stage_list / OmniStage)")
+class TestPDRouting:
+ """Test that the scheduling loop correctly routes requests from
+ prefill to decode stage with proper kv_transfer_params.
+ """
+
+ def test_prefill_stage_receives_max_tokens_1(self, monkeypatch):
+ """Stage 0 (prefill) should receive max_tokens=1."""
+ test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000002")
+
+ def _extra_setup(mp, omni_module):
+ mp.setattr(uuid, "uuid4", lambda: test_uuid)
+ mp.setattr(omni_module, "uuid", uuid)
+
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(stage_id=0),
+ _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
+ ],
+ extra_setup=_extra_setup,
+ )
+
+ expected_rid = f"0_{test_uuid}"
+
+ # Put stage outputs in both queues
+ omni.stage_list[0]._out_q.put_nowait(
+ {
+ "request_id": expected_rid,
+ "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1])])],
+ "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
+ }
+ )
+ omni.stage_list[1]._out_q.put_nowait(
+ {
+ "request_id": expected_rid,
+ "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2, 3])])],
+ "metrics": {"num_tokens_out": 3, "stage_gen_time_ms": 50.0},
+ }
+ )
+
+ sp_list = [SamplingParams(max_tokens=2048), SamplingParams(max_tokens=2048)]
+ omni.generate(prompts=["hello"], sampling_params_list=sp_list)
+
+ # Check what was submitted to stage 0's input queue
+ # (skip the stage_ready message first)
+ task = omni.stage_list[0]._in_q.get_nowait()
+ assert task["sampling_params"].max_tokens == 1
+ kv_params = task["sampling_params"].extra_args["kv_transfer_params"]
+ assert kv_params["do_remote_decode"] is True
+ assert kv_params["do_remote_prefill"] is False
+ assert kv_params["transfer_id"] == f"xfer-{expected_rid}"
+
+ def test_decode_stage_receives_original_prompt(self, monkeypatch):
+ """Decode stage should get the original prompt (not processed outputs)."""
+ test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000003")
+
+ def _extra_setup(mp, omni_module):
+ mp.setattr(uuid, "uuid4", lambda: test_uuid)
+ mp.setattr(omni_module, "uuid", uuid)
+
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(stage_id=0),
+ _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
+ ],
+ extra_setup=_extra_setup,
+ )
+
+ expected_rid = f"0_{test_uuid}"
+ original_prompt = "test prompt for PD"
+
+ omni.stage_list[0]._out_q.put_nowait(
+ {
+ "request_id": expected_rid,
+ "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1])])],
+ "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
+ }
+ )
+ omni.stage_list[1]._out_q.put_nowait(
+ {
+ "request_id": expected_rid,
+ "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2, 3])])],
+ "metrics": {"num_tokens_out": 3, "stage_gen_time_ms": 50.0},
+ }
+ )
+
+ sp_list = [SamplingParams(max_tokens=2048), SamplingParams(max_tokens=2048)]
+ omni.generate(prompts=[original_prompt], sampling_params_list=sp_list)
+
+ # Check what was forwarded to stage 1 (decode)
+ # The connector sends tasks to stage 1's input queue
+ task = omni.stage_list[1]._in_q.get_nowait()
+ # The engine_inputs should contain the original prompt
+ engine_inputs = task.get("engine_inputs")
+ # For PD routing, the original prompt is wrapped in a list
+ if isinstance(engine_inputs, list):
+ assert original_prompt in engine_inputs
+ else:
+ assert engine_inputs == original_prompt
+
+ def test_decode_kv_params_have_correct_flags(self, monkeypatch):
+ """Decode stage kv_transfer_params should have correct role flags."""
+ test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000004")
+
+ def _extra_setup(mp, omni_module):
+ mp.setattr(uuid, "uuid4", lambda: test_uuid)
+ mp.setattr(omni_module, "uuid", uuid)
+
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(stage_id=0),
+ _decode_stage_cfg(stage_id=1, engine_input_source=[0]),
+ ],
+ extra_setup=_extra_setup,
+ )
+
+ expected_rid = f"0_{test_uuid}"
+
+ omni.stage_list[0]._out_q.put_nowait(
+ {
+ "request_id": expected_rid,
+ "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1])])],
+ "metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
+ }
+ )
+ omni.stage_list[1]._out_q.put_nowait(
+ {
+ "request_id": expected_rid,
+ "engine_outputs": [_ns(request_id=expected_rid, outputs=[_ns(token_ids=[1, 2, 3])])],
+ "metrics": {"num_tokens_out": 3, "stage_gen_time_ms": 50.0},
+ }
+ )
+
+ sp_list = [SamplingParams(max_tokens=2048), SamplingParams(max_tokens=2048)]
+ omni.generate(prompts=["hello"], sampling_params_list=sp_list)
+
+ # Check decode task's kv_transfer_params
+ task = omni.stage_list[1]._in_q.get_nowait()
+ kv_params = task["sampling_params"].extra_args["kv_transfer_params"]
+ assert kv_params["do_remote_prefill"] is True
+ assert kv_params["do_remote_decode"] is False
+ assert kv_params["transfer_id"] == f"xfer-{expected_rid}"
+ assert kv_params["remote_bootstrap_addr"] == "127.0.0.1:25201"
+
+
+# ===================================================================
+# Tests: KV params cleanup
+# ===================================================================
+
+
+class TestKVParamsCleanup:
+ def test_drop_cleans_up(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ omni._pd_kv_params_by_req["req-1"] = {"transfer_id": "xfer-1"}
+ omni._drop_pd_kv_params("req-1")
+ assert "req-1" not in omni._pd_kv_params_by_req
+
+ def test_drop_nonexistent_is_noop(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ omni._drop_pd_kv_params("nonexistent") # should not raise
+
+ def test_pop_returns_stored_params(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ stored = {"transfer_id": "xfer-1", "extra_field": "value"}
+ omni._pd_kv_params_by_req["req-1"] = stored
+
+ result = omni._pop_pd_kv_params("req-1")
+ assert result == stored
+ assert "req-1" not in omni._pd_kv_params_by_req
+
+ def test_pop_uses_fallback_when_no_stored(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ fallback = {"transfer_id": "xfer-fallback"}
+ result = omni._pop_pd_kv_params("req-1", fallback=fallback)
+ assert result == fallback
+
+
+# ===================================================================
+# Tests: Config YAML loads without error
+# ===================================================================
+
+
+class TestPDYAMLConfig:
+ def test_pd_yaml_loads(self):
+ """The PD separation YAML config should load without errors."""
+ import os
+
+ yaml_path = os.path.join(
+ os.path.dirname(__file__),
+ "../../vllm_omni/model_executor/stage_configs/qwen3_omni_moe_pd_separation.yaml",
+ )
+ yaml_path = os.path.abspath(yaml_path)
+ if not os.path.exists(yaml_path):
+ pytest.skip("PD separation YAML not found")
+
+ from omegaconf import OmegaConf
+
+ cfg = OmegaConf.load(yaml_path)
+ stages = cfg.stage_args
+ assert len(stages) == 4
+
+ # Prefill stage
+ assert stages[0].is_prefill_only is True
+ assert stages[0].final_output is False
+ assert stages[0].is_comprehension is True
+
+ # Decode stage
+ assert stages[1].is_decode_only is True
+ assert stages[1].final_output is True
+ assert stages[1].final_output_type == "text"
+ assert stages[1].is_comprehension is True
+ assert 0 in stages[1].engine_input_source
+
+ # KV transfer configs
+ assert stages[0].engine_args.kv_transfer_config.kv_role == "kv_producer"
+ assert stages[1].engine_args.kv_transfer_config.kv_role == "kv_consumer"
+ assert stages[0].engine_args.kv_transfer_config.kv_connector == "MooncakeConnector"
+ assert stages[1].engine_args.kv_transfer_config.kv_connector == "MooncakeConnector"
+
+
+class TestPrefillStopNeutralization:
+ """Tests that _prepare_prefill_sampling_params neutralizes stop
+ conditions to ensure finish_reason='length'.
+ """
+
+ def test_clears_stop_strings(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048, stop=["", "STOP"])
+ result = omni._prepare_prefill_sampling_params("req-1", sp)
+ assert result.stop == []
+
+ def test_clears_stop_token_ids(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048, stop_token_ids=[151643, 151644])
+ result = omni._prepare_prefill_sampling_params("req-1", sp)
+ assert result.stop_token_ids == []
+
+ def test_clears_include_stop_str_in_output(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048, include_stop_str_in_output=True)
+ result = omni._prepare_prefill_sampling_params("req-1", sp)
+ assert result.include_stop_str_in_output is False
+
+ def test_original_sp_unchanged(self, monkeypatch):
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ sp = SamplingParams(max_tokens=2048, stop=[""], stop_token_ids=[151643])
+ _ = omni._prepare_prefill_sampling_params("req-1", sp)
+ assert sp.stop == [""]
+ assert sp.stop_token_ids == [151643]
+
+
+# ===================================================================
+# Tests: Failure mode & memory leak prevention
+# ===================================================================
+# NOTE: Full generate()-level failure mode tests are removed for now.
+# The _run_generation error handler (line 1344-1350 in omni.py) calls
+# _drop_pd_kv_params but does not increment completed_requests, causing
+# the while-loop to hang. These tests need to be revisited once the
+# production error-handling path is fixed to properly terminate on
+# stage errors.
+
+
+# ===================================================================
+# Tests: TP size validation
+# ===================================================================
+
+
+class TestTPSizeValidation:
+ """Tests that _validate_pd_separation_config checks tensor_parallel_size."""
+
+ def test_matching_tp_passes(self, monkeypatch):
+ """Same TP size should not raise."""
+ prefill_cfg = _prefill_stage_cfg()
+ prefill_cfg["engine_args"]["tensor_parallel_size"] = 2
+ decode_cfg = _decode_stage_cfg(engine_input_source=[0])
+ decode_cfg["engine_args"]["tensor_parallel_size"] = 2
+ omni = _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg])
+ assert omni._pd_separation_pair == (0, 1)
+
+ def test_mismatched_tp_raises(self, monkeypatch):
+ """Different TP sizes should raise ValueError."""
+ prefill_cfg = _prefill_stage_cfg()
+ prefill_cfg["engine_args"]["tensor_parallel_size"] = 2
+ decode_cfg = _decode_stage_cfg(engine_input_source=[0])
+ decode_cfg["engine_args"]["tensor_parallel_size"] = 4
+ with pytest.raises(ValueError, match="tensor_parallel_size"):
+ _make_pd_omni(monkeypatch, [prefill_cfg, decode_cfg])
+
+ def test_default_tp_no_error(self, monkeypatch):
+ """Stages without explicit TP (defaults to 1) should pass."""
+ omni = _make_pd_omni(
+ monkeypatch,
+ [
+ _prefill_stage_cfg(),
+ _decode_stage_cfg(engine_input_source=[0]),
+ ],
+ )
+ assert omni._pd_separation_pair == (0, 1)
diff --git a/tests/entrypoints/test_realtime_connection_helpers.py b/tests/entrypoints/test_realtime_connection_helpers.py
new file mode 100644
index 00000000000..e795aa92d0f
--- /dev/null
+++ b/tests/entrypoints/test_realtime_connection_helpers.py
@@ -0,0 +1,86 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for realtime streaming helpers (PR #2581 /v1/realtime path)."""
+
+from __future__ import annotations
+
+import base64
+
+import numpy as np
+import pytest
+import torch
+from vllm.sampling_params import RequestOutputKind, SamplingParams
+
+from vllm_omni.entrypoints.async_omni import AsyncOmni
+from vllm_omni.entrypoints.openai.realtime_connection import RealtimeConnection
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+@pytest.fixture
+def realtime_conn() -> RealtimeConnection:
+ return RealtimeConnection.__new__(RealtimeConnection)
+
+
+class TestRealtimeConnectionTensorAndPcm:
+ def test_tensor_to_numpy_none(self) -> None:
+ assert RealtimeConnection._tensor_to_numpy(None) is None
+
+ def test_tensor_to_numpy_1d_numpy(self) -> None:
+ arr = np.array([1.0, 2.0], dtype=np.float64)
+ out = RealtimeConnection._tensor_to_numpy(arr)
+ assert out is not None
+ assert out.dtype == np.float32
+ assert out.shape == (2,)
+
+ def test_tensor_to_numpy_2d_numpy_flattened(self) -> None:
+ arr = np.array([[0.5], [-0.5]], dtype=np.float32)
+ out = RealtimeConnection._tensor_to_numpy(arr)
+ assert out is not None
+ assert out.shape == (2,)
+
+ def test_tensor_to_numpy_torch(self) -> None:
+ t = torch.tensor([[0.25, -0.25]], dtype=torch.float32)
+ out = RealtimeConnection._tensor_to_numpy(t)
+ assert out is not None
+ assert out.shape == (2,)
+ np.testing.assert_allclose(out, [0.25, -0.25], rtol=1e-5)
+
+ def test_pcm16_b64_roundtrip(self) -> None:
+ audio = np.array([0.0, 1.0, -1.0], dtype=np.float32)
+ b64 = RealtimeConnection._pcm16_b64(audio)
+ raw = base64.b64decode(b64)
+ assert len(raw) == 6
+ pcm = np.frombuffer(raw, dtype=np.int16)
+ assert pcm[0] == 0
+ assert pcm[1] == 32767
+ assert pcm[2] == -32767
+
+
+class TestAsyncOmniStreamingParamsValidation:
+ def test_accepts_streaming_friendly_params(self) -> None:
+ p = SamplingParams(
+ n=1,
+ stop=[],
+ output_kind=RequestOutputKind.DELTA,
+ )
+ AsyncOmni._validate_streaming_input_sampling_params(p)
+
+ def test_rejects_non_sampling_params(self) -> None:
+ with pytest.raises(ValueError, match="Input streaming"):
+ AsyncOmni._validate_streaming_input_sampling_params(object()) # type: ignore[arg-type]
+
+ def test_rejects_n_greater_than_one(self) -> None:
+ p = SamplingParams(n=2, stop=[], output_kind=RequestOutputKind.DELTA)
+ with pytest.raises(ValueError, match="Input streaming"):
+ AsyncOmni._validate_streaming_input_sampling_params(p)
+
+ def test_rejects_final_only(self) -> None:
+ p = SamplingParams(n=1, stop=[], output_kind=RequestOutputKind.FINAL_ONLY)
+ with pytest.raises(ValueError, match="Input streaming"):
+ AsyncOmni._validate_streaming_input_sampling_params(p)
+
+ def test_rejects_stop_strings(self) -> None:
+ p = SamplingParams(n=1, stop=["\n"], output_kind=RequestOutputKind.DELTA)
+ with pytest.raises(ValueError, match="Input streaming"):
+ AsyncOmni._validate_streaming_input_sampling_params(p)
diff --git a/tests/entrypoints/test_serve.py b/tests/entrypoints/test_serve.py
index afa7fa82e4b..e60afc9cd7b 100644
--- a/tests/entrypoints/test_serve.py
+++ b/tests/entrypoints/test_serve.py
@@ -7,11 +7,33 @@
import pytest
from pytest_mock import MockerFixture
-from vllm_omni.entrypoints.cli.serve import run_headless
+from vllm_omni.entrypoints.cli.serve import OmniServeCommand, run_headless
+from vllm_omni.entrypoints.utils import detect_explicit_cli_keys
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+def test_serve_parser_accepts_no_async_chunk_and_marks_it_explicit() -> None:
+ """``--no-async-chunk`` should parse to ``async_chunk=False`` and mark the
+ shared deploy-level dest as explicitly provided by the user."""
+ try:
+ from vllm.utils.argparse_utils import FlexibleArgumentParser
+ except Exception as exc:
+ pytest.skip(f"Cannot build parser in this environment: {exc}")
+
+ root = FlexibleArgumentParser()
+ subparsers = root.add_subparsers(dest="subcommand")
+ cmd = OmniServeCommand()
+ serve_parser = cmd.subparser_init(subparsers)
+
+ argv = ["serve", "fake-model", "--omni", "--no-async-chunk"]
+ args = root.parse_args(argv)
+
+ assert args.async_chunk is False
+ explicit = detect_explicit_cli_keys(argv, serve_parser)
+ assert "async_chunk" in explicit
+
+
def _make_headless_args() -> argparse.Namespace:
return argparse.Namespace(
model="fake-model",
diff --git a/tests/examples/conftest.py b/tests/examples/conftest.py
index 137d15f163f..867731b21f9 100644
--- a/tests/examples/conftest.py
+++ b/tests/examples/conftest.py
@@ -1,353 +1,3 @@
-"""
-Shared fixtures, helpers, and path constants for tests/examples/.
-"""
+"""Pytest fixtures for tests/examples."""
-import json
-import os
-import re
-import shlex
-import subprocess
-import sys
-import tempfile
-from collections import defaultdict
-from collections.abc import Callable
-from pathlib import Path
-from typing import Any, NamedTuple, cast
-
-import pytest
-import torch
-from safetensors.torch import save_file
-
-# ---------------------------------------------------------------------------
-# Path constants and fixtures
-# ---------------------------------------------------------------------------
-
-REPO_ROOT = Path(__file__).resolve().parents[2]
-EXAMPLES = REPO_ROOT / "examples"
-
-# Use Python tempfile instead of pytest's tmp_path_factory because
-# OUTPUT_DIR is needed in test collection time, but tmp_path_factory is only available in test running time.
-# It is needed during test collection because extract_readme_snippets replaces LoRA path with a generated one under OUTPUT_DIR,
-# and extract_readme_snippets is called at collection time to generate separate test cases for each README code block.
-OUTPUT_DIR = (
- REPO_ROOT / prefix
- if (prefix := os.environ.get("OUTPUT_DIR"))
- else Path(tempfile.mkdtemp(prefix="vllm_omni_test_examples_"))
-)
-
-
-# ---------------------------------------------------------------------------
-# Code snippet extraction and asset file helpers
-# ---------------------------------------------------------------------------
-
-# parameters: language, code, h2_title
-ReadmeSnippetExtractionSkipPredicate = Callable[[str, str, str], tuple[bool, str]]
-
-
-class ReadmeSnippet(NamedTuple):
- language: str
- code: str
- h2_title: str
- index_in_section: int
- output_file_path: Path | None = None
- skip: tuple[bool, str] = (False, "")
-
- @property
- def test_id(self) -> str:
- return f"{ReadmeSnippet._slug(self.h2_title)}_{self.index_in_section:03d}"
-
- @staticmethod
- def extract_readme_snippets(
- readme_path: Path,
- skipif: ReadmeSnippetExtractionSkipPredicate | None = None,
- ) -> list["ReadmeSnippet"]:
- import mistune
-
- markdown = mistune.create_markdown(renderer="ast")
- tokens = markdown(readme_path.read_text(encoding="utf-8"))
- tokens = cast(list[dict[str, Any]], tokens) # mistune's AST renderer always produces a list, not a str
-
- h2_title = ""
- section_counts: defaultdict[str, int] = defaultdict(int)
- snippets: list[ReadmeSnippet] = []
-
- for token in tokens:
- token_type = token.get("type")
-
- if token_type == "heading":
- level = (token.get("attrs") or {}).get("level")
- title = ReadmeSnippet._heading_text(token)
- if level == 2:
- h2_title = title
- continue
-
- if token_type != "block_code":
- continue
-
- try:
- info = token.get("attrs").get("info") # type: ignore[reportOptionalMemberAccess]
- language = info.strip().split()[0].lower() # type: ignore[reportOptionalMemberAccess]
-
- # Common shell aliases to "bash" in several markdown renderers.
- if language in {"shell", "sh", "ksh", "zsh"}:
- language = "bash"
-
- if language not in {"bash", "python"}:
- continue
- except AttributeError:
- # The fence is missing explicit language info; skip it.
- continue
-
- key = h2_title
- section_counts[key] += 1
- code = token.get("raw", "")
- output_file_path = None
- if language == "bash":
- argv = ReadmeSnippet._normalize_bash_command(code, Path(readme_path.parent))
- code = shlex.join(argv)
- output_file_path = ReadmeSnippet._output_file_path_from_argv(argv)
- if skipif is not None:
- skip_config = skipif(language, code, h2_title)
- else:
- skip_config = (False, "")
- snippet = ReadmeSnippet(
- language=language,
- code=code,
- h2_title=h2_title,
- index_in_section=section_counts[key],
- output_file_path=output_file_path,
- skip=skip_config,
- )
- snippets.append(snippet)
-
- return snippets
-
- @staticmethod
- def _normalize_bash_command(command: str, readme_dir: Path) -> list[str]:
- line_joined_command = re.sub(r"\\\s*\n", " ", command).strip()
- argv = shlex.split(line_joined_command, comments=True)
- assert argv, "README bash fence produced an empty command"
-
- # Normalize python directory and example script location
- if argv[0] in {"python", "python3"}:
- argv[0] = sys.executable
- if len(argv) > 1 and argv[1].endswith(".py"):
- script_arg = argv[1]
- script_path = Path(script_arg)
- if script_path.is_absolute():
- resolved_script = script_path
- else:
- # Take the file name only, and append script_dir to its front
- resolved_script = readme_dir / script_path.name
- assert resolved_script.exists(), (
- f"README bash snippet references a script that does not exist: {script_arg} (resolved to {resolved_script})"
- )
- argv[1] = str(resolved_script)
-
- # Normalize LoRA adapter path and ensure README LoRA assets exist.
- try:
- lora_arg_idx = argv.index("--lora-path") # Raise ValueError if not found
- assert len(argv) > lora_arg_idx + 1, "README bash snippet uses --lora-path without a following value"
-
- lora_dir = OUTPUT_DIR / "lora"
- adapter_model = lora_dir / "adapter_model.safetensors"
- adapter_config = lora_dir / "adapter_config.json"
- if not adapter_model.exists() or not adapter_config.exists():
- write_zimage_lora(lora_dir, v_scale=8.0)
-
- argv[lora_arg_idx + 1] = str(lora_dir)
- except ValueError:
- pass
-
- return argv
-
- @staticmethod
- def _output_file_path_from_argv(argv: list[str]) -> Path | None:
- if "--output" not in argv:
- return None
- output_param_idx = argv.index("--output")
- assert len(argv) > output_param_idx + 1, "README bash snippet uses --output without a following value"
- output_arg = argv[output_param_idx + 1]
- return Path(output_arg)
-
- @staticmethod
- def _slug(text: str) -> str:
- return "".join(ch.lower() if ch.isalnum() else "_" for ch in text).strip("_")
-
- @staticmethod
- def _heading_text(token: dict) -> str:
- return "".join(child.get("raw", "") for child in token.get("children", [])).strip()
-
-
-# [TODO] Duplicate `_write_zimage_lora` in tests/e2e/online_serving/test_images_generations_lora.py. Combine these helpers and tests/e2e/offline_inference/test_diffusion_lora.py to test/utils later
-def write_zimage_lora(adapter_dir: Path, *, q_scale: float = 0.0, k_scale: float = 0.0, v_scale: float = 0.0):
- adapter_dir.mkdir(parents=True, exist_ok=True)
-
- # Z-Image transformer uses dim=3840 by default.
- dim = 3840
- module_name = "transformer.layers.0.attention.to_qkv"
- rank = 1
-
- lora_a = torch.zeros((rank, dim), dtype=torch.float32)
- lora_a[0, 0] = 1.0
-
- # QKVParallelLinear packs (Q, K, V) => out dim is 3 * dim (tp=1).
- lora_b = torch.zeros((3 * dim, rank), dtype=torch.float32)
- if q_scale:
- lora_b[:dim, 0] = q_scale
- if k_scale:
- lora_b[dim : 2 * dim, 0] = k_scale
- if v_scale:
- lora_b[2 * dim :, 0] = v_scale
-
- save_file(
- {
- f"base_model.model.{module_name}.lora_A.weight": lora_a,
- f"base_model.model.{module_name}.lora_B.weight": lora_b,
- },
- str(adapter_dir / "adapter_model.safetensors"),
- )
- (adapter_dir / "adapter_config.json").write_text(
- json.dumps(
- {
- "r": rank,
- "lora_alpha": rank,
- "target_modules": [module_name],
- }
- ),
- encoding="utf-8",
- )
-
-
-# ---------------------------------------------------------------------------
-# Code runner and subprocess helpers
-# ---------------------------------------------------------------------------
-
-
-class ExampleRunResult(NamedTuple):
- run_dir: Path
- assets: list[Path]
-
-
-class ExampleRunner:
- """Run extracted README snippets and return generated assets.
-
- The output materials are organized in a three-level directory structure:
- - Set at init: `self.output_root` for all tests (from env OUTPUT_DIR)
- - Set at `self.run(...)`: `output_subfolder` for a specific example page (e.g., `example_offline_t2i`)
- - Generated by `extract_readme_snippets`: `snippet.test_id` for a specific code block (matching H2 titles, e.g., `basic_usage_001`)
- """
-
- IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp"}
-
- def __init__(self, output_root: Path) -> None:
- self.output_root = output_root
-
- def run(
- self, snippet: ReadmeSnippet, *, output_subfolder: Path = Path("."), env: dict[str, str] | None = None
- ) -> ExampleRunResult:
- run_dir = self.output_root / output_subfolder / snippet.test_id
- run_dir.mkdir(parents=True, exist_ok=True)
-
- if snippet.language == "python":
- assets = self._run_python_snippet(snippet, run_dir, env)
- return ExampleRunResult(run_dir=run_dir, assets=assets)
-
- if snippet.language == "bash":
- asset = self._run_bash_snippet(snippet, run_dir, env)
- return ExampleRunResult(run_dir=run_dir, assets=[asset])
-
- raise AssertionError(f"Unsupported snippet language: {snippet.language}")
-
- def _run_python_snippet(
- self, snippet: ReadmeSnippet, run_dir: Path, env: dict[str, str] | None = None
- ) -> list[Path]:
- # Saving the script to a temporary file and `run_cmd` it.
- # Not using `exec(snippet.code)` because the output is lost.
- script_path = run_dir / "snippet.py"
- script_path.write_text(snippet.code, encoding="utf-8")
-
- before = self._collect_images(run_dir)
- run_cmd([sys.executable, str(script_path)], cwd=run_dir, env=env)
- after = self._collect_images(run_dir)
-
- assets = sorted(after - before)
- return assets
-
- def _run_bash_snippet(self, snippet: ReadmeSnippet, run_dir: Path, env: dict[str, str] | None = None) -> Path:
- run_cmd(snippet.code, shell=True, cwd=run_dir, env=env)
-
- assert snippet.output_file_path is not None, (
- f"README bash snippet is missing --output argument: {snippet.test_id}. "
- "The test script cannot guess the output file path."
- )
-
- # If the code snippet declares a relative path for the output file, append this path to the parent output collection directory.
- # If the code snippet declares an absolute path (not likely but just in case), the return value resolution removes `run_dir`, also correctly pointing to this file.
- return run_dir / snippet.output_file_path
-
- def _collect_images(self, root: Path) -> set[Path]:
- return {path for path in root.rglob("*") if path.suffix.lower() in self.IMAGE_SUFFIXES}
-
-
-@pytest.fixture
-def example_runner() -> ExampleRunner:
- return ExampleRunner(output_root=OUTPUT_DIR)
-
-
-def run_cmd(
- command: list[str] | str,
- *,
- shell: bool = False,
- env: dict[str, str] | None = None,
- cwd: Path | str | None = None,
-) -> str:
- """Run a command as a subprocess; assert zero exit code and return stdout.
-
- Output is fully captured and returned as a string so callers can parse it
- (e.g. with :func:`extract_content_after_keyword`).
- Use this for scripts whose printed output is part of the test assertion.
- """
- if env is not None:
- env = {**os.environ.copy(), **env}
- result = subprocess.run(command, capture_output=True, text=True, shell=shell, env=env, cwd=cwd)
-
- if result.returncode != 0:
- print(f"STDERR: {result.stderr}")
- raise subprocess.CalledProcessError(result.returncode, command)
-
- all_output = result.stdout
- print(f"All output:\n{all_output}")
- return all_output
-
-
-# ---------------------------------------------------------------------------
-# Output validation helpers
-# ---------------------------------------------------------------------------
-
-
-def extract_content_after_keyword(keywords: str, text: str) -> str:
- """Return the text that follows *keywords* in *text* (regex match).
-
- Raises ``AssertionError`` if the keyword is not found, so test failures
- produce a clear message pointing at the missing keyword.
- """
- matches = re.findall(rf"{keywords}\s*(.+)", text, re.DOTALL)
-
- if not matches:
- raise AssertionError(f"Keywords {keywords} not found in provided text output")
- return matches[0]
-
-
-def strip_trailing_audio_saved_line(text: str) -> str:
- """Drop trailing ``Audio saved to ...`` lines from captured client stdout.
-
- ``openai_chat_completion_client_for_multimodal_generation.py`` may print
- ``Chat completion output from text:`` for one choice and ``Audio saved to``
- for another; :func:`extract_content_after_keyword` uses ``re.DOTALL`` and
- would otherwise keep the audio progress line inside the *text* segment.
- """
- lines = text.splitlines()
- while lines and lines[-1].strip().startswith("Audio saved to"):
- lines.pop()
- return "\n".join(lines).strip()
+from tests.examples.helpers import example_runner # noqa: F401
diff --git a/tests/examples/helpers.py b/tests/examples/helpers.py
new file mode 100644
index 00000000000..137d15f163f
--- /dev/null
+++ b/tests/examples/helpers.py
@@ -0,0 +1,353 @@
+"""
+Shared fixtures, helpers, and path constants for tests/examples/.
+"""
+
+import json
+import os
+import re
+import shlex
+import subprocess
+import sys
+import tempfile
+from collections import defaultdict
+from collections.abc import Callable
+from pathlib import Path
+from typing import Any, NamedTuple, cast
+
+import pytest
+import torch
+from safetensors.torch import save_file
+
+# ---------------------------------------------------------------------------
+# Path constants and fixtures
+# ---------------------------------------------------------------------------
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
+EXAMPLES = REPO_ROOT / "examples"
+
+# Use Python tempfile instead of pytest's tmp_path_factory because
+# OUTPUT_DIR is needed in test collection time, but tmp_path_factory is only available in test running time.
+# It is needed during test collection because extract_readme_snippets replaces LoRA path with a generated one under OUTPUT_DIR,
+# and extract_readme_snippets is called at collection time to generate separate test cases for each README code block.
+OUTPUT_DIR = (
+ REPO_ROOT / prefix
+ if (prefix := os.environ.get("OUTPUT_DIR"))
+ else Path(tempfile.mkdtemp(prefix="vllm_omni_test_examples_"))
+)
+
+
+# ---------------------------------------------------------------------------
+# Code snippet extraction and asset file helpers
+# ---------------------------------------------------------------------------
+
+# parameters: language, code, h2_title
+ReadmeSnippetExtractionSkipPredicate = Callable[[str, str, str], tuple[bool, str]]
+
+
+class ReadmeSnippet(NamedTuple):
+ language: str
+ code: str
+ h2_title: str
+ index_in_section: int
+ output_file_path: Path | None = None
+ skip: tuple[bool, str] = (False, "")
+
+ @property
+ def test_id(self) -> str:
+ return f"{ReadmeSnippet._slug(self.h2_title)}_{self.index_in_section:03d}"
+
+ @staticmethod
+ def extract_readme_snippets(
+ readme_path: Path,
+ skipif: ReadmeSnippetExtractionSkipPredicate | None = None,
+ ) -> list["ReadmeSnippet"]:
+ import mistune
+
+ markdown = mistune.create_markdown(renderer="ast")
+ tokens = markdown(readme_path.read_text(encoding="utf-8"))
+ tokens = cast(list[dict[str, Any]], tokens) # mistune's AST renderer always produces a list, not a str
+
+ h2_title = ""
+ section_counts: defaultdict[str, int] = defaultdict(int)
+ snippets: list[ReadmeSnippet] = []
+
+ for token in tokens:
+ token_type = token.get("type")
+
+ if token_type == "heading":
+ level = (token.get("attrs") or {}).get("level")
+ title = ReadmeSnippet._heading_text(token)
+ if level == 2:
+ h2_title = title
+ continue
+
+ if token_type != "block_code":
+ continue
+
+ try:
+ info = token.get("attrs").get("info") # type: ignore[reportOptionalMemberAccess]
+ language = info.strip().split()[0].lower() # type: ignore[reportOptionalMemberAccess]
+
+ # Common shell aliases to "bash" in several markdown renderers.
+ if language in {"shell", "sh", "ksh", "zsh"}:
+ language = "bash"
+
+ if language not in {"bash", "python"}:
+ continue
+ except AttributeError:
+ # The fence is missing explicit language info; skip it.
+ continue
+
+ key = h2_title
+ section_counts[key] += 1
+ code = token.get("raw", "")
+ output_file_path = None
+ if language == "bash":
+ argv = ReadmeSnippet._normalize_bash_command(code, Path(readme_path.parent))
+ code = shlex.join(argv)
+ output_file_path = ReadmeSnippet._output_file_path_from_argv(argv)
+ if skipif is not None:
+ skip_config = skipif(language, code, h2_title)
+ else:
+ skip_config = (False, "")
+ snippet = ReadmeSnippet(
+ language=language,
+ code=code,
+ h2_title=h2_title,
+ index_in_section=section_counts[key],
+ output_file_path=output_file_path,
+ skip=skip_config,
+ )
+ snippets.append(snippet)
+
+ return snippets
+
+ @staticmethod
+ def _normalize_bash_command(command: str, readme_dir: Path) -> list[str]:
+ line_joined_command = re.sub(r"\\\s*\n", " ", command).strip()
+ argv = shlex.split(line_joined_command, comments=True)
+ assert argv, "README bash fence produced an empty command"
+
+ # Normalize python directory and example script location
+ if argv[0] in {"python", "python3"}:
+ argv[0] = sys.executable
+ if len(argv) > 1 and argv[1].endswith(".py"):
+ script_arg = argv[1]
+ script_path = Path(script_arg)
+ if script_path.is_absolute():
+ resolved_script = script_path
+ else:
+ # Take the file name only, and append script_dir to its front
+ resolved_script = readme_dir / script_path.name
+ assert resolved_script.exists(), (
+ f"README bash snippet references a script that does not exist: {script_arg} (resolved to {resolved_script})"
+ )
+ argv[1] = str(resolved_script)
+
+ # Normalize LoRA adapter path and ensure README LoRA assets exist.
+ try:
+ lora_arg_idx = argv.index("--lora-path") # Raise ValueError if not found
+ assert len(argv) > lora_arg_idx + 1, "README bash snippet uses --lora-path without a following value"
+
+ lora_dir = OUTPUT_DIR / "lora"
+ adapter_model = lora_dir / "adapter_model.safetensors"
+ adapter_config = lora_dir / "adapter_config.json"
+ if not adapter_model.exists() or not adapter_config.exists():
+ write_zimage_lora(lora_dir, v_scale=8.0)
+
+ argv[lora_arg_idx + 1] = str(lora_dir)
+ except ValueError:
+ pass
+
+ return argv
+
+ @staticmethod
+ def _output_file_path_from_argv(argv: list[str]) -> Path | None:
+ if "--output" not in argv:
+ return None
+ output_param_idx = argv.index("--output")
+ assert len(argv) > output_param_idx + 1, "README bash snippet uses --output without a following value"
+ output_arg = argv[output_param_idx + 1]
+ return Path(output_arg)
+
+ @staticmethod
+ def _slug(text: str) -> str:
+ return "".join(ch.lower() if ch.isalnum() else "_" for ch in text).strip("_")
+
+ @staticmethod
+ def _heading_text(token: dict) -> str:
+ return "".join(child.get("raw", "") for child in token.get("children", [])).strip()
+
+
+# [TODO] Duplicate `_write_zimage_lora` in tests/e2e/online_serving/test_images_generations_lora.py. Combine these helpers and tests/e2e/offline_inference/test_diffusion_lora.py to test/utils later
+def write_zimage_lora(adapter_dir: Path, *, q_scale: float = 0.0, k_scale: float = 0.0, v_scale: float = 0.0):
+ adapter_dir.mkdir(parents=True, exist_ok=True)
+
+ # Z-Image transformer uses dim=3840 by default.
+ dim = 3840
+ module_name = "transformer.layers.0.attention.to_qkv"
+ rank = 1
+
+ lora_a = torch.zeros((rank, dim), dtype=torch.float32)
+ lora_a[0, 0] = 1.0
+
+ # QKVParallelLinear packs (Q, K, V) => out dim is 3 * dim (tp=1).
+ lora_b = torch.zeros((3 * dim, rank), dtype=torch.float32)
+ if q_scale:
+ lora_b[:dim, 0] = q_scale
+ if k_scale:
+ lora_b[dim : 2 * dim, 0] = k_scale
+ if v_scale:
+ lora_b[2 * dim :, 0] = v_scale
+
+ save_file(
+ {
+ f"base_model.model.{module_name}.lora_A.weight": lora_a,
+ f"base_model.model.{module_name}.lora_B.weight": lora_b,
+ },
+ str(adapter_dir / "adapter_model.safetensors"),
+ )
+ (adapter_dir / "adapter_config.json").write_text(
+ json.dumps(
+ {
+ "r": rank,
+ "lora_alpha": rank,
+ "target_modules": [module_name],
+ }
+ ),
+ encoding="utf-8",
+ )
+
+
+# ---------------------------------------------------------------------------
+# Code runner and subprocess helpers
+# ---------------------------------------------------------------------------
+
+
+class ExampleRunResult(NamedTuple):
+ run_dir: Path
+ assets: list[Path]
+
+
+class ExampleRunner:
+ """Run extracted README snippets and return generated assets.
+
+ The output materials are organized in a three-level directory structure:
+ - Set at init: `self.output_root` for all tests (from env OUTPUT_DIR)
+ - Set at `self.run(...)`: `output_subfolder` for a specific example page (e.g., `example_offline_t2i`)
+ - Generated by `extract_readme_snippets`: `snippet.test_id` for a specific code block (matching H2 titles, e.g., `basic_usage_001`)
+ """
+
+ IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp"}
+
+ def __init__(self, output_root: Path) -> None:
+ self.output_root = output_root
+
+ def run(
+ self, snippet: ReadmeSnippet, *, output_subfolder: Path = Path("."), env: dict[str, str] | None = None
+ ) -> ExampleRunResult:
+ run_dir = self.output_root / output_subfolder / snippet.test_id
+ run_dir.mkdir(parents=True, exist_ok=True)
+
+ if snippet.language == "python":
+ assets = self._run_python_snippet(snippet, run_dir, env)
+ return ExampleRunResult(run_dir=run_dir, assets=assets)
+
+ if snippet.language == "bash":
+ asset = self._run_bash_snippet(snippet, run_dir, env)
+ return ExampleRunResult(run_dir=run_dir, assets=[asset])
+
+ raise AssertionError(f"Unsupported snippet language: {snippet.language}")
+
+ def _run_python_snippet(
+ self, snippet: ReadmeSnippet, run_dir: Path, env: dict[str, str] | None = None
+ ) -> list[Path]:
+ # Saving the script to a temporary file and `run_cmd` it.
+ # Not using `exec(snippet.code)` because the output is lost.
+ script_path = run_dir / "snippet.py"
+ script_path.write_text(snippet.code, encoding="utf-8")
+
+ before = self._collect_images(run_dir)
+ run_cmd([sys.executable, str(script_path)], cwd=run_dir, env=env)
+ after = self._collect_images(run_dir)
+
+ assets = sorted(after - before)
+ return assets
+
+ def _run_bash_snippet(self, snippet: ReadmeSnippet, run_dir: Path, env: dict[str, str] | None = None) -> Path:
+ run_cmd(snippet.code, shell=True, cwd=run_dir, env=env)
+
+ assert snippet.output_file_path is not None, (
+ f"README bash snippet is missing --output argument: {snippet.test_id}. "
+ "The test script cannot guess the output file path."
+ )
+
+ # If the code snippet declares a relative path for the output file, append this path to the parent output collection directory.
+ # If the code snippet declares an absolute path (not likely but just in case), the return value resolution removes `run_dir`, also correctly pointing to this file.
+ return run_dir / snippet.output_file_path
+
+ def _collect_images(self, root: Path) -> set[Path]:
+ return {path for path in root.rglob("*") if path.suffix.lower() in self.IMAGE_SUFFIXES}
+
+
+@pytest.fixture
+def example_runner() -> ExampleRunner:
+ return ExampleRunner(output_root=OUTPUT_DIR)
+
+
+def run_cmd(
+ command: list[str] | str,
+ *,
+ shell: bool = False,
+ env: dict[str, str] | None = None,
+ cwd: Path | str | None = None,
+) -> str:
+ """Run a command as a subprocess; assert zero exit code and return stdout.
+
+ Output is fully captured and returned as a string so callers can parse it
+ (e.g. with :func:`extract_content_after_keyword`).
+ Use this for scripts whose printed output is part of the test assertion.
+ """
+ if env is not None:
+ env = {**os.environ.copy(), **env}
+ result = subprocess.run(command, capture_output=True, text=True, shell=shell, env=env, cwd=cwd)
+
+ if result.returncode != 0:
+ print(f"STDERR: {result.stderr}")
+ raise subprocess.CalledProcessError(result.returncode, command)
+
+ all_output = result.stdout
+ print(f"All output:\n{all_output}")
+ return all_output
+
+
+# ---------------------------------------------------------------------------
+# Output validation helpers
+# ---------------------------------------------------------------------------
+
+
+def extract_content_after_keyword(keywords: str, text: str) -> str:
+ """Return the text that follows *keywords* in *text* (regex match).
+
+ Raises ``AssertionError`` if the keyword is not found, so test failures
+ produce a clear message pointing at the missing keyword.
+ """
+ matches = re.findall(rf"{keywords}\s*(.+)", text, re.DOTALL)
+
+ if not matches:
+ raise AssertionError(f"Keywords {keywords} not found in provided text output")
+ return matches[0]
+
+
+def strip_trailing_audio_saved_line(text: str) -> str:
+ """Drop trailing ``Audio saved to ...`` lines from captured client stdout.
+
+ ``openai_chat_completion_client_for_multimodal_generation.py`` may print
+ ``Chat completion output from text:`` for one choice and ``Audio saved to``
+ for another; :func:`extract_content_after_keyword` uses ``re.DOTALL`` and
+ would otherwise keep the audio progress line inside the *text* segment.
+ """
+ lines = text.splitlines()
+ while lines and lines[-1].strip().startswith("Audio saved to"):
+ lines.pop()
+ return "\n".join(lines).strip()
diff --git a/tests/examples/offline_inference/test_text_to_image.py b/tests/examples/offline_inference/test_text_to_image.py
index a08d16f1614..f24506587c1 100644
--- a/tests/examples/offline_inference/test_text_to_image.py
+++ b/tests/examples/offline_inference/test_text_to_image.py
@@ -7,9 +7,9 @@
import pytest
-from tests.conftest import assert_image_valid
-from tests.examples.conftest import EXAMPLES, ExampleRunner, ReadmeSnippet
-from tests.utils import hardware_marks
+from tests.examples.helpers import EXAMPLES, ExampleRunner, ReadmeSnippet
+from tests.helpers.assertions import assert_image_valid
+from tests.helpers.mark import hardware_marks
pytestmark = [pytest.mark.advanced_model, pytest.mark.example, *hardware_marks(res={"cuda": "H100"})]
diff --git a/tests/examples/online_serving/test_qwen2_5_omni.py b/tests/examples/online_serving/test_qwen2_5_omni.py
index a78ccf5924a..b8acfef84ba 100644
--- a/tests/examples/online_serving/test_qwen2_5_omni.py
+++ b/tests/examples/online_serving/test_qwen2_5_omni.py
@@ -5,33 +5,29 @@
import os
-from vllm_omni.platforms import current_omni_platform
-
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
from pathlib import Path
import pytest
-from tests.conftest import OmniServerParams, convert_audio_file_to_text, cosine_similarity_text
-from tests.examples.conftest import (
+from tests.examples.helpers import (
extract_content_after_keyword,
run_cmd,
strip_trailing_audio_saved_line,
)
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.media import convert_audio_file_to_text, cosine_similarity_text
+from tests.helpers.runtime import OmniServerParams
+from tests.helpers.stage_config import get_deploy_config_path
pytestmark = [pytest.mark.advanced_model, pytest.mark.example]
models = ["Qwen/Qwen2.5-Omni-7B"]
-
-stage_configs = [str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "qwen2_5_omni_ci.yaml")]
-
-if current_omni_platform.is_xpu():
- stage_configs = [
- str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "xpu" / "qwen2_5_omni_ci.yaml")
- ]
+# Single CI deploy YAML; rocm/xpu deltas are picked automatically via the
+# platforms: section in vllm_omni/deploy/ci/qwen2_5_omni.yaml.
+stage_configs = [get_deploy_config_path("ci/qwen2_5_omni.yaml")]
example_dir = str(Path(__file__).parent.parent.parent.parent / "examples" / "online_serving")
# Create parameter combinations for model and stage config
diff --git a/tests/examples/online_serving/test_qwen3_omni.py b/tests/examples/online_serving/test_qwen3_omni.py
index 65f99d7bf28..e2c3b33849b 100644
--- a/tests/examples/online_serving/test_qwen3_omni.py
+++ b/tests/examples/online_serving/test_qwen3_omni.py
@@ -5,31 +5,28 @@
import os
-from vllm_omni.platforms import current_omni_platform
-
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
from pathlib import Path
import pytest
-from tests.conftest import OmniServerParams, convert_audio_file_to_text, cosine_similarity_text
-from tests.examples.conftest import (
+from tests.examples.helpers import (
extract_content_after_keyword,
run_cmd,
strip_trailing_audio_saved_line,
)
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
+from tests.helpers.media import convert_audio_file_to_text, cosine_similarity_text
+from tests.helpers.runtime import OmniServerParams
+from tests.helpers.stage_config import get_deploy_config_path
pytestmark = [pytest.mark.advanced_model, pytest.mark.example]
models = ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]
-stage_configs = [str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "qwen3_omni_ci.yaml")]
-
-if current_omni_platform.is_xpu():
- stage_configs = [str(Path(__file__).parent.parent.parent / "e2e" / "stage_configs" / "xpu" / "qwen3_omni_ci.yaml")]
+stage_configs = [get_deploy_config_path("ci/qwen3_omni_moe.yaml")]
example_dir = str(Path(__file__).parent.parent.parent.parent / "examples" / "online_serving")
diff --git a/tests/examples/online_serving/test_text_to_image.py b/tests/examples/online_serving/test_text_to_image.py
index 51b7ff61bc9..6f89cb5496b 100644
--- a/tests/examples/online_serving/test_text_to_image.py
+++ b/tests/examples/online_serving/test_text_to_image.py
@@ -13,9 +13,10 @@
import pytest
-from tests.conftest import OmniServer, OmniServerParams, assert_image_valid
-from tests.examples.conftest import EXAMPLES, OUTPUT_DIR, run_cmd, write_zimage_lora
-from tests.utils import hardware_marks
+from tests.examples.helpers import EXAMPLES, OUTPUT_DIR, run_cmd, write_zimage_lora
+from tests.helpers.assertions import assert_image_valid
+from tests.helpers.mark import hardware_marks
+from tests.helpers.runtime import OmniServer, OmniServerParams
pytestmark = [pytest.mark.advanced_model, pytest.mark.example, *hardware_marks(res={"cuda": "H100"})]
diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py
new file mode 100644
index 00000000000..a3348b07fe0
--- /dev/null
+++ b/tests/helpers/__init__.py
@@ -0,0 +1,8 @@
+"""Shared, importable test helper utilities.
+
+Submodules (``assertions``, ``env``, ``media``, ``runtime``, …) are imported
+explicitly by callers. Avoid star-importing everything here: that ran before
+refactor only inside the old monolithic ``conftest``; a greedy ``__init__``
+changes import order and can affect in-process Omni (``OmniRunner`` / offline
+e2e) vs subprocess-based ``OmniServer`` tests.
+"""
diff --git a/tests/helpers/assertions.py b/tests/helpers/assertions.py
new file mode 100644
index 00000000000..533806477e5
--- /dev/null
+++ b/tests/helpers/assertions.py
@@ -0,0 +1,506 @@
+"""Assertion and response validation helpers for tests."""
+
+import io
+import tempfile
+import threading
+from io import BytesIO
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import soundfile as sf
+from PIL import Image
+
+from tests.helpers.media import cosine_similarity_text
+
+_GENDER_PIPELINE = None
+_GENDER_PIPELINE_LOCK = threading.Lock()
+_PCM_SPEECH_SAMPLE_RATE_HZ = 24_000
+_MIN_PCM_SPEECH_HNR_DB = 1.0
+_PRESET_VOICE_GENDER_MAP: dict[str, str] = {
+ "serena": "female",
+ "uncle_fu": "male",
+ "chelsie": "female",
+ "clone": "female",
+ "ethan": "male",
+}
+
+
+def assert_image_diffusion_response(
+ response,
+ request_config: dict[str, Any],
+ run_level: str = None,
+) -> None:
+ """
+ Validate image diffusion response.
+
+ Expected request_config schema:
+ {
+ "request_type": "image",
+ "extra_body": {
+ "num_outputs_per_prompt": 1,
+ "width": ...,
+ "height": ...,
+ ...
+ }
+ }
+ """
+ assert response.images is not None, "Image response is None"
+ assert len(response.images) > 0, "No images in response"
+
+ extra_body = request_config.get("extra_body") or {}
+
+ num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt")
+ if num_outputs_per_prompt is not None:
+ assert len(response.images) == num_outputs_per_prompt, (
+ f"Expected {num_outputs_per_prompt} images, got {len(response.images)}"
+ )
+
+ if run_level == "advanced_model":
+ width = extra_body.get("width")
+ height = extra_body.get("height")
+
+ if width is not None or height is not None:
+ for img in response.images:
+ assert_image_valid(img, width=width, height=height)
+
+
+def assert_video_diffusion_response(
+ response,
+ request_config: dict[str, Any],
+ run_level: str = None,
+) -> None:
+ """
+ Validate video diffusion response.
+
+ Expected request_config schema:
+ {
+ "request_type": "video",
+ "form_data": {
+ "prompt": "...",
+ "num_frames": ...,
+ "width": ...,
+ "height": ...,
+ "fps": ...,
+ ...
+ }
+ }
+ """
+ form_data = request_config.get("form_data", {})
+
+ assert response.videos is not None, "Video response is None"
+ assert len(response.videos) > 0, "No videos in response"
+
+ expected_frames = _maybe_int(form_data.get("num_frames"))
+ expected_width = _maybe_int(form_data.get("width"))
+ expected_height = _maybe_int(form_data.get("height"))
+ expected_fps = _maybe_int(form_data.get("fps"))
+
+ for vid_bytes in response.videos:
+ assert_video_valid(
+ vid_bytes,
+ num_frames=expected_frames,
+ width=expected_width,
+ height=expected_height,
+ fps=expected_fps,
+ )
+
+
+def assert_audio_diffusion_response(
+ response,
+ request_config: dict[str, Any],
+ run_level: str = None,
+) -> None:
+ """
+ Validate audio diffusion response.
+ """
+ raise NotImplementedError("Audio validation is not implemented yet")
+
+
+def _maybe_int(value: Any) -> int | None:
+ if value is None:
+ return None
+ return int(value)
+
+
+def assert_image_valid(image: Path | Image.Image, *, width: int | None = None, height: int | None = None):
+ """Assert the file is a loadable image with optional exact dimensions."""
+ if isinstance(image, Path):
+ assert image.exists(), f"Image not found: {image}"
+ image = Image.open(image)
+ image.load()
+ assert image.width > 0 and image.height > 0
+ if width is not None:
+ assert image.width == width, f"Expected width={width}, got {image.width}"
+ if height is not None:
+ assert image.height == height, f"Expected height={height}, got {image.height}"
+ return image
+
+
+def assert_video_valid(
+ video: Path | bytes | BytesIO,
+ *,
+ num_frames: int | None = None,
+ width: int | None = None,
+ height: int | None = None,
+ fps: float | None = None,
+) -> dict[str, int | float]:
+ """Assert the MP4 has the expected resolution and frame count.
+
+ For several diffusion backends, encoded MP4 frame count follows a codec-aligned
+ convention (e.g. request `num_frames=8` can produce 9 encoded frames). Keep
+ this compatibility behavior to avoid false negatives in online-serving tests.
+ """
+ temp_path = None
+ cap = None
+ try:
+ import cv2
+
+ if isinstance(video, Path):
+ if not video.exists():
+ raise AssertionError(f"Video file not found: {video}")
+ video_path = str(video)
+ else:
+ suffix = ".mp4"
+ with tempfile.NamedTemporaryFile(delete=False, suffix=suffix, mode="wb") as tmp:
+ if isinstance(video, bytes):
+ tmp.write(video)
+ elif isinstance(video, BytesIO):
+ tmp.write(video.getvalue())
+ else:
+ raise TypeError(f"Unsupported video type: {type(video)}")
+ temp_path = Path(tmp.name)
+ video_path = str(temp_path)
+
+ cap = cv2.VideoCapture(video_path)
+ if not cap.isOpened():
+ raise AssertionError("Failed to open video capture")
+
+ actual_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ actual_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ actual_fps = float(cap.get(cv2.CAP_PROP_FPS))
+ actual_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+ if width is not None:
+ assert actual_width == width, f"Expected width={width}, got {actual_width}"
+ if height is not None:
+ assert actual_height == height, f"Expected height={height}, got {actual_height}"
+ if fps is not None and actual_fps:
+ assert abs(actual_fps - float(fps)) < 1.0, f"Expected fps~={fps}, got {actual_fps}"
+ if num_frames is not None:
+ expected_frames = (int(num_frames) // 4) * 4 + 1
+ assert actual_frames == expected_frames, f"Expected frames={expected_frames}, got {actual_frames}"
+
+ return {
+ "width": actual_width,
+ "height": actual_height,
+ "fps": actual_fps,
+ "num_frames": actual_frames,
+ }
+ except Exception as e:
+ print(f"ERROR: {type(e).__name__}: {e}", flush=True)
+ raise
+ finally:
+ if cap is not None:
+ cap.release()
+ if temp_path and temp_path.exists():
+ try:
+ temp_path.unlink()
+ except OSError:
+ pass
+
+
+def assert_audio_valid(
+ audio_or_path: Path | np.ndarray,
+ *,
+ sample_rate: int,
+ channels: int,
+ duration_s: float,
+) -> None:
+ """Assert WAV file or (batch, channels, samples) ndarray matches expected audio format."""
+ expected_samples = int(duration_s * sample_rate)
+ if isinstance(audio_or_path, np.ndarray):
+ audio = audio_or_path
+ assert audio.ndim == 3, f"Expected audio ndim=3 (batch, channels, samples), got shape {audio.shape}"
+ assert audio.shape[0] == 1, f"Expected batch size 1, got {audio.shape[0]}"
+ assert audio.shape[1] == channels, f"Expected {channels} channels, got {audio.shape[1]}"
+ assert audio.shape[2] == expected_samples, (
+ f"Expected {expected_samples} samples ({duration_s}s @ {sample_rate} Hz), got {audio.shape[2]}"
+ )
+ return
+
+ path = audio_or_path
+ assert path.exists(), f"Audio not found: {path}"
+ info = sf.info(str(path))
+ assert info.samplerate == sample_rate, f"Expected sample_rate={sample_rate}, got {info.samplerate}"
+ assert info.channels == channels, f"Expected {channels} channel(s), got {info.channels}"
+ assert info.frames == expected_samples, (
+ f"Expected {expected_samples} frames ({duration_s}s @ {sample_rate} Hz), got {info.frames}"
+ )
+
+
+def _load_gender_pipeline():
+ global _GENDER_PIPELINE
+ if _GENDER_PIPELINE is not None:
+ return _GENDER_PIPELINE
+ model_name = "7wolf/wav2vec2-base-gender-classification"
+ try:
+ from transformers import pipeline
+
+ _GENDER_PIPELINE = pipeline(task="audio-classification", model=model_name, device=-1)
+ return _GENDER_PIPELINE
+ except Exception as exc: # pragma: no cover
+ print(f"Warning: failed to create gender pipeline '{model_name}': {exc}")
+ _GENDER_PIPELINE = None
+ return None
+
+
+def _median_pitch_hz_from_autocorr(mono: np.ndarray, sr: int) -> float | None:
+ x = np.asarray(mono, dtype=np.float64)
+ x = x - np.mean(x)
+ if x.size < int(0.15 * sr):
+ return None
+ frame_len = int(0.04 * sr)
+ hop = max(frame_len // 2, 1)
+ f0_min_hz, f0_max_hz = 70.0, 400.0
+ lag_min = max(1, int(sr / f0_max_hz))
+ lag_max = min(frame_len - 2, int(sr / f0_min_hz))
+ if lag_max <= lag_min:
+ return None
+ win = np.hamming(frame_len)
+ pitches: list[float] = []
+ for start in range(0, int(x.shape[0]) - frame_len, hop):
+ frame = x[start : start + frame_len] * win
+ frame = frame - np.mean(frame)
+ if float(np.sqrt(np.mean(frame**2))) < 1e-4:
+ continue
+ ac = np.correlate(frame, frame, mode="full")[frame_len - 1 :]
+ ac = ac / (float(ac[0]) + 1e-12)
+ region = ac[lag_min : lag_max + 1]
+ peak_rel = int(np.argmax(region))
+ peak_lag = peak_rel + lag_min
+ if peak_lag <= 0:
+ continue
+ f0 = float(sr) / float(peak_lag)
+ if f0_min_hz <= f0 <= f0_max_hz:
+ pitches.append(f0)
+ if len(pitches) < 4:
+ return None
+ return float(np.median(np.asarray(pitches, dtype=np.float64)))
+
+
+def _estimate_voice_gender_from_audio(audio_bytes: bytes) -> str:
+ data, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32", always_2d=True)
+ if data.size == 0:
+ raise ValueError("Empty audio")
+ mono = np.mean(data, axis=1)
+ try:
+ target_sr = 16000
+ if int(sr) != target_sr and mono.size > 1:
+ src_len = int(mono.shape[0])
+ dst_len = max(1, int(round(src_len * float(target_sr) / float(sr))))
+ src_idx = np.arange(src_len, dtype=np.float32)
+ dst_idx = np.linspace(0, src_len - 1, dst_len, dtype=np.float32)
+ mono = np.interp(dst_idx, src_idx, mono.astype(np.float32, copy=False)).astype(np.float32)
+ sr = target_sr
+
+ median_f0 = _median_pitch_hz_from_autocorr(mono, sr)
+ clf = _load_gender_pipeline()
+ if clf is None:
+ print("gender model not available, returning 'unknown'")
+ return "unknown"
+ with _GENDER_PIPELINE_LOCK:
+ outputs = clf(mono, sampling_rate=sr)
+ if not outputs:
+ return "unknown"
+ top = outputs[0]
+ label = str(top.get("label", "")).lower()
+ conf = float(top.get("score", 0.0))
+ if conf < 0.5:
+ gender = "unknown"
+ elif ("female" in label) or ("жен" in label):
+ gender = "female"
+ elif ("male" in label) or ("муж" in label):
+ gender = "male"
+ else:
+ gender = "unknown"
+
+ if gender == "female" and median_f0 is not None and median_f0 < 165.0 and conf < 0.88:
+ print(f"gender pitch assist: reclassifying female->male (median_f0={median_f0:.1f} Hz, conf={conf:.3f})")
+ gender = "male"
+ elif gender == "male" and median_f0 is not None and median_f0 > 230.0 and conf < 0.88:
+ print(f"gender pitch assist: reclassifying male->female (median_f0={median_f0:.1f} Hz, conf={conf:.3f})")
+ gender = "female"
+ print(
+ f"gender classifier: label={label}, conf={conf:.3f}, gender={gender}"
+ + (f", median_f0={median_f0:.1f}Hz" if median_f0 is not None else "")
+ )
+ return gender
+ except Exception as exc: # pragma: no cover
+ print(f"Warning: gender classification failed, returning 'unknown': {exc}")
+ return "unknown"
+
+
+def _assert_preset_voice_gender_from_audio(audio_bytes: bytes | None, voice_name: str | None) -> None:
+ """If ``voice_name`` matches a known preset, assert classifier gender matches (skip when unknown)."""
+ if not voice_name or not audio_bytes:
+ return
+ key = str(voice_name).lower()
+ expected_gender = _PRESET_VOICE_GENDER_MAP.get(key)
+ if expected_gender is None:
+ return
+ estimated_gender = _estimate_voice_gender_from_audio(audio_bytes)
+ print(f"Preset voice gender check: preset={key!r}, estimated={estimated_gender!r}, expected={expected_gender!r}")
+ if estimated_gender != "unknown":
+ assert estimated_gender == expected_gender, (
+ f"{voice_name!r} is expected {expected_gender}, but estimated gender is {estimated_gender!r}"
+ )
+
+
+def _compute_pcm_hnr_db(pcm_samples: np.ndarray, sr: int = _PCM_SPEECH_SAMPLE_RATE_HZ) -> float:
+ frame_len = int(0.03 * sr)
+ hop = frame_len // 2
+ hnr_values: list[float] = []
+ for start in range(0, len(pcm_samples) - frame_len, hop):
+ frame = pcm_samples[start : start + frame_len].astype(np.float32, copy=False)
+ frame = frame - np.mean(frame)
+ if np.max(np.abs(frame)) < 0.01:
+ continue
+ ac = np.correlate(frame, frame, mode="full")[len(frame) - 1 :]
+ ac = ac / (ac[0] + 1e-10)
+ min_lag = int(sr / 400)
+ max_lag = min(int(sr / 80), len(ac))
+ if min_lag >= max_lag:
+ continue
+ peak = float(np.max(ac[min_lag:max_lag]))
+ if 0 < peak < 1:
+ hnr_values.append(10 * np.log10(peak / (1 - peak + 1e-10)))
+ return float(np.mean(hnr_values)) if hnr_values else 0.0
+
+
+def _assert_pcm_int16_speech_hnr(audio_bytes: bytes) -> None:
+ assert audio_bytes is not None and len(audio_bytes) >= 2, "missing PCM bytes"
+ assert len(audio_bytes) % 2 == 0, "PCM byte length must be aligned to int16"
+ pcm_samples = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
+ hnr = _compute_pcm_hnr_db(pcm_samples)
+ print(f"PCM speech HNR: {hnr:.2f} dB (threshold: {_MIN_PCM_SPEECH_HNR_DB} dB)")
+ assert hnr >= _MIN_PCM_SPEECH_HNR_DB, (
+ f"Audio distortion detected: HNR={hnr:.2f} dB < {_MIN_PCM_SPEECH_HNR_DB} dB. "
+ "Voice clone decoder may be losing ref_code speaker context on later chunks."
+ )
+
+
+def assert_omni_response(response: Any, request_config: dict[str, Any], run_level):
+ """
+ Validate response results.
+
+ Args:
+ response: OmniResponse object
+
+ Raises:
+ AssertionError: When the response does not meet validation criteria
+ """
+ assert response.success, "The request failed."
+ e2e_latency = response.e2e_latency
+ if e2e_latency is not None:
+ print(f"the e2e latency is: {e2e_latency}")
+
+ modalities = request_config.get("modalities", ["text", "audio"])
+
+ if run_level == "advanced_model":
+ if "audio" in modalities:
+ assert response.audio_content is not None, "No audio output is generated"
+ print(f"audio content is: {response.audio_content}")
+ speaker = request_config.get("speaker")
+ if speaker:
+ _assert_preset_voice_gender_from_audio(
+ response.audio_bytes,
+ speaker,
+ )
+
+ if "text" in modalities:
+ assert response.text_content is not None, "No text output is generated"
+ print(f"text content is: {response.text_content}")
+
+ # Verify image description
+ word_types = ["text", "image", "audio", "video"]
+ keywords_dict = request_config.get("key_words", {})
+ for word_type in word_types:
+ keywords = keywords_dict.get(word_type)
+ if "text" in modalities:
+ if keywords:
+ text_lower = response.text_content.lower()
+ assert any(str(kw).lower() in text_lower for kw in keywords), (
+ "The output does not contain any of the keywords."
+ )
+ else:
+ if keywords:
+ audio_lower = response.audio_content.lower()
+ assert any(str(kw).lower() in audio_lower for kw in keywords), (
+ "The output does not contain any of the keywords."
+ )
+
+ # Verify similarity (Whisper transcript vs streamed/detokenized text)
+ if "text" in modalities and "audio" in modalities:
+ assert response.similarity is not None and response.similarity > 0.9, (
+ "The audio content is not same as the text"
+ )
+ print(f"similarity is: {response.similarity}")
+
+
+def assert_audio_speech_response(response: Any, request_config: dict[str, Any], run_level: str) -> None:
+ assert response.success, "The request failed."
+ e2e_latency = getattr(response, "e2e_latency", None)
+ if e2e_latency is not None:
+ print(f"the avg e2e latency is: {e2e_latency}")
+
+ req_fmt = request_config.get("response_format")
+ if req_fmt == "pcm" and response.audio_bytes:
+ _assert_pcm_int16_speech_hnr(response.audio_bytes)
+ if response.audio_format:
+ assert "pcm" in response.audio_format.lower(), (
+ f"Expected audio/pcm content-type, got {response.audio_format!r}"
+ )
+ elif req_fmt == "wav" and response.audio_format:
+ assert req_fmt in response.audio_format
+
+ if run_level == "advanced_model" and req_fmt != "pcm":
+ expected_text = request_config.get("input")
+ if expected_text:
+ transcript = (response.audio_content or "").strip()
+ print(f"audio content is: {transcript}")
+ print(f"input text is: {expected_text}")
+ similarity = cosine_similarity_text(transcript.lower(), expected_text.lower())
+ print(f"Cosine similarity: {similarity:.3f}")
+ assert similarity > 0.9, (
+ f"Transcript doesn't match input: similarity={similarity:.2f}, transcript='{transcript}'"
+ )
+ _assert_preset_voice_gender_from_audio(response.audio_bytes, request_config.get("voice"))
+
+
+def assert_diffusion_response(response: Any, request_config: dict[str, Any], run_level: str = None):
+ assert response.success, "The request failed."
+ e2e_latency = getattr(response, "e2e_latency", None)
+ if e2e_latency is not None:
+ print(f"the avg e2e is: {e2e_latency}")
+ has_any_content = any(content is not None for content in (response.images, response.videos, response.audios))
+ assert has_any_content, "Response contains no images, videos, or audios"
+ if response.images is not None:
+ assert_image_diffusion_response(response=response, request_config=request_config, run_level=run_level)
+ if response.videos is not None:
+ assert_video_diffusion_response(response=response, request_config=request_config, run_level=run_level)
+ if response.audios is not None:
+ assert_audio_diffusion_response(response=response, request_config=request_config, run_level=run_level)
+
+
+__all__ = [
+ "assert_audio_diffusion_response",
+ "assert_audio_speech_response",
+ "assert_diffusion_response",
+ "assert_image_diffusion_response",
+ "assert_image_valid",
+ "assert_omni_response",
+ "assert_video_diffusion_response",
+ "assert_video_valid",
+ "assert_audio_valid",
+]
diff --git a/tests/helpers/env.py b/tests/helpers/env.py
new file mode 100644
index 00000000000..22ec9a78626
--- /dev/null
+++ b/tests/helpers/env.py
@@ -0,0 +1,280 @@
+"""Test environment / lifecycle helpers (GPU cleanup hooks and memory monitoring for tests).
+
+``vllm.platforms`` / ``vllm_omni.platforms`` are imported only inside functions that need them
+so importing this module at pytest plugin load does not run before session autouse fixtures
+"""
+
+from __future__ import annotations
+
+import gc
+import os
+import subprocess
+import threading
+import time
+from contextlib import contextmanager
+
+import torch
+
+
+def run_forced_gpu_cleanup_round() -> None:
+ run_pre_test_cleanup(enable_force=True)
+ run_post_test_cleanup(enable_force=True)
+
+
+def get_physical_device_indices(devices):
+ visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+ if visible_devices is None:
+ return devices
+ visible_indices = [int(x) for x in visible_devices.split(",")]
+ index_mapping = {i: physical for i, physical in enumerate(visible_indices)}
+ return [index_mapping[i] for i in devices if i in index_mapping]
+
+
+def wait_for_gpu_memory_to_clear(
+ *,
+ devices: list[int],
+ threshold_bytes: int | None = None,
+ threshold_ratio: float | None = None,
+ timeout_s: float = 120,
+) -> None:
+ from vllm.platforms import current_platform
+
+ assert threshold_bytes is not None or threshold_ratio is not None
+ devices = get_physical_device_indices(devices)
+ start_time = time.time()
+
+ device_list = ", ".join(str(d) for d in devices)
+ if threshold_bytes is not None:
+ threshold_str = f"{threshold_bytes / 2**30:.2f} GiB"
+ condition_str = f"Memory usage ≤ {threshold_str}"
+ else:
+ threshold_percent = threshold_ratio * 100
+ threshold_str = f"{threshold_percent:.1f}%"
+ condition_str = f"Memory usage ratio ≤ {threshold_str}"
+
+ print(f"[GPU Memory Monitor] Waiting for GPU {device_list} to free memory, Condition: {condition_str}")
+
+ if threshold_bytes is not None:
+
+ def is_free(used, total):
+ return used <= threshold_bytes / 2**30
+ else:
+
+ def is_free(used, total):
+ return used / total <= threshold_ratio
+
+ @contextmanager
+ def nvml_scope():
+ if current_platform.is_rocm():
+ from amdsmi import amdsmi_init, amdsmi_shut_down
+
+ amdsmi_init()
+ try:
+ yield
+ finally:
+ amdsmi_shut_down()
+ elif current_platform.is_cuda():
+ from vllm.third_party.pynvml import nvmlInit, nvmlShutdown
+
+ nvmlInit()
+ try:
+ yield
+ finally:
+ nvmlShutdown()
+ else:
+ yield
+
+ is_rocm = current_platform.is_rocm()
+
+ with nvml_scope():
+ if is_rocm:
+ from amdsmi import amdsmi_get_gpu_vram_usage, amdsmi_get_processor_handles
+ elif current_platform.is_cuda():
+ from vllm.third_party.pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
+
+ while True:
+ output: dict[int, str] = {}
+ output_raw: dict[int, tuple[float, float]] = {}
+ for device in devices:
+ if is_rocm:
+ dev_handle = amdsmi_get_processor_handles()[device]
+ mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
+ gb_used = mem_info["vram_used"] / 2**10
+ gb_total = mem_info["vram_total"] / 2**10
+ else:
+ dev_handle = nvmlDeviceGetHandleByIndex(device)
+ mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
+ gb_used = mem_info.used / 2**30
+ gb_total = mem_info.total / 2**30
+ output_raw[device] = (gb_used, gb_total)
+ usage_percent = (gb_used / gb_total) * 100 if gb_total > 0 else 0
+ output[device] = f"{gb_used:.1f}GiB/{gb_total:.1f}GiB ({usage_percent:.1f}%)"
+
+ print("[GPU Memory Status] Current usage:")
+ for device_id, mem_info in output.items():
+ print(f" GPU {device_id}: {mem_info}")
+
+ dur_s = time.time() - start_time
+ elapsed_minutes = dur_s / 60
+ if all(is_free(used, total) for used, total in output_raw.values()):
+ print(f"[GPU Memory Freed] Devices {device_list} meet memory condition")
+ print(f" Condition: {condition_str}")
+ print(f" Wait time: {dur_s:.1f} seconds ({elapsed_minutes:.1f} minutes)")
+ break
+
+ if dur_s >= timeout_s:
+ raise ValueError(
+ f"[GPU Memory Timeout] Devices {device_list} still don't meet memory condition after {dur_s:.1f} seconds\n"
+ f"Condition: {condition_str}\n"
+ f"Current status:\n" + "\n".join(f" GPU {device}: {output[device]}" for device in devices)
+ )
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ time.sleep(5)
+
+
+def _print_gpu_processes() -> None:
+ """Print GPU information including nvidia-smi and system processes."""
+
+ print("\n" + "=" * 80)
+ print("NVIDIA GPU Information (nvidia-smi)")
+ print("=" * 80)
+
+ try:
+ nvidia_result = subprocess.run(
+ ["nvidia-smi"],
+ capture_output=True,
+ text=True,
+ timeout=5,
+ )
+
+ if nvidia_result.returncode == 0:
+ lines = nvidia_result.stdout.strip().split("\n")
+ for line in lines[:20]:
+ print(line)
+
+ if len(lines) > 20:
+ print(f"... (showing first 20 of {len(lines)} lines)")
+ else:
+ print("nvidia-smi command failed")
+
+ except (subprocess.TimeoutExpired, FileNotFoundError):
+ print("nvidia-smi not available or timed out")
+ except Exception as e:
+ print(f"Error running nvidia-smi: {e}")
+
+ print("\n" + "=" * 80)
+ print("Detailed GPU Processes (nvidia-smi pmon)")
+ print("=" * 80)
+
+ try:
+ pmon_result = subprocess.run(
+ ["nvidia-smi", "pmon", "-c", "1"],
+ capture_output=True,
+ text=True,
+ timeout=3,
+ )
+
+ if pmon_result.returncode == 0 and pmon_result.stdout.strip():
+ print(pmon_result.stdout)
+ else:
+ print("No active GPU processes found via nvidia-smi pmon")
+
+ except Exception:
+ print("nvidia-smi pmon not available")
+
+ print("\n" + "=" * 80)
+ print("System Processes with GPU keywords")
+ print("=" * 80)
+
+
+_SKIPPED_GPU_CLEANUP_MSG = (
+ "\nSkipping GPU memory cleanup check (typically: instance already up; no check needed between tests)\n"
+)
+
+
+def run_pre_test_cleanup(enable_force: bool = False) -> None:
+ if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force:
+ print(_SKIPPED_GPU_CLEANUP_MSG)
+ return
+
+ print("Pre-test GPU status:")
+
+ num_gpus = torch.cuda.device_count()
+ if num_gpus > 0:
+ try:
+ wait_for_gpu_memory_to_clear(
+ devices=list(range(num_gpus)),
+ threshold_ratio=0.05,
+ )
+ except Exception as e:
+ print(f"Pre-test cleanup note: {e}")
+
+
+def run_post_test_cleanup(enable_force: bool = False) -> None:
+ if os.getenv("VLLM_TEST_CLEAN_GPU_MEMORY", "0") != "1" and not enable_force:
+ print(_SKIPPED_GPU_CLEANUP_MSG)
+ return
+
+ if torch.cuda.is_available():
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ print("Post-test GPU status:")
+ _print_gpu_processes()
+
+
+class DeviceMemoryMonitor:
+ """Poll global device memory usage."""
+
+ def __init__(self, device_index: int, interval: float = 0.05):
+ self.device_index = device_index
+ self.interval = interval
+ self._peak_used_mb = 0.0
+ self._stop_event = threading.Event()
+ self._thread: threading.Thread | None = None
+
+ def start(self) -> None:
+ from vllm_omni.platforms import current_omni_platform
+
+ def monitor_loop() -> None:
+ while not self._stop_event.is_set():
+ try:
+ with current_omni_platform.device(self.device_index):
+ free_bytes, total_bytes = current_omni_platform.mem_get_info()
+ used_mb = (total_bytes - free_bytes) / (1024**2)
+ self._peak_used_mb = max(self._peak_used_mb, used_mb)
+ except Exception:
+ pass
+ time.sleep(self.interval)
+
+ self._thread = threading.Thread(target=monitor_loop, daemon=False)
+ self._thread.start()
+
+ def stop(self) -> None:
+ if self._thread is None:
+ return
+ self._stop_event.set()
+ self._thread.join(timeout=2.0)
+
+ @property
+ def peak_used_mb(self) -> float:
+ from vllm_omni.platforms import current_omni_platform
+
+ fallback_alloc = current_omni_platform.max_memory_allocated(device=self.device_index) / (1024**2)
+ fallback_reserved = current_omni_platform.max_memory_reserved(device=self.device_index) / (1024**2)
+ return max(self._peak_used_mb, fallback_alloc, fallback_reserved)
+
+ def __del__(self):
+ self.stop()
+
+
+__all__ = [
+ "DeviceMemoryMonitor",
+ "get_physical_device_indices",
+ "run_post_test_cleanup",
+ "run_pre_test_cleanup",
+ "run_forced_gpu_cleanup_round",
+ "wait_for_gpu_memory_to_clear",
+]
diff --git a/tests/helpers/fixtures/__init__.py b/tests/helpers/fixtures/__init__.py
new file mode 100644
index 00000000000..8bd090b7824
--- /dev/null
+++ b/tests/helpers/fixtures/__init__.py
@@ -0,0 +1 @@
+"""Pytest fixture modules under tests.helpers."""
diff --git a/tests/helpers/fixtures/env.py b/tests/helpers/fixtures/env.py
new file mode 100644
index 00000000000..939bad02ca4
--- /dev/null
+++ b/tests/helpers/fixtures/env.py
@@ -0,0 +1,59 @@
+import os
+
+import pytest
+import torch
+
+
+@pytest.fixture(scope="session", autouse=True)
+def default_env():
+ # Keep behavior but avoid import-time side effects (RFC #2299).
+ keys = ("VLLM_WORKER_MULTIPROC_METHOD", "VLLM_TARGET_DEVICE")
+ previous = {key: os.environ.get(key) for key in keys}
+ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = previous["VLLM_WORKER_MULTIPROC_METHOD"] or "spawn"
+ os.environ["VLLM_TARGET_DEVICE"] = previous["VLLM_TARGET_DEVICE"] or (
+ "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu"
+ )
+ yield
+ for key, value in previous.items():
+ if value is None:
+ os.environ.pop(key, None)
+ else:
+ os.environ[key] = value
+
+
+@pytest.fixture(scope="session")
+def model_prefix() -> str:
+ prefix = os.environ.get("MODEL_PREFIX", "")
+ return f"{prefix.rstrip('/')}/" if prefix else ""
+
+
+@pytest.fixture(autouse=True)
+def clean_gpu_memory_between_tests():
+ # Import here so ``tests.helpers.env`` (and vLLM platform modules) load only
+ # after session autouse fixtures like ``default_env`` have run (RFC #2299).
+ from tests.helpers.env import run_post_test_cleanup, run_pre_test_cleanup
+
+ print("\n=== PRE-TEST GPU CLEANUP ===")
+ run_pre_test_cleanup()
+ yield
+ run_post_test_cleanup()
+
+
+@pytest.fixture(scope="session", autouse=True)
+def default_vllm_config():
+ """Set a default VllmConfig for the whole test session.
+
+ Session scope ensures module-scoped fixtures (e.g. ``omni_runner``) and
+ deferred imports of ``tests.helpers.runtime`` both see the same context.
+ Function-scoped autouse ran too late for ``OmniRunner`` setup and could
+ desynchronize vLLM init vs request preprocessing (e.g. renderer state).
+ """
+ from vllm.config import DeviceConfig, VllmConfig, set_current_vllm_config
+
+ # Use CPU device if no GPU is available (e.g., in CI environments)
+ has_gpu = torch.cuda.is_available() and torch.cuda.device_count() > 0
+ device = "cuda" if has_gpu else "cpu"
+ device_config = DeviceConfig(device=device)
+
+ with set_current_vllm_config(VllmConfig(device_config=device_config)):
+ yield
diff --git a/tests/helpers/fixtures/log.py b/tests/helpers/fixtures/log.py
new file mode 100644
index 00000000000..798fa4ae6c7
--- /dev/null
+++ b/tests/helpers/fixtures/log.py
@@ -0,0 +1,7 @@
+import pytest
+
+
+@pytest.fixture(autouse=True)
+def log_test_name_before_test(request: pytest.FixtureRequest):
+ print(f"--- Running test: {request.node.name}")
+ yield
diff --git a/tests/helpers/fixtures/run_args.py b/tests/helpers/fixtures/run_args.py
new file mode 100644
index 00000000000..b18a64b9810
--- /dev/null
+++ b/tests/helpers/fixtures/run_args.py
@@ -0,0 +1,17 @@
+import pytest
+
+
+def pytest_addoption(parser):
+ parser.addoption(
+ "--run-level",
+ action="store",
+ default="core_model",
+ choices=["core_model", "advanced_model"],
+ help="Test level to run: L2, L3",
+ )
+
+
+@pytest.fixture(scope="session")
+def run_level(request) -> str:
+ """Session test level from ``--run-level`` (see CI five-level docs)."""
+ return request.config.getoption("--run-level")
diff --git a/tests/helpers/fixtures/runtime.py b/tests/helpers/fixtures/runtime.py
new file mode 100644
index 00000000000..77a468f2f34
--- /dev/null
+++ b/tests/helpers/fixtures/runtime.py
@@ -0,0 +1,137 @@
+"""Runtime fixtures (OmniRunner / OmniServer). Imports are deferred to fixture time.
+
+Loading ``tests.helpers.runtime`` at plugin import time (before session fixtures)
+pulls in vLLM/vllm_omni too early and breaks initialization order vs the legacy
+monolithic conftest. Defer imports until fixtures run so ``default_env`` /
+``default_vllm_config`` run first.
+"""
+
+from __future__ import annotations
+
+import threading
+from collections.abc import Generator
+from typing import Any
+
+import pytest
+import yaml
+
+from tests.helpers.runtime import OmniServer
+from tests.helpers.stage_config import modify_stage_config
+
+omni_fixture_lock = threading.Lock()
+
+
+@pytest.fixture(scope="module")
+def omni_server(request: pytest.FixtureRequest, run_level: str, model_prefix: str) -> Generator[OmniServer, Any, None]:
+ """Start vLLM-Omni through the standard or stage-CLI launcher.
+
+ The fixture stays module-scoped because multi-stage initialization is costly.
+ The ``use_stage_cli`` flag on ``OmniServerParams`` routes the setup through the
+ stage-CLI harness while still reusing the same fixture grouping semantics.
+ """
+ with omni_fixture_lock:
+ from tests.helpers.runtime import OmniServer, OmniServerParams, OmniServerStageCli
+
+ params: OmniServerParams = request.param
+ model = model_prefix + params.model
+ port = params.port
+ stage_config_path = params.stage_config_path
+ if run_level == "advanced_model" and stage_config_path is not None:
+ with open(stage_config_path, encoding="utf-8") as f:
+ cfg = yaml.safe_load(f) or {}
+ # Strip ``load_format: dummy`` (CI overlay default) so advanced_model
+ # tests use real weights. New schema (``stages:``) writes the field
+ # flat at stage level; legacy schema (``stage_args:``) nests it as
+ # ``engine_args.load_format``. Handle both.
+ new_schema_stages = cfg.get("stages")
+ stage_key = "stages" if new_schema_stages is not None else "stage_args"
+ delete_path = "load_format" if new_schema_stages is not None else "engine_args.load_format"
+ stage_entries = cfg.get(stage_key, [])
+ stage_ids = [stage["stage_id"] for stage in stage_entries if "stage_id" in stage]
+ stage_config_path = modify_stage_config(
+ stage_config_path,
+ deletes={stage_key: {stage_id: [delete_path] for stage_id in stage_ids}},
+ )
+
+ server_args = params.server_args or []
+ if params.use_omni and params.stage_init_timeout is not None:
+ server_args = [*server_args, "--stage-init-timeout", str(params.stage_init_timeout)]
+ else:
+ server_args = [*server_args, "--stage-init-timeout", "600"]
+ if params.init_timeout is not None:
+ server_args = [*server_args, "--init-timeout", str(params.init_timeout)]
+ else:
+ server_args = [*server_args, "--init-timeout", "900"]
+ if params.use_stage_cli:
+ if not params.use_omni:
+ raise ValueError("omni_server with use_stage_cli=True requires use_omni=True")
+ if stage_config_path is None:
+ raise ValueError("omni_server with use_stage_cli=True requires a stage_config_path")
+ server_args += ["--stage-configs-path", stage_config_path]
+
+ with OmniServerStageCli(
+ model,
+ stage_config_path,
+ server_args,
+ port=port,
+ env_dict=params.env_dict,
+ ) as server:
+ print("OmniServer started successfully")
+ yield server
+ print("OmniServer stopping...")
+ else:
+ if stage_config_path is not None:
+ server_args += ["--stage-configs-path", stage_config_path]
+
+ with (
+ OmniServer(
+ model,
+ server_args,
+ port=port,
+ env_dict=params.env_dict,
+ use_omni=params.use_omni,
+ )
+ if port
+ else OmniServer(
+ model,
+ server_args,
+ env_dict=params.env_dict,
+ use_omni=params.use_omni,
+ )
+ ) as server:
+ print("OmniServer started successfully")
+ yield server
+ print("OmniServer stopping...")
+
+ print("OmniServer stopped")
+
+
+@pytest.fixture
+def openai_client(request: pytest.FixtureRequest, run_level: str):
+ """Resolve ``omni_server`` lazily so parametrized server fixtures work like upstream."""
+ from tests.helpers.runtime import OpenAIClientHandler
+
+ server = request.getfixturevalue("omni_server")
+ return OpenAIClientHandler(host=server.host, port=server.port, api_key="EMPTY", run_level=run_level)
+
+
+@pytest.fixture(scope="module")
+def omni_runner(request: pytest.FixtureRequest, model_prefix: str):
+ from tests.helpers.runtime import OmniRunner
+
+ with omni_fixture_lock:
+ model, stage_config_path = request.param
+ model = model_prefix + model
+ with OmniRunner(model, seed=42, stage_configs_path=stage_config_path) as runner:
+ print("OmniRunner started successfully")
+ yield runner
+ print("OmniRunner stopping...")
+
+ print("OmniRunner stopped")
+
+
+@pytest.fixture
+def omni_runner_handler(omni_runner: Any):
+ from tests.helpers.runtime import OmniRunnerHandler
+
+ return OmniRunnerHandler(omni_runner)
diff --git a/tests/helpers/mark.py b/tests/helpers/mark.py
new file mode 100644
index 00000000000..ed45dd7e9a1
--- /dev/null
+++ b/tests/helpers/mark.py
@@ -0,0 +1,135 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Pytest marks and decorators for hardware / resource selection (CUDA, ROCm, …)."""
+
+import pytest
+from vllm.utils.torch_utils import cuda_device_count_stateless
+
+# Re-exported from tests.helpers.env (GPU wait + DeviceMemoryMonitor).
+
+
+def cuda_marks(*, res: str, num_cards: int):
+ test_platform_detail = pytest.mark.cuda
+ if res == "L4":
+ test_resource = pytest.mark.L4
+ elif res == "H100":
+ test_resource = pytest.mark.H100
+ else:
+ raise ValueError(f"Invalid CUDA resource type: {res}. Supported: L4, H100")
+ marks = [test_resource, test_platform_detail]
+ if num_cards == 1:
+ return marks
+ test_distributed = pytest.mark.distributed_cuda(num_cards=num_cards)
+ test_skipif = pytest.mark.skipif_cuda(
+ cuda_device_count_stateless() < num_cards,
+ reason=f"Need at least {num_cards} CUDA GPUs to run the test.",
+ )
+ return marks + [test_distributed, test_skipif]
+
+
+def rocm_marks(*, res: str, num_cards: int):
+ test_platform_detail = pytest.mark.rocm
+ if res == "MI325":
+ test_resource = pytest.mark.MI325
+ else:
+ raise ValueError(f"Invalid ROCm resource type: {res}. Supported: MI325")
+ marks = [test_resource, test_platform_detail]
+ if num_cards == 1:
+ return marks
+ test_distributed = pytest.mark.distributed_rocm(num_cards=num_cards)
+ return marks + [test_distributed]
+
+
+def xpu_marks(*, res: str, num_cards: int):
+ test_platform_detail = pytest.mark.xpu
+ if res == "B60":
+ test_resource = pytest.mark.B60
+ else:
+ raise ValueError(f"Invalid XPU resource type: {res}. Supported: B60")
+ marks = [test_resource, test_platform_detail]
+ if num_cards == 1:
+ return marks
+ test_distributed = pytest.mark.distributed_rocm(num_cards=num_cards)
+ return marks + [test_distributed]
+
+
+def musa_marks(*, res: str, num_cards: int):
+ test_platform_detail = pytest.mark.musa
+ if res == "S5000":
+ test_resource = pytest.mark.S5000
+ else:
+ raise ValueError(f"Invalid MUSA resource type: {res}. Supported: S5000")
+ marks = [test_resource, test_platform_detail]
+ if num_cards == 1:
+ return marks
+ test_distributed = pytest.mark.distributed_musa(num_cards=num_cards)
+ return marks + [test_distributed]
+
+
+def gpu_marks(*, res: str, num_cards: int):
+ test_platform = pytest.mark.gpu
+ if res in ("L4", "H100"):
+ return [test_platform] + cuda_marks(res=res, num_cards=num_cards)
+ if res == "MI325":
+ return [test_platform] + rocm_marks(res=res, num_cards=num_cards)
+ if res == "B60":
+ return [test_platform] + xpu_marks(res=res, num_cards=num_cards)
+ if res == "S5000":
+ return [test_platform] + musa_marks(res=res, num_cards=num_cards)
+ raise ValueError(f"Invalid resource type: {res}. Supported: L4, H100, MI325, B60, S5000")
+
+
+def npu_marks(*, res: str, num_cards: int):
+ test_platform = pytest.mark.npu
+ if res == "A2":
+ test_resource = pytest.mark.A2
+ elif res == "A3":
+ test_resource = pytest.mark.A3
+ else:
+ test_resource = None
+ if num_cards == 1:
+ return [mark for mark in [test_platform, test_resource] if mark is not None]
+ test_distributed = pytest.mark.distributed_npu(num_cards=num_cards)
+ return [mark for mark in [test_platform, test_resource, test_distributed] if mark is not None]
+
+
+def hardware_marks(*, res: dict[str, str], num_cards: int | dict[str, int] = 1):
+ for platform, _ in res.items():
+ if platform not in ("cuda", "rocm", "xpu", "npu", "musa"):
+ raise ValueError(f"Unsupported platform: {platform}")
+ if isinstance(num_cards, int):
+ num_cards_dict = {platform: num_cards for platform in res.keys()}
+ else:
+ num_cards_dict = num_cards
+ for platform in num_cards_dict.keys():
+ if platform not in res:
+ raise ValueError(f"Platform '{platform}' in num_cards but not in res.")
+ for platform in res.keys():
+ if platform not in num_cards_dict:
+ num_cards_dict[platform] = 1
+
+ all_marks: list[pytest.MarkDecorator] = []
+ for platform, resource in res.items():
+ cards = num_cards_dict[platform]
+ if platform in ("cuda", "rocm", "xpu"):
+ marks = gpu_marks(res=resource, num_cards=cards)
+ elif platform == "musa":
+ marks = musa_marks(res=resource, num_cards=cards)
+ elif platform == "npu":
+ marks = npu_marks(res=resource, num_cards=cards)
+ else:
+ raise ValueError(f"Unsupported platform: {platform}")
+ all_marks.extend(marks)
+ return all_marks
+
+
+def hardware_test(*, res: dict[str, str], num_cards: int | dict[str, int] = 1):
+ all_marks = hardware_marks(res=res, num_cards=num_cards)
+
+ def wrapper(f):
+ func = f
+ for mark in reversed(all_marks):
+ func = mark(func)
+ return func
+
+ return wrapper
diff --git a/tests/helpers/media.py b/tests/helpers/media.py
new file mode 100644
index 00000000000..3c45c2a9d95
--- /dev/null
+++ b/tests/helpers/media.py
@@ -0,0 +1,643 @@
+"""Synthetic media generation and media/text utilities for tests."""
+
+import base64
+import concurrent.futures
+import gc
+import io
+import logging
+import math
+import multiprocessing
+import os
+import random
+import re
+import subprocess
+import tempfile
+import time
+import uuid
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import soundfile as sf
+from PIL import Image
+
+logger = logging.getLogger(__name__)
+
+
+def _resolve_synthetic_media_cache_dir(cache_dir: Path | str | None) -> Path:
+ if cache_dir is not None:
+ return Path(cache_dir).expanduser().resolve()
+ return Path(tempfile.gettempdir()) / "vllm_omni_test_synthetic_media"
+
+
+def _np_array_from_mp4_bytes(video_bytes: bytes) -> np.ndarray:
+ """Decode MP4 bytes to a (T, H, W, 3) uint8 RGB stack (matches in-memory synthetic frames)."""
+ import cv2
+
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
+ tmp.write(video_bytes)
+ path = tmp.name
+ cap = None
+ try:
+ cap = cv2.VideoCapture(path)
+ if not cap.isOpened():
+ raise RuntimeError("Failed to open cached synthetic video for decode")
+ frames: list[np.ndarray] = []
+ while True:
+ ok, frame_bgr = cap.read()
+ if not ok:
+ break
+ frames.append(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
+ if not frames:
+ raise RuntimeError("Cached synthetic video has no decodable frames")
+ return np.stack(frames, axis=0)
+ finally:
+ if cap is not None:
+ cap.release()
+ try:
+ os.unlink(path)
+ except OSError:
+ pass
+
+
+def generate_synthetic_audio(
+ duration: int,
+ num_channels: int,
+ sample_rate: int = 48000,
+ *,
+ force_regenerate: bool = False,
+ cache_dir: Path | str | None = None,
+) -> dict[str, Any]:
+ """
+ Generate TTS speech with pyttsx3 and return base64 string.
+
+ Caches the WAV under ``cache_dir`` when given, else under the default temp
+ subdirectory. Reuses the file when the same
+ ``duration`` / ``num_channels`` / ``sample_rate`` are requested unless
+ ``force_regenerate`` is true.
+ """
+ root = _resolve_synthetic_media_cache_dir(cache_dir)
+ root.mkdir(parents=True, exist_ok=True)
+ cache_path = root / f"synth_audio_d{duration}_ch{num_channels}_sr{sample_rate}.wav"
+
+ if not force_regenerate and cache_path.is_file():
+ data, _sr = sf.read(str(cache_path), dtype="float32", always_2d=True)
+ audio_bytes = cache_path.read_bytes()
+ return {
+ "np_array": np.asarray(data, dtype=np.float32),
+ "base64": base64.b64encode(audio_bytes).decode("utf-8"),
+ "file_path": str(cache_path.resolve()),
+ }
+
+ import pyttsx3
+
+ def _pick_voice(engine: pyttsx3.Engine) -> str | None:
+ voices = engine.getProperty("voices")
+ if not voices:
+ return None
+
+ preferred_tokens = (
+ "natural",
+ "jenny",
+ "sonia",
+ "susan",
+ "zira",
+ "aria",
+ "hazel",
+ "samantha",
+ "ava",
+ "allison",
+ "female",
+ "woman",
+ "english-us",
+ "en-us",
+ "english",
+ )
+ discouraged_tokens = (
+ "espeak",
+ "robot",
+ "mbrola",
+ "microsoft david",
+ "male",
+ "man",
+ )
+
+ best_voice = voices[0]
+ best_score = float("-inf")
+ for voice in voices:
+ voice_text = f"{getattr(voice, 'id', '')} {getattr(voice, 'name', '')}".lower()
+ voice_languages = " ".join(
+ lang.decode(errors="ignore") if isinstance(lang, bytes) else str(lang)
+ for lang in getattr(voice, "languages", [])
+ ).lower()
+ combined_text = f"{voice_text} {voice_languages}"
+ score = 0
+ for idx, token in enumerate(preferred_tokens):
+ if token in combined_text:
+ score += 20 - idx
+ for token in discouraged_tokens:
+ if token in combined_text:
+ score -= 10
+ if "english" in combined_text or "en_" in combined_text or "en-" in combined_text:
+ score += 4
+ if "en-us" in combined_text or "english-us" in combined_text:
+ score += 4
+ if score > best_score:
+ best_score = score
+ best_voice = voice
+
+ return best_voice.id
+
+ def _resample_audio(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
+ if src_sr == dst_sr or len(audio) == 0:
+ return audio.astype(np.float32)
+ src_len = audio.shape[0]
+ dst_len = max(1, int(round(src_len * float(dst_sr) / float(src_sr))))
+ src_idx = np.arange(src_len, dtype=np.float32)
+ dst_idx = np.linspace(0, src_len - 1, dst_len, dtype=np.float32)
+ resampled_channels: list[np.ndarray] = []
+ for ch in range(audio.shape[1]):
+ resampled_channels.append(np.interp(dst_idx, src_idx, audio[:, ch]).astype(np.float32))
+ return np.stack(resampled_channels, axis=1)
+
+ def _match_channels(audio: np.ndarray, target_channels: int) -> np.ndarray:
+ current_channels = audio.shape[1]
+ if current_channels == target_channels:
+ return audio.astype(np.float32)
+ if target_channels == 1:
+ return np.mean(audio, axis=1, keepdims=True, dtype=np.float32)
+ if current_channels == 1:
+ return np.repeat(audio, target_channels, axis=1).astype(np.float32)
+ collapsed = np.mean(audio, axis=1, keepdims=True, dtype=np.float32)
+ return np.repeat(collapsed, target_channels, axis=1).astype(np.float32)
+
+ def _trim_silence(audio: np.ndarray, threshold: float = 0.01) -> np.ndarray:
+ if len(audio) == 0:
+ return audio
+ energy = np.max(np.abs(audio), axis=1)
+ voiced = np.where(energy > threshold)[0]
+ if len(voiced) == 0:
+ return audio
+ start = max(0, int(voiced[0]) - int(sample_rate * 0.02))
+ end = min(len(audio), int(voiced[-1]) + int(sample_rate * 0.04) + 1)
+ return audio[start:end]
+
+ def _enhance_speech(audio: np.ndarray) -> np.ndarray:
+ if len(audio) == 0:
+ return audio.astype(np.float32)
+ enhanced = audio.astype(np.float32).copy()
+ enhanced -= np.mean(enhanced, axis=0, keepdims=True, dtype=np.float32)
+ if len(enhanced) > 1:
+ preemphasis = enhanced.copy()
+ preemphasis[1:] = enhanced[1:] - 0.94 * enhanced[:-1]
+ enhanced = 0.7 * enhanced + 0.3 * preemphasis
+ enhanced = np.sign(enhanced) * np.sqrt(np.abs(enhanced))
+ fade = min(len(enhanced) // 4, max(1, int(sample_rate * 0.01)))
+ if fade > 1:
+ ramp_in = np.linspace(0.0, 1.0, fade, dtype=np.float32)
+ ramp_out = np.linspace(1.0, 0.0, fade, dtype=np.float32)
+ enhanced[:fade] *= ramp_in[:, None]
+ enhanced[-fade:] *= ramp_out[:, None]
+ peak = float(np.max(np.abs(enhanced)))
+ if peak > 1e-8:
+ enhanced = enhanced / peak * 0.95
+ return enhanced.astype(np.float32)
+
+ phrase_text = "test"
+ num_samples = int(sample_rate * max(1, duration))
+ audio_data = np.zeros((num_samples, num_channels), dtype=np.float32)
+
+ engine = pyttsx3.init()
+ engine.setProperty("rate", 112)
+ engine.setProperty("volume", 1.0)
+ selected_voice = _pick_voice(engine)
+ if selected_voice is not None:
+ engine.setProperty("voice", selected_voice)
+
+ temp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
+ temp_wav.close()
+ try:
+ engine.save_to_file(phrase_text, temp_wav.name)
+ engine.runAndWait()
+ engine.stop()
+
+ ready = False
+ for _ in range(50):
+ if os.path.exists(temp_wav.name) and os.path.getsize(temp_wav.name) > 44:
+ ready = True
+ break
+ time.sleep(0.1)
+ if not ready:
+ raise RuntimeError("pyttsx3 did not produce a WAV file in time.")
+
+ tts_audio, tts_sr = sf.read(temp_wav.name, dtype="float32", always_2d=True)
+ finally:
+ if os.path.exists(temp_wav.name):
+ os.unlink(temp_wav.name)
+
+ if len(tts_audio) == 0:
+ raise RuntimeError("pyttsx3 produced an empty WAV file.")
+
+ tts_audio = _resample_audio(tts_audio, tts_sr, sample_rate)
+ tts_audio = _match_channels(tts_audio, num_channels)
+ tts_audio = _trim_silence(tts_audio, threshold=0.012)
+ tts_audio = _enhance_speech(tts_audio)
+
+ lead_silence = min(int(sample_rate * 0.02), num_samples // 8)
+ pause_samples = int(sample_rate * 0.18)
+ start = lead_silence
+ phrase_len = tts_audio.shape[0]
+ while start < num_samples:
+ take = min(phrase_len, num_samples - start)
+ audio_data[start : start + take] = tts_audio[:take]
+ start += phrase_len + pause_samples
+
+ max_amp = float(np.max(np.abs(audio_data)))
+ if max_amp > 0:
+ audio_data = audio_data / max_amp * 0.95
+
+ sf.write(str(cache_path), audio_data, sample_rate, format="WAV", subtype="PCM_16")
+ audio_bytes = cache_path.read_bytes()
+
+ return {
+ "np_array": audio_data.copy(),
+ "base64": base64.b64encode(audio_bytes).decode("utf-8"),
+ "file_path": str(cache_path.resolve()),
+ }
+
+
+def _mux_mp4_bytes_with_synthetic_audio(
+ video_mp4_bytes: bytes,
+ *,
+ num_frames: int,
+ fps: float = 30.0,
+ sample_rate: int = 48000,
+) -> bytes:
+ duration_sec = num_frames / fps if fps > 0 else 0.0
+ duration_int = max(1, int(math.ceil(duration_sec)))
+
+ try:
+ audio_result = generate_synthetic_audio(
+ duration=duration_int,
+ num_channels=1,
+ sample_rate=sample_rate,
+ )
+ audio_pcm = audio_result["np_array"]
+ except Exception as e:
+ logger.warning("Synthetic video: generate_synthetic_audio failed (%s); using video-only MP4.", e)
+ return video_mp4_bytes
+
+ try:
+ import imageio_ffmpeg
+
+ ffmpeg_exe = imageio_ffmpeg.get_ffmpeg_exe()
+ except Exception:
+ ffmpeg_exe = "ffmpeg"
+
+ try:
+ with tempfile.TemporaryDirectory(prefix="syn_vid_mux_") as tmp:
+ vid_path = os.path.join(tmp, "video.mp4")
+ wav_path = os.path.join(tmp, "audio.wav")
+ out_path = os.path.join(tmp, "out.mp4")
+ with open(vid_path, "wb") as f:
+ f.write(video_mp4_bytes)
+ sf.write(wav_path, audio_pcm, sample_rate, format="WAV", subtype="PCM_16")
+ cmd = [
+ ffmpeg_exe,
+ "-y",
+ "-nostdin",
+ "-hide_banner",
+ "-loglevel",
+ "error",
+ "-i",
+ vid_path,
+ "-i",
+ wav_path,
+ "-c:v",
+ "copy",
+ "-c:a",
+ "aac",
+ "-b:a",
+ "128k",
+ "-shortest",
+ "-movflags",
+ "+faststart",
+ out_path,
+ ]
+ subprocess.run(cmd, check=True, stdin=subprocess.DEVNULL, timeout=300)
+ with open(out_path, "rb") as f:
+ return f.read()
+ except (
+ FileNotFoundError,
+ subprocess.CalledProcessError,
+ subprocess.TimeoutExpired,
+ OSError,
+ ) as e:
+ logger.warning("Synthetic video: audio mux failed (%s); using video-only MP4.", e)
+ return video_mp4_bytes
+
+
+def generate_synthetic_video(
+ width: int,
+ height: int,
+ num_frames: int,
+ *,
+ embed_audio: bool = False,
+ force_regenerate: bool = False,
+ cache_dir: Path | str | None = None,
+) -> dict[str, Any]:
+ """
+ Generate synthetic MP4 (optional AAC audio). Caches final bytes by
+ ``width`` / ``height`` / ``num_frames`` / ``embed_audio`` unless
+ ``force_regenerate`` is true. Cache root: ``cache_dir`` if given, else the
+ default temp subdirectory.
+ """
+ root = _resolve_synthetic_media_cache_dir(cache_dir)
+ root.mkdir(parents=True, exist_ok=True)
+ cache_path = root / f"synth_video_w{width}_h{height}_nf{num_frames}_ea{int(embed_audio)}.mp4"
+
+ if not force_regenerate and cache_path.is_file():
+ video_bytes = cache_path.read_bytes()
+ return {
+ "np_array": _np_array_from_mp4_bytes(video_bytes),
+ "base64": base64.b64encode(video_bytes).decode("utf-8"),
+ "file_path": str(cache_path.resolve()),
+ }
+
+ import cv2
+ import imageio
+
+ num_balls = random.randint(3, 8)
+ balls = []
+ for _ in range(num_balls):
+ radius = min(width, height) // 8
+ if radius < 1:
+ raise ValueError(f"Video dimensions ({width}x{height}) too small")
+ x = random.randint(radius, width - radius)
+ y = random.randint(radius, height - radius)
+ speed = random.uniform(3.0, 8.0)
+ angle = random.uniform(0, 2 * math.pi)
+ vx = speed * math.cos(angle)
+ vy = speed * math.sin(angle)
+ color_bgr = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255))
+ balls.append({"x": x, "y": y, "vx": vx, "vy": vy, "radius": radius, "color_bgr": color_bgr})
+
+ video_frames = []
+ for _ in range(num_frames):
+ frame_bgr = np.zeros((height, width, 3), dtype=np.uint8)
+ for ball in balls:
+ ball["x"] += ball["vx"]
+ ball["y"] += ball["vy"]
+ if ball["x"] - ball["radius"] <= 0 or ball["x"] + ball["radius"] >= width:
+ ball["vx"] = -ball["vx"]
+ ball["x"] = max(ball["radius"], min(width - ball["radius"], ball["x"]))
+ if ball["y"] - ball["radius"] <= 0 or ball["y"] + ball["radius"] >= height:
+ ball["vy"] = -ball["vy"]
+ ball["y"] = max(ball["radius"], min(height - ball["radius"], ball["y"]))
+ x, y = int(ball["x"]), int(ball["y"])
+ radius = int(ball["radius"])
+ cv2.circle(frame_bgr, (x, y), radius, ball["color_bgr"], -1)
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
+ video_frames.append(frame_rgb)
+
+ fps = 30
+ buffer = io.BytesIO()
+ writer_kwargs = {
+ "format": "mp4",
+ "fps": fps,
+ "codec": "libx264",
+ "quality": 7,
+ "pixelformat": "yuv420p",
+ "macro_block_size": 16,
+ "ffmpeg_params": ["-preset", "medium", "-crf", "23", "-movflags", "+faststart", "-pix_fmt", "yuv420p"],
+ }
+ try:
+ with imageio.get_writer(buffer, **writer_kwargs) as writer:
+ for frame in video_frames:
+ writer.append_data(frame)
+ buffer.seek(0)
+ video_only_bytes = buffer.read()
+ except Exception as e:
+ print(f"Warning: Failed to encode synthetic video: {e}")
+ raise
+ video_bytes = (
+ _mux_mp4_bytes_with_synthetic_audio(video_only_bytes, num_frames=num_frames, fps=float(fps))
+ if embed_audio
+ else video_only_bytes
+ )
+
+ cache_path.write_bytes(video_bytes)
+
+ return {
+ "np_array": np.array(video_frames),
+ "base64": base64.b64encode(video_bytes).decode("utf-8"),
+ "file_path": str(cache_path.resolve()),
+ }
+
+
+def generate_synthetic_image(
+ width: int,
+ height: int,
+ *,
+ force_regenerate: bool = False,
+ cache_dir: Path | str | None = None,
+) -> dict[str, Any]:
+ """
+ Random colored squares on white background. Caches JPEG by ``width`` /
+ ``height`` unless ``force_regenerate`` is true. Cache root: ``cache_dir``
+ if given, else the default temp subdirectory.
+ """
+ root = _resolve_synthetic_media_cache_dir(cache_dir)
+ root.mkdir(parents=True, exist_ok=True)
+ cache_path = root / f"synth_image_w{width}_h{height}.jpg"
+
+ if not force_regenerate and cache_path.is_file():
+ from PIL import Image as PILImage
+
+ image = PILImage.open(cache_path)
+ image.load()
+ image_bytes = cache_path.read_bytes()
+ return {
+ "np_array": np.array(image).copy(),
+ "base64": base64.b64encode(image_bytes).decode("utf-8"),
+ "file_path": str(cache_path.resolve()),
+ }
+
+ from PIL import ImageDraw
+
+ image = Image.new("RGB", (width, height), (255, 255, 255))
+ draw = ImageDraw.Draw(image)
+ num_squares = random.randint(3, 8)
+ for _ in range(num_squares):
+ square_size = random.randint(max(1, min(width, height) // 8), max(2, min(width, height) // 4))
+ x = random.randint(0, max(0, width - square_size - 1))
+ y = random.randint(0, max(0, height - square_size - 1))
+ color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+ border_width = random.randint(1, 5)
+ draw.rectangle([x, y, x + square_size, y + square_size], fill=color, outline=(0, 0, 0), width=border_width)
+
+ image.save(str(cache_path), format="JPEG", quality=85, optimize=True)
+ image_bytes = cache_path.read_bytes()
+
+ return {
+ "np_array": np.array(image).copy(),
+ "base64": base64.b64encode(image_bytes).decode("utf-8"),
+ "file_path": str(cache_path.resolve()),
+ }
+
+
+def decode_b64_image(b64: str):
+ img = Image.open(io.BytesIO(base64.b64decode(b64)))
+ img.load()
+ return img
+
+
+def preprocess_text(text):
+ import opencc
+
+ word_to_num = {
+ "zero": "0",
+ "one": "1",
+ "two": "2",
+ "three": "3",
+ "four": "4",
+ "five": "5",
+ "six": "6",
+ "seven": "7",
+ "eight": "8",
+ "nine": "9",
+ "ten": "10",
+ }
+ for word, num in word_to_num.items():
+ pattern = r"\b" + re.escape(word) + r"\b"
+ text = re.sub(pattern, num, text, flags=re.IGNORECASE)
+
+ text = re.sub(r"[^\w\s]", "", text)
+ text = re.sub(r"\s+", " ", text)
+ cc = opencc.OpenCC("t2s")
+ text = cc.convert(text)
+ text = re.sub(r"(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])", "", text)
+ return text.lower().strip()
+
+
+def cosine_similarity_text(text1, text2, n: int = 3):
+ from collections import Counter
+
+ if not text1 or not text2:
+ return 0.0
+
+ text1 = preprocess_text(text1)
+ text2 = preprocess_text(text2)
+ print(f"cosine similarity text1 is: {text1}, text2 is: {text2}")
+
+ ngrams1 = [text1[i : i + n] for i in range(len(text1) - n + 1)]
+ ngrams2 = [text2[i : i + n] for i in range(len(text2) - n + 1)]
+ counter1 = Counter(ngrams1)
+ counter2 = Counter(ngrams2)
+
+ all_ngrams = set(counter1.keys()) | set(counter2.keys())
+ vec1 = [counter1.get(ng, 0) for ng in all_ngrams]
+ vec2 = [counter2.get(ng, 0) for ng in all_ngrams]
+ dot_product = sum(a * b for a, b in zip(vec1, vec2))
+ norm1 = sum(a * a for a in vec1) ** 0.5
+ norm2 = sum(b * b for b in vec2) ** 0.5
+ if norm1 == 0 or norm2 == 0:
+ return 0.0
+ return dot_product / (norm1 * norm2)
+
+
+def _merge_base64_audio_to_segment(base64_list: list[str]):
+ from pydub import AudioSegment
+
+ merged = None
+ for b64 in base64_list:
+ raw = base64.b64decode(b64.split(",", 1)[-1])
+ seg = AudioSegment.from_file(io.BytesIO(raw))
+ merged = seg if merged is None else merged + seg
+ return merged
+
+
+@contextmanager
+def _serialize_whisper_small_model_download():
+ """Serialize Whisper ``small`` cache writes across processes (Linux/Unix)."""
+ import fcntl
+
+ lock_path = Path.home() / ".cache" / "whisper" / ".small_model_download.lock"
+ lock_path.parent.mkdir(parents=True, exist_ok=True)
+ f = open(lock_path, "a+b")
+ try:
+ fcntl.flock(f.fileno(), fcntl.LOCK_EX)
+ yield
+ finally:
+ fcntl.flock(f.fileno(), fcntl.LOCK_UN)
+ f.close()
+
+
+def _whisper_transcribe_in_current_process(output_path: str) -> str:
+ import whisper
+
+ device_index = None
+ from vllm_omni.platforms import current_omni_platform
+
+ if current_omni_platform.is_available():
+ n = current_omni_platform.get_device_count()
+ if n == 1:
+ device_index = 0
+ elif n > 1:
+ device_index = n - 1
+
+ if device_index is not None:
+ torch_device = current_omni_platform.get_torch_device(device_index)
+ current_omni_platform.set_device(torch_device)
+ device = str(torch_device)
+ use_accelerator = True
+ else:
+ use_accelerator = False
+ device = "cpu"
+
+ with _serialize_whisper_small_model_download():
+ model = whisper.load_model("small", device=device)
+ try:
+ text = model.transcribe(
+ output_path,
+ temperature=0.0,
+ word_timestamps=True,
+ condition_on_previous_text=False,
+ )["text"]
+ finally:
+ del model
+ gc.collect()
+ if use_accelerator:
+ current_omni_platform.synchronize()
+ current_omni_platform.empty_cache()
+ return text or ""
+
+
+def convert_audio_file_to_text(output_path: str) -> str:
+ """Convert an audio file to text in an isolated subprocess."""
+ ctx = multiprocessing.get_context("spawn")
+ with concurrent.futures.ProcessPoolExecutor(max_workers=1, mp_context=ctx) as executor:
+ future = executor.submit(_whisper_transcribe_in_current_process, output_path)
+ return future.result()
+
+
+def convert_audio_bytes_to_text(raw_bytes: bytes) -> str:
+ output_path = f"./test_{uuid.uuid4().hex}.wav"
+ data, samplerate = sf.read(io.BytesIO(raw_bytes))
+ sf.write(output_path, data, samplerate, format="WAV", subtype="PCM_16")
+ print(f"audio data is saved: {output_path}")
+ return convert_audio_file_to_text(output_path)
+
+
+__all__ = [
+ "_merge_base64_audio_to_segment",
+ "convert_audio_bytes_to_text",
+ "convert_audio_file_to_text",
+ "cosine_similarity_text",
+ "decode_b64_image",
+ "generate_synthetic_audio",
+ "generate_synthetic_image",
+ "generate_synthetic_video",
+ "preprocess_text",
+]
diff --git a/tests/e2e/offline_inference/utils.py b/tests/helpers/process.py
similarity index 58%
rename from tests/e2e/offline_inference/utils.py
rename to tests/helpers/process.py
index 3113599a305..094de965239 100644
--- a/tests/e2e/offline_inference/utils.py
+++ b/tests/helpers/process.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import contextlib
import functools
import os
import signal
@@ -10,73 +9,48 @@
import tempfile
from collections.abc import Callable
from contextlib import ExitStack, suppress
-from pathlib import Path
from typing import Any, Literal
import cloudpickle
from typing_extensions import ParamSpec
from vllm.platforms import current_platform
-VLLM_PATH = Path(__file__).parent.parent.parent
-"""Path to root of the vLLM repository."""
-
-
_P = ParamSpec("_P")
def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]:
- """Decorator to fork a new process for each test function.
- See https://github.com/vllm-project/vllm/issues/7053 for more details.
- """
+ """Decorator to fork a new process for each test function."""
@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
- # Make the process the leader of its own process group
- # to avoid sending SIGTERM to the parent process
os.setpgrp()
from _pytest.outcomes import Skipped
- # Create a unique temporary file to store exception info from child
- # process. Use test function name and process ID to avoid collisions.
with (
tempfile.NamedTemporaryFile(
- delete=False,
- mode="w+b",
- prefix=f"vllm_test_{func.__name__}_{os.getpid()}_",
- suffix=".exc",
+ delete=False, mode="w+b", prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", suffix=".exc"
) as exc_file,
ExitStack() as delete_after,
):
exc_file_path = exc_file.name
delete_after.callback(os.remove, exc_file_path)
-
pid = os.fork()
- print(f"Fork a new process to run a test {pid}")
if pid == 0:
- # Parent process responsible for deleting, don't delete
- # in child.
delete_after.pop_all()
try:
func(*args, **kwargs)
except Skipped as e:
- # convert Skipped to exit code 0
print(str(e))
os._exit(0)
except Exception as e:
import traceback
tb_string = traceback.format_exc()
-
- # Try to serialize the exception object first
exc_to_serialize: dict[str, Any]
try:
- # First, try to pickle the actual exception with
- # its traceback.
exc_to_serialize = {"pickled_exception": e}
- # Test if it can be pickled
cloudpickle.dumps(exc_to_serialize)
except (Exception, KeyboardInterrupt):
- # Fall back to string-based approach.
exc_to_serialize = {
"exception_type": type(e).__name__,
"exception_msg": str(e),
@@ -86,7 +60,6 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
with open(exc_file_path, "wb") as f:
cloudpickle.dump(exc_to_serialize, f)
except Exception:
- # Fallback: just print the traceback.
print(tb_string)
os._exit(1)
else:
@@ -94,40 +67,24 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
else:
pgid = os.getpgid(pid)
_pid, _exitcode = os.waitpid(pid, 0)
- # ignore SIGTERM signal itself
old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
- # kill all child processes
os.killpg(pgid, signal.SIGTERM)
- # restore the signal handler
signal.signal(signal.SIGTERM, old_signal_handler)
if _exitcode != 0:
- # Try to read the exception from the child process
exc_info = {}
if os.path.exists(exc_file_path):
- with (
- contextlib.suppress(Exception),
- open(exc_file_path, "rb") as f,
- ):
+ with suppress(Exception), open(exc_file_path, "rb") as f:
exc_info = cloudpickle.load(f)
-
- original_exception = exc_info.get("pickled_exception")
- if original_exception is not None and isinstance(original_exception, Exception):
- # Re-raise the actual exception object if it was
- # successfully pickled.
+ if (original_exception := exc_info.get("pickled_exception")) is not None:
+ assert isinstance(original_exception, Exception)
raise original_exception
-
if (original_tb := exc_info.get("traceback")) is not None:
- # Use string-based traceback for fallback case
raise AssertionError(
- f"Test {func.__name__} failed when called with"
- f" args {args} and kwargs {kwargs}"
+ f"Test {func.__name__} failed when called with args {args} and kwargs {kwargs}"
f" (exit code: {_exitcode}):\n{original_tb}"
) from None
-
- # Fallback to the original generic error
raise AssertionError(
- f"function {func.__name__} failed when called with"
- f" args {args} and kwargs {kwargs}"
+ f"function {func.__name__} failed when called with args {args} and kwargs {kwargs}"
f" (exit code: {_exitcode})"
) from None
@@ -139,9 +96,7 @@ def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]
@functools.wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
- # Check if we're already in a subprocess
if os.environ.get("RUNNING_IN_SUBPROCESS") == "1":
- # If we are, just run the function directly
return f(*args, **kwargs)
import torch.multiprocessing as mp
@@ -149,33 +104,18 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
with suppress(RuntimeError):
mp.set_start_method("spawn")
- # Get the module
module_name = f.__module__
-
- # Create a process with environment variable set
env = os.environ.copy()
env["RUNNING_IN_SUBPROCESS"] = "1"
with tempfile.TemporaryDirectory() as tempdir:
output_filepath = os.path.join(tempdir, "new_process.tmp")
-
- # `cloudpickle` allows pickling complex functions directly
input_bytes = cloudpickle.dumps((f, output_filepath))
-
- repo_root = str(VLLM_PATH.resolve())
-
- env = dict(env or os.environ)
- env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "")
-
cmd = [sys.executable, "-m", f"{module_name}"]
-
returned = subprocess.run(cmd, input=input_bytes, capture_output=True, env=env)
-
- # check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
- # wrap raised exception to provide more information
raise RuntimeError(f"Error raised in subprocess:\n{returned.stderr.decode()}") from e
return wrapper
@@ -184,27 +124,11 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
def create_new_process_for_each_test(
method: Literal["spawn", "fork"] | None = None,
) -> Callable[[Callable[_P, None]], Callable[_P, None]]:
- """Creates a decorator that runs each test function in a new process.
-
- Args:
- method: The process creation method. Can be either "spawn" or "fork".
- If not specified, it defaults to "spawn" on ROCm and XPU
- platforms and "fork" otherwise.
-
- Returns:
- A decorator to run test functions in separate processes.
- """
+ """Creates a decorator that runs each test function in a new process."""
if method is None:
- # TODO: Find out why spawn is not working correctly on ROCm
- # The test content will not run and tests passed immediately.
- # For now, using `fork` for ROCm as it can run with `fork`
- # and tests are running correctly.
use_spawn = current_platform.is_xpu()
method = "spawn" if use_spawn else "fork"
-
assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'"
-
if method == "fork":
return fork_new_process_for_each_test
-
return spawn_new_process_for_each_test
diff --git a/tests/helpers/runtime.py b/tests/helpers/runtime.py
new file mode 100644
index 00000000000..9520aff0c53
--- /dev/null
+++ b/tests/helpers/runtime.py
@@ -0,0 +1,1413 @@
+"""Server/client/runner runtime primitives for tests."""
+
+import base64
+import concurrent.futures
+import io
+import json
+import os
+import socket
+import subprocess
+import sys
+import tempfile
+import time
+from dataclasses import dataclass
+from io import BytesIO
+from pathlib import Path
+from typing import Any, NamedTuple
+
+import psutil
+import requests
+import soundfile as sf
+import torch
+import yaml
+from openai import OpenAI, omit
+from PIL import Image
+from vllm import TextPrompt
+from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
+from vllm.logger import init_logger
+
+from tests.helpers.assertions import (
+ assert_audio_speech_response,
+ assert_diffusion_response,
+ assert_omni_response,
+)
+from tests.helpers.env import run_forced_gpu_cleanup_round
+from tests.helpers.media import (
+ _merge_base64_audio_to_segment,
+ convert_audio_bytes_to_text,
+ cosine_similarity_text,
+ decode_b64_image,
+)
+from vllm_omni.config.stage_config import resolve_deploy_yaml
+from vllm_omni.platforms import current_omni_platform
+
+logger = init_logger(__name__)
+
+PromptAudioInput = list[tuple[Any, int]] | tuple[Any, int] | None
+PromptImageInput = list[Any] | Any | None
+PromptVideoInput = list[Any] | Any | None
+
+try:
+ from vllm.distributed.parallel_state import cleanup_dist_env_and_memory # type: ignore
+except Exception: # pragma: no cover
+
+ def cleanup_dist_env_and_memory() -> None:
+ return None
+
+
+def get_open_port() -> int:
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("127.0.0.1", 0))
+ return int(s.getsockname()[1])
+
+
+def dummy_messages_from_mix_data(
+ system_prompt: dict[str, Any] = None,
+ video_data_url: Any = None,
+ audio_data_url: Any = None,
+ image_data_url: Any = None,
+ content_text: str = None,
+):
+ """Create messages with video、image、audio data URL for OpenAI API."""
+ if content_text is not None:
+ content = [{"type": "text", "text": content_text}]
+ else:
+ content = []
+
+ media_items = []
+ if isinstance(video_data_url, list):
+ for video_url in video_data_url:
+ media_items.append((video_url, "video"))
+ else:
+ media_items.append((video_data_url, "video"))
+
+ if isinstance(image_data_url, list):
+ for url in image_data_url:
+ media_items.append((url, "image"))
+ else:
+ media_items.append((image_data_url, "image"))
+
+ if isinstance(audio_data_url, list):
+ for url in audio_data_url:
+ media_items.append((url, "audio"))
+ else:
+ media_items.append((audio_data_url, "audio"))
+
+ content.extend(
+ {"type": f"{media_type}_url", f"{media_type}_url": {"url": url}}
+ for url, media_type in media_items
+ if url is not None
+ )
+ messages = [{"role": "user", "content": content}]
+ if system_prompt is not None:
+ messages = [system_prompt] + messages
+ return messages
+
+
+def _omni_subprocess_cwd() -> str:
+ """Repo root for ``python -m vllm_omni...`` (legacy conftest lived under ``tests/``; helpers under ``tests/helpers/``)."""
+ return os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
+
+
+class OmniServerParams(NamedTuple):
+ model: str
+ port: int | None = None
+ stage_config_path: str | None = None
+ server_args: list[str] | None = None
+ env_dict: dict[str, str] | None = None
+ use_omni: bool = True
+ use_stage_cli: bool = False
+ init_timeout: int | None = None
+ stage_init_timeout: int | None = None # None: fixture supplies default (600 s)
+
+
+class OmniServer:
+ """Omniserver for vLLM-Omni tests."""
+
+ def __init__(
+ self,
+ model: str,
+ serve_args: list[str],
+ *,
+ port: int | None = None,
+ env_dict: dict[str, str] | None = None,
+ use_omni: bool = True,
+ ) -> None:
+ run_forced_gpu_cleanup_round()
+ cleanup_dist_env_and_memory()
+ self.model = model
+ self.serve_args = serve_args
+ self.env_dict = env_dict
+ self.use_omni = use_omni
+ self.proc: subprocess.Popen | None = None
+ self.host = "127.0.0.1"
+ self.port = get_open_port() if port is None else port
+
+ def _start_server(self) -> None:
+ env = os.environ.copy()
+ env.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
+ if self.env_dict is not None:
+ env.update(self.env_dict)
+
+ cmd = [
+ sys.executable,
+ "-m",
+ "vllm_omni.entrypoints.cli.main",
+ "serve",
+ self.model,
+ "--host",
+ self.host,
+ "--port",
+ str(self.port),
+ ]
+ if self.use_omni:
+ cmd.append("--omni")
+ cmd += self.serve_args
+
+ print(f"Launching OmniServer with: {' '.join(cmd)}")
+ self.proc = subprocess.Popen(
+ cmd,
+ env=env,
+ cwd=_omni_subprocess_cwd(),
+ )
+
+ max_wait = 1200
+ start_time = time.time()
+ while time.time() - start_time < max_wait:
+ ret = self.proc.poll()
+ if ret is not None:
+ raise RuntimeError(f"Server processes exited with code {ret} before becoming ready.")
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ sock.settimeout(1)
+ if sock.connect_ex((self.host, self.port)) == 0:
+ print(f"Server ready on {self.host}:{self.port}")
+ return
+ time.sleep(2)
+ raise RuntimeError(f"Server failed to start within {max_wait} seconds")
+
+ def _kill_process_tree(self, pid):
+ try:
+ parent = psutil.Process(pid)
+ children = parent.children(recursive=True)
+ all_pids = [pid] + [child.pid for child in children]
+
+ for child in children:
+ try:
+ child.terminate()
+ except psutil.NoSuchProcess:
+ pass
+
+ _, still_alive = psutil.wait_procs(children, timeout=10)
+
+ for child in still_alive:
+ try:
+ child.kill()
+ except psutil.NoSuchProcess:
+ pass
+
+ try:
+ parent.terminate()
+ parent.wait(timeout=10)
+ except (psutil.NoSuchProcess, psutil.TimeoutExpired):
+ try:
+ parent.kill()
+ except psutil.NoSuchProcess:
+ pass
+
+ time.sleep(1)
+ alive_processes = []
+ for check_pid in all_pids:
+ if psutil.pid_exists(check_pid):
+ alive_processes.append(check_pid)
+
+ if alive_processes:
+ print(f"Warning: Processes still alive: {alive_processes}")
+ for alive_pid in alive_processes:
+ try:
+ subprocess.run(["kill", "-9", str(alive_pid)], timeout=2)
+ except Exception as e:
+ print(f"Cleanup failed: {e}")
+
+ except psutil.NoSuchProcess:
+ pass
+
+ def __enter__(self):
+ self._start_server()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self.proc:
+ self._kill_process_tree(self.proc.pid)
+ run_forced_gpu_cleanup_round()
+ cleanup_dist_env_and_memory()
+
+
+class OmniServerStageCli(OmniServer):
+ """Omni server harness that exercises the stage CLI flow."""
+
+ def __init__(
+ self,
+ model: str,
+ stage_config_path: str,
+ serve_args: list[str] | None = None,
+ *,
+ stage_ids: list[int] | None = None,
+ port: int | None = None,
+ env_dict: dict[str, str] | None = None,
+ ) -> None:
+ super().__init__(model, serve_args or [], port=port, env_dict=env_dict, use_omni=True)
+ self.stage_config_path = stage_config_path
+ self.master_port = get_open_port()
+ self.visible_device_list = self._load_visible_device_list(env_dict)
+ resolved_cfg = resolve_deploy_yaml(stage_config_path)
+ # Dump the resolved deploy config so CI logs show each stage's
+ # gpu_memory_utilization / max_model_len / max_num_seqs after
+ # base_config inheritance and overlay merge — essential when
+ # diagnosing OOMs that depend on the merged values.
+ print(
+ f"[OmniServerStageCli] Resolved deploy config from {stage_config_path}:\n"
+ f"{yaml.safe_dump(resolved_cfg, sort_keys=False, default_flow_style=False)}",
+ flush=True,
+ )
+ self.stage_runtime_devices = self._load_stage_runtime_devices(resolved_cfg)
+ self.stage_ids = stage_ids or self._load_stage_ids(resolved_cfg)
+ if 0 not in self.stage_ids:
+ raise ValueError(f"Stage CLI test requires stage_id=0 in config: {stage_config_path}")
+ self.stage_procs: dict[int, subprocess.Popen] = {}
+ self.proc = None
+
+ @staticmethod
+ def _stage_entries(cfg: dict) -> list[dict]:
+ """Return the list of stage entries from either legacy (``stage_args``)
+ or new-schema (``stages``) deploy YAMLs."""
+ return cfg.get("stage_args") or cfg.get("stages") or []
+
+ @staticmethod
+ def _load_stage_ids(resolved_config: dict) -> list[int]:
+ stage_ids = [
+ stage["stage_id"] for stage in OmniServerStageCli._stage_entries(resolved_config) if "stage_id" in stage
+ ]
+ if not stage_ids:
+ raise ValueError("No stage IDs found in resolved config")
+ return stage_ids
+
+ @staticmethod
+ def _load_stage_runtime_devices(resolved_config: dict) -> dict[int, str]:
+ runtime_devices: dict[int, str] = {}
+ for stage in OmniServerStageCli._stage_entries(resolved_config):
+ stage_id = stage.get("stage_id")
+ # New schema: stage.devices is flat at stage level.
+ # Legacy schema: stage.runtime.devices is nested.
+ devices = stage.get("devices") or stage.get("runtime", {}).get("devices")
+ if stage_id is not None and devices:
+ runtime_devices[int(stage_id)] = str(devices)
+ return runtime_devices
+
+ @classmethod
+ def _parse_device_list(cls, devices: str | int) -> list[str]:
+ if isinstance(devices, int):
+ if devices < 0:
+ raise ValueError("Device IDs must be non-negative integers")
+ return [str(devices)]
+ return [token.strip() for token in str(devices).split(",") if token.strip()]
+
+ @classmethod
+ def _load_visible_device_list(cls, env_dict: dict[str, str] | None) -> list[str] | None:
+ env = os.environ.copy()
+ if env_dict is not None:
+ env.update(env_dict)
+
+ env_var = getattr(current_omni_platform, "device_control_env_var", None)
+ if env_var and env_var in env:
+ return [token.strip() for token in env[env_var].split(",") if token.strip()]
+ return None
+
+ @classmethod
+ def _map_stage_devices(cls, stage_id: int, visible_device_list: list[str] | None, devices: str) -> str:
+ device_list = cls._parse_device_list(devices)
+
+ if visible_device_list is None:
+ return ",".join(device_list)
+
+ if not all(device.isdigit() for device in device_list):
+ raise ValueError("Logical devices must be non-negative integers")
+
+ logical_ids = [int(device) for device in device_list]
+ if logical_ids and max(logical_ids) >= len(visible_device_list):
+ raise ValueError(
+ f"Stage {stage_id} has logical IDs {device_list}, one or more of which exceed the number of visible devices"
+ )
+
+ return ",".join(visible_device_list[idx] for idx in logical_ids)
+
+ def _set_stage_device_env(self, stage_id: int, env: dict[str, str], devices: str) -> None:
+ mapped_devices = self._map_stage_devices(stage_id, self.visible_device_list, devices)
+ env_var = getattr(current_omni_platform, "device_control_env_var", None)
+ if env_var:
+ env[env_var] = mapped_devices
+
+ def _build_stage_cmd(self, stage_id: int, *, headless: bool) -> list[str]:
+ cmd = [
+ sys.executable,
+ "-m",
+ "vllm_omni.entrypoints.cli.main",
+ "serve",
+ self.model,
+ "--omni",
+ "--stage-configs-path",
+ self.stage_config_path,
+ "--stage-id",
+ str(stage_id),
+ "--omni-master-address",
+ self.host,
+ "--omni-master-port",
+ str(self.master_port),
+ ]
+
+ if headless:
+ cmd.append("--headless")
+ else:
+ cmd += ["--host", self.host, "--port", str(self.port)]
+
+ cmd += self.serve_args
+ return cmd
+
+ def _launch_stage(self, stage_id: int, *, headless: bool) -> None:
+ env = os.environ.copy()
+ env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+ if self.env_dict is not None:
+ env.update(self.env_dict)
+
+ devices = self.stage_runtime_devices.get(stage_id)
+ if devices:
+ self._set_stage_device_env(stage_id, env, devices)
+
+ cmd = self._build_stage_cmd(stage_id, headless=headless)
+ print(f"Launching OmniServerStageCli stage {stage_id}: {' '.join(cmd)}")
+ # Capture each subprocess's stdout+stderr to a per-stage log file so
+ # debugging "Stage N exited before API server ready" doesn't rely on
+ # guessing; the file is surfaced in the RuntimeError message.
+ log_path = Path(tempfile.gettempdir()) / f"omni_stage_{stage_id}_{self.master_port}.log"
+ self._stage_log_paths = getattr(self, "_stage_log_paths", {})
+ self._stage_log_paths[stage_id] = log_path
+ log_fh = open(log_path, "w", buffering=1) # noqa: SIM115 - closed in __exit__
+ self._stage_log_files = getattr(self, "_stage_log_files", {})
+ self._stage_log_files[stage_id] = log_fh
+ proc = subprocess.Popen(
+ cmd,
+ env=env,
+ cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+ stdout=log_fh,
+ stderr=subprocess.STDOUT,
+ )
+ self.stage_procs[stage_id] = proc
+ if stage_id == 0:
+ self.proc = proc
+
+ def _ensure_stage_processes_alive(self) -> None:
+ for stage_id, proc in self.stage_procs.items():
+ ret = proc.poll()
+ if ret is not None:
+ log_path = getattr(self, "_stage_log_paths", {}).get(stage_id)
+ tail = ""
+ if log_path and log_path.exists():
+ try:
+ with open(log_path, encoding="utf-8", errors="replace") as f:
+ lines = f.readlines()
+ tail = "\n=== Last 60 lines of stage {} log ({}) ===\n{}".format(
+ stage_id, log_path, "".join(lines[-60:]) or ""
+ )
+ except Exception as exc: # pragma: no cover - diagnostic only
+ tail = f"\n"
+ raise RuntimeError(f"Stage {stage_id} exited with code {ret} before API server became ready.{tail}")
+
+ def _start_server(self) -> None:
+ ordered_stage_ids = [0, *[stage_id for stage_id in self.stage_ids if stage_id != 0]]
+
+ self._launch_stage(0, headless=False)
+ time.sleep(2)
+ self._ensure_stage_processes_alive()
+
+ for stage_id in ordered_stage_ids[1:]:
+ self._launch_stage(stage_id, headless=True)
+
+ max_wait = 1200
+ start_time = time.time()
+ while time.time() - start_time < max_wait:
+ self._ensure_stage_processes_alive()
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ sock.settimeout(1)
+ result = sock.connect_ex((self.host, self.port))
+ if result == 0:
+ print(f"OmniServerStageCli ready on {self.host}:{self.port}")
+ return
+ time.sleep(2)
+
+ raise RuntimeError(f"OmniServerStageCli failed to start within {max_wait} seconds")
+
+ def _dump_stage_logs_for_debug(self, head_lines: int = 300, tail_lines: int = 500) -> None:
+ """Tail each stage's subprocess log back to stdout on teardown.
+
+ Stage subprocesses redirect stdout/stderr to ``/tmp/omni_stage_*.log``
+ so we don't spam the main CI stream while tests run; but that also
+ hides engine init (KV cache size, Available KV cache memory, vLLM
+ engine config) when things go wrong. Dump them here so buildkite
+ captures them post-run. Head covers engine init; tail covers
+ whatever state the stage was in when it was torn down.
+ """
+ log_paths = getattr(self, "_stage_log_paths", {}) or {}
+ for stage_id in sorted(log_paths):
+ log_path = log_paths[stage_id]
+ if not log_path or not log_path.exists():
+ continue
+ try:
+ with open(log_path, encoding="utf-8", errors="replace") as f:
+ lines = f.readlines()
+ except Exception as exc: # pragma: no cover - diagnostic only
+ print(f"[OmniServerStageCli] stage {stage_id} log read failed: {exc}", flush=True)
+ continue
+ total = len(lines)
+ if total <= head_lines + tail_lines:
+ head_chunk = lines
+ tail_chunk = []
+ elided = 0
+ else:
+ head_chunk = lines[:head_lines]
+ tail_chunk = lines[-tail_lines:]
+ elided = total - head_lines - tail_lines
+ print(f"\n=== stage {stage_id} log HEAD ({log_path}) ===", flush=True)
+ print("".join(head_chunk).rstrip("\n"), flush=True)
+ if tail_chunk:
+ print(f"\n... [{elided} lines elided] ...", flush=True)
+ print(f"\n=== stage {stage_id} log TAIL ({log_path}) ===", flush=True)
+ print("".join(tail_chunk).rstrip("\n"), flush=True)
+ print(f"=== end stage {stage_id} log ===\n", flush=True)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._dump_stage_logs_for_debug()
+ for stage_id in sorted(self.stage_procs, reverse=True):
+ proc = self.stage_procs[stage_id]
+ if proc.poll() is None:
+ self._kill_process_tree(proc.pid)
+ run_forced_gpu_cleanup_round()
+ cleanup_dist_env_and_memory()
+
+
+@dataclass
+class OmniResponse:
+ text_content: str | None = None
+ audio_data: list[str] | None = None
+ audio_content: str | None = None
+ audio_format: str | None = None
+ audio_bytes: bytes | None = None
+ similarity: float | None = None
+ e2e_latency: float | None = None
+ success: bool = False
+ error_message: str | None = None
+ cached_tokens: int | None = None
+
+
+@dataclass
+class DiffusionResponse:
+ text_content: str | None = None
+ images: list[Image.Image] | None = None
+ audios: list[Any] | None = None
+ videos: list[Any] | None = None
+ e2e_latency: float | None = None
+ success: bool = False
+ error_message: str | None = None
+
+
+class OpenAIClientHandler:
+ def __init__(self, host: str = "127.0.0.1", port: int = None, api_key: str = "EMPTY", run_level: str = None):
+ if port is None:
+ port = get_open_port()
+ self.base_url = f"http://{host}:{port}"
+ self.client = OpenAI(base_url=f"http://{host}:{port}/v1", api_key=api_key)
+ self.run_level = run_level
+
+ def _process_stream_omni_response(self, chat_completion) -> OmniResponse:
+ result = OmniResponse()
+ start_time = time.perf_counter()
+ try:
+ text_content = ""
+ audio_data = []
+ for chunk in chat_completion:
+ for choice in chunk.choices:
+ content = getattr(getattr(choice, "delta", None), "content", None)
+ modality = getattr(chunk, "modality", None)
+ if modality == "audio" and content:
+ audio_data.append(content)
+ elif modality == "text" and content:
+ text_content += content
+ result.e2e_latency = time.perf_counter() - start_time
+ audio_content = None
+ similarity = None
+ if audio_data:
+ merged_seg = _merge_base64_audio_to_segment(audio_data)
+ wav_buf = BytesIO()
+ merged_seg.export(wav_buf, format="wav")
+ result.audio_bytes = wav_buf.getvalue()
+ audio_content = convert_audio_bytes_to_text(result.audio_bytes)
+ if audio_content and text_content:
+ similarity = cosine_similarity_text(audio_content.lower(), text_content.lower())
+ result.text_content = text_content
+ result.audio_data = audio_data
+ result.audio_content = audio_content
+ result.similarity = similarity
+ result.success = True
+ except Exception as e:
+ result.error_message = f"Stream processing error: {str(e)}"
+ print(f"Error: {result.error_message}")
+ return result
+
+ def _process_non_stream_omni_response(self, chat_completion) -> OmniResponse:
+ result = OmniResponse()
+ start_time = time.perf_counter()
+ try:
+ audio_data = None
+ text_content = None
+ for choice in chat_completion.choices:
+ if hasattr(choice.message, "audio") and choice.message.audio is not None:
+ audio_data = choice.message.audio.data
+ if hasattr(choice.message, "content") and choice.message.content is not None:
+ text_content = choice.message.content
+ # Extract cached_tokens for prefix caching tests
+ usage = getattr(chat_completion, "usage", None)
+ if usage and (details := getattr(usage, "prompt_tokens_details", None)):
+ result.cached_tokens = details.cached_tokens
+ result.e2e_latency = time.perf_counter() - start_time
+ audio_content = None
+ similarity = None
+ if audio_data:
+ result.audio_bytes = base64.b64decode(audio_data)
+ audio_content = convert_audio_bytes_to_text(result.audio_bytes)
+ if audio_content and text_content:
+ similarity = cosine_similarity_text(audio_content.lower(), text_content.lower())
+ result.text_content = text_content
+ result.audio_content = audio_content
+ result.similarity = similarity
+ result.success = True
+ except Exception as e:
+ result.error_message = f"Non-stream processing error: {str(e)}"
+ print(f"Error: {result.error_message}")
+ return result
+
+ def _process_diffusion_response(self, chat_completion) -> DiffusionResponse:
+ result = DiffusionResponse()
+ start_time = time.perf_counter()
+ try:
+ images = []
+ for choice in chat_completion.choices:
+ content = getattr(choice.message, "content", None)
+ if isinstance(content, list):
+ for item in content:
+ image_url = None
+ if isinstance(item, dict):
+ image_url = item.get("image_url", {}).get("url")
+ else:
+ image_url_obj = getattr(item, "image_url", None)
+ image_url = getattr(image_url_obj, "url", None) if image_url_obj else None
+ if image_url and image_url.startswith("data:image"):
+ b64_data = image_url.split(",", 1)[1]
+ images.append(decode_b64_image(b64_data))
+ result.e2e_latency = time.perf_counter() - start_time
+ result.images = images if images else None
+ result.success = True
+ except Exception as e:
+ result.error_message = f"Diffusion response processing error: {str(e)}"
+ print(f"Error: {result.error_message}")
+ return result
+
+ def send_omni_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]:
+ responses: list[OmniResponse] = []
+ stream = request_config.get("stream", False)
+ modalities = request_config.get("modalities", ["text", "audio"])
+ extra_body: dict[str, Any] = {}
+ if "speaker" in request_config:
+ extra_body["speaker"] = request_config["speaker"]
+ if request_config.get("use_audio_in_video"):
+ mm = dict(extra_body.get("mm_processor_kwargs") or {})
+ mm["use_audio_in_video"] = True
+ extra_body["mm_processor_kwargs"] = mm
+ create_kwargs: dict[str, Any] = {
+ "model": request_config.get("model"),
+ "messages": request_config.get("messages"),
+ "stream": stream,
+ "modalities": modalities,
+ }
+ if extra_body:
+ create_kwargs["extra_body"] = extra_body
+
+ if request_num == 1:
+ chat_completion = self.client.chat.completions.create(**create_kwargs)
+ resp = (
+ self._process_stream_omni_response(chat_completion)
+ if stream
+ else self._process_non_stream_omni_response(chat_completion)
+ )
+ assert_omni_response(resp, request_config, run_level=self.run_level)
+ responses.append(resp)
+ return responses
+
+ def _one():
+ chat_completion = self.client.chat.completions.create(**create_kwargs)
+ return (
+ self._process_stream_omni_response(chat_completion)
+ if stream
+ else self._process_non_stream_omni_response(chat_completion)
+ )
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
+ futures = [executor.submit(_one) for _ in range(request_num)]
+ for future in concurrent.futures.as_completed(futures):
+ resp = future.result()
+ assert_omni_response(resp, request_config, run_level=self.run_level)
+ responses.append(resp)
+ return responses
+
+ def _process_stream_audio_speech_response(self, response, *, response_format: str | None = None) -> OmniResponse:
+ """
+ Process streaming /v1/audio/speech responses into an OmniResponse.
+
+ This mirrors _process_stream_omni_response but operates on low-level
+ audio bytes and produces an OmniResponse with audio_content filled
+ from Whisper transcription.
+ """
+ result = OmniResponse()
+ start_time = time.perf_counter()
+
+ try:
+ # Aggregate all audio bytes from the streaming response.
+ data = bytearray()
+
+ # Preferred OpenAI helper.
+ if hasattr(response, "iter_bytes") and callable(getattr(response, "iter_bytes")):
+ for chunk in response.iter_bytes():
+ if chunk:
+ data.extend(chunk)
+ else:
+ # Generic iterable-of-bytes fallback (e.g., generator or list of chunks).
+ try:
+ iterator = iter(response)
+ except TypeError:
+ iterator = None
+
+ if iterator is not None:
+ for chunk in iterator:
+ if not chunk:
+ continue
+ if isinstance(chunk, (bytes, bytearray)):
+ data.extend(chunk)
+ elif hasattr(chunk, "data"):
+ data.extend(chunk.data) # type: ignore[arg-type]
+ elif hasattr(chunk, "content"):
+ data.extend(chunk.content) # type: ignore[arg-type]
+ else:
+ raise TypeError(f"Unsupported stream chunk type: {type(chunk)}")
+ else:
+ raise TypeError(f"Unsupported audio speech streaming response type: {type(response)}")
+
+ raw_bytes = bytes(data)
+ if response_format == "pcm":
+ transcript = None
+ else:
+ transcript = convert_audio_bytes_to_text(raw_bytes)
+
+ # Populate OmniResponse.
+ result.audio_bytes = raw_bytes
+ result.audio_content = transcript
+ result.e2e_latency = time.perf_counter() - start_time
+ result.success = True
+ result.audio_format = getattr(response, "response", None)
+ if result.audio_format is not None:
+ result.audio_format = result.audio_format.headers.get("content-type", "")
+
+ except Exception as e:
+ result.error_message = f"Audio speech stream processing error: {str(e)}"
+ print(f"Error: {result.error_message}")
+
+ return result
+
+ def _process_non_stream_audio_speech_response(
+ self, response, *, response_format: str | None = None
+ ) -> OmniResponse:
+ """
+ Process non-streaming /v1/audio/speech responses into an OmniResponse.
+
+ This mirrors _process_non_stream_omni_response but for the binary
+ audio payload returned by audio.speech.create.
+ """
+ result = OmniResponse()
+ start_time = time.perf_counter()
+
+ try:
+ # OpenAI non-streaming audio.speech.create returns HttpxBinaryResponseContent (.read() or .content)
+ if hasattr(response, "read") and callable(getattr(response, "read")):
+ raw_bytes = response.read()
+ elif hasattr(response, "content"):
+ raw_bytes = response.content # type: ignore[assignment]
+ else:
+ raise TypeError(f"Unsupported audio speech response type: {type(response)}")
+
+ if response_format == "pcm":
+ transcript = None
+ else:
+ transcript = convert_audio_bytes_to_text(raw_bytes)
+
+ result.audio_bytes = raw_bytes
+ result.audio_content = transcript
+ result.e2e_latency = time.perf_counter() - start_time
+ result.success = True
+ result.audio_format = getattr(response, "response", None)
+ if result.audio_format is not None:
+ result.audio_format = result.audio_format.headers.get("content-type", "")
+
+ except Exception as e:
+ result.error_message = f"Audio speech non-stream processing error: {str(e)}"
+ print(f"Error: {result.error_message}")
+
+ return result
+
+ def send_audio_speech_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]:
+ """
+ Call the /v1/audio/speech endpoint using the same configuration-dict
+ style as send_omni_request, but via the OpenAI Python client's
+ audio.speech APIs.
+
+ Expected keys in request_config:
+ - model: model name/path (required)
+ - input: text to synthesize (required)
+ - response_format: audio format such as "wav" or "pcm" (optional)
+ - task_type, ref_text, ref_audio: TTS-specific extras (optional, passed via extra_body)
+ - timeout: request timeout in seconds (float, optional, default 120.0)
+ - stream: whether to use streaming API (bool, optional, default False)
+ """
+ timeout = float(request_config.get("timeout", 120.0))
+
+ model = request_config["model"]
+ text_input = request_config["input"]
+ stream = bool(request_config.get("stream", False))
+ voice = request_config.get("voice", None)
+
+ # Standard OpenAI param: use omit when not provided to keep default behavior.
+ response_format = request_config.get("response_format", omit)
+
+ # Qwen3-TTS custom fields, forwarded via extra_body.
+ extra_body: dict[str, Any] = {}
+ # Keep this list aligned with vllm_omni.entrypoints.openai.protocol.audio params.
+ for key in ("task_type", "ref_text", "ref_audio", "language", "max_new_tokens"):
+ if key in request_config:
+ extra_body[key] = request_config[key]
+
+ responses: list[OmniResponse] = []
+
+ speech_fmt: str | None = None if response_format is omit else str(response_format).lower()
+
+ if request_num == 1:
+ if stream:
+ # Use streaming response helper.
+ with self.client.audio.speech.with_streaming_response.create(
+ model=model,
+ input=text_input,
+ response_format=response_format,
+ extra_body=extra_body or None,
+ timeout=timeout,
+ voice=voice,
+ ) as resp:
+ omni_resp = self._process_stream_audio_speech_response(resp, response_format=speech_fmt)
+ else:
+ # Non-streaming response.
+ resp = self.client.audio.speech.create(
+ model=model,
+ input=text_input,
+ response_format=response_format,
+ extra_body=extra_body or None,
+ timeout=timeout,
+ voice=voice,
+ )
+ omni_resp = self._process_non_stream_audio_speech_response(resp, response_format=speech_fmt)
+
+ assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level)
+ responses.append(omni_resp)
+ return responses
+ else:
+ # request_num > 1: concurrent requests (use same params as single-request path)
+
+ if stream:
+
+ def _stream_task():
+ with self.client.audio.speech.with_streaming_response.create(
+ model=model,
+ input=text_input,
+ response_format=response_format,
+ extra_body=extra_body or None,
+ timeout=timeout,
+ voice=voice,
+ ) as resp:
+ return self._process_stream_audio_speech_response(resp, response_format=speech_fmt)
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
+ futures = [executor.submit(_stream_task) for _ in range(request_num)]
+ for future in concurrent.futures.as_completed(futures):
+ omni_resp = future.result()
+ assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level)
+ responses.append(omni_resp)
+ else:
+ with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
+ futures = []
+ for _ in range(request_num):
+ future = executor.submit(
+ self.client.audio.speech.create,
+ model=model,
+ input=text_input,
+ response_format=response_format,
+ extra_body=extra_body or None,
+ timeout=timeout,
+ voice=voice,
+ )
+ futures.append(future)
+
+ for future in concurrent.futures.as_completed(futures):
+ resp = future.result()
+ omni_resp = self._process_non_stream_audio_speech_response(resp, response_format=speech_fmt)
+ assert_audio_speech_response(omni_resp, request_config, run_level=self.run_level)
+ responses.append(omni_resp)
+
+ return responses
+
+ def send_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[DiffusionResponse]:
+ """
+ Send OpenAI requests for diffusion models.
+ Args:
+ request_config: Request configuration dictionary containing parameters like model, messages
+ request_num: Number of requests to send concurrently, defaults to 1 (single request)
+ Returns:
+ list[DiffusionResponse]: List of DiffusionResponse objects containing the response data
+ """
+ responses: list[DiffusionResponse] = []
+ stream = request_config.get("stream", False)
+ modalities = request_config.get("modalities", omit) # Most diffusion models don't require modalities param
+ extra_body = request_config.get("extra_body", None)
+ if stream:
+ raise NotImplementedError("Streaming is not currently implemented for diffusion model e2e test")
+ if request_num == 1:
+ # Send single request
+ chat_completion = self.client.chat.completions.create(
+ model=request_config.get("model"),
+ messages=request_config.get("messages"),
+ extra_body=extra_body,
+ modalities=modalities,
+ )
+ response = self._process_diffusion_response(chat_completion)
+ assert_diffusion_response(response, request_config, run_level=self.run_level)
+ responses.append(response)
+ else:
+ # Send concurrent requests
+ with concurrent.futures.ThreadPoolExecutor(max_workers=request_num) as executor:
+ futures = []
+ # Submit all request tasks
+ for _ in range(request_num):
+ future = executor.submit(
+ self.client.chat.completions.create,
+ model=request_config.get("model"),
+ messages=request_config.get("messages"),
+ modalities=modalities,
+ extra_body=extra_body,
+ )
+ futures.append(future)
+ # Process completed tasks
+ for future in concurrent.futures.as_completed(futures):
+ chat_completion = future.result()
+ response = self._process_diffusion_response(chat_completion)
+ assert_diffusion_response(response, request_config, run_level=self.run_level)
+ responses.append(response)
+ return responses
+
+ def send_video_diffusion_request(self, request_config: dict[str, Any], request_num: int = 1) -> list[OmniResponse]:
+ """
+ Send native /v1/videos requests.
+ """
+ if request_num != 1:
+ raise NotImplementedError("Concurrent video diffusion requests are not currently implemented")
+ form_data = request_config.get("form_data")
+ if not isinstance(form_data, dict):
+ raise ValueError("Video request_config must contain 'form_data'")
+ normalized_form_data = {key: str(value) for key, value in form_data.items() if value is not None}
+ files: dict[str, tuple[str, BytesIO, str]] = {}
+ image_reference = request_config.get("image_reference")
+ if image_reference:
+ if image_reference.startswith("data:image"):
+ header, encoded = image_reference.split(",", 1)
+ content_type = header.split(";")[0].removeprefix("data:")
+ extension = content_type.split("/")[-1]
+ file_data = base64.b64decode(encoded)
+ files["input_reference"] = (f"reference.{extension}", BytesIO(file_data), content_type)
+ else:
+ normalized_form_data["image_reference"] = json.dumps({"image_url": image_reference})
+
+ result = DiffusionResponse()
+ start_time = time.perf_counter()
+ create_url = self._build_url("/v1/videos")
+ response = requests.post(
+ create_url,
+ data=normalized_form_data,
+ files=files,
+ headers={"Accept": "application/json"},
+ timeout=60,
+ )
+ response.raise_for_status()
+ job_data = response.json()
+ video_id = job_data["id"]
+ self._wait_until_video_completed(video_id)
+ video_content = self._download_video_content(video_id)
+ result.success = True
+ result.videos = [video_content]
+ result.e2e_latency = time.perf_counter() - start_time
+ assert_diffusion_response(result, request_config, run_level=self.run_level)
+ return [result]
+
+ def _wait_until_video_completed(
+ self, video_id: str, poll_interval_seconds: int = 2, timeout_seconds: int = 300
+ ) -> None:
+ status_url = self._build_url(f"/v1/videos/{video_id}")
+ deadline = time.monotonic() + timeout_seconds
+ while time.monotonic() < deadline:
+ status_resp = requests.get(status_url, headers={"Accept": "application/json"}, timeout=30)
+ status_resp.raise_for_status()
+ status_data = status_resp.json()
+ current_status = status_data["status"]
+ if current_status == "completed":
+ return
+ if current_status == "failed":
+ error_msg = status_data.get("last_error", "Unknown error")
+ raise RuntimeError(f"Job failed: {error_msg}")
+ time.sleep(poll_interval_seconds)
+ raise TimeoutError(f"Video job {video_id} did not complete within {timeout_seconds}s")
+
+ def _download_video_content(self, video_id: str) -> bytes:
+ download_url = self._build_url(f"/v1/videos/{video_id}/content")
+ video_resp = requests.get(download_url, stream=True, timeout=60)
+ video_resp.raise_for_status()
+ video_bytes = BytesIO()
+ for chunk in video_resp.iter_content(chunk_size=8192):
+ if chunk:
+ video_bytes.write(chunk)
+ return video_bytes.getvalue()
+
+ def _build_url(self, path: str) -> str:
+ return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"
+
+
+class OmniRunner:
+ def __init__(
+ self,
+ model_name: str,
+ seed: int = 42,
+ stage_init_timeout: int = 600,
+ batch_timeout: int = 10,
+ init_timeout: int = 900,
+ shm_threshold_bytes: int = 65536,
+ log_stats: bool = False,
+ stage_configs_path: str | None = None,
+ **kwargs,
+ ) -> None:
+ cleanup_dist_env_and_memory()
+ run_forced_gpu_cleanup_round()
+ self.model_name = model_name
+ self.seed = seed
+ self._prompt_len_estimate_cache: dict[str, Any] = {}
+ from vllm_omni.entrypoints.omni import Omni
+
+ self.omni = Omni(
+ model=model_name,
+ log_stats=log_stats,
+ stage_init_timeout=stage_init_timeout,
+ batch_timeout=batch_timeout,
+ init_timeout=init_timeout,
+ shm_threshold_bytes=shm_threshold_bytes,
+ stage_configs_path=stage_configs_path,
+ **kwargs,
+ )
+
+ def get_default_sampling_params_list(self) -> list[Any]:
+ if not hasattr(self.omni, "default_sampling_params_list"):
+ raise AttributeError("Omni.default_sampling_params_list is not available")
+ return list(self.omni.default_sampling_params_list)
+
+ def _estimate_prompt_len(
+ self,
+ additional_information: dict[str, Any],
+ model_name: str,
+ ) -> int:
+ """Estimate prompt_token_ids placeholder length for the Talker stage.
+
+ The AR Talker replaces all input embeddings via ``preprocess``, so the
+ placeholder values are irrelevant but the **length** must match the
+ embeddings that ``preprocess`` will produce.
+ """
+ _cache = self._prompt_len_estimate_cache
+ try:
+ from vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts import Qwen3TTSConfig
+ from vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_talker import (
+ Qwen3TTSTalkerForConditionalGeneration,
+ )
+
+ if model_name not in _cache:
+ from transformers import AutoTokenizer
+
+ tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left")
+ cfg = Qwen3TTSConfig.from_pretrained(model_name, trust_remote_code=True)
+ _cache[model_name] = (tok, getattr(cfg, "talker_config", None))
+
+ tok, tcfg = _cache[model_name]
+ task_type = (additional_information.get("task_type") or ["CustomVoice"])[0]
+ return Qwen3TTSTalkerForConditionalGeneration.estimate_prompt_len_from_additional_information(
+ additional_information=additional_information,
+ task_type=task_type,
+ tokenize_prompt=lambda t: tok(t, padding=False)["input_ids"],
+ codec_language_id=getattr(tcfg, "codec_language_id", None),
+ spk_is_dialect=getattr(tcfg, "spk_is_dialect", None),
+ )
+ except Exception as exc:
+ logger.warning("Failed to estimate prompt length, using fallback 2048: %s", exc)
+ return 2048
+
+ def get_omni_inputs(
+ self,
+ prompts: list[str] | str,
+ system_prompt: str | None = None,
+ audios: PromptAudioInput = None,
+ images: PromptImageInput = None,
+ videos: PromptVideoInput = None,
+ mm_processor_kwargs: dict[str, Any] | None = None,
+ modalities: list[str] | None = None,
+ ) -> list[TextPrompt]:
+ if system_prompt is None:
+ system_prompt = (
+ "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
+ "Group, capable of perceiving auditory and visual inputs, as well as "
+ "generating text and speech."
+ )
+ video_padding_token = "<|VIDEO|>"
+ image_padding_token = "<|IMAGE|>"
+ audio_padding_token = "<|AUDIO|>"
+ if "Qwen3-Omni-30B-A3B-Instruct" in self.model_name:
+ video_padding_token = "<|video_pad|>"
+ image_padding_token = "<|image_pad|>"
+ audio_padding_token = "<|audio_pad|>"
+ elif "Ming-flash-omni" in self.model_name:
+ video_padding_token = ""
+ image_padding_token = ""
+ audio_padding_token = ""
+ if isinstance(prompts, str):
+ prompts = [prompts]
+
+ # Qwen-TTS: follow examples/offline_inference/qwen3_tts/end2end.py style.
+ # Stage 0 expects token placeholders + additional_information (text/speaker/task_type/...),
+ # and Talker replaces embeddings in preprocess based on additional_information only.
+ is_tts_model = "Qwen3-TTS" in self.model_name or "qwen3_tts" in self.model_name.lower()
+ if is_tts_model and modalities == ["audio"]:
+ tts_kw = mm_processor_kwargs or {}
+ task_type = tts_kw.get("task_type", "CustomVoice")
+ speaker = tts_kw.get("speaker", "Vivian")
+ language = tts_kw.get("language", "Auto")
+ max_new_tokens = int(tts_kw.get("max_new_tokens", 2048))
+ ref_audio = tts_kw.get("ref_audio", None)
+ ref_text = tts_kw.get("ref_text", None)
+
+ omni_inputs: list[TextPrompt] = []
+ for prompt_text in prompts:
+ text_str = str(prompt_text).strip() or " "
+ additional_information: dict[str, Any] = {
+ "task_type": [task_type],
+ "text": [text_str],
+ "language": [language],
+ "speaker": [speaker],
+ "max_new_tokens": [max_new_tokens],
+ }
+ if ref_audio is not None:
+ additional_information["ref_audio"] = [ref_audio]
+ if ref_text is not None:
+ additional_information["ref_text"] = [ref_text]
+ plen = self._estimate_prompt_len(additional_information, self.model_name)
+ input_dict: TextPrompt = {
+ "prompt_token_ids": [0] * plen,
+ "additional_information": additional_information,
+ }
+ omni_inputs.append(input_dict)
+ return omni_inputs
+
+ def _normalize(mm_input, num_prompts):
+ if mm_input is None:
+ return [None] * num_prompts
+ if isinstance(mm_input, list):
+ if len(mm_input) != num_prompts:
+ raise ValueError("Multimodal input list length must match prompts length")
+ return mm_input
+ return [mm_input] * num_prompts
+
+ num_prompts = len(prompts)
+ audios_list = _normalize(audios, num_prompts)
+ images_list = _normalize(images, num_prompts)
+ videos_list = _normalize(videos, num_prompts)
+
+ omni_inputs = []
+ for i, prompt_text in enumerate(prompts):
+ user_content = ""
+ multi_modal_data = {}
+ audio = audios_list[i]
+ if audio is not None:
+ if isinstance(audio, list):
+ for _ in audio:
+ user_content += f"<|audio_bos|>{audio_padding_token}<|audio_eos|>"
+ multi_modal_data["audio"] = audio
+ else:
+ user_content += f"<|audio_bos|>{audio_padding_token}<|audio_eos|>"
+ multi_modal_data["audio"] = audio
+ image = images_list[i]
+ if image is not None:
+ if isinstance(image, list):
+ for _ in image:
+ user_content += f"<|vision_bos|>{image_padding_token}<|vision_eos|>"
+ multi_modal_data["image"] = image
+ else:
+ user_content += f"<|vision_bos|>{image_padding_token}<|vision_eos|>"
+ multi_modal_data["image"] = image
+ video = videos_list[i]
+ if video is not None:
+ if isinstance(video, list):
+ for _ in video:
+ user_content += f"<|vision_bos|>{video_padding_token}<|vision_eos|>"
+ multi_modal_data["video"] = video
+ else:
+ user_content += f"<|vision_bos|>{video_padding_token}<|vision_eos|>"
+ multi_modal_data["video"] = video
+ user_content += prompt_text
+
+ full_prompt = (
+ f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
+ f"<|im_start|>user\n{user_content}<|im_end|>\n"
+ f"<|im_start|>assistant\n"
+ )
+ input_dict: dict[str, Any] = {"prompt": full_prompt}
+ if multi_modal_data:
+ input_dict["multi_modal_data"] = multi_modal_data
+ if modalities:
+ input_dict["modalities"] = modalities
+ if mm_processor_kwargs:
+ input_dict["mm_processor_kwargs"] = mm_processor_kwargs
+ omni_inputs.append(input_dict)
+ return omni_inputs
+
+ def generate(
+ self,
+ prompts: list[Any],
+ sampling_params_list: list[Any] | None = None,
+ ) -> list[Any]:
+ if sampling_params_list is None:
+ sampling_params_list = self.get_default_sampling_params_list()
+ return self.omni.generate(prompts, sampling_params_list)
+
+ def generate_multimodal(
+ self,
+ prompts: list[str] | str,
+ sampling_params_list: list[Any] | None = None,
+ system_prompt: str | None = None,
+ audios: PromptAudioInput = None,
+ images: PromptImageInput = None,
+ videos: PromptVideoInput = None,
+ mm_processor_kwargs: dict[str, Any] | None = None,
+ modalities: list[str] | None = None,
+ ) -> list[Any]:
+ omni_inputs = self.get_omni_inputs(
+ prompts=prompts,
+ system_prompt=system_prompt,
+ audios=audios,
+ images=images,
+ videos=videos,
+ mm_processor_kwargs=mm_processor_kwargs,
+ modalities=modalities,
+ )
+ return self.generate(omni_inputs, sampling_params_list)
+
+ def start_profile(self, profile_prefix: str | None = None, stages: list[int] | None = None) -> list[Any]:
+ return self.omni.start_profile(profile_prefix=profile_prefix, stages=stages)
+
+ def stop_profile(self, stages: list[int] | None = None) -> list[Any]:
+ return self.omni.stop_profile(stages=stages)
+
+ def _cleanup_process(self):
+ try:
+ keywords = ["enginecore"]
+ matched = []
+ for proc in psutil.process_iter(["pid", "name", "cmdline", "username"]):
+ try:
+ cmdline = " ".join(proc.cmdline()).lower() if proc.cmdline() else ""
+ name = proc.name().lower()
+ if any(k in cmdline for k in keywords) or any(k in name for k in keywords):
+ print(f"Found vllm process: PID={proc.pid}, cmd={cmdline[:100]}")
+ matched.append(proc)
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
+ pass
+ for proc in matched:
+ try:
+ proc.terminate()
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
+ pass
+ _, still_alive = psutil.wait_procs(matched, timeout=5)
+ for proc in still_alive:
+ try:
+ proc.kill()
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
+ pass
+ if still_alive:
+ _, stubborn = psutil.wait_procs(still_alive, timeout=3)
+ if stubborn:
+ print(f"Warning: failed to kill residual vllm pids: {[p.pid for p in stubborn]}")
+ else:
+ print(f"Force-killed residual vllm pids: {[p.pid for p in still_alive]}")
+ elif matched:
+ print(f"Terminated vllm pids: {[p.pid for p in matched]}")
+ except Exception as e:
+ print(f"Error in psutil vllm cleanup: {e}")
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if hasattr(self.omni, "close"):
+ self.omni.close()
+ self._cleanup_process()
+ run_forced_gpu_cleanup_round()
+ cleanup_dist_env_and_memory()
+
+
+class OmniRunnerHandler:
+ def __init__(self, omni_runner):
+ self.runner = omni_runner
+
+ def _process_output(self, outputs: list[Any]) -> OmniResponse:
+ result = OmniResponse()
+ try:
+ text_content = None
+ audio_content = None
+ for stage_output in outputs:
+ if getattr(stage_output, "final_output_type", None) == "text":
+ text_content = stage_output.request_output.outputs[0].text
+ if getattr(stage_output, "final_output_type", None) == "audio":
+ audio_content = stage_output.request_output.outputs[0].multimodal_output["audio"]
+ result.audio_content = audio_content
+ result.text_content = text_content
+ result.success = True
+ except Exception as e:
+ result.error_message = f"Output processing error: {str(e)}"
+ result.success = False
+ print(f"Error: {result.error_message}")
+ return result
+
+ def send_request(self, request_config: dict[str, Any] | None = None) -> OmniResponse:
+ if request_config is None:
+ request_config = {}
+ prompts = request_config.get("prompts")
+ videos = request_config.get("videos")
+ images = request_config.get("images")
+ audios = request_config.get("audios")
+ modalities = request_config.get("modalities", ["text", "audio"])
+ outputs = self.runner.generate_multimodal(
+ prompts=prompts, videos=videos, images=images, audios=audios, modalities=modalities
+ )
+ response = self._process_output(outputs)
+ assert_omni_response(response, request_config, run_level="core_model")
+ return response
+
+ def send_audio_speech_request(self, request_config: dict[str, Any]) -> OmniResponse:
+ """
+ Offline TTS: text -> audio via generate_multimodal, then validate with assert_audio_speech_response.
+
+ request_config must contain:
+ - 'input' or 'prompts': text to synthesize.
+ Optional keys:
+ - 'voice' -> speaker (CustomVoice)
+ - 'task_type' -> task_type in additional_information (default: "CustomVoice")
+ - 'language' -> language in additional_information (default: "Auto")
+ - 'max_new_tokens' -> max_new_tokens in additional_information (default: 2048)
+ - 'response_format' -> desired audio format (used only for assertion)
+ """
+ input_text = request_config.get("input") or request_config.get("prompts")
+ if input_text is None:
+ raise ValueError("request_config must contain 'input' or 'prompts' for TTS")
+ if isinstance(input_text, list):
+ input_text = input_text[0] if input_text else ""
+
+ mm_processor_kwargs: dict[str, Any] = {}
+ if "voice" in request_config:
+ mm_processor_kwargs["speaker"] = request_config["voice"]
+ if "task_type" in request_config:
+ mm_processor_kwargs["task_type"] = request_config["task_type"]
+ if "ref_audio" in request_config:
+ mm_processor_kwargs["ref_audio"] = request_config["ref_audio"]
+ if "ref_text" in request_config:
+ mm_processor_kwargs["ref_text"] = request_config["ref_text"]
+ if "language" in request_config:
+ mm_processor_kwargs["language"] = request_config["language"]
+ if "max_new_tokens" in request_config:
+ mm_processor_kwargs["max_new_tokens"] = request_config["max_new_tokens"]
+
+ outputs = self.runner.generate_multimodal(
+ prompts=input_text,
+ modalities=["audio"],
+ mm_processor_kwargs=mm_processor_kwargs or None,
+ )
+ mm_out: dict[str, Any] | None = None
+ for stage_out in outputs:
+ if getattr(stage_out, "final_output_type", None) == "audio":
+ mm_out = stage_out.request_output.outputs[0].multimodal_output
+ break
+ if mm_out is None:
+ result = OmniResponse(success=False, error_message="No audio output from pipeline")
+ assert result.success, result.error_message
+ return result
+
+ audio_data = mm_out.get("audio")
+ if audio_data is None:
+ result = OmniResponse(success=False, error_message="No audio tensor in multimodal output")
+ assert result.success, result.error_message
+ return result
+
+ sr_raw = mm_out.get("sr")
+ sr_val = sr_raw[-1] if isinstance(sr_raw, list) and sr_raw else sr_raw
+ sr = int(sr_val.item() if hasattr(sr_val, "item") else sr_val)
+ wav_tensor = torch.cat(audio_data, dim=-1) if isinstance(audio_data, list) else audio_data
+ wav_buf = io.BytesIO()
+ sf.write(
+ wav_buf,
+ wav_tensor.float().cpu().numpy().reshape(-1),
+ samplerate=sr,
+ format="WAV",
+ subtype="PCM_16",
+ )
+ result = OmniResponse(success=True, audio_bytes=wav_buf.getvalue(), audio_format="audio/wav")
+ assert_audio_speech_response(result, request_config, run_level="core_model")
+ return result
+
+ def start_profile(self, profile_prefix: str | None = None, stages: list[int] | None = None) -> list[Any]:
+ return self.runner.start_profile(profile_prefix=profile_prefix, stages=stages)
+
+ def stop_profile(self, stages: list[int] | None = None) -> list[Any]:
+ return self.runner.stop_profile(stages=stages)
+
+
+__all__ = [
+ "DiffusionResponse",
+ "OmniResponse",
+ "OmniRunner",
+ "OmniRunnerHandler",
+ "OmniServer",
+ "OmniServerParams",
+ "OmniServerStageCli",
+ "OpenAIClientHandler",
+ "get_open_port",
+ "run_forced_gpu_cleanup_round",
+ "dummy_messages_from_mix_data",
+]
diff --git a/tests/helpers/stage_config.py b/tests/helpers/stage_config.py
new file mode 100644
index 00000000000..29a80372ecf
--- /dev/null
+++ b/tests/helpers/stage_config.py
@@ -0,0 +1,475 @@
+"""Config/message construction helpers used by tests."""
+
+import atexit
+import os
+import tempfile
+from pathlib import Path
+from typing import Any
+
+import yaml
+
+
+def modify_stage_config(
+ yaml_path: str,
+ updates: dict[str, Any] = None,
+ deletes: dict[str, Any] = None,
+) -> str:
+ """
+ Modify configurations in a YAML file, supporting both top-level and stage-specific modifications,
+ including addition, modification, and deletion of configurations.
+
+ Args:
+ yaml_path: Path to the YAML configuration file.
+ updates: Dictionary containing both top-level and stage-specific modifications to add or update.
+ Format: {
+ 'async_chunk': True,
+ 'stage_args': {
+ 0: {'engine_args.max_model_len': 5800},
+ 1: {'engine_args.max_num_seqs': 2}
+ }
+ }
+ deletes: Dictionary containing configurations to delete.
+ Format: {
+ 'old_config': None, # Delete entire key
+ 'stage_args': {
+ 0: ['engine_args.old_param'],
+ 1: ['runtime.unused_setting']
+ }
+ }
+
+ Returns:
+ str: Path to the newly created modified YAML file with timestamp suffix.
+ """
+ path = Path(yaml_path)
+ if not path.exists():
+ raise FileNotFoundError(f"yaml does not exist: {path}")
+
+ try:
+ with open(yaml_path, encoding="utf-8") as f:
+ config = yaml.safe_load(f) or {}
+ except Exception as e:
+ raise ValueError(f"Cannot parse YAML file: {e}")
+
+ # Helper function to apply update
+ def apply_update(config_dict: dict, key_path: str, value: Any) -> None:
+ """Apply update to dictionary using dot-separated path."""
+ # Handle direct list assignment (e.g., engine_input_source: [1, 2])
+ if "." not in key_path:
+ # Simple key, set directly
+ config_dict[key_path] = value
+ return
+
+ current = config_dict
+ keys = key_path.split(".")
+
+ for i in range(len(keys) - 1):
+ key = keys[i]
+
+ # Handle list indices
+ if key.isdigit() and isinstance(current, list):
+ index = int(key)
+ if index < 0:
+ raise ValueError(f"Negative list index not allowed: {index}")
+ if index >= len(current):
+ # Expand list if needed
+ while len(current) <= index:
+ # If we need to go deeper (more keys after this), create a dict
+ # Otherwise, create None placeholder
+ current.append({} if i < len(keys) - 2 else None)
+ current = current[index]
+ elif isinstance(current, dict):
+ # Handle dictionary keys
+ if key not in current:
+ # If there are more keys after this, create appropriate structure
+ if i < len(keys) - 1:
+ # Check if next key is a digit (list index) or string (dict key)
+ if keys[i + 1].isdigit():
+ current[key] = []
+ else:
+ current[key] = {}
+ else:
+ # This is the last key, create based on value type
+ current[key] = [] if isinstance(value, list) else {}
+ elif not isinstance(current[key], (dict, list)) and i < len(keys) - 1:
+ # If current value is not dict/list but we need to go deeper, replace it
+ if keys[i + 1].isdigit():
+ current[key] = []
+ else:
+ current[key] = {}
+ current = current[key]
+ else:
+ # Current is not a dict or list, cannot traverse further
+ raise TypeError(
+ f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list. It's a {type(current).__name__}"
+ )
+
+ # Set the final value
+ last_key = keys[-1]
+ if isinstance(current, list) and last_key.isdigit():
+ # Setting a value in a list by index
+ index = int(last_key)
+ if index < 0:
+ raise ValueError(f"Negative list index not allowed: {index}")
+ if index >= len(current):
+ # Expand list if needed
+ while len(current) <= index:
+ current.append(None)
+ current[index] = value
+ elif isinstance(current, dict):
+ # Special case: if the value is a list and we're setting a top-level key
+ # Example: updating engine_input_source with [1, 2]
+ current[last_key] = value
+ else:
+ # Current is not a dict, cannot set key
+ raise TypeError(f"Cannot set value at {key_path}. Current type is {type(current).__name__}, expected dict.")
+
+ # Helper function to delete by path
+ def delete_by_path(config_dict: dict, path: str) -> None:
+ """Delete configuration by dot-separated path."""
+ if not path:
+ return
+
+ current = config_dict
+ keys = path.split(".")
+
+ # Traverse to the parent
+ for i in range(len(keys) - 1):
+ key = keys[i]
+
+ # Handle list indices
+ if key.isdigit() and isinstance(current, list):
+ index = int(key)
+ if index < 0 or index >= len(current):
+ raise KeyError(f"List index {index} out of bounds")
+ current = current[index]
+ elif isinstance(current, dict):
+ if key not in current:
+ raise KeyError(f"Path {'.'.join(keys[: i + 1])} does not exist")
+ current = current[key]
+ else:
+ raise TypeError(
+ f"Cannot access {'.'.join(keys[: i + 1])} as a dict/list. It's a {type(current).__name__}"
+ )
+
+ # Delete the item
+ last_key = keys[-1]
+
+ if isinstance(current, list) and last_key.isdigit():
+ index = int(last_key)
+ if index < 0 or index >= len(current):
+ raise KeyError(f"List index {index} out of bounds")
+ del current[index]
+ elif isinstance(current, dict) and last_key in current:
+ del current[last_key]
+ else:
+ print(f"Path {path} does not exist")
+
+ _stage_key = "stages" if "stages" in config else "stage_args"
+
+ # Apply deletions first
+ if deletes:
+ for key, value in deletes.items():
+ if key in ("stage_args", "stages"):
+ if value and isinstance(value, dict):
+ stage_args = config.get(_stage_key, [])
+ if not stage_args:
+ raise ValueError("stage_args does not exist in config")
+
+ for stage_id, delete_paths in value.items():
+ if not delete_paths:
+ continue
+
+ # Find stage by ID
+ target_stage = None
+ for stage in stage_args:
+ if stage.get("stage_id") == int(stage_id):
+ target_stage = stage
+ break
+
+ if target_stage is None:
+ continue
+
+ # Delete specified paths in this stage
+ # Avoid shadowing the original YAML Path used for the output filename below.
+ for delete_path in delete_paths:
+ if delete_path: # Skip empty paths
+ delete_by_path(target_stage, delete_path)
+ elif "." in key:
+ # Delete using dot-separated path
+ delete_by_path(config, key)
+ elif value is None and key in config:
+ # Delete entire key
+ del config[key]
+
+ # Apply updates
+ if updates:
+ for key, value in updates.items():
+ if key in ("stage_args", "stages"):
+ if value and isinstance(value, dict):
+ stage_args = config.get(_stage_key, [])
+ if not stage_args:
+ raise ValueError("stage_args does not exist in config")
+
+ for stage_id, stage_updates in value.items():
+ # Find stage by ID
+ target_stage = None
+ for stage in stage_args:
+ if stage.get("stage_id") == int(stage_id):
+ target_stage = stage
+ break
+
+ if target_stage is None:
+ available_ids = [s.get("stage_id") for s in stage_args if "stage_id" in s]
+ raise KeyError(f"Stage ID {stage_id} not found, available: {available_ids}")
+
+ # Apply updates to this stage
+ for update_path, val in stage_updates.items():
+ # Check if this is a simple key (not dot-separated)
+ # Example: 'engine_input_source' vs 'engine_args.max_model_len'
+ if "." not in update_path:
+ # Direct key assignment (e.g., updating a list value)
+ target_stage[update_path] = val
+ else:
+ # Dot-separated path (e.g., nested dict access)
+ apply_update(target_stage, update_path, val)
+ elif "." in key:
+ # Apply using dot-separated path
+ apply_update(config, key, value)
+ else:
+ # Direct top-level key
+ config[key] = value
+
+ # Unique suffix: multiple modify_stage_config calls in one process often run
+ # within the same second (e.g. test_qwen3_omni_expansion imports both
+ # get_chunk_config and get_batch_token_config). int(time.time()) would collide
+ # and the later write would overwrite the earlier YAML on disk.
+ # Keep generated configs outside the repo and delete them when pytest exits.
+ output_fd, output_path = tempfile.mkstemp(prefix=f"{path.stem}_", suffix=".yaml")
+ atexit.register(Path(output_path).unlink, missing_ok=True)
+
+ with os.fdopen(output_fd, "w", encoding="utf-8") as f:
+ yaml.dump(config, f, default_flow_style=None, sort_keys=False, allow_unicode=True, indent=2)
+
+ return str(output_path)
+
+
+# ``stage_config.py`` lives under ``tests/helpers/``; repo root is three parents up.
+_REPO_ROOT = Path(__file__).resolve().parent.parent.parent
+_DEPLOY_DIR = _REPO_ROOT / "vllm_omni" / "deploy"
+_CI_GENERATED_DIR = _REPO_ROOT / "tests" / ".ci_generated"
+
+
+# CI overlays as Python dicts (LSP-friendly). Materialized on demand to
+# tests/.ci_generated/.yaml via get_deploy_config_path("ci/.yaml").
+_CI_OVERLAYS: dict[str, dict[str, Any]] = {
+ "qwen2_5_omni": {
+ "base_config": "qwen2_5_omni.yaml",
+ "async_chunk": False,
+ "stages": [
+ {
+ "stage_id": 0,
+ "max_model_len": 16384,
+ "max_num_batched_tokens": 16384,
+ "max_num_seqs": 1,
+ "gpu_memory_utilization": 0.9,
+ "skip_mm_profiling": True,
+ "load_format": "dummy",
+ "default_sampling_params": {"max_tokens": 128},
+ },
+ {
+ "stage_id": 1,
+ "max_model_len": 16384,
+ "max_num_batched_tokens": 16384,
+ "max_num_seqs": 1,
+ "gpu_memory_utilization": 0.4,
+ "skip_mm_profiling": True,
+ "load_format": "dummy",
+ "default_sampling_params": {"max_tokens": 4096},
+ },
+ {
+ "stage_id": 2,
+ "max_num_seqs": 1,
+ "gpu_memory_utilization": 0.5,
+ "max_num_batched_tokens": 8192,
+ "max_model_len": 8192,
+ "load_format": "dummy",
+ "devices": "2",
+ "default_sampling_params": {"max_tokens": 8192},
+ },
+ ],
+ "platforms": {
+ "rocm": {
+ "stages": [
+ {"stage_id": 0, "gpu_memory_utilization": 0.9},
+ {"stage_id": 1, "gpu_memory_utilization": 0.4},
+ {"stage_id": 2, "gpu_memory_utilization": 0.5, "devices": "2"},
+ ],
+ },
+ "xpu": {
+ "stages": [
+ {
+ "stage_id": 0,
+ "gpu_memory_utilization": 0.9,
+ "max_num_batched_tokens": 16384,
+ "max_model_len": 16384,
+ },
+ {"stage_id": 1, "gpu_memory_utilization": 0.5},
+ {
+ "stage_id": 2,
+ "gpu_memory_utilization": 0.3,
+ "max_num_batched_tokens": 4096,
+ "max_model_len": 4096,
+ "devices": "2",
+ },
+ ],
+ },
+ },
+ },
+ "qwen3_omni_moe": {
+ "base_config": "qwen3_omni_moe.yaml",
+ "async_chunk": False,
+ "stages": [
+ {
+ "stage_id": 0,
+ "max_num_seqs": 5,
+ "max_model_len": 32768,
+ "mm_processor_cache_gb": 0,
+ "load_format": "dummy",
+ "default_sampling_params": {"max_tokens": 150, "ignore_eos": False},
+ },
+ {
+ "stage_id": 1,
+ "gpu_memory_utilization": 0.5,
+ "max_num_seqs": 5,
+ "max_model_len": 32768,
+ "load_format": "dummy",
+ "default_sampling_params": {"max_tokens": 1000},
+ },
+ {
+ "stage_id": 2,
+ "max_num_seqs": 5,
+ "max_num_batched_tokens": 100000,
+ "load_format": "dummy",
+ "default_sampling_params": {"max_tokens": 2000},
+ },
+ ],
+ "platforms": {
+ "rocm": {
+ "stages": [
+ {"stage_id": 0, "max_num_seqs": 1, "default_sampling_params": {"max_tokens": 100}},
+ {
+ "stage_id": 1,
+ "max_num_seqs": 1,
+ "enforce_eager": True,
+ "default_sampling_params": {"max_tokens": 100},
+ },
+ {
+ "stage_id": 2,
+ "max_num_seqs": 1,
+ "max_num_batched_tokens": 1000000,
+ "default_sampling_params": {"max_tokens": 200},
+ },
+ ],
+ },
+ "xpu": {
+ "stages": [
+ {
+ "stage_id": 0,
+ "gpu_memory_utilization": 0.85,
+ "max_num_seqs": 1,
+ "tensor_parallel_size": 4,
+ "enforce_eager": True,
+ "max_num_batched_tokens": 4096,
+ "max_model_len": 4096,
+ "max_cudagraph_capture_size": 0,
+ "skip_mm_profiling": True,
+ "devices": "0,1,2,3",
+ "default_sampling_params": {"max_tokens": 100, "ignore_eos": False},
+ },
+ {
+ "stage_id": 1,
+ "gpu_memory_utilization": 0.6,
+ "max_num_seqs": 1,
+ "enforce_eager": True,
+ "max_num_batched_tokens": 4096,
+ "max_model_len": 4096,
+ "max_cudagraph_capture_size": 0,
+ "skip_mm_profiling": True,
+ "devices": "4",
+ },
+ {
+ "stage_id": 2,
+ "gpu_memory_utilization": 0.3,
+ "max_num_seqs": 1,
+ "max_num_batched_tokens": 100000,
+ "max_cudagraph_capture_size": 0,
+ "skip_mm_profiling": True,
+ "devices": "5",
+ "default_sampling_params": {"max_tokens": 2000},
+ },
+ ],
+ },
+ },
+ },
+ # Single-stage thinker-only topology for the abort test.
+ "qwen2_5_omni_thinker_only": {
+ "async_chunk": False,
+ "pipeline": "qwen2_5_omni_thinker_only",
+ "stages": [
+ {
+ "stage_id": 0,
+ "max_num_seqs": 1,
+ "gpu_memory_utilization": 0.9,
+ "enforce_eager": True,
+ "max_num_batched_tokens": 16384,
+ "max_model_len": 16384,
+ "skip_mm_profiling": True,
+ "mm_processor_cache_gb": 0,
+ "load_format": "dummy",
+ "devices": "0",
+ "default_sampling_params": {
+ "temperature": 0.0,
+ "top_p": 1.0,
+ "top_k": -1,
+ "max_tokens": 128,
+ "seed": 42,
+ "repetition_penalty": 1.1,
+ },
+ },
+ ],
+ },
+}
+
+
+def _materialize_ci_overlay(model_type: str) -> Path:
+ import yaml
+
+ if model_type not in _CI_OVERLAYS:
+ raise KeyError(f"No CI overlay registered for {model_type!r}. Available: {sorted(_CI_OVERLAYS)}")
+
+ _CI_GENERATED_DIR.mkdir(parents=True, exist_ok=True)
+ out = _CI_GENERATED_DIR / f"{model_type}.yaml"
+
+ overlay = {**_CI_OVERLAYS[model_type]}
+ base = overlay.get("base_config")
+ if base:
+ overlay["base_config"] = str(_DEPLOY_DIR / base)
+
+ with open(out, "w", encoding="utf-8") as f:
+ yaml.safe_dump(overlay, f, sort_keys=False)
+ return out
+
+
+def get_deploy_config_path(rel_path: str) -> str:
+ """Resolve a deploy yaml; ``ci/.yaml`` materializes from ``_CI_OVERLAYS``."""
+ if rel_path.startswith("ci/") and rel_path.endswith(".yaml"):
+ model_type = rel_path[len("ci/") : -len(".yaml")]
+ if model_type in _CI_OVERLAYS:
+ return str(_materialize_ci_overlay(model_type))
+ return str(_DEPLOY_DIR / rel_path)
+
+
+__all__ = [
+ "modify_stage_config",
+ "get_deploy_config_path",
+]
diff --git a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py
index ec24f6949fe..0e071f724e5 100644
--- a/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py
+++ b/tests/model_executor/models/cosyvoice3/test_cosyvoice3_components.py
@@ -8,7 +8,7 @@
import torch
import torch.nn as nn
-from tests.utils import hardware_test
+from tests.helpers.mark import hardware_test
class TestPreLookaheadLayer:
diff --git a/tests/model_executor/models/voxcpm2/__init__.py b/tests/model_executor/models/voxcpm2/__init__.py
new file mode 100644
index 00000000000..208f01a7cb5
--- /dev/null
+++ b/tests/model_executor/models/voxcpm2/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
diff --git a/tests/model_executor/models/voxcpm2/test_talker_state_eviction.py b/tests/model_executor/models/voxcpm2/test_talker_state_eviction.py
new file mode 100644
index 00000000000..5d8a35636b7
--- /dev/null
+++ b/tests/model_executor/models/voxcpm2/test_talker_state_eviction.py
@@ -0,0 +1,121 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Regression tests for VoxCPM2 talker per-request state lifecycle."""
+
+from __future__ import annotations
+
+import pytest
+
+torch = pytest.importorskip("torch")
+pytest.importorskip("librosa")
+
+from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import ( # noqa: E402
+ VoxCPM2TalkerForConditionalGeneration,
+ _RequestState,
+)
+
+
+def _make_bare_talker() -> VoxCPM2TalkerForConditionalGeneration:
+ talker = VoxCPM2TalkerForConditionalGeneration.__new__(VoxCPM2TalkerForConditionalGeneration)
+ talker._active_states = {}
+ talker._current_request_id = None
+ talker._pending_requests = []
+ talker._results_queue = []
+ talker._audio_queue = []
+ talker._deferred_cleanup_ids = set()
+ talker._max_batch_size = 4
+ talker._active_state_warn_threshold = 512
+ talker._active_state_warned = False
+ return talker
+
+
+def _seed_cached_decode(talker, req_id: str) -> _RequestState:
+ state = _RequestState(request_id=req_id)
+ state.prefill_completed = True
+ state.decode_step_count = 5
+ talker._active_states[req_id] = state
+ return state
+
+
+class TestStateEvictionContract:
+ def test_pending_requests_is_not_used_for_eviction(self) -> None:
+ talker = _make_bare_talker()
+
+ cached_ids = [f"req-{i}" for i in range(4)]
+ for rid in cached_ids:
+ _seed_cached_decode(talker, rid)
+
+ walked_so_far = ["req-new", cached_ids[0], cached_ids[1]]
+ talker._pending_requests = [(rid, False, None, 0) for rid in walked_so_far]
+
+ for rid in cached_ids:
+ assert rid in talker._active_states
+ assert talker._active_states[rid].prefill_completed is True
+
+ def test_on_requests_finished_defers_cleanup(self) -> None:
+ talker = _make_bare_talker()
+ _seed_cached_decode(talker, "req-A")
+ _seed_cached_decode(talker, "req-B")
+
+ talker.on_requests_finished({"req-A"})
+
+ assert "req-A" in talker._active_states
+ assert "req-A" in talker._deferred_cleanup_ids
+
+ def test_flush_deferred_cleanup_removes_only_finished(self) -> None:
+ talker = _make_bare_talker()
+ _seed_cached_decode(talker, "req-A")
+ _seed_cached_decode(talker, "req-B")
+ talker.on_requests_finished(["req-A"])
+
+ talker._flush_deferred_cleanup()
+
+ assert "req-A" not in talker._active_states
+ assert "req-B" in talker._active_states
+ assert talker._deferred_cleanup_ids == set()
+
+ def test_current_request_id_cleared_when_matching(self) -> None:
+ talker = _make_bare_talker()
+ _seed_cached_decode(talker, "req-A")
+ talker._current_request_id = "req-A"
+
+ talker.on_requests_finished({"req-A"})
+ talker._flush_deferred_cleanup()
+
+ assert talker._current_request_id is None
+
+ def test_current_request_id_preserved_when_not_finished(self) -> None:
+ talker = _make_bare_talker()
+ _seed_cached_decode(talker, "req-A")
+ _seed_cached_decode(talker, "req-B")
+ talker._current_request_id = "req-B"
+
+ talker.on_requests_finished({"req-A"})
+ talker._flush_deferred_cleanup()
+
+ assert talker._current_request_id == "req-B"
+
+
+class TestLeakWarnGuard:
+ def test_warn_fires_once_over_threshold(self, monkeypatch) -> None:
+ from vllm_omni.model_executor.models.voxcpm2 import voxcpm2_talker as tk
+
+ calls: list[str] = []
+
+ def _capture(msg, *args, **kwargs):
+ calls.append(msg % args if args else msg)
+
+ monkeypatch.setattr(tk.logger, "warning", _capture)
+
+ talker = _make_bare_talker()
+ talker._active_state_warn_threshold = 3
+
+ for i in range(4):
+ talker._active_states[f"seed-{i}"] = _RequestState(request_id=f"seed-{i}")
+
+ talker._get_or_create_state("new-1")
+ talker._get_or_create_state("new-2")
+
+ leak_warnings = [m for m in calls if "cleanup path leak" in m]
+ assert len(leak_warnings) == 1
+ assert talker._active_state_warned is True
diff --git a/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py
new file mode 100644
index 00000000000..18972c91d5d
--- /dev/null
+++ b/tests/model_executor/stage_input_processors/test_qwen3_omni_streaming_helpers.py
@@ -0,0 +1,81 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for Qwen3-Omni streaming thinker→talker / talker→codec helpers (PR #2581)."""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+import pytest
+
+import vllm_omni.model_executor.stage_input_processors.qwen3_omni as q3
+
+pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
+
+
+@pytest.fixture(autouse=True)
+def _streaming_context() -> SimpleNamespace:
+ return SimpleNamespace(bridge_states={})
+
+
+def test_get_streaming_talker_tokens_first_segment(_streaming_context: SimpleNamespace) -> None:
+ inc_p, inc_o, merged, thinker_in = q3._get_streaming_talker_tokens(
+ "r1",
+ [1, 2],
+ [10, 11],
+ streaming_context=_streaming_context,
+ )
+ assert inc_p == [1, 2]
+ assert inc_o == [10, 11]
+ assert merged == [1, 2, 10, 11]
+ assert thinker_in == [1, 2]
+
+
+def test_get_streaming_talker_tokens_second_segment_accumulates(_streaming_context: SimpleNamespace) -> None:
+ q3._get_streaming_talker_tokens("r2", [1, 2], [10, 11], streaming_context=_streaming_context)
+ inc_p, inc_o, merged, thinker_in = q3._get_streaming_talker_tokens(
+ "r2",
+ [1, 2, 3, 4],
+ [10, 11, 12, 13],
+ streaming_context=_streaming_context,
+ )
+ assert inc_p == [3, 4]
+ assert inc_o == [12, 13]
+ assert merged == [1, 2, 10, 3, 4, 12, 13]
+ assert thinker_in == [1, 2, 10, 3, 4]
+
+
+def test_get_streaming_talker_tokens_new_prompt_len_snapshot_truncates(
+ _streaming_context: SimpleNamespace,
+) -> None:
+ inc_p, inc_o, merged, thinker_in = q3._get_streaming_talker_tokens(
+ "r3",
+ [1, 2, 3, 4, 5, 6],
+ [10],
+ new_prompt_len_snapshot=2,
+ streaming_context=_streaming_context,
+ )
+ assert inc_p == [1, 2, 3, 4]
+ assert inc_o == [10]
+ assert merged == [1, 2, 3, 4, 10]
+ assert thinker_in == [1, 2, 3, 4]
+
+
+def test_get_streaming_talker_tokens_clear_state(_streaming_context: SimpleNamespace) -> None:
+ q3._get_streaming_talker_tokens("r4", [1], [2], streaming_context=_streaming_context, clear_state=True)
+ state = q3._get_qwen3_streaming_state("r4", _streaming_context).thinker2talker
+ assert state.last_prompt_len == 0
+ assert state.last_output_len == 0
+ assert state.merged_sequences == []
+
+
+def test_get_streaming_codec_delta_len_increments_and_finishes(_streaming_context: SimpleNamespace) -> None:
+ d1 = q3._get_streaming_codec_delta_len(5, "c1", SimpleNamespace(finished=False), _streaming_context)
+ assert d1 == 5
+ d2 = q3._get_streaming_codec_delta_len(8, "c1", SimpleNamespace(finished=False), _streaming_context)
+ assert d2 == 2
+ # After d2, stored cursor is cur_seq_len + 1 == 9; next delta uses new cur_seq_len - 9.
+ d3 = q3._get_streaming_codec_delta_len(10, "c1", SimpleNamespace(finished=True), _streaming_context)
+ assert d3 == 1
+ state = q3._get_qwen3_streaming_state("c1", _streaming_context)
+ assert state.talker2code2wav_last_seq_len == 0
diff --git a/tests/test_arg_utils.py b/tests/test_arg_utils.py
new file mode 100644
index 00000000000..dab5ed6878a
--- /dev/null
+++ b/tests/test_arg_utils.py
@@ -0,0 +1,353 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for vllm_omni.engine.arg_utils — invariants that must
+hold for the orchestrator/engine/server CLI flag partition."""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass, fields
+
+import pytest
+
+from vllm_omni.engine.arg_utils import (
+ SHARED_FIELDS,
+ derive_server_dests_from_vllm_parser,
+ internal_blacklist_keys,
+ orchestrator_args_from_argparse,
+ orchestrator_field_names,
+ split_kwargs,
+)
+
+# ---------------------------------------------------------------------------
+# Fake engine class for unit testing — avoids pulling in the full vllm
+# EngineArgs and its heavy __post_init__ at test time.
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class _FakeEngineArgs:
+ """Stand-in for OmniEngineArgs with a representative subset of fields."""
+
+ model: str = ""
+ stage_id: int = 0
+ max_num_seqs: int = 64
+ gpu_memory_utilization: float = 0.9
+ async_chunk: bool = False # also in OrchestratorArgs → shared
+ log_stats: bool = False # also in OrchestratorArgs → shared
+ stage_configs_path: str | None = None
+
+
+# ============================================================================
+# Invariant 1 — OrchestratorArgs and engine must not ambiguously overlap.
+# ============================================================================
+
+
+def test_no_ambiguous_overlap_with_fake_engine():
+ """OrchestratorArgs ∩ engine fields must be ⊆ SHARED_FIELDS."""
+ orch = orchestrator_field_names()
+ engine = {f.name for f in fields(_FakeEngineArgs)}
+ overlap = orch & engine
+ unexpected = overlap - SHARED_FIELDS
+ assert not unexpected, (
+ f"Fields declared in both OrchestratorArgs and the engine class "
+ f"but not in SHARED_FIELDS: {sorted(unexpected)}. These cause "
+ f"double-routing — either remove the duplicate declaration or add "
+ f"to SHARED_FIELDS if sharing is intentional."
+ )
+
+
+def test_no_ambiguous_overlap_with_real_engine():
+ """Same check, but against the real OmniEngineArgs."""
+ try:
+ from vllm_omni.engine.arg_utils import OmniEngineArgs
+ except Exception as exc:
+ pytest.skip(f"OmniEngineArgs not importable: {exc}")
+
+ orch = orchestrator_field_names()
+ engine = {f.name for f in fields(OmniEngineArgs)}
+ overlap = orch & engine
+ unexpected = overlap - SHARED_FIELDS
+ assert not unexpected, (
+ f"Real OmniEngineArgs has ambiguous overlap with OrchestratorArgs: "
+ f"{sorted(unexpected)}. Update SHARED_FIELDS or remove duplication."
+ )
+
+
+# ============================================================================
+# Invariant 2 — split_kwargs partitions correctly.
+# ============================================================================
+
+
+def test_split_orchestrator_only():
+ """Pure orchestrator fields go to OrchestratorArgs, not engine_kwargs."""
+ raw = {"stage_init_timeout": 500, "worker_backend": "ray"}
+ orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
+ assert orch.stage_init_timeout == 500
+ assert orch.worker_backend == "ray"
+ assert "stage_init_timeout" not in engine
+ assert "worker_backend" not in engine
+
+
+def test_split_engine_only():
+ """Pure engine fields go to engine_kwargs, not OrchestratorArgs."""
+ raw = {"max_num_seqs": 128, "gpu_memory_utilization": 0.85}
+ orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
+ assert engine["max_num_seqs"] == 128
+ assert engine["gpu_memory_utilization"] == 0.85
+ # These fields don't exist on OrchestratorArgs at all.
+
+
+def test_split_shared_fields_go_to_both():
+ """Fields in SHARED_FIELDS are copied to both buckets."""
+ raw = {"model": "Qwen/Qwen2.5-Omni-7B", "log_stats": True}
+ orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
+ assert orch.log_stats is True
+ assert engine["model"] == "Qwen/Qwen2.5-Omni-7B"
+ assert engine["log_stats"] is True
+
+
+def test_split_drops_unclassified():
+ """Unclassified fields (uvicorn/server) are dropped silently."""
+ raw = {
+ "max_num_seqs": 64, # engine
+ "host": "0.0.0.0", # unclassified (server)
+ "port": 8091, # unclassified (server)
+ "ssl_keyfile": "key.pem", # unclassified (server)
+ }
+ orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
+ assert engine == {"max_num_seqs": 64}
+ assert "host" not in engine
+ assert "port" not in engine
+ assert "ssl_keyfile" not in engine
+
+
+def test_split_mixed_real_world():
+ """End-to-end: raw CLI kwargs with all three classes present."""
+ raw = {
+ # orchestrator
+ "stage_init_timeout": 400,
+ "deploy_config": "/tmp/deploy.yaml",
+ "worker_backend": "multi_process",
+ "async_chunk": True,
+ # engine
+ "max_num_seqs": 32,
+ "gpu_memory_utilization": 0.8,
+ # shared
+ "model": "Qwen/Qwen3-Omni",
+ "log_stats": False,
+ # server / unclassified
+ "host": "0.0.0.0",
+ "port": 8091,
+ "api_key": "secret",
+ # None values
+ "ray_address": None,
+ }
+ orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
+
+ # Orchestrator side
+ assert orch.stage_init_timeout == 400
+ assert orch.deploy_config == "/tmp/deploy.yaml"
+ assert orch.worker_backend == "multi_process"
+ assert orch.async_chunk is True
+ assert orch.log_stats is False # shared, read from raw
+ assert orch.ray_address is None # default preserved
+
+ # Engine side
+ assert engine["max_num_seqs"] == 32
+ assert engine["gpu_memory_utilization"] == 0.8
+ assert engine["model"] == "Qwen/Qwen3-Omni"
+ assert engine["log_stats"] is False
+ assert "host" not in engine
+ assert "port" not in engine
+ assert "api_key" not in engine
+ # orchestrator-only keys never reach engine
+ assert "stage_init_timeout" not in engine
+ assert "deploy_config" not in engine
+ assert "async_chunk" not in engine
+
+
+# ============================================================================
+# Invariant 3 — user-typed unclassifiable flags warn (don't fail silently).
+# ============================================================================
+
+
+def test_user_typed_unclassified_warns(caplog):
+ """If the user types a flag we can't route, warn — don't silently drop."""
+ raw = {"bogus_flag": "value", "max_num_seqs": 64}
+ with caplog.at_level(logging.WARNING, logger="vllm_omni.engine.arg_utils"):
+ split_kwargs(raw, engine_cls=_FakeEngineArgs, user_typed={"bogus_flag"})
+ assert any("bogus_flag" in rec.message for rec in caplog.records), (
+ f"Expected warning mentioning 'bogus_flag', got: {[rec.message for rec in caplog.records]}"
+ )
+
+
+def test_unclassified_without_user_typed_silent(caplog):
+ """Without user_typed, unclassified keys drop silently (argparse defaults
+ for server flags shouldn't spam logs on every launch)."""
+ raw = {"host": "0.0.0.0", "port": 8091, "max_num_seqs": 64}
+ with caplog.at_level(logging.WARNING, logger="vllm_omni.engine.arg_utils"):
+ split_kwargs(raw, engine_cls=_FakeEngineArgs, user_typed=None)
+ # No warnings because we don't know these were user-typed.
+ assert not any("host" in rec.message or "port" in rec.message for rec in caplog.records)
+
+
+# ============================================================================
+# Invariant 4 — CLI flag classification completeness.
+# Catches new flags added without updating OrchestratorArgs or OmniEngineArgs.
+# ============================================================================
+
+
+def test_all_omni_cli_flags_classified():
+ """Every vllm-omni-added CLI flag must be classifiable.
+
+ Runs ``OmniServeCommand.subparser_init`` and checks that every new
+ argument (compared to vllm's base parser) is either:
+ - a field on OrchestratorArgs, OR
+ - a field on OmniEngineArgs, OR
+ - in SHARED_FIELDS
+ """
+ try:
+ from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+ from vllm_omni.engine.arg_utils import OmniEngineArgs
+ from vllm_omni.entrypoints.cli.serve import OmniServeCommand
+ except Exception as exc:
+ pytest.skip(f"Cannot build parser in this environment: {exc}")
+
+ # Build the serve parser
+ root = FlexibleArgumentParser()
+ subparsers = root.add_subparsers()
+ cmd = OmniServeCommand()
+ try:
+ parser = cmd.subparser_init(subparsers)
+ except Exception as exc:
+ pytest.skip(f"subparser_init failed (dev env issue): {exc}")
+
+ all_dests = {a.dest for a in parser._actions if a.dest and a.dest not in {"help", "model_tag"}}
+
+ orch = orchestrator_field_names()
+ engine = {f.name for f in fields(OmniEngineArgs)}
+ server_derived = derive_server_dests_from_vllm_parser()
+
+ unclassified = all_dests - orch - engine - SHARED_FIELDS - server_derived
+ # Some argparse-internal dests (suppressed, private) may not match —
+ # filter those out.
+ unclassified = {d for d in unclassified if not d.startswith("_")}
+
+ assert not unclassified, (
+ f"These CLI flags are not classified as "
+ f"orchestrator/engine/shared/server: {sorted(unclassified)}. "
+ f"Add them to OrchestratorArgs (if consumed by orchestrator), "
+ f"OmniEngineArgs (if consumed by per-stage engine), or the known-server "
+ f"allowlist (if they're vllm frontend flags). "
+ f"If intentional (e.g. a new CLI-only flag that doesn't map to either "
+ f"dataclass), add it to a KNOWN_UNROUTED allowlist."
+ )
+
+
+# ============================================================================
+# argparse interop (Phase 3).
+# ============================================================================
+
+
+def test_orchestrator_args_from_argparse():
+ """Can build OrchestratorArgs from an argparse.Namespace."""
+ import argparse
+
+ ns = argparse.Namespace(
+ stage_init_timeout=500,
+ deploy_config="/tmp/x.yaml",
+ max_num_seqs=64, # engine field — ignored
+ host="0.0.0.0", # server field — ignored
+ )
+ orch = orchestrator_args_from_argparse(ns)
+ assert orch.stage_init_timeout == 500
+ assert orch.deploy_config == "/tmp/x.yaml"
+ assert orch.worker_backend == "multi_process" # default
+
+
+def test_derive_server_dests_returns_frozenset():
+ """Server-dest derivation returns a frozenset (possibly empty)."""
+ result = derive_server_dests_from_vllm_parser()
+ assert isinstance(result, frozenset)
+
+
+# ============================================================================
+# internal_blacklist_keys — single source of truth for per-stage forwarding.
+# ============================================================================
+
+
+def test_internal_blacklist_keys_derived_from_orchestrator():
+ """Blacklist is exactly OrchestratorArgs fields minus SHARED_FIELDS.
+
+ This function replaces the old hardcoded INTERNAL_STAGE_OVERRIDE_KEYS
+ frozenset. Asserts the contract so future changes to OrchestratorArgs
+ automatically propagate to the blacklist.
+ """
+ blacklist = internal_blacklist_keys()
+ assert blacklist == orchestrator_field_names() - SHARED_FIELDS
+ # Spot-check expected entries
+ assert "stage_init_timeout" in blacklist
+ assert "deploy_config" in blacklist
+ assert "async_chunk" in blacklist
+ # Shared fields must NOT appear — they flow to both orchestrator and engine
+ assert "model" not in blacklist
+ assert "log_stats" not in blacklist
+
+
+# ============================================================================
+# Boundary value analysis — edge cases around split_kwargs.
+# ============================================================================
+
+
+def test_split_empty_kwargs():
+ """Empty kwargs yields default OrchestratorArgs and empty engine dict."""
+ orch, engine = split_kwargs({}, engine_cls=_FakeEngineArgs)
+ assert orch.stage_init_timeout == 300 # dataclass default
+ assert orch.worker_backend == "multi_process" # dataclass default
+ assert engine == {}
+
+
+def test_split_all_none_values_preserved_on_orchestrator():
+ """None values for orchestrator fields are kept (represents 'not set')."""
+ raw = {"ray_address": None, "deploy_config": None, "max_num_seqs": None}
+ orch, engine = split_kwargs(raw, engine_cls=_FakeEngineArgs)
+ assert orch.ray_address is None
+ assert orch.deploy_config is None
+ # Engine-side None still passes through; caller decides semantics downstream.
+ assert engine.get("max_num_seqs") is None
+
+
+def test_split_user_typed_with_empty_kwargs_no_warn(caplog):
+ """user_typed non-empty but kwargs empty — no warnings emitted."""
+ with caplog.at_level(logging.WARNING, logger="vllm_omni.engine.arg_utils"):
+ split_kwargs({}, engine_cls=_FakeEngineArgs, user_typed={"nothing"})
+ assert not caplog.records
+
+
+def test_ambiguous_field_strict_raises():
+ """strict=True raises ValueError on overlap outside SHARED_FIELDS."""
+
+ # deploy_config is on OrchestratorArgs; declaring it on the engine class
+ # too (without adding to SHARED_FIELDS) creates an ambiguous route.
+ @dataclass
+ class _AmbiguousEngine:
+ deploy_config: str | None = None
+
+ with pytest.raises(ValueError, match="both OrchestratorArgs and"):
+ split_kwargs({"deploy_config": "x"}, engine_cls=_AmbiguousEngine, strict=True)
+
+
+def test_ambiguous_field_non_strict_routes_to_orchestrator(caplog):
+ """strict=False logs ERROR but routes the ambiguous field to orchestrator."""
+
+ @dataclass
+ class _AmbiguousEngine:
+ deploy_config: str | None = None
+
+ with caplog.at_level(logging.ERROR, logger="vllm_omni.engine.arg_utils"):
+ orch, engine = split_kwargs({"deploy_config": "x"}, engine_cls=_AmbiguousEngine, strict=False)
+ assert orch.deploy_config == "x"
+ assert "deploy_config" not in engine
+ assert any("both OrchestratorArgs" in r.message for r in caplog.records)
diff --git a/tests/test_config_factory.py b/tests/test_config_factory.py
index e284de48d0b..4a271426ff3 100644
--- a/tests/test_config_factory.py
+++ b/tests/test_config_factory.py
@@ -4,12 +4,26 @@
Unit tests for StageConfigFactory and related classes.
"""
+from dataclasses import dataclass
+from pathlib import Path
+
+import pytest
+
from vllm_omni.config.stage_config import (
+ _EXECUTION_TYPE_TO_SCHEDULER,
+ _PIPELINE_REGISTRY,
ModelPipeline,
+ PipelineConfig,
StageConfig,
StageConfigFactory,
+ StageExecutionType,
+ StagePipelineConfig,
StageType,
+ build_stage_runtime_overrides,
+ register_pipeline,
+ strip_parent_engine_args,
)
+from vllm_omni.engine.arg_utils import SHARED_FIELDS, internal_blacklist_keys
class TestStageType:
@@ -241,8 +255,9 @@ def test_default_diffusion_no_yaml(self):
def test_default_diffusion_with_parallel_config(self):
"""Test diffusion config calculates devices from parallel_config."""
+ @dataclass
class MockParallelConfig:
- world_size = 4
+ world_size: int = 4
kwargs = {
"parallel_config": MockParallelConfig(),
@@ -270,7 +285,7 @@ def test_cli_override_forwards_engine_registered_args(self):
stage = StageConfig(stage_id=0, model_stage="thinker", input_sources=[])
cli_overrides = {
"gpu_memory_utilization": 0.9, # Well-known param
- "custom_engine_flag": True, # Not in _INTERNAL_KEYS, so forwarded
+ "custom_engine_flag": True, # Not orchestrator-owned, so forwarded
}
overrides = StageConfigFactory._merge_cli_overrides(stage, cli_overrides)
@@ -311,6 +326,56 @@ def test_per_stage_override_excludes_internal_keys(self):
assert "batch_timeout" not in overrides
+class TestStageResolutionHelpers:
+ """Tests for shared stage override / filtering helpers."""
+
+ def test_build_stage_runtime_overrides_ignores_other_stage_and_internal_keys(self):
+ # Pass the same filter set the function uses by default
+ # (orchestrator-only fields plus SHARED_FIELDS so ``model`` is
+ # treated as not-per-stage-overridable).
+ overrides = build_stage_runtime_overrides(
+ 0,
+ {
+ "gpu_memory_utilization": 0.5,
+ "stage_0_gpu_memory_utilization": 0.9,
+ "stage_1_gpu_memory_utilization": 0.1,
+ "stage_0_model": "should_be_ignored",
+ "parallel_config": {"world_size": 2},
+ },
+ internal_keys=internal_blacklist_keys() | SHARED_FIELDS,
+ )
+
+ assert overrides["gpu_memory_utilization"] == 0.9
+ assert "model" not in overrides
+ assert "parallel_config" not in overrides
+
+ def test_strip_parent_engine_args_reports_only_surprising_parent_overrides(self):
+ from dataclasses import fields as dc_fields
+
+ from vllm.engine.arg_utils import EngineArgs
+
+ parent_fields = {f.name: f for f in dc_fields(EngineArgs)}
+ filtered, overridden = strip_parent_engine_args(
+ {
+ "model": "some/model",
+ "stage_configs_path": "/tmp/stages.yaml",
+ "tensor_parallel_size": 4,
+ "worker_extension_cls": "some.Extension",
+ "custom_pipeline_args": {"pipeline_class": "demo.Pipeline"},
+ },
+ parent_fields=parent_fields,
+ keep_keys={"worker_extension_cls"},
+ strip_keys={"stage_configs_path"},
+ no_warn_keys={"model"},
+ )
+
+ assert filtered == {
+ "worker_extension_cls": "some.Extension",
+ "custom_pipeline_args": {"pipeline_class": "demo.Pipeline"},
+ }
+ assert overridden == ["tensor_parallel_size"]
+
+
class TestPipelineYamlParsing:
"""Tests for pipeline YAML file parsing (@ZJY0516)."""
@@ -609,16 +674,638 @@ def test_parse_missing_async_chunk_defaults_false(self, tmp_path):
assert pipeline.async_chunk is False
-class TestArchitectureFallback:
- """Tests for architecture-based model detection fallback."""
+class TestPipelineDiscovery:
+ """Tests for the central pipeline registry (``pipeline_registry._VLLM_OMNI_PIPELINES``)."""
+
+ def test_registry_has_known_models(self):
+ """Built-in pipelines are lazy-loaded from the central declaration
+ on first access; no eager import or discovery walk needed."""
+ # ``in`` triggers the lazy-map lookup without forcing a load.
+ assert "qwen2_5_omni" in _PIPELINE_REGISTRY
+ assert "qwen3_omni_moe" in _PIPELINE_REGISTRY
+ assert "qwen3_tts" in _PIPELINE_REGISTRY
+
+ def test_registry_loads_pipeline_on_getitem(self):
+ """Looking up a registered model_type returns the matching PipelineConfig."""
+ pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+ assert pipeline.model_type == "qwen3_omni_moe"
+ assert len(pipeline.stages) == 3 # thinker + talker + code2wav
+
+ def test_registry_returns_none_for_unknown(self):
+ """Unknown model_types aren't found; ``get()`` returns None."""
+ assert "definitely_not_a_real_model" not in _PIPELINE_REGISTRY
+ assert _PIPELINE_REGISTRY.get("definitely_not_a_real_model") is None
+
+ def test_pipeline_config_supports_hf_architectures(self):
+ """PipelineConfig accepts hf_architectures for HF-arch fallback
+ (replaces the old _ARCHITECTURE_MODELS dict)."""
+ p = PipelineConfig(
+ model_type="custom_collide",
+ hf_architectures=("SomeCollidingArch",),
+ )
+ assert p.hf_architectures == ("SomeCollidingArch",)
+
+
+class TestStagePipelineConfig:
+ def test_frozen(self):
+ s = StagePipelineConfig(stage_id=0, model_stage="a")
+ with pytest.raises(AttributeError):
+ s.model_stage = "changed"
+
+ def test_defaults(self):
+ s = StagePipelineConfig(stage_id=0, model_stage="a")
+ assert s.execution_type == StageExecutionType.LLM_AR
+ assert s.input_sources == ()
+ assert s.final_output is False
+ assert s.sampling_constraints == {}
+ assert s.engine_output_type is None
+
+
+class TestPipelineConfigNew:
+ def test_frozen(self):
+ p = PipelineConfig(model_type="t", model_arch="A")
+ with pytest.raises(AttributeError):
+ p.model_type = "changed"
+
+ def test_validate_valid(self):
+ p = PipelineConfig(
+ model_type="t",
+ model_arch="A",
+ stages=(
+ StagePipelineConfig(stage_id=0, model_stage="a"),
+ StagePipelineConfig(stage_id=1, model_stage="b", input_sources=(0,)),
+ ),
+ )
+ assert p.validate() == []
+
+ def test_validate_no_stages(self):
+ p = PipelineConfig(model_type="t", model_arch="A")
+ assert any("no stages" in e.lower() for e in p.validate())
+
+ def test_get_scheduler_cls(self):
+ p = PipelineConfig(
+ model_type="t",
+ model_arch="A",
+ stages=(
+ StagePipelineConfig(stage_id=0, model_stage="a", execution_type=StageExecutionType.LLM_AR),
+ StagePipelineConfig(
+ stage_id=1, model_stage="b", execution_type=StageExecutionType.LLM_GENERATION, input_sources=(0,)
+ ),
+ ),
+ )
+ assert "OmniARScheduler" in p.get_scheduler_cls(0)
+ assert "OmniGenerationScheduler" in p.get_scheduler_cls(1)
+
+
+class TestExecutionTypeToScheduler:
+ def test_all_types_mapped(self):
+ for et in StageExecutionType:
+ assert et in _EXECUTION_TYPE_TO_SCHEDULER
+
+
+class TestPipelineRegistry:
+ def test_register_and_lookup(self):
+ p = PipelineConfig(
+ model_type="__test_only__",
+ model_arch="A",
+ stages=(StagePipelineConfig(stage_id=0, model_stage="a"),),
+ )
+ register_pipeline(p)
+ assert _PIPELINE_REGISTRY["__test_only__"] is p
+ del _PIPELINE_REGISTRY["__test_only__"]
+
+
+class TestDeployConfigLoading:
+ def test_load_deploy_config(self):
+ from pathlib import Path
+
+ from vllm_omni.config.stage_config import load_deploy_config
+
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
+
+ deploy = load_deploy_config(deploy_path)
+ assert len(deploy.stages) == 3
+ assert deploy.async_chunk is True
+ assert deploy.connectors is not None
+ assert deploy.platforms is not None
+
+ def test_merge_pipeline_deploy(self):
+ from pathlib import Path
+
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+ from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
+
+ pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
+
+ deploy = load_deploy_config(deploy_path)
+ stages = merge_pipeline_deploy(pipeline, deploy)
+
+ assert len(stages) == 3
+ s0 = stages[0]
+ assert s0.model_stage == "thinker"
+ assert s0.yaml_engine_args["model_arch"] == "Qwen3OmniMoeForConditionalGeneration"
+ assert s0.yaml_engine_args["engine_output_type"] == "latent"
+ assert s0.yaml_extras["default_sampling_params"]["detokenize"] is True
+
+ def test_merge_pipeline_deploy_preserves_num_replicas(self, tmp_path):
+ from pathlib import Path
+
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+ from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
+
+ pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+ base = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not base.exists():
+ pytest.skip("Deploy config not found")
+
+ overlay = tmp_path / "multi_replicas.yaml"
+ overlay.write_text(f'base_config: {base}\nstages:\n - stage_id: 1\n devices: "1,2"\n num_replicas: 2\n')
+
+ deploy = load_deploy_config(overlay)
+ assert deploy.stages[1].num_replicas == 2
+
+ stages = merge_pipeline_deploy(pipeline, deploy)
+ assert stages[1].yaml_runtime["devices"] == "1,2"
+ assert stages[1].yaml_runtime["num_replicas"] == 2
+
+
+class TestQwen3OmniPipeline:
+ def test_registered(self):
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+
+ p = _PIPELINE_REGISTRY.get("qwen3_omni_moe")
+ assert p is not None
+ assert p.model_arch == "Qwen3OmniMoeForConditionalGeneration"
+ assert len(p.stages) == 3
+ assert p.validate() == []
+
+ def test_thinker(self):
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen3_omni_moe"].get_stage(0)
+ assert s.model_stage == "thinker"
+ assert s.execution_type == StageExecutionType.LLM_AR
+ assert s.owns_tokenizer is True
+ assert s.engine_output_type == "latent"
+ assert s.sampling_constraints["detokenize"] is True
+
+ def test_talker(self):
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen3_omni_moe"].get_stage(1)
+ assert s.input_sources == (0,)
+ assert s.sampling_constraints["stop_token_ids"] == [2150]
+ assert s.custom_process_input_func is not None
+ assert s.custom_process_next_stage_input_func is not None
+
+ def test_code2wav(self):
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen3_omni_moe"].get_stage(2)
+ assert s.execution_type == StageExecutionType.LLM_GENERATION
+ assert s.final_output_type == "audio"
+ assert s.custom_process_input_func is not None
+
+
+class TestQwen2_5OmniPipeline:
+ def test_registered(self):
+ import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401
+
+ p = _PIPELINE_REGISTRY.get("qwen2_5_omni")
+ assert p is not None
+ assert p.model_arch == "Qwen2_5OmniForConditionalGeneration"
+ assert len(p.stages) == 3
+ assert p.validate() == []
+
+ def test_thinker(self):
+ import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen2_5_omni"].get_stage(0)
+ assert s.model_stage == "thinker"
+ assert s.execution_type == StageExecutionType.LLM_AR
+ assert s.owns_tokenizer is True
+ assert s.engine_output_type == "latent"
+ assert s.requires_multimodal_data is True
+
+ def test_talker(self):
+ import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen2_5_omni"].get_stage(1)
+ assert s.input_sources == (0,)
+ assert s.sampling_constraints["stop_token_ids"] == [8294]
+ assert s.custom_process_input_func is not None
+
+ def test_code2wav(self):
+ import vllm_omni.model_executor.models.qwen2_5_omni.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen2_5_omni"].get_stage(2)
+ assert s.execution_type == StageExecutionType.LLM_GENERATION
+ assert s.final_output_type == "audio"
+ assert s.engine_output_type == "audio"
+
+
+class TestQwen3TTSPipeline:
+ def test_registered(self):
+ import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401
+
+ p = _PIPELINE_REGISTRY.get("qwen3_tts")
+ assert p is not None
+ assert p.model_arch == "Qwen3TTSTalkerForConditionalGeneration"
+ assert len(p.stages) == 2
+ assert p.validate() == []
+
+ def test_talker_stage(self):
+ import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen3_tts"].get_stage(0)
+ assert s.model_stage == "qwen3_tts"
+ assert s.execution_type == StageExecutionType.LLM_AR
+ assert s.owns_tokenizer is True
+ assert s.engine_output_type == "latent"
+ assert s.sampling_constraints["stop_token_ids"] == [2150]
+ # Stage 0 inherits the pipeline-level model_arch
+ assert s.model_arch is None
+
+ def test_code2wav_stage_has_per_stage_model_arch(self):
+ import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401
+
+ s = _PIPELINE_REGISTRY["qwen3_tts"].get_stage(1)
+ assert s.execution_type == StageExecutionType.LLM_GENERATION
+ assert s.final_output_type == "audio"
+ assert s.engine_output_type == "audio"
+ # Per-stage model_arch override (different from pipeline-level talker)
+ assert s.model_arch == "Qwen3TTSCode2Wav"
+ # tts_args is passed through via extras
+ assert s.extras["tts_args"]["max_instructions_length"] == 500
+
+ def test_per_stage_model_arch_flows_through_merge(self, tmp_path):
+ """Verify the new ps.model_arch override survives merge_pipeline_deploy."""
+ import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401
+ from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
+
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_tts.yaml"
+ if not deploy_path.exists():
+ pytest.skip("qwen3_tts deploy yaml not found")
+
+ deploy = load_deploy_config(deploy_path)
+ pipeline = _PIPELINE_REGISTRY["qwen3_tts"]
+ stages = merge_pipeline_deploy(pipeline, deploy)
+
+ # Stage 0 inherits pipeline-level model_arch
+ assert stages[0].yaml_engine_args["model_arch"] == "Qwen3TTSTalkerForConditionalGeneration"
+ # Stage 1 uses its per-stage override
+ assert stages[1].yaml_engine_args["model_arch"] == "Qwen3TTSCode2Wav"
+
+
+class TestBaseConfigInheritance:
+ """Test deploy YAML base_config inheritance."""
+
+ def test_ci_inherits_from_main(self):
+ from tests.helpers.stage_config import get_deploy_config_path
+ from vllm_omni.config.stage_config import load_deploy_config
+
+ ci_path = Path(get_deploy_config_path("ci/qwen3_omni_moe.yaml"))
+ if not ci_path.exists():
+ pytest.skip("CI deploy config not found")
+
+ deploy = load_deploy_config(ci_path)
+ assert len(deploy.stages) == 3
+ # CI overrides
+ assert deploy.stages[0].engine_extras.get("load_format") == "dummy"
+ assert deploy.stages[0].max_num_seqs == 5
+ # Inherited from base
+ assert deploy.stages[0].gpu_memory_utilization == 0.9
+ assert deploy.connectors is not None
+ assert "connector_of_shared_memory" in deploy.connectors
+ # CI overlay explicitly sets async_chunk: False (see
+ # tests.helpers.stage_config._CI_OVERLAYS and PR #2383 discussion). Overlay
+ # bool overrides base even when the base yaml has async_chunk: true.
+ assert deploy.async_chunk is False
+
+ def test_ci_sampling_merge(self):
+ from tests.helpers.stage_config import get_deploy_config_path
+ from vllm_omni.config.stage_config import load_deploy_config
+
+ ci_path = Path(get_deploy_config_path("ci/qwen3_omni_moe.yaml"))
+ if not ci_path.exists():
+ pytest.skip("CI deploy config not found")
+
+ deploy = load_deploy_config(ci_path)
+ s0 = deploy.stages[0].default_sampling_params
+ # CI overrides max_tokens
+ assert s0["max_tokens"] == 150
+ # Inherited from base
+ assert s0["temperature"] == 0.4
+ assert s0["seed"] == 42
+
+ def test_pure_inheritance_overlay(self, tmp_path):
+ """An overlay with only ``base_config`` inherits everything."""
+ from vllm_omni.config.stage_config import load_deploy_config
+
+ base = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not base.exists():
+ pytest.skip("Base deploy config not found")
+
+ overlay = tmp_path / "overlay.yaml"
+ overlay.write_text(f"base_config: {base}\n")
+
+ deploy = load_deploy_config(overlay)
+ assert len(deploy.stages) == 3
+ assert deploy.stages[0].gpu_memory_utilization == 0.9
+
+ def test_single_field_overlay(self, tmp_path):
+ """An overlay overriding one stage field merges with the base."""
+ from vllm_omni.config.stage_config import load_deploy_config
+
+ base = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not base.exists():
+ pytest.skip("Base deploy config not found")
+
+ overlay = tmp_path / "overlay.yaml"
+ overlay.write_text(f"base_config: {base}\nstages:\n - stage_id: 2\n max_num_batched_tokens: 1000000\n")
+
+ deploy = load_deploy_config(overlay)
+ assert deploy.stages[2].max_num_batched_tokens == 1000000
+ # Rest inherited
+ assert deploy.stages[0].gpu_memory_utilization == 0.9
+
+
+class TestPlatformOverrides:
+ """Test platform-specific deploy config overrides."""
+
+ def test_npu_overrides(self):
+ from pathlib import Path
+
+ from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config
+
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
+
+ deploy = load_deploy_config(deploy_path)
+ deploy = _apply_platform_overrides(deploy, platform="npu")
+
+ assert deploy.stages[0].gpu_memory_utilization == 0.6
+ assert deploy.stages[0].tensor_parallel_size == 2
+ assert deploy.stages[0].devices == "0,1"
+ # Stage 2 unaffected fields stay at base
+ assert deploy.stages[2].enforce_eager is True
+
+ def test_xpu_overrides(self):
+ from pathlib import Path
+
+ from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config
+
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
+
+ deploy = load_deploy_config(deploy_path)
+ deploy = _apply_platform_overrides(deploy, platform="xpu")
+
+ assert deploy.stages[0].tensor_parallel_size == 4
+ assert deploy.stages[0].devices == "0,1,2,3"
+ assert deploy.stages[0].engine_extras.get("max_cudagraph_capture_size") == 0
+
+ def test_unknown_platform_noop(self):
+ from pathlib import Path
+
+ from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config
+
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
+
+ deploy = load_deploy_config(deploy_path)
+ original_mem = deploy.stages[0].gpu_memory_utilization
+ deploy = _apply_platform_overrides(deploy, platform="unknown_hw")
+ assert deploy.stages[0].gpu_memory_utilization == original_mem
+
+ def test_platforms_deep_merge_inheritance(self, tmp_path):
+ """Overlay's platforms: block layers onto base's, per-stage."""
+ from vllm_omni.config.stage_config import _apply_platform_overrides, load_deploy_config
+
+ base = tmp_path / "base.yaml"
+ base.write_text(
+ "stages:\n"
+ " - stage_id: 0\n"
+ " gpu_memory_utilization: 0.9\n"
+ "platforms:\n"
+ " rocm:\n"
+ " stages:\n"
+ " - stage_id: 0\n"
+ " enforce_eager: true\n"
+ )
+ overlay = tmp_path / "overlay.yaml"
+ overlay.write_text(
+ f"base_config: {base.name}\n"
+ "platforms:\n"
+ " rocm:\n"
+ " stages:\n"
+ " - stage_id: 0\n"
+ " max_num_seqs: 1\n"
+ )
+
+ deploy = load_deploy_config(overlay)
+ deploy = _apply_platform_overrides(deploy, platform="rocm")
+ # Both base's enforce_eager and overlay's max_num_seqs should apply.
+ assert deploy.stages[0].enforce_eager is True
+ assert deploy.stages[0].max_num_seqs == 1
+ # Inherited stage default not touched by overlay platforms section.
+ assert deploy.stages[0].gpu_memory_utilization == 0.9
+
+
+class TestCLIOverrideFlow:
+ """Test --stage-overrides JSON merge into StageConfig."""
+
+ def test_stage_overrides_merge(self):
+ from pathlib import Path
+
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+ from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
+
+ pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
+
+ deploy = load_deploy_config(deploy_path)
+ stages = merge_pipeline_deploy(pipeline, deploy)
+
+ # Simulate --stage-overrides '{"0": {"gpu_memory_utilization": 0.5}}'
+ overrides = {"stage_0_gpu_memory_utilization": 0.5}
+ stages[0].runtime_overrides = StageConfigFactory._merge_cli_overrides(stages[0], overrides)
+ assert stages[0].runtime_overrides["gpu_memory_utilization"] == 0.5
+
+ def test_global_override_applies_to_all(self):
+ from pathlib import Path
+
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+ from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
+
+ pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
+
+ deploy = load_deploy_config(deploy_path)
+ stages = merge_pipeline_deploy(pipeline, deploy)
+
+ overrides = {"enforce_eager": True}
+ for s in stages:
+ s.runtime_overrides = StageConfigFactory._merge_cli_overrides(s, overrides)
+ assert s.runtime_overrides["enforce_eager"] is True
+
+
+class TestCLIExplicitPrecedence:
+ """Verify YAML > argparse defaults; explicit CLI args > YAML."""
+
+ def _stages(self, cli_overrides, cli_explicit_keys):
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+
+ return StageConfigFactory._create_from_registry(
+ "qwen3_omni_moe",
+ cli_overrides=cli_overrides,
+ cli_explicit_keys=cli_explicit_keys,
+ )
+
+ def test_explicit_cli_overrides_yaml(self):
+ """User-typed --max-num-seqs wins over the deploy YAML value."""
+ stages = self._stages(
+ cli_overrides={"max_num_seqs": 999},
+ cli_explicit_keys={"max_num_seqs"},
+ )
+ # Stage 2 yaml has max_num_seqs=1; explicit CLI must beat it.
+ assert stages[2].runtime_overrides.get("max_num_seqs") == 999
+
+ def test_default_cli_does_not_override_yaml(self):
+ """Argparse defaults must NOT clobber values that are present in YAML."""
+ stages = self._stages(
+ cli_overrides={"max_num_seqs": 256},
+ cli_explicit_keys=set(), # user typed nothing
+ )
+ # Stage 2's YAML value (1) should win because the user didn't type --max-num-seqs.
+ assert stages[2].runtime_overrides.get("max_num_seqs") != 256
+
+ def test_default_cli_fills_missing_yaml_field(self):
+ """Argparse defaults still fill fields the YAML doesn't set."""
+ stages = self._stages(
+ cli_overrides={"some_unrelated_knob": "fallback"},
+ cli_explicit_keys=set(),
+ )
+ # Field absent from YAML → CLI default flows through as a fallback.
+ assert stages[0].runtime_overrides.get("some_unrelated_knob") == "fallback"
+
+ def test_per_stage_overrides_always_explicit(self):
+ """``stage__*`` keys are always treated as explicit."""
+ stages = self._stages(
+ cli_overrides={"stage_0_gpu_memory_utilization": 0.42},
+ cli_explicit_keys=set(), # not in the explicit set, but per-stage
+ )
+ assert stages[0].runtime_overrides.get("gpu_memory_utilization") == 0.42
+
+ def test_none_explicit_set_treats_all_as_explicit(self):
+ """Programmatic Omni() callers (cli_explicit_keys=None) keep current behavior."""
+ stages = self._stages(
+ cli_overrides={"max_num_seqs": 999},
+ cli_explicit_keys=None,
+ )
+ assert stages[2].runtime_overrides.get("max_num_seqs") == 999
+
+ def test_explicit_async_chunk_false_overrides_yaml(self):
+ """``--no-async-chunk`` flips the deploy-level async_chunk to False even
+ when the YAML sets it to True. Verifies that the per-stage
+ ``async_chunk: True`` injection in ``merge_pipeline_deploy`` is skipped
+ and that ``async_chunk`` does not leak through ``_merge_cli_overrides``.
+ """
+ stages = self._stages(
+ cli_overrides={"async_chunk": False},
+ cli_explicit_keys={"async_chunk"},
+ )
+ # qwen3_omni_moe.yaml has `async_chunk: true`, so by default every
+ # stage's engine_args would carry it. With the explicit override, it
+ # must NOT show up.
+ for stage in stages:
+ assert stage.yaml_engine_args.get("async_chunk") is not True
+ assert stage.runtime_overrides.get("async_chunk") is None
+
+ def test_default_async_chunk_leaves_yaml_alone(self):
+ """An unset ``--async-chunk`` (default None) must leave the YAML's True
+ in force on every stage."""
+ stages = self._stages(
+ cli_overrides={"async_chunk": None},
+ cli_explicit_keys=set(),
+ )
+ # qwen3_omni_moe.yaml: `async_chunk: true` → injected on every stage.
+ for stage in stages:
+ assert stage.yaml_engine_args.get("async_chunk") is True
+
+ def test_explicit_enable_prefix_caching_overrides_yaml(self):
+ """``--enable-prefix-caching`` (global) flips every stage's
+ ``enable_prefix_caching`` to True regardless of the YAML default."""
+ stages = self._stages(
+ cli_overrides={"enable_prefix_caching": True},
+ cli_explicit_keys={"enable_prefix_caching"},
+ )
+ for stage in stages:
+ assert stage.runtime_overrides.get("enable_prefix_caching") is True
+
+ def test_async_chunk_dispatches_processors(self):
+ """A single ``qwen3_tts`` pipeline picks per-chunk vs end-to-end
+ processors based on ``deploy.async_chunk``, without needing a
+ separate variant pipeline registration."""
+ import vllm_omni.model_executor.models.qwen3_tts.pipeline # noqa: F401
+ from vllm_omni.config.stage_config import (
+ _PIPELINE_REGISTRY,
+ DeployConfig,
+ merge_pipeline_deploy,
+ )
+
+ pipeline = _PIPELINE_REGISTRY["qwen3_tts"]
+
+ # async_chunk=True → stage 0's per-chunk processor wires up, stage 1
+ # has no sync input processor.
+ async_stages = merge_pipeline_deploy(pipeline, DeployConfig(async_chunk=True))
+ assert (
+ async_stages[0]
+ .yaml_engine_args.get("custom_process_next_stage_input_func", "")
+ .endswith("talker2code2wav_async_chunk")
+ )
+ assert async_stages[1].custom_process_input_func is None
+
+ # async_chunk=False → stage 0 has no streaming processor, stage 1's
+ # batch-end processor wires up.
+ sync_stages = merge_pipeline_deploy(pipeline, DeployConfig(async_chunk=False))
+ assert "custom_process_next_stage_input_func" not in sync_stages[0].yaml_engine_args
+ assert sync_stages[1].custom_process_input_func is not None
+ assert sync_stages[1].custom_process_input_func.endswith("talker2code2wav")
+
+
+class TestSamplingConstraintsPrecedence:
+ """Test that pipeline sampling_constraints override deploy defaults."""
+
+ def test_constraints_win(self):
+ from pathlib import Path
+
+ import vllm_omni.model_executor.models.qwen3_omni.pipeline # noqa: F401
+ from vllm_omni.config.stage_config import load_deploy_config, merge_pipeline_deploy
+
+ pipeline = _PIPELINE_REGISTRY["qwen3_omni_moe"]
+ deploy_path = Path(__file__).parent.parent / "vllm_omni" / "deploy" / "qwen3_omni_moe.yaml"
+ if not deploy_path.exists():
+ pytest.skip("Deploy config not found")
- def test_architecture_models_mapping_exists(self):
- """Test that _ARCHITECTURE_MODELS contains expected entries."""
- assert "MiMoAudioForConditionalGeneration" in StageConfigFactory._ARCHITECTURE_MODELS
- assert StageConfigFactory._ARCHITECTURE_MODELS["MiMoAudioForConditionalGeneration"] == "mimo_audio"
- assert "HunyuanImage3ForCausalMM" in StageConfigFactory._ARCHITECTURE_MODELS
- assert StageConfigFactory._ARCHITECTURE_MODELS["HunyuanImage3ForCausalMM"] == "hunyuan_image3"
+ deploy = load_deploy_config(deploy_path)
+ stages = merge_pipeline_deploy(pipeline, deploy)
- def test_mimo_audio_in_pipeline_models(self):
- """Test that mimo_audio is registered in PIPELINE_MODELS."""
- assert "mimo_audio" in StageConfigFactory.PIPELINE_MODELS
+ # Pipeline says detokenize=True for thinker, deploy can't override
+ assert stages[0].yaml_extras["default_sampling_params"]["detokenize"] is True
+ # Pipeline says stop_token_ids=[2150] for talker
+ assert stages[1].yaml_extras["default_sampling_params"]["stop_token_ids"] == [2150]
+ # Deploy temperature still flows through
+ assert stages[0].yaml_extras["default_sampling_params"]["temperature"] == 0.4
diff --git a/tests/test_diffusion_config_propagation.py b/tests/test_diffusion_config_propagation.py
index 7d6d9c43f05..eeb3505efe9 100644
--- a/tests/test_diffusion_config_propagation.py
+++ b/tests/test_diffusion_config_propagation.py
@@ -15,6 +15,7 @@
DiffusionParallelConfig,
OmniDiffusionConfig,
)
+from vllm_omni.diffusion.model_metadata import QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES
pytestmark = [pytest.mark.core_model, pytest.mark.cpu]
@@ -109,3 +110,12 @@ def test_extra_kwargs_forwarded(self):
ea = stages[0]["engine_args"]
assert ea["enforce_eager"] is True
assert ea["lora_path"] == "/tmp/lora"
+
+
+def test_qwen_image_edit_plus_sets_generic_multimodal_limit():
+ od_config = OmniDiffusionConfig(model="Qwen/Qwen-Image-Edit-2511", model_class_name="QwenImageEditPlusPipeline")
+
+ od_config.update_multimodal_support()
+
+ assert od_config.supports_multimodal_inputs is True
+ assert od_config.max_multimodal_image_inputs == QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES
diff --git a/tests/test_version.py b/tests/test_version.py
new file mode 100644
index 00000000000..07e622a7d15
--- /dev/null
+++ b/tests/test_version.py
@@ -0,0 +1,58 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Unit tests for version compatibility warnings."""
+
+import warnings
+from unittest import mock
+
+import pytest
+
+from vllm_omni.version import warn_if_misaligned_vllm_version
+
+
+@mock.patch("vllm_omni.version.__version_tuple__", (0, 19, 0))
+@mock.patch("vllm_omni.version.__version__", "0.19.0")
+@mock.patch("vllm.__version_tuple__", (0, 18, 0))
+@mock.patch("vllm.__version__", "0.18.0")
+def test_version_mismatch_warning():
+ """Ensure that we warn when vLLM and vLLM-Omni major/minor versions differ."""
+ with pytest.warns(RuntimeWarning, match="mismatched major/minor versions"):
+ warn_if_misaligned_vllm_version()
+
+
+@pytest.mark.parametrize(
+ "vllm_ver,vllm_tuple,omni_ver,omni_tuple",
+ [
+ ("0.19.0", (0, 19, 0), "0.19.5", (0, 19, 5)), # Patch differs
+ ("0.18.0", (0, 18, 0), "dev", (0, 0, "dev")), # Omni dev
+ ("dev", (0, 0, "dev"), "0.19.0", (0, 19, 0)), # vLLM dev
+ # Ensure local identifies don't matter for the warning
+ ("0.19.0+foo", (0, 19, 0, "foo"), "0.19.5", (0, 19, 0)),
+ ("0.19.0", (0, 19, 0), "0.19.5+bar", (0, 19, 0, "bar")),
+ ("0.19.0+foo", (0, 19, 0, "foo"), "0.19.5+bar", (0, 19, 0, "bar")),
+ ],
+)
+def test_no_warning_cases(vllm_ver, vllm_tuple, omni_ver, omni_tuple):
+ """Ensure we don't warn when minor versions match or either is a dev build."""
+ with (
+ mock.patch.multiple("vllm", __version__=vllm_ver, __version_tuple__=vllm_tuple),
+ mock.patch.multiple("vllm_omni.version", __version__=omni_ver, __version_tuple__=omni_tuple),
+ ):
+ with warnings.catch_warnings():
+ warnings.simplefilter("error")
+ warn_if_misaligned_vllm_version()
+
+
+@mock.patch("vllm_omni.version.__version_tuple__", (0, 19, 0))
+@mock.patch("vllm_omni.version.__version__", "0.19.0rc2.dev21")
+@mock.patch("vllm.__version_tuple__", (0, 18, 0))
+@mock.patch("vllm.__version__", "0.18.0")
+def test_warning_contains_version_strings():
+ """Ensure that the warning contains the full version strings."""
+ with pytest.warns(RuntimeWarning) as record:
+ warn_if_misaligned_vllm_version()
+
+ assert len(record) == 1
+ msg = str(record[0].message)
+ assert "0.19.0rc2.dev21" in msg
+ assert "0.18.0" in msg
diff --git a/tests/utils.py b/tests/utils.py
deleted file mode 100644
index 84edbbf3d11..00000000000
--- a/tests/utils.py
+++ /dev/null
@@ -1,621 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-# Some functions are copied from vllm/tests/utils.py
-import functools
-import os
-import signal
-import subprocess
-import sys
-import tempfile
-import threading
-import time
-from collections.abc import Callable
-from contextlib import ExitStack, contextmanager, suppress
-from typing import Any, Literal
-
-import cloudpickle
-import pytest
-import torch
-from typing_extensions import ParamSpec
-from vllm.platforms import current_platform
-from vllm.utils.torch_utils import cuda_device_count_stateless
-
-from vllm_omni.platforms import current_omni_platform
-
-_P = ParamSpec("_P")
-
-if current_platform.is_rocm():
- from amdsmi import (
- amdsmi_get_gpu_vram_usage,
- amdsmi_get_processor_handles,
- amdsmi_init,
- amdsmi_shut_down,
- )
-
- @contextmanager
- def _nvml():
- try:
- amdsmi_init()
- yield
- finally:
- amdsmi_shut_down()
-elif current_platform.is_cuda():
- from vllm.third_party.pynvml import (
- nvmlDeviceGetHandleByIndex,
- nvmlDeviceGetMemoryInfo,
- nvmlInit,
- nvmlShutdown,
- )
-
- @contextmanager
- def _nvml():
- try:
- nvmlInit()
- yield
- finally:
- nvmlShutdown()
-else:
-
- @contextmanager
- def _nvml():
- yield
-
-
-def get_physical_device_indices(devices):
- visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
- if visible_devices is None:
- return devices
-
- visible_indices = [int(x) for x in visible_devices.split(",")]
- index_mapping = {i: physical for i, physical in enumerate(visible_indices)}
- return [index_mapping[i] for i in devices if i in index_mapping]
-
-
-@_nvml()
-def wait_for_gpu_memory_to_clear(
- *,
- devices: list[int],
- threshold_bytes: int | None = None,
- threshold_ratio: float | None = None,
- timeout_s: float = 120,
-) -> None:
- import gc
-
- assert threshold_bytes is not None or threshold_ratio is not None
- # Use nvml instead of pytorch to reduce measurement error from torch cuda
- # context.
- devices = get_physical_device_indices(devices)
- start_time = time.time()
-
- # Print waiting start information
- device_list = ", ".join(str(d) for d in devices)
- if threshold_bytes is not None:
- threshold_str = f"{threshold_bytes / 2**30:.2f} GiB"
- condition_str = f"Memory usage ≤ {threshold_str}"
- else:
- threshold_percent = threshold_ratio * 100
- threshold_str = f"{threshold_percent:.1f}%"
- condition_str = f"Memory usage ratio ≤ {threshold_str}"
-
- print(f"[GPU Memory Monitor] Waiting for GPU {device_list} to free memory, Condition: {condition_str}")
-
- # Define the is_free function based on threshold type
- if threshold_bytes is not None:
-
- def is_free(used, total):
- return used <= threshold_bytes / 2**30
- else:
-
- def is_free(used, total):
- return used / total <= threshold_ratio
-
- while True:
- output: dict[int, str] = {}
- output_raw: dict[int, tuple[float, float]] = {}
- for device in devices:
- if current_platform.is_rocm():
- dev_handle = amdsmi_get_processor_handles()[device]
- mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
- gb_used = mem_info["vram_used"] / 2**10
- gb_total = mem_info["vram_total"] / 2**10
- else:
- dev_handle = nvmlDeviceGetHandleByIndex(device)
- mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
- gb_used = mem_info.used / 2**30
- gb_total = mem_info.total / 2**30
- output_raw[device] = (gb_used, gb_total)
- # Format to more readable form
- usage_percent = (gb_used / gb_total) * 100 if gb_total > 0 else 0
- output[device] = f"{gb_used:.1f}GiB/{gb_total:.1f}GiB ({usage_percent:.1f}%)"
-
- # Optimized GPU memory status print
- print("[GPU Memory Status] Current usage:")
- for device_id, mem_info in output.items():
- print(f" GPU {device_id}: {mem_info}")
-
- # Calculate waiting duration
- dur_s = time.time() - start_time
- elapsed_minutes = dur_s / 60
-
- # Check if all devices meet the condition
- if all(is_free(used, total) for used, total in output_raw.values()):
- # Optimized completion message
- print(f"[GPU Memory Freed] Devices {device_list} meet memory condition")
- print(f" Condition: {condition_str}")
- print(f" Wait time: {dur_s:.1f} seconds ({elapsed_minutes:.1f} minutes)")
- print(" Final status:")
- for device_id, mem_info in output.items():
- print(f" GPU {device_id}: {mem_info}")
- break
-
- # Check timeout
- if dur_s >= timeout_s:
- raise ValueError(
- f"[GPU Memory Timeout] Devices {device_list} still don't meet memory condition after {dur_s:.1f} seconds\n"
- f"Condition: {condition_str}\n"
- f"Current status:\n" + "\n".join(f" GPU {device}: {output[device]}" for device in devices)
- )
-
- # Add waiting hint (optional)
- if dur_s > 10 and int(dur_s) % 10 == 0: # Show hint every 10 seconds
- print(f"Waiting... Already waited {dur_s:.1f} seconds ({elapsed_minutes:.1f} minutes)")
-
- gc.collect()
- torch.cuda.empty_cache()
-
- time.sleep(5)
-
-
-def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]:
- """Decorator to fork a new process for each test function.
- See https://github.com/vllm-project/vllm/issues/7053 for more details.
- """
-
- @functools.wraps(func)
- def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
- # Make the process the leader of its own process group
- # to avoid sending SIGTERM to the parent process
- os.setpgrp()
- from _pytest.outcomes import Skipped
-
- # Create a unique temporary file to store exception info from child
- # process. Use test function name and process ID to avoid collisions.
- with (
- tempfile.NamedTemporaryFile(
- delete=False, mode="w+b", prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", suffix=".exc"
- ) as exc_file,
- ExitStack() as delete_after,
- ):
- exc_file_path = exc_file.name
- delete_after.callback(os.remove, exc_file_path)
-
- pid = os.fork()
- print(f"Fork a new process to run a test {pid}")
- if pid == 0:
- # Parent process responsible for deleting, don't delete
- # in child.
- delete_after.pop_all()
- try:
- func(*args, **kwargs)
- except Skipped as e:
- # convert Skipped to exit code 0
- print(str(e))
- os._exit(0)
- except Exception as e:
- import traceback
-
- tb_string = traceback.format_exc()
-
- # Try to serialize the exception object first
- exc_to_serialize: dict[str, Any]
- try:
- # First, try to pickle the actual exception with
- # its traceback.
- exc_to_serialize = {"pickled_exception": e}
- # Test if it can be pickled
- cloudpickle.dumps(exc_to_serialize)
- except (Exception, KeyboardInterrupt):
- # Fall back to string-based approach.
- exc_to_serialize = {
- "exception_type": type(e).__name__,
- "exception_msg": str(e),
- "traceback": tb_string,
- }
- try:
- with open(exc_file_path, "wb") as f:
- cloudpickle.dump(exc_to_serialize, f)
- except Exception:
- # Fallback: just print the traceback.
- print(tb_string)
- os._exit(1)
- else:
- os._exit(0)
- else:
- pgid = os.getpgid(pid)
- _pid, _exitcode = os.waitpid(pid, 0)
- # ignore SIGTERM signal itself
- old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
- # kill all child processes
- os.killpg(pgid, signal.SIGTERM)
- # restore the signal handler
- signal.signal(signal.SIGTERM, old_signal_handler)
- if _exitcode != 0:
- # Try to read the exception from the child process
- exc_info = {}
- if os.path.exists(exc_file_path):
- with suppress(Exception), open(exc_file_path, "rb") as f:
- exc_info = cloudpickle.load(f)
-
- if (original_exception := exc_info.get("pickled_exception")) is not None:
- # Re-raise the actual exception object if it was
- # successfully pickled.
- assert isinstance(original_exception, Exception)
- raise original_exception
-
- if (original_tb := exc_info.get("traceback")) is not None:
- # Use string-based traceback for fallback case
- raise AssertionError(
- f"Test {func.__name__} failed when called with"
- f" args {args} and kwargs {kwargs}"
- f" (exit code: {_exitcode}):\n{original_tb}"
- ) from None
-
- # Fallback to the original generic error
- raise AssertionError(
- f"function {func.__name__} failed when called with"
- f" args {args} and kwargs {kwargs}"
- f" (exit code: {_exitcode})"
- ) from None
-
- return wrapper
-
-
-def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]:
- """Decorator to spawn a new process for each test function."""
-
- @functools.wraps(f)
- def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
- # Check if we're already in a subprocess
- if os.environ.get("RUNNING_IN_SUBPROCESS") == "1":
- # If we are, just run the function directly
- return f(*args, **kwargs)
-
- import torch.multiprocessing as mp
-
- with suppress(RuntimeError):
- mp.set_start_method("spawn")
-
- # Get the module
- module_name = f.__module__
-
- # Create a process with environment variable set
- env = os.environ.copy()
- env["RUNNING_IN_SUBPROCESS"] = "1"
-
- with tempfile.TemporaryDirectory() as tempdir:
- output_filepath = os.path.join(tempdir, "new_process.tmp")
-
- # `cloudpickle` allows pickling complex functions directly
- input_bytes = cloudpickle.dumps((f, output_filepath))
-
- cmd = [sys.executable, "-m", f"{module_name}"]
-
- returned = subprocess.run(cmd, input=input_bytes, capture_output=True, env=env)
-
- # check if the subprocess is successful
- try:
- returned.check_returncode()
- except Exception as e:
- # wrap raised exception to provide more information
- raise RuntimeError(f"Error raised in subprocess:\n{returned.stderr.decode()}") from e
-
- return wrapper
-
-
-def create_new_process_for_each_test(
- method: Literal["spawn", "fork"] | None = None,
-) -> Callable[[Callable[_P, None]], Callable[_P, None]]:
- """Creates a decorator that runs each test function in a new process.
-
- Args:
- method: The process creation method. Can be either "spawn" or "fork".
- If not specified, it defaults to "spawn" on ROCm and XPU
- platforms and "fork" otherwise.
-
- Returns:
- A decorator to run test functions in separate processes.
- """
- if method is None:
- # TODO: Spawn is not working correctly on ROCm
- # The test content will not run and tests passed immediately.
- # For now, using `fork` for ROCm as it can run with `fork`
- # and tests are running correctly.
- use_spawn = current_platform.is_xpu()
- method = "spawn" if use_spawn else "fork"
-
- assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'"
-
- if method == "fork":
- return fork_new_process_for_each_test
-
- return spawn_new_process_for_each_test
-
-
-def cuda_marks(*, res: str, num_cards: int):
- """
- Get a collection of pytest marks to apply for `@cuda_test`.
-
- Args:
- res: Resource type, e.g., "L4" or "H100".
- num_cards: Number of GPU cards required.
-
- Returns:
- List of pytest marks to apply.
- """
- test_platform_detail = pytest.mark.cuda
-
- if res == "L4":
- test_resource = pytest.mark.L4
- elif res == "H100":
- test_resource = pytest.mark.H100
- else:
- raise ValueError(f"Invalid CUDA resource type: {res}. Supported: L4, H100")
-
- marks = [test_resource, test_platform_detail]
-
- if num_cards == 1:
- return marks
- else:
- test_distributed = pytest.mark.distributed_cuda(num_cards=num_cards)
- test_skipif = pytest.mark.skipif_cuda(
- cuda_device_count_stateless() < num_cards,
- reason=f"Need at least {num_cards} CUDA GPUs to run the test.",
- )
- return marks + [test_distributed, test_skipif]
-
-
-def rocm_marks(*, res: str, num_cards: int):
- """
- Get a collection of pytest marks to apply for `@rocm_test`.
-
- Args:
- res: Resource type, e.g., "MI325".
- num_cards: Number of GPU cards required.
-
- Returns:
- List of pytest marks to apply.
- """
- test_platform_detail = pytest.mark.rocm
-
- if res == "MI325":
- test_resource = pytest.mark.MI325
- else:
- raise ValueError(f"Invalid ROCm resource type: {res}. Supported: MI325")
-
- marks = [test_resource, test_platform_detail]
-
- if num_cards == 1:
- return marks
- else:
- test_distributed = pytest.mark.distributed_rocm(num_cards=num_cards)
- # TODO: add ROCm support for `skipif_rocm` marker
- return marks + [test_distributed]
-
-
-def xpu_marks(*, res: str, num_cards: int):
- """
- Get a collection of pytest marks to apply for `@xpu_test`.
-
- Args:
- res: Resource type, e.g., "B60".
- num_cards: Number of GPU cards required.
-
- Returns:
- List of pytest marks to apply.
- """
- test_platform_detail = pytest.mark.xpu
-
- if res == "B60":
- test_resource = pytest.mark.B60
- else:
- raise ValueError(f"Invalid XPU resource type: {res}. Supported: B60")
-
- marks = [test_resource, test_platform_detail]
-
- if num_cards == 1:
- return marks
- else:
- test_distributed = pytest.mark.distributed_rocm(num_cards=num_cards)
- # TODO: add XPU support for `skipif_xpu` marker
- return marks + [test_distributed]
-
-
-def musa_marks(*, res: str, num_cards: int):
- """
- Get a collection of pytest marks to apply for `@musa_test`.
-
- Args:
- res: Resource type, e.g., "S5000".
- num_cards: Number of GPU cards required.
-
- Returns:
- List of pytest marks to apply.
- """
- test_platform_detail = pytest.mark.musa
-
- if res == "S5000":
- test_resource = pytest.mark.S5000
- else:
- raise ValueError(f"Invalid MUSA resource type: {res}. Supported: S5000")
-
- marks = [test_resource, test_platform_detail]
-
- if num_cards == 1:
- return marks
- else:
- test_distributed = pytest.mark.distributed_musa(num_cards=num_cards)
- # TODO: add MUSA support for `skipif_musa` marker
- return marks + [test_distributed]
-
-
-def gpu_marks(*, res: str, num_cards: int):
- """
- Get a collection of pytest marks to apply for `@gpu_test`.
- Platform is automatically determined based on resource type.
-
- Args:
- res: Resource type, e.g., "L4", "H100" for CUDA, or "MI325" for ROCm, or "B60" for XPU, or "S5000" for MUSA.
- num_cards: Number of GPU cards required.
-
- Returns:
- List of pytest marks to apply.
- """
- test_platform = pytest.mark.gpu
- if res in ("L4", "H100"):
- return [test_platform] + cuda_marks(res=res, num_cards=num_cards)
- if res == "MI325":
- return [test_platform] + rocm_marks(res=res, num_cards=num_cards)
- if res == "B60":
- return [test_platform] + xpu_marks(res=res, num_cards=num_cards)
- if res == "S5000":
- return [test_platform] + musa_marks(res=res, num_cards=num_cards)
- raise ValueError(f"Invalid resource type: {res}. Supported: L4, H100, MI325, B60, S5000")
-
-
-def npu_marks(*, res: str, num_cards: int):
- """Get a collection of pytest marks to apply for `@npu_test`."""
- test_platform = pytest.mark.npu
- if res == "A2":
- test_resource = pytest.mark.A2
- elif res == "A3":
- test_resource = pytest.mark.A3
- else:
- # TODO: Currently we don't have various NPU card types defined
- # Use None to skip resource-specific marking for unknown types
- test_resource = None
-
- if num_cards == 1:
- return [mark for mark in [test_platform, test_resource] if mark is not None]
- else:
- # Multiple cards scenario needs distributed_npu mark
- test_distributed = pytest.mark.distributed_npu(num_cards=num_cards)
- # TODO: add NPU support for `skipif_npu` marker
- return [mark for mark in [test_platform, test_resource, test_distributed] if mark is not None]
-
-
-def hardware_marks(*, res: dict[str, str], num_cards: int | dict[str, int] = 1):
- """
- Get a collection of pytest marks to apply for `@hardware_test`,
- including CUDA, ROCm, XPU, NPU, and MUSA,
- based on the specified platforms and resources.
- """
- # Validate platforms
- # Don't validate platform details in this decorator
- for platform, _ in res.items():
- if platform not in ("cuda", "rocm", "xpu", "npu", "musa"):
- raise ValueError(f"Unsupported platform: {platform}")
-
- # Normalize num_cards
- if isinstance(num_cards, int):
- num_cards_dict = {platform: num_cards for platform in res.keys()}
- else:
- num_cards_dict = num_cards
- for platform in num_cards_dict.keys():
- if platform not in res:
- raise ValueError(
- f"Platform '{platform}' in num_cards but not in res. Available platforms: {list(res.keys())}"
- )
- for platform in res.keys():
- if platform not in num_cards_dict:
- num_cards_dict[platform] = 1
-
- # Collect marks from all platforms
- all_marks: list[pytest.MarkDecorator] = []
- for platform, resource in res.items():
- cards = num_cards_dict[platform]
- if platform == "cuda" or platform == "rocm" or platform == "xpu":
- marks = gpu_marks(res=resource, num_cards=cards)
- elif platform == "musa":
- marks = musa_marks(res=resource, num_cards=cards)
- elif platform == "npu":
- marks = npu_marks(res=resource, num_cards=cards)
- else:
- raise ValueError(f"Unsupported platform: {platform}")
- all_marks.extend(marks)
- return all_marks
-
-
-def hardware_test(*, res: dict[str, str], num_cards: int | dict[str, int] = 1):
- """
- Decorate a test for multiple hardware platforms with a single call.
- Automatically wraps the test with @create_new_process_for_each_test() for distributed tests.
-
- Args:
- res: Mapping from platform to resource type. Supported platforms/resources:
- - cuda: L4, H100
- - rocm: MI325
- - xpu: B60
- - npu: A2, A3
- - musa: S5000
- num_cards: Number of cards required. Can be:
- - int: same card count for all platforms (default: 1)
- - dict: per-platform card count, e.g., {"cuda": 2, "rocm": 2}
-
- Example:
- @hardware_test(
- res={"cuda": "L4", "rocm": "MI325", "npu": "A2", "musa": "S5000"},
- num_cards={"cuda": 2, "rocm": 2, "npu": 2, "musa": 2},
- )
- def test_multi_platform():
- ...
- """
- all_marks = hardware_marks(res=res, num_cards=num_cards)
-
- def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
- func = f
- for mark in reversed(all_marks):
- func = mark(func)
- return func
-
- return wrapper
-
-
-class DeviceMemoryMonitor:
- """Poll global device memory usage."""
-
- def __init__(self, device_index: int, interval: float = 0.05):
- self.device_index = device_index
- self.interval = interval
- self._peak_used_mb = 0.0
- self._stop_event = threading.Event()
- self._thread: threading.Thread | None = None
-
- def start(self) -> None:
- def monitor_loop() -> None:
- while not self._stop_event.is_set():
- try:
- with current_omni_platform.device(self.device_index):
- free_bytes, total_bytes = current_omni_platform.mem_get_info()
- used_mb = (total_bytes - free_bytes) / (1024**2)
- self._peak_used_mb = max(self._peak_used_mb, used_mb)
- except Exception:
- pass
- time.sleep(self.interval)
-
- self._thread = threading.Thread(target=monitor_loop, daemon=False)
- self._thread.start()
-
- def stop(self) -> None:
- if self._thread is None:
- return
- self._stop_event.set()
- self._thread.join(timeout=2.0)
-
- @property
- def peak_used_mb(self) -> float:
- fallback_alloc = current_omni_platform.max_memory_allocated(device=self.device_index) / (1024**2)
- fallback_reserved = current_omni_platform.max_memory_reserved(device=self.device_index) / (1024**2)
- return max(self._peak_used_mb, fallback_alloc, fallback_reserved)
-
- def __del__(self):
- self.stop()
diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py
index 1c08a1543d2..819a7c8c3dd 100644
--- a/tools/pre_commit/check_pickle_imports.py
+++ b/tools/pre_commit/check_pickle_imports.py
@@ -16,8 +16,7 @@
# alternatives like msgpack or pydantic that are already in use in vLLM. Only
# add to this list if absolutely necessary and after careful security review.
ALLOWED_FILES = {
- "tests/e2e/offline_inference/utils.py",
- "tests/utils.py",
+ "tests/helpers/process.py",
"vllm_omni/diffusion/distributed/group_coordinator.py",
"tests/diffusion/attention/test_attention_sp.py",
}
diff --git a/vllm_omni/__init__.py b/vllm_omni/__init__.py
index cec8b0af7e8..65ad79c725a 100644
--- a/vllm_omni/__init__.py
+++ b/vllm_omni/__init__.py
@@ -12,6 +12,12 @@
processing
"""
+# We import version early, because it will warn if vLLM / vLLM Omni
+# are not using the same major + minor version (if vLLM is installed).
+# We should do this before applying patch, because vLLM imports might
+# throw in patch if the versions differ.
+from .version import __version__, __version_tuple__ # isort:skip # noqa: F401
+
try:
from . import patch # noqa: F401
except ModuleNotFoundError as exc: # pragma: no cover - optional dependency
@@ -25,8 +31,6 @@
from .config import OmniModelConfig
-from .version import __version__, __version_tuple__ # isort:skip
-
def __getattr__(name: str):
# Lazy import for AsyncOmni and Omni to avoid pulling in heavy
diff --git a/vllm_omni/config/__init__.py b/vllm_omni/config/__init__.py
index 2aa236e69f5..f02c0758805 100644
--- a/vllm_omni/config/__init__.py
+++ b/vllm_omni/config/__init__.py
@@ -5,10 +5,18 @@
from vllm_omni.config.lora import LoRAConfig
from vllm_omni.config.model import OmniModelConfig
from vllm_omni.config.stage_config import (
+ DeployConfig,
ModelPipeline,
+ PipelineConfig,
StageConfig,
StageConfigFactory,
+ StageDeployConfig,
+ StageExecutionType,
+ StagePipelineConfig,
StageType,
+ load_deploy_config,
+ merge_pipeline_deploy,
+ register_pipeline,
)
from vllm_omni.config.yaml_util import (
create_config,
@@ -24,6 +32,14 @@
"StageConfigFactory",
"ModelPipeline",
"StageType",
+ "StageExecutionType",
+ "StagePipelineConfig",
+ "PipelineConfig",
+ "StageDeployConfig",
+ "DeployConfig",
+ "load_deploy_config",
+ "merge_pipeline_deploy",
+ "register_pipeline",
"create_config",
"load_yaml_config",
"merge_configs",
diff --git a/vllm_omni/config/pipeline_registry.py b/vllm_omni/config/pipeline_registry.py
new file mode 100644
index 00000000000..c07bc2610c3
--- /dev/null
+++ b/vllm_omni/config/pipeline_registry.py
@@ -0,0 +1,55 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Central declarative registry of all vllm-omni pipelines.
+
+Mirrors the pattern in ``vllm/model_executor/models/registry.py``: each entry
+is ``model_type -> (module_path, variable_name)``, and the module is imported
+lazily on first lookup (see ``_LazyPipelineRegistry`` in
+``vllm_omni/config/stage_config.py``). Keeping every pipeline declared in one
+file makes it easy to spot a missing registration, which was the original
+motivation in https://github.com/vllm-project/vllm-omni/issues/2887 (item 4).
+
+Per-model ``pipeline.py`` modules still define the ``PipelineConfig`` instance;
+they just no longer need to self-register via ``register_pipeline(...)``.
+
+Adding a new pipeline:
+ 1. Define the ``PipelineConfig`` instance as a module-level variable in
+ ``vllm_omni/.../pipeline.py``.
+ 2. Add one line to ``_OMNI_PIPELINES`` or ``_DIFFUSION_PIPELINES`` below.
+
+``register_pipeline(config)`` in ``stage_config`` is still supported for
+out-of-tree plugins and tests that create pipelines at runtime; those override
+the entries declared here.
+"""
+
+from __future__ import annotations
+
+# --- Multi-stage omni pipelines (LLM-centric; audio / video I/O) ---
+_OMNI_PIPELINES: dict[str, tuple[str, str]] = {
+ # model_type -> (module_path, variable_name)
+ "qwen2_5_omni": (
+ "vllm_omni.model_executor.models.qwen2_5_omni.pipeline",
+ "QWEN2_5_OMNI_PIPELINE",
+ ),
+ "qwen2_5_omni_thinker_only": (
+ "vllm_omni.model_executor.models.qwen2_5_omni.pipeline",
+ "QWEN2_5_OMNI_THINKER_ONLY_PIPELINE",
+ ),
+ "qwen3_omni_moe": (
+ "vllm_omni.model_executor.models.qwen3_omni.pipeline",
+ "QWEN3_OMNI_PIPELINE",
+ ),
+ "qwen3_tts": (
+ "vllm_omni.model_executor.models.qwen3_tts.pipeline",
+ "QWEN3_TTS_PIPELINE",
+ ),
+}
+
+# --- Single-stage diffusion pipelines (populated in PR 3/N) ---
+_DIFFUSION_PIPELINES: dict[str, tuple[str, str]] = {}
+
+# Union view used by ``_LazyPipelineRegistry``; don't mutate at runtime.
+_VLLM_OMNI_PIPELINES: dict[str, tuple[str, str]] = {
+ **_OMNI_PIPELINES,
+ **_DIFFUSION_PIPELINES,
+}
diff --git a/vllm_omni/config/stage_config.py b/vllm_omni/config/stage_config.py
index a4e186c3bd2..ab975cafc3b 100644
--- a/vllm_omni/config/stage_config.py
+++ b/vllm_omni/config/stage_config.py
@@ -1,18 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-Stage Configuration System for vLLM-Omni.
-
-Pipeline structure (stages, types, data-flow) is defined in per-model YAML
-files and is set by model developers at integration time.
-Runtime parameters (gpu_memory_utilization, tp_size, etc.) come from CLI.
-"""
+"""Stage configuration system for vLLM-Omni."""
from __future__ import annotations
+import dataclasses
import re
import warnings
-from dataclasses import asdict, dataclass, field
+from dataclasses import asdict, dataclass, field, fields
from enum import Enum
from pathlib import Path
from typing import Any
@@ -20,76 +15,832 @@
from vllm.logger import init_logger
from vllm_omni.config.yaml_util import create_config, load_yaml_config, to_dict
+from vllm_omni.core.sched.omni_ar_scheduler import OmniARScheduler
+from vllm_omni.core.sched.omni_generation_scheduler import OmniGenerationScheduler
-# Pipeline YAMLs live alongside model code in model_executor/models//
_MODELS_DIR = Path(__file__).resolve().parent.parent / "model_executor" / "models"
def get_pipeline_path(model_dir: str, filename: str) -> Path:
- """Return the full path to a pipeline YAML file.
+ return _MODELS_DIR / model_dir / filename
+
+
+logger = init_logger(__name__)
+
+
+_STAGE_OVERRIDE_PATTERN = re.compile(r"^stage_(\d+)_(.+)$")
- Args:
- model_dir: Model subdirectory name (e.g., "qwen3_omni").
- filename: Name of the YAML file (e.g., "pipeline.yaml").
- Returns:
- Absolute path to the file.
+def build_stage_runtime_overrides(
+ stage_id: int,
+ cli_overrides: dict[str, Any],
+ *,
+ internal_keys: set[str] | frozenset[str] | None = None,
+) -> dict[str, Any]:
+ """Build per-stage runtime overrides from global and ``stage__*`` kwargs.
+
+ ``internal_keys`` defaults to the union of
+ ``arg_utils.internal_blacklist_keys()`` and ``arg_utils.SHARED_FIELDS``
+ so that neither orchestrator-only fields nor shared-pipeline fields
+ (``model`` / ``stage_configs_path`` / ``log_stats`` / ``stage_id``) leak
+ into a stage's per-stage runtime overrides — the orchestrator sets those
+ uniformly for every stage, they are not per-stage knobs. Callers can
+ pass an explicit set for tests or specialized flows.
"""
- return _MODELS_DIR / model_dir / filename
+ if internal_keys is None:
+ from vllm_omni.engine.arg_utils import SHARED_FIELDS, internal_blacklist_keys
+ internal_keys = internal_blacklist_keys() | SHARED_FIELDS
-logger = init_logger(__name__)
+ result: dict[str, Any] = {}
+
+ for key, value in cli_overrides.items():
+ if value is None or key in internal_keys:
+ continue
+
+ match = _STAGE_OVERRIDE_PATTERN.match(key)
+ if match is not None:
+ override_stage_id = int(match.group(1))
+ param_name = match.group(2)
+ if override_stage_id == stage_id and param_name not in internal_keys:
+ result[param_name] = value
+ continue
+
+ result[key] = value
+
+ return result
+
+
+def strip_parent_engine_args(
+ kwargs: dict[str, Any],
+ *,
+ parent_fields: dict[str, dataclasses.Field],
+ keep_keys: set[str] | frozenset[str] = frozenset(),
+ strip_keys: set[str] | frozenset[str] = frozenset(),
+ no_warn_keys: set[str] | frozenset[str] = frozenset(),
+) -> tuple[dict[str, Any], list[str]]:
+ """Strip parent ``EngineArgs`` fields before merging into stage YAML."""
+ overridden: list[str] = []
+ result: dict[str, Any] = {}
+
+ for key, value in kwargs.items():
+ if key in strip_keys:
+ continue
+
+ if key not in parent_fields or key in keep_keys:
+ result[key] = value
+ continue
+
+ field_def = parent_fields[key]
+ if field_def.default is not dataclasses.MISSING:
+ default = field_def.default
+ elif field_def.default_factory is not dataclasses.MISSING:
+ default = field_def.default_factory()
+ else:
+ default = dataclasses.MISSING
+
+ if default is dataclasses.MISSING or value is None:
+ continue
+
+ if dataclasses.is_dataclass(default) and not isinstance(default, type):
+ default = asdict(default)
+
+ if value != default and key not in no_warn_keys:
+ overridden.append(key)
+
+ return result, sorted(overridden)
class StageType(str, Enum):
"""Type of processing stage in the Omni pipeline."""
+ # TODO(@lishunyang12): remove once all models migrate to StageExecutionType
LLM = "llm"
DIFFUSION = "diffusion"
+class StageExecutionType(str, Enum):
+ """Merged StageType + WorkerType — 3 combinations today."""
+
+ LLM_AR = "llm_ar"
+ LLM_GENERATION = "llm_generation"
+ DIFFUSION = "diffusion"
+
+
+# Mapping class refs (not dotted-path strings) so module/class renames fail
+# at import time instead of lazily at scheduler resolution. YAML overrides
+# and downstream serialization still use the dotted-path string form; the
+# conversion happens at the map lookup site via _scheduler_path().
+_EXECUTION_TYPE_TO_SCHEDULER: dict[StageExecutionType, type | None] = {
+ StageExecutionType.LLM_AR: OmniARScheduler,
+ StageExecutionType.LLM_GENERATION: OmniGenerationScheduler,
+ StageExecutionType.DIFFUSION: None,
+}
+
+
+def _scheduler_path(cls: type | None) -> str | None:
+ """Return the dotted import path for a scheduler class (``None`` passes through)."""
+ if cls is None:
+ return None
+ return f"{cls.__module__}.{cls.__qualname__}"
+
+
+@dataclass(frozen=True)
+class StagePipelineConfig:
+ """Fixed topology for one stage (frozen, not user-configurable)."""
+
+ stage_id: int
+ model_stage: str
+ execution_type: StageExecutionType = StageExecutionType.LLM_AR
+ input_sources: tuple[int, ...] = ()
+ final_output: bool = False
+ final_output_type: str | None = None
+ owns_tokenizer: bool = False
+ requires_multimodal_data: bool = False
+ hf_config_name: str | None = None
+ engine_output_type: str | None = None
+ model_arch: str | None = None
+ sampling_constraints: dict[str, Any] = field(default_factory=dict)
+ custom_process_input_func: str | None = None
+ custom_process_next_stage_input_func: str | None = None
+ # Alternates picked by ``merge_pipeline_deploy`` based on ``deploy.async_chunk``.
+ async_chunk_process_next_stage_input_func: str | None = None
+ sync_process_input_func: str | None = None
+ prompt_expand_func: str | None = None
+ cfg_kv_collect_func: str | None = None
+ omni_kv_config: dict[str, Any] | None = None
+ extras: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass(frozen=True)
+class PipelineConfig:
+ """Complete pipeline topology for a model (frozen)."""
+
+ model_type: str
+ model_arch: str = ""
+ stages: tuple[StagePipelineConfig, ...] = ()
+ # HF architecture aliases: used by StageConfigFactory when the model's
+ # HF config reports a generic model_type that collides with a different
+ # model (e.g. MiMo Audio reports model_type="qwen2"). The factory
+ # matches ``hf_config.architectures[*]`` against this tuple to route
+ # to the correct pipeline. Leave empty for models with unique model_type.
+ hf_architectures: tuple[str, ...] = ()
+
+ def get_stage(self, stage_id: int) -> StagePipelineConfig | None:
+ """Look up a stage by its ID."""
+ for stage in self.stages:
+ if stage.stage_id == stage_id:
+ return stage
+ return None
+
+ def get_scheduler_cls(self, stage_id: int) -> str | None:
+ """Return the inferred scheduler class path for a stage.
+
+ Returns ``None`` for DIFFUSION stages (no vLLM scheduler). Raises
+ ``ValueError`` if ``stage_id`` doesn't exist in this pipeline, and
+ ``KeyError`` if ``execution_type`` isn't in the scheduler map.
+ """
+ stage = self.get_stage(stage_id)
+ if stage is None:
+ raise ValueError(f"Pipeline {self.model_type!r} has no stage with id {stage_id}")
+ return _scheduler_path(_EXECUTION_TYPE_TO_SCHEDULER[stage.execution_type])
+
+ def validate(self) -> list[str]:
+ """Return list of topology errors (empty if valid)."""
+ errors: list[str] = []
+ if not self.stages:
+ errors.append("Pipeline has no stages defined")
+ return errors
+ stage_ids = [s.stage_id for s in self.stages]
+ if len(stage_ids) != len(set(stage_ids)):
+ errors.append("Duplicate stage IDs found")
+ stage_id_set = set(stage_ids)
+ for stage in self.stages:
+ for src in stage.input_sources:
+ if src not in stage_id_set:
+ errors.append(f"Stage {stage.stage_id} references non-existent input source {src}")
+ if src == stage.stage_id:
+ errors.append(f"Stage {stage.stage_id} references itself")
+ if not any(not s.input_sources for s in self.stages):
+ errors.append("No entry point (stage with empty input_sources)")
+ return errors
+
+
+class _LazyPipelineRegistry:
+ """Dict-like registry that lazy-loads pipelines from the central declaration.
+
+ In-tree pipelines are declared once in
+ ``vllm_omni/config/pipeline_registry.py`` as
+ ``model_type -> (module_path, variable_name)`` entries; the module is
+ imported only when the pipeline is first looked up. This mirrors the
+ pattern in ``vllm/model_executor/models/registry.py`` and addresses
+ https://github.com/vllm-project/vllm-omni/issues/2887 (item 4): having
+ every registration in one file makes a missing entry easy to spot.
+
+ Out-of-tree / dynamic registrations via ``register_pipeline()`` are stored
+ directly in ``_loaded`` and take precedence over the lazy-map entry with
+ the same ``model_type``.
+
+ The class exposes the subset of ``dict`` operations the rest of this
+ module relies on (``__contains__``, ``__getitem__``, ``__setitem__``,
+ ``get``, ``keys``, ``values``, ``items``, ``__iter__``), so existing call
+ sites don't need to change.
+ """
+
+ def __init__(self) -> None:
+ self._loaded: dict[str, PipelineConfig] = {}
+ # Populated lazily to avoid a circular import at module init time.
+ self._lazy_map: dict[str, tuple[str, str]] | None = None
+
+ def _get_lazy_map(self) -> dict[str, tuple[str, str]]:
+ if self._lazy_map is None:
+ from vllm_omni.config.pipeline_registry import _VLLM_OMNI_PIPELINES
+
+ self._lazy_map = _VLLM_OMNI_PIPELINES
+ return self._lazy_map
+
+ def _load_lazy(self, model_type: str) -> PipelineConfig | None:
+ entry = self._get_lazy_map().get(model_type)
+ if entry is None:
+ return None
+ module_path, var_name = entry
+ import importlib
+
+ try:
+ module = importlib.import_module(module_path)
+ except ImportError as exc:
+ logger.error(
+ "Failed to import pipeline module %r for %r: %s",
+ module_path,
+ model_type,
+ exc,
+ )
+ return None
+ pipeline = getattr(module, var_name, None)
+ if pipeline is None:
+ logger.error(
+ "Pipeline variable %r not found in module %r (registered for %r)",
+ var_name,
+ module_path,
+ model_type,
+ )
+ return None
+ errors = pipeline.validate()
+ if errors:
+ logger.warning("Pipeline %s has issues: %s", pipeline.model_type, errors)
+ self._loaded[model_type] = pipeline
+ return pipeline
+
+ def __contains__(self, model_type: str) -> bool:
+ if model_type in self._loaded:
+ return True
+ return model_type in self._get_lazy_map()
+
+ def __getitem__(self, model_type: str) -> PipelineConfig:
+ if model_type in self._loaded:
+ return self._loaded[model_type]
+ pipeline = self._load_lazy(model_type)
+ if pipeline is None:
+ raise KeyError(model_type)
+ return pipeline
+
+ def get(self, model_type: str, default: PipelineConfig | None = None) -> PipelineConfig | None:
+ if model_type in self._loaded:
+ return self._loaded[model_type]
+ pipeline = self._load_lazy(model_type)
+ return pipeline if pipeline is not None else default
+
+ def __setitem__(self, model_type: str, pipeline: PipelineConfig) -> None:
+ self._loaded[model_type] = pipeline
+
+ def __delitem__(self, model_type: str) -> None:
+ """Remove a dynamically-registered pipeline.
+
+ Only the dynamic-cache side of the registry can be mutated; the
+ central declarative registry is immutable at runtime. Calling ``del``
+ on a model_type that only exists in the central registry raises
+ ``KeyError``.
+ """
+ if model_type in self._loaded:
+ del self._loaded[model_type]
+ return
+ if model_type in self._get_lazy_map():
+ raise KeyError(
+ f"{model_type!r} is declared in the central pipeline_registry and "
+ "cannot be removed at runtime. Edit "
+ "vllm_omni/config/pipeline_registry.py to delete a built-in entry."
+ )
+ raise KeyError(model_type)
+
+ def keys(self) -> set[str]:
+ return set(self._get_lazy_map().keys()) | set(self._loaded.keys())
+
+ def values(self):
+ # Iterating values forces load of every lazy pipeline.
+ for key in self.keys():
+ yield self[key]
+
+ def items(self):
+ for key in self.keys():
+ yield key, self[key]
+
+ def __iter__(self):
+ return iter(self.keys())
+
+
+_PIPELINE_REGISTRY = _LazyPipelineRegistry()
+
+
+def register_pipeline(pipeline: PipelineConfig) -> None:
+ """Register a pipeline config dynamically.
+
+ In-tree pipelines are declared in ``pipeline_registry._VLLM_OMNI_PIPELINES``
+ and loaded lazily; calling ``register_pipeline`` is only needed for
+ out-of-tree plugins or tests that build a ``PipelineConfig`` at runtime.
+ A dynamic registration overrides the central-registry entry with the same
+ ``model_type``.
+ """
+ errors = pipeline.validate()
+ if errors:
+ logger.warning("Pipeline %s has issues: %s", pipeline.model_type, errors)
+ _PIPELINE_REGISTRY[pipeline.model_type] = pipeline
+
+
+_DEPLOY_DIR = Path(__file__).resolve().parent.parent / "deploy"
+
+
+@dataclass
+class StageDeployConfig:
+ """Per-stage deployment knobs.
+
+ Only fields whose value legitimately varies across stages of the same
+ pipeline live here (e.g. ``max_num_seqs`` on thinker vs talker,
+ ``devices`` for GPU placement). Pipeline-wide settings
+ (``trust_remote_code``, ``distributed_executor_backend``, ``dtype``,
+ ``quantization``, prefix/chunked prefill, DP/PP sizes) are declared at
+ the top level of ``DeployConfig`` and propagated to every stage.
+ """
+
+ stage_id: int
+ max_num_seqs: int = 64
+ gpu_memory_utilization: float = 0.9
+ tensor_parallel_size: int = 1
+ enforce_eager: bool = False
+ max_num_batched_tokens: int = 32768
+ max_model_len: int | None = None
+ async_scheduling: bool | None = None
+ devices: str = "0"
+ num_replicas: int = 1
+ output_connectors: dict[str, str] | None = None
+ input_connectors: dict[str, str] | None = None
+ default_sampling_params: dict[str, Any] | None = None
+ engine_extras: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class DeployConfig:
+ """Loaded from deploy/.yaml — the only config file users edit.
+
+ Top-level fields (``trust_remote_code``, ``distributed_executor_backend``,
+ ``dtype``, ``quantization``, ``enable_prefix_caching``,
+ ``enable_chunked_prefill``, ``data_parallel_size``,
+ ``pipeline_parallel_size``) are pipeline-wide: they apply uniformly to
+ every stage. Fields that legitimately vary per stage live in the
+ individual ``StageDeployConfig`` entries under ``stages:``.
+ """
+
+ async_chunk: bool = True
+ connectors: dict[str, Any] | None = None
+ edges: list[dict[str, Any]] | None = None
+ stages: list[StageDeployConfig] = field(default_factory=list)
+ platforms: dict[str, Any] | None = None
+ # Overrides the auto-detected pipeline registry key for structural variants.
+ pipeline: str | None = None
+
+ # === Pipeline-wide engine settings (applied uniformly to every stage) ===
+ trust_remote_code: bool = True
+ distributed_executor_backend: str = "mp"
+ dtype: str | None = None
+ quantization: str | None = None
+ enable_prefix_caching: bool = False
+ enable_chunked_prefill: bool | None = None
+ data_parallel_size: int = 1
+ pipeline_parallel_size: int = 1
+
+
+_STAGE_NON_ENGINE_KEYS = frozenset(
+ {
+ "stage_id",
+ "devices",
+ "num_replicas",
+ "output_connectors",
+ "input_connectors",
+ "default_sampling_params",
+ "engine_extras",
+ }
+)
+
+# Fields on StageDeployConfig that are populated from engine_args dict
+_STAGE_DEPLOY_FIELDS = {f.name: f for f in fields(StageDeployConfig) if f.name not in _STAGE_NON_ENGINE_KEYS}
+
+
+def _parse_stage_deploy(stage_data: dict[str, Any]) -> StageDeployConfig:
+ """Parse a single stage entry from deploy YAML into StageDeployConfig."""
+ if "engine_args" in stage_data:
+ runtime_cfg = dict(stage_data.get("runtime", {}))
+ engine_args = dict(stage_data["engine_args"])
+ devices = runtime_cfg.get("devices", stage_data.get("devices", "0"))
+ num_replicas = runtime_cfg.get("num_replicas", stage_data.get("num_replicas", 1))
+ else:
+ engine_args = {k: v for k, v in stage_data.items() if k not in _STAGE_NON_ENGINE_KEYS and k != "stage_id"}
+ devices = stage_data.get("devices", "0")
+ num_replicas = stage_data.get("num_replicas", stage_data.get("runtime", {}).get("num_replicas", 1))
+
+ kwargs: dict[str, Any] = {
+ "stage_id": stage_data["stage_id"],
+ "devices": devices,
+ "num_replicas": int(num_replicas),
+ }
+ for name, f in _STAGE_DEPLOY_FIELDS.items():
+ if name in engine_args:
+ kwargs[name] = engine_args.pop(name)
+
+ kwargs["output_connectors"] = stage_data.get("output_connectors")
+ kwargs["input_connectors"] = stage_data.get("input_connectors")
+ kwargs["default_sampling_params"] = stage_data.get("default_sampling_params")
+ kwargs["engine_extras"] = engine_args
+ return StageDeployConfig(**kwargs)
+
+
+_DEEP_MERGE_KEYS = frozenset({"default_sampling_params", "engine_extras", "engine_args"})
+
+
+def _deep_merge_stage(base: dict, overlay: dict) -> dict:
+ """Deep-merge ``_DEEP_MERGE_KEYS`` so thin overlays don't drop base keys."""
+ merged = dict(base)
+ for k, v in overlay.items():
+ if k in _DEEP_MERGE_KEYS:
+ base_val = merged.get(k)
+ if isinstance(v, dict) and isinstance(base_val, dict):
+ merged[k] = {**base_val, **v}
+ continue
+ # Deep-merge key but at least one side isn't a dict: surface the
+ # silent clobber so mismatched YAML types don't get past review.
+ if base_val is not None:
+ logger.warning(
+ "Deep-merge key %r has non-dict value (base=%s, overlay=%s); "
+ "overlay will fully replace base instead of merging.",
+ k,
+ type(base_val).__name__,
+ type(v).__name__,
+ )
+ merged[k] = v
+ return merged
+
+
+def _merge_stage_lists(
+ base_stages: list[dict[str, Any]] | None,
+ overlay_stages: list[dict[str, Any]] | None,
+) -> list[dict[str, Any]]:
+ """Merge two ``stages:`` lists by ``stage_id`` (overlay wins per field)."""
+ by_id: dict[int, dict[str, Any]] = {s["stage_id"]: s for s in (base_stages or [])}
+ for overlay_stage in overlay_stages or []:
+ sid = overlay_stage["stage_id"]
+ if sid in by_id:
+ by_id[sid] = _deep_merge_stage(by_id[sid], overlay_stage)
+ else:
+ by_id[sid] = overlay_stage
+ return list(by_id.values())
+
+
+def _merge_platforms(
+ base: dict[str, Any] | None,
+ overlay: dict[str, Any] | None,
+) -> dict[str, Any] | None:
+ """Deep-merge two ``platforms:`` blocks per-platform, per-stage_id."""
+ if not base and not overlay:
+ return None
+ base = base or {}
+ overlay = overlay or {}
+ merged: dict[str, Any] = {}
+ for plat in set(base) | set(overlay):
+ bp = base.get(plat) or {}
+ op = overlay.get(plat) or {}
+ merged_plat = {**bp, **{k: v for k, v in op.items() if k != "stages"}}
+ merged_plat["stages"] = _merge_stage_lists(bp.get("stages"), op.get("stages"))
+ merged[plat] = merged_plat
+ return merged
+
+
+def resolve_deploy_yaml(path: str | Path) -> dict[str, Any]:
+ """Load a deploy YAML with optional ``base_config`` inheritance."""
+ raw_dict = to_dict(load_yaml_config(path))
+
+ base_path = raw_dict.pop("base_config", None)
+ if base_path is None:
+ return raw_dict
+
+ # Resolve relative to the overlay file's directory
+ base_path = Path(path).parent / base_path
+ base_dict = resolve_deploy_yaml(base_path)
+
+ # Merge top-level scalars: overlay wins. ``stages:`` and ``platforms:``
+ # are deep-merged below so an overlay can layer on top of the base.
+ merged = {
+ **base_dict,
+ **{k: v for k, v in raw_dict.items() if k not in ("stages", "platforms")},
+ }
+ merged["stages"] = _merge_stage_lists(base_dict.get("stages"), raw_dict.get("stages"))
+ merged_platforms = _merge_platforms(base_dict.get("platforms"), raw_dict.get("platforms"))
+ if merged_platforms is not None:
+ merged["platforms"] = merged_platforms
+
+ return merged
+
+
+def load_deploy_config(path: str | Path) -> DeployConfig:
+ """Load a deploy YAML (with optional base_config inheritance)."""
+ raw_dict = resolve_deploy_yaml(path)
+
+ stages = [_parse_stage_deploy(s) for s in raw_dict.get("stages", [])]
+
+ kwargs: dict[str, Any] = {
+ "async_chunk": raw_dict.get("async_chunk", True),
+ "connectors": raw_dict.get("connectors", None),
+ "edges": raw_dict.get("edges", None),
+ "stages": stages,
+ "platforms": raw_dict.get("platforms", None),
+ "pipeline": raw_dict.get("pipeline", None),
+ }
+ # Pipeline-wide engine settings: only set if explicitly present in YAML
+ # so the DeployConfig dataclass defaults take effect otherwise.
+ for name in (
+ "trust_remote_code",
+ "distributed_executor_backend",
+ "dtype",
+ "quantization",
+ "enable_prefix_caching",
+ "enable_chunked_prefill",
+ "data_parallel_size",
+ "pipeline_parallel_size",
+ ):
+ if name in raw_dict:
+ kwargs[name] = raw_dict[name]
+ return DeployConfig(**kwargs)
+
+
+def _detect_platform() -> str | None:
+ """Return "npu", "rocm", "xpu", or None (CUDA default)."""
+ try:
+ from vllm.platforms import current_platform
+
+ name = current_platform.device_name.lower()
+ if "npu" in name:
+ return "npu"
+ if "rocm" in name or "amd" in name:
+ return "rocm"
+ if "xpu" in name:
+ return "xpu"
+ except Exception as e:
+ logger.debug("Platform auto-detect failed, falling back to CUDA: %s", e)
+ return None
+
+
+def _extract_platform_overrides(ps: dict[str, Any]) -> tuple[dict[str, Any], str | None]:
+ """Return ``(overrides, devices)`` from a platform stage entry.
+
+ Handles both the nested layout (``engine_args:`` / ``runtime.devices``) and
+ the flat layout. ``devices`` is ``None`` when no override is set.
+ """
+ if "engine_args" in ps:
+ overrides = dict(ps["engine_args"])
+ runtime_cfg = ps.get("runtime", {})
+ if "num_replicas" in runtime_cfg:
+ overrides["num_replicas"] = runtime_cfg["num_replicas"]
+ return overrides, runtime_cfg.get("devices")
+ overrides = {k: v for k, v in ps.items() if k not in ("stage_id", "devices")}
+ return overrides, ps.get("devices")
+
+
+def _apply_platform_overrides(
+ deploy: DeployConfig,
+ platform: str | None = None,
+) -> DeployConfig:
+ """Merge platform-specific stage overrides into deploy config."""
+ if platform is None:
+ platform = _detect_platform()
+ if platform is None or deploy.platforms is None:
+ return deploy
+ platform_section = deploy.platforms.get(platform)
+ if platform_section is None:
+ return deploy
+
+ platform_stages = platform_section.get("stages", [])
+ base_by_id = {s.stage_id: s for s in deploy.stages}
+
+ for ps in platform_stages:
+ base = base_by_id.get(ps["stage_id"])
+ if base is None:
+ continue
+ overrides, devices = _extract_platform_overrides(ps)
+ if devices is not None:
+ base.devices = devices
+ for key, val in overrides.items():
+ if hasattr(base, key):
+ setattr(base, key, val)
+ else:
+ base.engine_extras[key] = val
+
+ return deploy
+
+
+_EXECUTION_TYPE_TO_STAGE_WORKER: dict[StageExecutionType, tuple[StageType, str | None]] = {
+ StageExecutionType.LLM_AR: (StageType.LLM, "ar"),
+ StageExecutionType.LLM_GENERATION: (StageType.LLM, "generation"),
+ StageExecutionType.DIFFUSION: (StageType.DIFFUSION, None),
+}
+
+
+def _resolve_execution_mode(
+ execution_type: StageExecutionType,
+) -> tuple[StageType, str | None]:
+ """Map ``execution_type`` → ``(stage_type, worker_type)`` legacy tuple."""
+ return _EXECUTION_TYPE_TO_STAGE_WORKER.get(execution_type, (StageType.LLM, None))
+
+
+def _select_processor_funcs(
+ ps: StagePipelineConfig,
+ async_chunk: bool,
+) -> tuple[str | None, str | None]:
+ """Pick ``(input_proc, next_stage_proc)`` based on the async_chunk mode."""
+ next_stage_proc = ps.custom_process_next_stage_input_func
+ input_proc = ps.custom_process_input_func
+ if async_chunk and ps.async_chunk_process_next_stage_input_func:
+ next_stage_proc = ps.async_chunk_process_next_stage_input_func
+ elif not async_chunk and ps.sync_process_input_func:
+ input_proc = ps.sync_process_input_func
+ return input_proc, next_stage_proc
+
+
+# Pipeline-wide DeployConfig fields that are propagated to every stage's
+# engine args during merge. These live at top level of the deploy YAML.
+_PIPELINE_WIDE_ENGINE_FIELDS: tuple[str, ...] = (
+ "trust_remote_code",
+ "distributed_executor_backend",
+ "dtype",
+ "quantization",
+ "enable_prefix_caching",
+ "enable_chunked_prefill",
+ "data_parallel_size",
+ "pipeline_parallel_size",
+)
+
+
+def _build_engine_args(
+ ps: StagePipelineConfig,
+ ds: StageDeployConfig | None,
+ pipeline: PipelineConfig,
+ deploy: DeployConfig,
+ next_stage_proc: str | None,
+) -> dict[str, Any]:
+ """Assemble the flat ``yaml_engine_args`` dict for one stage.
+
+ Pipeline-wide DeployConfig fields are applied uniformly to every stage;
+ per-stage StageDeployConfig overrides take precedence when present (e.g.
+ ``engine_extras`` can still carry a stage-specific ``dtype``).
+ """
+ engine_args: dict[str, Any] = {"model_arch": ps.model_arch or pipeline.model_arch}
+ if ps.engine_output_type:
+ engine_args["engine_output_type"] = ps.engine_output_type
+ if next_stage_proc:
+ engine_args["custom_process_next_stage_input_func"] = next_stage_proc
+
+ # Pipeline-wide top-level DeployConfig settings, applied to every stage.
+ for name in _PIPELINE_WIDE_ENGINE_FIELDS:
+ value = getattr(deploy, name)
+ if value is not None:
+ engine_args[name] = value
+
+ # Per-stage StageDeployConfig values override pipeline-wide settings.
+ if ds is not None:
+ for k, v in asdict(ds).items():
+ if k in _STAGE_NON_ENGINE_KEYS or v is None:
+ continue
+ engine_args[k] = v
+ engine_args.update(ds.engine_extras)
+ if deploy.async_chunk:
+ engine_args["async_chunk"] = True
+ return engine_args
+
+
+def _build_extras(
+ ps: StagePipelineConfig,
+ ds: StageDeployConfig | None,
+) -> dict[str, Any]:
+ """Assemble ``yaml_extras`` (sampling + connectors + pipeline extras)."""
+ extras: dict[str, Any] = {}
+ sampling: dict[str, Any] = {}
+ if ds is not None and ds.default_sampling_params:
+ sampling.update(ds.default_sampling_params)
+ sampling.update(ps.sampling_constraints)
+ if sampling:
+ extras["default_sampling_params"] = sampling
+ if ds is not None and ds.output_connectors:
+ extras["output_connectors"] = dict(ds.output_connectors)
+ if ds is not None and ds.input_connectors:
+ extras["input_connectors"] = dict(ds.input_connectors)
+ if ps.extras:
+ extras.update(ps.extras)
+ return extras
+
+
+def merge_pipeline_deploy(
+ pipeline: PipelineConfig,
+ deploy: DeployConfig,
+ cli_overrides: dict[str, Any] | None = None,
+) -> list[StageConfig]:
+ """Merge pipeline + deploy + platform overrides → list[StageConfig]."""
+ if cli_overrides is None:
+ cli_overrides = {}
+
+ deploy = _apply_platform_overrides(deploy)
+ deploy_by_id = {s.stage_id: s for s in deploy.stages}
+
+ # A pipeline supports async_chunk if any stage has either an explicit
+ # async-chunk-only processor slot OR a custom next-stage processor (some
+ # pipelines like qwen3_omni wire async-chunk processing directly through
+ # ``custom_process_next_stage_input_func``). Only raise when neither is
+ # present — that's the "user enabled async_chunk but pipeline has no
+ # inter-stage processing at all" case.
+ if deploy.async_chunk and not any(
+ ps.async_chunk_process_next_stage_input_func or ps.custom_process_next_stage_input_func
+ for ps in pipeline.stages
+ ):
+ raise ValueError(
+ f"Pipeline {pipeline.model_type!r} has async_chunk=True in deploy but no stage "
+ "declares a next-stage input processor "
+ "(``async_chunk_process_next_stage_input_func`` or ``custom_process_next_stage_input_func``). "
+ "Either set async_chunk=False or implement an async-chunk processor on the pipeline."
+ )
+
+ result: list[StageConfig] = []
+ for ps in pipeline.stages:
+ ds = deploy_by_id.get(ps.stage_id)
+ stage_type, worker_type = _resolve_execution_mode(ps.execution_type)
+ input_proc, next_stage_proc = _select_processor_funcs(ps, deploy.async_chunk)
+ engine_args = _build_engine_args(ps, ds, pipeline, deploy, next_stage_proc)
+ extras = _build_extras(ps, ds)
+ runtime: dict[str, Any] = {"process": True}
+ if ds is not None:
+ runtime["devices"] = ds.devices
+ runtime["num_replicas"] = ds.num_replicas
+
+ result.append(
+ StageConfig(
+ stage_id=ps.stage_id,
+ model_stage=ps.model_stage,
+ stage_type=stage_type,
+ input_sources=list(ps.input_sources),
+ custom_process_input_func=input_proc,
+ final_output=ps.final_output,
+ final_output_type=ps.final_output_type,
+ worker_type=worker_type,
+ scheduler_cls=_scheduler_path(_EXECUTION_TYPE_TO_SCHEDULER.get(ps.execution_type)),
+ hf_config_name=ps.hf_config_name,
+ is_comprehension=ps.owns_tokenizer,
+ yaml_engine_args=engine_args,
+ yaml_runtime=runtime,
+ yaml_extras=extras,
+ )
+ )
+ return result
+
+
@dataclass
class StageConfig:
- """Per-stage configuration from pipeline YAML.
+ """Per-stage config (legacy path). Used by both new and legacy loaders.
- Topology fields (stage_id, input_sources, etc.) define the DAG.
- Engine and runtime defaults come from the YAML; CLI overrides take
- precedence via ``runtime_overrides``.
+ TODO(@lishunyang12): replace with ResolvedStageConfig once all models are migrated.
"""
- # Identity
stage_id: int
model_stage: str
-
- # Stage type
stage_type: StageType = StageType.LLM
-
input_sources: list[int] = field(default_factory=list)
custom_process_input_func: str | None = None
final_output: bool = False
- final_output_type: str | None = None # "text", "audio", "image"
- worker_type: str | None = None # "ar" or "generation"
+ final_output_type: str | None = None
+ worker_type: str | None = None
scheduler_cls: str | None = None
hf_config_name: str | None = None
is_comprehension: bool = False
-
- # Per-stage engine args from pipeline YAML (defaults)
yaml_engine_args: dict[str, Any] = field(default_factory=dict)
- # Per-stage runtime config from pipeline YAML (devices, etc.)
yaml_runtime: dict[str, Any] = field(default_factory=dict)
- # Pass-through fields from pipeline YAML (default_sampling_params,
- # output_connectors, input_connectors, tts_args, etc.)
yaml_extras: dict[str, Any] = field(default_factory=dict)
-
- # Runtime overrides (populated from CLI, not from pipeline YAML)
runtime_overrides: dict[str, Any] = field(default_factory=dict)
def to_omegaconf(self) -> Any:
- """Convert to OmegaConf for backward compatibility with OmniStage.
-
- Returns:
- OmegaConf DictConfig with stage configuration in legacy format.
- """
+ """TODO(@lishunyang12): remove once engine consumes ResolvedStageConfig directly."""
# Start with YAML engine_args defaults
engine_args: dict[str, Any] = dict(self.yaml_engine_args)
@@ -152,9 +903,9 @@ def to_omegaconf(self) -> Any:
@dataclass
class ModelPipeline:
- """Complete pipeline definition for a multi-stage model.
+ """Complete pipeline definition for a multi-stage model (legacy).
- Defined by model developers, bundled with the model, not user-editable.
+ TODO(@lishunyang12): remove once all models migrate to PipelineConfig.
"""
model_type: str
@@ -225,49 +976,55 @@ class StageConfigFactory:
"""Factory that loads pipeline YAML and merges CLI overrides.
Handles both single-stage and multi-stage models.
- """
-
- # Mapping of model types to directories under model_executor/models/.
- PIPELINE_MODELS: dict[str, str] = {
- "qwen3_omni_moe": "qwen3_omni",
- "qwen2_5_omni": "qwen2_5_omni",
- "bagel": "bagel",
- "qwen3_tts": "qwen3_tts",
- "voxtral_tts": "voxtral_tts",
- "mimo_audio": "mimo_audio",
- "glm-image": "glm_image",
- "cosyvoice3": "cosyvoice3",
- "mammothmoda2": "mammoth_moda2",
- }
- # Fallback: map HF architecture class names to pipeline dirs.
- # Used when model_type collides with another model (e.g. MiMo Audio
- # reports model_type="qwen2" which matches plain Qwen2, not our pipeline).
- _ARCHITECTURE_MODELS: dict[str, str] = {
- "MiMoAudioForConditionalGeneration": "mimo_audio",
- "HunyuanImage3ForCausalMM": "hunyuan_image3",
- }
+ Pipelines are declared in ``vllm_omni/config/pipeline_registry.py`` and
+ loaded lazily via ``_PIPELINE_REGISTRY``; no hardcoded model-type →
+ directory mapping is maintained here. Models with generic HF
+ ``model_type`` collisions (e.g. MiMo Audio reports ``qwen2``) should
+ declare ``hf_architectures=(...)`` on their ``PipelineConfig`` so the
+ factory can disambiguate via ``hf_config.architectures``.
+ """
@classmethod
def create_from_model(
cls,
model: str,
cli_overrides: dict[str, Any] | None = None,
+ deploy_config_path: str | None = None,
+ cli_explicit_keys: set[str] | None = None,
) -> list[StageConfig] | None:
- """Load pipeline YAML, merge with CLI overrides.
+ """Load pipeline + deploy config, merge with CLI overrides.
- Args:
- model: Model name or path.
- cli_overrides: CLI overrides from VllmConfig/OmniDiffusionConfig.
+ Checks _PIPELINE_REGISTRY first (new path), falls back to legacy YAML.
- Returns:
- List of StageConfig objects with CLI overrides applied,
- or None if no pipeline definition was found for this model.
+ ``cli_explicit_keys`` is the set of CLI keys the user actually typed
+ (captured at the parser layer in ``vllm serve``). When ``None`` —
+ which is the case for programmatic ``Omni()`` callers — every kwarg
+ in ``cli_overrides`` is treated as explicit.
"""
if cli_overrides is None:
cli_overrides = {}
trust_remote_code = cli_overrides.get("trust_remote_code", True)
+
+ # --- New path: check pipeline registry by model_type first ---
+ model_type, hf_config = cls._auto_detect_model_type(model, trust_remote_code=trust_remote_code)
+ if model_type and model_type in _PIPELINE_REGISTRY:
+ return cls._create_from_registry(model_type, cli_overrides, deploy_config_path, cli_explicit_keys)
+
+ # --- HF architecture fallback: some models report a generic
+ # model_type that collides with another model. Match by the
+ # hf_architectures declared on each registered PipelineConfig.
+ if hf_config is not None:
+ hf_archs = set(getattr(hf_config, "architectures", []) or [])
+ if hf_archs:
+ for registered in _PIPELINE_REGISTRY.values():
+ if hf_archs.intersection(registered.hf_architectures):
+ return cls._create_from_registry(
+ registered.model_type, cli_overrides, deploy_config_path, cli_explicit_keys
+ )
+
+ # --- Legacy path: load from pipeline YAML ---
pipeline = cls._load_pipeline(model, trust_remote_code=trust_remote_code)
if pipeline is None:
@@ -295,6 +1052,78 @@ def create_from_model(
return result
+ @classmethod
+ def _create_from_registry(
+ cls,
+ model_type: str,
+ cli_overrides: dict[str, Any],
+ deploy_config_path: str | None = None,
+ cli_explicit_keys: set[str] | None = None,
+ ) -> list[StageConfig]:
+ """Create StageConfigs from pipeline registry + deploy YAML.
+
+ Precedence (high → low):
+ explicit CLI args > deploy YAML > parser default CLI values
+
+ ``cli_explicit_keys`` carries the set of long-option attribute names
+ the user actually typed (captured in ``OmniServeCommand.cmd``). Any
+ kwarg whose key is not in that set is treated as a parser default
+ and is only used to fill fields YAML doesn't already cover. When the
+ set is ``None`` (programmatic ``Omni()`` callers, which have no
+ argparse layer), every kwarg is treated as explicit.
+ """
+ # Resolve deploy config path
+ if deploy_config_path is None:
+ deploy_path = _DEPLOY_DIR / f"{model_type}.yaml"
+ else:
+ deploy_path = Path(deploy_config_path)
+
+ if not deploy_path.exists():
+ logger.warning(
+ "Deploy config not found: %s — using pipeline defaults only",
+ deploy_path,
+ )
+ deploy_cfg = DeployConfig()
+ else:
+ deploy_cfg = load_deploy_config(deploy_path)
+
+ cli_async_chunk = cli_overrides.get("async_chunk")
+ if cli_async_chunk is not None and (cli_explicit_keys is None or "async_chunk" in cli_explicit_keys):
+ deploy_cfg.async_chunk = bool(cli_async_chunk)
+
+ pipeline_key = deploy_cfg.pipeline or model_type
+ if pipeline_key not in _PIPELINE_REGISTRY:
+ raise KeyError(
+ f"Pipeline {pipeline_key!r} not in registry "
+ f"(resolved from {deploy_path.name!r}). Available: "
+ f"{sorted(_PIPELINE_REGISTRY.keys())}"
+ )
+ pipeline_cfg = _PIPELINE_REGISTRY[pipeline_key]
+
+ stages = merge_pipeline_deploy(pipeline_cfg, deploy_cfg, cli_overrides)
+
+ # Precedence: explicit CLI > yaml > parser-default CLI.
+ # Per-stage (``stage_N_*``) keys are always treated as explicit.
+ explicit_overrides: dict[str, Any] = {}
+ default_overrides: dict[str, Any] = {}
+ for key, value in cli_overrides.items():
+ if value is None:
+ continue
+ is_per_stage = bool(re.match(r"stage_\d+_", key))
+ is_explicit = cli_explicit_keys is None or key in cli_explicit_keys or is_per_stage
+ if is_explicit:
+ explicit_overrides[key] = value
+ else:
+ default_overrides[key] = value
+
+ for stage in stages:
+ yaml_keys = set(stage.yaml_engine_args)
+ fallback = {k: v for k, v in default_overrides.items() if k not in yaml_keys}
+ merged = {**fallback, **explicit_overrides}
+ stage.runtime_overrides = cls._merge_cli_overrides(stage, merged)
+
+ return stages
+
@classmethod
def create_default_diffusion(cls, kwargs: dict[str, Any]) -> list[dict[str, Any]]:
"""Single-stage diffusion - no YAML needed.
@@ -322,9 +1151,16 @@ def create_default_diffusion(cls, kwargs: dict[str, Any]) -> list[dict[str, Any]
continue
engine_args[key] = value
- # Serialize parallel_config as dict for OmegaConf compatibility
+ # Serialize parallel_config as dict for OmegaConf. Test helpers
+ # sometimes pass SimpleNamespace rather than a dataclass instance.
if "parallel_config" in kwargs:
- engine_args["parallel_config"] = asdict(kwargs["parallel_config"])
+ parallel_config = kwargs["parallel_config"]
+ if dataclasses.is_dataclass(parallel_config) and not isinstance(parallel_config, type):
+ engine_args["parallel_config"] = asdict(parallel_config)
+ elif hasattr(parallel_config, "__dict__"):
+ engine_args["parallel_config"] = dict(vars(parallel_config))
+ else:
+ engine_args["parallel_config"] = parallel_config
engine_args.setdefault("cache_backend", "none")
engine_args["model_stage"] = "diffusion"
@@ -351,40 +1187,49 @@ def create_default_diffusion(cls, kwargs: dict[str, Any]) -> list[dict[str, Any]
@classmethod
def _load_pipeline(cls, model: str, trust_remote_code: bool = True) -> ModelPipeline | None:
- """Load pipeline YAML for the model.
+ """Load a legacy ``pipeline.yaml`` for the model.
- Args:
- model: Model name or path.
- trust_remote_code: Whether to trust remote code for HF config loading.
+ Searches ``model_executor/models//pipeline.yaml`` by trying
+ (a) the raw ``model_type`` as the directory name, then
+ (b) ``model_type`` with hyphens replaced by underscores,
+ and finally (c) scanning every ``pipeline.yaml`` for one that
+ declares a matching ``model_type`` or ``hf_architectures``.
- Returns:
- ModelPipeline if found, None otherwise.
+ Returns None if no pipeline.yaml is found — caller handles the
+ ``resolve_model_config_path`` fallback via stage_configs/ YAMLs.
"""
model_type, hf_config = cls._auto_detect_model_type(model, trust_remote_code=trust_remote_code)
if model_type is None:
return None
- pipeline_dir = cls.PIPELINE_MODELS.get(model_type)
-
- # Fallback: check HF architectures when model_type doesn't match
- if pipeline_dir is None and hf_config is not None:
- for arch in getattr(hf_config, "architectures", []) or []:
- pipeline_dir = cls._ARCHITECTURE_MODELS.get(arch)
- if pipeline_dir is not None:
- model_type = pipeline_dir
- break
-
- if pipeline_dir is None:
- logger.debug(f"No pipeline mapping for model_type: {model_type}")
- return None
-
- pipeline_path = get_pipeline_path(pipeline_dir, "pipeline.yaml")
-
- if not pipeline_path.exists():
- logger.debug(f"Pipeline file not found: {pipeline_path}")
- return None
+ # Direct lookups by convention
+ candidates = [model_type, model_type.replace("-", "_")]
+ for dir_name in candidates:
+ pipeline_path = get_pipeline_path(dir_name, "pipeline.yaml")
+ if pipeline_path.exists():
+ return cls._parse_pipeline_yaml(pipeline_path, model_type)
+
+ # Scan fallback: read every pipeline.yaml and match on declared fields
+ hf_archs = set(getattr(hf_config, "architectures", []) or []) if hf_config else set()
+ if _MODELS_DIR.exists():
+ for subdir in sorted(_MODELS_DIR.iterdir()):
+ if not subdir.is_dir():
+ continue
+ pipeline_path = subdir / "pipeline.yaml"
+ if not pipeline_path.exists():
+ continue
+ try:
+ cfg = load_yaml_config(pipeline_path)
+ except Exception as exc:
+ logger.debug("Skip %s: %s", pipeline_path, exc)
+ continue
+ declared_type = getattr(cfg, "model_type", None)
+ declared_archs = set(getattr(cfg, "hf_architectures", None) or [])
+ if declared_type == model_type or (hf_archs and hf_archs.intersection(declared_archs)):
+ return cls._parse_pipeline_yaml(pipeline_path, declared_type or model_type)
- return cls._parse_pipeline_yaml(pipeline_path, model_type)
+ logger.debug("No pipeline.yaml found for model_type %s (archs=%s)", model_type, sorted(hf_archs))
+ return None
# Keys consumed as explicit StageConfig fields — everything else is
# passed through via yaml_extras.
@@ -542,66 +1387,17 @@ def _auto_detect_model_type(cls, model: str, trust_remote_code: bool = True) ->
return None, None
- # Keys that should never be forwarded as engine overrides (internal /
- # orchestrator-only knobs, complex objects, etc.).
- _INTERNAL_KEYS: set[str] = {
- "model",
- "stage_configs_path",
- "stage_id",
- "stage_init_timeout",
- "init_timeout",
- "shm_threshold_bytes",
- "worker_backend",
- "ray_address",
- "batch_timeout",
- "log_stats",
- "tokenizer",
- "parallel_config",
- }
-
@classmethod
def _merge_cli_overrides(
cls,
stage: StageConfig,
cli_overrides: dict[str, Any],
) -> dict[str, Any]:
- """Merge CLI overrides into stage runtime config.
+ """Merge global and per-stage (``stage_N_*``) CLI overrides.
- All CLI arguments registered by engine config classes (e.g.
- EngineArgs / OmniDiffusionConfig) are accepted as overrides
- unless they appear in ``_INTERNAL_KEYS``.
-
- Handles:
- - Global overrides (apply to all stages)
- - Per-stage overrides (--stage-N-* format, take precedence)
-
- Args:
- stage: The stage to merge overrides into.
- cli_overrides: CLI arguments from VllmConfig/OmniDiffusionConfig.
-
- Returns:
- Dict of runtime overrides for this stage.
+ Orchestrator-owned keys are filtered by ``build_stage_runtime_overrides``
+ using ``OrchestratorArgs`` as the single source of truth; unknown
+ server/uvicorn keys are dropped downstream by
+ ``filter_dataclass_kwargs(OmniEngineArgs, ...)``.
"""
- result: dict[str, Any] = {}
-
- # Apply global overrides – any key not in the internal blocklist
- # is forwarded so that engine-registered params work out of the box.
- for key, value in cli_overrides.items():
- if key in cls._INTERNAL_KEYS:
- continue
- if re.match(r"stage_\d+_", key):
- # Per-stage keys handled below
- continue
- if value is not None:
- result[key] = value
-
- # Apply per-stage overrides (--stage-N-* format, take precedence)
- stage_prefix = f"stage_{stage.stage_id}_"
- for key, value in cli_overrides.items():
- if key.startswith(stage_prefix) and value is not None:
- param_name = key[len(stage_prefix) :]
- if param_name in cls._INTERNAL_KEYS:
- continue
- result[param_name] = value
-
- return result
+ return build_stage_runtime_overrides(stage.stage_id, cli_overrides)
diff --git a/vllm_omni/core/sched/omni_ar_scheduler.py b/vllm_omni/core/sched/omni_ar_scheduler.py
index c5031ad11b0..a5579dd4640 100644
--- a/vllm_omni/core/sched/omni_ar_scheduler.py
+++ b/vllm_omni/core/sched/omni_ar_scheduler.py
@@ -15,9 +15,10 @@
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.metrics.perf import PerfStats
from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.request import Request, RequestStatus
+from vllm.v1.request import Request, RequestStatus, StreamingUpdate
from vllm.v1.spec_decode.metrics import SpecDecodingStats
+from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin
from vllm_omni.core.sched.output import OmniSchedulerOutput
from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import (
OmniChunkTransferAdapter,
@@ -38,7 +39,7 @@ def to_dict(self) -> dict[str, Any]:
return asdict(self)
-class OmniARScheduler(VLLMScheduler):
+class OmniARScheduler(OmniSchedulerMixin, VLLMScheduler):
"""
OmniARScheduler: Scheduler for vLLM-Omni multimodal processing.
@@ -76,6 +77,8 @@ def __init__(self, *args, **kwargs):
self.chunk_transfer_adapter = None
if getattr(model_config, "async_chunk", False):
self.chunk_transfer_adapter = OmniChunkTransferAdapter(self.vllm_config)
+ # Snapshot prompt length for each streaming input update
+ self._new_prompt_len_snapshot: dict[str, int] = {}
def _get_kv_transfer_criteria(self) -> dict | None:
# Note: vllm_config is available in Scheduler after super().__init__
@@ -338,6 +341,7 @@ def update_from_output(
)
stopped = False
+ is_segment_finished = False
new_logprobs = None
new_token_ids = generated_token_ids
pooler_output = pooler_outputs[req_index] if pooler_outputs else None
@@ -366,6 +370,7 @@ def update_from_output(
# Capture finish_reason BEFORE _handle_stopped_request, which may
# reset the status to WAITING for streaming requests that continue.
finish_reason = request.get_finished_reason()
+ is_segment_finished = request.is_finished() and request.resumable
finished = self._handle_stopped_request(request)
if finished:
kv_transfer_params = self._free_request(request)
@@ -418,6 +423,8 @@ def update_from_output(
num_external_computed_tokens=request.num_external_computed_tokens,
routed_experts=routed_experts,
num_nans_in_logits=request.num_nans_in_logits,
+ is_segment_finished=is_segment_finished,
+ new_prompt_len_snapshot=self._new_prompt_len_snapshot.get(req_id, None),
)
)
if self.chunk_transfer_adapter is not None:
@@ -557,6 +564,21 @@ def finish_requests(self, request_ids: Any, finished_status: RequestStatus) -> l
return super().finish_requests(request_ids, finished_status)
+ def _update_request_as_session(self, session: Request, update: StreamingUpdate) -> None:
+ """
+ Override: Only extend prompt at stage 0, and replace
+ the existing session with the next streaming update at other stages.
+
+ Discards the last sampled output token from the prior input chunk at stage 0.
+ """
+ req_id = session.request_id
+ self._new_prompt_len_snapshot[req_id] = len(update.prompt_token_ids)
+ if self.vllm_config.model_config.stage_id != 0:
+ self._replace_session_with_streaming_update(session, update)
+
+ else:
+ super()._update_request_as_session(session, update)
+
def _free_request(self, request: Request, delay_free_blocks: bool = False) -> dict[str, Any] | None:
# TODO(wzliu)! for offline mode, we should not end process until all data is transferred
"""Mark a request as finished and free its resources."""
@@ -570,6 +592,7 @@ def _free_request(self, request: Request, delay_free_blocks: bool = False) -> di
self.encoder_cache_manager.free(request)
request_id = request.request_id
self.finished_req_ids.add(request_id)
+ self._new_prompt_len_snapshot.pop(request_id, None)
if self.finished_req_ids_dict is not None:
self.finished_req_ids_dict[request.client_index].add(request_id)
diff --git a/vllm_omni/core/sched/omni_generation_scheduler.py b/vllm_omni/core/sched/omni_generation_scheduler.py
index 87f2d282aa7..81f0b7fc2b4 100644
--- a/vllm_omni/core/sched/omni_generation_scheduler.py
+++ b/vllm_omni/core/sched/omni_generation_scheduler.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import time
from collections import defaultdict
@@ -11,11 +13,16 @@
from vllm.v1.core.sched.request_queue import create_request_queue
from vllm.v1.core.sched.scheduler import Scheduler as VLLMScheduler
from vllm.v1.core.sched.utils import remove_all
-from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
+from vllm.v1.engine import (
+ EngineCoreEventType,
+ EngineCoreOutput,
+ EngineCoreOutputs,
+)
from vllm.v1.metrics.perf import PerfStats
-from vllm.v1.request import Request, RequestStatus
+from vllm.v1.request import Request, RequestStatus, StreamingUpdate
from vllm.v1.spec_decode.metrics import SpecDecodingStats
+from vllm_omni.core.sched.omni_scheduler_mixin import OmniSchedulerMixin
from vllm_omni.core.sched.output import OmniCachedRequestData, OmniNewRequestData
from vllm_omni.distributed.omni_connectors.transfer_adapter.chunk_transfer_adapter import (
OmniChunkTransferAdapter,
@@ -25,7 +32,7 @@
logger = init_logger(__name__)
-class OmniGenerationScheduler(VLLMScheduler):
+class OmniGenerationScheduler(OmniSchedulerMixin, VLLMScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
model_config = self.vllm_config.model_config
@@ -599,3 +606,11 @@ def update_from_output(
eco.scheduler_stats = stats
return engine_core_outputs
+
+ def _update_request_as_session(self, session: Request, update: StreamingUpdate) -> None:
+ """
+ Override: Just replace the existing session with the next streaming update.
+
+ Do not expend prompt id using update.
+ """
+ self._replace_session_with_streaming_update(session, update)
diff --git a/vllm_omni/core/sched/omni_scheduler_mixin.py b/vllm_omni/core/sched/omni_scheduler_mixin.py
new file mode 100644
index 00000000000..36080e63acc
--- /dev/null
+++ b/vllm_omni/core/sched/omni_scheduler_mixin.py
@@ -0,0 +1,33 @@
+from __future__ import annotations
+
+from vllm.v1.engine import EngineCoreEventType
+from vllm.v1.request import Request, RequestStatus, StreamingUpdate
+
+
+class OmniSchedulerMixin:
+ """Shared scheduler helpers for omni-specific request handling."""
+
+ def _replace_session_with_streaming_update(
+ self,
+ session: Request,
+ update: StreamingUpdate,
+ ) -> None:
+ """For streaming input: Replace an existing streaming session payload with the latest update."""
+ session._output_token_ids.clear()
+ session._all_token_ids.clear()
+ new_prompt = update.prompt_token_ids or ()
+ session._all_token_ids.extend(new_prompt)
+ session.num_computed_tokens = 0
+ session.prompt_token_ids = update.prompt_token_ids or ()
+ session.additional_information = update.additional_information or None
+ # Update block hashes for the new tokens.
+ session.update_block_hashes()
+ session.num_prompt_tokens = len(session.prompt_token_ids)
+ session.arrival_time = update.arrival_time
+ session.sampling_params = update.sampling_params
+ if session.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
+ self.num_waiting_for_streaming_input -= 1
+ session.status = RequestStatus.WAITING
+
+ if self.log_stats:
+ session.record_event(EngineCoreEventType.QUEUED)
diff --git a/vllm_omni/deploy/qwen2_5_omni.yaml b/vllm_omni/deploy/qwen2_5_omni.yaml
new file mode 100644
index 00000000000..41aef0df6f6
--- /dev/null
+++ b/vllm_omni/deploy/qwen2_5_omni.yaml
@@ -0,0 +1,92 @@
+# Qwen2.5-Omni deploy: CUDA defaults + platform overrides, verified on 2x H100.
+# Stage 2 disables flashinfer autotune because its DiT block never invokes
+# flashinfer; the autotune dummy run OOMs the shared cuda:0 device otherwise.
+#
+# Fields omitted from a stage fall back to StageDeployConfig dataclass
+# defaults (see vllm_omni/config/stage_config.py). For instance, every
+# stage here uses vLLM's default max_num_batched_tokens=32768 because
+# chat-sized prefill comfortably fits; only models with codec prefill
+# (Qwen3-Omni, Qwen3-TTS) need to bump it above 32k.
+#
+# enforce_eager policy across the three deploy YAMLs:
+# * code2wav / generation stages: always true (cudagraph incompatible with
+# the custom generation loop — set explicitly everywhere).
+# * AR stages (thinker, talker): model-dependent. Qwen2.5-Omni runs eager
+# on CUDA (thinker uses custom ops that don't trace cleanly); NPU / XPU
+# platform overrides flip back to false where cudagraph is verified.
+# Qwen3-Omni / Qwen3-TTS AR stages use the default (false = cudagraph on).
+async_chunk: false
+
+stages:
+ - stage_id: 0
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.8
+ enforce_eager: true
+ mm_processor_cache_gb: 0
+ devices: "0"
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ repetition_penalty: 1.1
+
+ - stage_id: 1
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.8
+ enforce_eager: true
+ devices: "1"
+ default_sampling_params:
+ temperature: 0.9
+ top_p: 0.8
+ top_k: 40
+ max_tokens: 2048
+ seed: 42
+ repetition_penalty: 1.05
+
+ - stage_id: 2
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.15
+ enforce_eager: true
+ enable_flashinfer_autotune: false
+ async_scheduling: false
+ devices: "0"
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 2048
+ seed: 42
+ repetition_penalty: 1.1
+
+platforms:
+ npu:
+ stages:
+ # NPU has cudagraph support for the thinker, unlike GPU which still
+ # only runs eager.
+ - stage_id: 0
+ enforce_eager: false
+ - stage_id: 2
+ # 3-NPU layout: stage 2 lives on its own card.
+ devices: "2"
+
+ rocm:
+ stages:
+ - stage_id: 2
+ # 3-GPU MI325 layout: stage 2 on a separate card.
+ devices: "2"
+
+ xpu:
+ stages:
+ # Verified on 2x Intel Arc Pro B60. Both AR stages use cudagraphs.
+ - stage_id: 0
+ gpu_memory_utilization: 0.9
+ enforce_eager: false
+ - stage_id: 1
+ gpu_memory_utilization: 0.5
+ enforce_eager: false
+ - stage_id: 2
+ gpu_memory_utilization: 0.3
+ # Stage 2 colocates with stage 1's device on XPU.
+ devices: "1"
diff --git a/vllm_omni/deploy/qwen3_omni_moe.yaml b/vllm_omni/deploy/qwen3_omni_moe.yaml
new file mode 100644
index 00000000000..fb8b6162139
--- /dev/null
+++ b/vllm_omni/deploy/qwen3_omni_moe.yaml
@@ -0,0 +1,98 @@
+# Qwen3-Omni-MoE production deploy, verified on 2x H100 (stage 0 on cuda:0,
+# stages 1+2 on cuda:1).
+#
+# Fields omitted from a stage fall back to StageDeployConfig defaults (see
+# vllm_omni/config/stage_config.py). Notable implicit defaults for this
+# model:
+# * Stages 0/1 (thinker, talker) do not set max_num_batched_tokens —
+# chat-sized prefill fits in the 32768 default.
+# * Stages 0/1 do not set enforce_eager — cudagraph runs by default
+# (false). Stage 2 (code2wav) sets true because its generation loop
+# is cudagraph-incompatible.
+# * Platform sections flip enforce_eager per-stage where platform
+# cudagraph support differs.
+async_chunk: true
+
+connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ codec_chunk_frames: 25
+ codec_left_context_frames: 25
+
+stages:
+ - stage_id: 0
+ gpu_memory_utilization: 0.9
+ devices: "0"
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 42
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ gpu_memory_utilization: 0.6
+ devices: "1"
+ input_connectors:
+ from_stage_0: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ repetition_penalty: 1.05
+
+ - stage_id: 2
+ gpu_memory_utilization: 0.1
+ max_num_seqs: 1
+ enforce_eager: true
+ async_scheduling: false
+ max_num_batched_tokens: 51200
+ devices: "1"
+ input_connectors:
+ from_stage_1: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ repetition_penalty: 1.1
+
+platforms:
+ npu:
+ stages:
+ - stage_id: 0
+ gpu_memory_utilization: 0.6
+ tensor_parallel_size: 2
+ devices: "0,1"
+ - stage_id: 1
+ gpu_memory_utilization: 0.6
+ enforce_eager: true
+ devices: "2"
+ - stage_id: 2
+ gpu_memory_utilization: 0.3
+ devices: "2"
+
+ rocm:
+ stages:
+ - stage_id: 0
+ enforce_eager: true
+
+ xpu:
+ stages:
+ - stage_id: 0
+ tensor_parallel_size: 4
+ enforce_eager: true
+ max_cudagraph_capture_size: 0
+ devices: "0,1,2,3"
+ - stage_id: 1
+ enforce_eager: true
+ max_cudagraph_capture_size: 0
+ devices: "4"
+ - stage_id: 2
+ gpu_memory_utilization: 0.3
+ max_cudagraph_capture_size: 0
+ devices: "4"
diff --git a/vllm_omni/deploy/qwen3_omni_moe_multi_replicas.yaml b/vllm_omni/deploy/qwen3_omni_moe_multi_replicas.yaml
new file mode 100644
index 00000000000..791e29a8de5
--- /dev/null
+++ b/vllm_omni/deploy/qwen3_omni_moe_multi_replicas.yaml
@@ -0,0 +1,101 @@
+# Qwen3-Omni-MoE multi-replica deploy, verified on 3x H20-96G
+# (stage 0 on cuda:0, stages 1+2 scaled out across cuda:1,2).
+#
+# Fields omitted from a stage fall back to StageDeployConfig defaults (see
+# vllm_omni/config/stage_config.py). Notable implicit defaults for this
+# layout:
+# * Stages 0/1 (thinker, talker) do not set max_num_batched_tokens:
+# chat-sized prefill fits in the 32768 default.
+# * Stage 1 (talker) does not set enforce_eager: cudagraph runs by
+# default (false).
+# * Stage 2 (code2wav) sets enforce_eager=true because its generation
+# loop is cudagraph-incompatible.
+# * Stages 1/2 use num_replicas=2 and share devices "1,2". Replica
+# routing is handled by StagePool with request-level affinity.
+async_chunk: true
+
+connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ codec_chunk_frames: 25
+ codec_left_context_frames: 25
+
+stages:
+ - stage_id: 0
+ gpu_memory_utilization: 0.9
+ devices: "0"
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ top_k: 1
+ max_tokens: 2048
+ seed: 42
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ gpu_memory_utilization: 0.6
+ devices: "1,2"
+ num_replicas: 2
+ input_connectors:
+ from_stage_0: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ repetition_penalty: 1.05
+
+ - stage_id: 2
+ gpu_memory_utilization: 0.1
+ max_num_seqs: 1
+ enforce_eager: true
+ async_scheduling: false
+ max_num_batched_tokens: 51200
+ devices: "1,2"
+ num_replicas: 2
+ input_connectors:
+ from_stage_1: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ repetition_penalty: 1.1
+
+platforms:
+ npu:
+ stages:
+ - stage_id: 0
+ gpu_memory_utilization: 0.6
+ tensor_parallel_size: 2
+ devices: "0,1"
+ - stage_id: 1
+ gpu_memory_utilization: 0.6
+ enforce_eager: true
+ devices: "2"
+ - stage_id: 2
+ gpu_memory_utilization: 0.3
+ devices: "2"
+
+ rocm:
+ stages:
+ - stage_id: 0
+ enforce_eager: true
+
+ xpu:
+ stages:
+ - stage_id: 0
+ tensor_parallel_size: 4
+ enforce_eager: true
+ max_cudagraph_capture_size: 0
+ devices: "0,1,2,3"
+ - stage_id: 1
+ enforce_eager: true
+ max_cudagraph_capture_size: 0
+ devices: "4"
+ - stage_id: 2
+ gpu_memory_utilization: 0.3
+ max_cudagraph_capture_size: 0
+ devices: "4"
diff --git a/vllm_omni/deploy/qwen3_tts.yaml b/vllm_omni/deploy/qwen3_tts.yaml
new file mode 100644
index 00000000000..32dceebd805
--- /dev/null
+++ b/vllm_omni/deploy/qwen3_tts.yaml
@@ -0,0 +1,81 @@
+# Qwen3-TTS deploy: talker → code2wav via shared-memory chunk streaming.
+# Verified on 1x H100.
+#
+# Fields omitted from a stage fall back to StageDeployConfig defaults (see
+# vllm_omni/config/stage_config.py). Notable choices for this model:
+# * Stage 0 (talker) sets max_num_batched_tokens=512 for async-chunk
+# latency tuning (not correctness) — small per-step batches keep
+# first-chunk latency low.
+# * Stage 1 (code2wav) sets max_num_batched_tokens=65536 for correctness:
+# codec prefill length (Q * num_frames) exceeds the 32k default.
+# * Stage 0 does not set enforce_eager — talker runs cudagraph by default.
+# Stage 1 sets true because its codec generation loop is not
+# cudagraph-compatible. NPU platform flips stage 0 to true where
+# cudagraph is not yet verified.
+async_chunk: true
+
+connectors:
+ connector_of_shared_memory:
+ name: SharedMemoryConnector
+ extra:
+ shm_threshold_bytes: 65536
+ codec_streaming: true
+ connector_get_sleep_s: 0.01
+ connector_get_max_wait_first_chunk: 3000
+ connector_get_max_wait: 300
+ # Must match the decoder sliding attention window.
+ codec_chunk_frames: 25
+ codec_left_context_frames: 72
+
+stages:
+ - stage_id: 0
+ max_num_seqs: 10
+ gpu_memory_utilization: 0.3
+ async_scheduling: true
+ max_num_batched_tokens: 512
+ max_model_len: 4096
+ devices: "0"
+ output_connectors:
+ to_stage_1: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.9
+ top_k: 50
+ max_tokens: 4096
+ seed: 42
+ repetition_penalty: 1.05
+
+ - stage_id: 1
+ max_num_seqs: 1
+ gpu_memory_utilization: 0.3
+ enforce_eager: true
+ async_scheduling: true
+ # Must be divisible by num_code_groups and cover (left_context + chunk).
+ # Prefill length is Q * num_frames (e.g. 16 * 2148 = 34368); keep
+ # headroom past 32k.
+ max_num_batched_tokens: 65536
+ # async_chunk appends windows per step; max_model_len must cover the
+ # accumulated flat codec stream.
+ max_model_len: 65536
+ devices: "0"
+ input_connectors:
+ from_stage_0: connector_of_shared_memory
+ default_sampling_params:
+ temperature: 0.0
+ top_p: 1.0
+ top_k: -1
+ max_tokens: 65536
+ seed: 42
+ repetition_penalty: 1.0
+
+platforms:
+ npu:
+ stages:
+ # NPU does not yet support async-scheduling for TTS, and the
+ # talker fits at max_num_seqs=1 only.
+ - stage_id: 0
+ max_num_seqs: 1
+ enforce_eager: true
+ async_scheduling: false
+ - stage_id: 1
+ gpu_memory_utilization: 0.2
+ async_scheduling: false
diff --git a/vllm_omni/diffusion/cache/cache_dit_backend.py b/vllm_omni/diffusion/cache/cache_dit_backend.py
index e9f79da4f3b..d5397dd1663 100644
--- a/vllm_omni/diffusion/cache/cache_dit_backend.py
+++ b/vllm_omni/diffusion/cache/cache_dit_backend.py
@@ -281,6 +281,7 @@ def enable_cache_for_longcat_image(pipeline: Any, cache_config: Any) -> Callable
],
forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_1],
params_modifiers=[modifier],
+ has_separate_cfg=True,
)
),
cache_config=db_cache_config,
@@ -632,6 +633,7 @@ def enable_cache_for_ltx2(pipeline: Any, cache_config: Any) -> Callable[[int], N
forward_pattern=ForwardPattern.Pattern_0,
# Treat audio_hidden_states as encoder_hidden_states in Pattern_0
check_forward_pattern=False,
+ has_separate_cfg=True,
),
cache_config=db_cache_config,
calibrator_config=calibrator_config,
@@ -1168,41 +1170,85 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
return refresh_cache_context
-def enable_cache_for_glm_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
- """Enable cache-dit for GLM-Image pipeline.
+def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+ """Enable cache-dit for Flux.2-dev pipeline.
- GLM-Image processes prompt and image by calling the transformer before the
- denoising loop. When an input image is provided (editing mode), the cache must
- be force-refreshed after the preprocessing step so stale hidden states are
- discarded. Set force_refresh_step_hint = 1 for editing, None for text-to-image.
+ Args:
+ pipeline: The Flux2 pipeline instance.
+ cache_config: DiffusionCacheConfig instance with cache configuration.
+ Returns:
+ A refresh function that can be called with a new ``num_inference_steps``
+ to update the cache context for the pipeline.
"""
+ # Build DBCacheConfig for transformer
db_cache_config = _build_db_cache_config(cache_config)
- calibrator_config = None
+ calibrator = None
if cache_config.enable_taylorseer:
- calibrator_config = TaylorSeerCalibratorConfig(taylorseer_order=cache_config.taylorseer_order)
- logger.info(f"TaylorSeer enabled with order={cache_config.taylorseer_order}")
+ taylorseer_order = cache_config.taylorseer_order
+ calibrator = TaylorSeerCalibratorConfig(taylorseer_order=taylorseer_order)
+ logger.info(f"TaylorSeer enabled with order={taylorseer_order}")
+
+ # Build ParamsModifier for transformer
+ modifier = ParamsModifier(
+ cache_config=db_cache_config,
+ calibrator_config=calibrator,
+ )
logger.info(
- f"Enabling cache-dit on GLM-Image transformer: "
+ f"Enabling cache-dit on Flux transformer with BlockAdapter: "
f"Fn={db_cache_config.Fn_compute_blocks}, "
f"Bn={db_cache_config.Bn_compute_blocks}, "
f"W={db_cache_config.max_warmup_steps}, "
- f"force_refresh_step_hint={db_cache_config.force_refresh_step_hint}, "
)
+ # Enable cache-dit using BlockAdapter for transformer
cache_dit.enable_cache(
- pipeline.transformer,
+ (
+ BlockAdapter(
+ transformer=pipeline.transformer,
+ blocks=[
+ pipeline.transformer.transformer_blocks,
+ pipeline.transformer.single_transformer_blocks,
+ ],
+ forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2],
+ params_modifiers=[modifier],
+ )
+ ),
cache_config=db_cache_config,
- calibrator_config=calibrator_config,
)
+ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool = True) -> None:
+ """Refresh cache context for the transformer with new num_inference_steps.
-def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
- """Enable cache-dit for Flux.2-dev pipeline.
+ Args:
+ pipeline: The Flux2 pipeline instance.
+ num_inference_steps: New number of inference steps.
+ """
+ if cache_config.scm_steps_mask_policy is None:
+ cache_dit.refresh_context(pipeline.transformer, num_inference_steps=num_inference_steps, verbose=verbose)
+ else:
+ cache_dit.refresh_context(
+ pipeline.transformer,
+ cache_config=DBCacheConfig().reset(
+ num_inference_steps=num_inference_steps,
+ steps_computation_mask=cache_dit.steps_mask(
+ mask_policy=cache_config.scm_steps_mask_policy,
+ total_steps=num_inference_steps,
+ ),
+ steps_computation_policy=cache_config.scm_steps_policy,
+ ),
+ verbose=verbose,
+ )
+
+ return refresh_cache_context
+
+
+def enable_cache_for_glm_image(pipeline: Any, cache_config: Any) -> Callable[[int], None]:
+ """Enable cache-dit for GlmImage pipeline.
Args:
- pipeline: The Flux2 pipeline instance.
+ pipeline: The GlmImage pipeline instance.
cache_config: DiffusionCacheConfig instance with cache configuration.
Returns:
A refresh function that can be called with a new ``num_inference_steps``
@@ -1224,23 +1270,25 @@ def enable_cache_for_flux2(pipeline: Any, cache_config: Any) -> Callable[[int],
)
logger.info(
- f"Enabling cache-dit on Flux transformer with BlockAdapter: "
+ f"Enabling cache-dit on GlmImage transformer with BlockAdapter: "
f"Fn={db_cache_config.Fn_compute_blocks}, "
f"Bn={db_cache_config.Bn_compute_blocks}, "
f"W={db_cache_config.max_warmup_steps}, "
)
# Enable cache-dit using BlockAdapter for transformer
+ # Note: We don't use patch_functor here because it's designed for diffusers' GlmImage,
+ # and our vllm-omni implementation has a different forward signature.
+ # We use ForwardPattern.Pattern_0 because our block returns (hidden_states, encoder_hidden_states)
cache_dit.enable_cache(
(
BlockAdapter(
transformer=pipeline.transformer,
- blocks=[
- pipeline.transformer.transformer_blocks,
- pipeline.transformer.single_transformer_blocks,
- ],
- forward_pattern=[ForwardPattern.Pattern_1, ForwardPattern.Pattern_2],
+ blocks=pipeline.transformer.transformer_blocks,
+ forward_pattern=ForwardPattern.Pattern_0,
params_modifiers=[modifier],
+ patch_functor=None,
+ has_separate_cfg=True,
)
),
cache_config=db_cache_config,
@@ -1250,7 +1298,7 @@ def refresh_cache_context(pipeline: Any, num_inference_steps: int, verbose: bool
"""Refresh cache context for the transformer with new num_inference_steps.
Args:
- pipeline: The Flux2 pipeline instance.
+ pipeline: The GlmImage pipeline instance.
num_inference_steps: New number of inference steps.
"""
if cache_config.scm_steps_mask_policy is None:
diff --git a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
index baec21c2762..38c805c28db 100644
--- a/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
+++ b/vllm_omni/diffusion/cache/teacache/coefficient_estimator.py
@@ -1,20 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
from typing import Any
import numpy as np
import torch
from vllm.config import LoadConfig
-from vllm.utils.torch_utils import set_default_torch_dtype
+from vllm.transformers_utils.config import get_hf_file_to_dict
from vllm_omni.diffusion.cache.teacache.extractors import get_extractor
-from vllm_omni.diffusion.data import OmniDiffusionConfig
+from vllm_omni.diffusion.data import OmniDiffusionConfig, TransformerConfig
from vllm_omni.diffusion.hooks import HookRegistry, ModelHook
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
-from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline
-from vllm_omni.diffusion.models.flux2.pipeline_flux2 import Flux2Pipeline
-from vllm_omni.diffusion.models.stable_audio.pipeline_stable_audio import StableAudioPipeline
from vllm_omni.diffusion.request import OmniDiffusionRequest
from vllm_omni.inputs.data import OmniDiffusionSamplingParams
@@ -36,6 +34,7 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
def new_forward(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
ctx = self.extractor_fn(module, *args, **kwargs)
+ # NOTE: We upcast to float32 to also handle bfloat16.
modulated_input_cpu = ctx.modulated_input.detach().float().cpu().numpy()
outputs = ctx.run_transformer_blocks()
@@ -54,23 +53,39 @@ def stop_collection(self) -> list[tuple[np.ndarray, np.ndarray]]:
return list(self.current_trajectory)
-class BagelAdapter:
- """Adapter for Bagel model."""
+class DefaultAdapter:
+ """Default adapter for standard diffusers pipelines."""
- @staticmethod
- def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> BagelPipeline:
- od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
- od_config.model_class_name = "BagelPipeline"
+ model_class_name = None
+ uses_tf_config = True
+
+ @classmethod
+ def load_pipeline(cls, model_path: str, device: str, dtype: torch.dtype) -> Any:
+ if cls.model_class_name is None:
+ raise ValueError("Adapter doesn't have a set class name.")
- pipeline = BagelPipeline(od_config=od_config)
- loader = DiffusersPipelineLoader(LoadConfig())
- loader.load_weights(pipeline)
- pipeline.to(device)
- return pipeline
+ od_config = OmniDiffusionConfig.from_kwargs(
+ model_class_name=cls.model_class_name,
+ model=model_path,
+ dtype=dtype,
+ )
+
+ if cls.uses_tf_config:
+ # TODO (Alex): Refactor to handle tf_model_config in OmniDiffusionConfig
+ # instead of OmniDiffusion and remove the manual population here
+ tf_config_dict = get_hf_file_to_dict(
+ os.path.join("transformer", "config.json"),
+ od_config.model,
+ )
+ od_config.tf_model_config = TransformerConfig.from_dict(tf_config_dict)
+
+ loader = DiffusersPipelineLoader(LoadConfig(), od_config=od_config)
+ # load_model will handle dtypes / device placement, put in .eval() mode
+ return loader.load_model(od_config=od_config, load_device=device)
@staticmethod
def get_transformer(pipeline: Any) -> tuple[Any, str]:
- return pipeline.bagel, "Bagel"
+ return pipeline.transformer, pipeline.transformer.__class__.__name__
@staticmethod
def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
@@ -78,25 +93,17 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
registry.register_hook(hook._HOOK_NAME, hook)
-class StableAudioAdapter:
- """Adapter for Stable Audio Open 1.0 coefficient estimation."""
-
- @staticmethod
- def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.float16) -> Any:
- od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
-
- # Strictly necessary because we bypass loader.load_model()
- with set_default_torch_dtype(dtype):
- pipeline = StableAudioPipeline(od_config=od_config)
+class BagelAdapter(DefaultAdapter):
+ """Adapter for Bagel model."""
- loader = DiffusersPipelineLoader(LoadConfig())
- loader.load_weights(pipeline)
- pipeline.to(device)
- return pipeline
+ model_class_name = "BagelPipeline"
+ # Skip the hack for loading the tf model config,
+ # because bagel doesn't use it.
+ uses_tf_config = False
@staticmethod
def get_transformer(pipeline: Any) -> tuple[Any, str]:
- return pipeline.transformer, "StableAudioDiTModel"
+ return pipeline.bagel, "Bagel"
@staticmethod
def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
@@ -104,52 +111,32 @@ def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
registry.register_hook(hook._HOOK_NAME, hook)
-class Flux2Adapter:
+class Flux2Adapter(DefaultAdapter):
"""Adapter for Flux2 model coefficient estimation."""
- @staticmethod
- def load_pipeline(model_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16) -> Flux2Pipeline:
- """Load Flux2 pipeline for coefficient estimation."""
- od_config = OmniDiffusionConfig.from_kwargs(model=model_path, dtype=dtype)
- od_config.model_class_name = "Flux2Pipeline"
-
- pipeline = Flux2Pipeline(od_config=od_config)
- loader = DiffusersPipelineLoader(LoadConfig())
- loader.load_weights(pipeline)
- pipeline.to(device)
- return pipeline
+ model_class_name = "Flux2Pipeline"
- @staticmethod
- def get_transformer(pipeline: Any) -> tuple[Any, str]:
- return pipeline.transformer, pipeline.transformer.__class__.__name__
- @staticmethod
- def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
- registry = HookRegistry.get_or_create(transformer)
- registry.register_hook(hook._HOOK_NAME, hook)
+class LongCatAdapter(DefaultAdapter):
+ """Adapter for LongCat Image - NOTE: currently this model needs the vLLM
+ context to be correctly configured to actually run the estimation, since it
+ uses vLLM norm layers etc.
+ """
+ model_class_name = "LongCatImagePipeline"
-class DefaultAdapter:
- """Default adapter for standard diffusers pipelines."""
- @staticmethod
- def load_pipeline(model_path: str, device: str, dtype: torch.dtype) -> Any:
- raise NotImplementedError("DefaultAdapter.load_pipeline not implemented")
-
- @staticmethod
- def get_transformer(pipeline: Any) -> tuple[Any, str]:
- return pipeline.transformer, pipeline.transformer.__class__.__name__
+class StableAudioAdapter(DefaultAdapter):
+ """Adapter for Stable Audio Open 1.0 coefficient estimation."""
- @staticmethod
- def install_hook(transformer: Any, hook: DataCollectionHook) -> None:
- registry = HookRegistry.get_or_create(transformer)
- registry.register_hook(hook._HOOK_NAME, hook)
+ model_class_name = "StableAudioPipeline"
_MODEL_ADAPTERS: dict[str, type] = {
"Bagel": BagelAdapter,
"StableAudio": StableAudioAdapter,
"Flux2": Flux2Adapter,
+ "LongCat": LongCatAdapter,
}
_EPSILON = 1e-6
@@ -196,7 +183,6 @@ def __init__(
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
- # Add validation here ⬇️
if model_type not in _MODEL_ADAPTERS:
available_types = list(_MODEL_ADAPTERS.keys())
raise ValueError(
@@ -205,7 +191,7 @@ def __init__(
f"To add support for a new model, add an entry to _MODEL_ADAPTERS."
)
- adapter = _MODEL_ADAPTERS.get(model_type, DefaultAdapter)
+ adapter = _MODEL_ADAPTERS[model_type]
self.pipeline = adapter.load_pipeline(model_path, device, dtype)
self.transformer, self.transformer_type = adapter.get_transformer(self.pipeline)
self.hook = DataCollectionHook(self.transformer_type)
diff --git a/vllm_omni/diffusion/cache/teacache/config.py b/vllm_omni/diffusion/cache/teacache/config.py
index ecf3bfc1d3d..7efdd418e12 100644
--- a/vllm_omni/diffusion/cache/teacache/config.py
+++ b/vllm_omni/diffusion/cache/teacache/config.py
@@ -73,6 +73,8 @@
3.20000000e00,
-2.00000000e-02,
],
+ # LongCat Image transformer coefficients
+ "LongCatImageTransformer2DModel": [652.5980, -424.1615, 84.5526, -4.5923, 0.1694],
}
diff --git a/vllm_omni/diffusion/cache/teacache/extractors.py b/vllm_omni/diffusion/cache/teacache/extractors.py
index 84c237b60d5..d0da0d9df3f 100644
--- a/vllm_omni/diffusion/cache/teacache/extractors.py
+++ b/vllm_omni/diffusion/cache/teacache/extractors.py
@@ -19,10 +19,13 @@
import torch
import torch.nn as nn
+from vllm.logger import init_logger
from vllm_omni.diffusion.forward_context import get_forward_context
from vllm_omni.platforms import current_omni_platform
+logger = init_logger(__name__)
+
@dataclass
class CacheContext:
@@ -723,6 +726,105 @@ def postprocess(h):
)
+def extract_longcat_context(
+ module: nn.Module, # LongCatImageTransformer2DModel
+ hidden_states,
+ timestep,
+ guidance,
+ encoder_hidden_states,
+ txt_ids,
+ img_ids,
+ **kwargs,
+) -> CacheContext:
+ """Extract the cache context for LongCat Image.
+
+ Similar to other extractors, this is currently the only code needed
+ for TeaCache support for LongCat image, and encapsulates preprocessing,
+ modulated input extraction, transformer execution, and postprocessing
+ logic.
+
+ Args & kawrgs are identical to the inputs to LongCat Image's forward.
+
+ Returns:
+ CacheContext with all information needed for generic caching
+ """
+ # TODO (Alex) - Refactor TeaCache extractors to more tightly integrate with .forward
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+
+ # 1. Model specific preprocessing
+ fwd_context = get_forward_context()
+ sp_size = module.parallel_config.sequence_parallel_size
+ if sp_size is not None and sp_size > 1:
+ # NOTE: For now, we set this to False on the forward context
+ # to be consistent with LongCat Image's current behavior when
+ # TeaCache is enabled. We do not need to reset it in post process
+ # since we should never split text embed in sp for this model.
+ fwd_context.split_text_embed_in_sp = False
+
+ hidden_states = module.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+
+ temb = module.time_embed(timestep, hidden_states.dtype)
+ encoder_hidden_states = module.context_embedder(encoder_hidden_states)
+
+ # Compute RoPE embeddings via rope_preparer module
+ # _sp_plan will automatically shard img_cos/img_sin (outputs 2, 3)
+ # txt_cos/txt_sin (outputs 0, 1) remain replicated for dual-stream attention
+ txt_cos, txt_sin, img_cos, img_sin = module.rope_preparer(txt_ids, img_ids)
+
+ # Reconstruct image_rotary_emb with chunked values
+ # Final shape: (txt_seq_len + img_seq_len // SP, head_dim)
+ image_rotary_emb = (
+ torch.cat([txt_cos, img_cos], dim=0),
+ torch.cat([txt_sin, img_sin], dim=0),
+ )
+
+ # 2. Extract the modulated output from the first mm-DiT block
+ first_block = module.transformer_blocks[0]
+ img_modulated = first_block.norm1(hidden_states, emb=temb)[0]
+
+ # 3. Define the transformer execution
+ def run_transformer_blocks():
+ """Execute all Longcat transformer blocks."""
+ h = hidden_states
+ e = encoder_hidden_states
+ for block in module.transformer_blocks:
+ e, h = block(
+ hidden_states=h,
+ encoder_hidden_states=e,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ for block in module.single_transformer_blocks:
+ e, h = block(
+ hidden_states=h,
+ encoder_hidden_states=e,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+ # Hook expects hidden states to be first
+ return (h, e)
+
+ # 4. Postprocessing
+ def postprocess(h):
+ """Apply Longcat-specific output postprocessing."""
+ h = module.norm_out(h, temb)
+ output = module.proj_out(h)
+ return Transformer2DModelOutput(sample=output)
+
+ # 5. Return the CacheContext
+ return CacheContext(
+ modulated_input=img_modulated,
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ run_transformer_blocks=run_transformer_blocks,
+ postprocess=postprocess,
+ )
+
+
def extract_stable_audio_context(
module: nn.Module,
hidden_states: torch.Tensor,
@@ -980,6 +1082,7 @@ def postprocess(h):
"Flux2Klein": extract_flux2_klein_context,
"StableAudioDiTModel": extract_stable_audio_context,
"Flux2Transformer2DModel": extract_flux2_context,
+ "LongCatImageTransformer2DModel": extract_longcat_context,
# Future models:
# "FluxTransformer2DModel": extract_flux_context,
# "CogVideoXTransformer3DModel": extract_cogvideox_context,
diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py
index fca0a5bad05..0a19eb11974 100644
--- a/vllm_omni/diffusion/data.py
+++ b/vllm_omni/diffusion/data.py
@@ -18,6 +18,7 @@
QuantizationConfig,
)
+from vllm_omni.diffusion.model_metadata import get_diffusion_model_metadata
from vllm_omni.diffusion.utils.network_utils import is_port_available
from vllm_omni.quantization import build_quant_config
@@ -481,8 +482,10 @@ class OmniDiffusionConfig:
# Scheduler flow_shift for Wan2.2 (12.0 for 480p, 5.0 for 720p)
flow_shift: float | None = None
- # support multi images input
+ # Support multi-image inputs and expose any model-specific request limit
+ # through a generic config field so serving code stays model-agnostic.
supports_multimodal_inputs: bool = False
+ max_multimodal_image_inputs: int | None = None
log_level: str = "info"
@@ -664,7 +667,11 @@ def set_tf_model_config(self, tf_config: "TransformerConfig") -> None:
)
def update_multimodal_support(self) -> None:
- self.supports_multimodal_inputs = self.model_class_name in {"QwenImageEditPlusPipeline"}
+ # Resolve serving-visible multimodal behavior from shared metadata
+ # instead of importing concrete pipeline modules into the config layer.
+ metadata = get_diffusion_model_metadata(self.model_class_name)
+ self.supports_multimodal_inputs = metadata.supports_multimodal_inputs
+ self.max_multimodal_image_inputs = metadata.max_multimodal_image_inputs
def enrich_config(self) -> None:
"""Load model metadata from HuggingFace and populate config fields.
diff --git a/vllm_omni/diffusion/model_metadata.py b/vllm_omni/diffusion/model_metadata.py
new file mode 100644
index 00000000000..ec133e7380e
--- /dev/null
+++ b/vllm_omni/diffusion/model_metadata.py
@@ -0,0 +1,31 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class DiffusionModelMetadata:
+ # Keep serving-facing capability metadata in a lightweight shared module so
+ # config/model plumbing can read it without importing concrete pipelines.
+ supports_multimodal_inputs: bool = False
+ max_multimodal_image_inputs: int | None = None
+
+
+QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES = 4
+
+
+_DIFFUSION_MODEL_METADATA: dict[str, DiffusionModelMetadata] = {
+ "QwenImageEditPlusPipeline": DiffusionModelMetadata(
+ supports_multimodal_inputs=True,
+ max_multimodal_image_inputs=QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES,
+ ),
+}
+
+
+def get_diffusion_model_metadata(model_class_name: str | None) -> DiffusionModelMetadata:
+ # Unknown models fall back to "no special multimodal capabilities" so new
+ # pipelines do not accidentally inherit limits meant for other models.
+ if model_class_name is None:
+ return DiffusionModelMetadata()
+ return _DIFFUSION_MODEL_METADATA.get(model_class_name, DiffusionModelMetadata())
diff --git a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py
index a3d2259e643..90baf5f6761 100644
--- a/vllm_omni/diffusion/models/bagel/pipeline_bagel.py
+++ b/vllm_omni/diffusion/models/bagel/pipeline_bagel.py
@@ -397,11 +397,26 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
cfg_text_context["ropes"] = cfg_text_metadata["ropes"]
else:
cfg_text_context["ropes"] = [cfg_text_seq_len]
-
- if cfg_img_kv is None and cfg_text_kv is not None:
- cfg_img_kv = injected_kv
-
- if cfg_img_kv is not None:
+ else:
+ # No cfg_text companion received. For text2img this is the
+ # expected path: original BAGEL uses an empty KV cache (0
+ # tokens) as the text-unconditional branch. Keep the default
+ # empty NaiveCache in cfg_text_context and preserve the
+ # original cfg_text_scale so CFG still applies.
+ pass
+
+ if cfg_img_kv is None:
+ # text2img multi-stage: cfg_img reuses gen KV (positive prompt,
+ # no image), mirroring forward_cache_update_text on cfg_img_context
+ # in the single-stage path.
+ cfg_img_seq_len = injected_kv.key_cache[0].shape[0]
+ cfg_img_context["past_key_values"] = injected_kv
+ cfg_img_context["kv_lens"] = [cfg_img_seq_len]
+ if req.sampling_params.kv_metadata and "ropes" in req.sampling_params.kv_metadata:
+ cfg_img_context["ropes"] = req.sampling_params.kv_metadata["ropes"]
+ else:
+ cfg_img_context["ropes"] = [cfg_img_seq_len]
+ else:
cfg_img_seq_len = cfg_img_kv.key_cache[0].shape[0]
cfg_img_context["past_key_values"] = cfg_img_kv
cfg_img_context["kv_lens"] = [cfg_img_seq_len]
@@ -410,15 +425,6 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput:
else:
cfg_img_context["ropes"] = [cfg_img_seq_len]
- if not cfg_parallel_contract:
- logger.warning("CFG is disabled: only single KV cache available")
- gen_params = BagelGenParams(
- num_timesteps=gen_params.num_timesteps,
- timestep_shift=gen_params.timestep_shift,
- cfg_text_scale=1.0,
- cfg_img_scale=1.0,
- )
-
else:
image_input = (
None
diff --git a/vllm_omni/diffusion/models/dreamid_omni/fusion.py b/vllm_omni/diffusion/models/dreamid_omni/fusion.py
index a534f5a76fa..abca4c9474f 100644
--- a/vllm_omni/diffusion/models/dreamid_omni/fusion.py
+++ b/vllm_omni/diffusion/models/dreamid_omni/fusion.py
@@ -1,3 +1,5 @@
+import re
+
import torch
import torch.nn as nn
from vllm.logger import init_logger
@@ -15,78 +17,26 @@
logger = init_logger(__name__)
-class FusionModel(nn.Module):
- def __init__(self, video_config=None, audio_config=None):
- super().__init__()
- has_video = True
- has_audio = True
- if video_config is not None:
- self.video_model = WanModel(**video_config)
- else:
- has_video = False
- self.video_model = None
- logger.warning("No video model is provided!")
-
- if audio_config is not None:
- self.audio_model = WanModel(**audio_config)
- else:
- has_audio = False
- self.audio_model = None
- logger.warning("No audio model is provided!")
-
- if has_video and has_audio:
- assert len(self.video_model.blocks) == len(self.audio_model.blocks)
- self.num_blocks = len(self.video_model.blocks)
-
- self.inject_cross_attention_kv_projections()
- self.device = get_local_device()
-
- self.num_heads = self.video_model.num_heads
- self.head_dim = self.video_model.dim // self.video_model.num_heads
- self.attn = Attention(
- num_heads=self.num_heads,
- head_size=self.head_dim,
- num_kv_heads=self.num_heads,
- softmax_scale=1.0 / (self.head_dim**0.5),
- causal=False,
- )
-
- def inject_cross_attention_kv_projections(self):
- for vid_block in self.video_model.blocks:
- vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim)
- vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim)
- vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True)
- vid_block.cross_attn.norm_k_fusion = (
- WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity()
- )
+class FusedBlock(nn.Module):
+ """Wrapper pairing a video block and audio block for layerwise offloading.
- for audio_block in self.audio_model.blocks:
- audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim)
- audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim)
- audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True)
- audio_block.cross_attn.norm_k_fusion = (
- WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity()
- )
+ Registers both blocks as submodules so their parameters are visible to the offload hooks.
+ """
- def merge_kwargs(self, vid_kwargs, audio_kwargs):
- """
- keys in each kwarg:
- e
- seq_lens
- grid_sizes
- freqs
- context
- context_lens
- """
- merged_kwargs = {}
- for key in vid_kwargs:
- merged_kwargs[f"vid_{key}"] = vid_kwargs[key]
- for key in audio_kwargs:
- merged_kwargs[f"audio_{key}"] = audio_kwargs[key]
- return merged_kwargs
+ def __init__(
+ self,
+ vid_block: nn.Module,
+ audio_block: nn.Module,
+ device: torch.device,
+ ):
+ super().__init__()
+ self.vid_block = vid_block
+ self.audio_block = audio_block
+ self.device = device
- def single_fusion_cross_attention_forward(
+ def _cross_attention_forward(
self,
+ attn: Attention,
cross_attn_block,
src_seq,
src_grid_sizes,
@@ -104,21 +54,17 @@ def single_fusion_cross_attention_forward(
):
b, n, d = src_seq.size(0), cross_attn_block.num_heads, cross_attn_block.head_dim
if hasattr(cross_attn_block, "k_img"):
- ## means is i2v block
q, k, v, k_img, v_img = cross_attn_block.qkv_fn(src_seq, context)
else:
- ## means is t2v block
q, k, v = cross_attn_block.qkv_fn(src_seq, context)
k_img = v_img = None
- x = self.attn(q, k, v)
+ x = attn(q, k, v)
if k_img is not None:
- img_x = self.attn(q, k_img, v_img)
+ img_x = attn(q, k_img, v_img)
x = x + img_x
- # is_vid = src_grid_sizes.shape[1] > 1
- # compute target attention
target_seq = cross_attn_block.pre_attn_norm_fusion(target_seq)
k_target = cross_attn_block.norm_k_fusion(cross_attn_block.k_fusion(target_seq)).view(b, -1, n, d)
v_target = cross_attn_block.v_fusion(target_seq).view(b, -1, n, d)
@@ -132,17 +78,16 @@ def single_fusion_cross_attention_forward(
freqs_scaling=target_freqs_scaling,
)
- target_x = self.attn(q, k_target, v_target)
+ target_x = attn(q, k_target, v_target)
x = x + target_x
-
- x = x.flatten(2) # [B, L/P, C]
-
+ x = x.flatten(2)
x = cross_attn_block.o(x)
return x
- def single_fusion_cross_attention_ffn_forward(
+ def _cross_attention_ffn_forward(
self,
+ attn: Attention,
attn_block,
src_seq,
src_grid_sizes,
@@ -159,7 +104,8 @@ def single_fusion_cross_attention_ffn_forward(
target_ref_lengths=None,
target_freqs_scaling=None,
):
- src_seq = src_seq + self.single_fusion_cross_attention_forward(
+ src_seq = src_seq + self._cross_attention_forward(
+ attn,
attn_block.cross_attn,
attn_block.norm3(src_seq),
src_grid_sizes=src_grid_sizes,
@@ -180,12 +126,11 @@ def single_fusion_cross_attention_ffn_forward(
src_seq = src_seq + y * src_e[5].squeeze(2)
return src_seq
- def single_fusion_block_forward(
+ def forward(
self,
- vid_block,
- audio_block,
vid,
audio,
+ attn: Attention,
vid_e,
vid_seq_lens,
vid_grid_sizes,
@@ -203,6 +148,9 @@ def single_fusion_block_forward(
audio_ref_lengths,
audio_freqs_scaling,
):
+ vid_block = self.vid_block
+ audio_block = self.audio_block
+
## audio modulation
assert audio_e.dtype == torch.bfloat16
assert len(audio_e.shape) == 4 and audio_e.size(2) == 6 and audio_e.shape[1] == audio.shape[1], (
@@ -246,7 +194,8 @@ def single_fusion_block_forward(
og_audio = audio
# audio cross-attention
- audio = self.single_fusion_cross_attention_ffn_forward(
+ audio = self._cross_attention_ffn_forward(
+ attn,
audio_block,
audio,
audio_grid_sizes,
@@ -267,7 +216,8 @@ def single_fusion_block_forward(
assert not torch.equal(og_audio, audio), "Audio should be changed after cross-attention!"
# video cross-attention
- vid = self.single_fusion_cross_attention_ffn_forward(
+ vid = self._cross_attention_ffn_forward(
+ attn,
vid_block,
vid,
vid_grid_sizes,
@@ -287,6 +237,128 @@ def single_fusion_block_forward(
return vid, audio
+
+class FusionModel(nn.Module):
+ _layerwise_offload_blocks_attrs = ["fused_blocks"]
+
+ def __init__(self, video_config=None, audio_config=None):
+ super().__init__()
+ has_video = True
+ has_audio = True
+ self.device = get_local_device()
+ if video_config is not None:
+ self.video_model = WanModel(**video_config)
+ else:
+ has_video = False
+ self.video_model = None
+ logger.warning("No video model is provided!")
+
+ if audio_config is not None:
+ self.audio_model = WanModel(**audio_config)
+ else:
+ has_audio = False
+ self.audio_model = None
+ logger.warning("No audio model is provided!")
+
+ if has_video and has_audio:
+ assert len(self.video_model.blocks) == len(self.audio_model.blocks)
+ self.num_blocks = len(self.video_model.blocks)
+
+ self.inject_cross_attention_kv_projections()
+
+ self.num_heads = self.video_model.num_heads
+ self.head_dim = self.video_model.dim // self.video_model.num_heads
+ # Make a single shared instance to pass in at forward time
+ self.attn = Attention(
+ num_heads=self.num_heads,
+ head_size=self.head_dim,
+ num_kv_heads=self.num_heads,
+ softmax_scale=1.0 / (self.head_dim**0.5),
+ causal=False,
+ )
+
+ if has_video and has_audio:
+ self.fused_blocks = nn.ModuleList(
+ [
+ FusedBlock(
+ self.video_model.blocks[i],
+ self.audio_model.blocks[i],
+ self.device,
+ )
+ for i in range(self.num_blocks)
+ ]
+ )
+
+ def load_state_dict(self, state_dict, strict=True, assign=False):
+ """Remap checkpoints where blocks are stored under
+ `video_model.blocks.N.*` / `audio_model.blocks.N.*` to the current
+ `fused_blocks.N.vid_block.*` / `fused_blocks.N.audio_block.*`.
+ """
+ needs_remap = any(re.match(r"^(video_model|audio_model)\.blocks\.\d+\.", k) for k in state_dict)
+ if needs_remap:
+ remapped = {}
+ for k, v in state_dict.items():
+ new_k = re.sub(r"^video_model\.blocks\.(\d+)\.", r"fused_blocks.\1.vid_block.", k)
+ new_k = re.sub(r"^audio_model\.blocks\.(\d+)\.", r"fused_blocks.\1.audio_block.", new_k)
+ remapped[new_k] = v
+ state_dict = remapped
+
+ self._detach_blocks_from_backbones()
+
+ return super().load_state_dict(state_dict, strict=strict, assign=assign)
+
+ def inject_cross_attention_kv_projections(self):
+ for vid_block in self.video_model.blocks:
+ vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim)
+ vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim)
+ vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True)
+ vid_block.cross_attn.norm_k_fusion = (
+ WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity()
+ )
+
+ for audio_block in self.audio_model.blocks:
+ audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim)
+ audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim)
+ audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True)
+ audio_block.cross_attn.norm_k_fusion = (
+ WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity()
+ )
+
+ def _detach_blocks_from_backbones(self) -> None:
+ """Keep offloadable blocks owned only by a single place.
+
+ NOTE: This is a special workaround to support layerwise offloading.
+ The model registers the same Wan blocks under both the video/audio
+ backbones and `fused_blocks` which is a wrapper for unified blocks
+ walking through. However, layerwise offloading will only consider
+ `fused_blocks` as offloadable components and will materialize all
+ other modules onto device, including the same blocks owned by both
+ `fused_blocks` and `video_model` and `audio_model`.
+ """
+ video_blocks = list(self.video_model.blocks)
+ audio_blocks = list(self.audio_model.blocks)
+ self.video_model._modules.pop("blocks", None)
+ self.audio_model._modules.pop("blocks", None)
+ self.video_model.blocks = tuple(video_blocks)
+ self.audio_model.blocks = tuple(audio_blocks)
+
+ def merge_kwargs(self, vid_kwargs, audio_kwargs):
+ """
+ keys in each kwarg:
+ e
+ seq_lens
+ grid_sizes
+ freqs
+ context
+ context_lens
+ """
+ merged_kwargs = {}
+ for key in vid_kwargs:
+ merged_kwargs[f"vid_{key}"] = vid_kwargs[key]
+ for key in audio_kwargs:
+ merged_kwargs[f"audio_{key}"] = audio_kwargs[key]
+ return merged_kwargs
+
def forward(
self,
vid,
@@ -316,17 +388,8 @@ def forward(
kwargs = self.merge_kwargs(vid_kwargs, audio_kwargs)
- for i in range(self.num_blocks):
- """
- 1 fusion block refers to 1 audio block with 1 video block.
- """
-
- vid_block = self.video_model.blocks[i]
- audio_block = self.audio_model.blocks[i]
-
- vid, audio = self.single_fusion_block_forward(
- vid_block=vid_block, audio_block=audio_block, vid=vid, audio=audio, **kwargs
- )
+ for fused_block in self.fused_blocks:
+ vid, audio = fused_block(vid, audio, self.attn, **kwargs)
vid = self.video_model.post_transformer_block_out(vid, vid_kwargs["grid_sizes"], vid_e)
audio = self.audio_model.post_transformer_block_out(audio, audio_kwargs["grid_sizes"], audio_e)
diff --git a/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py b/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py
index c7ab4662d14..cc932f8c1f8 100644
--- a/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py
+++ b/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py
@@ -4,6 +4,7 @@
import logging
import math
import os
+from collections.abc import Iterable
import torch
import torch.distributed
@@ -16,6 +17,7 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin
from vllm_omni.diffusion.distributed.utils import get_local_device
+from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.models.interface import SupportAudioInput, SupportImageInput
from vllm_omni.diffusion.request import OmniDiffusionRequest
@@ -27,7 +29,6 @@
init_mmaudio_vae,
init_text_model,
init_wan_vae_2_2,
- load_fusion_checkpoint,
)
from dreamid_omni.utils.rearrange import Rearrange
from dreamid_omni.utils.resize import NaResize
@@ -122,16 +123,24 @@ def __init__(
self.text_model = init_text_model(model, rank=self.device)
self.text_encoder = self.text_model.model
- # Fusion model
- ## load audio/video model config
- Fusion_model = FusionModel(VIDEO_CONFIG, AUDIO_CONFIG)
-
- checkpoint_path = self.od_config.tf_model_config.get("fusion", None)
- assert checkpoint_path is not None, "fusion checkpoint path is None"
- load_fusion_checkpoint(Fusion_model, checkpoint_path=os.path.join(model, checkpoint_path))
- self.model = Fusion_model
+ # Fusion model — weights are loaded later via load_weights()
+ self.model = FusionModel(VIDEO_CONFIG, AUDIO_CONFIG)
self.transformer = self.model
+ fusion_path = self.od_config.tf_model_config.get("fusion", None)
+ assert fusion_path is not None, "fusion checkpoint path is None in transformer config"
+ fusion_subfolder = os.path.dirname(fusion_path) or None
+ fusion_filename = os.path.basename(fusion_path)
+ self.weights_sources = [
+ DiffusersPipelineLoader.ComponentSource(
+ model_or_path=model,
+ subfolder=fusion_subfolder,
+ revision=None,
+ prefix="model.",
+ allow_patterns_overrides=[fusion_filename],
+ )
+ ]
+
# Fixed attributes, non-configurable
self.audio_latent_channel = AUDIO_CONFIG.get("in_dim")
self.video_latent_channel = VIDEO_CONFIG.get("in_dim")
@@ -226,8 +235,11 @@ def load_image_latent_ref_ip_video(
return ref_vae_latents, ref_audio_lengths
- def load_weights(self, weights):
- pass
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ prefix = "model."
+ state_dict = {name[len(prefix) :]: tensor for name, tensor in weights if name.startswith(prefix)}
+ self.model.load_state_dict(state_dict, strict=True)
+ return {prefix + k for k in state_dict}
def get_scheduler_time_steps(self, sampling_steps, solver_name="unipc", device=0, shift=5.0):
torch.manual_seed(4)
diff --git a/vllm_omni/diffusion/models/helios/helios_transformer.py b/vllm_omni/diffusion/models/helios/helios_transformer.py
index b3d2621ad88..5e7934c3ba6 100644
--- a/vllm_omni/diffusion/models/helios/helios_transformer.py
+++ b/vllm_omni/diffusion/models/helios/helios_transformer.py
@@ -62,10 +62,16 @@ def apply_rotary_emb_helios(
"""
x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
- out = torch.empty_like(hidden_states)
- out[..., 0::2] = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2]
- out[..., 1::2] = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2]
- return out.type_as(hidden_states)
+ # Use stack+flatten instead of strided slice assignment for contiguous
+ # memory layout and better performance on GPU/NPU (#2436, cf. PR #2393).
+ rotated = torch.stack(
+ (
+ x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2],
+ x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2],
+ ),
+ dim=-1,
+ )
+ return rotated.flatten(-2, -1).type_as(hidden_states)
class DistributedRMSNorm(nn.Module):
diff --git a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py
index a1bf7f7809c..95ef919c24e 100644
--- a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py
+++ b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py
@@ -1264,6 +1264,7 @@ class LTX2VideoTransformer3DModel(nn.Module):
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["LTX2VideoTransformerBlock"]
+ _layerwise_offload_blocks_attrs = ["transformer_blocks"]
_sp_plan: dict[str, Any] | None = None
@staticmethod
diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py
index a2702f3d295..2e25d0fe6b2 100644
--- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py
+++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image_edit_plus.py
@@ -25,6 +25,7 @@
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
+from vllm_omni.diffusion.model_metadata import QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES
from vllm_omni.diffusion.models.interface import SupportImageInput
from vllm_omni.diffusion.models.qwen_image.cfg_parallel import (
QwenImageCFGParallelMixin,
@@ -56,6 +57,12 @@
CONDITION_IMAGE_SIZE = 384 * 384
VAE_IMAGE_SIZE = 1024 * 1024
+# Keep this in sync with the practical conditioning-token budget for
+# Qwen-Image-Edit-2511. Empirically, 4 images stays within the supported range
+# while 5 images overflows the prompt/conditioning path and fails downstream.
+# Re-export the shared metadata value locally so this pipeline keeps a nearby,
+# descriptive constant for validation and tests without becoming the source of truth.
+MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES = QWEN_IMAGE_EDIT_PLUS_MAX_INPUT_IMAGES
def get_qwen_image_edit_plus_pre_process_func(
@@ -93,6 +100,11 @@ def pre_process_func(
if not isinstance(raw_image, list):
raw_image = [raw_image]
+ if len(raw_image) > MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES:
+ raise ValueError(
+ f"Received {len(raw_image)} input images. "
+ f"At most {MAX_QWEN_IMAGE_EDIT_PLUS_INPUT_IMAGES} images are supported by this model."
+ )
image = [
PIL.Image.open(im) if isinstance(im, str) else cast(PIL.Image.Image | np.ndarray | torch.Tensor, im)
for im in raw_image
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
index e1249e889c7..652425d5097 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
@@ -392,6 +392,102 @@ def num_timesteps(self):
def current_timestep(self):
return self._current_timestep
+ def diffuse(
+ self,
+ latents: torch.Tensor,
+ timesteps: torch.Tensor,
+ prompt_embeds: torch.Tensor,
+ negative_prompt_embeds: torch.Tensor | None,
+ guidance_low: float,
+ guidance_high: float,
+ boundary_timestep: float | None,
+ dtype: torch.dtype,
+ attention_kwargs: dict[str, Any],
+ latent_condition: torch.Tensor | None = None,
+ first_frame_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ with self.progress_bar(total=len(timesteps)) as pbar:
+ for t in timesteps:
+ self._current_timestep = t
+
+ # Select model based on timestep and boundary_ratio
+ # High noise stage (t >= boundary_timestep): use transformer
+ # Low noise stage (t < boundary_timestep): use transformer_2
+ if boundary_timestep is not None and t < boundary_timestep:
+ # Low noise stage - always use guidance_high for this stage
+ current_guidance_scale = guidance_high
+ if self.transformer_2 is not None:
+ current_model = self.transformer_2
+ elif self.transformer is not None:
+ # Fallback to transformer if transformer_2 not loaded
+ current_model = self.transformer
+ else:
+ raise RuntimeError("No transformer available for low-noise stage")
+ else:
+ # High noise stage - always use guidance_low for this stage
+ current_guidance_scale = guidance_low
+ if self.transformer is not None:
+ current_model = self.transformer
+ elif self.transformer_2 is not None:
+ # Fallback to transformer_2 if transformer not loaded
+ current_model = self.transformer_2
+ else:
+ raise RuntimeError("No transformer available for high-noise stage")
+
+ if self.expand_timesteps and latent_condition is not None:
+ # I2V mode: blend condition with latents using mask
+ latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents
+ latent_model_input = latent_model_input.to(dtype)
+
+ # Expand timesteps per patch - use floor division to match patch embedding
+ patch_size = self.transformer_config.patch_size
+ patch_height = latents.shape[3] // patch_size[1]
+ patch_width = latents.shape[4] // patch_size[2]
+
+ # Create mask at patch resolution (same as hidden states sequence length)
+ patch_mask = first_frame_mask[:, :, :, :: patch_size[1], :: patch_size[2]]
+ patch_mask = patch_mask[:, :, :, :patch_height, :patch_width] # Ensure correct dimensions
+ temp_ts = (patch_mask[0][0] * t).flatten()
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ # T2V mode: standard forward
+ latent_model_input = latents.to(dtype)
+ timestep = t.expand(latents.shape[0])
+
+ do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": current_model,
+ }
+ if do_true_cfg:
+ negative_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": negative_prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": current_model,
+ }
+ else:
+ negative_kwargs = None
+
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_true_cfg,
+ true_cfg_scale=current_guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ cfg_normalize=False,
+ )
+
+ latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
+ pbar.update()
+
+ return latents
+
def forward(
self,
req: OmniDiffusionRequest,
@@ -635,90 +731,19 @@ def forward(
if DEBUG_PERF:
_t_denoise_start = time.perf_counter()
- with self.progress_bar(total=len(timesteps)) as pbar:
- for t in timesteps:
- self._current_timestep = t
-
- # Select model based on timestep and boundary_ratio
- # High noise stage (t >= boundary_timestep): use transformer
- # Low noise stage (t < boundary_timestep): use transformer_2
- if boundary_timestep is not None and t < boundary_timestep:
- # Low noise stage - always use guidance_high for this stage
- current_guidance_scale = guidance_high
- if self.transformer_2 is not None:
- current_model = self.transformer_2
- elif self.transformer is not None:
- # Fallback to transformer if transformer_2 not loaded
- current_model = self.transformer
- else:
- raise RuntimeError("No transformer available for low-noise stage")
- else:
- # High noise stage - always use guidance_low for this stage
- current_guidance_scale = guidance_low
- if self.transformer is not None:
- current_model = self.transformer
- elif self.transformer_2 is not None:
- # Fallback to transformer_2 if transformer not loaded
- current_model = self.transformer_2
- else:
- raise RuntimeError("No transformer available for high-noise stage")
-
- if self.expand_timesteps and latent_condition is not None:
- # I2V mode: blend condition with latents using mask
- latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents
- latent_model_input = latent_model_input.to(dtype)
-
- # Expand timesteps per patch - use floor division to match patch embedding
- patch_size = self.transformer_config.patch_size
- num_latent_frames = latents.shape[2]
- patch_height = latents.shape[3] // patch_size[1]
- patch_width = latents.shape[4] // patch_size[2]
-
- # Create mask at patch resolution (same as hidden states sequence length)
- patch_mask = first_frame_mask[:, :, :, :: patch_size[1], :: patch_size[2]]
- patch_mask = patch_mask[:, :, :, :patch_height, :patch_width] # Ensure correct dimensions
- temp_ts = (patch_mask[0][0] * t).flatten()
- timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
- else:
- # T2V mode: standard forward
- latent_model_input = latents.to(dtype)
- timestep = t.expand(latents.shape[0])
-
- do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None
- # Prepare kwargs for positive and negative predictions
- positive_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": current_model,
- }
- if do_true_cfg:
- negative_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": negative_prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": current_model,
- }
- else:
- negative_kwargs = None
-
- # Predict noise with automatic CFG parallel handling
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=current_guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=False,
- )
-
- # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync
- latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
-
- pbar.update()
+ latents = self.diffuse(
+ latents=latents,
+ timesteps=timesteps,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ guidance_low=guidance_low,
+ guidance_high=guidance_high,
+ boundary_timestep=boundary_timestep,
+ dtype=dtype,
+ attention_kwargs=attention_kwargs,
+ latent_condition=latent_condition,
+ first_frame_mask=first_frame_mask,
+ )
# Wan2.2 is prone to out of memory errors when predicting large videos
# so we empty the cache here to avoid OOM before vae decoding.
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
index ca042ca228b..95d1e08bbc7 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py
@@ -12,6 +12,7 @@
import numpy as np
import PIL.Image
import torch
+import torchvision.transforms.functional as TF
from diffusers.utils.torch_utils import randn_tensor
from torch import nn
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
@@ -289,6 +290,82 @@ def num_timesteps(self):
def current_timestep(self):
return self._current_timestep
+ def diffuse(
+ self,
+ latents: torch.Tensor,
+ timesteps: torch.Tensor,
+ prompt_embeds: torch.Tensor,
+ negative_prompt_embeds: torch.Tensor | None,
+ image_embeds: torch.Tensor | None,
+ guidance_low: float,
+ guidance_high: float,
+ boundary_timestep: float | None,
+ dtype: torch.dtype,
+ attention_kwargs: dict[str, Any],
+ condition: torch.Tensor,
+ first_frame_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ with self.progress_bar(total=len(timesteps)) as pbar:
+ for t in timesteps:
+ self._current_timestep = t
+
+ # Select model and guidance scale based on timestep
+ current_model = self.transformer
+ current_guidance_scale = guidance_low
+ if boundary_timestep is not None and t < boundary_timestep and self.transformer_2 is not None:
+ current_model = self.transformer_2
+ current_guidance_scale = guidance_high
+
+ # Prepare latent input
+ if self.expand_timesteps:
+ # TI2V-5B style: blend condition with latents using mask
+ latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
+ latent_model_input = latent_model_input.to(dtype)
+
+ # Expand timesteps for each patch
+ temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ # Wan2.1 style: concatenate condition with latents
+ latent_model_input = torch.cat([latents, condition], dim=1).to(dtype)
+ timestep = t.expand(latents.shape[0])
+
+ do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": prompt_embeds,
+ "encoder_hidden_states_image": image_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": current_model,
+ }
+ if do_true_cfg:
+ negative_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": negative_prompt_embeds,
+ "encoder_hidden_states_image": image_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": current_model,
+ }
+ else:
+ negative_kwargs = None
+
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_true_cfg,
+ true_cfg_scale=current_guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ cfg_normalize=False,
+ )
+
+ latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
+ pbar.update()
+
+ return latents
+
def encode_image(
self,
image: PIL.Image.Image | list[PIL.Image.Image],
@@ -484,6 +561,7 @@ def forward(
video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
if isinstance(image, PIL.Image.Image):
+ image = TF.to_tensor(image).to(device)
image_tensor = video_processor.preprocess(image, height=height, width=width)
else:
image_tensor = image
@@ -492,6 +570,7 @@ def forward(
# Handle last_image if provided
if last_image is not None:
if isinstance(last_image, PIL.Image.Image):
+ image = TF.to_tensor(last_image).to(device)
last_image_tensor = video_processor.preprocess(last_image, height=height, width=width)
else:
last_image_tensor = last_image
@@ -522,68 +601,20 @@ def forward(
if DEBUG_PERF:
_t_denoise_start = time.perf_counter()
- with self.progress_bar(total=len(timesteps)) as pbar:
- for t in timesteps:
- self._current_timestep = t
-
- # Select model and guidance scale based on timestep
- current_model = self.transformer
- current_guidance_scale = guidance_low
- if boundary_timestep is not None and t < boundary_timestep and self.transformer_2 is not None:
- current_model = self.transformer_2
- current_guidance_scale = guidance_high
-
- # Prepare latent input
- if self.expand_timesteps:
- # TI2V-5B style: blend condition with latents using mask
- latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
- latent_model_input = latent_model_input.to(dtype)
-
- # Expand timesteps for each patch
- temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
- timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
- else:
- # Wan2.1 style: concatenate condition with latents
- latent_model_input = torch.cat([latents, condition], dim=1).to(dtype)
- timestep = t.expand(latents.shape[0])
-
- do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None
- # Prepare kwargs for positive and negative predictions
- positive_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": prompt_embeds,
- "encoder_hidden_states_image": image_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": current_model,
- }
- if do_true_cfg:
- negative_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": negative_prompt_embeds,
- "encoder_hidden_states_image": image_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": current_model,
- }
- else:
- negative_kwargs = None
-
- # Predict noise with automatic CFG parallel handling
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=current_guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=False,
- )
-
- # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync
- latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
-
- pbar.update()
+ latents = self.diffuse(
+ latents=latents,
+ timesteps=timesteps,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ image_embeds=image_embeds,
+ guidance_low=guidance_low,
+ guidance_high=guidance_high,
+ boundary_timestep=boundary_timestep,
+ dtype=dtype,
+ attention_kwargs=attention_kwargs,
+ condition=condition,
+ first_frame_mask=first_frame_mask,
+ )
# Wan2.2 is prone to out of memory errors when predicting large videos
# so we empty the cache here to avoid OOM before vae decoding.
@@ -844,12 +875,14 @@ def prepare_latents(
return latents, latent_condition, first_frame_mask
# Wan2.1 style: create mask and concatenate with condition
- mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
+ mask_lat_size = torch.ones(
+ batch_size, 1, num_frames, latent_height, latent_width, device=latent_condition.device
+ )
if last_image is None:
- mask_lat_size[:, :, list(range(1, num_frames))] = 0
+ mask_lat_size[:, :, 1:] = 0
else:
- mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
+ mask_lat_size[:, :, 1 : num_frames - 1] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
index 6cbd6d2d6bd..dba76ba8af8 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py
@@ -235,6 +235,77 @@ def num_timesteps(self):
def current_timestep(self):
return self._current_timestep
+ def diffuse(
+ self,
+ latents: torch.Tensor,
+ timesteps: torch.Tensor,
+ prompt_embeds: torch.Tensor,
+ negative_prompt_embeds: torch.Tensor | None,
+ guidance_scale: float,
+ dtype: torch.dtype,
+ attention_kwargs: dict[str, Any],
+ num_latent_frames: int,
+ latent_height: int,
+ latent_width: int,
+ latent_condition: torch.Tensor | None = None,
+ first_frame_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ with self.progress_bar(total=len(timesteps)) as pbar:
+ for t in timesteps:
+ self._current_timestep = t
+
+ # Prepare latent input
+ if latent_condition is not None:
+ # I2V mode: blend condition with latents using mask
+ latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents
+ latent_model_input = latent_model_input.to(dtype)
+
+ # Expand timesteps for each patch (TI2V style)
+ temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ # T2V mode: use latents directly
+ latent_model_input = latents.to(dtype)
+
+ # Expand timesteps for TI2V model architecture
+ mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width, device=latents.device)
+ temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+
+ do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": self.transformer,
+ }
+ if do_true_cfg:
+ negative_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": negative_prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "return_dict": False,
+ "current_model": self.transformer,
+ }
+ else:
+ negative_kwargs = None
+
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_true_cfg,
+ true_cfg_scale=guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ cfg_normalize=False,
+ )
+
+ latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
+ pbar.update()
+
+ return latents
+
def forward(
self,
req: OmniDiffusionRequest,
@@ -405,64 +476,20 @@ def forward(
if attention_kwargs is None:
attention_kwargs = {}
- # Denoising loop
- with self.progress_bar(total=len(timesteps)) as pbar:
- for t in timesteps:
- self._current_timestep = t
-
- # Prepare latent input
- if latent_condition is not None:
- # I2V mode: blend condition with latents using mask
- latent_model_input = (1 - first_frame_mask) * latent_condition + first_frame_mask * latents
- latent_model_input = latent_model_input.to(dtype)
-
- # Expand timesteps for each patch (TI2V style)
- temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
- timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
- else:
- # T2V mode: use latents directly
- latent_model_input = latents.to(dtype)
-
- # Expand timesteps for TI2V model architecture
- mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width, device=device)
- temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
- timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
-
- do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None
- # Prepare kwargs for positive and negative predictions
- positive_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": self.transformer,
- }
- if do_true_cfg:
- negative_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": negative_prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "return_dict": False,
- "current_model": self.transformer,
- }
- else:
- negative_kwargs = None
-
- # Predict noise with automatic CFG parallel handling
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=False,
- )
-
- # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync
- latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
-
- pbar.update()
+ latents = self.diffuse(
+ latents=latents,
+ timesteps=timesteps,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ guidance_scale=guidance_scale,
+ dtype=dtype,
+ attention_kwargs=attention_kwargs,
+ num_latent_frames=num_latent_frames,
+ latent_height=latent_height,
+ latent_width=latent_width,
+ latent_condition=latent_condition,
+ first_frame_mask=first_frame_mask,
+ )
# Wan2.2 is prone to out of memory errors when predicting large videos
# so we empty the cache here to avoid OOM before vae decoding.
diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py
index b661108cc6f..11408e2d24b 100644
--- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py
+++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py
@@ -176,6 +176,62 @@ def _create_transformer(self, config: dict) -> WanVACETransformer3DModel:
"""Build VACE transformer directly from config dict."""
return create_vace_transformer_from_config(config)
+ def diffuse(
+ self,
+ latents: torch.Tensor,
+ timesteps: torch.Tensor,
+ prompt_embeds: torch.Tensor,
+ negative_prompt_embeds: torch.Tensor | None,
+ guidance_scale: float,
+ dtype: torch.dtype,
+ attention_kwargs: dict[str, object],
+ vace_context: torch.Tensor | None,
+ vace_context_scale: float,
+ ) -> torch.Tensor:
+ with self.progress_bar(total=len(timesteps)) as pbar:
+ for t in timesteps:
+ self._current_timestep = t
+ latent_model_input = latents.to(dtype)
+ timestep = t.expand(latents.shape[0])
+
+ do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None
+
+ positive_kwargs = {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "vace_context": vace_context,
+ "vace_context_scale": vace_context_scale,
+ "return_dict": False,
+ }
+ negative_kwargs = (
+ {
+ "hidden_states": latent_model_input,
+ "timestep": timestep,
+ "encoder_hidden_states": negative_prompt_embeds,
+ "attention_kwargs": attention_kwargs,
+ "vace_context": vace_context,
+ "vace_context_scale": vace_context_scale,
+ "return_dict": False,
+ }
+ if do_true_cfg
+ else None
+ )
+
+ noise_pred = self.predict_noise_maybe_with_cfg(
+ do_true_cfg=do_true_cfg,
+ true_cfg_scale=guidance_scale,
+ positive_kwargs=positive_kwargs,
+ negative_kwargs=negative_kwargs,
+ cfg_normalize=False,
+ )
+
+ latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
+ pbar.update()
+
+ return latents
+
def check_inputs(
self,
prompt,
@@ -572,48 +628,17 @@ def forward(
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
- # Denoising loop
- with self.progress_bar(total=len(timesteps)) as pbar:
- for t in timesteps:
- self._current_timestep = t
- latent_model_input = latents.to(dtype)
- timestep = t.expand(latents.shape[0])
-
- do_true_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None
-
- positive_kwargs = {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "vace_context": vace_context,
- "vace_context_scale": vace_context_scale,
- "return_dict": False,
- }
- negative_kwargs = (
- {
- "hidden_states": latent_model_input,
- "timestep": timestep,
- "encoder_hidden_states": negative_prompt_embeds,
- "attention_kwargs": attention_kwargs,
- "vace_context": vace_context,
- "vace_context_scale": vace_context_scale,
- "return_dict": False,
- }
- if do_true_cfg
- else None
- )
-
- noise_pred = self.predict_noise_maybe_with_cfg(
- do_true_cfg=do_true_cfg,
- true_cfg_scale=guidance_scale,
- positive_kwargs=positive_kwargs,
- negative_kwargs=negative_kwargs,
- cfg_normalize=False,
- )
-
- latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg)
- pbar.update()
+ latents = self.diffuse(
+ latents=latents,
+ timesteps=timesteps,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ guidance_scale=guidance_scale,
+ dtype=dtype,
+ attention_kwargs=attention_kwargs,
+ vace_context=vace_context,
+ vace_context_scale=vace_context_scale,
+ )
self._current_timestep = None
diff --git a/vllm_omni/diffusion/postprocess/rife_interpolator.py b/vllm_omni/diffusion/postprocess/rife_interpolator.py
index b2b4a931914..89297d0a446 100644
--- a/vllm_omni/diffusion/postprocess/rife_interpolator.py
+++ b/vllm_omni/diffusion/postprocess/rife_interpolator.py
@@ -412,9 +412,12 @@ def interpolate_tensor(
return restore_layout(video), 1
video, restore_range = _normalize_video_tensor_range(video)
- # Prefer the decoded video's current device so CPU-offloaded requests do
- # not move the tensor back to GPU just for interpolation.
- model = self._ensure_model_loaded(preferred_device=video.device)
+ # A CPU tensor may be transport/offload state rather than an execution
+ # choice, so only trust it when it is already on an accelerator.
+ preferred_device = video.device
+ if preferred_device.type == "cpu":
+ preferred_device = _select_torch_device()
+ model = self._ensure_model_loaded(preferred_device=preferred_device)
video = video.to(model.device())
intermediates_per_pair = 2**exp // 2
diff --git a/vllm_omni/distributed/omni_connectors/utils/initialization.py b/vllm_omni/distributed/omni_connectors/utils/initialization.py
index 0497bbb3a23..f012af3c9c3 100644
--- a/vllm_omni/distributed/omni_connectors/utils/initialization.py
+++ b/vllm_omni/distributed/omni_connectors/utils/initialization.py
@@ -206,6 +206,19 @@ def load_omni_transfer_config(
if config_dict is None:
return None
+ # Normalize new-schema (top-level ``connectors`` + ``stages``) into the
+ # legacy ``runtime.connectors`` + ``stage_args`` shape the parser reads.
+ if "stages" in config_dict and "stage_args" not in config_dict:
+ normalized: dict[str, Any] = dict(config_dict)
+ runtime = dict(normalized.get("runtime") or {})
+ if "connectors" in normalized and "connectors" not in runtime:
+ runtime["connectors"] = normalized["connectors"]
+ if "edges" in normalized and "edges" not in runtime:
+ runtime["edges"] = normalized["edges"]
+ normalized["runtime"] = runtime
+ normalized["stage_args"] = normalized["stages"]
+ config_dict = normalized
+
# Parse connectors
connectors = {}
runtime_config = config_dict.get("runtime", {})
diff --git a/vllm_omni/engine/__init__.py b/vllm_omni/engine/__init__.py
index c8a96e6d25d..6c92d7952de 100644
--- a/vllm_omni/engine/__init__.py
+++ b/vllm_omni/engine/__init__.py
@@ -79,6 +79,10 @@ class OmniEngineCoreRequest(EngineCoreRequest):
class OmniEngineCoreOutput(EngineCoreOutput):
pooling_output: dict[str, torch.Tensor] | None = None
+ # Finished flag for streaming input segment
+ is_segment_finished: bool | None = False
+ # Streaming update prompt length
+ new_prompt_len_snapshot: int | None = None
class OmniEngineCoreOutputs(EngineCoreOutputs):
diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py
index 5b69d6b1f0c..d98ce7d419d 100644
--- a/vllm_omni/engine/arg_utils.py
+++ b/vllm_omni/engine/arg_utils.py
@@ -3,7 +3,7 @@
import json
import os
import tempfile
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, fields
from typing import Any
from vllm.engine.arg_utils import EngineArgs
@@ -300,3 +300,254 @@ def create_model_config(self) -> OmniModelConfig:
def output_modality(self) -> OutputModality:
"""Parse engine_output_type into a type-safe OutputModality flag."""
return OutputModality.from_string(self.engine_output_type)
+
+
+# ============================================================================
+# CLI argument routing
+# ============================================================================
+#
+# vLLM-Omni's CLI flags live in three buckets:
+#
+# ┌──────────────────┐ ┌──────────────────┐ ┌──────────────────┐
+# │ OrchestratorArgs │ │ OmniEngineArgs │ │ (upstream vllm) │
+# │ │ │ │ │ server/api │
+# │ stage_timeout │ │ max_num_seqs │ │ host, port │
+# │ worker_backend │ │ gpu_mem_util │ │ ssl_keyfile │
+# │ deploy_config │ │ dtype, quant │ │ api_key │
+# │ ... │ │ ... │ │ ... │
+# └──────────────────┘ └──────────────────┘ └──────────────────┘
+# │ │ │
+# ▼ ▼ ▼
+# orchestrator each stage uvicorn /
+# consumes engine FastAPI
+#
+# Fields in ``SHARED_FIELDS`` (e.g. ``model``, ``log_stats``) flow to BOTH
+# orchestrator and engine by design.
+#
+# Invariants enforced by ``tests/test_arg_utils.py``:
+#
+# 1. ``OrchestratorArgs`` ∩ ``OmniEngineArgs`` ⊆ ``SHARED_FIELDS``
+# 2. Every CLI flag is classifiable into one of the three buckets
+# 3. User-typed flags that match none of the above are logged as dropped
+#
+# Adding a new orchestrator-only flag → add a field to ``OrchestratorArgs``.
+# Everything else is automatic.
+
+
+@dataclass(frozen=True)
+class OrchestratorArgs:
+ """CLI flags consumed by the orchestrator.
+
+ Contract: every field here is either
+ (a) orchestrator-only (never needed by a stage engine), OR
+ (b) orchestrator-read-then-redistributed (e.g. ``async_chunk`` is read
+ from CLI, written to ``DeployConfig``, then propagated to every
+ stage via ``merge_pipeline_deploy`` — not via direct kwargs
+ forwarding).
+
+ Fields that BOTH orchestrator and engine genuinely need (e.g. ``model``,
+ ``log_stats``) should be listed in ``SHARED_FIELDS`` below; ``split_kwargs``
+ will copy them to both buckets.
+ """
+
+ # === Lifecycle ===
+ stage_init_timeout: int = 300
+ init_timeout: int = 600
+
+ # === Cross-stage Communication ===
+ shm_threshold_bytes: int = 65536
+ batch_timeout: int = 10
+
+ # === Cluster / Backend ===
+ worker_backend: str = "multi_process"
+ ray_address: str | None = None
+
+ # === Config Files ===
+ stage_configs_path: str | None = None
+ deploy_config: str | None = None
+ stage_overrides: str | None = None # raw JSON string; parsed downstream
+
+ # === Mode Switches (orchestrator reads, DeployConfig redistributes) ===
+ async_chunk: bool | None = None
+
+ # === Observability ===
+ log_stats: bool = False
+
+ # === Headless Mode (also forwarded to engine — see SHARED_FIELDS) ===
+ stage_id: int | None = None
+
+ # === Pre-built Objects ===
+ parallel_config: Any = None
+
+ # === Multi-stage guards ===
+ # --tokenizer is captured here so it does not propagate to every stage
+ # uniformly (different stages often need different tokenizers, e.g.
+ # qwen3_omni thinker vs talker). Users wanting a per-stage tokenizer
+ # should set it in the deploy YAML.
+ tokenizer: str | None = None
+
+
+# Fields that live in BOTH OrchestratorArgs and OmniEngineArgs by design.
+# Changes to this set are a review red flag — revisit the contract.
+SHARED_FIELDS: frozenset[str] = frozenset(
+ {
+ "model", # orch: detect model_type; engine: load weights
+ "stage_id", # orch: route (headless); engine: identity
+ "log_stats", # both want the flag
+ "stage_configs_path", # orch: load legacy YAML; engine: may reference for validation
+ }
+)
+
+
+def orchestrator_field_names() -> frozenset[str]:
+ """Return the names of every field on OrchestratorArgs."""
+ return frozenset(f.name for f in fields(OrchestratorArgs))
+
+
+def internal_blacklist_keys() -> frozenset[str]:
+ """Return the set of CLI keys that must never be forwarded as per-stage
+ engine overrides.
+
+ Derived from ``OrchestratorArgs`` fields minus ``SHARED_FIELDS``, so
+ adding a new orchestrator-owned flag is a one-line change to the
+ dataclass — this function updates automatically.
+ """
+ return orchestrator_field_names() - SHARED_FIELDS
+
+
+def split_kwargs(
+ kwargs: dict[str, Any],
+ *,
+ engine_cls: type | None = None,
+ user_typed: set[str] | None = None,
+ strict: bool = False,
+) -> tuple[OrchestratorArgs, dict[str, Any]]:
+ """Partition CLI kwargs into (orchestrator, engine) buckets.
+
+ Args:
+ kwargs: Raw dict, typically ``vars(args)``.
+ engine_cls: Engine dataclass used to whitelist-filter the engine
+ bucket. Defaults to ``OmniEngineArgs``. Pass a custom class
+ for testing.
+ user_typed: Keys the user actually typed on the command line. Used
+ to warn when a user-typed flag is unclassifiable.
+ strict: If True, raise ``ValueError`` on ambiguous (double-classified
+ but not in ``SHARED_FIELDS``) fields. Default False to keep the
+ rollout non-breaking; flip to True in tests and CI.
+
+ Returns:
+ ``(orchestrator_args, engine_kwargs)``. ``engine_kwargs`` has already
+ been whitelist-filtered against ``engine_cls`` — safe to pass directly
+ to ``engine_cls(**engine_kwargs)``.
+ """
+ if engine_cls is None:
+ engine_cls = OmniEngineArgs
+
+ orch_fields = orchestrator_field_names()
+ engine_fields = {f.name for f in fields(engine_cls)}
+
+ orch_kwargs: dict[str, Any] = {}
+ engine_candidate: dict[str, Any] = {}
+ shared_values: dict[str, Any] = {}
+ unclassified: dict[str, Any] = {}
+
+ for key, value in kwargs.items():
+ in_orch = key in orch_fields
+ in_engine = key in engine_fields
+ is_shared = key in SHARED_FIELDS
+
+ if is_shared:
+ shared_values[key] = value
+ elif in_orch and in_engine:
+ # Declared in both but not marked shared → ambiguous.
+ msg = (
+ f"Field {key!r} is defined on both OrchestratorArgs and "
+ f"{engine_cls.__name__} but is not in SHARED_FIELDS. "
+ f"This causes double-routing. Either remove the duplicate or "
+ f"add {key!r} to SHARED_FIELDS if the sharing is intentional."
+ )
+ if strict:
+ raise ValueError(msg)
+ logger.error(msg)
+ # Default: treat as orchestrator-only to preserve existing behavior.
+ orch_kwargs[key] = value
+ elif in_orch:
+ orch_kwargs[key] = value
+ elif in_engine:
+ engine_candidate[key] = value
+ else:
+ unclassified[key] = value
+
+ # Warn on user-typed but unclassifiable flags so we don't silently drop
+ # something the user cared about (fixes the class of bug that spawned #873).
+ if unclassified and user_typed:
+ user_typed_unknown = sorted(k for k in unclassified if k in user_typed)
+ if user_typed_unknown:
+ logger.warning(
+ "CLI flags not consumed by vllm-omni and dropped before "
+ "per-stage engine construction: %s. If these are vllm "
+ "frontend/uvicorn flags (host, port, ssl_*, api_key, …) this "
+ "is expected; otherwise check your spelling.",
+ user_typed_unknown,
+ )
+
+ # Engine bucket: shared + engine-only. We do NOT pass through unclassified
+ # fields — that's exactly the server/uvicorn noise we want to shed.
+ engine_kwargs = {**shared_values, **engine_candidate}
+
+ # Construct the orchestrator dataclass. Shared fields that OrchestratorArgs
+ # also declares get copied into its constructor.
+ orch_init: dict[str, Any] = dict(orch_kwargs)
+ for key, value in shared_values.items():
+ if key in orch_fields:
+ orch_init[key] = value
+ orch_args = OrchestratorArgs(**orch_init)
+
+ return orch_args, engine_kwargs
+
+
+def derive_server_dests_from_vllm_parser() -> frozenset[str]:
+ """Derive the set of argparse dests that belong to vllm's frontend/server.
+
+ Returns every dest registered by ``make_arg_parser`` that is NOT a field
+ of ``OmniEngineArgs`` and NOT a field of ``OrchestratorArgs``. Useful for
+ CI tests to assert all CLI flags are classifiable without maintaining
+ a hardcoded server list.
+
+ Returns empty frozenset if vllm's parser cannot be built (e.g. in a
+ minimal test environment).
+ """
+ try:
+ from vllm.entrypoints.openai.cli_args import make_arg_parser
+ from vllm.utils.argparse_utils import FlexibleArgumentParser
+ except ImportError:
+ logger.debug("Cannot import vllm parser — server-dest derivation skipped")
+ return frozenset()
+
+ try:
+ parser = make_arg_parser(FlexibleArgumentParser())
+ all_dests = {a.dest for a in parser._actions if a.dest and a.dest != "help"}
+ except Exception as exc:
+ logger.debug("Failed to build vllm parser: %s", exc)
+ return frozenset()
+
+ engine_fields = {f.name for f in fields(OmniEngineArgs)}
+ orch_fields = orchestrator_field_names()
+
+ return frozenset(all_dests - engine_fields - orch_fields - SHARED_FIELDS)
+
+
+def orchestrator_args_from_argparse(args: Any) -> OrchestratorArgs:
+ """Build an ``OrchestratorArgs`` from an ``argparse.Namespace``.
+
+ Only copies attributes that exist on the namespace — missing fields fall
+ back to the dataclass default. Useful when the full parser is already
+ built and ``vars(args)`` would include noise.
+ """
+ kwargs: dict[str, Any] = {}
+ for f in fields(OrchestratorArgs):
+ if hasattr(args, f.name):
+ value = getattr(args, f.name)
+ if value is not None or f.default is None:
+ kwargs[f.name] = value
+ return OrchestratorArgs(**kwargs)
diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py
index 1ebebba687a..d02af153c45 100644
--- a/vllm_omni/engine/async_omni_engine.py
+++ b/vllm_omni/engine/async_omni_engine.py
@@ -26,6 +26,7 @@
import janus
import torch
from omegaconf import OmegaConf
+from vllm import envs as vllm_envs
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.logger import init_logger
@@ -33,6 +34,7 @@
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor
+from vllm_omni.config.stage_config import strip_parent_engine_args
from vllm_omni.diffusion.data import DiffusionParallelConfig
from vllm_omni.diffusion.stage_diffusion_client import StageDiffusionClient
from vllm_omni.diffusion.stage_diffusion_proc import (
@@ -82,6 +84,7 @@
terminate_alive_proc,
)
from vllm_omni.engine.stage_pool import StagePool
+from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin
from vllm_omni.entrypoints.utils import (
inject_omni_kv_config,
load_and_resolve_stage_configs,
@@ -95,6 +98,27 @@
logger = init_logger(__name__)
+# ============================================================================
+# Parent-EngineArgs field-routing contracts (consumed by
+# AsyncOmniEngine._strip_parent_engine_args when ``stage_configs_path`` is set).
+# ============================================================================
+
+# Fields that must survive the "equal to default → strip" filter because
+# diffusion stages need them even when equal to vllm's default value
+# (e.g. colocate worker setup relies on worker_extension_cls being forwarded).
+_PARENT_ARGS_KEEP: frozenset[str] = frozenset({"worker_extension_cls"})
+
+# Omni orchestrator-level fields consumed by ``_resolve_stage_configs`` that
+# must never leak into per-stage EngineArgs (``stage_configs_path`` would
+# trigger the ``create_model_config`` guard).
+_PARENT_ARGS_STRIP: frozenset[str] = frozenset({"stage_configs_path"})
+
+# Fields always populated by callers (via ``from_cli_args`` / ``asdict``) so
+# their presence as an override is never a surprise — suppress the
+# "override ignored" warning for these.
+_PARENT_ARGS_NO_WARN: frozenset[str] = frozenset({"model"})
+
+
def _patch_generation_config_if_needed(model_config: Any) -> None:
"""Ensure try_get_generation_config won't crash for models whose HF
config.json lacks model_type (e.g. CosyVoice3). We probe it once;
@@ -1003,12 +1027,14 @@ async def _run_orchestrator() -> None:
self._initialize_janus_queues()
self._initialize_stages(stage_init_timeout)
+ pd_config = self._detect_pd_config()
orchestrator = Orchestrator(
request_async_queue=self.request_queue.async_q,
output_async_queue=self.output_queue.async_q,
rpc_async_queue=self.rpc_output_queue.async_q,
stage_pools=self.stage_pools,
async_chunk=self.async_chunk,
+ pd_config=pd_config,
)
if not startup_future.done():
startup_future.set_result(asyncio.get_running_loop())
@@ -1223,6 +1249,48 @@ def _normalize_cache_config(cache_backend: str | None, cache_config: Any | None)
cache_config = AsyncOmniEngine._get_default_cache_config(cache_backend)
return cache_config
+ def _detect_pd_config(self) -> dict[str, Any] | None:
+ """Detect PD (Prefill-Decode) disaggregation config from stage_configs.
+ Returns a dict with 'pd_pair' and 'bootstrap_addr', or None.
+ """
+ pd_pair = PDDisaggregationMixin.detect_pd_separation_from_stage_configs(self.stage_configs)
+ if pd_pair is None:
+ return None
+ prefill_idx, decode_idx = pd_pair
+
+ # Extract bootstrap address from prefill stage engine_args
+ bootstrap_addr: str | None = None
+ try:
+ prefill_cfg = self.stage_configs[prefill_idx]
+ ea = getattr(prefill_cfg, "engine_args", None)
+ kv_cfg = getattr(ea, "kv_transfer_config", None) if ea is not None else None
+ if kv_cfg is not None:
+ port = vllm_envs.VLLM_MOONCAKE_BOOTSTRAP_PORT
+ kv_ip = getattr(kv_cfg, "kv_ip", None) or "127.0.0.1"
+ bootstrap_addr = f"http://{kv_ip}:{port}"
+ except Exception as exc:
+ logger.warning("[AsyncOmniEngine] Could not extract PD bootstrap address: %s", exc)
+
+ logger.info(
+ "[AsyncOmniEngine] PD disaggregation detected: prefill=stage-%d, decode=stage-%d, bootstrap=%s",
+ prefill_idx,
+ decode_idx,
+ bootstrap_addr,
+ )
+ prefill_engine_id: str | None = None
+ try:
+ prefill_client = self.stage_clients[prefill_idx]
+ kv_cfg = getattr(getattr(prefill_client, "vllm_config", None), "kv_transfer_config", None)
+ prefill_engine_id = getattr(kv_cfg, "engine_id", None)
+ except Exception as exc:
+ logger.warning("[AsyncOmniEngine] Could not extract prefill engine_id: %s", exc)
+
+ return {
+ "pd_pair": (prefill_idx, decode_idx),
+ "bootstrap_addr": bootstrap_addr,
+ "prefill_engine_id": prefill_engine_id,
+ }
+
@staticmethod
def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list:
"""Create a default single-stage diffusion config from kwargs."""
@@ -1356,45 +1424,17 @@ def _strip_single_engine_args(kwargs: dict[str, Any]) -> dict[str, Any]:
Logs a warning for any parent field whose value differs from the
dataclass default, so users know their explicit overrides are ignored.
+ See the module-level ``_PARENT_ARGS_*`` constants for the routing
+ contracts this method enforces.
"""
- # worker_extension_cls is a parent field but must pass through to
- # diffusion stages for colocate worker setup.
- _keep = {"worker_extension_cls"}
- # Orchestrator-level OmniEngineArgs fields that are consumed by
- # _resolve_stage_configs and must not leak into per-stage configs
- # (stage_configs_path would trigger the create_model_config guard).
- _strip_omni = {"stage_configs_path"}
- # Fields that are always set by callers (via from_cli_args / asdict)
- # and would always appear as overridden — suppress from the warning
- # so it only surfaces genuinely surprising overrides.
- _no_warn = {"model"}
-
parent_fields: dict[str, dataclasses.Field] = {f.name: f for f in dataclasses.fields(EngineArgs)}
- overridden: list[str] = []
- result: dict[str, Any] = {}
- for k, v in kwargs.items():
- if k in _strip_omni:
- continue
- if k not in parent_fields or k in _keep:
- result[k] = v
- continue
- # Detect explicitly-set values that differ from the default.
- # Values may have been through asdict() which converts dataclass
- # defaults to dicts, so normalise before comparing.
- field = parent_fields[k]
- if field.default is not dataclasses.MISSING:
- default = field.default
- elif field.default_factory is not dataclasses.MISSING:
- default = field.default_factory()
- else:
- default = dataclasses.MISSING
- if default is dataclasses.MISSING or v is None:
- continue
- # Normalise dataclass defaults to dicts for comparison
- if dataclasses.is_dataclass(default) and not isinstance(default, type):
- default = dataclasses.asdict(default)
- if v != default and k not in _no_warn:
- overridden.append(k)
+ result, overridden = strip_parent_engine_args(
+ kwargs,
+ parent_fields=parent_fields,
+ keep_keys=_PARENT_ARGS_KEEP,
+ strip_keys=_PARENT_ARGS_STRIP,
+ no_warn_keys=_PARENT_ARGS_NO_WARN,
+ )
if overridden:
logger.warning(
@@ -1409,6 +1449,12 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st
"""Resolve stage configs and inject defaults shared by orchestrator/headless."""
stage_configs_path = kwargs.get("stage_configs_path", None)
+ deploy_config_path = kwargs.pop("deploy_config", None)
+ stage_overrides_json = kwargs.pop("stage_overrides", None)
+ # Set of CLI keys the user actually typed; ``None`` means we have no
+ # parser-level info (e.g. programmatic Omni() call) and the lower
+ # layers should treat all kwargs as explicit.
+ cli_explicit_keys = kwargs.pop("_cli_explicit_keys", None)
explicit_stage_configs = kwargs.pop("stage_configs", None)
if explicit_stage_configs is not None:
logger.warning(
@@ -1421,13 +1467,27 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st
else:
base_kwargs = kwargs
- # Use the legacy config loading path (load_and_resolve_stage_configs).
- # StageConfigFactory wiring will be done in config refactor [2/N].
+ # Parse --stage-overrides JSON string if provided
+ stage_overrides = None
+ if stage_overrides_json:
+ if isinstance(stage_overrides_json, str):
+ try:
+ stage_overrides = json.loads(stage_overrides_json)
+ except json.JSONDecodeError as exc:
+ raise ValueError(
+ f"--stage-overrides is not valid JSON: {exc}. Got: {stage_overrides_json!r}"
+ ) from exc
+ else:
+ stage_overrides = stage_overrides_json
+
config_path, stage_configs = load_and_resolve_stage_configs(
model,
stage_configs_path,
base_kwargs,
default_stage_cfg_factory=lambda: self._create_default_diffusion_stage_cfg(kwargs),
+ deploy_config_path=deploy_config_path,
+ stage_overrides=stage_overrides,
+ cli_explicit_keys=cli_explicit_keys,
)
# Inject diffusion LoRA-related knobs from kwargs if not present in the stage config.
diff --git a/vllm_omni/engine/orchestrator.py b/vllm_omni/engine/orchestrator.py
index 5b60c969890..81ff960116b 100644
--- a/vllm_omni/engine/orchestrator.py
+++ b/vllm_omni/engine/orchestrator.py
@@ -35,6 +35,8 @@ def build_engine_core_request_from_tokens(
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
model_config: ModelConfig | None = None,
+ resumable: bool = False,
+ mm_features: list | None = None,
) -> OmniEngineCoreRequest:
"""Build an OmniEngineCoreRequest directly from an OmniTokensPrompt."""
if arrival_time is None:
@@ -60,7 +62,7 @@ def build_engine_core_request_from_tokens(
return OmniEngineCoreRequest(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
- mm_features=None,
+ mm_features=mm_features,
sampling_params=sampling_params,
pooling_params=pooling_params,
arrival_time=arrival_time,
@@ -68,6 +70,7 @@ def build_engine_core_request_from_tokens(
cache_salt=None,
data_parallel_rank=None,
prompt_embeds=prompt_embeds,
+ resumable=resumable,
additional_information=additional_info_payload,
)
@@ -83,6 +86,23 @@ class OrchestratorRequestState:
# Metrics: timestamp when request was submitted to each stage.
stage_submit_ts: dict[int, float] = field(default_factory=dict)
+ mm_processor_kwargs: dict | None = None
+ mm_features: list | None = None
+ pd_prefill_multimodal_output: dict[str, Any] | None = None
+
+ streaming: StreamingInputState = field(default_factory=lambda: StreamingInputState())
+
+
+@dataclass
+class StreamingInputState:
+ # Flag of streaming input request
+ enabled: bool = False
+ # Flag of segment of streaming input finished
+ segment_finished: bool = False
+ # Streaming update prompt length
+ new_prompt_len_snapshot: int | None = None
+ # Model/bridge-specific runtime states (e.g., thinker->talker)
+ bridge_states: dict[str, Any] = field(default_factory=dict)
class Orchestrator:
@@ -96,6 +116,7 @@ def __init__(
stage_pools: list[StagePool],
*,
async_chunk: bool = False,
+ pd_config: dict[str, Any] | None = None,
) -> None:
self.request_async_queue = request_async_queue
self.output_async_queue = output_async_queue
@@ -105,6 +126,15 @@ def __init__(
self.num_stages = len(stage_pools)
self.stage_pools: list[StagePool] = stage_pools
+ # PD disaggregation state
+ self._pd_pair: tuple[int, int] | None = None
+ self._pd_bootstrap_addr: str | None = None
+ self._pd_prefill_engine_id: str | None = None
+ self._pd_kv_params: dict[str, Any] = {}
+ if pd_config is not None:
+ self._pd_pair = pd_config.get("pd_pair")
+ self._pd_bootstrap_addr = pd_config.get("bootstrap_addr")
+ self._pd_prefill_engine_id = pd_config.get("prefill_engine_id")
self.request_states: dict[str, OrchestratorRequestState] = {}
self._cfg_tracker = CfgCompanionTracker()
@@ -201,8 +231,10 @@ async def _handle_add_request(self, msg: dict[str, Any]) -> None:
prompt=original_prompt,
sampling_params_list=sampling_params_list,
final_stage_id=final_stage_id,
+ mm_features=getattr(prompt, "mm_features", None),
)
self.request_states[request_id] = req_state
+ req_state.streaming.enabled = bool(getattr(prompt, "resumable", False))
req_state.stage_submit_ts[stage_id] = _time.time()
await self.stage_pools[stage_id].submit_initial(
request_id,
@@ -234,6 +266,7 @@ async def _handle_streaming_update(self, msg: dict[str, Any]) -> None:
if "sampling_params_list" in msg and msg["sampling_params_list"]:
req_state.sampling_params_list = msg["sampling_params_list"]
+ req_state.streaming.enabled = True
req_state.stage_submit_ts[stage_id] = _time.time()
await self.stage_pools[stage_id].submit_update(request_id, req_state, request)
@@ -376,6 +409,16 @@ async def _orchestration_loop(self) -> None:
continue
await self._handle_kv_ready_raw_outputs(stage_id, raw_outputs)
+ for eco in raw_outputs.outputs:
+ req_state = self.request_states.get(getattr(eco, "request_id", None))
+ if req_state is None:
+ continue
+ req_state.streaming.segment_finished = bool(getattr(eco, "is_segment_finished", False))
+ req_state.streaming.new_prompt_len_snapshot = getattr(
+ eco,
+ "new_prompt_len_snapshot",
+ None,
+ )
raw_output = await pool.process_llm_raw_outputs(replica_id, raw_outputs)
await self._handle_processed_outputs(stage_id, replica_id, raw_output)
idle = False
@@ -443,6 +486,7 @@ async def _cleanup_request_ids(self, request_ids: list[str], *, abort: bool = Fa
await self._abort_request_ids(request_ids)
self._release_request_bindings(request_ids)
for request_id in request_ids:
+ self._pd_kv_params.pop(request_id, None)
self.request_states.pop(request_id, None)
def _maybe_clone_diffusion_params_for_cfg(self, request_id: str, params: Any) -> Any:
@@ -502,20 +546,50 @@ async def _route_output(
}
)
+ if self._pd_pair is not None and finished and stage_id == self._pd_pair[0]:
+ kv_params = getattr(output, "kv_transfer_params", None)
+ if kv_params is not None:
+ self._pd_kv_params[req_id] = kv_params if isinstance(kv_params, dict) else dict(kv_params)
+ req_state.pd_prefill_multimodal_output = getattr(output, "multimodal_output", None)
+
if (
- finished
+ (finished or (req_state.streaming.enabled and req_state.streaming.segment_finished))
and stage_id < req_state.final_stage_id
and not self.async_chunk
- and (stage_id + 1) not in req_state.stage_submit_ts
+ and (not self._next_stage_already_submitted(stage_id, req_state) or req_state.streaming.enabled)
):
- if self._cfg_tracker.has_companions(req_id) and not self._cfg_tracker.all_companions_done(req_id):
+ if (
+ finished
+ and self._cfg_tracker.has_companions(req_id)
+ and not self._cfg_tracker.all_companions_done(req_id)
+ ):
self._cfg_tracker.defer_parent(req_id, output, stage_id)
else:
- await self._forward_to_next_stage(req_id, stage_id, output, req_state)
+ await self._forward_to_next_stage(
+ req_id,
+ stage_id,
+ output,
+ req_state,
+ is_streaming_session=req_state.streaming.enabled,
+ is_final_update=False,
+ )
+ if req_state.streaming.enabled and finished:
+ # For streaming sessions, send the terminal (resumable=False) update only on a finish
+ await self._forward_to_next_stage(
+ req_id,
+ stage_id,
+ output,
+ req_state,
+ is_streaming_session=True,
+ is_final_update=True,
+ )
if finished and stage_id == req_state.final_stage_id:
await self._cleanup_request_ids([req_id, *self._cfg_tracker.cleanup_parent(req_id)])
+ def _next_stage_already_submitted(self, stage_id: int, req_state: OrchestratorRequestState) -> bool:
+ return (stage_id + 1) in req_state.stage_submit_ts
+
async def _handle_cfg_companion_ready(self, req_id: str) -> None:
"""Mark a CFG companion as done; if all companions are done, flush deferred parent."""
parent_id = self._cfg_tracker.on_companion_completed(req_id)
@@ -572,12 +646,60 @@ async def _handle_kv_ready_raw_outputs(
else:
await self._forward_to_next_stage(req_id, stage_id, raw_output, req_state)
+ def _build_pd_decode_params(self, req_id: str, sp: Any) -> Any:
+ """Build decode-side sampling params with KV transfer params for PD routing.
+
+ Clones the sampling params and injects kv_transfer_params that tell the
+ decode engine where to pull the KV cache from (prefill engine's bootstrap addr).
+ """
+ sp = sp.clone()
+ if sp.extra_args is None:
+ sp.extra_args = {}
+
+ # Get KV params captured from the prefill output (must include remote_request_id).
+ kv_prefill_params = self._pd_kv_params.pop(req_id, None)
+ if not kv_prefill_params or "remote_request_id" not in kv_prefill_params:
+ raise RuntimeError(
+ f"[Orchestrator][PD] Missing prefill kv_transfer_params.remote_request_id for req={req_id}"
+ )
+
+ decode_kv_params: dict[str, Any] = {
+ "transfer_id": f"xfer-{req_id}",
+ }
+
+ if self._pd_bootstrap_addr:
+ decode_kv_params["remote_bootstrap_addr"] = self._pd_bootstrap_addr
+
+ if self._pd_prefill_engine_id:
+ decode_kv_params["remote_engine_id"] = self._pd_prefill_engine_id
+
+ # Overlay params from prefill side (includes remote_request_id set by monkey patch).
+ decode_kv_params.update(kv_prefill_params)
+
+ # Ensure these flags are set correctly after any overlay.
+ decode_kv_params["do_remote_prefill"] = True
+ decode_kv_params["do_remote_decode"] = False
+ if not decode_kv_params.get("transfer_id"):
+ decode_kv_params["transfer_id"] = f"xfer-{req_id}"
+
+ sp.extra_args["kv_transfer_params"] = decode_kv_params
+
+ logger.debug(
+ "[Orchestrator][PD] decode kv_transfer_params for req=%s: %s",
+ req_id,
+ decode_kv_params,
+ )
+ return sp
+
async def _forward_to_next_stage(
self,
req_id: str,
src_stage_id: int,
output: Any,
req_state: OrchestratorRequestState,
+ *,
+ is_streaming_session: bool = False,
+ is_final_update: bool = False,
) -> None:
"""Forward output from the current logical stage to the next one."""
next_logical = src_stage_id + 1
@@ -585,6 +707,8 @@ async def _forward_to_next_stage(
next_client = next_pool.stage_client
params = req_state.sampling_params_list[next_logical]
source_outputs = [output]
+ next_stage_resumable = is_streaming_session and not is_final_update
+ already_submitted = self._next_stage_already_submitted(src_stage_id, req_state)
requires_multimodal_data = getattr(next_client, "requires_multimodal_data", False)
if next_pool.stage_type == "diffusion":
@@ -597,35 +721,106 @@ async def _forward_to_next_stage(
else:
diffusion_prompt = req_state.prompt
- req_state.stage_submit_ts.setdefault(next_logical, _time.time())
- await next_pool.submit_initial(
- req_id,
- req_state,
- diffusion_prompt,
- submit_kwargs={
- "kv_sender_info": self._build_kv_sender_info(
- list(getattr(next_client, "engine_input_source", None) or [src_stage_id]),
- request_id=req_id,
+ if already_submitted:
+ await next_pool.submit_update(req_id, req_state, diffusion_prompt)
+ else:
+ await next_pool.submit_initial(
+ req_id,
+ req_state,
+ diffusion_prompt,
+ submit_kwargs={
+ "kv_sender_info": self._build_kv_sender_info(
+ list(getattr(next_client, "engine_input_source", None) or [src_stage_id]),
+ request_id=req_id,
+ )
+ },
+ params_override=self._maybe_clone_diffusion_params_for_cfg(req_id, params),
+ )
+ req_state.stage_submit_ts[next_logical] = _time.time()
+ return
+
+ # PD disaggregation: prefill → decode routing uses original prompt + KV transfer params
+ if self._pd_pair is not None and (src_stage_id, next_logical) == self._pd_pair:
+ params = self._build_pd_decode_params(req_id, params)
+
+ # Use the original user prompt for the decode stage (not processed embeddings)
+ original_prompt = req_state.prompt
+ raw_decode_inputs = [original_prompt] if not isinstance(original_prompt, list) else original_prompt
+
+ decode_inputs: list[dict[str, Any]] = []
+ for decode_input in raw_decode_inputs:
+ if isinstance(decode_input, dict):
+ decode_inputs.append(decode_input)
+ continue
+ prompt_token_ids = getattr(decode_input, "prompt_token_ids", None)
+ if prompt_token_ids is None:
+ raise TypeError(
+ "[Orchestrator][PD] decode input must be dict or have prompt_token_ids, "
+ f"got {type(decode_input).__name__} for req={req_id}"
)
- },
- params_override=self._maybe_clone_diffusion_params_for_cfg(req_id, params),
- )
- else:
- req_state.stage_submit_ts.setdefault(next_logical, _time.time())
- next_inputs = next_client.process_engine_inputs(
- source_outputs,
- req_state.prompt,
- )
- for next_input in next_inputs:
+ decode_inputs.append({"prompt_token_ids": list(prompt_token_ids)})
+
+ for decode_input in decode_inputs:
request = build_engine_core_request_from_tokens(
request_id=req_id,
- prompt=next_input,
+ prompt=decode_input,
params=params,
model_config=next_pool.stage_vllm_config.model_config,
+ mm_features=req_state.mm_features,
+ resumable=next_stage_resumable,
)
request.external_req_id = request.request_id
+ if already_submitted:
+ await next_pool.submit_update(req_id, req_state, request)
+ else:
+ await next_pool.submit_initial(req_id, req_state, request, prompt_text=None)
+
+ req_state.stage_submit_ts[next_logical] = _time.time()
+ return
+
+ if req_state.pd_prefill_multimodal_output is not None:
+ req_state.streaming.bridge_states.setdefault(
+ "pd_prefill_multimodal_output_by_req",
+ {},
+ )[req_id] = req_state.pd_prefill_multimodal_output
+
+ try:
+ next_inputs = next_client.process_engine_inputs(
+ source_outputs,
+ req_state.prompt,
+ streaming_context=req_state.streaming,
+ )
+ except Exception:
+ logger.exception(
+ "[Orchestrator] req=%s process_engine_inputs FAILED for stage-%s",
+ req_id,
+ next_logical,
+ )
+ raise
+
+ # Build and submit requests for each input
+ for next_input in next_inputs:
+ # Only AR thinker stages consume encoder mm_features; downstream
+ # (talker/code2wav/…) must not see them (avoids encoder-cache misses).
+ model_stage = getattr(next_client, "model_stage", None)
+ mm_features = req_state.mm_features if model_stage == "thinker" else None
+ request = build_engine_core_request_from_tokens(
+ request_id=req_id,
+ prompt=next_input,
+ params=params,
+ model_config=next_pool.stage_vllm_config.model_config,
+ mm_features=mm_features,
+ resumable=next_stage_resumable,
+ )
+
+ request.external_req_id = request.request_id
+ if already_submitted:
+ await next_pool.submit_update(req_id, req_state, request)
+ else:
await next_pool.submit_initial(req_id, req_state, request, prompt_text=None)
+ req_state.stage_submit_ts[next_logical] = _time.time()
+
async def _prewarm_async_chunk_stages(
self,
request_id: str,
diff --git a/vllm_omni/engine/output_processor.py b/vllm_omni/engine/output_processor.py
index badd799fc94..67b4dd16504 100644
--- a/vllm_omni/engine/output_processor.py
+++ b/vllm_omni/engine/output_processor.py
@@ -233,10 +233,9 @@ def _new_completion_output(
# Reuse base text/logprobs logic, then annotate with pooling_result.
base_output = super()._new_completion_output(token_ids, finish_reason, stop_reason, routed_experts)
try:
+ if not hasattr(base_output, "multimodal_output"):
+ setattr(base_output, "multimodal_output", {})
if self.mm_accumulated is not None:
- # Attach accumulated multimodal dict on the completion output
- if not hasattr(base_output, "multimodal_output"):
- setattr(base_output, "multimodal_output", {})
mm_out = getattr(base_output, "multimodal_output")
if isinstance(mm_out, dict):
for k, v in self.mm_accumulated.items():
diff --git a/vllm_omni/engine/stage_engine_core_client.py b/vllm_omni/engine/stage_engine_core_client.py
index cd2c0823218..cf0e4241032 100644
--- a/vllm_omni/engine/stage_engine_core_client.py
+++ b/vllm_omni/engine/stage_engine_core_client.py
@@ -6,6 +6,7 @@
from __future__ import annotations
+import inspect
import socket
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
@@ -323,6 +324,7 @@ def process_engine_inputs(
self,
source_outputs: list[Any],
prompt: Any = None,
+ streaming_context: Any | None = None,
) -> list[OmniTokensPrompt]:
"""Process inputs from upstream stages.
@@ -330,6 +332,14 @@ def process_engine_inputs(
and the original prompt.
"""
if self.custom_process_input_func is not None:
+ signature = inspect.signature(self.custom_process_input_func)
+ if len(signature.parameters) >= 4:
+ return self.custom_process_input_func(
+ source_outputs,
+ prompt,
+ self.requires_multimodal_data,
+ streaming_context,
+ )
return self.custom_process_input_func(
source_outputs,
prompt,
diff --git a/vllm_omni/engine/stage_init_utils.py b/vllm_omni/engine/stage_init_utils.py
index 39233b719fb..456f5a9244b 100644
--- a/vllm_omni/engine/stage_init_utils.py
+++ b/vllm_omni/engine/stage_init_utils.py
@@ -533,6 +533,20 @@ def build_vllm_config(
filtered_engine_args_dict = filter_dataclass_kwargs(OmniEngineArgs, engine_args_dict)
omni_engine_args = OmniEngineArgs(**filtered_engine_args_dict)
+
+ # Multi-stage pipelines (qwen3_tts code2wav, etc.) set max_model_len
+ # larger than HF max_position_embeddings by design. vLLM's validator
+ # rejects that without the env flag.
+ if filtered_engine_args_dict.get("max_model_len") is not None and not os.environ.get(
+ "VLLM_ALLOW_LONG_MAX_MODEL_LEN"
+ ):
+ os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
+ logger.debug(
+ "Auto-set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 for stage %s (max_model_len=%s).",
+ stage_config.stage_id,
+ filtered_engine_args_dict["max_model_len"],
+ )
+
vllm_config = omni_engine_args.create_engine_config(
usage_context=UsageContext.LLM_CLASS,
headless=headless,
diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py
index 129ef3c99d8..9606cc80d0d 100644
--- a/vllm_omni/entrypoints/async_omni.py
+++ b/vllm_omni/entrypoints/async_omni.py
@@ -78,7 +78,6 @@ def __init__(self, *args: Any, model: str = "", **kwargs: Any) -> None:
self.final_output_task: asyncio.Task | None = None
self.config_path = self.engine.config_path
- self.stage_configs = self.engine.stage_configs
self.tts_max_instructions_length = kwargs.get("tts_max_instructions_length", None)
self.input_processor = self.engine.input_processor
@@ -209,6 +208,13 @@ async def generate(
# Start final output dispatcher on the first call to generate()
self._final_output_handler()
+ # Expand sampling params for PD disaggregation (user may provide N-1 params)
+ if (
+ sampling_params_list is not None
+ and isinstance(sampling_params_list, Sequence)
+ and not isinstance(sampling_params_list, (str, bytes))
+ ):
+ sampling_params_list = self._maybe_expand_sampling_params(list(sampling_params_list))
sampling_params_list = self.resolve_sampling_params_list(sampling_params_list)
# Track per-request metrics
@@ -228,20 +234,27 @@ async def generate(
req_state.metrics = metrics
self.request_states[request_id] = req_state
+ # PD disaggregation: modify prefill-stage sampling params per request
+ req_sp_list = list(sampling_params_list)
+ pd_pair = self._get_pd_separation_pair()
+ if pd_pair is not None:
+ p_id = pd_pair[0]
+ req_sp_list[p_id] = self._prepare_prefill_sampling_params(request_id, req_sp_list[p_id])
+
# Add request(s) to stage 0. For streaming inputs, submit
# chunks incrementally through streaming_update.
if isinstance(prompt, AsyncGenerator):
input_stream_task = await self._add_streaming_input_request(
request_id=request_id,
input_stream=prompt,
- sampling_params_list=sampling_params_list,
+ sampling_params_list=req_sp_list,
final_stage_id=final_stage_id_for_e2e,
)
else:
await self.engine.add_request_async(
request_id=request_id,
prompt=prompt,
- sampling_params_list=sampling_params_list,
+ sampling_params_list=req_sp_list,
final_stage_id=final_stage_id_for_e2e,
)
submit_ts = time.time()
@@ -296,7 +309,6 @@ async def _add_streaming_input_request(
if not stage0_params.skip_clone:
stage0_params = stage0_params.clone()
stage0_params.skip_clone = True
- stage0_params.output_kind = RequestOutputKind.DELTA
has_submitted_first_chunk = False
diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py
index 6e9adc24615..8bccfbb5916 100644
--- a/vllm_omni/entrypoints/cli/serve.py
+++ b/vllm_omni/entrypoints/cli/serve.py
@@ -9,6 +9,7 @@
import json
import os
import signal
+import sys
from types import FrameType
from typing import Any
@@ -21,6 +22,7 @@
from vllm_omni.entrypoints.cli.logo import log_logo
from vllm_omni.entrypoints.openai.api_server import omni_run_server
+from vllm_omni.entrypoints.utils import detect_explicit_cli_keys
logger = init_logger(__name__)
@@ -79,6 +81,9 @@ class OmniServeCommand(CLISubcommand):
"""The `serve` subcommand for the vLLM CLI."""
name = "serve"
+ # Parser stashed at subparser_init so ``cmd`` can resolve each user-typed
+ # flag to its real ``dest`` via the parser's action table.
+ _parser: FlexibleArgumentParser | None = None
@staticmethod
def cmd(args: argparse.Namespace) -> None:
@@ -90,6 +95,10 @@ def cmd(args: argparse.Namespace) -> None:
if hasattr(args, "model_tag") and args.model_tag is not None:
args.model = args.model_tag
+ # Stash the set of long-option keys the user actually typed so the
+ # stage-config factory can give YAML precedence over argparse defaults.
+ args._cli_explicit_keys = detect_explicit_cli_keys(sys.argv[1:], OmniServeCommand._parser)
+
if args.headless:
run_headless(args)
else:
@@ -138,11 +147,33 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
help="Default task type for TTS models (CustomVoice, VoiceDesign, or Base). "
"If not specified, will be inferred from model path.",
)
+ # TODO(@lishunyang12): deprecate once all models migrate to --deploy-config
omni_config_group.add_argument(
"--stage-configs-path",
type=str,
default=None,
- help="Path to the stage configs file. If not specified, the stage configs will be loaded from the model.",
+ help="[Deprecated — will be removed in a future release] Path to a legacy "
+ "stage configs YAML (stage_args format). Prefer --deploy-config for new-format deploy YAMLs.",
+ )
+ omni_config_group.add_argument(
+ "--deploy-config",
+ type=str,
+ default=None,
+ help="Path to a deploy config YAML (new format with stages/engine_args). "
+ "Mutually exclusive with --stage-configs-path.",
+ )
+ omni_config_group.add_argument(
+ "--stage-overrides",
+ type=str,
+ default=None,
+ help="Per-stage JSON overrides. Example: "
+ '\'{"0": {"gpu_memory_utilization": 0.8}, "2": {"enforce_eager": true}}\'',
+ )
+ omni_config_group.add_argument(
+ "--async-chunk",
+ action=argparse.BooleanOptionalAction,
+ default=None,
+ help="Override the deploy YAML's ``async_chunk:`` bool. Unset leaves the YAML value in force.",
)
omni_config_group.add_argument(
"--stage-id",
@@ -406,6 +437,9 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
+ # Stash via type(self) so the docs hook (which execs this function in a
+ # sandboxed globals dict via ``DummySelf``) doesn't fail on a NameError.
+ type(self)._parser = serve_parser
return serve_parser
@@ -461,10 +495,15 @@ def run_headless(args: argparse.Namespace) -> None:
raise ValueError("headless mode requires worker_backend=multi_process")
args_dict = vars(args).copy()
+ # Preserve the explicit-keys set captured at parse time so per-stage yaml
+ # values (e.g. stage 1's ``gpu_memory_utilization: 0.5``) are not
+ # overwritten by argparse defaults for flags the user didn't type.
+ cli_explicit_keys = args_dict.pop("_cli_explicit_keys", None)
config_path, stage_configs = load_and_resolve_stage_configs(
model,
args_dict.get("stage_configs_path"),
args_dict,
+ cli_explicit_keys=cli_explicit_keys,
)
# Locate the stage config that matches stage_id.
diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py
index a3bfe98ce2c..8ef7e2ee5b7 100644
--- a/vllm_omni/entrypoints/omni.py
+++ b/vllm_omni/entrypoints/omni.py
@@ -66,6 +66,13 @@ def generate(
py_generator: bool = False,
use_tqdm: bool | Callable[..., tqdm] = True,
) -> Generator[OmniRequestOutput, None, None] | list[OmniRequestOutput]:
+ # Expand sampling params for PD disaggregation (user may provide N-1 params)
+ if (
+ sampling_params_list is not None
+ and isinstance(sampling_params_list, Sequence)
+ and not isinstance(sampling_params_list, (str, bytes))
+ ):
+ sampling_params_list = self._maybe_expand_sampling_params(list(sampling_params_list))
sampling_params_list = self.resolve_sampling_params_list(sampling_params_list)
try:
if py_generator:
@@ -125,10 +132,17 @@ def _run_generation(
req_state.metrics = metrics
self.request_states[req_id] = req_state
+ # PD disaggregation: modify stage-0 (prefill) sampling params per request
+ req_sp_list = list(sampling_params_list)
+ pd_pair = self._get_pd_separation_pair()
+ if pd_pair is not None:
+ p_id = pd_pair[0]
+ req_sp_list[p_id] = self._prepare_prefill_sampling_params(req_id, req_sp_list[p_id])
+
self.engine.add_request(
request_id=req_id,
prompt=prompt,
- sampling_params_list=sampling_params_list,
+ sampling_params_list=req_sp_list,
final_stage_id=final_stage_id,
)
submit_ts = time.time()
diff --git a/vllm_omni/entrypoints/omni_base.py b/vllm_omni/entrypoints/omni_base.py
index 1a7ffc4a504..dca494efe72 100644
--- a/vllm_omni/entrypoints/omni_base.py
+++ b/vllm_omni/entrypoints/omni_base.py
@@ -1,6 +1,8 @@
from __future__ import annotations
+import argparse
import os
+import sys
import time
import types
import weakref
@@ -14,7 +16,8 @@
from vllm_omni.engine.async_omni_engine import AsyncOmniEngine
from vllm_omni.entrypoints.client_request_state import ClientRequestState
-from vllm_omni.entrypoints.utils import get_final_stage_id_for_e2e
+from vllm_omni.entrypoints.pd_utils import PDDisaggregationMixin
+from vllm_omni.entrypoints.utils import detect_explicit_cli_keys, get_final_stage_id_for_e2e
from vllm_omni.metrics.stats import OrchestratorAggregator as OrchestratorMetrics
from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific
from vllm_omni.outputs import OmniRequestOutput
@@ -65,9 +68,51 @@ def omni_snapshot_download(model_id: str) -> str:
OutputMessageHandleResult = tuple[Literal[True], None, None, None] | tuple[Literal[False], str, int, ClientRequestState]
-class OmniBase:
+class OmniBase(PDDisaggregationMixin):
"""Shared runtime foundation for AsyncOmni and Omni."""
+ @classmethod
+ def from_cli_args(
+ cls,
+ args: argparse.Namespace,
+ *,
+ parser: argparse.ArgumentParser | None = None,
+ **overrides: Any,
+ ) -> OmniBase:
+ """Construct an ``Omni`` / ``AsyncOmni`` from an ``argparse.Namespace``.
+
+ Mirrors the ``EngineArgs.from_cli_args`` pattern used upstream and in
+ ``OmniEngineArgs.from_cli_args``. This is the recommended entry point
+ for any argparse-based caller (offline scripts, tests, CI): it
+ expands ``vars(args)`` into kwargs and automatically captures which
+ flags the user typed on the command line so that argparse defaults
+ do not silently override deploy YAML values.
+
+ Passing ``parser`` is strongly recommended: without it, flag-to-dest
+ resolution falls back to a name-based heuristic that misidentifies
+ flags with ``dest=`` overrides, alias flags, and ``--disable-X`` /
+ ``store_false`` pairs. See :func:`detect_explicit_cli_keys`.
+
+ Args:
+ args: Parsed argparse namespace from ``parser.parse_args()``.
+ parser: The argparse parser used to produce ``args``. When
+ provided, each user-typed flag is resolved to its real
+ ``dest`` via the parser's action table.
+ **overrides: Extra keyword arguments that take precedence over
+ attributes on ``args``.
+
+ Example::
+
+ parser = FlexibleArgumentParser()
+ parser.add_argument("--model", required=True)
+ args = parser.parse_args()
+ omni = Omni.from_cli_args(args, parser=parser) # preferred
+ omni = Omni.from_cli_args(args, parser=parser, model="other")
+ """
+ kwargs: dict[str, Any] = {**vars(args), **overrides}
+ kwargs["_cli_explicit_keys"] = detect_explicit_cli_keys(sys.argv[1:], parser)
+ return cls(**kwargs)
+
def __init__(
self,
model: str,
@@ -77,16 +122,24 @@ def __init__(
stage_init_timeout = kwargs.pop("stage_init_timeout", 300)
init_timeout = kwargs.pop("init_timeout", 600)
log_stats = kwargs.pop("log_stats", False)
- async_chunk = kwargs.pop("async_chunk", False)
+ # NOTE: read-only lookup — must NOT pop. Popping here drops the key
+ # before it reaches ``StageConfigFactory._create_from_registry``, so
+ # ``--no-async-chunk`` (``async_chunk=False``) silently fails to
+ # override the deploy YAML's ``async_chunk: true`` default.
+ async_chunk = kwargs.get("async_chunk")
output_modalities = kwargs.pop("output_modalities", None)
diffusion_batch_size: int = kwargs.pop("diffusion_batch_size", 1)
if "log_requests" in kwargs:
raise TypeError("`log_requests` has been removed in Omni/AsyncOmni. Use `log_stats`.")
model = omni_snapshot_download(model)
+ self._name = self.__class__.__name__
self.model = model
self.log_stats = log_stats
- self.async_chunk = async_chunk
+ # Provisional value (mirrors the CLI/caller kwarg); the engine resolves
+ # pipeline + deploy YAML + CLI precedence below and the final value is
+ # re-assigned from ``self.engine.async_chunk`` after init.
+ self.async_chunk = bool(async_chunk) if async_chunk is not None else False
self.output_modalities = output_modalities or []
self.tts_batch_max_items: int = kwargs.pop("tts_batch_max_items", 32)
@@ -104,7 +157,11 @@ def __init__(
self._weak_finalizer = weakref.finalize(self, _weak_shutdown_engine, self.engine)
et = time.time()
logger.info("[%s] AsyncOmniEngine initialized in %.2f seconds", self.__class__.__name__, et - st)
- self.async_chunk = bool(self.async_chunk or getattr(self.engine, "async_chunk", False))
+ # Authoritative: ``AsyncOmniEngine`` resolves (pipeline + deploy YAML +
+ # CLI overrides) through ``StageConfigFactory`` and stores the final
+ # value on ``engine.async_chunk``; mirror it here so ``--no-async-chunk``
+ # (explicit ``False``) is not fallen-back-through by ``or``.
+ self.async_chunk = bool(getattr(self.engine, "async_chunk", False))
self.request_states: dict[str, ClientRequestState] = {}
@@ -125,10 +182,18 @@ def __init__(
model,
)
+ # PD disaggregation state (detects if a prefill/decode stage pair is configured)
+ self._init_pd_state()
+
@property
def num_stages(self) -> int:
return self.engine.num_stages
+ @property
+ def stage_configs(self) -> list:
+ """Expose engine stage configs for PD disaggregation detection and validation."""
+ return self.engine.stage_configs
+
@property
def is_running(self) -> bool:
return self.engine.is_alive()
diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py
index 1e45758368d..745b719d5b2 100644
--- a/vllm_omni/entrypoints/openai/api_server.py
+++ b/vllm_omni/entrypoints/openai/api_server.py
@@ -53,7 +53,6 @@
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.orca_metrics import metrics_header
-from vllm.entrypoints.openai.realtime.connection import RealtimeConnection
from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
from vllm.entrypoints.openai.server_utils import get_uvicorn_log_config
@@ -108,6 +107,7 @@
VideoListResponse,
VideoResponse,
)
+from vllm_omni.entrypoints.openai.realtime_connection import RealtimeConnection
from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat
from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech
from vllm_omni.entrypoints.openai.serving_speech_stream import OmniStreamingSpeechHandler
@@ -121,6 +121,7 @@
logger = init_logger(__name__)
router = APIRouter()
+MAX_UINT32_SEED = 2**32 - 1
profiler_router = APIRouter()
@@ -1203,6 +1204,22 @@ async def streaming_speech(websocket: WebSocket):
@router.websocket("/v1/realtime")
async def realtime_websocket(websocket: WebSocket):
"""WebSocket endpoint for OpenAI-style realtime interactions."""
+ engine_client = getattr(websocket.app.state, "engine_client", None)
+ if engine_client is not None and getattr(engine_client, "async_chunk", False):
+ await websocket.accept()
+ await websocket.send_json(
+ {
+ "type": "error",
+ "error": (
+ "The /v1/realtime API is not supported when async_chunk is enabled on the server. "
+ "Use a stage configuration with async_chunk disabled and restart the server before using "
+ "this endpoint."
+ ),
+ "code": "unsupported",
+ }
+ )
+ await websocket.close()
+ return
serving = getattr(websocket.app.state, "openai_serving_realtime", None)
if serving is None:
await websocket.accept()
@@ -1304,14 +1321,64 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
# Get engine client (AsyncOmni) from app state
engine_client, model_name, stage_configs = _get_engine_and_model(raw_request)
- # Validate model field (warn if mismatch, don't error)
if request.model is not None and request.model != model_name:
- logger.warning(
- f"Model mismatch: request specifies '{request.model}' but "
- f"server is running '{model_name}'. Using server model."
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_REQUEST.value,
+ detail=(f"Model mismatch: request specifies '{request.model}' but server is running '{model_name}'."),
)
try:
+ # Unify request construction for any multi-stage pipeline to avoid
+ # divergence between /v1/images and /v1/chat/completions.
+ if len(stage_configs) > 1:
+ chat_handler = getattr(raw_request.app.state, "openai_serving_chat", None)
+ if chat_handler is None:
+ logger.warning("openai_serving_chat is not initialized for multi-stage /v1/images/generations")
+ raise HTTPException(
+ status_code=HTTPStatus.SERVICE_UNAVAILABLE.value,
+ detail="openai_serving_chat is not initialized for multi-stage image generation.",
+ )
+
+ effective_seed = request.seed if request.seed is not None else random.randint(0, MAX_UINT32_SEED)
+ extra_body: dict[str, Any] = {
+ "seed": effective_seed,
+ "num_outputs_per_prompt": request.n,
+ }
+ if request.size is not None:
+ parse_size(request.size)
+ width, height = parse_size(request.size)
+ app_state_args = getattr(raw_request.app.state, "args", None)
+ _check_max_generated_image_size(app_state_args, width, height)
+ extra_body["size"] = request.size
+ if request.negative_prompt is not None:
+ extra_body["negative_prompt"] = request.negative_prompt
+ if request.num_inference_steps is not None:
+ extra_body["num_inference_steps"] = request.num_inference_steps
+ if request.guidance_scale is not None:
+ extra_body["guidance_scale"] = request.guidance_scale
+ if request.true_cfg_scale is not None:
+ extra_body["true_cfg_scale"] = request.true_cfg_scale
+ if request.generator_device is not None:
+ extra_body["generator_device"] = request.generator_device
+ if request.lora is not None:
+ # Keep /images validation semantics: invalid LoRA should fail with 400.
+ _parse_lora_request(request.lora)
+ extra_body["lora"] = request.lora
+
+ generation_result = await chat_handler.generate_diffusion_images(
+ prompt=request.prompt,
+ extra_body=extra_body,
+ request_id=f"img_gen-{random_uuid()}",
+ )
+ if isinstance(generation_result, ErrorResponse):
+ return JSONResponse(
+ status_code=generation_result.error.code if generation_result.error else 400,
+ content=generation_result.model_dump(),
+ )
+ flat_images, _, _ = generation_result
+ image_data = [ImageData(b64_json=encode_image_base64(img), revised_prompt=None) for img in flat_images]
+ return ImageGenerationResponse(created=int(time.time()), data=image_data)
+
# Build params - pass through user values directly
prompt: OmniTextPrompt = {"prompt": request.prompt}
if request.negative_prompt is not None:
@@ -1352,7 +1419,7 @@ async def generate_images(request: ImageGenerationRequest, raw_request: Request)
# This fixes issues where using the default global generator
# might produce blurry images in some environments.
_update_if_not_none(
- gen_params, "seed", request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
+ gen_params, "seed", request.seed if request.seed is not None else random.randint(0, MAX_UINT32_SEED)
)
_update_if_not_none(gen_params, "generator_device", request.generator_device)
_update_if_not_none(gen_params, "layers", request.layers)
@@ -1449,8 +1516,9 @@ async def edit_images(
# 1. get engine and model
engine_client, model_name, stage_configs = _get_engine_and_model(raw_request)
if model is not None and model != model_name:
- logger.warning(
- f"Model mismatch: request specifies '{model}' but server is running '{model_name}'. Using server model."
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_REQUEST.value,
+ detail=(f"Model mismatch: request specifies '{model}' but server is running '{model_name}'."),
)
# 2. get output format & compression
output_format = _choose_output_format(output_format, background)
@@ -1473,12 +1541,24 @@ async def edit_images(
input_images_list.extend(urls)
if not input_images_list:
raise HTTPException(status_code=422, detail="Field 'image' or 'url' is required")
- pil_images = await _load_input_images(input_images_list)
- if len(pil_images) > 1 and not _supports_multimodal_image_inputs(raw_request, engine_client):
+ # Reject oversized multi-image edit requests before fetching or decoding
+ # any inputs. This keeps over-limit URL requests from burning network,
+ # CPU, and memory on work that will be rejected anyway.
+ max_input_images = _get_max_edit_input_images(raw_request, engine_client)
+ if max_input_images is not None and len(input_images_list) > max_input_images:
+ detail = (
+ "Received multiple input images. Only a single image is supported by this model."
+ if max_input_images == 1
+ else (
+ f"Received {len(input_images_list)} input images. "
+ f"At most {max_input_images} images are supported by this model."
+ )
+ )
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
- detail="Received multiple input images. Only a single image is supported by this model.",
+ detail=detail,
)
+ pil_images = await _load_input_images(input_images_list)
prompt["multi_modal_data"] = {}
prompt["multi_modal_data"]["image"] = pil_images
@@ -1561,7 +1641,7 @@ async def edit_images(
# a proper generator is initialized in the backend.
# This fixes issues where using the default global generator
# might produce blurry images in some environments.
- _update_if_not_none(gen_params, "seed", seed if seed is not None else random.randint(0, 2**32 - 1))
+ _update_if_not_none(gen_params, "seed", seed if seed is not None else random.randint(0, MAX_UINT32_SEED))
_update_if_not_none(gen_params, "generator_device", generator_device)
_update_if_not_none(gen_params, "layers", layers)
_update_if_not_none(gen_params, "resolution", resolution)
@@ -1651,18 +1731,25 @@ def _get_engine_and_model(raw_request: Request):
return engine_client, model_name, normalized_stage_configs
-def _supports_multimodal_image_inputs(raw_request: Request, engine_client: Any) -> bool:
+def _get_diffusion_od_config(raw_request: Request, engine_client: Any) -> Any:
diffusion_engine = getattr(raw_request.app.state, "diffusion_engine", None) or engine_client
get_diffusion_od_config = getattr(diffusion_engine, "get_diffusion_od_config", None)
- od_config = (
+ return (
get_diffusion_od_config() if callable(get_diffusion_od_config) else getattr(diffusion_engine, "od_config", None)
)
+
+def _get_max_edit_input_images(raw_request: Request, engine_client: Any) -> int | None:
+ od_config = _get_diffusion_od_config(raw_request, engine_client)
if od_config is None:
# Preserve the existing compatibility behavior when the diffusion
# config is not exposed on the serving surface.
- return True
- return bool(getattr(od_config, "supports_multimodal_inputs", False))
+ return None
+
+ if not bool(getattr(od_config, "supports_multimodal_inputs", False)):
+ return 1
+
+ return getattr(od_config, "max_multimodal_image_inputs", None)
def _get_lora_from_json_str(lora_body):
@@ -2147,10 +2234,12 @@ async def _parse_video_form(
app_model_name, app_stage_configs = _resolve_video_runtime_context(raw_request)
effective_model_name = handler.model_name or app_model_name or request.model or "unknown"
if request.model is not None and effective_model_name is not None and request.model != effective_model_name:
- logger.warning(
- "Model mismatch: request specifies '%s' but server is running '%s'. Using server model.",
- request.model,
- effective_model_name,
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_REQUEST.value,
+ detail=(
+ f"Model mismatch: request specifies '{request.model}' but server is running "
+ f"'{effective_model_name}'."
+ ),
)
handler.set_stage_configs_if_missing(app_stage_configs)
except HTTPException:
diff --git a/vllm_omni/entrypoints/openai/realtime_connection.py b/vllm_omni/entrypoints/openai/realtime_connection.py
new file mode 100644
index 00000000000..1d5470f569c
--- /dev/null
+++ b/vllm_omni/entrypoints/openai/realtime_connection.py
@@ -0,0 +1,193 @@
+from __future__ import annotations
+
+import asyncio
+import base64
+import json
+from collections.abc import AsyncGenerator
+from uuid import uuid4
+
+import numpy as np
+from vllm.entrypoints.openai.engine.protocol import UsageInfo
+from vllm.entrypoints.openai.realtime.connection import RealtimeConnection as VllmRealtimeConnection
+from vllm.entrypoints.openai.realtime.protocol import TranscriptionDelta, TranscriptionDone
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class RealtimeConnection(VllmRealtimeConnection):
+ """Omni realtime connection with audio-only server events.
+
+ Reuses upstream vLLM websocket/session lifecycle and only customizes
+ generation output handling to emit audio deltas.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # Last audio buffer seen for this realtime generation (cumulative or concatenation
+ # of increments); used to turn server cumulative PCM into true deltas.
+ self._realtime_audio_ref: np.ndarray | None = None
+
+ async def start_generation(self):
+ await super().start_generation()
+
+ @staticmethod
+ def _tensor_to_numpy(value) -> np.ndarray | None:
+ if value is None:
+ return None
+ if isinstance(value, np.ndarray):
+ arr = value
+ elif hasattr(value, "detach"):
+ arr = value.detach().float().cpu().numpy()
+ else:
+ try:
+ arr = np.asarray(value)
+ except Exception:
+ return None
+ if arr.ndim > 1:
+ arr = arr.reshape(-1)
+ return arr.astype(np.float32, copy=False)
+
+ @staticmethod
+ def _numpy_audio_prefix_match(prev: np.ndarray, curr: np.ndarray) -> bool:
+ n = prev.shape[0]
+ if n == 0:
+ return True
+ if curr.shape[0] < n:
+ return False
+ return bool(np.allclose(curr[:n], prev, rtol=1e-3, atol=2e-4))
+
+ def _raw_waveform_to_deltas(self, arr: np.ndarray) -> list[np.ndarray]:
+ """Convert one streaming PCM f32 chunk into incremental piece(s) for the client.
+
+ Some engine paths emit a growing cumulative waveform each step; others emit
+ true per-step deltas. We support both without duplicating audio on the client.
+ """
+ if arr.size == 0:
+ return []
+ ref = self._realtime_audio_ref
+ if ref is None:
+ self._realtime_audio_ref = arr.copy()
+ return [arr]
+ if self._numpy_audio_prefix_match(ref, arr):
+ delta = arr[ref.shape[0] :]
+ self._realtime_audio_ref = arr.copy()
+ return [delta] if delta.size > 0 else []
+ # True per-step delta (not a prefix extension of what we have seen).
+ self._realtime_audio_ref = np.concatenate([ref, arr])
+ return [arr]
+
+ def _extract_audio_chunks(self, output) -> tuple[list[np.ndarray], int]:
+ mm = getattr(output, "multimodal_output", None)
+ if not isinstance(mm, dict):
+ return [], 24000
+
+ sr = mm.get("sr") or mm.get("sample_rate") or mm.get("audio_sample_rate") or 24000
+ key = "audio" if "audio" in mm else ("model_outputs" if "model_outputs" in mm else None)
+ if key is None:
+ return [], int(sr)
+
+ raw_audio = mm.get(key)
+ chunks: list[np.ndarray] = []
+ if isinstance(raw_audio, (list, tuple)):
+ if len(raw_audio) > 0:
+ arr = self._tensor_to_numpy(raw_audio[-1])
+ if arr is not None and arr.size > 0:
+ chunks.extend(self._raw_waveform_to_deltas(arr))
+ else:
+ arr = self._tensor_to_numpy(raw_audio)
+ if arr is not None and arr.size > 0:
+ chunks.extend(self._raw_waveform_to_deltas(arr))
+ return chunks, int(sr)
+
+ @staticmethod
+ def _pcm16_b64(audio_f32: np.ndarray) -> str:
+ clipped = np.clip(audio_f32, -1.0, 1.0)
+ pcm16 = (clipped * 32767.0).astype(np.int16)
+ return base64.b64encode(pcm16.tobytes()).decode("utf-8")
+
+ async def _run_generation(
+ self,
+ streaming_input_gen: AsyncGenerator,
+ input_stream: asyncio.Queue[list[int]],
+ ):
+ request_id = f"rt-{self.connection_id}-{uuid4()}"
+ sent_audio = False
+ audio_done_sent = False
+ full_text = ""
+ sent_text_len = 0
+ prompt_token_ids_len = 0
+ completion_tokens_len = 0
+ self._realtime_audio_ref = None
+
+ try:
+ result_gen = self.serving.engine_client.generate(
+ prompt=streaming_input_gen,
+ request_id=request_id,
+ )
+
+ async for output in result_gen:
+ if output.outputs and len(output.outputs) > 0:
+ output0 = output.outputs[0]
+ token_ids = list(output0.token_ids)
+ if token_ids:
+ input_stream.put_nowait(token_ids)
+ # token_ids are cumulative per request
+ completion_tokens_len = len(token_ids)
+ if not prompt_token_ids_len and output.prompt_token_ids:
+ prompt_token_ids_len = len(output.prompt_token_ids)
+ cumulative_text = output0.text or ""
+ if cumulative_text:
+ if len(cumulative_text) >= sent_text_len:
+ delta_text = cumulative_text[sent_text_len:]
+ else:
+ delta_text = cumulative_text
+ sent_text_len = len(cumulative_text)
+ full_text = cumulative_text
+ else:
+ delta_text = ""
+
+ if delta_text:
+ await self.send(TranscriptionDelta(delta=delta_text))
+
+ audio_chunks, sample_rate = self._extract_audio_chunks(output)
+
+ for chunk in audio_chunks:
+ sent_audio = True
+ await self.send_json(
+ {
+ "type": "response.audio.delta",
+ "audio": self._pcm16_b64(chunk),
+ "format": "pcm16",
+ "sample_rate_hz": sample_rate,
+ }
+ )
+
+ if not self._is_connected:
+ break
+
+ usage = UsageInfo(
+ prompt_tokens=prompt_token_ids_len,
+ completion_tokens=completion_tokens_len,
+ total_tokens=prompt_token_ids_len + completion_tokens_len,
+ )
+ await self.send(TranscriptionDone(text=full_text, usage=usage))
+
+ if sent_audio:
+ await self.send_json({"type": "response.audio.done", "has_audio": True})
+ audio_done_sent = True
+ except Exception as e:
+ logger.exception("Error in generation: %s", e)
+ await self.send_error(str(e), "processing_error")
+ finally:
+ # Always send terminal event so clients don't hang forever.
+ if self._is_connected and not audio_done_sent:
+ try:
+ await self.send_json({"type": "response.audio.done", "has_audio": sent_audio})
+ except Exception:
+ logger.exception("Failed to send response.audio.done")
+ while not self.audio_queue.empty():
+ self.audio_queue.get_nowait()
+
+ async def send_json(self, payload: dict):
+ await self.websocket.send_text(json.dumps(payload))
diff --git a/vllm_omni/entrypoints/openai/serving_chat.py b/vllm_omni/entrypoints/openai/serving_chat.py
index 28d6ef277b3..8cddac6a6c5 100644
--- a/vllm_omni/entrypoints/openai/serving_chat.py
+++ b/vllm_omni/entrypoints/openai/serving_chat.py
@@ -733,6 +733,12 @@ def _apply_request_overrides(
return params
+ @staticmethod
+ def _set_if_supported(obj: Any, **kwargs: Any) -> None:
+ for key, value in kwargs.items():
+ if value is not None and hasattr(obj, key):
+ setattr(obj, key, value)
+
def _build_sampling_params_list_from_request(
self,
request: ChatCompletionRequest,
@@ -1579,6 +1585,7 @@ async def chat_completion_full_generator(
role,
reasoning_parser,
)
+ final_res = omni_outputs.request_output
elif omni_outputs.final_output_type == "audio":
choices_data = self._create_audio_choice(omni_outputs, role, request, stream=False)
elif omni_outputs.final_output_type == "image":
@@ -2057,6 +2064,254 @@ def _create_image_choice(
return choices
# ==================== Diffusion Mode Methods ====================
+ def _build_multistage_generation_inputs(
+ self,
+ *,
+ engine: AsyncOmni,
+ prompt: str,
+ extra_body: dict[str, Any],
+ reference_images: list[Image.Image],
+ gen_params: OmniDiffusionSamplingParams,
+ ) -> tuple[OmniTextPrompt, list[Any]]:
+ """Build the shared multistage generation prompt and stage params."""
+ stage_configs = getattr(engine, "stage_configs", None) or []
+ default_params_list = list(getattr(engine, "default_sampling_params_list", []) or [])
+
+ height = gen_params.height
+ width = gen_params.width
+ seed = gen_params.seed
+ generator_device = gen_params.generator_device
+ num_outputs_per_prompt = gen_params.num_outputs_per_prompt
+ num_inference_steps = extra_body.get("num_inference_steps")
+ guidance_scale = extra_body.get("guidance_scale")
+ true_cfg_scale = extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale")
+ negative_prompt = extra_body.get("negative_prompt")
+ num_frames = extra_body.get("num_frames")
+ guidance_scale_2 = extra_body.get("guidance_scale_2")
+ lora_body = extra_body.get("lora")
+ layers = extra_body.get("layers")
+ resolution = extra_body.get("resolution")
+
+ engine_prompt_data: dict[str, Any] | None = None
+ modalities = ["image"]
+ if reference_images:
+ if len(reference_images) == 1:
+ engine_prompt_data = {"img2img": reference_images[0]}
+ modalities = ["img2img"]
+ else:
+ engine_prompt_data = {"image": reference_images}
+
+ engine_prompt: OmniTextPrompt = {"prompt": prompt}
+ engine_prompt["modalities"] = modalities
+ if negative_prompt is not None:
+ engine_prompt["negative_prompt"] = negative_prompt
+
+ mm_processor_kwargs: dict[str, Any] = {}
+ if height is not None:
+ mm_processor_kwargs["target_h"] = height
+ if width is not None:
+ mm_processor_kwargs["target_w"] = width
+ if mm_processor_kwargs:
+ engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
+ if engine_prompt_data is not None:
+ engine_prompt["multi_modal_data"] = engine_prompt_data
+
+ comprehension_idx = None
+ for idx, stage in enumerate(stage_configs):
+ if getattr(stage, "is_comprehension", False):
+ comprehension_idx = idx
+ break
+
+ sampling_params_list: list[Any] = []
+ for idx, stage_cfg in enumerate(stage_configs):
+ stage_type = get_stage_type(stage_cfg)
+ if idx < len(default_params_list):
+ default_stage_params = default_params_list[idx]
+ if hasattr(default_stage_params, "clone"):
+ try:
+ default_stage_params = default_stage_params.clone()
+ except Exception:
+ pass
+ elif stage_type == "diffusion":
+ default_stage_params = gen_params.clone()
+ else:
+ default_stage_params = SamplingParams()
+
+ if (
+ comprehension_idx is not None
+ and idx == comprehension_idx
+ and seed is not None
+ and hasattr(default_stage_params, "seed")
+ ):
+ default_stage_params.seed = seed
+
+ if stage_type == "diffusion":
+ self._set_if_supported(
+ default_stage_params,
+ height=height,
+ width=width,
+ seed=seed,
+ generator_device=generator_device,
+ num_outputs_per_prompt=num_outputs_per_prompt,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ true_cfg_scale=true_cfg_scale,
+ num_frames=num_frames,
+ guidance_scale_2=guidance_scale_2,
+ layers=layers,
+ resolution=resolution,
+ )
+ if lora_body and isinstance(lora_body, dict):
+ try:
+ lora_req, lora_scale = parse_lora_request(lora_body)
+ if lora_req is not None:
+ default_stage_params.lora_request = lora_req
+ if lora_scale is not None:
+ default_stage_params.lora_scale = lora_scale
+ except Exception as e: # pragma: no cover - safeguard
+ logger.warning("Failed to parse LoRA request: %s", e)
+
+ sampling_params_list.append(default_stage_params)
+
+ return engine_prompt, sampling_params_list
+
+ async def generate_diffusion_images(
+ self,
+ *,
+ prompt: str,
+ extra_body: dict[str, Any] | None = None,
+ reference_images: list[str] | None = None,
+ request_id: str | None = None,
+ ) -> tuple[list[Image.Image], dict[str, Any], float] | ErrorResponse:
+ """Generate diffusion images and return raw images plus generation stats."""
+ if request_id is None:
+ request_id = f"chatcmpl-{uuid.uuid4().hex[:16]}"
+ if extra_body is None:
+ extra_body = {}
+ if reference_images is None:
+ reference_images = []
+
+ engine = self._diffusion_engine if self._diffusion_engine is not None else self.engine_client
+
+ height = extra_body.get("height")
+ width = extra_body.get("width")
+ if "size" in extra_body:
+ try:
+ size_str = extra_body["size"]
+ if isinstance(size_str, str) and "x" in size_str.lower():
+ w, h = size_str.lower().split("x")
+ width, height = int(w), int(h)
+ except ValueError:
+ logger.warning("Invalid size format: %s", extra_body.get("size"))
+
+ seed = extra_body.get("seed")
+ generator_device = extra_body.get("generator_device")
+ negative_prompt = extra_body.get("negative_prompt")
+ num_outputs_per_prompt = extra_body.get("num_outputs_per_prompt", 1)
+ lora_body = extra_body.get("lora")
+
+ pil_images: list[Image.Image] = []
+ for img_b64 in reference_images:
+ try:
+ img_bytes = base64.b64decode(img_b64)
+ pil_images.append(Image.open(BytesIO(img_bytes)))
+ except Exception as e:
+ logger.warning("Failed to decode reference image: %s", e)
+
+ gen_params = OmniDiffusionSamplingParams(
+ height=height,
+ width=width,
+ num_outputs_per_prompt=num_outputs_per_prompt,
+ seed=seed,
+ )
+ self._set_if_supported(
+ gen_params,
+ generator_device=generator_device,
+ num_inference_steps=extra_body.get("num_inference_steps"),
+ guidance_scale=extra_body.get("guidance_scale"),
+ true_cfg_scale=extra_body.get("true_cfg_scale") or extra_body.get("cfg_scale"),
+ num_frames=extra_body.get("num_frames"),
+ guidance_scale_2=extra_body.get("guidance_scale_2"),
+ layers=extra_body.get("layers"),
+ resolution=extra_body.get("resolution"),
+ )
+
+ if lora_body and isinstance(lora_body, dict):
+ try:
+ lora_req, lora_scale = parse_lora_request(lora_body)
+ if lora_req is not None:
+ gen_params.lora_request = lora_req
+ if lora_scale is not None:
+ gen_params.lora_scale = lora_scale
+ except Exception as e: # pragma: no cover - safeguard
+ logger.warning("Failed to parse LoRA request: %s", e)
+
+ gen_prompt: OmniTextPrompt = {
+ "prompt": prompt,
+ "negative_prompt": negative_prompt,
+ }
+ if pil_images:
+ if len(pil_images) == 1:
+ gen_prompt["multi_modal_data"] = {"image": pil_images[0]}
+ else:
+ od_config = getattr(engine, "od_config", None)
+ supports_multimodal_inputs = getattr(od_config, "supports_multimodal_inputs", False)
+ if od_config is None:
+ supports_multimodal_inputs = True
+ if supports_multimodal_inputs:
+ gen_prompt["multi_modal_data"] = {"image": pil_images}
+ else:
+ return self._create_error_response(
+ "Multiple input images are not supported by the current diffusion model. "
+ "For multi-image editing, start the server with Qwen-Image-Edit-2509 "
+ "and send multiple images in the user message content.",
+ status_code=400,
+ )
+
+ if isinstance(engine, AsyncOmni):
+ diffusion_engine = cast(AsyncOmni, engine)
+ stage_configs = getattr(diffusion_engine, "stage_configs", None) or []
+ if len(stage_configs) > 1:
+ engine_prompt, sampling_params_list = self._build_multistage_generation_inputs(
+ engine=diffusion_engine,
+ prompt=prompt,
+ extra_body=extra_body,
+ reference_images=pil_images,
+ gen_params=gen_params,
+ )
+ else:
+ engine_prompt = gen_prompt
+ sampling_params_list = [gen_params]
+
+ result = None
+ async for output in diffusion_engine.generate(
+ prompt=engine_prompt,
+ sampling_params_list=sampling_params_list,
+ request_id=request_id,
+ ):
+ result = output
+ if result is None:
+ return self._create_error_response("No output generated from AsyncOmni", status_code=500)
+ else:
+ result = await engine.generate(
+ prompt=gen_prompt,
+ sampling_params=gen_params,
+ request_id=request_id,
+ )
+
+ images = getattr(result.request_output, "images", [])
+ stage_durations = result.stage_durations
+ peak_memory_mb = result.peak_memory_mb
+
+ flat_images: list[Image.Image] = []
+ for item in images:
+ if isinstance(item, list):
+ flat_images.extend(item)
+ else:
+ flat_images.append(item)
+
+ return flat_images, stage_durations, peak_memory_mb
+
async def _create_diffusion_chat_completion(
self,
request: ChatCompletionRequest,
@@ -2234,8 +2489,8 @@ async def _create_diffusion_chat_completion(
if hasattr(default_stage_params, "clone"):
try:
default_stage_params = default_stage_params.clone()
- except Exception:
- pass
+ except Exception as e:
+ logger.warning("Failed to clone default params for stage %d: %s", idx, e)
sampling_params_list.append(default_stage_params)
if not sampling_params_list:
diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py
index 3eaf18111c0..ba8292f0c27 100644
--- a/vllm_omni/entrypoints/openai/serving_speech.py
+++ b/vllm_omni/entrypoints/openai/serving_speech.py
@@ -457,6 +457,25 @@ def _estimate_fish_prompt_len(self, text: str, ref_text: str, ref_audio: object)
logger.warning("Failed to estimate Fish Speech prompt length, using fallback 2048: %s", e)
return 2048
+ async def _build_voxcpm2_prompt(self, request: OpenAICreateSpeechRequest) -> dict[str, Any]:
+ """Build prefill prompt for VoxCPM2 TTS (`prompt_token_ids` padded to full prefill length)."""
+ from vllm_omni.model_executor.models.voxcpm2.voxcpm2_talker import build_voxcpm2_prompt
+
+ self._voxcpm2_encode("") # lazy-init tokenizer + split_map
+ ref_audio = None
+ ref_sr = None
+ if request.ref_audio is not None:
+ ref_audio, ref_sr = await self._resolve_ref_audio(request.ref_audio)
+ return build_voxcpm2_prompt(
+ hf_config=self.engine_client.model_config.hf_config,
+ tokenizer=self._voxcpm2_tokenizer,
+ split_map=self._voxcpm2_split_map,
+ text=request.input,
+ ref_audio=ref_audio,
+ ref_sr=ref_sr,
+ ref_text=request.ref_text,
+ )
+
def _get_uploaded_audio_data(self, voice_name: str) -> str | None:
"""Get base64 encoded audio data for uploaded voice."""
voice_name_lower = voice_name.lower()
@@ -1524,16 +1543,8 @@ async def _prepare_speech_generation(
if request.instructions:
prompt["instruct"] = request.instructions
elif self._tts_model_type == "voxcpm2":
+ prompt = await self._build_voxcpm2_prompt(request)
tts_params = {}
- additional: dict[str, Any] = {}
- if request.ref_audio is not None:
- wav_list, sr = await self._resolve_ref_audio(request.ref_audio)
- additional["reference_audio"] = [[wav_list, sr]]
- # Pre-split multichar Chinese tokens (VoxCPM2 was trained with single-char CJK IDs).
- token_ids = self._voxcpm2_encode(request.input)
- prompt: dict[str, Any] = {"prompt_token_ids": token_ids}
- if additional:
- prompt["additional_information"] = additional
elif self._is_tts:
validation_error = self._validate_tts_request(request)
if validation_error:
diff --git a/vllm_omni/entrypoints/pd_utils.py b/vllm_omni/entrypoints/pd_utils.py
index 0e3d65f5537..413d5d6b448 100644
--- a/vllm_omni/entrypoints/pd_utils.py
+++ b/vllm_omni/entrypoints/pd_utils.py
@@ -23,9 +23,19 @@
class PDDisaggregationMixin:
"""Mixin supplying PD disaggregation helpers to OmniBase."""
+ def _get_pd_separation_pair(self) -> tuple[int, int] | None:
+ """PD prefill/decode indices when ``_init_pd_state`` ran; else ``None``.
+
+ Partial test doubles may skip ``OmniBase.__init__``; treat missing state as
+ no PD disaggregation instead of raising ``AttributeError``.
+ """
+ return getattr(self, "_pd_separation_pair", None)
+
def _init_pd_state(self) -> None:
"""Initialise PD disaggregation state."""
- self._pd_separation_pair: tuple[int, int] | None = self._detect_pd_separation()
+ self._pd_separation_pair: tuple[int, int] | None = self.detect_pd_separation_from_stage_configs(
+ self.stage_configs
+ )
self._pd_connector_info: dict[str, Any] | None = None
self._pd_kv_params_by_req: dict[str, dict[str, Any]] = {}
self._pd_kv_params_lock = threading.Lock()
@@ -40,11 +50,19 @@ def _init_pd_state(self) -> None:
d_id,
)
- def _detect_pd_separation(self) -> tuple[int, int] | None:
- """Scan stage_list for a prefill/decode pair. Returns (p_id, d_id) or None."""
+ @staticmethod
+ def detect_pd_separation_from_stage_configs(stage_configs: list[Any]) -> tuple[int, int] | None:
+ """Scan stage configs for a prefill/decode pair.
+
+ Returns:
+ (prefill_idx, decode_idx) if one pair exists, None if not found.
+
+ Raises:
+ ValueError: if multiple candidate PD pairs are found.
+ """
prefill_by_id: dict[int, int] = {}
decode_indices: list[int] = []
- for i, stage in enumerate(self.stage_list):
+ for i, stage in enumerate(stage_configs):
if getattr(stage, "is_prefill_only", False):
prefill_by_id[i] = i
sid = getattr(stage, "stage_id", i)
@@ -55,7 +73,7 @@ def _detect_pd_separation(self) -> tuple[int, int] | None:
pd_pairs: list[tuple[int, int]] = []
for j in decode_indices:
- source_ids = getattr(self.stage_list[j], "engine_input_source", [])
+ source_ids = getattr(stage_configs[j], "engine_input_source", [])
for src in source_ids:
if src in prefill_by_id:
pd_pairs.append((prefill_by_id[src], j))
@@ -107,10 +125,11 @@ def _normalize_kv_transfer_params(self, kv_params: Any) -> dict[str, Any] | None
def _validate_pd_separation_config(self) -> None:
"""Validate PD stage configurations are consistent."""
- assert self._pd_separation_pair is not None
- p_id, d_id = self._pd_separation_pair
- p_stage = self.stage_list[p_id]
- d_stage = self.stage_list[d_id]
+ pair = self._get_pd_separation_pair()
+ assert pair is not None
+ p_id, d_id = pair
+ p_stage = self.stage_configs[p_id]
+ d_stage = self.stage_configs[d_id]
def _get_kv_cfg(stage: "OmniStage") -> dict[str, Any]:
ea = stage.engine_args
@@ -158,11 +177,12 @@ def _get_kv_cfg(stage: "OmniStage") -> dict[str, Any]:
def _get_pd_connector_info(self) -> dict[str, Any] | None:
"""Extract prefill engine KV connector info."""
- if self._pd_separation_pair is None:
+ pair = self._get_pd_separation_pair()
+ if pair is None:
return None
- p_id, _ = self._pd_separation_pair
- p_stage = self.stage_list[p_id]
+ p_id, _ = pair
+ p_stage = self.stage_configs[p_id]
ea = p_stage.engine_args
kv_cfg = getattr(ea, "kv_transfer_config", None)
@@ -241,18 +261,17 @@ def _extract_kv_transfer_params(self, engine_outputs: Any) -> dict[str, Any] | N
def _is_pd_routing(self, stage_id: int, next_stage_id: int) -> bool:
"""True when edge stage_id → next_stage_id is the prefill→decode boundary."""
- return self._pd_separation_pair is not None and self._pd_separation_pair == (
- stage_id,
- next_stage_id,
- )
+ pair = self._get_pd_separation_pair()
+ return pair is not None and pair == (stage_id, next_stage_id)
def _maybe_expand_sampling_params(self, sampling_params_list: list) -> list:
"""Auto-duplicate thinker SP for decode stage when user provides N-1 params."""
- if self._pd_separation_pair is None:
+ pair = self._get_pd_separation_pair()
+ if pair is None:
return sampling_params_list
- if len(sampling_params_list) != len(self.stage_list) - 1:
+ if len(sampling_params_list) != len(self.stage_configs) - 1:
return sampling_params_list
- p_id, d_id = self._pd_separation_pair
+ p_id, d_id = pair
sp_list = list(sampling_params_list)
sp_list.insert(d_id, sp_list[p_id])
return sp_list
diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py
index 84391c2ea8d..5757d389900 100644
--- a/vllm_omni/entrypoints/utils.py
+++ b/vllm_omni/entrypoints/utils.py
@@ -1,3 +1,4 @@
+import argparse
import os
import types
from collections import Counter
@@ -5,10 +6,12 @@
from pathlib import Path
from typing import Any, get_args, get_origin
+import yaml
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config, get_hf_file_to_dict
from vllm.transformers_utils.repo_utils import file_or_path_exists
+from vllm_omni.config.stage_config import StageConfigFactory
from vllm_omni.config.yaml_util import create_config, load_yaml_config, merge_configs
from vllm_omni.entrypoints.stage_utils import _to_dict
from vllm_omni.platforms import current_omni_platform
@@ -23,6 +26,65 @@
}
+def detect_explicit_cli_keys(
+ argv: list[str],
+ parser: argparse.ArgumentParser | None = None,
+) -> set[str]:
+ """Walk ``argv`` and return the set of ``dest`` attribute names the user
+ explicitly provided (e.g. ``--max-num-seqs 64`` → ``max_num_seqs``).
+
+ Used to distinguish user-typed CLI args from argparse default values so
+ deploy YAMLs are not silently overridden by parser defaults. Shared
+ across online (``vllm serve``) and offline (scripts, examples, tests,
+ CI) entry points — offline callers that parse CLI args via argparse
+ should invoke this on ``sys.argv[1:]`` and pass the result through to
+ ``AsyncOmni`` / ``Omni`` via the ``_cli_explicit_keys`` kwarg.
+
+ When ``parser`` is provided, each token is looked up in the parser's
+ action table to find its real ``dest``. This correctly handles flags
+ with ``dest=`` overrides, alias flags (e.g. ``--usp`` /
+ ``--ulysses-degree`` both mapping to ``ulysses_degree``), and
+ ``--disable-foo`` / ``store_false`` patterns that map to a differently
+ named dest. Callers with access to an ``argparse.ArgumentParser`` should
+ always pass it.
+
+ When ``parser`` is ``None``, a name-based heuristic is used as a
+ fallback (hyphens → underscores, plus a ``no_`` prefix strip for
+ ``argparse.BooleanOptionalAction``). This is correct for simple flags
+ but silently misidentifies ``--disable-X``-style flags and explicit
+ ``dest=`` overrides, so prefer the parser-aware form.
+ """
+ if parser is not None:
+ dest_map: dict[str, str] = {}
+ for action in parser._actions:
+ for opt in action.option_strings:
+ dest_map[opt] = action.dest
+ explicit: set[str] = set()
+ for tok in argv:
+ if not tok.startswith("--"):
+ continue
+ flag = tok.split("=", 1)[0]
+ dest = dest_map.get(flag)
+ if dest is not None:
+ explicit.add(dest)
+ return explicit
+
+ # Fallback: name-based heuristic (legacy path for callers without a parser).
+ explicit = set()
+ for tok in argv:
+ if not tok.startswith("--"):
+ continue
+ name = tok[2:].split("=", 1)[0]
+ if not name:
+ continue
+ attr = name.replace("-", "_")
+ explicit.add(attr)
+ # BooleanOptionalAction: --no-foo records as dest `foo`, not `no_foo`.
+ if attr.startswith("no_"):
+ explicit.add(attr[3:])
+ return explicit
+
+
def inject_omni_kv_config(stage: Any, omni_conn_cfg: dict[str, Any], omni_from: str, omni_to: str) -> None:
"""Inject connector configuration into stage engine arguments."""
# Prepare omni_kv_config dict
@@ -273,29 +335,59 @@ def resolve_model_config_path(model: str) -> str:
return str(stage_config_path)
-def load_stage_configs_from_model(model: str, base_engine_args: dict | None = None) -> list:
+def load_stage_configs_from_model(
+ model: str,
+ base_engine_args: dict | None = None,
+ deploy_config_path: str | None = None,
+ stage_overrides: dict[str, dict[str, Any]] | None = None,
+ cli_explicit_keys: set[str] | None = None,
+) -> list:
"""Load stage configurations from model's default config file.
- .. deprecated::
- This is the legacy OmegaConf-based loading path. New code should use
- ``StageConfigFactory.create_from_model()`` instead. This function will
- be removed once all callers are migrated (see PR series [2/N]).
+ For models registered in the pipeline registry (new path), uses
+ ``StageConfigFactory.create_from_model()`` which merges
+ PipelineConfig + DeployConfig + CLI overrides.
- Loads stage configurations based on the model type and device type.
- First tries to load a device-specific YAML file from stage_configs/{device_type}/
- directory. If not found, falls back to the default config file.
+ For other models (legacy path), loads stage configs from YAML.
Args:
model: Model name or path (used to determine model_type)
+ base_engine_args: Base engine args to merge as CLI overrides.
+ deploy_config_path: Optional explicit deploy config path.
+ stage_overrides: Per-stage overrides from --stage-overrides.
+ cli_explicit_keys: Set of CLI keys the user actually typed. When
+ provided, only these keys override deploy YAML; argparse defaults
+ stay subordinate to YAML. ``None`` means treat every kwarg as
+ explicit (programmatic ``Omni()`` calls).
Returns:
List of stage configuration dictionaries
-
- Raises:
- FileNotFoundError: If no stage config file exists for the model type
"""
if base_engine_args is None:
base_engine_args = {}
+
+ cli_overrides = _convert_dataclasses_to_dict(dict(base_engine_args))
+ # Per-stage JSON overrides are always explicit (the user typed --stage-overrides).
+ explicit = set(cli_explicit_keys) if cli_explicit_keys is not None else None
+ if stage_overrides:
+ for stage_id_str, overrides in stage_overrides.items():
+ for key, val in overrides.items():
+ stage_key = f"stage_{stage_id_str}_{key}"
+ cli_overrides[stage_key] = val
+ if explicit is not None:
+ explicit.add(stage_key)
+
+ stages = StageConfigFactory.create_from_model(
+ model,
+ cli_overrides=cli_overrides,
+ deploy_config_path=deploy_config_path,
+ cli_explicit_keys=explicit,
+ )
+ if stages is not None:
+ # Convert StageConfig objects to OmegaConf for backward compat
+ return [stage.to_omegaconf() for stage in stages]
+
+ # Legacy fallback: load from YAML
stage_config_path = resolve_model_config_path(model)
if stage_config_path is None:
return []
@@ -312,10 +404,9 @@ def load_stage_configs_from_yaml(
base_engine_args: dict | None = None,
prefer_stage_engine_args: bool = True,
) -> list:
- """Load stage configurations from a YAML file.
+ """Load stage configurations from a YAML file (legacy OmegaConf path).
- .. deprecated::
- Legacy OmegaConf-based loader. Will be removed in PR series [2/N].
+ TODO(@lishunyang12): remove once all models use PipelineConfig + DeployConfig.
Args:
config_path: Path to the YAML configuration file
@@ -449,22 +540,75 @@ def load_and_resolve_stage_configs(
stage_configs_path: str | None,
kwargs: dict | None,
default_stage_cfg_factory: Any = None,
+ deploy_config_path: str | None = None,
+ stage_overrides: dict[str, dict[str, Any]] | None = None,
+ cli_explicit_keys: set[str] | None = None,
) -> tuple[str, list]:
"""Load stage configurations from model or YAML file with fallback to defaults.
Args:
model: Model name or path
- stage_configs_path: Optional path to YAML file containing stage configurations
+ stage_configs_path: Optional path to legacy YAML (stage_args format)
kwargs: Engine arguments to merge with stage configs
default_stage_cfg_factory: Optional callable that takes no args and returns
default stage config list when no configs are found
+ deploy_config_path: Optional path to deploy YAML (new format).
+ Mutually exclusive with ``stage_configs_path``.
+ stage_overrides: Per-stage overrides from ``--stage-overrides`` JSON.
+ Keys are stage_id strings, values are dicts of overrides.
Returns:
Tuple of (config_path, stage_configs)
"""
- if stage_configs_path is None:
+ if stage_configs_path is not None and deploy_config_path is not None:
+ raise ValueError(
+ "--stage-configs-path and --deploy-config are mutually exclusive: "
+ "they use different path resolution rules and loading paths. "
+ "Use --deploy-config for new-format YAMLs (preferred); "
+ "--stage-configs-path is kept only for the legacy `stage_args` format "
+ "and will be removed in a future release."
+ )
+ if stage_configs_path is not None and deploy_config_path is None:
+ if not os.path.exists(stage_configs_path):
+ raise FileNotFoundError(
+ f"--stage-configs-path {stage_configs_path!r} does not exist. "
+ "Legacy `stage_configs/` yamls were replaced by `vllm_omni/deploy/.yaml`; "
+ "use --deploy-config. See docs/configuration/stage_configs.md."
+ )
+ with open(stage_configs_path, encoding="utf-8") as f:
+ _peek = yaml.safe_load(f) or {}
+ if "stages" in _peek and "stage_args" not in _peek:
+ deploy_config_path = stage_configs_path
+ stage_configs_path = None
+ else:
+ logger.warning(
+ "--stage-configs-path is deprecated; migrate %r and use --deploy-config.",
+ stage_configs_path,
+ )
+
+ if deploy_config_path is not None:
+ config_path = deploy_config_path
+ stage_configs = load_stage_configs_from_model(
+ model,
+ base_engine_args=kwargs,
+ deploy_config_path=deploy_config_path,
+ stage_overrides=stage_overrides,
+ cli_explicit_keys=cli_explicit_keys,
+ )
+ if not stage_configs:
+ if default_stage_cfg_factory is not None:
+ default_stage_cfg = default_stage_cfg_factory()
+ stage_configs = create_config(default_stage_cfg)
+ else:
+ stage_configs = []
+ elif stage_configs_path is None:
config_path = resolve_model_config_path(model)
- stage_configs = load_stage_configs_from_model(model, base_engine_args=kwargs)
+ stage_configs = load_stage_configs_from_model(
+ model,
+ base_engine_args=kwargs,
+ stage_overrides=stage_overrides,
+ cli_explicit_keys=cli_explicit_keys,
+ )
if not stage_configs:
if default_stage_cfg_factory is not None:
default_stage_cfg = default_stage_cfg_factory()
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/__init__.py b/vllm_omni/model_executor/models/ming_flash_omni/__init__.py
new file mode 100644
index 00000000000..d7fa44fd7e4
--- /dev/null
+++ b/vllm_omni/model_executor/models/ming_flash_omni/__init__.py
@@ -0,0 +1,18 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+
+from .ming_flash_omni import MingFlashOmniForConditionalGeneration
+from .ming_flash_omni_thinker import (
+ MingFlashOmniThinkerDummyInputsBuilder,
+ MingFlashOmniThinkerForConditionalGeneration,
+ MingFlashOmniThinkerMultiModalProcessor,
+ MingFlashOmniThinkerProcessingInfo,
+)
+
+__all__ = [
+ "MingFlashOmniForConditionalGeneration",
+ "MingFlashOmniThinkerForConditionalGeneration",
+ "MingFlashOmniThinkerProcessingInfo",
+ "MingFlashOmniThinkerMultiModalProcessor",
+ "MingFlashOmniThinkerDummyInputsBuilder",
+]
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py b/vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py
new file mode 100644
index 00000000000..6ca19901141
--- /dev/null
+++ b/vllm_omni/model_executor/models/ming_flash_omni/audio_encoder.py
@@ -0,0 +1,246 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright 2024 ANT Group and the HuggingFace Inc. team.
+# Copyright (c) 2022 OpenAI
+# Adapted from Ming repository modeling_whisper_encoder.py
+# https://github.com/inclusionAI/Ming
+
+import operator
+from collections.abc import Iterable
+from itertools import accumulate
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from vllm_omni.diffusion.attention.backends.utils.fa import HAS_FLASH_ATTN, flash_attn_varlen_func
+from vllm_omni.model_executor.models.whisper_utils import Conv1d, Linear, sinusoids
+
+logger = init_logger(__name__)
+
+
+class MultiHeadAttention(nn.Module):
+ """Multi-head attention with packed sequence support.
+ Adapted from Qwen3-TTS WhisperEncoder.
+ """
+
+ def __init__(self, n_state: int, n_head: int, use_flash_attn: bool = True):
+ super().__init__()
+ self.n_head = n_head
+ self.query = Linear(n_state, n_state)
+ self.key = Linear(n_state, n_state, bias=False)
+ self.value = Linear(n_state, n_state)
+ self.out = Linear(n_state, n_state)
+
+ if use_flash_attn and not HAS_FLASH_ATTN:
+ logger.warning("flash-attn is not available. Fallback to manual PyTorch version")
+ self.use_flash_attn = use_flash_attn and HAS_FLASH_ATTN
+
+ def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
+ """Forward pass with packed sequence support.
+
+ Args:
+ x: [total_tokens, n_state] packed sequence
+ cu_seqlens: [num_seqs + 1] cumulative sequence lengths, e.g. [0, len1, len1+len2, ...]
+
+ Returns:
+ [total_tokens, n_state] attention output
+ """
+ q = self.query(x)
+ k = self.key(x)
+ v = self.value(x)
+
+ n_ctx, n_state = q.shape
+ head_dim = n_state // self.n_head
+
+ q = q.view(n_ctx, self.n_head, head_dim)
+ k = k.view(n_ctx, self.n_head, head_dim)
+ v = v.view(n_ctx, self.n_head, head_dim)
+
+ # Try flash attention varlen
+ if self.use_flash_attn and cu_seqlens is not None and q.dtype in [torch.float16, torch.bfloat16]:
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen)
+ else:
+ attn_output = self._manual_attention(q, k, v, cu_seqlens)
+
+ # Reshape back: [T, H, D] -> [T, H*D]
+ attn_output = attn_output.contiguous().view(n_ctx, n_state)
+ return self.out(attn_output)
+
+ def _manual_attention(
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens: torch.Tensor
+ ) -> torch.Tensor:
+ """Manual attention for variable-length sequences (fallback)."""
+ _, n_head, head_dim = q.shape
+ scale = head_dim**-0.5
+
+ # Unpack sequences and pad to max length
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ batch_size = len(seqlens)
+ max_seqlen = max(seqlens)
+
+ # Create padded tensors
+ q_padded = torch.zeros(batch_size, max_seqlen, n_head, head_dim, dtype=q.dtype, device=q.device)
+ k_padded = torch.zeros_like(q_padded)
+ v_padded = torch.zeros_like(q_padded)
+
+ # Fill with actual sequences
+ for i in range(batch_size):
+ start_idx = cu_seqlens[i]
+ end_idx = cu_seqlens[i + 1]
+ seq_len = seqlens[i]
+ q_padded[i, :seq_len] = q[start_idx:end_idx]
+ k_padded[i, :seq_len] = k[start_idx:end_idx]
+ v_padded[i, :seq_len] = v[start_idx:end_idx]
+
+ # Transpose for attention: [B, H, T, D]
+ q_padded = q_padded.transpose(1, 2)
+ k_padded = k_padded.transpose(1, 2)
+ v_padded = v_padded.transpose(1, 2)
+
+ # Create attention mask for variable lengths: 0 for valid positions, -inf for padding
+ padding_mask = (
+ torch.arange(max_seqlen, device=q.device)[None, :] >= torch.tensor(seqlens, device=q.device)[:, None]
+ )
+ attn_mask = torch.zeros(batch_size, 1, 1, max_seqlen, dtype=q.dtype, device=q.device)
+ attn_mask = attn_mask.masked_fill(padding_mask.unsqueeze(1).unsqueeze(2), -torch.finfo(q.dtype).max)
+
+ # Compute attention
+ attn_scores = torch.matmul(q_padded, k_padded.transpose(-2, -1)) * scale
+ attn_scores = attn_scores + attn_mask
+ attn_weights = F.softmax(attn_scores, dim=-1)
+ context = torch.matmul(attn_weights, v_padded)
+
+ # Transpose back: [B, H, T, D] -> [B, T, H, D]
+ context = context.transpose(1, 2).contiguous()
+ output_packed = torch.cat([context[i, : seqlens[i]] for i in range(batch_size)], dim=0)
+
+ return output_packed
+
+
+class ResidualAttentionBlock(nn.Module):
+ """Whisper-style residual attention block with packed sequence support.
+
+ Adapted from
+ https://github.com/openai/whisper/blob/v20250625/whisper/model.py
+ vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py
+ """
+
+ def __init__(self, n_state: int, n_head: int, use_flash_attn: bool = True):
+ super().__init__()
+ self.attn = MultiHeadAttention(n_state, n_head, use_flash_attn=use_flash_attn)
+ self.attn_ln = nn.LayerNorm(n_state)
+
+ n_mlp = n_state * 4
+ self.mlp = nn.Sequential(
+ Linear(n_state, n_mlp),
+ nn.GELU(),
+ Linear(n_mlp, n_state),
+ )
+ self.mlp_ln = nn.LayerNorm(n_state)
+
+ def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
+ x = x + self.attn(self.attn_ln(x), cu_seqlens=cu_seqlens)
+ x = x + self.mlp(self.mlp_ln(x))
+ return x
+
+
+class WhisperAudioEncoder(nn.Module):
+ """Whisper audio encoder for Ming with packed sequence support.
+
+ Adapted from
+ https://github.com/openai/whisper/blob/v20250625/whisper/model.py
+ vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py
+ """
+
+ def __init__(
+ self,
+ n_mels: int = 128,
+ n_ctx: int = 15000,
+ n_state: int = 1280,
+ n_head: int = 20,
+ n_layer: int = 32,
+ use_flash_attn: bool = True,
+ ):
+ super().__init__()
+ self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
+ self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
+ # self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
+ self.blocks = nn.ModuleList(
+ [ResidualAttentionBlock(n_state, n_head, use_flash_attn=use_flash_attn) for _ in range(n_layer)]
+ )
+ self.ln_post = nn.LayerNorm(n_state)
+ self.audio_emb_dim = n_state
+
+ self.n_layer = n_layer
+ self.n_mels = n_mels
+ self.use_flash_attn = use_flash_attn
+
+ def forward(
+ self,
+ x_list: list[torch.Tensor],
+ audio_lens: list[int],
+ ) -> torch.Tensor:
+ """Forward pass with packed sequence format for variable-length inputs.
+
+ Args:
+ x_list: List of [n_mels, T_i] mel spectrogram features for each audio
+ audio_lens: List of original audio lengths in frames
+
+ Returns:
+ [total_T', n_state] packed encoded audio features, where
+ total_T' is the sum of all encoded sequence lengths
+ """
+ # Cast inputs to model dtype
+ target_dtype = self.conv1.weight.dtype
+ x_list = [x.to(target_dtype) for x in x_list]
+
+ encoded_list = []
+ encoded_lens = []
+ for mel_spec in x_list:
+ # mel_spec: [n_mels, T] - process through conv layers
+ x = mel_spec.unsqueeze(0) # [1, n_mels, T]
+ x = F.gelu(self.conv1(x))
+ x = F.gelu(self.conv2(x))
+ x = x.squeeze(0).transpose(0, 1) # [T', n_state]
+
+ # Add positional embedding
+ seq_len = x.shape[0]
+ positional_embedding = self.positional_embedding[:seq_len, :]
+ x = (x + positional_embedding).to(x.dtype)
+
+ encoded_list.append(x)
+ encoded_lens.append(seq_len)
+
+ x_packed = torch.cat(encoded_list, dim=0) # [total_T', n_state]
+
+ cu_seqlens = list(accumulate(encoded_lens, func=operator.add, initial=0))
+ cu_seqlens = torch.tensor(cu_seqlens, device=x_packed.device, dtype=torch.int32)
+
+ for block in self.blocks:
+ x_packed = block(x_packed, cu_seqlens=cu_seqlens)
+
+ x_packed = self.ln_post(x_packed)
+ return x_packed
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ params_dict: dict[str, torch.Tensor] = {
+ **dict(self.named_parameters(remove_duplicate=False)),
+ **dict(self.named_buffers()),
+ }
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in weights:
+ if name not in params_dict:
+ logger.warning("Skipping unknown audio encoder weight: %s", name)
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+
+ return loaded_params
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py
new file mode 100644
index 00000000000..87728890b67
--- /dev/null
+++ b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni.py
@@ -0,0 +1,223 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright 2024 ANT Group and the HuggingFace Inc. team. All rights reserved.
+# Adapted from Ming repository modeling_bailingmm2.py
+# https://github.com/inclusionAI/Ming
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Ming-flash-omni-2.0 unified model (thinker + imagegen + talker)."""
+
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.model_executor.models.interfaces import (
+ SupportsMRoPE,
+ SupportsMultiModal,
+ SupportsPP,
+)
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.models.utils import (
+ init_vllm_registered_model,
+ maybe_prefix,
+)
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.sequence import IntermediateTensors
+
+from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin
+from vllm_omni.model_executor.models.output_templates import OmniOutput
+from vllm_omni.model_executor.models.utils import add_prefix_to_loaded_weights
+from vllm_omni.transformers_utils.configs.ming_flash_omni import BailingMM2Config, MingFlashOmniConfig
+
+from .ming_flash_omni_thinker import (
+ MingFlashOmniThinkerDummyInputsBuilder,
+ MingFlashOmniThinkerMultiModalProcessor,
+ MingFlashOmniThinkerProcessingInfo,
+)
+
+logger = init_logger(__name__)
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ MingFlashOmniThinkerMultiModalProcessor,
+ info=MingFlashOmniThinkerProcessingInfo,
+ dummy_inputs=MingFlashOmniThinkerDummyInputsBuilder,
+)
+class MingFlashOmniForConditionalGeneration(
+ nn.Module,
+ SupportsMultiModal,
+ SupportsPP,
+ SupportsMRoPE,
+ CustomProcessMixin,
+):
+ """Unified Ming-flash-omni-2.0 model combining thinker, imagegen, and talker."""
+
+ supports_multimodal = True
+ requires_raw_input_tokens: bool = True
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ self.have_multimodal_outputs = True
+ self.has_preprocess = False
+ self.has_postprocess = False
+
+ config = vllm_config.model_config.hf_config
+
+ self.vllm_config = vllm_config
+ self.config = config
+
+ if isinstance(config, MingFlashOmniConfig):
+ thinker_config = config.thinker_config
+ else:
+ thinker_config = config
+
+ self.thinker_config: BailingMM2Config = thinker_config
+ self.model_stage = vllm_config.model_config.model_stage
+
+ if self.model_stage == "thinker":
+ thinker_vllm_config = vllm_config.with_hf_config(
+ thinker_config, architectures=["MingFlashOmniThinkerForConditionalGeneration"]
+ )
+ self.thinker = init_vllm_registered_model(
+ vllm_config=thinker_vllm_config,
+ prefix=maybe_prefix(prefix, "thinker"),
+ architectures=["MingFlashOmniThinkerForConditionalGeneration"],
+ )
+ self.model = self.thinker
+ self.imagegen = None
+ self.talker = None
+
+ elif self.model_stage == "imagegen":
+ # TODO: Implement image generator stage
+ raise NotImplementedError(
+ "Image generation stage is not yet implemented. Please use model_stage='thinker' for now."
+ )
+
+ elif self.model_stage == "talker":
+ # TODO: Implement talker (TTS) stage
+ raise NotImplementedError(
+ "Talker (TTS) stage is not yet implemented. Please use model_stage='thinker' for now."
+ )
+
+ else:
+ raise ValueError(
+ f"Invalid model_stage: {self.model_stage}. Must be one of: 'thinker', 'imagegen', 'talker'"
+ )
+
+ # Set up intermediate tensors
+ self.make_empty_intermediate_tensors = (
+ self.thinker.make_empty_intermediate_tensors if self.model_stage == "thinker" else lambda: None
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **kwargs,
+ ) -> OmniOutput:
+ return self.model.forward(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ **kwargs,
+ )
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata=None,
+ ) -> torch.Tensor | None:
+ if hasattr(self.model, "compute_logits"):
+ return self.model.compute_logits(hidden_states, sampling_metadata)
+ return None
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata,
+ ):
+ if hasattr(self.model, "sample"):
+ return self.model.sample(logits, sampling_metadata)
+ raise NotImplementedError("sample method not available on current stage")
+
+ def get_mrope_input_positions(self, *args, **kwargs):
+ if hasattr(self.model, "get_mrope_input_positions"):
+ return self.model.get_mrope_input_positions(*args, **kwargs)
+ raise NotImplementedError("get_mrope_input_positions not available on current stage")
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ loaded_weights = set()
+ thinker_weights = []
+ imagegen_weights = []
+ talker_weights = []
+
+ for name, value in weights:
+ if name.startswith("thinker."):
+ thinker_weights.append((name, value))
+ elif name.startswith("imagegen."):
+ imagegen_weights.append((name, value))
+ elif name.startswith("talker."):
+ talker_weights.append((name, value))
+ else:
+ # Weights without prefix go to thinker by default
+ thinker_weights.append((name, value))
+
+ if self.model_stage == "thinker" and thinker_weights:
+ # Remove "thinker." prefix before loading
+ thinker_weights_stripped = [
+ (name.replace("thinker.", "", 1) if name.startswith("thinker.") else name, value)
+ for name, value in thinker_weights
+ ]
+ thinker_loaded = self.thinker.load_weights(thinker_weights_stripped)
+ thinker_loaded = add_prefix_to_loaded_weights(thinker_loaded, "thinker")
+ loaded_weights.update(thinker_loaded)
+
+ # TODO: Load imagegen weights when implemented
+ # TODO: Load talker weights when implemented
+
+ return loaded_weights
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ return MultiModelKeys.from_string_field(
+ language_model="thinker.language_model",
+ connector=["thinker.linear_proj.", "thinker.linear_proj_audio."],
+ tower_model=["thinker.vision.", "thinker.audio."],
+ )
+
+ @property
+ def sampler(self):
+ if hasattr(self.model, "sampler"):
+ return self.model.sampler
+ return None
+
+ def embed_input_ids(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings=None,
+ *,
+ is_multimodal=None,
+ ) -> torch.Tensor:
+ return self.model.embed_input_ids(
+ input_ids,
+ multimodal_embeddings,
+ is_multimodal=is_multimodal,
+ )
+
+ def embed_multimodal(self, **kwargs):
+ return self.model.embed_multimodal(**kwargs)
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py
new file mode 100644
index 00000000000..bde7477b945
--- /dev/null
+++ b/vllm_omni/model_executor/models/ming_flash_omni/ming_flash_omni_thinker.py
@@ -0,0 +1,893 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright 2024 ANT Group and the HuggingFace Inc. team.
+# Adapted from Ming repository modeling_bailingmm2.py and processing_bailingmm2.py
+# https://github.com/inclusionAI/Ming
+
+"""Ming-flash-omni-2.0 Thinker stage implementation (multimodal understanding)."""
+
+from collections.abc import Iterable, Iterator, Mapping, Sequence
+from typing import Annotated, Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.feature_extraction_utils import BatchFeature
+from vllm.config import VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.inputs import MultiModalDataDict
+from vllm.logger import init_logger
+from vllm.model_executor.models.interfaces import (
+ MultiModalEmbeddings,
+ SupportsMRoPE,
+ SupportsMultiModal,
+ SupportsPP,
+)
+from vllm.model_executor.models.qwen2_5_vl import (
+ Qwen2_5_VLImageInputs,
+ Qwen2_5_VLImagePixelInputs,
+ Qwen2_5_VLVideoInputs,
+ Qwen2_5_VLVideoPixelInputs,
+)
+from vllm.model_executor.models.qwen2_vl import (
+ Qwen2VLProcessingInfo,
+)
+from vllm.model_executor.models.utils import (
+ AutoWeightsLoader,
+ WeightsMapper,
+ _merge_multimodal_embeddings,
+ maybe_prefix,
+)
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+ MultiModalFeatureSpec,
+ MultiModalFieldConfig,
+ MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import (
+ AudioProcessorItems,
+ ImageProcessorItems,
+ MultiModalDataItems,
+ MultiModalDataParser,
+ VideoProcessorItems,
+)
+from vllm.multimodal.processing import (
+ BaseDummyInputsBuilder,
+ BaseMultiModalProcessor,
+ PromptReplacement,
+ PromptUpdate,
+ PromptUpdateDetails,
+)
+from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin
+from vllm_omni.model_executor.models.output_templates import OmniOutput
+from vllm_omni.transformers_utils.configs.ming_flash_omni import BailingMM2Config
+from vllm_omni.transformers_utils.processors.ming import (
+ PLACEHOLDER_AUDIO_TOKEN_IN_TEXT,
+ PLACEHOLDER_IMAGE_TOKEN_IN_TEXT,
+ PLACEHOLDER_VIDEO_TOKEN_IN_TEXT,
+ MingFlashOmniProcessor,
+ MingWhisperFeatureExtractor,
+)
+
+from .audio_encoder import WhisperAudioEncoder
+from .modeling_bailing_moe_v2 import BailingMoeV2ForCausalLM
+from .projectors import AudioProjector, VisionProjector
+from .vision_encoder import MingVisionEncoder
+
+logger = init_logger(__name__)
+
+
+class MingAudioInput(TensorSchema):
+ """
+ Dimensions:
+ - b: Batch size
+ - l: Total audio frames (clips concatenated along the time axis)
+ - nm: Number of mel bins
+ - N: Max number of audio clips per batch item
+ """
+
+ audio_feats: Annotated[
+ torch.Tensor,
+ TensorShape("b", "l", "nm"),
+ ]
+
+ audio_feats_lengths: Annotated[
+ torch.Tensor,
+ TensorShape("b", "N"),
+ ]
+
+
+class MingFlashOmniThinkerProcessingInfo(Qwen2VLProcessingInfo):
+ def get_hf_config(self) -> BailingMM2Config:
+ return self.ctx.get_hf_config(BailingMM2Config)
+
+ def get_hf_processor(self, **kwargs: object):
+ return self.ctx.get_hf_processor(MingFlashOmniProcessor, **kwargs)
+
+ def get_target_channels(self) -> int:
+ # See `_normalize_audio_tensor` in vllm_omni/transformers_utils/processors/ming.py
+ return 1
+
+ def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+ return {"image": None, "video": None, "audio": None}
+
+ def get_mm_max_tokens_per_item(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> Mapping[str, int]:
+ mm_counts = mm_counts or {}
+ requested_modalities = {m for m, c in mm_counts.items() if c > 0}
+ mm_max_tokens: dict[str, int] = {}
+
+ if requested_modalities & {"image", "video"}:
+ vl_tokens = super().get_mm_max_tokens_per_item(
+ seq_len=seq_len,
+ mm_counts=mm_counts,
+ )
+ mm_max_tokens.update({m: vl_tokens[m] for m in ["image", "video"] if m in requested_modalities})
+
+ if "audio" in requested_modalities:
+ # TODO: consider computing from audio config
+ mm_max_tokens["audio"] = 3000
+
+ return mm_max_tokens
+
+ def get_feature_extractor(self, **kwargs: object) -> MingWhisperFeatureExtractor:
+ hf_processor = self.get_hf_processor(**kwargs)
+ feature_extractor = hf_processor.audio_processor
+ assert isinstance(feature_extractor, MingWhisperFeatureExtractor)
+ return feature_extractor
+
+ def get_data_parser(self):
+ feature_extractor = self.get_feature_extractor()
+ return MultiModalDataParser(
+ target_sr=feature_extractor.sampling_rate,
+ target_channels=self.get_target_channels(),
+ expected_hidden_size=self._get_expected_hidden_size(),
+ )
+
+
+class MingFlashOmniThinkerDummyInputsBuilder(BaseDummyInputsBuilder[MingFlashOmniThinkerProcessingInfo]):
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+ num_videos = mm_counts.get("video", 0)
+ num_audios = mm_counts.get("audio", 0)
+
+ hf_processor = self.info.get_hf_processor()
+
+ audio_token: str = hf_processor.audio_token
+ image_token: str = hf_processor.image_token
+ video_token: str = hf_processor.video_token
+
+ return image_token * num_images + video_token * num_videos + audio_token * num_audios
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ mm_options: Mapping[str, BaseDummyOptions] | None = None,
+ ) -> MultiModalDataDict:
+ num_images = mm_counts.get("image", 0)
+ num_videos = mm_counts.get("video", 0)
+ num_audios = mm_counts.get("audio", 0)
+
+ # Default dimensions for dummy data
+ image_width, image_height = 448, 448
+ video_width, video_height = 448, 448
+ num_frames = 8
+ audio_duration = 3.0 # seconds
+ sample_rate = 16000
+
+ audio_length = int(audio_duration * sample_rate)
+
+ mm_data: MultiModalDataDict = {
+ "image": self._get_dummy_images(
+ width=image_width,
+ height=image_height,
+ num_images=num_images,
+ ),
+ "video": self._get_dummy_videos(
+ width=video_width,
+ height=video_height,
+ num_frames=num_frames,
+ num_videos=num_videos,
+ ),
+ "audio": [(np.random.randn(audio_length).astype(np.float32), sample_rate) for _ in range(num_audios)],
+ }
+
+ return mm_data
+
+
+class MingFlashOmniThinkerMultiModalProcessor(BaseMultiModalProcessor[MingFlashOmniThinkerProcessingInfo]):
+ """Multimodal processor for Ming-flash-omni Thinker stage.
+
+ Handles preprocessing of 1) image, 2) video, and 3) audio inputs,
+ and expands placeholder tokens to the correct number of patch tokens.
+ """
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargsItems,
+ ) -> Sequence[PromptUpdate]:
+ tokenizer = self.info.get_tokenizer()
+ # might want to add a fallback to resolve token ids
+ # vocab = tokenizer.get_vocab()
+ thinker_config = self.info.get_hf_config()
+
+ # patch/delimiter token IDs (used in replacement sequences)
+ image_start_token_id = thinker_config.llm_config.image_start_token
+ image_patch_token_id = thinker_config.llm_config.image_patch_token
+ image_end_token_id = thinker_config.llm_config.image_end_token
+
+ video_start_token_id = thinker_config.llm_config.video_start_token
+ frame_patch_token_id = thinker_config.llm_config.video_patch_token
+ video_end_token_id = thinker_config.llm_config.video_end_token
+
+ audio_start_token_id = thinker_config.llm_config.audio_start_token
+ audio_patch_token_id = thinker_config.llm_config.audio_patch_token
+ audio_end_token_id = thinker_config.llm_config.audio_end_token
+
+ vision_config = thinker_config.vision_config
+ spatial_merge_size = vision_config.spatial_merge_size if vision_config else 2
+
+ newline_token_ids: list[int] = tokenizer.encode("\n", add_special_tokens=False)
+
+ out_mm_data = out_mm_kwargs.get_data()
+
+ def get_replacement_image(item_idx: int) -> PromptUpdateDetails:
+ """Generate token sequence for an image."""
+ grid_thw = out_mm_data.get("image_grid_thw")
+ if grid_thw is None:
+ raise ValueError(
+ "image_grid_thw missing from processor output; "
+ "cannot determine image patch count for prompt replacement."
+ )
+ if isinstance(grid_thw, torch.Tensor):
+ thw = grid_thw[item_idx]
+ num_patches = int(thw.prod().item()) // (spatial_merge_size**2)
+ else:
+ thw = grid_thw[item_idx]
+ num_patches = (thw[0] * thw[1] * thw[2]) // (spatial_merge_size**2)
+
+ # Build token sequence: *N \n
+ # the newline token is added in purpose from original model processing
+ tokens: list[int] = []
+ tokens.append(image_start_token_id)
+ tokens.extend([image_patch_token_id] * num_patches)
+ tokens.append(image_end_token_id)
+ # Refer to Ming's BailingMM2Processor._expand_image_tokens
+ # https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/processing_bailingmm2.py
+ tokens.extend(newline_token_ids)
+
+ # Only tokens receive multimodal embeddings
+ return PromptUpdateDetails.select_token_id(tokens, image_patch_token_id)
+
+ def get_replacement_video(item_idx: int) -> PromptUpdateDetails:
+ """Generate token sequence for a video."""
+ grid_thw = out_mm_data.get("video_grid_thw", None)
+ if grid_thw is None:
+ raise ValueError(
+ "video_grid_thw missing from processor output; "
+ "cannot determine video patch count for prompt replacement."
+ )
+ if isinstance(grid_thw, torch.Tensor):
+ thw = grid_thw[item_idx]
+ num_patches = int(thw.prod().item()) // (spatial_merge_size**2)
+ else:
+ thw = grid_thw[item_idx]
+ num_patches = (thw[0] * thw[1] * thw[2]) // (spatial_merge_size**2)
+
+ # Build token sequence: *N \n
+ # the newline token is added in purpose from original model processing
+ tokens: list[int] = []
+ tokens.append(video_start_token_id)
+ tokens.extend([frame_patch_token_id] * num_patches)
+ tokens.append(video_end_token_id)
+ tokens.extend(newline_token_ids)
+
+ # Only tokens receive multimodal embeddings
+ return PromptUpdateDetails.select_token_id(tokens, frame_patch_token_id)
+
+ def get_replacement_audio(item_idx: int) -> PromptUpdateDetails:
+ """Generate token sequence for an audio."""
+ encoder_feats_lengths = out_mm_data.get("encoder_feats_lengths", None)
+ if encoder_feats_lengths is None:
+ raise ValueError(
+ "encoder_feats_lengths missing from processor output; "
+ "cannot determine audio patch count for prompt replacement."
+ )
+ if isinstance(encoder_feats_lengths, torch.Tensor):
+ num_patches = int(encoder_feats_lengths[item_idx].item())
+ else:
+ num_patches = encoder_feats_lengths[item_idx]
+
+ # Build token sequence: *N
+ tokens: list[int] = []
+ tokens.append(audio_start_token_id)
+ tokens.extend([audio_patch_token_id] * num_patches)
+ tokens.append(audio_end_token_id)
+
+ # Only tokens receive multimodal embeddings
+ return PromptUpdateDetails.select_token_id(tokens, audio_patch_token_id)
+
+ # Build prompt updates and process replacement
+ updates: list[PromptUpdate] = []
+
+ if "image" in mm_items and mm_items.get_items("image", ImageProcessorItems):
+ updates.append(
+ PromptReplacement(
+ modality="image",
+ target=PLACEHOLDER_IMAGE_TOKEN_IN_TEXT,
+ replacement=get_replacement_image,
+ )
+ )
+ if "video" in mm_items and mm_items.get_items("video", VideoProcessorItems):
+ updates.append(
+ PromptReplacement(
+ modality="video",
+ target=PLACEHOLDER_VIDEO_TOKEN_IN_TEXT,
+ replacement=get_replacement_video,
+ )
+ )
+ if "audio" in mm_items and mm_items.get_items("audio", AudioProcessorItems):
+ updates.append(
+ PromptReplacement(
+ modality="audio",
+ target=PLACEHOLDER_AUDIO_TOKEN_IN_TEXT,
+ replacement=get_replacement_audio,
+ )
+ )
+ return updates
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ config: dict[str, MultiModalFieldConfig] = {}
+
+ # Image fields, pixel_values is flat (concatenated patches from all images)
+ image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
+ if "pixel_values" in hf_inputs:
+ image_sizes = image_grid_thw.prod(-1)
+ config["pixel_values"] = MultiModalFieldConfig.flat_from_sizes(
+ "image",
+ image_sizes,
+ )
+ if "image_grid_thw" in hf_inputs:
+ config["image_grid_thw"] = MultiModalFieldConfig.batched("image")
+
+ # Video fields, same flat layout as images
+ video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
+ if "pixel_values_videos" in hf_inputs:
+ video_sizes = video_grid_thw.prod(-1)
+ config["pixel_values_videos"] = MultiModalFieldConfig.flat_from_sizes(
+ "video",
+ video_sizes,
+ )
+ if "video_grid_thw" in hf_inputs:
+ config["video_grid_thw"] = MultiModalFieldConfig.batched("video")
+
+ # Audio fields
+ if "audio_feats" in hf_inputs:
+ config["audio_feats"] = MultiModalFieldConfig.batched("audio")
+ if "audio_feats_lengths" in hf_inputs:
+ config["audio_feats_lengths"] = MultiModalFieldConfig.batched("audio")
+ if "encoder_feats_lengths" in hf_inputs:
+ config["encoder_feats_lengths"] = MultiModalFieldConfig.batched("audio")
+ if "placeholder_audio_loc_lens" in hf_inputs:
+ config["placeholder_audio_loc_lens"] = MultiModalFieldConfig.batched("audio")
+
+ return config
+
+ def _hf_processor_applies_updates(
+ self,
+ prompt_text: str,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ tokenization_kwargs: Mapping[str, object],
+ ) -> bool:
+ return False
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ tok_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ """Call sub-processors for multimodal inputs and tokenize.
+
+ We call the image/audio sub-processors directly (instead of going
+ through `MingFlashOmniProcessor.__call__`) so that the high-level
+ placeholder tokens remain **unexpanded** in the tokenized output.
+ """
+ hf_processor = self.info.get_hf_processor()
+ tokenizer = self.info.get_tokenizer()
+
+ data: dict[str, object] = {}
+
+ images = mm_data.get("images", None)
+ if images is not None:
+ image_outputs = hf_processor.image_processor(
+ images=images,
+ videos=None,
+ return_tensors="pt",
+ )
+ data.update(image_outputs)
+
+ videos = mm_data.get("videos", None)
+ if videos is not None:
+ video_outputs = hf_processor.image_processor(
+ images=None,
+ videos=videos,
+ return_tensors="pt",
+ )
+ # Rename keys to distinguish from images
+ if "pixel_values" in video_outputs:
+ video_outputs["pixel_values_videos"] = video_outputs.pop("pixel_values")
+ if "image_grid_thw" in video_outputs:
+ video_outputs["video_grid_thw"] = video_outputs.pop("image_grid_thw")
+ data.update(video_outputs)
+
+ audios = mm_data.get("audios", None)
+ if audios is not None:
+ # vLLM's AudioProcessorItems provides raw numpy arrays (already resampled).
+ # MingWhisperAudioProcessor expects (waveform, sr) tuples,
+ # so wrap them with the target sample rate.
+ target_sr = hf_processor.audio_processor.sampling_rate
+ audio_tuples = [(a, target_sr) if not isinstance(a, tuple) else a for a in audios]
+
+ audio_outputs = hf_processor.audio_processor(
+ audio_tuples,
+ return_tensors="pt",
+ )
+ data.update(audio_outputs)
+
+ # Tokenize text with placeholders still intact
+ text_outputs = tokenizer(prompt, return_tensors="pt", **tok_kwargs)
+ data.update(text_outputs)
+
+ return BatchFeature(data=data)
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ MingFlashOmniThinkerMultiModalProcessor,
+ info=MingFlashOmniThinkerProcessingInfo,
+ dummy_inputs=MingFlashOmniThinkerDummyInputsBuilder,
+)
+class MingFlashOmniThinkerForConditionalGeneration(
+ nn.Module,
+ SupportsMultiModal,
+ SupportsPP,
+ SupportsMRoPE,
+ CustomProcessMixin,
+):
+ """Ming Thinker stage: multimodal understanding
+ (text + image + video + audio) -> text generation.
+ """
+
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={"model.": "language_model."},
+ )
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+ # vllm_omni/transformers_utils/processors/ming.py
+ if modality.startswith("image"):
+ return ""
+ elif modality.startswith("video"):
+ return ""
+ elif modality.startswith("audio"):
+ return ""
+
+ raise ValueError("Only image, video, or audio modality is supported")
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+
+ config = vllm_config.model_config.hf_config
+
+ thinker_config: BailingMM2Config = config
+ if (
+ thinker_config.llm_config is None
+ or thinker_config.vision_config is None
+ or thinker_config.audio_config is None
+ ):
+ raise ValueError(
+ "MingFlashOmniThinker requires `llm_config`, `vision_config`, and `audio_config` in `thinker_config`."
+ )
+
+ llm_config = thinker_config.llm_config
+
+ self.config = llm_config
+ self.thinker_config = thinker_config
+ self.have_multimodal_outputs = True
+
+ # Initialize LLM as a component
+ with self._mark_language_model(vllm_config):
+ llm_vllm_config = vllm_config.with_hf_config(llm_config)
+ self.language_model = BailingMoeV2ForCausalLM(
+ vllm_config=llm_vllm_config, prefix=maybe_prefix(prefix, "llm")
+ )
+
+ # Ming thinker is inherently multimodal; initialize both towers eagerly.
+ with self._mark_tower_model(vllm_config, {"image", "video"}):
+ self.vision = MingVisionEncoder(
+ vision_config=thinker_config.vision_config,
+ quant_config=vllm_config.quant_config,
+ prefix=maybe_prefix(prefix, "vision"),
+ )
+ self.linear_proj = VisionProjector(
+ vision_dim=self.vision.image_emb_dim,
+ llm_dim=llm_config.hidden_size,
+ mlp_depth=getattr(thinker_config, "mlp_depth", 2),
+ )
+ logger.info("Initialized MingVisionEncoder and VisionProjector")
+
+ audio_cfg = thinker_config.audio_config
+ whisper_cfg = getattr(audio_cfg, "whisper_encoder_config", {}) or {}
+ with self._mark_tower_model(vllm_config, "audio"):
+ self.audio = WhisperAudioEncoder(
+ **whisper_cfg,
+ use_flash_attn=True,
+ )
+ self.linear_proj_audio = AudioProjector(
+ audio_dim=self.audio.audio_emb_dim,
+ llm_dim=llm_config.hidden_size,
+ ds_kernel_size=getattr(audio_cfg, "ds_kernel_size", 3),
+ ds_stride=getattr(audio_cfg, "ds_stride", 2),
+ mlp_depth=getattr(thinker_config, "mlp_depth", 1),
+ )
+ logger.info("Initialized WhisperAudioEncoder and AudioProjector")
+
+ # Expose interfaces
+ self.make_empty_intermediate_tensors = self.language_model.make_empty_intermediate_tensors
+
+ logger.info("MingFlashOmniThinker initialized with vision and audio towers")
+
+ def extract_image_feature(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+ """Extract and project image features.
+
+ Args:
+ pixel_values: Flattened pixel values from vision processor.
+ grid_thw: [num_images, 3] tensor of (t, h, w) grid dimensions.
+
+ Returns:
+ [seq_len, hidden_size] L2-normalized image embeddings.
+ """
+ if self.vision is None:
+ raise ValueError("Vision encoder not initialized")
+
+ with torch.amp.autocast(pixel_values.device.type, dtype=torch.bfloat16):
+ image_embeds = self.vision(pixel_values, grid_thw=grid_thw)
+
+ if self.vision.use_deepstack:
+ image_embeds = image_embeds[:, : self.vision.image_emb_dim]
+
+ image_embeds = self.linear_proj(image_embeds)
+ image_embeds = F.normalize(image_embeds, dim=-1)
+ return image_embeds
+
+ def extract_audio_feature(
+ self, audio_feats: torch.Tensor, audio_feats_lengths: torch.Tensor
+ ) -> tuple[torch.Tensor, ...]:
+ """Extract and project audio features.
+
+ Args:
+ audio_feats: [B, L_total, n_mels] wrapped mel features — multiple audio
+ clips per batch item are concatenated along the time dimension
+ (time-first, as produced by MingWhisperFeatureExtractor).
+ audio_feats_lengths: [B, N] lengths of each audio clip per batch item.
+ N is the max number of clips per item; zero-padded entries are skipped.
+
+ Returns:
+ Tuple of per-clip [T'_i, hidden_size] projected audio embeddings.
+ """
+ if self.audio is None:
+ raise ValueError("Audio encoder not initialized")
+
+ # Unwrap packed [B, L_total, n_mels] into a list of [n_mels, T'_i] tensors,
+ # one per audio clip, as expected by WhisperAudioEncoder.
+ x_list: list[torch.Tensor] = []
+ audio_lens: list[int] = []
+ for i in range(audio_feats_lengths.shape[0]):
+ feat_index = 0
+ for j in range(audio_feats_lengths.shape[1]):
+ feat_len = int(audio_feats_lengths[i, j].item())
+ if feat_len == 0:
+ break
+ mel_seg = audio_feats[i, feat_index : feat_index + feat_len].transpose(0, 1)
+ x_list.append(mel_seg)
+ audio_lens.append(feat_len)
+ feat_index += feat_len
+
+ audio_packed = self.audio(x_list, audio_lens)
+
+ # Compute per-clip lengths after Whisper Conv1d (kernel=3, stride=2, pad=1)
+ encoded_lens = [(audio_len - 3 + 2) // 2 + 1 for audio_len in audio_lens]
+
+ # Project packed
+ proj_packed, proj_lens = self.linear_proj_audio.forward_packed(audio_packed, encoded_lens)
+
+ normalize = getattr(self.thinker_config.audio_config, "norm_query_embeds", False)
+ if normalize:
+ proj_packed = F.normalize(proj_packed, dim=-1)
+
+ proj_packed = proj_packed.to(audio_feats.dtype)
+
+ # Split into per-clip tensors
+ return proj_packed.split(proj_lens)
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ """Parse and validate multimodal kwargs into per-modality dicts."""
+ mm_input_by_modality: dict[str, Qwen2_5_VLImageInputs | Qwen2_5_VLVideoInputs | MingAudioInput] = {}
+
+ for key in kwargs:
+ if key == "pixel_values" and "image" not in mm_input_by_modality:
+ pixel_values = kwargs.get("pixel_values")
+ image_grid_thw = kwargs.get("image_grid_thw")
+ if pixel_values is not None and image_grid_thw is not None:
+ mm_input_by_modality["image"] = Qwen2_5_VLImagePixelInputs(
+ type="pixel_values",
+ pixel_values=pixel_values, # type: ignore[arg-type]
+ image_grid_thw=image_grid_thw, # type: ignore[arg-type]
+ )
+ elif key == "pixel_values_videos" and "video" not in mm_input_by_modality:
+ pixel_values_videos = kwargs.get("pixel_values_videos")
+ video_grid_thw = kwargs.get("video_grid_thw")
+ second_per_grid_ts = kwargs.get("second_per_grid_ts")
+ if pixel_values_videos is not None and video_grid_thw is not None:
+ mm_input_by_modality["video"] = Qwen2_5_VLVideoPixelInputs(
+ type="pixel_values_videos",
+ pixel_values_videos=pixel_values_videos, # type: ignore[arg-type]
+ video_grid_thw=video_grid_thw, # type: ignore[arg-type]
+ second_per_grid_ts=second_per_grid_ts, # type: ignore[arg-type]
+ )
+ elif key == "audio_feats" and "audio" not in mm_input_by_modality:
+ audio_feats = kwargs.get("audio_feats")
+ audio_feats_lengths = kwargs.get("audio_feats_lengths")
+ if audio_feats is not None and audio_feats_lengths is not None:
+ mm_input_by_modality["audio"] = MingAudioInput(
+ audio_feats=audio_feats, # type: ignore[arg-type]
+ audio_feats_lengths=audio_feats_lengths, # type: ignore[arg-type]
+ )
+
+ return mm_input_by_modality
+
+ def _process_image_input(self, image_input: Qwen2_5_VLImageInputs) -> list[torch.Tensor]:
+ # Splits the flat [total_tokens, D] output of extract_image_feature
+ # into one tensor per image.
+ pixel_values = image_input["pixel_values"]
+ image_grid_thw = image_input["image_grid_thw"]
+ image_embeds = self.extract_image_feature(pixel_values, image_grid_thw)
+ merge_unit = self.thinker_config.vision_config.spatial_merge_size**2
+ sizes = (image_grid_thw.prod(dim=-1) // merge_unit).tolist()
+ return list(image_embeds.split([int(s) for s in sizes], dim=0))
+
+ def _process_video_input(self, video_input: Qwen2_5_VLVideoInputs) -> list[torch.Tensor]:
+ pixel_values_videos = video_input["pixel_values_videos"]
+ video_grid_thw = video_input["video_grid_thw"]
+ video_embeds = self.extract_image_feature(pixel_values_videos, video_grid_thw)
+ merge_unit = self.thinker_config.vision_config.spatial_merge_size**2
+ sizes = (video_grid_thw.prod(dim=-1) // merge_unit).tolist()
+ return list(video_embeds.split([int(s) for s in sizes], dim=0))
+
+ def _process_audio_input(self, audio_input: MingAudioInput) -> list[torch.Tensor]:
+ return list(self.extract_audio_feature(audio_input["audio_feats"], audio_input["audio_feats_lengths"]))
+
+ def _compute_modality_masks(self, input_ids: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+ """Compute vision and audio MoE-routing masks from input_ids.
+
+ Returns:
+ Tuple of (vision_mask, audio_mask), each shape [seq_len] bool.
+ """
+ llm_config = self.config
+
+ # vision mask
+ vision_mask = torch.zeros_like(input_ids, dtype=torch.bool)
+ image_token = llm_config.image_patch_token
+ video_token = llm_config.video_patch_token
+ vision_mask = vision_mask | (input_ids == image_token)
+ vision_mask = vision_mask | (input_ids == video_token)
+
+ # audio mask
+ audio_token = llm_config.audio_patch_token
+ audio_mask = input_ids == audio_token
+
+ return vision_mask, audio_mask
+
+ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
+ mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
+ if not mm_input_by_modality:
+ return []
+
+ # preserve the order of modalities
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+ for modality, mm_input in mm_input_by_modality.items():
+ if modality == "image":
+ multimodal_embeddings += tuple(self._process_image_input(mm_input)) # type: ignore[arg-type]
+ elif modality == "video":
+ multimodal_embeddings += tuple(self._process_video_input(mm_input)) # type: ignore[arg-type]
+ elif modality == "audio":
+ multimodal_embeddings += tuple(self._process_audio_input(mm_input)) # type: ignore[arg-type]
+
+ return multimodal_embeddings
+
+ def embed_input_ids(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: MultiModalEmbeddings | None = None,
+ *,
+ is_multimodal: torch.Tensor | None = None,
+ handle_oov_mm_token: bool = False,
+ ) -> torch.Tensor:
+ inputs_embeds = self.language_model.model.word_embeddings(input_ids)
+
+ if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+ return inputs_embeds
+
+ assert is_multimodal is not None, "`is_multimodal` mask required when `multimodal_embeddings` provided"
+ return _merge_multimodal_embeddings(
+ inputs_embeds=inputs_embeds,
+ multimodal_embeddings=multimodal_embeddings,
+ is_multimodal=is_multimodal,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ **kwargs,
+ ) -> OmniOutput:
+ # Compute MoE modality masks on every device
+ image_mask, audio_mask = self._compute_modality_masks(input_ids)
+ hidden_states = self.language_model.forward(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ image_mask=image_mask,
+ audio_mask=audio_mask,
+ )
+
+ # Capture embeddings for downstream stages
+ multimodal_outputs = {
+ "final_hidden_states": hidden_states,
+ }
+
+ return OmniOutput(
+ text_hidden_states=hidden_states,
+ multimodal_outputs=multimodal_outputs,
+ )
+
+ def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata) -> torch.Tensor | None:
+ return self.language_model.compute_logits(hidden_states, sampling_metadata)
+
+ def sample(self, logits: torch.Tensor, sampling_metadata):
+ return self.language_model.sample(logits, sampling_metadata)
+
+ @property
+ def sampler(self):
+ return self.language_model.sampler
+
+ def iter_mm_features(
+ self,
+ mm_features: list[MultiModalFeatureSpec],
+ ) -> Iterator[tuple[int, str, dict[str, Any]]]:
+ """Iterate over image/video features sorted by token position.
+
+ Yields: (offset, modality, feature_data) where feature_data contains:
+ - image: {"grid_t", "grid_h", "grid_w", "second_per_grid_t"}
+ - video: {"grid_t", "grid_h", "grid_w", "second_per_grid_t"}
+
+ Audio features are not yielded: Ming assigns them sequential
+ text positions (same T/H/W value) rather than 3D grid positions.
+ """
+ spatial_merge_size = self.config.spatial_merge_size
+
+ for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
+ if mm_feature.data is None:
+ continue
+
+ offset = mm_feature.mm_position.offset
+ modality = mm_feature.modality
+
+ if modality == "image":
+ t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
+ yield (
+ offset,
+ "image",
+ {
+ "grid_t": int(t),
+ "grid_h": int(h) // spatial_merge_size,
+ "grid_w": int(w) // spatial_merge_size,
+ "second_per_grid_t": 0.0,
+ },
+ )
+ elif modality == "video":
+ t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
+ second_per_grid_t = 1.0
+ spgt_field = mm_feature.data.get("second_per_grid_ts")
+ if spgt_field is not None:
+ second_per_grid_t = float(spgt_field.data.item())
+ yield (
+ offset,
+ "video",
+ {
+ "grid_t": int(t),
+ "grid_h": int(h) // spatial_merge_size,
+ "grid_w": int(w) // spatial_merge_size,
+ "second_per_grid_t": second_per_grid_t,
+ },
+ )
+
+ def get_mrope_input_positions(
+ self,
+ input_tokens: list[int],
+ mm_features: list[MultiModalFeatureSpec] | None = None,
+ **kwargs: object,
+ ) -> tuple[torch.Tensor, int]:
+ """Compute M-RoPE input positions using mm_features directly."""
+ llm_config = self.config
+ tokens_per_second: int = getattr(llm_config, "tokens_per_second", 2)
+ seq_len = len(input_tokens)
+
+ llm_pos_ids_list: list[np.ndarray] = []
+ st = 0 # index of next unprocessed token
+
+ for patch_offset, _modality, data in self.iter_mm_features(mm_features or []):
+ text_len = patch_offset - st
+ st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
+ if text_len > 0:
+ llm_pos_ids_list.append(np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx)
+ st_idx += text_len
+
+ # 3-D grid positions for patch tokens
+ grid_t: int = data["grid_t"]
+ grid_h: int = data["grid_h"]
+ grid_w: int = data["grid_w"]
+ second_per_grid_t: float = data["second_per_grid_t"]
+
+ t_raw = np.arange(grid_t)
+ if second_per_grid_t > 0:
+ t_index = (t_raw * second_per_grid_t * tokens_per_second).astype(np.int64)
+ else:
+ t_index = t_raw.astype(np.int64)
+ t_index = np.repeat(t_index, grid_h * grid_w)
+
+ h_index = np.tile(np.arange(grid_h).repeat(grid_w), grid_t)
+ w_index = np.tile(np.arange(grid_w), grid_t * grid_h)
+
+ llm_pos_ids_list.append(np.stack([t_index, h_index, w_index]) + st_idx)
+
+ num_patches = grid_t * grid_h * grid_w
+ st = patch_offset + num_patches
+
+ if st < seq_len:
+ st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
+ tail_len = seq_len - st
+ llm_pos_ids_list.append(np.broadcast_to(np.arange(tail_len), (3, tail_len)) + st_idx)
+
+ if llm_pos_ids_list:
+ position_ids = torch.from_numpy(np.concatenate(llm_pos_ids_list, axis=1).astype(np.int64)) # (3, seq_len)
+ else:
+ # text-only, simple sequential positions
+ position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0).expand(3, -1)
+
+ mrope_position_delta = int(position_ids.max().item()) + 1 - seq_len
+ return position_ids, mrope_position_delta
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/modeling_bailing_moe_v2.py b/vllm_omni/model_executor/models/ming_flash_omni/modeling_bailing_moe_v2.py
new file mode 100644
index 00000000000..1ff362c5b9d
--- /dev/null
+++ b/vllm_omni/model_executor/models/ming_flash_omni/modeling_bailing_moe_v2.py
@@ -0,0 +1,896 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved.
+# Adapted from Ming
+# https://github.com/inclusionAI/Ming/blob/2a0c02ae3130190160c215f89fce7de3005db483/modeling_bailing_moe_v2.py
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Iterable
+
+import torch
+from torch import nn
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
+from vllm.config.cache import CacheConfig
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.attention import Attention
+from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.utils import (
+ PPMissingLayer,
+ WeightsMapper,
+ make_empty_intermediate_tensors_factory,
+ make_layers,
+ maybe_prefix,
+)
+from vllm.sequence import IntermediateTensors
+from vllm.v1.outputs import SamplerOutput
+from vllm.v1.sample.sampler import Sampler
+
+from vllm_omni.model_executor.custom_process_mixin import CustomProcessMixin
+from vllm_omni.transformers_utils.configs.ming_flash_omni import BailingMoeV2Config
+
+logger = init_logger(__name__)
+
+
+class MingVideoRopeMRotaryEmbedding(MRotaryEmbedding):
+ """MRotaryEmbedding with Ming's video_rope cos/sin interleaving.
+
+ Unlike standard mrope which maps contiguous frequency sections to T/H/W,
+ video_rope alternates H/W frequencies element-wise in the spatial section
+ and places temporal frequencies at the end:
+ Standard mrope: [T T T ... H H H ... W W W ...]
+ Video rope: [H W H W ... H W ... T T T ...]
+
+ Refer to Ming's BailingMoeV2RotaryEmbedding3D
+ https://github.com/inclusionAI/Ming/blob/2a0c02ae3130190160c215f89fce7de3005db483/modeling_bailing_moe_v2.py#L174
+ """
+
+ def _remap_video_rope(
+ self,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Remap 3D cos/sin to video_rope interleaved layout.
+
+ Args:
+ cos, sin: [3, num_tokens, rotary_dim // 2]
+ Returns:
+ cos, sin: [num_tokens, rotary_dim // 2]
+
+ Refer to Ming's apply_3d_rotary_pos_emb
+ https://github.com/inclusionAI/Ming/blob/2a0c02ae3130190160c215f89fce7de3005db483/modeling_bailing_moe_v2.py#L226
+ """
+ assert self.mrope_section is not None
+ hw_size = self.mrope_section[1] + self.mrope_section[2]
+
+ result_cos = torch.empty_like(cos[0])
+ result_sin = torch.empty_like(sin[0])
+
+ # Spatial frequencies: even indices from H (dim 1), odd from W (dim 2)
+ result_cos[:, 0:hw_size:2] = cos[1, :, 0:hw_size:2]
+ result_cos[:, 1:hw_size:2] = cos[2, :, 1:hw_size:2]
+ result_sin[:, 0:hw_size:2] = sin[1, :, 0:hw_size:2]
+ result_sin[:, 1:hw_size:2] = sin[2, :, 1:hw_size:2]
+
+ # Temporal frequencies at the end
+ result_cos[:, hw_size:] = cos[0, :, hw_size:]
+ result_sin[:, hw_size:] = sin[0, :, hw_size:]
+
+ return result_cos, result_sin
+
+ def forward_native(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor | None = None,
+ offsets: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ assert positions.ndim == 1 or positions.ndim == 2
+ assert key is not None
+
+ cos_sin_cache = self._match_cos_sin_cache_dtype(query)
+ num_tokens = positions.shape[-1]
+ cos_sin = cos_sin_cache[positions]
+ cos, sin = cos_sin.chunk(2, dim=-1)
+
+ if positions.ndim == 2:
+ cos, sin = self._remap_video_rope(cos, sin)
+
+ query_shape = query.shape
+ query = query.view(num_tokens, -1, self.head_size)
+ query_rot = query[..., : self.rotary_dim]
+ query_pass = query[..., self.rotary_dim :]
+ query_rot = self.apply_rotary_emb.forward_native(query_rot, cos, sin)
+ query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+ key_shape = key.shape
+ key = key.view(num_tokens, -1, self.head_size)
+ key_rot = key[..., : self.rotary_dim]
+ key_pass = key[..., self.rotary_dim :]
+ key_rot = self.apply_rotary_emb.forward_native(key_rot, cos, sin)
+ key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+
+ return query, key
+
+ def forward_cuda(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor | None = None,
+ offsets: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ # No custom Triton kernel for video_rope; fall back to native for 3D
+ # TODO: Consider custom optimization
+ if positions.ndim == 2:
+ return self.forward_native(positions, query, key, offsets)
+ return super().forward_cuda(positions, query, key, offsets)
+
+ def forward_cpu(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor | None = None,
+ offsets: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ return self.forward_native(positions, query, key, offsets)
+
+
+class BailingMoeV2MLP(nn.Module):
+ def __init__(
+ self,
+ config: BailingMoeV2Config,
+ intermediate_size: int,
+ hidden_act: str = "silu",
+ quant_config: QuantizationConfig | None = None,
+ reduce_results: bool = True,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = intermediate_size
+
+ self.gate_up_proj = MergedColumnParallelLinear(
+ self.hidden_size,
+ [self.intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_up_proj",
+ )
+ self.down_proj = RowParallelLinear(
+ self.intermediate_size,
+ self.hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ reduce_results=reduce_results,
+ prefix=f"{prefix}.down_proj",
+ )
+
+ if hidden_act != "silu":
+ raise ValueError(f"Unsupported activation: {hidden_act}")
+ self.act_fn = SiluAndMul()
+
+ def forward(self, x):
+ gate_up, _ = self.gate_up_proj(x)
+ x = self.act_fn(gate_up)
+ x, _ = self.down_proj(x)
+ return x
+
+
+class BailingMoeV2Gate(nn.Module):
+ """MoE routing gate with grouped expert selection."""
+
+ def __init__(
+ self,
+ config: BailingMoeV2Config,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.top_k = config.num_experts_per_tok
+ self.num_experts = config.num_experts
+
+ self.n_group = config.n_group
+ self.topk_group = config.topk_group
+
+ self.gating_dim = config.hidden_size
+
+ self.gate = ReplicatedLinear(
+ self.gating_dim,
+ self.num_experts,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate",
+ )
+
+ self.routed_scaling_factor = config.routed_scaling_factor
+
+ self.expert_bias = nn.Parameter(torch.zeros(self.num_experts), requires_grad=False)
+
+ def group_limited_topk(self, scores: torch.Tensor):
+ """Group-limited top-k selection for expert routing."""
+ num_tokens, _ = scores.size()
+ # Organize experts into groups
+ group_scores = scores.view(num_tokens, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
+ group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
+ group_mask = torch.zeros_like(group_scores)
+ group_mask.scatter_(1, group_idx, 1)
+
+ # Mask experts based on selected groups
+ score_mask = (
+ group_mask.unsqueeze(-1)
+ .expand(num_tokens, self.n_group, self.num_experts // self.n_group)
+ .reshape(num_tokens, -1)
+ )
+
+ masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
+ probs, top_indices = torch.topk(masked_scores, k=self.top_k, dim=-1, sorted=False)
+
+ return probs, top_indices
+
+ def forward(self, hidden_states):
+ # compute gating score
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+ logits, _ = self.gate(hidden_states)
+
+ logits = logits.float()
+ scores = torch.sigmoid(logits)
+
+ scores_for_routing = scores + self.expert_bias
+ _, topk_idx = self.group_limited_topk(scores_for_routing)
+
+ scores = torch.gather(scores, dim=1, index=topk_idx).type_as(logits)
+
+ topk_weight = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.top_k > 1 else scores
+ topk_weight = topk_weight * self.routed_scaling_factor
+
+ return topk_idx, topk_weight, logits
+
+
+def _unpack_multi_routing(
+ hidden_states: torch.Tensor,
+ gating_output: torch.Tensor,
+ topk: int,
+ renormalize: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Stateless routing function that unpacks pre-computed routing results.
+
+ Used as `custom_routing_function` for `FusedMoE`. The caller is expected
+ to pack (topk_weight, topk_idx) into `gating_output` before
+ calling FusedMoE.forward(), and this function unpacks them.
+
+ Args:
+ gating_output: [num_tokens, top_k * 2]
+ - [:, :top_k], topk_weight (float)
+ - [:, top_k:], topk_idx (float, cast back to int)
+ """
+ topk_weight = gating_output[:, :topk].contiguous()
+ topk_idx = gating_output[:, topk:]
+ return topk_weight.to(torch.float32), topk_idx.to(torch.int32)
+
+
+class BailingMoeV2SparseMoeBlock(nn.Module):
+ """Sparse MoE block with MultiRouter support for multimodal routing.
+
+ Keep the custom multi-router gating logic external.
+ """
+
+ def __init__(
+ self,
+ config: BailingMoeV2Config,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.num_experts_per_tok = config.num_experts_per_tok
+
+ if isinstance(self.config.num_shared_experts, int) and self.config.num_shared_experts > 0:
+ self.shared_experts = BailingMoeV2MLP(
+ config=self.config,
+ intermediate_size=self.config.moe_intermediate_size * self.config.num_shared_experts,
+ quant_config=quant_config,
+ reduce_results=False,
+ prefix=f"{prefix}.shared_experts",
+ )
+ else:
+ self.shared_experts = None
+
+ self.experts = SharedFusedMoE(
+ shared_experts=self.shared_experts,
+ num_experts=config.num_experts,
+ top_k=config.num_experts_per_tok,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size,
+ custom_routing_function=_unpack_multi_routing,
+ renormalize=False, # we handle normalization in the gate
+ reduce_results=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.experts",
+ )
+
+ self.experts.expert_mapping = FusedMoE.make_expert_params_mapping(
+ self.experts,
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=config.num_experts,
+ )
+
+ self.router_type = self.config.router_type
+ if self.router_type == "topN":
+ self.gate = BailingMoeV2Gate(self.config, quant_config, prefix=f"{prefix}.gate")
+ elif self.router_type == "MultiRouter":
+ self.gate = BailingMoeV2Gate(self.config, quant_config, prefix=f"{prefix}.gate")
+ self.image_gate = BailingMoeV2Gate(self.config, quant_config, prefix=f"{prefix}.image_gate")
+ self.audio_gate = BailingMoeV2Gate(self.config, quant_config, prefix=f"{prefix}.audio_gate")
+ else:
+ raise ValueError(f"Unsupported router_type: {self.router_type}")
+
+ @staticmethod
+ def _normalize_mask(
+ mask: torch.Tensor,
+ bsz: int,
+ seq_len: int,
+ name: str,
+ ) -> torch.Tensor:
+ """Validate and reshape a modality mask to [bsz*seq_len, 1] bool."""
+ N = bsz * seq_len
+ if mask.ndim == 1:
+ # vLLM path: flat tokens [N]
+ assert mask.shape[0] == N, f"{name} length {mask.shape[0]} != N={N}"
+ elif mask.ndim == 2:
+ assert mask.shape == (bsz, seq_len), f"{name} shape {mask.shape} != ({bsz}, {seq_len})"
+ elif mask.ndim == 3:
+ assert mask.shape == (bsz, seq_len, 1), f"{name} shape {mask.shape} != ({bsz}, {seq_len}, 1)"
+ else:
+ raise ValueError(f"Unsupported {name} shape: {mask.shape}")
+
+ return mask.reshape(N, 1).bool()
+
+ def forward(self, hidden_states, image_mask: torch.Tensor, audio_mask: torch.Tensor):
+ # TODO(yuanheng-zhao): revise the shapes in the flow
+ assert 2 <= hidden_states.dim() <= 3, f"{self.__class__.__name__} only supports 2D or 3D inputs"
+ input_is_2d = hidden_states.ndim == 2
+ if input_is_2d:
+ hidden_states = hidden_states.unsqueeze(0)
+
+ bsz, seq_len, h = hidden_states.shape
+
+ if self.router_type == "MultiRouter":
+ image_mask = self._normalize_mask(image_mask, bsz, seq_len, "image_mask").to(hidden_states.device)
+ audio_mask = self._normalize_mask(audio_mask, bsz, seq_len, "audio_mask").to(hidden_states.device)
+
+ # if image_mask is not None and audio_mask is not None:
+ # assert torch.logical_and(image_mask, audio_mask).sum() == 0
+
+ image_topk_idx, image_topk_weight, _ = self.image_gate(hidden_states)
+ audio_topk_idx, audio_topk_weight, _ = self.audio_gate(hidden_states)
+ topk_idx, topk_weight, _ = self.gate(hidden_states)
+
+ topk_idx = torch.where(image_mask, image_topk_idx, topk_idx)
+ topk_weight = torch.where(image_mask, image_topk_weight, topk_weight)
+ topk_idx = torch.where(audio_mask, audio_topk_idx, topk_idx)
+ topk_weight = torch.where(audio_mask, audio_topk_weight, topk_weight)
+ else:
+ topk_idx, topk_weight, _ = self.gate(hidden_states)
+
+ # Pack pre-computed routing into a single tensor
+ packed_routing = torch.cat(
+ [
+ topk_weight.to(hidden_states.dtype),
+ topk_idx.to(hidden_states.dtype),
+ ],
+ dim=-1,
+ )
+
+ # SharedFusedMoE expects 2D hidden_states
+ hidden_states_2d = hidden_states.view(-1, h)
+ result = self.experts(hidden_states_2d, packed_routing)
+
+ if self.shared_experts is not None:
+ shared_output, fused_out = result
+ else:
+ shared_output, fused_out = None, result
+
+ final_hidden_states = fused_out + shared_output if shared_output is not None else fused_out
+
+ final_hidden_states = final_hidden_states.view(bsz, seq_len, h)
+
+ return final_hidden_states.squeeze(0) if input_is_2d else final_hidden_states
+
+
+class BailingMoeV2Attention(nn.Module):
+ """Multi-headed attention using vLLM's Attention layer with 3D RoPE support."""
+
+ def __init__(
+ self,
+ config: BailingMoeV2Config,
+ layer_idx: int,
+ cache_config: CacheConfig | None = None,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.num_kv_heads = config.num_key_value_heads
+ self.head_dim = config.head_dim
+
+ tp_size = get_tensor_model_parallel_world_size()
+ assert self.num_heads % tp_size == 0
+ self.num_heads = self.num_heads // tp_size
+ self.num_kv_heads = max(1, self.num_kv_heads // tp_size)
+
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim**-0.5
+
+ partial_rotary_factor = config.partial_rotary_factor
+ self.rope_dim = int(self.head_dim * partial_rotary_factor)
+
+ total_num_heads = config.num_attention_heads
+ total_num_kv_heads = config.num_key_value_heads
+ self.qkv_proj = QKVParallelLinear(
+ self.hidden_size,
+ self.head_dim,
+ total_num_heads,
+ total_num_kv_heads,
+ bias=config.use_qkv_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj",
+ )
+
+ self.dense = RowParallelLinear(
+ total_num_heads * self.head_dim,
+ self.hidden_size,
+ bias=config.use_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.dense",
+ )
+
+ # apply vLLM RMSNorm here rather than BailingMoeV2RMSNorm, diff might exist
+ self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ # 3D Rotary embeddings for multimodal
+ if config.rope_scaling is None:
+ raise ValueError("rope_scaling must not be None")
+
+ rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ mrope_section = config.rope_scaling.get("mrope_section", [8, 12, 12])
+
+ if rope_type == "video_rope":
+ # Ming-specific video_rope with custom H/W interleaving
+ self.rotary_emb = MingVideoRopeMRotaryEmbedding(
+ head_size=self.head_dim,
+ rotary_dim=self.rope_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ is_neox_style=True,
+ dtype=torch.get_default_dtype(),
+ mrope_section=mrope_section,
+ )
+ else:
+ # Standard m_rope (rope_type "3D", "default", or None)
+ rope_scaling = dict(config.rope_scaling)
+ rope_scaling["rope_type"] = "default" # normalize for get_rope dispatch
+ rope_scaling["mrope_section"] = mrope_section
+ self.rotary_emb = get_rope(
+ head_size=self.head_dim,
+ max_position=config.max_position_embeddings,
+ is_neox_style=True,
+ rope_parameters={
+ "rope_theta": config.rope_theta,
+ "partial_rotary_factor": config.partial_rotary_factor,
+ **rope_scaling,
+ },
+ )
+
+ self.attn = Attention(
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn",
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ """Forward pass for attention with 3D RoPE.
+
+ Args:
+ positions: Position IDs, shape (3, num_tokens) for 3D rope
+ or (num_tokens,) for text-only
+ hidden_states: Input hidden states, shape (num_tokens, hidden_size)
+
+ Returns:
+ Attention output tensor, shape (num_tokens, hidden_size)
+ """
+ qkv, _ = self.qkv_proj(hidden_states)
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+ num_tokens = q.shape[0]
+ q = self.q_norm(q.view(num_tokens, self.num_heads, self.head_dim)).view(num_tokens, self.q_size)
+ k = self.k_norm(k.view(num_tokens, self.num_kv_heads, self.head_dim)).view(num_tokens, self.kv_size)
+
+ q, k = self.rotary_emb(positions, q, k)
+
+ attn_output = self.attn(q, k, v)
+
+ output, _ = self.dense(attn_output)
+ return output
+
+
+class BailingMoeV2DecoderLayer(nn.Module):
+ """Decoder layer with attention and MoE MLP."""
+
+ def __init__(
+ self,
+ config: BailingMoeV2Config,
+ layer_idx: int,
+ cache_config: CacheConfig | None = None,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.layer_idx = layer_idx
+
+ self.attention = BailingMoeV2Attention(
+ config=config,
+ layer_idx=layer_idx,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attention",
+ )
+
+ # MLP or MoE based on layer index
+ if config.num_experts is not None and layer_idx >= config.first_k_dense_replace:
+ self.mlp = BailingMoeV2SparseMoeBlock(
+ config=config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ )
+ self.is_moe = True
+ else:
+ self.mlp = BailingMoeV2MLP(
+ config=config,
+ intermediate_size=config.intermediate_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ )
+ self.is_moe = False
+
+ # apply vLLM RMSNorm to replace BailingMoeV2RMSNorm, diff might exist
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor | None,
+ image_mask: torch.Tensor | None = None,
+ audio_mask: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Forward pass for decoder layer.
+
+ Args:
+ positions: Position IDs
+ hidden_states: Input hidden states
+ residual: Residual connection from previous layer
+ image_mask: Mask for image tokens (for MultiRouter MoE)
+ audio_mask: Mask for audio tokens (for MultiRouter MoE)
+
+ Returns:
+ Tuple of (hidden_states, residual)
+ """
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ hidden_states = self.attention(
+ positions=positions,
+ hidden_states=hidden_states,
+ )
+
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+
+ if self.is_moe:
+ hidden_states = self.mlp(hidden_states, image_mask, audio_mask)
+ else:
+ # Dense MLP only takes hidden_states (no routing masks)
+ hidden_states = self.mlp(hidden_states)
+
+ return hidden_states, residual
+
+
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ "image_mask": 0,
+ "audio_mask": 0,
+ }
+)
+class BailingMoeV2Model(nn.Module):
+ """BailingMoeV2 Model adapted from:
+
+ Ming repo BailingMoeV2Model
+ https://github.com/inclusionAI/Ming/blob/2a0c02ae3130190160c215f89fce7de3005db483/modeling_bailing_moe_v2.py
+ vLLM repo BailingMoeModel
+ https://github.com/vllm-project/vllm/blob/7291d1b288558d48508e1a17c37b0aa170332264/vllm/model_executor/models/bailing_moe.py
+ """
+
+ def __init__(
+ self,
+ *,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ ):
+ super().__init__()
+
+ # BailingMoeV2Config
+ config = vllm_config.model_config.hf_text_config
+
+ cache_config = vllm_config.cache_config
+ quant_config = vllm_config.quant_config
+
+ self.config = config
+ self.quant_config = quant_config
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
+
+ if get_pp_group().is_first_rank or (self.tie_word_embeddings and get_pp_group().is_last_rank):
+ self.word_embeddings = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.word_embeddings",
+ )
+ else:
+ self.word_embeddings = PPMissingLayer()
+
+ # Decoder layers with later pipeline parallelism support
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ config.num_hidden_layers,
+ lambda prefix: BailingMoeV2DecoderLayer(
+ config=config,
+ layer_idx=int(prefix.split(".")[-1]),
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=prefix,
+ ),
+ prefix=f"{prefix}.layers",
+ )
+
+ if get_pp_group().is_last_rank:
+ # apply vLLM RMSNorm to replace BailingMoeV2RMSNorm, diff might exist
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer()
+
+ self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size
+ )
+
+ def get_input_embeddings(self):
+ return self.word_embeddings
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ image_mask: torch.Tensor | None = None,
+ audio_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor | IntermediateTensors:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.word_embeddings(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+
+ for layer in self.layers[self.start_layer : self.end_layer]:
+ hidden_states, residual = layer(
+ positions,
+ hidden_states,
+ residual,
+ image_mask=image_mask,
+ audio_mask=audio_mask,
+ )
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors({"hidden_states": hidden_states, "residual": residual})
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+
+class BailingMoeV2ForCausalLM(nn.Module, CustomProcessMixin):
+ """BailingMoeV2 model for causal language modeling, adapted for vLLM.
+
+ Inherits from CustomProcessMixin to support custom preprocessing and postprocessing
+ for integration with omni model pipelines.
+ """
+
+ packed_modules_mapping = {
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+ "gate_up_proj": ["gate_proj", "up_proj"],
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+
+ # BailingMoeV2Config
+ config = vllm_config.model_config.hf_text_config
+ quant_config = vllm_config.quant_config
+
+ self.config = config
+ self.quant_config = quant_config
+
+ self.model = BailingMoeV2Model(
+ vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"),
+ )
+
+ self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
+ if get_pp_group().is_last_rank:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "lm_head"),
+ )
+ if self.tie_word_embeddings:
+ self.lm_head.weight = self.model.word_embeddings.weight
+ else:
+ self.lm_head = PPMissingLayer()
+
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+ self.sampler = Sampler()
+ self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ image_mask: torch.Tensor | None = None,
+ audio_mask: torch.Tensor | None = None,
+ ):
+ hidden_states = self.model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ image_mask=image_mask,
+ audio_mask=audio_mask,
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata,
+ ) -> torch.Tensor | None:
+ logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
+ return logits
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata,
+ ) -> SamplerOutput | None:
+ next_tokens = self.sampler(logits, sampling_metadata)
+ return next_tokens
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, weight_name, shard_id)
+ # BailingMoE stores fused QKV in checkpoint as query_key_value
+ ("qkv_proj", "query_key_value", None),
+ # Dense MLP and shared_experts gate/up are stored separately
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ # Gate router linear layers: checkpoint `{r}.weight` -> model `{r}.gate.weight`
+ gate_name_mapper = WeightsMapper(
+ orig_to_new_substr={f".{r}.weight": f".{r}.gate.weight" for r in ("gate", "image_gate", "audio_gate")}
+ )
+
+ # FusedMoE expert params mapping is identical across all MoE layers
+ expert_params_mapping: list[tuple[str, str, int, str]] = []
+ for layer in self.model.layers:
+ if hasattr(layer, "mlp") and hasattr(layer.mlp, "experts"):
+ expert_params_mapping = layer.mlp.experts.expert_mapping
+ break
+
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in gate_name_mapper.apply(weights):
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name or "mlp.experts" in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict.get(name)
+ if param is not None:
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ loaded_params.add(name)
+ break
+ else:
+ for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict.get(name)
+ if param is not None:
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id)
+ loaded_params.add(name)
+ break
+ else:
+ param = params_dict.get(name)
+ if param is not None:
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+
+ return loaded_params
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/projectors.py b/vllm_omni/model_executor/models/ming_flash_omni/projectors.py
new file mode 100644
index 00000000000..42e53d1c635
--- /dev/null
+++ b/vllm_omni/model_executor/models/ming_flash_omni/projectors.py
@@ -0,0 +1,184 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright (c) Ant Group. All rights reserved.
+# Adapted from Ming repository modeling_bailingmm2.py
+# https://github.com/inclusionAI/Ming
+
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+logger = init_logger(__name__)
+
+
+class Transpose(nn.Module):
+ """Used in nn.Sequential pipelines."""
+
+ def __init__(self, dim0: int, dim1: int):
+ super().__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x.transpose(self.dim0, self.dim1)
+
+
+class VisionProjector(nn.Module):
+ """MLP projector from vision encoder output to LLM hidden space.
+
+ Args:
+ vision_dim: Vision encoder output dimension (out_hidden_size).
+ llm_dim: LLM hidden dimension.
+ mlp_depth: Number of linear layers (>= 1).
+ """
+
+ def __init__(self, vision_dim: int, llm_dim: int, mlp_depth: int = 1):
+ super().__init__()
+ layers: list[nn.Module] = [nn.Linear(vision_dim, llm_dim)]
+ for _ in range(1, mlp_depth):
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(llm_dim, llm_dim))
+ self.proj = nn.Sequential(*layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Project vision features.
+
+ Args:
+ x: [seq_len, vision_dim] or [B, seq_len, vision_dim]
+
+ Returns:
+ Projected features with last dim = llm_dim.
+ """
+ return self.proj(x)
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ if not name.startswith("proj."):
+ name = f"proj.{name}"
+ if name not in params_dict:
+ logger.warning("Skipping unknown vision projector weight: %s", name)
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+class AudioProjector(nn.Module):
+ """Projector for audio features.
+
+ Args:
+ audio_dim: Audio encoder output dimension (n_state).
+ llm_dim: LLM hidden dimension.
+ ds_kernel_size: Conv1d kernel size for downsampling.
+ ds_stride: Conv1d stride for downsampling.
+ mlp_depth: Total number of projection layers (>= 1).
+ """
+
+ def __init__(
+ self,
+ audio_dim: int,
+ llm_dim: int,
+ ds_kernel_size: int = 3,
+ ds_stride: int = 2,
+ mlp_depth: int = 1,
+ ):
+ super().__init__()
+ self.ds_kernel_size = ds_kernel_size
+ self.ds_stride = ds_stride
+
+ layers: list[nn.Module] = [
+ nn.Conv1d(
+ audio_dim,
+ llm_dim,
+ kernel_size=ds_kernel_size,
+ stride=ds_stride,
+ padding=ds_kernel_size // 2,
+ ),
+ Transpose(-1, -2), # [B, llm_dim, T'] -> [B, T', llm_dim]
+ ]
+ for _ in range(1, mlp_depth):
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(llm_dim, llm_dim))
+ self.proj = nn.Sequential(*layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Project audio features with temporal downsampling.
+
+ Args:
+ x: [B, T, audio_dim] audio encoder output (channel-last).
+
+ Returns:
+ [B, T', llm_dim] projected features (channel-last),
+ where T' = (T - ds_kernel_size + 2*(ds_kernel_size//2)) // ds_stride + 1.
+ """
+ # Conv1d expects [B, C, T], so transpose input
+ x = x.transpose(-1, -2) # [B, audio_dim, T]
+ return self.proj(x)
+
+ def forward_packed(
+ self,
+ packed: torch.Tensor,
+ encoded_lens: list[int],
+ ) -> tuple[torch.Tensor, list[int]]:
+ """Project packed audio features from the Whisper encoder.
+
+ Args:
+ packed: [total_T', audio_dim] packed encoder output.
+ encoded_lens: Length of each clip after Whisper encoding.
+
+ Returns:
+ Tuple of:
+ - [total_T'', llm_dim] packed projected features.
+ - List of projected lengths per clip.
+ """
+ conv1d = self.proj[0]
+ mlp = self.proj[2:]
+
+ # Split packed tensor per clip for Conv1d
+ segments = packed.split(encoded_lens)
+ conv_segments = []
+ proj_lens: list[int] = []
+ for seg in segments:
+ out = conv1d(seg.transpose(0, 1).unsqueeze(0)) # [1, llm_dim, T'_i]
+ out = out.squeeze(0).transpose(0, 1) # [T'_i, llm_dim]
+ conv_segments.append(out)
+ proj_lens.append(out.shape[0])
+
+ packed_proj = torch.cat(conv_segments, dim=0) # [total_T'', llm_dim]
+ packed_proj = mlp(packed_proj)
+ return packed_proj, proj_lens
+
+ def compute_output_length(self, input_length: torch.Tensor) -> torch.Tensor:
+ """Compute output sequence length after Conv1d downsampling.
+
+ Args:
+ input_length: Original mel spectrogram lengths.
+
+ Returns:
+ Output lengths after both convolutions.
+ """
+ length = (input_length - 3 + 2 * 1) // 2 + 1
+ length = (length - self.ds_kernel_size + 2 * (self.ds_kernel_size // 2)) // self.ds_stride + 1
+ return length
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ if not name.startswith("proj."):
+ name = f"proj.{name}"
+ if name not in params_dict:
+ logger.warning("Skipping unknown audio projector weight: %s", name)
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
diff --git a/vllm_omni/model_executor/models/ming_flash_omni/vision_encoder.py b/vllm_omni/model_executor/models/ming_flash_omni/vision_encoder.py
new file mode 100644
index 00000000000..7976d76ce8d
--- /dev/null
+++ b/vllm_omni/model_executor/models/ming_flash_omni/vision_encoder.py
@@ -0,0 +1,125 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from Ming repository qwen3_moe_vit.py
+# https://github.com/inclusionAI/Ming
+
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.models.qwen3_omni_moe_thinker import (
+ Qwen3Omni_VisionTransformer,
+)
+from vllm.model_executor.models.utils import WeightsMapper
+
+logger = init_logger(__name__)
+
+
+def _adapt_vision_config(vision_config):
+ # Adapt Ming's Qwen3VLMoeVisionConfig to be compatible with vLLM's
+ # Qwen3Omni_VisionTransformer expectations.
+ if not hasattr(vision_config, "image_size") or vision_config.image_size is None:
+ if hasattr(vision_config, "num_position_embeddings") and vision_config.num_position_embeddings:
+ import math
+
+ num_grid = int(math.sqrt(vision_config.num_position_embeddings))
+ vision_config.image_size = num_grid * vision_config.patch_size
+ else:
+ vision_config.image_size = vision_config.patch_size * 14 # fallback
+
+ if not hasattr(vision_config, "apply_vit_abs_pos_embed"):
+ vision_config.apply_vit_abs_pos_embed = True
+
+ return vision_config
+
+
+class MingVisionEncoder(nn.Module):
+ """**Wrapper** around vLLM's Qwen3Omni_VisionTransformer for Ming."""
+
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_substr={
+ "deepstack_merger_list.": "merger_list.",
+ "merger.norm.": "merger.ln_q.",
+ "merger.linear_fc1.": "merger.mlp.0.",
+ "merger.linear_fc2.": "merger.mlp.2.",
+ }
+ )
+
+ def __init__(
+ self,
+ vision_config,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ adapted_config = _adapt_vision_config(vision_config)
+ norm_eps = 1e-6
+ self.encoder = Qwen3Omni_VisionTransformer(
+ vision_config=adapted_config,
+ norm_eps=norm_eps,
+ quant_config=quant_config,
+ prefix=f"{prefix}.encoder",
+ )
+ self.image_emb_dim = vision_config.out_hidden_size
+ self.use_deepstack = (
+ hasattr(vision_config, "deepstack_visual_indexes") and vision_config.deepstack_visual_indexes is not None
+ )
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return self.encoder.dtype
+
+ @property
+ def device(self) -> torch.device:
+ return self.encoder.device
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ grid_thw: torch.Tensor,
+ ) -> torch.Tensor:
+ """forward method of the vision encoder.
+
+ Args:
+ pixel_values: Flattened pixel values.
+ grid_thw: [num_images, 3] tensor of (t, h, w) grid sizes.
+
+ Returns:
+ If deepstack is enabled, returns concatenated multi-scale features
+ along the feature dim: [seq_len, hidden_size * (1 + num_deepstack)].
+ Otherwise returns [seq_len, hidden_size].
+ """
+ return self.encoder(pixel_values, grid_thw=grid_thw)
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ import re
+
+ def _remap_merger_list_inner(name: str) -> str:
+ name = re.sub(r"(merger_list\.\d+)\.norm\.", r"\1.ln_q.", name)
+ name = re.sub(r"(merger_list\.\d+)\.linear_fc1\.", r"\1.mlp.0.", name)
+ name = re.sub(r"(merger_list\.\d+)\.linear_fc2\.", r"\1.mlp.2.", name)
+
+ return name
+
+ remapped_weights = self.hf_to_vllm_mapper.apply(weights)
+ remapped_weights = ((_remap_merger_list_inner(name), tensor) for name, tensor in remapped_weights)
+ loaded_params = self.encoder.load_weights(remapped_weights)
+
+ loaded_params = {f"encoder.{loaded_param}" for loaded_param in loaded_params}
+
+ return loaded_params
diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py b/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py
new file mode 100644
index 00000000000..b44d08eb32a
--- /dev/null
+++ b/vllm_omni/model_executor/models/qwen2_5_omni/pipeline.py
@@ -0,0 +1,78 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Qwen2.5-Omni pipeline topology (frozen).
+
+Stage 0: Thinker — multimodal understanding + text generation
+Stage 1: Talker — text embeddings → speech tokens
+Stage 2: Code2Wav — speech tokens → audio waveform
+"""
+
+from vllm_omni.config.stage_config import (
+ PipelineConfig,
+ StageExecutionType,
+ StagePipelineConfig,
+)
+
+_PROC = "vllm_omni.model_executor.stage_input_processors.qwen2_5_omni"
+
+QWEN2_5_OMNI_PIPELINE = PipelineConfig(
+ model_type="qwen2_5_omni",
+ model_arch="Qwen2_5OmniForConditionalGeneration",
+ stages=(
+ StagePipelineConfig(
+ stage_id=0,
+ model_stage="thinker",
+ execution_type=StageExecutionType.LLM_AR,
+ input_sources=(),
+ final_output=True,
+ final_output_type="text",
+ owns_tokenizer=True,
+ requires_multimodal_data=True,
+ engine_output_type="latent",
+ sampling_constraints={"detokenize": True},
+ ),
+ StagePipelineConfig(
+ stage_id=1,
+ model_stage="talker",
+ execution_type=StageExecutionType.LLM_AR,
+ input_sources=(0,),
+ engine_output_type="latent",
+ custom_process_input_func=f"{_PROC}.thinker2talker",
+ sampling_constraints={
+ "detokenize": True,
+ "stop_token_ids": [8294],
+ },
+ ),
+ StagePipelineConfig(
+ stage_id=2,
+ model_stage="code2wav",
+ execution_type=StageExecutionType.LLM_GENERATION,
+ input_sources=(1,),
+ final_output=True,
+ final_output_type="audio",
+ engine_output_type="audio",
+ sampling_constraints={"detokenize": True},
+ ),
+ ),
+)
+
+
+# Single-stage thinker-only variant for the abort test.
+QWEN2_5_OMNI_THINKER_ONLY_PIPELINE = PipelineConfig(
+ model_type="qwen2_5_omni_thinker_only",
+ model_arch="Qwen2_5OmniForConditionalGeneration",
+ stages=(
+ StagePipelineConfig(
+ stage_id=0,
+ model_stage="thinker",
+ execution_type=StageExecutionType.LLM_AR,
+ input_sources=(),
+ final_output=True,
+ final_output_type="text",
+ owns_tokenizer=True,
+ requires_multimodal_data=True,
+ engine_output_type="latent",
+ sampling_constraints={"detokenize": True},
+ ),
+ ),
+)
diff --git a/vllm_omni/model_executor/models/qwen3_omni/pipeline.py b/vllm_omni/model_executor/models/qwen3_omni/pipeline.py
new file mode 100644
index 00000000000..1c69ec79570
--- /dev/null
+++ b/vllm_omni/model_executor/models/qwen3_omni/pipeline.py
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Qwen3-Omni-MoE pipeline topology (frozen).
+
+Stage 0: Thinker — multimodal understanding + text generation
+Stage 1: Talker — text embeddings → RVQ codec codes
+Stage 2: Code2Wav — RVQ codes → audio waveform
+"""
+
+from vllm_omni.config.stage_config import (
+ PipelineConfig,
+ StageExecutionType,
+ StagePipelineConfig,
+)
+
+_PROC = "vllm_omni.model_executor.stage_input_processors.qwen3_omni"
+
+QWEN3_OMNI_PIPELINE = PipelineConfig(
+ model_type="qwen3_omni_moe",
+ model_arch="Qwen3OmniMoeForConditionalGeneration",
+ stages=(
+ StagePipelineConfig(
+ stage_id=0,
+ model_stage="thinker",
+ execution_type=StageExecutionType.LLM_AR,
+ input_sources=(),
+ final_output=True,
+ final_output_type="text",
+ owns_tokenizer=True,
+ requires_multimodal_data=True,
+ hf_config_name="thinker_config",
+ engine_output_type="latent",
+ custom_process_next_stage_input_func=(f"{_PROC}.thinker2talker_async_chunk"),
+ sampling_constraints={"detokenize": True},
+ ),
+ StagePipelineConfig(
+ stage_id=1,
+ model_stage="talker",
+ execution_type=StageExecutionType.LLM_AR,
+ input_sources=(0,),
+ hf_config_name="talker_config",
+ engine_output_type="latent",
+ custom_process_input_func=f"{_PROC}.thinker2talker",
+ custom_process_next_stage_input_func=(f"{_PROC}.talker2code2wav_async_chunk"),
+ sampling_constraints={
+ "detokenize": False,
+ "stop_token_ids": [2150],
+ },
+ ),
+ StagePipelineConfig(
+ stage_id=2,
+ model_stage="code2wav",
+ execution_type=StageExecutionType.LLM_GENERATION,
+ input_sources=(1,),
+ final_output=True,
+ final_output_type="audio",
+ hf_config_name="thinker_config",
+ engine_output_type="audio",
+ custom_process_input_func=f"{_PROC}.talker2code2wav",
+ sampling_constraints={"detokenize": True},
+ ),
+ ),
+)
diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
index 7df69479734..f06ecf41d22 100644
--- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
+++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py
@@ -180,6 +180,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
"trailing_text_hidden",
"tts_pad_embed_projected",
}
+ # Keys that need to be accumulated across streaming inputs
+ self.streaming_accumulated_keys: set[str] = {
+ "thinker_prefill_embeddings",
+ "thinker_hidden_states",
+ }
elif self.model_stage == "code2wav":
self.enable_update_additional_information = True
@@ -891,10 +896,11 @@ def _thinker_to_talker_prefill(
Returns:
(input_ids, input_embeds) for talker
"""
+ target_len = thinker_result_ids.shape[-1]
im_start_indexes = torch.cat(
(
torch.nonzero(input_ids[0] == self.config.im_start_token_id).squeeze(),
- torch.tensor([thinker_result_ids.shape[-1]], device=input_ids.device, dtype=input_ids.dtype),
+ torch.tensor([target_len], device=input_ids.device, dtype=input_ids.dtype),
),
dim=-1,
) # Shape [n_starts + 1]; Take batch 0 since batched inference is not supported here.
@@ -1029,8 +1035,35 @@ def talker_preprocess_decode(
return last_talker_hidden, text_step, update_dict
def _get_talker_user_parts(self, im_start_index, segment_end_index, multimodal_mask, thinker_hidden, thinker_embed):
+ clamped = min(
+ segment_end_index,
+ multimodal_mask.shape[0],
+ thinker_hidden.shape[0],
+ thinker_embed.shape[0],
+ )
+ if clamped < segment_end_index:
+ logger.warning(
+ "_get_talker_user_parts: segment_end_index %d clamped to %d "
+ "(embed=%d, hidden=%d, mask=%d). "
+ "This usually means _merge_pd_embeddings failed to merge "
+ "prefill embeddings – check PD prefill_mm keys.",
+ segment_end_index,
+ clamped,
+ thinker_embed.shape[0],
+ thinker_hidden.shape[0],
+ multimodal_mask.shape[0],
+ )
+ segment_end_index = clamped
+ seg_len = segment_end_index - im_start_index
+ if seg_len <= 0:
+ return torch.empty(
+ (0, self.config.talker_config.text_config.hidden_size),
+ device=thinker_hidden.device,
+ dtype=torch.bfloat16,
+ )
+
user_talker_part = torch.empty(
- (segment_end_index - im_start_index, self.config.talker_config.text_config.hidden_size),
+ (seg_len, self.config.talker_config.text_config.hidden_size),
device=thinker_hidden.device,
dtype=torch.bfloat16,
)
diff --git a/vllm_omni/model_executor/models/qwen3_tts/pipeline.py b/vllm_omni/model_executor/models/qwen3_tts/pipeline.py
new file mode 100644
index 00000000000..5051715ceac
--- /dev/null
+++ b/vllm_omni/model_executor/models/qwen3_tts/pipeline.py
@@ -0,0 +1,48 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Qwen3-TTS pipeline: Talker (text → RVQ codec) → Code2Wav (codec → audio).
+
+Chunked vs end-to-end mode is dispatched from ``deploy.async_chunk``.
+"""
+
+from vllm_omni.config.stage_config import (
+ PipelineConfig,
+ StageExecutionType,
+ StagePipelineConfig,
+)
+
+_PROC = "vllm_omni.model_executor.stage_input_processors.qwen3_tts"
+
+QWEN3_TTS_PIPELINE = PipelineConfig(
+ model_type="qwen3_tts",
+ # Pipeline-level default; the code2wav stage overrides per-stage below.
+ model_arch="Qwen3TTSTalkerForConditionalGeneration",
+ stages=(
+ StagePipelineConfig(
+ stage_id=0,
+ model_stage="qwen3_tts",
+ execution_type=StageExecutionType.LLM_AR,
+ input_sources=(),
+ owns_tokenizer=True,
+ engine_output_type="latent",
+ async_chunk_process_next_stage_input_func=(f"{_PROC}.talker2code2wav_async_chunk"),
+ sampling_constraints={
+ "detokenize": False,
+ "stop_token_ids": [2150],
+ },
+ ),
+ StagePipelineConfig(
+ stage_id=1,
+ model_stage="code2wav",
+ execution_type=StageExecutionType.LLM_GENERATION,
+ input_sources=(0,),
+ final_output=True,
+ final_output_type="audio",
+ engine_output_type="audio",
+ model_arch="Qwen3TTSCode2Wav",
+ sync_process_input_func=f"{_PROC}.talker2code2wav",
+ sampling_constraints={"detokenize": True},
+ extras={"tts_args": {"max_instructions_length": 500}},
+ ),
+ ),
+)
diff --git a/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml b/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml
deleted file mode 100644
index fd8ea3a3f4e..00000000000
--- a/vllm_omni/model_executor/models/qwen3_tts/pipeline.yaml
+++ /dev/null
@@ -1,93 +0,0 @@
-model_type: qwen3_tts
-async_chunk: true
-
-stages:
- - stage_id: 0
- model_stage: qwen3_tts
- stage_type: llm
- is_comprehension: true
- input_sources: []
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 10
- model_arch: Qwen3TTSTalkerForConditionalGeneration
- hf_overrides:
- architectures: [Qwen3TTSTalkerForConditionalGeneration]
- enforce_eager: false
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: latent
- gpu_memory_utilization: 0.08
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 512
- max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 1
- model_stage: code2wav
- stage_type: llm
- input_sources: [0]
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- final_output: true
- final_output_type: audio
- runtime:
- devices: "0"
- engine_args:
- max_num_seqs: 1
- model_arch: Qwen3TTSCode2Wav
- hf_overrides:
- architectures: [Qwen3TTSCode2Wav]
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: audio
- gpu_memory_utilization: 0.08
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 65536
- max_model_len: 65536
- input_connectors:
- from_stage_0: connector_of_shared_memory
- tts_args:
- max_instructions_length: 500
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: true
- repetition_penalty: 1.0
-
-connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- codec_streaming: true
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- codec_chunk_frames: 25
- # Match the decoder sliding attention window to avoid chunk-boundary noise.
- codec_left_context_frames: 72
-
-edges:
- - from: 0
- to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
index 6b7b688f15a..d9cbcf7d4ef 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_talker.py
@@ -23,6 +23,7 @@
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.qwen3 import Qwen3Model
from vllm.model_executor.models.utils import AutoWeightsLoader, PPMissingLayer, WeightsMapper, maybe_prefix
+from vllm.multimodal.audio import AudioResampler
from vllm.sequence import IntermediateTensors
from vllm_omni.model_executor.models.output_templates import OmniOutput
@@ -1094,9 +1095,8 @@ def _extract_speaker_embedding(self, wav: np.ndarray, sr: int) -> torch.Tensor:
# Resample to 24kHz for speaker encoder.
target_sr = int(getattr(self.config.speaker_encoder_config, "sample_rate", 24000))
if sr != target_sr:
- from vllm.multimodal.audio import resample_audio_resampy
-
- wav = resample_audio_resampy(wav.astype(np.float32), orig_sr=int(sr), target_sr=target_sr)
+ resampler = AudioResampler(target_sr=target_sr)
+ wav = resampler.resample(wav.astype(np.float32), orig_sr=int(sr))
sr = target_sr
# Follow official implementation: mel_spectrogram expects 24kHz.
diff --git a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py
index 3db5cfd1b82..14bfbc5eedf 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/qwen3_tts_tokenizer.py
@@ -22,7 +22,7 @@
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoConfig, AutoFeatureExtractor, AutoModel
-from vllm.multimodal.audio import resample_audio_resampy
+from vllm.multimodal.audio import AudioResampler
from vllm.multimodal.media.audio import load_audio as _load_audio_file
from .tokenizer_12hz.configuration_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Config
@@ -161,7 +161,8 @@ def load_audio(
audio = np.mean(audio, axis=-1)
if sr != target_sr:
- audio = resample_audio_resampy(audio, orig_sr=sr, target_sr=target_sr)
+ resampler = AudioResampler(target_sr=target_sr)
+ audio = resampler.resample(audio, orig_sr=sr)
return audio.astype(np.float32)
@@ -209,7 +210,8 @@ def _normalize_audio_inputs(
if a.ndim > 1:
a = np.mean(a, axis=-1)
if int(sr) != target_sr:
- a = resample_audio_resampy(a.astype(np.float32), orig_sr=int(sr), target_sr=target_sr)
+ resampler = AudioResampler(target_sr=target_sr)
+ a = resampler.resample(a.astype(np.float32), orig_sr=int(sr))
out.append(a.astype(np.float32))
return out
diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py
index 92cecbff107..f7e664c74d6 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py
@@ -23,10 +23,11 @@
import torchaudio.compliance.kaldi as kaldi
from torch import Tensor
+from vllm_omni.model_executor.models.whisper_utils import Conv1d, ConvTranspose1d
from vllm_omni.utils.audio import mel_filter_bank, peak_normalize
from .core_vq import DistributedGroupResidualVectorQuantization
-from .whisper_encoder import Conv1d, ConvTranspose1d, WhisperEncoder
+from .whisper_encoder import WhisperEncoder
def dynamic_range_compression_torch(x, c=1, clip_val=1e-5):
diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py
index 8464f53c9df..7756720b2ba 100644
--- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py
+++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py
@@ -23,6 +23,7 @@
from torch import Tensor, nn
from vllm_omni.diffusion.attention.backends.utils.fa import HAS_FLASH_ATTN, flash_attn_varlen_func
+from vllm_omni.model_executor.models.whisper_utils import Conv1d, Linear, sinusoids
from vllm_omni.utils.audio import mel_filter_bank
N_FFT = 400
@@ -102,30 +103,6 @@ def get_mel_audio(audio, padding=False, audio_vq_ds_rate=1, n_mels=128):
return mel
-def sinusoids(length, channels, max_timescale=10000):
- """Returns sinusoids for positional embedding"""
- assert channels % 2 == 0
- log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
- inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
- scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
- return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
-
-
-class Conv1d(nn.Conv1d):
- def _conv_forward(self, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor:
- return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
-
-
-class ConvTranspose1d(nn.ConvTranspose1d):
- def _conv_forward(self, x: Tensor, weight: Tensor, bias: Tensor | None) -> Tensor:
- return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
-
-
-class Linear(nn.Linear):
- def forward(self, x: Tensor) -> Tensor:
- return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))
-
-
class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int, use_flash_attention: bool = True):
super().__init__()
diff --git a/vllm_omni/model_executor/models/registry.py b/vllm_omni/model_executor/models/registry.py
index 3407b428695..5a466dbd62e 100644
--- a/vllm_omni/model_executor/models/registry.py
+++ b/vllm_omni/model_executor/models/registry.py
@@ -174,6 +174,23 @@
"dynin_omni",
"DyninOmniForConditionalGeneration",
),
+ ## Ming-flash-omni-2.0
+ "MingFlashOmniForConditionalGeneration": (
+ "ming_flash_omni",
+ "ming_flash_omni",
+ "MingFlashOmniForConditionalGeneration",
+ ),
+ "MingFlashOmniThinkerForConditionalGeneration": (
+ "ming_flash_omni",
+ "ming_flash_omni_thinker",
+ "MingFlashOmniThinkerForConditionalGeneration",
+ ),
+ # Alias: HF repo currently ships this architecture name in config.json
+ "BailingMM2NativeForConditionalGeneration": (
+ "ming_flash_omni",
+ "ming_flash_omni",
+ "MingFlashOmniForConditionalGeneration",
+ ),
}
diff --git a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
index b666e41ebc9..3724528898a 100644
--- a/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
+++ b/vllm_omni/model_executor/models/voxcpm2/voxcpm2_talker.py
@@ -13,6 +13,7 @@
import copy
import dataclasses
import logging
+import math
import os
import time
from collections.abc import Iterable
@@ -40,6 +41,11 @@
_ENABLE_PROFILING = os.environ.get("VOXCPM2_PROFILE", "0") == "1"
+# Lower bound for the _active_states leak-warn threshold. The effective
+# threshold is max(_ACTIVE_STATE_LEAK_WARN_MIN, 4 * max_batch_size) so small
+# deployments still get a usable floor instead of a tiny noisy one.
+_ACTIVE_STATE_LEAK_WARN_MIN = 512
+
def is_cjk_char(c: str) -> bool:
"""Check if a character is a CJK ideograph."""
@@ -80,6 +86,44 @@ def split_multichar_chinese(token_ids: list[int], split_map: dict[int, list[int]
return result
+def build_voxcpm2_prompt(
+ hf_config: Any,
+ tokenizer: Any,
+ split_map: dict[int, list[int]],
+ text: str,
+ ref_audio: Any | None = None,
+ ref_sr: int | None = None,
+ ref_text: str | None = None,
+) -> dict[str, Any]:
+ """Build a VoxCPM2 prefill prompt whose ``prompt_token_ids`` length matches
+ the talker-side prefill length.
+
+ Used by both online serving (``serving_speech._build_voxcpm2_prompt``) and
+ the offline example, so the talker-side length assertion never fires.
+ """
+ ids = split_multichar_chinese(tokenizer.encode(text, add_special_tokens=True), split_map)
+ bos = tokenizer.bos_token_id
+ if ids and ids[0] == bos:
+ ids = ids[1:]
+ prefill_len = len(ids) + 1 # + audio_start
+ additional: dict[str, Any] = {"text_token_ids": [ids]}
+ if ref_audio is not None:
+ vae = hf_config.audio_vae_config
+ patch_samples = hf_config.patch_size * math.prod(vae["encoder_rates"])
+ ref_len = math.ceil(math.ceil(len(ref_audio) * vae["sample_rate"] / ref_sr) / patch_samples)
+ if ref_text is not None:
+ additional["prompt_audio"] = [[ref_audio, ref_sr]]
+ additional["prompt_text"] = [ref_text]
+ ref_ids = split_multichar_chinese(tokenizer.encode(ref_text, add_special_tokens=True), split_map)
+ if ref_ids and ref_ids[0] == bos:
+ ref_ids = ref_ids[1:]
+ prefill_len += ref_len + len(ref_ids)
+ else:
+ additional["reference_audio"] = [[ref_audio, ref_sr]]
+ prefill_len += ref_len + 2 # ref_start / ref_end
+ return {"prompt_token_ids": [1] * prefill_len, "additional_information": additional}
+
+
def _encode_raw_audio(
tts: nn.Module,
samples: list[float] | torch.Tensor,
@@ -401,6 +445,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self._results_queue: list[tuple[str, torch.Tensor | None]] = []
self._audio_queue: list[tuple[str, Any]] = []
self._deferred_cleanup_ids: set[str] = set()
+ self._active_state_warn_threshold = max(_ACTIVE_STATE_LEAK_WARN_MIN, 4 * self._max_batch_size)
+ # one-shot by design: fires at most once per process to avoid log spam.
+ self._active_state_warned = False
@property
def tts(self) -> nn.Module:
@@ -410,9 +457,20 @@ def tts(self) -> nn.Module:
# -------------------- request state management --------------------
def _get_or_create_state(self, request_id: str) -> _RequestState:
- if request_id not in self._active_states:
- self._active_states[request_id] = _RequestState(request_id=request_id)
- return self._active_states[request_id]
+ state = self._active_states.get(request_id)
+ if state is None:
+ state = _RequestState(request_id=request_id)
+ self._active_states[request_id] = state
+ if len(self._active_states) > self._active_state_warn_threshold and not self._active_state_warned:
+ logger.warning(
+ "VoxCPM2: _active_states size=%d exceeds threshold %d "
+ "(max_batch_size=%d); possible cleanup path leak",
+ len(self._active_states),
+ self._active_state_warn_threshold,
+ self._max_batch_size,
+ )
+ self._active_state_warned = True
+ return state
def _switch_to_request(self, request_id: str) -> _RequestState:
if request_id != self._current_request_id:
@@ -793,19 +851,12 @@ def _prepare_residual_prefill(self, state: _RequestState, base_lm_out: torch.Ten
tts_len = text_mask.shape[1]
scaffold_len = base_lm_out.shape[0]
-
- if scaffold_len < tts_len:
- # Voice clone / continuation: scaffold only processed vllm tokens.
- # Pad to match TTS sequence length (extra positions are masked out).
- pad = torch.zeros(
- tts_len - scaffold_len,
- base_lm_out.shape[-1],
- device=base_lm_out.device,
- dtype=base_lm_out.dtype,
- )
- enc_out = torch.cat([base_lm_out, pad], dim=0).unsqueeze(0)
- else:
- enc_out = base_lm_out.unsqueeze(0)
+ assert scaffold_len == tts_len, (
+ f"voxcpm2 prefill length mismatch: scaffold_len={scaffold_len} tts_len={tts_len}; "
+ "caller must pad prompt_token_ids to the full prefill length "
+ "(see serving_speech._build_voxcpm2_prompt or the offline example)."
+ )
+ enc_out = base_lm_out.unsqueeze(0)
prefix_feat_cond = (
feat[:, -1, ...]
@@ -1055,15 +1106,12 @@ def preprocess(
is_prefill = span_len > 1
if is_prefill:
- # Evict stale states
- pending_ids = {rid for rid, *_ in self._pending_requests}
- pending_ids.add(req_id)
- if self._current_request_id:
- pending_ids.add(self._current_request_id)
- for rid in [r for r, s in self._active_states.items() if r not in pending_ids and s.prefill_completed]:
- self._cleanup_request(rid)
-
- token_ids = input_ids.tolist()
+ # Do not evict state here: _pending_requests is a per-step prefix,
+ # not the full batch. Cleanup is driven by on_requests_finished ->
+ # _flush_deferred_cleanup (fed by vLLM scheduler._free_request via
+ # gpu_ar_model_runner.py).
+ real = info_dict.get("text_token_ids")
+ token_ids = input_ids.tolist() if real is None else real[0]
# Fail-fast: unsplit multichar Chinese IDs in input_ids means the
# serving layer didn't pre-split. Silent fixup here would cause
# input_ids/embeds length mismatch (scheduler slot count is fixed).
diff --git a/vllm_omni/model_executor/models/whisper_utils.py b/vllm_omni/model_executor/models/whisper_utils.py
new file mode 100644
index 00000000000..5aa2fc8a3ad
--- /dev/null
+++ b/vllm_omni/model_executor/models/whisper_utils.py
@@ -0,0 +1,39 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright (c) 2022 OpenAI
+#
+# Shared Whisper encoder primitives used by multiple model implementations.
+# Originally from the OpenAI Whisper codebase.
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def sinusoids(length, channels, max_timescale=10000):
+ """Returns sinusoids for positional embedding."""
+ assert channels % 2 == 0
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+
+class Conv1d(nn.Conv1d):
+ """Conv1d with automatic dtype casting for mixed precision inference."""
+
+ def _conv_forward(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None) -> torch.Tensor:
+ return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
+
+
+class ConvTranspose1d(nn.ConvTranspose1d):
+ def _conv_forward(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None) -> torch.Tensor:
+ return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
+
+
+class Linear(nn.Linear):
+ """Linear layer with automatic dtype casting for mixed precision inference."""
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype))
diff --git a/vllm_omni/model_executor/stage_configs/bagel.yaml b/vllm_omni/model_executor/stage_configs/bagel.yaml
index dfe9da1c26d..75f7c8a0637 100644
--- a/vllm_omni/model_executor/stage_configs/bagel.yaml
+++ b/vllm_omni/model_executor/stage_configs/bagel.yaml
@@ -71,10 +71,6 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
# Distributed connectors configuration (optional)
# More connectors will be supported in the future.
connectors:
@@ -104,4 +100,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
index af038f59fb8..7a0d851f0fd 100644
--- a/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
+++ b/vllm_omni/model_executor/stage_configs/bagel_multiconnector.yaml
@@ -64,10 +64,6 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
# Distributed connectors configuration (optional)
# More connectors will be supported in the future.
connectors:
@@ -104,4 +100,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/bagel_single_stage.yaml b/vllm_omni/model_executor/stage_configs/bagel_single_stage.yaml
index bb24763f906..b2d4b07b13b 100644
--- a/vllm_omni/model_executor/stage_configs/bagel_single_stage.yaml
+++ b/vllm_omni/model_executor/stage_configs/bagel_single_stage.yaml
@@ -22,6 +22,3 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
diff --git a/vllm_omni/model_executor/stage_configs/bagel_think.yaml b/vllm_omni/model_executor/stage_configs/bagel_think.yaml
index 0d2098a2034..2575e6736dd 100644
--- a/vllm_omni/model_executor/stage_configs/bagel_think.yaml
+++ b/vllm_omni/model_executor/stage_configs/bagel_think.yaml
@@ -65,9 +65,6 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
shared_memory_connector:
@@ -78,4 +75,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml b/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml
index 33002b9aa5c..4599f8b059c 100644
--- a/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml
+++ b/vllm_omni/model_executor/stage_configs/bagel_usp2.yaml
@@ -62,9 +62,6 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
shared_memory_connector:
name: SharedMemoryConnector
@@ -73,4 +70,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/bailingmm_moe_v2_lite.yaml b/vllm_omni/model_executor/stage_configs/bailingmm_moe_v2_lite.yaml
new file mode 100644
index 00000000000..b7d0aeeb742
--- /dev/null
+++ b/vllm_omni/model_executor/stage_configs/bailingmm_moe_v2_lite.yaml
@@ -0,0 +1,46 @@
+# Stage config for Ming-flash-omni-2.0
+# Stage 0: Thinker (Multimodal understanding + text generation)
+# Stage 1a: Image Generator (Text embeddings -> PIL image)
+# Stage 1b: Talker (Text embeddings -> audio waveform)
+
+async_chunk: false
+stage_args:
+ - stage_id: 0
+ stage_type: llm
+ runtime:
+ devices: "0,1,2,3"
+ max_batch_size: 1
+ engine_args:
+ model_stage: thinker
+ model_arch: MingFlashOmniForConditionalGeneration
+ # tokenizer_subdir: talker/llm
+ worker_type: ar
+ scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
+ gpu_memory_utilization: 0.9
+ enforce_eager: false
+ trust_remote_code: true
+ engine_output_type: latent
+ distributed_executor_backend: "mp"
+ enable_prefix_caching: false
+ max_num_batched_tokens: 32768
+ tensor_parallel_size: 4 # Use 4 GPUs for MoE model
+ # pipeline_parallel_size: 4
+ hf_config_name: llm_config
+ compilation_config:
+ pass_config:
+ # there's a version mismatch regarding vllm and flashinfer
+ # disable fuse allreduce for now
+ fuse_allreduce_rms: false
+ final_output: true # Can output text directly
+ final_output_type: text
+ is_comprehension: true
+ default_sampling_params:
+ temperature: 0.4
+ top_p: 0.9
+ max_tokens: 2048
+ repetition_penalty: 1.05
+ seed: 42
+ detokenize: true
+
+ # Future Stage 1a: Image Generator (Optional - not yet implemented)
+ # Future Stage 1b: Talker/TTS (Optional - not yet implemented)
diff --git a/vllm_omni/model_executor/stage_configs/cosyvoice3_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/cosyvoice3_async_chunk.yaml
index ca7e9850ae7..13419ef107f 100644
--- a/vllm_omni/model_executor/stage_configs/cosyvoice3_async_chunk.yaml
+++ b/vllm_omni/model_executor/stage_configs/cosyvoice3_async_chunk.yaml
@@ -63,9 +63,6 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
connector_of_shared_memory:
@@ -82,4 +79,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/dynin_omni.yaml b/vllm_omni/model_executor/stage_configs/dynin_omni.yaml
index 0724146aa7c..131a0d1cd70 100644
--- a/vllm_omni/model_executor/stage_configs/dynin_omni.yaml
+++ b/vllm_omni/model_executor/stage_configs/dynin_omni.yaml
@@ -67,14 +67,9 @@ stage_args:
# Top-level runtime config (concise): default windows and stage edges
runtime:
enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
edges:
- from: 0
to: 1
- window_size: -1
- from: 1
to: 2
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml
index 7259daa9ea0..4a54f8188aa 100644
--- a/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml
+++ b/vllm_omni/model_executor/stage_configs/dynin_omni_multiconnector.yaml
@@ -71,9 +71,6 @@ stage_args:
# Top-level runtime config (concise): default windows and stage edges
runtime:
enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
####
# same as Qwen2.5_omni version
# Distributed connectors configuration (optional)
@@ -108,7 +105,5 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
- from: 1
to: 2
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml b/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml
index 0b0b2785928..90f80c22d7a 100644
--- a/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml
+++ b/vllm_omni/model_executor/stage_configs/fish_speech_s2_pro.yaml
@@ -71,10 +71,6 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 16
-
connectors:
connector_of_shared_memory:
name: SharedMemoryConnector
@@ -93,4 +89,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/glm_image.yaml b/vllm_omni/model_executor/stage_configs/glm_image.yaml
index 3cc23e1e251..05ac84a7a09 100644
--- a/vllm_omni/model_executor/stage_configs/glm_image.yaml
+++ b/vllm_omni/model_executor/stage_configs/glm_image.yaml
@@ -70,11 +70,6 @@ stage_args:
# Top-level runtime config
runtime:
enabled: true
- defaults:
- window_size: -1 # Trigger downstream only after full upstream completion
- max_inflight: 1 # Process serially within each stage
-
edges:
- from: 0 # AR → Diffusion: trigger after AR completes
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml b/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml
index 719c73a9fc0..7bd66c403fc 100644
--- a/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml
+++ b/vllm_omni/model_executor/stage_configs/glm_image_muilticonnector.yaml
@@ -70,14 +70,9 @@ stage_args:
# Top-level runtime config with MultiConnector support
runtime:
enabled: true
- defaults:
- window_size: -1 # Trigger downstream only after full upstream completion
- max_inflight: 1 # Process serially within each stage
-
edges:
- from: 0 # AR → Diffusion
to: 1
- window_size: -1
# OmniConnector configuration for efficient inter-stage tensor transfer
connectors:
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml
index 203b54f2574..b68b184ec31 100644
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_i2t.yaml
@@ -39,6 +39,3 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml
index 9f6adece0fe..413e0f09cbe 100644
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_it2i.yaml
@@ -69,10 +69,6 @@ stage_args:
# Top-level runtime config
runtime:
enabled: true
- defaults:
- window_size: -1 # Trigger downstream only after full upstream completion
- max_inflight: 1 # Process serially within each stage
edges:
- from: 0 # AR → Diffusion
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml
index aeef27a9746..586b601bc5a 100644
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_moe_dit_2gpu_fp8.yaml
@@ -30,6 +30,3 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i.yaml
index a60fe9a5b5b..1d8c7f4812d 100644
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i.yaml
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i.yaml
@@ -29,6 +29,3 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i_2gpu.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i_2gpu.yaml
index e029c383623..41ed74ba62a 100644
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i_2gpu.yaml
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2i_2gpu.yaml
@@ -39,6 +39,3 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml
index 60da8e0bc7f..a0a1a0dc1c4 100644
--- a/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml
+++ b/vllm_omni/model_executor/stage_configs/hunyuan_image3_t2t.yaml
@@ -40,6 +40,3 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
diff --git a/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml
index b3c6bbbaf04..2fa1b982af8 100644
--- a/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml
+++ b/vllm_omni/model_executor/stage_configs/mimo_audio_async_chunk.yaml
@@ -74,10 +74,6 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
connectors:
connector_of_shared_memory:
name: SharedMemoryConnector
@@ -93,4 +89,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml
deleted file mode 100644
index 0a307b44778..00000000000
--- a/vllm_omni/model_executor/stage_configs/qwen2_5_omni.yaml
+++ /dev/null
@@ -1,107 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# The following config has been verified on 2x H100-80G GPU.
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- mm_processor_cache_gb: 0
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "0" # Example: use a different GPU than the previous stage; use "0" if single GPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.15
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- async_scheduling: false
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
-
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml
deleted file mode 100644
index 6e4f871e38d..00000000000
--- a/vllm_omni/model_executor/stage_configs/qwen2_5_omni_multiconnector.yaml
+++ /dev/null
@@ -1,141 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# The following config has been verified on 1x H100-80G GPU.
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- mm_processor_cache_gb: 0
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- # Distributed connector configuration (optional)
- output_connectors:
- to_stage_1: mooncake_connector
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
- # Distributed connector configuration (optional)
- input_connectors:
- from_stage_0: mooncake_connector
- output_connectors:
- to_stage_2: mooncake_connector
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "2" # Example: use a different GPU than the previous stage; use "0" if single GPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.3
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- # Distributed connector configuration (optional)
- input_connectors:
- from_stage_1: mooncake_connector
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
-
- # Distributed connectors configuration (optional)
- # More connectors will be supported in the future.
- connectors:
- # Mooncake connector for cross-node/intra-node communication
- mooncake_connector:
- name: MooncakeStoreConnector
- extra:
- host: "127.0.0.1"
- metadata_server: "http://10.90.67.86:8080/metadata"
- master: "10.90.67.86:50051"
- segment: 512000000 # 512MB
- localbuf: 64000000 # 64MB
- proto: "tcp"
-
- # Yuanrong connector for cross-node/intra-node communication
- yuanrong_connector:
- name: YuanrongConnector
- extra:
- host: "127.0.0.1"
- port: "35000"
-
- # SharedMemory connector for intra-node communication
- # Alternative SHM connector with different threshold
- shared_memory_connector:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536 # 64KB threshold
-
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml
deleted file mode 100644
index 0ce4f0c94fd..00000000000
--- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe.yaml
+++ /dev/null
@@ -1,101 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-async_chunk: false
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "0"
- engine_args:
- model_stage: thinker
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 32
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 1000000
- hf_config_name: thinker_config
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
deleted file mode 100644
index 38626fc081e..00000000000
--- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk.yaml
+++ /dev/null
@@ -1,117 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
-# Stage 2: Code2Wav (16-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "0"
- engine_args:
- model_stage: thinker
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk
- final_output: true
- final_output_type: text
- is_comprehension: true
- # Use named connector to apply runtime.connectors.extra.
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk
- engine_input_source: [0]
- # final_output: true
- # final_output_type: text
- # Distributed connector configuration
- input_connectors:
- from_stage_0: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 51200 # [TODO] if max_num_batch_tokens < max_num_seqs * 800, there will be precision problem.
- hf_config_name: thinker_config
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-runtime:
-
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- # Align with Omni: small chunks with sufficient context overlap.
- codec_chunk_frames: 25 # code2wav decode chunk size
- codec_left_context_frames: 25 # code2wav left context size
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk_multi_replicas.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk_multi_replicas.yaml
deleted file mode 100644
index b80d19460d0..00000000000
--- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_async_chunk_multi_replicas.yaml
+++ /dev/null
@@ -1,123 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# and multi-replica scale-out on stage 1 (talker) and stage 2 (code2wav).
-#
-# Stage 0: Thinker (multimodal understanding + text generation) — 1 replica, GPU 0
-# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes) — 2 replicas, GPU 1+2
-# Stage 2: Code2Wav (16-layer RVQ codes → audio waveform) — 2 replicas, GPU 1+2 (shared)
-#
-# Hardware: 3x H20-96G GPUs (GPU 0 for thinker, GPU 1+2 shared by talker + code2wav replicas).
-# Note: stage 1 and stage 2 share GPU 1+2. Code2Wav uses gpu_memory_utilization=0.1
-# so the combined footprint (talker 0.6 + code2wav 0.1) fits within a single GPU.
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "0"
- engine_args:
- model_stage: thinker
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk
- final_output: true
- final_output_type: text
- is_comprehension: true
- # Use named connector to apply runtime.connectors.extra.
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1,2"
- num_replicas: 2
- engine_args:
- model_stage: talker
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk
- engine_input_source: [0]
- # final_output: true
- # final_output_type: text
- # Distributed connector configuration
- input_connectors:
- from_stage_0: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1,2"
- num_replicas: 2
- engine_args:
- model_stage: code2wav
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 51200 # [TODO] if max_num_batch_tokens < max_num_seqs * 800, there will be precision problem.
- hf_config_name: thinker_config
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-runtime:
-
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- # Align with Omni: small chunks with sufficient context overlap.
- codec_chunk_frames: 25 # code2wav decode chunk size
- codec_left_context_frames: 25 # code2wav left context size
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml b/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml
deleted file mode 100644
index 6c2d2a7669d..00000000000
--- a/vllm_omni/model_executor/stage_configs/qwen3_omni_moe_multiconnector.yaml
+++ /dev/null
@@ -1,143 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings -> 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes -> audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "0"
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- # Distributed connector configuration
- output_connectors:
- to_stage_1: mooncake_connector
-
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- # tensor_parallel_size: 2
- enable_prefix_caching: false
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
- # Distributed connector configuration
- input_connectors:
- from_stage_0: mooncake_connector
- output_connectors:
- to_stage_2: mooncake_connector
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 64
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 1000000
- hf_config_name: thinker_config
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- # Distributed connector configuration
- input_connectors:
- from_stage_1: mooncake_connector
-
-# Top-level runtime config: default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
- # Distributed connectors configuration
- connectors:
- # Mooncake connector for cross-node/intra-node communication
- mooncake_connector:
- name: MooncakeStoreConnector
- extra:
- host: "127.0.0.1"
- metadata_server: "http://10.90.67.86:8080/metadata"
- master: "10.90.67.86:50051"
- segment: 512000000 # 512MB
- localbuf: 64000000 # 64MB
- proto: "tcp"
-
- # SharedMemory connector for intra-node communication
- shared_memory_connector:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536 # 64KB threshold
-
- edges:
- - from: 0
- to: 1
- window_size: -1
- - from: 1
- to: 2
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml
deleted file mode 100644
index a0d38eb4b9f..00000000000
--- a/vllm_omni/model_executor/stage_configs/qwen3_tts.yaml
+++ /dev/null
@@ -1,99 +0,0 @@
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm
- is_comprehension: true
- runtime:
- devices: "0"
- engine_args:
- model_stage: qwen3_tts
- max_num_seqs: 10
- model_arch: Qwen3TTSTalkerForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: false
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: latent
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 512
- max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
- # Use named connector to apply runtime.connectors.extra.
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "0"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen3TTSCode2Wav
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: audio
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- # Must be divisible by num_code_groups and cover (left_context + chunk).
- # Prefill length is Q * num_frames (e.g. 16 * 2148 = 34368); keep headroom past 32k.
- max_num_batched_tokens: 65536
- # async_chunk appends windows per step; max_model_len must cover accumulated flat codec stream.
- max_model_len: 65536
- engine_input_source: [0]
- final_output: true
- final_output_type: audio
- # Distributed connector configuration
- input_connectors:
- from_stage_0: connector_of_shared_memory
- tts_args:
- max_instructions_length: 500
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: true
- repetition_penalty: 1.0
-
-runtime:
- enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- # Frame-aligned codec streaming transport.
- codec_streaming: true
- # Connector polling / timeout (unit: loop count, sleep interval in seconds).
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- # Match the decoder sliding attention window to avoid chunk-boundary noise.
- codec_chunk_frames: 25
- codec_left_context_frames: 72
-
- edges:
- - from: 0
- to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
deleted file mode 100644
index 75b2bab3a27..00000000000
--- a/vllm_omni/model_executor/stage_configs/qwen3_tts_batch.yaml
+++ /dev/null
@@ -1,100 +0,0 @@
-# Same as qwen3_tts.yaml with batched talker and code2wav.
-# Stage 0: max_num_seqs 4, stage 1: max_num_seqs 4.
-# max_num_seqs must be a power of two to align with CUDA graph capture sizes
-# (stage 0) and must match --batch-size in end2end.py / benchmark scripts.
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm
- is_comprehension: true
- runtime:
- devices: "0"
- engine_args:
- model_stage: qwen3_tts
- max_num_seqs: 4
- model_arch: Qwen3TTSTalkerForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: false
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: latent
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 512
- max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
- # Use named connector to apply runtime.connectors.extra.
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "0"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 4
- model_arch: Qwen3TTSCode2Wav
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: true
- enable_prefix_caching: false
- engine_output_type: audio
- gpu_memory_utilization: 0.2
- distributed_executor_backend: "mp"
- # Must be divisible by num_code_groups and cover (left_context + chunk).
- max_num_batched_tokens: 65536
- # Flat codec prompt can exceed 32k tokens (Q * frames); align with max_tokens below.
- max_model_len: 65536
- engine_input_source: [0]
- final_output: true
- final_output_type: audio
- # Distributed connector configuration
- input_connectors:
- from_stage_0: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: true
- repetition_penalty: 1.0
-
-runtime:
- enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- # Frame-aligned codec streaming transport.
- codec_streaming: true
- # Connector polling / timeout (unit: loop count, sleep interval in seconds).
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- # Match the decoder sliding attention window to avoid chunk-boundary noise.
- codec_chunk_frames: 25
- codec_left_context_frames: 72
-
- edges:
- - from: 0
- to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml
deleted file mode 100644
index 3f412fc4dca..00000000000
--- a/vllm_omni/model_executor/stage_configs/qwen3_tts_no_async_chunk.yaml
+++ /dev/null
@@ -1,64 +0,0 @@
-async_chunk: false
-stage_args:
- - stage_id: 0
- stage_type: llm
- is_comprehension: true
- runtime:
- devices: "0"
- engine_args:
- model_stage: qwen3_tts
- max_num_seqs: 1
- model_arch: Qwen3TTSTalkerForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: false
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: latent
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 512
- max_model_len: 4096
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "0"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen3TTSCode2Wav
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio
- gpu_memory_utilization: 0.2
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 65536
- max_model_len: 65536
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav
- final_output: true
- final_output_type: audio
- tts_args:
- max_instructions_length: 500
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: true
- repetition_penalty: 1.0
diff --git a/vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml b/vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml
index d2e920806d2..4ca8d11ad77 100644
--- a/vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml
+++ b/vllm_omni/model_executor/stage_configs/qwen3_tts_uniproc.yaml
@@ -72,9 +72,6 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
connector_of_shared_memory:
@@ -94,4 +91,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml b/vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml
index cf78d4e4381..c6fd177a359 100644
--- a/vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml
+++ b/vllm_omni/model_executor/stage_configs/voxcpm_async_chunk.yaml
@@ -77,9 +77,6 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
voxcpm_shm:
@@ -99,4 +96,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_configs/voxtral_tts.yaml b/vllm_omni/model_executor/stage_configs/voxtral_tts.yaml
index 31cccb9ccfd..b0d9a81e782 100644
--- a/vllm_omni/model_executor/stage_configs/voxtral_tts.yaml
+++ b/vllm_omni/model_executor/stage_configs/voxtral_tts.yaml
@@ -82,10 +82,6 @@ stage_args:
# Top-level runtime config (concise): default windows and stage edges
runtime:
enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
-
connectors:
connector_of_shared_memory:
name: SharedMemoryConnector
@@ -102,4 +98,3 @@ runtime:
edges:
- from: 0 # language_model → acoustic_transformer: trigger only after receiving full input (-1)
to: 1
- window_size: -1
diff --git a/vllm_omni/model_executor/stage_input_processors/bagel.py b/vllm_omni/model_executor/stage_input_processors/bagel.py
index bfcff0ea0f3..52cc14d3aa2 100644
--- a/vllm_omni/model_executor/stage_input_processors/bagel.py
+++ b/vllm_omni/model_executor/stage_input_processors/bagel.py
@@ -82,6 +82,8 @@ def expand_cfg_prompts(
neg_prompt = _get_negative_prompt(prompt, sampling_params)
if "image" in modalities:
+ if not neg_prompt:
+ return []
neg_prompt_dict = {
"prompt": neg_prompt,
"modalities": prompt.get("modalities", []),
@@ -166,6 +168,8 @@ def expand_cfg_prompts_think(
companion_params = {"max_tokens": 1}
if "image" in modalities:
+ if not neg_prompt:
+ return []
neg_prompt_dict = {
"prompt": neg_prompt,
"modalities": prompt.get("modalities", []),
@@ -287,9 +291,10 @@ def _get_negative_prompt(
) -> str:
"""Resolve the negative prompt for CFG from prompt or sampling params.
- An empty string is treated the same as absent (falls through to
- the Bagel default token pair), because an empty negative prompt is
- not meaningful for CFG guidance.
+ Returns the negative prompt string when one is supplied, otherwise an
+ empty string. Callers decide how to treat the empty case: text2img
+ skips the cfg_text companion entirely, while img2img substitutes it
+ into the cfg_text prompt template.
"""
neg = prompt.get("negative_prompt")
if neg:
@@ -300,4 +305,4 @@ def _get_negative_prompt(
if neg:
return neg
- return "<|im_start|><|im_end|>"
+ return ""
diff --git a/vllm_omni/model_executor/stage_input_processors/dynin_omni.py b/vllm_omni/model_executor/stage_input_processors/dynin_omni.py
index 49e4bedc2fd..78cd535cb66 100644
--- a/vllm_omni/model_executor/stage_input_processors/dynin_omni.py
+++ b/vllm_omni/model_executor/stage_input_processors/dynin_omni.py
@@ -107,9 +107,7 @@ def _bridge_tokens(
if not token_ids:
token_ids = list(getattr(output, "token_ids", []) or [])
if not token_ids:
- raise RuntimeError(
- f"Stage output for request {source_output.request_id} has no token_ids"
- )
+ raise RuntimeError(f"Stage output for request {source_output.request_id} has no token_ids")
detok_id = _to_int(mm_out.get("detok_id"), default=0)
src_prompt = prompt_meta_by_reqid.get(source_output.request_id, {})
diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
index d13db3d046b..7a903be2ff8 100644
--- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
+++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py
@@ -3,6 +3,8 @@
# Copyright 2025 The Qwen team.
"""Stage input processor for Qwen3 Omni MoE: Thinker → Talker transition."""
+import logging
+from dataclasses import dataclass, field
from typing import Any
import torch
@@ -18,6 +20,12 @@
extract_speaker_from_request,
)
+logger = logging.getLogger(__name__)
+
+# Pooling output layer keys: "0" = word embedding, "24" = accept_hidden_layer
+_EMBED_LAYER_KEY = "0"
+_HIDDEN_LAYER_KEY = "24"
+
def _compute_talker_prompt_ids_length(info, device: torch.device | str = "cuda") -> int:
im_start_token_id = 151644
@@ -69,6 +77,183 @@ def _ensure_list(x):
return list(x)
+# =========================
+# PD disaggregation helpers
+# =========================
+
+
+def _get_prefill_multimodal_output(
+ request_id: str,
+ streaming_context: Any | None,
+) -> dict[str, Any] | None:
+ bridge_states = getattr(streaming_context, "bridge_states", None)
+ if not isinstance(bridge_states, dict):
+ return None
+ by_req = bridge_states.get("pd_prefill_multimodal_output_by_req")
+ if not isinstance(by_req, dict):
+ return None
+ prefill_mm = by_req.get(request_id)
+ return prefill_mm if isinstance(prefill_mm, dict) else None
+
+
+def _merge_pd_embeddings(
+ decode_emb: torch.Tensor,
+ decode_hid: torch.Tensor,
+ prefill_mm: dict[str, Any],
+ device: torch.device,
+ expected_total: int | None = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Merge prefill prompt embeddings with decode generated embeddings.
+
+ In PD mode the prefill engine processes the prompt and the decode engine
+ generates tokens starting from position 1. This function concatenates
+ them, removing the overlapping token(s):
+
+ merged = prefill[:P] + decode[overlap:]
+
+ where overlap = P + D - expected_total.
+ """
+ try:
+ p_emb = prefill_mm[_EMBED_LAYER_KEY].detach().to(device=device, dtype=torch.float)
+ p_hid = prefill_mm[_HIDDEN_LAYER_KEY].detach().to(device=device, dtype=torch.float)
+ except (KeyError, AttributeError, TypeError) as exc:
+ available_keys = list(prefill_mm.keys()) if isinstance(prefill_mm, dict) else type(prefill_mm).__name__
+ logger.error(
+ "_merge_pd_embeddings: failed to extract prefill embeddings (%s). "
+ "Expected keys %r and %r, got: %s. "
+ "Falling back to decode-only embeddings – talker user-segment will be degraded.",
+ exc,
+ _EMBED_LAYER_KEY,
+ _HIDDEN_LAYER_KEY,
+ available_keys,
+ )
+ return decode_emb, decode_hid
+
+ if p_emb.shape[0] == 0 or decode_emb.shape[0] == 0:
+ return decode_emb, decode_hid
+
+ raw_total = p_emb.shape[0] + decode_emb.shape[0]
+ overlap = max(0, raw_total - expected_total) if expected_total is not None else 0
+
+ merged_emb = torch.cat([p_emb, decode_emb[overlap:]], dim=0)
+ merged_hid = torch.cat([p_hid, decode_hid[overlap:]], dim=0)
+ return merged_emb, merged_hid
+
+
+def _resolve_tts_token_embedding(
+ key: str,
+ *,
+ thinker_mm: dict[str, Any],
+ prefill_mm: dict[str, Any] | None,
+ device: torch.device,
+) -> torch.Tensor | None:
+ """Return TTS BOS/EOS/PAD embedding tensors for the talker projection path.
+
+ Values are taken from the current thinker (decode) ``multimodal_output``; in
+ PD mode, missing keys may be filled from the paired prefill stage output.
+ """
+ val = thinker_mm.get(key)
+ if val is None and prefill_mm is not None:
+ val = prefill_mm.get(key)
+ return val.detach().to(device=device, dtype=torch.float) if val is not None else None
+
+
+# =========================
+# Streaming input helpers
+# =========================
+
+
+@dataclass
+class _Thinker2TalkerStreamingState:
+ last_prompt_len: int = 0
+ last_output_len: int = 0
+ merged_sequences: list[int] = field(default_factory=list)
+
+
+@dataclass
+class _Qwen3OmniStreamingState:
+ thinker2talker: _Thinker2TalkerStreamingState = field(default_factory=_Thinker2TalkerStreamingState)
+ talker2code2wav_last_seq_len: int = 0
+
+
+def _get_qwen3_streaming_state(
+ request_id: str,
+ streaming_context: Any | None,
+) -> _Qwen3OmniStreamingState:
+ bridge_states = getattr(streaming_context, "bridge_states", None)
+ per_model_state = bridge_states.setdefault("qwen3_omni", {})
+ state = per_model_state.get(request_id)
+ if state is None:
+ state = _Qwen3OmniStreamingState()
+ per_model_state[request_id] = state
+ return state
+
+
+def _get_streaming_talker_tokens(
+ request_id: str,
+ prompt_token_ids: list[int],
+ output_token_ids: list[int],
+ new_prompt_len_snapshot: int | None = None,
+ streaming_context: Any | None = None,
+ *,
+ clear_state: bool = False,
+) -> tuple[list[int], list[int], list[int], list[int]]:
+ """Return streaming token slices and merged token views for thinker->talker.
+ e.g. For the second streaming input request:
+ merged_sequences: [input_prompt 1, output_tokens 1[:-1], input_prompt 2, output_tokens 2]
+ thinker_input_ids: [input_prompt 1, output_tokens 1[:-1], input_prompt 2]
+ Returns:
+ inc_prompt: prompt token delta for this segment.
+ inc_output: output token delta for this segment.
+ merged_sequences: full thinker_sequences to send downstream.
+ thinker_input_ids: full thinker_input_ids paired with merged_sequences.
+ """
+ state = _get_qwen3_streaming_state(request_id, streaming_context).thinker2talker
+ if new_prompt_len_snapshot:
+ prompt_token_ids = prompt_token_ids[:-new_prompt_len_snapshot]
+ cur_prompt_len = len(prompt_token_ids)
+ cur_output_len = len(output_token_ids)
+
+ inc_prompt = prompt_token_ids[state.last_prompt_len :]
+ inc_output = output_token_ids[state.last_output_len :]
+ delta_sequences = inc_prompt + inc_output
+ cached_sequences = state.merged_sequences
+
+ merged_sequences = cached_sequences + delta_sequences
+ thinker_input_ids = cached_sequences + inc_prompt
+
+ # Persist history for next segment. Drop the latest sampled token to keep
+ # thinker_input_ids / thinker_sequences alignment with next-step append.
+ cached_sequences.extend(delta_sequences[:-1])
+
+ state.last_prompt_len = cur_prompt_len
+ state.last_output_len = cur_output_len
+
+ if clear_state:
+ state.last_prompt_len = 0
+ state.last_output_len = 0
+ state.merged_sequences.clear()
+
+ return inc_prompt, inc_output, merged_sequences, thinker_input_ids
+
+
+def _get_streaming_codec_delta_len(
+ cur_seq_len: int,
+ request_id: str,
+ talker_output: Any,
+ streaming_context: Any | None = None,
+) -> int:
+ """Return newly added seq_len for talker->code2wav in streaming mode."""
+ state = _get_qwen3_streaming_state(request_id, streaming_context)
+ prev_seq_len = state.talker2code2wav_last_seq_len
+ seq_len = cur_seq_len - prev_seq_len
+ state.talker2code2wav_last_seq_len = cur_seq_len + 1
+ if bool(getattr(talker_output, "finished", False)):
+ # Final segment: clear history to avoid cross-session carry-over.
+ state.talker2code2wav_last_seq_len = 0
+ return seq_len
+
+
# =========================
# Thinker -> Talker
# =========================
@@ -96,8 +281,8 @@ def thinker2talker_async_chunk(
all_token_ids = _ensure_list(all_token_ids)
prompt_token_ids = _ensure_list(prompt_token_ids)
talker_additional_info = {
- "thinker_prefill_embeddings": pooling_output.get("0").detach().cpu(),
- "thinker_hidden_states": pooling_output.get("24").detach().cpu(),
+ "thinker_prefill_embeddings": pooling_output.get(_EMBED_LAYER_KEY).detach().cpu(),
+ "thinker_hidden_states": pooling_output.get(_HIDDEN_LAYER_KEY).detach().cpu(),
"thinker_sequences": all_token_ids,
"thinker_input_ids": prompt_token_ids,
# Provide thinker-side TTS token embeddings for talker projection
@@ -146,7 +331,7 @@ def thinker2talker_async_chunk(
if output_token_ids:
talker_additional_info["override_keys"] = ["thinker_decode_embeddings", "thinker_output_token_ids"]
- talker_additional_info["thinker_decode_embeddings"] = pooling_output.get("0").detach().cpu()
+ talker_additional_info["thinker_decode_embeddings"] = pooling_output.get(_EMBED_LAYER_KEY).detach().cpu()
talker_additional_info["thinker_output_token_ids"] = output_token_ids
else:
# When prefilling a chunked thinker, thinker_hidden_states needs to be updated.
@@ -160,6 +345,7 @@ def thinker2talker(
source_outputs: list[Any],
prompt: OmniTokensPrompt | TextPrompt | None = None,
requires_multimodal_data: bool = False,
+ streaming_context: Any | None = None,
) -> list[OmniTokensPrompt]:
"""
Process thinker outputs to create talker inputs.
@@ -169,6 +355,9 @@ def thinker2talker(
2. Split hidden states into: prompt embeddings + generated embeddings
3. Package for talker with additional information
+ In PD disaggregation mode, merges prefill-stage prompt embeddings with
+ decode-stage generated embeddings before handing off to the talker.
+
Args:
prompt: Original prompt data
requires_multimodal_data: Whether multimodal data is required
@@ -184,18 +373,58 @@ def thinker2talker(
# Process each thinker output
for i, thinker_output in enumerate(thinker_outputs):
output = thinker_output.outputs[0]
+ req_id = str(getattr(thinker_output, "request_id", f"idx-{i}"))
+ prompt_token_ids = _ensure_list(thinker_output.prompt_token_ids)
+ output_ids = _ensure_list(output.token_ids)
+ is_streaming_session = bool(getattr(streaming_context, "enabled", False))
+ if is_streaming_session:
+ prompt_token_ids, output_ids, thinker_sequences, thinker_input_ids = _get_streaming_talker_tokens(
+ req_id,
+ prompt_token_ids,
+ output_ids,
+ getattr(streaming_context, "new_prompt_len_snapshot", None),
+ streaming_context,
+ clear_state=bool(getattr(thinker_output, "finished", False)),
+ )
+ else:
+ thinker_sequences = prompt_token_ids + output_ids
+ thinker_input_ids = prompt_token_ids
+ # For streaming input, just send incremental prefill and hidden states tensor to talker
+ # Equally applicable to non-streaming cases.
+ new_seq_length = len(prompt_token_ids + output_ids) - 1
+ thinker_mm = output.multimodal_output
+ # Full thinker embedding sequence for the talker: single thinker engine in the
+ # non-PD path; after optional merge with prefill-side tensors in PD mode.
+ thinker_emb = thinker_mm[_EMBED_LAYER_KEY].detach().to(device=device, dtype=torch.float)[-new_seq_length:]
+ thinker_hid = thinker_mm[_HIDDEN_LAYER_KEY].detach().to(device=device, dtype=torch.float)[-new_seq_length:]
+
+ prefill_mm: dict[str, Any] | None = None
+ prefill_mm = _get_prefill_multimodal_output(req_id, streaming_context)
+
+ if prefill_mm is not None:
+ expected_total = len(prompt_token_ids) + len(output_ids)
+ try:
+ thinker_emb, thinker_hid = _merge_pd_embeddings(
+ thinker_emb, thinker_hid, prefill_mm, device, expected_total=expected_total
+ )
+ except Exception as exc:
+ logger.warning("[PD] Could not merge prefill embeddings: %s", exc)
info = {
- "thinker_prefill_embeddings": output.multimodal_output["0"].detach().to(device=device, dtype=torch.float),
- "thinker_hidden_states": output.multimodal_output["24"].detach().to(device=device, dtype=torch.float),
- "thinker_sequences": (
- thinker_output.prompt_token_ids + output.token_ids
- ), # the thinker_sequences is the whole ids
- "thinker_input_ids": thinker_output.prompt_token_ids,
+ "thinker_prefill_embeddings": thinker_emb,
+ "thinker_hidden_states": thinker_hid,
+ "thinker_sequences": thinker_sequences, # the thinker_sequences is the whole ids
+ "thinker_input_ids": thinker_input_ids,
# Provide thinker-side TTS token embeddings for talker projection
- "tts_bos_embed": output.multimodal_output["tts_bos_embed"].detach().to(device=device, dtype=torch.float),
- "tts_eos_embed": output.multimodal_output["tts_eos_embed"].detach().to(device=device, dtype=torch.float),
- "tts_pad_embed": output.multimodal_output["tts_pad_embed"].detach().to(device=device, dtype=torch.float),
+ "tts_bos_embed": _resolve_tts_token_embedding(
+ "tts_bos_embed", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device
+ ),
+ "tts_eos_embed": _resolve_tts_token_embedding(
+ "tts_eos_embed", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device
+ ),
+ "tts_pad_embed": _resolve_tts_token_embedding(
+ "tts_pad_embed", thinker_mm=thinker_mm, prefill_mm=prefill_mm, device=device
+ ),
}
speaker = extract_speaker_from_prompt(prompt, index=i)
if speaker is not None:
@@ -295,6 +524,7 @@ def talker2code2wav(
source_outputs: list[Any],
_prompt: OmniTokensPrompt | TextPrompt | None = None,
_requires_multimodal_data: bool = False,
+ streaming_context: Any | None = None,
) -> list[OmniTokensPrompt]:
"""
Process talker outputs to create code2wav inputs.
@@ -311,9 +541,14 @@ def talker2code2wav(
talker_outputs = source_outputs
code2wav_inputs: list[OmniTokensPrompt] = []
# Process each talker output
- for talker_output in talker_outputs:
+ for i, talker_output in enumerate(talker_outputs):
output = talker_output.outputs[0]
- seq_len = len(output.token_ids) - 1
+ req_id = str(getattr(talker_output, "request_id", f"idx-{i}"))
+ cur_seq_len = len(output.token_ids) - 1
+ seq_len = cur_seq_len
+ is_streaming_session = bool(getattr(streaming_context, "enabled", False))
+ if is_streaming_session:
+ seq_len = _get_streaming_codec_delta_len(cur_seq_len, req_id, talker_output, streaming_context)
# Extract codec codes from talker output
# Expected shape: [8, seq_len] (8-layer RVQ codes)
codec_codes = (
diff --git a/vllm_omni/patch.py b/vllm_omni/patch.py
index d4ab78f13a0..f6c483a92f0 100644
--- a/vllm_omni/patch.py
+++ b/vllm_omni/patch.py
@@ -12,12 +12,13 @@
from vllm.v1.engine import EngineCoreRequest as _OriginalEngineCoreRequest
from vllm.v1.request import Request as _OriginalRequest
from vllm.v1.request import RequestStatus
+from vllm.v1.request import StreamingUpdate as _OriginalStreamingUpdate
import vllm_omni.logger # noqa: F401
from vllm_omni.engine import OmniEngineCoreOutput, OmniEngineCoreOutputs, OmniEngineCoreRequest
from vllm_omni.inputs.data import OmniTokensPrompt
from vllm_omni.model_executor.layers.rotary_embedding import OmniMRotaryEmbedding
-from vllm_omni.request import OmniRequest
+from vllm_omni.request import OmniRequest, OmniStreamingUpdate
# =============================================================================
# Patch ModelConfig.is_mm_prefix_lm to support omni-specific models
@@ -115,5 +116,7 @@ def _patched_glm_image_text_config_init(self, *args, **kwargs):
module.MRotaryEmbedding = OmniMRotaryEmbedding
if hasattr(module, "Request") and module.Request == _OriginalRequest:
module.Request = OmniRequest
+ if hasattr(module, "StreamingUpdate") and module.StreamingUpdate == _OriginalStreamingUpdate:
+ module.StreamingUpdate = OmniStreamingUpdate
if hasattr(module, "EngineCoreRequest") and module.EngineCoreRequest == _OriginalEngineCoreRequest:
module.EngineCoreRequest = OmniEngineCoreRequest
diff --git a/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_t2i.yaml b/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_t2i.yaml
index 053e8a8cca0..0fd03949d11 100644
--- a/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_t2i.yaml
+++ b/vllm_omni/platforms/npu/stage_configs/hunyuan_image3_t2i.yaml
@@ -33,6 +33,3 @@ stage_args:
# Runtime defaults
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
diff --git a/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml
deleted file mode 100644
index 8f7af161d65..00000000000
--- a/vllm_omni/platforms/npu/stage_configs/qwen2_5_omni.yaml
+++ /dev/null
@@ -1,97 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true # haven't supported talker ACL graph on NPU
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "2" # Example: use a different NPU than the previous stage; use "0" if single NPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.15
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml
deleted file mode 100644
index 2638c99cd4b..00000000000
--- a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe.yaml
+++ /dev/null
@@ -1,99 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 5x A2/A3-64G NPUs.
-stage_args:
- - stage_id: 0
- runtime:
- devices: "0,1"
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- hf_config_name: thinker_config
- tensor_parallel_size: 2
- # profiler_config:
- # profiler: torch
- # torch_profiler_dir: ./perf
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- runtime:
- devices: "2"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: true # haven't supported talker ACL graph on NPU
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- # tensor_parallel_size: 2
- enable_prefix_caching: false
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- runtime:
- devices: "2"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 1000000
- hf_config_name: thinker_config
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml
deleted file mode 100644
index 9aa20baecfb..00000000000
--- a/vllm_omni/platforms/npu/stage_configs/qwen3_omni_moe_async_chunk.yaml
+++ /dev/null
@@ -1,101 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 16-layer RVQ codec codes)
-# Stage 2: Code2Wav (16-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "0,1"
- engine_args:
- max_num_seqs: 10
- model_stage: thinker
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- hf_config_name: thinker_config
- tensor_parallel_size: 2
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker_async_chunk
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "2"
- engine_args:
- max_num_seqs: 10
- model_stage: talker
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav_async_chunk
- engine_input_source: [0]
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.0
- stop_token_ids: [2150]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "2"
- engine_args:
- max_num_seqs: 10
- model_stage: code2wav
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 51200 # [TODO] if max_num_batched_tokens < max_num_seqs * 800, there will be precision problem.
- hf_config_name: thinker_config
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml b/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml
deleted file mode 100644
index cd82d91b715..00000000000
--- a/vllm_omni/platforms/npu/stage_configs/qwen3_tts.yaml
+++ /dev/null
@@ -1,96 +0,0 @@
-async_chunk: true
-stage_args:
- - stage_id: 0
- stage_type: llm
- is_comprehension: true
- runtime:
- devices: "0"
- engine_args:
- model_stage: qwen3_tts
- max_num_seqs: 1
- model_arch: Qwen3TTSTalkerForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: latent
- gpu_memory_utilization: 0.3
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 512
- max_model_len: 4096
- custom_process_next_stage_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_tts.talker2code2wav_async_chunk
- # Use named connector to apply runtime.connectors.extra.
- output_connectors:
- to_stage_1: connector_of_shared_memory
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: false
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 1
- stage_type: llm
- runtime:
- devices: "0"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen3TTSCode2Wav
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- async_scheduling: false
- enable_prefix_caching: false
- engine_output_type: audio
- gpu_memory_utilization: 0.2
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 65536
- max_model_len: 65536
- engine_input_source: [0]
- final_output: true
- final_output_type: audio
- # Distributed connector configuration
- input_connectors:
- from_stage_0: connector_of_shared_memory
- tts_args:
- max_instructions_length: 500
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: true
- repetition_penalty: 1.0
-
-runtime:
- enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
- connectors:
- connector_of_shared_memory:
- name: SharedMemoryConnector
- extra:
- shm_threshold_bytes: 65536
- # Frame-aligned codec streaming transport.
- codec_streaming: true
- # Connector polling / timeout (unit: loop count, sleep interval in seconds).
- connector_get_sleep_s: 0.01
- connector_get_max_wait_first_chunk: 3000
- connector_get_max_wait: 300
- # Align with Omni: small chunks with sufficient context overlap.
- codec_chunk_frames: 25
- codec_left_context_frames: 72
-
- edges:
- - from: 0
- to: 1
- window_size: -1
diff --git a/vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml b/vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml
index 0a4ed7497d5..87843634cb7 100644
--- a/vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml
+++ b/vllm_omni/platforms/npu/stage_configs/voxcpm_async_chunk.yaml
@@ -73,9 +73,6 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
connector_of_shared_memory:
@@ -90,4 +87,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml
deleted file mode 100644
index 35e81935457..00000000000
--- a/vllm_omni/platforms/rocm/stage_configs/qwen2_5_omni.yaml
+++ /dev/null
@@ -1,102 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# The following config has been verified on 2x H100-80G GPU.
-stage_args:
- - stage_id: 0
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage (CUDA_VISIBLE_DEVICES/torch.cuda.set_device)
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true # Now we only support eager mode
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
- - stage_id: 1
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.8
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
-
- - stage_id: 2
- runtime:
- process: true
- devices: "2" # Example: use a different GPU than the previous stage; use "0" if single GPU
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.15
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
-
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/vllm_omni/platforms/rocm/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/rocm/stage_configs/qwen3_omni_moe.yaml
deleted file mode 100644
index 0ca150bee6c..00000000000
--- a/vllm_omni/platforms/rocm/stage_configs/qwen3_omni_moe.yaml
+++ /dev/null
@@ -1,97 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config has been verified on 2x H100-80G GPUs.
-stage_args:
- - stage_id: 0
- runtime:
- devices: "0"
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- hf_config_name: thinker_config
- tensor_parallel_size: 1
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- runtime:
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- # tensor_parallel_size: 2
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- runtime:
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.1
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 1000000
- hf_config_name: thinker_config
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/vllm_omni/platforms/xpu/stage_configs/bagel.yaml b/vllm_omni/platforms/xpu/stage_configs/bagel.yaml
index 0fc8a25ea5c..7b27f6a443a 100644
--- a/vllm_omni/platforms/xpu/stage_configs/bagel.yaml
+++ b/vllm_omni/platforms/xpu/stage_configs/bagel.yaml
@@ -67,10 +67,6 @@ stage_args:
# Runtime edges
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
-
# Distributed connectors configuration (optional)
# More connectors will be supported in the future.
connectors:
@@ -83,4 +79,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/platforms/xpu/stage_configs/hunyuan_image3_t2i.yaml b/vllm_omni/platforms/xpu/stage_configs/hunyuan_image3_t2i.yaml
index 8f969ced5f4..4e0005f82a1 100644
--- a/vllm_omni/platforms/xpu/stage_configs/hunyuan_image3_t2i.yaml
+++ b/vllm_omni/platforms/xpu/stage_configs/hunyuan_image3_t2i.yaml
@@ -78,6 +78,3 @@ stage_args:
# Top-level runtime config (concise): default windows and stage edges
runtime:
enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
diff --git a/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml b/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml
deleted file mode 100644
index 7dbedb29a5e..00000000000
--- a/vllm_omni/platforms/xpu/stage_configs/qwen2_5_omni.yaml
+++ /dev/null
@@ -1,101 +0,0 @@
-# stage config for running qwen2.5-omni for multi-stage omni runtime.
-
-# The following config is verified with 2 * Intel Arc Pro B60 XPU.
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true # Run this stage in a separate process
- devices: "0" # Visible devices for this stage
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9 # thinker weight is around 16.74GB for Qwen2.5-Omni-7B
- enforce_eager: false
- trust_remote_code: true
- engine_output_type: latent
- enable_prefix_caching: false
- is_comprehension: true
- final_output: true
- final_output_type: text
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.5 # talker weight is 6.03GB for Qwen2.5-Omni-7B
- enforce_eager: false
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: latent
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen2_5_omni.thinker2talker
- default_sampling_params:
- temperature: 0.9
- top_p: 0.8
- top_k: 40
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
- stop_token_ids: [8294]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- process: true
- devices: "1"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen2_5OmniForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- gpu_memory_utilization: 0.3 # code2wav weight is around 1.46GB for Qwen2.5-Omni-7B
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio
- engine_input_source: [1]
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
-
-# Top-level runtime config (concise): default windows and stage edges
-runtime:
- enabled: true
- defaults:
- window_size: -1 # Simplified: trigger downstream only after full upstream completion
- max_inflight: 1 # Simplified: process serially within each stage
-
- edges:
- - from: 0 # thinker → talker: trigger only after receiving full input (-1)
- to: 1
- window_size: -1
- - from: 1 # talker → code2wav: trigger only after receiving full input (-1)
- to: 2
- window_size: -1
diff --git a/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml b/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml
deleted file mode 100644
index 49914bebc43..00000000000
--- a/vllm_omni/platforms/xpu/stage_configs/qwen3_omni_moe.yaml
+++ /dev/null
@@ -1,102 +0,0 @@
-# Stage config for running Qwen3-Omni-MoE with 3-stage architecture
-# Stage 0: Thinker (multimodal understanding + text generation)
-# Stage 1: Talker (text embeddings → 8-layer RVQ codec codes)
-# Stage 2: Code2Wav (8-layer RVQ codes → audio waveform)
-
-# The following config is verified with 8 * Intel Arc Pro B60 XPU.
-stage_args:
- - stage_id: 0
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "0,1,2,3"
- engine_args:
- model_stage: thinker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.9 # thinker weight is around 61.08GB for Qwen3-Omni-30B-A3B-Instruct
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output hidden states for talker
- distributed_executor_backend: "mp"
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- hf_config_name: thinker_config
- tensor_parallel_size: 4
- max_cudagraph_capture_size: 0
- final_output: true
- final_output_type: text
- is_comprehension: true
- default_sampling_params:
- temperature: 0.4
- top_p: 0.9
- top_k: 1
- max_tokens: 2048
- seed: 42
- detokenize: True
- repetition_penalty: 1.05
-
- - stage_id: 1
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "4"
- engine_args:
- model_stage: talker
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: ar
- scheduler_cls: vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler
- gpu_memory_utilization: 0.6 # talker weight is around 8.5GB for Qwen3-Omni-30B-A3B-Instruct
- enforce_eager: true
- trust_remote_code: true
- engine_output_type: latent # Output codec codes for code2wav
- enable_prefix_caching: false
- max_num_batched_tokens: 32768
- distributed_executor_backend: "mp"
- hf_config_name: talker_config
- max_cudagraph_capture_size: 0
- engine_input_source: [0]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.thinker2talker
- # final_output: true
- # final_output_type: text
- default_sampling_params:
- temperature: 0.9
- top_k: 50
- max_tokens: 4096
- seed: 42
- detokenize: False
- repetition_penalty: 1.05
- stop_token_ids: [2150]
-
- - stage_id: 2
- stage_type: llm # Use llm stage type for AR stages
- runtime:
- devices: "4"
- engine_args:
- model_stage: code2wav
- max_num_seqs: 1
- model_arch: Qwen3OmniMoeForConditionalGeneration
- worker_type: generation
- scheduler_cls: vllm_omni.core.sched.omni_generation_scheduler.OmniGenerationScheduler
- enforce_eager: true
- trust_remote_code: true
- enable_prefix_caching: false
- engine_output_type: audio # Final output: audio waveform
- gpu_memory_utilization: 0.3 # code2wav weight is around 0.4GB for Qwen3-Omni-30B-A3B-Instruct
- distributed_executor_backend: "mp"
- max_num_batched_tokens: 1000000
- hf_config_name: thinker_config
- max_cudagraph_capture_size: 0
- engine_input_source: [1]
- custom_process_input_func: vllm_omni.model_executor.stage_input_processors.qwen3_omni.talker2code2wav
- final_output: true
- final_output_type: audio
- default_sampling_params:
- temperature: 0.0
- top_p: 1.0
- top_k: -1
- max_tokens: 65536
- seed: 42
- detokenize: True
- repetition_penalty: 1.1
diff --git a/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml b/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml
index 10051c1eda4..0820ab63203 100644
--- a/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml
+++ b/vllm_omni/platforms/xpu/stage_configs/voxtral_tts.yaml
@@ -88,9 +88,6 @@ stage_args:
runtime:
enabled: true
- defaults:
- window_size: -1
- max_inflight: 1
connectors:
connector_of_shared_memory:
@@ -108,4 +105,3 @@ runtime:
edges:
- from: 0
to: 1
- window_size: -1
diff --git a/vllm_omni/request.py b/vllm_omni/request.py
index 3ec325316fd..48cbf9b31d7 100644
--- a/vllm_omni/request.py
+++ b/vllm_omni/request.py
@@ -1,8 +1,11 @@
from collections.abc import Callable
+from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
import torch
+from vllm.multimodal.inputs import MultiModalFeatureSpec
+from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
if TYPE_CHECKING:
@@ -92,3 +95,34 @@ def from_engine_core_request(
resumable=request.resumable,
reasoning_ended=request.reasoning_ended,
)
+
+
+@dataclass
+class OmniStreamingUpdate:
+ """
+ Override: add additional information
+ Lightweight data for streaming session continuation.
+
+ Contains only the fields needed to update an existing streaming session
+ with new input data.
+ """
+
+ mm_features: list[MultiModalFeatureSpec] | None
+ prompt_token_ids: list[int] | None
+ max_tokens: int
+ arrival_time: float
+ sampling_params: SamplingParams | None
+ additional_information: AdditionalInformationPayload | None = None
+
+ @classmethod
+ def from_request(cls, request: "Request") -> "OmniStreamingUpdate | None":
+ if not request.resumable:
+ return None
+ return cls(
+ mm_features=request.mm_features,
+ prompt_token_ids=request.prompt_token_ids,
+ max_tokens=request.max_tokens,
+ arrival_time=request.arrival_time,
+ sampling_params=request.sampling_params,
+ additional_information=request.additional_information,
+ )
diff --git a/vllm_omni/transformers_utils/configs/__init__.py b/vllm_omni/transformers_utils/configs/__init__.py
index 0aa3624f802..598ac3a9655 100644
--- a/vllm_omni/transformers_utils/configs/__init__.py
+++ b/vllm_omni/transformers_utils/configs/__init__.py
@@ -19,6 +19,11 @@
"FishSpeechFastARConfig": "vllm_omni.transformers_utils.configs.fish_speech",
"VoxCPMConfig": "vllm_omni.transformers_utils.configs.voxcpm",
"VoxCPM2Config": "vllm_omni.transformers_utils.configs.voxcpm2",
+ "BailingMoeV2Config": "vllm_omni.transformers_utils.configs.ming_flash_omni",
+ "BailingMM2Config": "vllm_omni.transformers_utils.configs.ming_flash_omni",
+ "MingFlashOmniConfig": "vllm_omni.transformers_utils.configs.ming_flash_omni",
+ "Qwen3VLMoeVisionConfig": "vllm_omni.transformers_utils.configs.ming_flash_omni",
+ "WhisperEncoderConfig": "vllm_omni.transformers_utils.configs.ming_flash_omni",
}
__all__ = [
@@ -31,6 +36,11 @@
"FishSpeechFastARConfig",
"VoxCPMConfig",
"VoxCPM2Config",
+ "BailingMoeV2Config",
+ "BailingMM2Config",
+ "MingFlashOmniConfig",
+ "Qwen3VLMoeVisionConfig",
+ "WhisperEncoderConfig",
]
@@ -51,5 +61,6 @@ def __dir__():
# run as soon as `vllm_omni.transformers_utils.configs` is imported.
from vllm_omni.transformers_utils.configs import fish_speech as _fish_speech # noqa: F401, E402
from vllm_omni.transformers_utils.configs import mammoth_moda2 as _mammoth_moda2 # noqa: F401, E402
+from vllm_omni.transformers_utils.configs import ming_flash_omni as _ming_flash_omni # noqa: F401, E402
from vllm_omni.transformers_utils.configs import voxcpm as _voxcpm # noqa: F401, E402
from vllm_omni.transformers_utils.configs import voxcpm2 as _voxcpm2 # noqa: F401, E402
diff --git a/vllm_omni/transformers_utils/configs/ming_flash_omni.py b/vllm_omni/transformers_utils/configs/ming_flash_omni.py
new file mode 100644
index 00000000000..dd13b682dee
--- /dev/null
+++ b/vllm_omni/transformers_utils/configs/ming_flash_omni.py
@@ -0,0 +1,302 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright 2024 ANT Group and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Configuration for Ming-flash-omni-2.0 model"""
+
+import os
+from typing import Any, ClassVar
+
+from transformers import AutoConfig, AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class BailingMoeV2Config(PretrainedConfig):
+ model_type = "bailing_moe_v2"
+
+ def __init__(
+ self,
+ vocab_size=30592,
+ hidden_size=1024,
+ intermediate_size=None,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ num_key_value_heads=0,
+ hidden_act="silu",
+ use_qkv_bias=False,
+ use_qk_norm=False,
+ use_bias=True,
+ rms_norm_eps=1e-05,
+ norm_head=False,
+ tie_word_embeddings=False,
+ embedding_dropout=0.0,
+ attention_dropout=0.0,
+ output_dropout=0.0,
+ initializer_range=0.02,
+ max_position_embeddings=16384,
+ rope_theta=10000.0,
+ use_cache=True,
+ use_sliding_window=False,
+ sliding_window=81920,
+ max_window_layers=28,
+ rope_scaling=None,
+ mrope_section=None,
+ pad_token_id=126081,
+ num_experts=16,
+ num_shared_experts=1,
+ num_experts_per_tok=2,
+ n_group=8,
+ topk_group=4,
+ routed_scaling_factor=2.5,
+ moe_intermediate_size=None,
+ first_k_dense_replace=0,
+ head_dim=None,
+ output_router_logits=False,
+ partial_rotary_factor=0.5,
+ router_type="topN",
+ _attn_implementation="flash_attention_2",
+ use_interleaved_frame_timestamp=True,
+ # Multimodal token IDs
+ image_patch_token=157157,
+ video_patch_token=157175,
+ audio_patch_token=157168,
+ image_start_token=157158,
+ video_start_token=157160,
+ audio_start_token=157169,
+ image_end_token=157159,
+ video_end_token=157161,
+ audio_end_token=157170,
+ # Position encoding parameters
+ spatial_merge_size=2,
+ tokens_per_second=2,
+ **kwargs,
+ ):
+ self.num_hidden_layers = num_hidden_layers
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_attention_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.use_qkv_bias = use_qkv_bias
+ self.use_bias = use_bias
+ self.norm_head = norm_head
+ self.rms_norm_eps = rms_norm_eps
+ self.embedding_dropout = embedding_dropout
+ self.attention_dropout = attention_dropout
+ self.output_dropout = output_dropout
+ self.initializer_range = initializer_range
+ self.max_position_embeddings = max_position_embeddings
+ self.rope_theta = rope_theta
+ self.use_cache = use_cache
+ self.use_sliding_window = use_sliding_window
+ self.sliding_window = sliding_window
+ self.max_window_layers = max_window_layers
+ self.head_dim = head_dim or self.hidden_size // self.num_attention_heads
+ self.use_qk_norm = use_qk_norm # arg unused; QK norm is always applied
+
+ # By default, match the value of `mrope_section`
+ # to `apply_3d_rotary_pos_emb` in Ming's repo:
+ # https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/modeling_bailing_moe_v2.py
+ if mrope_section is None:
+ mrope_section = (rope_scaling or {}).get("mrope_section", [8, 12, 12])
+ # Ensure mrope_section is stored inside rope_scaling
+ if rope_scaling is not None and isinstance(rope_scaling, dict):
+ rope_scaling = dict(rope_scaling)
+ rope_scaling.setdefault("mrope_section", mrope_section)
+ self.rope_scaling = rope_scaling
+
+ # NOTE: Expose rope_parameters["mrope_section"]
+ # This refers to the pattern used for GLM-Image in vllm_omni/patch.py
+ rope_type = (rope_scaling or {}).get("type", (rope_scaling or {}).get("rope_type", ""))
+ if rope_type in ("video_rope", "3D", "mrope"):
+ self.rope_parameters = {"mrope_section": mrope_section}
+ else:
+ self.rope_parameters = None
+
+ # MoE configs
+ self.num_experts = num_experts
+ self.num_shared_experts = num_shared_experts
+ self.num_experts_per_tok = num_experts_per_tok
+ self.n_group = n_group
+ self.topk_group = topk_group
+ self.moe_intermediate_size = moe_intermediate_size
+ self.first_k_dense_replace = first_k_dense_replace
+ self.output_router_logits = output_router_logits
+ self.routed_scaling_factor = routed_scaling_factor
+ self.partial_rotary_factor = partial_rotary_factor
+ self.router_type = router_type
+ self.use_interleaved_frame_timestamp = use_interleaved_frame_timestamp
+ self._attn_implementation = _attn_implementation
+
+ # Multimodal token IDs and position encoding
+ self.image_patch_token = image_patch_token
+ self.video_patch_token = video_patch_token
+ self.audio_patch_token = audio_patch_token
+ self.image_start_token = image_start_token
+ self.video_start_token = video_start_token
+ self.audio_start_token = audio_start_token
+ self.image_end_token = image_end_token
+ self.video_end_token = video_end_token
+ self.audio_end_token = audio_end_token
+ self.spatial_merge_size = spatial_merge_size
+ self.tokens_per_second = tokens_per_second
+
+ super().__init__(pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+class Qwen3VLMoeVisionConfig(PretrainedConfig):
+ """Configuration class for Qwen3 MoE Vision Transformer"""
+
+ model_type = "qwen3_moe_vit"
+
+ def __init__(
+ self,
+ depth=27,
+ hidden_size=1152,
+ hidden_act="gelu_pytorch_tanh",
+ intermediate_size=4304,
+ num_heads=16,
+ in_channels=3,
+ patch_size=16,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ out_hidden_size=3584,
+ num_position_embeddings=2304,
+ deepstack_visual_indexes=[8, 16, 24],
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.depth = depth
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+ self.out_hidden_size = out_hidden_size
+ self.num_position_embeddings = num_position_embeddings
+ self.initializer_range = initializer_range
+ self.deepstack_visual_indexes = deepstack_visual_indexes
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs) -> "PretrainedConfig":
+ cls._set_token_in_kwargs(kwargs)
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ if "vision_config" in config_dict:
+ config_dict = config_dict["vision_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class WhisperEncoderConfig(PretrainedConfig):
+ """Configuration class for Whisper audio encoder"""
+
+ model_type = "whisper_encoder"
+
+ def __init__(
+ self,
+ whisper_encoder_config: dict[str, Any] | None = None,
+ ds_kernel_size=3,
+ ds_stride=2,
+ norm_query_embeds=True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.whisper_encoder_config = whisper_encoder_config or {}
+ self.ds_kernel_size = ds_kernel_size
+ self.ds_stride = ds_stride
+ self.norm_query_embeds = norm_query_embeds
+
+
+class BailingMM2Config(PretrainedConfig):
+ model_type = "bailingmm_moe_v2_lite"
+ is_composition = True
+ sub_configs: ClassVar = {"llm_config": AutoConfig}
+
+ def __init__(
+ self,
+ mlp_depth=1,
+ llm_config: BailingMoeV2Config | None = None,
+ vision_config: Qwen3VLMoeVisionConfig | None = None,
+ audio_config: WhisperEncoderConfig | None = None,
+ **kwargs,
+ ):
+ self.audio_config = WhisperEncoderConfig(**audio_config) if isinstance(audio_config, dict) else audio_config
+ self.vision_config = (
+ Qwen3VLMoeVisionConfig(**vision_config) if isinstance(vision_config, dict) else vision_config
+ )
+ self.llm_config = BailingMoeV2Config(**llm_config) if isinstance(llm_config, dict) else llm_config
+ self.mlp_depth = mlp_depth
+ super().__init__(**kwargs)
+
+ def get_text_config(self, decoder: bool = False) -> PretrainedConfig: # noqa: ARG002
+ return self.llm_config
+
+
+class MingFlashOmniConfig(PretrainedConfig):
+ """Configuration class for unified Ming-flash-omni-2.0 model"""
+
+ model_type = "ming_flash_omni"
+ is_composition = True
+ sub_configs: ClassVar = {"thinker_config": BailingMM2Config}
+
+ def __init__(
+ self,
+ thinker_config: BailingMM2Config | None = None,
+ image_gen_config: dict[str, Any] | None = None,
+ talker_config: dict[str, Any] | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if isinstance(thinker_config, dict):
+ self.thinker_config = BailingMM2Config(**thinker_config)
+ else:
+ self.thinker_config = thinker_config or BailingMM2Config()
+
+ # Image generation config (for future implementation)
+ self.image_gen_config = image_gen_config
+
+ # Talker config (for future implementation)
+ self.talker_config = talker_config
+
+ def get_text_config(self, decoder: bool = False) -> PretrainedConfig: # noqa: ARG002
+ return self.thinker_config.get_text_config()
+
+
+# Register model_type -> config class for AutoConfig
+AutoConfig.register(BailingMoeV2Config.model_type, BailingMoeV2Config)
+AutoConfig.register(BailingMM2Config.model_type, BailingMM2Config)
+AutoConfig.register(MingFlashOmniConfig.model_type, MingFlashOmniConfig)
+
+# Register tokenizer mapping for composition configs so that
+# AutoTokenizer.from_pretrained can resolve the tokenizer class
+AutoTokenizer.register(BailingMM2Config, fast_tokenizer_class=PreTrainedTokenizerFast)
+AutoTokenizer.register(MingFlashOmniConfig, fast_tokenizer_class=PreTrainedTokenizerFast)
diff --git a/vllm_omni/transformers_utils/processors/__init__.py b/vllm_omni/transformers_utils/processors/__init__.py
new file mode 100644
index 00000000000..52ca6575397
--- /dev/null
+++ b/vllm_omni/transformers_utils/processors/__init__.py
@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+
+from vllm_omni.transformers_utils.processors.ming import (
+ MingFlashOmniProcessor,
+ MingWhisperFeatureExtractor,
+)
+
+__all__ = [
+ "MingFlashOmniProcessor",
+ "MingWhisperFeatureExtractor",
+]
diff --git a/vllm_omni/transformers_utils/processors/ming.py b/vllm_omni/transformers_utils/processors/ming.py
new file mode 100644
index 00000000000..7f414b7268c
--- /dev/null
+++ b/vllm_omni/transformers_utils/processors/ming.py
@@ -0,0 +1,430 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The vLLM-Omni team.
+# Copyright 2024 ANT Group and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+import numpy as np
+import torch
+from transformers import AutoFeatureExtractor, AutoProcessor
+from transformers.feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+
+DEFAULT_IMAGE_PATCH_TOKEN = ""
+DEFAULT_IM_START_TOKEN = ""
+DEFAULT_IM_END_TOKEN = " "
+DEFAULT_VID_START_TOKEN = ""
+DEFAULT_VID_END_TOKEN = " "
+DEFAULT_FRAME_PATCH_TOKEN = ""
+
+DEFAULT_AUDIO_PATCH_TOKEN = ""
+DEFAULT_AU_START_TOKEN = ""
+DEFAULT_AU_END_TOKEN = " "
+
+# High-level placeholders used in user prompts
+PLACEHOLDER_IMAGE_TOKEN_IN_TEXT = ""
+PLACEHOLDER_VIDEO_TOKEN_IN_TEXT = ""
+PLACEHOLDER_AUDIO_TOKEN_IN_TEXT = ""
+
+# Chat template constants
+USER_PREFIX = "HUMAN "
+ASSISTANT_PREFIX = "ASSISTANT "
+SYSTEM_PROMPT_NOTHINK = "SYSTEM 你是一个友好的AI助手。\n\ndetailed thinking off"
+SYSTEM_PROMPT_THINK = "SYSTEM 你是一个友好的AI助手。\n\ndetailed thinking on"
+
+
+_NORM_FACTOR_FOR_DTYPE = {
+ torch.int8: 2**7,
+ torch.int16: 2**15,
+ torch.int32: 2**31,
+ torch.int64: 2**63,
+ torch.float32: 1,
+ torch.float64: 1,
+}
+
+
+def _normalize_audio_tensor(
+ waveform: torch.Tensor,
+ sample_rate: int,
+ target_sample_rate: int = 16000,
+) -> torch.Tensor:
+ """Normalize waveform to float32, mono, and optionally resample."""
+ norm_factor = _NORM_FACTOR_FOR_DTYPE.get(waveform.dtype, 1)
+ waveform = waveform.to(torch.float32) / norm_factor
+
+ # Remove channel dimension
+ while len(waveform.shape) > 1:
+ waveform = waveform[0]
+
+ # Resample if needed
+ if sample_rate != target_sample_rate:
+ import torchaudio
+
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
+ waveform = resampler(waveform.unsqueeze(0)).squeeze(0)
+
+ return waveform
+
+
+class MingWhisperFeatureExtractor(FeatureExtractionMixin):
+ """Whisper log-mel feature extractor for Ming-flash-omni-2.0.
+
+ Produces audio_feats in the time-first packed format.
+
+ Adapted from Ming's WhisperAudioEncoder
+ https://github.com/inclusionAI/Ming/blob/070dc3c13f95d97952ab7d22030df0c9e28a5122/modeling_whisper_encoder.py
+ and HF transformers WhisperFeatureExtractor
+ https://github.com/huggingface/transformers/blob/f842abaca95a7dbf3fc6e16122e7409109bc1431/src/transformers/models/whisper/feature_extraction_whisper.py#L33
+ """
+
+ model_input_names = ["audio_feats", "audio_feats_lengths"]
+
+ def __init__(self, feature_size: int = 128, sampling_rate: int = 16000, **kwargs):
+ # feature_size == n_mels; stored so to_dict() serialises it correctly.
+ self.feature_size = feature_size
+ self.sampling_rate = sampling_rate
+ super().__init__(**kwargs)
+
+ @property
+ def n_mels(self) -> int:
+ return self.feature_size
+
+ def __call__(
+ self,
+ audios: tuple | list,
+ return_tensors: str | None = None,
+ **kwargs,
+ ) -> BatchFeature:
+ """Preprocess audio(s) into Whisper log-mel spectrograms"""
+ import whisper
+
+ if not isinstance(audios, list):
+ audios = [audios]
+
+ audio_feat_list = []
+ for waveform, sr in audios:
+ if isinstance(waveform, np.ndarray):
+ waveform = torch.from_numpy(waveform)
+ waveform = _normalize_audio_tensor(waveform, sr, target_sample_rate=self.sampling_rate)
+ mel = whisper.log_mel_spectrogram(waveform, n_mels=self.n_mels)
+ audio_feat_list.append(mel.transpose(0, 1)) # [T, n_mels]
+
+ audio_feats_lengths = torch.tensor([[feat.shape[0] for feat in audio_feat_list]], dtype=torch.long)
+ # Two stride-2 convolutions in series:
+ # 1. WhisperAudioEncoder conv2: kernel=3, stride=2, padding=1
+ # (conv1 has stride=1 and does not change T)
+ # 2. AudioProjector Conv1d: kernel=3, stride=2, padding=1
+ # Combined: T → ((T-1)//2 + 1 - 1)//2 + 1
+ # See also: AudioProjector.compute_output_length()
+ encoder_feats_lengths = ((audio_feats_lengths - 3 + 2 * 1) // 2 + 1 - 3 + 2 * 1) // 2 + 1
+ audio_feats = torch.cat(audio_feat_list, dim=0).unsqueeze(0) # [1, T_total, n_mels]
+
+ data = {
+ # [1, T_total, n_mels], all audio clips concatenated
+ "audio_feats": audio_feats.numpy(),
+ # [1, n_audios], actual frame count
+ "audio_feats_lengths": audio_feats_lengths.numpy(),
+ # [1, n_audios]
+ "encoder_feats_lengths": encoder_feats_lengths,
+ }
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+class MingFlashOmniProcessor(ProcessorMixin):
+ """Top-level multimodal processor for Ming-flash-omni 2.0.
+
+ Adapted from Ming's BailingMM2Processor
+ https://github.com/inclusionAI/Ming/blob/3954fcb880ff5e61ff128bcf7f1ec344d46a6fe3/processing_bailingmm2.py
+
+ Subprocessors include:
+ - Qwen2VLImageProcessor (image/video)
+ - MingWhisperFeatureExtractor (modified audio processor using Whisper's log-mel spectrogram)
+ """
+
+ attributes = ["image_processor", "audio_processor", "tokenizer"]
+ image_processor_class = "AutoImageProcessor"
+ audio_processor_class = "AutoFeatureExtractor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ image_processor=None,
+ audio_processor=None,
+ tokenizer=None,
+ merge_size: int = 2,
+ **kwargs,
+ ):
+ # Enforce that all sub-processors exist
+ # Keep None defaults in the signature for HF ProcessorMixin compatibility
+ if image_processor is None:
+ raise ValueError("MingFlashOmniProcessor requires `image_processor`.")
+ if audio_processor is None:
+ raise ValueError("MingFlashOmniProcessor requires `audio_processor`.")
+ if tokenizer is None:
+ raise ValueError("MingFlashOmniProcessor requires `tokenizer`.")
+
+ self.spatial_merge_size = merge_size
+ self.image_token = PLACEHOLDER_IMAGE_TOKEN_IN_TEXT
+ self.video_token = PLACEHOLDER_VIDEO_TOKEN_IN_TEXT
+ self.audio_token = PLACEHOLDER_AUDIO_TOKEN_IN_TEXT
+ super().__init__(
+ image_processor=image_processor,
+ audio_processor=audio_processor,
+ tokenizer=tokenizer,
+ )
+
+ # Fall back to the tokenizer's own chat_template.
+ if self.chat_template is None:
+ self.chat_template = getattr(tokenizer, "chat_template", None)
+
+ def __call__(
+ self,
+ text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput],
+ images: Any | None = None,
+ videos: Any | None = None,
+ audios: tuple[np.ndarray, int] | list[tuple[np.ndarray, int]] | None = None,
+ **kwargs,
+ ) -> BatchFeature:
+ # This should always be parallel implementations that mirror
+ # `_get_prompt_updates` logic in Ming processor, and vice versa.
+ # Ensure text is a list
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list):
+ raise ValueError("text must be a string or list of strings")
+
+ data: dict[str, Any] = {}
+
+ if images is not None:
+ image_outputs = self.image_processor(
+ images=images,
+ videos=None,
+ return_tensors="pt",
+ **kwargs.get("images_kwargs", {}),
+ )
+ data.update(image_outputs)
+ if "image_grid_thw" in image_outputs:
+ text = self._expand_image_tokens(text, image_outputs["image_grid_thw"])
+
+ if videos is not None:
+ video_outputs = self.image_processor(
+ images=None,
+ videos=videos,
+ return_tensors="pt",
+ **kwargs.get("videos_kwargs", {}),
+ )
+ if "pixel_values" in video_outputs:
+ video_outputs["pixel_values_videos"] = video_outputs.pop("pixel_values")
+ if "image_grid_thw" in video_outputs:
+ video_outputs["video_grid_thw"] = video_outputs.pop("image_grid_thw")
+ data.update(video_outputs)
+ if "video_grid_thw" in video_outputs:
+ text = self._expand_video_tokens(text, video_outputs["video_grid_thw"])
+
+ if audios is not None:
+ audio_outputs = self.audio_processor(
+ audios,
+ return_tensors="pt",
+ **kwargs.get("audio_kwargs", {}),
+ )
+ data.update(audio_outputs)
+ if "encoder_feats_lengths" in audio_outputs:
+ text = self._expand_audio_tokens(text, audio_outputs["encoder_feats_lengths"])
+
+ text_outputs = self.tokenizer(
+ text,
+ return_tensors="pt",
+ **kwargs.get("text_kwargs", {}),
+ )
+ data.update(text_outputs)
+ return BatchFeature(data=data)
+
+ def _expand_image_tokens(
+ self,
+ text: list[str],
+ image_grid_thw: torch.Tensor,
+ special_token: str = PLACEHOLDER_IMAGE_TOKEN_IN_TEXT,
+ ) -> list[str]:
+ merge_size = self.spatial_merge_size
+ num_patches_per_image = torch.prod(image_grid_thw, dim=1) // (merge_size**2)
+ prompt_strings = []
+ image_index = 0
+ for sample in text:
+ num_images = sample.count(special_token)
+ if num_images > 0:
+ for i in range(image_index, num_images + image_index):
+ num_patches = int(num_patches_per_image[i].item())
+ img_text = (
+ DEFAULT_IM_START_TOKEN + (DEFAULT_IMAGE_PATCH_TOKEN * num_patches) + DEFAULT_IM_END_TOKEN + "\n"
+ )
+ sample = sample.replace(special_token, img_text, 1)
+ image_index += num_images
+ prompt_strings.append(sample)
+ return prompt_strings
+
+ def _expand_video_tokens(
+ self,
+ text: list[str],
+ video_grid_thw: torch.Tensor,
+ special_token: str = PLACEHOLDER_VIDEO_TOKEN_IN_TEXT,
+ ) -> list[str]:
+ merge_size = self.spatial_merge_size
+ num_patches_per_video = torch.prod(video_grid_thw, dim=1) // (merge_size**2)
+ prompt_strings = []
+ video_index = 0
+ for sample in text:
+ num_videos = sample.count(special_token)
+ if num_videos > 0:
+ for i in range(video_index, num_videos + video_index):
+ num_patches = int(num_patches_per_video[i].item())
+ video_text = (
+ DEFAULT_VID_START_TOKEN
+ + (DEFAULT_FRAME_PATCH_TOKEN * num_patches)
+ + DEFAULT_VID_END_TOKEN
+ + "\n"
+ )
+ sample = sample.replace(special_token, video_text, 1)
+ video_index += num_videos
+ prompt_strings.append(sample)
+ return prompt_strings
+
+ def _expand_audio_tokens(
+ self,
+ text: list[str],
+ encoder_feats_lengths: torch.Tensor,
+ special_token: str = PLACEHOLDER_AUDIO_TOKEN_IN_TEXT,
+ ) -> list[str]:
+ prompt_strings = []
+ for sample, lengths_tensor in zip(text, encoder_feats_lengths):
+ for length in lengths_tensor:
+ num_patches = int(length.item())
+ if num_patches == 0:
+ continue
+ audio_text = DEFAULT_AU_START_TOKEN + (DEFAULT_AUDIO_PATCH_TOKEN * num_patches) + DEFAULT_AU_END_TOKEN
+ if special_token in sample:
+ sample = sample.replace(special_token, audio_text, 1)
+ else:
+ sample = sample + audio_text + "\n"
+ prompt_strings.append(sample)
+ return prompt_strings
+
+ def apply_system_template(
+ self,
+ sys_prompt_exp: str | None = None,
+ use_cot_system_prompt: bool = False,
+ ) -> str:
+ sys_prompt = SYSTEM_PROMPT_THINK if use_cot_system_prompt else SYSTEM_PROMPT_NOTHINK
+ if sys_prompt_exp is not None:
+ sys_prompt = sys_prompt.replace("你是一个友好的AI助手。", sys_prompt_exp)
+ return sys_prompt
+
+ def apply_chat_template(
+ self,
+ conversation: list[dict[str, Any]],
+ sys_prompt_exp: str | None = None,
+ use_cot_system_prompt: bool = False,
+ **kwargs,
+ ) -> str:
+ eos = self.tokenizer.eos_token
+ text = self.apply_system_template(sys_prompt_exp, use_cot_system_prompt) + eos
+
+ for idx, message in enumerate(conversation):
+ assert message["role"] in ["HUMAN", "ASSISTANT"], (
+ f"Invalid role: {message['role']}. Must be 'HUMAN' or 'ASSISTANT'"
+ )
+ if idx == len(conversation) - 1:
+ assert message["role"] == "HUMAN", "Last message must be from HUMAN"
+
+ text += USER_PREFIX if message["role"] == "HUMAN" else ASSISTANT_PREFIX
+
+ content = message["content"]
+ if isinstance(content, str):
+ # text-only
+ text += content
+ elif isinstance(content, list):
+ # structured content with multimodal elements
+ # Count existing placeholders from text items only
+ image_placeholders = 0
+ video_placeholders = 0
+ audio_placeholders = 0
+ for content_item in content:
+ if content_item.get("type", "text") == "text":
+ t = content_item.get("text", "")
+ image_placeholders += t.count(PLACEHOLDER_IMAGE_TOKEN_IN_TEXT)
+ video_placeholders += t.count(PLACEHOLDER_VIDEO_TOKEN_IN_TEXT)
+ audio_placeholders += t.count(PLACEHOLDER_AUDIO_TOKEN_IN_TEXT)
+
+ if video_placeholders > 1:
+ raise ValueError("Video count must be at most 1 per message!")
+
+ # Insert placeholders only for media items not already covered
+ for content_item in content:
+ content_type = content_item.get("type", "text")
+
+ if content_type == "image":
+ image_data = content_item.get("image")
+ if image_data is not None:
+ from PIL import Image as PILImage
+
+ num_images = 1 if isinstance(image_data, (str, PILImage.Image)) else len(image_data)
+ for _ in range(num_images):
+ if image_placeholders > 0:
+ image_placeholders -= 1
+ else:
+ text += PLACEHOLDER_IMAGE_TOKEN_IN_TEXT
+
+ elif content_type == "video":
+ if video_placeholders > 0:
+ video_placeholders -= 1
+ else:
+ text += PLACEHOLDER_VIDEO_TOKEN_IN_TEXT
+ elif content_type == "audio":
+ audio_data = content_item.get("audio")
+ if audio_data is not None:
+ num_audios = 1 if isinstance(audio_data, str) else len(audio_data)
+ for _ in range(num_audios):
+ if audio_placeholders > 0:
+ audio_placeholders -= 1
+ else:
+ text += PLACEHOLDER_AUDIO_TOKEN_IN_TEXT
+
+ elif content_type == "text":
+ text += content_item.get("text", "")
+
+ # Add EOS token after each message except the last one
+ text += eos
+
+ text += ASSISTANT_PREFIX
+ return text
+
+ def batch_decode(self, *args, **kwargs):
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ names = (
+ self.tokenizer.model_input_names
+ + self.image_processor.model_input_names
+ + self.audio_processor.model_input_names
+ )
+ return list(dict.fromkeys(names))
+
+
+AutoFeatureExtractor.register("MingWhisperFeatureExtractor", MingWhisperFeatureExtractor)
+AutoProcessor.register("MingFlashOmniProcessor", MingFlashOmniProcessor)
diff --git a/vllm_omni/version.py b/vllm_omni/version.py
index e5f0b6b661d..296bebc8e20 100644
--- a/vllm_omni/version.py
+++ b/vllm_omni/version.py
@@ -5,12 +5,12 @@
and written to _version.py during package build.
"""
+import warnings
+
try:
# Import auto-generated version from _version.py (created by setuptools_scm)
from ._version import __version__, __version_tuple__
except ImportError as e:
- import warnings
-
warnings.warn(
f"Failed to import version from _version.py: {e}\n"
"This typically happens in development mode before building.\n"
@@ -22,4 +22,37 @@
__version__ = "dev"
__version_tuple__ = (0, 0, "dev")
+
+def warn_if_misaligned_vllm_version():
+ """Warn if vLLM and vllm-omni versions don't match (major.minor)."""
+ # Import vllm lazily since import order may be sensitive with current monkeypatching,
+ # but we want to check this before potentially breaking imports run.
+ from vllm import __version__ as vllm_version
+ from vllm import __version_tuple__ as vllm_version_tuple
+
+ omni_ver: tuple[str | int, ...] = __version_tuple__[:2]
+ vllm_ver: tuple[str | int, ...] = vllm_version_tuple[:2]
+ # Skip if either version is dev (0, 0)
+ if omni_ver == (0, 0) or vllm_ver == (0, 0):
+ return
+
+ # Compare major.minor
+ if omni_ver != vllm_ver:
+ warnings.warn(
+ "vLLM and vLLM-Omni appear to have mismatched major/minor versions:\n"
+ f" --> vLLM-Omni version {__version__}\n"
+ f" --> vLLM version {vllm_version}\n"
+ "This will likely cause compatibility issues.",
+ RuntimeWarning,
+ stacklevel=2,
+ )
+
+
__all__ = ["__version__", "__version_tuple__"]
+
+# Run version check automatically when this module is imported
+try:
+ warn_if_misaligned_vllm_version()
+except ModuleNotFoundError:
+ # vLLM not installed (e.g., documentation builds)
+ pass
diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py
index de78011c75a..d1c15eac640 100644
--- a/vllm_omni/worker/gpu_model_runner.py
+++ b/vllm_omni/worker/gpu_model_runner.py
@@ -308,6 +308,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput"):
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
if req_id in self.requests:
+ self._update_streaming_input_additional_info(new_req_data, req_id)
req_state = self._update_streaming_request(req_id, new_req_data)
reqs_to_add.append(req_state)
continue
@@ -1414,3 +1415,30 @@ def _update_intermediate_buffer(self, req_id: str, upd: dict) -> None:
def _merge_additional_information_update(self, req_id, upd):
logger.warning_once("_merge_additional_information_update is deprecated, use _update_intermediate_buffer")
return self._update_intermediate_buffer(req_id, upd)
+
+ def _update_streaming_input_additional_info(self, new_req_data, req_id):
+ # For streaming input prefill case only. Update buffer from last segment input
+ cached_additional_info = self.model_intermediate_buffer.get(req_id, {})
+ if cached_additional_info:
+ payload_info = getattr(new_req_data, "additional_information", None)
+ inc_info = deserialize_additional_information(payload_info)
+ if isinstance(inc_info, dict) and inc_info:
+ merged_info = dict(cached_additional_info)
+ for key, value in inc_info.items():
+ accumulated_keys: set[str] = set()
+ if hasattr(self, "model") and hasattr(self.model, "streaming_accumulated_keys"):
+ accumulated_keys = self.model.streaming_accumulated_keys
+ if key in accumulated_keys and isinstance(value, torch.Tensor):
+ inc_tensor = value.detach().to("cpu").contiguous()
+ old_tensor = merged_info.get(key)
+ if old_tensor is None:
+ merged_info[key] = inc_tensor
+ else:
+ merged_info[key] = torch.cat((old_tensor, inc_tensor), dim=0)
+ continue
+
+ # Default for other keys: latest value.
+ merged_info[key] = value
+ merged_info["num_processed_tokens"] = 0
+ self.model_intermediate_buffer[req_id] = merged_info
+ setattr(self.requests[req_id], "additional_information_cpu", merged_info)